idefics3.py 27.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Idefics3 model compatible with HuggingFace weights."""

import math
19
from collections.abc import Iterable, Mapping, Sequence
20
from typing import Dict, Literal, Optional, Set, Tuple, TypedDict, Union
21
22
23

import torch
from torch import nn
24
25
from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
                          Idefics3Processor)
26

27
from vllm.config import VllmConfig
28
29
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
30
from vllm.model_executor.layers.quantization import QuantizationConfig
31
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
32
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
33
from vllm.model_executor.models.module_mapping import MultiModelKeys
34
from vllm.model_executor.sampling_metadata import SamplingMetadata
35
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
36
37
38
from vllm.multimodal.parse import ImageProcessorItems, ImageSize
# yapf conflicts with isort for this block
# yapf: disable
39
40
41
42
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo,
                                        MultiModalDataItems,
                                        MultiModalFieldConfig,
43
                                        PromptReplacement, PromptUpdate,
44
                                        PromptUpdateDetails)
45
# yapf: enable
46
47
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
48
49
50
51
52

# yapf: disable
from .idefics2_vision_model import (
    Idefics2VisionTransformer as Idefics3VisionTransformer)
# yapf: enable
53
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
54
from .llama import LlamaModel
55
56
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                    merge_multimodal_embeddings)
57
58
59
60


class Idefics3ImagePixelInputs(TypedDict):
    # Discriminator tag: this input carries raw pixel tiles.
    type: Literal["pixel_values"]

    # Shape: `(batch_size * num_images * num_patches,
    #          num_channels, height, width)`
    pixel_values: torch.Tensor

    # Pixel-level attention mask paired with `pixel_values`; the model
    # unfolds it into a per-patch mask before running the vision encoder.
    pixel_attention_mask: torch.Tensor

    # Shape: `(batch_size * num_images)` -- number of tiles (grid tiles plus
    # the global view) that each image contributes to `pixel_values`.
    num_patches: torch.Tensor

class Idefics3ImageEmbeddingInputs(TypedDict):
    # Discriminator tag: this input carries precomputed image embeddings.
    type: Literal["image_embeds"]

    # Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
    # `hidden_size` must match the hidden size of language model backbone.
    data: torch.Tensor


# Either form of image input accepted by Idefics3ForConditionalGeneration;
# the `type` key tells them apart.
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]


84
class Idefics3ProcessingInfo(BaseProcessingInfo):
    """Processing helpers for Idefics3.

    Wraps the HF `Idefics3Processor` and mirrors its image-tiling geometry so
    that the number of patches and placeholder tokens per image can be
    computed without running the processor itself.
    """

    def get_hf_processor(
        self,
        *,
        size: Optional[Dict[str, int]] = None,
        **kwargs: object,
    ) -> Idefics3Processor:
        """Return the HF processor, optionally with an overridden `size`.

        Args:
            size: Optional override for the image processor's `size`
                (this class reads its `'longest_edge'` entry).
            **kwargs: Forwarded unchanged to `self.ctx.get_hf_processor`.
        """
        if size is not None:
            kwargs["size"] = size

        return self.ctx.get_hf_processor(Idefics3Processor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Only the image modality is supported; no explicit count is set."""
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the maximum number of placeholder tokens for one image.

        `seq_len` and `mm_counts` are part of the interface but unused here.
        """
        return {"image": self.get_max_image_tokens()}

    def _resize_output_size(self,
                            *,
                            height: int,
                            width: int,
                            max_len: Optional[int] = None,
                            min_len: int = 1,
                            max_size: Optional[int] = None) -> tuple[int, int]:
        """Rescale so the longest edge equals `max_len`, keeping aspect ratio.

        Note that this also enlarges images whose longest edge is shorter
        than `max_len`.

        Args:
            height: Input image height.
            width: Input image width.
            max_len: Target length of the longest edge; defaults to the
                current longest edge.
            min_len: Lower bound applied to both output dimensions.
            max_size: Optional cap applied to `max_len`.

        Returns:
            `(height, width)` after rescaling. Each dimension is rounded up
            to an even number, then clamped to at least `min_len`.
        """
        # Set default value for max_len if not provided
        max_len = max(height, width) if max_len is None else max_len
        aspect_ratio = width / height

        # Handle the maximum size constraint
        if max_size is not None:
            max_len = min(max_len, max_size)

        # Adjust dimensions according to the aspect ratio
        if width >= height:
            width = max_len
            height = int(width / aspect_ratio)
        else:
            height = max_len
            width = int(height * aspect_ratio)

        # Ensure both width and height are even (if needed)
        height += height % 2
        width += width % 2

        # Ensure dimensions are not smaller than the minimum length
        height = max(height, min_len)
        width = max(width, min_len)

        return height, width

    def _get_resize_output_image_size(
        self,
        *,
        image_width: int,
        image_height: int,
        resolution_max_side: int,
        processor: Optional[Idefics3Processor] = None,
    ) -> tuple[int, int]:
        """Return `(height, width)` after rescaling the longest edge.

        Args:
            image_width: Original image width.
            image_height: Original image height.
            resolution_max_side: Target length of the longest edge.
            processor: HF processor whose `size` bounds
                `resolution_max_side`. Defaults to `self.get_hf_processor()`.
                Callers holding an overridden processor (e.g. from mm_kwargs)
                must pass it, otherwise the bound is checked against the
                default processor's `size`.

        Raises:
            ValueError: If `resolution_max_side` exceeds the processor's
                `size['longest_edge']`.
        """
        if processor is None:
            processor = self.get_hf_processor()
        image_processor: Idefics3ImageProcessor = processor.image_processor
        max_image_size = image_processor.size['longest_edge']
        if resolution_max_side > max_image_size:
            raise ValueError(
                "`resolution_max_side` cannot be larger than `max_image_size`")

        # Find the output size, when rescaling the longest edge to max_len and
        # preserving the aspect ratio
        return self._resize_output_size(height=image_height,
                                        width=image_width,
                                        max_len=resolution_max_side)

    def _get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Idefics3Processor],
    ) -> tuple[int, int]:
        """Return `(grid_w, grid_h)`: the tile counts along each axis.

        The image is first rescaled so its longest edge equals
        `size['longest_edge']`. It is then cut into tiles of
        `max_image_size['longest_edge']`. Returns `(0, 0)` when the rescaled
        image fits in a single tile.

        Raises:
            ValueError: If `size['longest_edge']` is not a multiple of
                `max_image_size['longest_edge']`.
        """
        if processor is None:
            processor = self.get_hf_processor()

        image_processor: Idefics3ImageProcessor = processor.image_processor

        max_image_size = image_processor.max_image_size['longest_edge']
        size = image_processor.size['longest_edge']
        # Explicit check instead of `assert`: this validates user-supplied
        # mm_kwargs and must not disappear under `python -O`.
        if size % max_image_size != 0:
            raise ValueError(
                "`longest_edge` in image_processor's `size` must be divisible by "
                "`longest_edge` in `max_image_size`, this may be caused by "
                "incorrect mm_kwargs override.")

        resized_height, resized_width = self._get_resize_output_image_size(
            image_width=image_width,
            image_height=image_height,
            resolution_max_side=size,
            # Use the same (possibly overridden) processor for the bound check.
            processor=processor,
        )
        if resized_height > max_image_size or resized_width > max_image_size:
            grid_h = math.ceil(resized_height / max_image_size)
            grid_w = math.ceil(resized_width / max_image_size)
        else:
            grid_h = grid_w = 0
        return grid_w, grid_h

    def get_num_patches(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Idefics3Processor],
    ) -> int:
        """Return the number of tiles for one image: grid tiles + global."""
        grid_w, grid_h = self._get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )

        return grid_w * grid_h + 1

    def _get_image_token(
            self,
            processor: Optional[Idefics3Processor]) -> tuple[str, str, str]:
        """Return `(image_token, fake_image_token, global_image_tag)`."""
        if processor is None:
            processor = self.get_hf_processor()
        image_token = processor.image_token.content
        fake_image_token = processor.fake_image_token.content
        global_image_token = processor.global_image_tag
        return image_token, fake_image_token, global_image_token

    def get_image_repl(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Idefics3Processor],
    ) -> str:
        """Build the full placeholder text that replaces one image token.

        Each grid tile becomes `fake_image_token + <row_i_col_j> +
        image_token * image_seq_len`, and each row ends with a newline. The
        tiles are followed by a blank line, the global-image placeholder and
        a closing fake token. Images needing no tiling get only the global
        placeholder plus the closing fake token.
        """
        if processor is None:
            processor = self.get_hf_processor()

        image_token, fake_image_token, global_img_token = self._get_image_token(
            processor)
        image_seq_len = processor.image_seq_len
        grid_placeholder = "<row_{n_h}_col_{n_w}>"

        p_img = image_token * image_seq_len
        global_img_placeholder = fake_image_token + global_img_token + p_img
        tile_img_placeholder = fake_image_token + grid_placeholder + p_img

        grid_w, grid_h = self._get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )
        if grid_w == 0 and grid_h == 0:
            return global_img_placeholder + fake_image_token

        tiles_placeholder = list[str]()
        for i in range(grid_h):
            for j in range(grid_w):
                placeholder_per_tile = tile_img_placeholder.format(n_h=i + 1,
                                                                   n_w=j + 1)
                tiles_placeholder.append(placeholder_per_tile)
                # Add line break if it is the last tile in the row
                if j == grid_w - 1:
                    tiles_placeholder.append("\n")

        return "".join([
            *tiles_placeholder,
            "\n",
            global_img_placeholder,
            fake_image_token,
        ])

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Idefics3Processor],
    ) -> int:
        """Return the number of `image_token`s (embeddings) for one image."""
        if processor is None:
            processor = self.get_hf_processor()

        num_patches = self.get_num_patches(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )

        return num_patches * processor.image_seq_len

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square image whose side is `size['longest_edge']`."""
        processor = self.get_hf_processor()
        image_processor: Idefics3ImageProcessor = processor.image_processor

        return ImageSize(
            width=image_processor.size["longest_edge"],
            height=image_processor.size["longest_edge"],
        )

    def get_max_image_tokens(self) -> int:
        """Return the image token count for the largest-feature image size."""
        target_width, target_height = self.get_image_size_with_most_features()

        return self.get_num_image_tokens(
            image_width=target_width,
            image_height=target_height,
            processor=None,
        )

299

300
301
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
                                 ):
    """Builds dummy prompts and images used to profile Idefics3."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return a prompt with one image token per requested dummy image.

        Each dummy image is a square whose side is the image processor's
        `max_image_size['longest_edge']`. `seq_len` is not used.
        """
        image_count = mm_counts.get("image", 0)
        processor = self.info.get_hf_processor()
        tile_edge = processor.image_processor.max_image_size['longest_edge']
        image_token = self.info._get_image_token(processor)[0]

        dummy_images = self._get_dummy_images(width=tile_edge,
                                              height=tile_edge,
                                              num_images=image_count)

        return ProcessorInputs(
            prompt_text=image_token * image_count,
            mm_data={"image": dummy_images},
        )


327
class Idefics3MultiModalProcessor(
        BaseMultiModalProcessor[Idefics3ProcessingInfo]):
    """Runs the HF Idefics3 processor and expands image placeholders.

    Each image token in the prompt is replaced by the tile/global placeholder
    text built by `Idefics3ProcessingInfo.get_image_repl`.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and attach a per-image `num_patches` tensor.

        Prompts without images are only tokenized. For image prompts,
        `pixel_values` and `pixel_attention_mask` have their leading
        dimension squeezed in place (`squeeze_(0)` is a no-op unless that
        dimension has size 1).
        """
        images = mm_data.get("images", [])

        # Text-only input not supported in composite processor
        if not images:
            token_ids = self.info.get_tokenizer().encode(prompt)
            token_ids = self._apply_hf_processor_tokens_only(token_ids)
            return BatchFeature(dict(input_ids=[token_ids]), tensor_type="pt")

        outputs = super()._call_hf_processor(prompt, mm_data, mm_kwargs)

        image_items = self._get_data_parser().parse_mm_data({
            "image": images
        }).get_items("image", ImageProcessorItems)
        hf_processor = self.info.get_hf_processor(**mm_kwargs)

        patch_counts = []
        for idx in range(len(image_items)):
            dims = image_items.get_image_size(idx)
            patch_counts.append(
                self.info.get_num_patches(
                    image_width=dims.width,
                    image_height=dims.height,
                    processor=hf_processor,
                ))
        outputs["num_patches"] = torch.tensor(patch_counts)

        # Remove the extra batch dimension
        for key in ("pixel_values", "pixel_attention_mask"):
            outputs[key].squeeze_(0)

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output is split per image.

        Pixel tensors are flat over all tiles and split using `num_patches`.
        Embeddings and patch counts are batched with one entry per image.
        """
        patches_per_image = hf_inputs.get("num_patches", torch.empty(0))

        return {
            "pixel_values":
            MultiModalFieldConfig.flat_from_sizes("image", patches_per_image),
            "pixel_attention_mask":
            MultiModalFieldConfig.flat_from_sizes("image", patches_per_image),
            "image_embeds":
            MultiModalFieldConfig.batched("image"),
            "num_patches":
            MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> Sequence[PromptUpdate]:
        """Replace every image token with that image's full placeholder.

        Only the `image_token` positions inside the replacement are marked
        as embedding slots.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token = self.info._get_image_token(hf_processor)[0]

        def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails:
            dims = mm_items.get_items(
                "image", ImageProcessorItems).get_image_size(item_idx)
            replacement_text = self.info.get_image_repl(
                image_width=dims.width,
                image_height=dims.height,
                processor=hf_processor,
            )
            return PromptUpdateDetails.select_text(replacement_text,
                                                   embed_text=image_token)

        return [
            PromptReplacement(modality="image",
                              target=image_token,
                              replacement=get_replacement_idefics3)
        ]


class Idefics3SimpleMLP(nn.Module):
    """Bias-free linear projection into the text model's hidden size.

    The input width is `vision_config.hidden_size * scale_factor**2`, i.e.
    the channel width produced by the connector's pixel shuffle.
    """

    def __init__(
        self,
        config: Idefics3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        vision_hidden = config.vision_config.hidden_size
        shuffle_area = config.scale_factor**2
        text_hidden = config.text_config.hidden_size

        self.proj = ReplicatedLinear(
            vision_hidden * shuffle_area,
            text_hidden,
            bias=False,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "proj"),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The linear layer returns a pair; only the first element is used.
        projected, _unused = self.proj(x)
        return projected


class Idefics3Connector(nn.Module):
    """Maps vision-encoder outputs into the text embedding space.

    A pixel shuffle first trades spatial resolution for channel depth by
    `config.scale_factor` per side. The result is then projected by
    `Idefics3SimpleMLP`.
    """

    def __init__(
        self,
        config: Idefics3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.scale_factor = config.scale_factor
        projection_prefix = maybe_prefix(prefix, "modality_projection")
        self.modality_projection = Idefics3SimpleMLP(
            config,
            quant_config,
            prefix=projection_prefix,
        )

    def pixel_shuffle(self,
                      x: torch.Tensor,
                      scale_factor: int = 2) -> torch.Tensor:
        """Fold each `scale_factor x scale_factor` neighbourhood into channels.

        Args:
            x: Tensor of shape `(batch, seq, embed_dim)`. `seq` is treated
                as a square grid of side `int(seq ** 0.5)`.
            scale_factor: Downsampling factor per spatial side.

        Returns:
            Tensor of shape
            `(batch, seq / scale_factor**2, embed_dim * scale_factor**2)`.
        """
        batch, seq_len, dim = x.size()
        side = int(seq_len**0.5)
        reduced_side = int(side / scale_factor)
        merged_dim = dim * (scale_factor**2)

        # Merge neighbouring columns into the channel dimension.
        out = x.view(batch, side, side, dim).view(batch, side, reduced_side,
                                                  dim * scale_factor)
        # Transpose so neighbouring rows can be merged the same way.
        out = out.permute(0, 2, 1, 3).reshape(batch, reduced_side,
                                              reduced_side, merged_dim)
        # Restore the original axis order and flatten back into a sequence.
        out = out.permute(0, 2, 1, 3)
        return out.reshape(batch, int(seq_len / (scale_factor**2)),
                           merged_dim)

    def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor:
        shuffled = self.pixel_shuffle(image_hidden_states, self.scale_factor)
        return self.modality_projection(shuffled)


class Idefics3Model(nn.Module):
    """Idefics3 backbone: vision tower, connector and Llama text model."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: Idefics3Config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        vision_cfg = config.vision_config

        self.config = config
        self.vocab_size = config.text_config.vocab_size

        # Submodules are registered in the order vision -> connector -> text.
        self.vision_model = Idefics3VisionTransformer(
            vision_cfg,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "vision_model"))
        self.connector = Idefics3Connector(
            config,
            quant_config,
            prefix=maybe_prefix(prefix, "connector"),
        )
        self.text_model = LlamaModel(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "text_model"),
        )

        # Tokens per tile: (patches per side)^2, reduced by the pixel shuffle.
        patches_per_side = vision_cfg.image_size // vision_cfg.patch_size
        self.image_seq_len = int(
            (patches_per_side**2) / (config.scale_factor**2))
        self.image_token_id = config.image_token_id

    def image_pixels_to_features(
        self,
        pixel_values: torch.Tensor,
        pixel_attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Encode pixel tiles with the vision tower.

        Tiles whose every value is exactly 0 are dropped as padding, together
        with their mask rows. The pixel mask is reduced to a per-patch mask
        (a patch is kept if any pixel in it is set).

        NOTE(review): a genuine tile whose normalized values are all 0 would
        also be dropped. That would desynchronise the per-image split by
        `num_patches` done by the caller. Confirm this cannot happen.
        """
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        target_dtype = self.vision_model.embeddings.patch_embedding.weight.dtype
        pixel_values = pixel_values.to(dtype=target_dtype)  # fp16 compatibility

        # Remove padding images - padding images are full 0.
        values_per_tile = pixel_values.shape[1:].numel()
        zero_counts = (pixel_values == 0.0).sum(dim=(-1, -2, -3))
        keep = zero_counts != values_per_tile
        pixel_values = pixel_values[keep].contiguous()
        pixel_attention_mask = pixel_attention_mask[keep].contiguous()

        # Collapse each patch_size x patch_size block of the pixel mask into
        # a single boolean: True when any pixel in the block is set.
        patch = self.config.vision_config.patch_size
        blocks = pixel_attention_mask.unfold(dimension=1,
                                             size=patch,
                                             step=patch)
        blocks = blocks.unfold(dimension=2, size=patch, step=patch)
        patch_attention_mask = (blocks.sum(dim=(-1, -2)) > 0).bool()

        # Get sequence from the vision encoder
        return self.vision_model(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
        )

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate token embedding lookup to the text model."""
        return self.text_model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run only the text model on the given ids / embeddings."""
        return self.text_model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )


580
@MULTIMODAL_REGISTRY.register_processor(
    Idefics3MultiModalProcessor,
    info=Idefics3ProcessingInfo,
    dummy_inputs=Idefics3DummyInputsBuilder)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsLoRA):
    """Idefics3 vision-language model for conditional generation.

    Wraps `Idefics3Model` (vision tower, connector, Llama text model) with an
    LM head, a logits processor and a sampler.
    """
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.model = Idefics3Model(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.image_token_id = self.config.image_token_id

        self.lm_head = ParallelLMHead(
            config.text_config.vocab_size,
            config.text_config.hidden_size,
            quant_config=quant_config,
        )
        if self.config.text_config.tie_word_embeddings:
            # The text backbone is a LlamaModel, whose input embedding module
            # is `embed_tokens`; a `wte` lookup fails with AttributeError.
            self.lm_head.weight = self.model.text_model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
        self.sampler = get_sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check every tile has shape `(3, image_size, image_size)`.

        Raises:
            ValueError: On the first tile with a different shape.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"per patch is {expected_expr}. "
                    f"You supplied {actual_dims}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[ImageInputs]:
        """Build a typed image input from forward kwargs.

        Returns None when neither `pixel_values` nor `image_embeds` is given.
        Precomputed embeddings take precedence over pixel values.

        Raises:
            ValueError: If an input has an unexpected container type or the
                pixel tiles have the wrong shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return Idefics3ImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            pixel_attention_mask = kwargs.pop("pixel_attention_mask")
            if not isinstance(pixel_attention_mask, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel_attention_mask. "
                                 f"Got type: {type(pixel_attention_mask)}")

            num_patches = kwargs.pop("num_patches")
            if not isinstance(num_patches, (torch.Tensor, list)):
                raise ValueError("Incorrect type of num_patches. "
                                 f"Got type: {type(num_patches)}")

            pixel_values = flatten_bn(pixel_values, concat=True)
            pixel_attention_mask = flatten_bn(pixel_attention_mask,
                                              concat=True)
            num_patches = flatten_bn(num_patches, concat=True)

            return Idefics3ImagePixelInputs(
                type="pixel_values",
                pixel_values=self._validate_pixel_values(pixel_values),
                pixel_attention_mask=pixel_attention_mask,
                num_patches=num_patches,
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_pixels(
            self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor:
        """Run the vision tower on the pixel tiles of `inputs`."""
        pixel_values = inputs["pixel_values"]
        pixel_attention_mask = inputs["pixel_attention_mask"]

        return self.model.image_pixels_to_features(
            pixel_values,
            pixel_attention_mask=pixel_attention_mask,
        )

    def _process_image_input(
        self,
        image_input: ImageInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Return image embeddings.

        Precomputed embeddings are returned as-is. Pixel inputs yield one
        tensor per image, with that image's tiles flattened together.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        image_features = self._process_image_pixels(image_input)
        image_features = self.model.connector(image_features)

        num_patches = image_input["num_patches"]
        # Regroup the flat tile features per image, then merge each image's
        # tiles into one token sequence.
        return [
            e.flatten(0, 1) for e in image_features.split(num_patches.tolist())
        ]

    def get_multimodal_embeddings(
            self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
        """Return image embeddings for the kwargs, or None if no image."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None

        return self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed tokens, then write image embeddings at image-token slots."""
        inputs_embeds = self.model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids,
                inputs_embeds,
                multimodal_embeddings,
                self.config.image_token_id,
            )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the text model.

        When no embeddings are supplied, image embeddings are computed and
        merged first.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.model.text_model(input_ids,
                                              positions,
                                              intermediate_tensors,
                                              inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from `logits`."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights and return the names of loaded params."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="model.text_model",
            connector="model.connector",
            tower_model="model.vision_model")