mixtral.py 20.3 KB
Newer Older
Pierre Stock's avatar
Pierre Stock committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
24
from typing import Iterable, List, Optional, Tuple
Pierre Stock's avatar
Pierre Stock committed
25
26
27

import torch
from torch import nn
28
from transformers import MixtralConfig
Pierre Stock's avatar
Pierre Stock committed
29

30
from vllm import _custom_ops as ops
31
from vllm.attention import Attention, AttentionMetadata
Terry's avatar
Terry committed
32
from vllm.config import LoRAConfig
33
34
35
from vllm.distributed import (get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
Philipp Moritz's avatar
Philipp Moritz committed
36
from vllm.model_executor.layers.fused_moe import fused_moe
Pierre Stock's avatar
Pierre Stock committed
37
from vllm.model_executor.layers.layernorm import RMSNorm
38
from vllm.model_executor.layers.linear import (QKVParallelLinear,
Philipp Moritz's avatar
Philipp Moritz committed
39
                                               ReplicatedLinear,
Pierre Stock's avatar
Pierre Stock committed
40
                                               RowParallelLinear)
41
from vllm.model_executor.layers.logits_processor import LogitsProcessor
42
43
44
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
45
from vllm.model_executor.layers.rotary_embedding import get_rope
Pierre Stock's avatar
Pierre Stock committed
46
47
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
48
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
49
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
Pierre Stock's avatar
Pierre Stock committed
50
from vllm.model_executor.sampling_metadata import SamplingMetadata
Philipp Moritz's avatar
Philipp Moritz committed
51
from vllm.model_executor.utils import set_weight_attrs
Pierre Stock's avatar
Pierre Stock committed
52
from vllm.sequence import SamplerOutput
53
from vllm.utils import print_warning_once
Pierre Stock's avatar
Pierre Stock committed
54
55


Philipp Moritz's avatar
Philipp Moritz committed
56
57
58
59
60
61
62
63
class MixtralMoE(nn.Module):
    """Tensor-parallel mixture-of-experts layer for Mixtral.

    The intermediate dimension of every expert is split evenly across the
    tensor-parallel ranks. The forward pass runs a fused MoE kernel on the
    local shards and, when more than one rank is in use, all-reduces the
    partial results.
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: Optional[torch.dtype] = None,
        tp_size: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        """Allocate the router, the stacked expert weights and FP8 scales.

        Args:
            num_experts: Total number of experts held by this layer.
            top_k: Value forwarded to the fused MoE kernel as its top-k
                argument.
            hidden_size: Size of the last dimension of the layer input.
            intermediate_size: Unsharded intermediate size of one expert; it
                is floor-divided by the tensor-parallel size.
            params_dtype: Dtype of the expert weights; the torch default
                dtype is used when omitted.
            tp_size: Tensor-parallel size; queried from the distributed
                state when omitted (or zero).
            quant_config: Quantization config. Only ``Fp8Config`` instances
                switch this layer to FP8.
        """
        super().__init__()
        self.tp_size = tp_size or get_tensor_model_parallel_world_size()
        self.num_total_experts = num_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        # Width of the intermediate dimension owned by this rank.
        self.intermediate_size = intermediate_size // self.tp_size
        # FIXME(pcmoritz): Make this more general to support different
        # quantization schemes
        self.use_fp8 = isinstance(quant_config, Fp8Config)

        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)

        # The router is replicated on every rank and never quantized.
        self.gate = ReplicatedLinear(self.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     params_dtype=self.params_dtype,
                                     quant_config=None)

        # `ws` stacks the w1 and w3 shards of every expert along dim 1,
        # `w2s` holds the w2 shard of every expert.
        self.ws = nn.Parameter(
            torch.empty(self.num_total_experts,
                        2 * self.intermediate_size,
                        self.hidden_size,
                        dtype=self.params_dtype))
        self.w2s = nn.Parameter(
            torch.empty(self.num_total_experts,
                        self.hidden_size,
                        self.intermediate_size,
                        dtype=self.params_dtype))
        for expert_weights in (self.ws, self.w2s):
            set_weight_attrs(expert_weights,
                             {"weight_loader": self.weight_loader})

        # Scaling factors for FP8 weights
        if self.use_fp8:
            self.ws_scale = nn.Parameter(torch.ones(self.num_total_experts,
                                                    dtype=torch.float32),
                                         requires_grad=False)
            self.w2s_scale = nn.Parameter(torch.ones(self.num_total_experts,
                                                     dtype=torch.float32),
                                          requires_grad=False)
        else:
            self.ws_scale = None
            self.w2s_scale = None

        # Scaling factors for FP8 activations
        need_act_scales = (self.use_fp8
                           and quant_config.activation_scheme == "static")
        if need_act_scales:
            self.as_scale = nn.Parameter(torch.zeros(1, dtype=torch.float32),
                                         requires_grad=False)
            self.a2s_scale = nn.Parameter(torch.zeros(1, dtype=torch.float32),
                                          requires_grad=False)
            for act_scale in (self.as_scale, self.a2s_scale):
                set_weight_attrs(act_scale,
                                 {"weight_loader": self.weight_loader})
        else:
            self.as_scale = None
            self.a2s_scale = None

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
                      weight_name: str, expert_id: int):
        """Copy this rank's part of a checkpoint tensor into ``param``.

        Args:
            param: Destination parameter (stacked expert weights or an
                activation scale).
            loaded_weight: Tensor read from the checkpoint.
            weight_name: Checkpoint key fragment; its suffix selects the slot
                that is written.
            expert_id: Index of the expert the tensor belongs to.
        """
        rank = get_tensor_model_parallel_rank()
        dest = param.data
        width = self.intermediate_size
        local = slice(rank * width, (rank + 1) * width)

        # The three suffixes are mutually exclusive, so at most one branch
        # runs.
        if weight_name.endswith("w1.weight"):
            dest[expert_id, 0:width, :] = loaded_weight[local, :]
        elif weight_name.endswith("w3.weight"):
            dest[expert_id, width:2 * width, :] = loaded_weight[local, :]
        elif weight_name.endswith("w2.weight"):
            dest[expert_id, :, :] = loaded_weight[:, local]

        # Activation scales are merged by keeping the running maximum over
        # everything loaded into this parameter so far.
        if "act_scale" in weight_name:
            dest[:] = dest[:].max(loaded_weight)

    def process_weights_after_loading(self):
        """Quantize the loaded expert weights to FP8; no-op otherwise."""
        if not self.use_fp8:
            return

        fp8_ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn)
        fp8_w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn)
        for expert in range(self.num_total_experts):
            # Each call yields the quantized tensor and its scale; the
            # scale is stored per expert.
            fp8_ws[expert, :, :], self.ws_scale[expert] = (
                ops.scaled_fp8_quant(self.ws.data[expert, :, :]))
            fp8_w2s[expert, :, :], self.w2s_scale[expert] = (
                ops.scaled_fp8_quant(self.w2s.data[expert, :, :]))
        # Replace the higher-precision parameters with the FP8 copies.
        self.ws = nn.Parameter(fp8_ws, requires_grad=False)
        self.w2s = nn.Parameter(fp8_w2s, requires_grad=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route the tokens through the experts.

        Args:
            hidden_states: 2-D tensor; its shape is unpacked as
                ``(num_tokens, hidden_size)``.

        Returns:
            The expert output, viewed back to the shape of the input.
        """
        num_tokens, hidden_size = hidden_states.shape
        flat_states = hidden_states.view(-1, self.hidden_size)
        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(flat_states)
        expert_out = fused_moe(flat_states,
                               self.ws,
                               self.w2s,
                               router_logits,
                               self.top_k,
                               renormalize=True,
                               inplace=True,
                               use_fp8=self.use_fp8,
                               w1_scale=self.ws_scale,
                               w2_scale=self.w2s_scale,
                               a1_scale=self.as_scale,
                               a2_scale=self.a2s_scale)

        # Each rank only holds a slice of every expert, so the partial
        # outputs are reduced across ranks.
        if self.tp_size > 1:
            expert_out = tensor_model_parallel_all_reduce(expert_out)

        return expert_out.view(num_tokens, hidden_size)
Pierre Stock's avatar
Pierre Stock committed
190
191
192
193
194
195
196
197
198
199


class MixtralAttention(nn.Module):
    """Self-attention block: fused QKV projection, rotary embedding,
    attention and output projection, with heads split across
    tensor-parallel ranks."""

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 rope_theta: float = 10000,
                 quant_config: Optional[QuantizationConfig] = None,
                 sliding_window: Optional[int] = None) -> None:
        """Build the projections and derive the per-rank head layout.

        Args:
            hidden_size: Model hidden size.
            num_heads: Total number of query heads; must be divisible by the
                tensor-parallel size.
            num_kv_heads: Total number of key/value heads.
            max_position: Maximum position passed to the rotary embedding.
            rope_theta: Rotary base; truncated to ``int`` before use.
            quant_config: Quantization config for the linear layers. An
                ``Fp8Config`` is ignored here (with a one-time warning).
            sliding_window: Sliding-window size forwarded to ``Attention``.
        """
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = num_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = self.total_num_heads // world_size

        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= world_size:
            # At least as many KV heads as ranks: the KV heads are
            # partitioned across the tensor-parallel GPUs.
            assert self.total_num_kv_heads % world_size == 0
        else:
            # Fewer KV heads than ranks: the KV heads are replicated across
            # the tensor-parallel GPUs.
            assert world_size % self.total_num_kv_heads == 0
        # Every rank keeps at least one KV head.
        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)

        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        if isinstance(quant_config, Fp8Config):
            print_warning_once(
                "For Mixtral FP8 quantization, we currently do not quantize "
                "the attention layers until their FP8 performance is improved."
            )
            # Fall back to unquantized attention projections.
            quant_config = None

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            sliding_window=self.sliding_window,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Apply attention to ``hidden_states``.

        Args:
            positions: Token positions handed to the rotary embedding.
            hidden_states: Input to the QKV projection.
            kv_cache: Cache tensor forwarded to ``Attention``.
            attn_metadata: Attention metadata forwarded to ``Attention``.

        Returns:
            The output of the output projection.
        """
        fused_qkv, _ = self.qkv_proj(hidden_states)
        # The fused projection is laid out as [q | k | v] on the last dim.
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = fused_qkv.split(split_sizes, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        context = self.attn(q, k, v, kv_cache, attn_metadata)
        projected, _ = self.o_proj(context)
        return projected


class MixtralDecoderLayer(nn.Module):
    """One Mixtral transformer layer: RMSNorm -> self-attention -> RMSNorm
    -> sparse mixture-of-experts, with the residual stream passed alongside
    the hidden states."""

    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        """Build the attention block, the MoE block and both norms.

        Args:
            config: Model configuration the layer sizes are read from.
            quant_config: Quantization config handed to the attention and
                MoE sub-modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            quant_config=quant_config)
        self.block_sparse_moe = MixtralMoE(
            num_experts=config.num_local_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            quant_config=quant_config)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the layer.

        Args:
            positions: Token positions for the attention block.
            hidden_states: Layer input.
            kv_cache: Cache tensor for this layer's attention.
            attn_metadata: Attention metadata.
            residual: Residual stream from the previous layer, or ``None``
                when there is none yet (the input itself then becomes the
                residual).

        Returns:
            A ``(hidden_states, residual)`` pair: the MoE output and the
            residual produced by the post-attention norm.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Called with a residual, the norm returns the normalized
            # states together with the updated residual.
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual
Pierre Stock's avatar
Pierre Stock committed
333

334
335

class MixtralModel(nn.Module):
    """Mixtral backbone: token embedding, the decoder-layer stack and the
    final RMSNorm."""

    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Create the embedding, the decoder layers and the final norm.

        Args:
            config: Model configuration.
            quant_config: Quantization config passed to every decoder layer.
            lora_config: When given, the embedding vocabulary is enlarged by
                ``lora_extra_vocab_size * (max_loras or 1)`` entries.
        """
        super().__init__()
        self.padding_idx = config.pad_token_id

        if lora_config:
            extra_vocab = (lora_config.lora_extra_vocab_size *
                           (lora_config.max_loras or 1))
        else:
            extra_vocab = 0
        self.vocab_size = config.vocab_size + extra_vocab
        self.org_vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        decoder_layers = [
            MixtralDecoderLayer(config, quant_config=quant_config)
            for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Embed the tokens and run them through every decoder layer.

        Args:
            input_ids: Token ids to embed.
            positions: Token positions forwarded to each layer.
            kv_caches: One cache tensor per decoder layer, indexed by layer.
            attn_metadata: Attention metadata forwarded to each layer.

        Returns:
            The hidden states after the final norm.
        """
        hidden_states = self.embed_tokens(input_ids)
        # The first layer starts without a residual stream.
        residual = None
        for layer_idx, layer in enumerate(self.layers):
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[layer_idx],
                                            attn_metadata, residual)
        # The final norm also returns a residual, which is not needed.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MixtralForCausalLM(nn.Module):
    """Mixtral backbone plus LM head, logits processor and sampler, together
    with the checkpoint-loading logic."""

    fall_back_to_pt_during_load = False

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    supported_lora_modules = [
        "qkv_proj",
        "o_proj",
        "embed_tokens",
        "lm_head",
    ]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]

    def __init__(
        self,
        config: MixtralConfig,
        quant_config: Optional[QuantizationConfig] = None,
        lora_config: Optional[LoRAConfig] = None,
    ) -> None:
        """Build the backbone, LM head, logits processor and sampler.

        Args:
            config: Model configuration.
            quant_config: Quantization config passed to the backbone.
            lora_config: Optional LoRA config; enlarges the LM-head
                vocabulary and selects the LoRA vocabulary padding size.
        """
        super().__init__()
        self.config = config
        self.model = MixtralModel(config,
                                  quant_config,
                                  lora_config=lora_config)

        self.unpadded_vocab_size = config.vocab_size
        if lora_config:
            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size

        # We need bigger padding if using lora for kernel compatibility
        if lora_config:
            vocab_padding = lora_config.lora_vocab_padding_size
        else:
            vocab_padding = DEFAULT_VOCAB_PADDING_SIZE
        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=vocab_padding,
        )
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                config.vocab_size)
        self.sampler = Sampler()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Return the backbone's hidden states for the given tokens."""
        return self.model(input_ids, positions, kv_caches, attn_metadata)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Turn hidden states into logits using the LM-head weight."""
        return self.logits_processor(self.lm_head.weight, hidden_states,
                                     sampling_metadata)

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Run the sampler on ``logits`` and return its output."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this module's parameters.

        Each checkpoint name is tried, in order, against the fused-QKV
        mapping, then the expert mapping, and otherwise loaded under its own
        name with the parameter's ``weight_loader`` (or the default loader).

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.
        """
        # (param_name, shard_name, shard_id)
        qkv_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # (param_name, weight_name, expert_id): first the expert weights,
        # then the expert activation scales. w1/w3 share one stacked
        # parameter, w2 has its own.
        expert_mapping = []
        for suffix, w13_param, w2_param in (
            ("weight", "ws", "w2s"),
            ("act_scale", "as_scale", "a2s_scale"),
        ):
            for expert_id in range(self.config.num_local_experts):
                for proj in ["w1", "w2", "w3"]:
                    target = w13_param if proj in ["w1", "w3"] else w2_param
                    expert_mapping.append(
                        (target, f"experts.{expert_id}.{proj}.{suffix}",
                         expert_id))

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # NOTE: `name` is rewritten in place inside the loops below and
            # the later checks intentionally see the rewritten value.
            for fused_name, shard_name, shard_id in qkv_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, fused_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for target_name, ckpt_key, expert_id in expert_mapping:
                    if ckpt_key not in name:
                        continue
                    name = name.replace(ckpt_key, target_name)
                    param = params_dict[name]
                    param.weight_loader(param,
                                        loaded_weight,
                                        ckpt_key,
                                        expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    loader = getattr(param, "weight_loader",
                                     default_weight_loader)
                    loader(param, loaded_weight)