deepseek_v2.py 64.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

wangding zeng's avatar
wangding zeng committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
25
"""Inference-only DeepseekV2/DeepseekV3 model."""
26

27
28
import typing
from collections.abc import Callable, Iterable
29
from itertools import islice
wangding zeng's avatar
wangding zeng committed
30
31
32

import torch
from torch import nn
33
from transformers import DeepseekV2Config, DeepseekV3Config
wangding zeng's avatar
wangding zeng committed
34

35
from vllm._aiter_ops import rocm_aiter_ops
36
from vllm.attention.layer import Attention
37
from vllm.compilation.decorators import support_torch_compile
38
39
40
41
42
43
44
45
from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
46
47
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
wangding zeng's avatar
wangding zeng committed
48
from vllm.model_executor.layers.activation import SiluAndMul
49
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
50
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
51
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
52
53
54
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
55
    QKVParallelLinear,
56
57
58
    ReplicatedLinear,
    RowParallelLinear,
)
wangding zeng's avatar
wangding zeng committed
59
from vllm.model_executor.layers.logits_processor import LogitsProcessor
60
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
61
from vllm.model_executor.layers.quantization import QuantizationConfig
62
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
63
64
    per_token_group_quant_fp8,
)
wangding zeng's avatar
wangding zeng committed
65
66
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
67
68
69
    ParallelLMHead,
    VocabParallelEmbedding,
)
70
from vllm.model_executor.model_loader.weight_utils import (
71
72
73
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
74
from vllm.model_executor.models.utils import sequence_parallel_chunk
75
from vllm.platforms import current_platform
76
from vllm.sequence import IntermediateTensors
77
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
78
from vllm.utils.torch_utils import direct_register_custom_op
79
from vllm.v1.attention.backend import AttentionBackend
80
81
82
83
from vllm.v1.attention.backends.mla.indexer import (
    DeepseekV32IndexerBackend,
    DeepseekV32IndexerMetadata,
)
84
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
85
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
86
from vllm.v1.worker.workspace import current_workspace_manager
wangding zeng's avatar
wangding zeng committed
87

88
from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
89
90
91
92
93
94
95
from .utils import (
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
96

97
98
99
100
101
102
103
if current_platform.is_cuda_alike():
    from vllm import _custom_ops as ops
elif current_platform.is_xpu():
    from vllm._ipex_ops import ipex_ops as ops

logger = init_logger(__name__)

wangding zeng's avatar
wangding zeng committed
104

105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
class DeepseekAttention(nn.Module):
    """Normal MHA implementation used by Deepseek v1.

    Q, K and V come from a single fused ``QKVParallelLinear``; rotary
    embeddings are applied to Q and K, attention runs through the shared
    ``Attention`` layer, and ``o_proj`` maps the result back to
    ``hidden_size``.

    Query heads are sharded across the tensor-parallel group. KV heads are
    sharded when there are at least as many of them as TP ranks and
    replicated otherwise (each rank always keeps at least one KV head).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        **kwargs,
    ) -> None:
        """Build the projections, rotary embedding and attention kernel.

        Args:
            vllm_config: Not used by this class.
            config: Model config; ``num_key_value_heads`` and
                ``rope_parameters`` are read from it.
            hidden_size: Model hidden size; ``head_dim`` is
                ``hidden_size // num_heads``.
            num_heads: Total number of query heads (before TP sharding).
            max_position_embeddings: Maximum position passed to the rotary
                embedding.
            cache_config: KV-cache config forwarded to ``Attention``.
            quant_config: Quantization config for the linear layers and
                ``Attention``.
            prefix: Fully-qualified module name of this layer.
            **kwargs: Extra keyword arguments; ignored.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        # Pass a prefix so the quantization config sees each linear layer's
        # fully-qualified name, as every other linear layer in this file does
        # (previously these two were built with an empty prefix).
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run one attention step and return the ``o_proj`` output."""
        qkv, _ = self.qkv_proj(hidden_states)
        # The fused output is laid out as [q | k | v] along the last dim.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


wangding zeng's avatar
wangding zeng committed
186
187
188
189
190
191
class DeepseekV2MLP(nn.Module):
    """Gated SiLU feed-forward block: fused gate/up projection, SiluAndMul,
    then a down projection.

    When ``is_sequence_parallel`` is set, the input and output tensors are
    sharded across the ranks within the TP group, so both projections are
    built with ``disable_tp`` (replicated weights, no collective ops).
    Otherwise standard TP is used, with an all-reduce at the end governed by
    ``reduce_results``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        is_sequence_parallel=False,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Keyword arguments shared by both projections.
        common_linear_kwargs = dict(
            bias=False,
            quant_config=quant_config,
            disable_tp=is_sequence_parallel,
        )

        # One column-parallel matmul producing [gate | up].
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **common_linear_kwargs,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
            **common_linear_kwargs,
        )

        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

    def forward(self, x):
        """Apply gate/up projection, SiLU gating and the down projection."""
        fused_gate_up, _ = self.gate_up_proj(x)
        gated = self.act_fn(fused_gate_up)
        projected, _ = self.down_proj(gated)
        return projected


class DeepseekV2MoE(nn.Module):
    """DeepSeek-V2/V3 mixture-of-experts block.

    Routed experts run through a single ``SharedFusedMoE`` layer. That layer
    is also handed the shared-expert MLP when one is built here. Router
    weights live in ``self.gate`` (a replicated linear), optionally with an
    ``e_score_correction_bias`` for the ``noaux_tc`` top-k method. ``forward``
    combines routed and shared outputs, applies ``routed_scaling_factor`` and
    gathers/reduces across tensor-parallel ranks.
    """

    def __init__(
        self,
        config: DeepseekV2Config | DeepseekV3Config,
        parallel_config: ParallelConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

        ep_group_coordinator = get_ep_group()
        self.ep_group = ep_group_coordinator.device_group
        self.ep_rank = ep_group_coordinator.rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts: int = config.n_routed_experts
        self.n_shared_experts: int = config.n_shared_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        # Router: hidden_size -> n_routed_experts logits, never quantized.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.n_routed_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )
        uses_noaux_tc = getattr(config, "topk_method", None) == "noaux_tc"
        if uses_noaux_tc:
            correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts, dtype=torch.float32)
            )
        else:
            correction_bias = None
        self.gate.e_score_correction_bias = correction_bias

        # Load balancing (EPLB): logical experts plus redundant replicas give
        # the physical expert count, divided (integer division) across EP
        # ranks; this rank owns [physical_expert_start, physical_expert_end).
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb
        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
        self.is_fusion_moe_shared_experts_enabled = (
            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        )

        # Shared experts get their own dense MLP unless the config has none or
        # the AITER fusion path handles them inside the fused MoE kernel.
        wants_separate_shared_mlp = (
            config.n_shared_experts is not None
            and not self.is_fusion_moe_shared_experts_enabled
        )
        if wants_separate_shared_mlp:
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.moe_intermediate_size
                * config.n_shared_experts,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                is_sequence_parallel=self.is_sequence_parallel,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )
        else:
            self.shared_experts = None

        # Scaling by routed_scaling_factor is done in forward(), so the fused
        # layer normally gets 1.0 to avoid a double multiply; AITER applies
        # routed_scaling_factor internally, so it gets the real factor.
        fused_routed_scaling = (
            self.routed_scaling_factor if self.is_rocm_aiter_moe_enabled else 1.0
        )
        fused_shared_expert_count = (
            config.n_shared_experts
            if self.is_fusion_moe_shared_experts_enabled
            else None
        )
        self.experts = SharedFusedMoE(
            shared_experts=self.shared_experts,
            gate=self.gate,
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=getattr(config, "n_group", 1),
            topk_group=getattr(config, "topk_group", 1),
            prefix=f"{prefix}.experts",
            scoring_func=getattr(config, "scoring_func", "softmax"),
            routed_scaling_factor=fused_routed_scaling,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
            n_shared_experts=fused_shared_expert_count,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route ``hidden_states`` through the experts.

        Returns a tensor reshaped to ``(num_tokens, hidden_dim)`` of the input.
        """
        num_tokens, hidden_dim = hidden_states.shape
        tokens = hidden_states.view(-1, hidden_dim)

        # Chunk the hidden states so they aren't replicated across TP ranks.
        # This avoids duplicate computation in self.experts.
        # TODO: We can replace the all_reduce at the end of attn with a
        # reduce_scatter instead of chunking here.
        if self.is_sequence_parallel:
            tokens = sequence_parallel_chunk(tokens)

        if self.experts.is_internal_router:
            # The gate/router runs inside the FusedMoE class; the hidden
            # states are passed in the router_logits slot.
            logits_arg = tokens
        else:
            # router_logits: (num_tokens, n_experts)
            logits_arg, _ = self.gate(tokens)
        shared_out, routed_out = self.experts(
            hidden_states=tokens, router_logits=logits_arg
        )

        if self.shared_experts is None:
            assert shared_out is None

        # Fix FP16 overflow: in fp16 the shared output is scaled down instead
        # of scaling the routed output up. See DeepseekV2DecoderLayer for
        # more details.
        if tokens.dtype == torch.float16:
            if self.shared_experts is not None:
                assert shared_out is not None
                shared_out *= 1.0 / self.routed_scaling_factor
        elif not self.is_rocm_aiter_moe_enabled:
            routed_out *= self.routed_scaling_factor

        if self.shared_experts is not None:
            assert shared_out is not None
            routed_out += shared_out

        if self.is_sequence_parallel:
            gathered = tensor_model_parallel_all_gather(routed_out, 0)
            routed_out = gathered[:num_tokens]
        elif self.tp_size > 1:
            routed_out = self.experts.maybe_all_reduce_tensor_model_parallel(
                routed_out
            )

        return routed_out.view(num_tokens, hidden_dim)


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN magnitude correction for a RoPE scaling factor.

    For ``scale <= 1`` no correction is applied and 1.0 is returned.
    Otherwise the factor grows logarithmically with the scale:
    ``0.1 * mscale * ln(scale) + 1``.
    """
    import math

    return 1.0 if scale <= 1 else 1.0 + 0.1 * mscale * math.log(scale)


398
399
400
401
402
403
404
405
406
407
def _get_llama_4_scaling(
    original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
) -> torch.Tensor:
    scaling = 1 + scaling_beta * torch.log(
        1 + torch.floor(positions / original_max_position_embeddings)
    )
    # Broadcast over num_heads and head_dim
    return scaling[..., None, None]


wangding zeng's avatar
wangding zeng committed
408
409
410
class DeepseekV2Attention(nn.Module):
    """Multi-head latent attention for DeepSeek-V2/V3 using per-head K/V.

    Queries come either from a low-rank path (``q_a_proj`` ->
    ``q_a_layernorm`` -> ``q_b_proj``) when ``q_lora_rank`` is set, or from a
    single ``q_proj``. Keys/values are produced from a compressed latent:
    ``kv_a_proj_with_mqa`` emits ``kv_lora_rank`` latent features plus
    ``qk_rope_head_dim`` rotary features. The latent part is normalized and
    expanded per head by ``kv_b_proj``, and the rotary part is shared by all
    heads. V is zero-padded to ``qk_head_dim`` for the attention kernel and
    sliced back to ``v_head_dim`` afterwards.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        topk_indices_buffer: torch.Tensor | None = None,
        prefix: str = "",
    ) -> None:
        """Build projections, norms, rotary embedding and attention kernel.

        Args:
            vllm_config: Not used by this class.
            config: Model config; ``rms_norm_eps`` and ``rope_parameters``
                are read from it (``rope_parameters["rope_type"]`` may be
                rewritten in place, see below).
            hidden_size: Model hidden size.
            num_heads: Total attention heads (before TP sharding).
            qk_nope_head_dim: Per-head Q/K dims without rotary embedding.
            qk_rope_head_dim: Per-head Q/K dims that receive rotary embedding.
            v_head_dim: Per-head value dim.
            q_lora_rank: Rank of the low-rank query path, or None to use a
                direct ``q_proj``.
            kv_lora_rank: Width of the compressed KV latent.
            max_position_embeddings: Maximum position for the rotary embedding.
            cache_config: KV-cache config forwarded to ``Attention``.
            quant_config: Quantization config for linear layers and attention.
            topk_indices_buffer: Must be None; sparse indexing is not
                supported by this class.
            prefix: Fully-qualified module name of this layer.

        Raises:
            AssertionError: If ``num_heads`` is not divisible by the TP size
                or ``topk_indices_buffer`` is provided.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings
        # The message used to be split with a backslash continuation inside
        # the string literal, which embedded a run of spaces into it.
        assert topk_indices_buffer is None, (
            "topk_indices_buffer is not supported for DeepseekV2Attention"
        )

        if self.q_lora_rank is not None:
            # Low-rank query path: hidden -> q_lora_rank -> heads * qk_head_dim.
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_a_proj",
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )

        # Emits [kv latent (kv_lora_rank) | shared rope key (qk_rope_head_dim)].
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa",
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        # Expands the latent into per-head [k_nope | v].
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # NOTE: this rewrites the shared config dict in place. Re-running it
        # for every layer is idempotent: after the first rewrite rope_type is
        # no longer "default" and recomputes to the same value.
        if config.rope_parameters["rope_type"] != "default":
            config.rope_parameters["rope_type"] = (
                "deepseek_yarn"
                if config.rope_parameters.get("apply_yarn_scaling", True)
                else "deepseek_llama_scaling"
            )

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=False,
        )

        # YaRN also rescales the softmax temperature (mscale applied twice,
        # once for Q and once for K).
        if config.rope_parameters["rope_type"] == "deepseek_yarn":
            mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
            scaling_factor = config.rope_parameters["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.attn = Attention(
            self.num_local_heads,
            self.qk_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None,
    ) -> torch.Tensor:
        """Compute attention for ``hidden_states`` and return the ``o_proj`` output.

        ``llama_4_scaling``, when given, multiplies Q after rotary embedding.
        """
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # Add a singleton head axis so the rope key broadcasts across heads.
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a)
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        k_pe = latent_cache[:, :, self.kv_lora_rank :]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe

        # Apply llama 4 scaling if provided
        if llama_4_scaling is not None:
            q *= llama_4_scaling

        # padding value to qk_head_dim for alignment
        v = torch.nn.functional.pad(
            v, [0, self.qk_head_dim - self.v_head_dim], value=0
        ).view(-1, self.num_local_heads * self.qk_head_dim)
        attn_output = self.attn(q, k, v)
        # Drop the zero padding again before the output projection.
        attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[
            ..., : self.v_head_dim
        ].reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output


573
class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
    """Attention-layer stand-in that describes the DeepSeek-V3.2 indexer cache.

    On construction it registers itself in the current compilation config's
    ``static_forward_context`` under ``prefix``. It reports a single-vector
    MLA-style KV-cache spec and the indexer attention backend. ``forward``
    does nothing.
    """

    def __init__(
        self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig
    ):
        super().__init__()
        # Placeholder; presumably replaced with the real cache tensor once
        # allocated elsewhere — confirm against the KV-cache binding code.
        self.kv_cache = [torch.tensor([])]
        self.head_dim = head_dim
        self.prefix = prefix
        self.cache_config = cache_config
        self.dtype = dtype

        compile_cfg = get_current_vllm_config().compilation_config
        layer_registry = compile_cfg.static_forward_context
        if prefix in layer_registry:
            raise ValueError(f"Duplicate layer name: {prefix}")
        layer_registry[prefix] = self

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe the cache as one KV head of ``head_dim`` elements per token."""
        # Only has one vector instead of K + V.
        spec = MLAAttentionSpec(
            block_size=self.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
        )
        return spec

    def forward(self):
        """No-op; this layer only exists to own and describe the cache."""

    def get_attn_backend(self) -> AttentionBackend:
        """Return the attention backend class for the indexer cache."""
        return DeepseekV32IndexerBackend


def sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    """Fill ``topk_indices_buffer`` with per-token top-k KV indices for sparse MLA.

    Steps:
    1. Quantize ``k`` into the indexer ``kv_cache``.
    2. Reset the first ``hidden_states.shape[0]`` rows of the buffer to -1.
    3. For prefill chunks, gather the cached keys and score them with
       ``fp8_mqa_logits``, then keep the per-row top-k.
    4. For decode tokens, score them with ``fp8_paged_mqa_logits`` and keep
       the per-row top-k.

    If the forward context has no per-layer metadata dict (dummy/profiling
    run), the workspace is reserved and the buffer is returned untouched via
    ``sparse_attn_indexer_fake``.

    Returns:
        ``topk_indices_buffer`` (mutated in place).
    """
    # careful! this will be None in dummy run
    attn_metadata = get_forward_context().attn_metadata
    fp8_dtype = current_platform.fp8_dtype()

    if not isinstance(attn_metadata, dict):
        # Reserve workspace for indexer during profiling run. Use the same
        # dtypes the prefill path requests below so the reservation matches
        # the run-time request (previously hard-coded to float8_e4m3fn).
        current_workspace_manager().get_simultaneous(
            ((total_seq_lens, head_dim), fp8_dtype),
            ((total_seq_lens, 4), torch.uint8),
        )
        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )
    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    assert topk_indices_buffer is not None, (
        "sparse_attn_indexer requires a preallocated topk_indices_buffer"
    )
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    # Quantize this step's indexer keys and write them into the cache.
    ops.indexer_k_quant_and_cache(
        k,
        kv_cache,
        slot_mapping,
        quant_block_size,
        scale_fmt,
    )

    # Reset this step's rows; -1 presumably means "no index" to consumers
    # of the buffer — confirm against the sparse attention backend.
    topk_indices_buffer[: hidden_states.shape[0]] = -1

    if has_prefill:
        prefill_metadata = attn_metadata.prefill

        # Get the full shared workspace buffers once (will allocate on first use)
        workspace_manager = current_workspace_manager()
        k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
            ((total_seq_lens, head_dim), fp8_dtype),
            ((total_seq_lens, 4), torch.uint8),
        )

        # Kernel selection does not depend on the chunk; pick it once.
        fp8_mqa_logits_func = fp8_mqa_logits
        if current_platform.is_rocm():
            from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
                rocm_fp8_mqa_logits,
            )

            fp8_mqa_logits_func = rocm_fp8_mqa_logits

        for chunk in prefill_metadata.chunks:
            k_fp8 = k_fp8_full[: chunk.total_seq_lens]
            k_scale = k_scale_full[: chunk.total_seq_lens]
            ops.cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_fp8,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
            )

            # Scales are stored as 4 raw bytes per token; reinterpret as fp32.
            logits = fp8_mqa_logits_func(
                q_fp8[chunk.token_start : chunk.token_end],
                (k_fp8, k_scale.view(torch.float32)),
                weights[chunk.token_start : chunk.token_end],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )
            num_rows = logits.shape[0]
            topk_indices = topk_indices_buffer[
                chunk.token_start : chunk.token_end, :topk_tokens
            ]
            torch.ops._C.top_k_per_row_prefill(
                logits,
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
                topk_indices,
                num_rows,
                logits.stride(0),
                logits.stride(1),
                topk_tokens,
            )

    if has_decode:
        decode_metadata = attn_metadata.decode
        # kv_cache size requirement [num_block, block_size, n_head, head_dim],
        # we only have [num_block, block_size, head_dim],
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # pad in edge case where we have short chunked prefill length <
            # decode_threshold since we unstrictly split
            # prefill and decode by decode_threshold
            # (currently set to 1 + speculative tokens)

            # [num_decode_tokens, n_head, head_dim] -> [bs, 1+next_n, n_head, head_dim]
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens
            )
            # [num_decode_tokens, n_head] -> [bs, 1+next_n, n_head]
            padded_weights = pack_seq_triton(weights[:num_decode_tokens], decode_lens)
            # [bs, 1+next_n, n_head] -> [bs * next_n, n_head]
            padded_weights = padded_weights.flatten(0, 1)
        else:
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:]
            )
            padded_weights = weights
        # TODO: move and optimize below logic with triton kernels
        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n

        fp8_paged_mqa_logits_func = fp8_paged_mqa_logits
        if current_platform.is_rocm():
            from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
                rocm_fp8_paged_mqa_logits,
            )

            fp8_paged_mqa_logits_func = rocm_fp8_paged_mqa_logits
        logits = fp8_paged_mqa_logits_func(
            padded_q_fp8_decode_tokens,
            kv_cache,
            padded_weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            decode_metadata.schedule_metadata,
            max_model_len=max_model_len,
        )
        num_rows = logits.shape[0]
        topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]

        torch.ops._C.top_k_per_row_decode(
            logits,
            next_n,
            decode_metadata.seq_lens,
            topk_indices,
            num_rows,
            logits.stride(0),
            logits.stride(1),
            topk_tokens,
        )
        if decode_metadata.requires_padding:
            # if padded, we need to unpack
            # the topk indices removing padded tokens
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens,
            )
            topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
                topk_indices
            )

    return topk_indices_buffer


def sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
790
    scale_fmt: str | None,
791
792
793
794
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
795
    topk_indices_buffer: torch.Tensor | None,
796
797
798
799
800
801
802
803
804
805
806
807
808
809
) -> torch.Tensor:
    return topk_indices_buffer


# Expose the indexer as a torch custom op. The real implementation fills
# ``topk_indices_buffer`` in place, so that argument is declared as mutated;
# ``fake_impl`` provides the no-compute (fake tensor) behaviour.
direct_register_custom_op(
    op_name="sparse_attn_indexer",
    op_func=sparse_attn_indexer,
    fake_impl=sparse_attn_indexer_fake,
    mutates_args=["topk_indices_buffer"],
    dispatch_key=current_platform.dispatch_key,
)


class Indexer(nn.Module):
    """Sparse-attention indexer that selects top-k KV positions per token.

    Built by ``DeepseekV2MLAAttention`` when the config defines
    ``index_topk`` (DeepSeek V3.2). Queries are projected from the low-rank
    query activations ``qr``; a single key per token is projected from
    ``hidden_states``. Both receive RoPE on their leading
    ``qk_rope_head_dim`` channels.

    Queries are quantized to FP8 here. Key quantization is fused with cache
    insertion inside the ``sparse_attn_indexer`` custom op, which writes the
    selected indices into ``topk_indices_buffer``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        q_lora_rank: int,
        quant_config: QuantizationConfig | None,
        cache_config: CacheConfig | None,
        topk_indices_buffer: torch.Tensor | None,
        prefix: str = "",
    ):
        """Create the indexer projections and its FP8 key cache.

        Args:
            vllm_config: Global vLLM config. Provides ``max_model_len`` and
                the prefill buffer sizing.
            config: HF model config. Provides ``index_topk``,
                ``index_n_heads``, ``index_head_dim`` and
                ``qk_rope_head_dim``.
            hidden_size: Input width of ``wk`` and ``weights_proj``.
            q_lora_rank: Input width of ``wq_b`` (low-rank query size).
            quant_config: Quantization config for ``wq_b`` and ``wk``.
                ``weights_proj`` is always built unquantized.
            cache_config: Forwarded to the indexer key cache.
            topk_indices_buffer: Output buffer for the selected indices. It
                is shared by all layers of the model.
            prefix: Module-name prefix used for weight and cache names.
        """
        super().__init__()
        self.vllm_config = vllm_config
        self.config = config
        self.topk_tokens = config.index_topk
        self.n_head = config.index_n_heads  # 64
        self.head_dim = config.index_head_dim  # 128
        self.rope_dim = config.qk_rope_head_dim  # 64
        self.q_lora_rank = q_lora_rank  # 1536

        # No tensor parallelism: every rank holds full copies of these
        # projections.
        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.head_dim * self.n_head,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq_b",
        )
        self.wk = ReplicatedLinear(
            hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wk",
        )
        self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
        # One scalar weight per indexer head, kept unquantized.
        self.weights_proj = ReplicatedLinear(
            hidden_size,
            self.n_head,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.weights_proj",
        )
        self.softmax_scale = self.head_dim**-0.5

        self.scale_fmt = "ue8m0"
        self.quant_block_size = 128  # TODO: get from config
        self.topk_indices_buffer = topk_indices_buffer

        # NOTE: (zyongye) we use fp8 naive cache,
        #       where we store value in fp8 and scale in fp32
        #       per self.quant_block_size element
        # Each uint8 cache row holds head_dim fp8 bytes, followed by 4 bytes
        # (one fp32 scale) per quant block.
        self.k_cache = DeepseekV32IndexerCache(
            head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4,
            dtype=torch.uint8,
            prefix=f"{prefix}.k_cache",
            cache_config=cache_config,
        )
        self.max_model_len = vllm_config.model_config.max_model_len
        self.prefix = prefix
        from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size

        self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        qr: torch.Tensor,
        positions: torch.Tensor,
        rotary_emb: Callable[..., tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        """Select the top-k KV indices for the current tokens.

        Args:
            hidden_states: Token hidden states; input to ``wk`` and
                ``weights_proj``.
            qr: Low-rank query activations; input to ``wq_b``.
            positions: Token positions forwarded to ``rotary_emb``.
            rotary_emb: Rotary embedding applied to the RoPE slices of the
                query and key. It returns the rotated ``(q_pe, k_pe)``.

        Returns:
            The tensor returned by the ``sparse_attn_indexer`` custom op,
            i.e. ``topk_indices_buffer``.
        """
        q, _ = self.wq_b(qr)
        q = q.view(-1, self.n_head, self.head_dim)
        # The RoPE channels come first, then the non-positional channels.
        q_pe, q_nope = torch.split(
            q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
        )

        k, _ = self.wk(hidden_states)
        k = self.k_norm(k)
        k_pe, k_nope = torch.split(
            k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
        )

        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
        # Note: RoPE (NeoX) can introduce extra leading dimensions during compilation
        # so we need to reshape back to token-flattened shapes
        q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim)
        k_pe = k_pe.reshape(-1, 1, self.rope_dim)

        q = torch.cat([q_pe, q_nope], dim=-1)
        # `k_pe` is [num_tokens, 1, rope_dim] (MQA).
        k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)

        # we only quant q here since k quant is fused with cache insertion
        q = q.view(-1, self.head_dim)
        q_fp8, q_scale = per_token_group_quant_fp8(
            q,
            self.quant_block_size,
            column_major_scales=False,
            use_ue8m0=self.scale_fmt is not None,
        )
        q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
        # q_scale has one entry per quant block. Viewing it as one scale per
        # head assumes head_dim == quant_block_size.
        q_scale = q_scale.view(-1, self.n_head, 1)

        # Fold the FP8 query scale, the softmax scale and 1/sqrt(n_head) into
        # the per-head weights passed to the indexer op.
        weights, _ = self.weights_proj(hidden_states)
        weights = (
            weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
        )
        weights = weights.squeeze(-1)

        return torch.ops.vllm.sparse_attn_indexer(
            hidden_states,
            self.k_cache.prefix,
            self.k_cache.kv_cache[0],
            q_fp8,
            k,
            weights,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )


933
934
935
936
class DeepseekV2MLAAttention(nn.Module):
    """Multi-head Latent Attention (MLA) layer for DeepseekV2/V3.

    This module only builds the projections, norms and rotary embeddings.
    The attention computation itself is delegated to
    ``MultiHeadLatentAttentionWrapper``. When the config defines
    ``index_topk`` (DeepSeek V3.2), a sparse-attention ``Indexer`` is also
    created and handed to the wrapper.

    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

    For more info see ``MLACommonImpl`` in the v1 MLA attention backends
    (``vllm/v1/attention/backends/mla/``).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        topk_indices_buffer: torch.Tensor | None = None,
    ) -> None:
        """Build the MLA projections and the attention wrapper.

        Args:
            vllm_config: Global vLLM config (forwarded to the ``Indexer``).
            config: HF model config; provides ``rms_norm_eps`` and
                ``rope_parameters``. Its ``rope_parameters["rope_type"]``
                may be rewritten in place.
            hidden_size: Model hidden size.
            num_heads: Total number of attention heads. Must be divisible
                by the tensor-parallel world size.
            qk_nope_head_dim: Per-head query/key width without RoPE.
            qk_rope_head_dim: Per-head query/key width with RoPE.
            v_head_dim: Per-head value width.
            q_lora_rank: Low-rank query size, or None for a full-rank
                ``q_proj``.
            kv_lora_rank: Latent KV size.
            max_position_embeddings: Maximum position for the rotary
                embeddings.
            cache_config: KV cache config.
            quant_config: Quantization config for the linear layers.
            prefix: Module-name prefix.
            topk_indices_buffer: Shared index buffer, used only by the
                sparse (V3.2) path.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            # Low-rank query path. The q_a and kv_a down-projections share a
            # single merged linear with TP disabled. It is followed by
            # q_a_layernorm and the column-parallel q_b up-projection.
            self.fused_qkv_a_proj = MergedColumnParallelLinear(
                self.hidden_size,
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.fused_qkv_a_proj",
                disable_tp=True,
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                self.q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
        else:
            # Full-rank query path: a replicated KV down-projection plus a
            # direct column-parallel query projection.
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.kv_a_proj_with_mqa",
            )
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Remap any non-default RoPE type to the DeepSeek-specific variants.
        # NOTE: this mutates the shared HF config in place. The rewrite is
        # idempotent, so repeating it for every layer is harmless.
        if config.rope_parameters["rope_type"] != "default":
            config.rope_parameters["rope_type"] = (
                "deepseek_yarn"
                if config.rope_parameters.get("apply_yarn_scaling", True)
                else "deepseek_llama_scaling"
            )

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=False,
        )

        # YaRN also rescales attention: the softmax scale is multiplied by
        # mscale twice (once for q and once for k).
        if config.rope_parameters["rope_type"] == "deepseek_yarn":
            mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
            scaling_factor = config.rope_parameters["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        # DeepSeek V3.2 configs carry `index_topk` and use sparse attention.
        self.is_v32 = hasattr(config, "index_topk")

        if self.is_v32:
            # The indexer uses NeoX-style RoPE, unlike the main attention.
            self.indexer_rope_emb = get_rope(
                qk_rope_head_dim,
                max_position=max_position_embeddings,
                rope_parameters=config.rope_parameters,
                is_neox_style=True,
            )
            self.indexer = Indexer(
                vllm_config,
                config,
                hidden_size,
                q_lora_rank,
                quant_config,
                cache_config,
                topk_indices_buffer,
                f"{prefix}.indexer",
            )
        else:
            self.indexer_rope_emb = None
            self.indexer = None

        mla_modules = MLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            rotary_emb=self.rotary_emb,
            o_proj=self.o_proj,
            fused_qkv_a_proj=self.fused_qkv_a_proj
            if self.q_lora_rank is not None
            else None,
            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa
            if self.q_lora_rank is None
            else None,
            q_a_layernorm=self.q_a_layernorm if self.q_lora_rank is not None else None,
            q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
            q_proj=self.q_proj if self.q_lora_rank is None else None,
            indexer=self.indexer,
            indexer_rotary_emb=self.indexer_rope_emb,
            is_sparse=self.is_v32,
            topk_indices_buffer=topk_indices_buffer,
        )

        self.mla_attn = MultiHeadLatentAttentionWrapper(
            self.hidden_size,
            self.num_local_heads,
            self.scaling,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
            self.q_lora_rank,
            self.kv_lora_rank,
            mla_modules,
            cache_config,
            quant_config,
            prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run MLA attention through ``MultiHeadLatentAttentionWrapper``."""
        return self.mla_attn(positions, hidden_states, llama_4_scaling)
1116
1117


wangding zeng's avatar
wangding zeng committed
1118
class DeepseekV2DecoderLayer(nn.Module):
    """Decoder block: self-attention, then a dense MLP or a routed MoE.

    Each sub-layer is preceded by an RMSNorm that also carries the residual
    stream.

    The attention class is chosen from the config:

    - ``DeepseekAttention`` (plain MHA) when ``model_type == "deepseek"``
      or both qk head dims are 0;
    - otherwise ``DeepseekV2MLAAttention`` when ``model_config.use_mla``;
    - else ``DeepseekV2Attention``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str,
        config: DeepseekV2Config | DeepseekV3Config | None = None,
        topk_indices_buffer: torch.Tensor | None = None,
    ) -> None:
        """Build one decoder layer.

        Args:
            vllm_config: Global vLLM config.
            prefix: Module-name prefix. Its last dot-separated component
                must be the integer layer index.
            config: HF config. Defaults to ``vllm_config``'s ``hf_config``.
            topk_indices_buffer: Shared sparse-attention index buffer,
                forwarded to the attention module.
        """
        super().__init__()

        if config is None:
            config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        moe_layer_freq = getattr(config, "moe_layer_freq", 1)
        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        layer_idx = int(prefix.split(sep=".")[-1])
        self.layer_idx = layer_idx

        # verify MLA attention specific fields
        qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
        qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
        v_head_dim = getattr(config, "v_head_dim", 0)
        kv_lora_rank = getattr(config, "kv_lora_rank", 0)
        use_mha = config.model_type == "deepseek" or all(
            dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
        )
        self.use_mha = use_mha

        if use_mha:
            attn_cls = DeepseekAttention
        elif model_config.use_mla:
            attn_cls = DeepseekV2MLAAttention
        else:
            attn_cls = DeepseekV2Attention
        self.self_attn = attn_cls(
            vllm_config=vllm_config,
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            v_head_dim=v_head_dim,
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=kv_lora_rank,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            topk_indices_buffer=topk_indices_buffer,
        )

        # MoE is used on every `moe_layer_freq`-th layer after the first
        # `first_k_dense_replace` dense layers; all other layers are dense.
        if (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % moe_layer_freq == 0
        ):
            self.mlp = DeepseekV2MoE(
                config=config,
                parallel_config=parallel_config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        llama_4_scaling: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the layer and return ``(hidden_states, residual)``.

        ``residual`` is None for the first layer on the first pipeline rank.
        In that case it is initialised from a copy of the input.
        ``llama_4_scaling`` is forwarded only to non-MHA attention modules.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states.clone()
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        attn_kwargs = {
            "positions": positions,
            "hidden_states": hidden_states,
        }
        # The plain MHA `DeepseekAttention` does not accept llama_4_scaling.
        if not self.use_mha:
            attn_kwargs["llama_4_scaling"] = llama_4_scaling
        hidden_states = self.self_attn(**attn_kwargs)

        if (
            not isinstance(self.self_attn, DeepseekAttention)
            and hidden_states.dtype == torch.float16
        ):
            # Fix FP16 overflow
            # We scale both hidden_states and residual before
            # rmsnorm, and rmsnorm result would not affect by scale.
            hidden_states *= 1.0 / self.routed_scaling_factor
            if self.layer_idx == 0:
                # The residual is shared by all layers, we only scale it on
                # first layer.
                residual *= 1.0 / self.routed_scaling_factor

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
            # Fix FP16 overflow
            # Scaling the DeepseekV2MLP output, it is the input of
            # input_layernorm of next decoder layer.
            # The scaling of DeepseekV2MOE output would be done in the forward
            # of DeepseekV2MOE
            hidden_states *= 1.0 / self.routed_scaling_factor

        return hidden_states, residual


1252
@support_torch_compile
class DeepseekV2Model(nn.Module):
    """DeepseekV2/V3 decoder stack with pipeline-parallel support.

    Only the first pipeline rank owns the token embedding and only the last
    rank owns the final norm. Intermediate ranks exchange ``hidden_states``
    and ``residual`` via ``IntermediateTensors``.
    """

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embeddings, this rank's decoder layers and the final norm."""
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.device = current_platform.device_type

        self.vocab_size = config.vocab_size
        # DeepSeek V3.2 configs carry `index_topk`. All layers share a single
        # top-k index buffer, sized for the scheduler's maximum number of
        # batched tokens.
        self.is_v32 = hasattr(config, "index_topk")
        if self.is_v32:
            topk_tokens = config.index_topk
            topk_indices_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                topk_tokens,
                dtype=torch.int32,
                device=self.device,
            )
        else:
            topk_indices_buffer = None

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekV2DecoderLayer(
                vllm_config, prefix, topk_indices_buffer=topk_indices_buffer
            ),
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (valid on the first pipeline rank only)."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of decoder layers.

        Returns:
            ``IntermediateTensors`` with ``hidden_states`` and ``residual``
            on non-last pipeline ranks. On the last rank, the normalized
            final hidden states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Compute llama 4 scaling once per forward pass if enabled
        llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
        llama_4_scaling: torch.Tensor | None
        if llama_4_scaling_config is not None:
            llama_4_scaling = _get_llama_4_scaling(
                original_max_position_embeddings=llama_4_scaling_config[
                    "original_max_position_embeddings"
                ],
                scaling_beta=llama_4_scaling_config["beta"],
                positions=positions,
            )
        else:
            llama_4_scaling = None

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions, hidden_states, residual, llama_4_scaling
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

1350

1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
class DeepseekV2MixtureOfExperts(MixtureOfExperts):
    # All MoE MLP layers of the model, updated together when the
    # physical-expert layout changes.
    moe_mlp_layers: list[DeepseekV2MoE]

    def extract_moe_parameters(self, example_moe: DeepseekV2MoE | None):
        """Copy expert counts from a representative MoE layer onto the model.

        If ``example_moe`` is None (no DeepseekV2MoE layer exists), every
        expert-related count, including ``num_moe_layers`` and
        ``num_expert_groups``, is set to 0 and a warning is logged.
        """
        if example_moe is not None:
            self.num_logical_experts = example_moe.n_logical_experts
            self.num_physical_experts = example_moe.n_physical_experts
            self.num_local_physical_experts = example_moe.n_local_physical_experts
            self.num_routed_experts = example_moe.n_routed_experts
            self.num_shared_experts = example_moe.n_shared_experts
            self.num_redundant_experts = example_moe.n_redundant_experts
            return

        for attr_name in (
            "num_moe_layers",
            "num_expert_groups",
            "num_logical_experts",
            "num_physical_experts",
            "num_local_physical_experts",
            "num_routed_experts",
            "num_shared_experts",
            "num_redundant_experts",
        ):
            setattr(self, attr_name, 0)
        logger.warning("DeepSeekV2: No DeepseekV2MoE layer found in model.layers.")

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Record new physical-expert counts and push them into every MoE layer.

        The per-rank local expert count must stay the same. The redundant
        count is recomputed as physical minus logical experts, and each
        layer's expert map is rebuilt.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        redundant = num_physical_experts - self.num_logical_experts
        self.num_redundant_experts = redundant
        for moe_layer in self.moe_mlp_layers:
            moe_layer.n_local_physical_experts = num_local_physical_experts
            moe_layer.n_physical_experts = num_physical_experts
            moe_layer.n_redundant_experts = redundant
            moe_layer.experts.update_expert_map()


class DeepseekV2ForCausalLM(
1393
    nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle
1394
):
    """Causal language model built on ``DeepseekV2Model``.

    ``DeepseekForCausalLM`` and ``DeepseekV3ForCausalLM`` reuse it unchanged.
    It holds:

    - the backbone;
    - the LM head and logits processor (real only on the last
      pipeline-parallel rank);
    - the MoE layer bookkeeping from ``DeepseekV2MixtureOfExperts``;
    - the checkpoint-name remapping used by ``load_weights``.
    """
1395
1396
1397
    # Class-level dict that ``__init__`` extends in place. The additions are
    # therefore shared by every instance and by subclasses that do not
    # redefine it. The in-place update is relied upon (see the comment in
    # ``__init__`` about the quantization config).
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
1398
    # Backbone class; looked up through ``self.model_cls`` in ``__init__``.
    model_cls = DeepseekV2Model
1399
1400
1401
1402
1403
1404
1405

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, the LM head and the MoE metadata.

        Args:
            vllm_config: Global vLLM config. ``model_config.hf_config`` and
                ``quant_config`` are read from it.
            prefix: Name prefix for this module's parameters.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
1406

1407
1408
1409
1410
1411
1412
1413
1414
1415
        # Use plain multi-head attention (separate q/k/v projections) when
        # model_type is "deepseek", or when both qk_nope_head_dim and
        # qk_rope_head_dim are 0. A missing attribute counts as 0.
        qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
        qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
        self.use_mha = config.model_type == "deepseek" or all(
            dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
        )

        if self.use_mha:
            # q/k/v checkpoint projections are loaded into one packed qkv_proj.
            self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]

1416
1417
1418
1419
        # `packed_modules_mapping` needs to be modified before
        # initializing DeepseekV2Model, as it is passed inplace to
        # quantization config init and may be used to select the
        # quant_method for relevant layers during initialization.
1420
1421
1422
        # q_a_proj and kv_a_proj_with_mqa are packed into fused_qkv_a_proj
        # only when the config defines a non-None q_lora_rank.
        self.fuse_qkv_a_proj = (
            hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
        )
1423
1424
1425
1426
1427
1428
        if self.fuse_qkv_a_proj:
            self.packed_modules_mapping["fused_qkv_a_proj"] = [
                "q_a_proj",
                "kv_a_proj_with_mqa",
            ]

1429
        self.model = self.model_cls(
1430
1431
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
1432
        # Only the last pipeline-parallel rank owns a real LM head; the other
        # ranks keep a PPMissingLayer placeholder.
        if get_pp_group().is_last_rank:
1433
1434
1435
1436
1437
1438
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
1439
1440
1441
1442
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
1443
1444
            self.model.make_empty_intermediate_tensors
        )
1445
1446
1447
1448
1449
1450
1451
        # Set MoE hyperparameters
        # num_moe_layers counts the layers after the first
        # `first_k_dense_replace` ones. set_moe_parameters() resets it to 0
        # (via extract_moe_parameters) when no MoE layer is found.
        self.num_moe_layers = (
            self.config.num_hidden_layers - self.config.first_k_dense_replace
        )
        self.set_moe_parameters()

    def set_moe_parameters(self):
        """Collect this rank's MoE layers and derive expert counts from them.

        Fills ``moe_layers`` with the ``experts`` modules and
        ``moe_mlp_layers`` with the ``DeepseekV2MoE`` blocks. It then calls
        ``extract_moe_parameters`` with the last MoE block found, or ``None``
        if there is none.
        """
1452
1453
        # Reset here; not otherwise used in this class (presumably consumed
        # by the MixtureOfExperts machinery -- confirm).
        self.expert_weights = []

1454
        # Defaults to 1 when the config has no n_group attribute.
        self.num_expert_groups = getattr(self.config, "n_group", 1)
1455

1456
1457
        self.moe_layers = []
        self.moe_mlp_layers = []
1458
        example_moe = None
1459
        for layer in self.model.layers:
1460
1461
1462
            # Layers owned by other pipeline-parallel ranks are placeholders.
            if isinstance(layer, PPMissingLayer):
                continue

1463
1464
            assert isinstance(layer, DeepseekV2DecoderLayer)
            if isinstance(layer.mlp, DeepseekV2MoE):
1465
1466
                # Pick last one layer since the first ones may be dense layers.
                example_moe = layer.mlp
1467
                self.moe_mlp_layers.append(layer.mlp)
1468
1469
                self.moe_layers.append(layer.mlp.experts)

1470
        self.extract_moe_parameters(example_moe)
1471

1472
1473
    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the backbone's embedding."""
        return self.model.embed_input_ids(input_ids)
1474
1475
1476
1477
1478

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
1479
1480
1481
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged.

        Logits are not computed here; see ``compute_logits``. The result is
        hidden states, or ``IntermediateTensors`` per the return annotation
        (presumably on non-final pipeline ranks -- confirm).
        """
1482
1483
1484
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
1485
1486
1487
1488
1489
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
1490
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits through the LM head.

        May return ``None`` per the annotation (presumably when this rank
        produces no logits -- confirm against ``LogitsProcessor``).
        """
1491
        logits = self.logits_processor(self.lm_head, hidden_states)
1492
1493
        return logits

1494
1495
1496
1497
    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-parameter mapping for routed experts.

        Unlike the mapping built inside ``load_weights``, this one never
        includes redundant experts or appended shared-expert slots.
        """
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return SharedFusedMoE.make_expert_params_mapping(
1498
            self,
1499
1500
1501
1502
1503
1504
1505
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
            num_redundant_experts=0,
        )

1506
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this model's parameters.

        Each ``(name, tensor)`` pair is tried, in order, against:

        1. ``stacked_params_mapping``: gate/up into ``gate_up_proj``, plus
           either q/k/v into ``qkv_proj`` (MHA) or q_a/kv_a into
           ``fused_qkv_a_proj`` (MLA).
        2. ``expert_params_mapping``: routed experts, and, when ROCm AITER
           shared-expert fusion is enabled, shared experts split into extra
           expert slots after the routed ones.
        3. A direct name lookup (after FP8 kv-scale name remapping) using
           the parameter's ``weight_loader`` or ``default_weight_loader``.

        Rotary ``inv_freq`` entries and speculative-decoding layers are
        skipped.

        Returns:
            The set of names recorded as loaded.

        NOTE(review): the final ``loaded_params.add(name)`` also runs for
        weights that took a ``continue``, such as pipeline-missing params,
        expert weights not mapped to this rank, and skipped GPTQ biases. It
        can also add ``None`` when ``maybe_remap_kv_scale_name`` returns
        ``None``.
        """
1507
1508
1509
        # With AITER shared-expert fusion, shared experts are loaded as extra
        # expert slots (see num_experts in expert_params_mapping below).
        rocm_aiter_moe_shared_expert_enabled = (
            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        )
wangding zeng's avatar
wangding zeng committed
1510
1511
1512
1513
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
1514
1515
        ]
        mla_params_mapping = [
1516
1517
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
wangding zeng's avatar
wangding zeng committed
1518
        ]
1519
1520
1521
1522
1523
1524
1525
1526
1527
        mha_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        # Only the attention packing that matches __init__'s use_mha choice
        # is added.
        if self.use_mha:
            stacked_params_mapping.extend(mha_params_mapping)
        else:
            stacked_params_mapping.extend(mla_params_mapping)
wangding zeng's avatar
wangding zeng committed
1528

1529
1530
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
1531
        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
1532
            self,
1533
1534
1535
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
1536
1537
1538
            num_experts=self.config.n_routed_experts
            + (
                self.config.n_shared_experts
1539
                if rocm_aiter_moe_shared_expert_enabled
1540
1541
                else 0
            ),
1542
1543
            num_redundant_experts=self.num_redundant_experts,
        )
1544

wangding zeng's avatar
wangding zeng committed
1545
        params_dict = dict(self.named_parameters())
1546
        loaded_params: set[str] = set()
wangding zeng's avatar
wangding zeng committed
1547
        for name, loaded_weight in weights:
1548
1549
1550
            if "rotary_emb.inv_freq" in name:
                continue

1551
1552
1553
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model
1554

1555
1556
            # Shared-expert tensors that must be split into appended expert
            # slots (handled in the expert branch below).
            is_fusion_moe_shared_experts_layer = (
                rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
1557
1558
            )

1559
            for param_name, weight_name, shard_id in stacked_params_mapping:
1560
                # Skip non-stacked layers and experts (experts handled below).
wangding zeng's avatar
wangding zeng committed
1561
1562
                if weight_name not in name:
                    continue
1563
1564
1565
1566
1567
1568
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
1569
                if ("mlp.experts." in name) and name not in params_dict:
1570
                    continue
1571
                if is_fusion_moe_shared_experts_layer:
1572
                    continue
1573
                name_mapped = name.replace(weight_name, param_name)
1574
1575
1576

                # QKV fusion is optional, fall back to normal
                # weight loading if it's not enabled
1577
                # if go with fusion option, then update name
1578
1579
1580
                if (
                    param_name == "fused_qkv_a_proj"
                ) and name_mapped not in params_dict:
1581
                    continue
1582
1583
                else:
                    name = name_mapped
wangding zeng's avatar
wangding zeng committed
1584
1585
1586
                # NOTE(review): `name` has already been rebound to the mapped
                # name here. The `continue`s below therefore let later
                # mappings and the for-else fallback see the renamed string.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
1587
1588
1589
1590

                if is_pp_missing_parameter(name, self):
                    continue

wangding zeng's avatar
wangding zeng committed
1591
1592
1593
1594
1595
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked/fused mapping consumed this tensor: try the
                # expert mapping, then a direct load by name.
1596
                is_expert_weight = False
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606

                # Special handling: when AITER fusion_shared_experts is enabled,
                # checkpoints may provide a single widened shared_experts tensor
                # without explicit expert indices
                # (e.g. ...mlp.shared_experts.gate_proj.weight).
                # For models with multiple shared experts, split that tensor
                # evenly into per-shared-expert slices and load them into
                # appended expert slots mlp.experts.{n_routed_experts + j}.*
                # accordingly.
                num_chunks = 1
1607
                if is_fusion_moe_shared_experts_layer:
1608
1609
1610
1611
                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
                    # Determine split axis based on op type
                    # gate/up: ColumnParallel → split along dim 0
                    # down: RowParallel → split along dim 1
1612
1613
1614
1615
1616
                    split_dim = (
                        1
                        if ("down_proj.weight" in name and loaded_weight.ndim > 1)
                        else 0
                    )
1617
1618
1619
1620
                    total = loaded_weight.shape[split_dim]
                    assert total % num_chunks == 0, (
                        f"Shared expert weight dim {total} "
                        f"not divisible by num_chunks {num_chunks}"
1621
                    )
1622
1623
1624
1625
1626
1627
                    chunk_size = total // num_chunks

                # Without fusion this runs once with the unmodified name and
                # tensor.
                for j in range(num_chunks):
                    chunk_name = name
                    weight_to_load = loaded_weight

1628
                    if is_fusion_moe_shared_experts_layer:
1629
1630
1631
1632
1633
                        chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
                        if loaded_weight.ndim == 1:
                            weight_to_load = loaded_weight[chunk_slice]
                        elif split_dim == 0:
                            weight_to_load = loaded_weight[chunk_slice, :]
1634
                        else:
1635
                            weight_to_load = loaded_weight[:, chunk_slice]
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
                        # Synthesize an expert-style name so expert mapping
                        # can route it
                        chunk_name = name.replace(
                            "mlp.shared_experts",
                            f"mlp.experts.{self.config.n_routed_experts + j}",
                        )

                    # Use expert_params_mapping to locate the destination
                    # param and delegate to its expert-aware weight_loader
                    # with expert_id.
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in chunk_name:
                            continue

                        # Anyway, this is an expert weight and should not be
                        # attempted to load as other weights later
                        is_expert_weight = True

                        # Do not modify `name` since the loop may continue here
                        # Instead, create a new variable
                        name_mapped = chunk_name.replace(weight_name, param_name)

                        if is_pp_missing_parameter(name_mapped, self):
                            continue

                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or
                        # not here since otherwise we may skip experts with
                        # other available replicas.
                        weight_loader = typing.cast(
                            Callable[..., bool], param.weight_loader
                        )
                        success = weight_loader(
                            param,
                            weight_to_load,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                            return_success=True,
                        )
                        if success:
1678
                            if not is_fusion_moe_shared_experts_layer:
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
                                # Recorded by the loaded_params.add(name) after
                                # the chunk loop.
                                name = name_mapped
                            else:
                                # Fused shared-expert chunks are recorded per
                                # destination param; the add after the chunk
                                # loop is skipped for them.
                                loaded_params.add(name_mapped)
                            break
                    else:
                        if is_expert_weight:
                            # We've checked that this is an expert weight
                            # However it's not mapped locally to this rank
                            # So we simply skip it
                            continue

                        # Skip loading extra bias for GPTQ models.
                        if name.endswith(".bias") and name not in params_dict:
                            continue

                        # Remapping the name of FP8 kv-scale.
                        name = maybe_remap_kv_scale_name(name, params_dict)
                        if name is None:
                            continue

                        if is_pp_missing_parameter(name, self):
                            continue

                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
1707
            # Fused shared-expert tensors were already recorded per chunk.
            if not is_fusion_moe_shared_experts_layer:
1708
                loaded_params.add(name)
1709

1710
        return loaded_params
1711
1712


1713
1714
1715
1716
class DeepseekForCausalLM(DeepseekV2ForCausalLM):
    """Identical to ``DeepseekV2ForCausalLM``; exists under a separate name."""


1717
1718
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """Identical to ``DeepseekV2ForCausalLM``; exists under a separate name."""
1719
1720


1721
1722
# Compatibility with the `num_nextn_predict_layers` config field defined in
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py
1723
def get_spec_layer_idx_from_weight_name(
    config: DeepseekV2Config | DeepseekV3Config, weight_name: str
) -> int | None:
    """Return the speculative-decoding layer index a weight belongs to.

    The extra prediction layers (``config.num_nextn_predict_layers`` of them)
    are numbered right after the regular decoder layers. Their indices run
    from ``num_hidden_layers`` to
    ``num_hidden_layers + num_nextn_predict_layers - 1``, and their weights
    are named ``model.layers.{idx}.*``.

    Args:
        config: Model config. ``num_nextn_predict_layers`` may be absent or
            ``None``; both are treated as zero. ``num_hidden_layers`` is read
            only when that count is positive.
        weight_name: Checkpoint weight name.

    Returns:
        The layer index if ``weight_name`` lives under one of those layers,
        otherwise ``None``.
    """
    # `or 0` maps an explicit None to 0 so the comparison below cannot raise.
    num_spec_layers = getattr(config, "num_nextn_predict_layers", None) or 0
    if num_spec_layers <= 0:
        return None

    first_spec_idx = config.num_hidden_layers
    for layer_idx in range(first_spec_idx, first_spec_idx + num_spec_layers):
        # The trailing "." prevents e.g. layer 6 matching "model.layers.61.".
        if weight_name.startswith(f"model.layers.{layer_idx}."):
            return layer_idx
    return None