deepseek_mtp.py 24.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

zhuwenwen's avatar
zhuwenwen committed
4
5
import os
import re
zhuwenwen's avatar
zhuwenwen committed
6

7
8
import typing
from collections.abc import Callable, Iterable
9
10
11
12
13

import torch
import torch.nn as nn
from transformers import PretrainedConfig

14
from vllm.forward_context import get_forward_context
15
from vllm._aiter_ops import rocm_aiter_ops
16
from vllm.compilation.decorators import support_torch_compile
17
from vllm.config import VllmConfig
18
from vllm.logger import init_logger
19
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
20
21
22
23
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
24
25
26
    ParallelLMHead,
    VocabParallelEmbedding,
)
27
28
29
30
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
31
from vllm.platforms import current_platform
32
33
from vllm.sequence import IntermediateTensors

34
35
from .deepseek_v2 import (
    DeepseekV2DecoderLayer,
36
37
    DeepseekV2MixtureOfExperts,
    DeepseekV2MoE,
38
39
    get_spec_layer_idx_from_weight_name,
)
40
41
42
from vllm.distributed import (tensor_model_parallel_all_gather, 
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
43
from .utils import maybe_prefix
王敏's avatar
王敏 committed
44
from .interfaces import SupportsPP
zhuwenwen's avatar
zhuwenwen committed
45
from vllm import _custom_ops as ops
zhuwenwen's avatar
zhuwenwen committed
46
from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
47
import vllm.envs as envs
48
from vllm.utils import direct_register_custom_op
49
logger = init_logger(__name__)
50
51
52
53
54
55


class SharedHead(nn.Module):
    """Output-side modules of one MTP layer: an RMS norm and an LM head.

    ``forward`` only applies the norm. The ``head`` module is kept as an
    attribute so that callers can hand it to a logits processor together
    with the normalized hidden states.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        quant_config: QuantizationConfig | None = None,
    ) -> None:
        """Build the norm and the vocabulary projection.

        Args:
            config: HF config providing ``hidden_size``, ``rms_norm_eps``
                and ``vocab_size``.
            prefix: Parameter-name prefix of the owning layer; the head is
                registered under ``maybe_prefix(prefix, "head")``.
            quant_config: Optional quantization config forwarded to the head.
        """
        super().__init__()
        hidden_dim = config.hidden_size
        self.norm = RMSNorm(hidden_dim, eps=config.rms_norm_eps)

        head_prefix = maybe_prefix(prefix, "head")
        self.head = ParallelLMHead(
            config.vocab_size,
            hidden_dim,
            quant_config=quant_config,
            prefix=head_prefix,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``hidden_states`` after the RMS norm (the head is not applied)."""
        normalized = self.norm(hidden_states)
        return normalized


class DeepSeekMultiTokenPredictorLayer(nn.Module):
    """A single multi-token-prediction (MTP) layer.

    The layer RMS-normalizes the token embeddings and the previous hidden
    states separately, concatenates them, projects the result back to
    ``hidden_size`` with ``eh_proj`` and feeds it through one
    ``DeepseekV2DecoderLayer`` (``mtp_block``).
    """

    def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
        """Create the norms, the projection, the shared head and the block.

        Args:
            vllm_config: Global config; the HF config is taken from the
                speculative draft model config.
            prefix: Parameter-name prefix used for the sub-modules.
        """
        super().__init__()

        hf_config = vllm_config.speculative_config.draft_model_config.hf_config
        self.config = hf_config
        quant_config = vllm_config.quant_config

        hidden_dim = hf_config.hidden_size
        norm_eps = hf_config.rms_norm_eps
        self.enorm = RMSNorm(hidden_dim, eps=norm_eps)
        self.hnorm = RMSNorm(hidden_dim, eps=norm_eps)
        self.eh_proj = nn.Linear(hidden_dim * 2, hidden_dim, bias=False)

        self.device = current_platform.device_type

        # DSA is enabled by default whenever the config carries `index_topk`;
        # the VLLM_DISABLE_DSA environment switch forces it off.
        dsa_forced_off = envs.VLLM_DISABLE_DSA
        self.is_v32 = hasattr(hf_config, "index_topk") and not dsa_forced_off

        # The top-k index buffer is only allocated on the DSA path.
        topk_indices_buffer = None
        if self.is_v32:
            topk_indices_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                hf_config.index_topk,
                dtype=torch.int32,
                device=self.device,
            )

        self.shared_head = SharedHead(
            config=hf_config, prefix=prefix, quant_config=quant_config
        )
        self.mtp_block = DeepseekV2DecoderLayer(
            vllm_config,
            prefix,
            config=self.config,
            topk_indices_buffer=topk_indices_buffer,
        )

    # The two functions below are plain functions evaluated in the class body
    # (no `self`); they exist to be registered as the custom op
    # `torch.ops.vllm.fuse_fill_rms_x2_concat`, which `forward` calls.
    def fuse_fill_rms_x2_concat(hidden_states_fuse: torch.Tensor, positions: torch.Tensor, inputs_embeds: torch.Tensor,
                                    previous_hidden_states: torch.Tensor, weight_inputs_embeds: torch.Tensor,
                                    weight_previous_hidden_states: torch.Tensor, epsilon: float) -> None:
        # Imported lazily so `lightop` is only required when the op runs.
        # Presumably the fused equivalent of the unfused branch in `forward`
        # (mask at position 0, two RMS norms, concat) -- TODO confirm against
        # the lightop kernel.
        from lightop import fuse_fill_rms_x2_concat as _lightop_kernel
        _lightop_kernel(
            hidden_states_fuse,
            positions,
            inputs_embeds,
            previous_hidden_states,
            weight_inputs_embeds,
            weight_previous_hidden_states,
            epsilon,
        )

    def fuse_fill_rms_x2_concat_fake(hidden_states_fuse: torch.Tensor, positions: torch.Tensor, inputs_embeds: torch.Tensor,
                                    previous_hidden_states: torch.Tensor, weight_inputs_embeds: torch.Tensor,
                                    weight_previous_hidden_states: torch.Tensor, epsilon: float) -> None:
        # Fake implementation: does nothing and returns None, matching the
        # real op, which reports its results only through mutated arguments.
        return None

    direct_register_custom_op(
        op_name="fuse_fill_rms_x2_concat",
        op_func=fuse_fill_rms_x2_concat,
        # Arguments declared as written in place by the kernel.
        mutates_args=["hidden_states_fuse", "inputs_embeds"],
        fake_impl=fuse_fill_rms_x2_concat_fake,
    )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer.

        Args:
            input_ids: Not used by this layer.
            positions: Token positions; used for masking and by ``mtp_block``.
            previous_hidden_states: Hidden states to combine with the
                embeddings.
            inputs_embeds: Token embeddings; must not be ``None``.
            spec_step_index: Not used by this layer.

        Returns:
            The sum of the hidden states and the residual produced by
            ``mtp_block``.
        """
        assert inputs_embeds is not None

        if not envs.VLLM_USE_FUSED_FILL_RMS_CAT:
            # masking inputs at position 0, as not needed by MTP
            at_position_zero = positions.unsqueeze(-1) == 0
            masked_embeds = torch.where(at_position_zero, 0, inputs_embeds)
            normed_embeds = self.enorm(masked_embeds)
            normed_previous = self.hnorm(previous_hidden_states)
            projection_input = torch.cat([normed_embeds, normed_previous], dim=-1)
            hidden_states = self.eh_proj(projection_input)
        else:
            # Fused path: the custom op fills a pre-allocated buffer that is
            # twice as wide as the embeddings along dim 1.
            num_rows = inputs_embeds.shape[0]
            fused_width = inputs_embeds.shape[1] * 2
            fused_input = torch.empty(
                num_rows,
                fused_width,
                device=inputs_embeds.device,
                dtype=inputs_embeds.dtype,
            )
            torch.ops.vllm.fuse_fill_rms_x2_concat(
                fused_input,
                positions,
                inputs_embeds,
                previous_hidden_states,
                self.enorm.weight,
                self.hnorm.weight,
                self.enorm.variance_epsilon,
            )
            hidden_states = self.eh_proj(fused_input)

        hidden_states, residual = self.mtp_block(
            positions=positions, hidden_states=hidden_states, residual=None
        )
        return residual + hidden_states
156
157
158
159
160
161
162
163
164


class DeepSeekMultiTokenPredictor(nn.Module):
    """Container for the MTP layers, the token embedding and the logits processor.

    Layers are stored in a ``ModuleDict`` keyed by their absolute layer index
    (``num_hidden_layers`` onwards), so the keys line up with the layer
    indices used in weight names.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build ``num_nextn_predict_layers`` MTP layers plus shared modules."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers

        # to map the exact layer index from weights
        first_idx = self.mtp_start_layer_idx
        mtp_layers = {}
        for idx in range(first_idx, first_idx + self.num_mtp_layers):
            mtp_layers[str(idx)] = DeepSeekMultiTokenPredictorLayer(
                vllm_config, f"{prefix}.layers.{idx}"
            )
        self.layers = torch.nn.ModuleDict(mtp_layers)

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings for ``input_ids``."""
        embeddings = self.embed_tokens(input_ids)
        return embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer selected by ``spec_step_idx``.

        The step index wraps around ``num_mtp_layers``. When the forward
        context has ``enable_lightly_cp`` set, each tensor-parallel rank
        processes only its own rows (dim 0) of the inputs and the outputs are
        all-gathered afterwards.

        Args:
            input_ids: Token ids; embedded here if ``inputs_embeds`` is None.
            positions: Token positions (may be ``None`` on the split path).
            previous_hidden_states: Hidden states passed on to the MTP layer.
            inputs_embeds: Optional pre-computed embeddings.
            spec_step_idx: Speculative step index.

        Returns:
            Hidden states produced by the selected MTP layer.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        current_step_idx = spec_step_idx % self.num_mtp_layers

        forward_ctx = get_forward_context()
        enable_lightly_cp = forward_ctx.enable_lightly_cp
        if enable_lightly_cp:
            scatter_idx = forward_ctx.scatter_indexes_tensor
            if scatter_idx is not None:
                # Entries equal to -1 are redirected to row 0 before selecting.
                scatter_idx = torch.where(scatter_idx == -1, 0, scatter_idx)
                inputs_embeds = torch.index_select(inputs_embeds, 0, scatter_idx)
                previous_hidden_states = torch.index_select(
                    previous_hidden_states, 0, scatter_idx
                )
                if positions is not None:
                    positions = torch.index_select(positions, 0, scatter_idx)
            else:
                # No explicit indices: split into tp_size chunks along dim 0
                # and keep this rank's chunk.
                # NOTE(review): torch.chunk may return fewer than tp_size
                # chunks for short inputs -- presumably the caller pads;
                # confirm.
                rank = self.tp_rank
                inputs_embeds = torch.chunk(
                    inputs_embeds, chunks=self.tp_size, dim=0
                )[rank].contiguous()
                previous_hidden_states = torch.chunk(
                    previous_hidden_states, chunks=self.tp_size, dim=0
                )[rank].contiguous()
                if positions is not None:
                    positions = torch.chunk(
                        positions, chunks=self.tp_size, dim=0
                    )[rank].contiguous()

        layer_key = str(self.mtp_start_layer_idx + current_step_idx)
        hidden_states = self.layers[layer_key](
            input_ids,
            positions,
            previous_hidden_states,
            inputs_embeds,
            current_step_idx,
        )

        if enable_lightly_cp:
            # Reassemble the per-rank results, then optionally reorder rows.
            hidden_states = tensor_model_parallel_all_gather(
                hidden_states.contiguous(), dim=0
            )
            gather_idx = forward_ctx.gather_indexes_tensor
            if gather_idx is not None:
                hidden_states = torch.index_select(hidden_states, 0, gather_idx)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Compute logits with the shared head of the selected MTP layer.

        The hidden states are first passed through the layer's
        ``shared_head`` (its norm), then given to the logits processor
        together with the layer's LM head.
        """
        layer_key = str(
            self.mtp_start_layer_idx + spec_step_idx % self.num_mtp_layers
        )
        shared_head = self.layers[layer_key].shared_head
        normalized = shared_head(hidden_states)
        return self.logits_processor(shared_head.head, normalized)


251
@support_torch_compile
王敏's avatar
王敏 committed
252
class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP):
253
254
255
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the MTP predictor and record quantization / layout flags.

        Side effect: when a quantization config is present, the process
        environment variables ``LLAMA_NN`` and ``LM_NN`` are set to ``'0'``.
        """
        super().__init__()
        self.config = vllm_config.model_config.hf_config

        model_prefix = maybe_prefix(prefix, "model")
        self.model = DeepSeekMultiTokenPredictor(
            vllm_config=vllm_config, prefix=model_prefix
        )

        # Set MoE hyperparameters
        self.set_moe_parameters()

        quant_config = vllm_config.quant_config
        if quant_config is None:
            self.quant_method = None
        else:
            self.quant_method = quant_config.get_name()
            # Quantized models switch both env flags off.
            os.environ['LLAMA_NN'] = '0'
            os.environ['LM_NN'] = '0'

        # Read after the override above so quantized models always see False.
        self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
269
                
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287

    def set_moe_parameters(self):
        """Collect the MoE modules of all MTP blocks and derive MoE settings.

        Fills ``moe_mlp_layers`` / ``moe_layers`` with every MTP block's MLP
        (and its ``experts``) when that MLP is a ``DeepseekV2MoE``, then
        passes the last such MLP (or ``None`` if there is none) to
        ``extract_moe_parameters``.
        """
        self.expert_weights = []
        self.num_moe_layers = self.config.num_nextn_predict_layers
        self.num_expert_groups = self.config.n_group

        self.moe_layers = []
        self.moe_mlp_layers = []
        example_moe = None
        for mtp_layer in self.model.layers.values():
            assert isinstance(mtp_layer, DeepSeekMultiTokenPredictorLayer)
            decoder_block = mtp_layer.mtp_block
            assert isinstance(decoder_block, DeepseekV2DecoderLayer)

            mlp = decoder_block.mlp
            if not isinstance(mlp, DeepseekV2MoE):
                continue
            example_moe = mlp
            self.moe_mlp_layers.append(mlp)
            self.moe_layers.append(mlp.experts)

        self.extract_moe_parameters(example_moe)
288

289
290
    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the embedding lookup to the wrapped predictor."""
        embeddings = self.model.embed_input_ids(input_ids)
        return embeddings
291
292
293

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the wrapped predictor and return its hidden states.

        ``hidden_states`` is passed to the predictor as its previous hidden
        states. ``intermediate_tensors`` is accepted but not used here.
        """
        predicted = self.model(
            input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
        )
        return predicted

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
310
    ) -> torch.Tensor | None:
311
        return self.model.compute_logits(hidden_states, spec_step_idx)
312

313
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights that belong to the MTP (spec) layers.

        Each checkpoint tensor is routed through, in order: the stacked
        parameter mapping (gate/up and fused q_a/kv_a projections), the
        expert parameter mapping, and finally the plain per-parameter
        loader. Weights that do not belong to a spec layer are skipped.

        When ``LLAMA_NN`` is enabled and the model is not quantized, a set
        of dense weights is additionally re-laid-out in place through
        ``ops.trans_w16_gemm`` after loading.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of names recorded while loading. As before, this can
            contain names of weights that were skipped on this rank (for
            example experts that are not mapped locally).
        """
        rocm_aiter_moe_shared_expert_enabled = (
            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        )
        stacked_params_mapping = [
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]

        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts
            + (
                self.config.n_shared_experts
                if rocm_aiter_moe_shared_expert_enabled
                else 0
            ),
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        # Determine whether this model has any "indexer" parameters to load.
        model_has_indexer = any("indexer" in param_name for param_name in params_dict)

        for name, loaded_weight in weights:
            # Any checkpoint tensor whose name mentions embed_tokens is copied
            # into the first local embed_tokens parameter.
            # NOTE(review): there is no `continue` here, so the same tensor
            # still goes through the regular path below -- confirm intended.
            if "embed_tokens" in name:
                for local_name, param in params_dict.items():
                    if "embed_tokens" in local_name:
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
                        break

            if "rotary_emb.inv_freq" in name:
                continue

            # Skip "indexer" weights when the model has no indexer parameters.
            if "indexer" in name and not model_has_indexer:
                logger.info(f"Skipping indexer weight (DSA disabled): {name}")
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is None:
                continue

            is_fusion_moe_shared_experts_layer = (
                rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
            )
            name = self._rewrite_spec_layer_name(spec_layer, name)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                if is_fusion_moe_shared_experts_layer:
                    continue
                name_mapped = name.replace(weight_name, param_name)

                # QKV fusion is optional, fall back to normal
                # weight loading if it's not enabled
                if (
                    param_name == "fused_qkv_a_proj"
                ) and name_mapped not in params_dict:
                    continue
                else:
                    name = name_mapped

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Special handling: when AITER fusion_shared_experts is enabled,
                # checkpoints may provide a single widened shared_experts tensor
                # without explicit expert indices
                # (e.g. ...mlp.shared_experts.gate_proj.weight).
                # For models with multiple shared experts, split that tensor
                # evenly into per-shared-expert slices and load them into
                # appended expert slots mlp.experts.{n_routed_experts + j}.*
                # accordingly.
                num_chunks = 1
                if is_fusion_moe_shared_experts_layer:
                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
                    # Determine split axis based on op type
                    # gate/up: ColumnParallel → split along dim 0
                    # down: RowParallel → split along dim 1
                    split_dim = (
                        1
                        if ("down_proj.weight" in name and loaded_weight.ndim > 1)
                        else 0
                    )
                    total = loaded_weight.shape[split_dim]
                    assert total % num_chunks == 0, (
                        f"Shared expert weight dim {total} "
                        f"not divisible by num_chunks {num_chunks}"
                    )
                    chunk_size = total // num_chunks

                for j in range(num_chunks):
                    chunk_name = name
                    weight_to_load = loaded_weight

                    if is_fusion_moe_shared_experts_layer:
                        # narrow() yields the same view as basic slicing for
                        # 2-D tensors and, unlike `[a:b, :]`, also works for
                        # 1-D tensors split along dim 0.
                        weight_to_load = loaded_weight.narrow(
                            split_dim, j * chunk_size, chunk_size
                        )
                        # Synthesize an expert-style name so expert mapping
                        # can route it
                        chunk_name = name.replace(
                            "mlp.shared_experts",
                            f"mlp.experts.{self.config.n_routed_experts + j}",
                        )

                    # Use expert_params_mapping to locate the destination
                    # param and delegate to its expert-aware weight_loader
                    # with expert_id.
                    is_expert_weight = False
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in chunk_name:
                            continue

                        # Anyway, this is an expert weight and should not be
                        # attempted to load as other weights later
                        is_expert_weight = True

                        # Do not modify `name` since the loop may continue here
                        # Instead, create a new variable
                        name_mapped = chunk_name.replace(weight_name, param_name)

                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or
                        # not here since otherwise we may skip experts with
                        # other available replicas.
                        weight_loader = typing.cast(
                            Callable[..., bool], param.weight_loader
                        )
                        success = weight_loader(
                            param,
                            weight_to_load,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                            return_success=True,
                        )
                        if success:
                            if not is_fusion_moe_shared_experts_layer:
                                name = name_mapped
                            else:
                                loaded_params.add(name_mapped)
                            break
                    else:
                        if is_expert_weight:
                            # We've checked that this is an expert weight
                            # However it's not mapped locally to this rank
                            # So we simply skip it
                            continue

                        # Skip loading extra bias for GPTQ models.
                        if name.endswith(".bias") and name not in params_dict:
                            continue

                        name = maybe_remap_kv_scale_name(name, params_dict)
                        if name is None:
                            continue

                        # According to DeepSeek-V3 Technical Report, MTP modules
                        # shares embedding layer. We only load the first weights.
                        if (
                            spec_layer != self.model.mtp_start_layer_idx
                            and ".layers" not in name
                        ):
                            continue

                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
            # `name` is None when the kv-scale remap rejected the weight;
            # never record that in a set[str].
            if name is not None and not is_fusion_moe_shared_experts_layer:
                loaded_params.add(name)

        if self.use_llama_nn and self.quant_method is None:
            lay_key_words = (
                "self_attn.eh_proj.weight",
                "self_attn.q_proj.weight",
                "self_attn.q_a_proj.weight",
                "self_attn.q_b_proj.weight",
                "self_attn.kv_a_proj_with_mqa.weight",
                "self_attn.kv_b_proj.weight",
                "self_attn.o_proj.weight",
                "mlp.gate_up_proj.weight",
                "mlp.down_proj.weight",
                "mlp.gate.weight",
                "shared_experts.gate_up_proj.weight",
                "shared_experts.down_proj.weight",
                "shared_head.head.weight",
            )

            for layername in loaded_params:
                # Literal substring match (the dots are not wildcards).
                if not any(key_word in layername for key_word in lay_key_words):
                    continue
                # loaded_params can hold names of weights that were skipped on
                # this rank and therefore have no local parameter.
                weight = params_dict.get(layername)
                if weight is None:
                    continue

                _weight = torch.zeros_like(weight.data)
                ori_shape = _weight.shape

                # Re-lay-out via a scratch buffer, copy back, then expose the
                # data with its two leading dimensions swapped.
                ops.trans_w16_gemm(
                    _weight, weight.data, _weight.shape[0], _weight.shape[1]
                )
                weight.data.copy_(_weight)
                weight.data = weight.data.reshape(ori_shape[1], -1)

        return loaded_params

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
        """
        Rewrite the weight name to match the format of the original model.
        Add .mtp_block for modules in transformer layer block for spec layer
553
        and rename shared layer weights to be top level.
554
555
        """
        spec_layer_weight_names = [
556
557
558
559
560
            "embed_tokens",
            "enorm",
            "hnorm",
            "eh_proj",
            "shared_head",
561
        ]
562
        shared_weight_names = ["embed_tokens"]
563
        spec_layer_weight = False
564
        shared_weight = False
565
566
567
        for weight_name in spec_layer_weight_names:
            if weight_name in name:
                spec_layer_weight = True
568
569
                if weight_name in shared_weight_names:
                    shared_weight = True
570
571
572
                break
        if not spec_layer_weight:
            # treat rest weights as weights for transformer layer block
573
574
575
            name = name.replace(
                f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block."
            )
576
577
578
        elif shared_weight:
            # treat shared weights as top level weights
            name = name.replace(f"model.layers.{spec_layer}.", "model.")
579
        return name