qwen2_moe.py 24.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
24
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
25
26
27
28
29
30
31

import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
32
from vllm.compilation.decorators import support_torch_compile
33
from vllm.config import CacheConfig, VllmConfig
34
35
from vllm.distributed import (get_pp_group,
                              get_tensor_model_parallel_world_size,
36
                              tensor_model_parallel_all_reduce)
37
from vllm.model_executor.layers.activation import SiluAndMul
38
from vllm.model_executor.layers.fused_moe import FusedMoE
39
from vllm.model_executor.layers.layernorm import RMSNorm
40
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
41
42
43
44
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
45
from vllm.model_executor.layers.quantization import QuantizationConfig
46
from vllm.model_executor.layers.rotary_embedding import get_rope
Joe Runde's avatar
Joe Runde committed
47
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
48
49
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
50
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
51
from vllm.model_executor.sampling_metadata import SamplingMetadata
52
from vllm.sequence import IntermediateTensors
53
from vllm.utils import print_warning_once
54

55
from .interfaces import SupportsPP
56
from .utils import (extract_layer_index, is_pp_missing_parameter,
57
58
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
59

zhuwenwen's avatar
zhuwenwen committed
60
61
62
63
64
import os
import re
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf

65
66
67
68
69
70
71
72

class Qwen2MoeMLP(nn.Module):
    """Dense gated-SiLU feed-forward block.

    Computes ``down_proj(SiluAndMul(gate_up_proj(x)))``. It serves as the
    dense MLP of non-sparse decoder layers and as the shared expert inside
    ``Qwen2MoeSparseMoeBlock``.

    Args:
        hidden_size: Model hidden dimension (input and output width).
        intermediate_size: Width of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config for both projections.
        reduce_results: Forwarded to the down projection's
            ``RowParallelLinear``. The sparse MoE block passes ``False`` and
            performs its own all-reduce after combining expert outputs.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        # Gate and up projections are fused into one column-parallel matmul;
        # its concatenated output is consumed by SiluAndMul below.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Apply gate/up projection, gated SiLU, then the down projection."""
        gate_up_out, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up_out)
        out, _ = self.down_proj(activated)
        return out


class Qwen2MoeSparseMoeBlock(nn.Module):
    """Mixture-of-experts feed-forward block.

    Tokens are routed by ``self.gate`` to the ``FusedMoE`` experts. When the
    config defines a shared expert, its output is scaled by
    ``sigmoid(shared_expert_gate(x))`` and added to the routed output. A
    single tensor-parallel all-reduce is applied to the combined result.

    Args:
        config: HF config; reads ``num_experts``, ``num_experts_per_tok``,
            ``hidden_size``, ``moe_intermediate_size``, ``norm_topk_prob``,
            ``shared_expert_intermediate_size`` and ``hidden_act``.
        quant_config: Optional quantization config for experts and the
            shared expert (the router gate is never quantized).

    Raises:
        ValueError: If the tensor-parallel world size exceeds the number of
            experts.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()

        if config.num_experts < self.tp_size:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.num_experts}.")

        # Routed experts. Their results are not reduced here: the all-reduce
        # happens once in forward(), after the shared expert is added.
        self.experts = FusedMoE(
            num_experts=config.num_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
        )

        # Router producing one logit per expert; kept unquantized.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.num_experts,
            bias=False,
            quant_config=None,
        )

        shared_width = config.shared_expert_intermediate_size
        self.shared_expert = (
            Qwen2MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=shared_width,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            ) if shared_width > 0 else None)
        # Scalar gate for the shared expert's contribution (one output unit).
        self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
                                                  1,
                                                  bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return the MoE output, reshaped back to the input's shape."""
        # NOTE: hidden_states can have either 1D or 2D shape.
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, input_shape[-1])

        shared_out = None
        if self.shared_expert is not None:
            shared_out = self.shared_expert(tokens)
            if self.shared_expert_gate is not None:
                gate_scale = F.sigmoid(self.shared_expert_gate(tokens))
                shared_out = gate_scale * shared_out

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(tokens)
        combined = self.experts(hidden_states=tokens,
                                router_logits=router_logits)
        if shared_out is not None:
            combined = combined + shared_out

        if self.tp_size > 1:
            combined = tensor_model_parallel_all_reduce(combined)

        return combined.view(input_shape)
162
163
164
165
166
167
168
169
170
171
172
173


class Qwen2MoeAttention(nn.Module):
    """Rotary self-attention, sharded across tensor-parallel ranks.

    Query heads are split evenly across ranks. Key/value heads are split when
    there are at least as many as ranks; otherwise they are replicated so
    each rank holds at least one.

    Args:
        hidden_size: Model hidden dimension; ``head_dim`` is
            ``hidden_size // num_heads``.
        num_heads: Total number of query heads.
        num_kv_heads: Total number of key/value heads.
        rope_theta: Rotary embedding base.
        rope_scaling: Optional rotary scaling config passed to ``get_rope``.
        max_position_embeddings: Maximum position for the rotary embedding.
        cache_config: KV-cache config forwarded to ``Attention``.
        quant_config: Optional quantization config for projections/attention.
        prefix: Module name prefix; the attention layer is ``{prefix}.attn``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        world = get_tensor_model_parallel_world_size()

        # Query heads must divide evenly across ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % world == 0
        self.num_heads = self.total_num_heads // world

        # KV heads: replicated when fewer than ranks (rank count must then be
        # a multiple), partitioned otherwise (must divide evenly).
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < world:
            assert world % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % world == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // world)

        head_dim = hidden_size // self.total_num_heads
        self.head_dim = head_dim
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embedding, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        split_sizes = [self.q_size, self.kv_size, self.kv_size]
        q, k, v = qkv.split(split_sizes, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_out = self.attn(q, k, v, kv_cache, attn_metadata)
        projected, _ = self.o_proj(attn_out)
        return projected


class Qwen2MoeDecoderLayer(nn.Module):
    """A single Qwen2-MoE decoder layer.

    Pre-norm self-attention followed by a feed-forward sub-layer. The
    sub-layer is a ``Qwen2MoeSparseMoeBlock`` on sparse layers and a dense
    ``Qwen2MoeMLP`` otherwise. The residual stream is threaded explicitly:
    ``forward`` takes it and returns it alongside the hidden states.

    Args:
        config: HF config for the model.
        cache_config: KV-cache config forwarded to attention.
        quant_config: Optional quantization config.
        prefix: Module name prefix; the layer index is extracted from it.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = Qwen2MoeAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        # Note: Qwen/Qwen2-57B-A14B-Instruct does not have
        # `mlp_only_layers` in the config; a missing or None value means
        # "no layers are forced to be dense".
        layer_idx = extract_layer_index(prefix)
        mlp_only_layers = getattr(config, "mlp_only_layers", None) or []
        # A layer is sparse unless explicitly listed as dense, provided the
        # model has experts and the (1-based) layer number falls on the
        # `decoder_sparse_step` stride.
        use_moe = (layer_idx not in mlp_only_layers
                   and config.num_experts > 0
                   and (layer_idx + 1) % config.decoder_sparse_step == 0)
        if use_moe:
            self.mlp = Qwen2MoeSparseMoeBlock(config=config,
                                              quant_config=quant_config)
        else:
            self.mlp = Qwen2MoeMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run attention and feed-forward sub-layers.

        Args:
            positions: Token positions for the rotary embedding.
            hidden_states: Output of the previous layer (or embeddings).
            kv_cache: This layer's KV cache.
            attn_metadata: Attention metadata for the current batch.
            residual: Residual stream from the previous layer, or ``None``
                for the first layer on this rank.

        Returns:
            ``(hidden_states, residual)``: the sub-layer output and the
            updated residual stream, both passed on to the next layer.
        """
        # Self Attention
        if residual is None:
            # No accumulated residual yet: the raw input becomes it.
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            # Two-argument RMSNorm returns (normalized, updated residual).
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


325
@support_torch_compile
class Qwen2MoeModel(nn.Module):
    """Qwen2-MoE decoder stack: token embedding, decoder layers, final norm.

    Pipeline-parallel aware. The first rank embeds tokens. Later ranks
    resume from ``intermediate_tensors``. Every rank except the last returns
    ``IntermediateTensors`` instead of normalized hidden states.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.padding_idx = hf_config.pad_token_id
        self.vocab_size = hf_config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        def _make_decoder_layer(prefix: str) -> "Qwen2MoeDecoderLayer":
            # make_layers calls this once per layer owned by this PP rank.
            return Qwen2MoeDecoderLayer(config=hf_config,
                                        cache_config=cache_config,
                                        quant_config=quant_config,
                                        prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            _make_decoder_layer,
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the vocab-parallel embedding."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of decoder layers.

        ``kv_caches`` is indexed relative to ``self.start_layer``.
        """
        pp_group = get_pp_group()
        if pp_group.is_first_rank:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer_idx in range(self.start_layer, self.end_layer):
            decoder_layer = self.layers[layer_idx]
            layer_cache = kv_caches[layer_idx - self.start_layer]
            hidden_states, residual = decoder_layer(positions, hidden_states,
                                                    layer_cache,
                                                    attn_metadata, residual)

        if not pp_group.is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


391
class Qwen2MoeForCausalLM(nn.Module, SupportsPP):
    """Qwen2-MoE causal language model for vLLM.

    Wraps ``Qwen2MoeModel`` with an LM head, logits processor and sampler.
    It also maps HuggingFace checkpoint weights onto the fused vLLM
    parameters (stacked QKV / gate-up projections and ``FusedMoE`` expert
    weights).

    Vendor extension: when the environment variable ``LLAMA_NN`` is ``"1"``
    and the model is not quantized, ``load_weights`` re-lays-out a fixed set
    of dense weights with ``ops.trans_w16_gemm`` after loading (see
    ``_relayout_weights_for_llama_nn``).
    """

    fall_back_to_pt_during_load = False

    # Parameter-name fragments selecting the weights that receive the
    # LLAMA_NN re-layout. Escaped so "." matches a literal dot only.
    _LLAMA_NN_WEIGHT_PATTERN = re.compile("|".join(
        map(re.escape, (
            "gate_up_proj.weight",
            "down_proj.weight",
            "mlp.gate.weight",
            "self_attn.qkv_proj.weight",
            "self_attn.o_proj.weight",
            "lm_head.weight",
        ))))

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = Qwen2MoeModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

        # Quantization method name, or None when unquantized; the LLAMA_NN
        # re-layout in load_weights only runs when this is None.
        self.quant_method = (quant_config.get_name()
                             if quant_config is not None else None)

        # Vendor feature flags, each enabled iff its env var equals "1".
        # use_fa_pad / use_awq_pad are not read in this class; presumably
        # consumed elsewhere — kept for compatibility.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
        self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
        self.use_fa_pad = os.environ.get('FA_PAD') == '1'
        self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` using the underlying model's embedding."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Delegate to ``Qwen2MoeModel.forward``."""
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   attn_metadata, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def sample(
        self,
        logits: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this model's parameters.

        Args:
            weights: ``(checkpoint_name, tensor)`` pairs.

        Returns:
            Names, in this model's parameter namespace, of every parameter
            that received data: stacked, expert and plain parameters alike.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.num_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if "mlp.experts" in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if ((name.endswith(".bias") or name.endswith("_bias"))
                        and name not in params_dict):
                    continue
                # Skip layers on other devices.
                if is_pp_missing_parameter(name, self):
                    continue
                if name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if ((name.endswith(".bias") or name.endswith("_bias"))
                            and name not in params_dict):
                        continue
                    # Skip layers on other devices.
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Remapping the name of FP8 kv-scale.
                    if name.endswith("kv_scale"):
                        remapped_kv_scale_name = name.replace(
                            ".kv_scale", ".attn.kv_scale")
                        if remapped_kv_scale_name not in params_dict:
                            print_warning_once(
                                "Found kv scale in the checkpoint "
                                f"(e.g. {name}), but not found the expected "
                                f"name in the model "
                                f"(e.g. {remapped_kv_scale_name}). "
                                "kv-scale is not loaded.")
                            continue
                        else:
                            name = remapped_kv_scale_name
                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            # Record the parameter whichever branch above loaded it. Every
            # skip path uses `continue` on this outer loop, so only loaded
            # parameters reach this line. Previously only plain parameters
            # were recorded, so stacked and expert weights were missing.
            loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            self._relayout_weights_for_llama_nn(params_dict)
        return loaded_params

    def _relayout_weights_for_llama_nn(
            self, params_dict: Dict[str, nn.Parameter]) -> None:
        """Rewrite matching weights in place for the LLAMA_NN GEMM path.

        For each parameter whose name matches ``_LLAMA_NN_WEIGHT_PATTERN``:

        1. Optionally pad it with ``pad_weight(..., 32)`` when ``GEMM_PAD``
           is on and ``gemm_bank_conf`` flags its row count.
        2. Fill a same-shaped buffer via ``ops.trans_w16_gemm``.
        3. Copy the buffer back and reshape to ``(orig_cols, -1)``.

        The final reshape suggests the kernel writes a transposed copy —
        NOTE(review): confirm against the ``trans_w16_gemm`` kernel.
        """
        for layer_name, param in params_dict.items():
            if not self._LLAMA_NN_WEIGHT_PATTERN.search(layer_name):
                continue
            if self.use_gemm_pad and gemm_bank_conf(param.data.shape[0]):
                param.data = pad_weight(param.data, 32)

            relaid = torch.zeros_like(param.data)
            rows, cols = relaid.shape[0], relaid.shape[1]
            ops.trans_w16_gemm(relaid, param.data, rows, cols)
            param.data.copy_(relaid)
            param.data = param.data.reshape(cols, -1)