# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Idefics3 model compatible with HuggingFace weights."""

import math
20
from collections.abc import Iterable, Mapping, Sequence
21
from typing import Annotated, Literal, Optional, Union
22
23
24

import torch
from torch import nn
25
26
27
28
29
30
from transformers import (
    BatchFeature,
    Idefics3Config,
    Idefics3ImageProcessor,
    Idefics3Processor,
)
31

32
from vllm.config import VllmConfig
33
from vllm.config.multimodal import BaseDummyOptions
34
35
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
36
from vllm.model_executor.layers.quantization import QuantizationConfig
37
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
38
from vllm.model_executor.models.module_mapping import MultiModelKeys
39
from vllm.multimodal import MULTIMODAL_REGISTRY
40
41
42
43
44
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
45
from vllm.multimodal.parse import ImageProcessorItems, ImageSize
46

47
48
# yapf conflicts with isort for this block
# yapf: disable
49
50
51
52
53
54
55
56
57
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    MultiModalDataItems,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)

58
# yapf: enable
59
from vllm.multimodal.profiling import BaseDummyInputsBuilder
60
from vllm.sequence import IntermediateTensors
61
from vllm.utils.tensor_schema import TensorSchema, TensorShape
62
63
64

# yapf: disable
from .idefics2_vision_model import (
65
66
67
    Idefics2VisionTransformer as Idefics3VisionTransformer,
)

68
# yapf: enable
69
from .interfaces import MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal
70
from .llama import LlamaModel
71
from .utils import AutoWeightsLoader, maybe_prefix
72
73


74
class Idefics3ImagePixelInputs(TensorSchema):
    """
    Schema for image inputs supplied as raw pixel tiles.

    Dimensions:
        - bn: Batch size * number of images
        - bnp: Batch size * number of images * number of patches
        - c: Number of channels (3)
        - h: Height
        - w: Width
    """

    # Discriminator used to tell this schema apart from the embedding one.
    type: Literal["pixel_values"]
    # One entry per patch, channels-first.
    pixel_values: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
    # Per-pixel mask with the same patch count and spatial size as above.
    pixel_attention_mask: Annotated[torch.Tensor, TensorShape("bnp", "h", "w")]
    # Number of patches belonging to each image.
    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
88

89

90
class Idefics3ImageEmbeddingInputs(TensorSchema):
    """
    Schema for image inputs supplied as precomputed embeddings.

    Dimensions:
        - bn: Batch size * number of images
        - f: Image feature size
        - h: Hidden size (must match the hidden size of language model backbone)
    """

    # Discriminator used to tell this schema apart from the pixel one.
    type: Literal["image_embeds"]
    # Embeddings laid out as (bn, f, h).
    data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
100
101
102
103
104


# An image input is either raw pixel tiles or precomputed image embeddings.
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]


105
class Idefics3ProcessingInfo(BaseProcessingInfo):
    """Size arithmetic and placeholder construction for Idefics3 images.

    Every helper reads its settings (``size``, ``max_image_size``,
    ``image_seq_len`` and the special token strings) from the HF
    ``Idefics3Processor``; methods that take ``processor=None`` fall back to
    ``self.get_hf_processor()``.
    """

    def get_hf_processor(self, **kwargs: object) -> Idefics3Processor:
        """Return the ``Idefics3Processor`` obtained from the context."""
        return self.ctx.get_hf_processor(Idefics3Processor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        """Return the per-modality limits; only ``"image"`` is listed."""
        # NOTE(review): ``None`` presumably means "no fixed limit" - confirm
        # against the base-class contract.
        return {"image": None}

    def _resize_output_size(
        self,
        *,
        height: int,
        width: int,
        max_len: Optional[int] = None,
        min_len: int = 1,
        max_size: Optional[int] = None,
    ) -> tuple[int, int]:
        """Scale ``(height, width)`` so the longer side equals ``max_len``.

        Args:
            height: Original height.
            width: Original width.
            max_len: Target length of the longer side; defaults to the
                current longer side.
            min_len: Lower bound applied to both output dimensions.
            max_size: Optional upper bound applied to ``max_len``.

        Returns:
            ``(height, width)`` after rescaling, rounding odd values up to
            the next even number and clamping to ``min_len``.
        """
        # Set default value for max_len if not provided
        max_len = max(height, width) if max_len is None else max_len
        aspect_ratio = width / height

        # Handle the maximum size constraint
        if max_size is not None:
            max_len = min(max_len, max_size)

        # Adjust dimensions according to the aspect ratio; the shorter side
        # is truncated towards zero by ``int``.
        if width >= height:
            width = max_len
            height = int(width / aspect_ratio)
        else:
            height = max_len
            width = int(height * aspect_ratio)

        # Ensure both width and height are even (if needed)
        height += height % 2
        width += width % 2

        # Ensure dimensions are not smaller than the minimum length
        height = max(height, min_len)
        width = max(width, min_len)

        return height, width

    def _get_resize_output_image_size(
        self,
        *,
        image_width: int,
        image_height: int,
        resolution_max_side: int,
    ) -> tuple[int, int]:
        """Return ``(height, width)`` after resizing the longest edge.

        Raises:
            ValueError: If ``resolution_max_side`` exceeds the image
                processor's ``size["longest_edge"]``.
        """
        hf_processor = self.get_hf_processor()
        image_processor: Idefics3ImageProcessor = hf_processor.image_processor
        max_image_size = image_processor.size["longest_edge"]
        if resolution_max_side > max_image_size:
            raise ValueError(
                "`resolution_max_side` cannot be larger than `max_image_size`"
            )

        # Find the output size, when rescaling the longest edge to max_len and
        # preserving the aspect ratio
        height, width = self._resize_output_size(
            height=image_height, width=image_width, max_len=resolution_max_side
        )
        return height, width

    def _get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Idefics3Processor],
    ) -> tuple[int, int]:
        """Return the tile grid as ``(grid_w, grid_h)``.

        ``(0, 0)`` is returned when the resized image fits inside a single
        tile of side ``max_image_size["longest_edge"]``.
        """
        if processor is None:
            processor = self.get_hf_processor()

        image_processor: Idefics3ImageProcessor = processor.image_processor

        max_image_size = image_processor.max_image_size["longest_edge"]
        size = image_processor.size["longest_edge"]
        assert size % max_image_size == 0, (
            "`longest_edge` in image_processor's `size` must be divisible by "
            "`longest_edge` in `max_image_size`, this may be caused by "
            "incorrect mm_kwargs override."
        )

        resized_height, resized_width = self._get_resize_output_image_size(
            image_width=image_width,
            image_height=image_height,
            resolution_max_side=size,
        )
        if resized_height > max_image_size or resized_width > max_image_size:
            grid_h = math.ceil(resized_height / max_image_size)
            grid_w = math.ceil(resized_width / max_image_size)
        else:
            grid_h = grid_w = 0
        return grid_w, grid_h

    def get_num_patches(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Idefics3Processor],
    ) -> int:
        """Return the number of tiles in the grid plus one."""
        grid_w, grid_h = self._get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )

        # The "+ 1" matches the single global-image block that
        # ``get_image_repl`` always emits in addition to the tiles.
        return grid_w * grid_h + 1

    def _get_image_token(
        self, processor: Optional[Idefics3Processor]
    ) -> tuple[str, str, str]:
        """Return ``(image_token, fake_image_token, global_image_tag)``."""
        if processor is None:
            processor = self.get_hf_processor()

        image_token = processor.image_token
        fake_image_token = processor.fake_image_token
        global_image_token = processor.global_image_tag
        return image_token, fake_image_token, global_image_token

    def get_image_repl(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Idefics3Processor],
    ) -> str:
        """Build the text that replaces one image placeholder.

        Without a tile grid the result is the global block followed by a
        closing fake token. With a grid, one block per tile is emitted row by
        row (tagged ``<row_R_col_C>``, 1-based) with a newline after each
        row, then one more newline, the global block and a closing fake
        token.
        """
        if processor is None:
            processor = self.get_hf_processor()

        image_token, fake_image_token, global_img_token = self._get_image_token(
            processor
        )
        image_seq_len = processor.image_seq_len
        grid_placeholder = "<row_{n_h}_col_{n_w}>"

        p_img = image_token * image_seq_len
        global_img_placeholder = fake_image_token + global_img_token + p_img
        tile_img_placeholder = fake_image_token + grid_placeholder + p_img

        grid_w, grid_h = self._get_image_feature_grid_size(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )
        if grid_w == 0 and grid_h == 0:
            return global_img_placeholder + fake_image_token

        tiles_placeholder: list[str] = []
        for i in range(grid_h):
            for j in range(grid_w):
                placeholder_per_tile = tile_img_placeholder.format(
                    n_h=i + 1, n_w=j + 1
                )
                tiles_placeholder.append(placeholder_per_tile)
                # Add line break if it is the last tile in the row
                if j == grid_w - 1:
                    tiles_placeholder.append("\n")

        return "".join(
            [
                *tiles_placeholder,
                "\n",
                global_img_placeholder,
                fake_image_token,
            ]
        )

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[Idefics3Processor],
    ) -> int:
        """Return ``get_num_patches(...) * processor.image_seq_len``."""
        if processor is None:
            processor = self.get_hf_processor()

        num_patches = self.get_num_patches(
            image_width=image_width,
            image_height=image_height,
            processor=processor,
        )

        return num_patches * processor.image_seq_len

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return a square ``ImageSize`` of side ``size["longest_edge"]``."""
        processor = self.get_hf_processor()
        image_processor: Idefics3ImageProcessor = processor.image_processor

        return ImageSize(
            width=image_processor.size["longest_edge"],
            height=image_processor.size["longest_edge"],
        )

302

303
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]):
    """Builds dummy prompt text and dummy images for Idefics3."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the image token repeated once per requested image."""
        num_images = mm_counts.get("image", 0)

        processor = self.info.get_hf_processor()
        image_token, _, _ = self.info._get_image_token(processor)

        return image_token * num_images

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return dummy square images under the ``"image"`` key.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Number of items requested per modality; only
                ``"image"`` is read.
            mm_options: Optional per-modality overrides; the ``"image"``
                entry is forwarded to ``_get_dummy_images``.
        """
        num_images = mm_counts.get("image", 0)
        hf_processor = self.info.get_hf_processor()
        image_processor: Idefics3ImageProcessor = hf_processor.image_processor
        # Side length comes from ``max_image_size``, not from ``size``.
        longest_edge = image_processor.max_image_size["longest_edge"]

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=longest_edge,
                height=longest_edge,
                num_images=num_images,
                overrides=image_overrides,
            )
        }

334

335
class Idefics3MultiModalProcessor(BaseMultiModalProcessor[Idefics3ProcessingInfo]):
    """Multimodal processor glue between vLLM and the HF Idefics3 processor."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and attach per-image patch counts.

        For prompts without images only the tokenizer is used. Otherwise the
        parent implementation is called, ``num_patches`` is added to its
        output, and the leading dimension of ``pixel_values`` and
        ``pixel_attention_mask`` is squeezed in place.
        """
        # Text-only input not supported in composite processor
        if not (images := mm_data.get("images", [])):
            prompt_ids = self.info.get_tokenizer().encode(prompt)
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        processed_outputs = super()._call_hf_processor(
            prompt,
            mm_data,
            mm_kwargs,
            tok_kwargs,
        )

        # Re-parse the images to learn their sizes; the parser expects the
        # singular "image" key.
        parsed_images = (
            self._get_data_parser()
            .parse_mm_data({"image": images})
            .get_items("image", ImageProcessorItems)
        )
        image_sizes = [
            parsed_images.get_image_size(i) for i in range(len(parsed_images))
        ]
        hf_processor = self.info.get_hf_processor(**mm_kwargs)

        num_patches = [
            self.info.get_num_patches(
                image_width=size.width,
                image_height=size.height,
                processor=hf_processor,
            )
            for size in image_sizes
        ]
        processed_outputs["num_patches"] = torch.tensor(num_patches)

        # Remove the extra batch dimension
        processed_outputs["pixel_values"].squeeze_(0)
        processed_outputs["pixel_attention_mask"].squeeze_(0)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Describe how each processor output is split across images.

        Pixel tensors are flat and split by ``num_patches``;
        ``image_embeds`` and ``num_patches`` are batched per image.
        """
        num_patches = hf_inputs.get("num_patches", torch.empty(0))

        return dict(
            pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
            pixel_attention_mask=MultiModalFieldConfig.flat_from_sizes(
                "image", num_patches
            ),
            image_embeds=MultiModalFieldConfig.batched("image"),
            num_patches=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the replacement rule that expands each image token."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_token, _, _ = self.info._get_image_token(hf_processor)

        def get_replacement_idefics3(item_idx: int) -> PromptUpdateDetails:
            # The expansion depends on the size of the specific image.
            images = mm_items.get_items("image", ImageProcessorItems)

            image_size = images.get_image_size(item_idx)

            image_repl = self.info.get_image_repl(
                image_width=image_size.width,
                image_height=image_size.height,
                processor=hf_processor,
            )

            # ``image_token`` is passed as the text selected for embedding.
            return PromptUpdateDetails.select_text(
                image_repl,
                embed_text=image_token,
            )

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_replacement_idefics3,
            )
        ]
430
431
432


class Idefics3SimpleMLP(nn.Module):
    """Single bias-free linear projection from vision to text width."""

    def __init__(
        self,
        config: Idefics3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Create the projection layer.

        Args:
            config: Model config; reads ``vision_config.hidden_size``,
                ``scale_factor`` and ``text_config.hidden_size``.
            quant_config: Optional quantization settings passed to the layer.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        # Input width is the vision width times ``scale_factor ** 2``.
        input_size = config.vision_config.hidden_size * (config.scale_factor**2)
        output_size = config.text_config.hidden_size
        self.proj = ReplicatedLinear(
            input_size,
            output_size,
            bias=False,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "proj"),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` and return the first element of the layer's output."""
        # The layer returns a 2-tuple; only the first element is used.
        out, _ = self.proj(x)
        return out


class Idefics3Connector(nn.Module):
    """Pixel-shuffles vision features, then projects them to the text width."""

    def __init__(
        self,
        config: Idefics3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Store ``config.scale_factor`` and build the projection module."""
        super().__init__()
        self.scale_factor = config.scale_factor
        self.modality_projection = Idefics3SimpleMLP(
            config,
            quant_config,
            prefix=maybe_prefix(prefix, "modality_projection"),
        )

    def pixel_shuffle(self, x: torch.Tensor, scale_factor: int = 2) -> torch.Tensor:
        """Fold ``scale_factor x scale_factor`` neighbourhoods into channels.

        Args:
            x: Tensor of shape ``(bsz, seq, embed_dim)``; ``seq`` is treated
                as a square grid of side ``int(seq ** 0.5)``.
            scale_factor: Side of the neighbourhood folded into the channel
                dimension.

        Returns:
            Tensor of shape
            ``(bsz, seq / scale_factor**2, embed_dim * scale_factor**2)``.
        """
        bsz, seq, embed_dim = x.size()
        height = width = int(seq**0.5)
        # Lay the sequence out as a 2-D grid.
        x = x.view(bsz, height, width, embed_dim)
        # Merge groups of ``scale_factor`` columns into the channel axis.
        x = x.view(bsz, height, int(width / scale_factor), embed_dim * scale_factor)
        x = x.permute(0, 2, 1, 3)
        # Merge groups of ``scale_factor`` rows into the channel axis.
        x = x.reshape(
            bsz,
            int(width / scale_factor),
            int(height / scale_factor),
            embed_dim * (scale_factor**2),
        )
        x = x.permute(0, 2, 1, 3)
        # Flatten the reduced grid back into a sequence.
        x = x.reshape(bsz, int(seq / (scale_factor**2)), embed_dim * (scale_factor**2))
        return x

    def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply ``pixel_shuffle`` with ``self.scale_factor``, then project."""
        image_hidden_states = self.pixel_shuffle(image_hidden_states, self.scale_factor)
        image_hidden_states = self.modality_projection(image_hidden_states)
        return image_hidden_states


class Idefics3Model(nn.Module):
    """Vision tower, connector and Llama text model for Idefics3."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the sub-modules from ``vllm_config``.

        Args:
            vllm_config: Engine config; the HF config and quantization config
                are read from it.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()

        config: Idefics3Config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = self.config.text_config.vocab_size
        self.vision_model = Idefics3VisionTransformer(
            config.vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "vision_model"),
        )
        self.connector = Idefics3Connector(
            config,
            quant_config,
            prefix=maybe_prefix(prefix, "connector"),
        )
        # The text model is built against the nested text config.
        self.text_model = LlamaModel(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "text_model"),
        )

        # Patches per image side, squared, divided by ``scale_factor ** 2``.
        self.image_seq_len = int(
            ((config.vision_config.image_size // config.vision_config.patch_size) ** 2)
            / (config.scale_factor**2)
        )
        self.image_token_id = self.config.image_token_id

    def image_pixels_to_features(
        self,
        pixel_values: torch.Tensor,
        pixel_attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower on the non-padding images.

        Args:
            pixel_values: Image tensor indexed along dim 0; entries that are
                entirely zero are treated as padding and dropped.
            pixel_attention_mask: Per-pixel mask indexed along dim 0 like
                ``pixel_values``.

        Returns:
            The output of ``self.vision_model`` for the remaining images.
        """
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        pixel_values = pixel_values.to(
            dtype=self.vision_model.embeddings.patch_embedding.weight.dtype
        )  # fp16 compatibility

        # Remove padding images - padding images are full 0.
        nb_values_per_image = pixel_values.shape[1:].numel()
        real_images_inds = (pixel_values == 0.0).sum(
            dim=(-1, -2, -3)
        ) != nb_values_per_image
        pixel_values = pixel_values[real_images_inds].contiguous()

        # Handle the vision attention mask
        # Remove padding images from the mask
        pixel_attention_mask = pixel_attention_mask[real_images_inds].contiguous()

        # Reduce the pixel mask to one flag per patch: a patch is kept when
        # any of its pixels is unmasked.
        patch_size = self.config.vision_config.patch_size
        patches_subgrid = pixel_attention_mask.unfold(
            dimension=1, size=patch_size, step=patch_size
        )
        patches_subgrid = patches_subgrid.unfold(
            dimension=2, size=patch_size, step=patch_size
        )
        patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

        # Get sequence from the vision encoder
        image_hidden_states = self.vision_model(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
        )

        return image_hidden_states

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate to the text model's ``get_input_embeddings``."""
        return self.text_model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the text model and return its output unchanged."""
        hidden_states = self.text_model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states


580
@MULTIMODAL_REGISTRY.register_processor(
    Idefics3MultiModalProcessor,
    info=Idefics3ProcessingInfo,
    dummy_inputs=Idefics3DummyInputsBuilder,
)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA):
    """Idefics3 with an LM head, registered with the multimodal registry."""

    merge_by_field_config = True

    # Fused module name -> the per-projection names it replaces.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return ``"<image>"`` for image modalities.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the inner model, LM head and logits processor."""
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.model = Idefics3Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.image_token_id = self.config.image_token_id

        self.lm_head = ParallelLMHead(
            config.text_config.vocab_size,
            config.text_config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        # With tied embeddings the head reuses the token-embedding weight.
        if self.config.text_config.tie_word_embeddings:
            self.lm_head.weight = self.model.text_model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.text_config.vocab_size)

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[ImageInputs]:
        """Turn image keyword arguments into a typed input object.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds``
        is given. ``image_embeds`` takes precedence when both are present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if image_embeds is not None:
            return Idefics3ImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        if pixel_values is not None:
            # These two keys are required alongside ``pixel_values``; a
            # missing key raises ``KeyError``.
            pixel_attention_mask = kwargs.pop("pixel_attention_mask")
            num_patches = kwargs.pop("num_patches")
            expected_h = expected_w = self.config.vision_config.image_size

            return Idefics3ImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                pixel_attention_mask=pixel_attention_mask,
                num_patches=num_patches,
                resolve_bindings={"h": expected_h, "w": expected_w},
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_pixels(self, inputs: Idefics3ImagePixelInputs) -> torch.Tensor:
        """Run the vision tower on the pixel inputs."""
        pixel_values = inputs["pixel_values"]
        pixel_attention_mask = inputs["pixel_attention_mask"]

        return self.model.image_pixels_to_features(
            pixel_values,
            pixel_attention_mask=pixel_attention_mask,
        )

    def _process_image_input(
        self,
        image_input: ImageInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Return image embeddings for the given input.

        Precomputed embeddings are returned as-is. Pixel inputs go through
        the vision tower and connector, and the result is split per image by
        ``num_patches`` with each chunk's first two dimensions flattened.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        image_features = self._process_image_pixels(image_input)
        image_features = self.model.connector(image_features)

        num_patches = image_input["num_patches"]
        return [e.flatten(0, 1) for e in image_features.split(num_patches.tolist())]

    def get_language_model(self) -> torch.nn.Module:
        """Return ``self.model`` (the full ``Idefics3Model``)."""
        return self.model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings, or an empty list when no image is given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the text model; extra keyword arguments are ignored."""
        # ``inputs_embeds`` is discarded whenever intermediate tensors are
        # supplied.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.model.text_model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the logits processor with the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights via ``AutoWeightsLoader`` and return its result."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="model.text_model",
            connector="model.connector",
            tower_model="model.vision_model",
        )