deepseek_v2.py 79.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14

Liangsheng Yin's avatar
Liangsheng Yin committed
15
16
17
# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
18

19
import logging
20
import os
21
from enum import IntEnum, auto
Liangsheng Yin's avatar
Liangsheng Yin committed
22
23
24
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
Ke Bao's avatar
Ke Bao committed
25
import torch.nn.functional as F
Liangsheng Yin's avatar
Liangsheng Yin committed
26
from torch import nn
27
from tqdm import tqdm
Liangsheng Yin's avatar
Liangsheng Yin committed
28
from transformers import PretrainedConfig
29
30

from sglang.srt.distributed import (
Liangsheng Yin's avatar
Liangsheng Yin committed
31
    get_tensor_model_parallel_world_size,
32
    parallel_state,
Liangsheng Yin's avatar
Liangsheng Yin committed
33
34
    tensor_model_parallel_all_reduce,
)
35
from sglang.srt.layers.activation import SiluAndMul
36
37
38
39
40
from sglang.srt.layers.communicator import (
    LayerCommunicator,
    LayerScatterModes,
    enable_moe_dense_fully_dp,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
41
42
43
from sglang.srt.layers.dp_attention import (
    get_attention_tp_rank,
    get_attention_tp_size,
44
    get_local_attention_dp_size,
Lianmin Zheng's avatar
Lianmin Zheng committed
45
)
46
from sglang.srt.layers.layernorm import RMSNorm
47
48
49
50
51
52
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
53
from sglang.srt.layers.logits_processor import LogitsProcessor
fzyzcjy's avatar
fzyzcjy committed
54
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
Lianmin Zheng's avatar
Lianmin Zheng committed
55
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
56
from sglang.srt.layers.moe.topk import select_experts
57
from sglang.srt.layers.quantization.base_config import QuantizationConfig
58
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
59
60
from sglang.srt.layers.quantization.fp8_kernel import (
    per_tensor_quant_mla_fp8,
61
    per_token_group_quant_mla_deep_gemm_masked_fp8,
62
)
HandH1998's avatar
HandH1998 committed
63
from sglang.srt.layers.quantization.fp8_utils import (
64
    block_quant_dequant,
HandH1998's avatar
HandH1998 committed
65
    block_quant_to_tensor_quant,
66
    channel_quant_to_tensor_quant,
67
    normalize_e4m3fn_to_e4m3fnuz,
HandH1998's avatar
HandH1998 committed
68
)
69
70
71
from sglang.srt.layers.quantization.int8_utils import (
    block_dequant as int8_block_dequant,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
72
from sglang.srt.layers.radix_attention import RadixAttention
73
from sglang.srt.layers.rotary_embedding import get_rope
74
75
76
77
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
78
79
80
81
from sglang.srt.managers.expert_distribution import (
    get_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
82
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
83
from sglang.srt.managers.schedule_batch import global_server_args_dict
84
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
85
from sglang.srt.model_loader.weight_utils import default_weight_loader
86
87
88
89
from sglang.srt.two_batch_overlap import (
    MaybeTboDeepEPDispatcher,
    model_forward_maybe_tbo,
)
90
91
92
93
94
95
96
97
from sglang.srt.utils import (
    BumpAllocator,
    DeepEPMode,
    add_prefix,
    get_bool_env_var,
    get_int_env_var,
    is_cuda,
    is_hip,
98
    is_non_idle_and_non_empty,
99
    log_info_on_rank0,
100
)
101

102
_is_hip = is_hip()
Yineng Zhang's avatar
Yineng Zhang committed
103
_is_cuda = is_cuda()
104

Yineng Zhang's avatar
Yineng Zhang committed
105
if _is_cuda:
106
    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
107
108
109
110

    from sglang.srt.layers.quantization.deep_gemm import (
        grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
    )
Yineng Zhang's avatar
Yineng Zhang committed
111
else:
Lianmin Zheng's avatar
Lianmin Zheng committed
112
    from vllm._custom_ops import awq_dequantize
Liangsheng Yin's avatar
Liangsheng Yin committed
113

114
115
116
117
118
if _is_hip:
    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
        decode_attention_fwd_grouped_rope,
    )

119
120
logger = logging.getLogger(__name__)

Liangsheng Yin's avatar
Liangsheng Yin committed
121

122
123
124
125
126
127
128
129
130
131
132
class AttnForwardMethod(IntEnum):
    """Attention computation paths that an attention layer can choose between.

    The integer values are the ones ``auto()`` would assign (1, 2, 3, 4 in
    declaration order); they are spelled out so the mapping is visible.
    """

    # Use multi-head attention
    MHA = 1

    # Use absorbed multi-latent attention
    MLA = 2

    # Use multi-head attention, but with KV cache chunked.
    # This method can avoid OOM when prefix lengths are long.
    MHA_CHUNKED_KV = 3

    # Use MLA but with fused RoPE
    MLA_FUSED_ROPE = 4

136

Liangsheng Yin's avatar
Liangsheng Yin committed
137
138
139
140
141
142
143
144
class DeepseekV2MLP(nn.Module):
    """Feed-forward block: merged gate/up projection, ``SiluAndMul``, down projection.

    Only the ``"silu"`` activation is accepted; any other value raises
    ``ValueError`` at construction time.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
    ) -> None:
        """Build the projections.

        Args:
            hidden_size: Input and output width of the block.
            intermediate_size: Width of each of the two merged gate/up outputs.
            hidden_act: Activation name; must be ``"silu"``.
            quant_config: Quantization config forwarded to both linear layers.
            reduce_results: Forwarded to the down projection.
            prefix: Name prefix used to build the sub-layer prefixes.
            tp_rank: Tensor-parallel rank forwarded to both linear layers.
            tp_size: Tensor-parallel size forwarded to both linear layers and
                remembered for the empty-input shortcut in ``forward``.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        self.tp_size = tp_size

        # Keyword arguments that both projections receive unchanged.
        common_kwargs = dict(
            bias=False,
            quant_config=quant_config,
            tp_rank=tp_rank,
            tp_size=tp_size,
        )

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            prefix=add_prefix("gate_up_proj", prefix),
            **common_kwargs,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            reduce_results=reduce_results,
            prefix=add_prefix("down_proj", prefix),
            **common_kwargs,
        )

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x, forward_batch=None):
        """Apply the block to ``x``.

        ``forward_batch`` is accepted but not used here. When ``tp_size`` is 1
        and ``x`` has zero rows, ``x`` is returned as-is without running the
        projections.
        """
        if (self.tp_size == 1) and x.shape[0] == 0:
            return x

        merged, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged)
        projected, _ = self.down_proj(activated)
        return projected


Ke Bao's avatar
Ke Bao committed
188
class MoEGate(nn.Module):
    """Router for the MoE layer: a bias-free linear map from hidden states to expert logits."""

    def __init__(
        self,
        config,
        prefix: str = "",
    ):
        """Allocate the router parameters.

        Args:
            config: Object providing ``n_routed_experts``, ``hidden_size`` and
                ``topk_method``.
            prefix: Accepted for interface parity with other layers; not used here.
        """
        super().__init__()
        num_experts = config.n_routed_experts

        # Uninitialized storage; values are expected to be filled in afterwards.
        self.weight = nn.Parameter(torch.empty((num_experts, config.hidden_size)))

        # The per-expert score correction bias exists only for "noaux_tc".
        use_correction_bias = config.topk_method == "noaux_tc"
        self.e_score_correction_bias = (
            nn.Parameter(torch.empty(num_experts)) if use_correction_bias else None
        )

    def forward(self, hidden_states):
        """Return ``hidden_states @ weight.T`` (no bias)."""
        return F.linear(hidden_states, self.weight)


Liangsheng Yin's avatar
Liangsheng Yin committed
210
211
212
213
214
class DeepseekV2MoE(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the router, the routed experts and (optionally) the shared experts.

        Args:
            config: Model config; the fields read here include
                ``n_routed_experts``, ``n_shared_experts``,
                ``routed_scaling_factor``, ``num_experts_per_tok``,
                ``moe_intermediate_size``, ``hidden_size``, ``hidden_act``,
                ``norm_topk_prob``, ``n_group``, ``topk_group`` and
                ``torch_dtype``.
            layer_id: Index of the layer this module belongs to.
            quant_config: Quantization config forwarded to the expert layers.
            prefix: Name prefix used to build the sub-module prefixes.

        Raises:
            ValueError: If the tensor-parallel size exceeds the number of
                routed experts, or if ``config.hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
        self.config = config
        self.layer_id = layer_id

        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}."
            )

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))

        moe_impl_cls = get_moe_impl_class()
        # Fused shared experts and redundant experts are counted as extra
        # expert slots; fusion also adds one slot to top-k.
        total_expert_slots = (
            config.n_routed_experts
            + self.n_share_experts_fusion
            + global_server_args_dict["ep_num_redundant_experts"]
        )
        experts_top_k = config.num_experts_per_tok + min(
            self.n_share_experts_fusion, 1
        )
        # ``deepep_mode`` is passed only when DeepEP MoE is enabled.
        experts_extra_kwargs = {}
        if global_server_args_dict["enable_deepep_moe"]:
            experts_extra_kwargs["deepep_mode"] = DeepEPMode[
                global_server_args_dict["deepep_mode"]
            ]
        self.experts = moe_impl_cls(
            num_experts=total_expert_slots,
            top_k=experts_top_k,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            layer_id=self.layer_id,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            prefix=add_prefix("experts", prefix),
            **experts_extra_kwargs,
        )

        if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
            shared_intermediate_size = (
                config.moe_intermediate_size * config.n_shared_experts
            )
            # disable tp for shared experts when enable deepep moe
            shared_tp_kwargs = {}
            if global_server_args_dict["enable_deepep_moe"]:
                shared_tp_kwargs = dict(tp_rank=0, tp_size=1)
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=shared_intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=add_prefix("shared_experts", prefix),
                **shared_tp_kwargs,
            )

        self.top_k = config.num_experts_per_tok

        if global_server_args_dict["enable_deepep_moe"]:
            # TODO: we will support tp < ep in the future
            self.ep_size = get_tensor_model_parallel_world_size()
            self.num_experts = (
                config.n_routed_experts
                + global_server_args_dict["ep_num_redundant_experts"]
            )
            self.renormalize = config.norm_topk_prob
            self.topk_group = config.topk_group
            self.num_expert_group = config.n_group

            # Keep the raw tensor of the gate's correction bias, if it has one.
            gate_bias = self.gate.e_score_correction_bias
            if gate_bias is not None:
                self.correction_bias = gate_bias.data
            else:
                self.correction_bias = None

            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
                group=parallel_state.get_tp_group().device_group,
                router_topk=self.top_k,
                permute_fusion=True,
                num_experts=self.num_experts,
                num_local_experts=config.n_routed_experts // self.tp_size,
                hidden_size=config.hidden_size,
                params_dtype=config.torch_dtype,
                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
                async_finish=True,
                return_recv_hook=True,
            )

        self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
313

314
315
316
317
318
319
320
    def get_moe_weights(self):
        """Return the ``.data`` of every parameter of ``self.experts`` except ``correction_bias``."""
        return [
            param.data
            for param_name, param in self.experts.named_parameters()
            if param_name != "correction_bias"
        ]

321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
    def forward(
        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
    ) -> torch.Tensor:
        """Run the MoE layer.

        Uses ``forward_deepep`` when DeepEP MoE was enabled at construction
        time, and ``forward_normal`` (which ignores ``forward_batch``) otherwise.
        """
        if self._enable_deepep_moe:
            return self.forward_deepep(hidden_states, forward_batch)
        return self.forward_normal(hidden_states)

    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Non-DeepEP path: shared experts + routed experts, then TP all-reduce.

        The routed output is scaled in place by ``routed_scaling_factor``, the
        shared-expert output (if any) is added, and the sum is all-reduced
        across tensor-parallel ranks when ``tp_size > 1``.
        """
        shared_out = self._forward_shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        logits = self.gate(hidden_states)
        routed_out = self.experts(hidden_states=hidden_states, router_logits=logits)
        routed_out *= self.routed_scaling_factor

        combined = routed_out if shared_out is None else routed_out + shared_out
        if self.tp_size > 1:
            return tensor_model_parallel_all_reduce(combined)
        return combined

    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """DeepEP path: route, dispatch across ranks, run experts, combine.

        Args:
            hidden_states: Input to the MoE layer.
            forward_batch: Batch info; ``forward_mode`` and
                ``num_token_non_padded`` are read from it.

        Returns:
            The routed-expert output scaled by ``routed_scaling_factor``, plus
            the shared-expert output when one was computed.
        """
        forward_mode = forward_batch.forward_mode
        shared_output = None

        # These values are produced by the dispatcher, which only runs when
        # ``ep_size > 1``, yet they are passed to ``self.experts``
        # unconditionally. Bind them up front so the ``ep_size == 1`` case
        # passes ``None`` instead of raising UnboundLocalError.
        reorder_topk_ids = None
        num_recv_tokens_per_expert = None
        seg_indptr = None
        masked_m = None
        expected_m = None

        if is_non_idle_and_non_empty(forward_mode, hidden_states):
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            shared_output = self._forward_shared_experts(hidden_states)
            # NOTE(review): op_select_experts additionally passes
            # ``expert_location_dispatch_info``; confirm whether this path
            # should do the same.
            topk_weights, topk_idx = select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
                top_k=self.top_k,
                use_grouped_topk=True,
                renormalize=self.renormalize,
                topk_group=self.topk_group,
                num_expert_group=self.num_expert_group,
                correction_bias=self.correction_bias,
                routed_scaling_factor=self.routed_scaling_factor,
                num_token_non_padded=forward_batch.num_token_non_padded,
            )
        else:
            # Nothing to route on this rank: use zero-row placeholders.
            topk_idx = torch.full(
                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
            )
            topk_weights = torch.empty(
                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
            )
        if self.ep_size > 1:
            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
            (
                hidden_states,
                topk_idx,
                topk_weights,
                reorder_topk_ids,
                num_recv_tokens_per_expert,
                seg_indptr,
                masked_m,
                expected_m,
            ) = self.deepep_dispatcher.dispatch(
                hidden_states=hidden_states,
                topk_idx=topk_idx,
                topk_weights=topk_weights,
                forward_mode=forward_mode,
            )
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            reorder_topk_ids=reorder_topk_ids,
            seg_indptr=seg_indptr,
            masked_m=masked_m,
            expected_m=expected_m,
            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
            forward_mode=forward_mode,
        )
        if self.ep_size > 1:
            final_hidden_states = self.deepep_dispatcher.combine(
                hidden_states=final_hidden_states,
                topk_idx=topk_idx,
                topk_weights=topk_weights,
                forward_mode=forward_mode,
            )
        final_hidden_states *= self.routed_scaling_factor

        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output

        return final_hidden_states

    def _forward_shared_experts(self, hidden_states):
        """Run the shared experts on ``hidden_states``, or return ``None`` if there are none.

        ``self.shared_experts`` is only created by the constructor when the
        config declares shared experts (``n_shared_experts is not None``) and
        they are not fused into the routed experts
        (``n_share_experts_fusion == 0``). The same condition is checked here
        so a config without shared experts yields ``None`` instead of an
        ``AttributeError``.
        """
        if self.n_shared_experts is not None and self.n_share_experts_fusion == 0:
            return self.shared_experts(hidden_states)
        return None

419
    def op_gate(self, state):
        """Store the router logits for ``state.hidden_states_mlp_input`` on ``state``.

        ``state.router_logits`` is set to ``None`` when
        ``is_non_idle_and_non_empty`` reports false for the batch.
        """
        mlp_input = state.hidden_states_mlp_input
        has_work = is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, mlp_input
        )
        # router_logits: (num_tokens, n_experts)
        state.router_logits = self.gate(mlp_input) if has_work else None
427

428
    def op_shared_experts(self, state):
        """Pop the MLP input from ``state`` and store the shared-expert output.

        ``state.shared_output`` is ``None`` when this layer has no separate
        shared experts or when ``is_non_idle_and_non_empty`` reports false.
        """
        hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
        # Mirror the constructor: ``self.shared_experts`` exists only when the
        # config declares shared experts and they are not fused into the
        # routed experts. Checking both avoids an AttributeError when
        # ``n_shared_experts`` is None.
        has_shared_experts = (
            self.n_shared_experts is not None and self.n_share_experts_fusion == 0
        )
        if has_shared_experts and is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, hidden_states_mlp_input
        ):
            state.shared_output = self.shared_experts(hidden_states_mlp_input)
        else:
            state.shared_output = None
436

437
    def op_select_experts(self, state):
        """Pop the router logits from ``state`` and store the local top-k selection.

        Writes ``state.topk_weights_local`` and ``state.topk_idx_local``. When
        the popped logits are ``None``, zero-row placeholders of width
        ``top_k`` are stored instead (indices filled with -1).
        """
        router_logits = state.pop("router_logits")
        hidden_states = state.hidden_states_mlp_input

        if router_logits is None:
            device = hidden_states.device
            state.topk_idx_local = torch.full(
                (0, self.top_k), -1, dtype=torch.int, device=device
            )
            state.topk_weights_local = torch.empty(
                (0, self.top_k), dtype=torch.float32, device=device
            )
            return

        dispatch_info = ExpertLocationDispatchInfo.init_new(
            layer_id=self.layer_id,
        )
        selected_weights, selected_idx = select_experts(
            hidden_states=hidden_states,
            router_logits=router_logits,
            top_k=self.top_k,
            use_grouped_topk=True,
            renormalize=self.renormalize,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            correction_bias=self.correction_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            expert_location_dispatch_info=dispatch_info,
        )
        state.topk_weights_local = selected_weights
        state.topk_idx_local = selected_idx
463

464
    def op_dispatch_a(self, state):
        """Start the DeepEP dispatch (first half) for this state.

        Does nothing unless ``ep_size > 1``. The local top-k indices and
        weights are popped from ``state`` and handed to the dispatcher.
        """
        if self.ep_size <= 1:
            return

        # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
        mlp_input = state.hidden_states_mlp_input
        local_idx = state.pop("topk_idx_local")
        local_weights = state.pop("topk_weights_local")
        self.deepep_dispatcher.dispatch_a(
            hidden_states=mlp_input,
            topk_idx=local_idx,
            topk_weights=local_weights,
            forward_mode=state.forward_batch.forward_mode,
            tbo_subbatch_index=state.get("tbo_subbatch_index"),
        )
474

475
    def op_dispatch_b(self, state):
        """Finish the DeepEP dispatch (second half) and store its outputs on ``state``.

        Does nothing unless ``ep_size > 1``. The dispatcher call runs inside
        the global expert-distribution recorder's context for this layer.
        """
        if self.ep_size <= 1:
            return

        recorder = get_global_expert_distribution_recorder()
        with recorder.with_current_layer(self.layer_id):
            dispatched = self.deepep_dispatcher.dispatch_b(
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )
            # Unpack the eight dispatcher outputs onto the state, in order.
            (
                state.hidden_states_experts_input,
                state.topk_idx_dispatched,
                state.topk_weights_dispatched,
                state.reorder_topk_ids,
                state.num_recv_tokens_per_expert,
                state.seg_indptr,
                state.masked_m,
                state.expected_m,
            ) = dispatched
492
493

    def op_experts(self, state):
        """Run the routed experts on the dispatched inputs held in ``state``.

        Most inputs are popped from ``state``; the dispatched top-k indices
        and weights are only read, so they remain on ``state`` afterwards. The
        result is stored as ``state.hidden_states_experts_output``.
        """
        # Gather the arguments in the same order the experts call lists them.
        expert_kwargs = dict(
            hidden_states=state.pop("hidden_states_experts_input"),
            topk_idx=state.topk_idx_dispatched,
            topk_weights=state.topk_weights_dispatched,
            reorder_topk_ids=state.pop("reorder_topk_ids"),
            seg_indptr=state.pop("seg_indptr"),
            masked_m=state.pop("masked_m"),
            expected_m=state.pop("expected_m"),
            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
            forward_mode=state.forward_batch.forward_mode,
        )
        state.hidden_states_experts_output = self.experts(**expert_kwargs)
505

506
    def op_combine_a(self, state):
        """Start the DeepEP combine (first half) for this state.

        Does nothing unless ``ep_size > 1``. The expert output and the
        dispatched top-k indices/weights are popped from ``state``.
        """
        if self.ep_size <= 1:
            return

        experts_output = state.pop("hidden_states_experts_output")
        dispatched_idx = state.pop("topk_idx_dispatched")
        dispatched_weights = state.pop("topk_weights_dispatched")
        self.deepep_dispatcher.combine_a(
            hidden_states=experts_output,
            topk_idx=dispatched_idx,
            topk_weights=dispatched_weights,
            forward_mode=state.forward_batch.forward_mode,
            tbo_subbatch_index=state.get("tbo_subbatch_index"),
        )
515

516
    def op_combine_b(self, state):
        """Finish the DeepEP combine and store it as ``state.hidden_states_after_combine``.

        Does nothing unless ``ep_size > 1``.
        """
        if self.ep_size <= 1:
            return

        subbatch_index = state.get("tbo_subbatch_index")
        state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
            tbo_subbatch_index=subbatch_index,
        )
521
522

    def op_output(self, state):
        """Produce ``state.hidden_states_mlp_output`` from the combined expert output.

        The combined output is popped from ``state`` and scaled in place by
        ``routed_scaling_factor``; the popped shared-expert output is added
        when it is not ``None``.
        """
        mlp_output = state.pop("hidden_states_after_combine")
        mlp_output *= self.routed_scaling_factor

        shared_output = state.pop("shared_output")
        if shared_output is not None:
            mlp_output = mlp_output + shared_output

        state.hidden_states_mlp_output = mlp_output
529

Liangsheng Yin's avatar
Liangsheng Yin committed
530
531
532
533
534
535
536
537
538

def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the magnitude-scaling factor for a given rope scaling factor.

    Args:
        scale: Rope scaling factor.
        mscale: Multiplier applied to the logarithmic term.

    Returns:
        ``1.0`` when ``scale <= 1``; otherwise
        ``0.1 * mscale * ln(scale) + 1.0``.
    """
    import math

    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0


539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
class DeepseekV2AttentionMLA(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        layer_id: Optional[int] = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """Build the projections, rotary embedding and attention layers.

        Args:
            config: Model config; only ``rms_norm_eps`` is read here.
            hidden_size: Width of the layer input and of the output projection.
            num_heads: Total number of attention heads; must be divisible by
                the attention tensor-parallel size.
            qk_nope_head_dim: Per-head query/key width without rotary embedding.
            qk_rope_head_dim: Per-head query/key width with rotary embedding.
            v_head_dim: Per-head value width.
            q_lora_rank: Rank of the low-rank query projection. The code
                branches on ``None``: when ``None`` a direct ``q_proj`` is
                built instead.
            kv_lora_rank: Rank of the low-rank key/value projection.
            rope_theta: Base passed to ``get_rope``.
            rope_scaling: Optional rope scaling dict. NOTE: it is mutated in
                place (``"rope_type"`` is set to ``"deepseek_yarn"``), and
                ``"factor"`` / ``"mscale_all_dim"`` are read from it.
            max_position_embeddings: Maximum position passed to ``get_rope``.
            quant_config: Quantization config forwarded to the sub-layers.
            reduce_results: Forwarded to the output projection.
            layer_id: Layer index forwarded to both attention layers.
            prefix: Name prefix used to build the sub-module prefixes.
            alt_stream: Stored on the instance; not used in the constructor.
        """
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        # Full per-head query/key width: non-rotary part plus rotary part.
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        # Heads are split evenly across attention tensor-parallel ranks.
        assert num_heads % attn_tp_size == 0
        self.num_local_heads = num_heads // attn_tp_size
        # Base softmax scale; may be multiplied by mscale**2 further below.
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # For tensor parallel attention
        if self.q_lora_rank is not None:
            # One replicated projection whose output width is
            # q_lora_rank + kv_lora_rank + qk_rope_head_dim.
            self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("fused_qkv_a_proj_with_mqa", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        else:
            # No low-rank query path: direct query projection plus a separate
            # replicated key/value down-projection.
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("kv_a_proj_with_mqa", prefix),
            )

        # Up-projection from kv_lora_rank to num_heads * (qk_nope_head_dim + v_head_dim).
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

        if rope_scaling:
            # Mutates the caller's dict in place.
            rope_scaling["rope_type"] = "deepseek_yarn"

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
        )

        if rope_scaling:
            # With rope scaling, the softmax scale is multiplied by mscale**2.
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale
        else:
            # Without rope scaling, route forward() to the native implementation.
            self.rotary_emb.forward = self.rotary_emb.forward_native

        # Attention over a single KV head of width kv_lora_rank + qk_rope_head_dim.
        self.attn_mqa = RadixAttention(
            self.num_local_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=1,
            layer_id=layer_id,
            v_head_dim=self.kv_lora_rank,
            quant_config=quant_config,
            prefix=add_prefix("attn_mqa", prefix),
        )

        # Attention with one KV head per local query head.
        self.attn_mha = RadixAttention(
            self.num_local_heads,
            self.qk_nope_head_dim + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            v_head_dim=self.v_head_dim,
            quant_config=quant_config,
            prefix=add_prefix("attn_mha", prefix),
        )

        self.alt_stream = alt_stream

        # Placeholders; presumably populated elsewhere after weight loading —
        # TODO confirm against the loader.
        self.w_kc = None
        self.w_vc = None
        self.w_scale = None

        self.w_scale_k = None
        self.w_scale_v = None
        self.use_deep_gemm_bmm = False

        # Server-level switches captured once at construction time.
        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
        self.disable_chunked_prefix_cache = global_server_args_dict[
            "disable_chunked_prefix_cache"
        ]
        self.attention_backend = global_server_args_dict["attention_backend"]
        # Environment toggle; defaults to "false".
        self.rocm_fused_decode_mla = get_bool_env_var(
            "SGLANG_ROCM_FUSED_DECODE_MLA", "false"
        )

        # TODO: Design a finer way to determine the threshold
        # Environment-configurable threshold; defaults to 8192.
        self.chunked_prefix_cache_threshold = get_int_env_var(
            "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
        )
705
706
707
708

    def dispatch_attn_forward_method(
        self, forward_batch: ForwardBatch
    ) -> AttnForwardMethod:
        """Select the attention implementation to use for ``forward_batch``.

        A "plain extend" batch below means one whose forward mode reports
        ``is_extend()`` but neither ``is_target_verify()`` nor
        ``is_draft_extend()``. For such a batch the choice depends on the
        configured attention backend and on the summed cached-prefix lengths:

        * ``"flashinfer"``: ``MHA`` when ragged prefill is not disabled and the
          summed prefix length is zero.
        * ``"fa3"``: ``MHA_CHUNKED_KV`` when the chunked prefix cache is not
          disabled and the summed prefix length is either zero or at least
          ``self.chunked_prefix_cache_threshold``.
        * any other backend: ``MHA`` when the summed prefix length is zero.

        Every other case uses the weight-absorbed path: ``MLA_FUSED_ROPE`` on
        HIP builds when ``self.rocm_fused_decode_mla`` is set and the batch is
        a decode batch, otherwise ``MLA``.

        If ``forward_batch.extend_prefix_lens_cpu`` is ``None`` the prefix
        lengths cannot be inspected, so the weight-absorbed path is chosen
        instead of failing on an unbound or ``None`` value.

        Args:
            forward_batch: Batch metadata. Only ``forward_mode`` and
                ``extend_prefix_lens_cpu`` are read.

        Returns:
            The ``AttnForwardMethod`` member describing the path to run.
        """

        def _dispatch_mla_subtype():
            # The fused-RoPE variant is only considered on HIP builds, only
            # when enabled through the environment flag, and only for decode.
            if (
                _is_hip
                and self.rocm_fused_decode_mla
                and forward_batch.forward_mode.is_decode()
            ):
                return AttnForwardMethod.MLA_FUSED_ROPE
            return AttnForwardMethod.MLA

        forward_mode = forward_batch.forward_mode
        is_plain_extend = (
            forward_mode.is_extend()
            and not forward_mode.is_target_verify()
            and not forward_mode.is_draft_extend()
        )
        prefix_lens = forward_batch.extend_prefix_lens_cpu
        if not is_plain_extend or prefix_lens is None:
            return _dispatch_mla_subtype()

        sum_extend_prefix_lens = sum(prefix_lens)

        if self.attention_backend == "flashinfer":
            # Flashinfer MLA: Do not absorb when enabling ragged prefill
            if not self.flashinfer_mla_disable_ragged and sum_extend_prefix_lens == 0:
                return AttnForwardMethod.MHA
            return _dispatch_mla_subtype()

        if self.attention_backend == "fa3":
            # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
            if not self.disable_chunked_prefix_cache and (
                sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
                or sum_extend_prefix_lens == 0
            ):
                return AttnForwardMethod.MHA_CHUNKED_KV
            return _dispatch_mla_subtype()

        # Triton: Use normal computation for prefill and use weight absorption for extend/decode
        if sum_extend_prefix_lens == 0:
            return AttnForwardMethod.MHA
        return _dispatch_mla_subtype()
Lianmin Zheng's avatar
Lianmin Zheng committed
761

762
763
764
765
766
767
768
769
770
771
772
773
774
    def op_prepare(self, state):
        """Run ``forward_prepare`` from fields of ``state`` and stash the result.

        Reads ``positions``, ``forward_batch`` and ``zero_allocator`` from
        ``state``, removes ``"hidden_states_after_comm_pre_attn"`` from it via
        ``state.pop`` and stores the value returned by ``forward_prepare`` in
        ``state.attn_intermediate_state``.
        """
        prepare_kwargs = dict(
            positions=state.positions,
            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
            forward_batch=state.forward_batch,
            zero_allocator=state.zero_allocator,
        )
        state.attn_intermediate_state = self.forward_prepare(**prepare_kwargs)

    def op_core(self, state):
        """Run ``forward_core`` on the state stashed by ``op_prepare``.

        Removes ``"attn_intermediate_state"`` from ``state`` via ``state.pop``
        and stores the attention result in ``state.hidden_states_after_attn``.
        """
        prepared = state.pop("attn_intermediate_state")
        state.hidden_states_after_attn = self.forward_core(prepared)

775
776
777
778
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Run the attention layer: ``forward_prepare`` followed by ``forward_core``.

        Args:
            positions: Passed through to ``forward_prepare`` (used there for
                the rotary embedding).
            hidden_states: Layer input; its first dimension is the number of
                tokens.
            forward_batch: Batch metadata forwarded to both stages.
            zero_allocator: Allocator forwarded to ``forward_prepare``.

        Returns:
            Whatever ``forward_core`` returns for the prepared state.
        """
        s = self.forward_prepare(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )
        return self.forward_core(s)

    def forward_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Run the pre-attention stage of the method chosen for this batch.

        Returns a 4-tuple ``(hidden_states, attn_forward_method, forward_batch,
        inner_state)`` that ``forward_core`` consumes:

        * For an empty input (``hidden_states.shape[0] == 0``) the tuple is
          ``(hidden_states, None, forward_batch, None)``; ``forward_core``
          then returns ``hidden_states`` unchanged.
        * Otherwise the first element is ``None`` and ``inner_state`` is the
          tuple produced by the matching ``*_prepare`` method.

        Raises:
            AssertionError: If the input is empty while
                ``self.o_proj.reduce_results`` is set (see the assertion
                message: skipping the all-reduce would hang).
            NotImplementedError: If the dispatched method is not one of the
                four handled ``AttnForwardMethod`` members.
        """
        if hidden_states.shape[0] == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states, None, forward_batch, None

        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)

        if attn_forward_method == AttnForwardMethod.MHA:
            inner_state = self.forward_normal_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
            inner_state = self.forward_normal_chunked_kv_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA:
            inner_state = self.forward_absorb_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
            inner_state = self.forward_absorb_fused_mla_rope_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        else:
            raise NotImplementedError
        return None, attn_forward_method, forward_batch, inner_state
824

825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
    def forward_core(self, intermediate_state):
        """Run the attention stage for a state produced by ``forward_prepare``.

        ``intermediate_state`` is the 4-tuple ``(hidden_states,
        attn_forward_method, forward_batch, inner_state)``. When
        ``inner_state`` is ``None`` the stored ``hidden_states`` is returned
        as is; otherwise ``inner_state`` is unpacked into the ``*_core``
        method matching ``attn_forward_method``.

        Raises:
            NotImplementedError: If ``attn_forward_method`` is not one of the
                four handled ``AttnForwardMethod`` members.
        """
        passthrough, method, _, core_args = intermediate_state
        if core_args is None:
            return passthrough

        # Built after the guard so the pass-through path touches nothing else.
        core_by_method = {
            AttnForwardMethod.MHA: self.forward_normal_core,
            AttnForwardMethod.MHA_CHUNKED_KV: self.forward_normal_chunked_kv_core,
            AttnForwardMethod.MLA: self.forward_absorb_core,
            AttnForwardMethod.MLA_FUSED_ROPE: self.forward_absorb_fused_mla_rope_core,
        }
        core_fn = core_by_method.get(method)
        if core_fn is None:
            raise NotImplementedError
        return core_fn(*core_args)

    def forward_normal_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Build per-head ``q``, ``k``, ``v`` for the non-absorbed (MHA) path.

        Steps, in order:

        1. Project ``hidden_states`` to ``q`` (through the low-rank
           ``q_a``/``q_b`` projections when ``self.q_lora_rank`` is set, else
           through ``q_proj``) and to the latent KV vector ``latent_cache``.
        2. Normalize the first ``kv_lora_rank`` features of the latent and
           expand them with ``kv_b_proj`` into ``k_nope`` and ``v``.
        3. Apply the rotary embedding to the rope slices of ``q`` and of the
           latent.
        4. Write the normalized latent and the rotated key rope slice back
           into ``latent_cache`` and store it in the KV pool for
           ``self.attn_mha``.

        ``zero_allocator`` is accepted for signature parity with the other
        ``*_prepare`` methods and is not used here.

        Returns:
            ``(q, k, v, forward_batch)``, the arguments of
            ``forward_normal_core``.
        """
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]

        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # Insert a singleton head axis so the latent can be sliced like k below.
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope = kv[..., : self.qk_nope_head_dim]
        v = kv[..., self.qk_nope_head_dim :]
        k_pe = latent_cache[:, :, self.kv_lora_rank :]
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        # Write the rotated rope slice back into q in place.
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        # k_pe has a singleton head axis and is broadcast across all heads.
        k[..., self.qk_nope_head_dim :] = k_pe

        # The cached latent holds the normalized kv_a and the rotated k_pe.
        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
        latent_cache[:, :, self.kv_lora_rank :] = k_pe

        # Save latent cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )

        return q, k, v, forward_batch

    def forward_normal_core(self, q, k, v, forward_batch):
        """Run MHA attention on prepared ``q``/``k``/``v`` and project the output.

        ``save_kv_cache=False`` is passed because ``forward_normal_prepare``
        has already written the latent cache to the KV pool.

        Returns:
            The first element returned by ``self.o_proj`` for the attention
            output reshaped to ``(-1, num_local_heads * v_head_dim)``.
        """
        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output

893
    def forward_absorb_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare inputs for the weight-absorbed (MLA) attention path.

        Projects ``hidden_states`` to ``q`` and the latent KV vector,
        normalizes both, multiplies the no-rope part of ``q`` by ``self.w_kc``
        (choosing the batched-matmul kernel from ``self.use_deep_gemm_bmm``
        and ``self.w_kc.dtype``), and applies the rotary embedding to the
        rope parts.

        When ``self.alt_stream`` is set and CUDA-graph capture is active, the
        query and latent layer norms are issued on two different streams so
        they can overlap.

        Returns:
            ``(q_pe, k_pe, q_nope_out, k_nope, forward_batch,
            zero_allocator)``, the arguments of ``forward_absorb_core``.
        """
        # NOTE(review): imported inside the function, presumably to avoid an
        # import cycle; confirm before hoisting to module level.
        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            k_nope = latent_cache[..., : self.kv_lora_rank]

            # overlap qk norm
            if self.alt_stream is not None and get_is_capture_mode():
                current_stream = torch.cuda.current_stream()
                # The alternate stream must see everything queued so far.
                self.alt_stream.wait_stream(current_stream)
                q = self.q_a_layernorm(q)
                with torch.cuda.stream(self.alt_stream):
                    k_nope = self.kv_a_layernorm(k_nope)
                # Rejoin: later work on the current stream needs k_nope.
                current_stream.wait_stream(self.alt_stream)
            else:
                q = self.q_a_layernorm(q)
                k_nope = self.kv_a_layernorm(k_nope)

            k_nope = k_nope.unsqueeze(1)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
            k_nope = latent_cache[..., : self.kv_lora_rank]
            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

        # Absorb w_kc into the query. Every branch works head-major, i.e. on
        # q_nope.transpose(0, 1); the result is transposed back afterwards.
        if self.use_deep_gemm_bmm:
            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
            )
            q_nope_out = q_nope.new_empty(
                (self.num_local_heads, aligned_m, self.kv_lora_rank)
            )
            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
                (q_nope_val, q_nope_scale),
                (self.w_kc, self.w_scale_k),
                q_nope_out,
                masked_m,
                expected_m,
            )
            # Drop the rows beyond expected_m from the aligned_m-sized buffer.
            q_nope_out = q_nope_out[:, :expected_m, :]
        elif self.w_kc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)

        q_nope_out = q_nope_out.transpose(0, 1)
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator

    def forward_absorb_core(
        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
    ):
xu-yfei's avatar
xu-yfei committed
973
        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
974
            attn_output = self.attn_mqa(
Ke Bao's avatar
Ke Bao committed
975
                q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
976
977
978
            )
        else:
            q = torch.cat([q_nope_out, q_pe], dim=-1)
Ke Bao's avatar
Ke Bao committed
979
            k = torch.cat([k_nope, k_pe], dim=-1)
980
            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
981
982
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

983
984
        if self.use_deep_gemm_bmm:
            attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
985
986
                per_token_group_quant_mla_deep_gemm_masked_fp8(
                    attn_output.transpose(0, 1)
987
988
989
990
991
                )
            )
            attn_bmm_output = attn_output.new_empty(
                (self.num_local_heads, aligned_m, self.v_head_dim)
            )
992
            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
993
994
995
996
997
998
999
1000
                (attn_output_val, attn_output_scale),
                (self.w_vc, self.w_scale_v),
                attn_bmm_output,
                masked_m,
                expected_m,
            )
            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
        elif self.w_vc.dtype == torch.float8_e4m3fnuz:
1001
1002
1003
1004
1005
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
1006
        elif self.w_vc.dtype == torch.float8_e4m3fn:
1007
            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
Lianmin Zheng's avatar
Lianmin Zheng committed
1008
                attn_output.transpose(0, 1),
1009
                zero_allocator.allocate(1),
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)

        return output

1025
    def forward_absorb_fused_mla_rope_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare buffers for the fused-RoPE weight-absorbed decode kernel.

        Builds the absorbed query ``q_input`` and the latent key ``k_input``,
        writes ``k_input`` to the KV pool for ``self.attn_mqa``, and gathers
        the metadata that ``forward_absorb_fused_mla_rope_core`` passes to the
        kernel.

        RoPE handling depends on the environment variable
        ``SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION`` (default ``"1"``):

        * ``"1"``: the rotary embedding is not applied here; an empty
          ``k_pe_output`` buffer is allocated and ``use_rope`` is enabled for
          the kernel call in the core stage.
        * anything else: the rotary embedding is applied here to ``q_pe`` and
          to the rope slice of ``k_input``, and ``k_pe_output`` is ``None``.

        Returns:
            The 16-element tuple that is unpacked into
            ``forward_absorb_fused_mla_rope_core``.
        """
        enable_rope_fusion = (
            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
        )
        q_len = hidden_states.shape[0]
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Absorb w_kc into the no-rope part of the query (head-major bmm).
        if self.w_kc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        k_input = latent_cache.unsqueeze(1)
        # Replace the raw latent features with their normalized version.
        k_input[..., : self.kv_lora_rank] = v_input

        if not enable_rope_fusion:
            k_pe = k_input[..., self.kv_lora_rank :]
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
            k_input[..., self.kv_lora_rank :] = k_pe
            k_pe_output = None
        else:
            k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])

        # q_pe is the rotated value only when fusion is disabled; with fusion
        # enabled the un-rotated slice is stored and use_rope is passed on.
        q_input[..., self.kv_lora_rank :] = q_pe

        # The attention call itself happens in the core stage through the
        # fused kernel instead of self.attn_mqa; only its buffers are set up
        # here.
        attn_output = torch.empty(
            (q_len, self.num_local_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )
        attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
            forward_batch.attn_backend.forward_metadata
        )
        cos_sin_cache = self.rotary_emb.cos_sin_cache
        num_kv_split = forward_batch.attn_backend.num_kv_splits
        sm_scale = self.attn_mqa.scaling
        if attn_logits is None:
            attn_logits = torch.empty(
                (
                    forward_batch.batch_size,
                    self.num_local_heads,
                    num_kv_split,
                    self.kv_lora_rank + 1,
                ),
                dtype=torch.float32,
                device=q.device,
            )

        # save current latent cache.
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mqa, forward_batch.out_cache_loc, k_input, None
        )
        key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mqa.layer_id
        )
        # The value cache is the leading kv_lora_rank features of the key cache.
        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

        return (
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            enable_rope_fusion,
            k_input,
            forward_batch,
            zero_allocator,
        )

    def forward_absorb_fused_mla_rope_core(
        self,
        q_input,
        key_cache_buf,
        val_cache_buf,
        attn_output,
        kv_indptr,
        kv_indices,
        k_pe_output,
        cos_sin_cache,
        positions,
        attn_logits,
        num_kv_split,
        sm_scale,
        enable_rope_fusion,
        k_input,
        forward_batch,
        zero_allocator,
    ):
        """Run the fused-RoPE decode kernel and project its output.

        The parameters are exactly the tuple returned by
        ``forward_absorb_fused_mla_rope_prepare``. ``attn_output`` is the
        pre-allocated buffer handed to ``decode_attention_fwd_grouped_rope``
        and read back afterwards. When ``enable_rope_fusion`` is true,
        ``k_pe_output`` is copied into the rope slice of ``k_input`` and the
        KV pool entry for ``self.attn_mqa`` is written again.

        Returns:
            The first element returned by ``self.o_proj``.
        """
        decode_attention_fwd_grouped_rope(
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            self.kv_lora_rank,
            self.rotary_emb.rotary_dim,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            logit_cap=self.attn_mqa.logit_cap,
            use_rope=enable_rope_fusion,
            is_neox_style=self.rotary_emb.is_neox_style,
        )

        if enable_rope_fusion:
            # NOTE(review): k_pe_output is presumably filled by the kernel
            # above with the rotated key rope slice; confirm against the
            # kernel implementation.
            k_input[..., self.kv_lora_rank :] = k_pe_output
            forward_batch.token_to_kv_pool.set_kv_buffer(
                self.attn_mqa, forward_batch.out_cache_loc, k_input, None
            )

        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # Multiply by w_vc (head-major bmm), kernel chosen from its dtype.
        if self.w_vc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                attn_output.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        # Back to token-major, then merge the head and value dimensions.
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)

        return output

1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
    def _chunked_prefix_attn_mha(
        self,
        q: torch.Tensor,
        accum_output: torch.Tensor,
        accum_lse: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Attend ``q`` to the cached prefix chunk by chunk and merge the results.

        For each of ``forward_batch.num_prefix_chunks`` chunks the cached
        latent rows are gathered from the KV pool, expanded to per-head keys
        and values with ``kv_b_proj``, attended to with ``self.attn_mha`` and
        folded into the running ``(accum_output, accum_lse)`` pair through
        ``merge_state_v2``.

        Returns:
            The merged attention output after the last chunk (``accum_output``
            itself when there are no chunks).
        """
        assert forward_batch.num_prefix_chunks is not None

        nope_dim = self.qk_nope_head_dim
        rope_dim = self.qk_rope_head_dim

        for chunk_idx in range(forward_batch.num_prefix_chunks):
            forward_batch.set_prefix_chunk_idx(chunk_idx)

            # Fetch latent cache from memory pool with precomputed chunked kv indices
            pool_buf = forward_batch.token_to_kv_pool.get_key_buffer(
                self.attn_mha.layer_id
            )
            chunk_rows = forward_batch.prefix_chunk_kv_indices[chunk_idx]
            chunk_latent = pool_buf[chunk_rows].contiguous()

            normed_kv_a, chunk_k_pe = chunk_latent.split(
                [self.kv_lora_rank, rope_dim], dim=-1
            )
            normed_kv_a = normed_kv_a.squeeze(1).contiguous()
            chunk_kv = self.kv_b_proj(normed_kv_a)[0].view(
                -1, self.num_local_heads, nope_dim + self.v_head_dim
            )
            chunk_v = chunk_kv[..., nope_dim:]
            chunk_k_nope = chunk_kv[..., :nope_dim]

            # Assemble the full key: per-head no-rope part + shared rope part.
            chunk_k = torch.empty(
                (chunk_k_nope.shape[0], self.num_local_heads, nope_dim + rope_dim),
                dtype=chunk_v.dtype,
                device=chunk_v.device,
            )
            chunk_k[..., :nope_dim] = chunk_k_nope
            chunk_k[..., nope_dim:] = chunk_k_pe

            chunk_out, chunk_lse = self.attn_mha(
                q, chunk_k, chunk_v, forward_batch, save_kv_cache=False
            )
            chunk_lse = chunk_lse.transpose(0, 1).contiguous()

            # merge_state_v2 writes into fresh buffers, which then become the
            # running accumulators for the next chunk.
            merged_output = torch.empty_like(accum_output)
            merged_lse = torch.empty_like(accum_lse)
            merge_state_v2(
                chunk_out, chunk_lse, accum_output, accum_lse, merged_output, merged_lse
            )
            accum_output, accum_lse = merged_output, merged_lse

        return accum_output

1264
    def forward_normal_chunked_kv_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare ``q``/``k``/``v`` of the extended tokens for chunked-KV MHA.

        In normal mha, the k and v tensors will become overly large when the
        prefix length is long. To avoid this, the kv cache is split into
        chunks that are processed one after another (in
        ``forward_normal_chunked_kv_core``). Since mha is compute friendly,
        that loop does not introduce significant overhead. The top comments in
        https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
        are helpful for understanding the purpose of this path.

        The preparation for the extended part is the same computation as
        ``forward_normal_prepare`` (projection, rotary embedding, saving the
        latent cache for ``self.attn_mha``), so this method delegates to it
        instead of carrying a second copy of that code.

        Returns:
            ``(q, k, v, forward_batch)``, the arguments of
            ``forward_normal_chunked_kv_core``.
        """
        # First do normal mha forward to get output for extended part
        return self.forward_normal_prepare(
            positions, hidden_states, forward_batch, zero_allocator
        )

    def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
        """Run MHA over the extended tokens, then merge in the cached prefix.

        First attends ``q`` to the freshly computed ``k``/``v`` with prefix
        attention switched off. If any entry of
        ``forward_batch.extend_prefix_lens_cpu`` is non-zero, the cached
        prefix is then attended to chunk by chunk and merged into that result
        by ``_chunked_prefix_attn_mha``.

        Returns:
            The first element returned by ``self.o_proj`` for the attention
            output reshaped to ``(-1, num_local_heads * v_head_dim)``.
        """
        # Do mha for extended part without prefix
        forward_batch.set_attn_attend_prefix_cache(False)
        attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        lse = torch.transpose(lse, 0, 1).contiguous()

        # Do mha attention with chunked prefix cache if there are any sequence with prefix
        if any(forward_batch.extend_prefix_lens_cpu):
            # Only initialize the info once
            if forward_batch.num_prefix_chunks is None:
                forward_batch.prepare_chunked_prefix_cache_info(q.device)

            forward_batch.set_attn_attend_prefix_cache(True)
            attn_output = self._chunked_prefix_attn_mha(
                q=q,
                accum_output=attn_output,
                accum_lse=lse,
                forward_batch=forward_batch,
            )

        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output

1339

Liangsheng Yin's avatar
Liangsheng Yin committed
1340
1341
1342
1343
1344
1345
1346
class DeepseekV2DecoderLayer(nn.Module):
    """A single DeepSeek-V2 decoder layer.

    The layer consists of MLA self-attention followed by either a
    ``DeepseekV2MoE`` block (when the layer is "sparse") or a dense
    ``DeepseekV2MLP``. Layer norms, residual handling and cross-rank
    communication around both sub-blocks are delegated to a
    ``LayerCommunicator``.

    Besides the regular ``forward``, the class exposes ``op_*`` methods that
    perform the same steps one at a time on a shared ``state`` object.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        is_nextn: bool = False,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """Build the attention, MLP/MoE, norms and communicator for one layer.

        Args:
            config: Model configuration.
            layer_id: Index of this layer in the decoder stack.
            quant_config: Optional quantization configuration forwarded to
                the sub-modules.
            is_nextn: When True the layer is always treated as sparse
                (see ``_is_layer_sparse``).
            prefix: Parameter-name prefix for this layer.
            alt_stream: Optional CUDA stream forwarded to the attention module.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
        self.layer_id = layer_id
        self.self_attn = DeepseekV2AttentionMLA(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            # Configs without q_lora_rank fall back to None.
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            layer_id=layer_id,
            reduce_results=False,
            prefix=add_prefix("self_attn", prefix),
            alt_stream=alt_stream,
        )

        # The scatter modes depend on whether this layer and the one before it
        # are sparse; the previous layer is never evaluated as a nextn layer.
        self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
        is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=is_previous_layer_sparse,
        )

        if self.is_layer_sparse:
            self.mlp = DeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                layer_id=self.layer_id,
            )
        else:
            if enable_moe_dense_fully_dp():
                # Dense MLP is built with tp_rank 0 and tp_size 1.
                mlp_tp_rank, mlp_tp_size = 0, 1
            else:
                mlp_tp_rank, mlp_tp_size = None, None
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                tp_rank=mlp_tp_rank,
                tp_size=mlp_tp_size,
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
        )

    def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
        """Return whether the layer at ``layer_id`` uses the MoE block.

        A layer is sparse when ``is_nextn`` is True, or when the config defines
        routed experts, ``layer_id`` is at least ``first_k_dense_replace`` and
        ``layer_id`` is a multiple of ``moe_layer_freq``.
        """
        return is_nextn or (
            self.config.n_routed_experts is not None
            and layer_id >= self.config.first_k_dense_replace
            and layer_id % self.config.moe_layer_freq == 0
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention and the MLP/MoE block for this layer.

        Args:
            positions: Token positions forwarded to the attention module.
            hidden_states: Input activations of the layer.
            forward_batch: Batch state shared by all layers.
            residual: Residual carried over from the previous layer, or None.
            zero_allocator: Allocator forwarded to the attention module.

        Returns:
            ``(hidden_states, residual)`` as produced by
            ``LayerCommunicator.postprocess_layer``.
        """
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )

        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )

        hidden_states = self.mlp(hidden_states, forward_batch)

        hidden_states, residual = self.layer_communicator.postprocess_layer(
            hidden_states, residual, forward_batch
        )

        return hidden_states, residual

    def op_comm_prepare_attn(
        self,
        state,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
        tbo_subbatch_index: Optional[int] = None,
    ):
        """Step-wise counterpart of the ``prepare_attn`` call in ``forward``.

        Stores the communicator outputs in ``state`` as
        ``hidden_states_after_comm_pre_attn`` and ``residual_after_input_ln``,
        and records ``forward_batch``, ``positions``, ``zero_allocator`` and
        ``tbo_subbatch_index`` for the later steps.
        """
        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
        )
        state.update(
            dict(
                forward_batch=forward_batch,
                positions=positions,
                zero_allocator=zero_allocator,
                tbo_subbatch_index=tbo_subbatch_index,
            )
        )

    def op_comm_prepare_mlp(self, state):
        """Step-wise counterpart of the ``prepare_mlp`` call in ``forward``.

        Consumes ``hidden_states_after_attn`` and ``residual_after_input_ln``
        from ``state`` and stores ``hidden_states_mlp_input`` and
        ``residual_after_comm_pre_mlp``.
        """
        # NOTE(review): "hidden_states_after_attn" is not written in this
        # class; presumably an attention op sets it - confirm against caller.
        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
            self.layer_communicator.prepare_mlp(
                state.pop("hidden_states_after_attn"),
                state.pop("residual_after_input_ln"),
                state.forward_batch,
            )
        )

    def op_mlp(self, state):
        """Apply the MLP to ``hidden_states_mlp_input`` and store the result.

        The MLP is skipped (input passed through unchanged) when dense layers
        run fully data-parallel, this layer is dense, and the input has zero
        rows.
        """
        hidden_states = state.pop("hidden_states_mlp_input")
        if not (
            enable_moe_dense_fully_dp()
            and (not self.is_layer_sparse)
            and hidden_states.shape[0] == 0
        ):
            # Pass the ForwardBatch itself, matching the call in forward();
            # this slot previously received forward_batch.forward_mode.
            state.hidden_states_mlp_output = self.mlp(
                hidden_states, state.forward_batch
            )
        else:
            state.hidden_states_mlp_output = hidden_states

    def op_comm_postprocess_layer(self, state):
        """Step-wise counterpart of the ``postprocess_layer`` call in ``forward``.

        Returns:
            A dict with ``positions``, ``hidden_states``, ``residual``,
            ``forward_batch``, ``zero_allocator`` and ``tbo_subbatch_index``.
            ``state`` is cleared afterwards; only the four pass-through keys
            are expected to remain in it at that point.
        """
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            state.pop("hidden_states_mlp_output"),
            state.pop("residual_after_comm_pre_mlp"),
            state.forward_batch,
        )

        output = dict(
            positions=state.positions,
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=state.forward_batch,
            zero_allocator=state.zero_allocator,
            tbo_subbatch_index=state.tbo_subbatch_index,
        )

        state.clear(
            expect_keys={
                "positions",
                "forward_batch",
                "zero_allocator",
                "tbo_subbatch_index",
            }
        )
        return output

class DeepseekV2Model(nn.Module):
    """Token embedding, the stack of ``DeepseekV2DecoderLayer`` and the final RMSNorm."""

    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the embedding, ``config.num_hidden_layers`` decoder layers and the norm.

        Args:
            config: Model configuration.
            quant_config: Optional quantization configuration forwarded to
                every decoder layer.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.first_k_dense_replace = config.first_k_dense_replace

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            enable_tp=not global_server_args_dict["enable_dp_attention"],
        )
        # A single extra CUDA stream is shared by all layers (CUDA only).
        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
        self.layers = nn.ModuleList(
            [
                DeepseekV2DecoderLayer(
                    config,
                    layer_id,
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{layer_id}", prefix),
                    alt_stream=self.alt_stream,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.dp_size = get_local_attention_dp_size()

    def get_input_embeddings(self) -> VocabParallelEmbedding:
        """Return the token-embedding module (not a tensor)."""
        return self.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder stack and the final norm.

        Args:
            input_ids: Token ids; embedded unless ``input_embeds`` is given.
            positions: Token positions forwarded to every layer.
            forward_batch: Batch state shared by all layers.
            input_embeds: Optional precomputed embeddings used instead of
                ``input_ids``.

        Returns:
            The final hidden states. The closing norm is skipped when the
            forward mode is idle.
        """
        total_num_layers = len(self.layers)
        device = input_embeds.device if input_embeds is not None else input_ids.device
        # Sized as 2 * num_layers, doubled when the batch can run TBO.
        zero_allocator = BumpAllocator(
            buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
            dtype=torch.float32,
            device=device,
        )

        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        residual = None

        # With TBO only the first `first_k_dense_replace` layers run through
        # the plain loop; the remaining ones go through model_forward_maybe_tbo.
        normal_num_layers = (
            self.first_k_dense_replace
            if forward_batch.can_run_tbo
            else total_num_layers
        )
        for i in range(normal_num_layers):
            with get_global_expert_distribution_recorder().with_current_layer(i):
                layer = self.layers[i]
                hidden_states, residual = layer(
                    positions, hidden_states, forward_batch, residual, zero_allocator
                )

        if normal_num_layers != total_num_layers:
            hidden_states, residual = model_forward_maybe_tbo(
                layers=self.layers[normal_num_layers:],
                enable_tbo=True,
                positions=positions,
                forward_batch=forward_batch,
                hidden_states=hidden_states,
                residual=residual,
                zero_allocator=zero_allocator,
            )

        if not forward_batch.forward_mode.is_idle():
            if residual is None:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekV2ForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the decoder model, the LM head and the logits processor.

        Args:
            config: Model configuration.
            quant_config: Optional quantization configuration.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        # Runs before the model is built: it may rewrite
        # global_server_args_dict["n_share_experts_fusion"].
        self.determine_n_share_experts_fusion()
        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)
        self.dp_size = get_local_attention_dp_size()

    def determine_n_share_experts_fusion(
        self, architecture: str = "DeepseekV3ForCausalLM"
    ):
        """Decide how many shared-expert replicas are fused into the MoE.

        Reads ``n_share_experts_fusion`` from ``global_server_args_dict`` and
        stores the final value both on ``self`` and back into that dict.

        * Requested (> 0): reset to 0 unless running on CUDA with the given
          ``architecture`` and 256 routed experts; otherwise the value must
          equal ``self.tp_size``.
        * Not requested (== 0): set to ``self.tp_size`` when running on CUDA
          with device capability >= (9, 0), the given ``architecture``,
          256 routed experts and DeepEP MoE disabled.

        Args:
            architecture: Architecture name compared against
                ``config.architectures[0]``.

        Raises:
            AssertionError: If fusion is requested and supported but the
                requested value differs from ``self.tp_size``.
        """
        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
        if self.n_share_experts_fusion > 0:
            # Only Deepseek V3/R1 can use shared experts fusion optimization now.
            if (
                not _is_cuda
                or self.config.architectures[0] != architecture
                or self.config.n_routed_experts != 256
            ):
                self.n_share_experts_fusion = 0
                global_server_args_dict["n_share_experts_fusion"] = 0
                log_info_on_rank0(
                    logger,
                    "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
                )
            else:
                assert (
                    self.n_share_experts_fusion == self.tp_size
                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
        elif self.n_share_experts_fusion == 0:
            if (
                _is_cuda
                and torch.cuda.get_device_capability("cuda") >= (9, 0)
                and self.config.architectures[0] == architecture
                and self.config.n_routed_experts == 256
                and (not global_server_args_dict["enable_deepep_moe"])
            ):
                self.n_share_experts_fusion = self.tp_size
                global_server_args_dict["n_share_experts_fusion"] = self.tp_size
                log_info_on_rank0(
                    logger,
                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
                )
    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token-embedding layer held by the wrapped model."""
        embed_layer = self.model.embed_tokens
        return embed_layer

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder model and pass its hidden states to the logits processor.

        Args:
            input_ids: Token ids.
            positions: Token positions.
            forward_batch: Batch state shared by all layers.
            input_embeds: Optional precomputed embeddings forwarded to the model.

        Returns:
            Whatever ``self.logits_processor`` returns for the final hidden
            states and ``self.lm_head``.
        """
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)

        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )
    def post_load_weights(self, is_nextn=False):
        """Derive the per-layer attention weights from ``kv_b_proj`` after loading.

        For every processed attention module the ``kv_b_proj`` weight is
        dequantized or requantized as needed, split into its ``qk_nope`` and
        ``v`` parts, and stored as ``w_kc`` / ``w_vc`` together with the
        matching scale attributes. Unless ``is_nextn`` is set, the routed
        expert weights of every ``DeepseekV2MoE`` layer are also collected in
        ``self.routed_experts_weights_of_layer``.

        Args:
            is_nextn: When True only ``self.model.decoder.self_attn`` is
                processed; otherwise all ``self.model.layers``.
        """

        # Perform post-processing after loading weights
        layer_ids = (
            range(self.config.num_hidden_layers)
            if not is_nextn
            else [self.config.num_hidden_layers]
        )
        for layer_id in layer_ids:
            self_attn = (
                self.model.layers[layer_id].self_attn
                if not is_nextn
                else self.model.decoder.self_attn
            )
            if hasattr(self_attn.kv_b_proj, "qweight"):
                # AWQ compatible
                if _is_cuda:
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                    ).T
                else:
                    # The non-CUDA dequantize call takes three extra arguments.
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                        0,
                        0,
                        0,
                    ).T
            else:
                w = self_attn.kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
            # This may affect the accuracy of fp8 model.
            # Fix deepseek v3 blockwise bmm by using deep_gemm
            use_deep_gemm_bmm = False
            model_dtype = torch.get_default_dtype()

            if w.dtype in (
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
            ):
                if hasattr(self.quant_config, "weight_block_size"):
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                        if _is_hip:
                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                                weight=w,
                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                                input_scale=None,
                            )
                        else:
                            weight = w
                            weight_scale = self_attn.kv_b_proj.weight_scale_inv

                        if (
                            _is_cuda
                            and weight_block_size[0] == 128
                            and weight_block_size[1] == 128
                            and model_dtype == torch.bfloat16
                        ):
                            if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
                                "SGL_USE_DEEPGEMM_BMM", "false"
                            ):
                                # Keep the block-wise fp8 weight and its scales.
                                block_scale = weight_scale
                                use_deep_gemm_bmm = True
                            else:
                                w = block_quant_dequant(
                                    weight,
                                    weight_scale,
                                    weight_block_size,
                                    model_dtype,
                                )
                        else:
                            w, scale = block_quant_to_tensor_quant(
                                weight, weight_scale, weight_block_size
                            )
                            self_attn.w_scale = scale
                else:
                    weight = w
                    weight_scale = self_attn.kv_b_proj.weight_scale
                    w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                    self_attn.w_scale = scale

            if w.dtype == torch.int8:
                if hasattr(self.quant_config, "weight_block_size"):
                    # block-wise int8 need it
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
                        w = int8_block_dequant(
                            weight, weight_scale, weight_block_size
                        ).to(torch.bfloat16)
                else:
                    # channel-wise int8 need it
                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
                        torch.bfloat16
                    )

            # Group dim 0 into chunks of (qk_nope_head_dim + v_head_dim) rows,
            # then split each chunk into its qk_nope and v parts.
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            if not use_deep_gemm_bmm:
                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                if (
                    hasattr(self_attn.kv_b_proj, "weight_scale")
                    and self_attn.w_scale is None
                ):
                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                    if _is_hip:
                        # NOTE(review): w_scale aliases kv_b_proj.weight_scale
                        # here, so this in-place multiply also changes that
                        # tensor - confirm this is intended.
                        self_attn.w_scale *= 2.0
            else:
                # Split the block scales the same way as the weight, counted
                # in tiles of the quantization block size.
                num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
                num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
                ws_kc, ws_vc = block_scale.unflatten(
                    0, (-1, (num_tiles_k + num_tiles_n))
                ).split([num_tiles_k, num_tiles_n], dim=1)
                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
                self_attn.w_scale_v = ws_vc.contiguous()
                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
                self_attn.w_vc = w_vc.contiguous()
                self_attn.use_deep_gemm_bmm = True

        # TODO support nextn later
        if not is_nextn:
            self.routed_experts_weights_of_layer = {
                layer_id: layer.mlp.get_moe_weights()
                for layer_id, layer in enumerate(self.model.layers)
                if isinstance(layer.mlp, DeepseekV2MoE)
            }

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
        if is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
1843
                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
1844
1845
1846
1847
1848
1849
1850
1851
1852
                # compatible with old design
                nextn_layer_id = (
                    0
                    if self.config.num_hidden_layers == 1
                    else self.config.num_hidden_layers
                )
            else:
                raise ValueError("num_nextn_predict_layers is not in the config")

Liangsheng Yin's avatar
Liangsheng Yin committed
1853
1854
1855
1856
1857
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
1858
        if self.n_share_experts_fusion > 0:
1859
1860
            weights_list = list(weights)
            weights_dict = dict(weights_list)
1861
            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
                suffix_list = [
                    "down_proj.weight",
                    "down_proj.weight_scale",
                    "gate_proj.weight",
                    "gate_proj.weight_scale",
                    "up_proj.weight",
                    "up_proj.weight_scale",
                ]
            else:
                suffix_list = [
                    "down_proj.weight",
                    "down_proj.weight_scale_inv",
                    "gate_proj.weight",
                    "gate_proj.weight_scale_inv",
                    "up_proj.weight",
                    "up_proj.weight_scale_inv",
                ]
1879
            names_to_remove = []
1880
1881

            moe_layers = (
1882
1883
1884
1885
                range(
                    self.config.first_k_dense_replace,
                    self.config.num_hidden_layers,
                    self.config.moe_layer_freq,
1886
1887
1888
1889
1890
1891
1892
                )
                if not is_nextn
                else [nextn_layer_id]
            )

            for moe_layer in tqdm(
                moe_layers,
1893
1894
1895
                desc=f"Cloning {self.n_share_experts_fusion} "
                "replicas of the shared expert into MoE",
            ):
1896
1897
1898
1899
1900
                for suffix in suffix_list:
                    shared_expert_weight_name = (
                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
                    )
                    for num_repeat in range(self.n_share_experts_fusion):
1901
1902
1903
1904
1905
1906
                        weights_list.append(
                            (
                                f"model.layers.{moe_layer}."
                                f"mlp.experts."
                                f"{self.config.n_routed_experts + num_repeat}"
                                f".{suffix}",
1907
                                weights_dict[shared_expert_weight_name],
1908
1909
                            )
                        )
1910
                    names_to_remove += [shared_expert_weight_name]
1911
            weights = [w for w in weights_list if w[0] not in names_to_remove]
Liangsheng Yin's avatar
Liangsheng Yin committed
1912
1913
1914

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
1915
        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
Liangsheng Yin's avatar
Liangsheng Yin committed
1916
1917
1918
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
1919
            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
Liangsheng Yin's avatar
Liangsheng Yin committed
1920
1921
        )

1922
1923
1924
1925
1926
1927
        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
            self.config.q_lora_rank is not None
        )
        cached_a_proj = {} if fuse_qkv_a_proj else None

1928
1929
1930
1931
1932
1933
1934
1935
1936
        if is_nextn:
            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
            nextn_spec_weight_names = [
                "shared_head.norm",
                "eh_proj",
                "enorm",
                "hnorm",
            ]

Liangsheng Yin's avatar
Liangsheng Yin committed
1937
1938
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
            if not is_nextn:
                if hasattr(self.config, "num_nextn_predict_layers"):
                    num_nextn_layers = self.config.num_nextn_predict_layers
                    if num_nextn_layers > 0 and name.startswith("model.layers"):
                        name_list = name.split(".")
                        if (
                            len(name_list) >= 3
                            and int(name_list[2]) >= self.config.num_hidden_layers
                        ):
                            continue
            else:
                if not name.startswith(nextn_layer_prefix):
                    continue

                # Use shared head and embed weights from target model
                if "shared_head.head" in name or "embed_tokens" in name:
                    continue

                is_decoder = True
                # For nextn specific weights
                for weight_name in nextn_spec_weight_names:
                    if weight_name in name:
                        name = name.replace(nextn_layer_prefix, "model")
                        is_decoder = False
                        break
                # For decoder layer weights
                if is_decoder:
                    name = name.replace(nextn_layer_prefix, "model.decoder")

Liangsheng Yin's avatar
Liangsheng Yin committed
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
2001
                        name,
Liangsheng Yin's avatar
Liangsheng Yin committed
2002
2003
2004
2005
2006
2007
2008
2009
2010
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
                    if fuse_qkv_a_proj and (
                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                    ):
                        cached_a_proj[name] = loaded_weight
                        q_a_proj_name = (
                            name
                            if "q_a_proj" in name
                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
                        )
                        kv_a_proj_name = (
                            name
                            if "kv_a_proj_with_mqa" in name
                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
                        )

                        # Once both q_a_proj and kv_a_proj_with_mqa have been cached, load the fused weight into the parameter.
                        if (
                            q_a_proj_name in cached_a_proj
                            and kv_a_proj_name in cached_a_proj
                        ):
                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                            fused_weight = torch.cat(
                                [q_a_proj_weight, kv_a_proj_weight], dim=0
                            )

                            param_name = name.replace(
                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
                            )
                            param = params_dict[param_name]

                            weight_loader = getattr(
                                param, "weight_loader", default_weight_loader
                            )
                            weight_loader(param, fused_weight)
                            cached_a_proj.pop(q_a_proj_name)
                            cached_a_proj.pop(kv_a_proj_name)
                    else:
                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)
Liangsheng Yin's avatar
Liangsheng Yin committed
2054

2055
        self.post_load_weights(is_nextn=is_nextn)
Ke Bao's avatar
Ke Bao committed
2056

2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
    def get_embed_and_head(self):
        """Return the token-embedding weight and the LM-head weight as a pair.

        Returns:
            A 2-tuple ``(embed_weight, head_weight)`` taken from
            ``self.model.embed_tokens.weight`` and ``self.lm_head.weight``.
        """
        embed_weight = self.model.embed_tokens.weight
        head_weight = self.lm_head.weight
        return embed_weight, head_weight

    def set_embed_and_head(self, embed, head):
        """Replace the embedding and LM-head weights with the given objects.

        Both existing ``weight`` attributes are deleted before either
        replacement is assigned. Afterwards ``torch.cuda.empty_cache()`` and
        ``torch.cuda.synchronize()`` are called, in that order.

        Args:
            embed: New value for ``self.model.embed_tokens.weight``.
            head: New value for ``self.lm_head.weight``.
        """
        embed_module = self.model.embed_tokens
        head_module = self.lm_head

        # Drop the old weights first, then attach the replacements.
        del embed_module.weight
        del head_module.weight
        embed_module.weight = embed
        head_module.weight = head

        # Release cached allocator blocks, then wait for pending device work.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

2068
2069
2070
2071
2072
2073
2074
2075
    @classmethod
    def get_model_config_for_expert_location(cls, config):
        """Build a ``ModelConfigForExpertLocation`` from a model config.

        Args:
            config: Object exposing ``num_hidden_layers``,
                ``n_routed_experts`` and ``n_group`` attributes.

        Returns:
            A ``ModelConfigForExpertLocation`` populated with those three
            values as ``num_layers``, ``num_logical_experts`` and
            ``num_groups`` respectively.
        """
        layer_count = config.num_hidden_layers
        logical_expert_count = config.n_routed_experts
        group_count = config.n_group
        return ModelConfigForExpertLocation(
            num_layers=layer_count,
            num_logical_experts=logical_expert_count,
            num_groups=group_count,
        )

Liangsheng Yin's avatar
Liangsheng Yin committed
2076

HandH1998's avatar
HandH1998 committed
2077
2078
2079
2080
2081
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek-V3 model class; inherits DeepseekV2ForCausalLM without overrides."""


# Model classes exposed by this module.
# NOTE(review): presumably read by the model loader/registry to map
# architecture names to implementations — confirm against the loader.
EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]