deepseek_v2.py 73.8 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14

Liangsheng Yin's avatar
Liangsheng Yin committed
15
16
17
# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
18

19
import logging
20
import os
21
22
from dataclasses import dataclass
from enum import Enum, IntEnum, auto
Liangsheng Yin's avatar
Liangsheng Yin committed
23
24
25
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
Ke Bao's avatar
Ke Bao committed
26
import torch.nn.functional as F
Liangsheng Yin's avatar
Liangsheng Yin committed
27
from torch import nn
28
from tqdm import tqdm
Liangsheng Yin's avatar
Liangsheng Yin committed
29
from transformers import PretrainedConfig
30
31

from sglang.srt.distributed import (
32
    get_tensor_model_parallel_rank,
Liangsheng Yin's avatar
Liangsheng Yin committed
33
    get_tensor_model_parallel_world_size,
34
    parallel_state,
Liangsheng Yin's avatar
Liangsheng Yin committed
35
36
    tensor_model_parallel_all_reduce,
)
37
from sglang.srt.layers.activation import SiluAndMul
Lianmin Zheng's avatar
Lianmin Zheng committed
38
from sglang.srt.layers.dp_attention import (
39
    dp_gather_partial,
Lianmin Zheng's avatar
Lianmin Zheng committed
40
41
42
43
    dp_scatter,
    get_attention_dp_size,
    get_attention_tp_rank,
    get_attention_tp_size,
44
45
    tp_all_gather,
    tp_reduce_scatter,
Lianmin Zheng's avatar
Lianmin Zheng committed
46
)
47
from sglang.srt.layers.layernorm import RMSNorm
48
49
50
51
52
53
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
54
from sglang.srt.layers.logits_processor import LogitsProcessor
55
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
Lianmin Zheng's avatar
Lianmin Zheng committed
56
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
Ke Bao's avatar
Ke Bao committed
57
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
58
from sglang.srt.layers.moe.topk import select_experts
59
from sglang.srt.layers.quantization.base_config import QuantizationConfig
60
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
61
62
from sglang.srt.layers.quantization.fp8_kernel import (
    per_tensor_quant_mla_fp8,
63
    per_token_group_quant_mla_deep_gemm_masked_fp8,
64
)
HandH1998's avatar
HandH1998 committed
65
from sglang.srt.layers.quantization.fp8_utils import (
66
    block_quant_dequant,
HandH1998's avatar
HandH1998 committed
67
    block_quant_to_tensor_quant,
68
    channel_quant_to_tensor_quant,
69
    normalize_e4m3fn_to_e4m3fnuz,
HandH1998's avatar
HandH1998 committed
70
)
71
72
73
from sglang.srt.layers.quantization.int8_utils import (
    block_dequant as int8_block_dequant,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
74
from sglang.srt.layers.radix_attention import RadixAttention
75
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
76
77
78
79
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
80
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
81
from sglang.srt.managers.schedule_batch import global_server_args_dict
82
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
83
from sglang.srt.model_loader.weight_utils import default_weight_loader
84
85
86
87
88
89
90
91
from sglang.srt.utils import (
    BumpAllocator,
    DeepEPMode,
    add_prefix,
    get_bool_env_var,
    get_int_env_var,
    is_cuda,
    is_hip,
92
    log_info_on_rank0,
93
)
94

95
_is_hip = is_hip()
Yineng Zhang's avatar
Yineng Zhang committed
96
_is_cuda = is_cuda()
97

Yineng Zhang's avatar
Yineng Zhang committed
98
if _is_cuda:
99
    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
100
101
102
103

    from sglang.srt.layers.quantization.deep_gemm import (
        grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
    )
Yineng Zhang's avatar
Yineng Zhang committed
104
else:
Lianmin Zheng's avatar
Lianmin Zheng committed
105
    from vllm._custom_ops import awq_dequantize
Liangsheng Yin's avatar
Liangsheng Yin committed
106

107
108
109
110
111
if _is_hip:
    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
        decode_attention_fwd_grouped_rope,
    )

112
113
expert_distribution_recorder = ExpertDistributionRecorder()

114
115
logger = logging.getLogger(__name__)

Liangsheng Yin's avatar
Liangsheng Yin committed
116

117
118
119
120
121
122
123
124
125
126
127
128
class AttnForwardMethod(IntEnum):
    """Attention algorithm selected per batch by ``DeepseekV2AttentionMLA``.

    Values are spelled out explicitly; they match what ``auto()`` assigns on
    an ``IntEnum`` (1, 2, 3 in declaration order).
    """

    # Plain multi-head attention.
    MHA = 1

    # Multi-latent attention with absorbed weights.
    MLA = 2

    # Multi-head attention that reads the KV cache in chunks, which avoids
    # running out of memory when cached prefixes are long.
    MHA_CHUNKED_KV = 3


Liangsheng Yin's avatar
Liangsheng Yin committed
129
130
131
132
133
134
135
136
class DeepseekV2MLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act_fn(gate_up_proj(x)))``.

    The gate and up projections are fused into a single
    ``MergedColumnParallelLinear`` with two outputs of ``intermediate_size``
    each, followed by a ``SiluAndMul`` activation and a ``RowParallelLinear``.

    Args:
        hidden_size: input and output width.
        intermediate_size: width of each of the gate / up projections.
        hidden_act: activation name; anything other than ``"silu"`` raises
            ``ValueError``.
        quant_config: forwarded to both linear layers.
        reduce_results: forwarded to ``down_proj``.
        prefix: parameter-name prefix used to build the sub-layer prefixes.
        tp_rank, tp_size: passed through unchanged to both linear layers.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        # Arguments common to both projections.
        common = dict(quant_config=quant_config, tp_rank=tp_rank, tp_size=tp_size)

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            prefix=add_prefix("gate_up_proj", prefix),
            **common,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            reduce_results=reduce_results,
            prefix=add_prefix("down_proj", prefix),
            **common,
        )

        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )

    def forward(self, x, forward_mode: Optional[ForwardMode] = None):
        """Apply the gated MLP to ``x``; ``forward_mode`` is not used here."""
        projected, _ = self.gate_up_proj(x)
        activated = self.act_fn(projected)
        out, _ = self.down_proj(activated)
        return out


Ke Bao's avatar
Ke Bao committed
175
class MoEGate(nn.Module):
    """Expert router: a bias-free linear map from hidden states to expert logits.

    ``weight`` has shape ``(config.n_routed_experts, config.hidden_size)``.
    When ``config.topk_method == "noaux_tc"``, a learnable
    ``e_score_correction_bias`` of shape ``(config.n_routed_experts,)`` is
    also registered; otherwise that attribute is ``None``. Both parameters
    are allocated uninitialized (``torch.empty``) and are expected to be
    filled by weight loading. ``prefix`` is accepted but unused.
    """

    def __init__(
        self,
        config,
        prefix: str = "",
    ):
        super().__init__()
        num_experts = config.n_routed_experts
        self.weight = nn.Parameter(torch.empty((num_experts, config.hidden_size)))
        self.e_score_correction_bias = (
            nn.Parameter(torch.empty((num_experts)))
            if config.topk_method == "noaux_tc"
            else None
        )

    def forward(self, hidden_states):
        """Return router logits ``hidden_states @ weight.T`` (no bias)."""
        return F.linear(hidden_states, self.weight, None)


Liangsheng Yin's avatar
Liangsheng Yin committed
197
198
199
200
201
202
class DeepseekV2MoE(nn.Module):
    """Mixture-of-experts FFN used by DeepSeek-V2 decoder layers.

    Tokens are scored by ``MoEGate`` and processed by a routed-experts layer
    (``FusedMoE``, ``EPMoE`` or ``DeepEPMoE`` depending on
    ``global_server_args_dict``). When ``config.n_shared_experts`` is set and
    ``n_share_experts_fusion == 0``, a separate shared-expert MLP is built and
    its output is added to the routed output. When ``n_share_experts_fusion >
    0``, no separate shared MLP is built; instead the routed layer gets
    ``n_share_experts_fusion`` extra expert slots and ``top_k`` is raised by
    one (presumably the fused shared experts; confirm with the weight loader).

    Execution paths:

    * ``forward_normal``: gate -> experts -> ``* routed_scaling_factor`` ->
      ``+ shared output`` -> tensor-parallel all-reduce when ``tp_size > 1``.
    * ``forward_deepep`` (``enable_deepep_moe``): gate -> ``select_experts``
      -> DeepEP dispatch -> experts -> DeepEP combine ->
      ``* routed_scaling_factor`` -> ``+ shared output``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
        enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]

        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}."
            )

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))

        if enable_deepep_moe:
            MoEImpl = DeepEPMoE
        elif global_server_args_dict["enable_ep_moe"]:
            MoEImpl = EPMoE
        else:
            MoEImpl = FusedMoE

        self.experts = MoEImpl(
            num_experts=config.n_routed_experts + self.n_share_experts_fusion,
            top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            prefix=add_prefix("experts", prefix),
            **(
                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
                if enable_deepep_moe
                else {}
            ),
        )

        if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            # Disable TP for the shared experts when DeepEP MoE is enabled.
            shared_tp_overrides = (
                dict(tp_rank=0, tp_size=1) if enable_deepep_moe else {}
            )
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=add_prefix("shared_experts", prefix),
                **shared_tp_overrides,
            )

        if enable_deepep_moe:
            # TODO: we will support tp < ep in the future
            self.ep_size = get_tensor_model_parallel_world_size()
            self.num_experts = config.n_routed_experts
            self.top_k = config.num_experts_per_tok
            self.renormalize = config.norm_topk_prob
            self.topk_group = config.topk_group
            self.num_expert_group = config.n_group
            self.correction_bias = (
                self.gate.e_score_correction_bias.data
                if self.gate.e_score_correction_bias is not None
                else None
            )

            self.deepep_dispatcher = DeepEPDispatcher(
                group=parallel_state.get_tp_group().device_group,
                router_topk=self.top_k,
                permute_fusion=True,
                num_experts=config.n_routed_experts,
                num_local_experts=config.n_routed_experts // self.tp_size,
                hidden_size=config.hidden_size,
                params_dtype=config.torch_dtype,
                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
                async_finish=True,  # TODO
                return_recv_hook=True,
            )

    def forward(
        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
    ) -> torch.Tensor:
        """Run the MoE block; ``forward_mode`` is only used by the DeepEP path."""
        if not global_server_args_dict["enable_deepep_moe"]:
            return self.forward_normal(hidden_states)
        else:
            return self.forward_deepep(hidden_states, forward_mode)

    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Non-DeepEP path (FusedMoE / EPMoE), all-reducing over TP ranks."""
        shared_output = self._forward_shared_experts(hidden_states)
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        final_hidden_states = (
            self.experts(hidden_states=hidden_states, router_logits=router_logits)
            * self.routed_scaling_factor
        )
        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output
        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
    ) -> torch.Tensor:
        """DeepEP path: route locally, dispatch tokens, run experts, combine.

        Idle batches (``forward_mode`` is None / idle, or zero tokens) still
        take part in dispatch/combine with empty top-k tensors.

        Raises:
            NotImplementedError: if ``ep_size <= 1``. The expert kernel needs
                the layout tensors (``reorder_topk_ids``, ``seg_indptr``,
                ``masked_m``, ...) that only the dispatcher produces.
        """
        if self.ep_size <= 1:
            raise NotImplementedError(
                "DeepEP MoE requires ep_size > 1 "
                f"(got ep_size={self.ep_size}); disable enable_deepep_moe "
                "for single-rank runs."
            )

        shared_output = None
        if (
            forward_mode is not None
            and not forward_mode.is_idle()
            and hidden_states.shape[0] > 0
        ):
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            shared_output = self._forward_shared_experts(hidden_states)
            topk_weights, topk_idx = select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
                top_k=self.top_k,
                use_grouped_topk=True,
                renormalize=self.renormalize,
                topk_group=self.topk_group,
                num_expert_group=self.num_expert_group,
                correction_bias=self.correction_bias,
                routed_scaling_factor=self.routed_scaling_factor,
            )
        else:
            topk_idx = torch.full(
                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
            )
            topk_weights = torch.empty(
                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
            )

        # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
        (
            hidden_states,
            topk_idx,
            topk_weights,
            reorder_topk_ids,
            num_recv_tokens_per_expert,
            seg_indptr,
            masked_m,
            expected_m,
        ) = self.deepep_dispatcher.dispatch(
            hidden_states,
            topk_idx,
            topk_weights,
            forward_mode=forward_mode,
        )
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            reorder_topk_ids=reorder_topk_ids,
            seg_indptr=seg_indptr,
            masked_m=masked_m,
            expected_m=expected_m,
            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
            forward_mode=forward_mode,
        )
        final_hidden_states = self.deepep_dispatcher.combine(
            final_hidden_states,
            topk_idx,
            topk_weights,
            forward_mode,
        )
        final_hidden_states *= self.routed_scaling_factor

        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output

        return final_hidden_states

    def _forward_shared_experts(self, hidden_states):
        """Return the shared-expert output, or None when there is none.

        There is no separate shared MLP when shared experts are fused into the
        routed experts (``n_share_experts_fusion != 0``). There is also none
        when ``config.n_shared_experts`` is None, in which case
        ``self.shared_experts`` is never created.
        """
        if self.n_share_experts_fusion != 0:
            return None
        shared_experts = getattr(self, "shared_experts", None)
        if shared_experts is None:
            return None
        return shared_experts(hidden_states)

Liangsheng Yin's avatar
Liangsheng Yin committed
402
403
404
405
406
407
408
409
410

def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN magnitude factor for a rope scaling ``scale``.

    The factor is ``1.0`` when ``scale <= 1``; otherwise it is
    ``0.1 * mscale * ln(scale) + 1.0``.
    """
    import math

    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0


411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
class DeepseekV2AttentionMLA(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        layer_id: Optional[int] = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """Build the MLA projections, rotary embedding and attention kernels.

        Args:
            config: model config; only ``rms_norm_eps`` is read here.
            hidden_size: model hidden width.
            num_heads: total attention heads. Must be divisible by the
                attention TP size; each rank keeps ``num_heads // attn_tp_size``.
            qk_nope_head_dim, qk_rope_head_dim: per-head q/k widths without /
                with rotary embedding (their sum is ``qk_head_dim``).
            v_head_dim: per-head value width.
            q_lora_rank: rank of the query down-projection, or ``None`` to
                project queries straight from the hidden states (``q_proj``).
            kv_lora_rank: width of the compressed KV latent.
            rope_theta, max_position_embeddings: rotary embedding settings.
            rope_scaling: optional rope scaling dict. When truthy, a copy with
                ``rope_type="deepseek_yarn"`` is passed to ``get_rope``. The
                attention scaling (initially ``qk_head_dim ** -0.5``) is then
                multiplied by ``yarn_get_mscale(factor, mscale_all_dim) ** 2``.
                The caller's dict is left unmodified.
            quant_config: forwarded to the linear layers and RadixAttention.
            reduce_results: forwarded to ``o_proj``.
            layer_id: forwarded to both RadixAttention instances.
            prefix: parameter-name prefix for weight loading.
            alt_stream: optional secondary CUDA stream stored on the module.
                ``forward_absorb`` uses it to overlap the q/kv layernorms
                while a CUDA graph is being captured.
        """
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.dp_size = get_attention_dp_size()
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        assert num_heads % attn_tp_size == 0
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # For tensor parallel attention
        if self.q_lora_rank is not None:
            # A single replicated down-projection produces
            # [q latent | kv latent | rope key], which is split in forward.
            self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("fused_qkv_a_proj_with_mqa", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("kv_a_proj_with_mqa", prefix),
            )

        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

        if rope_scaling:
            # Copy rather than mutate: the caller's dict may be shared with
            # the model config and reused by other layers or other readers.
            rope_scaling = {**rope_scaling, "rope_type": "deepseek_yarn"}

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
        )

        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale
        else:
            self.rotary_emb.forward = self.rotary_emb.forward_native

        # Absorbed path: one shared latent "KV head" of width
        # kv_lora_rank + qk_rope_head_dim, with values of width kv_lora_rank.
        self.attn_mqa = RadixAttention(
            self.num_local_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=1,
            layer_id=layer_id,
            v_head_dim=self.kv_lora_rank,
            quant_config=quant_config,
            prefix=add_prefix("attn_mqa", prefix),
        )

        # Non-absorbed path: ordinary per-head attention.
        self.attn_mha = RadixAttention(
            self.num_local_heads,
            self.qk_nope_head_dim + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            v_head_dim=self.v_head_dim,
            quant_config=quant_config,
            prefix=add_prefix("attn_mha", prefix),
        )

        self.alt_stream = alt_stream

        # Filled in after weight loading (not in this block); None until then.
        self.w_kc = None
        self.w_vc = None
        self.w_scale = None

        self.w_scale_k = None
        self.w_scale_v = None
        self.use_deep_gemm_bmm = False

        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
        self.disable_chunked_prefix_cache = global_server_args_dict[
            "disable_chunked_prefix_cache"
        ]
        self.attention_backend = global_server_args_dict["attention_backend"]
        self.rocm_fused_decode_mla = get_bool_env_var(
            "SGLANG_ROCM_FUSED_DECODE_MLA", "false"
        )

        # TODO: Design a finer way to determine the threshold
        self.chunked_prefix_cache_threshold = get_int_env_var(
            "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
        )
578
579
580
581

    def dispatch_attn_forward_method(
        self, forward_batch: ForwardBatch
    ) -> AttnForwardMethod:
        """Pick the attention algorithm for ``forward_batch``.

        Only "regular" extend batches (extend, but not target-verify or
        draft-extend) can use an MHA variant; every other batch uses MLA.
        For regular extends, the choice depends on ``self.attention_backend``:

        * ``"flashinfer"``: MHA when the total cached prefix is 0 and ragged
          prefill is not disabled; otherwise MLA.
        * ``"fa3"``: MHA_CHUNKED_KV when chunked prefix cache is not disabled
          and the total cached prefix is either 0 or at least
          ``chunked_prefix_cache_threshold``; otherwise MLA.
        * any other backend (e.g. Triton): MHA when the total cached prefix
          is 0; otherwise MLA.

        If ``extend_prefix_lens_cpu`` is ``None``, the prefix length is
        unknown, so MLA is returned instead of raising.
        """
        forward_mode = forward_batch.forward_mode
        is_regular_extend = (
            forward_mode.is_extend()
            and not forward_mode.is_target_verify()
            and not forward_mode.is_draft_extend()
        )
        if not is_regular_extend:
            return AttnForwardMethod.MLA

        prefix_lens = forward_batch.extend_prefix_lens_cpu
        if prefix_lens is None:
            # No prefix information: fall back to the general absorbed path.
            return AttnForwardMethod.MLA
        sum_extend_prefix_lens = sum(prefix_lens)

        if self.attention_backend == "flashinfer":
            # Flashinfer MLA: Do not absorb when enabling ragged prefill
            if not self.flashinfer_mla_disable_ragged and sum_extend_prefix_lens == 0:
                return AttnForwardMethod.MHA
            return AttnForwardMethod.MLA

        if self.attention_backend == "fa3":
            # Flash Attention: Use MHA with chunked KV cache when prefilling on
            # long sequences.
            if not self.disable_chunked_prefix_cache and (
                sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
                or sum_extend_prefix_lens == 0
            ):
                return AttnForwardMethod.MHA_CHUNKED_KV
            return AttnForwardMethod.MLA

        # Triton: Use normal computation for prefill and use weight absorption
        # for extend/decode
        if sum_extend_prefix_lens == 0:
            return AttnForwardMethod.MHA
        return AttnForwardMethod.MLA
Lianmin Zheng's avatar
Lianmin Zheng committed
622

623
624
625
626
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ) -> torch.Tensor:
        """Compute attention, routing to the method chosen by
        ``dispatch_attn_forward_method``.

        A batch with zero rows is returned as-is. This is only allowed when
        ``o_proj`` does not reduce its results; otherwise skipping the
        all-reduce on this rank would hang the others.
        """
        if hidden_states.shape[0] == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states

        method = self.dispatch_attn_forward_method(forward_batch)

        if method == AttnForwardMethod.MHA:
            return self.forward_normal(positions, hidden_states, forward_batch)
        if method == AttnForwardMethod.MHA_CHUNKED_KV:
            return self.forward_normal_chunked_kv(
                positions, hidden_states, forward_batch
            )

        # Absorbed MLA. On ROCm, decode batches may use the fused MLA + RoPE
        # kernel when SGLANG_ROCM_FUSED_DECODE_MLA is enabled.
        use_fused_rocm_decode = (
            _is_hip
            and self.rocm_fused_decode_mla
            and forward_batch.forward_mode.is_decode()
        )
        if use_fused_rocm_decode:
            return self.forward_absorb_fused_mla_rope(
                positions, hidden_states, forward_batch
            )
        return self.forward_absorb(
            positions, hidden_states, forward_batch, zero_allocator
        )
661
662
663
664
665
666
667
668

    def forward_normal(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Non-absorbed MHA path.

        Steps:
        1. The KV latent is normalized (``kv_a_layernorm``) and decompressed
           by ``kv_b_proj`` into per-head ``k_nope`` and ``v``.
        2. RoPE is applied to the rope parts of q and k, and full keys are
           assembled as ``[k_nope | k_pe]``.
        3. The latent cache row ``[normalized kv_a | rotated k_pe]`` is
           written explicitly via ``set_kv_buffer``. ``attn_mha`` is
           therefore called with ``save_kv_cache=False``.

        Note: ``q`` and ``latent_cache`` are updated in place (slice
        assignment), and those writes go into the projection outputs they
        were split from.
        """
        nope_dim = self.qk_nope_head_dim
        rope_dim = self.qk_rope_head_dim
        lora_dim = self.kv_lora_rank
        heads = self.num_local_heads

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)[0].view(-1, heads, self.qk_head_dim)
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        else:
            fused = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
            q, latent_cache = fused.split(
                [self.q_lora_rank, lora_dim + rope_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, heads, self.qk_head_dim)

        _, q_pe = q.split([nope_dim, rope_dim], dim=-1)
        kv_a, _ = latent_cache.split([lora_dim, rope_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0].view(-1, heads, nope_dim + self.v_head_dim)
        k_nope, v = kv[..., :nope_dim], kv[..., nope_dim:]
        k_pe = latent_cache[:, :, lora_dim:]
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., nope_dim:] = q_pe
        k = torch.empty_like(q)
        k[..., :nope_dim] = k_nope
        k[..., nope_dim:] = k_pe

        latent_cache[:, :, :lora_dim] = kv_a.unsqueeze(1)
        latent_cache[:, :, lora_dim:] = k_pe

        # Save latent cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )
        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        attn_output = attn_output.reshape(-1, heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output

    def forward_absorb(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ) -> torch.Tensor:
        """Run MLA with weight absorption against the compressed latent.

        ``q_nope`` is multiplied by ``w_kc`` so that attention runs directly over
        the ``kv_lora_rank``-wide latent (``self.attn_mqa``). The attention output
        is then mapped back to ``v_head_dim`` per head through ``w_vc`` before
        ``o_proj``. The two batched matmuls are dispatched by weight
        format: DeepGEMM masked FP8, ``float8_e4m3fnuz`` (dequantized bf16 bmm),
        ``float8_e4m3fn`` (``bmm_fp8``), or a plain ``torch.bmm``.

        Args:
            positions: token positions forwarded to ``self.rotary_emb``.
            hidden_states: layer input fed to the Q / KV-A projections.
            forward_batch: batch metadata forwarded to the attention backend.
            zero_allocator: used to allocate the 1-element buffer passed to
                ``per_tensor_quant_mla_fp8`` (presumably its scale output).

        Returns:
            The ``o_proj`` output for every token.
        """
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            k_nope = latent_cache[..., : self.kv_lora_rank]

            # overlap qk norm
            # Only while a CUDA graph is being captured: q's norm is issued on the
            # current stream and the kv norm on ``alt_stream``, joined with
            # wait_stream on both sides.
            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
                current_stream = torch.cuda.current_stream()
                self.alt_stream.wait_stream(current_stream)
                q = self.q_a_layernorm(q)
                with torch.cuda.stream(self.alt_stream):
                    k_nope = self.kv_a_layernorm(k_nope)
                current_stream.wait_stream(self.alt_stream)
            else:
                q = self.q_a_layernorm(q)
                k_nope = self.kv_a_layernorm(k_nope)

            k_nope = k_nope.unsqueeze(1)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
            k_nope = latent_cache[..., : self.kv_lora_rank]
            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

        # Absorb W_UK into the query: head-major bmm of q_nope with w_kc.
        if self.use_deep_gemm_bmm:
            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
            )
            q_nope_out = q_nope.new_empty(
                (self.num_local_heads, aligned_m, self.kv_lora_rank)
            )
            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
                (q_nope_val, q_nope_scale),
                (self.w_kc, self.w_scale_k),
                q_nope_out,
                masked_m,
                expected_m,
            )
            # Drop the alignment padding rows.
            q_nope_out = q_nope_out[:, :expected_m, :]
        elif self.w_kc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)

        # Back to token-major layout.
        q_nope_out = q_nope_out.transpose(0, 1)
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        if self.attention_backend in ("fa3", "flashinfer"):
            # These backends take the rope parts separately.
            attn_output = self.attn_mqa(
                q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
            )
        else:
            q = torch.cat([q_nope_out, q_pe], dim=-1)
            k = torch.cat([k_nope, k_pe], dim=-1)
            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # Absorb W_UV on the output side: bmm of the latent output with w_vc.
        if self.use_deep_gemm_bmm:
            attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
                per_token_group_quant_mla_deep_gemm_masked_fp8(
                    attn_output.transpose(0, 1)
                )
            )
            attn_bmm_output = attn_output.new_empty(
                (self.num_local_heads, aligned_m, self.v_head_dim)
            )
            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
                (attn_output_val, attn_output_scale),
                (self.w_vc, self.w_scale_v),
                attn_bmm_output,
                masked_m,
                expected_m,
            )
            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
        elif self.w_vc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                attn_output.transpose(0, 1),
                zero_allocator.allocate(1),
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)

        return output

    def forward_absorb_fused_mla_rope(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ) -> torch.Tensor:
        """Absorbed MLA that calls ``decode_attention_fwd_grouped_rope`` directly.

        When the env var ``SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION`` is ``"1"`` (the
        default), RoPE is applied inside the kernel (``use_rope=True``). The
        rotated k_pe the kernel writes to ``k_pe_output`` is stored back into the
        KV pool afterwards. Otherwise RoPE is applied up front via
        ``self.rotary_emb`` and the kernel runs with ``use_rope=False``.

        Args:
            positions: token positions (used by RoPE, in or outside the kernel).
            hidden_states: layer input fed to the Q / KV-A projections.
            forward_batch: batch metadata; provides the attention backend's
                metadata, the KV pool and the cache slots.
            zero_allocator: used to allocate the 1-element buffer passed to
                ``per_tensor_quant_mla_fp8`` (presumably its scale output).

        Returns:
            The ``o_proj`` output for every token.
        """
        enable_rope_fusion = (
            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
        )
        q_len = hidden_states.shape[0]
        # Query layout expected by the kernel: [absorbed q_nope | q_pe].
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Absorb W_UK into the query.
        if self.w_kc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        # k_input is a view of latent_cache: writing the normalized latent here
        # updates latent_cache in place.
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input

        if not enable_rope_fusion:
            k_pe = k_input[..., self.kv_lora_rank :]
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
            k_input[..., self.kv_lora_rank :] = k_pe
            k_pe_output = None
        else:
            # The kernel writes the rotated k_pe here.
            k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])

        # q_pe is rotated when fusion is off and un-rotated (rotated in-kernel)
        # when fusion is on; either way it is the rope half of the query.
        q_input[..., self.kv_lora_rank :] = q_pe

        attn_output = torch.empty(
            (q_len, self.num_local_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )
        attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
            forward_batch.attn_backend.forward_metadata
        )
        cos_sin_cache = self.rotary_emb.cos_sin_cache
        num_kv_split = forward_batch.attn_backend.num_kv_splits
        sm_scale = self.attn_mqa.scaling
        if attn_logits is None:
            # The backend did not preallocate split-KV scratch space.
            attn_logits = torch.empty(
                (
                    forward_batch.batch_size,
                    self.num_local_heads,
                    num_kv_split,
                    self.kv_lora_rank + 1,
                ),
                dtype=torch.float32,
                device=q.device,
            )

        # save current latent cache.
        # With rope fusion, k_pe stored here is still un-rotated; it is
        # overwritten with the kernel's rotated k_pe after the call below.
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mqa, forward_batch.out_cache_loc, k_input, None
        )
        key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mqa.layer_id
        )
        # Values are the latent part of the cached keys.
        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

        decode_attention_fwd_grouped_rope(
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            self.kv_lora_rank,
            self.rotary_emb.rotary_dim,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            logit_cap=self.attn_mqa.logit_cap,
            use_rope=enable_rope_fusion,
            is_neox_style=self.rotary_emb.is_neox_style,
        )

        if enable_rope_fusion:
            k_input[..., self.kv_lora_rank :] = k_pe_output
            forward_batch.token_to_kv_pool.set_kv_buffer(
                self.attn_mqa, forward_batch.out_cache_loc, k_input, None
            )

        # attn_output already has shape (q_len, num_local_heads, kv_lora_rank).
        # Absorb W_UV on the output side.
        if self.w_vc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                attn_output.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)

        return output

981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
    def _chunked_prefix_attn_mha(
        self,
        q: torch.Tensor,
        accum_output: torch.Tensor,
        accum_lse: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Attend ``q`` to the cached prefix one chunk at a time.

        For each prefix chunk, the compressed latents are read from the KV pool
        via ``prefix_chunk_kv_indices`` and expanded to per-head K/V with
        ``kv_b_proj``. The chunk's attention output and LSE are then combined
        with the running (``accum_output``, ``accum_lse``) pair via
        ``merge_state_v2``.

        Returns:
            The merged attention output after all chunks have been folded in.
        """
        assert forward_batch.num_prefix_chunks is not None
        nope_dim = self.qk_nope_head_dim
        rope_dim = self.qk_rope_head_dim

        for chunk_idx in range(forward_batch.num_prefix_chunks):
            forward_batch.set_prefix_chunk_idx(chunk_idx)

            # Gather this chunk's latents from the pool.
            pool_buf = forward_batch.token_to_kv_pool.get_key_buffer(
                self.attn_mha.layer_id
            )
            chunk_latent = pool_buf[
                forward_batch.prefix_chunk_kv_indices[chunk_idx]
            ].contiguous()
            latent_part, rope_part = chunk_latent.split(
                [self.kv_lora_rank, rope_dim], dim=-1
            )

            # Expand the latent to per-head [K_nope | V].
            kv_proj = self.kv_b_proj(latent_part.squeeze(1).contiguous())[0]
            kv_proj = kv_proj.view(-1, self.num_local_heads, nope_dim + self.v_head_dim)
            k_nope, v = kv_proj.split([nope_dim, self.v_head_dim], dim=-1)

            k = torch.empty(
                (k_nope.shape[0], self.num_local_heads, nope_dim + rope_dim),
                dtype=v.dtype,
                device=v.device,
            )
            k[..., :nope_dim] = k_nope
            # The rope part is shared by all heads and broadcasts over dim 1.
            k[..., nope_dim:] = rope_part

            chunk_out, chunk_lse = self.attn_mha(
                q, k, v, forward_batch, save_kv_cache=False
            )
            chunk_lse = chunk_lse.transpose(0, 1).contiguous()

            merged_out = torch.empty_like(accum_output)
            merged_lse = torch.empty_like(accum_lse)
            merge_state_v2(
                chunk_out, chunk_lse, accum_output, accum_lse, merged_out, merged_lse
            )
            accum_output, accum_lse = merged_out, merged_lse

        return accum_output

    def forward_normal_chunked_kv(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """MHA path that handles a long cached prefix in chunks.

        The newly extended tokens are first attended without the prefix cache
        (``set_attn_attend_prefix_cache(False)``). If any sequence in the batch
        has a cached prefix, ``_chunked_prefix_attn_mha`` then attends to that
        prefix chunk by chunk and merges the results with the first pass using
        the returned LSE.

        Args:
            positions: token positions forwarded to ``self.rotary_emb``.
            hidden_states: layer input fed to the Q / KV-A projections.
            forward_batch: batch metadata; provides the KV pool, prefix lengths
                and chunked-prefix bookkeeping.

        Returns:
            The ``o_proj`` output for every token.
        """
        # In normal mha, the k and v tensors will become overly large when the prefix length is long.
        # To avoid this, we split the kv cache into chunks and process them one after another.
        # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
        # The top comments in https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
        # will be helpful for understanding the purpose of this function.

        # First do normal mha forward to get output for extended part
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]

        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope = kv[..., : self.qk_nope_head_dim]
        v = kv[..., self.qk_nope_head_dim :]
        k_pe = latent_cache[:, :, self.kv_lora_rank :]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        # k_pe has a single head and is broadcast across all local heads.
        k[..., self.qk_nope_head_dim :] = k_pe

        # Store the normalized latent and rotated k_pe (written into
        # latent_cache in place) in the KV pool.
        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
        latent_cache[:, :, self.kv_lora_rank :] = k_pe

        # Save latent cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )

        # Do mha for extended part without prefix
        forward_batch.set_attn_attend_prefix_cache(False)
        attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        lse = torch.transpose(lse, 0, 1).contiguous()

        # Do mha attention with chunked prefix cache if there are any sequence with prefix
        if any(forward_batch.extend_prefix_lens_cpu):
            # Only initialize the info once
            if forward_batch.num_prefix_chunks is None:
                forward_batch.prepare_chunked_prefix_cache_info(q.device)

            forward_batch.set_attn_attend_prefix_cache(True)
            attn_output = self._chunked_prefix_attn_mha(
                q=q,
                accum_output=attn_output,
                accum_lse=lse,
                forward_batch=forward_batch,
            )

        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output

1104

1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
class _FFNInputMode(Enum):
    # The MLP sublayer requires 1/tp_size tokens as input
    SCATTERED = auto()
    # The MLP sublayer requires all tokens as input
    FULL = auto()


@dataclass
class _DecoderLayerInfo:
    is_sparse: bool
    ffn_input_mode: _FFNInputMode


Liangsheng Yin's avatar
Liangsheng Yin committed
1118
1119
1120
1121
1122
1123
1124
class DeepseekV2DecoderLayer(nn.Module):
    """One DeepSeek-V2 decoder block.

    MLA self-attention is followed by either a MoE (sparse) or a dense MLP.
    RMSNorm is applied before each sublayer, and a residual stream is threaded
    through. ``forward`` dispatches on the token layout the MLP expects
    (``_FFNInputMode``).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        is_nextn: bool = False,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
        self.layer_id = layer_id
        self.dp_size = get_attention_dp_size()
        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()
        self.self_attn = DeepseekV2AttentionMLA(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=(
                config.q_lora_rank if hasattr(config, "q_lora_rank") else None
            ),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            layer_id=layer_id,
            reduce_results=False,
            prefix=add_prefix("self_attn", prefix),
            alt_stream=alt_stream,
        )

        self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
        # The previous layer's MLP mode determines how this layer's input
        # tokens are laid out.
        previous_layer_info = self._compute_info(
            config, layer_id=layer_id - 1, is_nextn=False
        )

        if self.info.is_sparse:
            self.mlp = DeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
        else:
            if self._enable_moe_dense_fully_dp():
                # Dense MLP is not tensor-parallel: every rank holds full weights.
                mlp_tp_rank, mlp_tp_size = 0, 1
            else:
                mlp_tp_rank, mlp_tp_size = None, None
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                tp_rank=mlp_tp_rank,
                tp_size=mlp_tp_size,
            )

        self.input_is_scattered = (
            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
        )
        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    @staticmethod
    def _enable_moe_dense_fully_dp():
        """Return True when dense MLPs run with ``moe_dense_tp_size == 1``."""
        return global_server_args_dict["moe_dense_tp_size"] == 1

    @staticmethod
    def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
        """Decide whether ``layer_id`` uses MoE and which MLP input layout it needs.

        NextN layers are always sparse. A sparse layer uses the SCATTERED layout
        under DeepEP MoE, and a dense layer uses it when dense MLPs are fully
        data-parallel. Every other case uses FULL.
        """
        is_sparse = is_nextn or (
            config.n_routed_experts is not None
            and layer_id >= config.first_k_dense_replace
            and layer_id % config.moe_layer_freq == 0
        )
        ffn_input_mode = (
            _FFNInputMode.SCATTERED
            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
            or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
            else _FFNInputMode.FULL
        )
        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the layer and return ``(hidden_states, residual)``.

        ``residual`` may be None on input (first layer). On output it is None
        only when the scattered path has already folded it into
        ``hidden_states`` (last layer with attention TP > 1).
        """
        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
            return self.forward_ffn_with_scattered_input(
                positions, hidden_states, forward_batch, residual, zero_allocator
            )
        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
            return self.forward_ffn_with_full_input(
                positions, hidden_states, forward_batch, residual, zero_allocator
            )
        else:
            raise NotImplementedError(
                f"Unsupported FFN input mode: {self.info.ffn_input_mode}"
            )

    def forward_ffn_with_full_input(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Layer forward where the MLP sees all tokens (``_FFNInputMode.FULL``).

        With DP attention (``dp_size != 1``), the attention outputs are gathered
        into ``forward_batch.gathered_buffer`` before the MLP and scattered back
        afterwards.
        """
        if hidden_states.shape[0] == 0:
            # No local tokens: skip norm/attention but keep residual defined.
            residual = hidden_states
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(hidden_states, residual)

            assert not (
                self.attn_tp_size != 1 and self.input_is_scattered
            ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"

            # Self Attention
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
                zero_allocator=zero_allocator,
            )

        # Gather
        if get_tensor_model_parallel_world_size() > 1:
            # all gather and all reduce
            if self.dp_size != 1:
                # Add the residual on attn-TP rank 0 only; presumably
                # dp_gather_partial sums partial results across attn-TP ranks,
                # so this adds it exactly once.
                if self.attn_tp_rank == 0:
                    hidden_states += residual
                hidden_states, local_hidden_states = (
                    forward_batch.gathered_buffer,
                    hidden_states,
                )
                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
                dp_scatter(residual, hidden_states, forward_batch)
                hidden_states = self.post_attention_layernorm(hidden_states)
            else:
                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual
                )
        else:
            hidden_states, residual = self.post_attention_layernorm(
                hidden_states, residual
            )

        # Fully Connected
        hidden_states = self.mlp(hidden_states)

        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
        # Scatter
        if self.dp_size != 1:
            # important: forward batch.gathered_buffer is used both after scatter and after gather.
            # be careful about this!
            hidden_states, global_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            dp_scatter(hidden_states, global_hidden_states, forward_batch)

        return hidden_states, residual

    def forward_ffn_with_scattered_input(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Layer forward where the MLP sees a 1/attn_tp_size token slice.

        If the input arrives scattered and attention TP > 1, it is all-gathered
        for attention. The attention output is reduce-scattered back to this
        rank's slice. On the last layer the residual is folded in and the full
        token set is all-gathered again, so ``residual`` is returned as None.
        """
        if hidden_states.shape[0] == 0:
            residual = hidden_states
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                hidden_states, residual = self.input_layernorm(hidden_states, residual)

        if self.attn_tp_size != 1 and self.input_is_scattered:
            hidden_states, local_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            tp_all_gather(
                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
            )

        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )

        if self.attn_tp_size != 1:
            if self.input_is_scattered:
                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
                hidden_states = tensor_list[self.attn_tp_rank]
                tp_reduce_scatter(hidden_states, tensor_list)
                if hidden_states.shape[0] != 0:
                    hidden_states, residual = self.post_attention_layernorm(
                        hidden_states, residual
                    )
            else:
                # Residual is still the full (unscattered) tensor: add it on
                # rank 0 before reduce-scatter, then restart the residual
                # stream from the scattered slice.
                if self.attn_tp_rank == 0:
                    hidden_states += residual
                tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
                hidden_states = tensor_list[self.attn_tp_rank]
                tp_reduce_scatter(hidden_states, tensor_list)
                residual = hidden_states
                if hidden_states.shape[0] != 0:
                    hidden_states = self.post_attention_layernorm(hidden_states)
        else:
            if hidden_states.shape[0] != 0:
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual
                )

        # A fully data-parallel dense MLP can be skipped when this rank holds
        # no tokens.
        if not (
            self._enable_moe_dense_fully_dp()
            and (not self.info.is_sparse)
            and hidden_states.shape[0] == 0
        ):
            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

        if self.is_last_layer and self.attn_tp_size != 1:
            hidden_states += residual
            residual = None
            hidden_states, local_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            tp_all_gather(
                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
            )

        return hidden_states, residual

Liangsheng Yin's avatar
Liangsheng Yin committed
1382
1383
1384
1385
1386
1387
1388
1389

class DeepseekV2Model(nn.Module):
    """DeepseekV2 backbone: token embedding, a stack of decoder layers, final RMSNorm.

    ``forward`` returns the final hidden states (normalized unless the batch is
    in idle forward mode); logits are produced by the enclosing causal-LM class.
    """

    # NOTE(review): presumably tells the weight loader not to fall back to
    # PyTorch ``.pt`` files — confirm against the loader.
    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the embedding, decoder layers and final norm.

        Args:
            config: HF model config (reads ``pad_token_id``, ``vocab_size``,
                ``hidden_size``, ``num_hidden_layers``, ``rms_norm_eps``).
            quant_config: Optional quantization config forwarded to each layer.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            # Tensor-parallel sharding of the embedding is turned off when
            # DP attention is enabled.
            enable_tp=not global_server_args_dict["enable_dp_attention"],
        )
        # One auxiliary CUDA stream, shared by every decoder layer.
        self.alt_stream = torch.cuda.Stream()
        self.layers = nn.ModuleList(
            [
                DeepseekV2DecoderLayer(
                    config,
                    layer_id,
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{layer_id}", prefix),
                    alt_stream=self.alt_stream,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.dp_size = get_attention_dp_size()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token-embedding module (not a tensor)."""
        return self.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder stack.

        Args:
            input_ids: Token ids; embedded only when ``input_embeds`` is None.
            positions: Position ids passed to every layer.
            forward_batch: Batch metadata passed to every layer.
            input_embeds: Optional precomputed embeddings used instead of
                ``embed_tokens(input_ids)``.

        Returns:
            Hidden states after the last layer, passed through ``self.norm``
            unless ``forward_batch.forward_mode`` is idle.
        """
        zero_allocator = BumpAllocator(
            # TODO for two-batch-overlap, we need a larger buffer size
            buffer_size=len(self.layers) * 2,
            dtype=torch.float32,
            device=(
                input_embeds.device if input_embeds is not None else input_ids.device
            ),
        )

        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        residual = None
        for i, layer in enumerate(self.layers):
            expert_distribution_recorder.set_current_layer(i)
            hidden_states, residual = layer(
                positions, hidden_states, forward_batch, residual, zero_allocator
            )

        if not forward_batch.forward_mode.is_idle():
            # The last decoder layer may already have added the residual into
            # hidden_states and returned residual=None.
            if residual is None:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekV2ForCausalLM(nn.Module):
    """DeepseekV2 causal LM: ``DeepseekV2Model`` backbone, LM head and logits processor.

    Also owns checkpoint loading, including the optional "shared experts
    fusion", which clones each MoE layer's shared expert into extra
    routed-expert slots.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        # May rewrite global_server_args_dict["n_share_experts_fusion"];
        # presumably read while building the model below, so keep it first.
        self.determine_n_share_experts_fusion()
        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
        )
        self.logits_processor = LogitsProcessor(config)
        self.dp_size = get_attention_dp_size()

    def determine_n_share_experts_fusion(
        self, architecture: str = "DeepseekV3ForCausalLM"
    ):
        """Decide how many shared-expert replicas are fused into the routed MoE.

        Sets ``self.n_share_experts_fusion`` from
        ``global_server_args_dict["n_share_experts_fusion"]`` and may overwrite
        both.

        * Positive request: forced to 0 unless the model is ``architecture``
          with 256 routed experts. Otherwise it must equal the TP size.
        * Zero request: auto-set to the TP size when the model is
          ``architecture`` with 256 routed experts, DeepEP MoE is disabled,
          and the CUDA device capability is >= (9, 0).

        Args:
            architecture: The only ``config.architectures[0]`` value allowed
                to use the fusion.
        """
        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
        if self.n_share_experts_fusion > 0:
            # Only Deepseek V3/R1 can use shared experts fusion optimization now.
            if (
                self.config.architectures[0] != architecture
                or self.config.n_routed_experts != 256
            ):
                self.n_share_experts_fusion = 0
                global_server_args_dict["n_share_experts_fusion"] = 0
                log_info_on_rank0(
                    logger,
                    "Only Deepseek V3/R1 can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
                )
            else:
                assert (
                    self.n_share_experts_fusion == self.tp_size
                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
        elif self.n_share_experts_fusion == 0:
            # Check the cheap, model-only conditions before querying the CUDA
            # device, so non-matching models never touch the device API.
            if (
                self.config.architectures[0] == architecture
                and self.config.n_routed_experts == 256
                and (not global_server_args_dict["enable_deepep_moe"])
                and torch.cuda.get_device_capability("cuda") >= (9, 0)
            ):
                self.n_share_experts_fusion = self.tp_size
                global_server_args_dict["n_share_experts_fusion"] = self.tp_size
                log_info_on_rank0(
                    logger,
                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
                )

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the backbone's token-embedding module."""
        return self.model.embed_tokens

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the backbone, then return ``self.logits_processor``'s output."""
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def post_load_weights(self, is_nextn=False):
        """Derive the split ``kv_b_proj`` weights each attention module uses.

        For every target layer the ``kv_b_proj`` weight is first brought into a
        usable form:

        * AWQ weights are dequantized.
        * fp8 block-quantized weights are either kept for DeepGEMM bmm,
          dequantized, or requantized per-tensor.
        * fp8 channel-quantized weights are requantized per-tensor.
        * int8 weights are dequantized to bf16.

        The weight is then unflattened along dim 0 into groups of
        ``qk_nope_head_dim + v_head_dim`` rows (presumably one per head) and
        split into ``self_attn.w_kc`` / ``self_attn.w_vc``. Related scales are
        also stored on ``self_attn``.

        Args:
            is_nextn: If True, process only ``self.model.decoder`` (the NextN
                layer) instead of ``self.model.layers``.
        """
        layer_ids = (
            range(self.config.num_hidden_layers)
            if not is_nextn
            else [self.config.num_hidden_layers]
        )
        for layer_id in layer_ids:
            self_attn = (
                self.model.layers[layer_id].self_attn
                if not is_nextn
                else self.model.decoder.self_attn
            )
            kv_b_proj = self_attn.kv_b_proj

            if hasattr(kv_b_proj, "qweight"):
                # AWQ compatible
                if _is_cuda:
                    w = awq_dequantize(
                        kv_b_proj.qweight,
                        kv_b_proj.scales,
                        kv_b_proj.qzeros,
                    ).T
                else:
                    w = awq_dequantize(
                        kv_b_proj.qweight,
                        kv_b_proj.scales,
                        kv_b_proj.qzeros,
                        0,
                        0,
                        0,
                    ).T
            else:
                w = kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
            # This may affect the accuracy of fp8 model.
            # Fix deepseek v3 blockwise bmm by using deep_gemm
            use_deep_gemm_bmm = False
            model_dtype = torch.get_default_dtype()

            if w.dtype in (
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
            ):
                if hasattr(self.quant_config, "weight_block_size"):
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(kv_b_proj, "weight_scale_inv")
                        if _is_hip:
                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                                weight=w,
                                weight_scale=kv_b_proj.weight_scale_inv,
                                input_scale=None,
                            )
                        else:
                            weight = w
                            weight_scale = kv_b_proj.weight_scale_inv

                        if (
                            _is_cuda
                            and weight_block_size[0] == 128
                            and weight_block_size[1] == 128
                            and model_dtype == torch.bfloat16
                        ):
                            if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
                                "SGL_USE_DEEPGEMM_BMM", "false"
                            ):
                                # Keep fp8 weights + block scales for DeepGEMM.
                                block_scale = weight_scale
                                use_deep_gemm_bmm = True
                            else:
                                w = block_quant_dequant(
                                    weight,
                                    weight_scale,
                                    weight_block_size,
                                    model_dtype,
                                )
                        else:
                            w, scale = block_quant_to_tensor_quant(
                                weight, weight_scale, weight_block_size
                            )
                            self_attn.w_scale = scale
                else:
                    weight = w
                    weight_scale = kv_b_proj.weight_scale
                    w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                    self_attn.w_scale = scale

            if w.dtype == torch.int8:
                if hasattr(self.quant_config, "weight_block_size"):
                    # block-wise int8 need it
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(kv_b_proj, "weight_scale_inv")
                        weight = w
                        weight_scale = kv_b_proj.weight_scale_inv
                        w = int8_block_dequant(
                            weight, weight_scale, weight_block_size
                        ).to(torch.bfloat16)
                else:
                    # channel-wise int8 need it
                    w = w.to(torch.bfloat16) * kv_b_proj.weight_scale.to(
                        torch.bfloat16
                    )

            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            if not use_deep_gemm_bmm:
                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                # Assumes self_attn.w_scale was initialized (e.g. to None)
                # by the attention module — TODO confirm.
                if (
                    hasattr(kv_b_proj, "weight_scale")
                    and self_attn.w_scale is None
                ):
                    self_attn.w_scale = kv_b_proj.weight_scale
                    if _is_hip:
                        # NOTE(review): in-place, so this also scales
                        # kv_b_proj.weight_scale itself — confirm intended.
                        self_attn.w_scale *= 2.0
            else:
                num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
                num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
                ws_kc, ws_vc = block_scale.unflatten(
                    0, (-1, (num_tiles_k + num_tiles_n))
                ).split([num_tiles_k, num_tiles_n], dim=1)
                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
                self_attn.w_scale_v = ws_vc.contiguous()
                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
                self_attn.w_vc = w_vc.contiguous()
                self_attn.use_deep_gemm_bmm = True

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
        """Load ``(name, tensor)`` checkpoint pairs into this model's parameters.

        Steps:

        * When shared experts fusion is active, append copies of each MoE
          layer's shared-expert tensors as routed experts
          ``n_routed_experts + i`` and drop the originals.
        * Route each tensor to one of three loaders: the stacked
          ``gate_up_proj`` loader, the expert loader, or the plain parameter
          loader.
        * Concatenate ``q_a_proj`` and ``kv_a_proj_with_mqa`` into
          ``fused_qkv_a_proj_with_mqa`` when ``q_lora_rank`` is set.
        * Finally call :meth:`post_load_weights`.

        Args:
            weights: Iterable of checkpoint ``(name, tensor)`` pairs.
            is_nextn: Load only the NextN layer, remapping its names onto
                ``model.*`` / ``model.decoder.*``.

        Raises:
            ValueError: ``is_nextn`` is set but the config has no
                ``num_nextn_predict_layers``.
        """
        if is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
                # compatible with old design
                nextn_layer_id = (
                    0
                    if self.config.num_hidden_layers == 1
                    else self.config.num_hidden_layers
                )
            else:
                raise ValueError("num_nextn_predict_layers is not in the config")

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        if self.n_share_experts_fusion > 0:
            weights_list = list(weights)
            weights_dict = dict(weights_list)
            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
                suffix_list = [
                    "down_proj.weight",
                    "down_proj.weight_scale",
                    "gate_proj.weight",
                    "gate_proj.weight_scale",
                    "up_proj.weight",
                    "up_proj.weight_scale",
                ]
            else:
                suffix_list = [
                    "down_proj.weight",
                    "down_proj.weight_scale_inv",
                    "gate_proj.weight",
                    "gate_proj.weight_scale_inv",
                    "up_proj.weight",
                    "up_proj.weight_scale_inv",
                ]
            # A set: it is probed once per checkpoint tensor below.
            names_to_remove = set()

            moe_layers = (
                range(
                    self.config.first_k_dense_replace,
                    self.config.num_hidden_layers,
                    self.config.moe_layer_freq,
                )
                if not is_nextn
                else [nextn_layer_id]
            )

            for moe_layer in tqdm(
                moe_layers,
                desc=f"Cloning {self.n_share_experts_fusion} "
                "replicas of the shared expert into MoE",
            ):
                for suffix in suffix_list:
                    shared_expert_weight_name = (
                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
                    )
                    # The checkpoint may lack some suffixes (presumably an
                    # unquantized one, where quant_config is None, has no
                    # *.weight_scale tensors); skip them instead of KeyError.
                    if shared_expert_weight_name not in weights_dict:
                        continue
                    shared_expert_weight = weights_dict[shared_expert_weight_name]
                    for num_repeat in range(self.n_share_experts_fusion):
                        weights_list.append(
                            (
                                f"model.layers.{moe_layer}."
                                f"mlp.experts."
                                f"{self.config.n_routed_experts + num_repeat}"
                                f".{suffix}",
                                shared_expert_weight,
                            )
                        )
                    names_to_remove.add(shared_expert_weight_name)
            weights = [w for w in weights_list if w[0] not in names_to_remove]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        MoEImpl = (
            DeepEPMoE
            if global_server_args_dict["enable_deepep_moe"]
            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
        )
        expert_params_mapping = MoEImpl.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
        )

        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
            self.config.q_lora_rank is not None
        )
        cached_a_proj = {} if fuse_qkv_a_proj else None

        if is_nextn:
            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
            nextn_spec_weight_names = [
                "shared_head.norm",
                "eh_proj",
                "enorm",
                "hnorm",
            ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if not is_nextn:
                # Skip NextN layers (index >= num_hidden_layers) in the main model.
                if hasattr(self.config, "num_nextn_predict_layers"):
                    num_nextn_layers = self.config.num_nextn_predict_layers
                    if num_nextn_layers > 0 and name.startswith("model.layers"):
                        name_list = name.split(".")
                        if (
                            len(name_list) >= 3
                            and int(name_list[2]) >= self.config.num_hidden_layers
                        ):
                            continue
            else:
                if not name.startswith(nextn_layer_prefix):
                    continue

                # Use shared head and embed weights from target model
                if "shared_head.head" in name or "embed_tokens" in name:
                    continue

                is_decoder = True
                # For nextn specific weights
                for weight_name in nextn_spec_weight_names:
                    if weight_name in name:
                        name = name.replace(nextn_layer_prefix, "model")
                        is_decoder = False
                        break
                # For decoder layer weights
                if is_decoder:
                    name = name.replace(nextn_layer_prefix, "model.decoder")

            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    if fuse_qkv_a_proj and (
                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                    ):
                        cached_a_proj[name] = loaded_weight
                        q_a_proj_name = (
                            name
                            if "q_a_proj" in name
                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
                        )
                        kv_a_proj_name = (
                            name
                            if "kv_a_proj_with_mqa" in name
                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
                        )

                        # When both q_a_proj and kv_a_proj_with_mqa have been cached, load the fused weight to parameter
                        if (
                            q_a_proj_name in cached_a_proj
                            and kv_a_proj_name in cached_a_proj
                        ):
                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                            fused_weight = torch.cat(
                                [q_a_proj_weight, kv_a_proj_weight], dim=0
                            )

                            # Derive the fused name from the q_a_proj name so
                            # it is correct whichever half arrived last
                            # ("kv_a_proj_with_mqa" does not contain
                            # "q_a_proj", so replacing on `name` could fail).
                            param_name = q_a_proj_name.replace(
                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
                            )
                            param = params_dict[param_name]

                            weight_loader = getattr(
                                param, "weight_loader", default_weight_loader
                            )
                            weight_loader(param, fused_weight)
                            cached_a_proj.pop(q_a_proj_name)
                            cached_a_proj.pop(kv_a_proj_name)
                    else:
                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)

        self.post_load_weights(is_nextn=is_nextn)

    def get_embed_and_head(self):
        """Return ``(embedding weight, lm_head weight)``."""
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        """Replace the embedding and LM-head weights with ``embed`` / ``head``."""
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

Liangsheng Yin's avatar
Liangsheng Yin committed
1897

HandH1998's avatar
HandH1998 committed
1898
1899
1900
1901
1902
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepseekV3 causal LM; inherits the DeepseekV2ForCausalLM implementation unchanged."""


# Model classes exported by this module. NOTE(review): presumably discovered
# by the model registry through the `EntryClass` name — verify in the loader.
EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]