phi3v.py 24 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2024 The vLLM team.
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16
17
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                    TypedDict, Union)
18
19
20

import torch
import torch.nn as nn
21
22
from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
                          ProcessorMixin)
23
24

from vllm.attention import AttentionMetadata
25
26
from vllm.config import VllmConfig
from vllm.inputs import InputContext
27
from vllm.logger import init_logger
28
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
29
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
30
31
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
32
from vllm.model_executor.models.clip import CLIPVisionModel
33
from vllm.model_executor.sampling_metadata import SamplingMetadata
34
from vllm.multimodal import MULTIMODAL_REGISTRY
35
from vllm.multimodal.inputs import NestedTensors
36
37
from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        MultiModalDataDict,
38
                                        MultiModalDataItems, ProcessorInputs,
39
                                        PromptReplacement)
40
from vllm.sequence import IntermediateTensors
41
from vllm.utils import is_list_of
42

43
from .clip import dummy_image_for_clip
44
from .interfaces import SupportsMultiModal, SupportsPP
45
46
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                    init_vllm_registered_model, maybe_prefix,
47
                    merge_multimodal_embeddings)
48

49
50
logger = init_logger(__name__)

# The image placeholder token id cannot be found in the HF config, so it is
# hard-coded here. Image placeholders in the prompt are rewritten to this id
# and it is the id whose positions receive the vision embeddings.
_IMAGE_TOKEN_ID = 32044

# Dummy image dimensions that result in the max possible feature size
# (h:w = 16:1). Used to estimate the max number of image tokens and to build
# dummy inputs for profiling.
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 8000
MAX_IMAGE_FEATURE_SIZE_WIDTH = 50

# Hyper-parameters of the CLIP ViT-L/14 (336px) vision tower used as the
# image encoder.
CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
                                                     hidden_act="quick_gelu",
                                                     hidden_size=1024,
                                                     image_size=336,
                                                     intermediate_size=4096,
                                                     num_attention_heads=16,
                                                     num_channels=3,
                                                     num_hidden_layers=24,
                                                     patch_size=14,
                                                     projection_dim=768)


70
def _init_img_processor(hf_config: PretrainedConfig,
                        quant_config: Optional[QuantizationConfig],
                        prefix: str = "") -> CLIPVisionModel:
    """Create the CLIP vision tower, truncated at the selected feature layer.

    ``hf_config.img_processor['layer_idx']`` (default ``-2``) is treated as
    an index into the CLIP hidden-states sequence, which has
    ``num_hidden_layers + 1`` entries: entry 0 is the embedding output and
    entry ``k`` is the output after ``k`` encoder layers. Only the encoder
    layers needed to produce that entry are instantiated, so a negative
    index and its equivalent non-negative index (e.g. ``-2`` and ``23`` for
    a 24-layer tower) build the same number of layers.

    Args:
        hf_config: HF model config whose ``img_processor`` dict may contain
            ``layer_idx``.
        quant_config: Optional quantization config for the vision tower.
        prefix: Weight-name prefix for the created module.

    Returns:
        A ``CLIPVisionModel`` built with ``num_hidden_layers_override``.

    Raises:
        ValueError: If ``layer_idx`` is out of range for the CLIP config.
    """
    clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
    total_layers = clip_config.num_hidden_layers
    layer_idx = hf_config.img_processor.get('layer_idx', -2)

    # Initialize the CLIP only up to the required feature layer. The
    # hidden-states sequence has `total_layers + 1` entries, so both branches
    # map an entry index to the number of encoder layers that produce it.
    if layer_idx < 0:
        num_hidden_layers = total_layers + 1 + layer_idx
    else:
        num_hidden_layers = layer_idx

    if not 0 <= num_hidden_layers <= total_layers:
        raise ValueError(
            f"img_processor.layer_idx={layer_idx} is out of range for a "
            f"vision tower with {total_layers} hidden layers.")

    img_processor = CLIPVisionModel(
        clip_config,
        quant_config,
        num_hidden_layers_override=num_hidden_layers,
        prefix=prefix,
    )

    return img_processor


93
94
95
96
class Phi3VImagePixelInputs(TypedDict):
    """Image inputs given as preprocessed pixel values plus image sizes."""

    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """

    image_sizes: torch.Tensor
    """
    Shape: `(batch_size * num_images, 2)`

    This should be in `(height, width)` format.
    """


class Phi3VImageEmbeddingInputs(TypedDict):
    """Image inputs given as precomputed image embeddings."""

    type: Literal["image_embeds"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """


# A parsed image input is either raw pixel values or precomputed embeddings;
# the two variants are told apart by their "type" key.
Phi3VImageInputs = Union[
    Phi3VImagePixelInputs,
    Phi3VImageEmbeddingInputs,
]


124
125
class Phi3ImageEmbeddingBase(nn.Module):
    """Shared feature-extraction step for the Phi-3 image embedders.

    Subclasses are responsible for assigning ``type_feature`` and
    ``img_processor`` before ``get_img_features`` is called.
    """

    def __init__(self) -> None:
        super().__init__()
        # Attribute declarations only (for type checkers); no values are
        # assigned here.
        self.layer_idx: int
        self.type_feature: str
        self.img_processor: CLIPVisionModel

    def get_img_features(self,
                         img_embeds: torch.FloatTensor) -> torch.FloatTensor:
        """Run the vision tower and select the configured token subset.

        Args:
            img_embeds: Input passed unchanged to ``self.img_processor``.

        Returns:
            The ``img_processor`` output without its first token along dim 1
            (presumably the CLS token) when ``type_feature == "patch"``, or
            the full output when ``type_feature == "cls_patch"``.

        Raises:
            NotImplementedError: For any other ``type_feature``.
        """
        type_feature = self.type_feature

        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the img_processor
        img_feature = self.img_processor(img_embeds)

        if type_feature == "patch":
            return img_feature[:, 1:]

        if type_feature == "cls_patch":
            return img_feature

        raise NotImplementedError(
            f"Unsupported type_feature: {type_feature!r}")


# adapted from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_embedding_phi3_v.py
class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
    """Phi3 Image embedding with HD transform.

    Every image arrives as a stack of crops: crop 0 is a global view and
    crops ``1..h_crop*w_crop`` are sub-crops laid out on a grid derived from
    the image size in units of 336 pixels. Patch features of each crop are
    merged 2x2 into the channel dimension, separator embeddings
    (``sub_GN`` per row, ``glb_GN`` between sub and global features) are
    inserted, and the result is projected to the language model hidden size.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig],
                 prefix: str = "") -> None:
        """Build the vision tower, separator embeddings and projector.

        Args:
            config: HF model config providing ``img_processor`` and
                ``embd_layer`` dicts and ``n_embd`` or ``hidden_size``.
            quant_config: Optional quantization config for the vision tower.
            prefix: Weight-name prefix for this module.

        Raises:
            ValueError: If ``embd_layer`` does not enable both
                ``use_hd_transform`` and ``with_learnable_separator`` (the
                only implemented mode).
        """
        super().__init__()

        # n_embed or hidden_size
        hidden_size = config.n_embd if hasattr(
            config, 'n_embd') else config.hidden_size

        self.img_processor = _init_img_processor(
            config, quant_config, prefix=f"{prefix}.img_processor")

        image_dim_out = config.img_processor['image_dim_out']
        self.num_img_tokens = config.img_processor['num_img_tokens']

        self.image_dim_out = image_dim_out

        # global_gn and sub_gn for hd transform, serves as line separator
        self.use_hd_transform = config.embd_layer.get('use_hd_transform',
                                                      False)
        self.with_learnable_separator = config.embd_layer.get(
            'with_learnable_separator', False)
        self.hd_transform_order = config.embd_layer.get(
            'hd_transform_order', 'glb_sub')
        # with_hd_transform and with_learnable_separator must both be
        # enabled. Raise explicitly rather than `assert` so the check is not
        # stripped when Python runs with -O.
        if not (self.use_hd_transform and self.with_learnable_separator):
            raise ValueError(
                "Phi3HDImageEmbedding requires both `use_hd_transform` and "
                "`with_learnable_separator` to be enabled in `embd_layer`, "
                f"got use_hd_transform={self.use_hd_transform} and "
                f"with_learnable_separator={self.with_learnable_separator}")

        # 1024 * 4, merge spatial to channel dimension
        self.glb_GN = nn.Parameter(torch.empty([1, 1, self.image_dim_out * 4]))
        self.sub_GN = nn.Parameter(
            torch.empty([1, 1, 1, self.image_dim_out * 4]))

        # Projector: Linear, then (depth - 1) x [GELU, Linear].
        dim_projection = hidden_size
        depth = 2
        layers = [nn.Linear(image_dim_out * 4, dim_projection)]
        for _ in range(1, depth):
            layers.extend(
                [nn.GELU(),
                 nn.Linear(dim_projection, dim_projection)])
        self.img_projection = nn.Sequential(*layers)

        self.type_feature = config.img_processor.get('type_feature', 'patch')

    def forward(self, pixel_values: torch.FloatTensor,
                image_sizes: torch.Tensor) -> List[torch.Tensor]:
        """
        process image and return vision embeddings.

        pixel_values: (num_images, num_crops, c, h, w)
        image_sizes: per-image (h, w) pairs used to recover the crop grid.
        output: list with one (num_tokens_i, hidden_size) tensor per image;
            num_tokens_i depends on that image's crop grid.
        """
        num_images, num_crops, c, h, w = pixel_values.shape
        pixel_values = pixel_values.flatten(0, 1)
        img_features = self.get_img_features(pixel_values)
        img_features = img_features.reshape(num_images, num_crops, -1,
                                            self.image_dim_out)
        image_features_proj = self.hd_feature_transform(
            img_features, image_sizes)
        return image_features_proj

    def hd_feature_transform(self, image_features, image_sizes):
        """Arrange per-crop features into the HD token sequence and project.

        image_features: (num_images, num_crops+1, 24*24, 1024); crop 0 is the
            global view, the remaining (padded) crops are the sub-crops.
        image_sizes: per-image (h, w); ``h // 336`` x ``w // 336`` is the
            sub-crop grid actually used for that image.

        Returns a list of per-image 2D tensors laid out as
        [sub features, separator, global features] after projection.

        Raises:
            NotImplementedError: If ``hd_transform_order`` is not 'sub_glb'.
        """
        if self.hd_transform_order != 'sub_glb':
            raise NotImplementedError(
                f'hd_transform_order `{self.hd_transform_order}` '
                'not implemented')
        if isinstance(self.img_projection, nn.Sequential):
            target_device = self.img_projection[0].bias.device
            target_dtype = self.img_projection[0].bias.dtype
        else:  # It's a single nn.Linear layer
            target_device = self.img_projection.bias.device
            target_dtype = self.img_projection.bias.dtype

        global_image_features = image_features[:,
                                               0]  # (num_images, 24*24, 1024)
        # global feature can be viewed as a special HD case with num_crops 1x1
        global_image_features_hd = self.reshape_hd_patches_2x2merge(
            global_image_features, 1, 1)
        global_image_features_hd_newline = self.add_image_newline(
            global_image_features_hd)

        batch_image_features_proj = []
        # need a for loop to process each image because of different image sizes
        # (patch arrangement is different for each image)
        for i, img_size in enumerate(image_sizes):
            h, w = img_size
            h_crop = h // 336
            w_crop = w // 336
            num_crops = h_crop * w_crop

            # NOTE: real num_crops is padded
            # (num_crops, 24*24, 1024)
            sub_image_features = image_features[i, 1:1 + num_crops]
            sub_image_features_hd = self.reshape_hd_patches_2x2merge(
                sub_image_features, h_crop, w_crop)
            sub_image_features_hd_newline = self.add_image_newline(
                sub_image_features_hd)

            # [sub features, separator, global features]
            image_embeddings = torch.cat([
                sub_image_features_hd_newline.squeeze(
                    0),  # (h_crop*12*(w_crop*12+1), 4096)
                self.glb_GN.squeeze(0),
                global_image_features_hd_newline[i],
            ])
            img_proj = self.img_projection(
                image_embeddings.to(target_device, target_dtype))
            batch_image_features_proj.append(img_proj)

        return batch_image_features_proj

    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
        """
        image_features: (num_images*num_crops, 24*24, 1024)
        output: (num_images, h_crop*12, w_crop*12, 4096)
        where h_crop*w_crop == num_crops
        """
        N, L, C = image_features.shape
        assert L == 576 and C == 1024 and N % (h_crop * w_crop) == 0
        num_images = N // (h_crop * w_crop)
        H = int(L**0.5)
        image_features_hd = (
            image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
            .reshape(N, -1, 4 * C)  # N, 144, 4096
            .reshape(num_images, h_crop, w_crop, H // 2, H // 2,
                     -1)  # n_img, h_crop, w_crop, 12, 12, 4096
            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
            .reshape(num_images, h_crop * H // 2, w_crop * H // 2,
                     4 * C)  # n_img, h_crop*12, w_crop*12, 4096
        )
        return image_features_hd

    def add_image_newline(self, image_features_hd):
        """
        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
        """
        num_images, h, w, hid_dim = image_features_hd.shape
        # add the newline token to the HD image feature patches
        newline_embeddings = self.sub_GN.expand(num_images, h, -1,
                                                -1)  # (n_img, h, 1, hid_dim)
        image_features_hd_newline = torch.cat(
            [image_features_hd, newline_embeddings],
            dim=2).reshape(num_images, -1, hid_dim)
        return image_features_hd_newline
303
304


305
306
307
def get_max_phi3v_image_tokens(ctx: InputContext) -> int:
    """Return the max number of image tokens a single image can produce.

    Asks the HF image processor for the token count of an image of
    ``MAX_IMAGE_FEATURE_SIZE_WIDTH`` x ``MAX_IMAGE_FEATURE_SIZE_HEIGHT``
    pixels, the aspect ratio that yields the largest feature size.
    """
    processor = ctx.get_hf_processor()
    image_processor = processor.image_processor  # type: ignore

    return image_processor.calc_num_image_tokens_from_image_size(
        width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
        height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
    )
313
314


315
class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
    """Multi-modal processor for Phi-3-vision prompts.

    Wraps the HF processor and expands each image placeholder string in the
    prompt into ``_IMAGE_TOKEN_ID`` repeated once per image feature token.
    """

    def _get_hf_processor(
        self,
        *,
        num_crops: Optional[int] = None,
    ) -> ProcessorMixin:
        """Return the HF processor, passing ``num_crops`` only when given."""
        if num_crops is not None:
            return self.ctx.get_hf_processor(num_crops=num_crops)
        return self.ctx.get_hf_processor()

    def _apply_hf_processor(
        self,
        prompt: str,
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        """Run the HF processor and normalize image placeholder ids.

        Every negative id in ``input_ids`` is overwritten in place with
        ``_IMAGE_TOKEN_ID``.
        """
        processed_outputs = super()._apply_hf_processor(
            prompt, mm_data, mm_processor_kwargs)
        # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
        # which will cause OverflowError when decoding the prompt_ids.
        # Therefore, we need to do an early replacement here
        token_ids = processed_outputs['input_ids']
        token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
        processed_outputs['input_ids'] = token_ids
        return processed_outputs

    def _get_prompt_replacements(
        self,
        mm_items: MultiModalDataItems,
        hf_inputs: BatchFeature,
        mm_processor_kwargs: Mapping[str, object],
    ) -> list[PromptReplacement]:
        """Build one replacement rule per image placeholder string.

        Only the first ``limit_per_prompt['image']`` (default 1) placeholder
        strings from ``hf_processor.img_tokens`` are used. Each is replaced by
        ``_IMAGE_TOKEN_ID`` repeated as many times as the image processor
        computes for the corresponding image's size.
        """
        hf_processor = self._get_hf_processor()
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore
        image_processor = hf_processor.image_processor  # type: ignore

        mm_config = self.ctx.get_mm_config()
        max_images = mm_config.limit_per_prompt.get("image", 1)

        def get_replacement_phi3v(item_idx: int):
            image_size = mm_items.get_image_size(item_idx)
            num_tokens = image_processor.calc_num_image_tokens_from_image_size(
                width=image_size.width,
                height=image_size.height,
            )

            return [_IMAGE_TOKEN_ID] * num_tokens

        return [
            PromptReplacement(
                modality="image",
                target=image_token,
                replacement=get_replacement_phi3v,
            ) for image_token in image_tokens[:max_images]
        ]

    def _get_dummy_mm_inputs(
        self,
        mm_counts: Mapping[str, int],
    ) -> ProcessorInputs:
        """Build dummy inputs with ``mm_counts['image']`` max-size images.

        The images use the max-feature-size dimensions and the prompt
        consists of the matching number of placeholder strings.
        """
        num_images = mm_counts["image"]

        data = dummy_image_for_clip(
            CLIP_VIT_LARGE_PATCH14_336_CONFIG,
            num_images,
            image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
            image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
        )

        hf_processor = self._get_hf_processor()
        image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

        return ProcessorInputs(
            prompt_text="".join(image_tokens[:num_images]),
            mm_data=data,
            mm_processor_kwargs={},
        )
393

394

395
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
    """Phi-3-vision: HD CLIP image encoder + Llama-architecture decoder.

    Image features from ``vision_embed_tokens`` (or user-supplied image
    embeddings) are merged into the token embeddings at positions holding
    ``_IMAGE_TOKEN_ID``, and the result is fed to the wrapped language model.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.image_token_id = _IMAGE_TOKEN_ID

        # Token embedding table. Checkpoints store it as
        # `model.vision_embed_tokens.wte`; if no such weight is loaded it is
        # replaced by the language model's embeddings in `load_weights`.
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "model.embed_tokens"),
        )

        # TODO: Optionally initializes this for supporting input embeddings.
        self.vision_embed_tokens = Phi3HDImageEmbedding(
            config,
            quant_config,
            prefix=maybe_prefix(prefix, "model.vision_embed_tokens"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            # The prefix is empty intentionally because default prefix of
            # LlamaForCausalLM is "model"
            prefix="",
            # We don't directly initialize vLLM's LlamaForCausalLM so we
            # can automatically apply embedding wrapper if this model is
            # initialized as an embedding model
            architectures=["LlamaForCausalLM"],
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Use the language model's sampler if it has one, else a default."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check that each per-image size entry has shape ``(2,)``.

        Returns ``data`` unchanged; raises ``ValueError`` on a bad entry.
        """
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Check each image's pixel values are ``(num_patches, 3, 336, 336)``.

        Returns ``data`` unchanged; raises ``ValueError`` on a bad entry.
        """
        h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[Phi3VImageInputs]:
        """Pop image kwargs and wrap them in a typed input, if present.

        ``pixel_values``/``image_sizes`` take precedence over
        ``image_embeds``. Returns ``None`` when neither kind is given.

        Raises:
            ValueError: If a provided input has an unexpected type or shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return Phi3VImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)))

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeddings. "
                                 f"Got type: {type(image_embeds)}")

            return Phi3VImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _process_image_input(
        self,
        image_input: Phi3VImageInputs,
    ) -> List[torch.Tensor]:
        """Turn a parsed image input into a list of per-image embeddings.

        Precomputed embeddings are split along dim 0 (or returned as-is if
        already a list); pixel values are encoded by ``vision_embed_tokens``.

        Raises:
            ValueError: If precomputed embeddings are neither a list of
                tensors nor a 3D tensor.
        """
        if image_input["type"] == "image_embeds":
            image_data = image_input["data"]
            if is_list_of(image_data, torch.Tensor):
                # it's already a list of tensors
                return image_data
            if len(image_data.shape) == 3:
                # 3D tensor
                return list(torch.unbind(image_data, dim=0))
            raise ValueError(
                "We expect batched 2D tensors; "
                "this can be either a list of 2D tensors or a single 3D tensor."
            )

        assert self.vision_embed_tokens is not None
        image_embeds = self.vision_embed_tokens(image_input["data"],
                                                image_input["image_sizes"])

        return image_embeds

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return per-image embeddings parsed from kwargs, or ``None``."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed tokens and, if given, merge in the image embeddings.

        Image embeddings are placed at positions whose id equals
        ``self.image_token_id``.
        """
        inputs_embeds = self.embed_tokens(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.image_token_id)
        return inputs_embeds

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                kv_caches: List[torch.Tensor],
                attn_metadata: AttentionMetadata,
                intermediate_tensors: Optional[IntermediateTensors] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                **kwargs: object):
        """Run the language model backbone and return its hidden states.

        When neither ``intermediate_tensors`` nor ``inputs_embeds`` is given,
        image inputs in ``kwargs`` are encoded and merged into the token
        embeddings here.
        """
        if intermediate_tensors is not None:
            # Hidden states come from `intermediate_tensors` (pipeline
            # parallelism), so no input embeddings are passed on.
            inputs_embeds = None

        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this condition is for v0 compatibility
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load HF checkpoint weights, remapping names to this module tree.

        Returns the set of parameter names that were loaded. If the
        checkpoint had no token embedding weight, ``embed_tokens`` is tied
        to the language model's embedding module.
        """
        hf_to_vllm_mapper = WeightsMapper(
            orig_to_new_prefix={
                "model.vision_embed_tokens.wte": "embed_tokens",
                "model.vision_embed_tokens.": "vision_embed_tokens.",
                "lm_head.": "language_model.lm_head.",
                "model.": "language_model.model.",
            })

        loader = AutoWeightsLoader(self)
        autoloaded_weights = loader.load_weights(weights,
                                                 mapper=hf_to_vllm_mapper)

        # The HF config doesn't specify whether these are tied,
        # so we detect it this way
        if "embed_tokens.weight" not in autoloaded_weights:
            self.embed_tokens = self.language_model.model.embed_tokens
            autoloaded_weights.add("embed_tokens.weight")
        return autoloaded_weights