aria.py 22.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from collections.abc import Iterable, Mapping, Sequence
4
from typing import Annotated, Literal
5
6
7

import torch
import torch.nn as nn
8
9
10
from transformers import AriaConfig, AriaTextConfig, BatchFeature
from transformers.models.aria.modeling_aria import AriaCrossAttention
from transformers.models.aria.processing_aria import AriaProcessor
11

12
from vllm.config import VllmConfig
13
from vllm.config.multimodal import BaseDummyOptions
14
from vllm.distributed import get_tensor_model_parallel_rank
15
from vllm.inputs import MultiModalDataDict
16
from vllm.model_executor.layers.activation import get_act_fn
17
18
19
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
)
20
from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
Divakar Verma's avatar
Divakar Verma committed
21
from vllm.model_executor.layers.logits_processor import LogitsProcessor
22
from vllm.model_executor.layers.quantization import QuantizationConfig
Divakar Verma's avatar
Divakar Verma committed
23
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
24
from vllm.model_executor.model_loader.weight_utils import (
25
26
27
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
28
from vllm.multimodal import MULTIMODAL_REGISTRY
29
30
31
32
from vllm.multimodal.inputs import (
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
33
from vllm.multimodal.parse import MultiModalDataItems
34
from vllm.multimodal.processing import (
35
    BaseDummyInputsBuilder,
36
37
38
39
40
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
)
41
from vllm.sequence import IntermediateTensors
42
from vllm.utils.tensor_schema import TensorSchema, TensorShape
43

44
from .idefics2_vision_model import Idefics2VisionConfig
45
from .idefics2_vision_model import (
46
47
    Idefics2VisionTransformer as Idefics3VisionTransformer,
)
48
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsQuant
49
from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
50
51
52
53
54
55
from .utils import (
    AutoWeightsLoader,
    WeightsMapper,
    is_pp_missing_parameter,
    maybe_prefix,
)
56
57


58
class AriaImagePixelInputs(TensorSchema):
    """
    Pixel inputs for Aria images, one row per image.

    Dimensions:
        - bn: Batch size * number of images (images are flattened into a
          single leading dimension)
        - h: Height of each image
        - w: Width of each image

    The channel dimension of `pixel_values` is fixed to 3.
    `pixel_mask` is optional; when present it shares the (bn, h, w) layout of
    the spatial dimensions of `pixel_values`.
    """

    type: Literal["pixel_values"]

    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("bn", 3, "h", "w"),
    ]

    pixel_mask: Annotated[
        torch.Tensor | None,
        TensorShape("bn", "h", "w"),
    ]

80

81
82
class AriaVisionTransformer(Idefics3VisionTransformer, SupportsQuant):
    """Idefics3-style vision encoder used by Aria, minus the final LayerNorm.

    The q/k/v projections live in a single fused ``qkv_proj`` parameter;
    ``packed_modules_mapping`` records which checkpoint projections it packs.
    """

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(
        self,
        config: Idefics2VisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__(config, quant_config=quant_config, prefix=prefix)
        # The parent class applies a LayerNorm after the last encoder layer;
        # Aria does not, so the inherited module is swapped for a no-op.
        self.post_layernorm = nn.Identity()

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Copy checkpoint tensors into this encoder's parameters.

        Checkpoint ``q_proj``/``k_proj``/``v_proj`` entries are routed into the
        fused ``qkv_proj`` parameter with shard ids ``"q"``/``"k"``/``"v"``.
        Entries containing ``post_layernorm`` are skipped because that module
        is an ``nn.Identity`` here.

        Returns:
            Names of the parameters that were written.
        """
        # (checkpoint substring, fused parameter substring, shard id); the
        # first entry whose substring occurs in a name is the one used.
        fused_shards = (
            ("q_proj", "qkv_proj", "q"),
            ("k_proj", "qkv_proj", "k"),
            ("v_proj", "qkv_proj", "v"),
        )
        named_params = dict(self.named_parameters())
        written: set[str] = set()

        for ckpt_name, tensor in weights:
            # NOTE: post_layernorm is not used in Aria
            if "post_layernorm" in ckpt_name:
                continue

            hit = next(
                (entry for entry in fused_shards if entry[0] in ckpt_name), None
            )
            if hit is None:
                # Plain (non-fused) parameter: same name as in the checkpoint.
                target_name = ckpt_name
                target = named_params[target_name]
                load = getattr(target, "weight_loader", default_weight_loader)
                load(target, tensor)
            else:
                src, dst, shard = hit
                target_name = ckpt_name.replace(src, dst)
                target = named_params[target_name]
                target.weight_loader(target, tensor, shard)
            written.add(target_name)

        return written


126
class AriaProjectorMLP(nn.Module):
    """Feed-forward head of the Aria projector.

    Computes ``linear_out(gelu_new(linear_in(x)))``. ``linear_in`` is
    column-parallel, ``linear_out`` is row-parallel, and neither has a bias.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        output_dim: int,
        prefix: str = "",
    ) -> None:
        super().__init__()

        in_prefix = f"{prefix}.linear_in"
        out_prefix = f"{prefix}.linear_out"
        self.linear_in = ColumnParallelLinear(
            in_features, hidden_features, bias=False, prefix=in_prefix
        )
        self.linear_out = RowParallelLinear(
            hidden_features, output_dim, bias=False, prefix=out_prefix
        )
        self.act = get_act_fn("gelu_new")

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply in-projection, activation, then out-projection.

        The parallel linear layers return ``(output, bias)`` pairs; only the
        output is kept.
        """
        projected, _ = self.linear_in(hidden_states)
        activated = self.act(projected)
        result, _ = self.linear_out(activated)
        return result


class AriaProjector(nn.Module):
    """
    A projection module with one cross attention layer and one FFN layer that
    maps vision-tower patch features to the language model's hidden size.

    A number of learned queries, selected by the input patch count through
    ``patch_to_query_dict``, cross-attend to the patch features. The result
    is layer-normed and passed through ``AriaProjectorMLP``.

    Args:
        config: [AriaConfig](https://huggingface.co/docs/transformers/main/model_doc/aria#transformers.AriaConfig)
            containing projector configuration parameters.

    Outputs:
        A tensor with the shape of (batch_size, query_number, output_dim)
    """

    def __init__(self, config: AriaConfig, prefix: str = "") -> None:
        super().__init__()

        vision_cfg = config.vision_config
        text_cfg = config.text_config

        # Maps number of patches -> number of learned queries to use.
        self.patch_to_query_dict = config.projector_patch_to_query_dict
        self.in_features = vision_cfg.hidden_size
        self.num_heads = vision_cfg.num_attention_heads
        self.kv_dim = vision_cfg.hidden_size
        self.hidden_features = text_cfg.hidden_size
        self.output_dim = text_cfg.hidden_size

        # One row per learned query; forward() uses a prefix of these rows.
        num_query_rows = config.max_value_projector_patch_to_query_dict
        self.query = nn.Parameter(torch.empty(num_query_rows, self.in_features))

        self.cross_attn = AriaCrossAttention(config)
        self.layer_norm = nn.LayerNorm(self.in_features)
        self.feed_forward = AriaProjectorMLP(
            self.in_features,
            self.hidden_features,
            self.output_dim,
            prefix=f"{prefix}.feed_forward",
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Project patch features ``x`` into language-model space.

        Args:
            x: Vision features. Dim 0 is the batch and dim 1 is the number of
                patches, which must be a key of ``patch_to_query_dict``.
            attn_mask: Optional per-patch mask. NOTE(review): presumably
                (batch, num_patches), judging by how it is expanded below;
                confirm with callers.

        Raises:
            KeyError: If the patch count has no entry in ``patch_to_query_dict``.
        """
        batch_size = x.shape[0]
        num_patches = x.shape[1]

        if num_patches not in self.patch_to_query_dict:
            raise KeyError(
                f"Number of patches {num_patches} not found in "
                "patch_to_query_dict amongst possible values "
                f"{self.patch_to_query_dict.keys()}."
            )
        query_num = self.patch_to_query_dict[num_patches]

        # Tile the selected learned queries across the batch.
        selected = self.query[:query_num]
        queries = selected.unsqueeze(0).repeat(batch_size, 1, 1)

        if attn_mask is not None:
            # Repeat each sample's mask once per attention head (consecutive
            # rows), then broadcast it over every query position.
            per_head_mask = attn_mask.repeat_interleave(self.num_heads, 0)
            attn_mask = per_head_mask.unsqueeze(1).expand(-1, queries.size(1), -1)

        attended = self.cross_attn(x, queries, attn_mask=attn_mask)
        return self.feed_forward(self.layer_norm(attended))


219
class AriaFusedMoE(FusedMoE):
    """FusedMoE whose weight loader understands Aria's pre-packed experts.

    Aria checkpoints already pack all experts of a layer into one tensor, so
    each call copies a whole expert stack, sliced for this tensor-parallel
    rank.
    """

    def weight_loader(
        self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_id: str
    ) -> None:
        """Copy this rank's slice of a packed expert tensor into ``param``.

        Overrides the FusedMoE weight loader to handle the expert weights in
        the Aria model, which are already packed with experts. The two halves
        of the first projection are kept together for each expert.

        Note: Loading expert weights with quantization is not supported.

        Args:
            param: Destination parameter.
            loaded_weight: Checkpoint tensor holding all experts.
            shard_id: ``"w13"`` for the packed first projection or ``"w2"``
                for the second projection.

        Raises:
            ValueError: If ``shard_id`` is neither ``"w13"`` nor ``"w2"``.
                Raising here prevents ``param`` from silently staying
                unloaded.
        """
        tp_rank = get_tensor_model_parallel_rank()
        if shard_id == "w13":
            # the shape of loaded_weight is
            # (num_experts, hidden_size, 2 * moe_intermediate_size)
            if self.tp_size > 1:
                # Split the last dim into its two halves, keep this rank's
                # contiguous slice of each half, and re-concatenate them in
                # the original half order.
                up, gate = loaded_weight.chunk(2, dim=-1)
                up_current_rank = up.chunk(self.tp_size, dim=-1)[tp_rank]
                gate_current_rank = gate.chunk(self.tp_size, dim=-1)[tp_rank]
                up_and_gate = torch.cat(
                    [up_current_rank, gate_current_rank], dim=-1
                ).transpose(1, 2)
                param.data.copy_(up_and_gate)
            else:
                param.data.copy_(loaded_weight.transpose(1, 2))
        elif shard_id == "w2":
            # the shape of loaded_weight is
            # (num_experts, moe_intermediate_size, hidden_size)
            if self.tp_size > 1:
                # Row-parallel: shard along the intermediate dimension.
                down_current_rank = loaded_weight.chunk(self.tp_size, dim=1)[tp_rank]
                param.data.copy_(down_current_rank.transpose(1, 2))
            else:
                param.data.copy_(loaded_weight.transpose(1, 2))
        else:
            raise ValueError(
                f"Unsupported shard_id {shard_id!r} for AriaFusedMoE; "
                "expected 'w13' or 'w2'."
            )


251
class AriaTextMoELayer(nn.Module):
    """
    Mixture of Experts (MoE) Layer for the AriaMoE model.

    This layer implements the MoE mechanism, which routes input tokens to
    different experts based on a routing algorithm, processes them through the
    experts, and then combines the outputs.

    Routing logits come from a bias-free linear map (``router_weight``). A
    dense ``LlamaMLP`` acting as the shared experts is handed to
    ``AriaFusedMoE`` through its ``shared_experts=`` argument.
    """

    def __init__(
        self,
        config: AriaTextConfig,
        quant_config: QuantizationConfig | None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config

        # One routing logit per routed expert: (moe_num_experts, hidden_size).
        self.router_weight = nn.Parameter(
            torch.empty((self.config.moe_num_experts, self.config.hidden_size))
        )

        # The shared experts are fused into a single MLP whose intermediate
        # size is scaled by the number of shared experts. The prefix gives
        # its layers fully-qualified names, consistent with `experts` below.
        self.shared_experts = LlamaMLP(
            config.hidden_size,
            config.intermediate_size * config.moe_num_shared_experts,
            "silu",
            quant_config=quant_config,
            bias=config.mlp_bias,
            prefix=f"{prefix}.shared_experts",
        )

        self.experts = AriaFusedMoE(
            shared_experts=self.shared_experts,
            num_experts=config.moe_num_experts,
            top_k=config.moe_topk,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the MoE Layer.

        Args:
            hidden_states: Input tensor whose last dimension is
                ``hidden_size``. NOTE(review): the leading dims are probably
                flattened tokens, i.e. (num_tokens, hidden_size), rather than
                (batch_size, sequence_length, hidden_size); confirm with
                callers.

        Returns:
            torch.Tensor: Output tensor after passing through the MoE layer.
        """
        router_output = torch.nn.functional.linear(hidden_states, self.router_weight)

        return self.experts(hidden_states, router_output)
306
307


308
class AriaTextDecoderLayer(LlamaDecoderLayer):
    """
    Llama decoder layer variant used by Aria, in which the layer's MLP is an
    `AriaTextMoELayer` (Mixture of Experts) instead of a dense MLP.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__(vllm_config, prefix)

        # The parent constructor sets `self.mlp`; overwrite it with the MoE
        # layer built from the same HF config and quantization settings.
        self.mlp = AriaTextMoELayer(
            vllm_config.model_config.hf_config,
            quant_config=vllm_config.quant_config,
            prefix=f"{prefix}.mlp",
        )
324
325


326
class AriaTextModel(LlamaModel, SupportsQuant):
    """
    Custom LlamaModel for the AriaMoE model which modifies the standard
    LlamaModel by replacing the `LlamaDecoderLayer` with
    `AriaTextDecoderLayer`, a decoder layer whose MLP is a Mixture of Experts.
    """

    # Fused vLLM parameter -> checkpoint entries packed into it.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts.w13_weight": ["experts.fc1.weight"],
        "experts.w2_weight": ["experts.fc2.weight"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(
            vllm_config=vllm_config, prefix=prefix, layer_type=AriaTextDecoderLayer
        )

    # Adapted from LlamaModel.load_weights with the modification of adding
    # the expert weights mapping to `stacked_params_mapping`
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's (possibly fused) parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names, after renaming to the fused vLLM
            names, that received data.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
            # Expert tensors arrive already packed across experts; the
            # "w13"/"w2" shard ids select the packed-expert loading path.
            ("experts.w13_weight", "experts.fc1.weight", "w13"),
            ("experts.w2_weight", "experts.fc2.weight", "w2"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if self.quant_config is not None and (
                scale_name := self.quant_config.get_cache_scale(name)
            ):
                # Loading kv cache quantization scales
                param = params_dict[scale_name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # A 0-d scale is used as-is; otherwise only the first entry
                # along dim 0 is loaded.
                loaded_weight = (
                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
                )
                weight_loader(param, loaded_weight)
                loaded_params.add(scale_name)
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # NOTE: each `continue` below moves on to the next mapping
                # entry with `name` already rewritten. No later entry in this
                # mapping matches a rewritten name, so control reaches the
                # `else` branch, which applies its own skip checks.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params


412
413
class AriaProcessingInfo(BaseProcessingInfo):
    """Config, processor and token-count lookups for Aria multimodal inputs."""

    def get_hf_config(self):
        """Return the model's ``AriaConfig`` from the processing context."""
        ctx = self.ctx
        return ctx.get_hf_config(AriaConfig)

    def get_vision_config(self):
        """Return the ``vision_config`` section of the Aria config."""
        hf_config = self.get_hf_config()
        return hf_config.vision_config

    def get_hf_processor(self, **kwargs: object):
        """Return an ``AriaProcessor``, forwarding ``kwargs`` to the context."""
        processor = self.ctx.get_hf_processor(AriaProcessor, **kwargs)
        return processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """Only images are supported (a ``None`` limit presumably means uncapped)."""
        limits: dict[str, int | None] = {"image": None}
        return limits

    def get_num_image_tokens(self) -> int:
        """Return the largest query count in ``projector_patch_to_query_dict``."""
        patch_to_query = self.get_hf_config().projector_patch_to_query_dict
        return max(patch_to_query.values())


class AriaDummyInputsBuilder(BaseDummyInputsBuilder[AriaProcessingInfo]):
    """Builds dummy text and image inputs for Aria."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one image placeholder token per requested image."""
        image_count = mm_counts.get("image", 0)
        tokenizer = self.info.get_hf_processor().tokenizer
        image_token: str = tokenizer.image_token  # type: ignore
        return image_token * image_count

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        """Return square dummy images sized to the vision config's image_size."""
        side = self.info.get_vision_config().image_size
        return {
            "image": self._get_dummy_images(
                width=side,
                height=side,
                num_images=mm_counts.get("image", 0),
                overrides=mm_options.get("image"),
            )
        }


462
class AriaMultiModalProcessor(BaseMultiModalProcessor[AriaProcessingInfo]):
    """Multimodal processor that expands each Aria image token in the prompt."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare ``pixel_values`` and ``pixel_mask`` as batched per image."""
        return {
            field: MultiModalFieldConfig.batched("image")
            for field in ("pixel_values", "pixel_mask")
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace each image token with ``get_num_image_tokens()`` copies of it."""
        image_token_id = self.info.get_hf_config().image_token_index
        num_image_tokens = self.info.get_num_image_tokens()
        expanded = [image_token_id] * num_image_tokens

        return [
            PromptReplacement(
                modality="image",
                target=[image_token_id],
                replacement=expanded,
            )
        ]
491
492


493
494
495
496
497
@MULTIMODAL_REGISTRY.register_processor(
    AriaMultiModalProcessor,
    info=AriaProcessingInfo,
    dummy_inputs=AriaDummyInputsBuilder,
)
class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
    """
    Aria model for conditional generation tasks.

    This model combines a vision tower, a multi-modal projector, and a language
    model to perform tasks that involve both image and text inputs.
    """

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            # mapping for original checkpoint
            "language_model.model": "language_model",
            "language_model.lm_head": "lm_head",
        },
        orig_to_new_suffix={
            "router.weight": "router_weight",
        },
    )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder for one image.

        Raises:
            ValueError: For any modality not starting with "image".
        """
        if modality.startswith("image"):
            return "<|fim_prefix|><|img|><|fim_suffix|>"

        raise ValueError("Only image modality is supported")

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = AriaVisionTransformer(
                config.vision_config,
                quant_config=quant_config,
                # Built with maybe_prefix like every other submodule, so an
                # empty `prefix` does not produce a ".vision_tower" name.
                prefix=maybe_prefix(prefix, "vision_tower"),
            )
            self.multi_modal_projector = AriaProjector(
                config, prefix=maybe_prefix(prefix, "multi_modal_projector")
            )

        with self._mark_language_model(vllm_config):
            self.language_model = AriaTextModel(
                vllm_config=vllm_config.with_hf_config(config.text_config),
                prefix=maybe_prefix(prefix, "language_model.model"),
            )

            self.lm_head = ParallelLMHead(
                config.text_config.vocab_size,
                config.text_config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

            logit_scale = getattr(config, "logit_scale", 1.0)
            self.logits_processor = LogitsProcessor(
                config.text_config.vocab_size, scale=logit_scale
            )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> AriaImagePixelInputs | None:
        """Pop ``pixel_values``/``pixel_mask`` from kwargs into a schema object.

        Returns None when no ``pixel_values`` were supplied.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        pixel_mask = kwargs.pop("pixel_mask", None)

        if pixel_values is None:
            return None

        return AriaImagePixelInputs(
            type="pixel_values",
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
        )

    def _create_patch_attention_mask(
        self,
        pixel_mask: torch.Tensor | None,
    ) -> torch.Tensor | None:
        """Reduce a per-pixel mask to a per-patch boolean mask.

        The mask is tiled into non-overlapping ``patch_size`` x ``patch_size``
        windows over dims 1 and 2. A patch is True when any pixel in its
        window is non-zero. Returns None when ``pixel_mask`` is None.
        """
        if pixel_mask is None:
            return None

        patch_size = self.vision_tower.config.patch_size
        patches_subgrid = pixel_mask.unfold(
            dimension=1,
            size=patch_size,
            step=patch_size,
        ).unfold(
            dimension=2,
            size=patch_size,
            step=patch_size,
        )
        # The comparison already yields a bool tensor.
        return patches_subgrid.sum(dim=(-1, -2)) > 0

    def _process_image_input(self, image_input: AriaImagePixelInputs) -> torch.Tensor:
        """Run the vision tower and projector on validated image inputs.

        Returns the projector output, a single tensor of image embeddings.
        """
        pixel_values = image_input["pixel_values"]
        pixel_mask = image_input["pixel_mask"]

        patch_attention_mask = self._create_patch_attention_mask(pixel_mask)

        image_outputs = self.vision_tower(
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
        )
        image_attn_mask = None
        if patch_attention_mask is not None:
            # Flatten the patch grid and invert it, so True marks patches
            # without any valid pixel.
            flattened_mask = patch_attention_mask.flatten(1)
            image_attn_mask = torch.logical_not(flattened_mask)

        return self.multi_modal_projector(image_outputs, image_attn_mask)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Return image embeddings, or an empty list when no images are given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
        multimodal_embeddings = self._process_image_input(image_input)
        return multimodal_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the language model.

        When ``intermediate_tensors`` is provided, ``inputs_embeds`` is
        ignored.
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Load weights, renaming checkpoint keys via ``hf_to_vllm_mapper``."""
        loader = AutoWeightsLoader(self)
        loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)