phi3v.py 25.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18
from collections.abc import Iterable, Mapping, Sequence
19
from typing import Annotated, Any, Literal, TypeAlias
20

21
import regex as re
22
23
import torch
import torch.nn as nn
24
25
26
27
28
29
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    PretrainedConfig,
    ProcessorMixin,
)
30

31
from vllm.config import VllmConfig
32
from vllm.config.multimodal import BaseDummyOptions
33
from vllm.logger import init_logger
34
from vllm.model_executor.layers.quantization import QuantizationConfig
35
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
36
from vllm.multimodal import MULTIMODAL_REGISTRY
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    MultiModalPromptUpdates,
    PlaceholderFeaturesInfo,
    PromptReplacement,
    PromptUpdate,
    ResolvedPromptUpdate,
)
57
from vllm.multimodal.profiling import BaseDummyInputsBuilder
58
from vllm.sequence import IntermediateTensors
59
from vllm.utils.tensor_schema import TensorSchema, TensorShape
60

61
from .clip import CLIPVisionModel
62
63
64
65
66
67
68
69
70
71
72
73
74
from .interfaces import (
    MultiModalEmbeddings,
    SupportsMultiModal,
    SupportsPP,
    SupportsQuant,
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    _merge_multimodal_embeddings,
    init_vllm_registered_model,
    maybe_prefix,
)
75

76
77
# Module-level logger created through vLLM's `init_logger` helper.
logger = init_logger(__name__)

78
79
80
# Token id written wherever an image placeholder is needed. It is hard-coded
# here because it cannot be found in the HF config.
_IMAGE_TOKEN_ID = 32044

81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Vision-tower configuration consumed by `_init_img_processor`. The values are
# fixed in this module rather than read from the model's HF config
# (336px input images split into 14px patches, 24 hidden layers).
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
    dropout=0.0,
    hidden_act="quick_gelu",
    hidden_size=1024,
    image_size=336,
    intermediate_size=4096,
    num_attention_heads=16,
    num_channels=3,
    num_hidden_layers=24,
    patch_size=14,
    projection_dim=768,
)


def _init_img_processor(
    hf_config: PretrainedConfig,
    quant_config: QuantizationConfig | None,
    prefix: str = "",
) -> CLIPVisionModel:
    """Build the CLIP vision tower, truncated at the feature layer in use.

    Args:
        hf_config: Model config; its `img_processor` mapping may carry a
            `layer_idx` entry (defaults to -2 when absent).
        quant_config: Quantization settings forwarded to the vision model.
        prefix: Weight-name prefix forwarded to the vision model.

    Returns:
        A `CLIPVisionModel` whose number of hidden layers is overridden so
        that nothing beyond the selected feature layer is instantiated.
    """
    vision_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
    feature_layer = hf_config.img_processor.get("layer_idx", -2)

    # Keep only the layers up to and including the selected one. A negative
    # index counts from the end of the full-depth encoder.
    depth = feature_layer + 1
    if feature_layer < 0:
        depth += vision_config.num_hidden_layers

    return CLIPVisionModel(
        vision_config,
        quant_config,
        num_hidden_layers_override=depth,
        prefix=prefix,
    )


119
class Phi3VImagePixelInputs(TensorSchema):
    """
    Schema for image inputs supplied as pixel values.

    Dimensions:
        - b: Batch size
        - n: Number of images
        - p: Number of patches
        - h: Height of each patch
        - w: Width of each patch
    """

    # NOTE(review): the Literal also admits "image_embeds" although this
    # schema describes pixel inputs -- confirm whether that is intentional.
    type: Literal["pixel_values", "image_embeds"] = "pixel_values"

    # Either a single stacked tensor or a list of (p, 3, h, w) tensors.
    # 'p' is declared dynamic, so it may vary across items.
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}),
    ]

    # Stacked tensor holding the height and width of each image
    image_sizes: Annotated[torch.Tensor | None, TensorShape("bn", 2)]
141
142


143
class Phi3VImageEmbeddingInputs(TensorSchema):
    """
    Schema for image inputs supplied as precomputed embeddings.

    Dimensions:
        - b: Batch size
        - n: Number of images
        - f: Image feature size (e.g., number of tokens per image)
        - h: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"

    # Either a single stacked tensor or a list of per-image tensors
    data: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "f", "h"),
    ]
157
158


159
# Union of the two accepted image input forms: raw pixels or embeddings.
Phi3VImageInputs: TypeAlias = Phi3VImagePixelInputs | Phi3VImageEmbeddingInputs
160
161


162
class Phi3ImageEmbeddingBase(nn.Module):
    """Base class for Phi-3 image embedding modules.

    Subclasses are expected to assign `type_feature` and `img_processor`;
    `__init__` only declares them.
    """

    def __init__(self) -> None:
        super().__init__()
        # Declarations only -- no values are assigned here.
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        """Run the image processor and select the requested features.

        Args:
            img_embeds: Input passed unchanged to `self.img_processor`.

        Returns:
            For `type_feature == "patch"`, the processor output with the
            first position along dim 1 dropped; for `"cls_patch"`, the
            processor output unchanged.

        Raises:
            NotImplementedError: If `type_feature` is neither `"patch"` nor
                `"cls_patch"`.
        """
        feature_kind = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if feature_kind == "patch":
            # Drop the leading position along the sequence dimension
            return img_feature[:, 1:]

        if feature_kind == "cls_patch":
            return img_feature

        raise NotImplementedError(
            f"Unsupported type_feature {feature_kind!r}; "
            "expected 'patch' or 'cls_patch'"
        )


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
    """Phi3 Image embedding with HD transform."""

190
191
192
    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
    ) -> None:
        """Create the vision tower, separator parameters and projection MLP.

        Args:
            config: Model config providing `img_processor` and `embd_layer`
                mappings, plus `n_embd` or `hidden_size`.
            quant_config: Quantization settings forwarded to the vision tower.
            prefix: Weight-name prefix of this module.
        """
        super().__init__()

        # n_embed or hidden_size
        if hasattr(config, "n_embd"):
            hidden_size = config.n_embd
        else:
            hidden_size = config.hidden_size

        self.img_processor = _init_img_processor(
            config, quant_config, prefix=f"{prefix}.img_processor"
        )

        img_cfg = config.img_processor
        image_dim_out = img_cfg["image_dim_out"]
        self.num_img_tokens = img_cfg["num_img_tokens"]
        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serves as line separator
        embd_cfg = config.embd_layer
        self.use_hd_transform = embd_cfg.get("use_hd_transform", False)
        self.with_learnable_separator = embd_cfg.get("with_learnable_separator", False)
        # NOTE(review): the default here is "glb_sub", while
        # `hd_feature_transform` asserts the order is "sub_glb".
        self.hd_transform_order = embd_cfg.get("hd_transform_order", "glb_sub")

        # Both the HD transform and the learnable separator must be enabled
        assert self.use_hd_transform and self.with_learnable_separator

        # 1024 * 4, merge spatial to channel dimension
        merged_dim = self.image_dim_out * 4
        self.glb_GN = nn.Parameter(torch.empty([1, 1, merged_dim]))
        self.sub_GN = nn.Parameter(torch.empty([1, 1, 1, merged_dim]))

        # Two-layer MLP projecting merged vision features to the LM width
        self.img_projection = nn.Sequential(
            nn.Linear(image_dim_out * 4, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

        self.type_feature = config.img_processor.get("type_feature", "patch")
231

232
233
234
    def forward(
        self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor
    ) -> torch.FloatTensor:
        """
        Encode image crops and apply the HD feature transform.

        pixel_values: (num_images, num_crops, c, h, w)
        image_sizes: one (height, width) entry per image
        output: whatever `hd_feature_transform` returns, i.e. a list with
            one projected embedding tensor per image
        """
        # Unpacking also enforces that the input is 5-dimensional
        num_images, num_crops, _c, _h, _w = pixel_values.shape

        # Fold the crop axis into the batch axis for the vision tower
        crops = pixel_values.flatten(0, 1)
        crop_features = self.get_img_features(crops)

        # Restore the (image, crop) layout before the HD transform
        crop_features = crop_features.reshape(
            num_images, num_crops, -1, self.image_dim_out
        )
        return self.hd_feature_transform(crop_features, image_sizes)

    def hd_feature_transform(self, image_features, image_sizes):
        """
        Arrange crop features per image and project them.

        image_features: (num_images, num_crops+1, 24*24, 1024)
        image_sizes: iterable yielding one (height, width) pair per image
        returns: list with one projected tensor per image
        """
        assert self.hd_transform_order == "sub_glb", (
            f"hd_transform_order `{self.hd_transform_order}` not implemented"
        )

        # The projection input is moved to the device/dtype of the bias of
        # the projection's first layer (or of the projection itself when it
        # is not an nn.Sequential).
        projection = self.img_projection
        if isinstance(projection, nn.Sequential):
            reference_bias = projection[0].bias
        else:  # It's a single nn.Linear layer
            reference_bias = projection.bias
        target_device = reference_bias.device
        target_dtype = reference_bias.dtype

        # Index 0 along the crop axis holds the global view of each image
        global_feats = image_features[:, 0]  # (num_images, 24*24, 1024)
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_hd = self.reshape_hd_patches_2x2merge(global_feats, 1, 1)
        global_hd_newline = self.add_image_newline(global_hd)

        projected_per_image = []
        # need a for loop to process each image because of different image sizes
        # (patch arrangement is different for each image)
        for idx, (h, w) in enumerate(image_sizes):
            h_crop = h // 336
            w_crop = w // 336
            num_crops = h_crop * w_crop

            # NOTE: real num_crops is padded
            # (num_crops, 24*24, 1024)
            sub_feats = image_features[idx, 1 : 1 + num_crops]
            sub_hd = self.reshape_hd_patches_2x2merge(sub_feats, h_crop, w_crop)
            sub_hd_newline = self.add_image_newline(sub_hd)

            # [sub features, separator, global features]
            pieces = [
                sub_hd_newline.squeeze(0),  # (h_crop*12*(w_crop*12+1), 4096)
                self.glb_GN.squeeze(0),
                global_hd_newline[idx],
            ]
            image_embeddings = torch.cat(pieces)
            projected_per_image.append(
                projection(image_embeddings.to(target_device, target_dtype))
            )

        return projected_per_image
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323

    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """
        Merge every 2x2 patch neighbourhood into the channel axis, then tile
        the crops on their h_crop x w_crop grid.

        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096)
        where h_crop*w_crop == num_crops
        """
        N, L, C = image_features.shape
        assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
        num_images = N // (h_crop * w_crop)
        H = int(L**0.5)
        half = H // 2

        # Step 1: fold each 2x2 block of patches into the channel dimension
        grid = image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
        grid = grid.reshape(N, half, 2, half, 2, C)  # N, 12, 2, 12, 2, 1024
        grid = grid.permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
        merged = grid.reshape(N, -1, 4 * C)  # N, 144, 4096

        # Step 2: place the crops of each image on their spatial grid
        tiled = merged.reshape(
            num_images, h_crop, w_crop, half, half, -1
        )  # n_img, h_crop, w_crop, 12, 12, 4096
        tiled = tiled.permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
        image_features_hd = tiled.reshape(
            num_images, h_crop * H // 2, w_crop * H // 2, 4 * C
        )  # n_img, h_crop*12, w_crop*12, 4096
        return image_features_hd

    def add_image_newline(self, image_features_hd):
        """
        Append the `sub_GN` separator at the end of every feature row and
        flatten the rows into a single sequence.

        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
        """
        # Unpacking also enforces a 4-dimensional input
        num_images, rows, _cols, hid_dim = image_features_hd.shape

        # One separator per row: (n_img, h, 1, hid_dim)
        separators = self.sub_GN.expand(num_images, rows, -1, -1)

        # add the newline token to the HD image feature patches
        with_separators = torch.cat([image_features_hd, separators], dim=2)
        return with_separators.reshape(num_images, -1, hid_dim)
348
349


350
class Phi3VProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Phi-3-Vision image inputs."""

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits; "image" is mapped to `None`."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: ProcessorMixin | None = None,
    ) -> int:
        """Return the number of image tokens for an image of the given size.

        The count is delegated to the HF processor's
        `calc_num_image_tokens_from_image_size`; when no processor is passed,
        the one returned by `self.get_hf_processor()` is used.
        """
        hf_processor = self.get_hf_processor() if processor is None else processor

        return hf_processor.calc_num_image_tokens_from_image_size(  # type: ignore
            width=image_width,
            height=image_height,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image size used to produce the largest feature count."""
        # Result in the max possible feature size.
        # NOTE(review): the original comment stated "h:w = 16:1", which does
        # not match the literal 8000x50 below -- confirm the intended ratio.
        return ImageSize(height=8000, width=50)

373
374

class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):
    """Builds dummy prompt text and dummy images for Phi-3-Vision."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Concatenate the first `num_images` image tokens of the HF processor."""
        num_images = mm_counts.get("image", 0)

        hf_processor = self.info.get_hf_processor()
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

        selected_tokens = image_tokens[:num_images]
        return "".join(selected_tokens)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Create dummy images at the size reported by the processing info.

        `seq_len` is not used by this implementation. The "image" entry of
        `mm_options`, if any, is forwarded as `overrides`.
        """
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()

        if mm_options:
            image_overrides = mm_options.get("image")
        else:
            image_overrides = None

        dummy_images = self._get_dummy_images(
            width=target_width,
            height=target_height,
            num_images=num_images,
            overrides=image_overrides,
        )
        return {"image": dummy_images}


405
class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
406
    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Call the parent HF processing step, then patch negative token ids.

        Every negative entry of `input_ids` is overwritten in place with
        `_IMAGE_TOKEN_ID`.
        """
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        input_ids = processed_outputs["input_ids"]
        assert isinstance(input_ids, torch.Tensor)

        # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
        # which will cause OverflowError when decoding the prompt_ids.
        # Therefore, we need to do an early replacement here
        negative_positions = input_ids < 0
        input_ids.masked_fill_(negative_positions, _IMAGE_TOKEN_ID)

        return processed_outputs

430
431
432
433
434
435
436
437
438
439
440
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare every image-related field as batched over "image".

        Neither argument is consulted by this implementation.
        """
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_sizes": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

441
    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Describe how each image tag is replaced by image token ids.

        The target for item `i` is the `i`-th entry of the HF processor's
        `img_tokens`; the replacement is `_IMAGE_TOKEN_ID` repeated once per
        image token of that item.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

        def get_replacement_phi3v(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            # Precomputed embeddings already know their feature size
            if isinstance(images, ImageEmbeddingItems):
                return [_IMAGE_TOKEN_ID] * images.get_feature_size(item_idx)

            # Otherwise derive the token count from the image dimensions
            size = images.get_image_size(item_idx)
            token_count = self.info.get_num_image_tokens(
                image_width=size.width,
                image_height=size.height,
                processor=hf_processor,
            )
            return [_IMAGE_TOKEN_ID] * token_count

        replacement = PromptReplacement(
            modality="image",
            target=image_tokens.__getitem__,
            replacement=get_replacement_phi3v,
        )
        return [replacement]

475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
    def _recompute_cached_prompt_update(
        self,
        cached_update: ResolvedPromptUpdate,
        new_item_idx: int,
    ) -> ResolvedPromptUpdate:
        """Re-target a cached image update at the tag of its new item index.

        Updates for modalities other than "image" are returned exactly as
        produced by the parent implementation.
        """
        new_update = super()._recompute_cached_prompt_update(
            cached_update,
            new_item_idx,
        )

        if cached_update.modality != "image":
            return new_update

        # Each image index has its own tag string, so the target must follow
        # the new index.
        hf_processor = self.info.get_hf_processor()
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore
        return new_update.with_target(image_tokens[new_item_idx])

492
    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """Apply prompt updates while reproducing the HF tokenization.

        When updates are present, the token ids are decoded to text, split
        around `<|image_N|>` tags and re-tokenized chunk by chunk before the
        parent implementation runs. Afterwards, if the result starts with
        the encoding of "<s> <|image|>", the second token is dropped and
        every placeholder start index is shifted left by one.
        """
        # align to hf behavior when there are images
        if len(mm_prompt_updates):
            tokenizer = self.info.get_tokenizer()
            # to decode token_ids to the original text, we need to
            # 1. remove the first bos token
            # 2. remove space after each special token
            #    introduced by the tokenizer
            if len(token_ids) and token_ids[0] == tokenizer.bos_token_id:
                token_ids = token_ids[1:]
            text = tokenizer.decode(token_ids)
            for entry in tokenizer.special_tokens_map.values():
                # Entries are either one token or a list of tokens; anything
                # else is ignored.
                if isinstance(entry, str):
                    entry = [entry]
                elif not isinstance(entry, list):
                    continue
                for special_token in entry:
                    text = text.replace(f"{special_token} ", special_token)

            # perform hf behavior
            # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/64f88b6/processing_phi3_v.py#L407
            pattern = r"<\|image_\d+\|>"
            prompt_chunks = [
                tokenizer(chunk).input_ids for chunk in re.split(pattern, text)
            ]
            image_tags = [
                tokenizer(chunk, add_special_tokens=False).input_ids
                for chunk in re.findall(pattern, text)
            ]
            # Pad the tag list so it pairs up with the text chunks
            if len(prompt_chunks) > len(image_tags):
                image_tags.append([])

            # Interleave: chunk 0, tag 0, chunk 1, tag 1, ...
            token_ids = []
            for chunk_ids, tag_ids in zip(prompt_chunks, image_tags):
                token_ids.extend(chunk_ids)
                token_ids.extend(tag_ids)

        token_ids, placeholders = super()._apply_prompt_updates(
            token_ids=token_ids,
            mm_prompt_updates=mm_prompt_updates,
        )

        # Keep the behavior in line with HF processor
        if len(mm_prompt_updates):
            # `tokenizer` was bound in the first branch under the same condition
            bos_image_ids = tokenizer.encode("<s> <|image|>", add_special_tokens=False)
            if token_ids[:2] == bos_image_ids:
                token_ids = [token_ids[0], *token_ids[2:]]
                shifted = {}
                for modality, ps in placeholders.items():
                    shifted[modality] = [
                        PlaceholderFeaturesInfo(
                            modality=p.modality,
                            item_idx=p.item_idx,
                            start_idx=p.start_idx - 1,
                            tokens=p.tokens,
                            is_embed=p.is_embed,
                        )
                        for p in ps
                    ]
                placeholders = shifted

        return token_ids, placeholders
557

558

559
560
561
562
563
564
@MULTIMODAL_REGISTRY.register_processor(
    Phi3VMultiModalProcessor,
    info=Phi3VProcessingInfo,
    dummy_inputs=Phi3VDummyInputsBuilder,
)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant):
565
566
    merge_by_field_config = True

567
568
569
570
571
572
    # Prefix rewrites passed to the weight loader in `load_weights`.
    # NOTE(review): the more specific "model.vision_embed_tokens.*" entries are
    # listed before the generic "model." entry -- presumably order matters for
    # matching; verify against WeightsMapper.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_embed_tokens.wte": "embed_tokens",
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )
575

576
    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt tag for the `i`-th item of an image modality.

        Raises:
            ValueError: If `modality` does not start with "image".
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")

        return f"<|image_{i}|>"

583
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Assemble token embeddings, the vision embedder and the language model.

        NOTE(review): `self.quant_config` is read below but not assigned in
        this method -- presumably it is provided by one of the base classes;
        verify.
        """
        super().__init__()

        model_config = vllm_config.model_config
        config = model_config.hf_config
        multimodal_config = model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.image_token_id = _IMAGE_TOKEN_ID

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "model.embed_tokens"),
        )

        # TODO: Optionally initializes this for supporting input embeddings.
        vision_prefix = maybe_prefix(prefix, "model.vision_embed_tokens")
        self.vision_embed_tokens = Phi3HDImageEmbedding(
            config,
            self.quant_config,
            prefix=vision_prefix,
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            # The prefix is empty intentionally because default prefix of
            # LlamaForCausalLM is "model"
            prefix="",
            # We don't directly initialize vLLM's LlamaForCausalLM so we
            # can automatically apply embedding wrapper if this model is
            # initialized as an embedding model
            architectures=["LlamaForCausalLM"],
        )

        # Re-export the language model's factory for empty intermediate tensors
        language_model = self.language_model
        self.make_empty_intermediate_tensors = (
            language_model.make_empty_intermediate_tensors
        )
620

621
    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Phi3VImageInputs | None:
        """Turn image-related keyword arguments into a typed input schema.

        Pixel values take precedence over embeddings when both are given.
        Returns `None` when neither is present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            # Patch height/width are pinned to the CLIP input resolution
            patch_side = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
            return Phi3VImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                resolve_bindings={"h": patch_side, "w": patch_side},
            )

        if image_embeds is not None:
            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        # Neither pixel values nor embeddings were supplied
        return None

    def _process_image_input(
        self,
        image_input: Phi3VImageInputs,
    ) -> torch.Tensor:
        """Return image embeddings for a validated image input.

        Embedding inputs are passed through as-is; pixel inputs are run
        through `self.vision_embed_tokens`.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_embed_tokens is not None

        pixel_values = image_input["pixel_values"]
        image_sizes = image_input["image_sizes"]
        return self.vision_embed_tokens(pixel_values, image_sizes)
664

665
666
667
    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model module."""
        language_model = self.language_model
        return language_model

668
    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
        """Compute image embeddings from keyword inputs.

        Returns an empty list when the keyword arguments contain no image
        input.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is not None:
            return self._process_image_input(image_input)
        return []

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed `input_ids` and merge in multimodal embeddings, if any.

        Raises:
            ValueError: If multimodal embeddings are supplied without
                `is_multimodal`.
        """
        inputs_embeds = self._get_text_embeddings(
            input_ids,
            self.embed_tokens,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        # Text-only case: nothing to merge
        has_mm_embeddings = not (
            multimodal_embeddings is None or len(multimodal_embeddings) == 0
        )
        if not has_mm_embeddings:
            return inputs_embeds

        if is_multimodal is None:
            raise ValueError(
                "`get_input_embeddings` now requires `is_multimodal` arg, "
                "please update your model runner according to "
                "https://github.com/vllm-project/vllm/pull/16229."
            )

        merged_embeds = _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
        )
        return merged_embeds
705

706
707
708
709
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the inner language model and return its hidden states.

        `inputs_embeds` is discarded whenever `intermediate_tensors` is
        given. Extra keyword arguments are accepted but not used here.
        """
        embeds = None if intermediate_tensors is not None else inputs_embeds

        return self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=embeds
        )

723
724
725
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the wrapped language model."""
        logits = self.language_model.compute_logits(hidden_states)
        return logits
728

729
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the set of loaded names.

        If no "embed_tokens.weight" was loaded, `self.embed_tokens` is
        replaced by the language model's own embedding and the name is
        added to the returned set.
        """
        loaded = AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

        # The HF config doesn't specify whether these are tied,
        # so we detect it this way
        if "embed_tokens.weight" not in loaded:
            self.embed_tokens = self.language_model.model.embed_tokens
            loaded.add("embed_tokens.weight")

        return loaded