# siglip.py
"""Implementation of SiglipVisionModel intended to be only used
within a vision language model."""

import math
5
from array import array
6
from typing import Iterable, Optional, Tuple
7
8
9
10

import torch
from PIL import Image
from torch import nn
11
from transformers import SiglipVisionConfig
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from transformers.models.siglip.modeling_siglip import SiglipAttention
from vllm_flash_attn import flash_attn_func
from xformers.ops import memory_efficient_attention

from vllm.config import ModelConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import LLMInputs
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding)
26
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
27
28
from vllm.multimodal.utils import (cached_get_tokenizer,
                                   repeat_and_pad_placeholder_tokens)
29
from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData
30
31
32


def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
    """Return the number of whole patches that fit along one image side.

    Args:
        image_size: Side length of the image, in pixels.
        patch_size: Side length of one patch, in pixels.

    Returns:
        ``image_size // patch_size``; a partial trailing patch is dropped.
    """
    # Since interpolation is applied, the image size need not be divisible
    # by the patch size, so divisibility is deliberately not asserted here.
    return image_size // patch_size


def get_siglip_num_patches(*, image_size: int, patch_size: int) -> int:
    """Return the total number of patches in the patch grid of an image.

    The grid is treated as square: the per-side patch count is computed by
    ``get_siglip_patch_grid_length`` and multiplied by itself.
    """
    side = get_siglip_patch_grid_length(
        image_size=image_size,
        patch_size=patch_size,
    )
    return side * side


def get_siglip_image_feature_size(hf_config: SiglipVisionConfig) -> int:
    """Return the patch count implied by ``hf_config``.

    Reads ``image_size`` and ``patch_size`` from the config and forwards
    them to ``get_siglip_num_patches``.
    """
    image_size = hf_config.image_size
    patch_size = hf_config.patch_size
    return get_siglip_num_patches(image_size=image_size,
                                  patch_size=patch_size)


def get_max_siglip_image_tokens(hf_config: SiglipVisionConfig) -> int:
    """Return the maximum number of image tokens for ``hf_config``.

    This is exactly the value of ``get_siglip_image_feature_size``.
    """
    max_image_tokens = get_siglip_image_feature_size(hf_config)
    return max_image_tokens


def dummy_seq_data_for_siglip(
    hf_config: SiglipVisionConfig,
    seq_len: int,
    num_images: int,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    """Build dummy sequence data holding image placeholder tokens.

    Args:
        hf_config: Vision config used to derive the per-image feature size
            when no override is given.
        seq_len: Target length of the dummy token sequence.
        num_images: Number of images the sequence should reserve
            placeholder tokens for.
        image_token_id: Token id used as the image placeholder.
        image_feature_size_override: Per-image feature size to use instead
            of the one derived from ``hf_config``.

    Returns:
        A ``SequenceData`` whose tokens are
        ``image_feature_size * num_images`` placeholders followed by zeros
        up to ``seq_len``.
    """
    if image_feature_size_override is None:
        image_feature_size = get_siglip_image_feature_size(hf_config)
    else:
        image_feature_size = image_feature_size_override

    # Reserve one run of placeholders per image so the token count matches
    # the ``num_images`` images produced by ``dummy_image_for_siglip``.
    num_image_tokens = image_feature_size * num_images

    token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
                      [image_token_id]) * num_image_tokens
    # Multiplying an array by a non-positive count yields an empty array,
    # so no padding is added when the placeholders already fill seq_len.
    token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
                       [0]) * (seq_len - num_image_tokens)
    return SequenceData(token_ids)


def dummy_image_for_siglip(
    hf_config: SiglipVisionConfig,
    num_images: int,
    *,
    image_width_override: Optional[int] = None,
    image_height_override: Optional[int] = None,
):
    """Create dummy multi-modal image data.

    Args:
        hf_config: Vision config; ``image_size`` is the default width and
            height.
        num_images: Number of images to return.
        image_width_override: Width to use instead of ``image_size``.
        image_height_override: Height to use instead of ``image_size``.

    Returns:
        ``{"image": image}`` when ``num_images == 1``; otherwise
        ``{"image": [image] * num_images}``. The list entries all refer to
        the same black RGB image object.
    """
    width = height = hf_config.image_size
    if image_width_override is not None:
        width = image_width_override
    if image_height_override is not None:
        height = image_height_override

    image = Image.new("RGB", (width, height), color=0)
    return {"image": image if num_images == 1 else [image] * num_images}


def input_processor_for_siglip(
    model_config: ModelConfig,
    hf_config: SiglipVisionConfig,
    llm_inputs: LLMInputs,
    *,
    image_token_id: int,
    image_feature_size_override: Optional[int] = None,
):
    """Expand the image placeholder token in the prompt.

    Args:
        model_config: Model config; its ``tokenizer`` is used to fetch the
            cached tokenizer.
        hf_config: Vision config used to derive the feature size for a PIL
            image input.
        llm_inputs: Inputs holding ``prompt_token_ids`` and optionally
            ``prompt`` and ``multi_modal_data``.
        image_token_id: Token id of the image placeholder to repeat.
        image_feature_size_override: Repeat count to use instead of the one
            derived from the image data.

    Returns:
        ``llm_inputs`` unchanged when it carries no image data; otherwise a
        new ``LLMInputs`` with the prompt and token ids returned by
        ``repeat_and_pad_placeholder_tokens``.

    Raises:
        TypeError: If no override is given and the image data is neither a
            ``PIL.Image.Image`` nor a ``torch.Tensor``.
    """
    multi_modal_data = llm_inputs.get("multi_modal_data")
    if multi_modal_data is None or "image" not in multi_modal_data:
        return llm_inputs

    tokenizer = cached_get_tokenizer(model_config.tokenizer)

    if image_feature_size_override is None:
        image_data = multi_modal_data["image"]
        if isinstance(image_data, Image.Image):
            image_feature_size = get_siglip_image_feature_size(hf_config)
        elif isinstance(image_data, torch.Tensor):
            # NOTE(review): presumably a tensor of precomputed image
            # features whose first dimension is the feature count - confirm.
            image_feature_size = image_data.shape[0]
        else:
            raise TypeError(f"Invalid image type: {type(image_data)}")
    else:
        image_feature_size = image_feature_size_override

    new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens(
        tokenizer,
        llm_inputs.get("prompt"),
        llm_inputs["prompt_token_ids"],
        placeholder_token_id=image_token_id,
        repeat_count=image_feature_size,
    )

    # NOTE: Create a defensive copy of the original inputs
    return LLMInputs(
        prompt_token_ids=new_token_ids,
        prompt=new_prompt,
        multi_modal_data=multi_modal_data,
    )


# Adapted from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L249 # noqa
class SiglipVisionEmbeddings(nn.Module):
    """Patch embedding plus learned position embedding for SigLIP."""

    def __init__(self, config: SiglipVisionConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.image_size = config.image_size
        self.patch_size = config.patch_size

        # Non-overlapping patches: kernel and stride both equal patch_size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        patches_per_side = self.image_size // self.patch_size
        self.num_patches = patches_per_side**2
        # SigLIP has no class token, so there is one position per patch.
        self.num_positions = self.num_patches
        self.position_embedding = VocabParallelEmbedding(
            self.num_positions, self.embed_dim)

        all_positions = torch.arange(self.num_positions, dtype=torch.int64)
        self.register_buffer(
            "position_ids",
            all_positions.expand((1, -1)),
            persistent=False,
        )

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int,
                                 width: int) -> torch.Tensor:
        """Resize the learned position embeddings to fit ``embeddings``.

        Adapted for SigLIP (which has no class embedding, unlike other ViTs)
        so the pre-trained position encodings can be used on images of a
        different resolution.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

        Args:
            embeddings: Patch embeddings; dimension 1 is the patch count and
                the last dimension is the embedding width.
            height: Input image height in pixels.
            width: Input image width in pixels.

        Returns:
            The unmodified position table (with a leading batch dimension)
            when the patch count matches and the image is square; otherwise
            the bicubically interpolated table flattened to
            ``(1, num_patches, dim)``.

        Raises:
            ValueError: If the interpolated grid does not have the expected
                height and width.
        """
        pos_table = self.position_embedding.weight.unsqueeze(0)
        num_positions = pos_table.shape[1]
        if embeddings.shape[1] == num_positions and height == width:
            return pos_table

        dim = embeddings.shape[-1]
        base_side = math.sqrt(num_positions)

        # we add a small number to avoid floating point error
        # in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        target_h = height // self.patch_size + 0.1
        target_w = width // self.patch_size + 0.1

        # Lay the table out as a 2D grid with channels first for interpolate.
        grid = pos_table.reshape(1, int(base_side), int(base_side), dim)
        grid = grid.permute(0, 3, 1, 2)
        resized = nn.functional.interpolate(
            grid,
            scale_factor=(target_h / base_side, target_w / base_side),
            mode="bicubic",
            align_corners=False,
        )

        out_h, out_w = resized.shape[-2], resized.shape[-1]
        if int(target_h) != out_h or int(target_w) != out_w:
            raise ValueError("Width or height does not match with "
                             "the interpolated position embeddings")

        return resized.permute(0, 2, 3, 1).view(1, -1, dim)

    def forward(self,
                pixel_values: torch.Tensor,
                interpolate_pos_encoding: bool = False) -> torch.Tensor:
        """Embed ``pixel_values`` into a sequence of patch embeddings.

        Args:
            pixel_values: 4D image tensor whose last two dimensions are
                height and width.
            interpolate_pos_encoding: When true, position embeddings are
                interpolated to the input resolution; otherwise the learned
                table is looked up directly with ``position_ids``.
        """
        height, width = pixel_values.shape[2:]
        _, _ = pixel_values.shape[:2]  # input must be exactly 4D
        pixels = pixel_values.to(dtype=self.patch_embedding.weight.dtype)
        # shape = [*, width, grid, grid]
        patch_grid = self.patch_embedding(pixels)
        # Flatten the spatial grid and move channels last.
        embeddings = patch_grid.flatten(2).transpose(1, 2)

        if interpolate_pos_encoding:
            positional = self.interpolate_pos_encoding(
                embeddings, height, width)
        else:
            positional = self.position_embedding(self.position_ids)
        return embeddings + positional


# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): Implement TP version of Attention
class SiglipTPAttention(nn.Module):
    """Multi-head self-attention with tensor-parallel projections.

    Attention heads are split evenly across the tensor-parallel world size;
    the eager (matmul + softmax) kernel is used unless a subclass replaces
    ``attn_fn``.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the projections.

        Args:
            config: Vision config providing ``hidden_size``,
                ``num_attention_heads`` and ``attention_dropout``.
            quant_config: Optional quantization config passed to both
                projection layers.

        Raises:
            ValueError: If the head count is not divisible by the
                tensor-parallel size, or ``hidden_size`` is not divisible
                by the head count.
        """
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size

        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads
        if self.total_num_heads % tp_size != 0:
            raise ValueError(
                f"Number of attention heads ({self.total_num_heads}) "
                "must be divisible by the tensor model parallel size"
                f" ({tp_size}).")

        # Heads handled by this rank.
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.embed_dim // self.total_num_heads
        if self.head_dim * self.total_num_heads != self.embed_dim:
            # The check is against the total head count, so report that one.
            raise ValueError(
                "embed_dim must be divisible by num_heads (got "
                f"`embed_dim`: {self.embed_dim} and `num_heads`:"
                f" {self.total_num_heads}).")
        # Width of each of q/k/v on this rank.
        self.qkv_size = self.num_heads * self.head_dim
        self.scale = self.head_dim**-0.5
        self.dropout = config.attention_dropout

        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            quant_config=quant_config,
        )
        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
        )

        self.attn_fn = self._basic_attention_forward

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: Batch x Time x Channel"""
        batch_size, q_len, _ = hidden_states.size()

        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.split(
            [self.qkv_size] * 3, dim=-1)

        attn_output = self.attn_fn(
            q=query_states,
            k=key_states,
            v=value_states,
            batch_size=batch_size,
            q_len=q_len,
        )

        attn_output, _ = self.out_proj(attn_output)
        return attn_output

    def _basic_attention_forward(self, q, k, v, batch_size, q_len):
        """Eager attention over ``(batch, q_len, num_heads * head_dim)`` inputs.

        Returns a tensor reshaped to ``(batch_size, q_len, embed_dim)``.

        Raises:
            ValueError: If the attention weights or the attention output do
                not have the expected shape.
        """
        q = q.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        k = k.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        v = v.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)

        k_v_seq_len = k.shape[-2]
        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale

        if attn_weights.size() != (
                batch_size,
                self.num_heads,
                q_len,
                k_v_seq_len,
        ):
            raise ValueError(
                "Attention weights should be of size "
                f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
                f" {attn_weights.size()}")

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights,
                                             dim=-1,
                                             dtype=torch.float32).to(q.dtype)
        attn_weights = nn.functional.dropout(attn_weights,
                                             p=self.dropout,
                                             training=self.training)
        attn_output = torch.matmul(attn_weights, v)

        if attn_output.size() != (
                batch_size,
                self.num_heads,
                q_len,
                self.head_dim,
        ):
            raise ValueError(
                "`attn_output` should be of size "
                f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}")

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)

        return attn_output


# NOTE: Not used - kept for later when we TP the ViT
# TODO(ChristopherCho): flash_attn_func is not working properly.
#                       It constantly throws a CUDA error.
class SiglipFlashAttention2(SiglipTPAttention):
    """``SiglipTPAttention`` variant that routes through ``flash_attn_func``."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attn_fn = self._flash_attention_forward

    # Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449
    # and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133
    def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args,
                                 **kwargs):
        """Implements the multihead softmax attention.

        Arguments
        ---------
            q, k, v: The projected query, key and value tensors. Each is
                     viewed here as (B, S, H, D) before the kernel call.
            batch_size, q_len: B and S used for the views above.
            *args, **kwargs: Accepted for signature compatibility; unused.

        Returns the attention output reshaped to (B, S, embed_dim).
        """

        q = q.view(batch_size, q_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, q_len, self.num_heads, self.head_dim)
        v = v.view(batch_size, q_len, self.num_heads, self.head_dim)

        # The kernel applies dropout whenever dropout_p > 0, so it must be
        # switched off explicitly outside of training.
        dropout_p = self.dropout if self.training else 0.0

        attn_output = flash_attn_func(
            q,
            k,
            v,
            dropout_p=dropout_p,
            causal=False,
        )

        attn_output = attn_output.reshape(batch_size, q_len,
                                          self.embed_dim).contiguous()

        return attn_output


# NOTE: Not used - kept for later when we TP the ViT
class SiglipSdpaAttention(SiglipTPAttention):
    """``SiglipTPAttention`` variant using PyTorch's fused SDPA kernel."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_causal = False
        self.attn_fn = self._sdpa_attention_forward

    def _sdpa_attention_forward(self, q, k, v, batch_size, q_len):
        """Run scaled-dot-product attention on the projected q/k/v.

        Each input is viewed as ``(batch_size, q_len, num_heads, head_dim)``
        and transposed to heads-first before the kernel call. The result is
        returned as ``(batch_size, q_len, embed_dim)``.
        """
        q = q.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        k = k.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)
        v = v.view(batch_size, q_len, self.num_heads,
                   self.head_dim).transpose(1, 2)

        # The functional API applies dropout whenever dropout_p > 0, so it
        # has to be disabled explicitly outside of training.
        dropout_p = self.dropout if self.training else 0.0

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=dropout_p,
            is_causal=self.is_causal,
            scale=self.scale)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, q_len, self.embed_dim)

        return attn_output


# NOTE: Not used - kept for later when we TP the ViT
class SiglipxFormersAttention(SiglipTPAttention):
    """``SiglipTPAttention`` variant backed by xFormers."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.attn_fn = self._xformers_attention_forward

    def _xformers_attention_forward(self, q, k, v, batch_size, q_len):
        """Run ``memory_efficient_attention`` on the projected q/k/v.

        Each input is viewed as ``(batch_size, q_len, num_heads, head_dim)``;
        dropout is fixed at 0.0. The result is returned as
        ``(batch_size, q_len, embed_dim)``.
        """
        per_head_shape = (batch_size, q_len, self.num_heads, self.head_dim)
        query, key, value = (t.view(*per_head_shape) for t in (q, k, v))

        context = memory_efficient_attention(query,
                                             key,
                                             value,
                                             p=0.0,
                                             scale=self.scale)

        # Merge the head dimensions back into a single embedding dimension.
        merged = context.reshape(batch_size, q_len, self.embed_dim)
        return merged.contiguous()


# NOTE: Not used - kept for later when we TP the ViT
# Maps an attention implementation name to the class that provides it.
SIGLIP_ATTENTION_CLASSES = dict((
    ("eager", SiglipTPAttention),
    ("flash_attention_2", SiglipFlashAttention2),
    ("sdpa", SiglipSdpaAttention),
    ("xformers", SiglipxFormersAttention),
))


class SiglipMLP(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> fc2."""

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the block from ``config``.

        ``quant_config`` is only forwarded to the linear layers when both
        ``hidden_size`` and ``intermediate_size`` are multiples of 64.
        """
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)

        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size

        # For quantization, we require the hidden size to be a multiple of 64
        if hidden_size % 64 != 0 or intermediate_size % 64 != 0:
            layer_quant_config = None
        else:
            layer_quant_config = quant_config

        self.fc1 = ColumnParallelLinear(
            hidden_size,
            intermediate_size,
            quant_config=layer_quant_config,
        )
        self.fc2 = RowParallelLinear(
            intermediate_size,
            hidden_size,
            quant_config=layer_quant_config,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply fc1, the activation, then fc2 to ``hidden_states``."""
        # Both linear layers return a pair; only the first item is used.
        expanded, _ = self.fc1(hidden_states)
        activated = self.activation_fn(expanded)
        projected, _ = self.fc2(activated)
        return projected


class SiglipEncoderLayer(nn.Module):
    """Pre-norm transformer layer: self-attention then MLP, each residual."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Build the layer.

        Args:
            config: Vision config providing ``hidden_size`` and
                ``layer_norm_eps``.
            quant_config: Optional quantization config, forwarded to the
                MLP only.
        """
        super().__init__()
        self.embed_dim = config.hidden_size

        # TODO(ChristopherCho): use TP'ed Attention block
        self.self_attn = SiglipAttention(config)
        self.layer_norm1 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config,
            quant_config=quant_config,
        )
        self.layer_norm2 = nn.LayerNorm(self.embed_dim,
                                        eps=config.layer_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, None]:
        """Run the layer on ``hidden_states``.

        Returns:
            ``(hidden_states, None)``; the second item is always ``None``.
        """
        residual = hidden_states

        hidden_states = self.layer_norm1(hidden_states)
        # The second item returned by the attention module is discarded.
        hidden_states, _ = self.self_attn(hidden_states=hidden_states)
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.layer_norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, None


class SiglipEncoder(nn.Module):
    """Stack of ``SiglipEncoderLayer`` modules applied in sequence."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        """Build the layer stack.

        Args:
            config: Vision config; ``num_hidden_layers`` is the default
                depth.
            quant_config: Optional quantization config passed to every
                layer.
            num_hidden_layers_override: Number of layers to build instead
                of ``config.num_hidden_layers``.
        """
        super().__init__()
        self.config = config

        if num_hidden_layers_override is None:
            num_hidden_layers = config.num_hidden_layers
        else:
            num_hidden_layers = num_hidden_layers_override

        self.layers = nn.ModuleList([
            SiglipEncoderLayer(config, quant_config=quant_config)
            for _ in range(num_hidden_layers)
        ])

    def forward(
        self,
        inputs_embeds: torch.Tensor,
    ) -> torch.Tensor:
        """Pass ``inputs_embeds`` through every layer and return the result."""
        hidden_states = inputs_embeds
        for encoder_layer in self.layers:
            # Each layer returns a pair; only the hidden states are kept.
            hidden_states, _ = encoder_layer(hidden_states)

        return hidden_states


class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Create the learned probe, the attention, layer norm and MLP."""
        super().__init__()
        hidden_size = config.hidden_size

        # A single learned query vector that attends over the sequence.
        self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
        # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention
        self.attention = torch.nn.MultiheadAttention(
            hidden_size,
            config.num_attention_heads,
            batch_first=True,
        )
        self.layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(config=config, quant_config=quant_config)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Pool ``hidden_state`` into one vector per batch element.

        The probe is repeated along the batch dimension and used as the
        query against ``hidden_state``; the attended result goes through a
        residual layer-norm + MLP step and position 0 is returned.
        """
        queries = self.probe.repeat(hidden_state.shape[0], 1, 1)

        # Only the attention output is used; the second item is dropped.
        pooled, _ = self.attention(queries, hidden_state, hidden_state)

        pooled = pooled + self.mlp(self.layernorm(pooled))
        return pooled[:, 0]


class SiglipVisionTransformer(nn.Module):
    """SigLIP vision tower: embeddings, encoder stack and final layer norm."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        """Build the vision tower.

        Args:
            config: Vision config. ``vision_use_head`` is honoured when
                present and treated as ``True`` otherwise.
            quant_config: Optional quantization config passed to the
                encoder and the pooling head.
            num_hidden_layers_override: Number of encoder layers to build
                instead of the configured count.
        """
        super().__init__()
        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)
        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )
        self.post_layernorm = nn.LayerNorm(embed_dim,
                                           eps=config.layer_norm_eps)
        self.use_head = getattr(config, "vision_use_head", True)
        if self.use_head:
            # Built here, but not applied in forward() (see the TODO there).
            self.head = SiglipMultiheadAttentionPoolingHead(
                config=config, quant_config=quant_config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = True,
    ) -> torch.Tensor:
        """Encode ``pixel_values`` and return the normalised hidden states.

        Args:
            pixel_values: Image tensor passed to the embedding module.
            interpolate_pos_encoding: Forwarded to the embedding module.

        Returns:
            The encoder output after ``post_layernorm``; no pooling is
            applied.
        """
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

        encoder_outputs = self.encoder(inputs_embeds=hidden_states)

        last_hidden_state = self.post_layernorm(encoder_outputs)

        # TODO: add this back when pooled_output is used in inference
        # if self.use_head:
        # pooled_output = self.head(last_hidden_state)

        return last_hidden_state


class SiglipVisionModel(nn.Module):
    """Wrapper exposing ``SiglipVisionTransformer`` as ``vision_model``."""

    config_class = SiglipVisionConfig
    main_input_name = "pixel_values"

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: Optional[QuantizationConfig] = None,
        num_hidden_layers_override: Optional[int] = None,
    ):
        """Build the wrapped vision transformer.

        Args:
            config: Vision config.
            quant_config: Optional quantization config.
            num_hidden_layers_override: Number of encoder layers to build
                instead of the configured count.
        """
        super().__init__()
        self.vision_model = SiglipVisionTransformer(
            config,
            quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
        )

    def get_input_embeddings(self) -> nn.Module:
        """Return the patch embedding module."""
        return self.vision_model.embeddings.patch_embedding

    def forward(
        self,
        pixel_values: torch.Tensor,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        """Run the vision transformer and return its output.

        ``interpolate_pos_encoding`` is always passed on explicitly, so the
        default of ``False`` declared here is the one that applies.
        """
        return self.vision_model(
            pixel_values=pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load ``(name, tensor)`` pairs into this module's parameters.

        Weights for encoder layers at or beyond the number of layers that
        were actually built are skipped.

        Raises:
            KeyError: If a weight name that is not skipped has no matching
                parameter.
        """
        params_dict = dict(self.named_parameters())
        layer_count = len(self.vision_model.encoder.layers)
        layer_marker = "vision_model.encoder.layers."

        for name, loaded_weight in weights:
            # omit layers when num_hidden_layers_override is set
            if layer_marker in name:
                # Read the index right after the marker so that a prefix in
                # front of it does not shift the position being parsed.
                index_text = name.split(layer_marker, 1)[1].split(".", 1)[0]
                if int(index_text) >= layer_count:
                    continue

            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)