# SPDX-License-Identifier: Apache-2.0

3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Idefics3 model compatible with HuggingFace weights."""

import math
19
20
from typing import (Dict, Iterable, List, Literal, Mapping, Optional, Set,
                    Tuple, TypedDict, Union)
21
22
23
24

import torch
import torch.utils.checkpoint
from torch import nn
25
26
from transformers import (BatchFeature, Idefics3Config, Idefics3ImageProcessor,
                          Idefics3Processor)
27
28

from vllm.attention import AttentionMetadata
29
from vllm.config import VllmConfig
30
31
32
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
33
from vllm.model_executor.layers.quantization import QuantizationConfig
34
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
35
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
36
from vllm.model_executor.models.module_mapping import MultiModelKeys
37
from vllm.model_executor.sampling_metadata import SamplingMetadata
38
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
39
from vllm.multimodal.inputs import NestedTensors
40
41
42
43
44
45
46
47
from vllm.multimodal.parse import ImageProcessorItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo,
                                        MultiModalDataItems,
                                        MultiModalFieldConfig,
                                        PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
48
49
50
51
52

# yapf: disable
from .idefics2_vision_model import (
    Idefics2VisionTransformer as Idefics3VisionTransformer)
# yapf: enable
53
from .interfaces import SupportsLoRA, SupportsMultiModal
54
from .llama import LlamaModel
55
56
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
                    merge_multimodal_embeddings)
57
58
59
60
61
62
63
64

logger = init_logger(__name__)


class Idefics3ImagePixelInputs(TypedDict):
    """Image input supplied as raw pixel patches for the vision tower."""

    type: Literal["pixel_values"]
    # Pixel patches, channels first.
    # Documented shape: `(batch_size * num_images * num_patches,
    #                     num_channels, height, width)`.
    # NOTE(review): Idefics3Model._validate_pixel_values and
    # _image_pixels_to_features iterate over the leading dimension and treat
    # each element as a `(num_patches, num_channels, height, width)` chunk;
    # confirm which layout actually arrives from the batching path.
    data: torch.Tensor
    # Optional pixel mask; it is indexed together with `data` when padding
    # images are dropped before the vision tower.
    pixel_attention_mask: Optional[torch.BoolTensor]


class Idefics3ImageEmbeddingInputs(TypedDict):
    """Image input supplied as precomputed embeddings."""

    type: Literal["image_embeds"]
    # Shape: `(batch_size * num_images, image_feature_size, hidden_size)`.
    # `hidden_size` must match the hidden size of the language model backbone.
    data: torch.Tensor


# Either form of image input accepted by the model.
ImageInputs = Union[Idefics3ImagePixelInputs, Idefics3ImageEmbeddingInputs]


83
class Idefics3ProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Idefics3.

    Gives access to the HF ``Idefics3Processor`` and computes, from image
    sizes alone, the tile grid that the prompt placeholders are built from.
    """

    def get_hf_processor(
        self,
        *,
        size: Optional[Dict[str, int]] = None,
        **kwargs: object,
    ) -> Idefics3Processor:
        """Return the HF processor, optionally overriding its ``size``.

        Args:
            size: If given, forwarded to the processor as ``size``
                (e.g. ``{"longest_edge": ...}``).
            **kwargs: Forwarded unchanged to ``ctx.get_hf_processor``.
        """
        if size is not None:
            kwargs["size"] = size

        return self.ctx.get_hf_processor(Idefics3Processor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # No fixed per-prompt image count is declared for this model.
        return {"image": None}

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int]:
        """Return the prompt-token count of one image at the largest size.

        The count is evaluated for a square image whose side equals the
        image processor's ``size['longest_edge']``. ``seq_len`` and
        ``mm_counts`` are not used.
        """
        hf_processor = self.get_hf_processor()
        image_processor: Idefics3ImageProcessor = hf_processor.image_processor
        longest_edge = image_processor.size['longest_edge']
        grid_w, grid_h = self._get_image_feature_grid_size(
            image_width=longest_edge,
            image_height=longest_edge,
        )
        # Every tile plus the global image contributes `image_seq_len`
        # image tokens.
        num_image_token = (grid_w * grid_h + 1) * hf_processor.image_seq_len

        # Non-image tokens.
        # NOTE: <row_1_col_1> and <global-img> are special tokens for SmolVLM
        # but not for Idefics3, so they are tokenized to get the real length.
        tokenizer = self.get_tokenizer()
        tile_token_len = len(tokenizer.tokenize("<row_1_col_1>"))
        glob_token_len = len(tokenizer.tokenize(hf_processor.global_image_tag))
        # A line break and <fake_token_around_image> each cost one token.
        fake_token_len = lb_len = 1
        # NOTE(review): keep in sync with the placeholder layout built in
        # Idefics3MultimodalProcessor._get_prompt_replacements; these
        # per-component counts depend on how the tokenizer splits adjacent
        # line breaks, so confirm against the tokenizer before changing.
        non_image_token = (grid_w * grid_h) * (
            tile_token_len + fake_token_len) + glob_token_len + (
                grid_h + 1) * lb_len + fake_token_len
        return {"image": num_image_token + non_image_token}

    def _resize_output_size(self,
                            *,
                            height: int,
                            width: int,
                            max_len: Optional[int] = None,
                            min_len: Optional[int] = 1,
                            max_size: Optional[int] = None) -> tuple[int, int]:
        """Rescale so the longer side becomes ``max_len``, keeping aspect.

        Args:
            height: Input image height.
            width: Input image width.
            max_len: Target length of the longer side; defaults to the
                current longer side.
            min_len: Lower bound applied to both output sides.
            max_size: Optional cap applied to ``max_len``.

        Returns:
            ``(height, width)`` after resizing. Only the side derived from
            the aspect ratio is rounded up to an even value, as in the HF
            ``Idefics3ImageProcessor``; the side pinned to ``max_len`` is
            left unchanged.
        """
        max_len = max(height, width) if max_len is None else max_len
        if max_size is not None:
            max_len = min(max_len, max_size)
        aspect_ratio = width / height

        if width >= height:
            width = max_len
            height = int(width / aspect_ratio)
            height += height % 2
        else:
            height = max_len
            width = int(height * aspect_ratio)
            width += width % 2

        # Never shrink a side below the minimum length.
        height = max(height, min_len)
        width = max(width, min_len)
        return height, width

    def _get_resize_output_image_size(
        self,
        *,
        image_width: int,
        image_height: int,
        resolution_max_side: int,
        size: Optional[dict[str, object]] = None,
    ) -> tuple[int, int]:
        """Return ``(height, width)`` after the processor's resize step.

        Args:
            image_width: Original image width.
            image_height: Original image height.
            resolution_max_side: Target length of the longer side.
            size: The same ``size`` override used by the caller, so the bound
                check compares against the overridden ``longest_edge`` rather
                than the processor default.

        Raises:
            ValueError: If ``resolution_max_side`` exceeds the processor's
                ``size['longest_edge']``.
        """
        hf_processor = self.get_hf_processor(size=size)
        image_processor: Idefics3ImageProcessor = hf_processor.image_processor
        max_image_size = image_processor.size['longest_edge']
        if resolution_max_side > max_image_size:
            raise ValueError(
                "`resolution_max_side` cannot be larger than `max_image_size`")

        # Rescale the longest edge to `resolution_max_side`, preserving the
        # aspect ratio.
        return self._resize_output_size(height=image_height,
                                        width=image_width,
                                        max_len=resolution_max_side)

    def _get_image_feature_grid_size(
        self,
        *,
        image_width: int,
        image_height: int,
        size: Optional[dict[str, object]] = None,
    ) -> tuple[int, int]:
        """Return the tile grid ``(grid_w, grid_h)`` for an image.

        ``(0, 0)`` means the resized image fits in a single vision-encoder
        tile, so only the global image is used.

        Raises:
            ValueError: If ``size['longest_edge']`` is not a multiple of
                ``max_image_size['longest_edge']``.
        """
        hf_processor = self.get_hf_processor(size=size)
        image_processor: Idefics3ImageProcessor = hf_processor.image_processor
        max_image_size = image_processor.max_image_size['longest_edge']
        longest_edge = image_processor.size['longest_edge']
        if longest_edge % max_image_size != 0:
            raise ValueError(
                "`longest_edge` in image_processor's `size` must be divisible "
                "by `longest_edge` in `max_image_size`, this may be caused by "
                "incorrect mm_kwargs override.")

        resized_height, resized_width = self._get_resize_output_image_size(
            image_width=image_width,
            image_height=image_height,
            resolution_max_side=longest_edge,
            size=size,
        )
        if resized_height > max_image_size or resized_width > max_image_size:
            grid_h = math.ceil(resized_height / max_image_size)
            grid_w = math.ceil(resized_width / max_image_size)
        else:
            grid_h = grid_w = 0
        return grid_w, grid_h
207
208


209
210
class Idefics3DummyInputsBuilder(BaseDummyInputsBuilder[Idefics3ProcessingInfo]
                                 ):
    """Builds the dummy prompt and images used for profiling."""

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return a prompt of one image token per image plus dummy images.

        Each dummy image is square, with side equal to the image processor's
        ``max_image_size['longest_edge']``. ``seq_len`` is not used.
        """
        num_images = mm_counts.get("image", 0)
        hf_processor = self.info.get_hf_processor()
        image_processor: Idefics3ImageProcessor = hf_processor.image_processor
        longest_edge = image_processor.max_image_size['longest_edge']
        image_token: str = hf_processor.image_token.content

        mm_data = {
            "image":
            self._get_dummy_images(width=longest_edge,
                                   height=longest_edge,
                                   num_images=num_images)
        }

        return ProcessorInputs(
            prompt_text=image_token * num_images,
            mm_data=mm_data,
        )
234
235


236
237
class Idefics3MultimodalProcessor(
        BaseMultiModalProcessor[Idefics3ProcessingInfo]):
    """Runs the HF processor and expands image tokens into tile placeholders."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Process ``prompt`` (and images, if any) into model inputs.

        For image inputs, ``pixel_values`` and ``pixel_attention_mask`` are
        flattened over their first two dims and split into one chunk per
        image. Each chunk holds ``grid_w * grid_h + 1`` patches: the tiles
        plus the global image.
        """
        if not mm_data:
            # Text-only prompt: tokenize directly.
            tokenizer = self.info.get_tokenizer()
            return tokenizer(prompt,
                             add_special_tokens=True,
                             return_tensors="pt")

        processed_outputs = super()._call_hf_processor(
            prompt, mm_data, mm_kwargs)
        image_grids = [
            self.info._get_image_feature_grid_size(
                image_width=img.width,
                image_height=img.height,
                **mm_kwargs,
            ) for img in mm_data["images"]
        ]
        patches_per_image = [
            grid_w * grid_h + 1 for grid_w, grid_h in image_grids
        ]
        for key in ("pixel_values", "pixel_attention_mask"):
            batched = processed_outputs.pop(key)
            processed_outputs[key] = batched.flatten(0, 1).split(
                patches_per_image)
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # All image fields are batched per image item.
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            pixel_attention_mask=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Replace each image token with its full tile/global placeholder.

        Layout for a ``grid_h x grid_w`` grid:

        - one ``<fake><row_i_col_j><image>*seq`` per tile, with a line break
          after each row;
        - a line break, then ``<fake><global-img><image>*seq``;
        - a closing ``<fake>``.

        Images that fit in one tile (grid ``(0, 0)``) get only the global
        part plus the closing fake token.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)

        image_token = hf_processor.image_token.content
        fake_image_token = hf_processor.fake_image_token.content
        global_img_token = hf_processor.global_image_tag
        image_seq_len = hf_processor.image_seq_len

        p_img = image_token * image_seq_len
        global_img_placeholder = fake_image_token + global_img_token + p_img

        def tile_placeholder(row: int, col: int) -> str:
            # Concatenate instead of str.format so brace characters in the
            # token strings cannot break the template.
            return f"{fake_image_token}<row_{row}_col_{col}>{p_img}"

        def get_replacement_idefics3(item_idx: int) -> str:
            images = mm_items.get_items("image", ImageProcessorItems)

            image_size = images.get_image_size(item_idx)
            grid_w, grid_h = self.info._get_image_feature_grid_size(
                image_width=image_size.width,
                image_height=image_size.height,
                **hf_processor_mm_kwargs,
            )
            if grid_w == 0 and grid_h == 0:
                image_placeholder = global_img_placeholder
            else:
                parts = list[str]()
                for i in range(grid_h):
                    for j in range(grid_w):
                        parts.append(tile_placeholder(i + 1, j + 1))
                        # Line break after the last tile in the row.
                        if j == grid_w - 1:
                            parts.append("\n")
                parts.append("\n")
                parts.append(global_img_placeholder)
                image_placeholder = "".join(parts)
            return image_placeholder + fake_image_token

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_replacement_idefics3,
            )
        ]
329
330
331
332


class Idefics3SimpleMLP(nn.Module):
    """Bias-free linear projection from pixel-shuffled vision features to
    the text model's hidden size."""

    def __init__(
        self,
        config: Idefics3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        # Matches the `embed_dim * scale_factor**2` feature width produced by
        # Idefics3Connector.pixel_shuffle.
        input_size = config.vision_config.hidden_size * (config.scale_factor**
                                                         2)
        output_size = config.text_config.hidden_size
        self.proj = ReplicatedLinear(
            input_size,
            output_size,
            bias=False,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "proj"),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The linear layer returns a pair; only the first element is used.
        out, _ = self.proj(x)
        return out


class Idefics3Connector(nn.Module):
    """Pixel-shuffles vision-tower outputs and projects them to text space."""

    def __init__(
        self,
        config: Idefics3Config,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.scale_factor = config.scale_factor
        self.modality_projection = Idefics3SimpleMLP(
            config,
            quant_config,
            prefix=maybe_prefix(prefix, "modality_projection"),
        )

    def pixel_shuffle(self,
                      x: torch.Tensor,
                      scale_factor: int = 2) -> torch.Tensor:
        """Trade sequence length for feature width.

        Args:
            x: ``(bsz, seq, embed_dim)``. ``seq`` is treated as a square
                ``height x width`` patch grid, and both sides are expected to
                be divisible by ``scale_factor``.
            scale_factor: Spatial reduction factor per side.

        Returns:
            ``(bsz, seq // scale_factor**2, embed_dim * scale_factor**2)``.
        """
        bsz, seq, embed_dim = x.size()
        # Exact integer square root instead of float `seq**0.5`.
        height = width = math.isqrt(seq)
        x = x.view(bsz, height, width, embed_dim)
        x = x.view(bsz, height, width // scale_factor,
                   embed_dim * scale_factor)
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(
            bsz,
            width // scale_factor,
            height // scale_factor,
            embed_dim * (scale_factor**2),
        )
        x = x.permute(0, 2, 1, 3)
        x = x.reshape(bsz, seq // (scale_factor**2),
                      embed_dim * (scale_factor**2))
        return x

    def forward(self, image_hidden_states: torch.Tensor) -> torch.Tensor:
        image_hidden_states = self.pixel_shuffle(image_hidden_states,
                                                 self.scale_factor)
        image_hidden_states = self.modality_projection(image_hidden_states)
        return image_hidden_states


class Idefics3Model(nn.Module):
    """Idefics3 backbone: vision tower, connector and Llama text model."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: Idefics3Config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.padding_idx = self.config.text_config.pad_token_id
        self.vocab_size = self.config.text_config.vocab_size

        self.vision_model = Idefics3VisionTransformer(
            config.vision_config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "vision_model"))
        self.connector = Idefics3Connector(
            config,
            quant_config,
            prefix=maybe_prefix(prefix, "connector"),
        )
        self.text_model = LlamaModel(
            vllm_config=vllm_config.with_hf_config(config.text_config),
            prefix=maybe_prefix(prefix, "text_model"),
        )

        # Image tokens per tile: (image_size // patch_size)**2 vision patches,
        # divided by scale_factor**2 by the connector's pixel shuffle.
        self.image_seq_len = int(
            ((config.vision_config.image_size //
              config.vision_config.patch_size)**2) / (config.scale_factor**2))
        self.image_token_id = self.config.image_token_id

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check that each element of ``data`` is ``(num_patches, 3, H, W)``
        with ``H == W == vision_config.image_size``; return ``data`` as-is.

        Raises:
            ValueError: On any shape mismatch.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[ImageInputs]:
        """Build typed image inputs from keyword arguments.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds``
        is given. ``image_embeds`` takes precedence when both are present.

        Raises:
            ValueError: If the provided input is not a tensor or list.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        pixel_attention_mask = kwargs.pop("pixel_attention_mask", None)

        if pixel_values is None and image_embeds is None:
            return None

        if image_embeds is not None:
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return Idefics3ImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds, concat=True),
            )

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if isinstance(pixel_values, list):
                # NOTE(review): concatenation is along dim=1, which assumes
                # each list element carries a leading batch dimension -
                # confirm against the batching path.
                pixel_values = torch.cat(pixel_values, dim=1)
                pixel_attention_mask = torch.cat(pixel_attention_mask, dim=1)
            else:
                pixel_values = flatten_bn(pixel_values)
                pixel_attention_mask = flatten_bn(pixel_attention_mask)

            return Idefics3ImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(pixel_values),
                pixel_attention_mask=pixel_attention_mask)

        raise AssertionError("This line should be unreachable.")

    def _image_pixels_to_features(
        self,
        pixel_values: torch.Tensor,
        pixel_attention_mask: Optional[torch.BoolTensor] = None,
    ) -> NestedTensors:
        """Run the vision tower and split its output back per image.

        The output is split along dim 0 using ``x.size(0)`` of each element
        of ``pixel_values``.
        """
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        num_patches = [x.size(0) for x in pixel_values]
        pixel_values = pixel_values.to(
            dtype=self.vision_model.embeddings.patch_embedding.weight.dtype
        )  # fp16 compatibility

        # Remove padding images - padding images are full 0.
        # NOTE(review): the zero count sums only the last three dims, while
        # `nb_values_per_image` is the product of all dims after the first.
        # They agree only for 4-D input; for higher-rank input no image is
        # ever classed as padding, and the boolean mask below flattens the
        # leading dims.
        nb_values_per_image = pixel_values.shape[1:].numel()
        real_images_inds = (pixel_values == 0.0).sum(
            dim=(-1, -2, -3)) != nb_values_per_image
        pixel_values = pixel_values[real_images_inds].contiguous()

        # Handle the vision attention mask
        if pixel_attention_mask is None:
            pixel_attention_mask = torch.ones(
                size=(pixel_values.size(0), pixel_values.size(2),
                      pixel_values.size(3)),
                dtype=torch.bool,
                device=pixel_values.device,
            )
        else:
            # Remove padding images from the mask
            pixel_attention_mask = pixel_attention_mask[
                real_images_inds].contiguous()

        # A patch is attended if any pixel inside it is set in the mask.
        patch_size = self.config.vision_config.patch_size
        patches_subgrid = pixel_attention_mask.unfold(dimension=1,
                                                      size=patch_size,
                                                      step=patch_size)
        patches_subgrid = patches_subgrid.unfold(dimension=2,
                                                 size=patch_size,
                                                 step=patch_size)
        patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()

        # Get sequence from the vision encoder
        image_hidden_states = self.vision_model(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
        )

        return image_hidden_states.split(num_patches)

    def _process_image_pixels(
            self, inputs: Idefics3ImagePixelInputs) -> NestedTensors:
        """Unpack pixel inputs and encode them with the vision tower."""
        assert self.vision_model is not None

        pixel_values = inputs["data"]
        pixel_attention_mask = inputs["pixel_attention_mask"]

        return self._image_pixels_to_features(pixel_values,
                                              pixel_attention_mask)

    def _process_image_input(self, image_input: ImageInputs) -> torch.Tensor:
        """Return per-image embeddings in the text model's hidden space.

        Precomputed embeddings are returned unchanged. Pixel inputs go
        through the vision tower, then the connector; the connector runs
        once on the concatenated features, and its output is split back per
        image.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_model is not None
        image_features = self._process_image_pixels(image_input)
        num_patches = [x.size(0) for x in image_features]
        image_features = torch.cat(image_features)
        return self.connector(image_features).split(num_patches)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Embed token ids with the text model's embedding layer."""
        return self.text_model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the text model; all arguments are forwarded unchanged."""
        hidden_states = self.text_model(
            input_ids,
            positions,
            kv_caches,
            attn_metadata,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states


583
584
585
586
@MULTIMODAL_REGISTRY.register_processor(
    Idefics3MultimodalProcessor,
    info=Idefics3ProcessingInfo,
    dummy_inputs=Idefics3DummyInputsBuilder)
class Idefics3ForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsLoRA):
    """Idefics3 with an LM head, sampler and multimodal embedding merge."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }
    # LoRA specific attributes
    supported_lora_modules = [
        # vision_model
        "fc1",
        "fc2",
        "out_proj",
        # text_model
        "qkv_proj",  # same name with vision encoder
        "o_proj",
        "gate_up_proj",
        "down_proj",
    ]

    embedding_modules = {}
    embedding_padding_modules = []

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        self.model = Idefics3Model(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.image_token_id = self.config.image_token_id

        self.lm_head = ParallelLMHead(
            config.text_config.vocab_size,
            config.text_config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.text_config.tie_word_embeddings:
            # The Llama backbone names its input embedding `embed_tokens`
            # (`wte` is GPT-2 naming). NOTE(review): re-check if the Llama
            # model's attribute names change.
            self.lm_head.weight = self.model.text_model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.text_config.vocab_size)
        self.sampler = get_sampler()

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return image embeddings from ``kwargs``, or ``None`` if absent."""
        image_input = self.model._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self.model._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids``, writing image embeddings into the positions
        that hold ``config.image_token_id``."""
        inputs_embeds = self.model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_id)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the text model, merging image embeddings when needed.

        With ``intermediate_tensors`` set, ``inputs_embeds`` is dropped.
        Otherwise, if ``inputs_embeds`` is missing, it is built here from
        ``input_ids`` and the image inputs in ``kwargs``.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.model.text_model(input_ids,
                                              positions,
                                              kv_caches,
                                              attn_metadata,
                                              intermediate_tensors,
                                              inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights; returns the names that were loaded."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="model.text_model",
            connector="model.connector",
            tower_model="model.vision_model")