phi3v.py 26 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18
from collections.abc import Iterable, Mapping, Sequence
19
from typing import Annotated, Any, Literal, TypeAlias
20

21
import regex as re
22
23
import torch
import torch.nn as nn
24
25
26
27
28
29
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    PretrainedConfig,
    ProcessorMixin,
)
30

31
from vllm.config import VllmConfig
32
from vllm.config.multimodal import BaseDummyOptions
33
from vllm.logger import init_logger
34
from vllm.model_executor.layers.quantization import QuantizationConfig
35
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
36
from vllm.multimodal import MULTIMODAL_REGISTRY
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    MultiModalPromptUpdates,
    PlaceholderFeaturesInfo,
    PromptReplacement,
    PromptUpdate,
    ResolvedPromptUpdate,
)
57
from vllm.multimodal.profiling import BaseDummyInputsBuilder
58
from vllm.sequence import IntermediateTensors
59
from vllm.utils import is_list_of
60
from vllm.utils.tensor_schema import TensorSchema, TensorShape
61

62
from .clip import CLIPVisionModel
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from .interfaces import (
    MultiModalEmbeddings,
    SupportsMultiModal,
    SupportsPP,
    SupportsQuant,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    _merge_multimodal_embeddings,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
77

78
79
# Module-level logger, named after this module.
logger = init_logger(__name__)

80
81
82
# Token id used for image placeholder positions. It is hard-coded here
# because it cannot be found in the HF config.
_IMAGE_TOKEN_ID = 32044

83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Fixed configuration of the CLIP vision tower used by this model
# (patch size 14, 336x336 input, 24 layers). `_init_img_processor` builds the
# vision tower from this constant instead of a vision config read from the
# checkpoint.
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
    dropout=0.0,
    hidden_act="quick_gelu",
    hidden_size=1024,
    image_size=336,
    intermediate_size=4096,
    num_attention_heads=16,
    num_channels=3,
    num_hidden_layers=24,
    patch_size=14,
    projection_dim=768,
)


def _init_img_processor(
    hf_config: PretrainedConfig,
    quant_config: QuantizationConfig | None,
    prefix: str = "",
) -> CLIPVisionModel:
    """Build the CLIP vision tower, truncated at the feature layer.

    Args:
        hf_config: Model config; ``hf_config.img_processor`` may carry a
            ``layer_idx`` entry (defaults to ``-2``) selecting the layer
            whose output is used as the image feature.
        quant_config: Quantization config forwarded to ``CLIPVisionModel``.
        prefix: Weight-name prefix forwarded to ``CLIPVisionModel``.

    Returns:
        A ``CLIPVisionModel`` built from ``CLIP_VIT_LARGE_PATCH14_336_CONFIG``
        with its number of hidden layers overridden.
    """
    vision_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
    feature_layer = hf_config.img_processor.get("layer_idx", -2)

    # Initialize the CLIP only up to the required feature layer.
    # A negative index counts from the end of the full layer stack.
    depth = feature_layer + 1
    if feature_layer < 0:
        depth += vision_config.num_hidden_layers

    return CLIPVisionModel(
        vision_config,
        quant_config,
        num_hidden_layers_override=depth,
        prefix=prefix,
    )


121
class Phi3VImagePixelInputs(TensorSchema):
    """
    Schema for image inputs given as raw pixel values.

    Dimensions:
        - b: Batch size
        - n: Number of images
        - p: Number of patches
        - h: Height of each patch
        - w: Width of each patch

    The leading ``bn`` axis in the shapes below presumably denotes the
    flattened ``b * n`` axis (the inputs are passed through ``flatten_bn``
    before this schema is built).
    """

    type: Literal["pixel_values", "image_embeds"] = "pixel_values"

    # Supports either a stacked tensor or a list of (p, 3, h, w) tensors.
    # 'p' is declared dynamic because it may vary across items.
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}),
    ]

    # Stacked tensor with height and width for each image
    image_sizes: Annotated[torch.Tensor | None, TensorShape("bn", 2)]
143
144


145
class Phi3VImageEmbeddingInputs(TensorSchema):
    """
    Schema for image inputs given as precomputed embeddings.

    Dimensions:
        - b: Batch size
        - n: Number of images
        - f: Image feature size (e.g., number of tokens per image)
        - h: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"

    # Either a stacked tensor or a list of per-image tensors.
    data: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "f", "h"),
    ]
159
160


161
# Either of the two accepted image input schemas: raw pixel values or
# precomputed image embeddings.
Phi3VImageInputs: TypeAlias = Phi3VImagePixelInputs | Phi3VImageEmbeddingInputs
162
163


164
class Phi3ImageEmbeddingBase(nn.Module):
    """Base module that runs the vision tower and selects its output tokens."""

    def __init__(self) -> None:
        super().__init__()
        # These attributes are only declared here, not assigned.
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        """Run ``self.img_processor`` and pick tokens per ``self.type_feature``.

        Args:
            img_embeds: Input passed unchanged to ``self.img_processor``.

        Returns:
            For ``"patch"``: the processor output without index 0 along
            dim 1 (presumably the CLS token). For ``"cls_patch"``: the
            processor output unchanged.

        Raises:
            NotImplementedError: For any other ``type_feature`` value.
        """
        feature_kind = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        features = self.img_processor(img_embeds)

        if feature_kind == "cls_patch":
            return features

        if feature_kind == "patch":
            return features[:, 1:]

        raise NotImplementedError


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
    """Phi3 Image embedding with HD transform."""

192
193
194
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
    ) -> None:
        """Create the vision tower, separator parameters and projection MLP.

        Args:
            config: Model config. Reads ``n_embd`` (or ``hidden_size``),
                the ``img_processor`` mapping (``image_dim_out``,
                ``num_img_tokens``, optional ``type_feature``) and the
                ``embd_layer`` mapping (HD-transform options).
            quant_config: Quantization config forwarded to the vision tower.
            prefix: Weight-name prefix of this module.
        """
        super().__init__()

        # n_embed or hidden_size
        if hasattr(config, "n_embd"):
            hidden_size = config.n_embd
        else:
            hidden_size = config.hidden_size

        self.img_processor = _init_img_processor(
            config, quant_config, prefix=f"{prefix}.img_processor"
        )

        vision_cfg = config.img_processor
        image_dim_out = vision_cfg["image_dim_out"]
        self.num_img_tokens = vision_cfg["num_img_tokens"]
        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serves as line separator
        embd_cfg = config.embd_layer
        self.use_hd_transform = embd_cfg.get("use_hd_transform", False)
        self.with_learnable_separator = embd_cfg.get(
            "with_learnable_separator", False
        )
        self.hd_transform_order = embd_cfg.get("hd_transform_order", "glb_sub")

        # Only the configuration with both the HD transform and the
        # learnable separator enabled is supported.
        assert self.use_hd_transform and self.with_learnable_separator

        # 1024 * 4, merge spatial to channel dimension
        merged_dim = self.image_dim_out * 4
        self.glb_GN = nn.Parameter(torch.empty([1, 1, merged_dim]))
        self.sub_GN = nn.Parameter(torch.empty([1, 1, 1, merged_dim]))

        # Projection MLP: Linear followed by (GELU, Linear) pairs, `depth`
        # linear layers in total.
        depth = 2
        layers = [nn.Linear(image_dim_out * 4, hidden_size)]
        for _ in range(depth - 1):
            layers += [nn.GELU(), nn.Linear(hidden_size, hidden_size)]
        self.img_projection = nn.Sequential(*layers)

        self.type_feature = vision_cfg.get("type_feature", "patch")
233

234
235
236
    def forward(
        self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor
    ) -> torch.FloatTensor:
        """
        Encode the image crops and return the projected vision embeddings.

        pixel_values: (num_images, num_crops, c, h, w)
        image_sizes: one (height, width) pair per image
        output: the result of `hd_feature_transform`, i.e. a list holding
            one projected embedding tensor per image
        """
        # The input must be 5-D; only the two leading sizes are needed here.
        num_images, num_crops, _, _, _ = pixel_values.shape

        # Run every crop of every image through the vision tower in one batch.
        flat_crops = pixel_values.flatten(0, 1)
        crop_features = self.get_img_features(flat_crops)

        # Restore the (image, crop) axes before the HD transform.
        crop_features = crop_features.reshape(
            num_images, num_crops, -1, self.image_dim_out
        )
        return self.hd_feature_transform(crop_features, image_sizes)

    def hd_feature_transform(self, image_features, image_sizes):
        """
        Assemble and project the HD embedding sequence of every image.

        image_features: (num_images, num_crops+1, 24*24, 1024)
        image_sizes: one (height, width) pair per image
        output: list with one projected tensor per image, built from
            [sub-image features, separator, global features]
        """
        assert self.hd_transform_order == "sub_glb", (
            f"hd_transform_order `{self.hd_transform_order}` not implemented"
        )

        # The projection's first (or only) linear layer decides the device
        # and dtype the features are moved to before projecting.
        projection = self.img_projection
        if isinstance(projection, nn.Sequential):
            reference_bias = projection[0].bias
        else:  # It's a single nn.Linear layer
            reference_bias = projection.bias
        target_device = reference_bias.device
        target_dtype = reference_bias.dtype

        # Index 0 along dim 1 holds the global view: (num_images, 24*24, 1024).
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_hd = self.reshape_hd_patches_2x2merge(image_features[:, 0], 1, 1)
        global_hd_newline = self.add_image_newline(global_hd)

        projected_per_image = []
        # need a for loop to process each image because of different image sizes
        # (patch arrangement is different for each image)
        for idx, (height, width) in enumerate(image_sizes):
            rows = height // 336
            cols = width // 336
            crop_count = rows * cols

            # NOTE: real num_crops is padded
            # (num_crops, 24*24, 1024)
            sub_features = image_features[idx, 1 : 1 + crop_count]
            sub_hd = self.reshape_hd_patches_2x2merge(sub_features, rows, cols)
            sub_hd_newline = self.add_image_newline(sub_hd)

            # [sub features, separator, global features]
            pieces = [
                sub_hd_newline.squeeze(0),  # (h_crop*12*(w_crop*12+1), 4096)
                self.glb_GN.squeeze(0),
                global_hd_newline[idx],
            ]
            sequence = torch.cat(pieces)
            projected_per_image.append(
                projection(sequence.to(target_device, target_dtype))
            )

        return projected_per_image
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325

    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """
        Merge each 2x2 patch neighbourhood into the channel axis, then tile
        the crops of every image into one spatial grid.

        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096)
        where h_crop*w_crop == num_crops
        """
        total_crops, seq_len, channels = image_features.shape
        assert (
            seq_len == 576
            and channels == 1024
            and total_crops % (h_crop * w_crop) == 0
        )
        num_images = total_crops // (h_crop * w_crop)
        side = int(seq_len**0.5)
        half = side // 2

        # Step 1: inside every crop, fold each 2x2 block of patches into the
        # channel dimension.
        grid = image_features.reshape(total_crops, side, side, channels)
        blocks = grid.reshape(total_crops, half, 2, half, 2, channels)
        merged = blocks.permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
        merged = merged.reshape(total_crops, -1, 4 * channels)  # N, 144, 4096

        # Step 2: lay the crops of one image out as an h_crop x w_crop mosaic.
        mosaic = merged.reshape(
            num_images, h_crop, w_crop, half, half, -1
        )  # n_img, h_crop, w_crop, 12, 12, 4096
        mosaic = mosaic.permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
        return mosaic.reshape(
            num_images, h_crop * half, w_crop * half, 4 * channels
        )  # n_img, h_crop*12, w_crop*12, 4096

    def add_image_newline(self, image_features_hd):
        """
        Append the `sub_GN` separator after every row of the feature grid
        and flatten the grid into a sequence.

        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
        """
        # The input must be 4-D; the grid width itself is not needed.
        num_images, rows, _, hid_dim = image_features_hd.shape

        # add the newline token to the HD image feature patches
        newline = self.sub_GN.expand(num_images, rows, -1, -1)  # (n_img, h, 1, hid_dim)
        with_newline = torch.cat([image_features_hd, newline], dim=2)
        return with_newline.reshape(num_images, -1, hid_dim)
350
351


352
class Phi3VProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Phi-3-Vision image inputs."""

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits; ``"image"`` maps to ``None``."""
        limits = {"image": None}
        return limits

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: ProcessorMixin | None = None,
    ) -> int:
        """Return the number of tokens an image of the given size expands to.

        The count is computed by the HF processor's
        ``calc_num_image_tokens_from_image_size``; when ``processor`` is not
        given, the one from ``self.get_hf_processor()`` is used.
        """
        hf_processor = processor if processor is not None else self.get_hf_processor()

        return hf_processor.calc_num_image_tokens_from_image_size(  # type: ignore
            width=image_width,
            height=image_height,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image size used as the worst case for feature count."""
        # Result in the max possible feature size (h:w = 16:1)
        return ImageSize(height=8000, width=50)

375
376

class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):
    """Builds dummy prompt text and dummy images for Phi-3-Vision."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Concatenate the first ``num_images`` image tags of the HF processor."""
        image_count = mm_counts.get("image", 0)

        image_tokens: list[str] = self.info.get_hf_processor().img_tokens  # type: ignore
        selected = image_tokens[:image_count]

        return "".join(selected)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Create dummy images at the size reported as yielding most features.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Number of items per modality; only ``"image"`` is read.
            mm_options: Optional per-modality overrides; the ``"image"`` entry
                is forwarded to ``_get_dummy_images``.
        """
        image_count = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()

        if mm_options:
            image_overrides = mm_options.get("image")
        else:
            image_overrides = None

        dummy_images = self._get_dummy_images(
            width=target_width,
            height=target_height,
            num_images=image_count,
            overrides=image_overrides,
        )
        return {"image": dummy_images}


407
class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
408
    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the parent HF processing and replace negative placeholder ids.

        Returns the parent's outputs; its ``input_ids`` tensor is modified
        in place.
        """
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        ids = outputs["input_ids"]
        assert isinstance(ids, torch.Tensor)

        # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
        # which will cause OverflowError when decoding the prompt_ids.
        # Therefore, we need to do an early replacement here
        negative_positions = ids < 0
        ids.masked_fill_(negative_positions, _IMAGE_TOKEN_ID)

        return outputs

432
433
434
435
436
437
438
439
440
441
442
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare every image-related field as batched over ``"image"``.

        Neither argument is consulted; the mapping is fixed.
        """
        field_names = ("pixel_values", "image_sizes", "image_embeds")
        return {name: MultiModalFieldConfig.batched("image") for name in field_names}

443
    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Describe how each image tag expands into image placeholder tokens.

        Returns a single ``PromptReplacement`` whose target is the
        ``item_idx``-th entry of the HF processor's ``img_tokens`` and whose
        replacement is ``_IMAGE_TOKEN_ID`` repeated once per image token.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

        def expand_placeholder(item_idx: int):
            items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(items, ImageEmbeddingItems):
                # Precomputed embeddings already know their feature size.
                token_count = items.get_feature_size(item_idx)
            else:
                size = items.get_image_size(item_idx)
                token_count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                    processor=hf_processor,
                )

            return [_IMAGE_TOKEN_ID] * token_count

        replacement = PromptReplacement(
            modality="image",
            target=image_tokens.__getitem__,
            replacement=expand_placeholder,
        )
        return [replacement]

477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
    def _recompute_cached_prompt_update(
        self,
        cached_update: ResolvedPromptUpdate,
        new_item_idx: int,
    ) -> ResolvedPromptUpdate:
        """Re-target a cached prompt update to the image tag of a new index.

        The parent recomputation runs first. For the ``"image"`` modality the
        target is then replaced by ``img_tokens[new_item_idx]``.
        """
        refreshed = super()._recompute_cached_prompt_update(
            cached_update,
            new_item_idx,
        )

        if cached_update.modality != "image":
            return refreshed

        image_tokens: list[str] = self.info.get_hf_processor().img_tokens  # type: ignore
        return refreshed.with_target(image_tokens[new_item_idx])

494
    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """Apply prompt updates while mirroring the HF processor's tokenization.

        Without any updates this simply defers to the parent. Otherwise the
        prompt is decoded, re-tokenized chunk by chunk around the image tags
        (as the HF processor does), passed to the parent, and finally
        adjusted when the result starts with the tokens of
        ``"<s> <|image|>"``.
        """
        if len(mm_prompt_updates) == 0:
            return super()._apply_prompt_updates(
                token_ids=token_ids,
                mm_prompt_updates=mm_prompt_updates,
            )

        # align to hf behavior when there are images
        tokenizer = self.info.get_tokenizer()

        # to decode token_ids to the original text, we need to
        # 1. remove the first bos token
        # 2. remove space after each special token
        #    introduced by the tokenizer
        if len(token_ids) and token_ids[0] == tokenizer.bos_token_id:
            token_ids = token_ids[1:]
        text = tokenizer.decode(token_ids)
        for entry in tokenizer.special_tokens_map.values():
            if isinstance(entry, str):
                candidates = [entry]
            elif isinstance(entry, list):
                candidates = entry
            else:
                continue
            for special_token in candidates:
                text = text.replace(f"{special_token} ", special_token)

        # perform hf behavior
        # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/64f88b6/processing_phi3_v.py#L407
        pattern = r"<\|image_\d+\|>"
        prompt_chunks = [
            tokenizer(chunk).input_ids for chunk in re.split(pattern, text)
        ]
        image_tags = [
            tokenizer(chunk, add_special_tokens=False).input_ids
            for chunk in re.findall(pattern, text)
        ]
        if len(prompt_chunks) > len(image_tags):
            image_tags.append([])

        # Interleave: chunk 0, tag 0, chunk 1, tag 1, ...
        rebuilt: list[int] = []
        for chunk_ids, tag_ids in zip(prompt_chunks, image_tags):
            rebuilt.extend(chunk_ids)
            rebuilt.extend(tag_ids)
        token_ids = rebuilt

        token_ids, placeholders = super()._apply_prompt_updates(
            token_ids=token_ids,
            mm_prompt_updates=mm_prompt_updates,
        )

        # Keep the behavior in line with HF processor
        leading_pair = tokenizer.encode("<s> <|image|>", add_special_tokens=False)
        if token_ids[:2] == leading_pair:
            # Drop the second token and shift every placeholder left by one.
            token_ids = [token_ids[0], *token_ids[2:]]
            shifted = {}
            for modality, infos in placeholders.items():
                shifted[modality] = [
                    PlaceholderFeaturesInfo(
                        modality=info.modality,
                        item_idx=info.item_idx,
                        start_idx=info.start_idx - 1,
                        tokens=info.tokens,
                        is_embed=info.is_embed,
                    )
                    for info in infos
                ]
            placeholders = shifted

        return token_ids, placeholders
559

560

561
562
563
564
565
566
@MULTIMODAL_REGISTRY.register_processor(
    Phi3VMultiModalProcessor,
    info=Phi3VProcessingInfo,
    dummy_inputs=Phi3VDummyInputsBuilder,
)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant):
567
568
569
570
571
572
    # Prefix renames from HF checkpoint weight names to this module's names.
    # The entry order is kept as-is: the more specific
    # "model.vision_embed_tokens." prefixes come before the generic "model.".
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_embed_tokens.wte": "embed_tokens",
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )
575

576
    @classmethod
577
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
578
579
580
581
582
        if modality.startswith("image"):
            return f"<|image_{i}|>"

        raise ValueError("Only image modality is supported")

583
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the token embedding, vision embedding and language model.

        Args:
            vllm_config: Engine configuration; the HF config and multimodal
                config are read from ``vllm_config.model_config``.
            prefix: Weight-name prefix used for the embedding submodules.
        """
        super().__init__()
        model_config = vllm_config.model_config
        hf_config = model_config.hf_config
        self.config = hf_config
        self.multimodal_config = model_config.multimodal_config
        self.image_token_id = _IMAGE_TOKEN_ID

        # NOTE(review): `self.quant_config` is not assigned in this method;
        # presumably it is provided by a base class (SupportsQuant) - confirm.
        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
            org_num_embeddings=hf_config.vocab_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "model.embed_tokens"),
        )

        # TODO: Optionally initializes this for supporting input embeddings.
        self.vision_embed_tokens = Phi3HDImageEmbedding(
            hf_config,
            self.quant_config,
            prefix=maybe_prefix(prefix, "model.vision_embed_tokens"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            # The prefix is empty intentionally because default prefix of
            # LlamaForCausalLM is "model"
            prefix="",
            # We don't directly initialize vLLM's LlamaForCausalLM so we
            # can automatically apply embedding wrapper if this model is
            # initialized as an embedding model
            architectures=["LlamaForCausalLM"],
        )

        # Expose the language model's factory under this module's own name.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )
620

621
    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Phi3VImageInputs | None:
        """Turn raw multimodal kwargs into a typed image input.

        ``pixel_values`` takes precedence over ``image_embeds`` when both are
        given. Returns ``None`` when neither is present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            # Patch height and width are pinned to the CLIP input resolution.
            side = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
            return Phi3VImagePixelInputs(
                type="pixel_values",
                pixel_values=flatten_bn(pixel_values),
                image_sizes=flatten_bn(image_sizes, concat=True),
                resolve_bindings={"h": side, "w": side},
            )

        if image_embeds is not None:
            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        # Neither pixel values nor embeddings were supplied.
        return None

    def _process_image_input(
        self,
        image_input: Phi3VImageInputs,
    ) -> list[torch.Tensor]:
        """Convert a validated image input into per-image embeddings.

        Args:
            image_input: Either precomputed embeddings (``type`` is
                ``"image_embeds"``) or pixel values with image sizes.

        Returns:
            A list with one embedding tensor per image. Precomputed
            embeddings are returned as given (list) or unbound along dim 0
            (3D tensor). Pixel inputs are encoded by
            ``self.vision_embed_tokens``.

        Raises:
            ValueError: If precomputed embeddings are neither a list of
                tensors nor a 3D tensor.
        """
        if image_input["type"] == "image_embeds":
            image_data = image_input["data"]
            if is_list_of(image_data, torch.Tensor):
                # it's already a list of tensors
                return image_data
            if len(image_data.shape) == 3:
                # 3D tensor: split into one 2D tensor per image
                return list(torch.unbind(image_data, dim=0))
            raise ValueError(
                "We expect batched 2D tensors; "
                "this can be either a list of 2D tensors or a single 3D tensor."
            )

        assert self.vision_embed_tokens is not None
        return self.vision_embed_tokens(
            image_input["pixel_values"], image_input["image_sizes"]
        )
673

674
675
676
    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model (``self.language_model``)."""
        return self.language_model

677
    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return the image embeddings for the given multimodal kwargs.

        An empty list is returned when the kwargs contain no image input.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        return [] if image_input is None else self._process_image_input(image_input)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed the tokens and merge in the multimodal embeddings.

        Text embeddings come from ``self.embed_tokens``. They are returned
        unchanged when no multimodal embeddings are supplied.

        Raises:
            ValueError: If multimodal embeddings are supplied without
                ``is_multimodal``.
        """
        text_embeds = self._get_text_embeddings(
            input_ids,
            self.embed_tokens,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        nothing_to_merge = (
            multimodal_embeddings is None or len(multimodal_embeddings) == 0
        )
        if nothing_to_merge:
            return text_embeds

        if is_multimodal is None:
            raise ValueError(
                "`get_input_embeddings` now requires `is_multimodal` arg, "
                "please update your model runner according to "
                "https://github.com/vllm-project/vllm/pull/16229."
            )

        return _merge_multimodal_embeddings(
            inputs_embeds=text_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )
714

715
716
717
718
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the language model backbone and return its hidden states.

        When ``intermediate_tensors`` is given, ``inputs_embeds`` is ignored
        and ``None`` is passed on instead. Extra ``kwargs`` are not used.
        """
        embeds = None if intermediate_tensors is not None else inputs_embeds

        return self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=embeds
        )

732
733
734
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the wrapped language model."""
        logits = self.language_model.compute_logits(hidden_states)
        return logits
737

738
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the names that were loaded.

        Names are remapped through ``hf_to_vllm_mapper``. If no
        ``embed_tokens.weight`` was loaded, the language model's embedding
        is reused and the name is reported as loaded.
        """
        loaded = AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

        # The HF config doesn't specify whether these are tied,
        # so we detect it this way
        if "embed_tokens.weight" in loaded:
            return loaded

        self.embed_tokens = self.language_model.model.embed_tokens
        loaded.add("embed_tokens.weight")
        return loaded