phi3v.py 25.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18
from collections.abc import Iterable, Mapping, Sequence
19
from typing import Annotated, Any, Literal, TypeAlias
20

21
import regex as re
22
23
import torch
import torch.nn as nn
24
25
26
27
28
29
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    PretrainedConfig,
    ProcessorMixin,
)
30

31
from vllm.config import VllmConfig
32
from vllm.config.multimodal import BaseDummyOptions, MultiModalConfig
33
from vllm.logger import init_logger
34
from vllm.model_executor.layers.quantization import QuantizationConfig
35
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
36
from vllm.multimodal import MULTIMODAL_REGISTRY
37
38
39
40
41
42
43
44
45
46
47
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
48
49
from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import (
50
51
52
53
54
55
56
57
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    MultiModalPromptUpdates,
    PlaceholderFeaturesInfo,
    PromptReplacement,
    PromptUpdate,
    ResolvedPromptUpdate,
)
58
from vllm.sequence import IntermediateTensors
59
from vllm.utils.tensor_schema import TensorSchema, TensorShape
60

61
from .clip import CLIPVisionModel
62
63
64
65
66
from .interfaces import (
    MultiModalEmbeddings,
    SupportsMultiModal,
    SupportsPP,
    SupportsQuant,
67
    _require_is_multimodal,
68
69
70
71
72
73
74
75
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    _merge_multimodal_embeddings,
    init_vllm_registered_model,
    maybe_prefix,
)
76

77
78
# Module-level logger, named after this module.
logger = init_logger(__name__)

79
80
81
# Token ID written at every image placeholder position in the prompt.
# It is hard-coded because it cannot be found in the HF config.
_IMAGE_TOKEN_ID = 32044

82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Hard-coded configuration of the CLIP vision tower (336x336 input,
# 14x14 patches). `_init_img_processor` builds the vision model from it and
# `image_size` is reused as the expected height/width of pixel inputs.
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
    dropout=0.0,
    hidden_act="quick_gelu",
    hidden_size=1024,
    image_size=336,
    intermediate_size=4096,
    num_attention_heads=16,
    num_channels=3,
    num_hidden_layers=24,
    patch_size=14,
    projection_dim=768,
)


def _init_img_processor(
    hf_config: PretrainedConfig,
    quant_config: QuantizationConfig | None,
    multimodal_config: MultiModalConfig | None,
    prefix: str = "",
) -> CLIPVisionModel:
    """Build the CLIP vision tower, truncated at the feature layer.

    Args:
        hf_config: Model config; ``hf_config.img_processor`` may carry a
            ``layer_idx`` entry selecting the feature layer (default ``-2``).
        quant_config: Quantization config forwarded to the vision model.
        multimodal_config: Multimodal config forwarded to the vision model.
        prefix: Module name prefix forwarded to the vision model.

    Returns:
        A ``CLIPVisionModel`` whose number of hidden layers is overridden so
        that only the layers up to the selected feature layer are created.
    """
    vision_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
    feature_layer = hf_config.img_processor.get("layer_idx", -2)

    # Keep every layer up to and including the feature layer. A negative
    # index counts back from the last layer, so it is offset by the depth.
    layers_to_build = feature_layer + 1
    if feature_layer < 0:
        layers_to_build += vision_config.num_hidden_layers

    return CLIPVisionModel(
        vision_config,
        quant_config=quant_config,
        multimodal_config=multimodal_config,
        num_hidden_layers_override=layers_to_build,
        prefix=prefix,
    )


122
class Phi3VImagePixelInputs(TensorSchema):
    """
    Raw pixel inputs for the image tower.

    Dimensions:
        - bn: Batch size * number of images
        - p: Number of patches (may differ between images)
        - h: Height of each patch
        - w: Width of each patch
    """

    # Discriminator tag. Only "pixel_values" is valid here: consumers treat
    # the "image_embeds" tag as "this object has a `data` field", which this
    # schema does not define.
    type: Literal["pixel_values"] = "pixel_values"

    # Supports either a stacked tensor or a list of (p, 3, h, w) tensors
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape(
            "bn", "p", 3, "h", "w", dynamic_dims={"p"}
        ),  # 'p' may vary across items
    ]

    # Stacked tensor with height and width for each image
    image_sizes: Annotated[torch.Tensor | None, TensorShape("bn", 2)]
144
145


146
class Phi3VImageEmbeddingInputs(TensorSchema):
    """
    Precomputed image embeddings supplied instead of pixel values.

    Dimensions:
        - bn: Batch size * number of images
        - f: Image feature size (e.g., number of tokens per image)
        - h: Hidden size (must match language model backbone)
    """

    type: Literal["image_embeds"] = "image_embeds"

    # Either one stacked tensor or a list of per-image tensors
    data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("bn", "f", "h")]
160
161


162
# Union of the two accepted image input forms; the `type` field of each
# schema ("pixel_values" / "image_embeds") tells them apart.
Phi3VImageInputs: TypeAlias = Phi3VImagePixelInputs | Phi3VImageEmbeddingInputs
163
164


165
# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
166
class Phi3HDImageEmbedding(nn.Module):
    """Phi3 Image embedding with HD transform.

    Image crops are encoded by a CLIP vision tower, every 2x2 group of patch
    features is merged into the channel dimension, learnable separator
    embeddings are inserted, and the result is projected to the language
    model hidden size.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        multimodal_config: MultiModalConfig | None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            config: Model config providing ``img_processor`` and
                ``embd_layer`` settings plus the hidden size.
            quant_config: Quantization config forwarded to the vision tower.
            multimodal_config: Multimodal config forwarded to the vision
                tower.
            prefix: Module name prefix of this embedding module.
        """
        super().__init__()

        # n_embed or hidden_size
        hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size

        self.img_processor = _init_img_processor(
            config,
            quant_config=quant_config,
            multimodal_config=multimodal_config,
            prefix=f"{prefix}.img_processor",
        )

        image_dim_out = config.img_processor["image_dim_out"]
        self.num_img_tokens = config.img_processor["num_img_tokens"]

        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = config.embd_layer.get("use_hd_transform", False)
        self.with_learnable_separator = config.embd_layer.get(
            "with_learnable_separator", False
        )
        self.hd_transform_order = config.embd_layer.get("hd_transform_order", "glb_sub")
        # Only the HD transform with learnable separators is implemented,
        # so both flags must be enabled.
        assert self.use_hd_transform and self.with_learnable_separator

        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(torch.empty([1, 1, 1, self.image_dim_out * 4]))

        # Projection MLP: Linear followed by (depth - 1) x [GELU, Linear]
        dim_projection = hidden_size
        depth = 2
        layers: list[nn.Module] = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.type_feature = config.img_processor.get("type_feature", "patch")

    def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        """Run the vision tower and select the configured feature type.

        ``"patch"`` drops the first token of every sequence, ``"cls_patch"``
        keeps all tokens; any other value raises ``NotImplementedError``.
        """
        type_feature = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if type_feature == "patch":
            patch_feature = img_feature[:, 1:]
            return patch_feature

        if type_feature == "cls_patch":
            return img_feature

        raise NotImplementedError(type_feature)

    def forward(
        self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor
    ) -> list[torch.FloatTensor]:
        """
        Process images and return their vision embeddings.

        pixel_values: (num_images, num_crops, c, h, w)
        image_sizes: one (height, width) pair per image
        output: list with one projected 2-D feature tensor per image; the
            number of rows depends on that image's crop layout
        """
        # Unpacking five values keeps the implicit rank-5 check on the input.
        num_images, num_crops, _, _, _ = pixel_values.shape
        pixel_values = pixel_values.flatten(0, 1)
        img_features = self.get_img_features(pixel_values)
        img_features = img_features.reshape(
            num_images, num_crops, -1, self.image_dim_out
        )
        return self.hd_feature_transform(img_features, image_sizes)

    def hd_feature_transform(
        self, image_features: torch.Tensor, image_sizes: torch.Tensor
    ) -> list[torch.Tensor]:
        """
        Assemble and project the HD features of every image.

        image_features: (num_images, num_crops+1, 24*24, 1024)
        image_sizes: one (height, width) pair per image
        output: list of per-image tensors laid out as
            [sub features, separator, global features] and projected by
            ``img_projection``
        """
        assert self.hd_transform_order == "sub_glb", (
            f"hd_transform_order `{self.hd_transform_order}` not implemented"
        )
        # `img_projection` is always the nn.Sequential built in __init__, so
        # the bias of its first layer gives the target device and dtype.
        first_bias = self.img_projection[0].bias
        target_device = first_bias.device
        target_dtype = first_bias.dtype

        global_image_features = image_features[:, 0]  # (num_images, 24*24, 1024)
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_image_features_hd = self.reshape_hd_patches_2x2merge(
            global_image_features, 1, 1
        )
        global_image_features_hd_newline = self.add_image_newline(
            global_image_features_hd
        )

        batch_image_features_proj = []
        # need a for loop to process each image because of different image sizes
        # (patch arrangement is different for each image)
        for i, img_size in enumerate(image_sizes):
            # Convert to plain ints once so the crop arithmetic, slicing and
            # reshapes below do not operate on 0-d tensors.
            h, w = (int(v) for v in img_size)
            # Each crop is 336x336 pixels (the CLIP tower's image_size).
            h_crop = h // 336
            w_crop = w // 336
            num_crops = h_crop * w_crop

            # NOTE: real num_crops is padded
            # (num_crops, 24*24, 1024)
            sub_image_features = image_features[i, 1 : 1 + num_crops]
            sub_image_features_hd = self.reshape_hd_patches_2x2merge(
                sub_image_features, h_crop, w_crop
            )
            sub_image_features_hd_newline = self.add_image_newline(
                sub_image_features_hd
            )

            # [sub features, separator, global features]
            image_embeddings = torch.cat(
                [
                    sub_image_features_hd_newline.squeeze(
                        0
                    ),  # (h_crop*12*(w_crop*12+1), 4096)
                    self.glb_GN.squeeze(0),
                    global_image_features_hd_newline[i],
                ]
            )
            img_proj = self.img_projection(
                image_embeddings.to(target_device, target_dtype)
            )
            batch_image_features_proj.append(img_proj)

        return batch_image_features_proj

    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """
        Merge every 2x2 patch group into channels and tile crops spatially.

        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096)
        where h_crop*w_crop == num_crops
        """
        N, L, C = image_features.shape
        assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
        num_images = N // (h_crop * w_crop)
        H = int(L**0.5)
        image_features_hd = (
            image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
            .reshape(N, -1, 4 * C)  # N, 144, 4096
            .reshape(
                num_images, h_crop, w_crop, H // 2, H // 2, -1
            )  # n_img, h_crop, w_crop, 12, 12, 4096
            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
            .reshape(
                num_images, h_crop * H // 2, w_crop * H // 2, 4 * C
            )  # n_img, h_crop*12, w_crop*12, 4096
        )
        return image_features_hd

    def add_image_newline(self, image_features_hd):
        """
        Append the learnable newline embedding to the end of every row.

        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
        """
        num_images, h, w, hid_dim = image_features_hd.shape
        # add the newline token to the HD image feature patches
        newline_embeddings = self.sub_GN.expand(
            num_images, h, -1, -1
        )  # (n_img, h, 1, hid_dim)
        image_features_hd_newline = torch.cat(
            [image_features_hd, newline_embeddings], dim=2
        ).reshape(num_images, -1, hid_dim)
        return image_features_hd_newline
347
348


349
class Phi3VProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Phi-3-Vision image inputs."""

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality item limits; images have no fixed limit."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: ProcessorMixin | None = None,
    ) -> int:
        """Return the number of placeholder tokens for an image of this size.

        The count is delegated to the HF processor; when ``processor`` is not
        given, the one from ``self.get_hf_processor()`` is used.
        """
        hf_processor = self.get_hf_processor() if processor is None else processor
        return hf_processor.calc_num_image_tokens_from_image_size(  # type: ignore
            width=image_width,
            height=image_height,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image size used as the worst case for feature count."""
        # Result in the max possible feature size (h:w = 16:1)
        return ImageSize(height=8000, width=50)

372
373

class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):
    """Builds dummy prompts and dummy images for Phi-3-Vision."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the concatenated image tags for the requested image count.

        The tags come from the HF processor's ``img_tokens`` list; at most
        as many tags as that list holds are used.
        """
        image_count = mm_counts.get("image", 0)
        all_tags: list[str] = self.info.get_hf_processor().img_tokens  # type: ignore
        return "".join(all_tags[:image_count])

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
    ) -> MultiModalDataDict:
        """Return dummy images at the size that yields the most features.

        Args:
            seq_len: Unused by this builder.
            mm_counts: Number of items per modality; only ``"image"`` is read.
            mm_options: Optional per-modality dummy options; the ``"image"``
                entry, if any, is passed on as overrides.
        """
        image_count = mm_counts.get("image", 0)
        width, height = self.info.get_image_size_with_most_features()

        overrides = None
        if mm_options:
            overrides = mm_options.get("image")

        dummy_images = self._get_dummy_images(
            width=width,
            height=height,
            num_images=image_count,
            overrides=overrides,
        )
        return {"image": dummy_images}


404
class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
    """Multimodal processor that mirrors the HF Phi-3-Vision tokenization."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and replace its negative placeholder ids."""
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        ids = outputs["input_ids"]
        assert isinstance(ids, torch.Tensor)

        # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
        # which will cause OverflowError when decoding the prompt_ids.
        # Therefore, we need to do an early replacement here
        ids.masked_fill_(ids < 0, _IMAGE_TOKEN_ID)

        return outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare every image-related output as batched per image."""
        return {
            "pixel_values": MultiModalFieldConfig.batched("image"),
            "image_sizes": MultiModalFieldConfig.batched("image"),
            "image_embeds": MultiModalFieldConfig.batched("image"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image tag with the right number of image token ids."""
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

        def placeholder_ids_for(item_idx: int):
            items = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(items, ImageEmbeddingItems):
                # Precomputed embeddings already know their feature size.
                token_count = items.get_feature_size(item_idx)
            else:
                size = items.get_image_size(item_idx)
                token_count = self.info.get_num_image_tokens(
                    image_width=size.width,
                    image_height=size.height,
                    processor=hf_processor,
                )

            return [_IMAGE_TOKEN_ID] * token_count

        replacement = PromptReplacement(
            modality="image",
            # The target of item i is the i-th image tag.
            target=image_tokens.__getitem__,
            replacement=placeholder_ids_for,
        )
        return [replacement]

    def _recompute_cached_prompt_update(
        self,
        cached_update: ResolvedPromptUpdate,
        new_item_idx: int,
    ) -> ResolvedPromptUpdate:
        """Retarget a cached image update at the tag of its new item index."""
        updated = super()._recompute_cached_prompt_update(
            cached_update,
            new_item_idx,
        )

        if cached_update.modality != "image":
            return updated

        image_tokens: list[str] = self.info.get_hf_processor().img_tokens  # type: ignore
        return updated.with_target(image_tokens[new_item_idx])

    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """Apply prompt updates after re-tokenizing the prompt the HF way.

        Without any prompt updates the call is forwarded unchanged. Otherwise
        the prompt is decoded, split around the image tags, re-tokenized
        chunk by chunk, and the updates are applied to the rebuilt ids.
        """
        if not len(mm_prompt_updates):
            return super()._apply_prompt_updates(
                token_ids=token_ids,
                mm_prompt_updates=mm_prompt_updates,
            )

        # align to hf behavior when there are images
        tokenizer = self.info.get_tokenizer()
        # to decode token_ids to the original text, we need to
        # 1. remove the first bos token
        # 2. remove space after each special token
        #    introduced by the tokenizer
        if len(token_ids) and token_ids[0] == tokenizer.bos_token_id:
            token_ids = token_ids[1:]
        text = tokenizer.decode(token_ids)
        for entry in tokenizer.special_tokens_map.values():
            if isinstance(entry, str):
                specials = [entry]
            elif isinstance(entry, list):
                specials = entry
            else:
                specials = []
            for special in specials:
                text = text.replace(f"{special} ", special)

        # perform hf behavior
        # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/64f88b6/processing_phi3_v.py#L407
        pattern = r"<\|image_\d+\|>"
        text_chunk_ids = [
            tokenizer(piece).input_ids for piece in re.split(pattern, text)
        ]
        tag_ids = [
            tokenizer(tag, add_special_tokens=False).input_ids
            for tag in re.findall(pattern, text)
        ]
        if len(text_chunk_ids) > len(tag_ids):
            # Pad so the trailing text chunk still pairs with something.
            tag_ids.append([])

        # Interleave: chunk, tag, chunk, tag, ...
        rebuilt_ids = []
        for chunk, tag in zip(text_chunk_ids, tag_ids):
            rebuilt_ids.extend(chunk)
            rebuilt_ids.extend(tag)
        token_ids = rebuilt_ids

        token_ids, placeholders = super()._apply_prompt_updates(
            token_ids=token_ids,
            mm_prompt_updates=mm_prompt_updates,
        )

        # Keep the behavior in line with HF processor
        leading_pair = tokenizer.encode("<s> <|image|>", add_special_tokens=False)
        if token_ids[:2] == leading_pair:
            # Drop the second token and shift every placeholder left by one.
            token_ids = [token_ids[0], *token_ids[2:]]
            placeholders = {
                modality: [
                    PlaceholderFeaturesInfo(
                        modality=info.modality,
                        item_idx=info.item_idx,
                        start_idx=info.start_idx - 1,
                        tokens=info.tokens,
                        is_embed=info.is_embed,
                    )
                    for info in infos
                ]
                for modality, infos in placeholders.items()
            }

        return token_ids, placeholders
556

557

558
559
560
561
562
563
@MULTIMODAL_REGISTRY.register_processor(
    Phi3VMultiModalProcessor,
    info=Phi3VProcessingInfo,
    dummy_inputs=Phi3VDummyInputsBuilder,
)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant):
    """Phi-3-Vision: an HD image embedding module plus a language model.

    The language model is created through ``init_vllm_registered_model`` with
    the ``LlamaForCausalLM`` architecture.
    """

    # Checkpoint name prefixes and the prefixes they are renamed to.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_embed_tokens.wte": "embed_tokens",
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt tag of the ``i``-th image.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if not modality.startswith("image"):
            raise ValueError("Only image modality is supported")

        return f"<|image_{i}|>"

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the token embedding, the image embedding and the LM."""
        super().__init__()
        model_config = vllm_config.model_config
        config = model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.image_token_id = _IMAGE_TOKEN_ID

        with self._mark_tower_model(vllm_config, "image"):
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "model.embed_tokens"),
            )
            self.vision_embed_tokens = Phi3HDImageEmbedding(
                config,
                quant_config=quant_config,
                multimodal_config=multimodal_config,
                prefix=maybe_prefix(prefix, "model.vision_embed_tokens"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                # The prefix is empty intentionally because default prefix of
                # LlamaForCausalLM is "model"
                prefix="",
                # We don't directly initialize vLLM's LlamaForCausalLM so we
                # can automatically apply embedding wrapper if this model is
                # initialized as an embedding model
                architectures=["LlamaForCausalLM"],
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Phi3VImageInputs | None:
        """Wrap the image keyword arguments in the matching input schema.

        Pixel values take precedence over precomputed embeddings; ``None`` is
        returned when neither is present.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is not None:
            # Height and width of each crop are fixed by the CLIP config.
            crop_size = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
            return Phi3VImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                resolve_bindings={
                    "h": crop_size,
                    "w": crop_size,
                },
            )

        if image_embeds is not None:
            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        return None

    def _process_image_input(
        self,
        image_input: Phi3VImageInputs,
    ) -> torch.Tensor:
        """Return the embeddings for a validated image input.

        Precomputed embeddings are passed through as-is; pixel inputs yield
        whatever ``self.vision_embed_tokens`` returns for them.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        return self.vision_embed_tokens(
            image_input["pixel_values"], image_input["image_sizes"]
        )

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return the image embeddings, or an empty list without images."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        return self._process_image_input(image_input)

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed the token ids and merge in the multimodal embeddings.

        When no multimodal embeddings are given, the text embeddings are
        returned unchanged; otherwise ``is_multimodal`` is required.
        """
        text_embeds = self._embed_text_input_ids(
            input_ids,
            self.embed_tokens,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return text_embeds

        return _merge_multimodal_embeddings(
            inputs_embeds=text_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=_require_is_multimodal(is_multimodal),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the language model backbone and return its hidden states."""
        # Input embeddings are dropped whenever intermediate tensors are
        # supplied.
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load the checkpoint weights and return the loaded parameter names."""
        loaded = AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

        # The HF config doesn't specify whether these are tied,
        # so we detect it this way
        if "embed_tokens.weight" not in loaded:
            self.embed_tokens = self.language_model.model.embed_tokens
            loaded.add("embed_tokens.weight")
        return loaded