deepseek_v2.py 58.1 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14

Liangsheng Yin's avatar
Liangsheng Yin committed
15
16
17
# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
18

19
import os
Liangsheng Yin's avatar
Liangsheng Yin committed
20
21
22
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
Ke Bao's avatar
Ke Bao committed
23
import torch.nn.functional as F
Liangsheng Yin's avatar
Liangsheng Yin committed
24
25
from torch import nn
from transformers import PretrainedConfig
26
27

from sglang.srt.distributed import (
Liangsheng Yin's avatar
Liangsheng Yin committed
28
    get_tensor_model_parallel_world_size,
29
    parallel_state,
Liangsheng Yin's avatar
Liangsheng Yin committed
30
31
    tensor_model_parallel_all_reduce,
)
32
from sglang.srt.layers.activation import SiluAndMul
Lianmin Zheng's avatar
Lianmin Zheng committed
33
from sglang.srt.layers.dp_attention import (
34
    dp_gather_partial,
Lianmin Zheng's avatar
Lianmin Zheng committed
35
36
37
38
    dp_scatter,
    get_attention_dp_size,
    get_attention_tp_rank,
    get_attention_tp_size,
39
40
    tp_all_gather,
    tp_reduce_scatter,
Lianmin Zheng's avatar
Lianmin Zheng committed
41
)
42
from sglang.srt.layers.layernorm import RMSNorm
43
44
45
46
47
48
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
49
from sglang.srt.layers.logits_processor import LogitsProcessor
50
51
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
Ke Bao's avatar
Ke Bao committed
52
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
53
from sglang.srt.layers.moe.topk import select_experts
54
from sglang.srt.layers.quantization.base_config import QuantizationConfig
HandH1998's avatar
HandH1998 committed
55
56
57
from sglang.srt.layers.quantization.fp8_utils import (
    block_quant_to_tensor_quant,
    input_to_float8,
58
    normalize_e4m3fn_to_e4m3fnuz,
HandH1998's avatar
HandH1998 committed
59
)
60
61
62
from sglang.srt.layers.quantization.int8_utils import (
    block_dequant as int8_block_dequant,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
63
from sglang.srt.layers.radix_attention import RadixAttention
64
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
65
66
67
68
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
69
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
70
from sglang.srt.managers.schedule_batch import global_server_args_dict
71
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
72
from sglang.srt.model_loader.weight_utils import default_weight_loader
73
from sglang.srt.utils import DeepEPMode, add_prefix, is_cuda, is_hip
74

75
# Platform probes, evaluated once when the module is imported.
_is_hip = is_hip()
_is_cuda = is_cuda()

# Kernel sources differ per platform: non-CUDA builds fall back to vLLM's
# custom ops, CUDA builds use sgl_kernel.
if not _is_cuda:
    from vllm import _custom_ops as ops
else:
    from sgl_kernel import awq_dequantize, bmm_fp8

# ROCm-only fused MLA decode kernel.
if _is_hip:
    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
        decode_attention_fwd_grouped_rope,
    )

# Module-wide recorder instance shared by everything in this file.
expert_distribution_recorder = ExpertDistributionRecorder()

Liangsheng Yin's avatar
Liangsheng Yin committed
90
91
92
93
94
95
96
97
98

class DeepseekV2MLP(nn.Module):
    """Dense feed-forward block: fused gate/up projection -> SiluAndMul -> down.

    Args:
        hidden_size: Model hidden dimension (input and output width).
        intermediate_size: Width of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to both linears.
        reduce_results: Forwarded to the down projection's ``reduce_results``.
        prefix: Parameter-name prefix for the submodules.
        tp_rank / tp_size: Optional tensor-parallel overrides forwarded to both
            linear layers (``None`` leaves the layers' own defaults in effect).

    Raises:
        ValueError: if ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        # Keyword arguments both parallel linear layers receive identically.
        common = dict(
            bias=False,
            quant_config=quant_config,
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        # Gate and up projections are fused into one column-parallel matmul.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            prefix=add_prefix("gate_up_proj", prefix),
            **common,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            reduce_results=reduce_results,
            prefix=add_prefix("down_proj", prefix),
            **common,
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
            return
        raise ValueError(
            f"Unsupported activation: {hidden_act}. "
            "Only silu is supported for now."
        )

    def forward(self, x):
        # The parallel linear layers return a 2-tuple; only the first item is used.
        gate_up, _ = self.gate_up_proj(x)
        out, _ = self.down_proj(self.act_fn(gate_up))
        return out


Ke Bao's avatar
Ke Bao committed
137
class MoEGate(nn.Module):
    """Expert router: one logit per routed expert from a bias-free linear map.

    ``weight`` has shape ``(config.n_routed_experts, config.hidden_size)``.
    ``e_score_correction_bias`` is allocated only when
    ``config.topk_method == "noaux_tc"``; otherwise it is ``None``.
    Both parameters are allocated uninitialized (``torch.empty``), so they are
    expected to be filled by weight loading.

    ``prefix`` is accepted for constructor uniformity but is not used.
    """

    def __init__(
        self,
        config,
        prefix: str = "",
    ):
        super().__init__()
        expert_count = config.n_routed_experts
        self.weight = nn.Parameter(torch.empty((expert_count, config.hidden_size)))
        has_correction_bias = config.topk_method == "noaux_tc"
        self.e_score_correction_bias = (
            nn.Parameter(torch.empty((expert_count)))
            if has_correction_bias
            else None
        )

    def forward(self, hidden_states):
        # Router logits: (..., n_routed_experts)
        return F.linear(hidden_states, self.weight, None)


Liangsheng Yin's avatar
Liangsheng Yin committed
159
160
161
162
163
164
class DeepseekV2MoE(nn.Module):
    """Routed + shared-expert feed-forward block for DeepSeek-V2 style layers.

    Each token is scored by ``MoEGate`` and processed by
    ``config.num_experts_per_tok`` of ``config.n_routed_experts`` routed
    experts. The routed result is multiplied by ``config.routed_scaling_factor``.
    When ``config.n_shared_experts`` is set, the output of a dense
    shared-expert MLP is added.

    ``global_server_args_dict`` selects the execution path:

    * ``enable_deepep_moe`` off -> ``FusedMoE`` (or ``EPMoE`` when
      ``enable_ep_moe`` is set), followed by a tensor-parallel all-reduce when
      ``tp_size > 1`` (see ``forward_normal``).
    * ``enable_deepep_moe`` on -> ``DeepEPMoE`` wrapped in DeepEP
      dispatch/combine (see ``forward_deepep``). In this mode the shared
      experts are built with a single-rank layout (``tp_rank=0, tp_size=1``).

    Raises:
        ValueError: if the TP world size exceeds the number of routed experts,
            or if ``config.hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}."
            )

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))

        enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
        if enable_deepep_moe:
            MoEImpl = DeepEPMoE
        elif global_server_args_dict["enable_ep_moe"]:
            MoEImpl = EPMoE
        else:
            MoEImpl = FusedMoE

        # Constructor arguments shared by every MoE implementation; only
        # DeepEPMoE additionally receives ``deepep_mode``.
        experts_kwargs = dict(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            prefix=add_prefix("experts", prefix),
        )
        if enable_deepep_moe:
            experts_kwargs["deepep_mode"] = DeepEPMode[
                global_server_args_dict["deepep_mode"]
            ]
        self.experts = MoEImpl(**experts_kwargs)

        if config.n_shared_experts is not None:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            # Disable TP for shared experts when DeepEP MoE is enabled by forcing
            # tp_rank=0 / tp_size=1. Otherwise pass None, which is
            # DeepseekV2MLP's own default (the linear layers then use their
            # default TP setup).
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=add_prefix("shared_experts", prefix),
                tp_rank=0 if enable_deepep_moe else None,
                tp_size=1 if enable_deepep_moe else None,
            )

        if enable_deepep_moe:
            # TODO: we will support tp < ep in the future
            self.ep_size = get_tensor_model_parallel_world_size()
            # Routing hyper-parameters, kept here because forward_deepep calls
            # select_experts itself instead of letting the experts layer do it.
            self.num_experts = config.n_routed_experts
            self.top_k = config.num_experts_per_tok
            self.renormalize = config.norm_topk_prob
            self.topk_group = config.topk_group
            self.num_expert_group = config.n_group
            self.correction_bias = (
                self.gate.e_score_correction_bias.data
                if self.gate.e_score_correction_bias is not None
                else None
            )

            self.deepep_dispatcher = DeepEPDispatcher(
                group=parallel_state.get_tp_group().device_group,
                router_topk=self.top_k,
                permute_fusion=True,
                num_experts=config.n_routed_experts,
                num_local_experts=config.n_routed_experts // self.tp_size,
                hidden_size=config.hidden_size,
                params_dtype=config.torch_dtype,
                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
                async_finish=True,  # TODO
                return_recv_hook=True,
            )

    def forward(
        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
    ) -> torch.Tensor:
        """Dispatch to the DeepEP or the regular MoE path (see class docstring)."""
        if not global_server_args_dict["enable_deepep_moe"]:
            return self.forward_normal(hidden_states)
        else:
            return self.forward_deepep(hidden_states, forward_mode)

    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Regular MoE path: routed experts (+ shared experts), then TP all-reduce.

        Shared experts are built with ``reduce_results=False``. Their partial
        output is therefore summed into the routed output *before* the single
        all-reduce.
        """
        # Bind unconditionally: configs without shared experts must not hit an
        # UnboundLocalError at the ``is not None`` check below.
        shared_output = None
        if self.n_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        final_hidden_states = (
            self.experts(hidden_states=hidden_states, router_logits=router_logits)
            * self.routed_scaling_factor
        )
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
    ) -> torch.Tensor:
        """DeepEP path: local top-k routing, dispatch, experts, combine.

        For idle batches, empty batches, or ``forward_mode is None``, routing is
        skipped. Empty ``(0, top_k)`` top-k tensors are still fed through
        dispatch/combine so that this rank keeps participating in them.
        """
        shared_output = None
        topk_idx = torch.full(
            (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
        )
        topk_weights = torch.empty(
            (0, self.top_k), dtype=torch.float32, device=hidden_states.device
        )
        if (
            forward_mode is not None
            and not forward_mode.is_idle()
            and hidden_states.shape[0] > 0
        ):
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            if self.n_shared_experts is not None:
                shared_output = self.shared_experts(hidden_states)
            topk_weights, topk_idx = select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
                top_k=self.top_k,
                use_grouped_topk=True,
                renormalize=self.renormalize,
                topk_group=self.topk_group,
                num_expert_group=self.num_expert_group,
                correction_bias=self.correction_bias,
            )
        # NOTE(review): when ep_size == 1 the dispatcher is skipped and
        # reorder_topk_ids / seg_indptr / masked_m / expected_m are never bound,
        # so the experts call below would raise UnboundLocalError. DeepEP
        # presumably always runs with ep_size > 1 — confirm before relying on a
        # single-rank DeepEP setup.
        if self.ep_size > 1:
            (
                hidden_states,
                topk_idx,
                topk_weights,
                reorder_topk_ids,
                seg_indptr,
                masked_m,
                expected_m,
            ) = self.deepep_dispatcher.dispatch(
                hidden_states,
                topk_idx,
                topk_weights,
                self.num_experts,
                forward_mode=forward_mode,
            )
        final_hidden_states = (
            self.experts(
                hidden_states=hidden_states,
                reorder_topk_ids=reorder_topk_ids,
                seg_indptr=seg_indptr,
                masked_m=masked_m,
                expected_m=expected_m,
                forward_mode=forward_mode,
            )
            * self.routed_scaling_factor
        )
        if self.ep_size > 1:
            final_hidden_states = self.deepep_dispatcher.combine(
                final_hidden_states,
                topk_idx,
                topk_weights,
                forward_mode,
            )
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output

        return final_hidden_states
Liangsheng Yin's avatar
Liangsheng Yin committed
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN attention magnitude-scaling factor.

    When ``scale <= 1`` there is no context extension and the factor is
    exactly ``1.0``. Otherwise it grows logarithmically with the scale:
    ``0.1 * mscale * ln(scale) + 1.0``.
    """
    from math import log

    return 1.0 if scale <= 1 else 0.1 * mscale * log(scale) + 1.0


class DeepseekV2Attention(nn.Module):
    """Multi-head latent attention computed without weight absorption.

    The compressed KV latent from ``kv_a_proj_with_mqa`` is normalized and
    expanded back into per-head keys/values by ``kv_b_proj``. Rotary embedding
    is applied to the rope slice of q and k. q/k/v are then zero-padded to
    ``_PADDED_HEAD_DIM`` so that ``RadixAttention`` runs with one fixed head
    size, and the output is sliced back to ``v_head_dim`` before ``o_proj``.
    Zero padding does not change q.k scores, and ``self.scaling`` is derived
    from the unpadded ``qk_head_dim``.

    Head-sharded projections (``q_proj``/``q_b_proj``/``kv_b_proj``/``o_proj``)
    all use the attention-TP rank/size, matching ``num_local_heads``.

    Raises:
        ValueError: if ``qk_head_dim`` or ``v_head_dim`` exceeds
            ``_PADDED_HEAD_DIM`` (a negative pad would silently crop).
    """

    # Fixed head width handed to RadixAttention; q/k/v are zero-padded to it.
    # TODO, support head_size 192
    _PADDED_HEAD_DIM = 256

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        layer_id: Optional[int] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        padded = self._PADDED_HEAD_DIM
        if self.qk_head_dim > padded or self.v_head_dim > padded:
            raise ValueError(
                f"Head dims (qk={self.qk_head_dim}, v={self.v_head_dim}) exceed "
                f"the padded attention head size {padded}."
            )

        self.dp_size = get_attention_dp_size()
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        assert num_heads % attn_tp_size == 0
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_a_proj", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            # Sharded over the attention-TP group so its per-rank output matches
            # the num_local_heads view in forward().
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
        # Same attention-TP sharding as q_b_proj / q_proj / o_proj.
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
            reduce_results=reduce_results,
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # rope_scaling defaults to None; only tag it when it is provided
        # (indexing None would raise TypeError).
        if rope_scaling:
            rope_scaling["rope_type"] = "deepseek_yarn"
        self.rotary_emb = get_rope_wrapper(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
            device=global_server_args_dict["device"],
        )

        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.attn = RadixAttention(
            self.num_local_heads,
            padded,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run attention for ``hidden_states`` at ``positions``.

        An empty batch is returned unchanged. That is only safe when ``o_proj``
        does not all-reduce, because otherwise this rank would skip a
        collective.
        """
        if hidden_states.shape[0] == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states

        padded = self._PADDED_HEAD_DIM

        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Latent = [compressed kv (kv_lora_rank) | shared rope key (qk_rope_head_dim)].
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
        k_pe = latent_cache[:, :, self.kv_lora_rank :]
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        # k_pe has a single head dim here and broadcasts across local heads.
        k[..., self.qk_nope_head_dim :] = k_pe
        q = F.pad(q, [0, padded - self.qk_head_dim], value=0).view(
            -1, self.num_local_heads * padded
        )
        k = F.pad(k, [0, padded - self.qk_head_dim], value=0).view(
            -1, self.num_local_heads * padded
        )
        v = F.pad(v, [0, padded - self.v_head_dim], value=0).view(
            -1, self.num_local_heads * padded
        )
        attn_output = self.attn(q, k, v, forward_batch)
        # Drop the zero-padded tail of each head before the output projection.
        attn_output = attn_output.view(-1, self.num_local_heads, padded)[
            ..., : self.v_head_dim
        ].reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output


544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
class DeepseekV2AttentionMLA(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        layer_id: Optional[int] = None,
        prefix: str = "",
    ) -> None:
        """Build projections, rotary embedding and attention kernels for MLA.

        Two RadixAttention instances are created for the same layer:

        * ``attn_mqa``: one KV head of width ``kv_lora_rank + qk_rope_head_dim``
          with value width ``kv_lora_rank``, i.e. attention over the latent.
        * ``attn_mha``: ``num_local_heads`` KV heads of width
          ``qk_nope_head_dim + qk_rope_head_dim`` with value width ``v_head_dim``.

        Head-sharded projections use the attention-TP rank/size.

        Note: when ``rope_scaling`` is truthy, this dict is mutated in place
        (``rope_type`` is set to ``"deepseek_yarn"``).
        """
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.dp_size = get_attention_dp_size()
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        assert num_heads % attn_tp_size == 0
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # For tensor parallel attention
        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_a_proj", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )

        # Replicated (not TP-sharded): produces the compressed KV latent plus
        # the shared rope key slice.
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

        if rope_scaling:
            rope_scaling["rope_type"] = "deepseek_yarn"

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
        )

        if rope_scaling:
            # YaRN: fold the magnitude correction into the softmax scale.
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale
        else:
            # Without rope scaling, route rotary calls to the native implementation.
            self.rotary_emb.forward = self.rotary_emb.forward_native

        self.attn_mqa = RadixAttention(
            self.num_local_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=1,
            layer_id=layer_id,
            v_head_dim=self.kv_lora_rank,
            prefix=add_prefix("attn_mqa", prefix),
        )

        self.attn_mha = RadixAttention(
            self.num_local_heads,
            self.qk_nope_head_dim + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            v_head_dim=self.v_head_dim,
            prefix=add_prefix("attn_mha", prefix),
        )

        # Absorbed-weight slots start empty here; presumably filled in after
        # weight loading (not visible in this section) — confirm.
        self.w_kc = None
        self.w_vc = None
        self.w_scale = None

        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
        self.attention_backend = global_server_args_dict["attention_backend"]
        # Opt-in via environment variable; only the exact value "1" enables it.
        self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"

    def no_absorb(self, forward_batch: ForwardBatch) -> bool:
        """Return True when this batch should use ``forward_normal`` instead of
        the weight-absorbed path.

        - ``fa3``: never; every extend/decode batch is absorbed.
        - ``flashinfer_mla``: only for a pure prefill (see below), and only when
          ragged prefill has not been disabled via
          ``flashinfer_mla_disable_ragged``.
        - any other backend (Triton): for every pure prefill.

        A "pure prefill" is an extend batch that is neither a speculative
        target-verify nor a draft-extend batch and has no cached prefix tokens.
        """
        if self.attention_backend == "fa3":
            # Flash Attention: keep absorbing for all extend/decode batches.
            return False

        if (
            self.attention_backend == "flashinfer_mla"
            and self.flashinfer_mla_disable_ragged
        ):
            # Flashinfer MLA with ragged prefill disabled: always absorb.
            return False

        # Shared rule for flashinfer_mla (ragged prefill enabled) and Triton:
        # use the normal computation only for a pure prefill.
        forward_mode = forward_batch.forward_mode
        return (
            forward_mode.is_extend()
            and not forward_mode.is_target_verify()
            and not forward_mode.is_draft_extend()
            and sum(forward_batch.extend_prefix_lens_cpu) == 0
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run MLA attention, choosing between the normal and absorbed paths.

        Args:
            positions: Token positions used by the rotary embedding.
            hidden_states: Input activations; the first dim is the token count.
            forward_batch: Batch metadata; ``no_absorb`` inspects its forward
                mode to pick the path.

        Returns:
            The attention output. An empty input is returned unchanged.
        """
        if hidden_states.shape[0] == 0:
            # Returning early skips o_proj. That is only safe if o_proj does not
            # run a collective that other ranks would be waiting on.
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states

        if self.no_absorb(forward_batch):
            return self.forward_normal(positions, hidden_states, forward_batch)

        # The fused decode-MLA kernel is HIP-only and opt-in via the
        # SGLANG_ROCM_FUSED_DECODE_MLA environment variable (read in __init__).
        if (
            _is_hip
            and self.rocm_fused_decode_mla
            and forward_batch.forward_mode.is_decode()
        ):
            return self.forward_absorb_fused_mla_rope(
                positions, hidden_states, forward_batch
            )
        return self.forward_absorb(positions, hidden_states, forward_batch)

    def forward_normal(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Non-absorbed MLA path: expand the latent KV with kv_b_proj and run MHA.

        Only the compressed latent (normalized kv_a followed by the rotated k_pe)
        is written to the KV pool. ``self.attn_mha`` is then called on the
        expanded per-head keys/values with ``save_kv_cache=False``.

        Returns:
            The ``o_proj`` output for every token.
        """
        n_heads = self.num_local_heads
        nope_dim = self.qk_nope_head_dim
        rope_dim = self.qk_rope_head_dim
        lora_rank = self.kv_lora_rank

        # Query projection, optionally via the low-rank q_a -> norm -> q_b chain.
        if self.q_lora_rank is None:
            q_proj_out = self.q_proj(hidden_states)[0]
        else:
            q_latent = self.q_a_layernorm(self.q_a_proj(hidden_states)[0])
            q_proj_out = self.q_b_proj(q_latent)[0]
        q = q_proj_out.view(-1, n_heads, self.qk_head_dim)
        q_pe = q.split([nope_dim, rope_dim], dim=-1)[1]

        # Latent KV per token: [kv_a (kv_lora_rank) | k_pe (qk_rope_head_dim)].
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        kv_a = latent_cache.split([lora_rank, rope_dim], dim=-1)[0]
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())

        # Expand the normalized latent into per-head k_nope and v.
        kv = self.kv_b_proj(kv_a)[0].view(-1, n_heads, nope_dim + self.v_head_dim)
        k_nope = kv[..., :nope_dim]
        v = kv[..., nope_dim:]
        k_pe = latent_cache[:, :, lora_rank:]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., nope_dim:] = q_pe

        # Full keys: per-head k_nope plus the single k_pe broadcast over heads.
        k = torch.empty_like(q)
        k[..., :nope_dim] = k_nope
        k[..., nope_dim:] = k_pe

        latent_cache[:, :, :lora_rank] = kv_a.unsqueeze(1)
        latent_cache[:, :, lora_rank:] = k_pe

        # Persist the compressed latent; the expanded k/v are not cached.
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )
        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        output, _ = self.o_proj(attn_output.reshape(-1, n_heads * self.v_head_dim))
        return output

    def forward_absorb(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Weight-absorbed MLA attention over the compressed latent KV.

        The nope part of the query is multiplied by ``self.w_kc`` so it lives in
        the ``kv_lora_rank`` latent space. Attention then runs through
        ``self.attn_mqa``, which has one shared KV head of width
        ``kv_lora_rank + qk_rope_head_dim``. The latent output is mapped back
        to per-head value space with ``self.w_vc`` before ``o_proj``.

        ``w_kc`` / ``w_vc`` (and ``w_scale`` for FP8 weights) are None after
        ``__init__`` and must be populated before this path runs.

        Args:
            positions: Token positions used by the rotary embedding.
            hidden_states: Input activations; the first dim is the token count.
            forward_batch: Batch metadata forwarded to the attention backend.

        Returns:
            The ``o_proj`` output for every token.
        """
        q_len = hidden_states.shape[0]
        # Query fed to attn_mqa: [w_kc-absorbed nope part | rotated rope part].
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Head-batched matmul: q_nope is transposed to (heads, tokens, nope) and
        # multiplied by w_kc, dispatching on the weight dtype.
        if self.w_kc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = input_to_float8(
                q_nope.transpose(0, 1), torch.float8_e4m3fn
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

        # Latent KV: [kv_a (kv_lora_rank) | k_pe (qk_rope_head_dim)]. The
        # normalized kv_a serves both as the value and as the key's first part.
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        # unsqueeze returns a view, so the writes below also update latent_cache.
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input
        k_pe = k_input[..., self.kv_lora_rank :]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q_input[..., self.kv_lora_rank :] = q_pe
        k_input[..., self.kv_lora_rank :] = k_pe

        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # Map the latent attention output back to per-head value space with w_vc.
        if self.w_vc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = input_to_float8(
                attn_output.transpose(0, 1), torch.float8_e4m3fn
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)

        return output

    def forward_absorb_fused_mla_rope(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Absorbed MLA decode that calls ``decode_attention_fwd_grouped_rope`` directly.

        ``forward`` routes here on HIP for decode batches when
        ``SGLANG_ROCM_FUSED_DECODE_MLA=1``. If the environment variable
        ``SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION`` is "1" (the default), RoPE is
        not applied here. It is delegated to the kernel (``use_rope=True``), and
        the key rope part passed back through ``k_pe_output`` is written into
        the KV cache afterwards. Otherwise RoPE is applied up front with
        ``self.rotary_emb``.

        Returns:
            The ``o_proj`` output for every token.
        """
        enable_rope_fusion = (
            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
        )
        q_len = hidden_states.shape[0]
        # Query layout: [w_kc-absorbed nope part (kv_lora_rank) | rope part].
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        if self.w_kc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = input_to_float8(
                q_nope.transpose(0, 1), torch.float8_e4m3fn
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

        # Latent KV: the normalized kv_a is both the value and the key's first part.
        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input

        if not enable_rope_fusion:
            k_pe = k_input[..., self.kv_lora_rank :]
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
            k_input[..., self.kv_lora_rank :] = k_pe
            k_pe_output = None
        else:
            # Passed to the kernel; its contents are copied into the cache below.
            k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])

        # q_pe is already rotated here unless RoPE is fused into the kernel.
        q_input[..., self.kv_lora_rank :] = q_pe

        # Bypass self.attn_mqa and call the grouped decode kernel directly so
        # RoPE can optionally be fused into it (use_rope=enable_rope_fusion).
        attn_output = torch.empty(
            (q_len, self.num_local_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )
        attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
            forward_batch.attn_backend.forward_metadata
        )
        cos_sin_cache = self.rotary_emb.cos_sin_cache
        num_kv_split = forward_batch.attn_backend.num_kv_splits
        sm_scale = self.attn_mqa.scaling
        if attn_logits is None:
            # The backend metadata carries no logits buffer; allocate one of
            # shape (batch, heads, num_kv_split, kv_lora_rank + 1).
            attn_logits = torch.empty(
                (
                    forward_batch.batch_size,
                    self.num_local_heads,
                    num_kv_split,
                    self.kv_lora_rank + 1,
                ),
                dtype=torch.float32,
                device=q.device,
            )

        # Save this step's latent KV before the kernel reads the cache buffers.
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mqa, forward_batch.out_cache_loc, k_input, None
        )
        key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mqa.layer_id
        )
        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

        decode_attention_fwd_grouped_rope(
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            self.kv_lora_rank,
            self.rotary_emb.rotary_dim,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            logit_cap=self.attn_mqa.logit_cap,
            use_rope=enable_rope_fusion,
            is_neox_style=self.rotary_emb.is_neox_style,
        )

        if enable_rope_fusion:
            # The key rope part saved above was unrotated; overwrite it with the
            # kernel-produced k_pe_output and store the latent again.
            k_input[..., self.kv_lora_rank :] = k_pe_output
            forward_batch.token_to_kv_pool.set_kv_buffer(
                self.attn_mqa, forward_batch.out_cache_loc, k_input, None
            )

        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # Map the latent attention output back to per-head value space with w_vc.
        if self.w_vc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = input_to_float8(
                attn_output.transpose(0, 1), torch.float8_e4m3fn
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)

        return output


class DeepseekV2DecoderLayer(nn.Module):
    """One DeepSeek-V2 decoder layer: self-attention followed by a dense MLP or MoE.

    The attention module is DeepseekV2AttentionMLA unless the ``disable_mla``
    server arg is set, in which case DeepseekV2Attention is used. The MLP is
    DeepseekV2MoE for NextN layers and for layers selected by the
    ``is_sparse_layer`` rule in ``__init__``; otherwise it is DeepseekV2MLP.

    Every forward path returns ``(hidden_states, residual)``. The residual is
    carried separately and handed back to the next layer's input RMSNorm (or to
    the model's final norm) together with ``hidden_states``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        is_nextn: bool = False,
        prefix: str = "",
    ) -> None:

        def is_sparse_layer(layer_idx: int) -> bool:
            # MoE layers: routed experts exist, the layer is past the first
            # `first_k_dense_replace` dense layers, and it is on the
            # `moe_layer_freq` stride.
            return (
                config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0
            )

        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
        self.layer_id = layer_id
        self.dp_size = get_attention_dp_size()
        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()

        # Both attention implementations are built from the same arguments.
        attn_cls = (
            DeepseekV2Attention
            if global_server_args_dict["disable_mla"]
            else DeepseekV2AttentionMLA
        )
        self.self_attn = attn_cls(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            layer_id=layer_id,
            # Cross-rank reduction of the attention output is done by this
            # layer's forward paths (all-reduce / DP gather / reduce-scatter).
            reduce_results=False,
            prefix=add_prefix("self_attn", prefix),
        )

        if is_nextn or is_sparse_layer(layer_id):
            self.mlp = DeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
            self.is_sparse = True
        else:
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
            self.is_sparse = False

        # forward_deepep leaves each attention-TP rank with only its slice of
        # the tokens, so a layer that follows a DeepEP sparse layer starts from
        # a scattered input.
        self.input_is_scattered = (
            is_sparse_layer(layer_id - 1)
            and global_server_args_dict["enable_deepep_moe"]
        )
        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Use forward_deepep for sparse layers under DeepEP MoE, else forward_normal."""
        if global_server_args_dict["enable_deepep_moe"] and self.is_sparse:
            return self.forward_deepep(
                positions, hidden_states, forward_batch, residual
            )
        return self.forward_normal(positions, hidden_states, forward_batch, residual)

    def forward_normal(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Attention + MLP with all-reduce (TP) or gather/scatter (DP attention).

        Returns:
            ``(hidden_states, residual)`` for the next layer.
        """
        if hidden_states.shape[0] == 0:
            residual = hidden_states
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(hidden_states, residual)

            # A scattered input (previous layer ran forward_deepep) is not
            # handled on this path when attention TP > 1.
            assert not (
                self.attn_tp_size != 1 and self.input_is_scattered
            ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"

            # Self Attention
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
            )

        # Gather
        if get_tensor_model_parallel_world_size() > 1:
            # all gather and all reduce
            if self.dp_size != 1:
                # The residual is added on attention-TP rank 0 only; presumably
                # dp_gather_partial sums the partial outputs across that group,
                # so the residual must be counted exactly once.
                if self.attn_tp_rank == 0:
                    hidden_states += residual
                hidden_states, local_hidden_states = (
                    forward_batch.gathered_buffer,
                    hidden_states,
                )
                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
                # The gathered sum already contains the residual, so refresh the
                # local residual from it and normalize without a residual add.
                dp_scatter(residual, hidden_states, forward_batch)
                hidden_states = self.post_attention_layernorm(hidden_states)
            else:
                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual
                )
        else:
            hidden_states, residual = self.post_attention_layernorm(
                hidden_states, residual
            )

        # Fully Connected
        hidden_states = self.mlp(hidden_states)

        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
        # Scatter
        if self.dp_size != 1:
            # important: forward_batch.gathered_buffer is used both after scatter and after gather.
            # be careful about this!
            hidden_states, global_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            dp_scatter(hidden_states, global_hidden_states, forward_batch)

        return hidden_states, residual

    def forward_deepep(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Layer forward for sparse (MoE) layers when DeepEP MoE is enabled.

        With attention TP > 1, the attention output is reduce-scattered so each
        rank runs the MoE on its own slice of the tokens. The full token set is
        all-gathered again before attention of a layer with a scattered input,
        and after the last layer. In the last-layer case the residual is first
        folded into ``hidden_states`` and None is returned as the residual.
        """
        if hidden_states.shape[0] == 0:
            residual = hidden_states
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(hidden_states, residual)

        if self.attn_tp_size != 1 and self.input_is_scattered:
            # Rebuild the full token set from every rank's slice before attention.
            hidden_states, local_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            tp_all_gather(
                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
            )

        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )

        if self.attn_tp_size != 1:
            if self.input_is_scattered:
                # residual is already this rank's slice; reduce-scatter the
                # attention output to the same slice.
                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
                hidden_states = tensor_list[self.attn_tp_rank]
                tp_reduce_scatter(hidden_states, tensor_list)
                if hidden_states.shape[0] != 0:
                    hidden_states, residual = self.post_attention_layernorm(
                        hidden_states, residual
                    )
            else:
                # residual still covers every token: add it on one rank only,
                # reduce-scatter the sum, and keep the local slice as residual.
                if self.attn_tp_rank == 0:
                    hidden_states += residual
                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
                hidden_states = tensor_list[self.attn_tp_rank]
                tp_reduce_scatter(hidden_states, tensor_list)
                residual = hidden_states
                if hidden_states.shape[0] != 0:
                    hidden_states = self.post_attention_layernorm(hidden_states)
        else:
            if hidden_states.shape[0] != 0:
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual
                )
        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

        if self.is_last_layer and self.attn_tp_size != 1:
            # Fold the residual in and restore the full token set for the
            # model's final norm.
            hidden_states += residual
            residual = None
            hidden_states, local_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            tp_all_gather(
                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
            )

        return hidden_states, residual


class DeepseekV2Model(nn.Module):
    """DeepSeek-V2 decoder stack: token embedding, decoder layers and final RMSNorm."""

    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size

        # The embedding is not tensor-parallel-sharded when DP attention is on.
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            enable_tp=not global_server_args_dict["enable_dp_attention"],
        )
        self.layers = nn.ModuleList(
            [
                DeepseekV2DecoderLayer(
                    config,
                    layer_id,
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{layer_id}", prefix),
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.dp_size = get_attention_dp_size()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run all decoder layers and return the final hidden states.

        Args:
            input_ids: Token ids, embedded unless ``input_embeds`` is given.
            positions: Token positions used by the rotary embeddings.
            forward_batch: Batch metadata shared by all layers.
            input_embeds: Optional precomputed input embeddings.

        Returns:
            Hidden states after the final RMSNorm. Idle batches skip the norm.
        """
        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        residual = None
        for layer_idx, layer in enumerate(self.layers):
            # Tell the expert-distribution recorder which layer is running.
            expert_distribution_recorder.set_current_layer(layer_idx)
            hidden_states, residual = layer(
                positions, hidden_states, forward_batch, residual
            )
        if not forward_batch.forward_mode.is_idle():
            # residual can be None, e.g. when the last DeepEP layer has already
            # added it into hidden_states.
            if residual is None:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekV2ForCausalLM(nn.Module):
    """Causal-LM head around ``DeepseekV2Model``.

    Holds the decoder stack (``self.model``), the vocabulary projection
    (``self.lm_head``) and a ``LogitsProcessor`` that turns the final hidden
    states into the returned output. ``load_weights`` maps checkpoint tensor
    names onto this module's parameters and, unless MLA is disabled, derives
    the per-layer ``w_kc`` / ``w_vc`` matrices from ``kv_b_proj``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
        )
        self.logits_processor = LogitsProcessor(config)
        self.dp_size = get_attention_dp_size()

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token-embedding module of the wrapped model."""
        return self.model.embed_tokens

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder stack and hand its output to the logits processor.

        Args:
            input_ids: Token ids (also forwarded to the logits processor).
            positions: Token positions.
            forward_batch: Batch metadata shared by the model and the
                logits processor.
            input_embeds: Optional precomputed embeddings that bypass the
                embedding lookup.

        Returns:
            Whatever ``self.logits_processor`` returns for these inputs.
        """
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this model's parameters.

        Routing, per checkpoint name:
          * weights of decoder layers with index >= ``num_hidden_layers``
            (extra nextn layers) and ``rotary_emb.inv_freq`` are skipped;
          * dense ``gate_proj`` / ``up_proj`` go into the fused
            ``gate_up_proj`` parameter as shard 0 / 1;
          * routed-expert weights go through the MoE implementation's
            expert mapping;
          * everything else is loaded by exact name (unmatched ``.bias``
            tensors are skipped, as GPTQ checkpoints may carry extras).

        Afterwards, unless ``disable_mla`` is set, ``_process_mla_weights``
        derives the MLA ``w_kc`` / ``w_vc`` matrices from ``kv_b_proj``.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Raises:
            KeyError: If a mapped name has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        MoEImpl = (
            DeepEPMoE
            if global_server_args_dict["enable_deepep_moe"]
            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
        )
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = MoEImpl.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts,
        )

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            # TODO(HandH1998): Modify it when nextn is supported.
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
                if num_nextn_layers > 0 and name.startswith("model.layers"):
                    name_list = name.split(".")
                    if (
                        len(name_list) >= 3
                        and int(name_list[2]) >= self.config.num_hidden_layers
                    ):
                        continue
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Not a stacked dense weight: try the routed-expert mapping.
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Plain parameter: load by exact name.
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

        if not global_server_args_dict["disable_mla"]:
            self._process_mla_weights()

    def _process_mla_weights(self) -> None:
        """Derive each layer's ``w_kc`` / ``w_vc`` (and ``w_scale``) from ``kv_b_proj``.

        For every decoder layer the ``kv_b_proj`` weight is first brought to
        a usable form: AWQ weights are dequantized, block-fp8 weights are
        requantized to a single per-tensor scale, and int8 weights are
        dequantized to bfloat16. The result is then split per head into the
        key part (``qk_nope_head_dim`` rows) and the value part
        (``v_head_dim`` rows).
        """
        for layer_id in range(self.config.num_hidden_layers):
            self_attn = self.model.layers[layer_id].self_attn
            if hasattr(self_attn.kv_b_proj, "qweight"):
                # AWQ compatible
                if _is_cuda:
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                    ).T
                else:
                    w = ops.awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                        0,
                        0,
                        0,
                    ).T
            else:
                w = self_attn.kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
            # This may affect the accuracy of fp8 model.
            if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
            ):
                weight_block_size = self.quant_config.weight_block_size
                if weight_block_size is not None:
                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                    if _is_hip:
                        # ROCm path: convert the e4m3fn weight/scale pair to
                        # e4m3fnuz before requantizing.
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                            weight=w,
                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                            input_scale=None,
                        )
                    else:
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv

                    w, scale = block_quant_to_tensor_quant(
                        weight, weight_scale, weight_block_size
                    )
                    self_attn.w_scale = scale
            if w.dtype == torch.int8:
                if hasattr(self.quant_config, "weight_block_size"):
                    # block-wise int8 need it
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
                        w = int8_block_dequant(
                            weight, weight_scale, weight_block_size
                        ).to(torch.bfloat16)
                else:
                    # channel-wise int8 need it
                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
                        torch.bfloat16
                    )
            # Rows of w are grouped per head as [k_nope rows | v rows]; split
            # them into the key and value up-projections.
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            # transpose -> contiguous -> transpose keeps w_kc's logical shape
            # but gives it a transposed (column-major per head) memory layout.
            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
            if (
                hasattr(self_attn.kv_b_proj, "weight_scale")
                and self_attn.w_scale is None
            ):
                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                if _is_hip:
                    # NOTE(review): presumably compensates for the
                    # e4m3fn -> e4m3fnuz encoding difference on ROCm; `*=` is
                    # in-place, so kv_b_proj.weight_scale is scaled too —
                    # confirm that is intended.
                    self_attn.w_scale *= 2.0

    def get_embed_and_head(self):
        """Return ``(embedding weight, lm_head weight)`` as a tuple."""
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        """Replace the embedding and lm_head weights with ``embed`` / ``head``.

        The old weight attributes are deleted first so their memory can be
        released, then the CUDA cache is emptied and the device synchronized.
        """
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()


class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek-V3 causal LM; reuses the DeepSeek-V2 implementation unchanged."""


# Model classes exported by this module (presumably picked up by the model
# registry via this name — confirm against the loader).
EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]