# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
25
"""Inference-only DeepseekV2/DeepseekV3 model."""
26

27
28
import typing
from collections.abc import Callable, Iterable
29
from itertools import islice
30
from typing import Any
wangding zeng's avatar
wangding zeng committed
31
32
33

import torch
from torch import nn
34
from transformers import DeepseekV2Config, DeepseekV3Config
wangding zeng's avatar
wangding zeng committed
35

36
from vllm.attention import Attention
37
38
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
39
from vllm.compilation.decorators import support_torch_compile
40
41
42
43
44
45
46
47
from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
    get_ep_group,
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
48
49
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
wangding zeng's avatar
wangding zeng committed
50
from vllm.model_executor.layers.activation import SiluAndMul
51
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
52
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
53
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
54
55
56
57
58
59
from vllm.model_executor.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
wangding zeng's avatar
wangding zeng committed
60
from vllm.model_executor.layers.logits_processor import LogitsProcessor
61
from vllm.model_executor.layers.mla import MLAModules, MultiHeadLatentAttentionWrapper
62
from vllm.model_executor.layers.quantization import QuantizationConfig
63
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
64
65
    per_token_group_quant_fp8,
)
wangding zeng's avatar
wangding zeng committed
66
67
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
68
69
70
    ParallelLMHead,
    VocabParallelEmbedding,
)
71
from vllm.model_executor.model_loader.weight_utils import (
72
73
74
    default_weight_loader,
    maybe_remap_kv_scale_name,
)
75
from vllm.model_executor.models.utils import sequence_parallel_chunk
76
from vllm.platforms import current_platform
77
from vllm.sequence import IntermediateTensors
78
from vllm.utils import direct_register_custom_op
79
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
80
81
82
83
from vllm.v1.attention.backends.mla.indexer import (
    DeepseekV32IndexerBackend,
    DeepseekV32IndexerMetadata,
)
84
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
wangding zeng's avatar
wangding zeng committed
85

86
from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
87
88
89
90
91
92
93
from .utils import (
    PPMissingLayer,
    is_pp_missing_parameter,
    make_empty_intermediate_tensors_factory,
    make_layers,
    maybe_prefix,
)
94

95
96
97
98
99
100
101
# Select the platform-specific custom-op namespace bound to `ops`.
# sparse_attn_indexer calls `ops.indexer_k_quant_and_cache` and
# `ops.cp_gather_indexer_k_quant_cache` through this name.
# NOTE(review): on platforms that are neither CUDA-alike nor XPU, `ops` is
# never bound, so any call through it would raise NameError.
if current_platform.is_cuda_alike():
    from vllm import _custom_ops as ops
elif current_platform.is_xpu():
    from vllm._ipex_ops import ipex_ops as ops

logger = init_logger(__name__)

wangding zeng's avatar
wangding zeng committed
102
103
104
105
106
107
108

class DeepseekV2MLP(nn.Module):
    """Gated SiLU feed-forward block (gate/up projection -> SiluAndMul -> down).

    Used both as a dense MLP and as the shared-expert MLP inside
    ``DeepseekV2MoE``.

    Args:
        hidden_size: Model hidden dimension (input and output width).
        intermediate_size: Width of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to the linears.
        reduce_results: Forwarded to the down projection's ``reduce_results``.
        is_sequence_parallel: When true, both linears are built with
            ``disable_tp=True`` (replicated weights, no collectives).
        prefix: Parameter-name prefix for weight loading.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        is_sequence_parallel=False,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # Sequence-parallel mode: inputs/outputs are already sharded across
        # the TP group, so the weights are kept whole and no collective op is
        # issued. Otherwise this is ordinary TP with an all-reduce at the end
        # (governed by `reduce_results`).
        fused_widths = [intermediate_size, intermediate_size]
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=False,
            quant_config=quant_config,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            disable_tp=is_sequence_parallel,
            prefix=f"{prefix}.down_proj",
        )

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x):
        # The parallel linears return (output, bias); only the output is used.
        fused = self.gate_up_proj(x)[0]
        activated = self.act_fn(fused)
        return self.down_proj(activated)[0]


class DeepseekV2MoE(nn.Module):
    """DeepSeek mixture-of-experts layer.

    Routed experts (grouped top-k routing, optional ``noaux_tc`` score
    correction bias) and optional shared experts are executed together by a
    single ``SharedFusedMoE``. The routed output is scaled by
    ``routed_scaling_factor`` here rather than inside the fused kernel.
    Supports sequence-parallel MoE (tokens chunked across TP ranks and
    all-gathered afterwards) and expert-parallel load balancing (EPLB)
    bookkeeping.
    """

    def __init__(
        self,
        config: DeepseekV2Config | DeepseekV3Config,
        parallel_config: ParallelConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()

        self.routed_scaling_factor = config.routed_scaling_factor

        ep_group = get_ep_group().device_group
        self.ep_group = ep_group
        self.ep_rank = ep_group.rank()
        self.ep_size = ep_group.size()
        self.n_routed_experts: int = config.n_routed_experts
        self.n_shared_experts: int = config.n_shared_experts

        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe

        activation = config.hidden_act
        if activation != "silu":
            raise ValueError(
                f"Unsupported activation: {activation}. "
                "Only silu is supported for now."
            )

        # Router: replicated (never quantized) projection to one logit per
        # routed expert.
        self.gate = ReplicatedLinear(
            config.hidden_size,
            config.n_routed_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )
        # "noaux_tc" routing carries a learned per-expert score correction
        # bias (loaded from the checkpoint); other methods carry none.
        wants_bias = config.topk_method == "noaux_tc"
        self.gate.e_score_correction_bias = (
            nn.Parameter(torch.empty(config.n_routed_experts, dtype=torch.float32))
            if wants_bias
            else None
        )

        # Expert-parallel load-balancing (EPLB) bookkeeping: logical experts
        # plus redundant replicas form the physical experts, which are split
        # evenly across EP ranks.
        eplb_config = parallel_config.eplb_config
        self.enable_eplb = parallel_config.enable_eplb

        self.n_redundant_experts = eplb_config.num_redundant_experts
        self.n_logical_experts = self.n_routed_experts
        self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts
        self.n_local_physical_experts = self.n_physical_experts // self.ep_size

        start = self.ep_rank * self.n_local_physical_experts
        self.physical_expert_start = start
        self.physical_expert_end = start + self.n_local_physical_experts

        # Shared experts are one dense MLP whose width is the per-expert width
        # times the number of shared experts. reduce_results=False because the
        # reduction is done once, after combining with the routed output.
        shared = None
        if config.n_shared_experts is not None:
            shared = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=(
                    config.moe_intermediate_size * config.n_shared_experts
                ),
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                is_sequence_parallel=self.is_sequence_parallel,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )
        self.shared_experts = shared

        self.experts = SharedFusedMoE(
            shared_experts=self.shared_experts,
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func=config.scoring_func,
            # Scaling is applied in forward(); 1.0 avoids applying it twice.
            routed_scaling_factor=1.0,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            enable_eplb=self.enable_eplb,
            num_redundant_experts=self.n_redundant_experts,
            is_sequence_parallel=self.is_sequence_parallel,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route ``hidden_states`` (num_tokens, hidden_dim) through the MoE."""
        num_tokens, hidden_dim = hidden_states.shape
        tokens = hidden_states.view(-1, hidden_dim)

        # In sequence-parallel mode each TP rank processes only its chunk of
        # tokens, so expert computation is not duplicated across ranks.
        # TODO: We can replace the all_reduce at the end of attn with a
        # reduce_scatter instead of chunking here.
        if self.is_sequence_parallel:
            tokens = sequence_parallel_chunk(tokens)

        # router_logits: (num_tokens, n_experts)
        router_logits, _ = self.gate(tokens)
        moe_result = self.experts(hidden_states=tokens, router_logits=router_logits)

        has_shared = self.shared_experts is not None
        if has_shared:
            shared_output, routed_output = moe_result
        else:
            shared_output, routed_output = None, moe_result

        # Fix FP16 overflow: in fp16 the routed output is left unscaled and
        # the shared output is scaled down instead.
        # See DeepseekV2DecoderLayer for more details.
        if hidden_states.dtype == torch.float16:
            if has_shared:
                assert shared_output is not None
                shared_output *= 1.0 / self.routed_scaling_factor
        else:
            routed_output *= self.routed_scaling_factor

        if has_shared:
            assert shared_output is not None
            routed_output += shared_output

        if self.is_sequence_parallel:
            # Re-assemble the per-rank chunks and drop any chunking padding.
            gathered = tensor_model_parallel_all_gather(routed_output, 0)
            routed_output = gathered[:num_tokens]
        elif self.tp_size > 1:
            routed_output = self.experts.maybe_all_reduce_tensor_model_parallel(
                routed_output
            )

        return routed_output.view(num_tokens, hidden_dim)


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN attention-magnitude correction for a RoPE scale factor.

    No context extension (``scale <= 1``) yields the neutral value ``1.0``.
    Otherwise the correction grows logarithmically with the scale:
    ``1 + 0.1 * mscale * ln(scale)``.
    """
    import math

    if scale > 1:
        return 1.0 + 0.1 * mscale * math.log(scale)
    return 1.0


class DeepseekV2Attention(nn.Module):
    """Standard (non-MLA) multi-head attention for DeepseekV2/V3.

    Queries come either from ``q_proj`` or, when ``q_lora_rank`` is set,
    from a low-rank path ``q_a_proj -> RMSNorm -> q_b_proj``. Keys and values
    are expanded by ``kv_b_proj`` from a normalized compressed latent produced
    by ``kv_a_proj_with_mqa``. The rotary part of the key is taken directly
    from that latent and broadcast across all heads. Heads are split evenly
    across tensor-parallel ranks.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        topk_indices_buffer: torch.Tensor | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        # The sparse-attention top-k buffer only applies to the indexer path.
        # The message is built with implicit concatenation; the previous
        # backslash continuation embedded a run of spaces in the text.
        assert topk_indices_buffer is None, (
            "topk_indices_buffer is not supported for DeepseekV2Attention"
        )

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_a_proj",
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )

        # Compressed KV latent plus the shared rotary key slice, replicated on
        # every rank.
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa",
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        if rope_scaling:
            # Work on a private copy before overriding rope_type. The caller's
            # dict is otherwise mutated in place. NOTE(review): callers may
            # pass the model config's own rope_scaling dict; confirm nothing
            # relies on the in-place override.
            rope_scaling = dict(rope_scaling)
            rope_scaling["rope_type"] = "deepseek_yarn"

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
        )

        if rope_scaling:
            # YaRN magnitude correction is applied to the softmax scale (once
            # for q and once for k, hence squared).
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.attn = Attention(
            self.num_local_heads,
            self.qk_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Compute attention for ``hidden_states`` at the given ``positions``."""
        # q: (num_tokens, num_local_heads, qk_head_dim)
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # Add a singleton head axis so the rotary key slice broadcasts over
        # all heads below.
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a)
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        k_pe = latent_cache[:, :, self.kv_lora_rank :]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        # Reassemble q and k as [nope | rope] per head.
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe
        # Pad v up to qk_head_dim so q, k and v share one head size.
        v = torch.nn.functional.pad(
            v, [0, self.qk_head_dim - self.v_head_dim], value=0
        ).view(-1, self.num_local_heads * self.qk_head_dim)
        attn_output = self.attn(q, k, v)
        # Drop the padding columns before the output projection.
        attn_output = attn_output.view(-1, self.num_local_heads, self.qk_head_dim)[
            ..., : self.v_head_dim
        ].reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output


457
class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
    """KV-cache holder for the DeepSeek-V3.2 sparse-attention indexer.

    Registers itself under ``prefix`` in the compilation config's
    ``static_forward_context`` so that it is allocated a cache like an
    attention layer. Each cached entry is a single vector (no separate K/V).
    It performs no computation of its own.

    Raises:
        ValueError: If another layer is already registered under ``prefix``.
    """

    def __init__(
        self, head_dim: int, dtype: torch.dtype, prefix: str, cache_config: CacheConfig
    ):
        super().__init__()
        # Placeholder; the real cache tensor is bound later by the runtime.
        self.kv_cache = [torch.tensor([])]
        self.head_dim = head_dim
        self.prefix = prefix
        self.cache_config = cache_config
        self.dtype = dtype

        registry = get_current_vllm_config().compilation_config.static_forward_context
        if prefix in registry:
            raise ValueError(f"Duplicate layer name: {prefix}")
        registry[prefix] = self

    def get_kv_cache_spec(self) -> KVCacheSpec:
        # One vector per token (no K + V pair), hence a single KV head.
        return MLAAttentionSpec(
            block_size=self.cache_config.block_size,
            num_kv_heads=1,
            head_size=self.head_dim,
            dtype=self.dtype,
        )

    def forward(self): ...

    def get_attn_backend(self) -> AttentionBackend:
        return DeepseekV32IndexerBackend


def _top_k_per_row(
    logits: torch.Tensor,
    row_starts: torch.Tensor,
    row_ends: torch.Tensor,
    topk_tokens: int,
) -> torch.Tensor:
    """Run ``torch.ops._C.top_k_per_row`` on ``logits`` and return indices.

    Shared by the prefill and decode paths of ``sparse_attn_indexer``.
    ``row_starts``/``row_ends`` are passed straight through to the kernel as
    per-row bounds. Returns an int32 tensor of shape
    ``(logits.shape[0], topk_tokens)``.
    """
    num_rows = logits.shape[0]
    assert topk_tokens == 2048, "top_k_per_row assumes size 2048"
    topk_indices = torch.empty(
        num_rows, topk_tokens, dtype=torch.int32, device=logits.device
    )
    topk_values = torch.empty(
        num_rows, topk_tokens, dtype=logits.dtype, device=logits.device
    )
    torch.ops._C.top_k_per_row(
        logits,
        row_starts,
        row_ends,
        topk_indices,
        topk_values,
        num_rows,
        logits.stride(0),
        logits.stride(1),
    )
    return topk_indices


def sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor | None,
) -> torch.Tensor:
    """Select the top-k key positions per query token for sparse attention.

    Quantizes ``k`` into the indexer cache, then computes fp8 MQA logits
    against the cached keys and writes each token's top-``topk_tokens``
    indices into ``topk_indices_buffer``:
      * prefill: one pass per ``prefill.chunks`` entry, using gathered keys;
      * decode: one paged pass over the decode tokens (optionally packed to a
        uniform per-request length first).

    When no per-layer metadata dict is available (e.g. the dummy/profile
    run), delegates to ``sparse_attn_indexer_fake``.

    Returns:
        ``topk_indices_buffer`` (mutated in place).
    """
    # careful! this will be None in dummy run
    attn_metadata = get_forward_context().attn_metadata
    if not isinstance(attn_metadata, dict):
        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )
    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    # The buffer is typed Optional but is written unconditionally below.
    assert topk_indices_buffer is not None, (
        "sparse_attn_indexer requires a topk_indices_buffer"
    )
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    ops.indexer_k_quant_and_cache(
        k,
        kv_cache,
        slot_mapping,
        quant_block_size,
        scale_fmt,
    )

    # Reset the rows for this step's tokens to -1 before filling them
    # (presumably -1 marks "no selected key"; confirm against consumers).
    topk_indices_buffer[: hidden_states.shape[0]] = -1

    if has_prefill:
        prefill_metadata = attn_metadata.prefill
        for chunk in prefill_metadata.chunks:
            # Gather this chunk's cached keys: fp8 values plus 4 raw bytes
            # per row that are reinterpreted as one float32 scale.
            k_fp8 = torch.empty(
                [chunk.total_seq_lens, head_dim],
                device=k.device,
                dtype=torch.float8_e4m3fn,
            )
            k_scale = torch.empty(
                [chunk.total_seq_lens, 4],
                device=k.device,
                dtype=torch.uint8,
            )
            ops.cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_fp8,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
            )
            logits = fp8_mqa_logits(
                q_fp8[chunk.token_start : chunk.token_end],
                (k_fp8, k_scale.view(torch.float32)),
                weights[chunk.token_start : chunk.token_end],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )
            topk_indices = _top_k_per_row(
                logits, chunk.cu_seqlen_ks, chunk.cu_seqlen_ke, topk_tokens
            )
            topk_indices_buffer[
                chunk.token_start : chunk.token_end, : topk_indices.shape[-1]
            ] = topk_indices.to(dtype=torch.int32)

    if has_decode:
        decode_metadata = attn_metadata.decode
        # kv_cache size requirement [num_block, block_size, n_head, head_dim],
        # we only have [num_block, block_size, head_dim],
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # pad in edge case where we have short chunked prefill length <
            # decode_threshold since we unstrictly split
            # prefill and decode by decode_threshold
            # (currently set to 1 + speculative tokens)
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens
            )
        else:
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:]
            )
        # TODO: move and optimize below logic with triton kernels
        batch_size, next_n = padded_q_fp8_decode_tokens.shape[:2]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n
        logits = fp8_paged_mqa_logits(
            padded_q_fp8_decode_tokens,
            kv_cache,
            weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            decode_metadata.schedule_metadata,
            max_model_len=max_model_len,
        )
        # Per padded query token: its request row and its offset within the
        # request's next_n tokens, giving each token's causal end position.
        token_ids = torch.arange(
            num_padded_tokens, device=padded_q_fp8_decode_tokens.device
        )
        row_indices = token_ids // next_n
        next_n_offset = token_ids % next_n
        index_end_pos = (
            decode_metadata.seq_lens[row_indices] - next_n + next_n_offset + 1
        ).unsqueeze(1)
        num_rows = logits.shape[0]
        topk_indices = _top_k_per_row(
            logits,
            torch.zeros(num_rows, dtype=torch.int32, device=logits.device),
            index_end_pos.to(dtype=torch.int32, device=logits.device),
            topk_tokens,
        )
        if decode_metadata.requires_padding:
            # if padded, we need to unpack
            # the topk indices removing padded tokens
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens,
            )
        topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
            topk_indices.to(dtype=torch.int32)
        )

    return topk_indices_buffer


def sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
668
    scale_fmt: str | None,
669
670
671
672
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
673
    topk_indices_buffer: torch.Tensor | None,
674
675
676
677
) -> torch.Tensor:
    # profile run
    # NOTE(Chen): create the max possible flattened_kv. So that
    # profile_run can get correct memory usage.
678
679
680
681
    _flattened_kv = torch.empty(
        [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
    )
    _k_fp8 = _flattened_kv[..., :head_dim].view(torch.float8_e4m3fn).contiguous()
682
683
684
685
686
687
688
689
690
691
692
693
694
695
    _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
    return topk_indices_buffer


# Expose sparse_attn_indexer as a custom op. It writes into
# topk_indices_buffer in place, and the fake implementation is used for
# tracing/profiling.
direct_register_custom_op(
    op_name="sparse_attn_indexer",
    op_func=sparse_attn_indexer,
    fake_impl=sparse_attn_indexer_fake,
    mutates_args=["topk_indices_buffer"],
    dispatch_key=current_platform.dispatch_key,
)


class Indexer(nn.Module):
    """Top-k token indexer for the sparse-attention (``index_topk``) path.

    Projects queries and keys into a small per-head scoring space, applies
    rotary embeddings to the rope slice, quantizes the queries to FP8 and
    then calls the ``sparse_attn_indexer`` custom op, which fills (and
    returns) ``topk_indices_buffer`` with the selected token indices.
    Key quantization is fused into the cache insertion inside that op.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        q_lora_rank: int,
        quant_config: QuantizationConfig | None,
        cache_config: CacheConfig | None,
        topk_indices_buffer: torch.Tensor | None,
        prefix: str = "",
    ):
        super().__init__()
        self.vllm_config = vllm_config
        self.config = config
        # Example values noted by the original author are given in the
        # trailing comments; they are checkpoint-dependent.
        self.topk_tokens = config.index_topk
        self.n_head = config.index_n_heads  # 64
        self.head_dim = config.index_head_dim  # 128
        self.rope_dim = config.qk_rope_head_dim  # 64
        self.q_lora_rank = q_lora_rank  # 1536
        # no tensor parallel, just replicated
        self.wq_b = ReplicatedLinear(
            self.q_lora_rank,
            self.head_dim * self.n_head,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wq_b",
        )
        self.wk = ReplicatedLinear(
            hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wk",
        )
        self.k_norm = LayerNorm(self.head_dim, eps=1e-6)
        # Per-token, per-head mixing weights; deliberately left unquantized.
        self.weights_proj = ReplicatedLinear(
            hidden_size, self.n_head, quant_config=None, prefix=f"{prefix}.weights_proj"
        )
        self.softmax_scale = self.head_dim**-0.5

        self.scale_fmt = "ue8m0"
        self.quant_block_size = 128  # TODO: get from config
        self.topk_indices_buffer = topk_indices_buffer

        # NOTE: (zyongye) we use fp8 naive cache,
        #       where we store value in fp8 and scale in fp32
        #       per self.quant_block_size element
        # Row width in bytes: head_dim FP8 values plus 4 bytes (one fp32
        # scale) for every quant_block_size values.
        self.k_cache = DeepseekV32IndexerCache(
            head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4,
            dtype=torch.uint8,
            prefix=f"{prefix}.k_cache",
            cache_config=cache_config,
        )
        self.max_model_len = vllm_config.model_config.max_model_len
        self.prefix = prefix
        from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size

        self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config)

    def forward(
        self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, rotary_emb
    ) -> torch.Tensor:
        """Select the top-k tokens to attend to for each query token.

        Args:
            hidden_states: layer input; projected by ``wk`` (keys) and
                ``weights_proj`` (per-head weights).
            qr: low-rank query activations of width ``q_lora_rank``, the
                input to ``wq_b``.
            positions: token positions passed to ``rotary_emb``.
            rotary_emb: callable ``(positions, q, k) -> (q, k)`` applied to the
                rope slices.

        Returns:
            The result of ``torch.ops.vllm.sparse_attn_indexer``.
        """
        q, _ = self.wq_b(qr)
        q = q.view(-1, self.n_head, self.head_dim)
        # The rope slice comes first in the indexer's head layout.
        q_pe, q_nope = torch.split(
            q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
        )

        k, _ = self.wk(hidden_states)
        k = self.k_norm(k)
        k_pe, k_nope = torch.split(
            k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1
        )

        # k has a single (shared) head; add a head axis for the rotary call.
        q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1))
        q = torch.cat([q_pe, q_nope], dim=-1)
        k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1)

        # we only quant q here since k quant is fused with cache insertion
        q = q.view(-1, self.head_dim)
        q_fp8, q_scale = per_token_group_quant_fp8(
            q,
            self.quant_block_size,
            column_major_scales=False,
            use_ue8m0=self.scale_fmt is not None,
        )
        q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
        q_scale = q_scale.view(-1, self.n_head, 1)

        # Fold the query dequant scale, softmax scale and 1/sqrt(n_head) into
        # the per-head weights so the op can consume raw FP8 q.
        weights, _ = self.weights_proj(hidden_states)
        weights = (
            weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
        )
        weights = weights.squeeze(-1)

        return torch.ops.vllm.sparse_attn_indexer(
            hidden_states,
            self.k_cache.prefix,
            self.k_cache.kv_cache[0],
            q_fp8,
            k,
            weights,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )


class DeepseekV2MLAAttention(nn.Module):
    """Multi-head latent attention (MLA) layer for DeepseekV2/V3.

    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

    For more info see MLACommonImpl in:
    vllm/v1/attention/backends/mla/utils.py

    This module only builds the projections / norms / rotary embedding and
    hands them to ``MultiHeadLatentAttentionWrapper``, which performs the
    actual forward computation. When ``config`` has ``index_topk`` (the
    sparse "v32" variant) an ``Indexer`` is attached as well.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        config: DeepseekV2Config | DeepseekV3Config,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int | None,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: dict[str, Any] | None = None,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        topk_indices_buffer: torch.Tensor | None = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # With a q LoRA rank, the q and kv down-projections are fused into a
        # single (non-TP) merged linear; otherwise only kv is down-projected.
        if self.q_lora_rank is not None:
            self.fused_qkv_a_proj = MergedColumnParallelLinear(
                self.hidden_size,
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.fused_qkv_a_proj",
                disable_tp=True,
            )
        else:
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.kv_a_proj_with_mqa",
            )

        if self.q_lora_rank is not None:
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                self.q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
            )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

        # NOTE: this writes into the caller's `rope_scaling` dict in place
        # (callers pass the config's dict), forcing the deepseek_yarn variant.
        if rope_scaling:
            rope_scaling["rope_type"] = "deepseek_yarn"
        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
        )
        if rope_scaling:
            # The attention scale is multiplied by the yarn mscale squared.
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        # Configs exposing `index_topk` use the sparse (indexer) attention path.
        self.is_v32 = hasattr(config, "index_topk")

        if self.is_v32:
            self.indexer = Indexer(
                vllm_config,
                config,
                hidden_size,
                q_lora_rank,
                quant_config,
                cache_config,
                topk_indices_buffer,
                f"{prefix}.indexer",
            )
        else:
            self.indexer = None

        # Only the modules created for the active (q_lora_rank) branch exist
        # as attributes; the others are passed as None.
        mla_modules = MLAModules(
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            rotary_emb=self.rotary_emb,
            o_proj=self.o_proj,
            fused_qkv_a_proj=self.fused_qkv_a_proj
            if self.q_lora_rank is not None
            else None,
            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa
            if self.q_lora_rank is None
            else None,
            q_a_layernorm=self.q_a_layernorm if self.q_lora_rank is not None else None,
            q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
            q_proj=self.q_proj if self.q_lora_rank is None else None,
            indexer=self.indexer,
            is_sparse=self.is_v32,
            topk_indices_buffer=topk_indices_buffer,
        )

        self.mla_attn = MultiHeadLatentAttentionWrapper(
            self.hidden_size,
            self.num_local_heads,
            self.scaling,
            self.qk_nope_head_dim,
            self.qk_rope_head_dim,
            self.v_head_dim,
            self.q_lora_rank,
            self.kv_lora_rank,
            mla_modules,
            cache_config,
            quant_config,
            prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Run MLA attention; delegates entirely to ``self.mla_attn``."""
        return self.mla_attn(positions, hidden_states)


class DeepseekV2DecoderLayer(nn.Module):
    """One transformer block: attention followed by a dense MLP or an MoE.

    The attention class is ``DeepseekV2MLAAttention`` when the model config
    enables MLA, else ``DeepseekV2Attention``. The feed-forward is an MoE on
    layers at or beyond ``first_k_dense_replace`` that fall on a
    ``moe_layer_freq`` boundary (when routed experts exist), else a dense MLP.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str,
        config: DeepseekV2Config | DeepseekV3Config | None = None,
        topk_indices_buffer: torch.Tensor | None = None,
    ) -> None:
        super().__init__()

        if config is None:
            config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        layer_idx = int(prefix.split(sep=".")[-1])
        self.layer_idx = layer_idx
        if model_config.use_mla:
            attn_cls = DeepseekV2MLAAttention
        else:
            attn_cls = DeepseekV2Attention
        self.self_attn = attn_cls(
            vllm_config=vllm_config,
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
            topk_indices_buffer=topk_indices_buffer,
        )

        if (
            config.n_routed_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        ):
            self.mlp = DeepseekV2MoE(
                config=config,
                parallel_config=parallel_config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        else:
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.routed_scaling_factor = config.routed_scaling_factor

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply attention + feed-forward with pre-norm residual handling.

        Args:
            positions: token positions for the attention's rotary embedding.
            hidden_states: layer input.
            residual: running residual from the previous layer, or ``None``
                for the first layer (it is then initialised from the input).

        Returns:
            ``(hidden_states, residual)`` to feed into the next layer.
        """
        # Self Attention
        if residual is None:
            residual = hidden_states.clone()
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        if hidden_states.dtype == torch.float16:
            # Fix FP16 overflow
            # We scale both hidden_states and residual before
            # rmsnorm, and rmsnorm result would not affect by scale.
            hidden_states *= 1.0 / self.routed_scaling_factor
            if self.layer_idx == 0:
                # The residual is shared by all layers, we only scale it on
                # first layer.
                residual *= 1.0 / self.routed_scaling_factor

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)

        if isinstance(self.mlp, DeepseekV2MLP) and hidden_states.dtype == torch.float16:
            # Fix FP16 overflow
            # Scaling the DeepseekV2MLP output, it is the input of
            # input_layernorm of next decoder layer.
            # The scaling of DeepseekV2MOE output would be done in the forward
            # of DeepseekV2MOE
            hidden_states *= 1.0 / self.routed_scaling_factor

        return hidden_states, residual


@support_torch_compile
class DeepseekV2Model(nn.Module):
    """DeepseekV2/V3 decoder stack: embeddings, decoder layers, final norm.

    Supports pipeline parallelism: the embedding exists only on the first
    rank, the final norm only on the last rank, and layers outside
    ``[start_layer, end_layer)`` are not run on this rank.
    """

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.device = current_platform.device_type

        self.vocab_size = config.vocab_size
        self.is_v32 = hasattr(config, "index_topk")
        if self.is_v32:
            # One buffer, shared by every decoder layer, with a row of
            # `index_topk` indices per batched token.
            topk_tokens = config.index_topk
            topk_indices_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                topk_tokens,
                dtype=torch.int32,
                device=self.device,
            )
        else:
            topk_indices_buffer = None

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekV2DecoderLayer(
                vllm_config, prefix, topk_indices_buffer=topk_indices_buffer
            ),
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids with the (first-rank) vocabulary embedding."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run this rank's slice of the decoder stack.

        On the first rank the input comes from ``inputs_embeds`` (if given) or
        from embedding ``input_ids``; other ranks read ``hidden_states`` and
        ``residual`` from ``intermediate_tensors``. Non-last ranks return
        ``IntermediateTensors``; the last rank returns the final-normed hidden
        states.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA):
    """DeepseekV2 causal language model: decoder stack + LM head + MoE/EPLB hooks."""

    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config

        # `packed_modules_mapping` needs to be modified before
        # initializing DeepseekV2Model, as it is passed inplace to
        # quantization config init and may be used to select the
        # quant_method for relevant layers during initialization.
        # NOTE: because of that, this intentionally mutates the class-level
        # dict (shared with subclasses) instead of copying it per instance.
        self.fuse_qkv_a_proj = (
            hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
        )
        if self.fuse_qkv_a_proj:
            self.packed_modules_mapping["fused_qkv_a_proj"] = [
                "q_a_proj",
                "kv_a_proj_with_mqa",
            ]

        self.model = DeepseekV2Model(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )
        self.expert_weights = []

        # Set MoE hyperparameters
        self.num_moe_layers = config.num_hidden_layers - config.first_k_dense_replace
        self.num_expert_groups = config.n_group

        self.moe_layers: list[SharedFusedMoE] = []
        example_moe = None
        for layer in self.model.layers:
            # Layers owned by other pipeline stages are placeholders.
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, DeepseekV2DecoderLayer)
            if isinstance(layer.mlp, DeepseekV2MoE):
                # Pick last one layer since the first ones may be dense layers.
                example_moe = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_moe is None:
            raise RuntimeError("No DeepseekV2MoE layer found in model.layers.")

        self.num_logical_experts = example_moe.n_logical_experts
        self.num_physical_experts = example_moe.n_physical_experts
        self.num_local_physical_experts = example_moe.n_local_physical_experts
        self.num_routed_experts = example_moe.n_routed_experts
        self.num_shared_experts = example_moe.n_shared_experts
        self.num_redundant_experts = example_moe.n_redundant_experts

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None:
        """Hand the EPLB bookkeeping tensors to every local MoE layer.

        Also appends each layer's expert weights to ``self.expert_weights``.
        """
        for layer_idx, layer in enumerate(self.moe_layers):
            # Register the expert weights.
            self.expert_weights.append(layer.get_expert_weights())
            layer.set_eplb_state(
                moe_layer_idx=layer_idx,
                expert_load_view=expert_load_view,
                logical_to_physical_map=logical_to_physical_map,
                logical_replica_count=logical_replica_count,
            )

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate new physical/redundant expert counts to local MoE layers."""
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # Skip pipeline-parallel placeholders (as __init__ does); they are
            # not decoder layers and have no `mlp` attribute.
            if isinstance(layer, PPMissingLayer):
                continue
            if isinstance(layer.mlp, DeepseekV2MoE):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed token ids via the underlying model."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack; logits are computed separately."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project final hidden states to vocabulary logits."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, handling stacked, fused and expert params.

        Returns:
            The set of parameter names that were loaded on this rank.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
            ("fused_qkv_a_proj", "q_a_proj", 0),
            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = SharedFusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
            num_redundant_experts=self.num_redundant_experts,
        )

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name_mapped = name.replace(weight_name, param_name)

                # QKV fusion is optional, fall back to normal
                # weight loading if it's not enabled
                # if go with fusion option, then update name
                if (
                    param_name == "fused_qkv_a_proj"
                ) and name_mapped not in params_dict:
                    continue
                name = name_mapped

                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue

                    # Anyway, this is an expert weight and should not be
                    # attempted to load as other weights later
                    is_expert_weight = True

                    # Do not modify `name` since the loop may continue here
                    # Instead, create a new variable
                    name_mapped = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name_mapped, self):
                        continue

                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
                    weight_loader = typing.cast(
                        Callable[..., bool], param.weight_loader
                    )
                    success = weight_loader(
                        param,
                        loaded_weight,
                        name_mapped,
                        shard_id=shard_id,
                        expert_id=expert_id,
                        return_success=True,
                    )
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        # We've checked that this is an expert weight
                        # However it's not mapped locally to this rank
                        # So we simply skip it
                        continue

                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)

        return loaded_params


class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek-V3 causal LM; all behavior is inherited from DeepseekV2.

    Presumably kept as a distinct class so the V3 architecture name can be
    registered separately — NOTE(review): confirm against the model registry.
    """


# Compatibility with
# https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/configuration_deepseek.py
def get_spec_layer_idx_from_weight_name(
    config: DeepseekV2Config | DeepseekV3Config, weight_name: str
) -> int | None:
    """Return the speculative (next-n prediction) layer index a weight belongs to.

    The ``num_nextn_predict_layers`` extra layers are numbered immediately
    after the ``num_hidden_layers`` regular layers, i.e. indices
    ``num_hidden_layers .. num_hidden_layers + num_nextn_predict_layers - 1``.

    Args:
        config: model config; ``num_nextn_predict_layers`` may be absent or
            ``None``, both meaning "no speculative layers".
        weight_name: checkpoint weight name, e.g. ``"model.layers.61.x"``.

    Returns:
        The layer index if ``weight_name`` belongs to a speculative layer,
        otherwise ``None``.
    """
    num_nextn = getattr(config, "num_nextn_predict_layers", None) or 0
    if num_nextn <= 0:
        return None

    first_spec_idx = config.num_hidden_layers
    for layer_idx in range(first_spec_idx, first_spec_idx + num_nextn):
        # The trailing "." prevents e.g. layer 6 from matching "layers.61.".
        if weight_name.startswith(f"model.layers.{layer_idx}."):
            return layer_idx
    return None