# SPDX-License-Identifier: Apache-2.0

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
import os
import re
from typing import Any, Dict, Iterable, Optional, Set, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, extract_layer_index,
                    is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)

# Vendor extensions (LLAMA_NN weight re-layout / padding helpers).
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

# Module-level logger; used by weight loading to warn about unmapped
# kv-scale entries found in a checkpoint.
logger = init_logger(__name__)

class Qwen2MoeMLP(nn.Module):
    """Gated feed-forward block: merged gate/up projection, SiLU-and-mul,
    then a down projection.

    Used as the dense MLP of non-sparse decoder layers and as the shared
    expert of ``Qwen2MoeSparseMoeBlock``.

    Args:
        hidden_size: Model hidden dimension (input and output width).
        intermediate_size: Width of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is supported.
        quant_config: Optional quantization config for both projections.
        reduce_results: Forwarded to the down projection. The sparse MoE
            block passes ``False`` and performs the all-reduce itself.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        # Validate before allocating the (possibly large) projection weights.
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        # Gate and up projections are fused into one column-parallel matmul
        # whose output is the two halves concatenated.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results)
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x`` (last dim ``hidden_size``)."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Qwen2MoeSparseMoeBlock(nn.Module):
    """Mixture-of-experts MLP with an optional sigmoid-gated shared expert.

    ``self.gate`` produces router logits over ``config.num_experts`` fused
    experts, of which ``config.num_experts_per_tok`` are used per token.
    When ``config.shared_expert_intermediate_size > 0`` a dense shared
    expert is also applied to every token and scaled by
    ``sigmoid(shared_expert_gate(x))``. Neither path reduces its own output
    across tensor-parallel ranks; one all-reduce is done on their sum when
    the TP world size is greater than 1.

    Raises:
        ValueError: If the TP world size exceeds ``config.num_experts``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        if self.tp_size > config.num_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}.")

        self.experts = FusedMoE(num_experts=config.num_experts,
                                top_k=config.num_experts_per_tok,
                                hidden_size=config.hidden_size,
                                intermediate_size=config.moe_intermediate_size,
                                reduce_results=False,
                                renormalize=config.norm_topk_prob,
                                quant_config=quant_config,
                                prefix=f"{prefix}.experts")

        # Router: kept unquantized and replicated on every rank.
        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.num_experts,
                                     bias=False,
                                     quant_config=None)
        if config.shared_expert_intermediate_size > 0:
            self.shared_expert = Qwen2MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.shared_expert_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                # Reduced together with the routed experts in forward().
                reduce_results=False,
            )
        else:
            self.shared_expert = None
        # Created unconditionally; only used in forward() when a shared
        # expert exists.
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
                                                  1,
                                                  bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts; output has the input's shape."""
        # NOTE: hidden_states can have either 1D or 2D shape.
        orig_shape = hidden_states.shape
        hidden_dim = hidden_states.shape[-1]
        hidden_states = hidden_states.view(-1, hidden_dim)
        shared_output = None
        if self.shared_expert is not None:
            shared_output = self.shared_expert(hidden_states)
            if self.shared_expert_gate is not None:
                shared_output = F.sigmoid(
                    self.shared_expert_gate(hidden_states)) * shared_output

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = self.experts(hidden_states=hidden_states,
                                           router_logits=router_logits)
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        # Single TP reduction for both the routed and the shared paths.
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(orig_shape)


class Qwen2MoeAttention(nn.Module):
    """Self-attention with a separate KV-head count and rotary embeddings.

    Query heads are split evenly across tensor-parallel ranks. KV heads are
    split across ranks when there are at least as many of them as ranks, and
    replicated otherwise (each rank keeps at least one KV head).
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embeddings, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


class Qwen2MoeDecoderLayer(nn.Module):
    """One transformer layer: self-attention followed by a dense or MoE MLP.

    The layer uses ``Qwen2MoeSparseMoeBlock`` when all of these hold:
    its index is not in ``config.mlp_only_layers``,
    ``config.num_experts > 0``, and ``(layer_idx + 1)`` is a multiple of
    ``config.decoder_sparse_step``. Otherwise it uses a dense
    ``Qwen2MoeMLP`` of width ``config.intermediate_size``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = Qwen2MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
        # `mlp_only_layers` in the config.
        layer_idx = extract_layer_index(prefix)
        mlp_only_layers = getattr(config, "mlp_only_layers", [])
        use_sparse_moe = (layer_idx not in mlp_only_layers
                          and config.num_experts > 0
                          and (layer_idx + 1) % config.decoder_sparse_step
                          == 0)
        if use_sparse_moe:
            self.mlp = Qwen2MoeSparseMoeBlock(config=config,
                                              quant_config=quant_config,
                                              prefix=f"{prefix}.mlp")
        else:
            self.mlp = Qwen2MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer.

        Returns ``(hidden_states, residual)``. The final residual add is
        left to the next layer's ``input_layernorm``, or to the model's final
        norm, via RMSNorm's two-argument form.
        """
        # Self Attention
        if residual is None:
            # First layer: the raw input becomes the residual stream.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


@support_torch_compile
class Qwen2MoeModel(nn.Module):
    """Qwen2-MoE decoder stack: token embedding, decoder layers, final norm.

    Supports pipeline parallelism: only the first rank embeds tokens, and
    non-last ranks return ``IntermediateTensors`` instead of normed states.

    Vendor switches, read from the environment at construction time:
    ``LLAMA_NN=1`` enables an in-place re-layout of selected unquantized
    weights after loading. ``GEMM_PAD``, ``FA_PAD`` and ``AWQ_PAD`` are
    recorded as attributes but not applied in this class.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: Qwen2MoeDecoderLayer(config=config,
                                                cache_config=cache_config,
                                                quant_config=quant_config,
                                                prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

        # Quantization method name, or None for unquantized models. The
        # LLAMA_NN re-layout in load_weights() only runs when this is None.
        self.quant_config = quant_config
        self.quant_method = None
        if quant_config is not None:
            self.quant_method = quant_config.get_name()

        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the vocab-parallel embedding table."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of decoder layers.

        On the first PP rank the input comes from ``inputs_embeds`` (if
        given) or from embedding ``input_ids``; later ranks read
        ``intermediate_tensors``. Non-last ranks return the un-normed
        ``hidden_states``/``residual`` pair for the next rank.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Checkpoint q/k/v and gate/up projections are loaded as shards of the
        fused ``qkv_proj`` / ``gate_up_proj`` parameters. Per-expert
        weights go to the FusedMoE parameters, and FP8 ``kv_scale`` entries
        are remapped to ``.attn.kv_scale``. If LLAMA_NN is enabled on an
        unquantized model, selected weights are then re-laid out in place.

        Returns:
            The set of parameter names that received a tensor.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            logger.warning_once(
                                "Found kv scale in the checkpoint "
                                f"(e.g. {name}), but not found the expected "
                                f"name in the model "
                                f"(e.g. {remapped_kv_scale_name}). "
                                "kv-scale is not loaded.")
                            continue
                        else:
                            name = remapped_kv_scale_name
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._relayout_weights_for_llama_nn(params_dict, loaded_params)
        return loaded_params

    def _relayout_weights_for_llama_nn(
            self, params_dict: Dict[str, nn.Parameter],
            loaded_params: Set[str]) -> None:
        """Re-lay out selected loaded weights in place for LLAMA_NN kernels.

        Every loaded parameter whose name contains one of the keywords below
        is passed through ``ops.trans_w16_gemm`` into a same-shaped buffer.
        The result is copied back and re-viewed as ``(shape[1], -1)``.
        """
        keywords = (
            "gate_up_proj.weight",
            "down_proj.weight",
            "mlp.gate.weight",
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            # NOTE(review): lm_head lives in Qwen2MoeForCausalLM, not in
            # this module, so this entry never matches here.
            "lm_head.weight",
        )
        # NOTE(review): presumably tells downstream code that the LM head
        # was not re-laid out; confirm against the consumer of LM_NN.
        if loaded_params:
            os.environ['LM_NN'] = '0'
        # NOTE: GEMM_PAD / FA_PAD padding via pad_weight() is not applied.
        for name in loaded_params:
            if not any(keyword in name for keyword in keywords):
                continue
            weight = params_dict[name]
            relaid = torch.zeros_like(weight.data)
            ori_shape = relaid.shape
            # Assumes a 2-D 16-bit weight; trans_w16_gemm presumably writes
            # its transpose into `relaid` (hence the (cols, -1) view below)
            # - TODO confirm against the kernel.
            ops.trans_w16_gemm(relaid, weight.data, ori_shape[0],
                               ori_shape[1])
            weight.data.copy_(relaid)
            weight.data = weight.data.reshape(ori_shape[1], -1)


class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
    """Causal language model: ``Qwen2MoeModel`` plus a parallel LM head.

    The LM head shares the embedding weight when
    ``config.tie_word_embeddings`` is set. Weight loading is delegated to
    ``AutoWeightsLoader``, skipping ``rotary_emb.inv_freq`` entries.
    """

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        quant_cfg = vllm_config.quant_config
        self.config = hf_config
        self.quant_config = quant_cfg

        self.model = Qwen2MoeModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(hf_config.vocab_size,
                                      hf_config.hidden_size,
                                      quant_config=quant_cfg)
        if hf_config.tie_word_embeddings:
            # One tensor serves as both input embedding and output head.
            self.lm_head.weight = self.model.embed_tokens.weight

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        # Re-exported so pipeline-parallel runners can allocate buffers.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the inner model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the decoder stack; logits are computed separately."""
        return self.model(input_ids, positions, intermediate_tensors,
                          inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project final hidden states to vocabulary logits."""
        return self.logits_processor(self.lm_head, hidden_states,
                                     sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint weights; returns the names that were loaded."""
        skipped = ["rotary_emb.inv_freq"]
        weights_loader = AutoWeightsLoader(self, skip_prefixes=skipped)
        return weights_loader.load_weights(weights)