deepseek_mtp.py 24.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

zhuwenwen's avatar
zhuwenwen committed
4
5
import os
import re
zhuwenwen's avatar
zhuwenwen committed
6

7
8
import typing
from collections.abc import Callable, Iterable
9
10
11
12
13

import torch
import torch.nn as nn
from transformers import PretrainedConfig

14
from vllm.forward_context import get_forward_context
15
from vllm._aiter_ops import rocm_aiter_ops
16
from vllm.compilation.decorators import support_torch_compile
17
from vllm.config import VllmConfig
18
from vllm.logger import init_logger
19
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
20
21
22
23
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
24
25
26
    ParallelLMHead,
    VocabParallelEmbedding,
)
27
28
29
30
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
31
from vllm.platforms import current_platform
32
33
from vllm.sequence import IntermediateTensors

34
35
from .deepseek_v2 import (
    DeepseekV2DecoderLayer,
36
37
    DeepseekV2MixtureOfExperts,
    DeepseekV2MoE,
38
39
    get_spec_layer_idx_from_weight_name,
)
40
41
42
from vllm.distributed import (tensor_model_parallel_all_gather, 
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
43
from .utils import maybe_prefix
王敏's avatar
王敏 committed
44
from .interfaces import SupportsPP
zhuwenwen's avatar
zhuwenwen committed
45
from vllm import _custom_ops as ops
zhuwenwen's avatar
zhuwenwen committed
46
from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
47
import vllm.envs as envs
48
from vllm.utils import direct_register_custom_op
49
logger = init_logger(__name__)
50
51
52
53
54
55


class SharedHead(nn.Module):
    """Output norm and LM head owned by a single MTP layer.

    ``forward`` only applies the RMS norm. The ``head`` module is stored here
    so that callers can hand it to a logits processor together with the
    normalized hidden states.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        """Build the norm and the vocabulary-parallel LM head.

        Args:
            config: HF config providing ``hidden_size``, ``vocab_size`` and
                ``rms_norm_eps``.
            prefix: Parameter-name prefix of the owning layer; the head is
                registered under ``maybe_prefix(prefix, "head")``.
            quant_config: Optional quantization config forwarded to the head.
        """
        super().__init__()
        hidden_size = config.hidden_size
        self.norm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
        self.head = ParallelLMHead(
            config.vocab_size,
            hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "head"),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` after the RMS norm (the head is NOT applied)."""
        normed = self.norm(hidden_states)
        return normed


class DeepSeekMultiTokenPredictorLayer(nn.Module):
    """One multi-token-prediction (MTP) layer.

    The layer normalizes the token embeddings and the previous hidden states
    separately, concatenates them, projects the result back to
    ``hidden_size`` with ``eh_proj`` and runs it through one
    ``DeepseekV2DecoderLayer`` (``mtp_block``). It also owns the
    ``SharedHead`` used later to compute logits.
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
        """Create the norms, projection, shared head and decoder block.

        Args:
            vllm_config: Global vLLM config; the HF config is taken from the
                speculative (draft) model config.
            prefix: Parameter-name prefix for this layer.
        """
        super().__init__()

        hf_config = vllm_config.speculative_config.draft_model_config.hf_config
        self.config = hf_config
        quant_config = vllm_config.quant_config

        hidden_size = hf_config.hidden_size
        self.enorm = RMSNorm(hidden_size, eps=hf_config.rms_norm_eps)
        self.hnorm = RMSNorm(hidden_size, eps=hf_config.rms_norm_eps)
        self.eh_proj = nn.Linear(hidden_size * 2, hidden_size, bias=False)

        self.device = current_platform.device_type

        # DSA is on by default: it is enabled whenever the config defines
        # `index_topk`, unless VLLM_DISABLE_DSA forces it off.
        dsa_disabled = envs.VLLM_DISABLE_DSA
        self.is_v32 = hasattr(hf_config, "index_topk") and not dsa_disabled

        # The top-k index buffer is only allocated on the DSA path; otherwise
        # the decoder block receives None.
        topk_indices_buffer = None
        if self.is_v32:
            topk_indices_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                hf_config.index_topk,
                dtype=torch.int32,
                device=self.device,
            )

        self.shared_head = SharedHead(
            config=hf_config, prefix=prefix, quant_config=quant_config
        )
        self.mtp_block = DeepseekV2DecoderLayer(
            vllm_config,
            prefix,
            config=self.config,
            topk_indices_buffer=topk_indices_buffer,
        )

    # The two functions below are plain functions evaluated in the class body
    # (they take no `self`); they exist only to be registered as the custom op
    # `torch.ops.vllm.fuse_fill_rms_x2_concat` right after their definition.
    def fuse_fill_rms_x2_concat(
        hidden_states_fuse: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        weight_inputs_embeds: torch.Tensor,
        weight_previous_hidden_states: torch.Tensor,
        epsilon: float,
    ) -> None:
        # Imported lazily so that `lightop` is only required when the fused
        # path is actually executed.
        # NOTE(review): presumably performs the position-0 masking, both RMS
        # norms and the concatenation done by the unfused branch of
        # `forward`, writing into `hidden_states_fuse` — confirm against
        # the lightop kernel.
        from lightop import fuse_fill_rms_x2_concat as _lightop_kernel

        _lightop_kernel(
            hidden_states_fuse,
            positions,
            inputs_embeds,
            previous_hidden_states,
            weight_inputs_embeds,
            weight_previous_hidden_states,
            epsilon,
        )

    def fuse_fill_rms_x2_concat_fake(
        hidden_states_fuse: torch.Tensor,
        positions: torch.Tensor,
        inputs_embeds: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        weight_inputs_embeds: torch.Tensor,
        weight_previous_hidden_states: torch.Tensor,
        epsilon: float,
    ) -> None:
        # Fake implementation: the op returns nothing, so there is nothing
        # to compute here.
        return None

    direct_register_custom_op(
        op_name="fuse_fill_rms_x2_concat",
        op_func=fuse_fill_rms_x2_concat,
        mutates_args=["hidden_states_fuse", "inputs_embeds"],
        fake_impl=fuse_fill_rms_x2_concat_fake,
    )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer on one batch of tokens.

        Args:
            input_ids: Not read by this method.
            positions: Token positions; embeddings at position 0 are zeroed
                in the unfused branch.
            previous_hidden_states: Hidden states fed into ``hnorm``.
            inputs_embeds: Token embeddings; must not be None.
            spec_step_index: Not read by this method.

        Returns:
            The sum of the decoder block's output and its residual.
        """
        assert inputs_embeds is not None
        # masking inputs at position 0, as not needed by MTP
        if envs.VLLM_USE_FUSED_FILL_RMS_CAT:
            # The custom op fills this pre-allocated buffer in place; its
            # last dimension is twice that of `inputs_embeds`.
            projection_input = torch.empty(
                inputs_embeds.shape[0],
                inputs_embeds.shape[1] * 2,
                device=inputs_embeds.device,
                dtype=inputs_embeds.dtype,
            )
            torch.ops.vllm.fuse_fill_rms_x2_concat(
                projection_input,
                positions,
                inputs_embeds,
                previous_hidden_states,
                self.enorm.weight,
                self.hnorm.weight,
                self.enorm.variance_epsilon,
            )
        else:
            masked_embeds = torch.where(
                positions.unsqueeze(-1) == 0, 0, inputs_embeds
            )
            projection_input = torch.cat(
                [self.enorm(masked_embeds), self.hnorm(previous_hidden_states)],
                dim=-1,
            )

        projected = self.eh_proj(projection_input)
        block_out, residual = self.mtp_block(
            positions=positions, hidden_states=projected, residual=None
        )
        return residual + block_out
156
157
158
159
160
161
162
163
164


class DeepSeekMultiTokenPredictor(nn.Module):
    """Container for the MTP layers, the shared embedding and logits processor.

    Layers are stored in a ``ModuleDict`` keyed by their absolute layer index
    (as a string), starting at ``config.num_hidden_layers``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Instantiate every MTP layer plus the embedding and logits processor.

        Args:
            vllm_config: Global vLLM config.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers

        # to map the exact layer index from weights
        first_idx = self.mtp_start_layer_idx
        mtp_layers = {}
        for layer_idx in range(first_idx, first_idx + self.num_mtp_layers):
            mtp_layers[str(layer_idx)] = DeepSeekMultiTokenPredictorLayer(
                vllm_config, f"{prefix}.layers.{layer_idx}"
            )
        self.layers = torch.nn.ModuleDict(mtp_layers)

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()

    def _local_chunk(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return this TP rank's contiguous slice of ``tensor`` along dim 0.

        NOTE(review): indexing with ``tp_rank`` assumes ``torch.chunk`` yields
        ``tp_size`` chunks — confirm the caller sizes dim 0 accordingly.
        """
        return torch.chunk(tensor, chunks=self.tp_size, dim=0)[
            self.tp_rank
        ].contiguous()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer selected by ``spec_step_idx``.

        When the forward context has ``enable_lightly_cp`` set, each rank
        first keeps only a subset of rows (dim 0) of the inputs — either an
        even chunk per TP rank or the rows named by
        ``scatter_indexes_tensor`` — and the layer output is all-gathered
        along dim 0 afterwards (then optionally re-indexed with
        ``gather_indexes_tensor``).

        Args:
            input_ids: Token ids; embedded here if ``inputs_embeds`` is None.
            positions: Token positions (may be None on the CP path).
            previous_hidden_states: Hidden states passed to the MTP layer.
            inputs_embeds: Optional precomputed embeddings.
            spec_step_idx: Speculative step; taken modulo the layer count.

        Returns:
            Hidden states produced by the selected MTP layer.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        current_step_idx = spec_step_idx % self.num_mtp_layers

        forward_ctx = get_forward_context()
        enable_lightly_cp = forward_ctx.enable_lightly_cp
        if enable_lightly_cp:
            scatter_index = forward_ctx.scatter_indexes_tensor
            if scatter_index is None:
                # No explicit index: split rows evenly across TP ranks.
                inputs_embeds = self._local_chunk(inputs_embeds)
                previous_hidden_states = self._local_chunk(previous_hidden_states)
                if positions is not None:
                    positions = self._local_chunk(positions)
            else:
                # Entries equal to -1 are redirected to row 0 so that
                # index_select always receives a valid index.
                scatter_index = torch.where(scatter_index == -1, 0, scatter_index)
                inputs_embeds = torch.index_select(inputs_embeds, 0, scatter_index)
                previous_hidden_states = torch.index_select(
                    previous_hidden_states, 0, scatter_index
                )
                if positions is not None:
                    positions = torch.index_select(positions, 0, scatter_index)

        mtp_layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
        hidden_states = mtp_layer(
            input_ids,
            positions,
            previous_hidden_states,
            inputs_embeds,
            current_step_idx,
        )

        if enable_lightly_cp:
            hidden_states = tensor_model_parallel_all_gather(
                hidden_states.contiguous(), dim=0
            )
            gather_index = forward_ctx.gather_indexes_tensor
            if gather_index is not None:
                hidden_states = torch.index_select(hidden_states, 0, gather_index)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Compute logits with the shared head of the selected MTP layer.

        The hidden states are normalized by the layer's ``shared_head`` and
        passed, together with its LM head, to the logits processor.
        """
        layer_key = str(
            self.mtp_start_layer_idx + spec_step_idx % self.num_mtp_layers
        )
        shared_head = self.layers[layer_key].shared_head
        return self.logits_processor(shared_head.head, shared_head(hidden_states))


251
@support_torch_compile
王敏's avatar
王敏 committed
252
class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP):
253
254
255
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the MTP predictor and record loader-related settings.

        Args:
            vllm_config: Global vLLM config.
            prefix: Parameter-name prefix; the predictor is registered under
                ``maybe_prefix(prefix, "model")``.

        Side effects:
            When a quantization config is present, the process environment
            variables ``LLAMA_NN`` and ``LM_NN`` are set to ``"0"``.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.model = DeepSeekMultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        # Set MoE hyperparameters
        self.set_moe_parameters()

        quant_config = vllm_config.quant_config
        is_quantized = quant_config is not None
        self.quant_method = quant_config.get_name() if is_quantized else None
        if is_quantized:
            # Quantized checkpoints force both NN-layout switches off.
            os.environ["LLAMA_NN"] = "0"
            os.environ["LM_NN"] = "0"
        # Read back from the environment so the override above is honoured.
        self.use_llama_nn = os.environ.get("LLAMA_NN") == "1"
        self.use_pp = vllm_config.parallel_config.pipeline_parallel_size > 1
王敏's avatar
王敏 committed
270

271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288

    def set_moe_parameters(self):
        """Collect the MoE modules of every MTP layer and derive MoE settings.

        Fills ``moe_mlp_layers`` / ``moe_layers`` with the ``DeepseekV2MoE``
        MLPs (and their ``experts``) found in the MTP decoder blocks, then
        passes the last MoE MLP seen (or None if there is none) to
        ``extract_moe_parameters``.
        """
        self.expert_weights = []
        self.num_moe_layers = self.config.num_nextn_predict_layers
        self.num_expert_groups = self.config.n_group

        self.moe_layers = []
        self.moe_mlp_layers = []
        example_moe = None
        for mtp_layer in self.model.layers.values():
            assert isinstance(mtp_layer, DeepSeekMultiTokenPredictorLayer)
            decoder_layer = mtp_layer.mtp_block
            assert isinstance(decoder_layer, DeepseekV2DecoderLayer)
            mlp = decoder_layer.mlp
            if not isinstance(mlp, DeepseekV2MoE):
                # Dense MLP: nothing to register.
                continue
            example_moe = mlp
            self.moe_mlp_layers.append(mlp)
            self.moe_layers.append(mlp.experts)
        self.extract_moe_parameters(example_moe)
289

290
291
    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` by delegating to the wrapped predictor."""
        predictor = self.model
        return predictor.embed_input_ids(input_ids)
292
293
294

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP predictor and return its hidden states.

        ``hidden_states`` is forwarded as the predictor's previous hidden
        states. ``intermediate_tensors`` is accepted but not read here.
        """
        return self.model(
            input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor | None:
        """Compute logits for ``hidden_states`` via the wrapped predictor."""
        logits = self.model.compute_logits(hidden_states, spec_step_idx)
        return logits
313

314
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into the MTP model's parameters.

        Only weights that belong to a speculative (MTP) layer are loaded,
        except that in pipeline-parallel mode the target model's
        ``embed_tokens`` weight is copied into the local embedding. Names are
        rewritten with ``_rewrite_spec_layer_name`` and then routed through,
        in order: the stacked-parameter mapping, the expert mapping, and the
        default loader.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names recorded while processing. A name is
            recorded for every spec-layer weight that reaches the end of the
            routing logic, including weights that were skipped (for example
            experts not mapped to this rank); ``None`` is never recorded.
        """
        rocm_aiter_moe_shared_expert_enabled = (
            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        )
        # (fused param name, checkpoint shard name, shard id)
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]

        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts
            + (
                self.config.n_shared_experts
                if rocm_aiter_moe_shared_expert_enabled
                else 0
            ),
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        # Determine whether "indexer" weights should be loaded: they are only
        # loaded when the model actually owns indexer parameters.
        model_has_indexer = any("indexer" in param_name for param_name in params_dict)

        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            # Skip "indexer" weights when the model has no indexer parameters.
            if "indexer" in name and not model_has_indexer:
                logger.info("Skipping indexer weight (DSA disabled): %s", name)
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is None:
                # load embed_tokens weight from target model in pp mode
                if "embed_tokens" in name and self.use_pp:
                    for local_name, param in params_dict.items():
                        if "embed_tokens" in local_name:
                            weight_loader = getattr(
                                param, "weight_loader", default_weight_loader
                            )
                            weight_loader(param, loaded_weight)
                            break
                continue

            is_fusion_moe_shared_experts_layer = (
                rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
            )

            name = self._rewrite_spec_layer_name(spec_layer, name)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                if is_fusion_moe_shared_experts_layer:
                    continue
                name_mapped = name.replace(weight_name, param_name)

                # QKV fusion is optional, fall back to normal
                # weight loading if it's not enabled
                if (
                    param_name == "fused_qkv_a_proj"
                ) and name_mapped not in params_dict:
                    continue
                else:
                    name = name_mapped

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Special handling: when AITER fusion_shared_experts is enabled,
                # checkpoints may provide a single widened shared_experts tensor
                # without explicit expert indices
                # (e.g. ...mlp.shared_experts.gate_proj.weight).
                # For models with multiple shared experts, split that tensor
                # evenly into per-shared-expert slices and load them into
                # appended expert slots mlp.experts.{n_routed_experts + j}.*
                # accordingly.
                num_chunks = 1
                if is_fusion_moe_shared_experts_layer:
                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
                    # Determine split axis based on op type
                    # gate/up: ColumnParallel → split along dim 0
                    # down: RowParallel → split along dim 1
                    split_dim = (
                        1
                        if ("down_proj.weight" in name and loaded_weight.ndim > 1)
                        else 0
                    )
                    total = loaded_weight.shape[split_dim]
                    assert total % num_chunks == 0, (
                        f"Shared expert weight dim {total} "
                        f"not divisible by num_chunks {num_chunks}"
                    )
                    chunk_size = total // num_chunks

                for j in range(num_chunks):
                    chunk_name = name
                    weight_to_load = loaded_weight

                    if is_fusion_moe_shared_experts_layer:
                        if split_dim == 0:
                            weight_to_load = loaded_weight[
                                j * chunk_size : (j + 1) * chunk_size, :
                            ]
                        else:
                            weight_to_load = loaded_weight[
                                :, j * chunk_size : (j + 1) * chunk_size
                            ]
                        # Synthesize an expert-style name so expert mapping
                        # can route it
                        chunk_name = name.replace(
                            "mlp.shared_experts",
                            f"mlp.experts.{self.config.n_routed_experts + j}",
                        )

                    # Use expert_params_mapping to locate the destination
                    # param and delegate to its expert-aware weight_loader
                    # with expert_id.
                    is_expert_weight = False
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in chunk_name:
                            continue

                        # Anyway, this is an expert weight and should not be
                        # attempted to load as other weights later
                        is_expert_weight = True

                        # Do not modify `name` since the loop may continue here
                        # Instead, create a new variable
                        name_mapped = chunk_name.replace(weight_name, param_name)

                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or
                        # not here since otherwise we may skip experts with
                        # other available replicas.
                        weight_loader = typing.cast(
                            Callable[..., bool], param.weight_loader
                        )
                        success = weight_loader(
                            param,
                            weight_to_load,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                            return_success=True,
                        )
                        if success:
                            if not is_fusion_moe_shared_experts_layer:
                                name = name_mapped
                            else:
                                loaded_params.add(name_mapped)
                            break
                    else:
                        if is_expert_weight:
                            # We've checked that this is an expert weight
                            # However it's not mapped locally to this rank
                            # So we simply skip it
                            continue

                        # Skip loading extra bias for GPTQ models.
                        if name.endswith(".bias") and name not in params_dict:
                            continue

                        name = maybe_remap_kv_scale_name(name, params_dict)
                        if name is None:
                            continue

                        # According to DeepSeek-V3 Technical Report, MTP modules
                        # shares embedding layer. We only load the first weights.
                        if (
                            spec_layer != self.model.mtp_start_layer_idx
                            and ".layers" not in name
                        ):
                            continue

                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
            # `name` is None when the kv-scale remap rejected the weight;
            # never record None in a set[str].
            if name is not None and not is_fusion_moe_shared_experts_layer:
                loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            # NOTE(review): presumably re-lays-out the listed 2-D weights for
            # the LLAMA_NN GEMM path via `ops.trans_w16_gemm` — confirm the
            # exact layout against the custom op.
            lay_key_words = [
                "self_attn.eh_proj.weight",
                "self_attn.q_proj.weight",
                "self_attn.q_a_proj.weight",
                "self_attn.q_b_proj.weight",
                "self_attn.kv_a_proj_with_mqa.weight",
                "self_attn.kv_b_proj.weight",
                "self_attn.o_proj.weight",
                "mlp.gate_up_proj.weight",
                "mlp.down_proj.weight",
                "mlp.gate.weight",
                "shared_experts.gate_up_proj.weight",
                "shared_experts.down_proj.weight",
                "shared_head.head.weight",
            ]
            layout_pattern = re.compile("|".join(lay_key_words))

            for layername in loaded_params:
                if layout_pattern.search(layername) is None:
                    continue
                # `loaded_params` can contain names that were recorded but
                # skipped and therefore have no local parameter; ignore them
                # instead of failing with KeyError.
                weight = params_dict.get(layername)
                if weight is None:
                    continue

                _weight = torch.zeros_like(weight.data)
                ori_shape = _weight.shape

                ops.trans_w16_gemm(
                    _weight, weight.data, _weight.shape[0], _weight.shape[1]
                )
                weight.data.copy_(_weight)

                weight.data = weight.data.reshape(ori_shape[1], -1)
        return loaded_params

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
        """
        Rewrite the weight name to match the format of the original model.
        Add .mtp_block for modules in transformer layer block for spec layer
557
        and rename shared layer weights to be top level.
558
559
        """
        spec_layer_weight_names = [
560
561
562
563
564
            "embed_tokens",
            "enorm",
            "hnorm",
            "eh_proj",
            "shared_head",
565
        ]
566
        shared_weight_names = ["embed_tokens"]
567
        spec_layer_weight = False
568
        shared_weight = False
569
570
571
        for weight_name in spec_layer_weight_names:
            if weight_name in name:
                spec_layer_weight = True
572
573
                if weight_name in shared_weight_names:
                    shared_weight = True
574
575
576
                break
        if not spec_layer_weight:
            # treat rest weights as weights for transformer layer block
577
578
579
            name = name.replace(
                f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block."
            )
580
581
582
        elif shared_weight:
            # treat shared weights as top level weights
            name = name.replace(f"model.layers.{spec_layer}.", "model.")
583
        return name