aria.py 22.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Iterable, Mapping, Sequence
4
from typing import Annotated, Literal
5
6
7

import torch
import torch.nn as nn
8
9
10
from transformers import AriaConfig, AriaTextConfig, BatchFeature
from transformers.models.aria.modeling_aria import AriaCrossAttention
from transformers.models.aria.processing_aria import AriaProcessor
11

12
from vllm.config import VllmConfig
13
from vllm.config.multimodal import BaseDummyOptions
14
from vllm.distributed import get_tensor_model_parallel_rank
15
from vllm.inputs import MultiModalDataDict
16
from vllm.model_executor.layers.activation import get_act_fn
17
from vllm.model_executor.layers.fused_moe import FusedMoE
18
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
Divakar Verma's avatar
Divakar Verma committed
19
from vllm.model_executor.layers.logits_processor import LogitsProcessor
20
from vllm.model_executor.layers.quantization import QuantizationConfig
Divakar Verma's avatar
Divakar Verma committed
21
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
22
from vllm.model_executor.model_loader.weight_utils import (
23
24
25
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
26
from vllm.multimodal import MULTIMODAL_REGISTRY
27
28
29
30
from vllm.multimodal.inputs import (
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
31
from vllm.multimodal.parse import MultiModalDataItems
32
from vllm.multimodal.processing import (
33
    BaseDummyInputsBuilder,
34
35
36
37
38
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
39
from vllm.sequence import IntermediateTensors
40
from vllm.utils.tensor_schema import TensorSchema, TensorShape
41

42
from .idefics2_vision_model import Idefics2VisionConfig
43
from .idefics2_vision_model import (
44
45
    Idefics2VisionTransformer as Idefics3VisionTransformer,
)
46
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
47
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
48
49
50
51
52
53
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    is_pp_missing_parameter,
    maybe_prefix,
)
54
55


56
class AriaImagePixelInputs(TensorSchema):
    """
    Pixel-value inputs for Aria's vision tower.

    Dimensions:
        - bn: Batch size * number of images (a single leading dimension
          holding every image in the batch)
        - h: Height of each image
        - w: Width of each image

    The channel dimension of `pixel_values` is fixed to 3.
    """

    type: Literal["pixel_values"]

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("bn", 3, "h", "w"),
    ]

    # Optional per-pixel mask with the same spatial size as `pixel_values`;
    # None when the caller supplies no mask.
    pixel_mask: Annotated[
        torch.Tensor | None,
        TensorShape("bn", "h", "w"),
    ]

78

79
80
class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
    """
    Vision encoder used by Aria.

    Reuses the Idefics3 (Idefics2) vision transformer but without the
    LayerNorm applied after the final encoder layer, which Aria omits.
    """

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(
        self,
        config: Idefics2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, quant_config=quant_config, prefix=prefix)
        # The parent creates a post-encoder LayerNorm; Aria does not
        # normalize there, so it is swapped for a pass-through module.
        self.post_layernorm = nn.Identity()

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """
        Load checkpoint tensors into this module.

        Separate q/k/v projection tensors are routed into the fused
        `qkv_proj` parameter with the matching shard id; every other tensor
        is loaded under its own name. `post_layernorm` tensors are skipped.

        Returns:
            The names of the parameters that received a tensor.
        """
        # (checkpoint fragment, shard id) for the fused qkv projection; the
        # first fragment found in a name wins.
        qkv_shards = (
            ("q_proj", "q"),
            ("k_proj", "k"),
            ("v_proj", "v"),
        )
        params = dict(self.named_parameters())
        loaded: set[str] = set()

        for ckpt_name, tensor in weights:
            # post_layernorm was replaced by nn.Identity(), so there is no
            # parameter to receive these tensors.
            if "post_layernorm" in ckpt_name:
                continue

            shard = next(
                ((frag, sid) for frag, sid in qkv_shards if frag in ckpt_name),
                None,
            )
            if shard is None:
                target = ckpt_name
                param = params[target]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, tensor)
            else:
                fragment, shard_id = shard
                target = ckpt_name.replace(fragment, "qkv_proj")
                param = params[target]
                param.weight_loader(param, tensor, shard_id)
            loaded.add(target)

        return loaded


124
class AriaProjectorMLP(nn.Module):
    """
    Feed-forward block of the Aria projector.

    Computes ``linear_out(gelu_new(linear_in(x)))`` using bias-free
    column-parallel (input) and row-parallel (output) linear layers.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        output_dim: int,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.linear_in = ColumnParallelLinear(
            in_features,
            hidden_features,
            bias=False,
            prefix=f"{prefix}.linear_in",
        )
        self.linear_out = RowParallelLinear(
            hidden_features,
            output_dim,
            bias=False,
            prefix=f"{prefix}.linear_out",
        )
        self.act = get_act_fn("gelu_new")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return a pair; only the first element
        # (the projected activations) is used.
        projected = self.linear_in(hidden_states)[0]
        return self.linear_out(self.act(projected))[0]


class AriaProjector(nn.Module):
    """
    Projects ViT outputs into the inputs of the MoE language model.

    A number of learned queries, chosen from the patch count via
    ``config.projector_patch_to_query_dict``, cross-attend to the vision
    features. The result is layer-normalized and passed through a two-layer
    MLP.

    Args:
        config: [AriaConfig](https://huggingface.co/docs/transformers/main/model_doc/aria#transformers.AriaConfig)
            containing projector configuration parameters.
        prefix: Parameter-name prefix forwarded to the feed-forward block.

    Outputs:
        A tensor with the shape of (batch_size, query_number, output_dim)
    """

    def __init__(self, config: AriaConfig, prefix: str = "") -> None:
        super().__init__()

        vision_cfg = config.vision_config
        text_cfg = config.text_config

        self.patch_to_query_dict = config.projector_patch_to_query_dict
        self.in_features = vision_cfg.hidden_size
        self.num_heads = vision_cfg.num_attention_heads
        self.kv_dim = vision_cfg.hidden_size
        self.hidden_features = text_cfg.hidden_size
        self.output_dim = text_cfg.hidden_size

        # Query bank sized for the largest configured query count; forward()
        # slices off the prefix it needs.
        self.query = nn.Parameter(
            torch.empty(
                config.max_value_projector_patch_to_query_dict, self.in_features
            )
        )
        self.cross_attn = AriaCrossAttention(config)
        self.layer_norm = nn.LayerNorm(self.in_features)
        self.feed_forward = AriaProjectorMLP(
            self.in_features,
            self.hidden_features,
            self.output_dim,
            prefix=f"{prefix}.feed_forward",
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        batch_size = x.shape[0]
        num_patches = x.shape[1]

        if num_patches not in self.patch_to_query_dict:
            raise KeyError(
                f"Number of patches {num_patches} not found in "
                "patch_to_query_dict amongst possible values "
                f"{self.patch_to_query_dict.keys()}."
            )
        query_num = self.patch_to_query_dict[num_patches]

        # One copy of the first `query_num` learned queries per batch item.
        queries = self.query[:query_num][None].repeat(batch_size, 1, 1)

        if attn_mask is not None:
            # Replicate the patch mask once per attention head, then
            # broadcast it across every query position.
            per_head_mask = attn_mask.repeat_interleave(self.num_heads, 0)
            attn_mask = per_head_mask.unsqueeze(1).expand(-1, queries.size(1), -1)

        attended = self.cross_attn(x, queries, attn_mask=attn_mask)
        return self.feed_forward(self.layer_norm(attended))


217
class AriaFusedMoE(FusedMoE):
    """FusedMoE whose weight loader understands Aria's pre-packed expert tensors."""

    def weight_loader(
        self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str
    ) -> None:
        """
        Copy a packed Aria expert tensor into ``param``, sharded for TP.

        Aria checkpoints store all experts of a projection in one tensor
        (with the two w13 halves already concatenated), so this override
        replaces FusedMoE's per-expert loader. Loading expert weights with
        quantization is not supported.

        Args:
            param: Destination parameter.
            loaded_weight: For ``"w13"``, shaped
                (num_experts, hidden_size, 2 * moe_intermediate_size); for
                ``"w2"``, shaped (num_experts, moe_intermediate_size,
                hidden_size).
            shard_id: ``"w13"`` or ``"w2"``; other values copy nothing.
        """
        tp_rank = get_tensor_model_parallel_rank()

        if shard_id == "w13":
            if self.tp_size > 1:
                # Shard each half on its own so this rank keeps matching
                # slices of both, then concatenate them back together.
                first_half, second_half = loaded_weight.chunk(2, dim=-1)
                loaded_weight = torch.cat(
                    [
                        first_half.chunk(self.tp_size, dim=-1)[tp_rank],
                        second_half.chunk(self.tp_size, dim=-1)[tp_rank],
                    ],
                    dim=-1,
                )
            param.data.copy_(loaded_weight.transpose(1, 2))
        elif shard_id == "w2":
            if self.tp_size > 1:
                # Split along the intermediate dimension (dim 1).
                loaded_weight = loaded_weight.chunk(self.tp_size, dim=1)[tp_rank]
            param.data.copy_(loaded_weight.transpose(1, 2))


249
class AriaTextMoELayer(nn.Module):
    """
    Mixture of Experts (MoE) Layer for the AriaMoE model.

    This layer implements the MoE mechanism, which routes input tokens to
    different experts based on a routing algorithm, processes them through the
    experts, and then combines the outputs. A dense shared-expert MLP is
    passed to the fused-experts module alongside the routed experts.
    """

    def __init__(
        self,
        config: AriaTextConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        # Router weights: one row of logit weights per routed expert.
        self.router_weight = nn.Parameter(
            torch.empty((self.config.moe_num_experts, self.config.hidden_size))
        )

        # The shared experts are realized as one MLP whose intermediate width
        # is scaled by the number of shared experts. A prefix is passed, as for
        # `experts` below, so quantization configs can resolve this
        # sub-module by its full name.
        self.shared_experts = LlamaMLP(
            config.hidden_size,
            config.intermediate_size * config.moe_num_shared_experts,
            "silu",
            quant_config=quant_config,
            bias=config.mlp_bias,
            prefix=f"{prefix}.shared_experts",
        )

        self.experts = AriaFusedMoE(
            shared_experts=self.shared_experts,
            num_experts=config.moe_num_experts,
            top_k=config.moe_topk,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the MoE Layer.

        Args:
            hidden_states: Input tensor whose last dimension is
                ``hidden_size``.

        Returns:
            torch.Tensor: Output of the fused-experts module for these
            hidden states and their router logits.
        """
        # Router logits: one score per routed expert for every token.
        router_output = torch.nn.functional.linear(hidden_states, self.router_weight)

        return self.experts(hidden_states, router_output)
304
305


306
class AriaTextDecoderLayer(LlamaDecoderLayer):
    """
    Aria decoder layer: a ``LlamaDecoderLayer`` whose dense MLP is swapped
    for a Mixture of Experts layer (``AriaTextMoELayer``).
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config, prefix)

        # Overwrite the MLP built by the parent with Aria's MoE layer.
        self.mlp = AriaTextMoELayer(
            vllm_config.model_config.hf_config,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.mlp",
        )
322
323


324
class AriaTextModel(LlamaModel, SupportsQuant):
    """
    Llama-based text backbone for Aria.

    Same as ``LlamaModel`` except that every decoder layer is an
    ``AriaTextDecoderLayer`` (MoE instead of a dense MLP), and weight loading
    also handles Aria's packed expert tensors.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts.w13_weight": ["experts.fc1.weight"],
        "experts.w2_weight": ["experts.fc2.weight"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            layer_type=AriaTextDecoderLayer,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """
        Load checkpoint tensors, fusing split projections on the fly.

        Adapted from ``LlamaModel.load_weights``, with two extra entries that
        route Aria's packed ``experts.fc1``/``experts.fc2`` weights to the
        fused MoE parameters (shard ids ``"w13"``/``"w2"``).

        Returns:
            Names of the parameters and kv-cache scales that were loaded.
        """
        # (vLLM parameter fragment, checkpoint fragment, shard id). The first
        # matching entry wins, so the order matters.
        stacked_params_mapping = [
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
            ("experts.w13_weight", "experts.fc1.weight", "w13"),
            ("experts.w2_weight", "experts.fc2.weight", "w2"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        # Rotary-embedding tensors in the checkpoint are not loaded; models
        # trained with ColossalAI may include the cos/sin caches.
        ignored_fragments = (
            "rotary_emb.inv_freq",
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )

        for name, loaded_weight in weights:
            if any(fragment in name for fragment in ignored_fragments):
                continue

            scale_name = (
                self.quant_config.get_cache_scale(name)
                if self.quant_config is not None
                else None
            )
            if scale_name:
                # kv-cache quantization scale: used as-is when 0-d, otherwise
                # index 0 of the stored tensor is loaded.
                param = params_dict[scale_name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                if loaded_weight.dim() != 0:
                    loaded_weight = loaded_weight[0]
                loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            for fused_fragment, ckpt_fragment, shard_id in stacked_params_mapping:
                if ckpt_fragment not in name:
                    continue
                name = name.replace(ckpt_fragment, fused_fragment)
                # Skip extra GPTQ biases that have no matching parameter, and
                # parameters owned by another pipeline stage.
                missing_bias = name.endswith(".bias") and name not in params_dict
                if missing_bias or is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked weight: load it under its own name.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remap the name of an FP8 kv-scale, if needed.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None or is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                loader = getattr(param, "weight_loader", default_weight_loader)
                loader(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params


410
411
class AriaProcessingInfo(BaseProcessingInfo):
    """Accessors for Aria's HF config/processor and image-token accounting."""

    def get_hf_config(self):
        return self.ctx.get_hf_config(AriaConfig)

    def get_vision_config(self):
        hf_config = self.get_hf_config()
        return hf_config.vision_config

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(AriaProcessor, **kwargs)

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        # Images are the only supported modality; no numeric limit is set.
        return {"image": None}

    def get_num_image_tokens(self) -> int:
        """Return the largest query count any image can be projected to."""
        query_counts = self.get_hf_config().projector_patch_to_query_dict.values()
        return max(query_counts)


class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
    """Builds placeholder text and images (dummy inputs) for Aria."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image token per requested image."""
        image_token: str = self.info.get_hf_processor().tokenizer.image_token  # type: ignore
        return image_token * mm_counts.get("image", 0)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return square dummy images at the vision tower's configured size."""
        side = self.info.get_vision_config().image_size
        return {
            "image": self._get_dummy_images(
                width=side,
                height=side,
                num_images=mm_counts.get("image", 0),
                overrides=mm_options.get("image"),
            )
        }


460
class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
    """Multimodal processor that expands each image token to a fixed length."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        # Both pixel tensors are batched per image.
        return {
            key: MultiModalFieldConfig.batched("image")
            for key in ("pixel_values", "pixel_mask")
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        # Replace each single image token with `get_num_image_tokens()`
        # copies of itself.
        image_token_id = self.info.get_hf_config().image_token_index
        expanded = [image_token_id] * self.info.get_num_image_tokens()

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=expanded,
            )
        ]
489
490


491
492
493
494
495
@MULTIMODAL_REGISTRY.register_processor(
    AriaMultiModalProcessor,
    info=AriaProcessingInfo,
    dummy_inputs=AriaDummyInputsBuilder,
)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
    """
    Aria model for conditional generation tasks.

    This model combines a vision tower, a multi-modal projector, and a language
    model to perform tasks that involve both image and text inputs.
    """

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            # mapping for original checkpoint
            "language_model.model": "language_model",
            "language_model.lm_head": "lm_head",
        },
        orig_to_new_suffix={
            "router.weight": "router_weight",
        },
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for an item of ``modality``.

        Raises:
            ValueError: If ``modality`` does not start with ``"image"``.
        """
        if modality.startswith("image"):
            return "<|fim_prefix|><|img|><|fim_suffix|>"

        raise ValueError("Only image modality is supported")

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = AriaVisionTransformer(
                config.vision_config,
                quant_config=quant_config,
                # maybe_prefix avoids a leading "." when `prefix` is empty and
                # matches how the other sub-modules are prefixed.
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = AriaProjector(
                config, prefix=maybe_prefix(prefix, "multi_modal_projector")
            )

        with self._mark_language_model(vllm_config):
            self.language_model = AriaTextModel(
                vllm_config=vllm_config.with_hf_config(config.text_config),
                prefix=maybe_prefix(prefix, "language_model.model"),
            )

            self.lm_head = ParallelLMHead(
                config.text_config.vocab_size,
                config.text_config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                config.text_config.vocab_size, scale=logit_scale
            )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> AriaImagePixelInputs | None:
        """Pop image tensors from ``kwargs``; return None if there are none."""
        pixel_values = kwargs.pop("pixel_values", None)
        pixel_mask = kwargs.pop("pixel_mask", None)

        if pixel_values is None:
            return None

        return AriaImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
        )

    def _create_patch_attention_mask(
        self,
        pixel_mask: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """
        Reduce a pixel mask to a per-patch boolean mask.

        ``pixel_mask`` is (bn, h, w); it is tiled into non-overlapping
        ``patch_size`` x ``patch_size`` windows, and a patch is True when any
        pixel in its window is nonzero. Returns None when no mask is given.
        """
        if pixel_mask is None:
            return None

        patch_size = self.vision_tower.config.patch_size
        patches_subgrid = pixel_mask.unfold(
            dimension=1,
            size=patch_size,
            step=patch_size,
        ).unfold(
            dimension=2,
            size=patch_size,
            step=patch_size,
        )
        # The comparison already yields a bool tensor.
        return patches_subgrid.sum(dim=(-1, -2)) > 0

    def _process_image_input(
        self, image_input: AriaImagePixelInputs
    ) -> torch.Tensor:
        """Encode images and project them to language-model embeddings."""
        pixel_values = image_input["pixel_values"]
        pixel_mask = image_input["pixel_mask"]

        patch_attention_mask = self._create_patch_attention_mask(pixel_mask)

        image_outputs = self.vision_tower(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
        )
        image_attn_mask = None
        if patch_attention_mask is not None:
            # The projector takes the inverted mask: True marks patches that
            # contain no valid pixels.
            flattened_mask = patch_attention_mask.flatten(1)
            image_attn_mask = torch.logical_not(flattened_mask)

        return self.multi_modal_projector(image_outputs, image_attn_mask)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return projected image embeddings, or [] when no images are given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        multimodal_embeddings = self._process_image_input(image_input)
        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        # When intermediate tensors are given, explicit embeddings are
        # ignored and the language model consumes the intermediate tensors.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        # Checkpoint names are renamed through `hf_to_vllm_mapper` before
        # dispatch to sub-modules.
        # NOTE(review): the loaded-name set returned by
        # AutoWeightsLoader.load_weights is discarded, so this returns None;
        # confirm whether callers should receive it before changing.
        loader = AutoWeightsLoader(self)
        loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)