deepseek_v2.py 64.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

wangding zeng's avatar
wangding zeng committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
25
"""Inference-only DeepseekV2/DeepseekV3 model."""
26

27
28
import typing
from collections.abc import Callable, Iterable
29
from itertools import islice
wangding zeng's avatar
wangding zeng committed
30
31
32

import torch
from torch import nn
33
from transformers import DeepseekV2Config, DeepseekV3Config
wangding zeng's avatar
wangding zeng committed
34

35
import vllm._custom_ops as ops
36
from vllm._aiter_ops import rocm_aiter_ops
37
from vllm.compilation.decorators import support_torch_compile
38
39
40
41
42
43
44
45
from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
46
from vllm.logger import init_logger
wangding zeng's avatar
wangding zeng committed
47
from vllm.model_executor.layers.activation import SiluAndMul
48
from vllm.model_executor.layers.attention import Attention
49
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
50
from vllm.model_executor.layers.fused_moe import (
51
    FusedMoE,
52
53
    GateLinear,
    RoutingMethodType,
54
    fused_moe_make_expert_params_mapping,
55
)
56
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
57
58
59
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
60
    QKVParallelLinear,
61
62
63
    ReplicatedLinear,
    RowParallelLinear,
)
wangding zeng's avatar
wangding zeng committed
64
from vllm.model_executor.layers.logits_processor import LogitsProcessor
65
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
66
from vllm.model_executor.layers.quantization import QuantizationConfig
67
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
68
69
    per_token_group_quant_fp8,
)
70
71
72
73
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    scaled_dequantize,
)
wangding zeng's avatar
wangding zeng committed
74
from vllm.model_executor.layers.rotary_embedding import get_rope
75
76
77
from vllm.model_executor.layers.sparse_attn_indexer import (
    SparseAttnIndexer,
)
wangding zeng's avatar
wangding zeng committed
78
from vllm.model_executor.layers.vocab_parallel_embedding import (
79
80
81
    ParallelLMHead,
    VocabParallelEmbedding,
)
82
from vllm.model_executor.model_loader.weight_utils import (
83
84
85
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
86
from vllm.model_executor.models.utils import sequence_parallel_chunk
87
from vllm.platforms import current_platform
88
from vllm.sequence import IntermediateTensors
89
from vllm.utils.torch_utils import direct_register_custom_op
90
from vllm.v1.attention.backend import AttentionBackend
91
92
93
from vllm.v1.attention.backends.mla.indexer import (
    DeepseekV32IndexerBackend,
)
94
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
wangding zeng's avatar
wangding zeng committed
95

96
97
98
99
100
101
102
from .interfaces import (
    MixtureOfExperts,
    SupportsEagle,
    SupportsEagle3,
    SupportsLoRA,
    SupportsPP,
)
103
104
105
106
107
108
109
from .utils import (
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
110

111
112
logger = init_logger(__name__)

wangding zeng's avatar
wangding zeng committed
113

114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
class DeepseekAttention(nn.Module):
    """Normal MHA implementation used by Deepseek v1.

    Q, K and V come from a single fused ``QKVParallelLinear`` whose heads are
    sharded across tensor-parallel (TP) ranks. KV heads are partitioned across
    ranks when there are at least as many KV heads as ranks, and replicated
    otherwise.

    Args:
        vllm_config: Global vLLM config (accepted for signature parity; not
            read by this class).
        config: HF model config; ``num_key_value_heads`` and
            ``rope_parameters`` are read from it.
        hidden_size: Model hidden size.
        num_heads: Total number of query heads across all TP ranks.
        max_position_embeddings: Maximum position passed to ``get_rope``.
        cache_config: KV-cache config forwarded to ``Attention``.
        quant_config: Quantization config forwarded to the linear layers and
            ``Attention``.
        prefix: Module-name prefix of this layer; also used to derive the
            prefixes of the child layers.
        **kwargs: Extra keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        **kwargs,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        # Pass explicit prefixes (as every other linear layer in this file
        # does) so quantization configs that match layers by name see the
        # real module path instead of an empty string.
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to Q/K/V, apply RoPE to Q and K, attend, and project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


wangding zeng's avatar
wangding zeng committed
195
196
197
198
199
200
class DeepseekV2MLP(nn.Module):
    """Gated feed-forward block used for dense layers and shared experts.

    The fused ``gate_up_proj`` output is fed through ``SiluAndMul`` and then
    projected back to ``hidden_size`` by ``down_proj``.

    Args:
        hidden_size: Input/output width.
        intermediate_size: Width of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Quantization config forwarded to both projections.
        reduce_results: Forwarded to ``down_proj`` as ``reduce_results``.
        is_sequence_parallel: When true, TP is disabled on both projections.
        prefix: Module-name prefix of this block.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        is_sequence_parallel=False,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Under sequence parallelism the activations are already sharded
        # across the TP group, so the weights are kept replicated (TP
        # disabled) and no collective is required. Otherwise plain TP is
        # used and ``reduce_results`` decides whether down_proj reduces.
        replicate_weights = is_sequence_parallel

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            disable_tp=replicate_weights,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=replicate_weights,
            prefix=f"{prefix}.down_proj",
        )

        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

    def forward(self, x):
        """Apply gate/up projection, SiLU gating, then the down projection."""
        fused_gate_up, _ = self.gate_up_proj(x)
        out, _ = self.down_proj(self.act_fn(fused_gate_up))
        return out


class DeepseekV2MoE(nn.Module):
    """Mixture-of-experts block for DeepseekV2/V3.

    A ``GateLinear`` router scores ``config.n_routed_experts`` experts. It
    carries an ``e_score_correction_bias`` parameter when the config's
    ``topk_method`` is ``"noaux_tc"``. The routed experts live in a
    ``FusedMoE`` configured for grouped top-k routing.

    Shared experts are handled in one of two ways:
    - If the config defines them and ROCm AITER shared-expert fusion is off,
      they are built as a ``DeepseekV2MLP`` and handed to ``FusedMoE``.
    - If fusion is on, their count is passed to ``FusedMoE`` as
      ``n_shared_experts`` instead.

    Args:
        config: HF model config.
        parallel_config: Supplies ``use_sequence_parallel_moe`` and the EPLB
            settings (``enable_eplb``, ``eplb_config.num_redundant_experts``).
        quant_config: Quantization config forwarded to the experts and shared
            experts.
        prefix: Module-name prefix of this block.

    Raises:
        ValueError: If ``config.hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        config: DeepseekV2Config | DeepseekV3Config,
        parallel_config: ParallelConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

        ep_coordinator = get_ep_group()
        self.ep_group = ep_coordinator.device_group
        self.ep_rank = ep_coordinator.rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts: int = config.n_routed_experts
        self.n_shared_experts: int = config.n_shared_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        # Router. The correction bias exists only for the noaux_tc top-k
        # method; otherwise the attribute is explicitly None.
        self.gate = GateLinear(
            config.hidden_size,
            config.n_routed_experts,
            prefix=f"{prefix}.gate",
        )
        uses_noaux_tc = getattr(config, "topk_method", None) == "noaux_tc"
        self.gate.e_score_correction_bias = (
            nn.Parameter(torch.empty(config.n_routed_experts, dtype=torch.float32))
            if uses_noaux_tc
            else None
        )

        # Expert-parallel load balancing: the physical experts are the logical
        # (routed) ones plus any redundant replicas, split evenly over EP ranks.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
        first_local = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_start = first_local
        self.physical_expert_end = first_local + self.n_local_physical_experts

        self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
        self.is_fusion_moe_shared_experts_enabled = (
            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        )
        fuse_shared = self.is_fusion_moe_shared_experts_enabled

        if config.n_shared_experts is not None and not fuse_shared:
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.moe_intermediate_size
                * config.n_shared_experts,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                is_sequence_parallel=self.is_sequence_parallel,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )
        else:
            self.shared_experts = None

        fused_shared_count = config.n_shared_experts if fuse_shared else None
        self.experts = FusedMoE(
            shared_experts=self.shared_experts,
            gate=self.gate,
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=getattr(config, "n_group", 1),
            topk_group=getattr(config, "topk_group", 1),
            prefix=f"{prefix}.experts",
            scoring_func=getattr(config, "scoring_func", "softmax"),
            # aiter applies routed_scaling_factor internally
            routed_scaling_factor=self.routed_scaling_factor,
            apply_routed_scale_to_output=not self.is_rocm_aiter_moe_enabled,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
            n_shared_experts=fused_shared_count,
        )

        # NOTE(rob): this is a hack until we finish off the PR for
        # merging TRTLLM kernels into the MK framework. Then we can
        # query the MonolithicMK for the expected router logits.
        # NOTE(dbari): Use BF16 if routing is not Deepseek, e.g. Mistral Large 3
        wants_fp32_logits = (
            self.experts.quant_method.is_monolithic
            and self.experts.routing_method_type == RoutingMethodType.DeepSeekV3
        )
        self.gate.set_out_dtype(
            torch.float32 if wants_fp32_logits else torch.bfloat16
        )

        # On ROCm the aiter biased_grouped_topk kernel needs the bias in the
        # gate output dtype. Cast once here instead of on every forward. The
        # same nn.Parameter object is shared with FusedMoE and the router, so
        # rebinding .data updates all of them. Weight loading uses copy_(),
        # which converts dtype on the fly.
        correction_bias = self.gate.e_score_correction_bias
        if self.is_rocm_aiter_moe_enabled and correction_bias is not None:
            correction_bias.data = correction_bias.data.to(self.gate.out_dtype)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route ``hidden_states`` (``[num_tokens, hidden_dim]``) through the MoE.

        With sequence-parallel MoE, each rank processes only its chunk of the
        tokens. The per-rank results are then all-gathered and trimmed back to
        ``num_tokens`` rows.
        """
        num_tokens, hidden_dim = hidden_states.shape
        tokens = hidden_states.view(-1, hidden_dim)

        # Chunk the hidden states so they aren't replicated across TP ranks.
        # This avoids duplicate computation in self.experts.
        # TODO: We can replace the all_reduce at the end of attn with a
        # reduce_scatter instead of chunking here.
        if self.is_sequence_parallel:
            tokens = sequence_parallel_chunk(tokens)

        # When FusedMoE reports an internal router, the hidden states are
        # passed in the router_logits slot (routing presumably happens inside
        # FusedMoE); otherwise the gate is run here.
        if self.experts.is_internal_router:
            router_logits = tokens
        else:
            router_logits, _ = self.gate(tokens)
        out = self.experts(hidden_states=tokens, router_logits=router_logits)

        if self.is_sequence_parallel:
            out = tensor_model_parallel_all_gather(out, 0)[:num_tokens]

        return out.view(num_tokens, hidden_dim)


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN magnitude correction for a RoPE scaling factor.

    For ``scale > 1`` this is ``0.1 * mscale * ln(scale) + 1``. For any
    ``scale <= 1`` (no context extension) the correction is exactly ``1.0``.
    """
    import math

    if scale > 1:
        return 1.0 + 0.1 * mscale * math.log(scale)
    return 1.0


405
406
407
408
409
410
411
412
413
414
def _get_llama_4_scaling(
    original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
) -> torch.Tensor:
    scaling = 1 + scaling_beta * torch.log(
        1 + torch.floor(positions / original_max_position_embeddings)
    )
    # Broadcast over num_heads and head_dim
    return scaling[..., None, None]


wangding zeng's avatar
wangding zeng committed
415
416
417
class DeepseekV2Attention(nn.Module):
    """Latent-compressed (MLA-style) attention evaluated as ordinary MHA.

    Queries come from one of two paths:
    - a low-rank ``q_a_proj -> q_a_layernorm -> q_b_proj`` path, when
      ``q_lora_rank`` is set;
    - a single ``q_proj`` otherwise.

    Keys and values are built as follows:
    - ``kv_a_proj_with_mqa`` produces a ``kv_lora_rank`` latent plus one
      rotary key head.
    - The latent is normalized and expanded per head by ``kv_b_proj`` into
      ``(k_nope, v)``.
    - The single rotary key head is broadcast to every query head.

    V is zero-padded to the Q/K head dim so a standard ``Attention`` layer
    can be used; the padding is sliced off the output.

    Args:
        vllm_config: Global vLLM config (accepted for signature parity; not
            read by this class).
        config: HF model config; ``rms_norm_eps`` and ``rope_parameters`` are
            read from it. NOTE: ``config.rope_parameters["rope_type"]`` is
            rewritten in place when it is not ``"default"``.
        hidden_size: Model hidden size.
        num_heads: Total number of attention heads across TP ranks.
        qk_nope_head_dim: Per-head Q/K width without rotary embedding.
        qk_rope_head_dim: Per-head Q/K width that receives rotary embedding.
        v_head_dim: Per-head V width.
        q_lora_rank: Rank of the low-rank query path, or None for a direct
            ``q_proj``.
        kv_lora_rank: Width of the compressed KV latent.
        max_position_embeddings: Maximum position passed to ``get_rope``.
        cache_config: KV-cache config forwarded to ``Attention``.
        quant_config: Quantization config forwarded to all linear layers.
        topk_indices_buffer: Must be None; sparse attention is not supported.
        prefix: Module-name prefix of this layer.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        topk_indices_buffer: torch.Tensor | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        # Implicit concatenation instead of a backslash-newline inside the
        # literal, which used to embed the next line's indentation into the
        # assertion message.
        assert topk_indices_buffer is None, (
            "topk_indices_buffer is not supported for DeepseekV2Attention"
        )

        if self.q_lora_rank is not None:
            # Low-rank query: replicated down-projection, RMSNorm, then a
            # TP-sharded up-projection to all heads.
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_a_proj",
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )

        # KV latent plus the single rotary key head, replicated on every rank.
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa",
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        # Expands the normalized latent into per-head (k_nope, v).
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        if config.rope_parameters["rope_type"] != "default":
            # NOTE: mutates the shared config dict in place; every layer
            # applies the same mapping, so later layers see the same result.
            config.rope_parameters["rope_type"] = (
                "deepseek_yarn"
                if config.rope_parameters.get("apply_yarn_scaling", True)
                else "deepseek_llama_scaling"
            )

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=False,
        )

        # "deepseek_yarn" already implies rope_type != "default"; the old
        # two-part check was redundant.
        if config.rope_parameters["rope_type"] == "deepseek_yarn":
            # Fold the squared YaRN mscale into the softmax scale.
            mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
            scaling_factor = config.rope_parameters["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.attn = Attention(
            self.num_local_heads,
            self.qk_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None,
    ) -> torch.Tensor:
        """Compute attention for ``hidden_states`` at ``positions``.

        If ``llama_4_scaling`` is given, it is multiplied into the full query
        after RoPE has been applied.
        """
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # Add a singleton head axis so the rotary key is one shared head.
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a)
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        k_pe = latent_cache[:, :, self.kv_lora_rank :]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        # k_pe has one head; this assignment broadcasts it to every head.
        k[..., self.qk_nope_head_dim :] = k_pe

        # Apply llama 4 scaling if provided
        if llama_4_scaling is not None:
            q *= llama_4_scaling

        # padding value to qk_head_dim for alignment
        v = torch.nn.functional.pad(
            v, [0, self.qk_head_dim - self.v_head_dim], value=0
        ).view(-1, self.num_local_heads * self.qk_head_dim)
        attn_output = self.attn(q, k, v)
        # Drop the padded tail of each head before the output projection.
        attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[
            ..., : self.v_head_dim
        ].reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output


580
class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
    """Cache-holding layer for the DeepSeek V3.2 sparse-attention indexer keys.

    On construction it registers itself in the current compilation config's
    ``static_forward_context`` under ``prefix``. This is presumably so the
    runtime can discover it and bind a KV cache to it the same way it does for
    attention layers.

    Args:
        head_dim: Per-token entry size reported in the cache spec.
        dtype: Element dtype of the cache.
        prefix: Unique layer name used as the registration key.
        cache_config: Supplies ``block_size`` for the cache spec.

    Raises:
        ValueError: If another layer is already registered under ``prefix``.
    """

    def __init__(
        self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig
    ):
        super().__init__()
        # Empty placeholder until a real cache tensor is assigned.
        self.kv_cache = torch.tensor([])
        self.head_dim = head_dim
        self.prefix = prefix
        self.cache_config = cache_config
        self.dtype = dtype
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe the cache layout: one head, one vector per token."""
        return MLAAttentionSpec(  # Only has one vector instead of K + V
            block_size=self.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
        )

    def forward(self): ...

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the indexer backend class (a class, not an instance)."""
        return DeepseekV32IndexerBackend


class Indexer(nn.Module):
    """DeepSeek V3.2 sparse-attention indexer.

    Scores cached keys against per-head indexer queries and hands the
    FP8-quantized queries, keys and per-head weights to ``SparseAttnIndexer``.

    Key layers:
    - ``wq_b``: query up-projection from the low-rank query ``qr``,
      replicated on every rank.
    - ``wk_weights_proj``: a fused projection that yields the key and the
      per-head weights in a single GEMM.
    - ``k_cache``: registers the uint8 indexer cache. Each entry holds FP8
      values plus one 4-byte scale per ``quant_block_size`` elements.

    Args:
        vllm_config: Global config; ``model_config.max_model_len`` is read and
            it is passed to ``get_max_prefill_buffer_size``.
        config: HF config providing ``index_topk``, ``index_n_heads``,
            ``index_head_dim`` and ``qk_rope_head_dim``.
        hidden_size: Input width of ``wk_weights_proj``.
        q_lora_rank: Input width of ``wq_b``.
        quant_config: Quantization config for ``wq_b`` (the fused
            ``wk_weights_proj`` is always unquantized).
        cache_config: Forwarded to the indexer cache layer.
        topk_indices_buffer: Forwarded to ``SparseAttnIndexer``.
        prefix: Module-name prefix of this indexer.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        q_lora_rank: int,
        quant_config: QuantizationConfig | None,
        cache_config: CacheConfig | None,
        topk_indices_buffer: torch.Tensor | None,
        prefix: str = "",
    ):
        super().__init__()
        self.vllm_config = vllm_config
        self.config = config
        self.quant_config = quant_config

        # Indexer geometry (trailing comments give typical values).
        self.topk_tokens = config.index_topk
        self.n_head = config.index_n_heads  # 64
        self.head_dim = config.index_head_dim  # 128
        self.rope_dim = config.qk_rope_head_dim  # 64
        self.q_lora_rank = q_lora_rank  # 1536

        # Query up-projection: no tensor parallelism, just replicated.
        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.head_dim * self.n_head,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq_b",
        )
        # wk and weights_proj fused into one GEMM producing
        # [head_dim + n_head] columns. FP8 wk checkpoints are upcast to BF16
        # at load time so the fusion stays possible.
        self.wk_weights_proj = MergedColumnParallelLinear(
            hidden_size,
            [self.head_dim, self.n_head],
            bias=False,
            quant_config=None,
            disable_tp=True,
            prefix=f"{prefix}.wk_weights_proj",
        )
        self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
        self.softmax_scale = self.head_dim**-0.5

        self.scale_fmt = "ue8m0"
        self.quant_block_size = 128  # TODO: get from config
        self.topk_indices_buffer = topk_indices_buffer

        # NOTE: (zyongye) naive FP8 cache: values stored as fp8 bytes plus an
        # fp32 scale (4 bytes) for every quant_block_size elements.
        scale_bytes = self.head_dim // self.quant_block_size * 4
        self.k_cache = DeepseekV32IndexerCache(
            head_dim=self.head_dim + scale_bytes,
            dtype=torch.uint8,
            prefix=f"{prefix}.k_cache",
            cache_config=cache_config,
        )
        self.max_model_len = vllm_config.model_config.max_model_len
        self.prefix = prefix

        from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size

        self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)
        self.indexer_op = SparseAttnIndexer(
            self.k_cache,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )

    def forward(
        self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb
    ) -> torch.Tensor:
        """Build indexer q/k/weights for this batch and run the sparse indexer."""
        nope_dim = self.head_dim - self.rope_dim

        # Queries: [num_tokens, n_head, head_dim], rope part first.
        q_full, _ = self.wq_b(qr)
        q_pe, q_nope = torch.split(
            q_full.view(-1, self.n_head, self.head_dim),
            [self.rope_dim, nope_dim],
            dim=-1,
        )

        # One fused GEMM yields the single key head and the per-head weights.
        fused_out, _ = self.wk_weights_proj(hidden_states)
        k_raw = fused_out[:, : self.head_dim]
        head_weights = fused_out[:, self.head_dim :]
        k_pe, k_nope = torch.split(
            self.k_norm(k_raw), [self.rope_dim, nope_dim], dim=-1
        )

        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
        # Note: RoPE (NeoX) can introduce extra leading dimensions during
        # compilation, so flatten back to token-major shapes.
        q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim)
        k_pe = k_pe.reshape(-1, 1, self.rope_dim)

        q = torch.cat([q_pe, q_nope], dim=-1)
        # k_pe is [num_tokens, 1, rope_dim] (a single MQA head).
        k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)

        # Only q is quantized here; k quantization is fused with cache insertion.
        q_fp8, q_scale = per_token_group_quant_fp8(
            q.view(-1, self.head_dim),
            self.quant_block_size,
            column_major_scales=False,
            use_ue8m0=self.scale_fmt is not None,
        )
        q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
        q_scale = q_scale.view(-1, self.n_head, 1)

        scaled_weights = (
            head_weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
        )
        return self.indexer_op(hidden_states, q_fp8, k, scaled_weights.squeeze(-1))
728
729


730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
def _try_load_fp8_indexer_wk(name, tensor, buf, params_dict, loaded_params):
    """
    We fuse the WK and weights_proj projections, but in some checkpoints WK is stored
    in FP8 with a separate weight_scale_inv, while weights_proj is stored in BF16.
    Upcasting to BF16 during loading enables the fusion. This function loads the FP8 WK
    weights and scale, and when both are available, dequantizes to BF16 and stores into
    the fused wk_weights_proj.weight parameter.
    """
    if "indexer.wk." not in name or "wk_weights" in name:
        return False  # Weight is not an isolated WK weight for the indexer, ignore.
    is_weight = name.endswith(".weight") and tensor.dtype == torch.float8_e4m3fn
    is_scale = "weight_scale_inv" in name
    if not is_weight and not is_scale:
        return False  # WK is not in FP8 format, ignore.
    # Buffer this tensor (weight or scale) until both have arrived.
    layer_prefix = name.rsplit(".wk.", 1)[0]  # e.g. "model.layers.0.self_attn.indexer"
    entry = buf.setdefault(layer_prefix, {})
    entry["weight" if is_weight else "scale"] = tensor
    if "weight" not in entry or "scale" not in entry:
        return True  # still waiting for the other param

    # We have both weight and scale: dequantize FP8 to BF16.
    weight_fp8, scale_inv = entry["weight"], entry["scale"]
    del buf[layer_prefix]
    block_size = weight_fp8.shape[1] // scale_inv.shape[1]
    weight_bf16 = scaled_dequantize(
        weight_fp8,
        scale_inv,
        group_shape=GroupShape(block_size, block_size),
        out_dtype=torch.bfloat16,
    )

    # Load the dequantized weight into shard 0 of the fused buffer.
    fused_name = f"{layer_prefix}.wk_weights_proj.weight"
    param = params_dict[fused_name]
    param.weight_loader(param, weight_bf16, 0)
    loaded_params.add(fused_name)
    return True


770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
def _min_latency_fused_qkv_a_proj_impl(
    input_: torch.Tensor,
    weight: torch.Tensor,
) -> torch.Tensor:
    """
    Dynamically run min-latency gemm if num_tokens <= 16.
    This must be wrapped in a custom op because our torch.compile integration
    does not support runtime dispatching on num_tokens.
    """
    num_tokens = input_.shape[0]
    if 0 < num_tokens <= 16:
        output = torch.empty(
            num_tokens,
            weight.shape[0],
            dtype=torch.bfloat16,
            device=input_.device,
        )
        ops.dsv3_fused_a_gemm(output, input_, weight.T)
        return output
    else:
        return torch.nn.functional.linear(input_, weight)


def _min_latency_fused_qkv_a_proj_fake(
    input_: torch.Tensor,
    weight: torch.Tensor,
) -> torch.Tensor:
    return input_.new_empty(input_.shape[0], weight.shape[0])


# Register the runtime dispatcher as torch.ops.vllm.min_latency_fused_qkv_a_proj.
# The fake implementation supplies output shapes while tracing; the op only
# writes to an output it allocates itself, so no argument is declared mutated.
direct_register_custom_op(
    op_name="min_latency_fused_qkv_a_proj",
    fake_impl=_min_latency_fused_qkv_a_proj_fake,
    op_func=_min_latency_fused_qkv_a_proj_impl,
    mutates_args=[],
)


808
class DeepSeekV2FusedQkvAProjLinear(MergedColumnParallelLinear):
    """
    Replicated (``disable_tp=True``) merged projection with a small-batch path.

    When the layer holds a BF16 ``weight`` of shape (2112, 7168) on a CUDA
    device with capability 90 or in the 100 capability family, ``forward``
    routes through the ``min_latency_fused_qkv_a_proj`` custom op. Otherwise
    it defers to ``MergedColumnParallelLinear.forward``.
    """

    def __init__(
        self,
        input_size: int,
        output_size: list[int],
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__(
            input_size,
            output_size,
            bias=False,
            quant_config=quant_config,
            disable_tp=True,
            prefix=prefix,
        )

        # The DeepSeek V3 fused A GEMM kernel (PDL-capable, tuned for low batch
        # sizes) only supports this exact weight layout on these GPUs.
        fast_path_ok = False
        if hasattr(self, "weight"):
            w = self.weight
            fast_path_ok = (
                w.dtype == torch.bfloat16
                and w.shape[0] == 2112
                and w.shape[1] == 7168
                and current_platform.is_cuda()
                and (
                    current_platform.is_device_capability(90)
                    or current_platform.is_device_capability_family(100)
                )
            )
        self._use_min_latency_gemm = fast_path_ok

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]:
        if not self._use_min_latency_gemm:
            # Fused A GEMM kernel not applicable: use the standard linear path.
            return super().forward(input_)

        output = torch.ops.vllm.min_latency_fused_qkv_a_proj(input_, self.weight)
        if self.return_bias:
            bias = self.bias if self.skip_bias_add else None
            return output, bias
        return output


855
856
857
858
class DeepseekV2MLAAttention(nn.Module):
    """
    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

    For more info see MLACommonImpl in:
    vllm/v1/attention/backends/mla/utils.py

    This module builds the projections, norms, rotary embeddings and (for
    configs with ``index_topk``) the sparse indexer, then hands them to
    ``MultiHeadLatentAttentionWrapper``, which performs the forward pass.

    Notable args:
        q_lora_rank: If not None, queries use the low-rank path
            (``fused_qkv_a_proj`` -> ``q_a_layernorm`` -> ``q_b_proj``).
            Otherwise a full ``q_proj`` is used, together with a separate
            ``kv_a_proj_with_mqa``.
        input_size: Input width of the q/kv projections. Defaults to
            ``hidden_size`` (a different width is used by Eagle3 Deepseek
            with MLA).
        topk_indices_buffer: Forwarded to the indexer and to ``MLAModules``.

    NOTE: for non-"default" rope types this rewrites
    ``config.rope_parameters["rope_type"]`` in place.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        topk_indices_buffer: torch.Tensor | None = None,
        input_size: int | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        # Use input_size for projection input dimensions if provided,
        # otherwise default to hidden_size (used in Eagle3 Deepseek with MLA)
        proj_input_size = input_size if input_size is not None else self.hidden_size

        has_q_lora = self.q_lora_rank is not None
        # Module registration order is kept identical to the two-branch layout:
        # A-projection first, then the query-side modules.
        if has_q_lora:
            self.fused_qkv_a_proj = DeepSeekV2FusedQkvAProjLinear(
                proj_input_size,
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                quant_config=quant_config,
                prefix=f"{prefix}.fused_qkv_a_proj",
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                self.q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
        else:
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                proj_input_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.kv_a_proj_with_mqa",
            )
            self.q_proj = ColumnParallelLinear(
                proj_input_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Any non-default rope type is mapped onto one of the DeepSeek
        # variants (mutates the shared config in place).
        if config.rope_parameters["rope_type"] != "default":
            config.rope_parameters["rope_type"] = (
                "deepseek_yarn"
                if config.rope_parameters.get("apply_yarn_scaling", True)
                else "deepseek_llama_scaling"
            )

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=False,
        )

        # YaRN also rescales the softmax scale by mscale**2.
        if config.rope_parameters["rope_type"] == "deepseek_yarn":
            mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
            scaling_factor = config.rope_parameters["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        # DeepSeek V3.2 (sparse attention) configs carry `index_topk`.
        self.is_v32 = hasattr(config, "index_topk")

        if self.is_v32:
            self.indexer_rope_emb = get_rope(
                qk_rope_head_dim,
                max_position=max_position_embeddings,
                rope_parameters=config.rope_parameters,
                is_neox_style=not getattr(config, "indexer_rope_interleave", False),
            )
            self.indexer = Indexer(
                vllm_config,
                config,
                hidden_size,
                q_lora_rank,
                quant_config,
                cache_config,
                topk_indices_buffer,
                f"{prefix}.indexer",
            )
        else:
            self.indexer_rope_emb = None
            self.indexer = None

        # Only the modules created for the active query path are passed; the
        # others are None (their attributes do not exist on this instance).
        mla_modules = MLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            rotary_emb=self.rotary_emb,
            o_proj=self.o_proj,
            fused_qkv_a_proj=self.fused_qkv_a_proj if has_q_lora else None,
            kv_a_proj_with_mqa=None if has_q_lora else self.kv_a_proj_with_mqa,
            q_a_layernorm=self.q_a_layernorm if has_q_lora else None,
            q_b_proj=self.q_b_proj if has_q_lora else None,
            q_proj=None if has_q_lora else self.q_proj,
            indexer=self.indexer,
            indexer_rotary_emb=self.indexer_rope_emb,
            is_sparse=self.is_v32,
            topk_indices_buffer=topk_indices_buffer,
        )

        self.mla_attn = MultiHeadLatentAttentionWrapper(
            self.hidden_size,
            self.num_local_heads,
            self.scaling,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
            self.q_lora_rank,
            self.kv_lora_rank,
            mla_modules,
            cache_config,
            quant_config,
            prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None,
    ) -> torch.Tensor:
        """Delegate to the MLA wrapper built in ``__init__``."""
        return self.mla_attn(positions, hidden_states, llama_4_scaling)
1041
1042


wangding zeng's avatar
wangding zeng committed
1043
class DeepseekV2DecoderLayer(nn.Module):
    """
    One decoder block: pre-norm self-attention followed by an MLP or MoE.

    The attention class is selected from the config:
    - ``DeepseekAttention`` when ``model_type == "deepseek"`` or both qk head
      dims are 0;
    - ``DeepseekV2MLAAttention`` when ``model_config.use_mla`` is set;
    - ``DeepseekV2Attention`` otherwise.

    The feed-forward is ``DeepseekV2MoE`` when ``n_routed_experts`` is set and
    the layer index is ``>= first_k_dense_replace`` and a multiple of
    ``moe_layer_freq``. Every other layer uses a dense ``DeepseekV2MLP``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str,
        config: DeepseekV2Config | None = None,
        topk_indices_buffer: torch.Tensor | None = None,
    ) -> None:
        super().__init__()

        if config is None:
            config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        moe_layer_freq = getattr(config, "moe_layer_freq", 1)
        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        layer_idx = int(prefix.split(sep=".")[-1])
        self.layer_idx = layer_idx

        # MLA head dimensions; fields missing from the config default to 0.
        qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
        qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
        v_head_dim = getattr(config, "v_head_dim", 0)
        kv_lora_rank = getattr(config, "kv_lora_rank", 0)
        use_mha = config.model_type == "deepseek" or all(
            dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
        )
        self.use_mha = use_mha

        if use_mha:
            attn_cls = DeepseekAttention
        elif model_config.use_mla:
            attn_cls = DeepseekV2MLAAttention
        else:
            attn_cls = DeepseekV2Attention
        self.self_attn = attn_cls(
            vllm_config=vllm_config,
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            v_head_dim=v_head_dim,
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=kv_lora_rank,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            topk_indices_buffer=topk_indices_buffer,
        )

        if (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % moe_layer_freq == 0
        ):
            self.mlp = DeepseekV2MoE(
                config=config,
                parallel_config=parallel_config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        llama_4_scaling: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run attention and the feed-forward on ``hidden_states``.

        Args:
            positions: Token positions, passed to the attention module.
            hidden_states: Input activations for this layer.
            residual: Running residual stream. If None (first layer), it is
                initialized from a clone of ``hidden_states``.
            llama_4_scaling: Forwarded to the attention module unless it is
                the MHA (``DeepseekAttention``) variant.

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states.clone()
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        attn_kwargs = {
            "positions": positions,
            "hidden_states": hidden_states,
        }
        if not self.use_mha:
            attn_kwargs["llama_4_scaling"] = llama_4_scaling
        hidden_states = self.self_attn(**attn_kwargs)

        if (
            not isinstance(self.self_attn, DeepseekAttention)
            and hidden_states.dtype == torch.float16
        ):
            # Fix FP16 overflow.
            # Both hidden_states and residual are scaled before the RMSNorm;
            # the RMSNorm output is unaffected by a uniform scale.
            hidden_states *= 1.0 / self.routed_scaling_factor
            if self.layer_idx == 0:
                # The residual is shared by all layers, so it is only scaled
                # on the first layer.
                residual *= 1.0 / self.routed_scaling_factor

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
            # Fix FP16 overflow.
            # Scale the DeepseekV2MLP output, which feeds the input_layernorm
            # of the next decoder layer. DeepseekV2MoE output is scaled inside
            # DeepseekV2MoE.forward instead.
            hidden_states *= 1.0 / self.routed_scaling_factor

        return hidden_states, residual


1177
@support_torch_compile
class DeepseekV2Model(nn.Module):
    """
    DeepSeek V2/V3 decoder stack: embeddings, decoder layers, and final norm.

    Under pipeline parallelism, only the first rank owns ``embed_tokens`` and
    only the last rank owns ``norm``; the others hold ``PPMissingLayer``.
    Non-last ranks return ``IntermediateTensors`` from ``forward``.
    """

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.device = current_platform.device_type

        self.vocab_size = config.vocab_size
        # Sparse-attention (V3.2) configs carry `index_topk`. One top-k index
        # buffer is allocated here and shared by every decoder layer.
        self.is_v32 = hasattr(config, "index_topk")
        if self.is_v32:
            topk_tokens = config.index_topk
            topk_indices_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                topk_tokens,
                dtype=torch.int32,
                device=self.device,
            )
        else:
            topk_indices_buffer = None

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekV2DecoderLayer(
                vllm_config,
                prefix,
                topk_indices_buffer=topk_indices_buffer,
            ),
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )
        # Indices of layers whose pre-layer residual stream is collected in
        # forward() and returned alongside the final hidden states.
        self.aux_hidden_state_layers: tuple[int, ...] = ()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the vocab-parallel embedding table."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Run this rank's slice of decoder layers.

        Returns:
            ``IntermediateTensors`` on non-last PP ranks. On the last rank,
            either the normalized hidden states or, when any aux layer was
            hit, ``(hidden_states, aux_hidden_states)``.

        Raises:
            ValueError: on the first rank when neither ``input_ids`` nor
                ``inputs_embeds`` is given.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                if input_ids is None:
                    raise ValueError(
                        "Either input_ids or inputs_embeds must be provided "
                        "to DeepseekV2Model.forward"
                    )
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Compute llama 4 scaling once per forward pass if enabled
        llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
        llama_4_scaling: torch.Tensor | None
        if llama_4_scaling_config is not None:
            llama_4_scaling = _get_llama_4_scaling(
                original_max_position_embeddings=llama_4_scaling_config[
                    "original_max_position_embeddings"
                ],
                scaling_beta=llama_4_scaling_config["beta"],
                positions=positions,
            )
        else:
            llama_4_scaling = None

        aux_hidden_states = []
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer),
            start=self.start_layer,
        ):
            if idx in self.aux_hidden_state_layers:
                # Before the first layer on the first PP rank there is no
                # residual yet; the stream is just the embeddings.
                aux_hidden_states.append(
                    hidden_states if residual is None else hidden_states + residual
                )
            hidden_states, residual = layer(
                positions, hidden_states, residual, llama_4_scaling
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states
        return hidden_states

1292

1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
class DeepseekV2MixtureOfExperts(MixtureOfExperts):
    # List of MoE MLP layers in the model.
    moe_mlp_layers: list[DeepseekV2MoE]

    def extract_moe_parameters(self, example_moe: DeepseekV2MoE | None):
        """
        Copy expert-count metadata from a representative MoE layer.

        If ``example_moe`` is None, every counter (including the MoE layer and
        expert-group counts) is zeroed and a warning is logged.
        """
        if example_moe is not None:
            self.num_logical_experts = example_moe.n_logical_experts
            self.num_physical_experts = example_moe.n_physical_experts
            self.num_local_physical_experts = example_moe.n_local_physical_experts
            self.num_routed_experts = example_moe.n_routed_experts
            self.num_shared_experts = example_moe.n_shared_experts
            self.num_redundant_experts = example_moe.n_redundant_experts
            return

        self.num_moe_layers = 0
        self.num_expert_groups = 0
        self.num_logical_experts = 0
        self.num_physical_experts = 0
        self.num_local_physical_experts = 0
        self.num_routed_experts = 0
        self.num_shared_experts = 0
        self.num_redundant_experts = 0
        logger.warning("DeepSeekV2: No DeepseekV2MoE layer found in model.layers.")

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """
        Push a new physical-expert layout to this model and every MoE layer.

        The per-rank local expert count must be unchanged. The redundant-expert
        count is recomputed as physical minus logical, and each MoE layer
        rebuilds its expert map after receiving the new counts.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        redundant = num_physical_experts - self.num_logical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = redundant
        for layer in self.moe_mlp_layers:
            layer.n_local_physical_experts = num_local_physical_experts
            layer.n_physical_experts = num_physical_experts
            layer.n_redundant_experts = redundant
            layer.experts.update_expert_map()


class DeepseekV2ForCausalLM(
    nn.Module,
    SupportsPP,
    DeepseekV2MixtureOfExperts,
    SupportsLoRA,
    SupportsEagle,
    SupportsEagle3,
):
    """Causal language model wrapper for DeepseekV2/V3-style checkpoints.

    Wraps ``model_cls`` (``DeepseekV2Model``) with an LM head (on the last
    pipeline rank only) and a logits processor. It exposes MoE layer metadata
    through ``DeepseekV2MixtureOfExperts`` and implements checkpoint weight
    loading for stacked/fused projections and routed experts.
    """

    # NOTE: this dict is a class attribute and ``__init__`` mutates it in place
    # (via ``self.packed_modules_mapping[...] = ...``), so the additions are
    # shared with every instance and subclass. Per the comment in ``__init__``,
    # the mapping is passed in place to quantization config init, so this is
    # deliberate; do not replace it with a per-instance copy.
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
    model_cls = DeepseekV2Model

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head and MoE metadata.

        Args:
            vllm_config: Global vLLM config; ``model_config.hf_config`` and
                ``quant_config`` are read from it.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        # Plain multi-head attention is used for the "deepseek" model type or
        # when the config carries no MLA head dims.
        qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
        qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
        self.use_mha = config.model_type == "deepseek" or all(
            dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
        )

        # `packed_modules_mapping` needs to be modified before
        # initializing DeepseekV2Model, as it is passed inplace to
        # quantization config init and may be used to select the
        # quant_method for relevant layers during initialization.
        if self.use_mha:
            self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]

        # q_a_proj and kv_a_proj_with_mqa are fused only when the config
        # defines a (non-None) q_lora_rank.
        self.fuse_qkv_a_proj = getattr(config, "q_lora_rank", None) is not None
        if self.fuse_qkv_a_proj:
            self.packed_modules_mapping["fused_qkv_a_proj"] = [
                "q_a_proj",
                "kv_a_proj_with_mqa",
            ]

        self.model = self.model_cls(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters
        self.num_moe_layers = (
            self.config.num_hidden_layers - self.config.first_k_dense_replace
        )
        self.set_moe_parameters()

    def set_moe_parameters(self):
        """Collect the MoE layers present on this rank and record expert counts.

        Decoder layers replaced by ``PPMissingLayer`` are skipped. For every
        layer whose ``mlp`` is a ``DeepseekV2MoE``, the MLP is appended to
        ``moe_mlp_layers`` and its ``experts`` submodule to ``moe_layers``.
        The last such MLP (or None) is handed to ``extract_moe_parameters``.
        """
        self.expert_weights = []

        self.num_expert_groups = getattr(self.config, "n_group", 1)

        self.moe_layers = []
        self.moe_mlp_layers = []
        example_moe = None
        for layer in self.model.layers:
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, DeepseekV2DecoderLayer)
            if isinstance(layer.mlp, DeepseekV2MoE):
                # Pick last one layer since the first ones may be dense layers.
                example_moe = layer.mlp
                self.moe_mlp_layers.append(layer.mlp)
                self.moe_layers.append(layer.mlp.experts)

        self.extract_moe_parameters(example_moe)

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Set which decoder-layer indices the backbone exports as aux states."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the default EAGLE-3 aux layers.

        These are layer 2, the middle layer and the third-from-last layer of
        ``self.model.layers``.
        """
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the backbone's input embedding."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone.

        Returns hidden states, or ``IntermediateTensors`` when the backbone
        returns them (e.g. on a non-final pipeline rank).
        """
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the routed-expert checkpoint-to-parameter mapping.

        Each tuple is ``(param_name, weight_name, expert_id, shard_id)``.
        Unlike the mapping built in ``load_weights``, this one uses only
        ``n_routed_experts`` and no redundant experts.
        """
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return fused_moe_make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
            num_redundant_experts=0,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Each ``(name, tensor)`` pair is handled by the first matching route:

        1. Skip ``rotary_emb.inv_freq`` and speculative-decode (next-n)
           layers, as identified by ``get_spec_layer_idx_from_weight_name``.
        2. Offer the tensor to ``_try_load_fp8_indexer_wk``, which may consume
           it.
        3. ``stacked_params_mapping`` handles the fused projections:
           ``gate_up_proj``, the indexer ``wk_weights_proj``, and either
           ``qkv_proj`` (MHA) or ``fused_qkv_a_proj`` (MLA, only when that
           parameter exists).
        4. ``expert_params_mapping`` handles routed experts. When ROCm AITER
           shared-expert fusion is enabled, ``mlp.shared_experts`` tensors are
           split into ``n_shared_experts`` slices and loaded into expert slots
           appended after the routed experts.
        5. Otherwise the tensor is loaded directly, after FP8 kv-scale name
           remapping.

        Returns:
            The set of names recorded as handled. An expert weight that is not
            mapped to this rank is still recorded under its checkpoint name.
        """
        rocm_aiter_moe_shared_expert_enabled = (
            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        )

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        mla_params_mapping = [
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]
        mha_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
        _pending_wk_fp8: dict = {}  # When WK is in FP8, we dequant to BF16 for fusion
        indexer_fused_mapping = [
            ("wk_weights_proj", "wk", 0),
            ("wk_weights_proj", "weights_proj", 1),
        ]
        stacked_params_mapping.extend(indexer_fused_mapping)

        if self.use_mha:
            stacked_params_mapping.extend(mha_params_mapping)
        else:
            stacked_params_mapping.extend(mla_params_mapping)

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        # With AITER shared-expert fusion, shared experts occupy extra expert
        # slots after the routed ones.
        expert_params_mapping = fused_moe_make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts
            + (
                self.config.n_shared_experts
                if rocm_aiter_moe_shared_expert_enabled
                else 0
            ),
            num_redundant_experts=self.num_redundant_experts,
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

            is_fusion_moe_shared_experts_layer = (
                rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
            )

            if _try_load_fp8_indexer_wk(
                name, loaded_weight, _pending_wk_fp8, params_dict, loaded_params
            ):
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                if is_fusion_moe_shared_experts_layer:
                    continue
                name_mapped = name.replace(weight_name, param_name)

                # QKV fusion is optional, fall back to normal
                # weight loading if it's not enabled;
                # if go with fusion option, then update name.
                if param_name == "fused_qkv_a_proj" and name_mapped not in params_dict:
                    continue
                name = name_mapped

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False

                # Special handling: when AITER fusion_shared_experts is enabled,
                # checkpoints may provide a single widened shared_experts tensor
                # without explicit expert indices
                # (e.g. ...mlp.shared_experts.gate_proj.weight).
                # For models with multiple shared experts, split that tensor
                # evenly into per-shared-expert slices and load them into
                # appended expert slots mlp.experts.{n_routed_experts + j}.*
                # accordingly.
                num_chunks = 1
                if is_fusion_moe_shared_experts_layer:
                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
                    # Determine split axis based on op type
                    # gate/up: ColumnParallel → split along dim 0
                    # down: RowParallel → split along dim 1
                    split_dim = (
                        1
                        if ("down_proj.weight" in name and loaded_weight.ndim > 1)
                        else 0
                    )
                    total = loaded_weight.shape[split_dim]
                    assert total % num_chunks == 0, (
                        f"Shared expert weight dim {total} "
                        f"not divisible by num_chunks {num_chunks}"
                    )
                    chunk_size = total // num_chunks

                for j in range(num_chunks):
                    chunk_name = name
                    weight_to_load = loaded_weight

                    if is_fusion_moe_shared_experts_layer:
                        chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
                        if loaded_weight.ndim == 1:
                            weight_to_load = loaded_weight[chunk_slice]
                        elif split_dim == 0:
                            weight_to_load = loaded_weight[chunk_slice, :]
                        else:
                            weight_to_load = loaded_weight[:, chunk_slice]
                        # Synthesize an expert-style name so expert mapping
                        # can route it
                        chunk_name = name.replace(
                            "mlp.shared_experts",
                            f"mlp.experts.{self.config.n_routed_experts + j}",
                        )

                    # Use expert_params_mapping to locate the destination
                    # param and delegate to its expert-aware weight_loader
                    # with expert_id.
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in chunk_name:
                            continue

                        # Anyway, this is an expert weight and should not be
                        # attempted to load as other weights later
                        is_expert_weight = True

                        # Do not modify `name` since the loop may continue here
                        # Instead, create a new variable
                        name_mapped = chunk_name.replace(weight_name, param_name)

                        if is_pp_missing_parameter(name_mapped, self):
                            continue

                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or
                        # not here since otherwise we may skip experts with
                        # other available replicas.
                        weight_loader = typing.cast(
                            Callable[..., bool], param.weight_loader
                        )
                        success = weight_loader(
                            param,
                            weight_to_load,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                            return_success=True,
                        )
                        if success:
                            if not is_fusion_moe_shared_experts_layer:
                                name = name_mapped
                            else:
                                # Each shared-expert slice is recorded here;
                                # the per-weight add below is skipped for them.
                                loaded_params.add(name_mapped)
                            break
                    else:
                        if is_expert_weight:
                            # We've checked that this is an expert weight
                            # However it's not mapped locally to this rank
                            # So we simply skip it
                            continue

                        # Skip loading extra bias for GPTQ models.
                        if name.endswith(".bias") and name not in params_dict:
                            continue

                        # Remapping the name of FP8 kv-scale.
                        name = maybe_remap_kv_scale_name(name, params_dict)
                        if name is None:
                            continue

                        if is_pp_missing_parameter(name, self):
                            continue

                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
            # `name` is None only when the kv-scale remap dropped the weight.
            if name is not None and not is_fusion_moe_shared_experts_layer:
                loaded_params.add(name)

        return loaded_params
1678
1679


1680
1681
1682
1683
class DeepseekForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek causal LM; reuses ``DeepseekV2ForCausalLM`` without changes.

    The parent selects plain multi-head attention when
    ``config.model_type == "deepseek"``.
    """


1684
1685
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek-V3 causal LM; reuses ``DeepseekV2ForCausalLM`` without changes."""
1686
1687


Jee Jee Li's avatar
Jee Jee Li committed
1688
1689
1690
1691
class GlmMoeDsaForCausalLM(DeepseekV2ForCausalLM):
    """GlmMoeDsa causal LM; reuses ``DeepseekV2ForCausalLM`` without changes.

    Presumably exists so this architecture name resolves to the shared
    DeepSeek-V2 implementation — confirm against the model registry.
    """


1692
1693
# Compatibility with
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py
def get_spec_layer_idx_from_weight_name(
    config: DeepseekV2Config | DeepseekV3Config, weight_name: str
) -> int | None:
    """Return the speculative-decode (next-n) layer index a weight belongs to.

    Next-n prediction layers are numbered directly after the main decoder
    stack, i.e. ``num_hidden_layers .. num_hidden_layers + n - 1`` where
    ``n = config.num_nextn_predict_layers``.

    Args:
        config: HF model config. A missing or None ``num_nextn_predict_layers``
            is treated as 0.
        weight_name: Checkpoint tensor name, e.g. ``"model.layers.61.mlp..."``.

    Returns:
        The layer index if ``weight_name`` lives under one of those layers,
        otherwise None.
    """
    # `or 0` also covers configs that set the field explicitly to None, which
    # would otherwise raise TypeError on the `> 0` comparison.
    num_nextn_layers = getattr(config, "num_nextn_predict_layers", None) or 0
    if num_nextn_layers <= 0:
        return None

    first_spec_idx = config.num_hidden_layers
    for layer_idx in range(first_spec_idx, first_spec_idx + num_nextn_layers):
        # Trailing "." prevents e.g. layer 6 from matching "model.layers.61.".
        if weight_name.startswith(f"model.layers.{layer_idx}."):
            return layer_idx
    return None