deepseek_v2.py 57.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

wangding zeng's avatar
wangding zeng committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
25
"""Inference-only DeepseekV2/DeepseekV3 model."""
26

27
28
import typing
from collections.abc import Callable, Iterable
29
from itertools import islice
30
from typing import Any
wangding zeng's avatar
wangding zeng committed
31
32
33

import torch
from torch import nn
34
from transformers import DeepseekV2Config, DeepseekV3Config
wangding zeng's avatar
wangding zeng committed
35

36
from vllm.attention import Attention
37
38
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
39
from vllm.compilation.decorators import support_torch_compile
40
41
42
43
44
45
46
47
from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
48
49
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
wangding zeng's avatar
wangding zeng committed
50
from vllm.model_executor.layers.activation import SiluAndMul
51
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
52
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
53
54
55
56
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    is_rocm_aiter_fusion_shared_expert_enabled,
    is_rocm_aiter_moe_enabled,
)
57
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
58
59
60
61
62
63
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
wangding zeng's avatar
wangding zeng committed
64
from vllm.model_executor.layers.logits_processor import LogitsProcessor
65
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
66
from vllm.model_executor.layers.quantization import QuantizationConfig
67
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
68
69
    per_token_group_quant_fp8,
)
wangding zeng's avatar
wangding zeng committed
70
71
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
72
73
74
    ParallelLMHead,
    VocabParallelEmbedding,
)
75
from vllm.model_executor.model_loader.weight_utils import (
76
77
78
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
79
from vllm.model_executor.models.utils import sequence_parallel_chunk
80
from vllm.platforms import current_platform
81
from vllm.sequence import IntermediateTensors
82
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
83
from vllm.utils.torch_utils import direct_register_custom_op
84
85
86
87
from vllm.v1.attention.backends.mla.indexer import (
    DeepseekV32IndexerBackend,
    DeepseekV32IndexerMetadata,
)
88
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
wangding zeng's avatar
wangding zeng committed
89

90
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
91
92
93
94
95
96
97
from .utils import (
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
98

99
100
101
102
103
104
105
# Select the custom-op namespace for the current platform: CUDA/ROCm use
# vLLM's compiled `_custom_ops`, XPU uses the IPEX wrappers. On any other
# platform `ops` is never bound, so module code that calls `ops.*` (e.g.
# `ops.indexer_k_quant_and_cache` in `sparse_attn_indexer`) would raise
# NameError there.
if current_platform.is_cuda_alike():
    from vllm import _custom_ops as ops
elif current_platform.is_xpu():
    from vllm._ipex_ops import ipex_ops as ops

logger = init_logger(__name__)

wangding zeng's avatar
wangding zeng committed
106
107
108
109
110
111
112

class DeepseekV2MLP(nn.Module):
    """Gated SiLU feed-forward block computing ``down_proj(silu(gate) * up)``.

    The gate and up projections are fused into one
    ``MergedColumnParallelLinear`` (``gate_up_proj``) whose output is split
    and combined by ``SiluAndMul``.

    Args:
        hidden_size: Input/output width of the block.
        intermediate_size: Width of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config for both projections.
        reduce_results: Forwarded to ``down_proj``; when False the caller is
            responsible for reducing partial outputs across TP ranks.
        is_sequence_parallel: When True, TP is disabled on both projections
            (weights replicated, tokens sharded across the TP group), so no
            collective is issued by this block.
        prefix: Parameter-name prefix used for weight loading.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        is_sequence_parallel=False,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Sequence-parallel mode shards activations by token and replicates
        # the weights, so TP is switched off on both linears. Otherwise this
        # is standard TP, with the final all-reduce controlled by
        # `reduce_results` on down_proj.
        gate_and_up_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            gate_and_up_widths,
            bias=False,
            quant_config=quant_config,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.down_proj",
        )

        supported = hidden_act == "silu"
        if not supported:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply the gated MLP to ``x`` and return the projected output."""
        fused_gate_up, _gate_up_bias = self.gate_up_proj(x)
        activated = self.act_fn(fused_gate_up)
        projected, _down_bias = self.down_proj(activated)
        return projected


class DeepseekV2MoE(nn.Module):
    """DeepSeek-V2/V3 mixture-of-experts feed-forward layer.

    A replicated, unquantized router (``gate``) scores
    ``config.n_routed_experts`` experts. ``SharedFusedMoE`` (``experts``)
    performs grouped top-k routing over them and, when configured, also runs
    the shared-expert MLP (``shared_experts``). The shared output is added to
    the routed output in ``forward``.

    Args:
        config: DeepSeek V2/V3 HF config; reads hidden size, expert counts,
            routing options and ``routed_scaling_factor``.
        parallel_config: Supplies sequence-parallel-MoE and EPLB settings.
        quant_config: Optional quantization config for experts.
        prefix: Parameter-name prefix used for weight loading.

    Raises:
        ValueError: If ``config.hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        config: DeepseekV2Config | DeepseekV3Config,
        parallel_config: ParallelConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        self.routed_scaling_factor = config.routed_scaling_factor

        ep_device_group = get_ep_group().device_group
        self.ep_group = ep_device_group
        self.ep_rank = ep_device_group.rank()
        self.ep_size = ep_device_group.size()
        self.n_routed_experts: int = config.n_routed_experts
        self.n_shared_experts: int = config.n_shared_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        # Router: replicated on every rank and deliberately left unquantized.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.n_routed_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )
        # "noaux_tc" routing carries an fp32 per-expert correction bias that is
        # handed to the fused MoE; other routing methods have none.
        self.gate.e_score_correction_bias = (
            nn.Parameter(torch.empty(config.n_routed_experts, dtype=torch.float32))
            if config.topk_method == "noaux_tc"
            else None
        )

        # Expert-parallel load balancing (EPLB) bookkeeping: logical experts
        # plus redundant replicas form the physical pool, split evenly over
        # the EP ranks.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        local_count = self.n_physical_experts // self.ep_size
        self.n_local_physical_experts = local_count

        first_local = self.ep_rank * local_count
        self.physical_expert_start = first_local
        self.physical_expert_end = first_local + local_count

        # When ROCm AITER fuses the shared experts into the MoE kernel, no
        # separate shared-expert MLP is built here.
        aiter_fuses_shared = is_rocm_aiter_fusion_shared_expert_enabled()
        if config.n_shared_experts is not None and not aiter_fuses_shared:
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.moe_intermediate_size
                * config.n_shared_experts,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                is_sequence_parallel=self.is_sequence_parallel,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )
        else:
            self.shared_experts = None

        aiter_moe = is_rocm_aiter_moe_enabled()
        self.experts = SharedFusedMoE(
            shared_experts=self.shared_experts,
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func=config.scoring_func,
            # Normally the scaling is applied in forward(), so the kernel gets
            # 1.0 to avoid multiplying twice; AITER instead applies
            # routed_scaling_factor inside the kernel.
            routed_scaling_factor=self.routed_scaling_factor if aiter_moe else 1.0,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
            n_shared_experts=config.n_shared_experts if aiter_fuses_shared else None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Run routed (+ shared) experts over a 2-D ``hidden_states`` tensor.

        Returns a tensor viewed back to ``(num_tokens, hidden_dim)`` of the
        input.
        """
        total_tokens, width = hidden_states.shape
        tokens = hidden_states.view(-1, width)

        # In sequence-parallel mode each TP rank processes only its own chunk
        # of tokens, avoiding duplicate expert work; the chunks are gathered
        # back after the experts.
        # TODO: We can replace the all_reduce at the end of attn with a
        # reduce_scatter instead of chunking here.
        if self.is_sequence_parallel:
            tokens = sequence_parallel_chunk(tokens)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(tokens)
        shared_out, routed_out = self.experts(
            hidden_states=tokens, router_logits=router_logits
        )
        if self.shared_experts is None:
            assert shared_out is None

        # FP16 overflow workaround (see DeepseekV2DecoderLayer): in fp16 the
        # routed output stays unscaled and the shared output is divided by the
        # scaling factor instead, so the sum equals the non-fp16 result / factor.
        if hidden_states.dtype == torch.float16:
            if self.shared_experts is not None:
                assert shared_out is not None
                shared_out *= 1.0 / self.routed_scaling_factor
        elif not is_rocm_aiter_moe_enabled():
            routed_out *= self.routed_scaling_factor

        if self.shared_experts is not None:
            assert shared_out is not None
            routed_out += shared_out

        if self.is_sequence_parallel:
            gathered = tensor_model_parallel_all_gather(routed_out, 0)
            # Drop any padding introduced by chunking.
            routed_out = gathered[:total_tokens]
        elif self.tp_size > 1:
            routed_out = self.experts.maybe_all_reduce_tensor_model_parallel(
                routed_out
            )

        return routed_out.view(total_tokens, width)


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN attention magnitude-scaling factor.

    Computes ``0.1 * mscale * ln(scale) + 1``. The result is exactly ``1.0``
    whenever ``scale <= 1``, i.e. when no context extension is applied.
    """
    import math

    return 1.0 if scale <= 1 else 1.0 + 0.1 * mscale * math.log(scale)


class DeepseekV2Attention(nn.Module):
    """DeepSeek-V2 attention with low-rank query and key/value projections.

    Query path:
    - If ``q_lora_rank`` is set, queries go through a low-rank bottleneck
      (``q_a_proj`` -> ``q_a_layernorm`` -> ``q_b_proj``).
    - Otherwise a single ``q_proj`` is used.

    Key/value path:
    - Keys and values are decompressed from a latent (``kv_a_proj_with_mqa``
      -> ``kv_a_layernorm`` -> ``kv_b_proj``).
    - The rotary part of the key (``qk_rope_head_dim`` wide) is sliced from
      the same ``kv_a_proj_with_mqa`` output and broadcast to every head.

    Full per-head K and V are materialized and passed to a standard
    ``Attention`` layer. Heads are split across tensor-parallel ranks, so
    ``num_heads`` must be divisible by the TP world size.

    Args:
        vllm_config: Global vLLM config (unused here beyond the interface).
        config: HF model config; ``rms_norm_eps`` is read for the norms.
        hidden_size: Model hidden width.
        num_heads: Total attention heads across all TP ranks.
        qk_nope_head_dim: Per-head query/key width without rotary embedding.
        qk_rope_head_dim: Per-head query/key width that receives rotary
            embedding.
        v_head_dim: Per-head value width.
        q_lora_rank: Query bottleneck rank, or None for a direct ``q_proj``.
        kv_lora_rank: Width of the compressed KV latent.
        rope_theta: Rotary base.
        rope_scaling: Optional rope-scaling dict. It is copied before being
            specialised to ``"deepseek_yarn"``; the caller's dict is never
            modified.
        max_position_embeddings: Max position for the rotary embedding.
        cache_config: Optional KV-cache config forwarded to ``Attention``.
        quant_config: Optional quantization config for the projections.
        topk_indices_buffer: Must be None; sparse indexing is not supported
            by this class.
        prefix: Parameter-name prefix used for weight loading.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        topk_indices_buffer: torch.Tensor | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        # A single literal: the previous backslash continuation embedded the
        # source indentation into the message.
        assert topk_indices_buffer is None, (
            "topk_indices_buffer is not supported for DeepseekV2Attention"
        )

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_a_proj",
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )

        # One replicated projection yields both the KV latent and the shared
        # rotary key slice.
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa",
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        if rope_scaling:
            # Work on a private copy. The dict is owned by the caller
            # (presumably the shared model config), and overriding
            # "rope_type" in place would leak into every other reader.
            rope_scaling = dict(rope_scaling)
            rope_scaling["rope_type"] = "deepseek_yarn"

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
        )

        if rope_scaling:
            # YaRN: the softmax scale is multiplied by mscale**2.
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.attn = Attention(
            self.num_local_heads,
            self.qk_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention output for ``hidden_states`` at ``positions``.

        ``hidden_states`` is assumed to be ``(num_tokens, hidden_size)``
        (TODO confirm against callers). Returns the ``o_proj`` output.
        """
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # Add a singleton head axis so the rotary key slice broadcasts over
        # all local heads below.
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a)
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        k_pe = latent_cache[:, :, self.kv_lora_rank :]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe
        # padding value to qk_head_dim for alignment
        v = torch.nn.functional.pad(
            v, [0, self.qk_head_dim - self.v_head_dim], value=0
        ).view(-1, self.num_local_heads * self.qk_head_dim)
        attn_output = self.attn(q, k, v)
        # Strip the value padding before the output projection.
        attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[
            ..., : self.v_head_dim
        ].reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output


469
class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
    """KV-cache placeholder layer for the DeepSeek-V3.2 sparse-attention indexer.

    On construction the layer registers itself under ``prefix`` in the current
    compilation config's ``static_forward_context``. ``kv_cache`` starts as a
    list holding one empty tensor (presumably replaced once the real cache is
    allocated — TODO confirm).

    Args:
        head_dim: Width of the single cached vector per token.
        dtype: Dtype reported in the KV-cache spec.
        prefix: Unique layer name; a duplicate raises ``ValueError``.
        cache_config: Source of ``block_size`` for the KV-cache spec.
    """

    def __init__(
        self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig
    ):
        super().__init__()
        self.kv_cache = [torch.tensor([])]
        self.head_dim = head_dim
        self.prefix = prefix
        self.cache_config = cache_config
        self.dtype = dtype

        layer_registry = (
            get_current_vllm_config().compilation_config.static_forward_context
        )
        if prefix in layer_registry:
            raise ValueError(f"Duplicate layer name: {prefix}")
        layer_registry[prefix] = self

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe this layer's cache: one head of ``head_dim`` per token."""
        # An MLA-style spec is used because only a single vector (not K + V)
        # is stored per token.
        block_size = self.cache_config.block_size
        return MLAAttentionSpec(
            block_size=block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
        )

    def forward(self):
        """No-op: the layer exists only to own a KV cache."""

    def get_attn_backend(self) -> AttentionBackend:
        """Return the indexer attention backend class."""
        return DeepseekV32IndexerBackend


def _indexer_prefill_topk(
    prefill_metadata,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    head_dim: int,
    topk_tokens: int,
    topk_indices_buffer: torch.Tensor,
) -> None:
    """Fill ``topk_indices_buffer`` rows for every prefill chunk (in place)."""
    for chunk in prefill_metadata.chunks:
        # Gather this chunk's quantized keys (fp8 values + 4 scale bytes per
        # row) out of the paged cache.
        gathered_k = torch.empty(
            [chunk.total_seq_lens, head_dim],
            device=k.device,
            dtype=torch.float8_e4m3fn,
        )
        gathered_scale = torch.empty(
            [chunk.total_seq_lens, 4],
            device=k.device,
            dtype=torch.uint8,
        )
        ops.cp_gather_indexer_k_quant_cache(
            kv_cache,
            gathered_k,
            gathered_scale,
            chunk.block_table,
            chunk.cu_seq_lens,
        )

        rows = slice(chunk.token_start, chunk.token_end)
        logits = fp8_mqa_logits(
            q_fp8[rows],
            (gathered_k, gathered_scale.view(torch.float32)),
            weights[rows],
            chunk.cu_seqlen_ks,
            chunk.cu_seqlen_ke,
        )

        n_rows = logits.shape[0]
        assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
        chunk_topk = torch.empty(
            n_rows, topk_tokens, dtype=torch.int32, device=logits.device
        )
        torch.ops._C.top_k_per_row(
            logits,
            chunk.cu_seqlen_ks,
            chunk.cu_seqlen_ke,
            chunk_topk,
            n_rows,
            logits.stride(0),
            logits.stride(1),
        )
        topk_indices_buffer[rows, : chunk_topk.shape[-1]] = chunk_topk.to(
            dtype=torch.int32
        )


def _indexer_decode_topk(
    decode_metadata,
    num_decode_tokens: int,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    weights: torch.Tensor,
    topk_tokens: int,
    max_model_len: int,
    topk_indices_buffer: torch.Tensor,
) -> None:
    """Fill ``topk_indices_buffer[:num_decode_tokens]`` for decode requests."""
    # kv_cache size requirement [num_block, block_size, n_head, head_dim],
    # we only have [num_block, block_size, head_dim],
    paged_cache = kv_cache.unsqueeze(-2)
    decode_lens = decode_metadata.decode_lens
    decode_q = q_fp8[:num_decode_tokens]
    if decode_metadata.requires_padding:
        # pad in edge case where we have short chunked prefill length <
        # decode_threshold since we unstrictly split
        # prefill and decode by decode_threshold
        # (currently set to 1 + speculative tokens)
        batched_q = pack_seq_triton(decode_q, decode_lens)
    else:
        batched_q = decode_q.reshape(decode_lens.shape[0], -1, *q_fp8.shape[1:])

    # TODO: move and optimize below logic with triton kernels
    batch_size, next_n = batched_q.shape[0], batched_q.shape[1]
    assert batch_size == decode_metadata.seq_lens.shape[0]
    padded_tokens = batch_size * next_n
    logits = fp8_paged_mqa_logits(
        batched_q,
        paged_cache,
        weights[:padded_tokens],
        decode_metadata.seq_lens,
        decode_metadata.block_table,
        decode_metadata.schedule_metadata,
        max_model_len=max_model_len,
    )

    # Flat padded row p belongs to request p // next_n at offset p % next_n.
    # Its end position is seq_len - next_n + offset + 1: seq_len for the last
    # query of a request and one less for each earlier one.
    flat_pos = torch.arange(padded_tokens, device=batched_q.device)
    request_of_row = flat_pos // next_n
    offset_in_request = flat_pos % next_n
    end_positions = (
        decode_metadata.seq_lens[request_of_row] - next_n + offset_in_request + 1
    ).unsqueeze(1)

    n_rows = logits.shape[0]
    assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
    decode_topk = torch.empty(
        n_rows, topk_tokens, dtype=torch.int32, device=logits.device
    )
    torch.ops._C.top_k_per_row(
        logits,
        torch.zeros(n_rows, dtype=torch.int32, device=logits.device),
        end_positions.to(dtype=torch.int32, device=logits.device),
        decode_topk,
        n_rows,
        logits.stride(0),
        logits.stride(1),
    )
    if decode_metadata.requires_padding:
        # if padded, we need to unpack
        # the topk indices removing padded tokens
        decode_topk = unpack_seq_triton(
            decode_topk.reshape(batch_size, -1, decode_topk.shape[-1]),
            decode_lens,
        )
    topk_indices_buffer[:num_decode_tokens, : decode_topk.shape[-1]] = (
        decode_topk.to(dtype=torch.int32)
    )


def sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    """Select the top-k key positions per query token for sparse attention.

    Steps:
    - If the forward context carries no per-layer metadata dict (it is None
      in the dummy run), delegate to ``sparse_attn_indexer_fake``.
    - Otherwise quantize ``k`` into ``kv_cache`` at the metadata's slot
      mapping.
    - Reset the first ``hidden_states.shape[0]`` rows of
      ``topk_indices_buffer`` to -1.
    - Write top-k indices for the prefill chunks, then for the decode tokens.

    The buffer is mutated in place and also returned. ``topk_tokens`` must be
    2048 (asserted).
    """
    # careful! this will be None in dummy run
    forward_metadata = get_forward_context().attn_metadata
    if not isinstance(forward_metadata, dict):
        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )

    layer_metadata = forward_metadata[k_cache_prefix]
    assert isinstance(layer_metadata, DeepseekV32IndexerMetadata)
    run_prefill = layer_metadata.num_prefills > 0
    run_decode = layer_metadata.num_decodes > 0
    n_decode_tokens = layer_metadata.num_decode_tokens

    ops.indexer_k_quant_and_cache(
        k,
        kv_cache,
        layer_metadata.slot_mapping,
        quant_block_size,
        scale_fmt,
    )

    topk_indices_buffer[: hidden_states.shape[0]] = -1

    if run_prefill:
        _indexer_prefill_topk(
            layer_metadata.prefill,
            kv_cache,
            q_fp8,
            k,
            weights,
            head_dim,
            topk_tokens,
            topk_indices_buffer,
        )

    if run_decode:
        _indexer_decode_topk(
            layer_metadata.decode,
            n_decode_tokens,
            kv_cache,
            q_fp8,
            weights,
            topk_tokens,
            max_model_len,
            topk_indices_buffer,
        )

    return topk_indices_buffer


def sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
672
    scale_fmt: str | None,
673
674
675
676
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
677
    topk_indices_buffer: torch.Tensor | None,
678
679
680
681
) -> torch.Tensor:
    # profile run
    # NOTE(Chen): create the max possible flattened_kv. So that
    # profile_run can get correct memory usage.
682
683
684
685
    _flattened_kv = torch.empty(
        [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
    )
    _k_fp8 = _flattened_kv[..., :head_dim].view(torch.float8_e4m3fn).contiguous()
686
687
688
689
690
691
692
693
694
695
696
697
698
699
    _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
    return topk_indices_buffer


# Expose `sparse_attn_indexer` as a custom op named "sparse_attn_indexer".
# `topk_indices_buffer` is listed in `mutates_args` because the real
# implementation writes its top-k results into that buffer in place.
# `sparse_attn_indexer_fake` is supplied as the fake implementation; it only
# allocates a scratch tensor and returns the buffer unchanged. The op is
# presumably kept opaque to torch.compile via this registration — confirm
# against `direct_register_custom_op`.
direct_register_custom_op(
    op_name="sparse_attn_indexer",
    op_func=sparse_attn_indexer,
    mutates_args=["topk_indices_buffer"],
    fake_impl=sparse_attn_indexer_fake,
    dispatch_key=current_platform.dispatch_key,
)


class Indexer(nn.Module):
    """Top-k token indexer used by DeepSeek-V3.2 sparse attention.

    The query (from the low-rank ``qr``) is projected to ``n_head`` heads of
    size ``head_dim``. The hidden states are projected to a single key of size
    ``head_dim``. Rotary embedding is applied to the first ``rope_dim``
    channels of both. The query is then fp8-quantized, and the work is handed
    to the ``sparse_attn_indexer`` custom op, which fills
    ``topk_indices_buffer``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        q_lora_rank: int,
        quant_config: QuantizationConfig | None,
        cache_config: CacheConfig | None,
        topk_indices_buffer: torch.Tensor | None,
        prefix: str = "",
    ):
        super().__init__()
        self.vllm_config = vllm_config
        self.config = config
        self.topk_tokens = config.index_topk
        self.n_head = config.index_n_heads
        self.head_dim = config.index_head_dim
        self.rope_dim = config.qk_rope_head_dim
        self.q_lora_rank = q_lora_rank

        # Every projection here is replicated; there is no tensor parallelism.
        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.head_dim * self.n_head,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq_b",
        )
        self.wk = ReplicatedLinear(
            hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wk",
        )
        self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
        # Per-head mixing weights; explicitly left unquantized.
        self.weights_proj = ReplicatedLinear(
            hidden_size,
            self.n_head,
            quant_config=None,
            prefix=f"{prefix}.weights_proj",
        )
        self.softmax_scale = self.head_dim**-0.5

        self.scale_fmt = "ue8m0"
        self.quant_block_size = 128  # TODO: get from config
        self.topk_indices_buffer = topk_indices_buffer

        # NOTE (zyongye): naive fp8 cache. Each row stores the key in fp8
        # followed by one fp32 scale (4 bytes) per quant_block_size elements,
        # so the row is wider than head_dim.
        scale_bytes = self.head_dim // self.quant_block_size * 4
        self.k_cache = DeepseekV32IndexerCache(
            head_dim=self.head_dim + scale_bytes,
            dtype=torch.uint8,
            prefix=f"{prefix}.k_cache",
            cache_config=cache_config,
        )
        self.max_model_len = vllm_config.model_config.max_model_len
        self.prefix = prefix
        from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size

        self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)

    def forward(
        self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb
    ) -> torch.Tensor:
        """Compute indexer inputs and run the sparse top-k indexer op.

        Returns whatever ``torch.ops.vllm.sparse_attn_indexer`` returns. In the
        fake implementation visible in this file, that is
        ``topk_indices_buffer``.
        """
        # The rotary part comes first in both query and key.
        split_sizes = [self.rope_dim, self.head_dim - self.rope_dim]

        query, _ = self.wq_b(qr)
        query = query.view(-1, self.n_head, self.head_dim)
        query_pe, query_nope = torch.split(query, split_sizes, dim=-1)

        key, _ = self.wk(hidden_states)
        key = self.k_norm(key)
        key_pe, key_nope = torch.split(key, split_sizes, dim=-1)

        # The key has no head axis, so give it a singleton one for the rotary op.
        query_pe, key_pe = rotary_emb(positions, query_pe, key_pe.unsqueeze(1))
        query = torch.cat([query_pe, query_nope], dim=-1)
        key = torch.cat([key_pe.squeeze(1), key_nope], dim=-1)

        # Only the query is quantized here; key quantization is fused with the
        # cache insertion inside the custom op.
        q_fp8, q_scale = per_token_group_quant_fp8(
            query.view(-1, self.head_dim),
            self.quant_block_size,
            column_major_scales=False,
            use_ue8m0=self.scale_fmt is not None,
        )
        q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
        q_scale = q_scale.view(-1, self.n_head, 1)

        # Fold the query dequant scale and the softmax / head-count
        # normalisation into the per-head weights.
        head_weights, _ = self.weights_proj(hidden_states)
        head_weights = (
            head_weights.unsqueeze(-1)
            * q_scale
            * self.softmax_scale
            * self.n_head**-0.5
        ).squeeze(-1)

        return torch.ops.vllm.sparse_attn_indexer(
            hidden_states,
            self.k_cache.prefix,
            self.k_cache.kv_cache[0],
            q_fp8,
            key,
            head_weights,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )


813
814
815
816
class DeepseekV2MLAAttention(nn.Module):
    """Multi-head latent attention (MLA) for DeepSeek-V2/V3 and V3.2.

    Main reference: the DeepseekV2 paper and the FlashInfer implementation
    (https://arxiv.org/abs/2405.04434 and
    https://github.com/flashinfer-ai/flashinfer/pull/551).

    This module only builds the projections, norms, rotary embedding and (for
    V3.2) the sparse indexer. The attention computation itself is delegated
    to ``MultiHeadLatentAttentionWrapper``. For backend details see
    MLACommonImpl in vLLM's v1 MLA attention backends.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        topk_indices_buffer: torch.Tensor | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads

        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # With a query LoRA rank, the q down-projection and the kv
        # down-projection share one fused (non-TP) matmul. Without one, the
        # query goes through a plain column-parallel projection.
        use_q_lora = self.q_lora_rank is not None
        if use_q_lora:
            self.fused_qkv_a_proj = MergedColumnParallelLinear(
                self.hidden_size,
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.fused_qkv_a_proj",
                disable_tp=True,
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                self.q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
        else:
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.kv_a_proj_with_mqa",
            )
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )

        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # NOTE(review): this writes into the caller's rope_scaling dict.
        if rope_scaling:
            rope_scaling["rope_type"] = "deepseek_yarn"
        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
        )
        if rope_scaling:
            # YaRN also rescales the softmax: multiply by mscale twice.
            mscale = yarn_get_mscale(
                rope_scaling["factor"],
                float(rope_scaling.get("mscale_all_dim", False)),
            )
            self.scaling = self.scaling * mscale * mscale

        # Only DeepSeek-V3.2 configs define `index_topk`; they get a sparse
        # top-k indexer.
        self.is_v32 = hasattr(config, "index_topk")
        self.indexer = (
            Indexer(
                vllm_config,
                config,
                hidden_size,
                q_lora_rank,
                quant_config,
                cache_config,
                topk_indices_buffer,
                f"{prefix}.indexer",
            )
            if self.is_v32
            else None
        )

        mla_modules = MLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            rotary_emb=self.rotary_emb,
            o_proj=self.o_proj,
            fused_qkv_a_proj=self.fused_qkv_a_proj if use_q_lora else None,
            kv_a_proj_with_mqa=None if use_q_lora else self.kv_a_proj_with_mqa,
            q_a_layernorm=self.q_a_layernorm if use_q_lora else None,
            q_b_proj=self.q_b_proj if use_q_lora else None,
            q_proj=None if use_q_lora else self.q_proj,
            indexer=self.indexer,
            is_sparse=self.is_v32,
            topk_indices_buffer=topk_indices_buffer,
        )

        self.mla_attn = MultiHeadLatentAttentionWrapper(
            self.hidden_size,
            self.num_local_heads,
            self.scaling,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
            self.q_lora_rank,
            self.kv_lora_rank,
            mla_modules,
            cache_config,
            quant_config,
            prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run latent attention via the wrapped MLA implementation."""
        return self.mla_attn(positions, hidden_states)
983
984


wangding zeng's avatar
wangding zeng committed
985
class DeepseekV2DecoderLayer(nn.Module):
    """A single DeepSeek-V2/V3 decoder block.

    Pre-norm layout: ``input_layernorm`` -> self-attention ->
    ``post_attention_layernorm`` -> MLP. The MLP is a routed ``DeepseekV2MoE``
    on MoE layers and a dense ``DeepseekV2MLP`` otherwise. The residual stream
    is threaded through the norms, which take and return
    ``(hidden_states, residual)`` pairs.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str,
        config: DeepseekV2Config | DeepseekV3Config | None = None,
        topk_indices_buffer: torch.Tensor | None = None,
    ) -> None:
        """Build the attention and MLP sub-modules for one layer.

        Args:
            vllm_config: Global vLLM configuration.
            prefix: Module prefix. Its last ``.``-separated component must be
                the integer layer index.
            config: HF model config. Defaults to
                ``vllm_config.model_config.hf_config``.
            topk_indices_buffer: Shared top-k index buffer used by V3.2 sparse
                attention, or ``None``.
        """
        super().__init__()

        if config is None:
            config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        # DecoderLayers are created with `make_layers`, which passes a prefix
        # ending in the layer's index.
        layer_idx = int(prefix.split(sep=".")[-1])
        self.layer_idx = layer_idx

        if model_config.use_mla:
            attn_cls = DeepseekV2MLAAttention
        else:
            attn_cls = DeepseekV2Attention
        self.self_attn = attn_cls(
            vllm_config=vllm_config,
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            # Configs without a query LoRA rank use a plain q projection.
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            topk_indices_buffer=topk_indices_buffer,
        )

        # The first `first_k_dense_replace` layers are dense. After that,
        # every `moe_layer_freq`-th layer is MoE.
        if (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        ):
            self.mlp = DeepseekV2MoE(
                config=config,
                parallel_config=parallel_config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.routed_scaling_factor = config.routed_scaling_factor

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the block.

        Args:
            positions: Token positions, forwarded to attention.
            hidden_states: Block input.
            residual: Running residual, or ``None`` for the first layer run on
                this stage (the input is then copied to start the residual).

        Returns:
            ``(hidden_states, residual)``. Both are passed on to the next
            layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states.clone()
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        if hidden_states.dtype == torch.float16:
            # Fix FP16 overflow.
            # Both hidden_states and residual are scaled before the rmsnorm;
            # the rmsnorm result is not affected by this common scale.
            hidden_states *= 1.0 / self.routed_scaling_factor
            if self.layer_idx == 0:
                # The residual is shared by all layers, so it is scaled only
                # on the first layer.
                residual *= 1.0 / self.routed_scaling_factor

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
            # Fix FP16 overflow.
            # Scale the DeepseekV2MLP output, which is the input to the next
            # layer's input_layernorm. The DeepseekV2MoE output is scaled in
            # DeepseekV2MoE.forward instead.
            hidden_states *= 1.0 / self.routed_scaling_factor

        return hidden_states, residual


1100
@support_torch_compile
class DeepseekV2Model(nn.Module):
    """Decoder stack for DeepSeek-V2/V3.

    On the first pipeline stage it holds the token embedding, then this
    stage's slice of decoder layers, and on the last stage the final RMSNorm.
    Stages other than the last return ``IntermediateTensors``.
    """

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.device = current_platform.device_type
        self.vocab_size = config.vocab_size

        # V3.2 configs (those with `index_topk`) need one top-k index buffer
        # shared by every layer: one row per batched token.
        self.is_v32 = hasattr(config, "index_topk")
        topk_indices_buffer = None
        if self.is_v32:
            topk_indices_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                config.index_topk,
                dtype=torch.int32,
                device=self.device,
            )

        self.embed_tokens = (
            VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
            if get_pp_group().is_first_rank
            else PPMissingLayer()
        )

        # `make_layers` calls this with `prefix=...` as a keyword argument.
        def build_layer(prefix: str) -> DeepseekV2DecoderLayer:
            return DeepseekV2DecoderLayer(
                vllm_config, prefix, topk_indices_buffer=topk_indices_buffer
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )

        self.norm = (
            RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            if get_pp_group().is_last_rank
            else PPMissingLayer()
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings (first pipeline stage only)."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this stage's layers.

        Returns the final normalized hidden states on the last stage.
        Otherwise returns the ``hidden_states``/``residual`` pair for the next
        stage.
        """
        if get_pp_group().is_first_rank:
            hidden_states = (
                inputs_embeds
                if inputs_embeds is not None
                else self.get_input_embeddings(input_ids)
            )
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)

        if get_pp_group().is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states
        return IntermediateTensors(
            {"hidden_states": hidden_states, "residual": residual}
        )

1183

1184
class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA):
1185
1186
1187
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
1188
1189
1190
1191
1192
1193
1194

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        # `packed_modules_mapping` must be updated BEFORE DeepseekV2Model is
        # built. Per the original note, this very dict is passed in place to
        # the quantization config, which may consult it when choosing the
        # quant_method of the layers created below. It is therefore mutated in
        # place rather than rebound to a copy.
        self.fuse_qkv_a_proj = getattr(config, "q_lora_rank", None) is not None
        if self.fuse_qkv_a_proj:
            self.packed_modules_mapping["fused_qkv_a_proj"] = [
                "q_a_proj",
                "kv_a_proj_with_mqa",
            ]

        self.model = DeepseekV2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )
        self.expert_weights = []

        # MoE hyperparameters.
        self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
        self.num_expert_groups = config.n_group

        # Collect the MoE blocks owned by this pipeline stage.
        local_moe_mlps: list[DeepseekV2MoE] = []
        for layer in self.model.layers:
            if isinstance(layer, PPMissingLayer):
                continue
            assert isinstance(layer, DeepseekV2DecoderLayer)
            if isinstance(layer.mlp, DeepseekV2MoE):
                local_moe_mlps.append(layer.mlp)
        self.moe_layers: list[SharedFusedMoE] = [
            mlp.experts for mlp in local_moe_mlps
        ]

        if not local_moe_mlps:
            raise RuntimeError("No DeepseekV2MoE layer found in model.layers.")
        # Take the last MoE block as the reference; the first layers may be
        # dense.
        example_moe = local_moe_mlps[-1]

        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_shared_experts = example_moe.n_shared_experts
        self.num_redundant_experts = example_moe.n_redundant_experts

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        """Register expert weights and pass the EPLB state to each MoE layer.

        ``moe_layer_idx`` is the position of the layer in ``self.moe_layers``.
        """
        for moe_idx, experts in enumerate(self.moe_layers):
            # Register the expert weights of this layer.
            self.expert_weights.append(experts.get_expert_weights())
            experts.set_eplb_state(
                moe_layer_idx=moe_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )
1268

1269
1270
1271
1272
1273
1274
1275
1276
    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate a new physical-expert count to the model and its MoE layers.

        The number of local physical experts must not change. The number of
        redundant experts is recomputed as physical minus logical experts.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # Layers owned by other pipeline stages are PPMissingLayer
            # placeholders with no `.mlp`. Skip them, as __init__ does;
            # otherwise pipeline-parallel runs would hit AttributeError.
            if isinstance(layer, PPMissingLayer):
                continue
            if isinstance(layer.mlp, DeepseekV2MoE):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

1286
1287
1288
1289
1290
1291
1292
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` using the inner model's embedding layer."""
        embeddings = self.model.get_input_embeddings(input_ids)
        return embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Delegate to the decoder stack.

        Logits are computed separately by ``compute_logits``.
        """
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project final hidden states to vocabulary logits via ``lm_head``."""
        return self.logits_processor(self.lm_head, hidden_states)

1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-parameter mapping for routed experts.

        Each entry is ``(param_name, weight_name, expert_id, shard_id)`` and
        covers weights plus fp8 weight/activation scales. Only the routed
        experts are included, and no redundant experts.
        """
        return SharedFusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_up_proj_name="up_proj",
            ckpt_down_proj_name="down_proj",
            num_experts=self.config.n_routed_experts,
            num_redundant_experts=0,
        )

1319
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
wangding zeng's avatar
wangding zeng committed
1320
1321
1322
1323
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
1324
1325
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
wangding zeng's avatar
wangding zeng committed
1326
1327
        ]

1328
1329
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
1330
        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
1331
1332
1333
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
1334
1335
1336
1337
1338
1339
            num_experts=self.config.n_routed_experts
            + (
                self.config.n_shared_experts
                if is_rocm_aiter_fusion_shared_expert_enabled()
                else 0
            ),
1340
1341
            num_redundant_experts=self.num_redundant_experts,
        )
1342

wangding zeng's avatar
wangding zeng committed
1343
        params_dict = dict(self.named_parameters())
1344
        loaded_params: set[str] = set()
wangding zeng's avatar
wangding zeng committed
1345
        for name, loaded_weight in weights:
1346
1347
1348
            if "rotary_emb.inv_freq" in name:
                continue

1349
1350
1351
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model
1352

1353
1354
1355
1356
1357
            is_fuse_shared_experts_layer = (
                is_rocm_aiter_fusion_shared_expert_enabled()
                and ("mlp.shared_experts" in name)
            )

1358
            for param_name, weight_name, shard_id in stacked_params_mapping:
1359
                # Skip non-stacked layers and experts (experts handled below).
wangding zeng's avatar
wangding zeng committed
1360
1361
                if weight_name not in name:
                    continue
1362
1363
1364
1365
1366
1367
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
1368
                if ("mlp.experts." in name) and name not in params_dict:
1369
                    continue
1370
1371
                if is_fuse_shared_experts_layer:
                    continue
1372
                name_mapped = name.replace(weight_name, param_name)
1373
1374
1375

                # QKV fusion is optional, fall back to normal
                # weight loading if it's not enabled
1376
                # if go with fusion option, then update name
1377
1378
1379
                if (
                    param_name == "fused_qkv_a_proj"
                ) and name_mapped not in params_dict:
1380
                    continue
1381
1382
                else:
                    name = name_mapped
wangding zeng's avatar
wangding zeng committed
1383
1384
1385
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
1386
1387
1388
1389

                if is_pp_missing_parameter(name, self):
                    continue

wangding zeng's avatar
wangding zeng committed
1390
1391
1392
1393
1394
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
1395
                is_expert_weight = False
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415

                # Special handling: when AITER fusion_shared_experts is enabled,
                # checkpoints may provide a single widened shared_experts tensor
                # without explicit expert indices
                # (e.g. ...mlp.shared_experts.gate_proj.weight).
                # For models with multiple shared experts, split that tensor
                # evenly into per-shared-expert slices and load them into
                # appended expert slots mlp.experts.{n_routed_experts + j}.*
                # accordingly.
                num_chunks = 1
                if is_fuse_shared_experts_layer:
                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
                    # Determine split axis based on op type
                    # gate/up: ColumnParallel → split along dim 0
                    # down: RowParallel → split along dim 1
                    split_dim = 1 if "down_proj.weight" in name else 0
                    total = loaded_weight.shape[split_dim]
                    assert total % num_chunks == 0, (
                        f"Shared expert weight dim {total} "
                        f"not divisible by num_chunks {num_chunks}"
1416
                    )
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
                    chunk_size = total // num_chunks

                for j in range(num_chunks):
                    chunk_name = name
                    weight_to_load = loaded_weight

                    if is_fuse_shared_experts_layer:
                        if split_dim == 0:
                            weight_to_load = loaded_weight[
                                j * chunk_size : (j + 1) * chunk_size, :
                            ]
                        else:
                            weight_to_load = loaded_weight[
                                :, j * chunk_size : (j + 1) * chunk_size
                            ]
                        # Synthesize an expert-style name so expert mapping
                        # can route it
                        chunk_name = name.replace(
                            "mlp.shared_experts",
                            f"mlp.experts.{self.config.n_routed_experts + j}",
                        )

                    # Use expert_params_mapping to locate the destination
                    # param and delegate to its expert-aware weight_loader
                    # with expert_id.
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in chunk_name:
                            continue

                        # Anyway, this is an expert weight and should not be
                        # attempted to load as other weights later
                        is_expert_weight = True

                        # Do not modify `name` since the loop may continue here
                        # Instead, create a new variable
                        name_mapped = chunk_name.replace(weight_name, param_name)

                        if is_pp_missing_parameter(name_mapped, self):
                            continue

                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or
                        # not here since otherwise we may skip experts with
                        # other available replicas.
                        weight_loader = typing.cast(
                            Callable[..., bool], param.weight_loader
                        )
                        success = weight_loader(
                            param,
                            weight_to_load,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                            return_success=True,
                        )
                        if success:
                            if not is_fuse_shared_experts_layer:
                                name = name_mapped
                            else:
                                loaded_params.add(name_mapped)
                            break
                    else:
                        if is_expert_weight:
                            # We've checked that this is an expert weight
                            # However it's not mapped locally to this rank
                            # So we simply skip it
                            continue

                        # Skip loading extra bias for GPTQ models.
                        if name.endswith(".bias") and name not in params_dict:
                            continue

                        # Remapping the name of FP8 kv-scale.
                        name = maybe_remap_kv_scale_name(name, params_dict)
                        if name is None:
                            continue

                        if is_pp_missing_parameter(name, self):
                            continue

                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
            if not is_fuse_shared_experts_layer:
                loaded_params.add(name)
1505

1506
        return loaded_params
1507
1508
1509
1510


class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """Causal-LM class for the DeepseekV3 architecture.

    Defines nothing of its own: all behavior is inherited unchanged from
    ``DeepseekV2ForCausalLM``. It exists as a distinct class name, presumably
    so the V3 architecture name can be mapped to it (confirm against the
    model registry).
    """
1511
1512


1513
1514
# Compatibility with
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py
1515
def get_spec_layer_idx_from_weight_name(
1516
1517
    config: DeepseekV2Config | DeepseekV3Config, weight_name: str
) -> int | None:
1518
1519
1520
1521
    if (
        hasattr(config, "num_nextn_predict_layers")
        and config.num_nextn_predict_layers > 0
    ):
1522
1523
        layer_idx = config.num_hidden_layers
        for i in range(config.num_nextn_predict_layers):
1524
            if weight_name.startswith(f"model.layers.{layer_idx + i}."):
1525
1526
                return layer_idx + i
    return None