phi3v.py 27.4 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any, List, Literal, Optional, Set, Tuple, TypedDict, Union

import torch
import torch.nn as nn
from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
                          ProcessorMixin)

from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                    NestedTensors)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                   ImageSize, MultiModalDataItems)
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo,
                                        BoundPromptReplacement,
                                        PlaceholderFeaturesInfo,
                                        PromptReplacement,
                                        PromptReplacementDetails)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of

from .clip import CLIPVisionModel
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    init_vllm_registered_model, maybe_prefix,
                    merge_multimodal_embeddings)

logger = init_logger(__name__)

# Token id used for image placeholders in the prompt. It cannot be found in
# the HF config, so it is hard-coded here.
_IMAGE_TOKEN_ID = 32044

# Hard-coded CLIP ViT-L/14 (336px input) vision-tower configuration used to
# build the image encoder.
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
    hidden_size=1024,
    intermediate_size=4096,
    projection_dim=768,
    num_hidden_layers=24,
    num_attention_heads=16,
    num_channels=3,
    image_size=336,
    patch_size=14,
    hidden_act="quick_gelu",
    dropout=0.0,
)


73
def _init_img_processor(hf_config: PretrainedConfig,
                        quant_config: Optional[QuantizationConfig],
                        prefix: str = "") -> CLIPVisionModel:
    """Build the CLIP vision tower used as the image encoder.

    Only the encoder layers up to and including the feature layer chosen by
    ``hf_config.img_processor['layer_idx']`` (default ``-2``) are built. A
    negative index counts from the end of the layer stack in
    ``CLIP_VIT_LARGE_PATCH14_336_CONFIG``: ``-1`` keeps every layer and
    ``-2`` drops the last one.

    Args:
        hf_config: Model config; its ``img_processor`` mapping is read.
        quant_config: Optional quantization config forwarded to the tower.
        prefix: Weight-name prefix for the vision tower.

    Returns:
        The truncated ``CLIPVisionModel``.
    """
    vision_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
    feature_layer = hf_config.img_processor.get('layer_idx', -2)

    # Turn the feature-layer index into a layer count, so that layers after
    # the one whose output is used are never instantiated.
    total_layers = vision_config.num_hidden_layers
    layers_to_keep = (feature_layer + 1 if feature_layer >= 0 else
                      total_layers + feature_layer + 1)

    return CLIPVisionModel(
        vision_config,
        quant_config,
        num_hidden_layers_override=layers_to_keep,
        prefix=prefix,
    )


96
97
98
99
class Phi3VImagePixelInputs(TypedDict):
    """Image inputs given as raw pixel values (encoded by the vision tower)."""

    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    image_sizes: torch.Tensor
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """


class Phi3VImageEmbeddingInputs(TypedDict):
    """Image inputs given as precomputed embeddings (vision tower skipped)."""

    type: Literal["image_embeds"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Either form of image input accepted by the model.
Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]


127
128
class Phi3ImageEmbeddingBase(nn.Module):
    """Shared CLIP feature extraction for the Phi-3-vision image embedders.

    Subclasses are expected to set ``type_feature`` and ``img_processor``
    before calling :meth:`get_img_features`; ``layer_idx`` is only declared
    here and is not read by this class.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def get_img_features(self,
                         img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        """Run the vision tower and select tokens by ``self.type_feature``.

        Args:
            img_embeds: Pixel input passed straight to ``self.img_processor``.

        Returns:
            For ``"patch"``, the tower output without its first token along
            dim 1. For ``"cls_patch"``, the full tower output.

        Raises:
            NotImplementedError: If ``type_feature`` is any other value.
        """
        type_feature = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if type_feature == "patch":
            # Drop token 0 along dim 1 (presumably the CLS token, given the
            # "cls_patch" option name).
            return img_feature[:, 1:]

        if type_feature == "cls_patch":
            return img_feature

        raise NotImplementedError(
            f"Unsupported image feature type {type_feature!r}; "
            "expected 'patch' or 'cls_patch'.")


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
    """Phi3 Image embedding with HD transform.

    Each crop is encoded by the CLIP tower. Neighbouring 2x2 patches are then
    merged into the channel dimension. Learnable separators are inserted:
    ``sub_GN`` at the end of every patch row, and ``glb_GN`` between the
    sub-image features and the global features. Finally the result is
    projected to the language-model hidden size by ``img_projection``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig],
                 prefix: str = "") -> None:
        super().__init__()

        # Language-model width: ``n_embd`` when present, else ``hidden_size``.
        hidden_size = config.n_embd if hasattr(
            config, 'n_embd') else config.hidden_size

        self.img_processor = _init_img_processor(
            config, quant_config, prefix=f"{prefix}.img_processor")

        image_dim_out = config.img_processor['image_dim_out']
        self.num_img_tokens = config.img_processor['num_img_tokens']

        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serve as line separators
        self.use_hd_transform = config.embd_layer.get('use_hd_transform',
                                                      False)
        self.with_learnable_separator = config.embd_layer.get(
            'with_learnable_separator', False)
        self.hd_transform_order = config.embd_layer.get(
            'hd_transform_order', 'glb_sub')
        # Only the HD transform with learnable separators is implemented, so
        # both flags must be enabled.
        assert self.use_hd_transform and self.with_learnable_separator

        # image_dim_out * 4: 2x2 spatial patches are merged into channels.
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(
            torch.empty([1, 1, 1, self.image_dim_out * 4]))

        dim_projection = hidden_size
        depth = 2
        layers = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend(
                [nn.GELU(),
                 nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.type_feature = config.img_processor.get('type_feature', 'patch')

    def forward(self, pixel_values: torch.FloatTensor,
                image_sizes: torch.Tensor) -> List[torch.Tensor]:
        """Process images and return per-image vision embeddings.

        Args:
            pixel_values: ``(num_images, num_crops, c, h, w)``. Crop 0 is the
                global view and the remaining crops are sub-image tiles.
            image_sizes: Per-image ``(height, width)``.

        Returns:
            A list with one ``(num_tokens_i, hidden_size)`` tensor per image.
            The token count depends on each image's crop grid, so the results
            are not stacked into a single tensor.
        """
        # Unpacking all five dims also enforces the expected 5-D input.
        num_images, num_crops, _, _, _ = pixel_values.shape
        pixel_values = pixel_values.flatten(0, 1)
        img_features = self.get_img_features(pixel_values)
        img_features = img_features.reshape(num_images, num_crops, -1,
                                            self.image_dim_out)
        return self.hd_feature_transform(img_features, image_sizes)

    def hd_feature_transform(self, image_features, image_sizes):
        """Apply the HD transform and project each image's features.

        Args:
            image_features: ``(num_images, num_crops+1, 24*24, 1024)``. Index
                0 along dim 1 is the global view, followed by the sub-image
                crops (padded; only the first ``h_crop*w_crop`` are used).
            image_sizes: Per-image ``(height, width)``. Each side is divided
                by the 336-pixel crop size.

        Returns:
            A list of per-image projected tensors laid out as
            ``[sub-image features, glb_GN separator, global features]``.
        """
        assert (
            self.hd_transform_order == 'sub_glb'
        ), f'hd_transform_order `{self.hd_transform_order}` not implemented'
        if isinstance(self.img_projection, nn.Sequential):
            target_device = self.img_projection[0].bias.device
            target_dtype = self.img_projection[0].bias.dtype
        else:  # It's a single nn.Linear layer
            target_device = self.img_projection.bias.device
            target_dtype = self.img_projection.bias.dtype

        # (num_images, 24*24, 1024)
        global_image_features = image_features[:, 0]
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_image_features_hd = self.reshape_hd_patches_2x2merge(
            global_image_features, 1, 1)
        global_image_features_hd_newline = self.add_image_newline(
            global_image_features_hd)

        batch_image_features_proj = []
        # need a for loop to process each image because of different image sizes
        # (patch arrangement is different for each image)
        for i, img_size in enumerate(image_sizes):
            # Use Python ints so the slicing and reshapes below do not depend
            # on the dtype of ``image_sizes`` (e.g. a float tensor).
            h, w = (int(side) for side in img_size)
            h_crop = h // 336
            w_crop = w // 336
            num_crops = h_crop * w_crop

            # NOTE: real num_crops is padded
            # (num_crops, 24*24, 1024)
            sub_image_features = image_features[i, 1:1 + num_crops]
            sub_image_features_hd = self.reshape_hd_patches_2x2merge(
                sub_image_features, h_crop, w_crop)
            sub_image_features_hd_newline = self.add_image_newline(
                sub_image_features_hd)

            # [sub features, separator, global features]
            image_embeddings = torch.cat([
                # (h_crop*12*(w_crop*12+1), 4096)
                sub_image_features_hd_newline.squeeze(0),
                self.glb_GN.squeeze(0),
                global_image_features_hd_newline[i],
            ])
            img_proj = self.img_projection(
                image_embeddings.to(target_device, target_dtype))
            batch_image_features_proj.append(img_proj)

        return batch_image_features_proj

    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """Merge 2x2 patch neighbourhoods and tile crops into one HD grid.

        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096)
        where h_crop*w_crop == num_crops
        """
        N, L, C = image_features.shape
        assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
        num_images = N // (h_crop * w_crop)
        H = int(L**0.5)
        image_features_hd = (
            image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
            .reshape(N, -1, 4 * C)  # N, 144, 4096
            .reshape(num_images, h_crop, w_crop, H // 2, H // 2,
                     -1)  # n_img, h_crop, w_crop, 12, 12, 4096
            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
            .reshape(num_images, h_crop * H // 2, w_crop * H // 2,
                     4 * C)  # n_img, h_crop*12, w_crop*12, 4096
        )
        return image_features_hd

    def add_image_newline(self, image_features_hd):
        """Append ``sub_GN`` to every row of the HD grid, then flatten.

        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
        """
        num_images, h, w, hid_dim = image_features_hd.shape
        # add the newline token to the HD image feature patches
        newline_embeddings = self.sub_GN.expand(num_images, h, -1,
                                                -1)  # (n_img, h, 1, hid_dim)
        image_features_hd_newline = torch.cat(
            [image_features_hd, newline_embeddings],
            dim=2).reshape(num_images, -1, hid_dim)
        return image_features_hd_newline
306
307


308
class Phi3VProcessingInfo(BaseProcessingInfo):
    """Gives access to the Phi-3-vision HF processor and image-token counts."""

    def get_hf_processor(
        self,
        *,
        num_crops: Optional[int] = None,
    ) -> ProcessorMixin:
        """Return the HF processor, forwarding ``num_crops`` only if given."""
        overrides = {} if num_crops is None else {"num_crops": num_crops}
        return self.ctx.get_hf_processor(**overrides)

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # ``None`` sets no explicit per-prompt image limit here; the meaning
        # of ``None`` is defined by the base class.
        return {"image": None}

    def get_mm_max_tokens_per_item(self, seq_len: int) -> Mapping[str, int]:
        """Return the image-token count of the largest-feature image.

        ``seq_len`` is not used here.
        """
        width, height = self.get_image_size_with_most_features()
        return {
            "image":
            self.get_num_image_tokens(
                image_width=width,
                image_height=height,
                processor=None,
            ),
        }

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[ProcessorMixin],
    ) -> int:
        """Ask the HF processor how many tokens an image of this size takes.

        Falls back to ``self.get_hf_processor()`` if ``processor`` is None.
        """
        hf_processor = (self.get_hf_processor()
                        if processor is None else processor)
        return hf_processor.calc_num_image_tokens_from_image_size(  # type: ignore
            width=image_width,
            height=image_height,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        # A very tall, narrow image gives the max possible feature size.
        # NOTE(review): the literal ratio here is 160:1; an older comment said
        # "h:w = 16:1", presumably describing the resulting crop grid after the
        # processor's crop limit — confirm.
        return ImageSize(height=8000, width=50)

353
354
355

class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):
    """Builds dummy Phi-3-vision prompts and images for multimodal profiling.

    Every dummy image uses the size returned by
    ``get_image_size_with_most_features``.
    """

    def get_dummy_processor_inputs(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Return a prompt of image placeholders with matching dummy images.

        ``seq_len`` is not used here.
        """
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)

        processor = self.info.get_hf_processor()
        # One placeholder string per image, as listed by the HF processor.
        placeholders: list[str] = processor.img_tokens  # type: ignore
        prompt = "".join(placeholders[:image_count])

        return ProcessorInputs(
            prompt_text=prompt,
            mm_data={"image": dummy_images},
        )


382
class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
    """Multimodal processor that expands Phi-3-vision image placeholders."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and normalize its image placeholder ids.

        The Phi3v processor inserts -1, -2, etc. as placeholders in the prompt
        ids, which would cause an OverflowError when decoding the prompt ids.
        They are therefore replaced in place with ``_IMAGE_TOKEN_ID`` early.
        """
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
        )

        token_ids = outputs["input_ids"]
        assert isinstance(token_ids, torch.Tensor)

        placeholder_mask = token_ids < 0
        token_ids.masked_fill_(placeholder_mask, _IMAGE_TOKEN_ID)

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Every image-related output is batched along the "image" modality.
        field_names = ("pixel_values", "image_sizes", "image_embeds")
        return {
            name: MultiModalFieldConfig.batched("image")
            for name in field_names
        }

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargs,
    ) -> list[PromptReplacement]:
        """Replace the i-th image placeholder string with its feature tokens.

        Each replacement is the image's feature tokens
        (``_IMAGE_TOKEN_ID`` repeated) followed by the tokenizer's BOS id.
        Only the feature tokens are reported as features.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        placeholder_strs: list[str] = hf_processor.img_tokens  # type: ignore

        tokenizer = self.info.get_tokenizer()
        bos_token_id = tokenizer.bos_token_id
        assert isinstance(bos_token_id, int)

        def get_replacement_phi3v(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if isinstance(images, ImageEmbeddingItems):
                feature_size = images.get_feature_size(item_idx)
            else:
                size = images.get_image_size(item_idx)
                feature_size = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                    processor=hf_processor,
                )

            feature_tokens = [_IMAGE_TOKEN_ID] * feature_size

            return PromptReplacementDetails(
                full=feature_tokens + [bos_token_id],
                features=feature_tokens,
            )

        num_images = mm_items.get_count("image", strict=False)

        return [
            PromptReplacement(
                modality="image",
                target=placeholder,
                replacement=get_replacement_phi3v,
            ) for placeholder in placeholder_strs[:num_images]
        ]

    def _apply_prompt_replacements(
        self,
        token_ids: list[int],
        mm_prompt_repls: Mapping[str, Sequence[BoundPromptReplacement]],
        mm_item_counts: Mapping[str, int],
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """Apply the replacements, then match the HF processor's prompt form.

        If the text starts with ``"<s> <|image|>"``, it is rewritten to
        ``"<s><|image|>"``. The token at index 1 is dropped (presumably the
        extra space) and every placeholder start index is shifted back by one.
        """
        token_ids, text, placeholders = super()._apply_prompt_replacements(
            token_ids=token_ids,
            mm_prompt_repls=mm_prompt_repls,
            mm_item_counts=mm_item_counts,
        )

        # Keep the behavior in line with HF processor
        if not text.startswith("<s> <|image|>"):
            return token_ids, text, placeholders

        text = text.replace("<s> <|image|>", "<s><|image|>", 1)
        token_ids = [token_ids[0], *token_ids[2:]]

        def shift_back(info: PlaceholderFeaturesInfo) -> PlaceholderFeaturesInfo:
            return PlaceholderFeaturesInfo(
                modality=info.modality,
                item_idx=info.item_idx,
                start_idx=info.start_idx - 1,
                tokens=info.tokens,
            )

        placeholders = {
            modality: [shift_back(info) for info in infos]
            for modality, infos in placeholders.items()
        }

        return token_ids, text, placeholders

491

492
493
494
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,
                                        info=Phi3VProcessingInfo,
                                        dummy_inputs=Phi3VDummyInputsBuilder)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Phi-3-vision: an HD CLIP image embedder in front of a Llama LM.

    Image embeddings from ``vision_embed_tokens`` are merged into the token
    embeddings (via ``merge_multimodal_embeddings``) at positions whose id is
    ``image_token_id``. The merged sequence then runs through
    ``language_model``.
    """

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_embed_tokens.wte": "embed_tokens",
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        })

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.image_token_id = _IMAGE_TOKEN_ID

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "model.embed_tokens"),
        )

        # TODO: Optionally initializes this for supporting input embeddings.
        self.vision_embed_tokens = Phi3HDImageEmbedding(
            config,
            quant_config,
            prefix=maybe_prefix(prefix, "model.vision_embed_tokens"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            # The prefix is empty intentionally because default prefix of
            # LlamaForCausalLM is "model"
            prefix="",
            # We don't directly initialize vLLM's LlamaForCausalLM so we
            # can automatically apply embedding wrapper if this model is
            # initialized as an embedding model
            architectures=["LlamaForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Prefer the language model's own sampler when it defines one.
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every per-image size entry has shape ``(2,)``.

        Raises:
            ValueError: If any entry has a different shape.
        """
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check every per-image pixel tensor is ``(num_patches, 3, 336, 336)``.

        The trailing dims come from ``CLIP_VIT_LARGE_PATCH14_336_CONFIG``.

        Raises:
            ValueError: If any entry has different trailing dimensions.
        """
        h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Phi3VImageInputs]:
        """Extract image inputs from ``kwargs``, or return None if absent.

        Pixel values take precedence over image embeddings. Both tensors and
        lists of tensors are accepted, matching the ``data`` field types of
        ``Phi3VImagePixelInputs`` and ``Phi3VImageEmbeddingInputs``.

        Raises:
            ValueError: On inputs of an unexpected type or shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return Phi3VImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)))

        if image_embeds is not None:
            # Lists are allowed: per-image embeddings may differ in length,
            # and _process_image_input handles a list of tensors.
            if not isinstance(image_embeds, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: Phi3VImageInputs,
    ) -> NestedTensors:
        """Turn parsed image inputs into a list of per-image embeddings.

        Precomputed embeddings are passed through unchanged: either as a list,
        or unbound along dim 0 when given as one 3-D tensor. Pixel inputs are
        encoded by ``vision_embed_tokens``.

        Raises:
            ValueError: If precomputed embeddings are neither a list of
                tensors nor a 3-D tensor.
        """
        if image_input["type"] == "image_embeds":
            image_data = image_input["data"]
            if is_list_of(image_data, torch.Tensor):
                # it's already a list of tensors
                return image_data
            if len(image_data.shape) == 3:
                # 3D tensor
                return list(torch.unbind(image_data, dim=0))
            raise ValueError(
                "We expect batched 2D tensors; "
                "this can be either a list of 2D tensors or a single 3D tensor."
            )

        assert self.vision_embed_tokens is not None
        image_embeds = self.vision_embed_tokens(image_input["data"],
                                                image_input["image_sizes"])

        return image_embeds

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return per-image embeddings, or None if no image input is given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in image embeddings if provided."""
        inputs_embeds = self.embed_tokens(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.image_token_id)
        return inputs_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object):
        """Run the language model.

        On non-first pipeline stages (``intermediate_tensors`` set), no
        embeddings are passed in. Otherwise, embeddings are built here if the
        caller did not supply them.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this condition is for v0 compatibility
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load HF weights via ``hf_to_vllm_mapper``; return loaded names."""
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(weights,
                                                 mapper=self.hf_to_vllm_mapper)

        # The HF config doesn't specify whether these are tied,
        # so we detect it this way
        if "embed_tokens.weight" not in autoloaded_weights:
            self.embed_tokens = self.language_model.model.embed_tokens
            autoloaded_weights.add("embed_tokens.weight")
        return autoloaded_weights