deepseek_v2.py 57.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

wangding zeng's avatar
wangding zeng committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
25
"""Inference-only DeepseekV2/DeepseekV3 model."""
26

27
28
import typing
from collections.abc import Callable, Iterable
29
from itertools import islice
wangding zeng's avatar
wangding zeng committed
30
31
32

import torch
from torch import nn
33
from transformers import DeepseekV2Config, DeepseekV3Config
wangding zeng's avatar
wangding zeng committed
34

35
from vllm._aiter_ops import rocm_aiter_ops
36
from vllm.attention.layer import Attention
37
from vllm.compilation.decorators import support_torch_compile
38
39
40
41
42
43
44
45
from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
46
from vllm.logger import init_logger
wangding zeng's avatar
wangding zeng committed
47
from vllm.model_executor.layers.activation import SiluAndMul
48
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
49
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
50
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
51
52
53
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
54
    QKVParallelLinear,
55
56
57
    ReplicatedLinear,
    RowParallelLinear,
)
wangding zeng's avatar
wangding zeng committed
58
from vllm.model_executor.layers.logits_processor import LogitsProcessor
59
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
60
from vllm.model_executor.layers.quantization import QuantizationConfig
61
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
62
63
    per_token_group_quant_fp8,
)
wangding zeng's avatar
wangding zeng committed
64
from vllm.model_executor.layers.rotary_embedding import get_rope
65
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
wangding zeng's avatar
wangding zeng committed
66
from vllm.model_executor.layers.vocab_parallel_embedding import (
67
68
69
    ParallelLMHead,
    VocabParallelEmbedding,
)
70
from vllm.model_executor.model_loader.weight_utils import (
71
72
73
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
74
from vllm.model_executor.models.utils import sequence_parallel_chunk
75
from vllm.platforms import current_platform
76
from vllm.sequence import IntermediateTensors
77
from vllm.v1.attention.backend import AttentionBackend
78
79
80
from vllm.v1.attention.backends.mla.indexer import (
    DeepseekV32IndexerBackend,
)
81
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
wangding zeng's avatar
wangding zeng committed
82

83
from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
84
85
86
87
88
89
90
from .utils import (
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
91

92
93
logger = init_logger(__name__)

wangding zeng's avatar
wangding zeng committed
94

95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
class DeepseekAttention(nn.Module):
    """Normal MHA implementation used by Deepseek v1.

    The forward pass is: fused QKV projection -> rotary embedding on Q/K ->
    ``Attention`` -> output projection.

    Args:
        vllm_config: Accepted for signature parity with the other attention
            classes in this file; not read by this class.
        config: HF model config; ``num_key_value_heads`` and
            ``rope_parameters`` are read from it.
        hidden_size: Model hidden size.
        num_heads: Total number of query heads (before TP sharding).
        max_position_embeddings: Forwarded to ``get_rope`` as ``max_position``.
        cache_config: Forwarded to ``Attention``.
        quant_config: Forwarded to the linear layers and ``Attention``.
        prefix: Module name prefix; used to name the ``Attention`` layer.
        **kwargs: Ignored.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        **kwargs,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than or equal to TP size, so we
            # partition the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        # Every rank holds at least one KV head, even when replicated.
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the Q and K/V slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        # NOTE(review): unlike the other projections in this file, these two
        # linear layers are built without a `prefix` -- confirm this is
        # intentional for quantization configs that key on layer names.
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` and return the projected output.

        Args:
            positions: Token positions passed to the rotary embedding.
            hidden_states: Input activations fed to the fused QKV projection.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # Slice the fused projection into this rank's Q, K and V parts.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

wangding zeng's avatar
wangding zeng committed
175
176
177
178
179
180
181

class DeepseekV2MLP(nn.Module):
    """Gated MLP: merged gate/up projection, ``SiluAndMul``, down projection.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Width of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Forwarded to both linear layers.
        reduce_results: Forwarded to the down projection.
        is_sequence_parallel: Forwarded as ``disable_tp`` to both linear
            layers (see the comment in ``__init__``).
        prefix: Module name prefix used to name the sub-layers.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        is_sequence_parallel: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Reject unsupported activations before building any layer.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

        # If is_sequence_parallel, the input and output tensors are sharded
        # across the ranks within the tp_group. In this case the weights are
        # replicated and no collective ops are needed.
        # Otherwise we use standard TP with an allreduce at the end.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply gate/up projection, activation, then down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class DeepseekV2MoE(nn.Module):
    """Mixture-of-experts block with optional shared experts.

    Builds a replicated router (``gate``), an optional ``DeepseekV2MLP`` for
    the shared experts, and a ``SharedFusedMoE`` that runs the routed experts.

    Args:
        config: HF model config; expert counts, sizes and routing options are
            read from it.
        parallel_config: Supplies ``use_sequence_parallel_moe``,
            ``enable_eplb`` and ``eplb_config``.
        quant_config: Forwarded to the shared experts and the fused MoE. The
            router itself is always built with ``quant_config=None``.
        prefix: Module name prefix used to name the sub-layers.

    Raises:
        ValueError: If ``config.hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        config: DeepseekV2Config | DeepseekV3Config,
        parallel_config: ParallelConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts: int = config.n_routed_experts
        # May be None: it is checked against None below.
        self.n_shared_experts: int | None = config.n_shared_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.n_routed_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )
        # The correction bias only exists for the "noaux_tc" top-k method.
        if getattr(config, "topk_method", None) == "noaux_tc":
            self.gate.e_score_correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts, dtype=torch.float32)
            )
        else:
            self.gate.e_score_correction_bias = None

        # Load balancing settings.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Half-open range [start, end) of physical experts on this EP rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
        self.is_fusion_moe_shared_experts_enabled = (
            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        )
        if config.n_shared_experts is None or self.is_fusion_moe_shared_experts_enabled:
            # Either there are no shared experts, or they are handed to the
            # fused MoE via `n_shared_experts` below.
            self.shared_experts = None
        else:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts

            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                is_sequence_parallel=self.is_sequence_parallel,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )

        # We do the scaling outside (in forward), so pass 1.0 to avoid a
        # double multiply -- except with aiter, which applies
        # routed_scaling_factor internally.
        fused_routed_scaling_factor = (
            self.routed_scaling_factor if self.is_rocm_aiter_moe_enabled else 1.0
        )
        fused_n_shared_experts = (
            config.n_shared_experts
            if self.is_fusion_moe_shared_experts_enabled
            else None
        )

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_experts,
            gate=self.gate,
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=getattr(config, "n_group", 1),
            topk_group=getattr(config, "topk_group", 1),
            prefix=f"{prefix}.experts",
            scoring_func=getattr(config, "scoring_func", "softmax"),
            routed_scaling_factor=fused_routed_scaling_factor,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
            n_shared_experts=fused_n_shared_experts,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and combine with shared experts.

        Args:
            hidden_states: 2-D tensor ``(num_tokens, hidden_dim)``; the shape
                is unpacked into exactly two dimensions.

        Returns:
            Tensor of shape ``(num_tokens, hidden_dim)``.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # Chunk the hidden states so they aren't replicated across TP ranks.
        # This avoids duplicate computation in self.experts.
        # TODO: We can replace the all_reduce at the end of attn with a
        # reduce_scatter instead of chunking here.
        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        if self.experts.is_internal_router:
            # In this case, the gate/router runs inside the FusedMoE class,
            # so the raw hidden states are passed as `router_logits`.
            fused_moe_out = self.experts(
                hidden_states=hidden_states, router_logits=hidden_states
            )
        else:
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)
            fused_moe_out = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )

        shared_output, final_hidden_states = fused_moe_out
        if self.shared_experts is None:
            assert shared_output is None

        # Fix FP16 overflow
        # See DeepseekV2DecoderLayer for more details.
        # Non-fp16: scale the routed output up (unless aiter already did).
        # fp16: leave the routed output alone and scale the shared output
        # down by the same factor instead.
        if hidden_states.dtype != torch.float16:
            if not self.is_rocm_aiter_moe_enabled:
                final_hidden_states *= self.routed_scaling_factor
        elif self.shared_experts is not None:
            assert shared_output is not None
            shared_output *= 1.0 / self.routed_scaling_factor

        if self.shared_experts is not None:
            assert shared_output is not None
            final_hidden_states += shared_output

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            # Trim back to the original token count after the gather.
            final_hidden_states = final_hidden_states[:num_tokens]
        elif self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
                final_hidden_states
            )

        return final_hidden_states.view(num_tokens, hidden_dim)


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN magnitude-scale factor.

    Args:
        scale: Context-extension factor.
        mscale: Multiplier applied to the logarithmic term.

    Returns:
        ``1.0`` when ``scale <= 1``; otherwise
        ``0.1 * mscale * ln(scale) + 1.0``.
    """
    import math

    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


388
389
390
391
392
393
394
395
396
397
def _get_llama_4_scaling(
    original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
) -> torch.Tensor:
    scaling = 1 + scaling_beta * torch.log(
        1 + torch.floor(positions / original_max_position_embeddings)
    )
    # Broadcast over num_heads and head_dim
    return scaling[..., None, None]


wangding zeng's avatar
wangding zeng committed
398
399
400
class DeepseekV2Attention(nn.Module):
    """DeepSeek-V2/V3 attention with low-rank Q/KV projections.

    K and V are materialised per head (via ``kv_b_proj``) and fed to a
    regular ``Attention`` layer; V is zero-padded up to ``qk_head_dim``
    before the call and the padding is sliced off afterwards.

    Args:
        vllm_config: Not read by this class.
        config: HF model config; ``rms_norm_eps`` and ``rope_parameters`` are
            read from it. ``rope_parameters["rope_type"]`` is rewritten in
            place when it is not ``"default"``.
        hidden_size: Model hidden size.
        num_heads: Total number of heads (before TP sharding).
        qk_nope_head_dim: Per-head width of the non-rotary part of Q/K.
        qk_rope_head_dim: Per-head width of the rotary part of Q/K.
        v_head_dim: Per-head width of V.
        q_lora_rank: Rank of the low-rank Q projection, or ``None`` to use a
            single full-rank ``q_proj``.
        kv_lora_rank: Rank of the compressed KV latent.
        max_position_embeddings: Forwarded to ``get_rope``.
        cache_config: Forwarded to ``Attention``.
        quant_config: Forwarded to the linear layers and ``Attention``.
        topk_indices_buffer: Must be ``None``; not supported here.
        prefix: Module name prefix used to name the sub-layers.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        topk_indices_buffer: torch.Tensor | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings
        assert topk_indices_buffer is None, (
            "topk_indices_buffer is not supported for DeepseekV2Attention"
        )

        if self.q_lora_rank is not None:
            # Low-rank Q: down-project, normalise, then up-project per head.
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_a_proj",
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )

        # Produces the compressed KV latent plus the shared rotary K part.
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa",
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # Any non-default rope type is mapped onto one of the two DeepSeek
        # variants. NOTE: this mutates `config.rope_parameters` in place.
        if config.rope_parameters["rope_type"] != "default":
            config.rope_parameters["rope_type"] = (
                "deepseek_yarn"
                if config.rope_parameters.get("apply_yarn_scaling", True)
                else "deepseek_llama_scaling"
            )

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=False,
        )

        if config.rope_parameters["rope_type"] == "deepseek_yarn":
            # Fold the YaRN magnitude scale (squared) into the softmax scale.
            mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
            scaling_factor = config.rope_parameters["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.attn = Attention(
            self.num_local_heads,
            self.qk_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states``.

        Args:
            positions: Token positions passed to the rotary embedding.
            hidden_states: Input activations.
            llama_4_scaling: Optional factor multiplied into Q in place.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        # Only the rotary slice of q is needed separately; the non-rotary
        # slice stays in place inside `q`.
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a)
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # Rotary K part: a single head (dim 1 has size 1) shared by all heads.
        k_pe = latent_cache[:, :, self.kv_lora_rank :]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        # Write the rotated slice back into q, then assemble k the same way;
        # k_pe broadcasts across the head dimension.
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe

        # Apply llama 4 scaling if provided
        if llama_4_scaling is not None:
            q *= llama_4_scaling

        # padding value to qk_head_dim for alignment
        v = torch.nn.functional.pad(
            v, [0, self.qk_head_dim - self.v_head_dim], value=0
        ).view(-1, self.num_local_heads * self.qk_head_dim)
        attn_output = self.attn(q, k, v)
        # Drop the padded columns again before the output projection.
        attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[
            ..., : self.v_head_dim
        ].reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output


563
class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
    """KV-cache holder layer for the DeepSeek V3.2 indexer.

    The layer has no computation of its own (``forward`` is a no-op). On
    construction it registers itself under ``prefix`` in the current
    compilation config's ``static_forward_context``.

    Args:
        head_dim: Head size reported in the KV cache spec.
        dtype: Dtype reported in the KV cache spec.
        prefix: Unique layer name used as the registration key.
        cache_config: Supplies ``block_size`` for the KV cache spec.

    Raises:
        ValueError: If ``prefix`` is already registered.
    """

    def __init__(
        self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig
    ):
        super().__init__()
        # Placeholder cache entry (a single empty tensor).
        self.kv_cache = [torch.tensor([])]
        self.head_dim = head_dim
        self.prefix = prefix
        self.cache_config = cache_config
        self.dtype = dtype
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Return the cache spec for this layer; ``vllm_config`` is not read."""
        return MLAAttentionSpec(  # Only has one vector instead of K + V
            block_size=self.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
        )

    def forward(self): ...

    def get_attn_backend(self) -> type[AttentionBackend]:
        """Return the attention backend class (not an instance)."""
        return DeepseekV32IndexerBackend


class Indexer(nn.Module):
    """Sparse-attention indexer.

    Projects the low-rank query input and the hidden states into indexer
    Q/K, applies rotary embedding to their leading ``rope_dim`` channels,
    FP8-quantises Q, builds per-head weights, and hands everything to
    ``SparseAttnIndexer``.

    Args:
        vllm_config: Supplies ``model_config.max_model_len`` and is passed to
            ``get_max_prefill_buffer_size``.
        config: HF model config; ``index_topk``, ``index_n_heads``,
            ``index_head_dim`` and ``qk_rope_head_dim`` are read from it.
        hidden_size: Model hidden size.
        q_lora_rank: Width of the low-rank query input ``qr``.
        quant_config: Forwarded to ``wq_b`` and ``wk`` (``weights_proj`` is
            always built with ``quant_config=None``).
        cache_config: Forwarded to the indexer K cache layer.
        topk_indices_buffer: Forwarded to ``SparseAttnIndexer``.
        prefix: Module name prefix used to name the sub-layers.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        q_lora_rank: int,
        quant_config: QuantizationConfig | None,
        cache_config: CacheConfig | None,
        topk_indices_buffer: torch.Tensor | None,
        prefix: str = "",
    ):
        super().__init__()
        self.vllm_config = vllm_config
        self.config = config
        self.topk_tokens = config.index_topk
        self.n_head = config.index_n_heads  # e.g. 64
        self.head_dim = config.index_head_dim  # e.g. 128
        self.rope_dim = config.qk_rope_head_dim  # e.g. 64
        self.q_lora_rank = q_lora_rank  # e.g. 1536
        # no tensor parallel, just replicated
        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.head_dim * self.n_head,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq_b",
        )
        self.wk = ReplicatedLinear(
            hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wk",
        )
        self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
        self.weights_proj = ReplicatedLinear(
            hidden_size,
            self.n_head,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.weights_proj",
        )
        self.softmax_scale = self.head_dim**-0.5

        self.scale_fmt = "ue8m0"
        self.quant_block_size = 128  # TODO: get from config
        self.topk_indices_buffer = topk_indices_buffer

        # NOTE: (zyongye) we use fp8 naive cache,
        #       where we store value in fp8 and scale in fp32
        #       per self.quant_block_size element
        # Hence the uint8 row width below: head_dim value bytes plus 4 bytes
        # for each block of quant_block_size elements.
        self.k_cache = DeepseekV32IndexerCache(
            head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4,
            dtype=torch.uint8,
            prefix=f"{prefix}.k_cache",
            cache_config=cache_config,
        )
        self.max_model_len = vllm_config.model_config.max_model_len
        self.prefix = prefix
        from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size

        self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)
        self.indexer_op = SparseAttnIndexer(
            self.k_cache,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )

    def forward(
        self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb
    ) -> torch.Tensor:
        """Compute the indexer output for the given tokens.

        Args:
            hidden_states: Input activations; projected to K and to the
                per-head weights.
            qr: Low-rank query input, projected to Q by ``wq_b``.
            positions: Token positions passed to ``rotary_emb``.
            rotary_emb: Callable invoked as ``rotary_emb(positions, q, k)``
                and returning the rotated ``(q, k)`` pair.

        Returns:
            Whatever ``self.indexer_op`` returns for
            ``(hidden_states, q_fp8, k, weights)``.
        """
        q, _ = self.wq_b(qr)
        q = q.view(-1, self.n_head, self.head_dim)
        # The rotary channels come first in the indexer head layout.
        q_pe, q_nope = torch.split(
            q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
        )

        k, _ = self.wk(hidden_states)
        k = self.k_norm(k)
        k_pe, k_nope = torch.split(
            k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
        )

        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
        # Note: RoPE (NeoX) can introduce extra leading dimensions during compilation
        # so we need to reshape back to token-flattened shapes
        q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim)
        k_pe = k_pe.reshape(-1, 1, self.rope_dim)

        # After the reshape, `q_pe` is [num_tokens, n_head, rope_dim].
        q = torch.cat([q_pe, q_nope], dim=-1)
        # `k_pe` is [num_tokens, 1, rope_dim] (MQA); drop the head axis.
        k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)

        # we only quant q here since k quant is fused with cache insertion
        q = q.view(-1, self.head_dim)
        q_fp8, q_scale = per_token_group_quant_fp8(
            q,
            self.quant_block_size,
            column_major_scales=False,
            use_ue8m0=self.scale_fmt is not None,
        )
        q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
        q_scale = q_scale.view(-1, self.n_head, 1)

        weights, _ = self.weights_proj(hidden_states)
        # Fold the Q quantisation scale, the softmax scale and an
        # n_head**-0.5 factor into the per-head weights.
        weights = (
            weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
        )
        weights = weights.squeeze(-1)

        return self.indexer_op(hidden_states, q_fp8, k, weights)
712
713


714
715
716
717
class DeepseekV2MLAAttention(nn.Module):
    """
    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

    This module only builds the projection, normalization and rotary
    embedding submodules, bundles them into an `MLAModules` container and
    delegates the forward pass to a `MultiHeadLatentAttentionWrapper`.

    For more info see MLACommonImpl in:
    vllm/v1/attention/backends/mla/utils.py
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        topk_indices_buffer: torch.Tensor | None = None,
    ) -> None:
        """Create the MLA submodules.

        Args:
            vllm_config: Engine configuration; only forwarded to `Indexer`.
            config: HF model config; supplies `rms_norm_eps`,
                `rope_parameters` and (optionally) `index_topk`.
            hidden_size: Model hidden size.
            num_heads: Total number of attention heads; must be divisible by
                the tensor-parallel world size.
            qk_nope_head_dim: Per-head query/key width without RoPE.
            qk_rope_head_dim: Per-head query/key width with RoPE.
            v_head_dim: Per-head value width.
            q_lora_rank: Rank of the low-rank query projection, or `None` to
                use a single full query projection.
            kv_lora_rank: Rank of the compressed key/value latent.
            max_position_embeddings: Maximum position passed to `get_rope`.
            cache_config: Forwarded to the indexer and the attention wrapper.
            quant_config: Forwarded to every linear layer that is built here.
            prefix: Dotted module-name prefix used for submodule prefixes.
            topk_indices_buffer: Forwarded to the indexer and `MLAModules`.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        use_q_lora = self.q_lora_rank is not None
        if use_q_lora:
            # Low-rank query path: the query and key/value down-projections
            # share one merged (non tensor-parallel) linear layer.
            self.fused_qkv_a_proj = MergedColumnParallelLinear(
                self.hidden_size,
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.fused_qkv_a_proj",
                disable_tp=True,
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                self.q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
        else:
            # Full-rank query path: separate key/value down-projection and a
            # single query projection.
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.kv_a_proj_with_mqa",
            )
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # Any non-default rope type is rewritten (in place, on the config) to
        # one of the two DeepSeek-specific variants before building RoPE.
        if config.rope_parameters["rope_type"] != "default":
            config.rope_parameters["rope_type"] = (
                "deepseek_yarn"
                if config.rope_parameters.get("apply_yarn_scaling", True)
                else "deepseek_llama_scaling"
            )

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=False,
        )

        # With YaRN the attention scale is corrected by mscale^2.
        if config.rope_parameters["rope_type"] == "deepseek_yarn":
            mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
            scaling_factor = config.rope_parameters["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        # The presence of `index_topk` on the config selects the sparse path
        # with an indexer and its own NeoX-style rotary embedding.
        self.is_v32 = hasattr(config, "index_topk")

        if self.is_v32:
            self.indexer_rope_emb = get_rope(
                qk_rope_head_dim,
                max_position=max_position_embeddings,
                rope_parameters=config.rope_parameters,
                is_neox_style=True,
            )
            self.indexer = Indexer(
                vllm_config,
                config,
                hidden_size,
                q_lora_rank,
                quant_config,
                cache_config,
                topk_indices_buffer,
                f"{prefix}.indexer",
            )
        else:
            self.indexer_rope_emb = None
            self.indexer = None

        # Only the projections of the active query path exist as attributes,
        # so the other path is passed as `None`.
        mla_modules = MLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            rotary_emb=self.rotary_emb,
            o_proj=self.o_proj,
            fused_qkv_a_proj=self.fused_qkv_a_proj if use_q_lora else None,
            kv_a_proj_with_mqa=None if use_q_lora else self.kv_a_proj_with_mqa,
            q_a_layernorm=self.q_a_layernorm if use_q_lora else None,
            q_b_proj=self.q_b_proj if use_q_lora else None,
            q_proj=None if use_q_lora else self.q_proj,
            indexer=self.indexer,
            indexer_rotary_emb=self.indexer_rope_emb,
            is_sparse=self.is_v32,
            topk_indices_buffer=topk_indices_buffer,
        )

        self.mla_attn = MultiHeadLatentAttentionWrapper(
            self.hidden_size,
            self.num_local_heads,
            self.scaling,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
            self.q_lora_rank,
            self.kv_lora_rank,
            mla_modules,
            cache_config,
            quant_config,
            prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run attention by delegating to the wrapped `mla_attn` module."""
        return self.mla_attn(positions, hidden_states, llama_4_scaling)


class DeepseekV2DecoderLayer(nn.Module):
    """One decoder layer: pre-norm self-attention followed by a pre-norm MLP.

    The attention implementation (MHA, MLA or plain DeepseekV2 attention) and
    the feed-forward implementation (MoE or dense MLP) are chosen from the
    config and the layer index at construction time.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str,
        config: DeepseekV2Config | None = None,
        topk_indices_buffer: torch.Tensor | None = None,
    ) -> None:
        """Build the attention, feed-forward and normalization submodules.

        Args:
            vllm_config: Engine configuration; supplies the model, cache,
                quantization and parallel configs.
            prefix: Dotted module name whose last component is the layer index.
            config: HF model config; defaults to
                `vllm_config.model_config.hf_config` when omitted.
            topk_indices_buffer: Forwarded to the attention module.
        """
        super().__init__()

        if config is None:
            config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        moe_layer_freq = getattr(config, "moe_layer_freq", 1)
        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        layer_idx = int(prefix.split(sep=".")[-1])
        self.layer_idx = layer_idx

        # verify MLA attention specific fields
        qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
        qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
        v_head_dim = getattr(config, "v_head_dim", 0)
        kv_lora_rank = getattr(config, "kv_lora_rank", 0)
        # Plain multi-head attention is used for the "deepseek" model type or
        # when both query/key head dims are absent (zero).
        use_mha = config.model_type == "deepseek" or all(
            dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
        )

        self.use_mha = use_mha

        if use_mha:
            attn_cls = DeepseekAttention
        elif model_config.use_mla:
            attn_cls = DeepseekV2MLAAttention
        else:
            attn_cls = DeepseekV2Attention
        self.self_attn = attn_cls(
            vllm_config=vllm_config,
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            v_head_dim=v_head_dim,
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=kv_lora_rank,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            topk_indices_buffer=topk_indices_buffer,
        )

        # A layer is MoE when routed experts are configured, the layer lies
        # past the leading dense layers and it matches the MoE frequency.
        if (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % moe_layer_freq == 0
        ):
            self.mlp = DeepseekV2MoE(
                config=config,
                parallel_config=parallel_config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        llama_4_scaling: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply self-attention and the feed-forward block.

        Args:
            positions: Positions forwarded to the attention module.
            hidden_states: Layer input.
            residual: Running residual, or `None` on the first layer, in which
                case it is initialised from a copy of `hidden_states`.
            llama_4_scaling: Optional scaling tensor; only passed to the
                attention module when it is not plain MHA.

        Returns:
            The `(hidden_states, residual)` pair for the next layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states.clone()
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        attn_kwargs = {
            "positions": positions,
            "hidden_states": hidden_states,
        }
        if not self.use_mha:
            attn_kwargs["llama_4_scaling"] = llama_4_scaling
        hidden_states = self.self_attn(**attn_kwargs)

        if (
            not isinstance(self.self_attn, DeepseekAttention)
            and hidden_states.dtype == torch.float16
        ):
            # Fix FP16 overflow.
            # Both hidden_states and residual are scaled before RMSNorm; the
            # RMSNorm result is not affected by the scale.
            hidden_states *= 1.0 / self.routed_scaling_factor
            if self.layer_idx == 0:
                # The residual is shared by all layers, so it is only scaled
                # on the first layer.
                residual *= 1.0 / self.routed_scaling_factor

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
            # Fix FP16 overflow.
            # Scale the DeepseekV2MLP output, which is the input of the
            # input_layernorm of the next decoder layer.
            # The scaling of the DeepseekV2MoE output is done in the forward
            # of DeepseekV2MoE.
            hidden_states *= 1.0 / self.routed_scaling_factor

        return hidden_states, residual


@support_torch_compile
class DeepseekV2Model(nn.Module):
    """Decoder stack: token embedding, decoder layers and the final norm.

    With pipeline parallelism, the embedding exists only on the first rank and
    the final norm only on the last rank; other ranks hold `PPMissingLayer`
    placeholders.
    """

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the embedding, the decoder layers and the final norm.

        Args:
            vllm_config: Engine configuration; supplies the HF config, the
                quantization config and the scheduler config.
            prefix: Dotted module-name prefix used for submodule prefixes.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.device = current_platform.device_type

        self.vocab_size = config.vocab_size
        self.is_v32 = hasattr(config, "index_topk")
        if self.is_v32:
            # One int32 buffer, sized for the largest batch, is shared by all
            # decoder layers.
            topk_tokens = config.index_topk
            topk_indices_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                topk_tokens,
                dtype=torch.int32,
                device=self.device,
            )
        else:
            topk_indices_buffer = None

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekV2DecoderLayer(
                vllm_config, prefix, topk_indices_buffer=topk_indices_buffer
            ),
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the embeddings of `input_ids` via `embed_tokens`."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        Args:
            input_ids: Token ids; used on the first pipeline rank when
                `inputs_embeds` is not given.
            positions: Positions forwarded to every decoder layer.
            intermediate_tensors: Output of the previous pipeline rank;
                required on every rank except the first.
            inputs_embeds: Optional precomputed embeddings that replace the
                embedding lookup on the first rank.

        Returns:
            The normalized hidden states on the last pipeline rank, otherwise
            an `IntermediateTensors` holding `hidden_states` and `residual`.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Compute llama 4 scaling once per forward pass if enabled
        llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
        llama_4_scaling: torch.Tensor | None
        if llama_4_scaling_config is not None:
            llama_4_scaling = _get_llama_4_scaling(
                original_max_position_embeddings=llama_4_scaling_config[
                    "original_max_position_embeddings"
                ],
                scaling_beta=llama_4_scaling_config["beta"],
                positions=positions,
            )
        else:
            llama_4_scaling = None

        # Only the layers in [start_layer, end_layer) are run on this rank.
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions, hidden_states, residual, llama_4_scaling
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekV2MixtureOfExperts(MixtureOfExperts):
    moe_mlp_layers: list[DeepseekV2MoE]
    """
    List of MoE MLP layers in the model.
    """

    def extract_moe_parameters(self, example_moe: DeepseekV2MoE | None):
        """Copy the expert counters from a sample MoE layer onto the model.

        When no MoE layer is available, every counter (including the layer
        and expert-group counts) is reset to zero and a warning is logged.
        """
        if example_moe is not None:
            # Model-side attribute name -> attribute name on the MoE layer.
            mirrored = (
                ("num_logical_experts", "n_logical_experts"),
                ("num_physical_experts", "n_physical_experts"),
                ("num_local_physical_experts", "n_local_physical_experts"),
                ("num_routed_experts", "n_routed_experts"),
                ("num_shared_experts", "n_shared_experts"),
                ("num_redundant_experts", "n_redundant_experts"),
            )
            for target, source in mirrored:
                setattr(self, target, getattr(example_moe, source))
            return

        zeroed = (
            "num_moe_layers",
            "num_expert_groups",
            "num_logical_experts",
            "num_physical_experts",
            "num_local_physical_experts",
            "num_routed_experts",
            "num_shared_experts",
            "num_redundant_experts",
        )
        for counter in zeroed:
            setattr(self, counter, 0)
        logger.warning("DeepSeekV2: No DeepseekV2MoE layer found in model.layers.")

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Record new physical expert counts on the model and each MoE layer.

        The local physical expert count must equal the current one. The
        redundant count is the physical count minus the logical count, and
        every MoE layer refreshes its expert map afterwards.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        redundant = num_physical_experts - self.num_logical_experts

        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = redundant

        for moe_layer in self.moe_mlp_layers:
            moe_layer.n_local_physical_experts = num_local_physical_experts
            moe_layer.n_physical_experts = num_physical_experts
            moe_layer.n_redundant_experts = redundant
            moe_layer.experts.update_expert_map()


class DeepseekV2ForCausalLM(
1174
    nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle
1175
):
1176
1177
1178
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
1179
    model_cls = DeepseekV2Model
wangding zeng's avatar
wangding zeng committed
1180

1181
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder model, the LM head and the MoE bookkeeping.

        Args:
            vllm_config: Engine configuration; supplies the HF config and the
                quantization config.
            prefix: Dotted module-name prefix used for submodule prefixes.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
        qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
        self.use_mha = config.model_type == "deepseek" or all(
            dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
        )

        # NOTE: `packed_modules_mapping` is deliberately updated in place (it
        # is never rebound to a new dict), see the comment below.
        if self.use_mha:
            self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]

        # `packed_modules_mapping` needs to be modified before
        # initializing DeepseekV2Model, as it is passed inplace to
        # quantization config init and may be used to select the
        # quant_method for relevant layers during initialization.
        self.fuse_qkv_a_proj = getattr(config, "q_lora_rank", None) is not None
        if self.fuse_qkv_a_proj:
            self.packed_modules_mapping["fused_qkv_a_proj"] = [
                "q_a_proj",
                "kv_a_proj_with_mqa",
            ]

        self.model = self.model_cls(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        # The LM head only exists on the last pipeline rank.
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )
        # Set MoE hyperparameters
        self.num_moe_layers = (
            self.config.num_hidden_layers - self.config.first_k_dense_replace
        )
        self.set_moe_parameters()

    def set_moe_parameters(self):
        """Collect the MoE layers of this rank and derive the expert counters.

        Resets `expert_weights`, `moe_layers` and `moe_mlp_layers`, gathers
        every `DeepseekV2MoE` MLP (and its `experts` module) found in
        `self.model.layers`, then hands the last one found to
        `extract_moe_parameters`.
        """
        self.expert_weights = []

        self.num_expert_groups = getattr(self.config, "n_group", 1)

        self.moe_layers = []
        self.moe_mlp_layers = []
        example_moe = None
        for layer in self.model.layers:
            # Layers owned by other pipeline ranks are placeholders.
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, DeepseekV2DecoderLayer)
            if isinstance(layer.mlp, DeepseekV2MoE):
                # Pick last one layer since the first ones may be dense layers.
                example_moe = layer.mlp
                self.moe_mlp_layers.append(layer.mlp)
                self.moe_layers.append(layer.mlp.experts)

        self.extract_moe_parameters(example_moe)
    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed `input_ids` by delegating to the wrapped decoder model."""
        embeddings = self.model.embed_input_ids(input_ids)
        return embeddings
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the wrapped decoder model and return its output unchanged.

        All arguments are forwarded positionally to `self.model`.
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Compute logits by applying `logits_processor` with the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-parameter mapping for the routed experts.

        The mapping is built for `config.n_routed_experts` experts with no
        redundant experts.
        """
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
            num_redundant_experts=0,
        )

1287
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
1288
1289
1290
        rocm_aiter_moe_shared_expert_enabled = (
            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        )
wangding zeng's avatar
wangding zeng committed
1291
1292
1293
1294
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
1295
1296
        ]
        mla_params_mapping = [
1297
1298
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
wangding zeng's avatar
wangding zeng committed
1299
        ]
1300
1301
1302
1303
1304
1305
1306
1307
1308
        mha_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        if self.use_mha:
            stacked_params_mapping.extend(mha_params_mapping)
        else:
            stacked_params_mapping.extend(mla_params_mapping)
wangding zeng's avatar
wangding zeng committed
1309

1310
1311
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
1312
        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
1313
            self,
1314
1315
1316
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
1317
1318
1319
            num_experts=self.config.n_routed_experts
            + (
                self.config.n_shared_experts
1320
                if rocm_aiter_moe_shared_expert_enabled
1321
1322
                else 0
            ),
1323
1324
            num_redundant_experts=self.num_redundant_experts,
        )
1325

wangding zeng's avatar
wangding zeng committed
1326
        params_dict = dict(self.named_parameters())
1327
        loaded_params: set[str] = set()
wangding zeng's avatar
wangding zeng committed
1328
1329
1330
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
1331

1332
1333
1334
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model
1335

1336
1337
            is_fusion_moe_shared_experts_layer = (
                rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
1338
1339
            )

1340
            for param_name, weight_name, shard_id in stacked_params_mapping:
1341
                # Skip non-stacked layers and experts (experts handled below).
wangding zeng's avatar
wangding zeng committed
1342
1343
                if weight_name not in name:
                    continue
1344
1345
1346
1347
1348
1349
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
1350
                if ("mlp.experts." in name) and name not in params_dict:
1351
                    continue
1352
                if is_fusion_moe_shared_experts_layer:
1353
                    continue
1354
                name_mapped = name.replace(weight_name, param_name)
1355
1356
1357

                # QKV fusion is optional, fall back to normal
                # weight loading if it's not enabled
1358
                # if go with fusion option, then update name
1359
1360
1361
                if (
                    param_name == "fused_qkv_a_proj"
                ) and name_mapped not in params_dict:
1362
                    continue
1363
1364
                else:
                    name = name_mapped
wangding zeng's avatar
wangding zeng committed
1365
1366
1367
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
1368
1369
1370
1371

                if is_pp_missing_parameter(name, self):
                    continue

wangding zeng's avatar
wangding zeng committed
1372
1373
1374
1375
1376
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
1377
                is_expert_weight = False
zhuwenwen's avatar
zhuwenwen committed
1378

1379
1380
1381
1382
1383
1384
1385
1386
1387
                # Special handling: when AITER fusion_shared_experts is enabled,
                # checkpoints may provide a single widened shared_experts tensor
                # without explicit expert indices
                # (e.g. ...mlp.shared_experts.gate_proj.weight).
                # For models with multiple shared experts, split that tensor
                # evenly into per-shared-expert slices and load them into
                # appended expert slots mlp.experts.{n_routed_experts + j}.*
                # accordingly.
                num_chunks = 1
1388
                if is_fusion_moe_shared_experts_layer:
1389
1390
1391
1392
                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
                    # Determine split axis based on op type
                    # gate/up: ColumnParallel → split along dim 0
                    # down: RowParallel → split along dim 1
1393
1394
1395
1396
1397
                    split_dim = (
                        1
                        if ("down_proj.weight" in name and loaded_weight.ndim > 1)
                        else 0
                    )
1398
1399
1400
1401
                    total = loaded_weight.shape[split_dim]
                    assert total % num_chunks == 0, (
                        f"Shared expert weight dim {total} "
                        f"not divisible by num_chunks {num_chunks}"
1402
                    )
1403
1404
1405
1406
1407
1408
                    chunk_size = total // num_chunks

                for j in range(num_chunks):
                    chunk_name = name
                    weight_to_load = loaded_weight

1409
                    if is_fusion_moe_shared_experts_layer:
1410
1411
1412
1413
1414
                        chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
                        if loaded_weight.ndim == 1:
                            weight_to_load = loaded_weight[chunk_slice]
                        elif split_dim == 0:
                            weight_to_load = loaded_weight[chunk_slice, :]
1415
                        else:
1416
                            weight_to_load = loaded_weight[:, chunk_slice]
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
                        # Synthesize an expert-style name so expert mapping
                        # can route it
                        chunk_name = name.replace(
                            "mlp.shared_experts",
                            f"mlp.experts.{self.config.n_routed_experts + j}",
                        )

                    # Use expert_params_mapping to locate the destination
                    # param and delegate to its expert-aware weight_loader
                    # with expert_id.
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in chunk_name:
                            continue

                        # Whatever happens next, this is an expert weight, so it
                        # must not fall through to the generic loading path below.
                        is_expert_weight = True

                        # Do not modify `name` since the loop may continue here
                        # Instead, create a new variable
                        name_mapped = chunk_name.replace(weight_name, param_name)

                        if is_pp_missing_parameter(name_mapped, self):
                            continue

                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or
                        # not here since otherwise we may skip experts with
                        # other available replicas.
                        weight_loader = typing.cast(
                            Callable[..., bool], param.weight_loader
                        )
                        success = weight_loader(
                            param,
                            weight_to_load,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                            return_success=True,
                        )
                        if success:
1459
                            if not is_fusion_moe_shared_experts_layer:
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
                                name = name_mapped
                            else:
                                loaded_params.add(name_mapped)
                            break
                    else:
                        if is_expert_weight:
                            # We've checked that this is an expert weight
                            # However it's not mapped locally to this rank
                            # So we simply skip it
                            continue

                        # Skip loading extra bias for GPTQ models.
                        if name.endswith(".bias") and name not in params_dict:
                            continue

                        # Remapping the name of FP8 kv-scale.
                        name = maybe_remap_kv_scale_name(name, params_dict)
                        if name is None:
                            continue

                        if is_pp_missing_parameter(name, self):
                            continue
1482

zhuwenwen's avatar
zhuwenwen committed
1483
                        param = params_dict[name]
1484
1485
1486
1487
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
1488
            if not is_fusion_moe_shared_experts_layer:
1489
                loaded_params.add(name)
zhuwenwen's avatar
zhuwenwen committed
1490

1491
        return loaded_params
1492
1493


1494
1495
1496
1497
class DeepseekForCausalLM(DeepseekV2ForCausalLM):
    """DeepseekV2ForCausalLM under another class name; nothing is overridden."""


1498
1499
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek-V3 entry point; reuses DeepseekV2ForCausalLM without changes."""
1500
1501


1502
1503
# Compatibility with
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py
1504
def get_spec_layer_idx_from_weight_name(
    config: DeepseekV2Config | DeepseekV3Config, weight_name: str
) -> int | None:
    """Return the index of the next-N-predict ("spec") layer a weight belongs to.

    The layers counted by ``config.num_nextn_predict_layers`` are numbered
    directly after the regular ones: from ``config.num_hidden_layers`` up to
    ``num_hidden_layers + num_nextn_predict_layers - 1``.

    Args:
        config: Model config. ``num_nextn_predict_layers`` may be missing,
            ``None`` or ``0``; all three mean there are no such layers.
        weight_name: Checkpoint weight name, e.g. ``model.layers.61.foo``.

    Returns:
        The matching layer index, or ``None`` when ``weight_name`` does not
        start with ``model.layers.<idx>.`` for any of those layers.
    """
    # `or 0` folds an explicit None into "no such layers" so the comparison
    # below cannot raise TypeError on configs that set the field to None.
    num_spec_layers = getattr(config, "num_nextn_predict_layers", None) or 0
    if num_spec_layers <= 0:
        return None

    first_spec_idx = config.num_hidden_layers
    for layer_idx in range(first_spec_idx, first_spec_idx + num_spec_layers):
        # The trailing dot stops index 6 from matching "model.layers.61.".
        if weight_name.startswith(f"model.layers.{layer_idx}."):
            return layer_idx
    return None