llava.py 23.9 KB
Newer Older
1
from functools import cached_property
2
from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
3
                    Tuple, TypedDict, Union)
4
5

import torch
6
import torch.nn as nn
7
from PIL import Image
8
from transformers import (CLIPVisionConfig, LlavaConfig, PixtralVisionConfig,
9
                          PretrainedConfig, SiglipVisionConfig)
10
11

from vllm.attention import AttentionMetadata
12
from vllm.config import VllmConfig
13
14
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
                         InputContext)
15
from vllm.model_executor.layers.activation import get_act_fn
16
17
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               RowParallelLinear)
18
from vllm.model_executor.layers.quantization import QuantizationConfig
Joe Runde's avatar
Joe Runde committed
19
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
20
from vllm.model_executor.sampling_metadata import SamplingMetadata
21
from vllm.multimodal import MULTIMODAL_REGISTRY
22
from vllm.multimodal.inputs import NestedTensors
23
from vllm.sequence import IntermediateTensors
24
from vllm.utils import is_list_of
25

26
27
28
from .clip import (CLIPVisionModel, dummy_image_for_clip,
                   dummy_seq_data_for_clip, get_max_clip_image_tokens,
                   input_processor_for_clip)
29
from .interfaces import SupportsMultiModal, SupportsPP
30
31
32
33
from .pixtral import (PixtralHFVisionModel, dummy_image_for_pixtral_hf,
                      dummy_seq_data_for_pixtral_hf,
                      get_max_pixtral_hf_image_tokens,
                      input_processor_for_pixtral_hf)
34
35
36
from .siglip import (SiglipVisionModel, dummy_image_for_siglip,
                     dummy_seq_data_for_siglip, get_max_siglip_image_tokens,
                     input_processor_for_siglip)
37
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
38
                    maybe_prefix, merge_multimodal_embeddings)
39
40


41
42
class LlavaImagePixelInputs(TypedDict):
    """Raw image pixels that still have to be encoded by the vision tower."""

    type: Literal["pixel_values"]
    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape: `(batch_size * num_images, num_channels, height, width)`

    Note that `height` or `width` may be different per batch and image,
    in which case the data is passed as a list instead of a batched tensor.
    """
50
51
52
53
54


class LlavaImageEmbeddingInputs(TypedDict):
    """Precomputed image embeddings that skip the vision tower and projector.
    """

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of the language model backbone.
    """


# Either raw pixel inputs or precomputed image embeddings; consumers branch
# on the ``type`` key to tell the two variants apart.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]


64
65
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP mapping vision features into the text hidden size.

    Computes ``linear_2(act(linear_1(x)))``. ``linear_1`` is column-parallel
    and ``linear_2`` is row-parallel.
    """

    def __init__(self,
                 vision_hidden_size: int,
                 text_hidden_size: int,
                 projector_hidden_act: str,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        """
        Args:
            vision_hidden_size: Feature size produced by the vision tower.
            text_hidden_size: Hidden size of the language model; used as
                both the intermediate and the output size.
            projector_hidden_act: Name of the activation, resolved with
                ``get_act_fn``.
            quant_config: Optional quantization config for both linears.
            prefix: Parameter-name prefix for the sub-layers.
        """
        super().__init__()

        self.linear_1 = ColumnParallelLinear(vision_hidden_size,
                                             text_hidden_size,
                                             bias=True,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.linear_1")
        self.act = get_act_fn(projector_hidden_act)
        self.linear_2 = RowParallelLinear(text_hidden_size,
                                          text_hidden_size,
                                          bias=True,
                                          quant_config=quant_config,
                                          prefix=f"{prefix}.linear_2")

    def forward(self, image_features: torch.Tensor) -> torch.Tensor:
        # The parallel linear layers return a pair; only the first element
        # (the output) is used here.
        hidden_states, _ = self.linear_1(image_features)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.linear_2(hidden_states)
        return hidden_states


93
94
95
96
97
def get_max_llava_image_tokens(ctx: InputContext):
    """Return the maximum number of image tokens per image for this model.

    The raw token count comes from the helper for the vision encoder family
    selected by ``hf_config.vision_config``. Under the ``"default"`` feature
    selection strategy the first feature position is dropped (see
    ``LlavaForConditionalGeneration._select_image_features``), so the count
    is reduced by one.

    Raises:
        NotImplementedError: If the vision config type is not supported.
        ValueError: If ``vision_feature_select_strategy`` is unknown.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    if isinstance(vision_config, CLIPVisionConfig):
        num_image_tokens = get_max_clip_image_tokens(vision_config)
    elif isinstance(vision_config, SiglipVisionConfig):
        num_image_tokens = get_max_siglip_image_tokens(vision_config)
    elif isinstance(vision_config, PixtralVisionConfig):
        num_image_tokens = get_max_pixtral_hf_image_tokens(vision_config)
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    strategy = hf_config.vision_feature_select_strategy
    if strategy == "default":
        return num_image_tokens - 1
    if strategy == "full":
        return num_image_tokens
    raise ValueError(f"Unexpected select feature strategy: {strategy}")
114
115


116
117
def dummy_data_for_llava(ctx: InputContext, seq_len: int,
                         mm_counts: Mapping[str, int]):
    """Build dummy sequence and image data for a LLaVA model.

    Registered via ``INPUT_REGISTRY.register_dummy_data``. The
    encoder-specific helpers are chosen from the type of
    ``hf_config.vision_config``.

    Args:
        ctx: Input context providing the HF ``LlavaConfig``.
        seq_len: Length of the dummy token sequence.
        mm_counts: Number of items per modality; only ``"image"`` is read.

    Returns:
        ``DummyData`` built from the dummy sequence, the dummy images and the
        ranges returned by the encoder-specific sequence helper.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config
    num_images = mm_counts["image"]

    image_feature_size = get_max_llava_image_tokens(ctx)

    # Select the (sequence builder, image builder) pair for the encoder
    # family. All three pairs share the same calling convention.
    if isinstance(vision_config, CLIPVisionConfig):
        dummy_seq_data_fn = dummy_seq_data_for_clip
        dummy_image_fn = dummy_image_for_clip
    elif isinstance(vision_config, SiglipVisionConfig):
        dummy_seq_data_fn = dummy_seq_data_for_siglip
        dummy_image_fn = dummy_image_for_siglip
    elif isinstance(vision_config, PixtralVisionConfig):
        dummy_seq_data_fn = dummy_seq_data_for_pixtral_hf
        dummy_image_fn = dummy_image_for_pixtral_hf
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    seq_data, ranges = dummy_seq_data_fn(
        vision_config,
        seq_len,
        num_images,
        image_token_id=hf_config.image_token_index,
        image_feature_size_override=image_feature_size,
    )
    mm_data = dummy_image_fn(vision_config, num_images)
    return DummyData(seq_data, mm_data, ranges)


162
163
def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs):
    """Prepare decoder inputs that carry image data for a LLaVA model.

    Inputs without ``multi_modal_data["image"]`` are returned unchanged.
    Otherwise the per-image feature size is determined from the image data
    and the work is delegated to the vision-encoder specific input processor.

    Raises:
        TypeError: If the image data is not a PIL image, a tensor, or a list
            of either.
        NotImplementedError: If the vision config type is not supported.
    """
    multi_modal_data = inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return inputs

    model_config = ctx.model_config
    hf_config = ctx.get_hf_config(LlavaConfig)
    vision_config = hf_config.vision_config

    image_data = multi_modal_data["image"]
    if isinstance(image_data, Image.Image):
        image_feature_size = get_max_llava_image_tokens(ctx)
    elif is_list_of(image_data, Image.Image):
        max_tokens = get_max_llava_image_tokens(ctx)
        image_feature_size = [max_tokens] * len(image_data)
    elif isinstance(image_data, torch.Tensor):
        # Unpacking requires a 3-D tensor; the middle dimension is the
        # per-image feature size (assumed layout: images x features x hidden).
        _, image_feature_size, _ = image_data.shape
    elif is_list_of(image_data, torch.Tensor):
        # NOTE(review): assumes dim 1 of each item is the feature size.
        image_feature_size = [item.shape[1] for item in image_data]
    else:
        raise TypeError(f"Invalid image type: {type(image_data)}")

    if isinstance(vision_config, CLIPVisionConfig):
        return input_processor_for_clip(
            model_config,
            vision_config,
            inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    if isinstance(vision_config, SiglipVisionConfig):
        return input_processor_for_siglip(
            model_config,
            vision_config,
            inputs,
            image_token_id=hf_config.image_token_index,
            image_feature_size_override=image_feature_size,
        )
    if isinstance(vision_config, PixtralVisionConfig):
        # We ignore image_feature_size_override since we have non-uniform
        # image sizes for Pixtral
        return input_processor_for_pixtral_hf(
            model_config,
            vision_config,
            inputs,
            image_token_id=hf_config.image_token_index,
        )

    msg = f"Unsupported vision config: {type(vision_config)}"
    raise NotImplementedError(msg)


214
215
class LlavaLikeConfig(Protocol):
    """Structural type for configs that wrap a vision encoder config.

    This is the config type accepted by the vision-tower helpers in this
    module.
    """

    # Config of the vision encoder (CLIP, SigLIP or Pixtral-HF are handled).
    vision_config: PretrainedConfig
    # Index, or indices, of the vision encoder layer(s) whose outputs are
    # used as image features; negative values count back from the end.
    vision_feature_layer: Union[int, List[int]]


def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int:
    """Determine the number of hidden layers to initialize up to in the
    visual encoder.

    Args:
        hf_config: Model config with vision feature layer(s).

    Returns:
        The number of leading encoder layers needed so that every requested
        feature layer is computed.

    Raises:
        ValueError: If ``vision_feature_layer`` is an empty list or tuple.
        TypeError: If ``vision_feature_layer`` is neither an int nor a
            list or tuple of ints.
    """
    feature_layers = hf_config.vision_feature_layer
    num_hidden_layers = hf_config.vision_config.num_hidden_layers
    # If we have one feature layer, initialize up to that layer
    if isinstance(feature_layers, int):
        return _get_layer_index(feature_layers, num_hidden_layers)
    # If we have multiple feature layers, initialize up to the deepest one
    if isinstance(feature_layers, (list, tuple)):
        if not feature_layers:
            # Previously surfaced as an opaque "max() arg is an empty
            # sequence"; keep ValueError but say what is wrong.
            raise ValueError("vision_feature_layer must not be empty")
        return max(
            _get_layer_index(idx, num_hidden_layers) for idx in feature_layers)
    raise TypeError(f"vision_feature_layer type: {type(feature_layers)}"
                    " is not supported")


def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int:
    """Given an signed vision feature layer, get the number of hidden layers
    needed to leverage it.

    Args:
        feature_layer_index: Index of a required layer in the visual encoder.
        num_hidden_layers: The total number of hidden layers in the visual
            encoder.
    """
    if feature_layer_index < 0:
        return num_hidden_layers + feature_layer_index + 1
    return feature_layer_index + 1
251
252
253
254
255
256
257


def init_vision_tower_for_llava(
    hf_config: LlavaLikeConfig,
    quant_config: Optional[QuantizationConfig],
    *,
    require_post_norm: Optional[bool] = None,
    prefix: str = "",
):
    """Build the vision encoder that matches ``hf_config.vision_config``.

    The encoder is truncated to the deepest layer referenced by
    ``hf_config.vision_feature_layer``.

    Args:
        hf_config: Config exposing ``vision_config`` and
            ``vision_feature_layer``.
        quant_config: Optional quantization config for the encoder.
        require_post_norm: Forwarded unchanged to the encoder constructor.
        prefix: Parameter-name prefix for the encoder.

    Raises:
        NotImplementedError: If the vision config type is not supported.
    """
    vision_config = hf_config.vision_config

    # Initialize the vision tower only up to the deepest required feature
    # layer. Computed before the type dispatch, as before, so config errors
    # from _get_num_hidden_layers surface first.
    num_hidden_layers = _get_num_hidden_layers(hf_config)

    if isinstance(vision_config, CLIPVisionConfig):
        vision_model_cls = CLIPVisionModel
    elif isinstance(vision_config, SiglipVisionConfig):
        vision_model_cls = SiglipVisionModel
    elif isinstance(vision_config, PixtralVisionConfig):
        vision_model_cls = PixtralHFVisionModel
    else:
        msg = f"Unsupported vision config: {type(vision_config)}"
        raise NotImplementedError(msg)

    # All supported encoders share the same constructor keywords.
    return vision_model_cls(
        vision_config,
        quant_config=quant_config,
        num_hidden_layers_override=num_hidden_layers,
        require_post_norm=require_post_norm,
        prefix=prefix,
    )


294
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    """LLaVA-style model: vision tower + projector feeding a language model.

    Images are encoded by ``vision_tower``, mapped to the text hidden size by
    ``multi_modal_projector``, and merged into the token embeddings at the
    positions of ``config.image_token_index`` before the language model runs.
    """

    # BitsAndBytes specific attributes
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        self.config = config
        self.multimodal_config = multimodal_config

        # NOTE: These are special cases for Pixtral-12B in the HF-format
        # https://huggingface.co/mistral-community/pixtral-12b/blob/main/config.json  # noqa
        if (config.text_config.architectures is None
                and config.text_config.model_type == "mistral"):
            config.text_config.architectures = ["MistralForCausalLM"]
        if (config.projector_hidden_act is None
                and config.vision_config.hidden_act == "gelu"):
            config.projector_hidden_act = "gelu"

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=config.vision_config.hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "multi_modal_projector"))

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Sampler of the language model if it has one, else the default."""
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler

        return get_sampler()

    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
        """Check that every image in a batched tensor is (3, size, size).

        ``size`` is ``config.vision_config.image_size``; the leading batch
        dimension is not checked.

        Raises:
            ValueError: If the per-image shape does not match.
        """
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)
        actual_dims = tuple(data.shape[1:])

        if actual_dims != expected_dims:
            expected_expr = ("batch_size", *map(str, expected_dims))
            raise ValueError(
                f"The expected shape of pixel values is {expected_expr}. "
                f"You supplied {tuple(data.shape)}.")

        return data

    def _validate_image_sizes(
        self,
        images: List[torch.Tensor],
        sizes: Union[torch.Tensor, List[torch.Tensor]],
    ) -> List[torch.Tensor]:
        """Check a flat list of images against their declared sizes.

        ``sizes`` (a tensor or a list of tensors) is read as consecutive
        ``(height, width)`` pairs, one pair per image, in order.

        Raises:
            ValueError: On a count mismatch, a height/width mismatch, or an
                image whose channel dimension is not 3.
        """
        if not isinstance(sizes, list):
            sizes = [sizes]

        total_images = sum(size.numel() // 2 for size in sizes)
        if total_images != len(images):
            raise ValueError("Mismatch in number of images. "
                             f"Expected {total_images}, got {len(images)}")
        img_idx = 0
        for size in sizes:
            # Flatten the size tensor to a list of (height, width) pairs
            pairs = size.view(-1, 2).tolist()
            for expected_h, expected_w in pairs:
                if img_idx >= len(images):
                    raise ValueError("Ran out of images before sizes. "
                                     f"{img_idx} >= {len(images)}")
                img = images[img_idx]
                if img.shape[-2:] != (expected_h, expected_w):
                    raise ValueError(
                        "Image size mismatch. Expected "
                        f"{(expected_h, expected_w)}, got {img.shape[-2:]}")
                if img.shape[-3] != 3:
                    raise ValueError("Image channel mismatch. Expected 3, "
                                     f"got {img.shape[-3]}")
                img_idx += 1
        return images

    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaImageInputs]:
        """Pop and validate image inputs from ``kwargs``.

        Reads ``pixel_values``, ``image_sizes`` and ``image_embeds``. Pixel
        values take precedence over embeddings when both are given.

        Returns:
            ``None`` if neither pixel values nor embeddings are present,
            otherwise the validated pixel or embedding inputs.

        Raises:
            ValueError: On an unexpected input type or shape.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            if not isinstance(pixel_values, (torch.Tensor, list)):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            # Case for models like PixtralHF that have dynamic image sizes
            # so we need to produce a list of tensors
            if image_sizes is not None:

                def flatten_to_3d_tensors(item):
                    if isinstance(item, torch.Tensor):
                        if item.dim() >= 3:
                            return [t for t in item.view(-1, *item.shape[-3:])]
                        else:
                            raise ValueError(
                                f"Unexpected tensor dimension: {item.dim()}")
                    elif isinstance(item, list):
                        return [
                            t for subitem in item
                            for t in flatten_to_3d_tensors(subitem)
                        ]
                    else:
                        raise ValueError(f"Unexpected type: {type(item)}")

                # Flatten the (possibly nested) batch into a flat list of
                # per-image tensors keeping only the last three dimensions.
                images = flatten_to_3d_tensors(pixel_values)

                return LlavaImagePixelInputs(
                    type="pixel_values",
                    data=self._validate_image_sizes(images, image_sizes),
                )

            return LlavaImagePixelInputs(
                type="pixel_values",
                data=self._validate_pixel_values(
                    flatten_bn(pixel_values, concat=True)),
            )

        # Only embeddings remain at this point (both-None returned above).
        if not isinstance(image_embeds, (torch.Tensor, list)):
            raise ValueError("Incorrect type of image embeddings. "
                             f"Got type: {type(image_embeds)}")

        return LlavaImageEmbeddingInputs(
            type="image_embeds",
            data=flatten_bn(image_embeds, concat=True),
        )

    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        """Apply ``vision_feature_select_strategy`` to encoder outputs.

        ``"default"`` drops the first position along dim 1; ``"full"`` keeps
        everything.
        """
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel,
                            PixtralHFVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower and apply the feature selection strategy."""
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        image_features = vision_tower(pixel_values)

        return self._select_image_features(
            image_features,
            strategy=self.config.vision_feature_select_strategy,
        )

    def _process_image_pixels(self,
                              inputs: LlavaImagePixelInputs) -> torch.Tensor:
        """Encode pixel inputs into (unprojected) image features."""
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        return self._image_pixels_to_features(self.vision_tower, pixel_values)

    def _process_image_input(self,
                             image_input: LlavaImageInputs) -> torch.Tensor:
        """Return image embeddings in the language model's hidden size.

        Precomputed embeddings are returned as-is; pixel inputs are encoded
        and passed through the projector.
        """
        if image_input["type"] == "image_embeds":
            return image_input["data"]

        image_features = self._process_image_pixels(image_input)
        return self.multi_modal_projector(image_features)

    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Compute image embeddings from ``kwargs``, or ``None`` if absent."""
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return None
        vision_embeddings = self._process_image_input(image_input)
        return vision_embeddings

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and splice in image embeddings if given.

        Image embeddings replace the positions holding
        ``config.image_token_index``.
        """
        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
        if multimodal_embeddings is not None:
            inputs_embeds = merge_multimodal_embeddings(
                input_ids, inputs_embeds, multimodal_embeddings,
                self.config.image_token_index)
        return inputs_embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LLaVA-1.5.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.

        Tokenizer outputs:
        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
        29901]`.

        We insert 575 tokens so that including the original image token in the
        input, there are a total of 576 (24 * 24) image tokens, which
        corresponds to the number of image tokens inputted to the language
        model, i.e. the number of image tokens outputted by the visual encoder.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            pixel_values: The pixels in each input image (passed via
                ``**kwargs``).

        See also:
            :class:`LlavaImageInputs`
        """
        if intermediate_tensors is not None:
            # Non-first pipeline stage: hidden states come from
            # intermediate_tensors, so no embeddings are needed here.
            inputs_embeds = None
        elif inputs_embeds is None:
            # NOTE: In v1, inputs_embeds is always generated at model runner,
            # this condition is for v0 compatibility.
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logit computation to the language model."""
        return self.language_model.compute_logits(hidden_states,
                                                  sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        return self.language_model.sample(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights via ``AutoWeightsLoader``.

        Returns:
            Whatever ``AutoWeightsLoader.load_weights`` returns, annotated
            here as a set of names.
        """
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)