deepseek_v2.py 61.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

wangding zeng's avatar
wangding zeng committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
25
"""Inference-only DeepseekV2/DeepseekV3 model."""
26

27
28
import typing
from collections.abc import Callable, Iterable
29
from itertools import islice
wangding zeng's avatar
wangding zeng committed
30
31
32

import torch
from torch import nn
33
from transformers import DeepseekV2Config, DeepseekV3Config
wangding zeng's avatar
wangding zeng committed
34

35
import vllm._custom_ops as ops
36
from vllm._aiter_ops import rocm_aiter_ops
37
from vllm.compilation.decorators import support_torch_compile
38
39
40
41
42
43
44
45
from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
46
from vllm.logger import init_logger
wangding zeng's avatar
wangding zeng committed
47
from vllm.model_executor.layers.activation import SiluAndMul
48
from vllm.model_executor.layers.attention import Attention
49
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
50
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
51
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
52
53
54
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
55
    QKVParallelLinear,
56
57
58
    ReplicatedLinear,
    RowParallelLinear,
)
wangding zeng's avatar
wangding zeng committed
59
from vllm.model_executor.layers.logits_processor import LogitsProcessor
60
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
61
from vllm.model_executor.layers.quantization import QuantizationConfig
62
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
63
64
    per_token_group_quant_fp8,
)
wangding zeng's avatar
wangding zeng committed
65
from vllm.model_executor.layers.rotary_embedding import get_rope
66
from vllm.model_executor.layers.sparse_attn_indexer import SparseAttnIndexer
wangding zeng's avatar
wangding zeng committed
67
from vllm.model_executor.layers.vocab_parallel_embedding import (
68
69
70
    ParallelLMHead,
    VocabParallelEmbedding,
)
71
from vllm.model_executor.model_loader.weight_utils import (
72
73
74
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
75
from vllm.model_executor.models.utils import sequence_parallel_chunk
76
from vllm.platforms import current_platform
77
from vllm.sequence import IntermediateTensors
78
from vllm.v1.attention.backend import AttentionBackend
79
80
81
from vllm.v1.attention.backends.mla.indexer import (
    DeepseekV32IndexerBackend,
)
82
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
wangding zeng's avatar
wangding zeng committed
83

84
from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
85
86
87
88
89
90
91
from .utils import (
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
92

93
94
# Module-level logger, named after this module via init_logger(__name__).
logger = init_logger(__name__)

wangding zeng's avatar
wangding zeng committed
95

96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
class DeepseekAttention(nn.Module):
    """Normal MHA implementation used by Deepseek v1.

    Standard multi-head attention: a fused QKV projection sharded across
    tensor-parallel ranks (KV heads are partitioned when there are at least
    as many as TP ranks, replicated otherwise), rotary position embeddings
    on q/k, and a row-parallel output projection.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        **kwargs,
    ) -> None:
        """
        Args:
            vllm_config: engine config (not read by this class).
            config: HF config; ``num_key_value_heads`` and ``rope_parameters``
                are read from it.
            hidden_size: model hidden size; ``head_dim`` is
                ``hidden_size // num_heads``.
            num_heads: total (unsharded) number of query heads; must be
                divisible by the TP world size.
            max_position_embeddings: max position for the rotary embedding.
            cache_config: forwarded to ``Attention``.
            quant_config: forwarded to the projections and ``Attention``.
            prefix: fully-qualified module name; sub-layer prefixes are
                derived from it.
            **kwargs: accepted and ignored (presumably so callers can pass the
                same argument set as the MLA attention classes — confirm).
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q / k / v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        # Pass prefixes so quantization configs can match these layers by
        # name, as every other linear layer in this file does.
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embeddings, attend, project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


wangding zeng's avatar
wangding zeng committed
177
178
179
180
181
182
class DeepseekV2MLP(nn.Module):
    """Gated SiLU feed-forward block: gate/up projection -> SiluAndMul -> down.

    When ``is_sequence_parallel`` is set, input and output tensors are
    sharded across the ranks of the TP group; the weights are then
    replicated (``disable_tp``) and no collective ops are needed. Otherwise
    standard tensor parallelism is used, with the all-reduce at the end
    controlled by ``reduce_results``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        is_sequence_parallel=False,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Settings common to both projections.
        common = dict(
            bias=False,
            quant_config=quant_config,
            disable_tp=is_sequence_parallel,
        )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **common,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
            **common,
        )

        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

    def forward(self, x):
        """Apply the gated MLP to ``x``."""
        fused_gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused_gate_up)
        result, _ = self.down_proj(activated)
        return result


224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
class DeepSeekV2Gate(ReplicatedLinear):
    """Replicated, unquantized MoE router projection (hidden_size -> n_experts).

    On CUDA devices with compute capability 9.0 or of the 100 family, and
    only for the expert counts / hidden sizes listed below (DSV3 and KIMI
    shapes), batches of at most 16 tokens are routed through the specialized
    ``ops.dsv3_router_gemm`` kernel. Everything else uses the regular
    ``ReplicatedLinear`` forward.

    ``set_out_dtype`` must be called once after construction; the kernel
    path reads ``out_dtype``, which raises until it has been set.
    """

    # Shapes the specialized DSV3 router GEMM supports.
    _ROUTER_GEMM_NUM_EXPERTS = (256, 384)
    _ROUTER_GEMM_HIDDEN_SIZES = (7168,)
    # Largest token count dispatched to the specialized GEMM.
    _ROUTER_GEMM_MAX_TOKENS = 16

    def __init__(
        self,
        hidden_size: int,
        n_experts: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        # The router is always built unquantized.
        assert quant_config is None
        super().__init__(
            hidden_size,
            n_experts,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate",
        )

        # Unquantized only, will be called "weight".
        assert hasattr(self, "weight")

        capability_ok = current_platform.is_device_capability(
            (9, 0)
        ) or current_platform.is_device_capability_family(100)
        shape_ok = (
            n_experts in self._ROUTER_GEMM_NUM_EXPERTS
            and hidden_size in self._ROUTER_GEMM_HIDDEN_SIZES
        )
        self.allow_dsv3_router_gemm = (
            current_platform.is_cuda() and capability_ok and shape_ok
        )

        self._out_dtype: torch.dtype | None = None

    def set_out_dtype(self, out_dtype: torch.dtype) -> None:
        """
        Set out dtype for the router logits. This is needed after
        __init__, b/c we need to check if the trtllm kernel is
        selected before we decide between bf16 and fp32.

        Raises:
            ValueError: if the dtype has already been set.
        """
        if self._out_dtype is not None:
            raise ValueError("out_dtype has already been set")
        self._out_dtype = out_dtype

    @property
    def out_dtype(self) -> torch.dtype:
        """Router-logit dtype; raises ValueError until ``set_out_dtype`` runs."""
        if self._out_dtype is None:
            raise ValueError("out_dtype has not been set yet")
        return self._out_dtype

    def forward(
        self,
        x: torch.Tensor,
    ) -> tuple[torch.Tensor, None]:
        """
        Use specialized GEMM for low batch size for DSV3 and KIMI.

        Returns ``(router_logits, None)`` on the kernel path, otherwise
        whatever ``ReplicatedLinear.forward`` returns.
        """
        small_batch = self.allow_dsv3_router_gemm and (
            x.shape[0] <= self._ROUTER_GEMM_MAX_TOKENS
        )
        if not small_batch:
            return super().forward(x)
        logits = ops.dsv3_router_gemm(
            hidden_states=x, router_weight=self.weight, output_dtype=self.out_dtype
        )
        return logits, None


wangding zeng's avatar
wangding zeng committed
291
292
293
class DeepseekV2MoE(nn.Module):
    """Mixture-of-experts FFN used by DeepseekV2/V3 decoder layers.

    Combines a router gate (``DeepSeekV2Gate``), the routed experts and the
    optional shared experts in a ``SharedFusedMoE`` layer.
    ``routed_scaling_factor`` is applied to the routed output here, except
    when ROCm AITER MoE is enabled, in which case the fused kernel applies it.
    """

    def __init__(
        self,
        config: DeepseekV2Config | DeepseekV3Config,
        parallel_config: ParallelConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            config: HF config providing expert counts/sizes and routing options.
            parallel_config: provides sequence-parallel MoE and EPLB settings.
            quant_config: quantization config for the experts (the gate is
                always unquantized).
            prefix: fully-qualified module name of this MoE block.

        Raises:
            ValueError: if ``config.hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts: int = config.n_routed_experts
        # None when the config defines no shared experts (handled below).
        self.n_shared_experts: int | None = config.n_shared_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        # DeepSeekV2Gate appends ".gate" to the prefix it receives, so pass
        # this module's own prefix: the layer is then named "<prefix>.gate"
        # (matching its attribute path) instead of "<prefix>.gate.gate".
        self.gate = DeepSeekV2Gate(
            config.hidden_size,
            config.n_routed_experts,
            quant_config=None,
            prefix=prefix,
        )
        if getattr(config, "topk_method", None) == "noaux_tc":
            # (n_routed_experts,) float32 parameter left uninitialized here —
            # presumably filled by weight loading.
            self.gate.e_score_correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts, dtype=torch.float32)
            )
        else:
            self.gate.e_score_correction_bias = None

        # Load balancing settings.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Contiguous range [start, end) of physical experts owned by this rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
        self.is_fusion_moe_shared_experts_enabled = (
            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        )
        # Shared experts run as a separate MLP unless the config has none, or
        # they are fused into the AITER MoE (passed via n_shared_experts below).
        if config.n_shared_experts is None or self.is_fusion_moe_shared_experts_enabled:
            self.shared_experts = None
        else:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts

            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                is_sequence_parallel=self.is_sequence_parallel,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_experts,
            gate=self.gate,
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=getattr(config, "n_group", 1),
            topk_group=getattr(config, "topk_group", 1),
            prefix=f"{prefix}.experts",
            scoring_func=getattr(config, "scoring_func", "softmax"),
            # we do scaling outside, set factor to 1.0 to avoid double mul
            # aiter applies routed_scaling_factor internally
            routed_scaling_factor=1.0
            if not self.is_rocm_aiter_moe_enabled
            else self.routed_scaling_factor,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
            n_shared_experts=config.n_shared_experts
            if self.is_fusion_moe_shared_experts_enabled
            else None,
        )

        # NOTE(rob): this is a hack until we finish off the PR for
        # merging TRTLLM kernels into the MK framework. Then we can
        # query the MonolithicMK for the expected router logits.
        self.gate.set_out_dtype(
            torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the routed (and shared) experts.

        Args:
            hidden_states: 2-D ``(num_tokens, hidden_dim)`` activations.

        Returns:
            A ``(num_tokens, hidden_dim)`` tensor. It is all-gathered when
            sequence parallel, otherwise TP-reduced via the experts layer
            when ``tp_size > 1``.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # Chunk the hidden states so they aren't replicated across TP ranks.
        # This avoids duplicate computation in self.experts.
        # TODO: We can replace the all_reduce at the end of attn with a
        # reduce_scatter instead of chunking here.
        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        if self.experts.is_internal_router:
            # In this case, the gate/router runs inside the FusedMoE class
            fused_moe_out = self.experts(
                hidden_states=hidden_states, router_logits=hidden_states
            )
        else:
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)
            fused_moe_out = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )

        shared_output, final_hidden_states = fused_moe_out
        if self.shared_experts is None:
            assert shared_output is None

        # Fix FP16 overflow
        # See DeepseekV2DecoderLayer for more details.
        # Non-fp16: scale the routed output here (unless aiter already did).
        # fp16: leave routed output unscaled and divide the shared output by
        # the factor instead.
        if hidden_states.dtype != torch.float16:
            if not self.is_rocm_aiter_moe_enabled:
                final_hidden_states *= self.routed_scaling_factor
        elif self.shared_experts is not None:
            assert shared_output is not None
            shared_output *= 1.0 / self.routed_scaling_factor

        if self.shared_experts is not None:
            assert shared_output is not None
            final_hidden_states += shared_output

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            # Drop any padding added by sequence_parallel_chunk.
            final_hidden_states = final_hidden_states[:num_tokens]
        elif self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
                final_hidden_states
            )

        return final_hidden_states.view(num_tokens, hidden_dim)


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN magnitude-correction factor for a RoPE scale factor.

    No correction is applied (1.0) when ``scale <= 1``. Otherwise the result
    is ``0.1 * mscale * ln(scale) + 1.0``.
    """
    import math

    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0


462
463
464
465
466
467
468
469
470
471
def _get_llama_4_scaling(
    original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
) -> torch.Tensor:
    scaling = 1 + scaling_beta * torch.log(
        1 + torch.floor(positions / original_max_position_embeddings)
    )
    # Broadcast over num_heads and head_dim
    return scaling[..., None, None]


wangding zeng's avatar
wangding zeng committed
472
473
474
class DeepseekV2Attention(nn.Module):
    """DeepseekV2/V3 multi-head latent attention, computed in explicit MHA form.

    Queries come either from a low-rank path (``q_a_proj`` -> RMSNorm ->
    ``q_b_proj``) when ``q_lora_rank`` is set, or from ``q_proj`` directly.
    Keys/values are expanded from a compressed latent (``kv_a_proj_with_mqa``
    -> RMSNorm -> ``kv_b_proj``). The rotary part of the key comes from the
    same latent projection and is shared by all heads.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        topk_indices_buffer: torch.Tensor | None = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            vllm_config: engine config (not read by this class).
            config: HF config; ``rms_norm_eps`` and ``rope_parameters`` are
                read. ``rope_parameters["rope_type"]`` is rewritten in place
                when it is not ``"default"``.
            hidden_size: model hidden size.
            num_heads: total (unsharded) number of heads; must be divisible
                by the TP world size.
            qk_nope_head_dim / qk_rope_head_dim: non-rotary / rotary parts of
                each q/k head.
            v_head_dim: value head size (padded to the q/k head size for the
                attention kernel).
            q_lora_rank: rank of the low-rank query path, or None to project
                queries directly.
            kv_lora_rank: rank of the compressed KV latent.
            max_position_embeddings: max position for the rotary embedding.
            cache_config / quant_config: forwarded to sub-layers.
            topk_indices_buffer: must be None (not supported here).
            prefix: fully-qualified module name.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings
        # The sparse top-k indices path is not implemented by this class.
        assert topk_indices_buffer is None, (
            "topk_indices_buffer is not supported for DeepseekV2Attention"
        )

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_a_proj",
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )

        # Produces the KV latent plus a single rotary key slice.
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa",
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # Rewrites rope_type on the config in place. This is idempotent:
        # both rewritten values differ from "default", so re-running it (e.g.
        # for every layer, if the config is shared) yields the same value.
        if config.rope_parameters["rope_type"] != "default":
            config.rope_parameters["rope_type"] = (
                "deepseek_yarn"
                if config.rope_parameters.get("apply_yarn_scaling", True)
                else "deepseek_llama_scaling"
            )

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=False,
        )

        # For YaRN, the softmax scale is multiplied by mscale**2.
        if config.rope_parameters["rope_type"] == "deepseek_yarn":
            mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
            scaling_factor = config.rope_parameters["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.attn = Attention(
            self.num_local_heads,
            self.qk_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None,
    ) -> torch.Tensor:
        """Compute attention output for ``hidden_states``.

        Args:
            positions: token positions passed to the rotary embedding
                (assumed one per token — confirm against callers).
            hidden_states: input activations; each row is one token.
            llama_4_scaling: optional multiplier applied to q in place,
                broadcast against q's ``(tokens, local_heads, qk_head_dim)``.
        """
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        # Only the rotary slice of q is needed separately.
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a)
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # Single rotary key slice (head dim of size 1), broadcast to all heads.
        k_pe = latent_cache[:, :, self.kv_lora_rank :]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe

        # Apply llama 4 scaling if provided
        if llama_4_scaling is not None:
            q *= llama_4_scaling

        # padding value to qk_head_dim for alignment
        v = torch.nn.functional.pad(
            v, [0, self.qk_head_dim - self.v_head_dim], value=0
        ).view(-1, self.num_local_heads * self.qk_head_dim)
        attn_output = self.attn(q, k, v)
        # Drop the padded tail of each head before the output projection.
        attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[
            ..., : self.v_head_dim
        ].reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output


637
class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
    """Cache-owning placeholder layer for the DeepseekV3.2 indexer.

    On construction it registers itself under ``prefix`` in the current
    compilation config's ``static_forward_context``. It reports an
    MLA-style KV-cache spec holding one vector per token instead of K + V.
    Its ``forward`` does nothing.
    """

    def __init__(
        self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig
    ):
        """
        Raises:
            ValueError: if a layer with the same ``prefix`` is already
                registered in the static forward context.
        """
        super().__init__()
        self.kv_cache = [torch.tensor([])]
        self.head_dim = head_dim
        self.prefix = prefix
        self.cache_config = cache_config
        self.dtype = dtype

        compilation_config = get_current_vllm_config().compilation_config
        forward_ctx = compilation_config.static_forward_context
        if prefix in forward_ctx:
            raise ValueError(f"Duplicate layer name: {prefix}")
        forward_ctx[prefix] = self

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe this layer's cache: one head of ``head_dim`` per token."""
        # Only has one vector instead of K + V
        return MLAAttentionSpec(
            block_size=self.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
        )

    def forward(self):
        """No-op; this layer exists only to own the indexer cache."""
        return None

    def get_attn_backend(self) -> AttentionBackend:
        """Return the attention backend class used for the indexer cache."""
        return DeepseekV32IndexerBackend


class Indexer(nn.Module):
    """Sparse-attention indexer used by DeepSeek-V3.2 (configs with ``index_topk``).

    Projects the query LoRA activations (``qr``) into ``n_head`` indexer
    queries and the hidden states into a single indexer key. It applies RoPE
    to the first ``rope_dim`` channels of both and quantizes the queries to
    FP8. Scoring / top-k selection is delegated to ``SparseAttnIndexer``
    together with per-head weights. All projections are replicated (no
    tensor parallelism).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        q_lora_rank: int,
        quant_config: QuantizationConfig | None,
        cache_config: CacheConfig | None,
        topk_indices_buffer: torch.Tensor | None,
        prefix: str = "",
    ):
        super().__init__()
        self.vllm_config = vllm_config
        self.config = config
        self.topk_tokens = config.index_topk
        self.n_head = config.index_n_heads  # 64
        self.head_dim = config.index_head_dim  # 128
        self.rope_dim = config.qk_rope_head_dim  # 64
        self.q_lora_rank = q_lora_rank  # 1536

        # No tensor parallelism: every rank holds full replicas.
        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.head_dim * self.n_head,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq_b",
        )
        self.wk = ReplicatedLinear(
            hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wk",
        )
        self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
        # Per-head weights are never quantized (quant_config=None).
        self.weights_proj = ReplicatedLinear(
            hidden_size,
            self.n_head,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.weights_proj",
        )
        self.softmax_scale = self.head_dim**-0.5

        self.scale_fmt = "ue8m0"
        self.quant_block_size = 128  # TODO: get from config
        self.topk_indices_buffer = topk_indices_buffer

        # NOTE: (zyongye) we use fp8 naive cache,
        #       where we store value in fp8 and scale in fp32
        #       per self.quant_block_size element.
        # Row width in bytes: head_dim fp8 values followed by one 4-byte fp32
        # scale per quant block.
        self.k_cache = DeepseekV32IndexerCache(
            head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4,
            dtype=torch.uint8,
            prefix=f"{prefix}.k_cache",
            cache_config=cache_config,
        )
        self.max_model_len = vllm_config.model_config.max_model_len
        self.prefix = prefix
        # Function-level import kept from the original (possibly to avoid an
        # import cycle — TODO confirm).
        from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size

        self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)
        self.indexer_op = SparseAttnIndexer(
            self.k_cache,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )

    def forward(
        self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb
    ) -> torch.Tensor:
        """Compute indexer inputs for a batch of tokens and run the indexer op.

        Args:
            hidden_states: Token hidden states; input to ``wk`` and
                ``weights_proj`` (assumed 2-D ``[num_tokens, hidden]``).
            qr: Query LoRA activations; input to ``wq_b``.
            positions: Token positions, forwarded to ``rotary_emb``.
            rotary_emb: Rotary embedding applied to the RoPE slice of q and k.

        Returns:
            The result of ``self.indexer_op`` for these inputs.
        """
        q, _ = self.wq_b(qr)
        q = q.view(-1, self.n_head, self.head_dim)
        q_pe, q_nope = torch.split(
            q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
        )

        k, _ = self.wk(hidden_states)
        k = self.k_norm(k)
        k_pe, k_nope = torch.split(
            k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
        )

        # k has a single (MQA) head: give it a head axis for rotary_emb.
        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
        # RoPE (NeoX) can introduce extra leading dimensions during
        # compilation, so reshape back to token-flattened shapes:
        # q_pe -> [num_tokens, n_head, rope_dim], k_pe -> [num_tokens, 1, rope_dim].
        q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim)
        k_pe = k_pe.reshape(-1, 1, self.rope_dim)

        # q: [num_tokens, n_head, head_dim]; k: [num_tokens, head_dim] after
        # dropping k_pe's singleton head axis.
        q = torch.cat([q_pe, q_nope], dim=-1)
        k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)

        # Only q is quantized here; k quantization is fused with cache insertion.
        q = q.view(-1, self.head_dim)
        q_fp8, q_scale = per_token_group_quant_fp8(
            q,
            self.quant_block_size,
            column_major_scales=False,
            # Enable UE8M0 scales only for that format, rather than for any
            # non-None scale_fmt.
            use_ue8m0=self.scale_fmt == "ue8m0",
        )
        q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
        # NOTE(review): viewing as one scale per head assumes
        # head_dim == quant_block_size (128 == 128 for the current config).
        q_scale = q_scale.view(-1, self.n_head, 1)

        # Fold the q dequant scale, the softmax scale and a 1/sqrt(n_head)
        # factor into the per-head weights: [num_tokens, n_head].
        weights, _ = self.weights_proj(hidden_states)
        weights = (
            weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
        )
        weights = weights.squeeze(-1)

        return self.indexer_op(hidden_states, q_fp8, k, weights)


class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
    """Fused ``q_a_proj`` + ``kv_a_proj_with_mqa`` projection with TP disabled.

    Some calls take a fast path through ``ops.dsv3_fused_a_gemm``. This
    requires CUDA SM90 or the SM100 family, a BF16 weight of exactly
    ``(_FUSED_A_GEMM_OUT_FEATURES, _FUSED_A_GEMM_IN_FEATURES)``, and between
    1 and ``_FUSED_A_GEMM_MAX_TOKENS`` tokens. Everything else uses the
    regular ``MergedColumnParallelLinear`` forward.
    """

    # Weight shape the fused A GEMM kernel is specialized for. The output
    # width is presumably q_lora_rank + kv_lora_rank + qk_rope_head_dim
    # (1536 + 512 + 64) and the input width the hidden size of DeepSeek-V3.
    _FUSED_A_GEMM_OUT_FEATURES = 2112
    _FUSED_A_GEMM_IN_FEATURES = 7168
    # Largest token count routed to the low-latency kernel.
    _FUSED_A_GEMM_MAX_TOKENS = 16

    def __init__(
        self,
        input_size: int,
        output_size: list[int],
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__(
            input_size,
            output_size,
            bias=False,
            quant_config=quant_config,
            disable_tp=True,
            prefix=prefix,
        )

        # Check if the DeepSeek V3 fused A GEMM kernel can be used.
        # This kernel supports PDL and is optimized for low batch size.
        self._use_min_latency_gemm = (
            hasattr(self, "weight")
            and self.weight.dtype == torch.bfloat16
            and self.weight.shape[0] == self._FUSED_A_GEMM_OUT_FEATURES
            and self.weight.shape[1] == self._FUSED_A_GEMM_IN_FEATURES
            and current_platform.is_cuda()
            and (
                current_platform.is_device_capability(90)
                or current_platform.is_device_capability_family(100)
            )
        )

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]:
        """Project ``input_``, using the fused A GEMM kernel when eligible.

        Returns the output tensor, or ``(output, bias_or_None)`` when
        ``self.return_bias`` is set, mirroring the parent's convention.
        """
        num_tokens = input_.shape[0]
        if not (
            self._use_min_latency_gemm
            and 0 < num_tokens <= self._FUSED_A_GEMM_MAX_TOKENS
        ):
            # Fallback to the standard forward method when
            # the fused A GEMM kernel cannot be used.
            return super().forward(input_)

        output = torch.empty(
            num_tokens,
            self._FUSED_A_GEMM_OUT_FEATURES,
            dtype=torch.bfloat16,
            device=input_.device,
        )
        ops.dsv3_fused_a_gemm(output, input_, self.weight.T)
        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias


class DeepseekV2MLAAttention(nn.Module):
    """Multi-head latent attention (MLA) layer for DeepSeek-V2/V3/V3.2.

    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and
    https://github.com/flashinfer-ai/flashinfer/pull/551).

    This module builds the projections, norms and rotary embeddings. For
    configs with ``index_topk`` (V3.2) it also builds the sparse indexer. The
    attention itself is delegated to ``MultiHeadLatentAttentionWrapper``.
    For more info see ``MLACommonImpl`` in the vLLM v1 MLA attention
    backends (vllm/v1/attention/backends/mla/).
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        topk_indices_buffer: torch.Tensor | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        has_q_lora = self.q_lora_rank is not None
        if has_q_lora:
            # Low-rank query path: q_a and the MQA kv_a projection share one
            # fused GEMM, followed by q_a_layernorm and q_b_proj.
            self.fused_qkv_a_proj = DeepSeekV2FusedQkvAProj(
                self.hidden_size,
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                quant_config=quant_config,
                prefix=f"{prefix}.fused_qkv_a_proj",
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                self.q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
        else:
            # Full-rank query path with a separate kv_a projection.
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.kv_a_proj_with_mqa",
            )
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # NOTE: this rewrites `rope_type` on the shared config in place (as
        # before). Repeating it for every layer is idempotent.
        if config.rope_parameters["rope_type"] != "default":
            config.rope_parameters["rope_type"] = (
                "deepseek_yarn"
                if config.rope_parameters.get("apply_yarn_scaling", True)
                else "deepseek_llama_scaling"
            )

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=False,
        )

        # YaRN scaling also rescales the softmax temperature by mscale^2.
        if config.rope_parameters["rope_type"] == "deepseek_yarn":
            mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
            scaling_factor = config.rope_parameters["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.is_v32 = hasattr(config, "index_topk")

        if self.is_v32:
            self.indexer_rope_emb = get_rope(
                qk_rope_head_dim,
                max_position=max_position_embeddings,
                rope_parameters=config.rope_parameters,
                is_neox_style=not getattr(config, "indexer_rope_interleave", False),
            )
            self.indexer = Indexer(
                vllm_config=vllm_config,
                config=config,
                hidden_size=hidden_size,
                q_lora_rank=q_lora_rank,
                quant_config=quant_config,
                cache_config=cache_config,
                topk_indices_buffer=topk_indices_buffer,
                prefix=f"{prefix}.indexer",
            )
        else:
            self.indexer_rope_emb = None
            self.indexer = None

        # Only the modules that exist for the chosen query path are passed;
        # the others are None.
        mla_modules = MLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            rotary_emb=self.rotary_emb,
            o_proj=self.o_proj,
            fused_qkv_a_proj=self.fused_qkv_a_proj if has_q_lora else None,
            kv_a_proj_with_mqa=None if has_q_lora else self.kv_a_proj_with_mqa,
            q_a_layernorm=self.q_a_layernorm if has_q_lora else None,
            q_b_proj=self.q_b_proj if has_q_lora else None,
            q_proj=None if has_q_lora else self.q_proj,
            indexer=self.indexer,
            indexer_rotary_emb=self.indexer_rope_emb,
            is_sparse=self.is_v32,
            topk_indices_buffer=topk_indices_buffer,
        )

        self.mla_attn = MultiHeadLatentAttentionWrapper(
            self.hidden_size,
            self.num_local_heads,
            self.scaling,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
            self.q_lora_rank,
            self.kv_lora_rank,
            mla_modules,
            cache_config,
            quant_config,
            prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run MLA on ``hidden_states`` (delegates to ``self.mla_attn``)."""
        return self.mla_attn(positions, hidden_states, llama_4_scaling)


class DeepseekV2DecoderLayer(nn.Module):
    """One DeepSeek decoder layer: self-attention followed by a dense MLP or MoE.

    The attention class is chosen from the config:
    - ``DeepseekAttention`` for MHA checkpoints (``model_type == "deepseek"``
      or both q/k head splits are zero).
    - ``DeepseekV2MLAAttention`` when ``model_config.use_mla`` is set.
    - ``DeepseekV2Attention`` otherwise.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str,
        config: DeepseekV2Config | None = None,
        topk_indices_buffer: torch.Tensor | None = None,
    ) -> None:
        super().__init__()

        if config is None:
            config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        moe_layer_freq = getattr(config, "moe_layer_freq", 1)
        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        layer_idx = int(prefix.split(sep=".")[-1])
        self.layer_idx = layer_idx

        # verify MLA attention specific fields
        qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
        qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
        v_head_dim = getattr(config, "v_head_dim", 0)
        kv_lora_rank = getattr(config, "kv_lora_rank", 0)
        use_mha = config.model_type == "deepseek" or all(
            dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
        )

        self.use_mha = use_mha

        if use_mha:
            attn_cls = DeepseekAttention
        elif model_config.use_mla:
            attn_cls = DeepseekV2MLAAttention
        else:
            attn_cls = DeepseekV2Attention
        self.self_attn = attn_cls(
            vllm_config=vllm_config,
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            v_head_dim=v_head_dim,
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=kv_lora_rank,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            topk_indices_buffer=topk_indices_buffer,
        )

        # MoE after the first `first_k_dense_replace` layers, every
        # `moe_layer_freq`-th layer; dense MLP otherwise.
        if (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % moe_layer_freq == 0
        ):
            self.mlp = DeepseekV2MoE(
                config=config,
                parallel_config=parallel_config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        llama_4_scaling: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply attention and MLP/MoE with fused pre-norm residual handling.

        Args:
            positions: Token positions, forwarded to attention.
            hidden_states: Input activations of this layer.
            residual: Running residual, or ``None`` when there is none yet
                (it is then initialized from ``hidden_states``).
            llama_4_scaling: Optional query scaling, passed to attention only
                when it is not the MHA implementation.

        Returns:
            ``(hidden_states, residual)`` to feed the next layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states.clone()
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        attn_kwargs = {
            "positions": positions,
            "hidden_states": hidden_states,
        }
        if not self.use_mha:
            attn_kwargs["llama_4_scaling"] = llama_4_scaling
        hidden_states = self.self_attn(**attn_kwargs)

        if (
            not isinstance(self.self_attn, DeepseekAttention)
            and hidden_states.dtype == torch.float16
        ):
            # Fix FP16 overflow
            # We scale both hidden_states and residual before
            # rmsnorm, and rmsnorm result would not affect by scale.
            hidden_states *= 1.0 / self.routed_scaling_factor
            if self.layer_idx == 0:
                # The residual is shared by all layers, we only scale it on
                # first layer.
                residual *= 1.0 / self.routed_scaling_factor

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
            # Fix FP16 overflow
            # Scaling the DeepseekV2MLP output, it is the input of
            # input_layernorm of next decoder layer.
            # The scaling of DeepseekV2MOE output would be done in the forward
            # of DeepseekV2MOE
            hidden_states *= 1.0 / self.routed_scaling_factor

        return hidden_states, residual


@support_torch_compile
class DeepseekV2Model(nn.Module):
    """Decoder stack of DeepSeek-V2/V3: token embedding, decoder layers, norm.

    With pipeline parallelism, only the first rank owns ``embed_tokens`` and
    only the last rank owns ``norm``; the others hold ``PPMissingLayer``.
    Intermediate ranks hand ``hidden_states`` / ``residual`` on through
    ``IntermediateTensors``.
    """

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = hf_config
        self.device = current_platform.device_type
        self.vocab_size = hf_config.vocab_size

        # V3.2 configs carry `index_topk`. A single int32 buffer
        # [max_num_batched_tokens, index_topk] is then shared by every layer
        # (presumably filled with the indexer's selected token indices).
        self.is_v32 = hasattr(hf_config, "index_topk")
        topk_indices_buffer = (
            torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                hf_config.index_topk,
                dtype=torch.int32,
                device=self.device,
            )
            if self.is_v32
            else None
        )

        pp_group = get_pp_group()
        self.embed_tokens = (
            VocabParallelEmbedding(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
            if pp_group.is_first_rank
            else PPMissingLayer()
        )
        # `make_layers` calls the factory with a `prefix=` keyword argument,
        # so the lambda parameter must keep that name.
        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            lambda prefix: DeepseekV2DecoderLayer(
                vllm_config, prefix, topk_indices_buffer=topk_indices_buffer
            ),
            prefix=f"{prefix}.layers",
        )
        self.norm = (
            RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
            if pp_group.is_last_rank
            else PPMissingLayer()
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], hf_config.hidden_size
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder layers.

        Returns the final normalized hidden states on the last PP rank, and
        ``IntermediateTensors`` with ``hidden_states`` / ``residual`` elsewhere.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = (
                self.embed_input_ids(input_ids)
                if inputs_embeds is None
                else inputs_embeds
            )
            residual = None

        # Llama-4 scaling is computed once per forward pass and shared by all
        # layers.
        scaling_cfg = getattr(self.config, "llama_4_scaling", None)
        llama_4_scaling: torch.Tensor | None = None
        if scaling_cfg is not None:
            llama_4_scaling = _get_llama_4_scaling(
                original_max_position_embeddings=scaling_cfg[
                    "original_max_position_embeddings"
                ],
                scaling_beta=scaling_cfg["beta"],
                positions=positions,
            )

        for decoder_layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = decoder_layer(
                positions, hidden_states, residual, llama_4_scaling
            )

        if pp_group.is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states

        return IntermediateTensors(
            {"hidden_states": hidden_states, "residual": residual}
        )


class DeepseekV2MixtureOfExperts(MixtureOfExperts):
    # MoE MLP layers of the model; updated in place by
    # `update_physical_experts_metadata`.
    moe_mlp_layers: list[DeepseekV2MoE]

    def extract_moe_parameters(self, example_moe: DeepseekV2MoE | None):
        """Copy expert counts from a representative MoE layer onto the model.

        If ``example_moe`` is ``None``, every counter is zeroed, including
        ``num_moe_layers`` and ``num_expert_groups``, and a warning is logged.
        """
        if example_moe is not None:
            moe = example_moe
            (
                self.num_logical_experts,
                self.num_physical_experts,
                self.num_local_physical_experts,
                self.num_routed_experts,
                self.num_shared_experts,
                self.num_redundant_experts,
            ) = (
                moe.n_logical_experts,
                moe.n_physical_experts,
                moe.n_local_physical_experts,
                moe.n_routed_experts,
                moe.n_shared_experts,
                moe.n_redundant_experts,
            )
            return

        self.num_moe_layers = self.num_expert_groups = 0
        self.num_logical_experts = self.num_physical_experts = 0
        self.num_local_physical_experts = self.num_routed_experts = 0
        self.num_shared_experts = self.num_redundant_experts = 0
        logger.warning("DeepSeekV2: No DeepseekV2MoE layer found in model.layers.")

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate new physical expert counts to the model and every MoE layer.

        The local physical expert count must not change. The redundant count
        is recomputed as physical minus logical experts, and each layer's
        expert map is rebuilt.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        redundant = num_physical_experts - self.num_logical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = redundant
        for moe_layer in self.moe_mlp_layers:
            moe_layer.n_local_physical_experts = num_local_physical_experts
            moe_layer.n_physical_experts = num_physical_experts
            moe_layer.n_redundant_experts = redundant
            moe_layer.experts.update_expert_map()


class DeepseekV2ForCausalLM(
    nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle
):
    """Causal language model built on ``model_cls`` (``DeepseekV2Model``).

    Adds the LM head and logits processor on top of the backbone, records
    MoE layer metadata, and maps checkpoint weight names onto this model's
    (possibly fused) parameters. Only the last pipeline-parallel rank owns a
    real ``lm_head``; other ranks hold a ``PPMissingLayer`` placeholder.
    """

    # Fused parameter name -> checkpoint sub-module names. ``__init__`` adds
    # attention-specific entries to this dict before building the backbone.
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
    model_cls = DeepseekV2Model

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head and MoE bookkeeping.

        Args:
            vllm_config: Engine config; ``model_config.hf_config`` and
                ``quant_config`` are read from it.
            prefix: Module-name prefix used for the submodules.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        # Plain multi-head attention for the original "deepseek" model type or
        # when both qk head-dim fields are zero/absent; otherwise MLA.
        qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
        qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
        self.use_mha = config.model_type == "deepseek" or all(
            dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
        )

        # NOTE(review): unless a base __init__ created an instance copy, these
        # writes go into the class-level dict in place (shared by subclasses).
        if self.use_mha:
            self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]

        # `packed_modules_mapping` needs to be modified before
        # initializing DeepseekV2Model, as it is passed inplace to
        # quantization config init and may be used to select the
        # quant_method for relevant layers during initialization.
        self.fuse_qkv_a_proj = (
            hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
        )
        if self.fuse_qkv_a_proj:
            self.packed_modules_mapping["fused_qkv_a_proj"] = [
                "q_a_proj",
                "kv_a_proj_with_mqa",
            ]

        self.model = self.model_cls(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )
        # Set MoE hyperparameters: every layer after the first
        # `first_k_dense_replace` dense layers is counted as an MoE layer.
        self.num_moe_layers = (
            self.config.num_hidden_layers - self.config.first_k_dense_replace
        )
        self.set_moe_parameters()

    def set_moe_parameters(self):
        """Collect the MoE layers present on this rank and record expert counts.

        Layers replaced by ``PPMissingLayer`` (owned by another pipeline rank)
        are skipped.
        """
        self.expert_weights = []

        self.num_expert_groups = getattr(self.config, "n_group", 1)

        self.moe_layers = []
        self.moe_mlp_layers = []
        example_moe = None
        for layer in self.model.layers:
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, DeepseekV2DecoderLayer)
            if isinstance(layer.mlp, DeepseekV2MoE):
                # Pick last one layer since the first ones may be dense layers.
                example_moe = layer.mlp
                self.moe_mlp_layers.append(layer.mlp)
                self.moe_layers.append(layer.mlp.experts)

        self.extract_moe_parameters(example_moe)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the backbone's input embedding."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged.

        The return value is whatever ``self.model`` returns (hidden states, or
        ``IntermediateTensors`` per the annotation).
        """
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via ``lm_head``."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert weight mapping for the routed experts only.

        Unlike the mapping used in ``load_weights``, no shared-expert slots
        and no redundant experts are included.
        """
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
            num_redundant_experts=0,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint ``(name, tensor)`` pairs into this model's parameters.

        Each weight goes through the first matching path:
        1. stacked/fused params (``gate_up_proj`` plus ``qkv_proj`` for MHA or
           ``fused_qkv_a_proj`` for MLA),
        2. per-expert params via ``expert_params_mapping``,
        3. a direct name match (after FP8 kv-scale name remapping).

        ``rotary_emb.inv_freq`` and speculative next-n layer weights are skipped.

        Returns:
            The set of names recorded as handled during loading.
        """
        rocm_aiter_moe_shared_expert_enabled = (
            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        )
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        mla_params_mapping = [
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]
        mha_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        if self.use_mha:
            stacked_params_mapping.extend(mha_params_mapping)
        else:
            stacked_params_mapping.extend(mla_params_mapping)

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        # With AITER shared-expert fusion, the shared experts occupy extra
        # expert slots after the routed ones (see chunk_name below).
        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts
            + (
                self.config.n_shared_experts
                if rocm_aiter_moe_shared_expert_enabled
                else 0
            ),
            num_redundant_experts=self.num_redundant_experts,
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

            is_fusion_moe_shared_experts_layer = (
                rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
            )

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                # Fused shared experts are loaded as extra experts below.
                if is_fusion_moe_shared_experts_layer:
                    continue
                name_mapped = name.replace(weight_name, param_name)

                # QKV fusion is optional, fall back to normal
                # weight loading if it's not enabled
                # if go with fusion option, then update name
                if (
                    param_name == "fused_qkv_a_proj"
                ) and name_mapped not in params_dict:
                    continue
                else:
                    name = name_mapped
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping consumed this weight: try experts, then
                # fall back to a direct parameter-name match.
                is_expert_weight = False

                # Special handling: when AITER fusion_shared_experts is enabled,
                # checkpoints may provide a single widened shared_experts tensor
                # without explicit expert indices
                # (e.g. ...mlp.shared_experts.gate_proj.weight).
                # For models with multiple shared experts, split that tensor
                # evenly into per-shared-expert slices and load them into
                # appended expert slots mlp.experts.{n_routed_experts + j}.*
                # accordingly.
                num_chunks = 1
                if is_fusion_moe_shared_experts_layer:
                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
                    # Determine split axis based on op type
                    # gate/up: ColumnParallel → split along dim 0
                    # down: RowParallel → split along dim 1
                    split_dim = (
                        1
                        if ("down_proj.weight" in name and loaded_weight.ndim > 1)
                        else 0
                    )
                    total = loaded_weight.shape[split_dim]
                    assert total % num_chunks == 0, (
                        f"Shared expert weight dim {total} "
                        f"not divisible by num_chunks {num_chunks}"
                    )
                    chunk_size = total // num_chunks

                for j in range(num_chunks):
                    chunk_name = name
                    weight_to_load = loaded_weight

                    if is_fusion_moe_shared_experts_layer:
                        chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
                        if loaded_weight.ndim == 1:
                            weight_to_load = loaded_weight[chunk_slice]
                        elif split_dim == 0:
                            weight_to_load = loaded_weight[chunk_slice, :]
                        else:
                            weight_to_load = loaded_weight[:, chunk_slice]
                        # Synthesize an expert-style name so expert mapping
                        # can route it
                        chunk_name = name.replace(
                            "mlp.shared_experts",
                            f"mlp.experts.{self.config.n_routed_experts + j}",
                        )

                    # Use expert_params_mapping to locate the destination
                    # param and delegate to its expert-aware weight_loader
                    # with expert_id.
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in chunk_name:
                            continue

                        # Anyway, this is an expert weight and should not be
                        # attempted to load as other weights later
                        is_expert_weight = True

                        # Do not modify `name` since the loop may continue here
                        # Instead, create a new variable
                        name_mapped = chunk_name.replace(weight_name, param_name)

                        if is_pp_missing_parameter(name_mapped, self):
                            continue

                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or
                        # not here since otherwise we may skip experts with
                        # other available replicas.
                        weight_loader = typing.cast(
                            Callable[..., bool], param.weight_loader
                        )
                        success = weight_loader(
                            param,
                            weight_to_load,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                            return_success=True,
                        )
                        if success:
                            # Fused shared-expert slices are recorded per
                            # slot here; other weights are recorded below.
                            if not is_fusion_moe_shared_experts_layer:
                                name = name_mapped
                            else:
                                loaded_params.add(name_mapped)
                            break
                    else:
                        if is_expert_weight:
                            # We've checked that this is an expert weight
                            # However it's not mapped locally to this rank
                            # So we simply skip it
                            continue

                        # Skip loading extra bias for GPTQ models.
                        if name.endswith(".bias") and name not in params_dict:
                            continue

                        # Remapping the name of FP8 kv-scale.
                        name = maybe_remap_kv_scale_name(name, params_dict)
                        if name is None:
                            continue

                        if is_pp_missing_parameter(name, self):
                            continue

                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
            if name is not None and not is_fusion_moe_shared_experts_layer:
                loaded_params.add(name)

        return loaded_params
1622
1623


1624
1625
1626
1627
class DeepseekForCausalLM(DeepseekV2ForCausalLM):
    """``DeepseekV2ForCausalLM`` reused unchanged under the ``DeepseekForCausalLM`` name.

    The parent's ``__init__`` selects plain MHA when
    ``config.model_type == "deepseek"`` (presumably the configs this alias
    receives — confirm against the model registry).
    """


1628
1629
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """``DeepseekV2ForCausalLM`` reused unchanged under the ``DeepseekV3ForCausalLM`` name."""


class GlmMoeDsaForCausalLM(DeepseekV2ForCausalLM):
    """``DeepseekV2ForCausalLM`` reused unchanged under the ``GlmMoeDsaForCausalLM`` name."""


1636
1637
# Compatibility with
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py
def get_spec_layer_idx_from_weight_name(
    config: DeepseekV2Config | DeepseekV3Config, weight_name: str
) -> int | None:
    """Return the speculative (next-n prediction) layer index of a weight.

    The candidate indices are ``num_hidden_layers`` through
    ``num_hidden_layers + num_nextn_predict_layers - 1``; a weight matches an
    index when its name starts with ``model.layers.<idx>.``.

    Args:
        config: Model config; ``num_nextn_predict_layers`` may be absent or
            ``None``, both treated as 0.
        weight_name: Checkpoint weight name.

    Returns:
        The matching layer index, or ``None`` if the weight is not in a
        speculative layer (or the config declares none).
    """
    num_spec_layers = getattr(config, "num_nextn_predict_layers", 0) or 0
    if num_spec_layers <= 0:
        # Nothing to match; also avoids touching num_hidden_layers.
        return None

    first_spec_idx = config.num_hidden_layers
    for layer_idx in range(first_spec_idx, first_spec_idx + num_spec_layers):
        if weight_name.startswith(f"model.layers.{layer_idx}."):
            return layer_idx
    return None