llava_next.py 24 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
from abc import abstractmethod
5
6
7
from collections.abc import Iterable, Mapping
from typing import (Final, Literal, Optional, Protocol, TypedDict, TypeVar,
                    Union)
8
9
10

import torch
import torch.nn as nn
11
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
12
13
14
15
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

16
from vllm.config import VllmConfig
17
from vllm.model_executor.sampling_metadata import SamplingMetadata
18
from vllm.multimodal import MULTIMODAL_REGISTRY
19
from vllm.multimodal.inputs import MultiModalFieldConfig
20
from vllm.multimodal.parse import ImageSize
21
from vllm.sequence import IntermediateTensors
22

23
from .clip import CLIPVisionModel
24
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
25
26
from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
                    LlavaDummyInputsBuilder, LlavaLikeConfig,
27
                    LlavaMultiModalProjector, init_vision_tower_for_llava)
28
from .siglip import SiglipVisionModel
29
30
from .utils import (AutoWeightsLoader, WeightsMapper, embed_multimodal,
                    flatten_bn, init_vllm_registered_model, maybe_prefix)
31
32
33
34


class LlavaNextImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
35
    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
36
    """
37
38
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
39

40
41
    Note that `num_patches` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
42
    """
43
44

    image_sizes: NotRequired[torch.Tensor]
45
    """
46
    Shape: `(batch_size * num_images, 2)`
47
48
49

    This should be in `(height, width)` format.
    """
50
51


52
53
54
class LlavaNextImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings for LLaVA-NeXT."""

    # Discriminator separating this input kind from raw pixel values.
    type: Literal["image_embeds"]

    # Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
    #
    # `hidden_size` must match the hidden size of language model backbone.
    data: torch.Tensor


# Either kind of image input; consumers narrow on the `"type"` key.
LlavaNextImageInputs = Union[
    LlavaNextImagePixelInputs,
    LlavaNextImageEmbeddingInputs,
]
63
64


65
66
class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
    """Structural type for HF configs accepted by LLaVA-NeXT models.

    Extends `LlavaLikeConfig` with the anyres grid pinpoints.
    """

    # Candidate `(height, width)` resolutions used to choose the anyres
    # patch grid.
    image_grid_pinpoints: Final[list[list[int]]]
67

68

69
class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
    """Processing metadata for LLaVA-NeXT ("anyres") models.

    Adds LLaVA-NeXT token accounting on top of the base LLaVA info. An image
    contributes the tokens of its base view, plus the unpadded tokens of its
    high-resolution patch grid, plus one newline token per grid row.
    """

    def get_hf_config(self) -> LlavaNextLikeConfig:
        """Return the HF config as a LLaVA-NeXT-like config."""
        return self.ctx.get_hf_config(LlavaNextConfig)

    def get_hf_processor(self, **kwargs: object):
        """Return the HF processor, filling in `patch_size` when missing.

        Some checkpoints omit `patch_size` from `processor_config.json`
        (e.g. E5-V: https://huggingface.co/royokong/e5-v). In that case the
        value is taken from the vision encoder.
        """
        processor = self.ctx.get_hf_processor(LlavaNextProcessor, **kwargs)

        if processor.patch_size is None:
            encoder_patch_size = self.get_vision_encoder_info().get_patch_size()
            processor.patch_size = encoder_patch_size

        return processor

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of placeholder tokens for one image.

        Args:
            image_width: Original image width in pixels.
            image_height: Original image height in pixels.

        Returns:
            Unpadded grid tokens + per-row newline tokens + base-view tokens.
        """
        config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        # Tokens produced by the base (resized) view of the image.
        encoder_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        base_tokens = self._apply_feature_select_strategy(
            config.vision_feature_select_strategy,
            encoder_tokens,
        )

        grid_rows, grid_cols = get_anyres_image_grid_shape(
            image_size=(image_height, image_width),
            grid_pinpoints=config.image_grid_pinpoints,
            patch_size=encoder_info.get_image_size(),
        )

        unpadded_tokens, newline_tokens = self._get_num_unpadded_features(
            original_height=image_height,
            original_width=image_width,
            npatches=encoder_info.get_patch_grid_length(),
            num_patch_height=grid_rows,
            num_patch_width=grid_cols,
        )

        return unpadded_tokens + newline_tokens + base_tokens

    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
        """Count grid features left after removing aspect-ratio padding.

        Returns:
            `(unpadded_features, newline_features)`, where the second value
            equals the number of remaining feature rows.
        """
        grid_h = npatches * num_patch_height
        grid_w = npatches * num_patch_width

        if original_width / original_height > grid_w / grid_h:
            # Image is relatively wider than the grid: trim padded rows.
            fitted_h = int(
                round(original_height * (grid_w / original_width), 7))
            grid_h -= 2 * ((grid_h - fitted_h) // 2)
        else:
            # Image is relatively taller (or equal): trim padded columns.
            fitted_w = int(
                round(original_width * (grid_h / original_height), 7))
            grid_w -= 2 * ((grid_w - fitted_w) // 2)

        return grid_h * grid_w, grid_h

    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the grid pinpoint that yields the most image tokens.

        Raises:
            ValueError: If no pinpoint yields a positive token count.
        """
        config = self.get_hf_config()

        best_count = 0
        best_size: Optional[ImageSize] = None
        for height, width in config.image_grid_pinpoints:
            count = self.get_num_image_tokens(image_width=width,
                                              image_height=height)
            if count > best_count:
                best_count = count
                best_size = ImageSize(width=width, height=height)

        if best_count == 0 or best_size is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return best_size


172
173
174
175
176
177
178
179
180
181
182
183
184
185
# Processing-info type accepted by LLaVA-NeXT multimodal processors.
_I = TypeVar("_I", bound=LlavaNextProcessingInfo)


class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):
    """Shared base for LLaVA-NeXT-style multimodal processors.

    Subclasses must declare how each HF processor output is batched.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Map each HF processor output key to its field config."""
        raise NotImplementedError

186

187
188
class LlavaNextMultiModalProcessor(
        BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]):
    """Multimodal processor for LLaVA-NeXT checkpoints."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare every image-related output as batched per image."""
        image_keys = ("pixel_values", "image_sizes", "image_embeds")
        # One independent config object per key, in the same order as before.
        return {
            key: MultiModalFieldConfig.batched("image")
            for key in image_keys
        }
200
201


202
203
204
@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor,
                                        info=LlavaNextProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
    """LLaVA-NeXT vision-language model.

    Components:
    - a vision tower built by `init_vision_tower_for_llava`;
    - a projector into the text hidden size;
    - a learned `image_newline` embedding;
    - a vLLM-registered language model.
    """

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "model.image_newline": "image_newline",
            "lm_head.": "language_model.lm_head.",
        })

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
        """Return the prompt placeholder for the `i`-th item of `modality`.

        Raises:
            ValueError: If `modality` is not an image modality.
        """
        if modality.startswith("image"):
            return "<image>"

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        vision_feature_layer = config.vision_feature_layer
        # Determine the layer up to which we will initialize the vision tower
        if isinstance(vision_feature_layer, int):
            vision_hidden_size = config.vision_config.hidden_size
            self.feature_sample_layers = None
        # Used for multimodal granite models to control encoder outputs
        elif isinstance(vision_feature_layer, (list, tuple)):
            # Features from several layers are concatenated, so the
            # projector input width scales with the number of layers.
            vision_hidden_size = config.vision_config.hidden_size * len(
                vision_feature_layer)
            self.feature_sample_layers = vision_feature_layer
        else:
            raise TypeError(
                f"vision_layer_feature type: {type(vision_feature_layer)}"
                " is not supported")

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        # Appended after every row of unpadded patch features
        # (see `_merge_image_patch_embeddings`).
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=vision_hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias)

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every per-image size entry has shape `(2,)`.

        Raises:
            ValueError: On the first entry with a different shape.
        """
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _validate_pixel_values(
        self, data: Union[torch.Tensor, list[torch.Tensor]]
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Check that each image's patches are `(3, image_size, image_size)`.

        The leading (patch-count) dimension of each image is not checked,
        because it may vary per image.

        Raises:
            ValueError: On the first image with mismatched trailing dims.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
                raise ValueError(
                    "The expected shape of pixel values per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)

        return data

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        """Pop image kwargs and build a typed, validated image input.

        Pixel values take precedence over image embeddings when both are
        present.

        Returns:
            `None` if neither `pixel_values` nor `image_embeds` is given.

        Raises:
            ValueError: If an input has an unexpected type or shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, (torch.Tensor, list)):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            return LlavaNextImagePixelInputs(
                type="pixel_values",
                pixel_values=self._validate_pixel_values(
                    flatten_bn(pixel_values)),
                image_sizes=self._validate_image_sizes(
                    flatten_bn(image_sizes, concat=True)),
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply the vision feature selection strategy.

        `"default"` drops the first token along dim 1; `"full"` keeps all.

        Raises:
            ValueError: For any other strategy.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Encode stacked patches with the vision tower and select features."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(
            pixel_values, feature_sample_layers=self.feature_sample_layers)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        """Merge one image's per-patch embeddings into a token sequence.

        Args:
            image_size: Original `(height, width)` of the image.
            patch_embeddings: Projected embeddings, one entry per patch along
                dim 0. Entry 0 is the base view; the rest are grid patches.
            strategy: `"flat"`, or a `"spatial..."` strategy. A spatial
                strategy containing `"unpad"` removes aspect-ratio padding
                and appends `image_newline` after each row.

        Raises:
            ValueError: If the strategy is unknown, or the base view's
                token count does not match the vision config.
        """
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = (self.config.vision_config.image_size //
                              self.config.vision_config.patch_size)

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

                # image_aspect_ratio == "anyres"
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                # -> (grid_rows, grid_cols, height, width, hidden)
                other_patch_embeds = other_patch_embeds[:num_patches].view(
                    num_patch_height, num_patch_width, height, width, -1)

                if "unpad" in strategy:
                    # -> (hidden, grid_rows * height, grid_cols * width)
                    other_patch_embeds = (other_patch_embeds.permute(
                        4, 0, 2, 1, 3).contiguous().flatten(1, 2).flatten(2, 3))
                    other_patch_embeds = unpad_image(other_patch_embeds,
                                                     (orig_height, orig_width))
                    # One newline column appended to every feature row.
                    newline_column = self.image_newline[:, None, None].expand(
                        *other_patch_embeds.shape[:-1],
                        1).to(other_patch_embeds.device)
                    other_patch_embeds = torch.cat(
                        (other_patch_embeds, newline_column), dim=-1)
                    # -> (num_tokens, hidden), row-major
                    other_patch_embeds = other_patch_embeds.flatten(
                        1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds.permute(
                        0, 2, 1, 3, 4).contiguous().flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None].to(base_patch_embeds.device)),
                        dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Encode and project pixel patches.

        Returns:
            For a batched tensor input, a tensor shaped
            `(b, num_patches, *projected_feature_shape)`. For a list input,
            a tuple with one tensor per image, split by patch count.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["pixel_values"]

        if isinstance(pixel_values, torch.Tensor):
            b, num_patches, c, h, w = pixel_values.shape
            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
            stacked_image_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixel_values)
            stacked_patch_embeddings = self.multi_modal_projector(
                stacked_image_features)

            return stacked_patch_embeddings.view(
                b, num_patches, *stacked_patch_embeddings.shape[1:])

        # Variable patch counts: encode all patches in one pass, then split.
        num_patches_per_batch = [v.shape[0] for v in pixel_values]
        stacked_pixel_values = torch.cat(pixel_values)
        stacked_image_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixel_values)

        return torch.split(self.multi_modal_projector(stacked_image_features),
                           num_patches_per_batch)

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, list[torch.Tensor]]:
        """Turn a parsed image input into per-image embedding sequences.

        Embedding inputs are returned as-is (wrapped in a list). Pixel
        inputs are encoded, then merged with the `"spatial_unpad"` strategy.
        """
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            # Only pixel inputs reach this point; their data lives under
            # "pixel_values" (there is no "data" key for this input type).
            # Assume each image has the vision tower's native square size.
            batch_size = len(image_input["pixel_values"])
            vision_config = self.config.vision_config
            default_height = default_width = vision_config.image_size
            image_sizes = torch.as_tensor([[default_height, default_width]
                                           for _ in range(batch_size)])

        return [
            self._merge_image_patch_embeddings(image_sizes[i],
                                               patch_features_batch,
                                               strategy="spatial_unpad")
            for i, patch_features_batch in enumerate(patch_embeddings)
        ]

    def get_language_model(self) -> torch.nn.Module:
        """Return the wrapped language model."""
        return self.language_model

    def get_multimodal_embeddings(self,
                                  **kwargs: object) -> MultiModalEmbeddings:
        """Compute image embeddings from kwargs, or `[]` if none are given."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
    ) -> torch.Tensor:
        """Embed `input_ids`, merging multimodal embeddings when provided.

        Multimodal embeddings are merged at `config.image_token_index`
        positions.
        """
        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
            return self.language_model.get_input_embeddings(input_ids)

        inputs_embeds = embed_multimodal(
            input_ids,
            self.config.image_token_index,
            self.language_model.model.get_input_embeddings,
            multimodal_embeddings,
        )
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LlaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, placeholder tokens must be inserted
        before the input reaches the model. The input processor therefore
        expands the single image token (denoted as `32000`) into many
        image tokens, resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by `LlavaNextProcessingInfo.get_num_image_tokens`.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Position indices for `input_ids`.
            intermediate_tensors: Hidden states from the previous pipeline
                stage. When given, `inputs_embeds` is ignored.
            inputs_embeds: Precomputed input embeddings. When omitted (and
                there are no intermediate tensors), they are computed here
                from `input_ids` and the image kwargs.
            pixel_values: The pixels in each grid patch for each input image
                (passed via `kwargs`).
            image_sizes: The original `(height, width)` for each input image
                (passed via `kwargs`).

        Info:
            [LlavaNextImageInputs][]
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load weights, remapping post-v4.52 HF checkpoint prefixes.

        Returns:
            The set of loaded parameter names (from `AutoWeightsLoader`).
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)