# phi3v.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
18
from collections.abc import Iterable, Mapping, Sequence
19
from typing import Annotated, Any, Literal, TypeAlias
20

21
import regex as re
22
23
import torch
import torch.nn as nn
24
25
26
27
28
29
from transformers import (
    BatchFeature,
    CLIPVisionConfig,
    PretrainedConfig,
    ProcessorMixin,
)
30

31
from vllm.config import VllmConfig
32
from vllm.config.multimodal import BaseDummyOptions
33
from vllm.logger import init_logger
34
from vllm.model_executor.layers.quantization import QuantizationConfig
35
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
36
from vllm.multimodal import MULTIMODAL_REGISTRY
37
38
39
40
41
42
43
44
45
46
47
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    ImageEmbeddingItems,
    ImageProcessorItems,
    ImageSize,
    MultiModalDataItems,
)
48
49
from vllm.multimodal.processing import BaseDummyInputsBuilder
from vllm.multimodal.processing.processor import (
50
51
52
53
54
55
56
57
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    MultiModalPromptUpdates,
    PlaceholderFeaturesInfo,
    PromptReplacement,
    PromptUpdate,
    ResolvedPromptUpdate,
)
58
from vllm.sequence import IntermediateTensors
59
from vllm.utils.tensor_schema import TensorSchema, TensorShape
60

61
from .clip import CLIPVisionModel
62
63
64
65
66
from .interfaces import (
    MultiModalEmbeddings,
    SupportsMultiModal,
    SupportsPP,
    SupportsQuant,
67
    _require_is_multimodal,
68
69
70
71
72
73
74
75
)
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    _merge_multimodal_embeddings,
    init_vllm_registered_model,
    maybe_prefix,
)
76

77
78
logger = init_logger(__name__)

79
80
81
# The image placeholder token ID cannot be found in the HF config, so it is hard-coded here.
_IMAGE_TOKEN_ID = 32044

82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Hard-coded configuration of the vision tower. `_init_img_processor` builds
# the CLIP model from this constant rather than from the HF model config, and
# `image_size` is also used to resolve the expected pixel height/width when
# validating image inputs.
# NOTE(review): the values presumably mirror CLIP ViT-L/14 at 336px, as the
# constant's name suggests — confirm against the upstream checkpoint.
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(
    dropout=0.0,
    hidden_act="quick_gelu",
    hidden_size=1024,
    image_size=336,
    intermediate_size=4096,
    num_attention_heads=16,
    num_channels=3,
    num_hidden_layers=24,
    patch_size=14,
    projection_dim=768,
)


def _init_img_processor(
    hf_config: PretrainedConfig,
    quant_config: QuantizationConfig | None,
    prefix: str = "",
) -> CLIPVisionModel:
    """Build the CLIP vision tower, truncated at the feature layer.

    Args:
        hf_config: Model config. ``hf_config.img_processor`` is read as a
            mapping and may carry ``layer_idx`` (defaults to ``-2``).
        quant_config: Quantization settings forwarded to the CLIP model.
        prefix: Weight-name prefix forwarded to the CLIP model.

    Returns:
        A ``CLIPVisionModel`` built from ``CLIP_VIT_LARGE_PATCH14_336_CONFIG``
        whose ``num_hidden_layers_override`` is set so that only the layers
        up to and including ``layer_idx`` are created.
    """
    clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
    layer_idx = hf_config.img_processor.get("layer_idx", -2)

    # Initialize the CLIP only up to the required feature layer.
    # A negative index counts back from the last layer (-1 keeps all layers).
    if layer_idx < 0:
        num_hidden_layers = clip_config.num_hidden_layers + layer_idx + 1
    else:
        num_hidden_layers = layer_idx + 1

    return CLIPVisionModel(
        clip_config,
        quant_config=quant_config,
        num_hidden_layers_override=num_hidden_layers,
        prefix=prefix,
    )


120
class Phi3VImagePixelInputs(TensorSchema):
    """
    Schema for image inputs given as raw pixel values.

    Dimensions:
        - b: Batch size
        - n: Number of images
        - p: Number of patches
        - h: Height of each patch
        - w: Width of each patch

    NOTE(review): the shapes below use a single "bn" dimension, presumably
    the flattened batch-size x number-of-images axis — confirm.
    """

    # NOTE(review): `_parse_and_validate_image_input` only ever passes
    # "pixel_values" here; the "image_embeds" alternative looks like a
    # leftover — confirm before narrowing the Literal.
    type: Literal["pixel_values", "image_embeds"] = "pixel_values"

    # Supports either a stacked tensor or a list of (p, 3, h, w) tensors
    pixel_values: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape(
            "bn", "p", 3, "h", "w", dynamic_dims={"p"}
        ),  # 'p' may vary across items
    ]

    # Stacked tensor with height and width for each image
    image_sizes: Annotated[torch.Tensor | None, TensorShape("bn", 2)]
142
143


144
class Phi3VImageEmbeddingInputs(TensorSchema):
    """
    Schema for image inputs given as precomputed embeddings.

    Dimensions:
        - b: Batch size
        - n: Number of images
        - f: Image feature size (e.g., number of tokens per image)
        - h: Hidden size (must match language model backbone)

    NOTE(review): the shape below uses a single "bn" dimension, presumably
    the flattened batch-size x number-of-images axis — confirm.
    """

    type: Literal["image_embeds"] = "image_embeds"

    # Either one stacked tensor or one tensor per image.
    data: Annotated[
        torch.Tensor | list[torch.Tensor],
        TensorShape("bn", "f", "h"),
    ]
158
159


160
# Either of the two accepted image-input schemas: raw pixel values or
# precomputed image embeddings.
Phi3VImageInputs: TypeAlias = Phi3VImagePixelInputs | Phi3VImageEmbeddingInputs
161
162


163
# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
164
class Phi3HDImageEmbedding(nn.Module):
    """Phi3 Image embedding with HD transform.

    Runs the CLIP vision tower over every crop of every image, merges each
    2x2 neighbourhood of patch features into the channel dimension, inserts
    learnable separator embeddings, and projects the result to the language
    model's hidden size.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
    ) -> None:
        """Create the vision tower, separator parameters and projection MLP.

        Args:
            config: Model config; reads ``n_embd``/``hidden_size``,
                ``img_processor`` and ``embd_layer``.
            quant_config: Quantization settings forwarded to the vision tower.
            prefix: Weight-name prefix of this module.

        Raises:
            AssertionError: If ``use_hd_transform`` or
                ``with_learnable_separator`` is not enabled in
                ``config.embd_layer``.
        """
        super().__init__()

        # n_embed or hidden_size
        hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size

        self.img_processor = _init_img_processor(
            config,
            quant_config=quant_config,
            prefix=f"{prefix}.img_processor",
        )

        image_dim_out = config.img_processor["image_dim_out"]
        self.num_img_tokens = config.img_processor["num_img_tokens"]

        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = config.embd_layer.get("use_hd_transform", False)
        self.with_learnable_separator = config.embd_layer.get(
            "with_learnable_separator", False
        )
        self.hd_transform_order = config.embd_layer.get("hd_transform_order", "glb_sub")
        # Only the HD transform with a learnable separator is supported here,
        # so both flags must be enabled (not merely equal to each other).
        assert self.use_hd_transform and self.with_learnable_separator

        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(torch.empty([1, 1, 1, self.image_dim_out * 4]))

        # Two-layer MLP: Linear -> GELU -> Linear. The Sequential layout is
        # kept as-is so the submodule indices (and weight names) do not change.
        dim_projection = hidden_size
        depth = 2
        layers: list[nn.Module] = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.type_feature = config.img_processor.get("type_feature", "patch")

    def get_img_features(self, img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        """Run the vision tower and select the configured feature type.

        Args:
            img_embeds: Input passed straight to ``self.img_processor``.

        Returns:
            For ``type_feature == "patch"`` the tower output without its
            first token along dim 1 (presumably the CLS token); for
            ``"cls_patch"`` the unmodified tower output.

        Raises:
            NotImplementedError: For any other ``type_feature``.
        """
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if self.type_feature == "patch":
            return img_feature[:, 1:]

        if self.type_feature == "cls_patch":
            return img_feature

        raise NotImplementedError(self.type_feature)

    def forward(
        self, pixel_values: torch.FloatTensor, image_sizes: torch.Tensor
    ) -> list[torch.Tensor]:
        """
        process image and return vision embeddings.

        pixel_values: (num_images, num_crops, c, h, w)
        image_sizes: one (height, width) pair per image
        output: list with one (num_tokens_i, hidden_size) tensor per image;
            the token count depends on each image's size, so the results
            are not stacked.
        """
        num_images, num_crops, c, h, w = pixel_values.shape
        # Fold the crops into the batch dimension for the vision tower.
        pixel_values = pixel_values.flatten(0, 1)
        img_features = self.get_img_features(pixel_values)
        img_features = img_features.reshape(
            num_images, num_crops, -1, self.image_dim_out
        )
        return self.hd_feature_transform(img_features, image_sizes)

    def hd_feature_transform(self, image_features, image_sizes):
        """
        Arrange crop features per image, add separators and project them.

        image_features: (num_images, num_crops+1, 24*24, 1024); index 0 along
            dim 1 is the global view, the remaining entries are the
            (possibly padded) sub-image crops.
        image_sizes: iterable of (height, width) per image.
        output: list with one projected (num_tokens_i, hidden) tensor per
            image, laid out as [sub features, separator, global features].
        """
        assert self.hd_transform_order == "sub_glb", (
            f"hd_transform_order `{self.hd_transform_order}` not implemented"
        )
        # The projection's first linear layer defines the device/dtype the
        # assembled features are moved to before projecting.
        if isinstance(self.img_projection, nn.Sequential):
            first_linear = self.img_projection[0]
        else:  # It's a single nn.Linear layer
            first_linear = self.img_projection
        target_device = first_linear.bias.device
        target_dtype = first_linear.bias.dtype

        global_image_features = image_features[:, 0]  # (num_images, 24*24, 1024)
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_image_features_hd = self.reshape_hd_patches_2x2merge(
            global_image_features, 1, 1
        )
        global_image_features_hd_newline = self.add_image_newline(
            global_image_features_hd
        )

        batch_image_features_proj = []
        # need a for loop to process each image because of different image sizes
        # (patch arrangement is different for each image)
        for i, img_size in enumerate(image_sizes):
            h, w = img_size
            # Each crop covers a 336x336 region of the resized image.
            h_crop = h // 336
            w_crop = w // 336
            num_crops = h_crop * w_crop

            # NOTE: real num_crops is padded
            # (num_crops, 24*24, 1024)
            sub_image_features = image_features[i, 1 : 1 + num_crops]
            sub_image_features_hd = self.reshape_hd_patches_2x2merge(
                sub_image_features, h_crop, w_crop
            )
            sub_image_features_hd_newline = self.add_image_newline(
                sub_image_features_hd
            )

            # [sub features, separator, global features]
            image_embeddings = torch.cat(
                [
                    sub_image_features_hd_newline.squeeze(
                        0
                    ),  # (h_crop*12*(w_crop*12+1), 4096)
                    self.glb_GN.squeeze(0),
                    global_image_features_hd_newline[i],
                ]
            )
            img_proj = self.img_projection(
                image_embeddings.to(target_device, target_dtype)
            )
            batch_image_features_proj.append(img_proj)

        return batch_image_features_proj

    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """
        Merge every 2x2 block of patches into the channel dimension and tile
        the crops of each image into one spatial grid.

        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096)
        where h_crop*w_crop == num_crops
        """
        N, L, C = image_features.shape
        assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
        num_images = N // (h_crop * w_crop)
        H = int(L**0.5)
        image_features_hd = (
            image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
            .reshape(N, -1, 4 * C)  # N, 144, 4096
            .reshape(
                num_images, h_crop, w_crop, H // 2, H // 2, -1
            )  # n_img, h_crop, w_crop, 12, 12, 4096
            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
            .reshape(
                num_images, h_crop * H // 2, w_crop * H // 2, 4 * C
            )  # n_img, h_crop*12, w_crop*12, 4096
        )
        return image_features_hd

    def add_image_newline(self, image_features_hd):
        """
        Append the ``sub_GN`` separator embedding to the end of every row.

        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
        """
        num_images, h, w, hid_dim = image_features_hd.shape
        # add the newline token to the HD image feature patches
        newline_embeddings = self.sub_GN.expand(
            num_images, h, -1, -1
        )  # (n_img, h, 1, hid_dim)
        image_features_hd_newline = torch.cat(
            [image_features_hd, newline_embeddings], dim=2
        ).reshape(num_images, -1, hid_dim)
        return image_features_hd_newline
343
344


345
class Phi3VProcessingInfo(BaseProcessingInfo):
    """Processing metadata for Phi-3-Vision image inputs."""

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Return the per-modality limits; images map to ``None``."""
        return {"image": None}

    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: ProcessorMixin,
    ) -> int:
        """Return the number of image tokens for an image of the given size.

        The computation is delegated to the HF processor's
        ``calc_num_image_tokens_from_image_size``.
        """
        return processor.calc_num_image_tokens_from_image_size(  # type: ignore
            width=image_width,
            height=image_height,
        )

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the image size used as the worst case for feature count."""
        # Result in the max possible feature size (h:w = 16:1)
        # NOTE(review): 8000x50 is a 160:1 aspect ratio, not 16:1 —
        # presumably the HF processor limits the ratio itself; confirm.
        return ImageSize(height=8000, width=50)

365
366

class Phi3VDummyInputsBuilder(BaseDummyInputsBuilder[Phi3VProcessingInfo]):
    """Builds dummy prompts and images for Phi-3-Vision."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return the concatenation of the first ``num_images`` image tokens.

        The tokens are taken from the HF processor's ``img_tokens`` list.
        """
        num_images = mm_counts.get("image", 0)

        hf_processor = self.info.get_hf_processor()
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

        return "".join(image_tokens[:num_images])

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return dummy images sized to yield the most image features.

        Args:
            seq_len: Not used by this implementation.
            mm_counts: Number of items per modality; only "image" is read.
            mm_options: Optional per-modality overrides; the "image" entry
                is forwarded to ``_get_dummy_images``.
        """
        num_images = mm_counts.get("image", 0)

        # NOTE(review): relies on ImageSize unpacking as (width, height) —
        # confirm against its definition.
        target_width, target_height = self.info.get_image_size_with_most_features()

        image_overrides = mm_options.get("image")

        return {
            "image": self._get_dummy_images(
                width=target_width,
                height=target_height,
                num_images=num_images,
                overrides=image_overrides,
            )
        }


397
class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
    """Multimodal processor that aligns vLLM's prompt handling with the
    Phi-3-Vision HF processor."""

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and replace its negative placeholder ids.

        Returns:
            The HF processor output, whose ``input_ids`` tensor has been
            modified in place so that every negative id becomes
            ``_IMAGE_TOKEN_ID``.
        """
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        input_ids = processed_outputs["input_ids"]
        assert isinstance(input_ids, torch.Tensor)

        # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
        # which will cause OverflowError when decoding the prompt_ids.
        # Therefore, we need to do an early replacement here
        input_ids.masked_fill_(input_ids < 0, _IMAGE_TOKEN_ID)

        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare every image field as batched over the "image" modality."""
        return dict(
            pixel_values=MultiModalFieldConfig.batched("image"),
            image_sizes=MultiModalFieldConfig.batched("image"),
            image_embeds=MultiModalFieldConfig.batched("image"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, Any],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Return the replacement that expands each image tag.

        The i-th image tag (taken from the HF processor's ``img_tokens``) is
        replaced with ``_IMAGE_TOKEN_ID`` repeated once per image feature.
        """
        hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

        def get_replacement_phi3v(item_idx: int):
            images = mm_items.get_items(
                "image", (ImageEmbeddingItems, ImageProcessorItems)
            )

            if isinstance(images, ImageEmbeddingItems):
                # Precomputed embeddings already fix the feature count.
                num_image_tokens = images.get_feature_size(item_idx)
            else:
                image_size = images.get_image_size(item_idx)
                num_image_tokens = self.info.get_num_image_tokens(
                    image_width=image_size.width,
                    image_height=image_size.height,
                    processor=hf_processor,
                )

            return [_IMAGE_TOKEN_ID] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                # The target depends on the item index: image_tokens[i].
                target=image_tokens.__getitem__,
                replacement=get_replacement_phi3v,
            )
        ]

    def _recompute_cached_prompt_update(
        self,
        cached_update: ResolvedPromptUpdate,
        new_item_idx: int,
    ) -> ResolvedPromptUpdate:
        """Re-target a cached image update at the tag of its new item index."""
        new_update = super()._recompute_cached_prompt_update(
            cached_update,
            new_item_idx,
        )

        if cached_update.modality == "image":
            hf_processor = self.info.get_hf_processor()
            image_tokens: list[str] = hf_processor.img_tokens  # type: ignore
            new_update = new_update.with_target(image_tokens[new_item_idx])

        return new_update

    def _apply_prompt_updates(
        self,
        token_ids: list[int],
        mm_prompt_updates: MultiModalPromptUpdates,
    ) -> tuple[list[int], Mapping[str, list[PlaceholderFeaturesInfo]]]:
        """Apply prompt updates while mimicking the HF processor's tokenization.

        When there are updates, the prompt is decoded, split around the
        ``<|image_N|>`` tags and re-tokenized chunk by chunk before the base
        implementation runs; afterwards a leading ``<s> <|image|>`` pair gets
        its second token dropped and all placeholder starts shift left by one.
        """
        has_updates = len(mm_prompt_updates) > 0

        # align to hf behavior when there are images
        if has_updates:
            tokenizer = self.info.get_tokenizer()
            # to decode token_ids to the original text, we need to
            # 1. remove the first bos token
            # 2. remove space after each special token
            #    introduced by the tokenizer
            if len(token_ids) and token_ids[0] == tokenizer.bos_token_id:
                token_ids = token_ids[1:]
            text = tokenizer.decode(token_ids)
            for special_tokens in tokenizer.special_tokens_map.values():
                if isinstance(special_tokens, str):
                    text = text.replace(f"{special_tokens} ", special_tokens)
                elif isinstance(special_tokens, list):
                    for special_token in special_tokens:
                        text = text.replace(f"{special_token} ", special_token)
            # perform hf behavior
            # https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/64f88b6/processing_phi3_v.py#L407
            pattern = r"<\|image_\d+\|>"
            prompt_chunks = [
                tokenizer(chunk).input_ids for chunk in re.split(pattern, text)
            ]
            image_tags = [
                tokenizer(chunk, add_special_tokens=False).input_ids
                for chunk in re.findall(pattern, text)
            ]
            # Pad so that zip() below does not drop the final text chunk.
            if len(prompt_chunks) > len(image_tags):
                image_tags.append([])
            # Interleave: chunk0, tag0, chunk1, tag1, ... flattened to ids.
            token_ids = [
                token_id
                for chunk_and_tag in zip(prompt_chunks, image_tags)
                for ids in chunk_and_tag
                for token_id in ids
            ]

        token_ids, placeholders = super()._apply_prompt_updates(
            token_ids=token_ids,
            mm_prompt_updates=mm_prompt_updates,
        )

        # Keep the behavior in line with HF processor
        # (`tokenizer` is only bound, and only used, when has_updates is True)
        if has_updates and (
            token_ids[:2] == tokenizer.encode("<s> <|image|>", add_special_tokens=False)
        ):
            token_ids = [token_ids[0], *token_ids[2:]]
            # One token was removed before the placeholders, so shift them.
            placeholders = {
                modality: [
                    PlaceholderFeaturesInfo(
                        modality=p.modality,
                        item_idx=p.item_idx,
                        start_idx=p.start_idx - 1,
                        tokens=p.tokens,
                        is_embed=p.is_embed,
                    )
                    for p in ps
                ]
                for modality, ps in placeholders.items()
            }

        return token_ids, placeholders
549

550

551
552
553
554
555
556
@MULTIMODAL_REGISTRY.register_processor(
    Phi3VMultiModalProcessor,
    info=Phi3VProcessingInfo,
    dummy_inputs=Phi3VDummyInputsBuilder,
)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant):
    """Phi-3-Vision: an HD image embedding front-end plus a Llama-architecture
    language model."""

    # Prefix rewrites applied to checkpoint weight names. The more specific
    # "model.vision_embed_tokens.*" entries are listed before the generic
    # "model." entry.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "model.vision_embed_tokens.wte": "embed_tokens",
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt tag for the ``i``-th image, ``<|image_{i}|>``.

        Raises:
            ValueError: If ``modality`` does not start with "image".
        """
        if modality.startswith("image"):
            return f"<|image_{i}|>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the token embedding, vision embedding and language model."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.image_token_id = _IMAGE_TOKEN_ID

        with self._mark_tower_model(vllm_config, "image"):
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "model.embed_tokens"),
            )
            self.vision_embed_tokens = Phi3HDImageEmbedding(
                config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "model.vision_embed_tokens"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                # The prefix is empty intentionally because default prefix of
                # LlamaForCausalLM is "model"
                prefix="",
                # We don't directly initialize vLLM's LlamaForCausalLM so we
                # can automatically apply embedding wrapper if this model is
                # initialized as an embedding model
                architectures=["LlamaForCausalLM"],
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Phi3VImageInputs | None:
        """Pop the image kwargs and wrap them in the matching input schema.

        Returns:
            ``None`` when neither ``pixel_values`` nor ``image_embeds`` is
            given; otherwise a pixel-input schema (preferred when both are
            present) or an embedding-input schema.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            return Phi3VImagePixelInputs(
                type="pixel_values",
                pixel_values=pixel_values,
                image_sizes=image_sizes,
                # Each crop is expected at the CLIP input resolution.
                resolve_bindings={
                    "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
                    "w": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
                },
            )

        if image_embeds is not None:
            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=image_embeds,
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: Phi3VImageInputs,
    ) -> torch.Tensor:
        """Turn validated image inputs into image embeddings.

        Precomputed embeddings are returned unchanged; pixel inputs are run
        through ``self.vision_embed_tokens``.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        return self.vision_embed_tokens(
            image_input["pixel_values"], image_input["image_sizes"]
        )

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return the image embeddings for the given kwargs, or ``[]`` when
        no image input is present."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        return self._process_image_input(image_input)

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        """Embed the token ids and merge in multimodal embeddings, if any.

        Text tokens are embedded with this model's own ``embed_tokens``.
        When ``multimodal_embeddings`` is missing or empty the text
        embeddings are returned as-is; otherwise ``is_multimodal`` is
        required to place the multimodal embeddings.
        """
        inputs_embeds = self._embed_text_input_ids(
            input_ids,
            self.embed_tokens,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return inputs_embeds

        return _merge_multimodal_embeddings(
            inputs_embeds=inputs_embeds,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=_require_is_multimodal(is_multimodal),
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ):
        """Run the language model backbone and return its hidden states."""
        # When intermediate tensors are supplied (presumably on a non-first
        # pipeline-parallel stage), input embeddings are not used.
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model.model(
            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights and return the names that were loaded.

        Weight names are remapped through ``hf_to_vllm_mapper`` first.
        """
        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

        # The HF config doesn't specify whether these are tied,
        # so we detect it this way
        if "embed_tokens.weight" not in autoloaded_weights:
            # No separate embedding was loaded: share the language model's.
            self.embed_tokens = self.language_model.model.embed_tokens
            autoloaded_weights.add("embed_tokens.weight")
        return autoloaded_weights