# SPDX-License-Identifier: Apache-2.0

3
from abc import abstractmethod
4
from functools import cached_property
5
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
6
                    Protocol, Set, Tuple, TypedDict, TypeVar, Union)
7
8
9

import torch
import torch.nn as nn
10
from transformers import BatchFeature, LlavaNextConfig, LlavaNextProcessor
11
12
13
14
15
from transformers.models.llava_next.modeling_llava_next import (
    get_anyres_image_grid_shape, unpad_image)
from typing_extensions import NotRequired

from vllm.attention import AttentionMetadata
16
from vllm.config import VllmConfig
Joe Runde's avatar
Joe Runde committed
17
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
18
from vllm.model_executor.sampling_metadata import SamplingMetadata
19
from vllm.multimodal import MULTIMODAL_REGISTRY
20
21
from vllm.multimodal.inputs import MultiModalFieldConfig, NestedTensors
from vllm.multimodal.parse import ImageSize
22
from vllm.sequence import IntermediateTensors
23

24
from .clip import CLIPVisionModel
25
from .interfaces import SupportsMultiModal, SupportsPP
26
27
from .llava import (BaseLlavaMultiModalProcessor, BaseLlavaProcessingInfo,
                    LlavaDummyInputsBuilder, LlavaLikeConfig,
28
                    LlavaMultiModalProjector, init_vision_tower_for_llava)
29
from .siglip import SiglipVisionModel
Cyrus Leung's avatar
Cyrus Leung committed
30
from .utils import (AutoWeightsLoader, embed_multimodal, flatten_bn,
31
                    init_vllm_registered_model, maybe_prefix)
32
33
34
35


class LlavaNextImagePixelInputs(TypedDict):
    """Image input variant that carries raw pixel patches."""

    type: Literal["pixel_values"]

    data: Union[torch.Tensor, List[torch.Tensor]]
    """
    Shape:
    `(batch_size * num_images, 1 + num_patches, num_channels, height, width)`

    `num_patches` is allowed to vary between batches and images; when it
    does, a list of tensors is passed in place of a single batched tensor.
    """

    image_sizes: NotRequired[torch.Tensor]
    """
    Shape: `(batch_size * num_images, 2)`

    Each entry is expected in `(height, width)` order.
    """
51
52


53
54
55
class LlavaNextImageEmbeddingInputs(TypedDict):
    """Image input variant that carries precomputed image embeddings."""

    type: Literal["image_embeds"]

    data: torch.Tensor
    """
    Shape: `(batch_size * num_images, image_feature_size, hidden_size)`

    `hidden_size` has to equal the hidden size of the language model
    backbone.
    """


# An image input is either raw pixel patches or precomputed embeddings;
# the `type` key of the dict tells the two variants apart.
LlavaNextImageInputs = Union[
    LlavaNextImagePixelInputs,
    LlavaNextImageEmbeddingInputs,
]
64
65


66
67
class LlavaNextLikeConfig(LlavaLikeConfig, Protocol):
    """Structural type for configs that add `image_grid_pinpoints`.

    Extends `LlavaLikeConfig` with the list of grid pinpoints that this
    module iterates as `(height, width)` pairs.
    """

    image_grid_pinpoints: Final[list[list[int]]]
68

69

70
class LlavaNextProcessingInfo(BaseLlavaProcessingInfo):
71

72
    def get_hf_config(self) -> LlavaNextLikeConfig:
        """Return the HF config, requested from the context as
        `LlavaNextConfig`."""
        hf_config = self.ctx.get_hf_config(LlavaNextConfig)
        return hf_config
74

75
    def get_hf_processor(self):
        """Return the `LlavaNextProcessor`, filling in a missing patch size.

        When the processor reports `patch_size is None`, the value is taken
        from the vision encoder info and stored on the processor before it
        is returned.
        """
        processor = self.ctx.get_hf_processor(LlavaNextProcessor)

        if processor.patch_size is not None:
            return processor

        # `patch_size` can be absent from `processor_config.json`,
        # e.g. for E5-V: https://huggingface.co/royokong/e5-v
        encoder_info = self.get_vision_encoder_info()
        processor.patch_size = encoder_info.get_patch_size()
        return processor
85

86
    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L113
87
    def get_num_image_tokens(
        self,
        *,
        image_width: int,
        image_height: int,
    ) -> int:
        """Return the number of image tokens for an image of the given size.

        The result is the sum of three parts: the base feature count (the
        vision encoder's token count after the configured feature-select
        strategy), the unpadded feature count of the anyres grid, and one
        newline feature per remaining grid row.

        Args:
            image_width: Width of the original image.
            image_height: Height of the original image.
        """
        config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        encoder_tokens = encoder_info.get_num_image_tokens(
            image_width=image_width,
            image_height=image_height,
        )
        base_features = self._apply_feature_select_strategy(
            config.vision_feature_select_strategy,
            encoder_tokens,
        )

        # Number of tiles along each axis of the anyres grid.
        grid_rows, grid_cols = get_anyres_image_grid_shape(
            image_size=(image_height, image_width),
            grid_pinpoints=config.image_grid_pinpoints,
            patch_size=encoder_info.get_image_size(),
        )

        unpadded_features, newline_features = self._get_num_unpadded_features(
            original_height=image_height,
            original_width=image_width,
            npatches=encoder_info.get_patch_grid_length(),
            num_patch_height=grid_rows,
            num_patch_width=grid_cols,
        )

        return unpadded_features + newline_features + base_features
122

123
    # Based on: https://github.com/huggingface/text-generation-inference/blob/v3.0.1/server/text_generation_server/models/vlm_causal_lm.py#L86
124
125
126
127
128
129
130
131
132
    def _get_num_unpadded_features(
        self,
        *,
        original_height: int,
        original_width: int,
        npatches: int,
        num_patch_height: int,
        num_patch_width: int,
    ) -> tuple[int, int]:
133
134
        current_height = npatches * num_patch_height
        current_width = npatches * num_patch_width
135

136
137
        aspect_ratio = original_width / original_height
        current_aspect_ratio = current_width / current_height
138

139
140
141
142
        if aspect_ratio > current_aspect_ratio:
            new_height = (original_height * current_width) // original_width
            padding = (current_height - new_height) // 2
            current_height = current_height - (2 * padding)
143
        else:
144
145
146
            new_width = (original_width * current_height) // original_height
            padding = (current_width - new_width) // 2
            current_width = current_width - (2 * padding)
147

148
149
        unpadded_features = current_height * current_width
        newline_features = current_height
150

151
152
        return (unpadded_features, newline_features)

153
154
    def get_image_size_with_most_features(self) -> ImageSize:
        """Return the grid pinpoint that yields the most image tokens.

        Each `(height, width)` entry of `image_grid_pinpoints` is scored
        with `get_num_image_tokens`; the first entry with the strictly
        largest score wins.

        Raises:
            ValueError: If no pinpoint produces a positive token count.
        """
        config = self.get_hf_config()

        best_tokens = 0
        best_size = None
        for pin_height, pin_width in config.image_grid_pinpoints:
            tokens = self.get_num_image_tokens(image_width=pin_width,
                                               image_height=pin_height)
            if tokens <= best_tokens:
                continue
            best_tokens = tokens
            best_size = ImageSize(width=pin_width, height=pin_height)

        if best_tokens == 0 or best_size is None:
            raise ValueError("Cannot have a largest feature size of 0!")

        return best_size


171
172
173
174
175
176
177
178
179
180
181
182
183
184
# Type variable for the processing-info type that a processor is
# parameterized with; bound to `LlavaNextProcessingInfo`.
_I = TypeVar("_I", bound=LlavaNextProcessingInfo)


class BaseLlavaNextMultiModalProcessor(BaseLlavaMultiModalProcessor[_I]):
    """Abstract base for LLaVA-NeXT style multimodal processors.

    Subclasses must provide `_get_mm_fields_config`.
    """

    # Copied from BaseMultiModalProcessor
    @abstractmethod
    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Map each multimodal field name to its field config."""
        raise NotImplementedError

185

186
187
class LlavaNextMultiModalProcessor(
        BaseLlavaNextMultiModalProcessor[LlavaNextProcessingInfo]):
    """Multimodal processor for LLaVA-NeXT."""

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        """Declare every image field as batched over the "image" modality.

        Neither argument is consulted; the same three fields are always
        returned.
        """
        field_names = ("pixel_values", "image_sizes", "image_embeds")
        return {
            name: MultiModalFieldConfig.batched("image")
            for name in field_names
        }
199
200


201
202
203
@MULTIMODAL_REGISTRY.register_processor(LlavaNextMultiModalProcessor,
                                        info=LlavaNextProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
204
205
class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal,
                                        SupportsPP):
206

207
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the vision tower, projector and language model.

        Args:
            vllm_config: Engine configuration. The HF config, quantization
                config and multimodal config are read from it.
            prefix: Prefix combined (via `maybe_prefix`) with the names
                `vision_tower` and `language_model` when building those
                submodules.

        Raises:
            TypeError: If `config.vision_feature_layer` is neither an `int`
                nor a `list`/`tuple`.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config

        vision_feature_layer = config.vision_feature_layer
        # A single layer index feeds the projector one hidden state; a
        # list/tuple of indices multiplies the projector input width by the
        # number of listed layers.
        if isinstance(vision_feature_layer, int):
            vision_hidden_size = config.vision_config.hidden_size
            self.feature_sample_layers = None
        # Used for multimodal granite models to control encoder outputs
        elif isinstance(vision_feature_layer, (list, tuple)):
            vision_hidden_size = config.vision_config.hidden_size * len(
                vision_feature_layer)
            self.feature_sample_layers = vision_feature_layer
        else:
            # The message names the real config field, `vision_feature_layer`.
            raise TypeError(
                f"vision_feature_layer type: {type(vision_feature_layer)}"
                " is not supported")

        self.config = config
        self.multimodal_config = multimodal_config

        # TODO: Optionally initializes this for supporting embeddings.
        self.vision_tower = init_vision_tower_for_llava(
            config,
            quant_config,
            require_post_norm=False,
            prefix=maybe_prefix(prefix, "vision_tower"))
        # Learned vector appended as a "newline" when merging patch
        # embeddings; allocated uninitialized here.
        self.image_newline = nn.Parameter(
            torch.empty(config.text_config.hidden_size))
        self.multi_modal_projector = LlavaMultiModalProjector(
            vision_hidden_size=vision_hidden_size,
            text_hidden_size=config.text_config.hidden_size,
            projector_hidden_act=config.projector_hidden_act,
            multimodal_projector_bias=config.multimodal_projector_bias)

        self.language_model = init_vllm_registered_model(
            vllm_config=vllm_config,
            hf_config=config.text_config,
            prefix=maybe_prefix(prefix, "language_model"),
        )

        # Delegate creation of empty intermediate tensors to the language
        # model.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        """Return the language model's sampler, or a fresh one if it has
        none."""
        language_model = self.language_model
        if not hasattr(language_model, "sampler"):
            return get_sampler()

        return language_model.sampler
260
261

    def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
262
263
264
265
266
267
268
269
270
271
272
273
274
        expected_dims = (2, )

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape)

            if actual_dims != expected_dims:
                expected_expr = str(expected_dims)
                raise ValueError(
                    f"The expected shape of image sizes per image per batch "
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

        for d in data:
            _validate_shape(d)
275
276
277

        return data

278
279
280
281
    def _validate_pixel_values(
        self, data: Union[torch.Tensor, List[torch.Tensor]]
    ) -> Union[torch.Tensor, List[torch.Tensor]]:

282
283
284
285
286
287
288
289
        h = w = self.config.vision_config.image_size
        expected_dims = (3, h, w)

        def _validate_shape(d: torch.Tensor):
            actual_dims = tuple(d.shape[1:])

            if actual_dims != expected_dims:
                expected_expr = ("num_patches", *map(str, expected_dims))
290
                raise ValueError(
291
                    "The expected shape of pixel values per image per batch "
292
                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")
293

294
295
        for d in data:
            _validate_shape(d)
296
297
298

        return data

299
    def _parse_and_validate_image_input(
            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
        """Turn the image-related keyword arguments into a typed input dict.

        Reads `pixel_values`, `image_sizes` and `image_embeds` from
        `kwargs`. Pixel values take precedence over embeddings when both
        are supplied.

        Returns:
            `None` when neither pixel values nor embeddings are given,
            otherwise the matching `LlavaNextImageInputs` variant.

        Raises:
            ValueError: If a supplied value has an unsupported type.
        """
        pixel_values = kwargs.pop("pixel_values", None)
        image_sizes = kwargs.pop("image_sizes", None)
        image_embeds = kwargs.pop("image_embeds", None)

        if pixel_values is None and image_embeds is None:
            return None

        if pixel_values is not None:
            tensor_or_list = (torch.Tensor, list)

            if not isinstance(pixel_values, tensor_or_list):
                raise ValueError("Incorrect type of pixel values. "
                                 f"Got type: {type(pixel_values)}")

            if not isinstance(image_sizes, tensor_or_list):
                raise ValueError("Incorrect type of image sizes. "
                                 f"Got type: {type(image_sizes)}")

            checked_pixels = self._validate_pixel_values(
                flatten_bn(pixel_values))
            checked_sizes = self._validate_image_sizes(
                flatten_bn(image_sizes, concat=True))
            return LlavaNextImagePixelInputs(
                type="pixel_values",
                data=checked_pixels,
                image_sizes=checked_sizes,
            )

        if image_embeds is not None:
            if not isinstance(image_embeds, torch.Tensor):
                raise ValueError("Incorrect type of image embeds. "
                                 f"Got type: {type(image_embeds)}")

            return LlavaNextImageEmbeddingInputs(
                type="image_embeds",
                data=flatten_bn(image_embeds),
            )

        raise AssertionError("This line should be unreachable.")
335

Cyrus Leung's avatar
Cyrus Leung committed
336
337
338
339
340
341
342
343
344
345
    def _select_image_features(self, image_features: torch.Tensor, *,
                               strategy: str) -> torch.Tensor:
        # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421  # noqa
        if strategy == "default":
            return image_features[:, 1:]
        elif strategy == "full":
            return image_features

        raise ValueError(f"Unexpected select feature strategy: {strategy}")

346
347
348
349
350
    def _image_pixels_to_features(
        self,
        vision_tower: Union[CLIPVisionModel, SiglipVisionModel],
        pixel_values: torch.Tensor,
    ) -> torch.Tensor:
        """Run the vision tower, then apply the feature-select strategy.

        Args:
            vision_tower: Vision encoder to call.
            pixel_values: Pixel tensor handed to the vision tower as-is.
        """
        # NOTE: we skip the step to select the vision feature layer since
        # this is already done inside the vision tower
        tower_output = vision_tower(
            pixel_values, feature_sample_layers=self.feature_sample_layers)

        select_strategy = self.config.vision_feature_select_strategy
        return self._select_image_features(tower_output,
                                           strategy=select_strategy)

362
    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
    def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                      patch_embeddings: torch.Tensor, *,
                                      strategy: str) -> torch.Tensor:
        if strategy == "flat":
            return patch_embeddings.flatten(0, 1)

        if strategy.startswith("spatial"):
            height = width = self.config.vision_config.image_size \
                // self.config.vision_config.patch_size

            base_patch_embeds = patch_embeddings[0]
            if height * width != base_patch_embeds.shape[0]:
                raise ValueError(
                    "The number of patches is not consistent with the "
                    "image size.")

            if patch_embeddings.shape[0] > 1:
                other_patch_embeds = patch_embeddings[1:]

382
383
384
                # Move to CPU to avoid floating-point errors
                orig_height, orig_width = image_size.tolist()

385
                # image_aspect_ratio == "anyres"
386
387
                num_patch_height, num_patch_width = get_anyres_image_grid_shape(
                    (orig_height, orig_width),
388
389
390
                    self.config.image_grid_pinpoints,
                    self.config.vision_config.image_size,
                )
391
392
393
394
                num_patches = num_patch_height * num_patch_width

                # Image patches might be padded for batch processing
                other_patch_embeds = other_patch_embeds[:num_patches] \
395
                    .view(num_patch_height, num_patch_width, height, width, -1)
396
397
398
399
400
401

                if "unpad" in strategy:
                    other_patch_embeds = other_patch_embeds \
                        .permute(4, 0, 2, 1, 3).contiguous() \
                        .flatten(1, 2).flatten(2, 3)
                    other_patch_embeds = unpad_image(other_patch_embeds,
402
                                                     (orig_height, orig_width))
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
                    other_patch_embeds = torch.cat((
                        other_patch_embeds,
                        self.image_newline[:, None, None] \
                            .expand(*other_patch_embeds.shape[:-1], 1) \
                            .to(other_patch_embeds.device),
                    ), dim=-1)
                    other_patch_embeds = other_patch_embeds \
                        .flatten(1, 2).transpose(0, 1)
                else:
                    other_patch_embeds = other_patch_embeds \
                        .permute(0, 2, 1, 3, 4).contiguous() \
                        .flatten(0, 3)

                merged_patch_embeddings = torch.cat(
                    (base_patch_embeds, other_patch_embeds), dim=0)
            else:
                if "unpad" in strategy:
                    merged_patch_embeddings = torch.cat(
                        (base_patch_embeds,
                         self.image_newline[None] \
                            .to(base_patch_embeds.device)
                    ), dim=0)
                else:
                    merged_patch_embeddings = base_patch_embeds

            return merged_patch_embeddings

        raise ValueError(f"Unexpected patch merge strategy: {strategy}")

    def _process_image_pixels(
        self,
        inputs: LlavaNextImagePixelInputs,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, ...]]:
        """Encode and project the pixel patches of every image.

        All patches are pushed through the vision tower and the projector
        in one stacked call, then regrouped per image.

        Returns:
            For a batched tensor input, a tensor whose first two dims are
            `(b, num_patches)`. For a list input, a tuple with one tensor
            per list entry, split by each entry's patch count.
        """
        assert self.vision_tower is not None

        pixel_values = inputs["data"]

        if not isinstance(pixel_values, torch.Tensor):
            # Ragged input: remember each image's patch count so the
            # stacked result can be split back afterwards.
            patch_counts = [image.shape[0] for image in pixel_values]
            stacked_pixels = torch.cat(pixel_values)
            stacked_features = self._image_pixels_to_features(
                self.vision_tower, stacked_pixels)

            return torch.split(self.multi_modal_projector(stacked_features),
                               patch_counts)

        b, num_patches, c, h, w = pixel_values.shape
        stacked_pixels = pixel_values.view(b * num_patches, c, h, w)
        stacked_features = self._image_pixels_to_features(
            self.vision_tower, stacked_pixels)
        stacked_embeds = self.multi_modal_projector(stacked_features)

        # Restore the (image, patch) leading dims.
        return stacked_embeds.view(b, num_patches, *stacked_embeds.shape[1:])
458
459

    def _process_image_input(
        self,
        image_input: LlavaNextImageInputs,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        """Convert a validated image input into merged image embeddings.

        Precomputed embeddings are returned wrapped in a one-element list.
        Pixel inputs are encoded and then merged per image with the
        `"spatial_unpad"` strategy. When `image_sizes` is absent, every
        image is assumed to be a square of `vision_config.image_size`.
        """
        if image_input["type"] == "image_embeds":
            return [image_input["data"]]

        patch_embeddings = self._process_image_pixels(image_input)

        image_sizes = image_input.get("image_sizes")
        if image_sizes is None:
            num_images = len(image_input["data"])
            side = self.config.vision_config.image_size
            image_sizes = torch.as_tensor([[side, side]
                                           for _ in range(num_images)])

        merged_embeddings = []
        for index, image_patches in enumerate(patch_embeddings):
            merged_embeddings.append(
                self._merge_image_patch_embeddings(image_sizes[index],
                                                   image_patches,
                                                   strategy="spatial_unpad"))
        return merged_embeddings

483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
    def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
        """Return image embeddings for the given kwargs, or `None` when the
        kwargs contain no image input."""
        parsed_input = self._parse_and_validate_image_input(**kwargs)
        if parsed_input is not None:
            return self._process_image_input(parsed_input)

        return None

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[NestedTensors] = None,
    ) -> torch.Tensor:
        """Embed `input_ids`, merging in image embeddings when provided.

        Without multimodal embeddings, this is the language model's own
        input embedding. Otherwise `embed_multimodal` is called with the
        configured image token index.
        """
        if multimodal_embeddings is not None:
            return embed_multimodal(
                input_ids,
                self.config.image_token_index,
                self.language_model.model.get_input_embeddings,
                multimodal_embeddings,
            )

        return self.language_model.get_input_embeddings(input_ids)

507
508
509
510
511
512
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs: object,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run forward pass for LlaVA-NeXT.

        One key thing to understand is the `input_ids` already accounts for the
        positions of the to-be-inserted image embeddings.

        Concretely, consider a text prompt:
        `"A chat between a curious human and an artificial intelligence
        assistant. The assistant gives helpful, detailed, and polite answers to
        the human's questions.
        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.

        Tokenizer outputs:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
        9047, 13566, 29901]`.

        To reserve space in KV cache, we have to insert placeholder tokens
        before they are inputted to the model, so the input processor prepends
        additional image tokens (denoted as `32000`), resulting in:
        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
        319, 1799, 9047, 13566, 29901]`.

        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
        model depends on the original size of the input image. Including the
        original image token in the input, the required number of image tokens
        is given by :meth:`LlavaNextProcessingInfo.get_num_image_tokens`.

        This way, the `positions` and `attn_metadata` are consistent
        with the `input_ids`.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Passed through to the language model unchanged.
            kv_caches: Passed through to the language model unchanged.
            attn_metadata: Passed through to the language model unchanged.
            intermediate_tensors: When given, `inputs_embeds` is discarded
                and these tensors are forwarded to the language model.
            inputs_embeds: Precomputed input embeddings. When `None` (and
                no intermediate tensors are given) they are computed here
                from `input_ids` and the image kwargs.
            **kwargs: Image inputs, read by
                `_parse_and_validate_image_input`:
                `pixel_values` (the pixels in each grid patch for each
                input image), `image_sizes` (the original
                `(height, width)` for each input image) and
                `image_embeds`.

        Returns:
            Whatever `self.language_model.model` returns for these inputs.

        See also:
            :class:`LlavaNextImageInputs`
        """
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE: In v1, inputs_embeds is always generated at model runner, this
        # condition is for v0 compatibility.
        elif inputs_embeds is None:
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
            inputs_embeds = self.get_input_embeddings(input_ids,
                                                      vision_embeddings)
            # The embeddings now carry the token information, so the ids
            # are not passed on.
            input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
                                                  kv_caches,
                                                  attn_metadata,
                                                  intermediate_tensors,
                                                  inputs_embeds=inputs_embeds)

        return hidden_states

580
581
582
583
584
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Delegate logits computation to the language model."""
        logits = self.language_model.compute_logits(hidden_states,
                                                    sampling_metadata)
        return logits
587
588
589
590
591
592

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Delegate sampling to the language model."""
        sampler_output = self.language_model.sample(logits, sampling_metadata)
        return sampler_output
594

595
596
    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load `(name, tensor)` weights via `AutoWeightsLoader` and return
        what its `load_weights` call returns."""
        return AutoWeightsLoader(self).load_weights(weights)