phi3v.py 26.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18
from collections.abc import Iterable, Mapping, Sequence
19
from typing import Annotated, Any, Literal, Optional, Union
20

21
import regex as re
22
23
import torch
import torch.nn as nn
24
25
26
27
28
29
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    PretrainedConfig,
    ProcessorMixin,
)
30

31
from vllm.config import VllmConfig
32
from vllm.config.multimodal import BaseDummyOptions
33
from vllm.logger import init_logger
34
from vllm.model_executor.layers.quantization import QuantizationConfig
35
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
36
from vllm.multimodal import MULTIMODAL_REGISTRY
37
38
39
40
41
42
43
44
45
46
47
48
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)

49
50
# yapf conflicts with isort for this block
# yapf: disable
51
52
53
54
55
56
57
58
59
60
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    MultiModalPromptUpdates,
    PlaceholderFeaturesInfo,
    PromptReplacement,
    PromptUpdate,
    ResolvedPromptUpdate,
)

61
# yapf: enable
62
from vllm.multimodal.profiling import BaseDummyInputsBuilder
63
from vllm.sequence import IntermediateTensors
64
from vllm.utils import is_list_of
65
from vllm.utils.tensor_schema import TensorSchema, TensorShape
66

67
from .clip import CLIPVisionModel
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from .interfaces import (
    MultiModalEmbeddings,
    SupportsMultiModal,
    SupportsPP,
    SupportsQuant,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    _merge_multimodal_embeddings,
    flatten_bn,
    init_vllm_registered_model,
    maybe_prefix,
)
82

83
84
logger = init_logger(__name__)

# The image placeholder token id cannot be found in the HF config, so it is
# hard-coded here. It replaces the negative placeholder ids emitted by the HF
# processor and is also the token used for image prompt replacements.
_IMAGE_TOKEN_ID = 32044

# Fixed configuration of the CLIP ViT-L/14 @ 336px vision tower used by
# Phi-3-vision; it is not read from the HF config.
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
    dropout=0.0,
    hidden_act="quick_gelu",
    hidden_size=1024,
    image_size=336,
    intermediate_size=4096,
    num_attention_heads=16,
    num_channels=3,
    num_hidden_layers=24,
    patch_size=14,
    projection_dim=768,
)


def _init_img_processor(
    hf_config: PretrainedConfig,
    quant_config: Optional[QuantizationConfig],
    prefix: str = "",
) -> CLIPVisionModel:
    """Build the truncated CLIP vision tower used by Phi-3-vision.

    Only the encoder layers up to and including the feature layer given by
    ``hf_config.img_processor["layer_idx"]`` (default ``-2``) are created,
    so the returned model already outputs the selected feature layer.

    Args:
        hf_config: HF model config; its ``img_processor`` dict is read.
        quant_config: Optional quantization config forwarded to CLIP.
        prefix: Weight-name prefix for the vision tower.

    Returns:
        A ``CLIPVisionModel`` built from ``CLIP_VIT_LARGE_PATCH14_336_CONFIG``.
    """
    vision_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
    feature_layer = hf_config.img_processor.get("layer_idx", -2)

    # Negative indices count back from the last encoder layer; non-negative
    # ones are 0-based. Either way, convert to the number of leading layers
    # that must be instantiated.
    layers_to_keep = (
        vision_config.num_hidden_layers + feature_layer + 1
        if feature_layer < 0
        else feature_layer + 1
    )

    return CLIPVisionModel(
        vision_config,
        quant_config,
        num_hidden_layers_override=layers_to_keep,
        prefix=prefix,
    )


126
class Phi3VImagePixelInputs(TensorSchema):
    """
    Pixel-value image inputs for the Phi-3-vision HD image encoder.

    Dimensions:
        - bn: Batch size * number of images
        - p: Number of patches (image crops) per image; may vary across
          items when a list of tensors is passed
        - h: Height of each patch
        - w: Width of each patch
    """

    # Discriminator of the `Phi3VImageInputs` union. It must not overlap
    # with `Phi3VImageEmbeddingInputs.type`, since consumers branch on
    # `type == "image_embeds"` to pick the embedding path.
    type: Literal["pixel_values"] = "pixel_values"

    # Supports either a stacked tensor or a list of (p, 3, h, w) tensors
    pixel_values: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}),  # 'p' may vary across items
    ]

    # Stacked tensor with height and width for each image
    image_sizes: Annotated[Optional[torch.Tensor], TensorShape("bn", 2)]
149


150
class Phi3VImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings, used instead of running the vision encoder.

    Dimensions:
        - bn: Batch size * number of images
        - f: Image feature size (e.g., number of tokens per image)
        - h: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"
    data: Annotated[
        Union[torch.Tensor, list[torch.Tensor]],
        TensorShape("bn", "f", "h"),
    ]


# Image inputs accepted by the model; the two variants are told apart by
# their `type` field ("pixel_values" vs. "image_embeds").
Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]


169
class Phi3ImageEmbeddingBase(nn.Module):
    """Common base for Phi-3 image embedders.

    Subclasses must assign ``type_feature`` and ``img_processor`` before
    :meth:`get_img_features` is called.
    """

    def __init__(self) -> None:
        super().__init__()
        # Declarations for type checkers only; subclasses assign the values.
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        """Run the vision tower and select tokens according to ``type_feature``.

        ``"patch"`` drops the first token along dim 1 (presumably the CLS
        token); ``"cls_patch"`` returns every token.

        Raises:
            NotImplementedError: if ``type_feature`` is any other value.
        """
        type_feature = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if type_feature == "patch":
            return img_feature[:, 1:]

        if type_feature == "cls_patch":
            return img_feature

        raise NotImplementedError(
            f"Unsupported image feature type {type_feature!r}; "
            "expected 'patch' or 'cls_patch'."
        )


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
    """Phi3 Image embedding with HD transform.

    Each image crop is encoded by the CLIP vision tower. Every 2x2 block of
    patch features is merged into one 4x-wide feature. The crops are laid
    out on their spatial grid with a learnable separator (``sub_GN``) at the
    end of every row. A second separator (``glb_GN``) and the global view,
    processed the same way, are appended. The result is projected with a
    2-layer MLP.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig],
        prefix: str = "",
    ) -> None:
        super().__init__()

        # n_embd or hidden_size
        hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size

        self.img_processor = _init_img_processor(
            config, quant_config, prefix=f"{prefix}.img_processor"
        )

        image_dim_out = config.img_processor["image_dim_out"]
        self.num_img_tokens = config.img_processor["num_img_tokens"]

        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = config.embd_layer.get("use_hd_transform", False)
        self.with_learnable_separator = config.embd_layer.get(
            "with_learnable_separator", False
        )
        self.hd_transform_order = config.embd_layer.get("hd_transform_order", "glb_sub")

        # Only the HD transform with learnable separators is implemented.
        # Use explicit checks (not `assert`, which `python -O` strips) and
        # fail at construction time rather than on the first forward pass.
        if not (self.use_hd_transform and self.with_learnable_separator):
            raise ValueError(
                "Phi3HDImageEmbedding requires both `use_hd_transform` and "
                "`with_learnable_separator` to be enabled in `embd_layer`."
            )
        if self.hd_transform_order != "sub_glb":
            raise NotImplementedError(
                f"hd_transform_order `{self.hd_transform_order}` not implemented"
            )

        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(torch.empty([1, 1, 1, self.image_dim_out * 4]))

        dim_projection = hidden_size
        depth = 2
        layers = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.type_feature = config.img_processor.get("type_feature", "patch")

    def forward(
        self,
        pixel_values: Union[torch.Tensor, list[torch.Tensor]],
        image_sizes: torch.Tensor,
    ) -> list[torch.Tensor]:
        """
        Process images and return their vision embeddings.

        pixel_values: either a stacked tensor of shape
            (num_images, num_crops, c, h, w), or a list of per-image tensors
            of shape (num_crops_i, c, h, w) whose crop counts may differ.
        image_sizes: (num_images, 2) holding (height, width) per image.
        output: list with one (num_tokens_i, hidden_size) tensor per image;
            the token count depends on each image's size.
        """
        if isinstance(pixel_values, (list, tuple)):
            # Items with different crop counts cannot be stacked, so each
            # image is encoded on its own as a batch of one.
            return [
                embedding
                for image_pixels, image_size in zip(pixel_values, image_sizes)
                for embedding in self._embed_stacked(
                    image_pixels.unsqueeze(0), image_size.unsqueeze(0)
                )
            ]
        return self._embed_stacked(pixel_values, image_sizes)

    def _embed_stacked(
        self, pixel_values: torch.Tensor, image_sizes: torch.Tensor
    ) -> list[torch.Tensor]:
        """Encode a stacked (num_images, num_crops, c, h, w) pixel tensor."""
        num_images, num_crops, c, h, w = pixel_values.shape
        pixel_values = pixel_values.flatten(0, 1)
        img_features = self.get_img_features(pixel_values)
        img_features = img_features.reshape(
            num_images, num_crops, -1, self.image_dim_out
        )
        return self.hd_feature_transform(img_features, image_sizes)

    def hd_feature_transform(self, image_features, image_sizes):
        """Arrange crop features on the HD grid and project them.

        image_features: (num_images, num_crops+1, 24*24, 1024); entry 0 of
            dim 1 is the global view, the rest are the (padded) crops.
        image_sizes: per-image (height, width); the crop grid of image i is
            (height // 336) x (width // 336).
        output: list with one (num_tokens_i, hidden_size) tensor per image.
        """
        assert self.hd_transform_order == "sub_glb", (
            f"hd_transform_order `{self.hd_transform_order}` not implemented"
        )
        if isinstance(self.img_projection, nn.Sequential):
            target_device = self.img_projection[0].bias.device
            target_dtype = self.img_projection[0].bias.dtype
        else:  # It's a single nn.Linear layer
            target_device = self.img_projection.bias.device
            target_dtype = self.img_projection.bias.dtype

        global_image_features = image_features[:, 0]  # (num_images, 24*24, 1024)
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_image_features_hd = self.reshape_hd_patches_2x2merge(
            global_image_features, 1, 1
        )
        global_image_features_hd_newline = self.add_image_newline(
            global_image_features_hd
        )

        batch_image_features_proj = []
        # need a for loop to process each image because of different image sizes
        # (patch arrangement is different for each image)
        for i, img_size in enumerate(image_sizes):
            h, w = img_size
            h_crop = h // 336
            w_crop = w // 336
            num_crops = h_crop * w_crop

            # NOTE: real num_crops is padded
            # (num_crops, 24*24, 1024)
            sub_image_features = image_features[i, 1 : 1 + num_crops]
            sub_image_features_hd = self.reshape_hd_patches_2x2merge(
                sub_image_features, h_crop, w_crop
            )
            sub_image_features_hd_newline = self.add_image_newline(
                sub_image_features_hd
            )

            # [sub features, separator, global features]
            image_embeddings = torch.cat(
                [
                    # (h_crop*12*(w_crop*12+1), 4096)
                    sub_image_features_hd_newline.squeeze(0),
                    self.glb_GN.squeeze(0),
                    global_image_features_hd_newline[i],
                ]
            )
            img_proj = self.img_projection(
                image_embeddings.to(target_device, target_dtype)
            )
            batch_image_features_proj.append(img_proj)

        return batch_image_features_proj

    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """
        Merge 2x2 neighbouring patches and tile the crops on their grid.

        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096)
        where h_crop*w_crop == num_crops
        """
        N, L, C = image_features.shape
        assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
        num_images = N // (h_crop * w_crop)
        H = int(L**0.5)
        image_features_hd = (
            image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
            .reshape(N, -1, 4 * C)  # N, 144, 4096
            .reshape(
                num_images, h_crop, w_crop, H // 2, H // 2, -1
            )  # n_img, h_crop, w_crop, 12, 12, 4096
            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
            .reshape(
                num_images, h_crop * H // 2, w_crop * H // 2, 4 * C
            )  # n_img, h_crop*12, w_crop*12, 4096
        )
        return image_features_hd

    def add_image_newline(self, image_features_hd):
        """
        Append the learnable ``sub_GN`` separator to every row and flatten.

        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
        """
        num_images, h, w, hid_dim = image_features_hd.shape
        # add the newline token to the HD image feature patches
        newline_embeddings = self.sub_GN.expand(
            num_images, h, -1, -1
        )  # (n_img, h, 1, hid_dim)
        image_features_hd_newline = torch.cat(
            [image_features_hd, newline_embeddings], dim=2
        ).reshape(num_images, -1, hid_dim)
        return image_features_hd_newline
355
356


357
class Phi3VProcessingInfo(BaseProcessingInfo):
    """Processing metadata (limits, token counts, sizes) for Phi-3-vision."""

    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
        # No fixed upper bound on the number of images per prompt.
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Optional[ProcessorMixin] = None,
    ) -> int:
        """Return the number of placeholder tokens for an image of this size.

        The computation is delegated to the HF processor. When ``processor``
        is not given, the default one from :meth:`get_hf_processor` is used.
        """
        if processor is None:
            processor = self.get_hf_processor()

        return processor.calc_num_image_tokens_from_image_size(  # type: ignore
            width=image_width,
            height=image_height,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        # An extremely tall, narrow image (8000x50, h:w = 160:1) is used to
        # reach the maximum possible feature size.
        # NOTE(review): the previous comment said "h:w = 16:1"; presumably
        # the HD crop grid saturates at 16x1 — confirm the intended ratio.
        return ImageSize(height=8000, width=50)

380
381

class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):
    """Builds dummy prompts and images for Phi-3-vision profiling."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one HF image tag per requested image, concatenated."""
        image_count = mm_counts.get("image", 0)
        processor = self.info.get_hf_processor()
        tags: list[str] = processor.img_tokens  # type: ignore
        selected_tags = tags[:image_count]
        return "".join(selected_tags)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
    ) -> MultiModalDataDict:
        """Return ``mm_counts["image"]`` dummy images of the size with the
        most features, applying any ``"image"`` overrides from
        ``mm_options``."""
        image_count = mm_counts.get("image", 0)

        width, height = self.info.get_image_size_with_most_features()

        if mm_options:
            image_overrides = mm_options.get("image")
        else:
            image_overrides = None

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=image_overrides,
        )
        return {"image": dummy_images}


412
class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
    """Multimodal processor for Phi-3-vision.

    On top of the base flow, it replaces the negative image placeholder ids
    produced by the HF processor and re-tokenizes prompts so the final token
    ids line up with the HF processor's output.
    """

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and normalize its placeholder token ids."""
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        input_ids = processed_outputs["input_ids"]
        assert isinstance(input_ids, torch.Tensor)

        # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
        # which will cause OverflowError when decoding the prompt_ids.
        # Therefore, we need to do an early replacement here
        input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Every image-related field is batched per image item.
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_sizes=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace the i-th image tag with ``_IMAGE_TOKEN_ID`` repeated once
        per feature token of the i-th image."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

        def get_replacement_phi3v(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                )

            return [_IMAGE_TOKEN_ID] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=image_tokens.__getitem__,
                replacement=get_replacement_phi3v,
            )
        ]

    def _recompute_cached_prompt_update(
        self,
        cached_update: ResolvedPromptUpdate,
        new_item_idx: int,
    ) -> ResolvedPromptUpdate:
        """Re-target a cached image update at the tag for ``new_item_idx``."""
        new_update = super()._recompute_cached_prompt_update(
            cached_update,
            new_item_idx,
        )

        if cached_update.modality == "image":
            hf_processor = self.info.get_hf_processor()
            image_tokens: list[str] = hf_processor.img_tokens  # type: ignore
            new_update = new_update.with_target(image_tokens[new_item_idx])

        return new_update

    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """Apply prompt updates, reproducing HF tokenization when images exist.

        Without any multimodal updates this defers to the base class as-is.
        Otherwise the prompt is first re-tokenized the way the HF processor
        does it. After the updates are applied, the second token is dropped
        (with placeholders shifted) when the prompt starts with the encoding
        of ``"<s> <|image|>"``.
        """
        if len(mm_prompt_updates) == 0:
            return super()._apply_prompt_updates(
                token_ids=token_ids,
                mm_prompt_updates=mm_prompt_updates,
            )

        # Fetched once up front so it is always bound where used below.
        tokenizer = self.info.get_tokenizer()

        # align to hf behavior when there are images
        token_ids = self._retokenize_like_hf(token_ids, tokenizer)

        token_ids, placeholders = super()._apply_prompt_updates(
            token_ids=token_ids,
            mm_prompt_updates=mm_prompt_updates,
        )

        # Keep the behavior in line with HF processor
        # NOTE(review): this compares a 2-token prefix with the full encoding
        # of "<s> <|image|>"; it can only match if that encoding is exactly
        # two tokens long — confirm against the tokenizer.
        if token_ids[:2] == tokenizer.encode("<s> <|image|>", add_special_tokens=False):
            token_ids, placeholders = self._drop_second_token(token_ids, placeholders)

        return token_ids, placeholders

    @staticmethod
    def _retokenize_like_hf(token_ids: list[int], tokenizer: Any) -> list[int]:
        """Re-tokenize ``token_ids`` the way the HF Phi-3-vision processor does.

        The ids are decoded back to text and split on ``<|image_N|>`` tags.
        Each text chunk is tokenized with special tokens, each tag without.
        The results are interleaved chunk, tag, chunk, ...
        """
        # to decode token_ids to the original text, we need to
        # 1. remove the first bos token
        # 2. remove space after each special token
        #    introduced by the tokenizer
        if token_ids and token_ids[0] == tokenizer.bos_token_id:
            token_ids = token_ids[1:]
        text = tokenizer.decode(token_ids)
        for special_tokens in tokenizer.special_tokens_map.values():
            if isinstance(special_tokens, str):
                text = text.replace(f"{special_tokens} ", special_tokens)
            elif isinstance(special_tokens, list):
                for special_token in special_tokens:
                    text = text.replace(f"{special_token} ", special_token)

        # perform hf behavior
        # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/64f88b6/processing_phi3_v.py#L407
        pattern = r"<\|image_\d+\|>"
        prompt_chunks = [
            tokenizer(chunk).input_ids for chunk in re.split(pattern, text)
        ]
        image_tags = [
            tokenizer(chunk, add_special_tokens=False).input_ids
            for chunk in re.findall(pattern, text)
        ]
        # Pad so zip() below does not drop the trailing text chunk.
        if len(prompt_chunks) > len(image_tags):
            image_tags.append([])
        return [
            token
            for pair in zip(prompt_chunks, image_tags)
            for piece in pair
            for token in piece
        ]

    @staticmethod
    def _drop_second_token(
        token_ids: list[int],
        placeholders: Mapping[str, list[PlaceholderFeaturesInfo]],
    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """Remove ``token_ids[1]`` and shift every placeholder left by one."""
        token_ids = [token_ids[0], *token_ids[2:]]
        placeholders = {
            modality: [
                PlaceholderFeaturesInfo(
                    modality=p.modality,
                    item_idx=p.item_idx,
                    start_idx=p.start_idx - 1,
                    tokens=p.tokens,
                    is_embed=p.is_embed,
                )
                for p in ps
            ]
            for modality, ps in placeholders.items()
        }
        return token_ids, placeholders
564

565

566
567
568
569
570
571
@MULTIMODAL_REGISTRY.register_processor(
    Phi3VMultiModalProcessor,
    info=Phi3VProcessingInfo,
    dummy_inputs=Phi3VDummyInputsBuilder,
)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant):
    """Phi-3-vision: an HD CLIP image encoder feeding a LlamaForCausalLM."""

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_embed_tokens.wte": "embed_tokens",
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt tag ``<|image_{i}|>`` for image modalities."""
        if modality.startswith("image"):
            return f"<|image_{i}|>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.image_token_id = _IMAGE_TOKEN_ID

        # NOTE(review): `self.quant_config` is not assigned in this class;
        # presumably it is provided by the `SupportsQuant` mixin — confirm.
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "model.embed_tokens"),
        )

        # TODO: Optionally initializes this for supporting input embeddings.
        self.vision_embed_tokens = Phi3HDImageEmbedding(
            config,
            self.quant_config,
            prefix=maybe_prefix(prefix, "model.vision_embed_tokens"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            # The prefix is empty intentionally because default prefix of
            # LlamaForCausalLM is "model"
            prefix="",
            # We don't directly initialize vLLM's LlamaForCausalLM so we
            # can automatically apply embedding wrapper if this model is
            # initialized as an embedding model
            architectures=["LlamaForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Optional[Phi3VImageInputs]:
        """Pop image kwargs and wrap them in the matching input schema.

        Pixel values take precedence over image embeddings. Returns ``None``
        when neither is present.

        Raises:
            ValueError: if ``pixel_values`` is given without ``image_sizes``.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            # The HD transform needs each image's size to lay out its crops;
            # fail clearly here instead of deep inside `flatten_bn`.
            if image_sizes is None:
                raise ValueError(
                    "`image_sizes` must be provided together with `pixel_values`."
                )
            return Phi3VImagePixelInputs(
                type="pixel_values",
                pixel_values=flatten_bn(pixel_values),
                image_sizes=flatten_bn(image_sizes, concat=True),
                resolve_bindings={
                    "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
                    "w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
                },
            )

        if image_embeds is not None:
            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: Phi3VImageInputs,
    ) -> MultiModalEmbeddings:
        """Return one embedding tensor per image.

        Precomputed embeddings are returned as a list of 2D tensors. Pixel
        inputs are run through the HD vision encoder.
        """
        if image_input["type"] == "image_embeds":
            image_data = image_input["data"]
            if is_list_of(image_data, torch.Tensor):
                # it's already a list of tensors
                return image_data
            if len(image_data.shape) == 3:
                # 3D tensor
                return list(torch.unbind(image_data, dim=0))
            raise ValueError(
                "We expect batched 2D tensors; "
                "this can be either a list of 2D tensors or a single 3D tensor."
            )

        assert self.vision_embed_tokens is not None
        image_embeds = self.vision_embed_tokens(
            image_input["pixel_values"], image_input["image_sizes"]
        )

        return image_embeds

    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return per-image embeddings, or ``[]`` when no image input is given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
        *,
        is_multimodal: Optional[torch.Tensor] = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and merge in multimodal embeddings if given.

        Raises:
            ValueError: if multimodal embeddings are passed without
                ``is_multimodal``.
        """
        inputs_embeds = self._get_text_embeddings(
            input_ids,
            self.embed_tokens,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        if is_multimodal is None:
            raise ValueError(
                "`get_input_embeddings` now requires `is_multimodal` arg, "
                "please update your model runner according to "
                "https://github.com/vllm-project/vllm/pull/16229."
            )

        return _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ):
        """Run the language-model backbone and return its hidden states.

        ``inputs_embeds`` is ignored whenever ``intermediate_tensors`` is
        given (presumably on non-first pipeline-parallel stages).
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> Optional[torch.Tensor]:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the names that were loaded."""
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

        # The HF config doesn't specify whether these are tied,
        # so we detect it this way
        if "embed_tokens.weight" not in autoloaded_weights:
            self.embed_tokens = self.language_model.model.embed_tokens
            autoloaded_weights.add("embed_tokens.weight")
        return autoloaded_weights