# phi3v.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18
from collections.abc import Iterable, Mapping, Sequence
19
from typing import Annotated, Any, Literal, TypeAlias
20

21
import regex as re
22
23
import torch
import torch.nn as nn
24
25
26
27
28
29
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    PretrainedConfig,
    ProcessorMixin,
)
30

31
from vllm.config import VllmConfig
32
from vllm.config.multimodal import BaseDummyOptions
33
from vllm.logger import init_logger
34
from vllm.model_executor.layers.quantization import QuantizationConfig
35
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
36
from vllm.multimodal import MULTIMODAL_REGISTRY
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    MultiModalPromptUpdates,
    PlaceholderFeaturesInfo,
    PromptReplacement,
    PromptUpdate,
    ResolvedPromptUpdate,
)
57
from vllm.multimodal.profiling import BaseDummyInputsBuilder
58
from vllm.sequence import IntermediateTensors
59
from vllm.utils.tensor_schema import TensorSchema, TensorShape
60

61
from .clip import CLIPVisionModel
62
63
64
65
66
from .interfaces import (
    MultiModalEmbeddings,
    SupportsMultiModal,
    SupportsPP,
    SupportsQuant,
67
    _require_is_multimodal,
68
69
70
71
72
73
74
75
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    _merge_multimodal_embeddings,
    init_vllm_registered_model,
    maybe_prefix,
)
76

77
78
logger = init_logger(__name__)

# The image placeholder token id cannot be found in the HF config,
# so it is hard-coded here.
_IMAGE_TOKEN_ID = 32044

# Fixed CLIP vision configuration (336x336 input, 14x14 patches, 24 layers,
# hidden size 1024) used when building the image processor.
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
    dropout=0.0,
    hidden_act="quick_gelu",
    hidden_size=1024,
    image_size=336,
    intermediate_size=4096,
    num_attention_heads=16,
    num_channels=3,
    num_hidden_layers=24,
    patch_size=14,
    projection_dim=768,
)


def _init_img_processor(
    hf_config: PretrainedConfig,
    quant_config: QuantizationConfig | None,
    prefix: str = "",
) -> CLIPVisionModel:
    """Build the CLIP vision model, truncated at the selected feature layer.

    Args:
        hf_config: Model config; ``hf_config.img_processor`` is read as a
            mapping and may contain ``"layer_idx"`` (defaults to ``-2``).
        quant_config: Quantization config forwarded to ``CLIPVisionModel``.
        prefix: Parameter-name prefix forwarded to ``CLIPVisionModel``.

    Returns:
        A ``CLIPVisionModel`` built from ``CLIP_VIT_LARGE_PATCH14_336_CONFIG``
        with ``num_hidden_layers_override`` set so that only the layers up to
        and including ``layer_idx`` are created.
    """
    clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
    layer_idx = hf_config.img_processor.get("layer_idx", -2)

    # Initialize the CLIP only up to the required feature layer.
    # A negative index counts from the end (-1 keeps every layer).
    if layer_idx < 0:
        num_hidden_layers = clip_config.num_hidden_layers + layer_idx + 1
    else:
        num_hidden_layers = layer_idx + 1

    return CLIPVisionModel(
        clip_config,
        quant_config,
        num_hidden_layers_override=num_hidden_layers,
        prefix=prefix,
    )


120
class Phi3VImagePixelInputs(TensorSchema):
    """
    Schema for image inputs given as pixel values.

    Dimensions:
        - b: Batch size
        - n: Number of images
        - bn: Leading dimension used in the shapes below (presumably
          batch size * number of images - TODO confirm)
        - p: Number of patches
        - h: Height of each patch
        - w: Width of each patch
    """

    # NOTE(review): the Literal also admits "image_embeds" although this
    # schema describes pixel inputs; kept as-is for compatibility.
    type: Literal["pixel_values", "image_embeds"] = "pixel_values"

    # Supports either a stacked tensor or a list of (p, 3, h, w) tensors
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape(
            "bn", "p", 3, "h", "w", dynamic_dims={"p"}
        ),  # 'p' may vary across items
    ]

    # Stacked tensor with height and width for each image
    image_sizes: Annotated[torch.Tensor | None, TensorShape("bn", 2)]
142
143


144
class Phi3VImageEmbeddingInputs(TensorSchema):
    """
    Schema for image inputs given as precomputed embeddings.

    Dimensions:
        - b: Batch size
        - n: Number of images
        - bn: Leading dimension used in the shape below (presumably
          batch size * number of images - TODO confirm)
        - f: Image feature size (e.g., number of tokens per image)
        - h: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"

    # Either a stacked tensor or a list of per-image tensors
    data: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "f", "h"),
    ]
158
159


160
# Either of the two accepted image input schemas: pixel values or embeddings.
Phi3VImageInputs: TypeAlias = Phi3VImagePixelInputs | Phi3VImageEmbeddingInputs
161
162


163
class Phi3ImageEmbeddingBase(nn.Module):
    """Base module that extracts image features through ``img_processor``."""

    def __init__(self) -> None:
        super().__init__()
        # Attribute declarations only; no values are assigned here.
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        """Run ``img_processor`` and select features per ``type_feature``.

        Args:
            img_embeds: Input passed unchanged to ``self.img_processor``.

        Returns:
            For ``"patch"``, the processor output with index 0 along
            dimension 1 dropped; for ``"cls_patch"``, the full output.

        Raises:
            NotImplementedError: If ``type_feature`` is any other value.
        """
        type_feature = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if type_feature == "patch":
            # Drop the first position along dim 1 (presumably the CLS token).
            return img_feature[:, 1:]

        if type_feature == "cls_patch":
            return img_feature

        raise NotImplementedError(f"Unsupported type_feature: {type_feature!r}")


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
    """Phi3 Image embedding with HD transform."""

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
    ) -> None:
        """Create the vision tower, separator parameters and projection MLP.

        Args:
            config: Model config; ``config.img_processor`` and
                ``config.embd_layer`` are read as mappings.
            quant_config: Quantization config forwarded to the vision tower.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()

        # n_embed or hidden_size
        hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size

        self.img_processor = _init_img_processor(
            config, quant_config, prefix=f"{prefix}.img_processor"
        )

        image_dim_out = config.img_processor["image_dim_out"]
        self.num_img_tokens = config.img_processor["num_img_tokens"]

        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = config.embd_layer.get("use_hd_transform", False)
        self.with_learnable_separator = config.embd_layer.get(
            "with_learnable_separator", False
        )
        self.hd_transform_order = config.embd_layer.get("hd_transform_order", "glb_sub")
        # Both use_hd_transform and with_learnable_separator must be enabled;
        # this module only implements that configuration.
        assert self.use_hd_transform and self.with_learnable_separator, (
            "use_hd_transform and with_learnable_separator must both be enabled"
        )

        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(torch.empty([1, 1, 1, self.image_dim_out * 4]))

        # Projection MLP: Linear -> (GELU -> Linear) repeated depth - 1 times.
        dim_projection = hidden_size
        depth = 2
        layers = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.type_feature = config.img_processor.get("type_feature", "patch")

    def forward(
        self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor
    ) -> list[torch.Tensor]:
        """
        process image and return vision embeddings.

        pixel_values: (num_images, num_crops, c, h, w)
        image_sizes: one (height, width) pair per image
        output: list with one projected 2-D tensor per image; the number of
            rows depends on that image's crop layout
        """
        num_images, num_crops, c, h, w = pixel_values.shape
        # Fold the crops into the batch dimension for the vision tower.
        pixel_values = pixel_values.flatten(0, 1)
        img_features = self.get_img_features(pixel_values)
        img_features = img_features.reshape(
            num_images, num_crops, -1, self.image_dim_out
        )
        image_features_proj = self.hd_feature_transform(img_features, image_sizes)
        return image_features_proj

    def hd_feature_transform(self, image_features, image_sizes):
        """
        Arrange crop features per image, add separators and project them.

        image_features: (num_images, num_crops+1, 24*24, 1024)
        image_sizes: one (height, width) pair per image
        output: list of per-image tensors returned by ``img_projection``
        """
        assert self.hd_transform_order == "sub_glb", (
            f"hd_transform_order `{self.hd_transform_order}` not implemented"
        )
        if isinstance(self.img_projection, nn.Sequential):
            target_device = self.img_projection[0].bias.device
            target_dtype = self.img_projection[0].bias.dtype
        else:  # It's a single nn.Linear layer
            target_device = self.img_projection.bias.device
            target_dtype = self.img_projection.bias.dtype

        global_image_features = image_features[:, 0]  # (num_images, 24*24, 1024)
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_image_features_hd = self.reshape_hd_patches_2x2merge(
            global_image_features, 1, 1
        )
        global_image_features_hd_newline = self.add_image_newline(
            global_image_features_hd
        )

        # Convert the sizes to Python ints once so the per-image arithmetic,
        # slicing and reshapes below do not operate on 0-dim tensors.
        if isinstance(image_sizes, torch.Tensor):
            sizes = image_sizes.tolist()
        else:
            sizes = [(int(size[0]), int(size[1])) for size in image_sizes]

        batch_image_features_proj = []
        # need a for loop to process each image because of different image sizes
        # (patch arrangement is different for each image)
        for i, (h, w) in enumerate(sizes):
            h_crop = h // 336
            w_crop = w // 336
            num_crops = h_crop * w_crop

            # NOTE: real num_crops is padded
            # (num_crops, 24*24, 1024)
            sub_image_features = image_features[i, 1 : 1 + num_crops]
            sub_image_features_hd = self.reshape_hd_patches_2x2merge(
                sub_image_features, h_crop, w_crop
            )
            sub_image_features_hd_newline = self.add_image_newline(
                sub_image_features_hd
            )

            # [sub features, separator, global features]
            image_embeddings = torch.cat(
                [
                    sub_image_features_hd_newline.squeeze(
                        0
                    ),  # (h_crop*12*(w_crop*12+1), 4096)
                    self.glb_GN.squeeze(0),
                    global_image_features_hd_newline[i],
                ]
            )
            img_proj = self.img_projection(
                image_embeddings.to(target_device, target_dtype)
            )
            batch_image_features_proj.append(img_proj)

        return batch_image_features_proj

    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """
        Merge each 2x2 patch neighbourhood into channels and tile the crops.

        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096)
        where h_crop*w_crop == num_crops
        """
        N, L, C = image_features.shape
        assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
        num_images = N // (h_crop * w_crop)
        H = int(L**0.5)
        image_features_hd = (
            image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
            .reshape(N, -1, 4 * C)  # N, 144, 4096
            .reshape(
                num_images, h_crop, w_crop, H // 2, H // 2, -1
            )  # n_img, h_crop, w_crop, 12, 12, 4096
            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
            .reshape(
                num_images, h_crop * H // 2, w_crop * H // 2, 4 * C
            )  # n_img, h_crop*12, w_crop*12, 4096
        )
        return image_features_hd

    def add_image_newline(self, image_features_hd):
        """
        Append the ``sub_GN`` separator at the end of every feature row.

        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
        """
        num_images, h, w, hid_dim = image_features_hd.shape
        # add the newline token to the HD image feature patches
        newline_embeddings = self.sub_GN.expand(
            num_images, h, -1, -1
        )  # (n_img, h, 1, hid_dim)
        image_features_hd_newline = torch.cat(
            [image_features_hd, newline_embeddings], dim=2
        ).reshape(num_images, -1, hid_dim)
        return image_features_hd_newline
349
350


351
class Phi3VProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Phi-3-Vision image inputs."""

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality item limits (``None`` for images)."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: ProcessorMixin | None = None,
    ) -> int:
        """Return the number of tokens an image of the given size occupies.

        Args:
            image_width: Image width.
            image_height: Image height.
            processor: HF processor to query; obtained from
                ``get_hf_processor()`` when omitted.

        Returns:
            Whatever the processor's
            ``calc_num_image_tokens_from_image_size`` reports.
        """
        if processor is None:
            processor = self.get_hf_processor()

        return processor.calc_num_image_tokens_from_image_size(  # type: ignore
            width=image_width,
            height=image_height,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image size used to produce the largest feature count."""
        # Result in the max possible feature size (h:w = 16:1)
        # NOTE(review): the raw size below is 160:1; the 16:1 figure presumably
        # refers to the crop layout chosen by the HF processor - confirm.
        return ImageSize(height=8000, width=50)

374
375

class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):
    """Builds dummy prompts and images for Phi-3-Vision."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the first ``num_images`` image tokens of the HF processor,
        concatenated without a separator."""
        num_images = mm_counts.get("image", 0)

        hf_processor = self.info.get_hf_processor()
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

        return "".join(image_tokens[:num_images])

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images sized for the largest feature count.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Number of items per modality; only ``"image"`` is read.
            mm_options: Optional per-modality overrides; the ``"image"`` entry
                is forwarded to ``_get_dummy_images``.
        """
        num_images = mm_counts.get("image", 0)

        target_width, target_height = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image") if mm_options else None

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


406
class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
    """Multimodal processor that aligns prompts with the HF Phi-3-Vision
    processor."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Call the HF processor and replace negative placeholder ids in
        ``input_ids`` with ``_IMAGE_TOKEN_ID`` (in place)."""
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        input_ids = processed_outputs["input_ids"]
        assert isinstance(input_ids, torch.Tensor)

        # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
        # which will cause OverflowError when decoding the prompt_ids.
        # Therefore, we need to do an early replacement here
        input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare every image field as batched over the "image" modality."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_sizes=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return a replacement that expands each image tag into a run of
        ``_IMAGE_TOKEN_ID`` whose length is that image's token count."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

        def get_replacement_phi3v(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                )

            return [_IMAGE_TOKEN_ID] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                # The target tag differs per item: image_tokens[item_idx].
                target=image_tokens.__getitem__,
                replacement=get_replacement_phi3v,
            )
        ]

    def _recompute_cached_prompt_update(
        self,
        cached_update: ResolvedPromptUpdate,
        new_item_idx: int,
    ) -> ResolvedPromptUpdate:
        """Re-point a cached image update at the tag of ``new_item_idx``."""
        new_update = super()._recompute_cached_prompt_update(
            cached_update,
            new_item_idx,
        )

        if cached_update.modality == "image":
            hf_processor = self.info.get_hf_processor()
            image_tokens: list[str] = hf_processor.img_tokens  # type: ignore
            new_update = new_update.with_target(image_tokens[new_item_idx])

        return new_update

    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """Apply prompt updates after re-tokenizing the prompt the HF way.

        Without prompt updates the call is delegated unchanged. Otherwise the
        token ids are decoded, split around the image tags, re-tokenized chunk
        by chunk and interleaved with the tag tokens before the base
        implementation runs; afterwards the extra token following BOS is
        dropped when the result starts with the encoding of ``"<s> <|image|>"``.
        """
        if not len(mm_prompt_updates):
            return super()._apply_prompt_updates(
                token_ids=token_ids,
                mm_prompt_updates=mm_prompt_updates,
            )

        # align to hf behavior when there are images
        tokenizer = self.info.get_tokenizer()
        # to decode token_ids to the original text, we need to
        # 1. remove the first bos token
        # 2. remove space after each special token
        #    introduced by the tokenizer
        if len(token_ids) and token_ids[0] == tokenizer.bos_token_id:
            token_ids = token_ids[1:]
        text = tokenizer.decode(token_ids)
        for special_tokens in tokenizer.special_tokens_map.values():
            if isinstance(special_tokens, str):
                text = text.replace(f"{special_tokens} ", special_tokens)
            elif isinstance(special_tokens, list):
                for special_token in special_tokens:
                    text = text.replace(f"{special_token} ", special_token)
        # perform hf behavior
        # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/64f88b6/processing_phi3_v.py#L407
        pattern = r"<\|image_\d+\|>"
        prompt_chunks = [
            tokenizer(chunk).input_ids for chunk in re.split(pattern, text)
        ]
        image_tags = [
            tokenizer(chunk, add_special_tokens=False).input_ids
            for chunk in re.findall(pattern, text)
        ]
        # Pad the tags so that zip() below keeps the final text chunk.
        if len(prompt_chunks) > len(image_tags):
            image_tags.append([])
        # Interleave: chunk0, tag0, chunk1, tag1, ...
        token_ids = [
            token
            for pair in zip(prompt_chunks, image_tags)
            for ids in pair
            for token in ids
        ]

        token_ids, placeholders = super()._apply_prompt_updates(
            token_ids=token_ids,
            mm_prompt_updates=mm_prompt_updates,
        )

        # Keep the behavior in line with HF processor
        if token_ids[:2] == tokenizer.encode("<s> <|image|>", add_special_tokens=False):
            # Drop the second token and shift every placeholder left by one.
            token_ids = [token_ids[0], *token_ids[2:]]
            placeholders = {
                modality: [
                    PlaceholderFeaturesInfo(
                        modality=p.modality,
                        item_idx=p.item_idx,
                        start_idx=p.start_idx - 1,
                        tokens=p.tokens,
                        is_embed=p.is_embed,
                    )
                    for p in ps
                ]
                for modality, ps in placeholders.items()
            }

        return token_ids, placeholders
558

559

560
561
562
563
564
565
@MULTIMODAL_REGISTRY.register_processor(
    Phi3VMultiModalProcessor,
    info=Phi3VProcessingInfo,
    dummy_inputs=Phi3VDummyInputsBuilder,
)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant):
    """Phi-3-Vision: an HD image embedder in front of a Llama language model."""

    # Maps HF checkpoint name prefixes to this module's parameter names.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_embed_tokens.wte": "embed_tokens",
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt tag of the ``i``-th image.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return f"<|image_{i}|>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the token embedding, the image embedder and the language
        model from ``vllm_config``."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.image_token_id = _IMAGE_TOKEN_ID

        # NOTE(review): `self.quant_config` is not assigned in this method;
        # presumably a base class (e.g. SupportsQuant) provides it - confirm.
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "model.embed_tokens"),
        )

        # TODO: Optionally initializes this for supporting input embeddings.
        self.vision_embed_tokens = Phi3HDImageEmbedding(
            config,
            self.quant_config,
            prefix=maybe_prefix(prefix, "model.vision_embed_tokens"),
        )

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            # The prefix is empty intentionally because default prefix of
            # LlamaForCausalLM is "model"
            prefix="",
            # We don't directly initialize vLLM's LlamaForCausalLM so we
            # can automatically apply embedding wrapper if this model is
            # initialized as an embedding model
            architectures=["LlamaForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Phi3VImageInputs | None:
        """Build the image input schema from keyword arguments.

        Returns ``None`` when neither ``pixel_values`` nor ``image_embeds`` is
        given; ``pixel_values`` takes precedence when both are present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            return Phi3VImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                # Bind the patch height/width to the CLIP input resolution.
                resolve_bindings={
                    "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
                    "w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
                },
            )

        if image_embeds is not None:
            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: Phi3VImageInputs,
    ) -> torch.Tensor | list[torch.Tensor]:
        """Return image embeddings: the given embeddings as-is, or the output
        of ``vision_embed_tokens`` for pixel inputs."""
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_embed_tokens is not None

        image_embeds = self.vision_embed_tokens(
            image_input["pixel_values"], image_input["image_sizes"]
        )

        return image_embeds

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return the image embeddings for the given inputs, or an empty list
        when no image input is present."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed ``input_ids`` with ``embed_tokens`` and, when multimodal
        embeddings are given, merge them in according to ``is_multimodal``."""
        inputs_embeds = self._embed_text_input_ids(
            input_ids,
            self.embed_tokens,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        return _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=_require_is_multimodal(is_multimodal),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the language model and return its hidden states."""
        # When intermediate tensors are supplied, input embeddings are ignored.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights through ``AutoWeightsLoader`` and return the names of
        the loaded parameters."""
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

        # The HF config doesn't specify whether these are tied,
        # so we detect it this way
        if "embed_tokens.weight" not in autoloaded_weights:
            # No separate embedding was loaded: reuse the language model's.
            self.embed_tokens = self.language_model.model.embed_tokens
            autoloaded_weights.add("embed_tokens.weight")
        return autoloaded_weights