"tests/models/quantization/test_gguf.py" did not exist on "9597a095f2c02670b44f5973635ce4b9852e8eab"
deepseek_v2.py 63.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

wangding zeng's avatar
wangding zeng committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
25
"""Inference-only DeepseekV2/DeepseekV3 model."""
26

27
28
import typing
from collections.abc import Callable, Iterable
29
from itertools import islice
wangding zeng's avatar
wangding zeng committed
30
31
32

import torch
from torch import nn
33
from transformers import DeepseekV2Config, DeepseekV3Config
wangding zeng's avatar
wangding zeng committed
34

35
import vllm._custom_ops as ops
36
from vllm._aiter_ops import rocm_aiter_ops
37
from vllm.compilation.decorators import support_torch_compile
38
39
40
41
42
43
44
45
from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
46
from vllm.logger import init_logger
wangding zeng's avatar
wangding zeng committed
47
from vllm.model_executor.layers.activation import SiluAndMul
48
from vllm.model_executor.layers.attention import Attention
49
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
50
from vllm.model_executor.layers.fused_moe import (
51
    FusedMoE,
52
53
    GateLinear,
    RoutingMethodType,
54
    fused_moe_make_expert_params_mapping,
55
)
56
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
57
58
59
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
60
    QKVParallelLinear,
61
62
63
    ReplicatedLinear,
    RowParallelLinear,
)
wangding zeng's avatar
wangding zeng committed
64
from vllm.model_executor.layers.logits_processor import LogitsProcessor
65
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
66
from vllm.model_executor.layers.quantization import QuantizationConfig
67
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
68
69
    per_token_group_quant_fp8,
)
70
71
72
73
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape,
    scaled_dequantize,
)
wangding zeng's avatar
wangding zeng committed
74
from vllm.model_executor.layers.rotary_embedding import get_rope
75
76
77
from vllm.model_executor.layers.sparse_attn_indexer import (
    SparseAttnIndexer,
)
wangding zeng's avatar
wangding zeng committed
78
from vllm.model_executor.layers.vocab_parallel_embedding import (
79
80
81
    ParallelLMHead,
    VocabParallelEmbedding,
)
82
from vllm.model_executor.model_loader.weight_utils import (
83
84
85
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
86
from vllm.model_executor.models.utils import sequence_parallel_chunk
87
from vllm.platforms import current_platform
88
from vllm.sequence import IntermediateTensors
89
from vllm.utils.torch_utils import direct_register_custom_op
90
from vllm.v1.attention.backend import AttentionBackend
91
92
93
from vllm.v1.attention.backends.mla.indexer import (
    DeepseekV32IndexerBackend,
)
94
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
wangding zeng's avatar
wangding zeng committed
95

96
97
98
99
100
101
102
from .interfaces import (
    MixtureOfExperts,
    SupportsEagle,
    SupportsEagle3,
    SupportsLoRA,
    SupportsPP,
)
103
104
105
106
107
108
109
from .utils import (
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
110

111
112
# Module-level logger for this model file, created via vLLM's init_logger.
logger = init_logger(__name__)

wangding zeng's avatar
wangding zeng committed
113

114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
class DeepseekAttention(nn.Module):
    """Normal MHA implementation used by Deepseek v1.

    Q, K and V come from a single fused ``QKVParallelLinear``. Query heads are
    sharded evenly across tensor-parallel ranks. KV heads are partitioned when
    there are at least as many of them as ranks, and replicated otherwise.
    Rotary embeddings are applied to Q and K before calling ``Attention``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        **kwargs,
    ) -> None:
        """
        Args:
            vllm_config: Global vLLM config (accepted for signature parity with
                the other attention classes; not read here).
            config: HF model config; ``num_key_value_heads`` and
                ``rope_parameters`` are read from it.
            hidden_size: Model hidden size.
            num_heads: Total number of query heads across all TP ranks.
            max_position_embeddings: Maximum position for the rotary embedding.
            cache_config: KV-cache configuration forwarded to ``Attention``.
            quant_config: Quantization config for the projections/attention.
            prefix: Module-name prefix used to name the sub-layers.
            **kwargs: Extra keyword arguments accepted for interface parity;
                ignored.
        """
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = config.num_key_value_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        # Pass a prefix like every other linear layer in this file, so that
        # name-based quantization settings (e.g. ignored layers) can match.
        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run fused QKV projection, RoPE, attention and the output projection.

        Args:
            positions: Token positions consumed by the rotary embedding.
            hidden_states: Input activations for the local tokens.

        Returns:
            The ``o_proj`` output for the same tokens.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The local Q slice is wider than K/V when KV heads are fewer than Q heads.
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output


wangding zeng's avatar
wangding zeng committed
195
196
197
198
199
200
class DeepseekV2MLP(nn.Module):
    """Feed-forward block: merged gate/up projection, ``SiluAndMul``, down projection.

    Used both as the dense MLP and as the shared-experts MLP inside the MoE.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        is_sequence_parallel: bool = False,
        prefix: str = "",
    ) -> None:
        """
        Args:
            hidden_size: Input/output feature size.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Quantization config for the linear layers.
            reduce_results: Forwarded to the down projection's
                ``RowParallelLinear`` (controls its final reduction).
            is_sequence_parallel: When True, TP is disabled on both projections
                (weights replicated, see comment below).
            prefix: Module-name prefix for the sub-layers.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()

        # If is_sequence_parallel, the input and output tensors are sharded
        # across the ranks within the tp_group. In this case the weights are
        # replicated and no collective ops are needed.
        # Otherwise we use standard TP with an allreduce at the end.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.down_proj",
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        """Apply gate/up projection, the SiluAndMul activation, then down projection."""
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class DeepseekV2MoE(nn.Module):
    """Mixture-of-experts block for DeepseekV2/V3.

    Routed experts are executed by ``FusedMoE``. The optional shared experts
    (a ``DeepseekV2MLP``) are handed to ``FusedMoE`` as well, which presumably
    combines their output with the routed output (TODO confirm in FusedMoE).
    With ROCm AITER fused shared experts enabled, no separate shared-experts
    module is built and ``n_shared_experts`` is passed to ``FusedMoE`` instead.
    """

    def __init__(
        self,
        config: DeepseekV2Config | DeepseekV3Config,
        parallel_config: ParallelConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """
        Args:
            config: HF model config supplying expert counts/sizes, routing
                options (``topk_method``, ``n_group``, ``topk_group``,
                ``scoring_func``, ``routed_scaling_factor``) and activation.
            parallel_config: Supplies sequence-parallel MoE and EPLB settings.
            quant_config: Quantization config for experts and shared experts.
            prefix: Module-name prefix for the sub-layers.

        Raises:
            ValueError: If ``config.hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

        self.ep_group = get_ep_group().device_group
        self.ep_rank = get_ep_group().rank_in_group
        self.ep_size = self.ep_group.size()
        self.n_routed_experts: int = config.n_routed_experts
        # May be None (checked below) for configs without shared experts.
        self.n_shared_experts: int | None = config.n_shared_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = GateLinear(
            config.hidden_size,
            config.n_routed_experts,
            prefix=f"{prefix}.gate",
        )
        # "noaux_tc" routing uses a learned per-expert bias (loaded from the
        # checkpoint); other routing methods have none.
        if getattr(config, "topk_method", None) == "noaux_tc":
            self.gate.e_score_correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts, dtype=torch.float32)
            )
        else:
            self.gate.e_score_correction_bias = None

        # Load balancing settings.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        # Contiguous [start, end) range of physical experts owned by this EP rank.
        self.physical_expert_start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_end = (
            self.physical_expert_start + self.n_local_physical_experts
        )

        self.is_rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
        self.is_fusion_moe_shared_experts_enabled = (
            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        )
        if config.n_shared_experts is None or self.is_fusion_moe_shared_experts_enabled:
            self.shared_experts = None
        else:
            # All shared experts are evaluated as one MLP of combined width.
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts

            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                is_sequence_parallel=self.is_sequence_parallel,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )

        self.experts = FusedMoE(
            shared_experts=self.shared_experts,
            gate=self.gate,
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=getattr(config, "n_group", 1),
            topk_group=getattr(config, "topk_group", 1),
            prefix=f"{prefix}.experts",
            scoring_func=getattr(config, "scoring_func", "softmax"),
            # aiter applies routed_scaling_factor internally
            routed_scaling_factor=self.routed_scaling_factor,
            apply_routed_scale_to_output=not self.is_rocm_aiter_moe_enabled,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
            n_shared_experts=config.n_shared_experts
            if self.is_fusion_moe_shared_experts_enabled
            else None,
        )

        # NOTE(rob): this is a hack until we finish off the PR for
        # merging TRTLLM kernels into the MK framework. Then we can
        # query the MonolithicMK for the expected router logits.
        # NOTE(dbari): Use BF16 if routing is not Deepseek, e.g. Mistral Large 3
        self.gate.set_out_dtype(
            torch.float32
            if self.experts.quant_method.is_monolithic
            and self.experts.routing_method_type == RoutingMethodType.DeepSeekV3
            else torch.bfloat16
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route tokens through the experts and return the combined output.

        Args:
            hidden_states: 2-D ``[num_tokens, hidden_dim]`` activations (the
                shape is unpacked as two dims below).

        Returns:
            Tensor of shape ``[num_tokens, hidden_dim]``.
        """
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

        # Chunk the hidden states so they aren't replicated across TP ranks.
        # This avoids duplicate computation in self.experts.
        # TODO: We can replace the all_reduce at the end of attn with a
        # reduce_scatter instead of chunking here.
        if self.is_sequence_parallel:
            hidden_states = sequence_parallel_chunk(hidden_states)

        if self.experts.is_internal_router:
            # The experts layer owns the router here. hidden_states is passed
            # in the router_logits slot, presumably so FusedMoE can compute the
            # logits itself (TODO confirm).
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=hidden_states
            )
        else:
            router_logits, _ = self.gate(hidden_states)
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )

        if self.is_sequence_parallel:
            final_hidden_states = tensor_model_parallel_all_gather(
                final_hidden_states, 0
            )
            # Drop any rows beyond the original token count (presumably padding
            # added by sequence_parallel_chunk).
            final_hidden_states = final_hidden_states[:num_tokens]

        return final_hidden_states.view(num_tokens, hidden_dim)


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN magnitude-scaling factor for a context-extension ratio.

    Args:
        scale: Context-extension factor (the rope ``factor``).
        mscale: Multiplier on the logarithmic term.

    Returns:
        ``1.0`` when ``scale <= 1`` (no extension); otherwise
        ``0.1 * mscale * ln(scale) + 1.0``.
    """
    import math

    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


390
391
392
393
394
395
396
397
398
399
def _get_llama_4_scaling(
    original_max_position_embeddings: int, scaling_beta: float, positions: torch.Tensor
) -> torch.Tensor:
    scaling = 1 + scaling_beta * torch.log(
        1 + torch.floor(positions / original_max_position_embeddings)
    )
    # Broadcast over num_heads and head_dim
    return scaling[..., None, None]


wangding zeng's avatar
wangding zeng committed
400
401
402
class DeepseekV2Attention(nn.Module):
    """DeepseekV2 latent attention computed with a standard ``Attention`` layer.

    Queries come from either a low-rank ``q_a_proj -> RMSNorm -> q_b_proj``
    path (when ``q_lora_rank`` is set) or a direct ``q_proj``. Keys and values
    are decompressed from a shared latent produced by ``kv_a_proj_with_mqa``
    via ``kv_b_proj``. The rotary part of the key comes straight from the
    latent and is shared by all heads. V is zero-padded to ``qk_head_dim`` so
    that Q, K and V share one head size inside ``Attention``; the padding is
    sliced off again before ``o_proj``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        topk_indices_buffer: torch.Tensor | None = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            vllm_config: Global vLLM config (not read here).
            config: HF model config; ``rms_norm_eps`` and ``rope_parameters``
                are read. NOTE: ``config.rope_parameters["rope_type"]`` is
                rewritten in place when it is not ``"default"``.
            hidden_size: Model hidden size.
            num_heads: Total attention heads across TP ranks.
            qk_nope_head_dim: Per-head Q/K width without rotary embedding.
            qk_rope_head_dim: Per-head Q/K width that receives rotary embedding.
            v_head_dim: Per-head value width.
            q_lora_rank: Rank of the low-rank query projection, or ``None`` to
                use a direct ``q_proj``.
            kv_lora_rank: Width of the compressed KV latent.
            max_position_embeddings: Maximum position for the rotary embedding.
            cache_config: KV-cache configuration forwarded to ``Attention``.
            quant_config: Quantization config for the projections/attention.
            topk_indices_buffer: Must be ``None``; sparse top-k indexing is not
                supported by this attention variant.
            prefix: Module-name prefix for the sub-layers.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings
        # Implicit string concatenation: the old backslash continuation inside
        # the literal embedded a run of spaces into the message.
        assert topk_indices_buffer is None, (
            "topk_indices_buffer is not supported for DeepseekV2Attention"
        )

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_a_proj",
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )

        # Produces [kv latent | rotary key part] in one replicated GEMM.
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa",
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # NOTE: mutates the shared config dict in place, selecting the
        # DeepSeek-specific rope variant for any non-default rope type.
        if config.rope_parameters["rope_type"] != "default":
            config.rope_parameters["rope_type"] = (
                "deepseek_yarn"
                if config.rope_parameters.get("apply_yarn_scaling", True)
                else "deepseek_llama_scaling"
            )

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=False,
        )

        # With YaRN, the softmax scale is corrected by mscale squared (applied
        # once for q and once for k).
        if config.rope_parameters["rope_type"] == "deepseek_yarn":
            mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
            scaling_factor = config.rope_parameters["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.attn = Attention(
            self.num_local_heads,
            self.qk_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None,
    ) -> torch.Tensor:
        """Compute attention for the local heads and project the result back.

        Args:
            positions: Token positions for the rotary embedding.
            hidden_states: Input activations for the local tokens.
            llama_4_scaling: Optional multiplier applied in place to the full
                query after RoPE. It must broadcast against
                ``[num_tokens, num_local_heads, qk_head_dim]``.

        Returns:
            The ``o_proj`` output for the same tokens.
        """
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a)
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        # Single rotary key slice (head dim of size 1) shared by all heads.
        k_pe = latent_cache[:, :, self.kv_lora_rank :]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        # Broadcasts the shared rotary key across all local heads.
        k[..., self.qk_nope_head_dim :] = k_pe

        # Apply llama 4 scaling if provided
        if llama_4_scaling is not None:
            q *= llama_4_scaling

        # padding value to qk_head_dim for alignment
        v = torch.nn.functional.pad(
            v, [0, self.qk_head_dim - self.v_head_dim], value=0
        ).view(-1, self.num_local_heads * self.qk_head_dim)
        attn_output = self.attn(q, k, v)
        # Drop the zero-padded tail of each head before the output projection.
        attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[
            ..., : self.v_head_dim
        ].reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output


565
class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
    """KV-cache placeholder layer for the DeepSeek-V3.2 sparse-attention indexer.

    It owns no weights. It registers itself under ``prefix`` in the compilation
    config's static forward context, reports an ``MLAAttentionSpec`` with a
    single vector per token, and names ``DeepseekV32IndexerBackend`` as its
    attention backend.
    """

    def __init__(
        self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig
    ):
        """
        Args:
            head_dim: Per-token cache entry width (in elements of ``dtype``).
            dtype: Cache element dtype.
            prefix: Unique layer name used as the forward-context key.
            cache_config: Supplies ``block_size`` for the cache spec.

        Raises:
            ValueError: If another layer is already registered under ``prefix``.
        """
        super().__init__()
        # Empty placeholder; presumably replaced by the real cache tensor once
        # KV caches are allocated (TODO confirm).
        self.kv_cache = torch.tensor([])
        self.head_dim = head_dim
        self.prefix = prefix
        self.cache_config = cache_config
        self.dtype = dtype
        compilation_config = get_current_vllm_config().compilation_config
        if prefix in compilation_config.static_forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        compilation_config.static_forward_context[prefix] = self

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
        """Describe this layer's cache: one head of ``head_dim`` per token."""
        return MLAAttentionSpec(  # Only has one vector instead of K + V
            block_size=self.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
        )

    def forward(self): ...

    def get_attn_backend(self) -> AttentionBackend:
        """Return the attention backend class that manages this cache."""
        return DeepseekV32IndexerBackend


class Indexer(nn.Module):
    """Lightning-indexer module used by DeepSeek-V3.2 sparse attention.

    It builds ``n_head`` indexer queries from the query latent ``qr`` and one
    shared indexer key plus per-head weights from ``hidden_states``. It applies
    RoPE, quantizes the queries to FP8, and hands everything to
    ``SparseAttnIndexer``, which selects up to ``index_topk`` tokens. All
    weights are replicated; this module does no tensor parallelism.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        q_lora_rank: int,
        quant_config: QuantizationConfig | None,
        cache_config: CacheConfig | None,
        topk_indices_buffer: torch.Tensor | None,
        prefix: str = "",
    ):
        """
        Args:
            vllm_config: Global vLLM config (max model length, prefill buffer
                sizing).
            config: HF model config; ``index_topk``, ``index_n_heads``,
                ``index_head_dim`` and ``qk_rope_head_dim`` are read.
            hidden_size: Model hidden size (input width of the key projection).
            q_lora_rank: Width of the query latent ``qr``.
            quant_config: Quantization config for ``wq_b``.
            cache_config: Forwarded to the indexer key cache.
            topk_indices_buffer: Output buffer forwarded to ``SparseAttnIndexer``.
            prefix: Module-name prefix for the sub-layers.
        """
        super().__init__()
        self.vllm_config = vllm_config
        self.config = config
        self.quant_config = quant_config
        self.topk_tokens = config.index_topk
        self.n_head = config.index_n_heads  # e.g. 64
        self.head_dim = config.index_head_dim  # e.g. 128
        self.rope_dim = config.qk_rope_head_dim  # e.g. 64
        self.q_lora_rank = q_lora_rank  # e.g. 1536
        # no tensor parallel, just replicated
        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.head_dim * self.n_head,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq_b",
        )
        # Fused wk + weights_proj: single GEMM producing [head_dim + n_head].
        # FP8 wk weights are upcasted to BF16 during loading to maintain fusion.
        self.wk_weights_proj = MergedColumnParallelLinear(
            hidden_size,
            [self.head_dim, self.n_head],
            bias=False,
            quant_config=None,
            disable_tp=True,
            prefix=f"{prefix}.wk_weights_proj",
        )
        self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
        self.softmax_scale = self.head_dim**-0.5

        self.scale_fmt = "ue8m0"
        self.quant_block_size = 128  # TODO: get from config
        self.topk_indices_buffer = topk_indices_buffer

        # NOTE: (zyongye) we use fp8 naive cache,
        #       where we store value in fp8 and scale in fp32
        #       per self.quant_block_size element
        # Entry width in bytes: head_dim fp8 values plus 4 bytes per fp32 scale.
        self.k_cache = DeepseekV32IndexerCache(
            head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4,
            dtype=torch.uint8,
            prefix=f"{prefix}.k_cache",
            cache_config=cache_config,
        )
        self.max_model_len = vllm_config.model_config.max_model_len
        self.prefix = prefix
        from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size

        self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)
        self.indexer_op = SparseAttnIndexer(
            self.k_cache,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )

    def forward(
        self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb
    ) -> torch.Tensor:
        """Compute indexer queries/keys/weights and run the sparse indexer op.

        Args:
            hidden_states: Token activations (input to the fused wk/weights GEMM).
            qr: Query latent of width ``q_lora_rank`` (input to ``wq_b``).
            positions: Token positions for ``rotary_emb``.
            rotary_emb: Rotary embedding callable ``(positions, q, k) -> (q, k)``.

        Returns:
            Whatever ``SparseAttnIndexer`` returns for these inputs.
        """
        q, _ = self.wq_b(qr)
        q = q.view(-1, self.n_head, self.head_dim)
        # Indexer heads keep the rotary part first, then the non-rotary part.
        q_pe, q_nope = torch.split(
            q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
        )

        # Fused wk + weights_proj: one GEMM, then split
        kw, _ = self.wk_weights_proj(hidden_states)
        k = kw[:, : self.head_dim]
        weights = kw[:, self.head_dim :]

        k = self.k_norm(k)
        k_pe, k_nope = torch.split(
            k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
        )

        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))

        # Note: RoPE (NeoX) can introduce extra leading dimensions during compilation
        # so we need to reshape back to token-flattened shapes
        q_pe = q_pe.reshape(-1, self.n_head, self.rope_dim)
        k_pe = k_pe.reshape(-1, 1, self.rope_dim)

        # After the reshape, q_pe is [num_tokens, n_head, rope_dim].
        q = torch.cat([q_pe, q_nope], dim=-1)
        # After the reshape, k_pe is [num_tokens, 1, rope_dim]: a single key
        # head shared by all indexer heads (MQA).
        k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)

        # we only quant q here since k quant is fused with cache insertion
        q = q.view(-1, self.head_dim)
        q_fp8, q_scale = per_token_group_quant_fp8(
            q,
            self.quant_block_size,
            column_major_scales=False,
            use_ue8m0=self.scale_fmt is not None,
        )
        q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
        q_scale = q_scale.view(-1, self.n_head, 1)

        # Fold the q quantization scale, softmax scale and 1/sqrt(n_head) into
        # the per-head weights.
        weights = (
            weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
        )
        weights = weights.squeeze(-1)

        return self.indexer_op(hidden_states, q_fp8, k, weights)
713
714


715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
def _try_load_fp8_indexer_wk(name, tensor, buf, params_dict, loaded_params):
    """
    We fuse the WK and weights_proj projections, but in some checkpoints WK is stored
    in FP8 with a separate weight_scale_inv, while weights_proj is stored in BF16.
    Upcasting to BF16 during loading enables the fusion. This function loads the FP8 WK
    weights and scale, and when both are available, dequantizes to BF16 and stores into
    the fused wk_weights_proj.weight parameter.
    """
    if "indexer.wk." not in name or "wk_weights" in name:
        return False  # Weight is not an isolated WK weight for the indexer, ignore.
    is_weight = name.endswith(".weight") and tensor.dtype == torch.float8_e4m3fn
    is_scale = "weight_scale_inv" in name
    if not is_weight and not is_scale:
        return False  # WK is not in FP8 format, ignore.
    # Buffer this tensor (weight or scale) until both have arrived.
    layer_prefix = name.rsplit(".wk.", 1)[0]  # e.g. "model.layers.0.self_attn.indexer"
    entry = buf.setdefault(layer_prefix, {})
    entry["weight" if is_weight else "scale"] = tensor
    if "weight" not in entry or "scale" not in entry:
        return True  # still waiting for the other param

    # We have both weight and scale: dequantize FP8 to BF16.
    weight_fp8, scale_inv = entry["weight"], entry["scale"]
    del buf[layer_prefix]
    block_size = weight_fp8.shape[1] // scale_inv.shape[1]
    weight_bf16 = scaled_dequantize(
        weight_fp8,
        scale_inv,
        group_shape=GroupShape(block_size, block_size),
        out_dtype=torch.bfloat16,
    )

    # Load the dequantized weight into shard 0 of the fused buffer.
    fused_name = f"{layer_prefix}.wk_weights_proj.weight"
    param = params_dict[fused_name]
    param.weight_loader(param, weight_bf16, 0)
    loaded_params.add(fused_name)
    return True


def _min_latency_fused_qkv_a_proj_impl(
    input_: torch.Tensor,
    weight: torch.Tensor,
) -> torch.Tensor:
    """
    Compute ``input_ @ weight.T`` for the fused q_a / kv_a projection.

    Batches of 1..16 tokens go through the min-latency DSV3 fused-A GEMM.
    Any other size, including 0, goes through ``F.linear``. Because the choice
    depends on the runtime token count, which the torch.compile integration
    cannot dispatch on, this body is wrapped in a custom op.
    """
    num_tokens = input_.shape[0]
    if num_tokens <= 0 or num_tokens > 16:
        return torch.nn.functional.linear(input_, weight)

    out = torch.empty(
        num_tokens, weight.shape[0], dtype=torch.bfloat16, device=input_.device
    )
    # The kernel fills `out` in place and takes the transposed weight.
    ops.dsv3_fused_a_gemm(out, input_, weight.T)
    return out


def _min_latency_fused_qkv_a_proj_fake(
    input_: torch.Tensor,
    weight: torch.Tensor,
) -> torch.Tensor:
    """Shape-only stand-in for tracing: uninitialized ``(num_tokens, out_features)``."""
    rows, cols = input_.shape[0], weight.shape[0]
    return input_.new_empty(rows, cols)


# Exposed as torch.ops.vllm.min_latency_fused_qkv_a_proj. The fake impl only
# supplies the output shape, so tracing never runs the real kernel. The real
# impl writes only to a tensor it allocates itself, so no arguments are mutated.
direct_register_custom_op(
    op_name="min_latency_fused_qkv_a_proj",
    op_func=_min_latency_fused_qkv_a_proj_impl,
    fake_impl=_min_latency_fused_qkv_a_proj_fake,
    mutates_args=[],
)


class DeepSeekV2FusedQkvAProjLinear(MergedColumnParallelLinear):
    """
    Replicated (``disable_tp=True``) fused projection that produces the
    ``output_size`` shards (q_a and kv_a_with_mqa) with a single GEMM.

    When the weight is BF16 with the exact shape below, and the device is
    CUDA SM90 or the SM100 family, ``forward`` routes through the
    ``min_latency_fused_qkv_a_proj`` custom op. That op picks a low-latency
    kernel for small batches. Otherwise the regular linear forward is used.
    """

    # Weight shape (out_features, in_features) the fused A GEMM kernel is
    # specialized for (presumably the DeepSeek-V3 configuration).
    _MIN_LATENCY_OUT_FEATURES = 2112
    _MIN_LATENCY_IN_FEATURES = 7168

    def __init__(
        self,
        input_size: int,
        output_size: list[int],
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__(
            input_size,
            output_size,
            bias=False,
            quant_config=quant_config,
            disable_tp=True,
            prefix=prefix,
        )

        # Check if the DeepSeek V3 fused A GEMM kernel can be used.
        # This kernel supports PDL and is optimized for low batch size.
        # The hasattr guard covers the case where no plain `weight`
        # parameter was created.
        self._use_min_latency_gemm = (
            hasattr(self, "weight")
            and self.weight.dtype == torch.bfloat16
            and self.weight.shape[0] == self._MIN_LATENCY_OUT_FEATURES
            and self.weight.shape[1] == self._MIN_LATENCY_IN_FEATURES
            and current_platform.is_cuda()
            and (
                current_platform.is_device_capability(90)
                or current_platform.is_device_capability_family(100)
            )
        )

    def forward(
        self,
        input_,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]:
        """
        Project ``input_``.

        Returns the output alone, or ``(output, bias_or_None)`` when
        ``return_bias`` is set (bias only when ``skip_bias_add``).
        """
        if not self._use_min_latency_gemm:
            # Fallback to the standard forward method when
            # the fused A GEMM kernel cannot be used.
            return super().forward(input_)

        output = torch.ops.vllm.min_latency_fused_qkv_a_proj(input_, self.weight)
        if not self.return_bias:
            return output
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias


class DeepseekV2MLAAttention(nn.Module):
    """
    Multi-head Latent Attention (MLA) layer for DeepSeek-V2/V3 (and V3.2
    sparse attention when the config defines ``index_topk``).

    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

    For more info see MLACommonImpl in:
    vllm/v1/attention/backends/mla/utils.py
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        topk_indices_buffer: torch.Tensor | None = None,
        input_size: int | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.max_position_embeddings = max_position_embeddings

        # Use input_size for projection input dimensions if provided,
        # otherwise default to hidden_size (used in Eagle3 Deepseek with MLA)
        proj_input_size = input_size if input_size is not None else self.hidden_size

        has_q_lora = self.q_lora_rank is not None
        if has_q_lora:
            # Low-rank query path: q_a and kv_a share one fused GEMM.
            self.fused_qkv_a_proj = DeepSeekV2FusedQkvAProjLinear(
                proj_input_size,
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                quant_config=quant_config,
                prefix=f"{prefix}.fused_qkv_a_proj",
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                self.q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
        else:
            # Full-rank query projection; kv_a is its own replicated linear.
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                proj_input_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.kv_a_proj_with_mqa",
            )
            self.q_proj = ColumnParallelLinear(
                proj_input_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # NOTE: rewrites the shared HF config in place. Every instance re-applies
        # it, and the outcome is the same each time.
        if config.rope_parameters["rope_type"] != "default":
            config.rope_parameters["rope_type"] = (
                "deepseek_yarn"
                if config.rope_parameters.get("apply_yarn_scaling", True)
                else "deepseek_llama_scaling"
            )

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            max_position=max_position_embeddings,
            rope_parameters=config.rope_parameters,
            is_neox_style=False,
        )

        # YaRN also rescales the attention softmax scale by mscale^2.
        if config.rope_parameters["rope_type"] == "deepseek_yarn":
            mscale_all_dim = config.rope_parameters.get("mscale_all_dim", False)
            scaling_factor = config.rope_parameters["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        # V3.2 (sparse attention) configs are detected by `index_topk`.
        self.is_v32 = hasattr(config, "index_topk")

        if self.is_v32:
            self.indexer_rope_emb = get_rope(
                qk_rope_head_dim,
                max_position=max_position_embeddings,
                rope_parameters=config.rope_parameters,
                is_neox_style=not getattr(config, "indexer_rope_interleave", False),
            )
            self.indexer = Indexer(
                vllm_config,
                config,
                hidden_size,
                q_lora_rank,
                quant_config,
                cache_config,
                topk_indices_buffer,
                f"{prefix}.indexer",
            )
        else:
            self.indexer_rope_emb = None
            self.indexer = None

        mla_modules = MLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            rotary_emb=self.rotary_emb,
            o_proj=self.o_proj,
            fused_qkv_a_proj=self.fused_qkv_a_proj if has_q_lora else None,
            kv_a_proj_with_mqa=None if has_q_lora else self.kv_a_proj_with_mqa,
            q_a_layernorm=self.q_a_layernorm if has_q_lora else None,
            q_b_proj=self.q_b_proj if has_q_lora else None,
            q_proj=None if has_q_lora else self.q_proj,
            indexer=self.indexer,
            indexer_rotary_emb=self.indexer_rope_emb,
            is_sparse=self.is_v32,
            topk_indices_buffer=topk_indices_buffer,
        )

        self.mla_attn = MultiHeadLatentAttentionWrapper(
            self.hidden_size,
            self.num_local_heads,
            self.scaling,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
            self.q_lora_rank,
            self.kv_lora_rank,
            mla_modules,
            cache_config,
            quant_config,
            prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run MLA attention; all work is delegated to ``self.mla_attn``."""
        return self.mla_attn(positions, hidden_states, llama_4_scaling)


class DeepseekV2DecoderLayer(nn.Module):
    """
    One DeepSeek decoder layer: pre-norm self-attention followed by either a
    dense MLP or a MoE block, with the residual stream threaded through the
    fused RMSNorms.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str,
        config: DeepseekV2Config | None = None,
        topk_indices_buffer: torch.Tensor | None = None,
    ) -> None:
        super().__init__()

        if config is None:
            config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.hidden_size = config.hidden_size
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        moe_layer_freq = getattr(config, "moe_layer_freq", 1)
        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        layer_idx = int(prefix.split(sep=".")[-1])
        self.layer_idx = layer_idx

        # MLA attention specific fields; 0 when the config does not define them.
        qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
        qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
        v_head_dim = getattr(config, "v_head_dim", 0)
        kv_lora_rank = getattr(config, "kv_lora_rank", 0)
        # Plain multi-head attention for `deepseek` model_type configs or when
        # no MLA head dims are configured.
        use_mha = config.model_type == "deepseek" or all(
            dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
        )
        self.use_mha = use_mha

        if use_mha:
            attn_cls = DeepseekAttention
        elif model_config.use_mla:
            attn_cls = DeepseekV2MLAAttention
        else:
            attn_cls = DeepseekV2Attention
        self.self_attn = attn_cls(
            vllm_config=vllm_config,
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=qk_nope_head_dim,
            qk_rope_head_dim=qk_rope_head_dim,
            v_head_dim=v_head_dim,
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=kv_lora_rank,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            topk_indices_buffer=topk_indices_buffer,
        )

        # MoE from `first_k_dense_replace` onward, on every `moe_layer_freq`-th
        # layer; dense MLP otherwise.
        if (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % moe_layer_freq == 0
        ):
            self.mlp = DeepseekV2MoE(
                config=config,
                parallel_config=parallel_config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
        llama_4_scaling: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply attention and MLP/MoE.

        ``residual`` is None for the very first layer; in that case it is
        initialized from a copy of ``hidden_states``.

        Returns:
            ``(hidden_states, residual)`` to feed the next layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states.clone()
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        attn_kwargs = {
            "positions": positions,
            "hidden_states": hidden_states,
        }
        # Only the non-MHA attention classes accept llama_4_scaling.
        if not self.use_mha:
            attn_kwargs["llama_4_scaling"] = llama_4_scaling
        hidden_states = self.self_attn(**attn_kwargs)

        if (
            not isinstance(self.self_attn, DeepseekAttention)
            and hidden_states.dtype == torch.float16
        ):
            # Fix FP16 overflow
            # We scale both hidden_states and residual before
            # rmsnorm, and rmsnorm result would not affect by scale.
            hidden_states *= 1.0 / self.routed_scaling_factor
            if self.layer_idx == 0:
                # The residual is shared by all layers, we only scale it on
                # first layer.
                residual *= 1.0 / self.routed_scaling_factor

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
            # Fix FP16 overflow
            # Scaling the DeepseekV2MLP output, it is the input of
            # input_layernorm of next decoder layer.
            # The scaling of DeepseekV2MOE output would be done in the forward
            # of DeepseekV2MOE
            hidden_states *= 1.0 / self.routed_scaling_factor

        return hidden_states, residual


@support_torch_compile
class DeepseekV2Model(nn.Module):
    """
    DeepSeek-V2/V3 decoder stack: token embedding, decoder layers, final norm.

    Pipeline-parallel aware. The embedding exists only on the first PP rank
    and the final norm only on the last. Other ranks exchange
    ``hidden_states``/``residual`` through ``IntermediateTensors``.
    """

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.device = current_platform.device_type

        self.vocab_size = config.vocab_size
        self.is_v32 = hasattr(config, "index_topk")
        if self.is_v32:
            # One top-k index buffer sized for the largest scheduled batch,
            # shared by every decoder layer.
            topk_tokens = config.index_topk
            topk_indices_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                topk_tokens,
                dtype=torch.int32,
                device=self.device,
            )
        else:
            topk_indices_buffer = None

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekV2DecoderLayer(
                vllm_config,
                prefix,
                topk_indices_buffer=topk_indices_buffer,
            ),
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        # Indices of layers whose input (hidden_states + residual) is also
        # collected as an auxiliary hidden state; empty by default.
        self.aux_hidden_state_layers: tuple[int, ...] = ()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """
        Run this rank's slice of decoder layers.

        Returns:
            ``IntermediateTensors`` on non-last PP ranks. On the last rank,
            the normalized hidden states, or
            ``(hidden_states, aux_hidden_states)`` when any auxiliary layer
            was collected.

        Raises:
            ValueError: on the first PP rank when neither ``input_ids`` nor
                ``inputs_embeds`` is provided.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            elif input_ids is None:
                raise ValueError(
                    "Either input_ids or inputs_embeds must be provided "
                    "to DeepseekV2Model.forward"
                )
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # Compute llama 4 scaling once per forward pass if enabled
        llama_4_scaling_config = getattr(self.config, "llama_4_scaling", None)
        llama_4_scaling: torch.Tensor | None
        if llama_4_scaling_config is not None:
            llama_4_scaling = _get_llama_4_scaling(
                original_max_position_embeddings=llama_4_scaling_config[
                    "original_max_position_embeddings"
                ],
                scaling_beta=llama_4_scaling_config["beta"],
                positions=positions,
            )
        else:
            llama_4_scaling = None

        aux_hidden_states = []
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer),
            start=self.start_layer,
        ):
            if idx in self.aux_hidden_state_layers:
                # Before the first layer on the first PP rank there is no
                # residual yet, so the layer input is hidden_states alone.
                aux_hidden_states.append(
                    hidden_states if residual is None else hidden_states + residual
                )
            hidden_states, residual = layer(
                positions, hidden_states, residual, llama_4_scaling
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states

        return hidden_states

class DeepseekV2MixtureOfExperts(MixtureOfExperts):
    """
    MixtureOfExperts bookkeeping for DeepSeek-V2/V3.

    Mirrors expert counts from a representative DeepseekV2MoE layer and
    pushes updated physical-expert counts down to every MoE layer.
    """

    # List of MoE MLP layers in the model.
    moe_mlp_layers: list[DeepseekV2MoE]

    def extract_moe_parameters(self, example_moe: DeepseekV2MoE | None):
        """
        Copy expert counts from ``example_moe``.

        When there is no MoE layer (``example_moe is None``), every count is
        zeroed and a warning is logged.
        """
        if example_moe is not None:
            self.num_logical_experts = example_moe.n_logical_experts
            self.num_physical_experts = example_moe.n_physical_experts
            self.num_local_physical_experts = example_moe.n_local_physical_experts
            self.num_routed_experts = example_moe.n_routed_experts
            self.num_shared_experts = example_moe.n_shared_experts
            self.num_redundant_experts = example_moe.n_redundant_experts
            return

        for counter in (
            "num_moe_layers",
            "num_expert_groups",
            "num_logical_experts",
            "num_physical_experts",
            "num_local_physical_experts",
            "num_routed_experts",
            "num_shared_experts",
            "num_redundant_experts",
        ):
            setattr(self, counter, 0)
        logger.warning("DeepSeekV2: No DeepseekV2MoE layer found in model.layers.")

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """
        Record a new physical-expert total and propagate it to all MoE layers.

        The per-rank local count must not change. Redundant experts are
        recomputed as physical minus logical, and each layer's expert map is
        refreshed.
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        redundant = num_physical_experts - self.num_logical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = redundant
        for layer in self.moe_mlp_layers:
            layer.n_local_physical_experts = num_local_physical_experts
            layer.n_physical_experts = num_physical_experts
            layer.n_redundant_experts = redundant
            layer.experts.update_expert_map()


class DeepseekV2ForCausalLM(
1320
1321
1322
1323
1324
1325
    nn.Module,
    SupportsPP,
    DeepseekV2MixtureOfExperts,
    SupportsLoRA,
    SupportsEagle,
    SupportsEagle3,
1326
):
1327
1328
1329
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
    }
1330
    model_cls = DeepseekV2Model
1331
1332
1333
1334
1335
1336
1337

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
1338

1339
1340
1341
1342
1343
1344
1345
1346
1347
        qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
        qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
        self.use_mha = config.model_type == "deepseek" or all(
            dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
        )

        if self.use_mha:
            self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]

1348
1349
1350
1351
        # `packed_modules_mapping` needs to be modified before
        # initializing DeepseekV2Model, as it is passed inplace to
        # quantization config init and may be used to select the
        # quant_method for relevant layers during initialization.
1352
1353
1354
        self.fuse_qkv_a_proj = (
            hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
        )
1355
1356
1357
1358
1359
1360
        if self.fuse_qkv_a_proj:
            self.packed_modules_mapping["fused_qkv_a_proj"] = [
                "q_a_proj",
                "kv_a_proj_with_mqa",
            ]

1361
        self.model = self.model_cls(
1362
1363
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
1364
        if get_pp_group().is_last_rank:
1365
1366
1367
1368
1369
1370
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
1371
1372
1373
1374
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
1375
1376
            self.model.make_empty_intermediate_tensors
        )
1377
1378
1379
1380
1381
1382
1383
        # Set MoE hyperparameters
        self.num_moe_layers = (
            self.config.num_hidden_layers - self.config.first_k_dense_replace
        )
        self.set_moe_parameters()

    def set_moe_parameters(self):
1384
1385
        self.expert_weights = []

1386
        self.num_expert_groups = getattr(self.config, "n_group", 1)
1387

1388
1389
        self.moe_layers = []
        self.moe_mlp_layers = []
1390
        example_moe = None
1391
        for layer in self.model.layers:
1392
1393
1394
            if isinstance(layer, PPMissingLayer):
                continue

1395
1396
            assert isinstance(layer, DeepseekV2DecoderLayer)
            if isinstance(layer.mlp, DeepseekV2MoE):
1397
1398
                # Pick last one layer since the first ones may be dense layers.
                example_moe = layer.mlp
1399
                self.moe_mlp_layers.append(layer.mlp)
1400
1401
                self.moe_layers.append(layer.mlp.experts)

1402
        self.extract_moe_parameters(example_moe)
1403

1404
1405
1406
1407
1408
1409
1410
    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

1411
1412
    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.embed_input_ids(input_ids)
1413
1414
1415

    def forward(
        self,
1416
        input_ids: torch.Tensor | None,
1417
        positions: torch.Tensor,
1418
1419
1420
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
1421
1422
1423
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
1424
1425
1426
1427
1428
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
1429
    ) -> torch.Tensor | None:
1430
        logits = self.logits_processor(self.lm_head, hidden_states)
1431
1432
        return logits

1433
1434
1435
    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
1436
        return fused_moe_make_expert_params_mapping(
1437
            self,
1438
1439
1440
1441
1442
1443
1444
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
            num_redundant_experts=0,
        )

1445
    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
1446
1447
1448
        rocm_aiter_moe_shared_expert_enabled = (
            rocm_aiter_ops.is_fusion_moe_shared_experts_enabled()
        )
wangding zeng's avatar
wangding zeng committed
1449
1450
1451
1452
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
1453
1454
        ]
        mla_params_mapping = [
1455
1456
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
wangding zeng's avatar
wangding zeng committed
1457
        ]
1458
1459
1460
1461
1462
        mha_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
1463
1464
1465
1466
1467
1468
1469
        # Fused indexer wk + weights_proj (shard 0 = wk, shard 1 = weights_proj)
        _pending_wk_fp8: dict = {}  # When WK is in FP8, we dequant to BF16 for fusion
        indexer_fused_mapping = [
            ("wk_weights_proj", "wk", 0),
            ("wk_weights_proj", "weights_proj", 1),
        ]
        stacked_params_mapping.extend(indexer_fused_mapping)
1470

1471
1472
1473
1474
        if self.use_mha:
            stacked_params_mapping.extend(mha_params_mapping)
        else:
            stacked_params_mapping.extend(mla_params_mapping)
wangding zeng's avatar
wangding zeng committed
1475

1476
1477
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
1478
        expert_params_mapping = fused_moe_make_expert_params_mapping(
1479
            self,
1480
1481
1482
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
1483
1484
1485
            num_experts=self.config.n_routed_experts
            + (
                self.config.n_shared_experts
1486
                if rocm_aiter_moe_shared_expert_enabled
1487
1488
                else 0
            ),
1489
1490
            num_redundant_experts=self.num_redundant_experts,
        )
1491

wangding zeng's avatar
wangding zeng committed
1492
        params_dict = dict(self.named_parameters())
1493
        loaded_params: set[str] = set()
wangding zeng's avatar
wangding zeng committed
1494
        for name, loaded_weight in weights:
1495
1496
1497
            if "rotary_emb.inv_freq" in name:
                continue

1498
1499
1500
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model
1501

1502
1503
            is_fusion_moe_shared_experts_layer = (
                rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
1504
1505
            )

1506
1507
1508
1509
1510
            if _try_load_fp8_indexer_wk(
                name, loaded_weight, _pending_wk_fp8, params_dict, loaded_params
            ):
                continue

1511
            for param_name, weight_name, shard_id in stacked_params_mapping:
1512
                # Skip non-stacked layers and experts (experts handled below).
wangding zeng's avatar
wangding zeng committed
1513
1514
                if weight_name not in name:
                    continue
1515
1516
1517
1518
1519
1520
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
1521
                if ("mlp.experts." in name) and name not in params_dict:
1522
                    continue
1523
                if is_fusion_moe_shared_experts_layer:
1524
                    continue
1525
                name_mapped = name.replace(weight_name, param_name)
1526
1527
1528

                # QKV fusion is optional, fall back to normal
                # weight loading if it's not enabled
1529
                # if go with fusion option, then update name
1530
1531
1532
                if (
                    param_name == "fused_qkv_a_proj"
                ) and name_mapped not in params_dict:
1533
                    continue
1534
1535
                else:
                    name = name_mapped
wangding zeng's avatar
wangding zeng committed
1536
1537
1538
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
1539
1540
1541
1542

                if is_pp_missing_parameter(name, self):
                    continue

wangding zeng's avatar
wangding zeng committed
1543
1544
1545
1546
1547
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
1548
                is_expert_weight = False
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558

                # Special handling: when AITER fusion_shared_experts is enabled,
                # checkpoints may provide a single widened shared_experts tensor
                # without explicit expert indices
                # (e.g. ...mlp.shared_experts.gate_proj.weight).
                # For models with multiple shared experts, split that tensor
                # evenly into per-shared-expert slices and load them into
                # appended expert slots mlp.experts.{n_routed_experts + j}.*
                # accordingly.
                num_chunks = 1
1559
                if is_fusion_moe_shared_experts_layer:
1560
1561
1562
1563
                    num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
                    # Determine split axis based on op type
                    # gate/up: ColumnParallel → split along dim 0
                    # down: RowParallel → split along dim 1
1564
1565
1566
1567
1568
                    split_dim = (
                        1
                        if ("down_proj.weight" in name and loaded_weight.ndim > 1)
                        else 0
                    )
1569
1570
1571
1572
                    total = loaded_weight.shape[split_dim]
                    assert total % num_chunks == 0, (
                        f"Shared expert weight dim {total} "
                        f"not divisible by num_chunks {num_chunks}"
1573
                    )
1574
1575
1576
1577
1578
1579
                    chunk_size = total // num_chunks

                for j in range(num_chunks):
                    chunk_name = name
                    weight_to_load = loaded_weight

1580
                    if is_fusion_moe_shared_experts_layer:
1581
1582
1583
1584
1585
                        chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
                        if loaded_weight.ndim == 1:
                            weight_to_load = loaded_weight[chunk_slice]
                        elif split_dim == 0:
                            weight_to_load = loaded_weight[chunk_slice, :]
1586
                        else:
1587
                            weight_to_load = loaded_weight[:, chunk_slice]
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
                        # Synthesize an expert-style name so expert mapping
                        # can route it
                        chunk_name = name.replace(
                            "mlp.shared_experts",
                            f"mlp.experts.{self.config.n_routed_experts + j}",
                        )

                    # Use expert_params_mapping to locate the destination
                    # param and delegate to its expert-aware weight_loader
                    # with expert_id.
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in chunk_name:
                            continue

                        # Anyway, this is an expert weight and should not be
                        # attempted to load as other weights later
                        is_expert_weight = True

                        # Do not modify `name` since the loop may continue here
                        # Instead, create a new variable
                        name_mapped = chunk_name.replace(weight_name, param_name)

                        if is_pp_missing_parameter(name_mapped, self):
                            continue

                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or
                        # not here since otherwise we may skip experts with
                        # other available replicas.
                        weight_loader = typing.cast(
                            Callable[..., bool], param.weight_loader
                        )
                        success = weight_loader(
                            param,
                            weight_to_load,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                            return_success=True,
                        )
                        if success:
1630
                            if not is_fusion_moe_shared_experts_layer:
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
                                name = name_mapped
                            else:
                                loaded_params.add(name_mapped)
                            break
                    else:
                        if is_expert_weight:
                            # We've checked that this is an expert weight
                            # However it's not mapped locally to this rank
                            # So we simply skip it
                            continue

                        # Skip loading extra bias for GPTQ models.
                        if name.endswith(".bias") and name not in params_dict:
                            continue

                        # Remapping the name of FP8 kv-scale.
                        name = maybe_remap_kv_scale_name(name, params_dict)
                        if name is None:
                            continue

                        if is_pp_missing_parameter(name, self):
                            continue

                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
1659
            if name is not None and not is_fusion_moe_shared_experts_layer:
1660
                loaded_params.add(name)
1661

1662
        return loaded_params
1663
1664


1665
1666
1667
1668
class DeepseekForCausalLM(DeepseekV2ForCausalLM):
    pass


1669
1670
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    pass
1671
1672


Jee Jee Li's avatar
Jee Jee Li committed
1673
1674
1675
1676
class GlmMoeDsaForCausalLM(DeepseekV2ForCausalLM):
    pass


1677
1678
# Compatibility with
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py
1679
def get_spec_layer_idx_from_weight_name(
1680
1681
    config: DeepseekV2Config | DeepseekV3Config, weight_name: str
) -> int | None:
1682
1683
1684
1685
    if (
        hasattr(config, "num_nextn_predict_layers")
        and config.num_nextn_predict_layers > 0
    ):
1686
1687
        layer_idx = config.num_hidden_layers
        for i in range(config.num_nextn_predict_layers):
1688
            if weight_name.startswith(f"model.layers.{layer_idx + i}."):
1689
1690
                return layer_idx + i
    return None