# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""

import logging
import os
from enum import IntEnum, auto
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from transformers import PretrainedConfig

from sglang.srt.distributed import (
    get_tensor_model_parallel_world_size,
    parallel_state,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import (
    LayerCommunicator,
    LayerScatterModes,
    enable_moe_dense_fully_dp,
)
from sglang.srt.layers.dp_attention import (
    get_attention_tp_rank,
    get_attention_tp_size,
    get_local_attention_dp_size,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
from sglang.srt.layers.quantization.fp8_kernel import (
    per_tensor_quant_mla_fp8,
    per_token_group_quant_mla_deep_gemm_masked_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import (
    block_quant_dequant,
    block_quant_to_tensor_quant,
    channel_quant_to_tensor_quant,
    normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.layers.quantization.int8_utils import (
    block_dequant as int8_block_dequant,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.managers.expert_distribution import (
    get_global_expert_distribution_recorder,
)
from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.operations import execute_operations
from sglang.srt.operations_strategy import compute_layer_operations
from sglang.srt.utils import (
    BumpAllocator,
    DeepEPMode,
    add_prefix,
    get_bool_env_var,
    get_int_env_var,
    is_cuda,
    is_hip,
    log_info_on_rank0,
)

_is_hip = is_hip()
_is_cuda = is_cuda()

if _is_cuda:
    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2

    from sglang.srt.layers.quantization.deep_gemm import (
        grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
    )
else:
    from vllm._custom_ops import awq_dequantize

if _is_hip:
    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
        decode_attention_fwd_grouped_rope,
    )

logger = logging.getLogger(__name__)


class AttnForwardMethod(IntEnum):
    # Use multi-head attention
    MHA = auto()

    # Use absorbed multi-latent attention
    MLA = auto()

    # Use multi-head attention, but with KV cache chunked.
    # This method can avoid OOM when prefix lengths are long.
    MHA_CHUNKED_KV = auto()

    # Use MLA but with fused RoPE
    MLA_FUSED_ROPE = auto()


class DeepseekV2MLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.tp_size = tp_size

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("down_proj", prefix),
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(self, x, forward_batch=None):
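        # For an empty batch the projections can be skipped, but only when
        # tp_size == 1; with tensor parallelism the collective in down_proj must
        # still run on every rank, so the shortcut below is not taken.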
        if (self.tp_size == 1) and x.shape[0] == 0:
            return x

        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class MoEGate(nn.Module):
    def __init__(
        self,
        config,
        prefix: str = "",
    ):
        super().__init__()
        self.weight = nn.Parameter(
            torch.empty((config.n_routed_experts, config.hidden_size))
        )
        if config.topk_method == "noaux_tc":
            self.e_score_correction_bias = nn.Parameter(
                torch.empty((config.n_routed_experts))
            )
        else:
            self.e_score_correction_bias = None

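    # Router logits are a plain linear projection onto the (unsharded) gating
    # weight; the actual top-k expert selection happens later in select_experts
    # or inside the MoE implementation.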
    def forward(self, hidden_states):
        logits = F.linear(hidden_states, self.weight, None)
        return logits


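# Helper for the DeepEP path: gate and shared-expert work is skipped for idle
# forward modes and for empty batches.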
def is_non_idle_and_non_empty(forward_mode, hidden_states):
    return (
        (forward_mode is not None)
        and not forward_mode.is_idle()
        and hidden_states.shape[0] > 0
    )


class DeepseekV2MoE(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
        self.layer_id = layer_id

        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}."
            )

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))

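        # With shared-expert fusion (n_share_experts_fusion > 0) the shared expert
        # is folded into the routed expert pool: the expert count grows by the
        # fusion factor, top_k is increased by one, and the separate
        # shared_experts MLP below is not built.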
        self.experts = get_moe_impl_class()(
            num_experts=config.n_routed_experts + self.n_share_experts_fusion,
            top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            layer_id=self.layer_id,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            prefix=add_prefix("experts", prefix),
            **(
                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
                if global_server_args_dict["enable_deepep_moe"]
                else {}
            ),
        )

        if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            # Disable TP for the shared experts when DeepEP MoE is enabled.
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=add_prefix("shared_experts", prefix),
                **(
                    dict(tp_rank=0, tp_size=1)
                    if global_server_args_dict["enable_deepep_moe"]
                    else {}
                ),
            )

        self.top_k = config.num_experts_per_tok

        if global_server_args_dict["enable_deepep_moe"]:
            # TODO: we will support tp < ep in the future
            self.ep_size = get_tensor_model_parallel_world_size()
            self.num_experts = config.n_routed_experts
            self.renormalize = config.norm_topk_prob
            self.topk_group = config.topk_group
            self.num_expert_group = config.n_group
            self.correction_bias = (
                self.gate.e_score_correction_bias.data
                if self.gate.e_score_correction_bias is not None
                else None
            )

            self.deepep_dispatcher = DeepEPDispatcher(
                group=parallel_state.get_tp_group().device_group,
                router_topk=self.top_k,
                permute_fusion=True,
                num_experts=config.n_routed_experts,
                num_local_experts=config.n_routed_experts // self.tp_size,
                hidden_size=config.hidden_size,
                params_dtype=config.torch_dtype,
                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
                async_finish=True,  # TODO
                return_recv_hook=True,
            )

    @property
    def _enable_deepep_moe(self):
        return global_server_args_dict["enable_deepep_moe"]

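    # The op_* methods below are the stages of the MoE forward pass; the decoder
    # layer assembles them via compute_layer_operations / execute_operations.
    # On the DeepEP path the asynchronous dispatch and combine are each split
    # into a launch stage (*_a) and a completion stage (*_b).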
    def op_gate(self, state):
        if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, state.hidden_states_mlp_input
        ):
            # router_logits: (num_tokens, n_experts)
            state.router_logits = self.gate(state.hidden_states_mlp_input)
        else:
            state.router_logits = None

    def op_shared_experts(self, state):
        if (self.n_share_experts_fusion == 0) and (
            (not self._enable_deepep_moe)
            or is_non_idle_and_non_empty(
                state.forward_batch.forward_mode, state.hidden_states_mlp_input
            )
        ):
            state.shared_output = self.shared_experts(state.hidden_states_mlp_input)
        else:
            state.shared_output = None

    def op_select_experts(self, state):
        router_logits = state.router_logits
        hidden_states = state.hidden_states_mlp_input

        if self._enable_deepep_moe:
            if router_logits is not None:
                state.topk_weights_local, state.topk_idx_local = select_experts(
                    hidden_states=hidden_states,
                    router_logits=router_logits,
                    top_k=self.top_k,
                    use_grouped_topk=True,
                    renormalize=self.renormalize,
                    topk_group=self.topk_group,
                    num_expert_group=self.num_expert_group,
                    correction_bias=self.correction_bias,
                    routed_scaling_factor=self.routed_scaling_factor,
                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                        layer_id=self.layer_id,
                    ),
                )
            else:
                state.topk_idx_local = torch.full(
                    (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
                )
                state.topk_weights_local = torch.empty(
                    (0, self.top_k), dtype=torch.float32, device=hidden_states.device
                )

    def op_dispatch_a(self, state):
        if self._enable_deepep_moe and (self.ep_size > 1):
            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
            self.deepep_dispatcher.dispatch_a(
                hidden_states=state.pop("hidden_states_mlp_input"),
                topk_idx=state.pop("topk_idx_local"),
                topk_weights=state.pop("topk_weights_local"),
                forward_mode=state.forward_batch.forward_mode,
            )

    def op_dispatch_b(self, state):
        if self._enable_deepep_moe and (self.ep_size > 1):
            (
                state.hidden_states_experts_input,
                state.topk_idx_dispatched,
                state.topk_weights_dispatched,
                state.reorder_topk_ids,
                state.num_recv_tokens_per_expert,
                state.seg_indptr,
                state.masked_m,
                state.expected_m,
            ) = self.deepep_dispatcher.dispatch_b()

    def op_experts(self, state):
        if self._enable_deepep_moe:
            state.pop("router_logits")
            state.hidden_states_experts_output = self.experts(
                hidden_states=state.pop("hidden_states_experts_input"),
                topk_idx=state.topk_idx_dispatched,
                topk_weights=state.topk_weights_dispatched,
                reorder_topk_ids=state.pop("reorder_topk_ids"),
                seg_indptr=state.pop("seg_indptr"),
                masked_m=state.pop("masked_m"),
                expected_m=state.pop("expected_m"),
                num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
                forward_mode=state.forward_batch.forward_mode,
            )
        else:
            state.hidden_states_experts_output = self.experts(
                hidden_states=state.pop("hidden_states_mlp_input"),
                router_logits=state.pop("router_logits"),
            )

    def op_combine_a(self, state):
        if self._enable_deepep_moe and (self.ep_size > 1):
            self.deepep_dispatcher.combine_a(
                state.pop("hidden_states_experts_output"),
                topk_idx=state.pop("topk_idx_dispatched"),
                topk_weights=state.pop("topk_weights_dispatched"),
                forward_mode=state.forward_batch.forward_mode,
            )

    def op_combine_b(self, state):
        if self._enable_deepep_moe and (self.ep_size > 1):
            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b()

    def op_output(self, state):
        final_hidden_states = (
            state.pop("hidden_states_after_combine")
            if self._enable_deepep_moe
            else state.pop("hidden_states_experts_output")
        )

        final_hidden_states *= self.routed_scaling_factor

        if (s := state.pop("shared_output")) is not None:
            final_hidden_states = final_hidden_states + s

        if (not self._enable_deepep_moe) and (self.tp_size > 1):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

        state.hidden_states_mlp_output = final_hidden_states


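# YaRN attention-scale correction: for a RoPE scaling factor > 1 the returned
# multiplier is 0.1 * mscale * ln(scale) + 1 (e.g. scale=40, mscale=1.0 gives
# roughly 1.369). DeepseekV2AttentionMLA applies this factor twice to its
# softmax scaling.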
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    import math

    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


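# Multi-head Latent Attention (MLA). Each token's KV cache entry is a compressed
# latent of size kv_lora_rank plus qk_rope_head_dim rather than full per-head
# K/V. In the absorbed paths the query is projected into the latent space with
# w_kc so the core attention runs as MQA over the latent cache, and w_vc maps
# the result back to the per-head value dimension.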
class DeepseekV2AttentionMLA(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        layer_id: int = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        assert num_heads % attn_tp_size == 0
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # For tensor parallel attention
        if self.q_lora_rank is not None:
            self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("fused_qkv_a_proj_with_mqa", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("kv_a_proj_with_mqa", prefix),
            )

        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

        if rope_scaling:
            rope_scaling["rope_type"] = "deepseek_yarn"

        self.rotary_emb = get_rope(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
        )

        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale
        else:
            self.rotary_emb.forward = self.rotary_emb.forward_native

        self.attn_mqa = RadixAttention(
            self.num_local_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=1,
            layer_id=layer_id,
            v_head_dim=self.kv_lora_rank,
            quant_config=quant_config,
            prefix=add_prefix("attn_mqa", prefix),
        )

        self.attn_mha = RadixAttention(
            self.num_local_heads,
            self.qk_nope_head_dim + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            v_head_dim=self.v_head_dim,
            quant_config=quant_config,
            prefix=add_prefix("attn_mha", prefix),
        )

        self.alt_stream = alt_stream

        self.w_kc = None
        self.w_vc = None
        self.w_scale = None

        self.w_scale_k = None
        self.w_scale_v = None
        self.use_deep_gemm_bmm = False

        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
        self.disable_chunked_prefix_cache = global_server_args_dict[
            "disable_chunked_prefix_cache"
        ]
        self.attention_backend = global_server_args_dict["attention_backend"]
        self.rocm_fused_decode_mla = get_bool_env_var(
            "SGLANG_ROCM_FUSED_DECODE_MLA", "false"
        )

        # TODO: Design a finer way to determine the threshold
        self.chunked_prefix_cache_threshold = get_int_env_var(
            "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
        )

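    # Pick the attention path for this batch: plain MHA for prefix-free prefill,
    # MHA over chunked prefix KV for long-prefix prefill on the fa3 backend, and
    # absorbed MLA (optionally with the ROCm fused-RoPE decode kernel) otherwise.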
    def dispatch_attn_forward_method(
        self, forward_batch: ForwardBatch
    ) -> AttnForwardMethod:
        def _dispatch_mla_subtype():
            if _is_hip:
                if (
                    self.rocm_fused_decode_mla
                    and forward_batch.forward_mode.is_decode()
                ):
                    return AttnForwardMethod.MLA_FUSED_ROPE
                else:
                    return AttnForwardMethod.MLA
            else:
                return AttnForwardMethod.MLA

        if self.attention_backend == "flashinfer":
            # Flashinfer MLA: Do not absorb when enabling ragged prefill
            if (
                not self.flashinfer_mla_disable_ragged
                and forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and sum(forward_batch.extend_prefix_lens_cpu) == 0
            ):
                return AttnForwardMethod.MHA
            else:
                return _dispatch_mla_subtype()
        elif self.attention_backend == "fa3":
            # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
            if forward_batch.extend_prefix_lens_cpu is not None:
                sum_extend_prefix_lens = sum(forward_batch.extend_prefix_lens_cpu)
            if (
                forward_batch.forward_mode.is_extend()
                and not self.disable_chunked_prefix_cache
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and (
                    sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
                    or sum_extend_prefix_lens == 0
                )
            ):
                return AttnForwardMethod.MHA_CHUNKED_KV
            else:
                return _dispatch_mla_subtype()
        else:
            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
            if (
                forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and sum(forward_batch.extend_prefix_lens_cpu) == 0
            ):
                return AttnForwardMethod.MHA
            else:
                return _dispatch_mla_subtype()

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ) -> torch.Tensor:
        if hidden_states.shape[0] == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states

        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)

        if attn_forward_method == AttnForwardMethod.MHA:
            return self.forward_normal(positions, hidden_states, forward_batch)
        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
            return self.forward_normal_chunked_kv(
                positions, hidden_states, forward_batch
            )
        elif attn_forward_method == AttnForwardMethod.MLA:
            return self.forward_absorb(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
            return self.forward_absorb_fused_mla_rope(
                positions, hidden_states, forward_batch
            )
        else:
            raise NotImplementedError

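    # Standard MHA path: full per-head K and V are reconstructed from the
    # compressed latent via kv_b_proj, while only the latent itself is written
    # to the KV cache.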
    def forward_normal(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]

        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope = kv[..., : self.qk_nope_head_dim]
        v = kv[..., self.qk_nope_head_dim :]
        k_pe = latent_cache[:, :, self.kv_lora_rank :]
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe

        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
        latent_cache[:, :, self.kv_lora_rank :] = k_pe

        # Save latent cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )
        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output

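    # Absorbed MLA path: q_nope is projected into the latent space with w_kc,
    # attention runs as MQA over the cached latent (k_nope) plus the rotary part
    # (k_pe), and w_vc maps the attention output back to v_head_dim before
    # o_proj. The projections use bmm_fp8 or DeepGEMM when the absorbed weights
    # are FP8.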
    def forward_absorb(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ) -> torch.Tensor:
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            k_nope = latent_cache[..., : self.kv_lora_rank]

            # overlap qk norm
            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
                current_stream = torch.cuda.current_stream()
                self.alt_stream.wait_stream(current_stream)
                q = self.q_a_layernorm(q)
                with torch.cuda.stream(self.alt_stream):
                    k_nope = self.kv_a_layernorm(k_nope)
                current_stream.wait_stream(self.alt_stream)
            else:
                q = self.q_a_layernorm(q)
                k_nope = self.kv_a_layernorm(k_nope)

            k_nope = k_nope.unsqueeze(1)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
            k_nope = latent_cache[..., : self.kv_lora_rank]
            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

        if self.use_deep_gemm_bmm:
            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
            )
            q_nope_out = q_nope.new_empty(
                (self.num_local_heads, aligned_m, self.kv_lora_rank)
            )
            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
                (q_nope_val, q_nope_scale),
                (self.w_kc, self.w_scale_k),
                q_nope_out,
                masked_m,
                expected_m,
            )
            q_nope_out = q_nope_out[:, :expected_m, :]
        elif self.w_kc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)

        q_nope_out = q_nope_out.transpose(0, 1)
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
            attn_output = self.attn_mqa(
                q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
            )
        else:
            q = torch.cat([q_nope_out, q_pe], dim=-1)
            k = torch.cat([k_nope, k_pe], dim=-1)
            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        if self.use_deep_gemm_bmm:
            attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
                per_token_group_quant_mla_deep_gemm_masked_fp8(
                    attn_output.transpose(0, 1)
                )
            )
            attn_bmm_output = attn_output.new_empty(
                (self.num_local_heads, aligned_m, self.v_head_dim)
            )
            deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
                (attn_output_val, attn_output_scale),
                (self.w_vc, self.w_scale_v),
                attn_bmm_output,
                masked_m,
                expected_m,
            )
            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
        elif self.w_vc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                attn_output.transpose(0, 1),
                zero_allocator.allocate(1),
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)

        return output

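    # ROCm-only absorbed path: RoPE is applied inside the grouped MLA decode
    # kernel (decode_attention_fwd_grouped_rope) instead of as a separate op;
    # the fusion can be turned off with SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION=0.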
    def forward_absorb_fused_mla_rope(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ) -> torch.Tensor:
        enable_rope_fusion = (
            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
        )
        q_len = hidden_states.shape[0]
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        if self.w_kc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input

        if not enable_rope_fusion:
            k_pe = k_input[..., self.kv_lora_rank :]
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
            q_input[..., self.kv_lora_rank :] = q_pe
            k_input[..., self.kv_lora_rank :] = k_pe
            k_pe_output = None
        else:
            k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])

        q_input[..., self.kv_lora_rank :] = q_pe

        # attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
        # Use Fused ROPE with use_rope=OFF.
        attn_output = torch.empty(
            (q_len, self.num_local_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )
        attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
            forward_batch.attn_backend.forward_metadata
        )
        cos_sin_cache = self.rotary_emb.cos_sin_cache
        num_kv_split = forward_batch.attn_backend.num_kv_splits
        sm_scale = self.attn_mqa.scaling
        if attn_logits is None:
            attn_logits = torch.empty(
                (
                    forward_batch.batch_size,
                    self.num_local_heads,
                    num_kv_split,
                    self.kv_lora_rank + 1,
                ),
                dtype=torch.float32,
                device=q.device,
            )

        # save current latent cache.
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mqa, forward_batch.out_cache_loc, k_input, None
        )
        key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mqa.layer_id
        )
        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

        decode_attention_fwd_grouped_rope(
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            self.kv_lora_rank,
            self.rotary_emb.rotary_dim,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            logit_cap=self.attn_mqa.logit_cap,
            use_rope=enable_rope_fusion,
            is_neox_style=self.rotary_emb.is_neox_style,
        )

        if enable_rope_fusion:
            k_input[..., self.kv_lora_rank :] = k_pe_output
            forward_batch.token_to_kv_pool.set_kv_buffer(
                self.attn_mqa, forward_batch.out_cache_loc, k_input, None
            )

        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        if self.w_vc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                attn_output.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)

        return output

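    # Iterate over the precomputed prefix KV chunks, rebuild K/V for each chunk
    # from the cached latent via kv_b_proj, and fold each chunk's attention
    # output into the running result with an LSE-weighted merge (merge_state_v2).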
    def _chunked_prefix_attn_mha(
        self,
        q: torch.Tensor,
        accum_output: torch.Tensor,
        accum_lse: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:

        assert forward_batch.num_prefix_chunks is not None
        for i in range(forward_batch.num_prefix_chunks):
            forward_batch.set_prefix_chunk_idx(i)

            # Fetch latent cache from memory pool with precomputed chunked kv indices
            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
                self.attn_mha.layer_id
            )
            latent_cache = latent_cache_buf[
                forward_batch.prefix_chunk_kv_indices[i]
            ].contiguous()

            kv_a_normed, k_pe = latent_cache.split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )
            kv_a_normed = kv_a_normed.squeeze(1).contiguous()
            kv = self.kv_b_proj(kv_a_normed)[0]
            kv = kv.view(
                -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
            )
            v = kv[..., self.qk_nope_head_dim :]
            k_nope = kv[..., : self.qk_nope_head_dim]

            k = torch.empty(
                (
                    k_nope.shape[0],
                    self.num_local_heads,
                    self.qk_nope_head_dim + self.qk_rope_head_dim,
                ),
                dtype=v.dtype,
                device=v.device,
            )
            k[..., : self.qk_nope_head_dim] = k_nope
            k[..., self.qk_nope_head_dim :] = k_pe

            output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
            lse = torch.transpose(lse, 0, 1).contiguous()
            tmp_output = torch.empty_like(accum_output)
            tmp_lse = torch.empty_like(accum_lse)
            merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
            accum_output, accum_lse = tmp_output, tmp_lse

        return accum_output

    def forward_normal_chunked_kv(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        # In normal mha, the k and v tensors will become overly large when the prefix length is long.
        # To avoid this, we split the kv cache into chunks and process them one after another.
        # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
        # The top comments in https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
        # will be helpful for understanding the purpose of this function.

        # First do normal mha forward to get output for extended part
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a.contiguous())
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope = kv[..., : self.qk_nope_head_dim]
        v = kv[..., self.qk_nope_head_dim :]
        k_pe = latent_cache[:, :, self.kv_lora_rank :]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe

        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
        latent_cache[:, :, self.kv_lora_rank :] = k_pe

        # Save latent cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )

        # Do mha for extended part without prefix
        forward_batch.set_attn_attend_prefix_cache(False)
        attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        lse = torch.transpose(lse, 0, 1).contiguous()

        # Do mha attention with chunked prefix cache if there are any sequence with prefix
        if any(forward_batch.extend_prefix_lens_cpu):
            # Only initialize the info once
            if forward_batch.num_prefix_chunks is None:
                forward_batch.prepare_chunked_prefix_cache_info(q.device)

            forward_batch.set_attn_attend_prefix_cache(True)
            attn_output = self._chunked_prefix_attn_mha(
                q=q,
                accum_output=attn_output,
                accum_lse=lse,
                forward_batch=forward_batch,
            )

        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output


class DeepseekV2DecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        is_nextn: bool = False,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
        self.layer_id = layer_id
        self.self_attn = DeepseekV2AttentionMLA(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=(
                config.q_lora_rank if hasattr(config, "q_lora_rank") else None
            ),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            layer_id=layer_id,
            reduce_results=False,
            prefix=add_prefix("self_attn", prefix),
            alt_stream=alt_stream,
        )

        self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
        is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=is_previous_layer_sparse,
        )

        if self.is_layer_sparse:
            self.mlp = DeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                layer_id=self.layer_id,
            )
        else:
            if enable_moe_dense_fully_dp():
                mlp_tp_rank, mlp_tp_size = 0, 1
            else:
                mlp_tp_rank, mlp_tp_size = None, None
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                tp_rank=mlp_tp_rank,
                tp_size=mlp_tp_size,
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
        )

    def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
        return is_nextn or (
            self.config.n_routed_experts is not None
            and layer_id >= self.config.first_k_dense_replace
            and layer_id % self.config.moe_layer_freq == 0
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
    ) -> torch.Tensor:
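        # The layer body is expressed as a sequence of op_* stages
        # (communication before attention, attention, communication before the
        # MLP, MLP, layer postprocessing) that execute_operations runs in order.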
        return execute_operations(
            inputs=dict(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
                residual=residual,
                zero_allocator=zero_allocator,
            ),
            operations=compute_layer_operations(self),
        )

    def op_comm_prepare_attn(
        self,
        state,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
    ):
        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
        )
        state.update(
            dict(
                forward_batch=forward_batch,
                positions=positions,
                zero_allocator=zero_allocator,
            )
        )

    def op_attn(self, state):
        state.hidden_states_after_attn = self.self_attn(
            positions=state.positions,
            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
            forward_batch=state.forward_batch,
            zero_allocator=state.zero_allocator,
        )

    def op_comm_prepare_mlp(self, state):
        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
            self.layer_communicator.prepare_mlp(
                state.pop("hidden_states_after_attn"),
                state.pop("residual_after_input_ln"),
                state.forward_batch,
            )
        )

    def op_mlp(self, state):
        hidden_states = state.pop("hidden_states_mlp_input")
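        # With fully data-parallel dense MLPs, this rank may have an empty
        # batch; skip the MLP call in that case and pass hidden states through.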
        if not (
            enable_moe_dense_fully_dp()
            and (not self.is_layer_sparse)
            and hidden_states.shape[0] == 0
        ):
            state.hidden_states_mlp_output = self.mlp(
                hidden_states, state.forward_batch.forward_mode
            )
        else:
            state.hidden_states_mlp_output = hidden_states

    def op_comm_postprocess_layer(self, state):
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            state.pop("hidden_states_mlp_output"),
            state.pop("residual_after_comm_pre_mlp"),
            state.forward_batch,
        )

        state.clear(expect_keys={"positions", "forward_batch", "zero_allocator"})
        return hidden_states, residual


class DeepseekV2Model(nn.Module):
    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            enable_tp=not global_server_args_dict["enable_dp_attention"],
        )
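        # Auxiliary CUDA stream shared by all decoder layers (CUDA only).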
        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
        self.layers = nn.ModuleList(
            [
                DeepseekV2DecoderLayer(
                    config,
                    layer_id,
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{layer_id}", prefix),
                    alt_stream=self.alt_stream,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.dp_size = get_local_attention_dp_size()

    def get_input_embeddings(self) -> torch.Tensor:
        return self.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
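        # Per-forward scratch allocator shared by the decoder layers
        # (two fp32 slots per layer; see buffer_size below).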
        zero_allocator = BumpAllocator(
            # TODO for two-batch-overlap, we need a larger buffer size
            buffer_size=len(self.layers) * 2,
            dtype=torch.float32,
            device=(
                input_embeds.device if input_embeds is not None else input_ids.device
            ),
        )

        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        residual = None
        for i in range(len(self.layers)):
            with get_global_expert_distribution_recorder().with_current_layer(i):
                layer = self.layers[i]
                hidden_states, residual = layer(
                    positions, hidden_states, forward_batch, residual, zero_allocator
                )
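        # Idle forward passes (e.g. a data-parallel attention rank with no
        # tokens in this step) skip the final normalization.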
        if not forward_batch.forward_mode.is_idle():
            if residual is None:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekV2ForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        self.determine_n_share_experts_fusion()
        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)
        self.dp_size = get_local_attention_dp_size()

    def determine_n_share_experts_fusion(
        self, architecture: str = "DeepseekV3ForCausalLM"
    ):
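        # A positive n_share_experts_fusion appends that many replicas of the
        # shared expert to the routed experts (see load_weights), so the shared
        # expert is executed together with the routed experts.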
        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
        if self.n_share_experts_fusion > 0:
            # Only Deepseek V3/R1 can use shared experts fusion optimization now.
            if (
                not _is_cuda
                or self.config.architectures[0] != architecture
                or self.config.n_routed_experts != 256
            ):
                self.n_share_experts_fusion = 0
                global_server_args_dict["n_share_experts_fusion"] = 0
                log_info_on_rank0(
                    logger,
                    "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
                )
            else:
                assert (
                    self.n_share_experts_fusion == self.tp_size
                ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
        elif self.n_share_experts_fusion == 0:
            if (
                _is_cuda
                and torch.cuda.get_device_capability("cuda") >= (9, 0)
                and self.config.architectures[0] == architecture
                and self.config.n_routed_experts == 256
                and (not global_server_args_dict["enable_deepep_moe"])
            ):
                self.n_share_experts_fusion = self.tp_size
                global_server_args_dict["n_share_experts_fusion"] = self.tp_size
                log_info_on_rank0(
                    logger,
                    "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
                )

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:

        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)

        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def post_load_weights(self, is_nextn=False):

        # Perform post-processing after loading weights
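        # For MLA, split each layer's kv_b_proj into the w_kc / w_vc matrices
        # used by the attention kernels, dequantizing or requantizing the
        # weight first when the checkpoint is AWQ, fp8, or int8 quantized.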
        layer_ids = (
            range(self.config.num_hidden_layers)
            if not is_nextn
            else [self.config.num_hidden_layers]
        )
        for layer_id in layer_ids:
            self_attn = (
                self.model.layers[layer_id].self_attn
                if not is_nextn
                else self.model.decoder.self_attn
            )
            if hasattr(self_attn.kv_b_proj, "qweight"):
                # AWQ compatible
                if _is_cuda:
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                    ).T
                else:
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                        0,
                        0,
                        0,
                    ).T
            else:
                w = self_attn.kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
            # This may affect the accuracy of fp8 model.
            # Fix deepseek v3 blockwise bmm by using deep_gemm
            use_deep_gemm_bmm = False
            model_dtype = torch.get_default_dtype()

            if w.dtype in (
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
            ):
                if hasattr(self.quant_config, "weight_block_size"):
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                        if _is_hip:
                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                                weight=w,
                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                                input_scale=None,
                            )
                        else:
                            weight = w
                            weight_scale = self_attn.kv_b_proj.weight_scale_inv

                        if (
                            _is_cuda
                            and weight_block_size[0] == 128
                            and weight_block_size[1] == 128
                            and model_dtype == torch.bfloat16
                        ):
                            if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
                                "SGL_USE_DEEPGEMM_BMM", "false"
                            ):
                                block_scale = weight_scale
                                use_deep_gemm_bmm = True
                            else:
                                w = block_quant_dequant(
                                    weight,
                                    weight_scale,
                                    weight_block_size,
                                    model_dtype,
                                )
                        else:
                            w, scale = block_quant_to_tensor_quant(
                                weight, weight_scale, weight_block_size
                            )
                            self_attn.w_scale = scale
                else:
                    weight = w
                    weight_scale = self_attn.kv_b_proj.weight_scale
                    w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                    self_attn.w_scale = scale

            if w.dtype == torch.int8:
                if hasattr(self.quant_config, "weight_block_size"):
                    # block-wise int8 need it
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
                        w = int8_block_dequant(
                            weight, weight_scale, weight_block_size
                        ).to(torch.bfloat16)
                else:
                    # channel-wise int8 need it
                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
                        torch.bfloat16
                    )

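            # Each head's kv_b_proj output stacks qk_nope_head_dim rows (K)
            # on top of v_head_dim rows (V); split them into w_kc and w_vc.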
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            if not use_deep_gemm_bmm:
                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                if (
                    hasattr(self_attn.kv_b_proj, "weight_scale")
                    and self_attn.w_scale is None
                ):
                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                    if _is_hip:
                        self_attn.w_scale *= 2.0
            else:
                num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
                num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
                ws_kc, ws_vc = block_scale.unflatten(
                    0, (-1, (num_tiles_k + num_tiles_n))
                ).split([num_tiles_k, num_tiles_n], dim=1)
                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
                self_attn.w_scale_v = ws_vc.contiguous()
                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
                self_attn.w_vc = w_vc.contiguous()
                self_attn.use_deep_gemm_bmm = True

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
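        # NextN (speculative draft) weights are stored in the checkpoint under
        # layer index num_hidden_layers (0 for single-layer configs) and are
        # remapped onto `model.decoder` further below.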
        if is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
                # compatible with old design
                nextn_layer_id = (
                    0
                    if self.config.num_hidden_layers == 1
                    else self.config.num_hidden_layers
                )
            else:
                raise ValueError("num_nextn_predict_layers is not in the config")

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
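        # With shared experts fusion, clone the shared expert's weights into
        # extra expert slots (ids >= n_routed_experts) and drop the original
        # shared-expert entries so they are loaded as routed experts instead.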
        if self.n_share_experts_fusion > 0:
            weights_list = list(weights)
            weights_dict = dict(weights_list)
            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
                suffix_list = [
                    "down_proj.weight",
                    "down_proj.weight_scale",
                    "gate_proj.weight",
                    "gate_proj.weight_scale",
                    "up_proj.weight",
                    "up_proj.weight_scale",
                ]
            else:
                suffix_list = [
                    "down_proj.weight",
                    "down_proj.weight_scale_inv",
                    "gate_proj.weight",
                    "gate_proj.weight_scale_inv",
                    "up_proj.weight",
                    "up_proj.weight_scale_inv",
                ]
            names_to_remove = []

            moe_layers = (
                range(
                    self.config.first_k_dense_replace,
                    self.config.num_hidden_layers,
                    self.config.moe_layer_freq,
                )
                if not is_nextn
                else [nextn_layer_id]
            )

            for moe_layer in tqdm(
                moe_layers,
                desc=f"Cloning {self.n_share_experts_fusion} "
                "replicas of the shared expert into MoE",
            ):
                for suffix in suffix_list:
                    shared_expert_weight_name = (
                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
                    )
                    for num_repeat in range(self.n_share_experts_fusion):
                        weights_list.append(
                            (
                                f"model.layers.{moe_layer}."
                                f"mlp.experts."
                                f"{self.config.n_routed_experts + num_repeat}"
                                f".{suffix}",
                                weights_dict[shared_expert_weight_name],
                            )
                        )
                    names_to_remove += [shared_expert_weight_name]
            weights = [w for w in weights_list if w[0] not in names_to_remove]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
        )

        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
            self.config.q_lora_rank is not None
        )
        cached_a_proj = {} if fuse_qkv_a_proj else None

        if is_nextn:
            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
            nextn_spec_weight_names = [
                "shared_head.norm",
                "eh_proj",
                "enorm",
                "hnorm",
            ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if not is_nextn:
                if hasattr(self.config, "num_nextn_predict_layers"):
                    num_nextn_layers = self.config.num_nextn_predict_layers
                    if num_nextn_layers > 0 and name.startswith("model.layers"):
                        name_list = name.split(".")
                        if (
                            len(name_list) >= 3
                            and int(name_list[2]) >= self.config.num_hidden_layers
                        ):
                            continue
            else:
                if not name.startswith(nextn_layer_prefix):
                    continue

                # Use shared head and embed weights from target model
                if "shared_head.head" in name or "embed_tokens" in name:
                    continue

                is_decoder = True
                # For nextn specific weights
                for weight_name in nextn_spec_weight_names:
                    if weight_name in name:
                        name = name.replace(nextn_layer_prefix, "model")
                        is_decoder = False
                        break
                # For decoder layer weights
                if is_decoder:
                    name = name.replace(nextn_layer_prefix, "model.decoder")

            if "rotary_emb.inv_freq" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        shard_id=shard_id,
                        expert_id=expert_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

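                    # The model fuses q_a_proj and kv_a_proj_with_mqa into one
                    # parameter (fused_qkv_a_proj_with_mqa); cache each half
                    # until both halves of a layer are seen, then concatenate.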
                    if fuse_qkv_a_proj and (
                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                    ):
                        cached_a_proj[name] = loaded_weight
                        q_a_proj_name = (
                            name
                            if "q_a_proj" in name
                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
                        )
                        kv_a_proj_name = (
                            name
                            if "kv_a_proj_with_mqa" in name
                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
                        )

                        # When both q_a_proj and kv_a_proj_with_mqa have been cached, load the fused weight into the parameter
                        if (
                            q_a_proj_name in cached_a_proj
                            and kv_a_proj_name in cached_a_proj
                        ):
                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                            fused_weight = torch.cat(
                                [q_a_proj_weight, kv_a_proj_weight], dim=0
                            )

                            param_name = name.replace(
                                "q_a_proj", "fused_qkv_a_proj_with_mqa"
                            )
                            param = params_dict[param_name]

                            weight_loader = getattr(
                                param, "weight_loader", default_weight_loader
                            )
                            weight_loader(param, fused_weight)
                            cached_a_proj.pop(q_a_proj_name)
                            cached_a_proj.pop(kv_a_proj_name)
                    else:
                        param = params_dict[name]
                        weight_loader = getattr(
                            param, "weight_loader", default_weight_loader
                        )
                        weight_loader(param, loaded_weight)

        self.post_load_weights(is_nextn=is_nextn)

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    @classmethod
    def get_model_config_for_expert_location(cls, config):
        return ModelConfigForExpertLocation(
            num_layers=config.num_hidden_layers,
            num_logical_experts=config.n_routed_experts,
            num_groups=config.n_group,
        )


class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    pass


EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]