"vllm/model_executor/offloader/prefetch.py" did not exist on "ed42507f6d6e326663997da5cca6991da5d8a23f"
phi3v.py 26.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# coding=utf-8
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16
import itertools
17
18
import re
from functools import lru_cache
19
20
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
                    Tuple, TypedDict, Union)
21

22
import numpy as np
23
24
import torch
import torch.nn as nn
25
from PIL import Image
26
from transformers import CLIPVisionConfig, PretrainedConfig
27
28

from vllm.attention import AttentionMetadata
29
from vllm.config import CacheConfig, ModelConfig, MultiModalConfig
30
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
31
from vllm.logger import init_logger
32
from vllm.model_executor.layers.logits_processor import LogitsProcessor
33
from vllm.model_executor.layers.quantization import QuantizationConfig
34
35
36
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
37
from vllm.model_executor.models.clip import CLIPVisionModel
38
39
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
40
from vllm.multimodal import MULTIMODAL_REGISTRY
41
from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
42
from vllm.sequence import IntermediateTensors, SamplerOutput
43
from vllm.utils import is_list_of
44

45
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
46
from .interfaces import SupportsMultiModal
47
from .utils import flatten_bn, merge_multimodal_embeddings
48

49
50
logger = init_logger(__name__)

# Checkpoint weight-name prefixes rewritten in ``load_weights`` so that they
# match this module's parameter names (the vision tower is registered as
# ``Phi3VForCausalLM.vision_embed_tokens``).
_KEYS_TO_MODIFY_MAPPING = {
    "model.vision_embed_tokens": "vision_embed_tokens",
}

# Cannot find the following number from hf config.
_IMAGE_TOKEN_ID = 32044

# Image size that results in the max possible feature size: with the default
# ``num_crops=16`` the HD transform turns it into a 16x1 (h:w) grid of
# 336x336 crops.
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 8000
MAX_IMAGE_FEATURE_SIZE_WIDTH = 50

# CLIP ViT-L/14 at 336px resolution, used as the Phi-3-vision image encoder.
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                     hidden_act="quick_gelu",
                                                     hidden_size=1024,
                                                     image_size=336,
                                                     intermediate_size=4096,
                                                     num_attention_heads=16,
                                                     num_channels=3,
                                                     num_hidden_layers=24,
                                                     patch_size=14,
                                                     projection_dim=768)
class Phi3VImagePixelInputs(TypedDict):
    """Raw pixel-value image input consumed by ``Phi3HDImageEmbedding``."""

    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    image_sizes: torch.Tensor
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """


class Phi3VImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings that bypass the vision encoder."""

    type: Literal["image_embeds"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# Either kind of image input; discriminated by the ``type`` key.
Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]
class Phi3ImageEmbeddingBase(nn.Module):
    """Shared feature-extraction logic for Phi-3-vision image embedders.

    Subclasses are expected to assign ``layer_idx``, ``type_feature`` and
    ``img_processor`` in their own ``__init__``; this base class only
    declares them.
    """

    def __init__(self) -> None:
        super().__init__()
        # Declared (not assigned) here; subclasses provide the values.
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def get_img_features(self,
                         img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        """Run the vision tower and select the configured token features.

        Args:
            img_embeds: pixel values forwarded as-is to ``self.img_processor``.

        Returns:
            For ``type_feature == "patch"``, the processor output with the
            first token along dim 1 dropped (presumably the CLS token);
            for ``type_feature == "cls_patch"``, the full processor output.

        Raises:
            NotImplementedError: for any other ``type_feature`` value.
        """
        type_feature = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if type_feature == "patch":
            return img_feature[:, 1:]

        if type_feature == "cls_patch":
            return img_feature

        raise NotImplementedError(
            f"Unsupported type_feature: {type_feature!r}")


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
    """Phi3 Image embedding with HD transform.

    Each 336x336 crop is encoded with a (layer-truncated) CLIP ViT-L/14-336,
    every 2x2 block of patch features is merged into one 4x-wide feature,
    learnable separators are inserted (``sub_GN`` at the end of every row,
    ``glb_GN`` between the sub-image and global features) and the result is
    projected to the language model's hidden size by a 2-layer MLP.
    """

    def __init__(self, config: PretrainedConfig) -> None:
        super().__init__()

        # n_embed or hidden_size
        hidden_size = config.n_embd if hasattr(
            config, 'n_embd') else config.hidden_size

        clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
        self.layer_idx = config.img_processor.get('layer_idx', -2)

        # Initialize the CLIP only up to the required feature layer, e.g.
        # layer_idx=-2 keeps num_hidden_layers - 1 encoder layers.
        if self.layer_idx < 0:
            num_hidden_layers = clip_config.num_hidden_layers + \
                self.layer_idx + 1
        else:
            num_hidden_layers = self.layer_idx + 1

        self.img_processor = CLIPVisionModel(
            clip_config, num_hidden_layers_override=num_hidden_layers)

        image_dim_out = config.img_processor['image_dim_out']
        self.num_img_tokens = config.img_processor['num_img_tokens']

        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = config.embd_layer.get('use_hd_transform',
                                                      False)
        self.with_learnable_separator = config.embd_layer.get(
            'with_learnable_separator', False)
        self.hd_transform_order = config.embd_layer.get(
            'hd_transform_order', 'glb_sub')
        # Only the HD transform with learnable separators is implemented,
        # so both flags must be enabled.
        assert self.use_hd_transform and self.with_learnable_separator

        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(
            torch.empty([1, 1, 1, self.image_dim_out * 4]))

        dim_projection = hidden_size
        depth = 2
        layers = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend(
                [nn.GELU(),
                 nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.type_feature = config.img_processor.get('type_feature', 'patch')

    def forward(self, pixel_values: torch.FloatTensor,
                image_sizes: torch.Tensor) -> List[torch.FloatTensor]:
        """
        process image and return vision embeddings.

        pixel_values: (num_images, num_crops, c, h, w)
        image_sizes: (num_images, 2) in (height, width) format
        output: list with one (num_img_tokens_i, hidden_size) tensor per
            image; the token count varies with each image's crop grid.
        """
        num_images, num_crops, c, h, w = pixel_values.shape
        pixel_values = pixel_values.flatten(0, 1)
        img_features = self.get_img_features(pixel_values)
        img_features = img_features.reshape(num_images, num_crops, -1,
                                            self.image_dim_out)
        image_features_proj = self.hd_feature_transform(
            img_features, image_sizes)
        return image_features_proj

    def hd_feature_transform(self, image_features, image_sizes):
        """Assemble and project the token sequence of every image.

        image_features: (num_images, num_crops+1, 24*24, 1024); index 0
            along dim 1 is the global image, the rest are sub-image crops
            (the crop dimension may be padded).
        image_sizes: per-image (height, width); ``h // 336`` and
            ``w // 336`` give the crop grid.

        Returns a list with one tensor per image of shape
        ``(h_crop*12*(w_crop*12+1) + 1 + 12*13, hidden_size)``:
        [sub-image features, glb_GN separator, global features].
        """
        assert (
            self.hd_transform_order == 'sub_glb'
        ), f'hd_transform_order `{self.hd_transform_order}` not implemented'
        if isinstance(self.img_projection, nn.Sequential):
            target_device = self.img_projection[0].bias.device
            target_dtype = self.img_projection[0].bias.dtype
        else:  # It's a single nn.Linear layer
            target_device = self.img_projection.bias.device
            target_dtype = self.img_projection.bias.dtype

        global_image_features = image_features[:,
                                               0]  # (num_images, 24*24, 1024)
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_image_features_hd = self.reshape_hd_patches_2x2merge(
            global_image_features, 1, 1)
        global_image_features_hd_newline = self.add_image_newline(
            global_image_features_hd)

        batch_image_features_proj = []
        # need a for loop to process each image because of different image sizes
        # (patch arrangement is different for each image)
        for i, img_size in enumerate(image_sizes):
            h, w = img_size
            h_crop = h // 336
            w_crop = w // 336
            num_crops = h_crop * w_crop

            # NOTE: the crop dimension may be padded; only the first
            # num_crops sub-images are real.
            # (num_crops, 24*24, 1024)
            sub_image_features = image_features[i, 1:1 + num_crops]
            sub_image_features_hd = self.reshape_hd_patches_2x2merge(
                sub_image_features, h_crop, w_crop)
            sub_image_features_hd_newline = self.add_image_newline(
                sub_image_features_hd)

            # [sub features, separator, global features]
            image_embeddings = torch.cat([
                sub_image_features_hd_newline.squeeze(
                    0),  # (h_crop*12*(w_crop*12+1), 4096)
                self.glb_GN.squeeze(0),
                global_image_features_hd_newline[i],
            ])
            img_proj = self.img_projection(
                image_embeddings.to(target_device, target_dtype))
            batch_image_features_proj.append(img_proj)

        return batch_image_features_proj

    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """
        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096)
        where h_crop*w_crop == num_crops
        """
        N, L, C = image_features.shape
        assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
        num_images = N // (h_crop * w_crop)
        H = int(L**0.5)
        image_features_hd = (
            image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
            .reshape(N, -1, 4 * C)  # N, 144, 4096
            .reshape(num_images, h_crop, w_crop, H // 2, H // 2,
                     -1)  # n_img, h_crop, w_crop, 12, 12, 4096
            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
            .reshape(num_images, h_crop * H // 2, w_crop * H // 2,
                     4 * C)  # n_img, h_crop*12, w_crop*12, 4096
        )
        return image_features_hd

    def add_image_newline(self, image_features_hd):
        """
        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
        """
        num_images, h, w, hid_dim = image_features_hd.shape
        # add the newline token to the HD image feature patches
        newline_embeddings = self.sub_GN.expand(num_images, h, -1,
                                                -1)  # (n_img, h, 1, hid_dim)
        image_features_hd_newline = torch.cat(
            [image_features_hd, newline_embeddings],
            dim=2).reshape(num_images, -1, hid_dim)
        return image_features_hd_newline
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
293
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
294
295
296
297
298
299
300
301
    target_height = int(np.ceil(height / padding_unit) * padding_unit)
    top_padding = int((target_height - height) / 2)
    bottom_padding = target_height - height - top_padding
    padded_width = width
    padded_height = height + top_padding + bottom_padding
    return padded_width, padded_height


302
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
    """Return the ``(width, height)`` an image is resized and padded to.

    The image is treated as landscape (transposed when taller than wide),
    its long side is scaled to the largest multiple of 336 whose crop grid
    ``scale * ceil(scale / ratio)`` does not exceed ``hd_num``, the short
    side follows the aspect ratio and is then padded up to a multiple of
    336. The result is transposed back if needed, so both returned sides
    are multiples of 336.
    """
    transposed = False
    if width < height:
        width, height = height, width
        transposed = True

    ratio = width / height
    # Largest scale whose crop count fits in hd_num (scale >= 1 always,
    # since ratio >= 1 here).
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1

    new_width = int(scale * 336)
    new_height = int(new_width / ratio)

    padded_width, padded_height = _calc_padded_size(width=new_width,
                                                    height=new_height)

    if transposed:
        padded_width, padded_height = padded_height, padded_width

    return padded_width, padded_height
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
def get_phi3v_image_feature_size(
    hf_config: Dict[str, Any],
    *,
    input_height: int,
    input_width: int,
) -> int:
    """Return the number of image tokens used for one image.

    Args:
        hf_config: HF image processor config; ``num_crops`` (default 16)
            bounds the number of 336x336 sub-image crops.
        input_height: original image height in pixels.
        input_width: original image width in pixels.

    Returns:
        The token count of the layout built by
        ``Phi3HDImageEmbedding.hd_feature_transform``: 144 (12x12) tokens
        per crop for the sub-images plus one global image, one newline
        token per row of the sub-image grid (12 rows per crop) and of the
        global image (12 rows), and one separator between the two.
    """
    num_crops = hf_config.get("num_crops", 16)
    new_width, new_height = _calc_hd_transform_size(width=input_width,
                                                    height=input_height,
                                                    hd_num=num_crops)

    # Both sides returned by _calc_hd_transform_size are multiples of 336,
    # so these floor divisions are exact crop counts.
    h_crops = new_height // 336
    w_crops = new_width // 336
    return (h_crops * w_crops + 1) * 144 + 1 + (h_crops + 1) * 12
def get_max_phi3v_image_tokens(ctx: InputContext):
    """Return the largest per-image token count for this model.

    Evaluates ``get_phi3v_image_feature_size`` at the
    ``MAX_IMAGE_FEATURE_SIZE_HEIGHT`` x ``MAX_IMAGE_FEATURE_SIZE_WIDTH``
    image size using the model's HF image processor config.
    """
    return get_phi3v_image_feature_size(
        ctx.get_hf_image_processor_config(),
        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
    )
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    """Build dummy sequence and image data with ``mm_counts["image"]`` images.

    Every dummy image uses the ``MAX_IMAGE_FEATURE_SIZE_HEIGHT`` x
    ``MAX_IMAGE_FEATURE_SIZE_WIDTH`` size, and every image placeholder in the
    token sequence is sized to ``get_max_phi3v_image_tokens(ctx)``, so the
    dummy request uses the largest per-image feature size.

    Returns:
        ``(seq_data, mm_data)`` as built by the CLIP dummy-data helpers.
    """
    num_images = mm_counts["image"]

    image_feature_size = get_max_phi3v_image_tokens(ctx)

    seq_data = dummy_seq_data_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        seq_len,
        num_images,
        image_token_id=_IMAGE_TOKEN_ID,
        image_feature_size_override=image_feature_size,
    )
    mm_data = dummy_image_for_clip(
        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
        num_images,
        image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )

    return seq_data, mm_data


# Kept as a separate helper so placeholders for additional images can be
# handled as well [ref: PR #5820]
@lru_cache
def _get_image_placeholder_token_ids(model_config: ModelConfig,
                                     idx: int) -> List[int]:
    """Return the token ids of ``<|image_{idx}|>`` when it follows text.

    The placeholder is encoded right after a dummy ``"a"`` and the id of
    that leading ``"a"`` is stripped off, so the result is the mid-sequence
    tokenization rather than the sequence-initial one. Results are memoized
    per ``(model_config, idx)``.
    """
    assert idx > 0

    tok = cached_get_tokenizer(model_config.tokenizer)

    # Encoding after "a" yields the token for "<" instead of "▁<"; see
    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
    (lone_a_id, ) = tok.encode("a", add_special_tokens=False)
    prefixed_ids = tok.encode(f"a<|image_{idx}|>", add_special_tokens=False)
    leading_id, *placeholder_ids = prefixed_ids
    assert lone_a_id == leading_id

    return placeholder_ids


def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
    """Expand Phi-3-vision image placeholders in the prompt token ids.

    Each ``<|image_{idx}|>`` placeholder found in ``prompt_token_ids`` is
    collapsed into a single ``_IMAGE_TOKEN_ID`` and then repeated once per
    image feature token, so the sequence reserves the positions that
    ``Phi3VForCausalLM.forward`` fills with vision embeddings.

    Args:
        ctx: input context giving access to the model config and the HF
            image processor config.
        llm_inputs: inputs holding ``prompt_token_ids`` and optionally
            ``prompt`` and ``multi_modal_data``.

    Returns:
        ``llm_inputs`` unchanged when there is no image data; otherwise a new
        ``LLMInputs`` with the expanded token ids.

    Raises:
        TypeError: if the image data is not a PIL image, a list of PIL
            images, or a tensor.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_image_processor_config()

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        w, h = image_data.size
        image_feature_size = [
            get_phi3v_image_feature_size(hf_config,
                                         input_width=w,
                                         input_height=h)
        ]
        image_data = [image_data]
    elif is_list_of(image_data, Image.Image):
        image_feature_size = []
        for image in image_data:
            w, h = image.size
            image_feature_size.append(
                get_phi3v_image_feature_size(hf_config,
                                             input_width=w,
                                             input_height=h))
    elif isinstance(image_data, torch.Tensor):
        # Precomputed embeddings for one image: the first dimension is taken
        # as its feature size (assumed `(image_feature_size, hidden_size)` --
        # TODO confirm with callers). Both values are wrapped in lists so
        # the per-image indexing below works; previously the size was a bare
        # int and indexing it raised TypeError.
        image_feature_size = [image_data.shape[0]]
        image_data = [image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    prompt = llm_inputs.get("prompt")
    if prompt is None:
        # For async server requests we assume the prompt token ids are
        # already in the correct format, i.e. num_image_tags equals
        # len(image_data).
        image_idx = range(1, len(image_data) + 1)
        new_prompt = None
    else:
        image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt)))
        if prompt.count("<|image|>") > 0:
            logger.warning("Please follow the prompt format that is "
                           "documented on HuggingFace which does not involve "
                           "repeating <|image|> tokens.")
        elif (num_image_tags := len(image_idx)) > 1:
            assert num_image_tags == len(
                image_data), "The count of image_placeholder not match image's"
        new_prompt = prompt

    prompt_token_ids = llm_inputs["prompt_token_ids"].copy()

    # Mask the first occurrence of each placeholder with the image token id.
    for idx in image_idx:
        image_token_ids = _get_image_placeholder_token_ids(model_config,
                                                           idx=idx)
        for i in range(len(prompt_token_ids) - len(image_token_ids) + 1):
            if prompt_token_ids[i:i + len(image_token_ids)] == image_token_ids:
                prompt_token_ids[i:i + len(image_token_ids)] = [
                    _IMAGE_TOKEN_ID
                ] * len(image_token_ids)
                break

    # Collapse each run of image token ids into a single image token id.
    merged_token_ids: List[int] = []
    for is_placeholder, token_ids in itertools.groupby(
            prompt_token_ids, lambda x: x == _IMAGE_TOKEN_ID):
        if is_placeholder:
            merged_token_ids.append(_IMAGE_TOKEN_ID)
        else:
            merged_token_ids.extend(list(token_ids))

    # TODO: Move this to utils or integrate with clip.
    # Expand the n-th image token id to the n-th image's feature size.
    # (Plain iteration instead of repeated list.pop(0), which was O(n^2).)
    new_token_ids: List[int] = []
    placeholder_idx = 0
    for token_id in merged_token_ids:
        if token_id == _IMAGE_TOKEN_ID:
            new_token_ids.extend(
                repeat_and_pad_token(
                    _IMAGE_TOKEN_ID,
                    repeat_count=image_feature_size[placeholder_idx],
                ))
            placeholder_idx += 1
        else:
            new_token_ids.append(token_id)

    # NOTE: Create a defensive copy of the original inputs
    llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
                           prompt=new_prompt,
                           multi_modal_data=multi_modal_data)
    return llm_inputs

@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal):
    """Phi-3-vision: a ``LlamaModel`` decoder plus an HD CLIP image embedder.

    Image features (or precomputed image embeddings) replace the input
    embeddings at positions holding ``_IMAGE_TOKEN_ID``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 multimodal_config: MultiModalConfig,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()

        self.config = config
        self.multimodal_config = multimodal_config
        self.image_token_id = _IMAGE_TOKEN_ID

        self.model = LlamaModel(config, cache_config, quant_config)

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_embed_tokens = Phi3HDImageEmbedding(config)

        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = Sampler()

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every entry of ``data`` has shape ``(2, )``.

        Raises:
            ValueError: on the first entry with a different shape.
        """
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check that each image is ``(num_patches, 3, 336, 336)``.

        Only the trailing ``(channels, height, width)`` dims are checked;
        ``num_patches`` may differ per image.

        Raises:
            ValueError: on the first image with mismatching trailing dims.
        """
        h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Phi3VImageInputs]:
        """Extract and validate the image input from the forward kwargs.

        Returns:
            ``Phi3VImagePixelInputs`` when ``pixel_values`` is given (which
            then also requires ``image_sizes``), otherwise
            ``Phi3VImageEmbeddingInputs`` when ``image_embeds`` is given,
            or ``None`` when neither is present.

        Raises:
            ValueError: if a provided input has the wrong type or shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        # Only bail out when *both* inputs are missing; returning early on
        # `pixel_values is None` alone would silently drop image_embeds.
        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return Phi3VImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)))

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: Phi3VImageInputs,
    ) -> torch.Tensor:
        """Return image embeddings, running the vision tower when needed.

        Precomputed embeddings are returned as-is; pixel values are encoded
        by ``self.vision_embed_tokens`` together with their image sizes.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        assert self.vision_embed_tokens is not None
        image_embeds = self.vision_embed_tokens(image_input["data"],
                                                image_input["image_sizes"])

        return image_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                **kwargs: object):
        """Run the decoder, merging image embeddings into the input if any.

        With image inputs present, the text embeddings are computed here,
        the positions holding ``self.image_token_id`` are replaced by the
        image embeddings, and the language model receives ``inputs_embeds``
        instead of ``input_ids``.
        """
        image_input = self._parse_and_validate_image_input(**kwargs)

        if image_input is not None:
            vision_embeddings = self._process_image_input(image_input)
            inputs_embeds = self.model.get_input_embeddings(input_ids)
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, vision_embeddings,
                self.image_token_id)
            input_ids = None
        else:
            inputs_embeds = None

        hidden_states = self.model(input_ids,
                                   positions,
                                   kv_caches,
                                   attn_metadata,
                                   intermediate_tensors,
                                   inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint weights into this model.

        Skips rotary ``inv_freq`` entries and the unused CLIP
        ``post_layernorm``, renames keys via ``_KEYS_TO_MODIFY_MAPPING`` and
        routes the language model's q/k/v and gate/up projections into
        their fused parameters. Weights with no matching parameter are
        ignored.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # post_layernorm is not needed in CLIPVisionModel
            if "vision_model.post_layernorm" in name:
                continue

            for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                if key_to_modify in name:
                    name = name.replace(key_to_modify, new_key)
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # We only do sharding for language model
                # and not vision model for now.
                if "vision_embed_tokens" in name and self.vision_embed_tokens:
                    continue
                if weight_name not in name:
                    continue
                param = params_dict[name.replace(weight_name, param_name)]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if name in params_dict:
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)