deepseek_v2.py 73 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14

Liangsheng Yin's avatar
Liangsheng Yin committed
15
16
17
# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
18

19
import logging
20
import os
21
22
from dataclasses import dataclass
from enum import Enum, IntEnum, auto
Liangsheng Yin's avatar
Liangsheng Yin committed
23
24
25
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
Ke Bao's avatar
Ke Bao committed
26
import torch.nn.functional as F
Liangsheng Yin's avatar
Liangsheng Yin committed
27
from torch import nn
28
from tqdm import tqdm
Liangsheng Yin's avatar
Liangsheng Yin committed
29
from transformers import PretrainedConfig
30
31

from sglang.srt.distributed import (
32
    get_tensor_model_parallel_rank,
Liangsheng Yin's avatar
Liangsheng Yin committed
33
    get_tensor_model_parallel_world_size,
34
    parallel_state,
Liangsheng Yin's avatar
Liangsheng Yin committed
35
36
    tensor_model_parallel_all_reduce,
)
37
from sglang.srt.layers.activation import SiluAndMul
Lianmin Zheng's avatar
Lianmin Zheng committed
38
from sglang.srt.layers.dp_attention import (
39
40
    attn_tp_all_gather,
    attn_tp_reduce_scatter,
41
    dp_gather_partial,
Lianmin Zheng's avatar
Lianmin Zheng committed
42
43
44
    dp_scatter,
    get_attention_tp_rank,
    get_attention_tp_size,
45
    get_local_attention_dp_size,
Lianmin Zheng's avatar
Lianmin Zheng committed
46
)
47
from sglang.srt.layers.layernorm import RMSNorm
48
49
50
51
52
53
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
54
from sglang.srt.layers.logits_processor import LogitsProcessor
55
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE, get_moe_impl_class
Lianmin Zheng's avatar
Lianmin Zheng committed
56
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
Ke Bao's avatar
Ke Bao committed
57
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
58
from sglang.srt.layers.moe.topk import select_experts
59
from sglang.srt.layers.quantization.base_config import QuantizationConfig
60
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
61
62
from sglang.srt.layers.quantization.fp8_kernel import (
    per_tensor_quant_mla_fp8,
63
    per_token_group_quant_mla_deep_gemm_masked_fp8,
64
)
HandH1998's avatar
HandH1998 committed
65
from sglang.srt.layers.quantization.fp8_utils import (
66
    block_quant_dequant,
HandH1998's avatar
HandH1998 committed
67
    block_quant_to_tensor_quant,
68
    channel_quant_to_tensor_quant,
69
    normalize_e4m3fn_to_e4m3fnuz,
HandH1998's avatar
HandH1998 committed
70
)
71
72
73
from sglang.srt.layers.quantization.int8_utils import (
    block_dequant as int8_block_dequant,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
74
from sglang.srt.layers.radix_attention import RadixAttention
75
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
76
77
78
79
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
80
81
82
83
84
from sglang.srt.managers.expert_distribution import (
    ExpertDistributionRecorder,
    get_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
85
from sglang.srt.managers.schedule_batch import global_server_args_dict
86
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
87
from sglang.srt.model_loader.weight_utils import default_weight_loader
88
89
90
91
92
93
94
95
from sglang.srt.utils import (
    BumpAllocator,
    DeepEPMode,
    add_prefix,
    get_bool_env_var,
    get_int_env_var,
    is_cuda,
    is_hip,
96
    log_info_on_rank0,
97
)
98

99
logger = logging.getLogger(__name__)

# Platform probes, evaluated once at import time and reused by the kernels below.
_is_hip = is_hip()
_is_cuda = is_cuda()

if not _is_cuda:
    from vllm._custom_ops import awq_dequantize
else:
    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2

    from sglang.srt.layers.quantization.deep_gemm import (
        grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
    )

if _is_hip:
    # ROCm-only fused MLA decode kernel.
    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
        decode_attention_fwd_grouped_rope,
    )

Liangsheng Yin's avatar
Liangsheng Yin committed
118

119
120
121
122
123
124
125
126
127
128
129
130
class AttnForwardMethod(IntEnum):
    """Attention computation strategy chosen per forward batch."""

    # Plain multi-head attention.
    MHA = 1
    # Multi-latent attention with absorbed weights.
    MLA = 2
    # Multi-head attention over a chunked KV cache; avoids OOM when prefix
    # lengths are long.
    MHA_CHUNKED_KV = 3


Liangsheng Yin's avatar
Liangsheng Yin committed
131
132
133
134
135
136
137
138
class DeepseekV2MLP(nn.Module):
    """Dense gated feed-forward layer: merged gate/up projection, SiluAndMul, down projection.

    Args:
        hidden_size: Input and output feature size.
        intermediate_size: Output width of each half of the merged gate/up projection.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config passed to both linear layers.
        reduce_results: Passed to ``down_proj`` (a ``RowParallelLinear``).
        prefix: Parameter-name prefix for the sub-layers.
        tp_rank: Optional tensor-parallel rank override passed to both linear layers.
        tp_size: Optional tensor-parallel size override passed to both linear layers.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        # Options passed identically to both projections.
        shared_kwargs = dict(quant_config=quant_config, tp_rank=tp_rank, tp_size=tp_size)
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            prefix=add_prefix("gate_up_proj", prefix),
            **shared_kwargs,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            reduce_results=reduce_results,
            prefix=add_prefix("down_proj", prefix),
            **shared_kwargs,
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x, forward_batch: Optional[ForwardBatch] = None):
        """Return ``down_proj(act_fn(gate_up_proj(x)))``; ``forward_batch`` is unused here."""
        merged = self.gate_up_proj(x)[0]
        return self.down_proj(self.act_fn(merged))[0]


Ke Bao's avatar
Ke Bao committed
177
class MoEGate(nn.Module):
    """Router that produces one logit per routed expert via a bias-free linear map.

    Args:
        config: Model config; reads ``n_routed_experts``, ``hidden_size`` and
            ``topk_method``.
        prefix: Parameter-name prefix (unused here).
    """

    def __init__(
        self,
        config,
        prefix: str = "",
    ):
        super().__init__()
        num_experts = config.n_routed_experts
        # Uninitialized storage; real values come from the checkpoint.
        self.weight = nn.Parameter(torch.empty((num_experts, config.hidden_size)))
        # The "noaux_tc" top-k method carries a per-expert score correction bias.
        self.e_score_correction_bias = (
            nn.Parameter(torch.empty((num_experts)))
            if config.topk_method == "noaux_tc"
            else None
        )

    def forward(self, hidden_states):
        """Return router logits ``hidden_states @ weight.T`` (no bias)."""
        return F.linear(hidden_states, self.weight, None)


199
200
201
202
203
204
205
206
def is_non_idle_and_non_empty(forward_mode, hidden_states):
    """Return True iff a forward mode is given, it is not idle, and there is at least one row.

    Args:
        forward_mode: Object exposing ``is_idle()``, or None.
        hidden_states: Tensor whose first dimension is checked for emptiness.
    """
    if forward_mode is None or forward_mode.is_idle():
        return False
    return hidden_states.shape[0] > 0


Liangsheng Yin's avatar
Liangsheng Yin committed
207
208
209
210
211
212
class DeepseekV2MoE(nn.Module):
    """DeepSeek-V2 mixture-of-experts feed-forward block.

    Tokens are scored by ``MoEGate`` and sent to the routed experts built by
    ``get_moe_impl_class()``. The routed output is multiplied by
    ``config.routed_scaling_factor`` and, when configured and not fused into the
    routed experts, the output of a dense shared-expert ``DeepseekV2MLP`` is added.
    With ``enable_deepep_moe`` tokens are exchanged across ranks by
    ``DeepEPDispatcher``; otherwise the result is all-reduced over the
    tensor-parallel group when ``tp_size > 1``.

    Args:
        config: Model config (expert counts, sizes, routing options).
        quant_config: Optional quantization config for experts and shared experts.
        prefix: Parameter-name prefix for sub-layers.

    Raises:
        ValueError: If the tensor-parallel size exceeds ``config.n_routed_experts``
            or ``config.hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]

        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}."
            )

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))

        # When shared experts are fused, they are appended to the routed experts
        # and occupy one extra top-k slot.
        self.experts = get_moe_impl_class()(
            num_experts=config.n_routed_experts + self.n_share_experts_fusion,
            top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            prefix=add_prefix("experts", prefix),
            **(
                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
                if global_server_args_dict["enable_deepep_moe"]
                else {}
            ),
        )

        # `self.shared_experts` exists only under this condition; forward() checks
        # the same condition before using it.
        if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            # disable tp for shared experts when enable deepep moe
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=add_prefix("shared_experts", prefix),
                **(
                    dict(tp_rank=0, tp_size=1)
                    if global_server_args_dict["enable_deepep_moe"]
                    else {}
                ),
            )

        self.top_k = config.num_experts_per_tok

        if global_server_args_dict["enable_deepep_moe"]:
            # TODO: we will support tp < ep in the future
            self.ep_size = get_tensor_model_parallel_world_size()
            self.num_experts = config.n_routed_experts
            self.renormalize = config.norm_topk_prob
            self.topk_group = config.topk_group
            self.num_expert_group = config.n_group
            self.correction_bias = (
                self.gate.e_score_correction_bias.data
                if self.gate.e_score_correction_bias is not None
                else None
            )

            self.deepep_dispatcher = DeepEPDispatcher(
                group=parallel_state.get_tp_group().device_group,
                router_topk=self.top_k,
                permute_fusion=True,
                num_experts=config.n_routed_experts,
                num_local_experts=config.n_routed_experts // self.tp_size,
                hidden_size=config.hidden_size,
                params_dtype=config.torch_dtype,
                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
                async_finish=True,  # TODO
                return_recv_hook=True,
            )

    @property
    def _enable_deepep_moe(self):
        return global_server_args_dict["enable_deepep_moe"]

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
    ) -> torch.Tensor:
        """Apply the MoE block to ``hidden_states``.

        Args:
            hidden_states: Token activations; the first dimension is the token count.
            forward_batch: Batch metadata. May be None; the mode is then treated as
                unknown, which the non-DeepEP path does not need.

        Returns:
            Routed-expert output scaled by ``routed_scaling_factor``, plus the
            shared-expert output when shared experts are computed separately.
        """
        forward_mode = forward_batch.forward_mode if forward_batch is not None else None

        # Without DeepEP every rank computes locally. With DeepEP, idle or empty
        # ranks skip the gate and shared experts but still take part in
        # dispatch/combine.
        has_local_work = (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
            forward_mode, hidden_states
        )

        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states) if has_local_work else None

        if (
            has_local_work
            and self.n_shared_experts is not None
            and self.n_share_experts_fusion == 0
        ):
            shared_output = self.shared_experts(hidden_states)
        else:
            shared_output = None

        if self._enable_deepep_moe:
            final_hidden_states = self._forward_deepep(
                hidden_states, router_logits, forward_batch, forward_mode
            )
        else:
            final_hidden_states = self.experts(
                hidden_states=hidden_states, router_logits=router_logits
            )

        final_hidden_states *= self.routed_scaling_factor

        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output

        if (not self._enable_deepep_moe) and (self.tp_size > 1):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        return final_hidden_states

    def _forward_deepep(self, hidden_states, router_logits, forward_batch, forward_mode):
        """Select top-k experts, then dispatch, run and combine the routed experts via DeepEP."""
        if self.ep_size <= 1:
            # The dispatcher is the only source of the expert-layout tensors
            # (reorder_topk_ids, seg_indptr, masked_m, ...) that the experts need.
            raise NotImplementedError(
                "DeepEP MoE requires an expert-parallel size greater than 1."
            )

        if router_logits is not None:
            topk_weights, topk_idx = select_experts(
                hidden_states=hidden_states,
                router_logits=router_logits,
                top_k=self.top_k,
                use_grouped_topk=True,
                renormalize=self.renormalize,
                topk_group=self.topk_group,
                num_expert_group=self.num_expert_group,
                correction_bias=self.correction_bias,
                routed_scaling_factor=self.routed_scaling_factor,
                num_token_non_padded=forward_batch.num_token_non_padded,
            )
        else:
            # Idle/empty rank: contribute zero routed tokens to the dispatch.
            topk_idx = torch.full(
                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
            )
            topk_weights = torch.empty(
                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
            )

        # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
        (
            hidden_states,
            topk_idx,
            topk_weights,
            reorder_topk_ids,
            num_recv_tokens_per_expert,
            seg_indptr,
            masked_m,
            expected_m,
        ) = self.deepep_dispatcher.dispatch(
            hidden_states,
            topk_idx,
            topk_weights,
            forward_mode=forward_mode,
        )

        expert_output = self.experts(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            reorder_topk_ids=reorder_topk_ids,
            seg_indptr=seg_indptr,
            masked_m=masked_m,
            expected_m=expected_m,
            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
            forward_mode=forward_mode,
        )

        return self.deepep_dispatcher.combine(
            expert_output,
            topk_idx,
            topk_weights,
            forward_mode,
        )
397

Liangsheng Yin's avatar
Liangsheng Yin committed
398
399
400
401
402
403
404
405
406

def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN attention magnitude factor ``0.1 * mscale * ln(scale) + 1``.

    Scales of 1 or less (no context extension) yield exactly 1.0.
    """
    import math

    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0


407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
class DeepseekV2AttentionMLA(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        layer_id: Optional[int] = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """Build the projections, rotary embedding and attention modules of one MLA layer.

        Args:
            config: Model config; ``rms_norm_eps`` is read for the RMSNorm layers.
            hidden_size: Model hidden size.
            num_heads: Total attention heads; must be divisible by the attention TP size.
            qk_nope_head_dim: Per-head query/key width without rotary embedding.
            qk_rope_head_dim: Per-head query/key width with rotary embedding.
            v_head_dim: Per-head value width.
            q_lora_rank: Low-rank query width, or None to use a direct ``q_proj``.
            kv_lora_rank: Width of the compressed KV latent.
            rope_theta: Rotary base.
            rope_scaling: Optional rope-scaling dict. It is copied, not modified.
            max_position_embeddings: Maximum position for the rotary embedding.
            quant_config: Optional quantization config for the linear layers.
            reduce_results: Passed to ``o_proj``.
            layer_id: Layer index passed to the attention modules.
            prefix: Parameter-name prefix for sub-layers.
            alt_stream: Optional side CUDA stream.
        """
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        assert num_heads % attn_tp_size == 0
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # For tensor parallel attention
        if self.q_lora_rank is not None:
            # One replicated projection yields the q latent, the kv latent and k_pe.
            self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("fused_qkv_a_proj_with_mqa", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("kv_a_proj_with_mqa", prefix),
            )

        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

        if rope_scaling:
            # Copy before overriding the rope type: the caller's dict is typically
            # the shared model config and should not be mutated per layer.
            rope_scaling = {**rope_scaling, "rope_type": "deepseek_yarn"}

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
        )

        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale
        else:
            self.rotary_emb.forward = self.rotary_emb.forward_native

        # Absorbed path: one shared latent KV head of width kv_lora_rank + rope dim.
        self.attn_mqa = RadixAttention(
            self.num_local_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=1,
            layer_id=layer_id,
            v_head_dim=self.kv_lora_rank,
            quant_config=quant_config,
            prefix=add_prefix("attn_mqa", prefix),
        )

        # Non-absorbed path: ordinary per-head K/V.
        self.attn_mha = RadixAttention(
            self.num_local_heads,
            self.qk_nope_head_dim + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            v_head_dim=self.v_head_dim,
            quant_config=quant_config,
            prefix=add_prefix("attn_mha", prefix),
        )

        self.alt_stream = alt_stream

        # Absorbed-weight tensors and scales start unset here; presumably they are
        # populated after weight loading (not visible in this chunk).
        self.w_kc = None
        self.w_vc = None
        self.w_scale = None

        self.w_scale_k = None
        self.w_scale_v = None
        self.use_deep_gemm_bmm = False

        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
        self.disable_chunked_prefix_cache = global_server_args_dict[
            "disable_chunked_prefix_cache"
        ]
        self.attention_backend = global_server_args_dict["attention_backend"]
        self.rocm_fused_decode_mla = get_bool_env_var(
            "SGLANG_ROCM_FUSED_DECODE_MLA", "false"
        )

        # TODO: Design a finer way to determine the threshold
        self.chunked_prefix_cache_threshold = get_int_env_var(
            "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
        )
573
574
575
576

    def dispatch_attn_forward_method(
        self, forward_batch: ForwardBatch
    ) -> AttnForwardMethod:
        """Choose how attention is computed for ``forward_batch``.

        A "plain extend" batch is an extend batch that is neither target-verify nor
        draft-extend.

        - flashinfer: MHA for a plain extend with no cached prefix, unless ragged
          prefill is disabled; otherwise MLA.
        - fa3: MHA_CHUNKED_KV for a plain extend when the chunked prefix cache is
          enabled and the total prefix length is 0 or at least
          ``chunked_prefix_cache_threshold``; otherwise MLA.
        - other backends: MHA for a plain extend with no cached prefix; otherwise MLA.

        When ``extend_prefix_lens_cpu`` is None the prefix length is unknown and
        MLA is returned instead of failing.
        """
        forward_mode = forward_batch.forward_mode
        is_plain_extend = (
            forward_mode.is_extend()
            and not forward_mode.is_target_verify()
            and not forward_mode.is_draft_extend()
        )
        prefix_lens = forward_batch.extend_prefix_lens_cpu
        sum_extend_prefix_lens = sum(prefix_lens) if prefix_lens is not None else None

        if self.attention_backend == "flashinfer":
            # Flashinfer MLA: Do not absorb when enabling ragged prefill
            if (
                not self.flashinfer_mla_disable_ragged
                and is_plain_extend
                and sum_extend_prefix_lens == 0
            ):
                return AttnForwardMethod.MHA
            return AttnForwardMethod.MLA

        if self.attention_backend == "fa3":
            # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
            if (
                is_plain_extend
                and not self.disable_chunked_prefix_cache
                and sum_extend_prefix_lens is not None
                and (
                    sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
                    or sum_extend_prefix_lens == 0
                )
            ):
                return AttnForwardMethod.MHA_CHUNKED_KV
            return AttnForwardMethod.MLA

        # Triton: Use normal computation for prefill and use weight absorption for extend/decode
        if is_plain_extend and sum_extend_prefix_lens == 0:
            return AttnForwardMethod.MHA
        return AttnForwardMethod.MLA
Lianmin Zheng's avatar
Lianmin Zheng committed
617

618
619
620
621
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ) -> torch.Tensor:
        """Run MLA attention using the method picked by ``dispatch_attn_forward_method``.

        An empty ``hidden_states`` is returned unchanged; this is only allowed when
        ``o_proj`` does not reduce results.
        """
        if hidden_states.shape[0] == 0:
            # Skipping o_proj here would skip its collective if it reduced results.
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states

        method = self.dispatch_attn_forward_method(forward_batch)

        if method == AttnForwardMethod.MHA:
            return self.forward_normal(positions, hidden_states, forward_batch)

        if method == AttnForwardMethod.MHA_CHUNKED_KV:
            return self.forward_normal_chunked_kv(
                positions, hidden_states, forward_batch
            )

        # MLA: on ROCm, decode batches may use the fused MLA+rope kernel.
        use_fused_rope = (
            _is_hip
            and self.rocm_fused_decode_mla
            and forward_batch.forward_mode.is_decode()
        )
        if use_fused_rope:
            return self.forward_absorb_fused_mla_rope(
                positions, hidden_states, forward_batch
            )
        return self.forward_absorb(
            positions, hidden_states, forward_batch, zero_allocator
        )
656
657
658
659
660
661
662
663

    def forward_normal(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Attention without weight absorption (plain MHA over expanded K/V).

        The latent (normalized kv_a followed by rotated k_pe) is written to the KV
        pool via ``set_kv_buffer``, then ``attn_mha`` runs with
        ``save_kv_cache=False`` on the per-head q, k, v built here.
        """
        nope_dim = self.qk_nope_head_dim
        rope_dim = self.qk_rope_head_dim
        lora_rank = self.kv_lora_rank
        heads = self.num_local_heads

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)[0].view(-1, heads, self.qk_head_dim)
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        else:
            fused_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
            q_latent, latent_cache = fused_out.split(
                [self.q_lora_rank, lora_rank + rope_dim], dim=-1
            )
            q_latent = self.q_a_layernorm(q_latent)
            q = self.q_b_proj(q_latent)[0].view(-1, heads, self.qk_head_dim)

        # Views into q and latent_cache; the nope part of q stays in place.
        q_pe = q.split([nope_dim, rope_dim], dim=-1)[1]
        kv_a = latent_cache.split([lora_rank, rope_dim], dim=-1)[0]
        latent_cache = latent_cache.unsqueeze(1)

        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0].view(-1, heads, nope_dim + self.v_head_dim)
        k_nope = kv[..., :nope_dim]
        v = kv[..., nope_dim:]
        k_pe = latent_cache[:, :, lora_rank:]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., nope_dim:] = q_pe

        k = torch.empty_like(q)
        k[..., :nope_dim] = k_nope
        k[..., nope_dim:] = k_pe

        latent_cache[:, :, :lora_rank] = kv_a.unsqueeze(1)
        latent_cache[:, :, lora_rank:] = k_pe

        # Save latent cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )
        attn_out = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        output, _ = self.o_proj(attn_out.reshape(-1, heads * self.v_head_dim))
        return output

    def forward_absorb(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ) -> torch.Tensor:
        """MLA attention in the weight-absorbed (latent-space) form.

        The no-PE part of the query is projected into the ``kv_lora_rank``
        latent space with ``self.w_kc``. Attention then runs against the
        normalized latent, which serves as both key and value, plus the rotary
        part. The per-head latent output is projected to ``v_head_dim`` with
        ``self.w_vc`` and finally passed through ``self.o_proj``.

        Both ``w_kc`` and ``w_vc`` products dispatch on the same four paths:
        DeepGEMM masked grouped FP8 GEMM, a bf16 fallback for
        ``float8_e4m3fnuz``, per-tensor FP8 quant + ``bmm_fp8`` for
        ``float8_e4m3fn``, or a plain ``torch.bmm``.

        Args:
            positions: position ids consumed by ``self.rotary_emb``.
            hidden_states: layer input for the tokens in this batch.
            forward_batch: batch metadata forwarded to ``self.attn_mqa``.
            zero_allocator: supplies the one-element buffer
                (``allocate(1)``) passed to ``per_tensor_quant_mla_fp8`` on the
                ``float8_e4m3fn`` path.

        Returns:
            The output of ``self.o_proj``.
        """
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            k_nope = latent_cache[..., : self.kv_lora_rank]

            # overlap qk norm: while capturing a CUDA graph, run the kv norm on
            # the side stream concurrently with the q norm on the main stream.
            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
                current_stream = torch.cuda.current_stream()
                self.alt_stream.wait_stream(current_stream)
                q = self.q_a_layernorm(q)
                with torch.cuda.stream(self.alt_stream):
                    k_nope = self.kv_a_layernorm(k_nope)
                current_stream.wait_stream(self.alt_stream)
            else:
                q = self.q_a_layernorm(q)
                k_nope = self.kv_a_layernorm(k_nope)

            k_nope = k_nope.unsqueeze(1)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
            k_nope = latent_cache[..., : self.kv_lora_rank]
            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

        # Absorb w_kc: per head, map q_nope into the kv_lora_rank latent space.
        # Every branch yields a (heads, tokens, kv_lora_rank)-ordered result.
        if self.use_deep_gemm_bmm:
            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
            )
            q_nope_out = q_nope.new_empty(
                (self.num_local_heads, aligned_m, self.kv_lora_rank)
            )
            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
                (q_nope_val, q_nope_scale),
                (self.w_kc, self.w_scale_k),
                q_nope_out,
                masked_m,
                expected_m,
            )
            # Drop the alignment padding rows.
            q_nope_out = q_nope_out[:, :expected_m, :]
        elif self.w_kc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)

        # Back to token-major order.
        q_nope_out = q_nope_out.transpose(0, 1)
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        if self.attention_backend in ("fa3", "flashinfer"):
            # These backends receive the rotary parts as separate arguments.
            attn_output = self.attn_mqa(
                q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
            )
        else:
            q = torch.cat([q_nope_out, q_pe], dim=-1)
            k = torch.cat([k_nope, k_pe], dim=-1)
            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # Absorb w_vc: per head, map the latent attention output to v_head_dim.
        if self.use_deep_gemm_bmm:
            attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
                per_token_group_quant_mla_deep_gemm_masked_fp8(
                    attn_output.transpose(0, 1)
                )
            )
            attn_bmm_output = attn_output.new_empty(
                (self.num_local_heads, aligned_m, self.v_head_dim)
            )
            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
                (attn_output_val, attn_output_scale),
                (self.w_vc, self.w_scale_v),
                attn_bmm_output,
                masked_m,
                expected_m,
            )
            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
        elif self.w_vc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                attn_output.transpose(0, 1),
                zero_allocator.allocate(1),
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        # (heads, tokens, v_head_dim) -> (tokens, heads * v_head_dim)
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)

        return output

    def forward_absorb_fused_mla_rope(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ) -> torch.Tensor:
        """Weight-absorbed MLA decode through ``decode_attention_fwd_grouped_rope``.

        This is the same latent-space attention as ``forward_absorb``, but it
        calls the grouped decode kernel directly on this layer's latent KV
        buffer instead of going through ``self.attn_mqa``.

        The env var ``SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION`` (default ``"1"``)
        selects where RoPE is applied:
        - ``"1"``: the kernel applies RoPE (``use_rope=True``) and writes the
          rotated key rope part to ``k_pe_output``. That is then copied into
          the latent cache, which is saved a second time.
        - any other value: ``self.rotary_emb`` is applied here before the
          kernel, which then runs with ``use_rope=False``.

        Args:
            positions: position ids (used by ``self.rotary_emb`` or the kernel).
            hidden_states: layer input for the tokens in this batch.
            forward_batch: batch metadata. Its attention backend supplies
                ``forward_metadata`` / ``num_kv_splits``, and its KV pool
                receives this step's latent cache.
            zero_allocator: supplies the one-element buffer
                (``allocate(1)``) passed to ``per_tensor_quant_mla_fp8``.

        Returns:
            The output of ``self.o_proj``.
        """
        enable_rope_fusion = (
            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
        )
        q_len = hidden_states.shape[0]
        # Kernel query layout per head: [absorbed q_nope (kv_lora_rank) | q_pe].
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        if self.w_kc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            # NOTE(review): forward_absorb calls per_tensor_quant_mla_fp8
            # without `dtype=`; confirm the helper still accepts this keyword.
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

        # The normalized latent is the value; k_input is a view over
        # latent_cache holding [normalized latent | k_pe].
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input

        if enable_rope_fusion:
            # Output slot for the rotated key rope part; copied into the cache
            # after the kernel call below.
            k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])
        else:
            k_pe = k_input[..., self.kv_lora_rank :]
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
            k_input[..., self.kv_lora_rank :] = k_pe
            k_pe_output = None
        # Written once for both paths. Here q_pe is already rotated when
        # fusion is off, and un-rotated when the kernel applies RoPE.
        q_input[..., self.kv_lora_rank :] = q_pe

        # Equivalent of self.attn_mqa(q_input, k_input, v_input, forward_batch),
        # done with the fused kernel.
        attn_output = torch.empty(
            (q_len, self.num_local_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )
        attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
            forward_batch.attn_backend.forward_metadata
        )
        cos_sin_cache = self.rotary_emb.cos_sin_cache
        num_kv_split = forward_batch.attn_backend.num_kv_splits
        sm_scale = self.attn_mqa.scaling
        if attn_logits is None:
            # Scratch space for the split-KV partial results.
            attn_logits = torch.empty(
                (
                    forward_batch.batch_size,
                    self.num_local_heads,
                    num_kv_split,
                    self.kv_lora_rank + 1,
                ),
                dtype=torch.float32,
                device=q.device,
            )

        # save current latent cache.
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mqa, forward_batch.out_cache_loc, k_input, None
        )
        key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mqa.layer_id
        )
        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

        decode_attention_fwd_grouped_rope(
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            self.kv_lora_rank,
            self.rotary_emb.rotary_dim,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            logit_cap=self.attn_mqa.logit_cap,
            use_rope=enable_rope_fusion,
            is_neox_style=self.rotary_emb.is_neox_style,
        )

        if enable_rope_fusion:
            # Re-save the latent cache now that the rotated k_pe is known.
            k_input[..., self.kv_lora_rank :] = k_pe_output
            forward_batch.token_to_kv_pool.set_kv_buffer(
                self.attn_mqa, forward_batch.out_cache_loc, k_input, None
            )

        # attn_output is already (q_len, num_local_heads, kv_lora_rank).
        if self.w_vc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                attn_output.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)

        return output

976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
    def _chunked_prefix_attn_mha(
        self,
        q: torch.Tensor,
        accum_output: torch.Tensor,
        accum_lse: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Attend ``q`` to the cached prefix chunk by chunk and merge the results.

        For chunk ``i``, the latent cache rows selected by
        ``forward_batch.prefix_chunk_kv_indices[i]`` are expanded with
        ``self.kv_b_proj`` into per-head K and V. K is the no-PE part plus the
        cached rotary part. The chunk is attended with ``self.attn_mha`` and
        merged into the running ``(accum_output, accum_lse)`` pair with
        ``merge_state_v2``.

        Args:
            q: query for the extended tokens (no-PE and rotary parts).
            accum_output: attention output accumulated so far.
            accum_lse: the ``lse`` values matching ``accum_output``, in the
                same transposed layout produced inside the loop.
            forward_batch: must already have ``num_prefix_chunks`` prepared.

        Returns:
            The attention output merged over all prefix chunks.
        """
        assert forward_batch.num_prefix_chunks is not None
        # The layer's latent buffer does not depend on the chunk index, so it
        # is fetched once rather than on every iteration.
        latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mha.layer_id
        )
        for i in range(forward_batch.num_prefix_chunks):
            forward_batch.set_prefix_chunk_idx(i)

            # Fetch this chunk's latent cache with the precomputed kv indices.
            latent_cache = latent_cache_buf[
                forward_batch.prefix_chunk_kv_indices[i]
            ].contiguous()

            kv_a_normed, k_pe = latent_cache.split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )
            kv_a_normed = kv_a_normed.squeeze(1).contiguous()
            kv = self.kv_b_proj(kv_a_normed)[0]
            kv = kv.view(
                -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
            )
            v = kv[..., self.qk_nope_head_dim :]
            k_nope = kv[..., : self.qk_nope_head_dim]

            k = torch.empty(
                (
                    k_nope.shape[0],
                    self.num_local_heads,
                    self.qk_nope_head_dim + self.qk_rope_head_dim,
                ),
                dtype=v.dtype,
                device=v.device,
            )
            k[..., : self.qk_nope_head_dim] = k_nope
            # k_pe is shared by all heads and broadcast across the head dim.
            k[..., self.qk_nope_head_dim :] = k_pe

            output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
            lse = torch.transpose(lse, 0, 1).contiguous()
            tmp_output = torch.empty_like(accum_output)
            tmp_lse = torch.empty_like(accum_lse)
            merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
            accum_output, accum_lse = tmp_output, tmp_lse

        return accum_output

    def forward_normal_chunked_kv(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Multi-head (non-absorbed) MLA prefill that visits the prefix KV in chunks.

        Materializing full K and V for a long cached prefix can get very
        large, so attention is done in two phases:

        1. the newly extended tokens attend only among themselves, and
        2. if any sequence in the batch has a cached prefix, the queries
           attend to that prefix chunk by chunk through
           ``_chunked_prefix_attn_mha``. Each chunk's partial result is merged
           into the running (output, lse) pair.

        MHA is compute-bound, so the extra Python loop adds little overhead.
        The header comments of vLLM's
        ``vllm/v1/attention/backends/mla/common.py`` explain the approach.

        Args:
            positions: position ids consumed by ``self.rotary_emb``.
            hidden_states: layer input for the tokens being extended.
            forward_batch: batch metadata. Its KV pool receives this step's
                latent cache.

        Returns:
            The output of ``self.o_proj``.
        """
        nope_dim = self.qk_nope_head_dim
        rope_dim = self.qk_rope_head_dim
        lora_rank = self.kv_lora_rank
        num_heads = self.num_local_heads

        # Projections: query, plus the compressed latent [kv_a | k_pe].
        if self.q_lora_rank is None:
            query = self.q_proj(hidden_states)[0].view(
                -1, num_heads, self.qk_head_dim
            )
            latent = self.kv_a_proj_with_mqa(hidden_states)[0]
        else:
            fused = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
            q_a, latent = fused.split([self.q_lora_rank, lora_rank + rope_dim], dim=-1)
            query = self.q_b_proj(self.q_a_layernorm(q_a))[0].view(
                -1, num_heads, self.qk_head_dim
            )

        _, q_rope = torch.split(query, [nope_dim, rope_dim], dim=-1)
        kv_a, _ = torch.split(latent, [lora_rank, rope_dim], dim=-1)
        latent = latent.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())

        # Expand the latent to per-head K (no-PE part) and V.
        expanded = self.kv_b_proj(kv_a)[0].view(-1, num_heads, nope_dim + self.v_head_dim)
        key_nope = expanded[..., :nope_dim]
        value = expanded[..., nope_dim:]

        key_rope = latent[:, :, lora_rank:]
        q_rope, key_rope = self.rotary_emb(positions, q_rope, key_rope)
        query[..., nope_dim:] = q_rope

        key = torch.empty_like(query)
        key[..., :nope_dim] = key_nope
        key[..., nope_dim:] = key_rope

        # Put the normalized latent and the rotated rope part back into the
        # latent tensor, then store it in the KV pool.
        latent[:, :, :lora_rank] = kv_a.unsqueeze(1)
        latent[:, :, lora_rank:] = key_rope
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent, None
        )

        # Phase 1: extended tokens only, no prefix.
        forward_batch.set_attn_attend_prefix_cache(False)
        attn_output, lse = self.attn_mha(
            query, key, value, forward_batch, save_kv_cache=False
        )
        lse = lse.transpose(0, 1).contiguous()

        # Phase 2: chunked prefix attention, only if some sequence has a prefix.
        has_prefix = any(forward_batch.extend_prefix_lens_cpu)
        if has_prefix:
            # The chunk info is built only once per batch.
            if forward_batch.num_prefix_chunks is None:
                forward_batch.prepare_chunked_prefix_cache_info(query.device)

            forward_batch.set_attn_attend_prefix_cache(True)
            attn_output = self._chunked_prefix_attn_mha(
                q=query,
                accum_output=attn_output,
                accum_lse=lse,
                forward_batch=forward_batch,
            )

        flat_output = attn_output.reshape(-1, num_heads * self.v_head_dim)
        output, _ = self.o_proj(flat_output)
        return output

1099

1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
class _FFNInputMode(Enum):
    # The MLP sublayer requires 1/tp_size tokens as input
    SCATTERED = auto()
    # The MLP sublayer requires all tokens as input
    FULL = auto()


@dataclass()
class _DecoderLayerInfo:
    """Per-layer FFN facts produced by ``DeepseekV2DecoderLayer._compute_info``."""

    # True when the layer's FFN is a routed-expert MoE rather than a dense MLP.
    is_sparse: bool
    # Whether the FFN sublayer consumes scattered or full token input.
    ffn_input_mode: _FFNInputMode


Liangsheng Yin's avatar
Liangsheng Yin committed
1113
1114
1115
1116
1117
1118
1119
class DeepseekV2DecoderLayer(nn.Module):
    """One DeepSeek-V2 transformer block.

    The block applies ``input_layernorm``, then MLA self-attention
    (``DeepseekV2AttentionMLA``), then ``post_attention_layernorm``, then the
    FFN. The FFN is a routed-expert ``DeepseekV2MoE`` on sparse layers and a
    dense ``DeepseekV2MLP`` otherwise. The FFN input mode (see
    ``_compute_info``) picks which forward path and communication pattern is
    used.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        is_nextn: bool = False,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """Build the attention, FFN and norm sublayers for layer ``layer_id``.

        Args:
            config: model config.
            layer_id: index of this layer. Sparsity and input mode depend on it.
            quant_config: optional quantization config for the sublayers.
            is_nextn: build this as the nextn layer, which is always sparse.
            prefix: parameter-name prefix for the sublayers.
            alt_stream: optional side CUDA stream handed to the attention.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
        self.layer_id = layer_id
        self.local_dp_size = get_local_attention_dp_size()
        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()
        self.self_attn = DeepseekV2AttentionMLA(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            layer_id=layer_id,
            # The reduction is done by this layer's forward paths instead.
            reduce_results=False,
            prefix=add_prefix("self_attn", prefix),
            alt_stream=alt_stream,
        )

        self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
        previous_layer_info = self._compute_info(
            config, layer_id=layer_id - 1, is_nextn=False
        )

        if self.info.is_sparse:
            self.mlp = DeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
            )
        else:
            if self._enable_moe_dense_fully_dp():
                # Dense MLP is replicated: no tensor parallelism for it.
                mlp_tp_rank, mlp_tp_size = 0, 1
            else:
                mlp_tp_rank, mlp_tp_size = None, None
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                tp_rank=mlp_tp_rank,
                tp_size=mlp_tp_size,
            )

        # The input arrives scattered when the previous layer's FFN ran in
        # SCATTERED mode (never for layer 0).
        self.input_is_scattered = (
            layer_id > 0
            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
        )
        self.is_last_layer = self.layer_id == config.num_hidden_layers - 1

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    @staticmethod
    def _enable_moe_dense_fully_dp():
        """Return True when dense (non-MoE) MLPs run without tensor parallelism."""
        return global_server_args_dict["moe_dense_tp_size"] == 1

    @staticmethod
    def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
        """Classify layer ``layer_id`` as sparse/dense and choose its FFN input mode.

        A layer is sparse when it is the nextn layer, or when all of these
        hold: the config has routed experts, ``layer_id >=
        first_k_dense_replace``, and ``layer_id`` is a multiple of
        ``moe_layer_freq``.

        The FFN input is SCATTERED for a sparse layer with DeepEP MoE enabled,
        or for a dense layer when dense MLPs are fully data-parallel.
        Otherwise it is FULL.
        """
        is_sparse = is_nextn or (
            config.n_routed_experts is not None
            and layer_id >= config.first_k_dense_replace
            and layer_id % config.moe_layer_freq == 0
        )
        ffn_input_mode = (
            _FFNInputMode.SCATTERED
            if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
            or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
            else _FFNInputMode.FULL
        )
        return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run the layer through the forward path matching its FFN input mode.

        Returns:
            ``(hidden_states, residual)``. ``residual`` is ``None`` when the
            scattered path has already folded it into ``hidden_states``.

        Raises:
            NotImplementedError: for an unknown ``ffn_input_mode``.
        """
        if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
            return self.forward_ffn_with_scattered_input(
                positions, hidden_states, forward_batch, residual, zero_allocator
            )
        elif self.info.ffn_input_mode == _FFNInputMode.FULL:
            return self.forward_ffn_with_full_input(
                positions, hidden_states, forward_batch, residual, zero_allocator
            )
        else:
            raise NotImplementedError

    def forward_ffn_with_full_input(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward path where the FFN sees all tokens.

        Attention and the input norm are skipped for an empty batch. The
        attention output is then combined across ranks before the FFN:
        - with local DP attention, by a DP gather into
          ``forward_batch.gathered_buffer``;
        - otherwise, by a TP all-reduce when TP is enabled.
        With local DP attention, the FFN output is scattered back to this
        rank's tokens afterwards.

        Returns:
            ``(hidden_states, residual)``.
        """
        if hidden_states.shape[0] == 0:
            residual = hidden_states
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                # Fused residual add + norm.
                hidden_states, residual = self.input_layernorm(hidden_states, residual)

            assert not (
                self.attn_tp_size != 1 and self.input_is_scattered
            ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"

            # Self Attention
            hidden_states = self.self_attn(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
                zero_allocator=zero_allocator,
            )

        # Gather
        if get_tensor_model_parallel_world_size() > 1:
            # all gather and all reduce
            if self.local_dp_size != 1:
                # Residual is added on attn-TP rank 0 only. Presumably this is
                # because dp_gather_partial sums the un-reduced partial
                # attention outputs across ranks, so the residual must be
                # counted once.
                if self.attn_tp_rank == 0:
                    hidden_states += residual
                hidden_states, local_hidden_states = (
                    forward_batch.gathered_buffer,
                    hidden_states,
                )
                dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
                dp_scatter(residual, hidden_states, forward_batch)
                hidden_states = self.post_attention_layernorm(hidden_states)
            else:
                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
                hidden_states, residual = self.post_attention_layernorm(
                    hidden_states, residual
                )
        else:
            hidden_states, residual = self.post_attention_layernorm(
                hidden_states, residual
            )

        # Fully Connected
        hidden_states = self.mlp(hidden_states, forward_batch)

        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
        # Scatter
        if self.local_dp_size != 1:
            # important: forward batch.gathered_buffer is used both after scatter and after gather.
            # be careful about this!
            hidden_states, global_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            dp_scatter(hidden_states, global_hidden_states, forward_batch)

        return hidden_states, residual

    def forward_ffn_with_scattered_input(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Forward path where the FFN sees only this attention-TP rank's slice.

        The steps are:
        1. With attention TP, a scattered input is all-gathered before
           attention.
        2. After attention, the output is reduce-scattered back to this
           rank's slice. The residual is sliced the same way if the input was
           full.
        3. On the last layer, the residual is folded in and the slices are
           all-gathered again.

        Returns:
            ``(hidden_states, residual)``. ``residual`` is ``None`` after
            the last-layer fold.
        """
        if hidden_states.shape[0] == 0:
            residual = hidden_states
        else:
            if residual is None:
                residual = hidden_states
                hidden_states = self.input_layernorm(hidden_states)
            else:
                # Fused residual add + norm.
                hidden_states, residual = self.input_layernorm(hidden_states, residual)

        if self.attn_tp_size != 1 and self.input_is_scattered:
            # Reassemble the full token set for attention.
            hidden_states, local_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            attn_tp_all_gather(
                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
            )

        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )

        if self.attn_tp_size != 1:
            tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
            hidden_states = tensor_list[self.attn_tp_rank]
            attn_tp_reduce_scatter(hidden_states, tensor_list)
            if not self.input_is_scattered:
                residual = residual.tensor_split(self.attn_tp_size)[self.attn_tp_rank]

        if hidden_states.shape[0] != 0:
            hidden_states, residual = self.post_attention_layernorm(
                hidden_states, residual
            )

        # A replicated dense MLP has nothing to do on an empty slice.
        if not (
            self._enable_moe_dense_fully_dp()
            and (not self.info.is_sparse)
            and hidden_states.shape[0] == 0
        ):
            hidden_states = self.mlp(hidden_states, forward_batch)

        if self.is_last_layer and self.attn_tp_size != 1:
            # Fold the residual in and return the full token set.
            hidden_states += residual
            residual = None
            hidden_states, local_hidden_states = (
                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
                hidden_states,
            )
            attn_tp_all_gather(
                list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
            )

        return hidden_states, residual


class DeepseekV2Model(nn.Module):
    """DeepSeek-V2 decoder stack: token embedding, decoder layers and final RMSNorm."""

    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size

        # The embedding is only vocab-sharded across TP ranks when DP attention
        # is disabled.
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            enable_tp=not global_server_args_dict["enable_dp_attention"],
        )
        # A single auxiliary CUDA stream, handed to every decoder layer.
        self.alt_stream = torch.cuda.Stream()
        self.layers = nn.ModuleList(
            [
                DeepseekV2DecoderLayer(
                    config,
                    layer_id,
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{layer_id}", prefix),
                    alt_stream=self.alt_stream,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.dp_size = get_local_attention_dp_size()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token-embedding module."""
        return self.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder stack and return the final hidden states.

        Args:
            input_ids: Token ids; embedded with ``embed_tokens`` unless
                ``input_embeds`` is given.
            positions: Position ids forwarded to every decoder layer.
            forward_batch: Batch metadata forwarded to every decoder layer.
            input_embeds: Optional precomputed embeddings that replace the
                embedding lookup.

        Returns:
            The hidden states after the last layer. The final norm is skipped
            when ``forward_batch.forward_mode`` is idle.
        """
        zero_allocator = BumpAllocator(
            # TODO for two-batch-overlap, we need a larger buffer size
            buffer_size=len(self.layers) * 2,
            dtype=torch.float32,
            device=(
                input_embeds.device if input_embeds is not None else input_ids.device
            ),
        )

        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        # Each layer returns (hidden_states, residual). The residual is fused
        # into the final norm below when a layer left it unmerged.
        residual = None
        for i, layer in enumerate(self.layers):
            with get_global_expert_distribution_recorder().with_current_layer(i):
                hidden_states, residual = layer(
                    positions, hidden_states, forward_batch, residual, zero_allocator
                )

        if not forward_batch.forward_mode.is_idle():
            if residual is None:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekV2ForCausalLM(nn.Module):
    """Causal-LM wrapper around ``DeepseekV2Model``.

    Adds the LM head, logits processing, checkpoint loading and the
    post-load preparation of the MLA ``kv_b_proj`` weights.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        # Runs before the model is built. It may rewrite
        # global_server_args_dict["n_share_experts_fusion"], which the MoE
        # layers presumably read, so keep this ordering.
        self.determine_n_share_experts_fusion()
        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)
        self.dp_size = get_local_attention_dp_size()

    def determine_n_share_experts_fusion(
        self, architecture: str = "DeepseekV3ForCausalLM"
    ):
        """Resolve how many shared-expert replicas are fused into the routed MoE.

        Sets ``self.n_share_experts_fusion`` and writes the resolved value back
        into ``global_server_args_dict["n_share_experts_fusion"]``.

        - If the user asked for fusion (> 0) on an unsupported setup, fusion is
          disabled (set to 0).
        - If fusion is supported, the requested value must equal the TP size.
        - If the value is 0 and the setup qualifies, fusion is enabled
          automatically with ``tp_size`` replicas. A qualifying setup is CUDA
          with SM >= 90, the given architecture, 256 routed experts, and DeepEP
          MoE off.
        """
        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
        if self.n_share_experts_fusion > 0:
            # Only Deepseek V3/R1 can use shared experts fusion optimization now.
            if (
                not _is_cuda
                or self.config.architectures[0] != architecture
                or self.config.n_routed_experts != 256
            ):
                self.n_share_experts_fusion = 0
                global_server_args_dict["n_share_experts_fusion"] = 0
                log_info_on_rank0(
                    logger,
                    "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
                )
            else:
                assert (
                    self.n_share_experts_fusion == self.tp_size
                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
        elif self.n_share_experts_fusion == 0:
            if (
                _is_cuda
                and torch.cuda.get_device_capability("cuda") >= (9, 0)
                and self.config.architectures[0] == architecture
                and self.config.n_routed_experts == 256
                and (not global_server_args_dict["enable_deepep_moe"])
            ):
                self.n_share_experts_fusion = self.tp_size
                global_server_args_dict["n_share_experts_fusion"] = self.tp_size
                log_info_on_rank0(
                    logger,
                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
                )

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the underlying model's token-embedding module."""
        return self.model.embed_tokens

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the decoder stack, then pass its output through the logits processor."""
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)

        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def post_load_weights(self, is_nextn=False):
        """Derive the attention modules' ``w_kc`` / ``w_vc`` from ``kv_b_proj``.

        For each decoder layer (or only ``model.decoder`` when ``is_nextn``),
        ``kv_b_proj``'s weight is dequantized or requantized as needed. It is
        then split along dim 0 into ``qk_nope_head_dim`` and ``v_head_dim``
        parts, and the results (plus any scales) are stored on the attention
        module.
        """
        layer_ids = (
            range(self.config.num_hidden_layers)
            if not is_nextn
            else [self.config.num_hidden_layers]
        )
        for layer_id in layer_ids:
            self_attn = (
                self.model.layers[layer_id].self_attn
                if not is_nextn
                else self.model.decoder.self_attn
            )
            if hasattr(self_attn.kv_b_proj, "qweight"):
                # AWQ compatible
                if _is_cuda:
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                    ).T
                else:
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                        0,
                        0,
                        0,
                    ).T
            else:
                w = self_attn.kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
            # This may affect the accuracy of fp8 model.
            # Fix deepseek v3 blockwise bmm by using deep_gemm
            use_deep_gemm_bmm = False
            model_dtype = torch.get_default_dtype()

            if w.dtype in (
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
            ):
                if hasattr(self.quant_config, "weight_block_size"):
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                        if _is_hip:
                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                                weight=w,
                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                                input_scale=None,
                            )
                        else:
                            weight = w
                            weight_scale = self_attn.kv_b_proj.weight_scale_inv

                        if (
                            _is_cuda
                            and weight_block_size[0] == 128
                            and weight_block_size[1] == 128
                            and model_dtype == torch.bfloat16
                        ):
                            # Either keep the block-quantized weight for a
                            # DeepGEMM bmm, or dequantize to the model dtype.
                            if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
                                "SGL_USE_DEEPGEMM_BMM", "false"
                            ):
                                block_scale = weight_scale
                                use_deep_gemm_bmm = True
                            else:
                                w = block_quant_dequant(
                                    weight,
                                    weight_scale,
                                    weight_block_size,
                                    model_dtype,
                                )
                        else:
                            w, scale = block_quant_to_tensor_quant(
                                weight, weight_scale, weight_block_size
                            )
                            self_attn.w_scale = scale
                else:
                    weight = w
                    weight_scale = self_attn.kv_b_proj.weight_scale
                    w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                    self_attn.w_scale = scale

            if w.dtype == torch.int8:
                if hasattr(self.quant_config, "weight_block_size"):
                    # block-wise int8 need it
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
                        w = int8_block_dequant(
                            weight, weight_scale, weight_block_size
                        ).to(torch.bfloat16)
                else:
                    # channel-wise int8 need it
                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
                        torch.bfloat16
                    )

            # Rows of w are grouped as [qk_nope_head_dim | v_head_dim] chunks
            # (presumably one chunk per attention head).
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            if not use_deep_gemm_bmm:
                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                if (
                    hasattr(self_attn.kv_b_proj, "weight_scale")
                    and self_attn.w_scale is None
                ):
                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                    if _is_hip:
                        # NOTE(review): w_scale presumably aliases
                        # kv_b_proj.weight_scale here, so this in-place multiply
                        # also doubles kv_b_proj's own scale. Confirm that this
                        # is intended.
                        self_attn.w_scale *= 2.0
            else:
                num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
                num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
                ws_kc, ws_vc = block_scale.unflatten(
                    0, (-1, (num_tiles_k + num_tiles_n))
                ).split([num_tiles_k, num_tiles_n], dim=1)
                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
                self_attn.w_scale_v = ws_vc.contiguous()
                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
                self_attn.w_vc = w_vc.contiguous()
                self_attn.use_deep_gemm_bmm = True

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
        """Load checkpoint ``(name, tensor)`` pairs into this model's parameters.

        Args:
            weights: Iterable of checkpoint tensor names and tensors.
            is_nextn: When True, only tensors under
                ``model.layers.{nextn_layer_id}`` are loaded. They are remapped
                to ``model`` for the nextn-specific weights and to
                ``model.decoder`` otherwise.

        Raises:
            ValueError: If ``is_nextn`` is set but the config has no
                ``num_nextn_predict_layers``.
        """
        if is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
                # compatible with old design
                nextn_layer_id = (
                    0
                    if self.config.num_hidden_layers == 1
                    else self.config.num_hidden_layers
                )
            else:
                raise ValueError("num_nextn_predict_layers is not in the config")

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        if self.n_share_experts_fusion > 0:
            # Clone each shared expert into extra "routed" expert slots
            # (ids n_routed_experts .. n_routed_experts + n - 1), then drop the
            # original shared-expert tensors.
            weights_list = list(weights)
            weights_dict = dict(weights_list)
            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
                suffix_list = [
                    "down_proj.weight",
                    "down_proj.weight_scale",
                    "gate_proj.weight",
                    "gate_proj.weight_scale",
                    "up_proj.weight",
                    "up_proj.weight_scale",
                ]
            else:
                suffix_list = [
                    "down_proj.weight",
                    "down_proj.weight_scale_inv",
                    "gate_proj.weight",
                    "gate_proj.weight_scale_inv",
                    "up_proj.weight",
                    "up_proj.weight_scale_inv",
                ]
            # A set: it is probed once for every checkpoint tensor below.
            names_to_remove = set()

            moe_layers = (
                range(
                    self.config.first_k_dense_replace,
                    self.config.num_hidden_layers,
                    self.config.moe_layer_freq,
                )
                if not is_nextn
                else [nextn_layer_id]
            )

            for moe_layer in tqdm(
                moe_layers,
                desc=f"Cloning {self.n_share_experts_fusion} "
                "replicas of the shared expert into MoE",
            ):
                for suffix in suffix_list:
                    shared_expert_weight_name = (
                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
                    )
                    for num_repeat in range(self.n_share_experts_fusion):
                        weights_list.append(
                            (
                                f"model.layers.{moe_layer}."
                                f"mlp.experts."
                                f"{self.config.n_routed_experts + num_repeat}"
                                f".{suffix}",
                                weights_dict[shared_expert_weight_name],
                            )
                        )
                    names_to_remove.add(shared_expert_weight_name)
            weights = [w for w in weights_list if w[0] not in names_to_remove]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
        )

        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
            self.config.q_lora_rank is not None
        )
        cached_a_proj = {} if fuse_qkv_a_proj else None

        if is_nextn:
            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
            nextn_spec_weight_names = [
                "shared_head.norm",
                "eh_proj",
                "enorm",
                "hnorm",
            ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if not is_nextn:
                # Skip nextn (MTP) layers stored after the regular layers.
                if hasattr(self.config, "num_nextn_predict_layers"):
                    num_nextn_layers = self.config.num_nextn_predict_layers
                    if num_nextn_layers > 0 and name.startswith("model.layers"):
                        name_list = name.split(".")
                        if (
                            len(name_list) >= 3
                            and int(name_list[2]) >= self.config.num_hidden_layers
                        ):
                            continue
            else:
                if not name.startswith(nextn_layer_prefix):
                    continue

                # Use shared head and embed weights from target model
                if "shared_head.head" in name or "embed_tokens" in name:
                    continue

                is_decoder = True
                # For nextn specific weights
                for weight_name in nextn_spec_weight_names:
                    if weight_name in name:
                        name = name.replace(nextn_layer_prefix, "model")
                        is_decoder = False
                        break
                # For decoder layer weights
                if is_decoder:
                    name = name.replace(nextn_layer_prefix, "model.decoder")

            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    if fuse_qkv_a_proj and (
                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                    ):
                        cached_a_proj[name] = loaded_weight
                        q_a_proj_name = (
                            name
                            if "q_a_proj" in name
                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
                        )
                        kv_a_proj_name = (
                            name
                            if "kv_a_proj_with_mqa" in name
                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
                        )

                        # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
                        if (
                            q_a_proj_name in cached_a_proj
                            and kv_a_proj_name in cached_a_proj
                        ):
                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                            fused_weight = torch.cat(
                                [q_a_proj_weight, kv_a_proj_weight], dim=0
                            )

                            # Derive the fused name from q_a_proj_name, which
                            # always contains "q_a_proj". `name` may be the
                            # kv_a_proj_with_mqa tensor if that one arrived
                            # second.
                            param_name = q_a_proj_name.replace(
                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
                            )
                            param = params_dict[param_name]

                            weight_loader = getattr(
                                param, "weight_loader", default_weight_loader
                            )
                            weight_loader(param, fused_weight)
                            cached_a_proj.pop(q_a_proj_name)
                            cached_a_proj.pop(kv_a_proj_name)
                    else:
                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)

        self.post_load_weights(is_nextn=is_nextn)

    def get_embed_and_head(self):
        """Return the (embedding weight, LM-head weight) pair."""
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        """Replace the embedding and LM-head weights with the given tensors."""
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    @classmethod
    def get_model_config_for_expert_location(cls, config):
        """Build the expert-location config (layers, logical experts, groups) from ``config``."""
        return ModelConfigForExpertLocation(
            num_layers=config.num_hidden_layers,
            num_logical_experts=config.n_routed_experts,
            num_groups=config.n_group,
        )


class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek-V3 entry point; reuses the DeepSeek-V2 implementation unchanged."""


EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]