t5.py 26 KB
Newer Older
hepj's avatar
hepj committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
# SPDX-License-Identifier: Apache-2.0
# Adapted from transformers: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/t5/modeling_t5.py

# Derived from T5 implementation posted on HuggingFace; license below:
#
# coding=utf-8
# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch T5 & UMT5 model."""

import math
from dataclasses import dataclass
from typing import Iterable, Optional, Set, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from fastvideo.v1.configs.models.encoders import BaseEncoderOutput, T5Config
from fastvideo.v1.configs.quantization import QuantizationConfig
from fastvideo.v1.distributed import (get_tensor_model_parallel_rank,
                                      get_tensor_model_parallel_world_size)
from fastvideo.v1.layers.activation import get_act_fn
from fastvideo.v1.layers.layernorm import RMSNorm
from fastvideo.v1.layers.linear import (MergedColumnParallelLinear,
                                        QKVParallelLinear, RowParallelLinear)
from fastvideo.v1.layers.vocab_parallel_embedding import VocabParallelEmbedding
from fastvideo.v1.models.encoders.base import TextEncoder
from fastvideo.v1.models.loader.weight_utils import default_weight_loader


class AttentionType:
    """Namespace of string tags identifying the kind of attention a layer runs.

    Plain strings (rather than an ``Enum``) are used so the values stay
    compatible with `torch.compile`.
    """
    DECODER = "decoder"  # decoder self-attention over previous-layer Q/K/V
    ENCODER = "encoder"  # encoder self-attention in an encoder-decoder model
    ENCODER_ONLY = "encoder_only"  # encoder self-attention, encoder-only model
    ENCODER_DECODER = "encoder_decoder"  # decoder Q attending to encoder K/V


@dataclass
class AttentionMetadata:
    """Mutable per-forward-pass container for the additive attention bias.

    The encoders in this file create it with ``attn_bias=None``; the first
    attention layer that owns a relative-position table computes the bias and
    stores it here so later layers can reuse it.
    """
    # Additive bias applied to the attention logits, or None until the first
    # layer with a relative attention bias has filled it in.
    attn_bias: Optional[torch.Tensor] = None


class T5DenseActDense(nn.Module):
    """Non-gated T5 feed-forward: ``wo(act(wi(x)))``.

    Args:
        config: Model configuration; ``d_model``, ``d_ff`` and
            ``dense_act_fn`` are read.
        quant_config: Optional quantization configuration forwarded to both
            linear layers.
    """

    def __init__(self,
                 config: T5Config,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        # quant_config is forwarded here as well so that the up-projection is
        # configured the same way as `wo` and the gated variant's projections.
        self.wi = MergedColumnParallelLinear(config.d_model, [config.d_ff],
                                             bias=False,
                                             quant_config=quant_config)
        self.wo = RowParallelLinear(config.d_ff,
                                    config.d_model,
                                    bias=False,
                                    quant_config=quant_config)
        self.act = get_act_fn(config.dense_act_fn)

    def forward(self, hidden_states) -> torch.Tensor:
        """Apply up-projection, activation and down-projection."""
        # The parallel linear layers return a tuple; only the first element
        # (the projected tensor) is used.
        hidden_states, _ = self.wi(hidden_states)
        hidden_states = self.act(hidden_states)
        hidden_states, _ = self.wo(hidden_states)
        return hidden_states


class T5DenseGatedActDense(nn.Module):
    """Gated T5 feed-forward: ``wo(act(wi_0(x)) * wi_1(x))``.

    Args:
        config: Model configuration; ``d_model``, ``d_ff`` and
            ``dense_act_fn`` are read.
        quant_config: Optional quantization configuration forwarded to all
            three linear layers.
    """

    def __init__(self,
                 config: T5Config,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        d_model, d_ff = config.d_model, config.d_ff
        # Gate branch (goes through the activation).
        self.wi_0 = MergedColumnParallelLinear(d_model, [d_ff],
                                               bias=False,
                                               quant_config=quant_config)
        # Linear branch (multiplied with the activated gate).
        self.wi_1 = MergedColumnParallelLinear(d_model, [d_ff],
                                               bias=False,
                                               quant_config=quant_config)
        # Should not run in fp16 unless mixed-precision is used,
        # see https://github.com/huggingface/transformers/issues/20287.
        self.wo = RowParallelLinear(d_ff,
                                    d_model,
                                    bias=False,
                                    quant_config=quant_config)
        self.act = get_act_fn(config.dense_act_fn)

    def forward(self, hidden_states) -> torch.Tensor:
        """Return the gated feed-forward transform of ``hidden_states``."""
        # Each parallel linear returns a tuple; element 0 is the projection.
        gate_proj, _ = self.wi_0(hidden_states)
        gate = self.act(gate_proj)
        linear, _ = self.wi_1(hidden_states)
        projected, _ = self.wo(gate * linear)
        return projected


class T5LayerFF(nn.Module):
    """Pre-norm residual feed-forward sub-layer of a T5 block.

    The inner network is the gated variant when ``config.is_gated_act`` is
    truthy and the plain variant otherwise.
    """

    def __init__(self,
                 config: T5Config,
                 quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        dense_cls = (T5DenseGatedActDense
                     if config.is_gated_act else T5DenseActDense)
        self.DenseReluDense = dense_cls(config, quant_config=quant_config)
        self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)

    def forward(self, hidden_states) -> torch.Tensor:
        """Return ``hidden_states + FF(norm(hidden_states))``."""
        normed = self.layer_norm.forward_native(hidden_states)
        return hidden_states + self.DenseReluDense(normed)


# T5 has attn_bias and does not use softmax scaling
class T5MultiHeadAttention(nn.Module):
    """Unscaled dot-product attention with an optional additive bias.

    The logits are ``q . k`` with no ``1/sqrt(d)`` factor; ``attn_bias`` (when
    given) is added to them in place before the softmax.
    """

    def forward(self, q, k, v, attn_bias=None):
        """Attend ``q`` over ``k``/``v``.

        Args:
            q: Queries laid out as (batch, q_len, heads, head_dim).
            k: Keys laid out as (batch, kv_len, heads, head_dim).
            v: Values laid out as (batch, kv_len, heads, head_dim).
            attn_bias: Optional tensor added to the
                (batch, heads, q_len, kv_len) logits.

        Returns:
            Tensor of shape (batch, q_len, heads * head_dim).
        """
        batch, _, heads, head_dim = q.shape
        scores = torch.einsum('binc,bjnc->bnij', q, k)
        if attn_bias is not None:
            scores += attn_bias

        # Softmax is evaluated in float32 and cast back to the logits' dtype.
        probs = F.softmax(scores.float(), dim=-1).type_as(scores)
        context = torch.einsum('bnij,bjnc->binc', probs, v)
        return context.reshape(batch, -1, heads * head_dim)


class T5Attention(nn.Module):
    """T5 attention layer with fused QKV projection and additive bias.

    Q, K and V are all projected from the same ``hidden_states`` through
    ``qkv_proj``; heads are split evenly across tensor-parallel ranks. The
    additive bias is either computed here from the relative-position table
    (when ``has_relative_attention_bias`` is set) or taken from
    ``attn_metadata.attn_bias`` as cached by an earlier layer.

    Args:
        config: Model configuration; ``d_model``, ``d_kv``, ``num_heads``,
            ``relative_attention_num_buckets`` and
            ``relative_attention_max_distance`` are read.
        attn_type: One of the ``AttentionType`` string tags.
        has_relative_attention_bias: Whether this layer owns a
            relative-position embedding table.
        quant_config: Optional quantization configuration forwarded to the
            parallel layers.
        prefix: Parameter-name prefix forwarded to the projections.
    """

    def __init__(self,
                 config: T5Config,
                 attn_type: str,
                 has_relative_attention_bias=False,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.attn_type = attn_type
        # Cross-attention has no relative pos encoding anyway
        self.is_decoder = attn_type == AttentionType.DECODER
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = \
            config.relative_attention_num_buckets
        self.relative_attention_max_distance = \
            config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.total_num_heads = self.total_num_kv_heads = config.num_heads

        # Partition heads across multiple tensor parallel GPUs.
        tp_world_size = get_tensor_model_parallel_world_size()
        assert config.num_heads % tp_world_size == 0
        self.n_heads = config.num_heads // tp_world_size

        # Per-rank width of each of Q, K and V. No GQA in t5.
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        # The per-head width is d_kv. T5 does not require
        # d_kv == d_model // num_heads, so d_kv is used for the projection
        # sizes to keep them consistent with `inner_dim` above.
        self.qkv_proj = QKVParallelLinear(
            self.d_model,
            self.key_value_proj_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.attn = T5MultiHeadAttention()

        if self.has_relative_attention_bias:
            self.relative_attention_bias = \
                VocabParallelEmbedding(self.relative_attention_num_buckets,
                                       self.total_num_heads,
                                       org_num_embeddings=self.relative_attention_num_buckets,
                                       padding_size=self.relative_attention_num_buckets,
                                       quant_config=quant_config)
        # Output projection maps the concatenated heads back to d_model.
        self.o = RowParallelLinear(
            self.total_num_heads * self.key_value_proj_dim,
            self.d_model,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    @staticmethod
    def _relative_position_bucket(relative_position,
                                  bidirectional=True,
                                  num_buckets=32,
                                  max_distance=128) -> torch.Tensor:
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
        Translate relative position to a bucket number for relative attention.
        The relative position is defined as memory_position - query_position,
        i.e. the distance in tokens from the attending position to the
        attended-to position. If bidirectional=False, then positive relative
        positions are invalid. We use smaller buckets for small absolute
        relative_position and larger buckets for larger absolute
        relative_positions. All relative positions >=max_distance map to the
        same bucket. All relative positions <=-max_distance map to the same
        bucket. This should allow for more graceful generalization to longer
        sequences than the model has been trained on
        Args:
            relative_position: an integer Tensor
            bidirectional: a boolean - whether the attention is bidirectional
            num_buckets: an integer
            max_distance: an integer
        Returns:
            a Tensor with the same shape as relative_position, containing
            integer values in the range [0, num_buckets)
        """# noqa: E501
        relative_buckets = 0
        if bidirectional:
            # Half of the buckets encode the sign of the offset.
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(
                torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position,
                                           torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)

        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # The other half of the buckets are for logarithmically bigger bins
        # in positions up to max_distance. For position 0 the log is -inf;
        # that entry is discarded by the `torch.where` below.
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact) /
            math.log(max_distance / max_exact) *
            (num_buckets - max_exact)).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large,
            torch.full_like(relative_position_if_large, num_buckets - 1))

        relative_buckets += torch.where(is_small, relative_position,
                                        relative_position_if_large)
        return relative_buckets

    def compute_bias(self,
                     query_length,
                     key_length,
                     device=None) -> torch.Tensor:
        """Compute binned relative position bias.

        Args:
            query_length: Number of query positions.
            key_length: Number of key positions.
            device: Device for the position indices; defaults to the device
                of the relative-attention embedding weight.

        Returns:
            Tensor of shape (1, num_heads, query_length, key_length).
        """
        if device is None:
            device = self.relative_attention_bias.weight.device
        context_position = torch.arange(query_length,
                                        dtype=torch.long,
                                        device=device)[:, None]
        memory_position = torch.arange(key_length,
                                       dtype=torch.long,
                                       device=device)[None, :]
        # shape (query_length, key_length)
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(
            relative_position_bucket
        )  # shape (query_length, key_length, num_heads)
        x = values.permute([2, 0, 1]).unsqueeze(
            0)  # shape (1, num_heads, query_length, key_length)
        return x

    def forward(
        self,
        hidden_states: torch.Tensor,  # (batch, seq_len, d_model)
        attention_mask: Optional[torch.Tensor] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            hidden_states: 3-D input, unpacked as (batch, seq_len, d_model).
            attention_mask: Optional mask; positions equal to 0 are filled
                with the most negative finite value in the bias. A 2-D mask
                is viewed as (batch, 1, 1, -1); any other rank gets a
                singleton dimension inserted at index 1. ``None`` skips
                masking.
            attn_metadata: Required carrier of the (cached) attention bias;
                updated in place when this layer computes the bias.

        Returns:
            The output-projected attention result.

        Raises:
            AssertionError: If ``attn_metadata`` is None, if no cached bias is
                available for a layer without its own bias table, or if a
                layer with a bias table is not of type ``ENCODER``.
        """
        bs, seq_len, _ = hidden_states.shape
        n, c = self.n_heads, self.key_value_proj_dim
        qkv, _ = self.qkv_proj(hidden_states)
        # Projection of 'own' hidden state (self-attention). No GQA here.
        q, k, v = qkv.split(self.inner_dim, dim=-1)
        q = q.reshape(bs, seq_len, n, c)
        k = k.reshape(bs, seq_len, n, c)
        v = v.reshape(bs, seq_len, n, c)

        assert attn_metadata is not None
        attn_bias = attn_metadata.attn_bias
        # Not compatible with CP here (as all encoder-decoder models),
        # as it assumes homogeneous batch (prefills or decodes).
        if self.has_relative_attention_bias:
            # Self-attention. Compute T5 relative positional encoding and
            # replicate it for every sequence in the batch. `repeat` returns
            # a fresh tensor, so the in-place masking below does not touch the
            # embedding output.
            assert self.attn_type == AttentionType.ENCODER
            attn_bias = self.compute_bias(seq_len,
                                          seq_len).repeat(bs, 1, 1, 1)
            attn_metadata.attn_bias = attn_bias
        else:
            # Encoder/Decoder Self-Attention Layer, attn bias already cached.
            assert attn_bias is not None

        if attention_mask is not None:
            attention_mask = attention_mask.view(
                bs, 1, 1,
                -1) if attention_mask.ndim == 2 else attention_mask.unsqueeze(1)
            # In place: the tensor held by attn_metadata is masked as well,
            # so layers reusing the cached bias see the masked version.
            attn_bias.masked_fill_(attention_mask == 0,
                                   torch.finfo(q.dtype).min)

        if get_tensor_model_parallel_world_size() > 1:
            # The bias covers all heads; keep only this rank's slice.
            rank = get_tensor_model_parallel_rank()
            attn_bias = attn_bias[:, rank * self.n_heads:(rank + 1) *
                                  self.n_heads, :, :]
        attn_output = self.attn(q, k, v, attn_bias)
        output, _ = self.o(attn_output)
        return output


class T5LayerSelfAttention(nn.Module):
    """Pre-norm residual self-attention sub-layer of a T5 block.

    The attention type is ``DECODER`` when the substring ``"decoder"`` occurs
    in ``prefix`` and ``ENCODER`` otherwise.
    """

    def __init__(
        self,
        config,
        has_relative_attention_bias=False,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        attn_type = (AttentionType.DECODER
                     if "decoder" in prefix else AttentionType.ENCODER)
        self.SelfAttention = T5Attention(
            config,
            attn_type,
            has_relative_attention_bias=has_relative_attention_bias,
            quant_config=quant_config,
            prefix=f"{prefix}.SelfAttention")
        self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        attn_metadata: Optional[AttentionMetadata] = None,
    ) -> torch.Tensor:
        """Return ``hidden_states + SelfAttention(norm(hidden_states))``."""
        attn_out = self.SelfAttention(
            hidden_states=self.layer_norm.forward_native(hidden_states),
            attention_mask=attention_mask,
            attn_metadata=attn_metadata,
        )
        return hidden_states + attn_out


class T5LayerCrossAttention(nn.Module):
    """Pre-norm residual sub-layer wrapping an ``ENCODER_DECODER`` attention.

    NOTE(review): the wrapped ``T5Attention`` projects Q, K and V from the
    same ``hidden_states`` and is given no encoder output here — confirm
    whether true cross-attention is intended before using the decoder path.
    """

    def __init__(self,
                 config,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.EncDecAttention = T5Attention(config,
                                           AttentionType.ENCODER_DECODER,
                                           has_relative_attention_bias=False,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.EncDecAttention")
        self.layer_norm = RMSNorm(config.d_model, eps=config.layer_norm_epsilon)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attn_metadata: Optional[AttentionMetadata] = None,
    ) -> torch.Tensor:
        """Return ``hidden_states + EncDecAttention(norm(hidden_states))``."""
        normed_hidden_states = self.layer_norm.forward_native(hidden_states)
        # T5Attention.forward takes an `attention_mask` argument; pass None
        # explicitly (which skips padding masking) so the call does not fail
        # with a missing-argument TypeError.
        attention_output = self.EncDecAttention(
            hidden_states=normed_hidden_states,
            attention_mask=None,
            attn_metadata=attn_metadata,
        )
        hidden_states = hidden_states + attention_output
        return hidden_states


class T5Block(nn.Module):
    """One T5 transformer block.

    ``self.layer`` holds, in order: self-attention, cross-attention (decoder
    blocks only) and the feed-forward sub-layer.
    """

    def __init__(self,
                 config: T5Config,
                 is_decoder: bool,
                 has_relative_attention_bias=False,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.is_decoder = is_decoder

        sublayers = [
            T5LayerSelfAttention(
                config,
                has_relative_attention_bias=has_relative_attention_bias,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn")
        ]
        if self.is_decoder:
            sublayers.append(
                T5LayerCrossAttention(config,
                                      quant_config=quant_config,
                                      prefix=f"{prefix}.cross_attn"))
        sublayers.append(T5LayerFF(config, quant_config=quant_config))
        self.layer = nn.ModuleList(sublayers)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        attn_metadata: Optional[AttentionMetadata] = None,
    ) -> torch.Tensor:
        """Run the block's sub-layers in order and return the result."""
        out = self.layer[0](hidden_states=hidden_states,
                            attention_mask=attention_mask,
                            attn_metadata=attn_metadata)
        if self.is_decoder:
            out = self.layer[1](hidden_states=out,
                                attn_metadata=attn_metadata)

        # The feed-forward sub-layer is always last (index 2 for decoder
        # blocks, index 1 otherwise).
        return self.layer[-1](out)


class T5Stack(nn.Module):
    """Embedding lookup, a stack of ``T5Block`` layers and a final RMSNorm.

    Args:
        config: Model configuration.
        is_decoder: Forwarded to every block.
        n_layers: Number of blocks to build.
        embed_tokens: Module used to embed ``input_ids``.
        quant_config: Optional quantization configuration for the blocks.
        prefix: Parameter-name prefix; block ``i`` gets
            ``f"{prefix}.blocks.{i}"``.
        is_umt5: When true every block owns a relative attention bias;
            otherwise only block 0 does.
    """

    def __init__(self,
                 config: T5Config,
                 is_decoder: bool,
                 n_layers: int,
                 embed_tokens=None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "",
                 is_umt5: bool = False):
        super().__init__()
        self.embed_tokens = embed_tokens
        self.is_umt5 = is_umt5

        blocks = []
        for i in range(n_layers):
            # UMT5: per-layer relative positional encoding.
            # T5: only the first block has relative positional encoding.
            owns_bias = True if is_umt5 else i == 0
            blocks.append(
                T5Block(config,
                        is_decoder=is_decoder,
                        has_relative_attention_bias=owns_bias,
                        quant_config=quant_config,
                        prefix=f"{prefix}.blocks.{i}"))
        self.block = nn.ModuleList(blocks)
        self.final_layer_norm = RMSNorm(config.d_model,
                                        eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed ``input_ids``, apply every block, then the final norm."""
        states = self.embed_tokens(input_ids)
        for layer in self.block:
            states = layer(
                hidden_states=states,
                attention_mask=attention_mask,
                attn_metadata=attn_metadata,
            )
        return self.final_layer_norm.forward_native(states)


class T5EncoderModel(TextEncoder):
    """Encoder-only T5 text encoder.

    Only block 0 of the stack owns a relative attention bias
    (``is_umt5=False``); later blocks reuse the bias cached in the
    ``AttentionMetadata`` created per forward call.
    """

    def __init__(self, config: T5Config, prefix: str = ""):
        super().__init__(config)

        quant_config = None

        self.shared = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
            org_num_embeddings=config.vocab_size)

        self.encoder = T5Stack(config,
                               False,
                               config.num_layers,
                               self.shared,
                               quant_config=quant_config,
                               prefix=f"{prefix}.encoder",
                               is_umt5=False)

    def get_input_embeddings(self):
        """Return the shared token-embedding module."""
        return self.shared

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> BaseEncoderOutput:
        """Encode ``input_ids`` and return the final hidden states.

        ``position_ids``, ``inputs_embeds``, ``output_hidden_states`` and any
        extra keyword arguments are accepted but not used by this method.
        """
        # Fresh metadata per call; the bias is filled in by the first block.
        attn_metadata = AttentionMetadata(None)
        hidden_states = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            attn_metadata=attn_metadata,
        )

        return BaseEncoderOutput(last_hidden_state=hidden_states)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Weights whose name contains ``"decoder"`` or ``"lm_head"`` are
        ignored. Separate ``.q``/``.k``/``.v`` tensors are routed into the
        fused ``.qkv_proj`` parameter with the matching shard id. Names that
        have no corresponding parameter (including extra biases, e.g. from
        GPTQ checkpoints) are skipped.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names that received a weight.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q", "q"),
            (".qkv_proj", ".k", "k"),
            (".qkv_proj", ".v", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "decoder" in name or "lm_head" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Build the fused name in a separate variable: a substring
                # hit that does not resolve to a real parameter must not
                # alter `name` for the remaining checks or the fallback.
                fused_name = name.replace(weight_name, param_name)
                if fused_name not in params_dict:
                    continue

                param = params_dict[fused_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(fused_name)
                break
            else:
                # Not a stacked shard: load under the checkpoint name.
                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
        return loaded_params


class UMT5EncoderModel(TextEncoder):
    """Encoder-only UMT5 text encoder.

    Every block of the stack owns its own relative attention bias
    (``is_umt5=True``). The returned output also carries the attention mask
    that was passed in.
    """

    def __init__(self, config: T5Config, prefix: str = ""):
        super().__init__(config)

        quant_config = None

        self.shared = VocabParallelEmbedding(
            config.vocab_size,
            config.d_model,
            org_num_embeddings=config.vocab_size)

        self.encoder = T5Stack(config,
                               False,
                               config.num_layers,
                               self.shared,
                               quant_config=quant_config,
                               prefix=f"{prefix}.encoder",
                               is_umt5=True)

    def get_input_embeddings(self):
        """Return the shared token-embedding module."""
        return self.shared

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> BaseEncoderOutput:
        """Encode ``input_ids``; return final hidden states and the mask.

        ``position_ids``, ``inputs_embeds``, ``output_hidden_states`` and any
        extra keyword arguments are accepted but not used by this method.
        """
        # Fresh metadata per call; each block overwrites the bias it holds.
        attn_metadata = AttentionMetadata(None)
        hidden_states = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            attn_metadata=attn_metadata,
        )

        return BaseEncoderOutput(
            last_hidden_state=hidden_states,
            attention_mask=attention_mask,
        )

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Weights whose name contains ``"decoder"`` or ``"lm_head"`` are
        ignored. Separate ``.q``/``.k``/``.v`` tensors are routed into the
        fused ``.qkv_proj`` parameter with the matching shard id. Names that
        have no corresponding parameter (including extra biases, e.g. from
        GPTQ checkpoints) are skipped.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Returns:
            The set of parameter names that received a weight.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q", "q"),
            (".qkv_proj", ".k", "k"),
            (".qkv_proj", ".v", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "decoder" in name or "lm_head" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Build the fused name in a separate variable: a substring
                # hit that does not resolve to a real parameter must not
                # alter `name` for the remaining checks or the fallback.
                fused_name = name.replace(weight_name, param_name)
                if fused_name not in params_dict:
                    continue

                param = params_dict[fused_name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(fused_name)
                break
            else:
                # Not a stacked shard: load under the checkpoint name.
                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
        return loaded_params