# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
25
"""Inference-only DeepseekV2/DeepseekV3 model."""
26

27
28
import typing
from collections.abc import Callable, Iterable
29
from itertools import islice
30
from typing import Any
wangding zeng's avatar
wangding zeng committed
31
32
33

import torch
from torch import nn
34
from transformers import DeepseekV2Config, DeepseekV3Config
wangding zeng's avatar
wangding zeng committed
35

36
from vllm.attention import Attention
37
38
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
39
from vllm.compilation.decorators import support_torch_compile
40
41
42
43
44
45
46
47
from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
48
49
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
wangding zeng's avatar
wangding zeng committed
50
from vllm.model_executor.layers.activation import SiluAndMul
51
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
52
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
53
54
55
56
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    is_rocm_aiter_fusion_shared_expert_enabled,
    is_rocm_aiter_moe_enabled,
)
57
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
58
59
60
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
61
    QKVParallelLinear,
62
63
64
    ReplicatedLinear,
    RowParallelLinear,
)
wangding zeng's avatar
wangding zeng committed
65
from vllm.model_executor.layers.logits_processor import LogitsProcessor
66
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
67
from vllm.model_executor.layers.quantization import QuantizationConfig
68
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
69
70
    per_token_group_quant_fp8,
)
wangding zeng's avatar
wangding zeng committed
71
72
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
73
74
75
    ParallelLMHead,
    VocabParallelEmbedding,
)
76
from vllm.model_executor.model_loader.weight_utils import (
77
78
79
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
80
from vllm.model_executor.models.utils import sequence_parallel_chunk
81
from vllm.platforms import current_platform
82
from vllm.sequence import IntermediateTensors
83
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
84
from vllm.utils.torch_utils import direct_register_custom_op
85
86
87
88
from vllm.v1.attention.backends.mla.indexer import (
    DeepseekV32IndexerBackend,
    DeepseekV32IndexerMetadata,
)
89
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
wangding zeng's avatar
wangding zeng committed
90

91
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
92
93
94
95
96
97
98
from .utils import (
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
99

100
101
102
103
104
105
106
if current_platform.is_cuda_alike():
    from vllm import _custom_ops as ops
elif current_platform.is_xpu():
    from vllm._ipex_ops import ipex_ops as ops

logger = init_logger(__name__)

wangding zeng's avatar
wangding zeng committed
107

108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
class DeepseekAttention(nn.Module):
    """Normal MHA implementation used by Deepseek v1.

    Regular multi-head attention with (possibly fewer) KV heads: a fused
    tensor-parallel QKV projection, rotary position embedding applied to
    q and k, vLLM's ``Attention`` kernel, and a row-parallel output
    projection.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        **kwargs,
    ) -> None:
        """Build the projections, rotary embedding and attention kernel.

        Args:
            vllm_config: Global vLLM config (not read by this class).
            config: Model config; only ``num_key_value_heads`` is read here.
            hidden_size: Model hidden size; ``head_dim`` is derived as
                ``hidden_size // num_heads``.
            num_heads: Total number of query heads across all TP ranks.
            rope_theta: Rotary embedding base.
            rope_scaling: Optional rope-scaling settings forwarded to
                ``get_rope``.
            max_position_embeddings: Maximum position for the rotary table.
            cache_config: KV-cache config forwarded to ``Attention``.
            quant_config: Quantization config for the linear layers and
                attention.
            prefix: Module name prefix, used to name the sub-layers.
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        # Per-rank widths of the q and k/v slices of the fused QKV output.
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Pass `prefix` like every other layer in this file so that
        # name-based quantization / LoRA matching can address these layers.
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run attention over ``hidden_states`` at the given ``positions``."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


wangding zeng's avatar
wangding zeng committed
194
195
196
197
198
199
class DeepseekV2MLP(nn.Module):
    """Gated SiLU feed-forward block: gate/up projection, SiLU-and-multiply,
    then down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        is_sequence_parallel: bool = False,
        prefix: str = "",
    ) -> None:
        """Build the MLP.

        Args:
            hidden_size: Input/output feature size.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is supported.
            quant_config: Quantization config for the linear layers.
            reduce_results: Forwarded to the down projection's
                ``reduce_results``.
            is_sequence_parallel: If True, tensor parallelism is disabled on
                both projections (see comment below).
            prefix: Module name prefix, used to name the sub-layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()

        # Validate before allocating any projection weights.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )

        # If is_sequence_parallel, the input and output tensors are sharded
        # across the ranks within the tp_group. In this case the weights are
        # replicated and no collective ops are needed.
        # Otherwise we use standard TP with an allreduce at the end.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.down_proj",
        )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, SiLU-and-multiply, and down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class DeepseekV2MoE(nn.Module):
    """Mixture-of-experts block: routed experts plus optional shared experts.

    Routing and expert computation are delegated to ``SharedFusedMoE``; this
    module owns the gate, the optional shared-expert MLP, the routed scaling
    factor, and the cross-rank communication of the final output.
    """

    def __init__(
        self,
        config: DeepseekV2Config | DeepseekV3Config,
        parallel_config: ParallelConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the gate, shared experts and fused routed experts.

        Args:
            config: Model config (expert counts, sizes, routing options).
            parallel_config: Parallelism config (sequence-parallel MoE and
                EPLB settings are read from it).
            quant_config: Quantization config for experts / shared experts.
            prefix: Module name prefix, used to name the sub-layers.

        Raises:
            ValueError: If ``config.hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts: int = config.n_routed_experts
        # May be None (checked below when building the shared experts).
        self.n_shared_experts: int | None = config.n_shared_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.n_routed_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )
        if getattr(config, "topk_method", None) == "noaux_tc":
            self.gate.e_score_correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts, dtype=torch.float32)
            )
        else:
            self.gate.e_score_correction_bias = None

        # Load balancing settings.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Contiguous range of physical expert ids owned by this EP rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        # Shared experts are skipped when the model has none, or when ROCm
        # AITER fuses them into the routed-expert kernel (see
        # `n_shared_experts` passed to SharedFusedMoE below).
        if (
            config.n_shared_experts is None
            or is_rocm_aiter_fusion_shared_expert_enabled()
        ):
            self.shared_experts = None
        else:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts

            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                is_sequence_parallel=self.is_sequence_parallel,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_experts,
            gate=self.gate,
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=getattr(config, "n_group", 1),
            topk_group=getattr(config, "topk_group", 1),
            prefix=f"{prefix}.experts",
            scoring_func=getattr(config, "scoring_func", "softmax"),
            # we do scaling outside, set factor to 1.0 to avoid double mul
            # aiter applies routed_scaling_factor internally
            routed_scaling_factor=1.0
            if not is_rocm_aiter_moe_enabled()
            else self.routed_scaling_factor,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
            n_shared_experts=config.n_shared_experts
            if is_rocm_aiter_fusion_shared_expert_enabled()
            else None,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MoE block to 2-D ``hidden_states`` (tokens x hidden).

        Returns a tensor of the same ``(num_tokens, hidden_dim)`` shape.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # Chunk the hidden states so they aren't replicated across TP ranks.
        # This avoids duplicate computation in self.experts.
        # TODO: We can replace the all_reduce at the end of attn with a
        # reduce_scatter instead of chunking here.
        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        if self.experts.is_internal_router:
            # In this case, the gate/router runs inside the FusedMoE class
            fused_moe_out = self.experts(
                hidden_states=hidden_states, router_logits=hidden_states
            )
        else:
            # router_logits: (num_tokens, n_experts)
            router_logits, _ = self.gate(hidden_states)
            fused_moe_out = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )

        shared_output, final_hidden_states = fused_moe_out
        if self.shared_experts is None:
            assert shared_output is None

        # Fix FP16 overflow
        # See DeepseekV2DecoderLayer for more details.
        # Non-fp16: apply the routed scaling here (unless AITER already did).
        # fp16: leave routed output unscaled and divide the shared output
        # instead.
        if hidden_states.dtype != torch.float16:
            if not is_rocm_aiter_moe_enabled():
                final_hidden_states *= self.routed_scaling_factor
        elif self.shared_experts is not None:
            assert shared_output is not None
            shared_output *= 1.0 / self.routed_scaling_factor

        if self.shared_experts is not None:
            assert shared_output is not None
            final_hidden_states += shared_output

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            # Presumably drops padding rows added by sequence_parallel_chunk.
            final_hidden_states = final_hidden_states[:num_tokens]
        elif self.tp_size > 1:
            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(
                final_hidden_states
            )

        return final_hidden_states.view(num_tokens, hidden_dim)


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN magnitude-scaling factor for a rope scaling ``scale``.

    Args:
        scale: Context-extension factor. Values ``<= 1`` mean no extension.
        mscale: Coefficient applied to the logarithmic term.

    Returns:
        ``1.0`` when ``scale <= 1``, otherwise
        ``0.1 * mscale * ln(scale) + 1.0``.
    """
    import math

    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


class DeepseekV2Attention(nn.Module):
    """DeepSeek-V2 latent attention, materialized as regular multi-head
    attention.

    Queries come from an optional low-rank (LoRA-style) projection. Keys and
    values are expanded per head from a compressed latent via ``kv_b_proj``.
    Only the last ``qk_rope_head_dim`` channels of q/k receive rotary
    embedding. ``v`` is zero-padded to ``qk_head_dim`` so q, k and v share
    one head size in the attention kernel.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        topk_indices_buffer: torch.Tensor | None = None,
        prefix: str = "",
    ) -> None:
        """Build the projections, rotary embedding and attention kernel.

        Args:
            vllm_config: Global vLLM config (not read by this class).
            config: Model config; ``rms_norm_eps`` is read here.
            hidden_size: Model hidden size.
            num_heads: Total number of heads across all TP ranks.
            qk_nope_head_dim: Per-head q/k channels without rotary embedding.
            qk_rope_head_dim: Per-head q/k channels with rotary embedding.
            v_head_dim: Per-head value channels.
            q_lora_rank: Rank of the low-rank query projection, or ``None``
                to use a single full-rank ``q_proj``.
            kv_lora_rank: Width of the compressed KV latent.
            rope_theta: Rotary embedding base.
            rope_scaling: Optional rope-scaling settings. NOTE: when truthy,
                this dict is mutated in place (``rope_type`` is set to
                ``"deepseek_yarn"``).
            max_position_embeddings: Maximum position for the rotary table.
            cache_config: KV-cache config forwarded to ``Attention``.
            quant_config: Quantization config for linear layers/attention.
            topk_indices_buffer: Must be ``None``; not supported here.
            prefix: Module name prefix, used to name the sub-layers.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        # Implicit string concatenation: the previous backslash continuation
        # inside the literal embedded the next line's indentation spaces
        # into the message.
        assert topk_indices_buffer is None, (
            "topk_indices_buffer is not supported for DeepseekV2Attention"
        )

        if self.q_lora_rank is not None:
            # Low-rank query path: hidden -> q_lora_rank -> RMSNorm -> heads.
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_a_proj",
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )

        # Produces the compressed KV latent plus the shared (single-head)
        # rotary key part in one projection.
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa",
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        # Expands the latent into per-head non-rotary keys and values.
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        if rope_scaling:
            # Mutates the caller-supplied dict in place.
            rope_scaling["rope_type"] = "deepseek_yarn"

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
        )

        if rope_scaling:
            # YaRN: scale the softmax temperature by mscale on both q and k.
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.attn = Attention(
            self.num_local_heads,
            self.qk_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention output for ``hidden_states`` at ``positions``."""
        # q: (tokens, local_heads, qk_head_dim)
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # Add a singleton head axis; the rotary key part is shared by all
        # heads (assumes latent_cache is 2-D (tokens, features) — TODO confirm).
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a)
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        k_pe = latent_cache[:, :, self.kv_lora_rank :]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        # Write the rotated part back into q and assemble the full k.
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe  # broadcast over heads
        # padding value to qk_head_dim for alignment
        v = torch.nn.functional.pad(
            v, [0, self.qk_head_dim - self.v_head_dim], value=0
        ).view(-1, self.num_local_heads * self.qk_head_dim)
        attn_output = self.attn(q, k, v)
        # Drop the padded value channels before the output projection.
        attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[
            ..., : self.v_head_dim
        ].reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output


562
class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
    """Attention-layer stand-in that owns the sparse-attention indexer's
    KV cache.

    It registers itself in the compilation config's static forward context
    under ``prefix``. It reports an MLA-style single-vector cache spec and
    the DeepSeek V3.2 indexer backend. It performs no computation itself.
    """

    def __init__(
        self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig
    ):
        """Create the cache holder and register it by layer name.

        Args:
            head_dim: Size of each cached vector.
            dtype: Cache element dtype reported in the KV-cache spec.
            prefix: Unique layer name used as the forward-context key.
            cache_config: Provides ``block_size`` for the cache spec.

        Raises:
            ValueError: If another layer is already registered as ``prefix``.
        """
        super().__init__()
        # Placeholder; presumably replaced once the real cache is allocated.
        self.kv_cache = [torch.tensor([])]
        self.head_dim = head_dim
        self.prefix = prefix
        self.cache_config = cache_config
        self.dtype = dtype
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe this layer's cache: one head of ``head_dim`` per slot."""
        return MLAAttentionSpec(  # Only has one vector instead of K + V
            block_size=self.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
        )

    def forward(self): ...

    def get_attn_backend(self) -> AttentionBackend:
        """Return the attention backend class used for this cache."""
        return DeepseekV32IndexerBackend


def sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    """Select the top-k key positions per query token for sparse attention.

    Steps:
      1. Quantize the new keys ``k`` and write them into ``kv_cache``.
      2. Reset the first ``hidden_states.shape[0]`` rows of
         ``topk_indices_buffer`` to -1.
      3. For prefill chunks: gather fp8 keys from the cache, compute
         ``fp8_mqa_logits`` and write per-row top-k indices.
      4. For decode tokens: compute ``fp8_paged_mqa_logits`` and write
         per-row top-k indices, unpacking if the queries were padded.

    When no attention-metadata dict is available (the dummy run), this
    defers to ``sparse_attn_indexer_fake``.

    Args:
        hidden_states: Only its leading dimension (token count) is read.
        k_cache_prefix: Key into the attention-metadata dict for this layer.
        kv_cache: Indexer key cache written and read by the custom ops.
        q_fp8: fp8 indexer queries, indexed by token.
        k: New indexer keys to quantize into the cache.
        weights: Per-token weights passed to the logits kernels.
        quant_block_size: Block size for key quantization.
        scale_fmt: Scale format forwarded to the quantize-and-cache op.
        topk_tokens: Number of indices kept per row; must be 2048.
        head_dim: Indexer key width.
        max_model_len: Forwarded to the paged decode logits kernel.
        total_seq_lens: Only used on the dummy-run path.
        topk_indices_buffer: Preallocated output buffer, filled in place.

    Returns:
        ``topk_indices_buffer`` (the same tensor, mutated in place).
    """
    # careful! this will be None in dummy run
    attn_metadata = get_forward_context().attn_metadata
    # assert isinstance(attn_metadata, dict)
    if not isinstance(attn_metadata, dict):
        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )
    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    ops.indexer_k_quant_and_cache(
        k,
        kv_cache,
        slot_mapping,
        quant_block_size,
        scale_fmt,
    )

    # The buffer is typed Optional but is indexed unconditionally below;
    # fail with a clear message instead of an opaque TypeError on None.
    assert topk_indices_buffer is not None, (
        "sparse_attn_indexer requires a preallocated topk_indices_buffer"
    )
    topk_indices_buffer[: hidden_states.shape[0]] = -1
    if has_prefill:
        prefill_metadata = attn_metadata.prefill
        for chunk in prefill_metadata.chunks:
            # Gather this chunk's cached keys: fp8 values plus one 4-byte
            # (float32) scale per token.
            k_fp8 = torch.empty(
                [chunk.total_seq_lens, head_dim],
                device=k.device,
                dtype=torch.float8_e4m3fn,
            )
            k_scale = torch.empty(
                [chunk.total_seq_lens, 4],
                device=k.device,
                dtype=torch.uint8,
            )
            ops.cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_fp8,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
            )
            logits = fp8_mqa_logits(
                q_fp8[chunk.token_start : chunk.token_end],
                (k_fp8, k_scale.view(torch.float32)),
                weights[chunk.token_start : chunk.token_end],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )
            num_rows = logits.shape[0]
            assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
            # View into the output buffer; the kernel writes it in place.
            topk_indices = topk_indices_buffer[
                chunk.token_start : chunk.token_end, :topk_tokens
            ]
            torch.ops._C.top_k_per_row(
                logits,
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
                topk_indices,
                num_rows,
                logits.stride(0),
                logits.stride(1),
            )

    if has_decode:
        decode_metadata = attn_metadata.decode
        # kv_cache size requirement [num_block, block_size, n_head, head_dim],
        # we only have [num_block, block_size, head_dim],
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # pad in edge case where we have short chunked prefill length <
            # decode_threshold since we unstrictly split
            # prefill and decode by decode_threshold
            # (currently set to 1 + speculative tokens)
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens
            )
        else:
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:]
            )
        # TODO: move and optimize below logic with triton kernels
        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n
        logits = fp8_paged_mqa_logits(
            padded_q_fp8_decode_tokens,
            kv_cache,
            weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            decode_metadata.schedule_metadata,
            max_model_len=max_model_len,
        )
        num_rows = logits.shape[0]
        assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
        topk_indices = topk_indices_buffer[:num_decode_tokens, :topk_tokens]

        torch.ops._C.top_k_per_row_decode(
            logits,
            next_n,
            decode_metadata.seq_lens,
            topk_indices,
            num_rows,
            logits.stride(0),
            logits.stride(1),
        )
        if decode_metadata.requires_padding:
            # if padded, we need to unpack
            # the topk indices removing padded tokens
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens,
            )
            topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
                topk_indices
            )

    return topk_indices_buffer


def sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
750
    scale_fmt: str | None,
751
752
753
754
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
755
    topk_indices_buffer: torch.Tensor | None,
756
757
758
759
) -> torch.Tensor:
    # profile run
    # NOTE(Chen): create the max possible flattened_kv. So that
    # profile_run can get correct memory usage.
760
761
762
763
    _flattened_kv = torch.empty(
        [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
    )
    _k_fp8 = _flattened_kv[..., :head_dim].view(torch.float8_e4m3fn).contiguous()
764
765
766
767
768
769
770
771
772
773
774
775
776
777
    _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
    return topk_indices_buffer


# Register the indexer as a custom op. The op writes its result into
# `topk_indices_buffer`, so that argument is declared as mutated.
# `sparse_attn_indexer_fake` is the fake implementation used when tracing or
# profiling.
direct_register_custom_op(
    op_name="sparse_attn_indexer",
    op_func=sparse_attn_indexer,
    fake_impl=sparse_attn_indexer_fake,
    mutates_args=["topk_indices_buffer"],
    dispatch_key=current_platform.dispatch_key,
)


class Indexer(nn.Module):
    """Top-k token indexer used by sparse MLA (configs that define ``index_topk``).

    A lightweight multi-head query (``index_n_heads`` heads of
    ``index_head_dim``) is built from the query LoRA activations. A single
    shared key head is built from the hidden states. Both are rotated on
    their leading ``qk_rope_head_dim`` channels. The query is quantized to
    fp8, and everything is handed to the ``sparse_attn_indexer`` custom op,
    which fills ``topk_indices_buffer``. None of the projections are
    tensor-parallel sharded.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        q_lora_rank: int,
        quant_config: QuantizationConfig | None,
        cache_config: CacheConfig | None,
        topk_indices_buffer: torch.Tensor | None,
        prefix: str = "",
    ):
        super().__init__()
        self.vllm_config = vllm_config
        self.config = config
        self.topk_tokens = config.index_topk
        self.n_head = config.index_n_heads  # 64
        self.head_dim = config.index_head_dim  # 128
        self.rope_dim = config.qk_rope_head_dim  # 64
        self.q_lora_rank = q_lora_rank  # 1536

        # Replicated (not TP-sharded) projections. Registration order is
        # wq_b, wk, k_norm, weights_proj.
        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.n_head * self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq_b",
        )
        self.wk = ReplicatedLinear(
            hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wk",
        )
        self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
        # Per-head mixing weights; kept unquantized.
        self.weights_proj = ReplicatedLinear(
            hidden_size,
            self.n_head,
            quant_config=None,
            prefix=f"{prefix}.weights_proj",
        )
        self.softmax_scale = self.head_dim**-0.5

        self.scale_fmt = "ue8m0"
        self.quant_block_size = 128  # TODO: get from config
        self.topk_indices_buffer = topk_indices_buffer

        # NOTE: (zyongye) we use fp8 naive cache,
        #       where we store value in fp8 and scale in fp32
        #       per self.quant_block_size element.
        # Each cache row is raw bytes: head_dim fp8 values, then 4 bytes of
        # fp32 scale for every quant_block_size values.
        scale_bytes_per_row = (self.head_dim // self.quant_block_size) * 4
        self.k_cache = DeepseekV32IndexerCache(
            head_dim=self.head_dim + scale_bytes_per_row,
            dtype=torch.uint8,
            prefix=f"{prefix}.k_cache",
            cache_config=cache_config,
        )
        self.max_model_len = vllm_config.model_config.max_model_len
        self.prefix = prefix

        from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size

        self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)

    def forward(
        self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb
    ) -> torch.Tensor:
        """Compute the top-k token indices for the current batch.

        Args:
            hidden_states: layer input, fed to ``wk`` and ``weights_proj``.
            qr: query LoRA activations (input of ``wq_b``).
            positions: token positions forwarded to ``rotary_emb``.
            rotary_emb: callable invoked as ``rotary_emb(positions, q, k)``.

        Returns:
            The tensor returned by ``torch.ops.vllm.sparse_attn_indexer``.
        """
        split_sizes = [self.rope_dim, self.head_dim - self.rope_dim]

        # Query: (tokens, n_head, head_dim), split into rotary / pass-through.
        query, _ = self.wq_b(qr)
        query = query.view(-1, self.n_head, self.head_dim)
        query_rot, query_pass = torch.split(query, split_sizes, dim=-1)

        # Key: a single normalized head of width head_dim.
        key, _ = self.wk(hidden_states)
        key = self.k_norm(key)
        key_rot, key_pass = torch.split(key, split_sizes, dim=-1)

        # rotary_emb gets a head axis on the key, which is removed again below.
        query_rot, key_rot = rotary_emb(positions, query_rot, key_rot.unsqueeze(1))
        query = torch.cat([query_rot, query_pass], dim=-1)
        key = torch.cat([key_rot.squeeze(1), key_pass], dim=-1)

        # Only the query is quantized here; k quantization is fused with
        # cache insertion.
        query_fp8, query_scale = per_token_group_quant_fp8(
            query.view(-1, self.head_dim),
            self.quant_block_size,
            column_major_scales=False,
            use_ue8m0=self.scale_fmt is not None,
        )
        query_fp8 = query_fp8.view(-1, self.n_head, self.head_dim)
        query_scale = query_scale.view(-1, self.n_head, 1)

        # Fold the query quant scale, softmax scale and 1/sqrt(n_head) into the
        # per-head weights. Keep the left-to-right multiplication order.
        head_weights, _ = self.weights_proj(hidden_states)
        head_weights = (
            head_weights.unsqueeze(-1)
            * query_scale
            * self.softmax_scale
            * self.n_head**-0.5
        ).squeeze(-1)

        return torch.ops.vllm.sparse_attn_indexer(
            hidden_states,
            self.k_cache.prefix,
            self.k_cache.kv_cache[0],
            query_fp8,
            key,
            head_weights,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )


891
892
893
894
class DeepseekV2MLAAttention(nn.Module):
    """Multi-head Latent Attention (MLA) block for DeepseekV2/V3.

    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

    This module only builds the projections, norms, rotary embedding and, for
    configs that define ``index_topk``, the sparse ``Indexer``. The attention
    computation itself is delegated to ``MultiHeadLatentAttentionWrapper``.

    For more info see MLACommonImpl in:
    vllm/v1/attention/backends/mla/utils.py
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        topk_indices_buffer: torch.Tensor | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        # Heads are split evenly across tensor-parallel ranks.
        self.num_heads = num_heads
        tp_world = get_tensor_model_parallel_world_size()
        assert num_heads % tp_world == 0
        self.num_local_heads = num_heads // tp_world

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        has_q_lora = self.q_lora_rank is not None
        kv_a_width = self.kv_lora_rank + self.qk_rope_head_dim
        q_out_width = self.num_heads * self.qk_head_dim

        if has_q_lora:
            # Low-rank query: q_a and kv_a are one fused, non-TP-sharded
            # matmul, followed by the q norm and q_b projection.
            self.fused_qkv_a_proj = MergedColumnParallelLinear(
                self.hidden_size,
                [self.q_lora_rank, kv_a_width],
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.fused_qkv_a_proj",
                disable_tp=True,
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                self.q_lora_rank,
                q_out_width,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
        else:
            # Full-rank query: a separate replicated kv_a projection plus a
            # direct q projection.
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                kv_a_width,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.kv_a_proj_with_mqa",
            )
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                q_out_width,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )

        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # NOTE: this writes into the caller's rope_scaling dict in place, so
        # get_rope below builds the DeepSeek YaRN variant.
        if rope_scaling:
            rope_scaling["rope_type"] = "deepseek_yarn"
        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
        )
        if rope_scaling:
            # YaRN also rescales the attention softmax scale by mscale twice.
            mscale = yarn_get_mscale(
                rope_scaling["factor"],
                float(rope_scaling.get("mscale_all_dim", False)),
            )
            self.scaling = self.scaling * mscale * mscale

        # Sparse (V3.2-style) checkpoints are detected by `index_topk`.
        self.is_v32 = hasattr(config, "index_topk")
        self.indexer = (
            Indexer(
                vllm_config,
                config,
                hidden_size,
                q_lora_rank,
                quant_config,
                cache_config,
                topk_indices_buffer,
                f"{prefix}.indexer",
            )
            if self.is_v32
            else None
        )

        mla_modules = MLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            rotary_emb=self.rotary_emb,
            o_proj=self.o_proj,
            fused_qkv_a_proj=self.fused_qkv_a_proj if has_q_lora else None,
            kv_a_proj_with_mqa=None if has_q_lora else self.kv_a_proj_with_mqa,
            q_a_layernorm=self.q_a_layernorm if has_q_lora else None,
            q_b_proj=self.q_b_proj if has_q_lora else None,
            q_proj=None if has_q_lora else self.q_proj,
            indexer=self.indexer,
            is_sparse=self.is_v32,
            topk_indices_buffer=topk_indices_buffer,
        )

        self.mla_attn = MultiHeadLatentAttentionWrapper(
            self.hidden_size,
            self.num_local_heads,
            self.scaling,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
            self.q_lora_rank,
            self.kv_lora_rank,
            mla_modules,
            cache_config,
            quant_config,
            prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Delegate the whole attention computation to the MLA wrapper."""
        return self.mla_attn(positions, hidden_states)
1061
1062


wangding zeng's avatar
wangding zeng committed
1063
class DeepseekV2DecoderLayer(nn.Module):
    """One decoder block: pre-norm attention followed by a pre-norm MLP.

    The attention flavour (plain, MHA-style `DeepseekAttention`, or MLA) and
    the MLP flavour (MoE or dense) are chosen from the config and this
    layer's index.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str,
        config: DeepseekV2Config | None = None,
        topk_indices_buffer: torch.Tensor | None = None,
    ) -> None:
        super().__init__()

        if config is None:
            config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        moe_layer_freq = getattr(config, "moe_layer_freq", 1)

        # DecoderLayers are created with `make_layers`, which passes a prefix
        # ending in the layer's index.
        layer_idx = int(prefix.rsplit(".", 1)[-1])
        self.layer_idx = layer_idx

        # MLA-specific dimensions; absent fields default to 0.
        qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
        qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
        v_head_dim = getattr(config, "v_head_dim", 0)
        kv_lora_rank = getattr(config, "kv_lora_rank", 0)
        use_mha = config.model_type == "deepseek" or (
            qk_nope_head_dim == 0 and qk_rope_head_dim == 0
        )

        if use_mha:
            attn_cls = DeepseekAttention
        else:
            attn_cls = (
                DeepseekV2MLAAttention if model_config.use_mla else DeepseekV2Attention
            )
        self.self_attn = attn_cls(
            vllm_config=vllm_config,
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            v_head_dim=v_head_dim,
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            topk_indices_buffer=topk_indices_buffer,
        )

        # The first `first_k_dense_replace` layers stay dense. After that,
        # every `moe_layer_freq`-th layer uses routed experts.
        is_moe_layer = (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % moe_layer_freq == 0
        )
        if is_moe_layer:
            self.mlp = DeepseekV2MoE(
                config=config,
                parallel_config=parallel_config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run attention then the MLP; returns ``(hidden_states, residual)``.

        When ``residual`` is ``None`` (first layer of the stack), it is seeded
        with a copy of the input. Otherwise the two-argument RMSNorm call
        returns both the normalized states and the updated residual.
        """
        # --- attention sub-block ---
        if residual is not None:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        else:
            residual = hidden_states.clone()
            hidden_states = self.input_layernorm(hidden_states)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        if hidden_states.dtype == torch.float16 and not isinstance(
            self.self_attn, DeepseekAttention
        ):
            # Fix FP16 overflow: scale both hidden_states and residual down
            # before rmsnorm; the rmsnorm result is unaffected by the scale.
            hidden_states *= 1.0 / self.routed_scaling_factor
            if self.layer_idx == 0:
                # The residual is shared by all layers, so it is only scaled
                # on the first layer.
                residual *= 1.0 / self.routed_scaling_factor

        # --- MLP sub-block ---
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if hidden_states.dtype == torch.float16 and isinstance(
            self.mlp, DeepseekV2MLP
        ):
            # Fix FP16 overflow: the dense MLP output feeds the next layer's
            # input_layernorm. DeepseekV2MoE does this scaling in its own
            # forward.
            hidden_states *= 1.0 / self.routed_scaling_factor

        return hidden_states, residual


1194
@support_torch_compile
class DeepseekV2Model(nn.Module):
    """Decoder stack: embeddings, the local slice of decoder layers, final norm.

    With pipeline parallelism, only the first rank owns the embedding and
    only the last rank owns the final norm; other ranks hold
    ``PPMissingLayer`` placeholders.
    """

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.device = current_platform.device_type
        self.vocab_size = config.vocab_size

        # Sparse (index_topk) models get one int32 buffer of shape
        # (max_num_batched_tokens, index_topk). Every decoder layer receives
        # the same buffer.
        self.is_v32 = hasattr(config, "index_topk")
        topk_indices_buffer = None
        if self.is_v32:
            topk_indices_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                config.index_topk,
                dtype=torch.int32,
                device=self.device,
            )

        pp_group = get_pp_group()
        if pp_group.is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()

        # `make_layers` calls the factory with the keyword `prefix`.
        def build_layer(prefix: str):
            return DeepseekV2DecoderLayer(
                vllm_config, prefix, topk_indices_buffer=topk_indices_buffer
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers",
        )

        if pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's layers.

        Returns the normalized hidden states on the last PP rank. On any other
        rank it returns ``IntermediateTensors`` for the next rank.
        """
        pp_group = get_pp_group()
        if pp_group.is_first_rank:
            hidden_states = (
                inputs_embeds
                if inputs_embeds is not None
                else self.get_input_embeddings(input_ids)
            )
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not pp_group.is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

1276

1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
class DeepseekV2MixtureOfExperts(MixtureOfExperts):
    """Mixin that records MoE/expert-parallel counts for DeepseekV2 models."""

    # List of MoE MLP layers in the model.
    moe_mlp_layers: list[DeepseekV2MoE]

    def extract_moe_parameters(self, example_moe: DeepseekV2MoE | None):
        """Copy expert counts from a representative MoE layer.

        If ``example_moe`` is ``None`` (no MoE layer was found), every count
        is zeroed and a warning is logged.
        """
        if example_moe is None:
            for counter in (
                "num_moe_layers",
                "num_expert_groups",
                "num_logical_experts",
                "num_physical_experts",
                "num_local_physical_experts",
                "num_routed_experts",
                "num_shared_experts",
                "num_redundant_experts",
            ):
                setattr(self, counter, 0)
            logger.warning("DeepSeekV2: No DeepseekV2MoE layer found in model.layers.")
            return

        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_shared_experts = example_moe.n_shared_experts
        self.num_redundant_experts = example_moe.n_redundant_experts

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Apply a new physical-expert layout to the model and every MoE layer.

        The per-rank local expert count must not change; the redundant count
        is recomputed as physical minus logical experts.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        redundant = num_physical_experts - self.num_logical_experts
        self.num_redundant_experts = redundant
        for moe_layer in self.moe_mlp_layers:
            moe_layer.n_local_physical_experts = num_local_physical_experts
            moe_layer.n_physical_experts = num_physical_experts
            moe_layer.n_redundant_experts = redundant
            moe_layer.experts.update_expert_map()


class DeepseekV2ForCausalLM(
    nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA
):
1321
1322
1323
    # Fused module name -> checkpoint shard names. Class-level on purpose:
    # __init__ adds entries to this same dict in place.
    packed_modules_mapping = {"gate_up_proj": ["gate_proj", "up_proj"]}
1324
1325
1326
1327
1328
1329
1330

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the decoder stack, LM head and MoE bookkeeping."""
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        nope_dim = getattr(config, "qk_nope_head_dim", 0)
        rope_dim = getattr(config, "qk_rope_head_dim", 0)
        self.use_mha = config.model_type == "deepseek" or (
            nope_dim == 0 and rope_dim == 0
        )

        # `packed_modules_mapping` needs to be modified before
        # initializing DeepseekV2Model, as it is passed inplace to
        # quantization config init and may be used to select the
        # quant_method for relevant layers during initialization.
        # (The item assignments below therefore mutate the class-level dict.)
        if self.use_mha:
            self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]

        self.fuse_qkv_a_proj = getattr(config, "q_lora_rank", None) is not None
        if self.fuse_qkv_a_proj:
            self.packed_modules_mapping["fused_qkv_a_proj"] = [
                "q_a_proj",
                "kv_a_proj_with_mqa",
            ]

        self.model = DeepseekV2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        # Only the last pipeline rank owns a real LM head.
        self.lm_head = (
            ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
            if get_pp_group().is_last_rank
            else PPMissingLayer()
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors

        # Set MoE hyperparameters
        self.num_moe_layers = (
            self.config.num_hidden_layers - self.config.first_k_dense_replace
        )
        self.set_moe_parameters()

    def set_moe_parameters(self):
        """Collect this rank's MoE layers and derive the expert counts from them."""
        self.expert_weights = []
        self.num_expert_groups = getattr(self.config, "n_group", 1)

        self.moe_layers = []
        self.moe_mlp_layers = []
        example_moe = None
        for decoder_layer in self.model.layers:
            # Layers owned by other pipeline ranks are placeholders.
            if isinstance(decoder_layer, PPMissingLayer):
                continue
            assert isinstance(decoder_layer, DeepseekV2DecoderLayer)
            mlp = decoder_layer.mlp
            if not isinstance(mlp, DeepseekV2MoE):
                continue
            # Overwritten on each hit so the last MoE layer wins; the first
            # layers may be dense.
            example_moe = mlp
            self.moe_mlp_layers.append(mlp)
            self.moe_layers.append(mlp.experts)

        self.extract_moe_parameters(example_moe)
1396

1397
1398
1399
1400
1401
1402
1403
    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Forward the embedding lookup to the inner decoder model."""
        embed = self.model.get_input_embeddings
        return embed(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack; logits are produced by ``compute_logits``."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the logits processor."""
        return self.logits_processor(self.lm_head, hidden_states)

1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the checkpoint-to-parameter mapping for routed expert weights.

        Each entry is ``(param_name, weight_name, expert_id, shard_id)``,
        covering weights, fp8 weight scales and fp8 activation scales. No
        redundant experts are included.
        """
        routed_experts = self.config.n_routed_experts
        return SharedFusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_up_proj_name="up_proj",
            ckpt_down_proj_name="down_proj",
            num_experts=routed_experts,
            num_redundant_experts=0,
        )

1430
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
wangding zeng's avatar
wangding zeng committed
1431
1432
1433
1434
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
1435
1436
        ]
        mla_params_mapping = [
1437
1438
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
wangding zeng's avatar
wangding zeng committed
1439
        ]
1440
1441
1442
1443
1444
1445
1446
1447
1448
        mha_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        if self.use_mha:
            stacked_params_mapping.extend(mha_params_mapping)
        else:
            stacked_params_mapping.extend(mla_params_mapping)
wangding zeng's avatar
wangding zeng committed
1449

1450
1451
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
1452
        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
1453
1454
1455
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
1456
1457
1458
1459
1460
1461
            num_experts=self.config.n_routed_experts
            + (
                self.config.n_shared_experts
                if is_rocm_aiter_fusion_shared_expert_enabled()
                else 0
            ),
1462
1463
            num_redundant_experts=self.num_redundant_experts,
        )
1464

wangding zeng's avatar
wangding zeng committed
1465
        params_dict = dict(self.named_parameters())
1466
        loaded_params: set[str] = set()
wangding zeng's avatar
wangding zeng committed
1467
        for name, loaded_weight in weights:
1468
1469
1470
            if "rotary_emb.inv_freq" in name:
                continue

1471
1472
1473
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model
1474

1475
1476
1477
1478
1479
            is_fuse_shared_experts_layer = (
                is_rocm_aiter_fusion_shared_expert_enabled()
                and ("mlp.shared_experts" in name)
            )

1480
            for param_name, weight_name, shard_id in stacked_params_mapping:
1481
                # Skip non-stacked layers and experts (experts handled below).
wangding zeng's avatar
wangding zeng committed
1482
1483
                if weight_name not in name:
                    continue
1484
1485
1486
1487
1488
1489
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
1490
                if ("mlp.experts." in name) and name not in params_dict:
1491
                    continue
1492
1493
                if is_fuse_shared_experts_layer:
                    continue
1494
                name_mapped = name.replace(weight_name, param_name)
1495
1496
1497

                # QKV fusion is optional, fall back to normal
                # weight loading if it's not enabled
1498
                # if go with fusion option, then update name
1499
1500
1501
                if (
                    param_name == "fused_qkv_a_proj"
                ) and name_mapped not in params_dict:
1502
                    continue
1503
1504
                else:
                    name = name_mapped
wangding zeng's avatar
wangding zeng committed
1505
1506
1507
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
1508
1509
1510
1511

                if is_pp_missing_parameter(name, self):
                    continue

wangding zeng's avatar
wangding zeng committed
1512
1513
1514
1515
1516
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
1517
                is_expert_weight = False
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537

                # Special handling: when AITER fusion_shared_experts is enabled,
                # checkpoints may provide a single widened shared_experts tensor
                # without explicit expert indices
                # (e.g. ...mlp.shared_experts.gate_proj.weight).
                # For models with multiple shared experts, split that tensor
                # evenly into per-shared-expert slices and load them into
                # appended expert slots mlp.experts.{n_routed_experts + j}.*
                # accordingly.
                num_chunks = 1
                if is_fuse_shared_experts_layer:
                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
                    # Determine split axis based on op type
                    # gate/up: ColumnParallel → split along dim 0
                    # down: RowParallel → split along dim 1
                    split_dim = 1 if "down_proj.weight" in name else 0
                    total = loaded_weight.shape[split_dim]
                    assert total % num_chunks == 0, (
                        f"Shared expert weight dim {total} "
                        f"not divisible by num_chunks {num_chunks}"
1538
                    )
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
                    chunk_size = total // num_chunks

                for j in range(num_chunks):
                    chunk_name = name
                    weight_to_load = loaded_weight

                    if is_fuse_shared_experts_layer:
                        if split_dim == 0:
                            weight_to_load = loaded_weight[
                                j * chunk_size : (j + 1) * chunk_size, :
                            ]
                        else:
                            weight_to_load = loaded_weight[
                                :, j * chunk_size : (j + 1) * chunk_size
                            ]
                        # Synthesize an expert-style name so expert mapping
                        # can route it
                        chunk_name = name.replace(
                            "mlp.shared_experts",
                            f"mlp.experts.{self.config.n_routed_experts + j}",
                        )

                    # Use expert_params_mapping to locate the destination
                    # param and delegate to its expert-aware weight_loader
                    # with expert_id.
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in chunk_name:
                            continue

                        # Anyway, this is an expert weight and should not be
                        # attempted to load as other weights later
                        is_expert_weight = True

                        # Do not modify `name` since the loop may continue here
                        # Instead, create a new variable
                        name_mapped = chunk_name.replace(weight_name, param_name)

                        if is_pp_missing_parameter(name_mapped, self):
                            continue

                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or
                        # not here since otherwise we may skip experts with
                        # other available replicas.
                        weight_loader = typing.cast(
                            Callable[..., bool], param.weight_loader
                        )
                        success = weight_loader(
                            param,
                            weight_to_load,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                            return_success=True,
                        )
                        if success:
                            if not is_fuse_shared_experts_layer:
                                name = name_mapped
                            else:
                                loaded_params.add(name_mapped)
                            break
                    else:
                        if is_expert_weight:
                            # We've checked that this is an expert weight
                            # However it's not mapped locally to this rank
                            # So we simply skip it
                            continue

                        # Skip loading extra bias for GPTQ models.
                        if name.endswith(".bias") and name not in params_dict:
                            continue

                        # Remapping the name of FP8 kv-scale.
                        name = maybe_remap_kv_scale_name(name, params_dict)
                        if name is None:
                            continue

                        if is_pp_missing_parameter(name, self):
                            continue

                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
            if not is_fuse_shared_experts_layer:
                loaded_params.add(name)
1627

1628
        return loaded_params
1629
1630


1631
1632
1633
1634
class DeepseekForCausalLM(DeepseekV2ForCausalLM):
    """Name-only subclass of ``DeepseekV2ForCausalLM``.

    It adds no attributes or overrides, so the ``DeepseekForCausalLM``
    class name resolves to exactly the DeepseekV2 implementation.
    """


1635
1636
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """Name-only subclass of ``DeepseekV2ForCausalLM``.

    It adds no attributes or overrides, so the ``DeepseekV3ForCausalLM``
    class name resolves to exactly the DeepseekV2 implementation.
    """
1637
1638


1639
1640
# Compatibility with
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py
1641
def get_spec_layer_idx_from_weight_name(
1642
1643
    config: DeepseekV2Config | DeepseekV3Config, weight_name: str
) -> int | None:
1644
1645
1646
1647
    if (
        hasattr(config, "num_nextn_predict_layers")
        and config.num_nextn_predict_layers > 0
    ):
1648
1649
        layer_idx = config.num_hidden_layers
        for i in range(config.num_nextn_predict_layers):
1650
            if weight_name.startswith(f"model.layers.{layer_idx + i}."):
1651
1652
                return layer_idx + i
    return None