# phi3v.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
import itertools
17
import re
18
from functools import cached_property, lru_cache
19
20
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
                    Tuple, TypedDict, Union)
21

22
import numpy as np
23
24
import torch
import torch.nn as nn
25
from PIL import Image
26
from transformers import CLIPVisionConfig, PretrainedConfig
27
28

from vllm.attention import AttentionMetadata
29
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
30
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
31
from vllm.logger import init_logger
32
from vllm.model_executor.layers.quantization import QuantizationConfig
33
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
34
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
35
from vllm.model_executor.models.clip import CLIPVisionModel
36
from vllm.model_executor.models.llama import LlamaForCausalLM
37
from vllm.model_executor.sampling_metadata import SamplingMetadata
38
from vllm.multimodal import MULTIMODAL_REGISTRY
39
from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
40
from vllm.sequence import IntermediateTensors
41
from vllm.utils import is_list_of
42

43
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
44
45
46
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (flatten_bn, group_weights_with_prefix,
                    merge_multimodal_embeddings)
47

48
49
logger = init_logger(__name__)

# Maps the HF checkpoint prefix of the vision embedding to the vLLM module name.
_KEYS_TO_MODIFY_MAPPING = {
    "model.vision_embed_tokens": "vision_embed_tokens",
}

# Image placeholder token id. It cannot be found in the HF config, so it is
# hard-coded here.
_IMAGE_TOKEN_ID = 32044

# Dummy image size that results in the max possible feature size: with the
# default of 16 crops, the HD transform turns an 8000x50 (h x w) image into a
# 16x1 grid of 336x336 crops (crop grid h:w = 16:1).
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 8000
MAX_IMAGE_FEATURE_SIZE_WIDTH = 50

# CLIP ViT-L/14 @ 336px vision tower configuration.
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                     hidden_act="quick_gelu",
                                                     hidden_size=1024,
                                                     image_size=336,
                                                     intermediate_size=4096,
                                                     num_attention_heads=16,
                                                     num_channels=3,
                                                     num_hidden_layers=24,
                                                     patch_size=14,
                                                     projection_dim=768)


73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def _init_img_processor(hf_config: PretrainedConfig):
    """Build the CLIP vision tower truncated at the configured feature layer.

    ``hf_config.img_processor['layer_idx']`` (default ``-2``) selects the
    hidden layer whose output is used as the image feature. Negative values
    count from the last layer, like Python indexing. Only the layers up to
    and including that one are instantiated.
    """
    vision_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
    feature_layer = hf_config.img_processor.get('layer_idx', -2)

    # Translate the (possibly negative) layer index into a layer count.
    layers_to_keep = (feature_layer + 1 if feature_layer >= 0 else
                      vision_config.num_hidden_layers + feature_layer + 1)

    return CLIPVisionModel(vision_config,
                           num_hidden_layers_override=layers_to_keep)


90
91
92
93
class Phi3VImagePixelInputs(TypedDict):
    """Image inputs given as preprocessed pixel values."""

    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    image_sizes: torch.Tensor
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """


class Phi3VImageEmbeddingInputs(TypedDict):
    """Image inputs given as precomputed image embeddings."""

    type: Literal["image_embeds"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Either kind of image input; discriminated by the ``type`` key.
Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]


121
122
class Phi3ImageEmbeddingBase(nn.Module):
    """Base module that runs the vision tower and selects the feature type.

    The attributes below are only declared here; ``Phi3HDImageEmbedding``
    assigns ``img_processor`` and ``type_feature``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def get_img_features(self,
                         img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        """Run ``img_processor`` and post-process according to ``type_feature``.

        ``"patch"`` drops the first token (index 0 along dim 1) of the tower
        output; ``"cls_patch"`` returns the tower output unchanged.

        Raises:
            NotImplementedError: for any other ``type_feature``.
        """
        type_feature = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if type_feature == "patch":
            return img_feature[:, 1:]

        if type_feature == "cls_patch":
            return img_feature

        raise NotImplementedError(
            f"Unsupported image feature type: {type_feature!r}")


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
    """Phi3 Image embedding with HD transform.

    Per image: CLIP features for every crop -> 2x2 spatial merge into the
    channel dimension -> learned row separators (``sub_GN``) and a learned
    separator between sub-crops and the global view (``glb_GN``) -> MLP
    projection to the language model hidden size.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()

        # n_embd or hidden_size
        hidden_size = config.n_embd if hasattr(
            config, 'n_embd') else config.hidden_size

        self.img_processor = _init_img_processor(config)

        image_dim_out = config.img_processor['image_dim_out']
        self.num_img_tokens = config.img_processor['num_img_tokens']

        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = config.embd_layer.get('use_hd_transform',
                                                      False)
        self.with_learnable_separator = config.embd_layer.get(
            'with_learnable_separator', False)
        self.hd_transform_order = config.embd_layer.get(
            'hd_transform_order', 'glb_sub')
        # Only the HD transform with learnable separators is implemented, so
        # both flags must be enabled.
        assert self.use_hd_transform and self.with_learnable_separator

        # image_dim_out * 4 channels: the 2x2 merge folds space into channels.
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(
            torch.empty([1, 1, 1, self.image_dim_out * 4]))

        # Two-layer MLP projector: Linear -> GELU -> Linear. The Sequential
        # indices (0, 1, 2) determine the parameter names ("0.weight",
        # "2.bias", ...) that load_weights looks up.
        dim_projection = hidden_size
        self.img_projection = nn.Sequential(
            nn.Linear(image_dim_out * 4, dim_projection),
            nn.GELU(),
            nn.Linear(dim_projection, dim_projection),
        )

        self.type_feature = config.img_processor.get('type_feature', 'patch')

    def forward(
        self, pixel_values: Union[torch.FloatTensor, List[torch.Tensor]],
        image_sizes: torch.Tensor
    ) -> List[torch.Tensor]:
        """Process images and return one embedding tensor per image.

        Args:
            pixel_values: ``(num_images, num_crops, c, h, w)``, where crop 0
                is the global view and the remaining crops are the (padded)
                HD sub-crops. A list of per-image ``(num_crops, c, h, w)``
                tensors is also accepted; each image is then encoded on its
                own, so the crop count may differ between images.
            image_sizes: ``(num_images, 2)`` in ``(height, width)`` format.

        Returns:
            A list with one ``(num_tokens_i, hidden_size)`` tensor per image;
            ``num_tokens_i`` depends on that image's size.
        """
        if isinstance(pixel_values, (list, tuple)):
            outputs: List[torch.Tensor] = []
            for idx, image_pixels in enumerate(pixel_values):
                outputs.extend(
                    self.forward(image_pixels.unsqueeze(0),
                                 image_sizes[idx:idx + 1]))
            return outputs

        num_images, num_crops = pixel_values.shape[:2]
        img_features = self.get_img_features(pixel_values.flatten(0, 1))
        img_features = img_features.reshape(num_images, num_crops, -1,
                                            self.image_dim_out)
        return self.hd_feature_transform(img_features, image_sizes)

    def hd_feature_transform(self, image_features, image_sizes):
        """Merge, add separators to, and project per-crop features per image.

        image_features: (num_images, num_crops+1, 24*24, 1024); index 0 along
            dim 1 is the global view, the rest are (padded) sub-crops.
        image_sizes: per-image ``(height, width)``; the crop grid is recovered
            as ``(height // 336, width // 336)``.

        Returns a list with one ``(num_tokens, hidden_size)`` tensor per image,
        laid out as [sub-crop features, ``glb_GN``, global features].
        """
        assert (
            self.hd_transform_order == 'sub_glb'
        ), f'hd_transform_order `{self.hd_transform_order}` not implemented'
        if isinstance(self.img_projection, nn.Sequential):
            target_device = self.img_projection[0].bias.device
            target_dtype = self.img_projection[0].bias.dtype
        else:  # It's a single nn.Linear layer
            target_device = self.img_projection.bias.device
            target_dtype = self.img_projection.bias.dtype

        # (num_images, 24*24, 1024)
        global_image_features = image_features[:, 0]
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_image_features_hd = self.reshape_hd_patches_2x2merge(
            global_image_features, 1, 1)
        global_image_features_hd_newline = self.add_image_newline(
            global_image_features_hd)

        batch_image_features_proj = []
        # need a for loop to process each image because of different image sizes
        # (patch arrangement is different for each image)
        for i, img_size in enumerate(image_sizes):
            # Plain ints keep the slicing/reshapes below free of 0-d tensors.
            h, w = (int(v) for v in img_size)
            h_crop = h // 336
            w_crop = w // 336
            num_crops = h_crop * w_crop

            # NOTE: real num_crops is padded
            # (num_crops, 24*24, 1024)
            sub_image_features = image_features[i, 1:1 + num_crops]
            sub_image_features_hd = self.reshape_hd_patches_2x2merge(
                sub_image_features, h_crop, w_crop)
            sub_image_features_hd_newline = self.add_image_newline(
                sub_image_features_hd)

            # [sub features, separator, global features]
            image_embeddings = torch.cat([
                # (h_crop*12*(w_crop*12+1), 4096)
                sub_image_features_hd_newline.squeeze(0),
                self.glb_GN.squeeze(0),
                global_image_features_hd_newline[i],
            ])
            img_proj = self.img_projection(
                image_embeddings.to(target_device, target_dtype))
            batch_image_features_proj.append(img_proj)

        return batch_image_features_proj

    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """
        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096)
        where h_crop*w_crop == num_crops

        Each 2x2 block of patches is concatenated along the channel dim, and
        the crops of an image are tiled row-major into one feature map.
        """
        N, L, C = image_features.shape
        assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
        num_images = N // (h_crop * w_crop)
        H = int(L**0.5)
        image_features_hd = (
            image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
            .reshape(N, -1, 4 * C)  # N, 144, 4096
            .reshape(num_images, h_crop, w_crop, H // 2, H // 2,
                     -1)  # n_img, h_crop, w_crop, 12, 12, 4096
            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
            .reshape(num_images, h_crop * H // 2, w_crop * H // 2,
                     4 * C)  # n_img, h_crop*12, w_crop*12, 4096
        )
        return image_features_hd

    def add_image_newline(self, image_features_hd):
        """
        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)

        Appends ``sub_GN`` at the end of every row, then flattens rows.
        """
        num_images, h, w, hid_dim = image_features_hd.shape
        # add the newline token to the HD image feature patches
        newline_embeddings = self.sub_GN.expand(num_images, h, -1,
                                                -1)  # (n_img, h, 1, hid_dim)
        image_features_hd_newline = torch.cat(
            [image_features_hd, newline_embeddings],
            dim=2).reshape(num_images, -1, hid_dim)
        return image_features_hd_newline

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load the vision tower, separator parameters and MLP projector.

        ``weights`` is split with ``group_weights_with_prefix``; the names
        seen per group are presumably the remainder after the group prefix
        (empty for the bare ``glb_GN``/``sub_GN`` parameters) — see that
        helper to confirm.
        """

        def _load(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
            # Prefer a parameter-specific loader when one is attached.
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)

        # prepare weight iterators for components
        weights_group = group_weights_with_prefix(weights)

        # load vision encoder
        self.img_processor.load_weights(weights_group["img_processor"])

        # load glb_GN, then sub_GN (bare parameters: no name suffix expected)
        for param_name in ("glb_GN", "sub_GN"):
            param = getattr(self, param_name)
            for name, loaded_weight in weights_group[param_name]:
                assert name == ""
                _load(param, loaded_weight)

        # load mlp projector
        mlp_params_dict = dict(self.img_projection.named_parameters())
        for name, loaded_weight in weights_group["img_projection"]:
            _load(mlp_params_dict[name], loaded_weight)

328

329
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
330
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
331
332
333
334
335
336
337
338
    target_height = int(np.ceil(height / padding_unit) * padding_unit)
    top_padding = int((target_height - height) / 2)
    bottom_padding = target_height - height - top_padding
    padded_width = width
    padded_height = height + top_padding + bottom_padding
    return padded_width, padded_height


339
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int):
    """Return the ``(width, height)`` an image has after the HD transform.

    Working in landscape orientation, the long side is resized to
    ``scale * 336``. ``scale`` is the largest integer (found by linear
    search) with ``scale * ceil(scale / aspect_ratio) <= hd_num``. The short
    side keeps the aspect ratio and is then padded up to a multiple of 336
    via ``_calc_padded_size``. Portrait images are transposed back at the end.
    """
    # Normalize to landscape; remember to swap the result back.
    transposed = width < height
    if transposed:
        width, height = height, width

    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1

    new_width = int(scale * 336)
    new_height = int(new_width / ratio)

    padded_width, padded_height = _calc_padded_size(width=new_width,
                                                    height=new_height)

    if transposed:
        padded_width, padded_height = padded_height, padded_width

    return padded_width, padded_height


364
365
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
def get_phi3v_image_feature_size(
    hf_config: Dict[str, Any],
    *,
    input_height: int,
    input_width: int,
    num_crops: Optional[int],
) -> int:
    """Return the number of image tokens one image expands to.

    Args:
        hf_config: HF image processor config. ``num_crops`` is read from it
            (default 16) when the ``num_crops`` argument is ``None``.
        input_height: Image height in pixels.
        input_width: Image width in pixels.
        num_crops: Maximum number of HD crops, or ``None`` to use the config.
    """
    if num_crops is None:
        num_crops = hf_config.get("num_crops", 16)

    new_width, new_height = _calc_hd_transform_size(width=input_width,
                                                    height=input_height,
                                                    hd_num=num_crops)
    h_crop = new_height // 336
    w_crop = new_width // 336

    # Each 336x336 crop becomes 12x12 = 144 tokens after the 2x2 merge.
    #   sub-crops:   h_crop * w_crop * 144 + one newline per row (h_crop * 12)
    #   separator:   1
    #   global view: 144 + one newline per row (12)
    return (h_crop * w_crop + 1) * 144 + 1 + (h_crop + 1) * 12

381

382
383
384
def get_max_phi3v_image_tokens(ctx: InputContext,
                               *,
                               num_crops: Optional[int] = None):
    """Return the largest number of tokens a single image can expand to.

    Evaluates the feature size for the extreme
    ``MAX_IMAGE_FEATURE_SIZE_HEIGHT`` x ``MAX_IMAGE_FEATURE_SIZE_WIDTH``
    image, which is chosen to give the max possible feature size.
    """
    return get_phi3v_image_feature_size(
        ctx.get_hf_image_processor_config(),
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        num_crops=num_crops,
    )


394
395
396
397
398
def dummy_data_for_phi3v(ctx: InputContext,
                         seq_len: int,
                         mm_counts: Mapping[str, int],
                         *,
                         num_crops: Optional[int] = None):
    """Build dummy inputs (registered via ``INPUT_REGISTRY.register_dummy_data``).

    Returns ``(seq_data, mm_data)`` for ``mm_counts["image"]`` images. Each
    image uses the max-feature-size dimensions, and each placeholder is sized
    with ``get_max_phi3v_image_tokens``.
    """
    num_images = mm_counts["image"]

    image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops)

    seq_data = dummy_seq_data_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        seq_len,
        num_images,
        image_token_id=_IMAGE_TOKEN_ID,
        image_feature_size_override=image_feature_size,
    )
    mm_data = dummy_image_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        num_images,
        image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )

    return seq_data, mm_data


# Reserve this function to also handle placeholders for additional images
# [ref: PR #5820]
@lru_cache
def _get_image_placeholder_token_ids(model_config: ModelConfig,
                                     idx: int) -> List[int]:
    """Return the token ids of the ``<|image_{idx}|>`` tag (``idx`` >= 1).

    The tag is tokenized after a leading ``"a"``, and that first id is
    stripped. This yields the ids the tag has mid-text (the token for "<",
    not "▁<"). Memoized with ``functools.lru_cache`` (default 128 entries).
    """
    assert idx > 0

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
    prefix_ids = tokenizer.encode("a", add_special_tokens=False)
    (prefix_id, ) = prefix_ids

    leading_id, *tag_ids = tokenizer.encode(f"a<|image_{idx}|>",
                                            add_special_tokens=False)
    # The prefix must tokenize identically on its own and before the tag.
    assert prefix_id == leading_id

    return tag_ids


439
440
441
442
def input_processor_for_phi3v(ctx: InputContext,
                              llm_inputs: LLMInputs,
                              *,
                              num_crops: Optional[int] = None):
    """Expand image placeholders in the prompt token ids into image tokens.

    The token ids of each ``<|image_{idx}|>`` tag are replaced by
    ``_IMAGE_TOKEN_ID``. Consecutive placeholder ids are collapsed to one,
    and each placeholder is then expanded to as many ``_IMAGE_TOKEN_ID``
    tokens as that image has features.

    Returns ``llm_inputs`` unchanged when it has no image data; otherwise a
    new ``LLMInputs`` with the rewritten ``prompt_token_ids``.

    Raises:
        TypeError: if the image data has an unsupported type.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_image_processor_config()

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        image_data = [image_data]

    # Number of feature tokens for each image, in order.
    if is_list_of(image_data, Image.Image):
        image_feature_size = []
        for image in image_data:
            w, h = image.size
            image_feature_size.append(
                get_phi3v_image_feature_size(hf_config,
                                             input_width=w,
                                             input_height=h,
                                             num_crops=num_crops))
    elif isinstance(image_data, torch.Tensor):
        # (num_images, image_feature_size, hidden_size). Keep a per-image
        # list: it is indexed per placeholder below.
        num_images, feature_size, _ = image_data.shape
        image_feature_size = [feature_size] * num_images
    elif is_list_of(image_data, torch.Tensor):
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    prompt = llm_inputs.get("prompt")
    if prompt is None:
        # for async server request, we assume prompt and its token_ids is always
        # in correct format. And num_image_tags == len(image_data) always True.
        image_idx = range(1, len(image_data) + 1)
        new_prompt = None
    else:
        image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt)))
        if prompt.count("<|image|>") > 0:
            logger.warning("Please follow the prompt format that is "
                           "documented on HuggingFace which does not involve "
                           "repeating <|image|> tokens.")
        elif (num_image_tags := len(image_idx)) > 1:
            assert num_image_tags == len(
                image_data), "The count of image_placeholder not match image's"
        new_prompt = prompt

    prompt_token_ids = llm_inputs["prompt_token_ids"].copy()

    # Mask the first occurrence of each tag's token ids with _IMAGE_TOKEN_ID.
    for idx in image_idx:
        image_token_ids = _get_image_placeholder_token_ids(model_config,
                                                           idx=idx)
        tag_len = len(image_token_ids)
        for i in range(len(prompt_token_ids) - tag_len + 1):
            if prompt_token_ids[i:i + tag_len] == image_token_ids:
                prompt_token_ids[i:i + tag_len] = [_IMAGE_TOKEN_ID] * tag_len
                break

    # Merge each run of consecutive _IMAGE_TOKEN_IDs into a single one.
    merged_token_ids: List[int] = []
    for is_placeholder, token_ids in itertools.groupby(
            prompt_token_ids, lambda x: x == _IMAGE_TOKEN_ID):
        if is_placeholder:
            merged_token_ids.append(_IMAGE_TOKEN_ID)
        else:
            merged_token_ids.extend(token_ids)

    # TODO: Move this to utils or integrate with clip.
    # Expand the n-th placeholder to image_feature_size[n] tokens. Plain
    # iteration (instead of repeated pop(0)) keeps this linear.
    new_token_ids: List[int] = []
    placeholder_idx = 0
    for token_id in merged_token_ids:
        if token_id == _IMAGE_TOKEN_ID:
            new_token_ids.extend(
                repeat_and_pad_token(
                    _IMAGE_TOKEN_ID,
                    repeat_count=image_feature_size[placeholder_idx],
                ))
            placeholder_idx += 1
        else:
            new_token_ids.append(token_id)

    # NOTE: Create a defensive copy of the original inputs
    llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
                           prompt=new_prompt,
                           multi_modal_data=multi_modal_data)
    return llm_inputs
535

536
537

@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Phi-3-vision: HD CLIP image embedding feeding a Llama causal LM.

    Positions holding ``_IMAGE_TOKEN_ID`` in ``input_ids`` are filled with
    image embeddings (via ``merge_multimodal_embeddings``) before the
    language model runs.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config
        self.image_token_id = _IMAGE_TOKEN_ID

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_embed_tokens = Phi3HDImageEmbedding(config)

        self.language_model = LlamaForCausalLM(config, cache_config,
                                               quant_config)

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Language model's sampler if it defines one, else a new Sampler."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return Sampler()

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every entry of ``data`` has shape ``(2,)``.

        Returns ``data`` unchanged.

        Raises:
            ValueError: if an entry has any other shape.
        """
        expected_dims = (2, )

        for d in data:
            actual_dims = tuple(d.shape)
            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check that every image in ``data`` is ``(num_patches, 3, 336, 336)``.

        Returns ``data`` unchanged.

        Raises:
            ValueError: if an image has a different trailing shape.
        """
        h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
        expected_dims = (3, h, w)

        for d in data:
            actual_dims = tuple(d.shape[1:])
            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Phi3VImageInputs]:
        """Extract and validate image inputs from ``kwargs``.

        Returns pixel-value inputs when ``pixel_values`` is given, otherwise
        embedding inputs when ``image_embeds`` is given, otherwise ``None``.

        Raises:
            ValueError: if a provided input has an unexpected type or shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        # Return early only when *both* are absent. Returning on a missing
        # ``pixel_values`` alone would make the image_embeds path unreachable.
        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return Phi3VImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)))

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: Phi3VImageInputs,
    ) -> torch.Tensor:
        """Return image embeddings, encoding pixel values when needed."""
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_embed_tokens is not None
        image_embeds = self.vision_embed_tokens(image_input["data"],
                                                image_input["image_sizes"])

        return image_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs: object):
        """Run the language model, with image embeddings merged in if given.

        When ``intermediate_tensors`` is provided (presumably a non-first
        pipeline-parallel stage), ``input_ids`` are dropped and image inputs
        are not processed.
        """
        if intermediate_tensors is not None:
            input_ids = None
            inputs_embeds = None
        else:
            image_input = self._parse_and_validate_image_input(**kwargs)

            if image_input is not None:
                vision_embeddings = self._process_image_input(image_input)
                inputs_embeds = self.language_model.model.get_input_embeddings(
                    input_ids)
                inputs_embeds = merge_multimodal_embeddings(
                    input_ids, inputs_embeds, vision_embeddings,
                    self.image_token_id)
                input_ids = None
            else:
                inputs_embeds = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Rename HF checkpoint keys to vLLM names and load each component."""
        # Checked in insertion order: the vision prefix must be tried before
        # the broader "model." prefix.
        hf_to_vllm_mapping = {
            "model.vision_embed_tokens.": "vision_embed_tokens.",
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }

        def hf_to_vllm_name(key: str) -> str:
            for hf_name, vllm_name in hf_to_vllm_mapping.items():
                if key.startswith(hf_name):
                    return key.replace(hf_name, vllm_name, 1)

            return key

        vllm_weights = {hf_to_vllm_name(k): v for k, v in weights}

        # prepare weight iterators for components
        weights_group = group_weights_with_prefix(vllm_weights.items())

        # load vision embeddings and encoder
        self.vision_embed_tokens.load_weights(
            weights_group["vision_embed_tokens"])

        # load llm backbone
        self.language_model.load_weights(weights_group["language_model"])