# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18
from collections.abc import Iterable, Mapping, Sequence
19
from typing import Annotated, Any, Literal, Optional, Union
20

21
import regex as re
22
23
import torch
import torch.nn as nn
24
25
from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
                          ProcessorMixin)
26

27
from vllm.config import VllmConfig
28
from vllm.logger import init_logger
29
from vllm.model_executor.layers.quantization import QuantizationConfig
30
31
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
32
from vllm.multimodal import MULTIMODAL_REGISTRY
33
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
34
                                    MultiModalKwargsItems)
35
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
36
                                   ImageSize, MultiModalDataItems)
37
38
# yapf conflicts with isort for this block
# yapf: disable
39
from vllm.multimodal.processing import (BaseMultiModalProcessor,
40
41
                                        BaseProcessingInfo,
                                        MultiModalPromptUpdates,
42
                                        PlaceholderFeaturesInfo,
43
44
                                        PromptReplacement, PromptUpdate,
                                        ResolvedPromptUpdate)
45
# yapf: enable
46
from vllm.multimodal.profiling import BaseDummyInputsBuilder
47
from vllm.sequence import IntermediateTensors
48
from vllm.utils import is_list_of
49
from vllm.utils.tensor_schema import TensorSchema, TensorShape
50

51
from .clip import CLIPVisionModel
52
53
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal, SupportsPP,
                         SupportsQuant)
54
55
56
from .utils import (AutoWeightsLoader, WeightsMapper,
                    _merge_multimodal_embeddings, flatten_bn,
                    init_vllm_registered_model, maybe_prefix)
57

58
59
logger = init_logger(__name__)

# Token id used for every image-feature position in the prompt. It cannot be
# found in the HF config, so it is hard-coded here.
_IMAGE_TOKEN_ID = 32044

# Fixed configuration of the CLIP ViT-L/14 (336px) vision tower built by
# `_init_img_processor`; like the token id above, it is hard-coded.
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
    dropout=0.0,
    hidden_act="quick_gelu",
    hidden_size=1024,
    image_size=336,
    intermediate_size=4096,
    num_attention_heads=16,
    num_channels=3,
    num_hidden_layers=24,
    patch_size=14,
    projection_dim=768,
)


75
def _init_img_processor(hf_config: PretrainedConfig,
                        quant_config: Optional[QuantizationConfig],
                        prefix: str = "") -> CLIPVisionModel:
    """Build the CLIP vision tower, truncated at the configured feature layer.

    ``hf_config.img_processor['layer_idx']`` (default ``-2``) selects the
    hidden layer whose output is used as the image feature. Negative values
    count from the end, Python-style. Only the layers up to and including
    that one are instantiated, so the tower's final output already is the
    selected feature.

    Args:
        hf_config: HF model config; only its ``img_processor`` dict is read.
        quant_config: Optional quantization config forwarded to the tower.
        prefix: Weight-name prefix for the tower's parameters.

    Returns:
        A ``CLIPVisionModel`` built from ``CLIP_VIT_LARGE_PATCH14_336_CONFIG``
        with ``num_hidden_layers_override`` applied.
    """
    vision_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
    feature_layer = hf_config.img_processor.get('layer_idx', -2)

    # With 24 layers: -2 -> keep 23 layers, 5 -> keep 6 layers.
    layers_to_keep = (vision_config.num_hidden_layers + feature_layer + 1
                      if feature_layer < 0 else feature_layer + 1)

    return CLIPVisionModel(
        vision_config,
        quant_config,
        num_hidden_layers_override=layers_to_keep,
        prefix=prefix,
    )


98
class Phi3VImagePixelInputs(TensorSchema):
    """
    Pixel-value inputs for the Phi-3-vision image encoder.

    Dimensions:
        - bn: Batch size * number of images (batch and image dims flattened)
        - p: Number of patches (crops) per image; may vary across images
        - h: Height of each patch
        - w: Width of each patch
    """

    # Discriminator for `Phi3VImageInputs`; this schema only ever carries
    # pixel values (embeddings use `Phi3VImageEmbeddingInputs`).
    type: Literal["pixel_values"] = "pixel_values"

    # Supports either a stacked tensor or a list of (p, 3, h, w) tensors;
    # 'p' may vary across items, hence the dynamic dim.
    data: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}),
    ]

    # Stacked tensor with the (height, width) of each image
    image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
119
120


121
class Phi3VImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings that are used as-is, bypassing the vision
    encoder.

    Dimensions:
        - bn: Batch size * number of images (batch and image dims flattened)
        - f: Image feature size (e.g., number of tokens per image)
        - h: Hidden size (must match language model backbone)
    """
    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "f", "h"),
    ]


# Either form of image input, discriminated by the `type` field.
Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]


139
140
class Phi3ImageEmbeddingBase(nn.Module):
    """Shared feature-extraction logic for Phi-3 image embedding modules.

    Subclasses assign ``img_processor`` (the vision tower) and
    ``type_feature`` before :meth:`get_img_features` is called.
    """

    def __init__(self) -> None:
        super().__init__()
        # Declared here for type checkers; assigned by subclasses.
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def get_img_features(self,
                         img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        """Run the vision tower and select features per ``type_feature``.

        Args:
            img_embeds: Pixel values passed straight to ``img_processor``.

        Returns:
            ``"patch"``: the tower output without its first position along
            dim 1 (presumably the CLS token). ``"cls_patch"``: the full
            tower output.

        Raises:
            NotImplementedError: For any other ``type_feature`` value.
        """
        type_feature = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if type_feature == "patch":
            return img_feature[:, 1:]

        if type_feature == "cls_patch":
            return img_feature

        raise NotImplementedError(
            f"Unsupported type_feature {type_feature!r}; "
            "expected 'patch' or 'cls_patch'.")


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
    """Phi3 Image embedding with HD transform.

    Every crop of an image is encoded by the CLIP tower. Each 2x2 block of
    patch features is merged into one feature 4x as wide. The crops are laid
    out on their 2D grid with the learnable ``sub_GN`` separator appended to
    every row. Finally the sub-image grid, the ``glb_GN`` separator and the
    global view are concatenated and projected to the LM hidden size.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig],
                 prefix: str = "") -> None:
        super().__init__()

        # n_embed or hidden_size
        hidden_size = config.n_embd if hasattr(
            config, 'n_embd') else config.hidden_size

        self.img_processor = _init_img_processor(
            config, quant_config, prefix=f"{prefix}.img_processor")

        image_dim_out = config.img_processor['image_dim_out']
        self.num_img_tokens = config.img_processor['num_img_tokens']

        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = config.embd_layer.get('use_hd_transform',
                                                      False)
        self.with_learnable_separator = config.embd_layer.get(
            'with_learnable_separator', False)
        self.hd_transform_order = config.embd_layer.get(
            'hd_transform_order', 'glb_sub')
        # Only the HD transform with learnable separators is implemented,
        # so both flags must be enabled.
        assert self.use_hd_transform and self.with_learnable_separator

        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(
            torch.empty([1, 1, 1, self.image_dim_out * 4]))

        dim_projection = hidden_size
        depth = 2
        layers = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend(
                [nn.GELU(),
                 nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.type_feature = config.img_processor.get('type_feature', 'patch')

    def forward(
        self,
        pixel_values: Union[torch.FloatTensor, list[torch.Tensor]],
        image_sizes: torch.Tensor,
    ) -> list[torch.Tensor]:
        """Process images and return their vision embeddings.

        Args:
            pixel_values: Either a stacked tensor of shape
                ``(num_images, num_crops, c, h, w)`` or a list of per-image
                ``(num_crops_i, c, h, w)`` tensors whose crop counts may
                differ (the form ``Phi3VImagePixelInputs`` allows).
            image_sizes: ``(num_images, 2)`` (height, width) per image.

        Returns:
            One ``(num_img_tokens_i, hidden_size)`` tensor per image; the
            token count depends on the image's size.
        """
        if isinstance(pixel_values, (list, tuple)):
            # Crop counts may differ between images, so they cannot be
            # stacked; embed each image on its own. The HD transform handles
            # every image independently, so this matches the stacked path.
            embeddings: list[torch.Tensor] = []
            for idx, image_pixels in enumerate(pixel_values):
                embeddings.extend(
                    self._embed_stacked(image_pixels.unsqueeze(0),
                                        image_sizes[idx:idx + 1]))
            return embeddings

        return self._embed_stacked(pixel_values, image_sizes)

    def _embed_stacked(self, pixel_values: torch.Tensor,
                       image_sizes: torch.Tensor) -> list[torch.Tensor]:
        """Embed a stacked ``(num_images, num_crops, c, h, w)`` batch."""
        num_images, num_crops, c, h, w = pixel_values.shape
        pixel_values = pixel_values.flatten(0, 1)
        img_features = self.get_img_features(pixel_values)
        img_features = img_features.reshape(num_images, num_crops, -1,
                                            self.image_dim_out)
        return self.hd_feature_transform(img_features, image_sizes)

    def hd_feature_transform(self, image_features, image_sizes):
        """Lay out crop features on their grid and project them.

        image_features: (num_images, num_crops+1, 24*24, 1024); index 0 along
            dim 1 is the global view, the rest are (padded) sub-image crops.
        image_sizes: per-image (height, width) in pixels.
        returns: list of (num_img_tokens_i, hidden_size), one per image.
        """
        assert (
            self.hd_transform_order == 'sub_glb'
        ), f'hd_transform_order `{self.hd_transform_order}` not implemented'
        if isinstance(self.img_projection, nn.Sequential):
            target_device = self.img_projection[0].bias.device
            target_dtype = self.img_projection[0].bias.dtype
        else:  # It's a single nn.Linear layer
            target_device = self.img_projection.bias.device
            target_dtype = self.img_projection.bias.dtype

        global_image_features = image_features[:,
                                               0]  # (num_images, 24*24, 1024)
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_image_features_hd = self.reshape_hd_patches_2x2merge(
            global_image_features, 1, 1)
        global_image_features_hd_newline = self.add_image_newline(
            global_image_features_hd)

        batch_image_features_proj = []
        # need a for loop to process each image because of different image sizes
        # (patch arrangement is different for each image)
        for i, img_size in enumerate(image_sizes):
            h, w = img_size
            h_crop = h // 336
            w_crop = w // 336
            num_crops = h_crop * w_crop

            # NOTE: real num_crops is padded
            # (num_crops, 24*24, 1024)
            sub_image_features = image_features[i, 1:1 + num_crops]
            sub_image_features_hd = self.reshape_hd_patches_2x2merge(
                sub_image_features, h_crop, w_crop)
            sub_image_features_hd_newline = self.add_image_newline(
                sub_image_features_hd)

            # [sub features, separator, global features]
            image_embeddings = torch.cat([
                sub_image_features_hd_newline.squeeze(
                    0),  # (h_crop*12*(w_crop*12+1), 4096)
                self.glb_GN.squeeze(0),
                global_image_features_hd_newline[i],
            ])
            img_proj = self.img_projection(
                image_embeddings.to(target_device, target_dtype))
            batch_image_features_proj.append(img_proj)

        return batch_image_features_proj

    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """
        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096)
        where h_crop*w_crop == num_crops
        """
        N, L, C = image_features.shape
        assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
        num_images = N // (h_crop * w_crop)
        H = int(L**0.5)
        image_features_hd = (
            image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
            .reshape(N, -1, 4 * C)  # N, 144, 4096
            .reshape(num_images, h_crop, w_crop, H // 2, H // 2,
                     -1)  # n_img, h_crop, w_crop, 12, 12, 4096
            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
            .reshape(num_images, h_crop * H // 2, w_crop * H // 2,
                     4 * C)  # n_img, h_crop*12, w_crop*12, 4096
        )
        return image_features_hd

    def add_image_newline(self, image_features_hd):
        """
        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
        """
        num_images, h, w, hid_dim = image_features_hd.shape
        # add the newline token to the HD image feature patches
        newline_embeddings = self.sub_GN.expand(num_images, h, -1,
                                                -1)  # (n_img, h, 1, hid_dim)
        image_features_hd_newline = torch.cat(
            [image_features_hd, newline_embeddings],
            dim=2).reshape(num_images, -1, hid_dim)
        return image_features_hd_newline
318
319


320
class Phi3VProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Phi-3-vision: item limits and token counts."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # `None` presumably means no fixed per-prompt image limit.
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[ProcessorMixin] = None,
    ) -> int:
        """Number of placeholder tokens the HF processor emits for an image.

        Delegates to the HF processor's
        ``calc_num_image_tokens_from_image_size``. When no processor is
        passed, one is obtained from ``get_hf_processor()``.
        """
        hf_processor = (self.get_hf_processor()
                        if processor is None else processor)

        return hf_processor.calc_num_image_tokens_from_image_size(  # type: ignore
            width=image_width,
            height=image_height,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        # A very tall, narrow image results in the max possible feature size.
        # NOTE(review): the pixel ratio here is 160:1; the "h:w = 16:1" often
        # quoted for this presumably refers to the resulting crop grid.
        return ImageSize(height=8000, width=50)

344
345
346

class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):
    """Builds dummy prompts and images for Phi-3-vision profiling runs."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Concatenate one HF image tag per requested image."""
        image_count = mm_counts.get("image", 0)
        tags: list[str] = self.info.get_hf_processor().img_tokens  # type: ignore
        return "".join(tags[:image_count])

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> MultiModalDataDict:
        """Create ``mm_counts['image']`` dummy images of the size that yields
        the most image features."""
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()

        dummy_images = self._get_dummy_images(width=width,
                                              height=height,
                                              num_images=image_count)
        return {"image": dummy_images}


373
class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
    """Multimodal processor reproducing the HF Phi-3-vision prompt layout."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor, then map its negative ids to the image id.

        The Phi3v processor inserts -1, -2, ... as image placeholders in
        ``input_ids``, which causes an OverflowError when decoding, so they
        are replaced in place with ``_IMAGE_TOKEN_ID`` right away.
        """
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        prompt_ids = outputs["input_ids"]
        assert isinstance(prompt_ids, torch.Tensor)
        prompt_ids.masked_fill_(prompt_ids < 0, _IMAGE_TOKEN_ID)

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Each image-related output is batched along the image items.
        field_names = ("pixel_values", "image_sizes", "image_embeds")
        return {
            name: MultiModalFieldConfig.batched("image")
            for name in field_names
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace the tag of image ``i`` with its image-feature tokens."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

        def get_replacement_phi3v(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems))

            if not isinstance(images, ImageEmbeddingItems):
                size = images.get_image_size(item_idx)
                feature_len = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                    processor=hf_processor,
                )
            else:
                feature_len = images.get_feature_size(item_idx)

            return [_IMAGE_TOKEN_ID] * feature_len

        replacement = PromptReplacement(
            modality="image",
            target=image_tokens.__getitem__,
            replacement=get_replacement_phi3v,
        )
        return [replacement]

    def _recompute_cached_prompt_update(
        self,
        cached_update: ResolvedPromptUpdate,
        new_item_idx: int,
    ) -> ResolvedPromptUpdate:
        """Re-target a cached image update at the tag of ``new_item_idx``.

        Each image has its own tag in ``img_tokens``, so reusing a cached
        image update for another item index also swaps in that item's tag.
        """
        new_update = super()._recompute_cached_prompt_update(
            cached_update,
            new_item_idx,
        )
        if cached_update.modality != "image":
            return new_update

        hf_processor = self.info.get_hf_processor()
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore
        return new_update.with_target(image_tokens[new_item_idx])

    def _retokenize_like_hf(self, token_ids: list[int]) -> list[int]:
        """Re-tokenize ``token_ids`` the way the HF Phi3v processor does.

        The ids are decoded back to text, without a leading BOS and with the
        space the tokenizer adds after each special token removed. The text
        is split on ``<|image_N|>`` tags. Text chunks are tokenized with the
        tokenizer defaults, and tags with ``add_special_tokens=False``.
        The two are then interleaved.
        See https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/64f88b6/processing_phi3_v.py#L407
        """
        tokenizer = self.info.get_tokenizer()

        if token_ids and token_ids[0] == tokenizer.bos_token_id:
            token_ids = token_ids[1:]
        text = tokenizer.decode(token_ids)

        for special in tokenizer.special_tokens_map.values():
            if isinstance(special, str):
                names = [special]
            elif isinstance(special, list):
                names = special
            else:
                continue
            for name in names:
                text = text.replace(f"{name} ", name)

        tag_pattern = r"<\|image_\d+\|>"
        chunk_ids = [
            tokenizer(chunk).input_ids
            for chunk in re.split(tag_pattern, text)
        ]
        tag_ids = [
            tokenizer(tag, add_special_tokens=False).input_ids
            for tag in re.findall(tag_pattern, text)
        ]
        if len(chunk_ids) > len(tag_ids):
            tag_ids.append([])

        merged: list[int] = []
        for chunk, tag in zip(chunk_ids, tag_ids):
            merged.extend(chunk)
            merged.extend(tag)
        return merged

    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """Apply image replacements while matching the HF tokenization.

        If there are any prompt updates, the prompt is first re-tokenized
        the HF way (``_retokenize_like_hf``). After the base-class update, a
        leading ``"<s> <|image|>"`` is collapsed to ``"<s><|image|>"``. This
        drops the token at index 1, so every placeholder is shifted left by
        one.
        """
        # align to hf behavior when there are images
        if len(mm_prompt_updates):
            token_ids = self._retokenize_like_hf(token_ids)

        token_ids, text, placeholders = super()._apply_prompt_updates(
            token_ids=token_ids,
            mm_prompt_updates=mm_prompt_updates,
        )

        # Keep the behavior in line with HF processor
        if not text.startswith("<s> <|image|>"):
            return token_ids, text, placeholders

        text = text.replace("<s> <|image|>", "<s><|image|>", 1)
        # Drop the token right after BOS (presumably the one that decodes to
        # the extra space).
        token_ids = [*token_ids[:1], *token_ids[2:]]

        shifted: dict[str, list[PlaceholderFeaturesInfo]] = {}
        for modality, infos in placeholders.items():
            shifted[modality] = [
                PlaceholderFeaturesInfo(
                    modality=info.modality,
                    item_idx=info.item_idx,
                    start_idx=info.start_idx - 1,
                    tokens=info.tokens,
                    is_embed=info.is_embed,
                ) for info in infos
            ]

        return token_ids, text, shifted

523

524
525
526
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor,
                                        info=Phi3VProcessingInfo,
                                        dummy_inputs=Phi3VDummyInputsBuilder)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
                       SupportsQuant):
    """Phi-3-vision: HD CLIP image embedding feeding a Llama language model.

    Image features from ``vision_embed_tokens`` are merged into the text
    embeddings at the positions flagged by ``is_multimodal``. The result is
    run through a vLLM-registered ``LlamaForCausalLM``.
    """

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_embed_tokens.wte": "embed_tokens",
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt tag ``<|image_{i}|>`` for image ``i``.

        Raises:
            ValueError: For any modality that is not an image modality.
        """
        if modality.startswith("image"):
            return f"<|image_{i}|>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.image_token_id = _IMAGE_TOKEN_ID

        # NOTE(review): `self.quant_config` is not assigned in this class;
        # presumably it is populated by the SupportsQuant mixin — confirm.
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "model.embed_tokens"),
        )

        # TODO: Optionally initializes this for supporting input embeddings.
        self.vision_embed_tokens = Phi3HDImageEmbedding(
            config,
            self.quant_config,
            prefix=maybe_prefix(prefix, "model.vision_embed_tokens"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            # The prefix is empty intentionally because default prefix of
            # LlamaForCausalLM is "model"
            prefix="",
            # We don't directly initialize vLLM's LlamaForCausalLM so we
            # can automatically apply embedding wrapper if this model is
            # initialized as an embedding model
            architectures=["LlamaForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Phi3VImageInputs]:
        """Build a typed image input from the multimodal kwargs, if any.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds``
        is present. ``pixel_values`` takes precedence when both are given.

        Raises:
            ValueError: If ``pixel_values`` is given without ``image_sizes``
                (the HD transform needs every image's size).
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if image_sizes is None:
                # Fail with a clear message instead of an opaque error from
                # flattening/concatenating `None` below.
                raise ValueError(
                    "`image_sizes` must be provided together with "
                    "`pixel_values`.")
            return Phi3VImagePixelInputs(
                type="pixel_values",
                data=flatten_bn(pixel_values),
                image_sizes=flatten_bn(image_sizes, concat=True),
                resolve_bindings={
                    "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
                    "w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
                })

        if image_embeds is not None:
            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: Phi3VImageInputs,
    ) -> MultiModalEmbeddings:
        """Turn parsed image input into per-image embeddings.

        Precomputed embeddings are passed through, as a list of 2D tensors.
        Pixel values are run through ``vision_embed_tokens``.

        Raises:
            ValueError: If precomputed embeddings are neither a list of
                tensors nor a single 3D tensor.
        """
        if image_input["type"] == "image_embeds":
            image_data = image_input["data"]
            if is_list_of(image_data, torch.Tensor):
                # it's already a list of tensors
                return image_data
            if len(image_data.shape) == 3:
                # 3D tensor
                return list(torch.unbind(image_data, dim=0))
            raise ValueError(
                "We expect batched 2D tensors; "
                "this can be either a list of 2D tensors or a single 3D tensor."
            )

        assert self.vision_embed_tokens is not None
        image_embeds = self.vision_embed_tokens(image_input["data"],
                                                image_input["image_sizes"])

        return image_embeds

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings for the given kwargs, or ``[]`` if none."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed text tokens and scatter in multimodal embeddings.

        Raises:
            ValueError: If multimodal embeddings are given without
                ``is_multimodal``.
        """
        inputs_embeds = self._get_text_embeddings(
            input_ids,
            self.embed_tokens,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        if is_multimodal is None:
            raise ValueError(
                "`get_input_embeddings` now requires `is_multimodal` arg, "
                "please update your model runner according to "
                "https://github.com/vllm-project/vllm/pull/16229.")

        return _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the language model backbone.

        When ``intermediate_tensors`` is given, ``inputs_embeds`` is ignored.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load HF weights via ``hf_to_vllm_mapper``.

        Returns the set of loaded parameter names.
        """
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(weights,
                                                 mapper=self.hf_to_vllm_mapper)

        # The HF config doesn't specify whether these are tied,
        # so we detect it this way
        if "embed_tokens.weight" not in autoloaded_weights:
            self.embed_tokens = self.language_model.model.embed_tokens
            autoloaded_weights.add("embed_tokens.weight")
        return autoloaded_weights