deepseek_v2.py 103 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14

Liangsheng Yin's avatar
Liangsheng Yin committed
15
16
17
# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
18

19
import concurrent.futures
20
import logging
21
import os
22
from enum import IntEnum, auto
Liangsheng Yin's avatar
Liangsheng Yin committed
23
24
25
from typing import Any, Dict, Iterable, Optional, Tuple

import torch
Ke Bao's avatar
Ke Bao committed
26
import torch.nn.functional as F
Liangsheng Yin's avatar
Liangsheng Yin committed
27
from torch import nn
28
from tqdm import tqdm
Liangsheng Yin's avatar
Liangsheng Yin committed
29
from transformers import PretrainedConfig
30
31

from sglang.srt.distributed import (
32
    get_moe_expert_parallel_world_size,
Liangsheng Yin's avatar
Liangsheng Yin committed
33
    get_tensor_model_parallel_world_size,
34
    parallel_state,
Liangsheng Yin's avatar
Liangsheng Yin committed
35
36
    tensor_model_parallel_all_reduce,
)
37
38
39
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
fzyzcjy's avatar
fzyzcjy committed
40
41
42
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
43
from sglang.srt.layers.activation import SiluAndMul
44
from sglang.srt.layers.amx_utils import PackWeightMethod
45
46
47
48
49
from sglang.srt.layers.communicator import (
    LayerCommunicator,
    LayerScatterModes,
    enable_moe_dense_fully_dp,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
50
51
52
from sglang.srt.layers.dp_attention import (
    get_attention_tp_rank,
    get_attention_tp_size,
53
    is_dp_attention_enabled,
Lianmin Zheng's avatar
Lianmin Zheng committed
54
)
55
from sglang.srt.layers.layernorm import RMSNorm
56
57
58
59
60
61
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
62
from sglang.srt.layers.logits_processor import LogitsProcessor
63
64
65
66
67
from sglang.srt.layers.moe import (
    get_deepep_mode,
    get_moe_a2a_backend,
    should_use_flashinfer_cutlass_moe_fp4_allgather,
)
68
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
69
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
70
from sglang.srt.layers.moe.topk import TopK
71
from sglang.srt.layers.quantization import deep_gemm_wrapper
72
from sglang.srt.layers.quantization.base_config import QuantizationConfig
73
from sglang.srt.layers.quantization.fp8_kernel import (
74
    is_fp8_fnuz,
75
    per_tensor_quant_mla_fp8,
76
    per_token_group_quant_mla_deep_gemm_masked_fp8,
77
)
HandH1998's avatar
HandH1998 committed
78
from sglang.srt.layers.quantization.fp8_utils import (
79
    block_quant_dequant,
HandH1998's avatar
HandH1998 committed
80
    block_quant_to_tensor_quant,
81
    channel_quant_to_tensor_quant,
82
    normalize_e4m3fn_to_e4m3fnuz,
83
    requant_weight_ue8m0_inplace,
HandH1998's avatar
HandH1998 committed
84
)
85
86
87
from sglang.srt.layers.quantization.int8_utils import (
    block_dequant as int8_block_dequant,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
88
from sglang.srt.layers.radix_attention import RadixAttention
89
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
90
from sglang.srt.layers.utils import is_sm100_supported
91
92
93
94
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
95
from sglang.srt.managers.schedule_batch import global_server_args_dict
96
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
97
from sglang.srt.model_loader.weight_utils import default_weight_loader
98
99
100
101
from sglang.srt.two_batch_overlap import (
    MaybeTboDeepEPDispatcher,
    model_forward_maybe_tbo,
)
102
103
from sglang.srt.utils import (
    BumpAllocator,
104
    LazyValue,
105
    add_prefix,
106
    bind_or_assign,
107
    cpu_has_amx_support,
108
    get_bool_env_var,
109
    get_device_sm,
110
    get_int_env_var,
111
    is_cpu,
112
    is_cuda,
113
    is_flashinfer_available,
114
    is_hip,
115
    is_non_idle_and_non_empty,
116
    log_info_on_rank0,
117
    use_intel_amx_backend,
118
)
119

120
# Platform capability flags, evaluated once at import time.
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()
# aiter kernels are only considered on ROCm (HIP) and when explicitly enabled.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_device_sm = get_device_sm()

# Platform-specific kernel imports.
if _is_cuda:
    from sgl_kernel import (
        awq_dequantize,
        bmm_fp8,
        dsv3_fused_a_gemm,
        dsv3_router_gemm,
        merge_state_v2,
    )
elif _is_cpu and _is_cpu_amx_available:
    # Nothing is imported here; CPU AMX kernels are presumably reached through
    # torch.ops.sgl_kernel at call sites instead.
    pass
elif _is_hip:
    from sglang.srt.layers.quantization.awq_triton import (
        awq_dequantize_triton as awq_dequantize,
    )
else:
    from vllm._custom_ops import awq_dequantize

if _is_hip:
    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
        decode_attention_fwd_grouped_rope,
    )

_is_flashinfer_available = is_flashinfer_available()
# Reuse the cached CUDA flag rather than re-probing the platform.
_is_sm100_supported = _is_cuda and is_sm100_supported()

logger = logging.getLogger(__name__)

Liangsheng Yin's avatar
Liangsheng Yin committed
156

157
158
159
160
161
162
163
164
165
166
167
class AttnForwardMethod(IntEnum):
    """Attention computation strategy selectable by the attention layer.

    Members are numbered automatically (starting at 1) in declaration order.
    """

    # Use multi-head attention
    MHA = auto()

    # Use absorbed multi-latent attention
    MLA = auto()

    # Use multi-head attention, but with KV cache chunked.
    # This method can avoid OOM when prefix lengths are long.
    MHA_CHUNKED_KV = auto()

    # Use MLA but with fused RoPE
    MLA_FUSED_ROPE = auto()

    # Use MLA with fused RoPE kernel for CPU
    MLA_FUSED_ROPE_CPU = auto()

174

Liangsheng Yin's avatar
Liangsheng Yin committed
175
176
177
178
179
180
181
182
class DeepseekV2MLP(nn.Module):
    """Gated MLP computing ``down_proj(SiluAndMul(gate_up_proj(x)))``.

    ``gate_up_proj`` is a merged column-parallel projection producing the gate
    and up halves (each ``intermediate_size`` wide); ``down_proj`` is a
    row-parallel projection back to ``hidden_size``. In this file it is used,
    for example, as the non-fused shared experts of ``DeepseekV2MoE``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
    ) -> None:
        """Create the gate/up and down projections.

        Args:
            hidden_size: Input and output feature size.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is supported.
            quant_config: Optional quantization config passed to both linears.
            reduce_results: Forwarded to ``down_proj`` (``RowParallelLinear``).
            prefix: Parameter-name prefix used when loading weights.
            tp_rank: Optional tensor-parallel rank override for both linears.
            tp_size: Optional tensor-parallel size override for both linears
                (e.g. ``tp_rank=0, tp_size=1`` keeps the MLP unsharded).

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        # Validate before allocating any (possibly large) projection weights.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        # NOTE: may be None when no override was given; forward() only takes
        # its empty-batch shortcut when this is exactly 1.
        self.tp_size = tp_size

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("down_proj", prefix),
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        self.act_fn = SiluAndMul()

    def forward(
        self,
        x,
        forward_batch=None,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
    ):
        """Apply the gated MLP to ``x``.

        Args:
            x: Input activations; ``x.shape[0]`` is the number of rows (tokens).
            forward_batch: Unused here; accepted for call-site compatibility.
            should_allreduce_fusion: If True, ``down_proj`` skips its all-reduce
                (the reduction is expected to happen elsewhere).
            use_reduce_scatter: If True, ``down_proj`` also skips its
                all-reduce (presumably the caller reduce-scatters instead).

        Returns:
            The output of ``down_proj``, or ``x`` itself for an empty,
            unsharded batch.
        """
        # Empty input with tp_size == 1: nothing to compute. With tp > 1 the
        # projections still run, presumably so that every rank enters the
        # same collectives.
        if self.tp_size == 1 and x.shape[0] == 0:
            return x

        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        out, _ = self.down_proj(
            activated, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
        )
        return out


Ke Bao's avatar
Ke Bao committed
234
class MoEGate(nn.Module):
    """Router of the DeepSeek MoE layer: computes one logit per routed expert.

    Holds a ``(n_routed_experts, hidden_size)`` weight and, when
    ``config.topk_method == "noaux_tc"``, a float32 ``e_score_correction_bias``
    of length ``n_routed_experts`` (``None`` otherwise).
    """

    # Shape/arch conditions under which the specialised ``dsv3_router_gemm``
    # kernel is used instead of ``F.linear`` on CUDA.
    _ROUTER_GEMM_MAX_TOKENS = 16
    _ROUTER_GEMM_HIDDEN_SIZE = 7168
    _ROUTER_GEMM_NUM_EXPERTS = 256
    _ROUTER_GEMM_MIN_SM = 90

    def __init__(
        self,
        config,
        prefix: str = "",
        is_nextn: bool = False,
    ):
        """Allocate the router weight (and correction bias, if configured).

        Args:
            config: Model config; reads ``n_routed_experts``, ``hidden_size``
                and ``topk_method``.
            prefix: Parameter-name prefix (currently unused here).
            is_nextn: Stored as ``self.is_nextn``.
        """
        super().__init__()
        self.is_nextn = is_nextn
        self.weight = nn.Parameter(
            torch.empty((config.n_routed_experts, config.hidden_size))
        )
        if config.topk_method == "noaux_tc":
            self.e_score_correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts, dtype=torch.float32)
            )
        else:
            self.e_score_correction_bias = None
        if _is_cpu and _is_cpu_amx_available:
            # Lets the weight be packed for the AMX kernel used in forward().
            self.quant_method = PackWeightMethod(weight_names=["weight"])

    def forward(self, hidden_states):
        """Return router logits of shape ``(num_tokens, n_routed_experts)``.

        Dispatch order: the AMX packed kernel when the AMX backend is active;
        ``dsv3_router_gemm`` (float32 output) for small CUDA batches matching
        the specialised shape on SM90+; ``F.linear`` otherwise.
        """
        if use_intel_amx_backend(self):
            return torch.ops.sgl_kernel.weight_packed_linear(
                hidden_states,
                self.weight,
                None,  # bias
                True,  # is_vnni
            )

        # NOTE: For some unknown reason, router_gemm seems to degrade accept length.
        if (
            _is_cuda
            and hidden_states.shape[0] <= self._ROUTER_GEMM_MAX_TOKENS
            and hidden_states.shape[1] == self._ROUTER_GEMM_HIDDEN_SIZE
            and self.weight.shape[0] == self._ROUTER_GEMM_NUM_EXPERTS
            and _device_sm >= self._ROUTER_GEMM_MIN_SM
        ):
            # router gemm output float32
            return dsv3_router_gemm(hidden_states, self.weight)
        return F.linear(hidden_states, self.weight, None)


Liangsheng Yin's avatar
Liangsheng Yin committed
280
281
282
283
284
class DeepseekV2MoE(nn.Module):
    """DeepSeek-V2 mixture-of-experts block.

    Tokens go through a ``MoEGate`` router, a ``TopK`` selector and the routed
    experts built by ``get_moe_impl_class()``. Unless shared-experts fusion is
    enabled (``num_fused_shared_experts > 0``), the shared experts run as a
    separate ``DeepseekV2MLP``, and their output is added to the routed output.

    Execution paths:
      * ``forward_normal`` / ``forward_normal_dual_stream`` when DeepEP is not
        used (the latter runs routing on ``alt_stream`` while the shared
        experts run on the current stream);
      * ``forward_cpu`` when the shared experts use the Intel AMX backend;
      * ``forward_deepep`` when the DeepEP all-to-all backend is active;
      * the ``op_*`` stage methods. These read and write a ``state`` object and
        pass ``tbo_subbatch_index`` to the DeepEP dispatcher, presumably for
        two-batch overlap.
    """

    # ``forward`` overlaps shared and routed experts on two streams only for
    # batches with 1..DUAL_STREAM_TOKEN_THRESHOLD tokens.
    DUAL_STREAM_TOKEN_THRESHOLD = 1024

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
        is_nextn: bool = False,
    ):
        """Build router, top-k selector, routed experts and shared experts.

        Args:
            config: Model config (expert counts, sizes, routing settings).
            layer_id: Index of the decoder layer owning this block.
            quant_config: Optional quantization config for the experts.
            prefix: Parameter-name prefix used when loading weights.
            alt_stream: Optional secondary CUDA stream for the dual-stream path.
            is_nextn: Forwarded to ``MoEGate``.

        Raises:
            ValueError: If the tensor-parallel size exceeds the number of
                routed experts, or if ``config.hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        # With fusion enabled, shared experts are appended to the routed experts
        # and always selected by TopK; otherwise they run as a separate MLP.
        self.num_fused_shared_experts = (
            0
            if global_server_args_dict["disable_shared_experts_fusion"]
            else config.n_shared_experts
        )
        self.config = config
        self.layer_id = layer_id
        self.alt_stream = alt_stream
        enable_deepep = get_moe_a2a_backend().is_deepep()

        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}."
            )

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = MoEGate(
            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
        )

        self.topk = TopK(
            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
            renormalize=config.norm_topk_prob,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            num_fused_shared_experts=self.num_fused_shared_experts,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            routed_scaling_factor=self.routed_scaling_factor,
        )

        self.experts = get_moe_impl_class()(
            num_experts=config.n_routed_experts
            + self.num_fused_shared_experts
            + global_server_args_dict["ep_num_redundant_experts"],
            num_fused_shared_experts=self.num_fused_shared_experts,
            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            layer_id=self.layer_id,
            quant_config=quant_config,
            routed_scaling_factor=self.routed_scaling_factor,
            prefix=add_prefix("experts", prefix),
        )

        # Weight-format flags consumed by forward_cpu().
        self.shared_experts_is_int8 = False
        self.shared_experts_is_fp8 = False
        self.shared_experts_weight_block_size = None
        if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            # disable tp for shared experts when enable deepep moe, or with fp4 allgather
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=add_prefix("shared_experts", prefix),
                **(
                    dict(tp_rank=0, tp_size=1)
                    if enable_deepep
                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
                    else {}
                ),
            )
            gate_up_quant_method = self.shared_experts.gate_up_proj.quant_method
            # Packed-weight formats (AWQ / WNA16) are neither plain int8 nor fp8.
            # The short-circuit matters: those layers may not expose ``.weight``.
            is_packed_weight = hasattr(
                gate_up_quant_method, "quant_config"
            ) and gate_up_quant_method.quant_config.get_name() in {
                "awq",
                "awq_marlin",
                "moe_wna16",
            }
            self.shared_experts_is_int8 = (
                not is_packed_weight
                and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
            )
            self.shared_experts_is_fp8 = (
                not is_packed_weight
                and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
            )
            if self.shared_experts_is_fp8:
                assert (
                    gate_up_quant_method.quant_config.weight_block_size
                    == self.shared_experts.down_proj.quant_method.quant_config.weight_block_size
                )
                self.shared_experts_weight_block_size = (
                    gate_up_quant_method.quant_config.weight_block_size
                )

        self.top_k = config.num_experts_per_tok

        if enable_deepep:
            # TODO: we will support tp < ep in the future
            self.ep_size = get_moe_expert_parallel_world_size()
            self.num_experts = (
                config.n_routed_experts
                + global_server_args_dict["ep_num_redundant_experts"]
            )
            self.renormalize = config.norm_topk_prob
            self.topk_group = config.topk_group
            self.num_expert_group = config.n_group
            self.correction_bias = (
                self.gate.e_score_correction_bias.data
                if self.gate.e_score_correction_bias is not None
                else None
            )

            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
                group=parallel_state.get_tp_group().device_group,
                router_topk=self.top_k,
                permute_fusion=True,
                num_experts=self.num_experts,
                num_local_experts=config.n_routed_experts // self.tp_size,
                hidden_size=config.hidden_size,
                params_dtype=config.torch_dtype,
                deepep_mode=get_deepep_mode(),
                async_finish=True,
                return_recv_hook=True,
            )

        self._enable_deepep_moe = enable_deepep

    def get_moe_weights(self):
        """Return the ``.data`` of every routed-expert parameter except ``correction_bias``."""
        return [
            param.data
            for name, param in self.experts.named_parameters()
            if name != "correction_bias"
        ]

    def forward(
        self,
        hidden_states: torch.Tensor,
        forward_batch: Optional[ForwardBatch] = None,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
    ) -> torch.Tensor:
        """Run the MoE block, choosing the DeepEP, dual-stream or normal path.

        Args:
            hidden_states: Input activations; ``shape[0]`` is the token count.
            forward_batch: Batch metadata; only used by the DeepEP path.
            should_allreduce_fusion: Skip the tensor-parallel all-reduce here
                (the reduction is expected to happen elsewhere).
            use_reduce_scatter: Also skip the all-reduce here.

        Returns:
            The combined routed + shared expert output.
        """
        if self._enable_deepep_moe:
            return self.forward_deepep(hidden_states, forward_batch)

        num_tokens = hidden_states.shape[0]
        if (
            self.alt_stream is not None
            and self.num_fused_shared_experts == 0
            and 0 < num_tokens <= self.DUAL_STREAM_TOKEN_THRESHOLD
        ):
            return self.forward_normal_dual_stream(
                hidden_states,
                should_allreduce_fusion,
                use_reduce_scatter,
            )
        return self.forward_normal(
            hidden_states,
            should_allreduce_fusion,
            use_reduce_scatter,
        )

    def forward_normal_dual_stream(
        self,
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
    ) -> torch.Tensor:
        """Run shared experts on the current stream and routing on ``alt_stream``."""
        current_stream = torch.cuda.current_stream()
        # alt_stream must first observe all work already queued on the current
        # stream (including whatever produced hidden_states).
        self.alt_stream.wait_stream(current_stream)
        shared_output = self._forward_shared_experts(hidden_states)

        with torch.cuda.stream(self.alt_stream):
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            topk_output = self.topk(hidden_states, router_logits)
            final_hidden_states = self.experts(hidden_states, topk_output)
            # Same rule as forward_normal(): on CUDA / aiter the scaling is
            # fused into top-k, so it is applied here only otherwise.
            if not _is_cuda and not _use_aiter:
                final_hidden_states *= self.routed_scaling_factor

        current_stream.wait_stream(self.alt_stream)
        final_hidden_states = self._add_shared_output(
            final_hidden_states, shared_output
        )
        return self._maybe_all_reduce(
            final_hidden_states, should_allreduce_fusion, use_reduce_scatter
        )

    def forward_normal(
        self,
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
    ) -> torch.Tensor:
        """Single-stream (non-DeepEP) MoE forward; defers to ``forward_cpu`` on AMX."""
        if hasattr(self, "shared_experts") and use_intel_amx_backend(
            self.shared_experts.gate_up_proj
        ):
            return self.forward_cpu(hidden_states, should_allreduce_fusion)

        if hidden_states.shape[0] > 0:
            shared_output = self._forward_shared_experts(hidden_states)
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            topk_output = self.topk(hidden_states, router_logits)
        else:
            shared_output = None
            topk_output = self.topk.empty_topk_output(hidden_states.device)

        final_hidden_states = self.experts(hidden_states, topk_output)
        if not _is_cuda and not _use_aiter:
            # fused in biased_grouped_topk so we can skip here
            final_hidden_states *= self.routed_scaling_factor
        if shared_output is not None:
            final_hidden_states = self._add_shared_output(
                final_hidden_states, shared_output
            )
        return self._maybe_all_reduce(
            final_hidden_states, should_allreduce_fusion, use_reduce_scatter
        )

    def _add_shared_output(
        self, final_hidden_states: torch.Tensor, shared_output: torch.Tensor
    ) -> torch.Tensor:
        """Return ``final_hidden_states + shared_output`` in a new buffer.

        The buffer is allocated under ``use_symmetric_memory`` for the TP group
        and tagged, as the subsequent all-reduce may rely on it.
        """
        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
            summed = torch.empty_like(final_hidden_states)
        torch.add(final_hidden_states, shared_output, out=summed)
        sm.tag(summed)
        return summed

    def _maybe_all_reduce(
        self,
        final_hidden_states: torch.Tensor,
        should_allreduce_fusion: bool,
        use_reduce_scatter: bool,
    ) -> torch.Tensor:
        """All-reduce across TP ranks unless the reduction is handled elsewhere."""
        if (
            self.tp_size > 1
            and not should_allreduce_fusion
            and not use_reduce_scatter
            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
        ):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def forward_cpu(
        self,
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
    ) -> torch.Tensor:
        """CPU path: routed experts, then the fused AMX ``shared_expert_cpu`` kernel."""
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        topk_output = self.topk(hidden_states, router_logits)
        fused_experts_out = self.experts(
            hidden_states=hidden_states, topk_output=topk_output
        )

        shared = self.shared_experts
        assert use_intel_amx_backend(shared.gate_up_proj) == use_intel_amx_backend(
            shared.down_proj
        )

        # Scales depend on the shared-experts weight format detected in __init__.
        if self.shared_experts_is_int8:
            w1_scale = shared.gate_up_proj.weight_scale
            w2_scale = shared.down_proj.weight_scale
        elif self.shared_experts_is_fp8:
            w1_scale = shared.gate_up_proj.weight_scale_inv
            w2_scale = shared.down_proj.weight_scale_inv
        else:
            w1_scale = w2_scale = None
        block_size = (
            self.shared_experts_weight_block_size
            if self.shared_experts_is_fp8
            else None
        )

        # [Note] inplace should be False in fused_experts.
        # If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts
        # While hidden_states is still needed in shared_expert.
        final_hidden_states = torch.ops.sgl_kernel.shared_expert_cpu(
            hidden_states,
            shared.gate_up_proj.weight,
            shared.down_proj.weight,
            fused_experts_out,
            self.routed_scaling_factor,
            True,  # inplace
            self.shared_experts_is_int8,  # use_int8_w8a8
            self.shared_experts_is_fp8,  # use_fp8_w8a16
            w1_scale,
            w2_scale,
            block_size,
            None,  # a1_scale
            None,  # a2_scale
            True,  # is_vnni
        )
        if self.tp_size > 1 and not should_allreduce_fusion:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """DeepEP path: route, dispatch/compute via experts, combine with shared output."""
        shared_output = None
        if hidden_states.shape[0] > 0:
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            shared_output = self._forward_shared_experts(hidden_states)
            topk_weights, topk_idx, _ = self.topk(
                hidden_states,
                router_logits,
                num_token_non_padded=forward_batch.num_token_non_padded,
                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                    layer_id=self.layer_id,
                ),
            )
        else:
            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
                hidden_states.device
            )

        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            forward_batch=forward_batch,
        )
        return self._combine_routed_and_shared(final_hidden_states, shared_output)

    def _combine_routed_and_shared(
        self,
        final_hidden_states: torch.Tensor,
        shared_output: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Return ``shared_output + routed_scaling_factor * final_hidden_states``.

        Accumulates in place into ``shared_output`` when present; otherwise
        scales ``final_hidden_states`` in place.
        """
        if shared_output is None:
            final_hidden_states *= self.routed_scaling_factor
            return final_hidden_states
        shared_output.add_(final_hidden_states, alpha=self.routed_scaling_factor)
        return shared_output

    def _forward_shared_experts(self, hidden_states):
        """Run the separate shared experts, or return None when they are fused."""
        if self.num_fused_shared_experts == 0:
            return self.shared_experts(hidden_states)
        return None

    def op_gate(self, state):
        """Stage: compute ``state.router_logits`` (None for idle/empty batches)."""
        if is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, state.hidden_states_mlp_input
        ):
            # router_logits: (num_tokens, n_experts)
            state.router_logits = self.gate(state.hidden_states_mlp_input)
        else:
            state.router_logits = None

    def op_shared_experts(self, state):
        """Stage: pop ``hidden_states_mlp_input`` and compute ``state.shared_output``."""
        hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
        if (self.num_fused_shared_experts == 0) and is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, hidden_states_mlp_input
        ):
            state.shared_output = self.shared_experts(hidden_states_mlp_input)
        else:
            state.shared_output = None

    def op_select_experts(self, state):
        """Stage: turn ``router_logits`` into local top-k ids/weights on ``state``."""
        router_logits = state.pop("router_logits")
        hidden_states = state.hidden_states_mlp_input

        if router_logits is not None:
            with get_global_expert_distribution_recorder().with_current_layer(
                self.layer_id
            ):
                state.topk_weights_local, state.topk_idx_local, _ = self.topk(
                    hidden_states=hidden_states,
                    router_logits=router_logits,
                    num_token_non_padded=state.forward_batch.num_token_non_padded,
                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                        layer_id=self.layer_id,
                    ),
                )
        else:
            # No tokens: empty (0, top_k) selections so later stages still run.
            state.topk_idx_local = torch.full(
                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
            )
            state.topk_weights_local = torch.empty(
                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
            )

    def op_dispatch_a(self, state):
        """Stage: start the DeepEP dispatch (only when ``ep_size > 1``)."""
        if self.ep_size > 1:
            self.experts.deepep_dispatcher.dispatch_a(
                hidden_states=state.hidden_states_mlp_input,
                topk_idx=state.pop("topk_idx_local"),
                topk_weights=state.pop("topk_weights_local"),
                forward_batch=state.forward_batch,
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )

    def op_dispatch_b(self, state):
        """Stage: finish the DeepEP dispatch into ``state.dispatch_output``."""
        if self.ep_size > 1:
            with get_global_expert_distribution_recorder().with_current_layer(
                self.layer_id
            ):
                state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
                )

    def op_experts(self, state):
        """Stage: run the experts on the dispatched tokens."""
        state.hidden_states_experts_output = self.experts.moe_impl(
            dispatch_output=state.dispatch_output,
        )

    def op_combine_a(self, state):
        """Stage: start the DeepEP combine of the expert outputs."""
        if self.ep_size > 1:
            self.experts.deepep_dispatcher.combine_a(
                hidden_states=state.pop("hidden_states_experts_output"),
                topk_idx=state.dispatch_output.topk_idx,
                topk_weights=state.dispatch_output.topk_weights,
                forward_batch=state.forward_batch,
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )
            state.pop("dispatch_output")

    def op_combine_b(self, state):
        """Stage: finish the DeepEP combine into ``state.hidden_states_after_combine``."""
        if self.ep_size > 1:
            state.hidden_states_after_combine = (
                self.experts.deepep_dispatcher.combine_b(
                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
                )
            )

    def op_output(self, state):
        """Stage: scale routed output, add shared output, store ``hidden_states_mlp_output``."""
        final_hidden_states = state.pop("hidden_states_after_combine")
        shared_output = state.pop("shared_output")
        state.hidden_states_mlp_output = self._combine_routed_and_shared(
            final_hidden_states, shared_output
        )
729

Liangsheng Yin's avatar
Liangsheng Yin committed
730
731
732
733
734
735
736
737
738

def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN magnitude-correction factor for a context scale.

    Yields ``1.0`` when ``scale <= 1``; otherwise
    ``0.1 * mscale * ln(scale) + 1.0``.
    """
    import math

    return 1.0 if scale <= 1 else 0.1 * mscale * math.log(scale) + 1.0


739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
class DeepseekV2AttentionMLA(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        layer_id: Optional[int] = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """Build projections, rotary embedding and attention modules for one MLA layer.

        Args:
            config: Model config; only ``rms_norm_eps`` is read here.
            hidden_size: Input/output width of the layer.
            num_heads: Total attention heads; split evenly across attention-TP ranks.
            qk_nope_head_dim: Per-head q/k width that does not get rotary embedding.
            qk_rope_head_dim: Per-head q/k width that gets rotary embedding.
            v_head_dim: Per-head value width.
            q_lora_rank: Low-rank width of the query path. If ``None``, a direct
                ``q_proj`` plus ``kv_a_proj_with_mqa`` are used instead of the
                fused ``fused_qkv_a_proj_with_mqa``.
            kv_lora_rank: Width of the compressed KV latent.
            rope_theta: Rotary base.
            rope_scaling: Optional rotary-scaling dict. When truthy its
                ``rope_type`` is set to ``"deepseek_yarn"`` and the softmax scale
                is multiplied by the yarn mscale squared.
            max_position_embeddings: Maximum rotary position.
            quant_config: Quantization config forwarded to linears/attention.
            reduce_results: Forwarded to ``o_proj``.
            layer_id: Layer index forwarded to the RadixAttention modules.
            prefix: Parameter-name prefix.
            alt_stream: Optional side stream (used to overlap q/kv layernorms).

        Raises:
            ValueError: if ``num_heads`` is not divisible by the attention TP size.
        """
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        # Explicit check (not `assert`) so it survives `python -O`.
        if num_heads % attn_tp_size != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by the attention "
                f"tensor-parallel size ({attn_tp_size})"
            )
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # For tensor parallel attention
        if self.q_lora_rank is not None:
            # One replicated GEMM producing [q_a | kv_latent | k_rope].
            self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("fused_qkv_a_proj_with_mqa", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("kv_a_proj_with_mqa", prefix),
            )

        # Expands the KV latent into per-head k_nope and v.
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

        if rope_scaling:
            # NOTE(review): this writes into the caller's dict (presumably the
            # shared config.rope_scaling). It is left as-is in case other code
            # reads the rewritten rope_type; confirm before switching to a copy.
            rope_scaling["rope_type"] = "deepseek_yarn"

        self.rotary_emb = get_rope_wrapper(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
            device=global_server_args_dict["device"],
        )

        if rope_scaling:
            # YaRN: the softmax scale is corrected by mscale on both q and k.
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale
        else:
            self.rotary_emb.forward = self.rotary_emb.forward_native

        # Absorbed-path attention: one shared KV head whose "value" is the latent.
        self.attn_mqa = RadixAttention(
            self.num_local_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=1,
            layer_id=layer_id,
            v_head_dim=self.kv_lora_rank,
            quant_config=quant_config,
            prefix=add_prefix("attn_mqa", prefix),
        )

        # Non-absorbed path: ordinary per-head MHA.
        self.attn_mha = RadixAttention(
            self.num_local_heads,
            self.qk_nope_head_dim + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            v_head_dim=self.v_head_dim,
            quant_config=quant_config,
            prefix=add_prefix("attn_mha", prefix),
        )

        self.alt_stream = alt_stream
        # Filled lazily with self.kv_b_proj on the first forward_prepare call.
        self.attn_mha.kv_b_proj = None

        # Absorbed K/V weights and scales; presumably populated after weight
        # loading elsewhere in the file — TODO confirm.
        self.w_kc = None
        self.w_vc = None
        self.w_scale = 1.0

        self.w_scale_k = None
        self.w_scale_v = None
        self.use_deep_gemm_bmm = False

        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
        self.disable_chunked_prefix_cache = global_server_args_dict[
            "disable_chunked_prefix_cache"
        ]

        self.current_attention_backend = (
            None  # Attention backend used by current forward batch
        )
        self.rocm_fused_decode_mla = get_bool_env_var(
            "SGLANG_ROCM_FUSED_DECODE_MLA", "false"
        )

        # TODO: Design a finer way to determine the threshold
        self.chunked_prefix_cache_threshold = get_int_env_var(
            "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
        )

        # If we have self.fused_qkv_a_proj_with_mqa and we're running on CPU, we will choose the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel
        # which requires self.w_kc and self.w_vc to be packed.
        # If not, we will use torch.bmm and weight shouldn't be packed in this case
        has_fused_proj = hasattr(self, "fused_qkv_a_proj_with_mqa")
        if has_fused_proj and _is_cpu and _is_cpu_amx_available:
            self.quant_method = PackWeightMethod(
                weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
            )

        # Weight-only quantization schemes whose packed weights rule out the
        # dtype-based fast paths below.
        is_packed_weight = (
            has_fused_proj
            and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
            and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
            in {"awq", "awq_marlin", "moe_wna16"}
        )
        # Specialized small-batch GEMM only for this exact bf16 weight shape on SM90+.
        self.use_min_latency_fused_a_gemm = (
            has_fused_proj
            and not is_packed_weight
            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
            and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
            and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
            and _is_cuda
            and _device_sm >= 90
        )

        self.qkv_proj_with_rope_is_int8 = (
            has_fused_proj
            and not is_packed_weight
            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
        )
        self.qkv_proj_with_rope_is_fp8 = (
            has_fused_proj
            and not is_packed_weight
            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
        )

        self.weight_block_size = None
        if self.qkv_proj_with_rope_is_fp8 and _is_cpu and _is_cpu_amx_available:
            # The fused CPU kernel needs both projections quantized the same way.
            assert getattr(
                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
            ) == getattr(self.q_b_proj.quant_method, "block_quant", False)
            use_block_quant = getattr(
                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
            )

            if use_block_quant:
                assert (
                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
                    == self.q_b_proj.quant_method.quant_config.weight_block_size
                )
                self.weight_block_size = (
                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
                )

965
966
967
    def dispatch_attn_forward_method(
        self, forward_batch: ForwardBatch
    ) -> AttnForwardMethod:
        """Choose how this layer computes attention for ``forward_batch``.

        Side effect: records the selected backend in
        ``self.current_attention_backend``. Decode/idle batches use
        ``decode_attention_backend``; all others use ``prefill_attention_backend``.

        Returns:
            An ``AttnForwardMethod``:

            * ``MHA`` / ``MHA_CHUNKED_KV`` select the non-absorbed path (K/V
              expanded through ``kv_b_proj``), chosen for plain extend batches
              as each backend allows.
            * ``MLA`` / ``MLA_FUSED_ROPE`` / ``MLA_FUSED_ROPE_CPU`` select the
              weight-absorbed path.
        """

        def _dispatch_mla_subtype():
            # Pick the absorbed-path variant for the current platform.
            if _is_hip:
                if (
                    self.rocm_fused_decode_mla
                    and forward_batch.forward_mode.is_decode()
                ):
                    return AttnForwardMethod.MLA_FUSED_ROPE
                else:
                    return AttnForwardMethod.MLA
            else:
                if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
                    self
                ):
                    return AttnForwardMethod.MLA_FUSED_ROPE_CPU
                else:
                    return AttnForwardMethod.MLA

        # Determine attention backend used by current forward batch
        if forward_batch.forward_mode.is_decode_or_idle():
            attention_backend = global_server_args_dict["decode_attention_backend"]
        else:
            attention_backend = global_server_args_dict["prefill_attention_backend"]
        self.current_attention_backend = attention_backend

        if attention_backend == "ascend":
            return AttnForwardMethod.MLA
        elif attention_backend == "flashinfer":
            # Flashinfer MLA: Do not absorb when enabling ragged prefill
            if (
                not self.flashinfer_mla_disable_ragged
                and forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and sum(forward_batch.extend_prefix_lens_cpu) == 0
            ):
                return AttnForwardMethod.MHA
            else:
                return _dispatch_mla_subtype()
        elif attention_backend == "fa3":
            # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
            # Without prefix-length information the chunked decision cannot be
            # made, so such batches fall back to the absorbed path. Previously
            # this raised UnboundLocalError on `sum_extend_prefix_lens`.
            prefix_lens = forward_batch.extend_prefix_lens_cpu
            if (
                prefix_lens is not None
                and forward_batch.forward_mode.is_extend()
                and not self.disable_chunked_prefix_cache
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
            ):
                sum_extend_prefix_lens = sum(prefix_lens)
                if (
                    sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
                    or sum_extend_prefix_lens == 0
                ):
                    return AttnForwardMethod.MHA_CHUNKED_KV
            return _dispatch_mla_subtype()
        elif attention_backend == "aiter":
            if (
                forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
            ):
                return AttnForwardMethod.MHA
            else:
                return AttnForwardMethod.MLA
        else:
            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
            if (
                forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and sum(forward_batch.extend_prefix_lens_cpu) == 0
            ):
                return AttnForwardMethod.MHA
            else:
                return _dispatch_mla_subtype()
Lianmin Zheng's avatar
Lianmin Zheng committed
1043

1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
    def op_prepare(self, state):
        """Staged-execution step: run ``forward_prepare`` on values held in ``state``.

        Consumes ``hidden_states_after_comm_pre_attn`` and stores the prepared
        tuple as ``state.attn_intermediate_state``.
        """
        prepare = self.forward_prepare
        pos = state.positions
        attn_input = state.pop("hidden_states_after_comm_pre_attn")
        state.attn_intermediate_state = prepare(
            positions=pos,
            hidden_states=attn_input,
            forward_batch=state.forward_batch,
            zero_allocator=state.zero_allocator,
        )

    def op_core(self, state):
        """Staged-execution step: run ``forward_core`` on the state from ``op_prepare``."""
        core = self.forward_core
        prepared = state.pop("attn_intermediate_state")
        state.hidden_states_after_attn = core(prepared)

1057
1058
1059
1060
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Full attention forward, equivalent to ``forward_core(forward_prepare(...))``.

        ``op_prepare`` / ``op_core`` run the same two halves as separate steps.
        """
        return self.forward_core(
            self.forward_prepare(
                positions=positions,
                hidden_states=hidden_states,
                forward_batch=forward_batch,
                zero_allocator=zero_allocator,
            )
        )

    def forward_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Run the pre-attention half of the layer.

        Returns:
            A 4-tuple ``(hidden_states, attn_forward_method, forward_batch,
            inner_state)`` consumed by ``forward_core``.

            * Empty batch: the input tensor is handed back with
              ``attn_forward_method=None`` and ``inner_state=None``.
            * Otherwise: the first slot is ``None`` and ``inner_state`` is the
              output of the selected ``*_prepare`` method.

        Raises:
            NotImplementedError: for an ``AttnForwardMethod`` with no prepare step.
        """
        # Share kv_b_proj with the MHA attention module the first time through.
        if self.attn_mha.kv_b_proj is None:
            self.attn_mha.kv_b_proj = self.kv_b_proj

        num_tokens = hidden_states.shape[0]
        if num_tokens == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states, None, forward_batch, None

        method = self.dispatch_attn_forward_method(forward_batch)
        # Resolve by attribute name so only the selected method is looked up.
        prepare_by_method = {
            AttnForwardMethod.MHA: "forward_normal_prepare",
            AttnForwardMethod.MHA_CHUNKED_KV: "forward_normal_chunked_kv_prepare",
            AttnForwardMethod.MLA: "forward_absorb_prepare",
            AttnForwardMethod.MLA_FUSED_ROPE: "forward_absorb_fused_mla_rope_prepare",
            AttnForwardMethod.MLA_FUSED_ROPE_CPU: "forward_absorb_fused_mla_rope_cpu_prepare",
        }
        if method not in prepare_by_method:
            raise NotImplementedError
        prepare = getattr(self, prepare_by_method[method])
        inner_state = prepare(positions, hidden_states, forward_batch, zero_allocator)
        return None, method, forward_batch, inner_state
1113

1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
    def forward_core(self, intermediate_state):
        """Run the attention half selected and prepared by ``forward_prepare``.

        ``intermediate_state`` is the 4-tuple returned by ``forward_prepare``.
        When its ``inner_state`` slot is ``None`` (the empty-batch
        short-circuit), the carried ``hidden_states`` is returned unchanged.

        Raises:
            NotImplementedError: for an ``AttnForwardMethod`` with no core step.
        """
        hidden_states, method, _forward_batch, inner_state = intermediate_state
        if inner_state is None:
            return hidden_states

        core_by_method = {
            AttnForwardMethod.MHA: "forward_normal_core",
            AttnForwardMethod.MHA_CHUNKED_KV: "forward_normal_chunked_kv_core",
            AttnForwardMethod.MLA: "forward_absorb_core",
            AttnForwardMethod.MLA_FUSED_ROPE: "forward_absorb_fused_mla_rope_core",
            AttnForwardMethod.MLA_FUSED_ROPE_CPU: "forward_absorb_fused_mla_rope_cpu_core",
        }
        if method not in core_by_method:
            raise NotImplementedError
        return getattr(self, core_by_method[method])(*inner_state)

    def forward_normal_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare q/k/v for the non-absorbed (MHA) attention path.

        Steps:

        1. Project ``hidden_states`` to per-head queries and to the compressed KV
           latent.
        2. Expand the normalized latent into per-head ``k_nope`` / ``v`` through
           ``kv_b_proj``.
        3. Apply rotary embedding to the rope parts of q and k.
        4. Write the ``[normalized latent | rotated k_pe]`` row of every token
           into the KV pool at ``forward_batch.out_cache_loc``.

        ``zero_allocator`` is not used here (presumably kept so all ``*_prepare``
        methods share one signature).

        Returns:
            ``(q, k, v, forward_batch)`` for ``forward_normal_core``. ``q`` is
            viewed as ``(-1, num_local_heads, qk_head_dim)`` and ``k`` has the
            same shape. ``v`` is the trailing ``v_head_dim`` slice of the
            expanded KV.
        """
        if self.q_lora_rank is not None:
            # Split the fused output into q_a (q_lora_rank) and the latent
            # (kv_lora_rank + qk_rope_head_dim).
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]

        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # Size-1 axis at dim 1: k_pe sliced from it broadcasts across heads below.
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a)
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope = kv[..., : self.qk_nope_head_dim]
        v = kv[..., self.qk_nope_head_dim :]
        k_pe = latent_cache[:, :, self.kv_lora_rank :]
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        # Store the rotated rope part back into q in place.
        q[..., self.qk_nope_head_dim :] = q_pe
        # Assemble full per-head keys: [k_nope | k_pe (broadcast over heads)].
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe

        # Overwrite the raw latent row with the normalized latent and rotated
        # k_pe before it is cached.
        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
        latent_cache[:, :, self.kv_lora_rank :] = k_pe

        # Save latent cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )

        return q, k, v, forward_batch

    def forward_normal_core(self, q, k, v, forward_batch):
        """Dense per-head attention followed by the output projection.

        The KV cache is not written here (``save_kv_cache=False``);
        ``forward_normal_prepare`` already stored the latent cache.
        """
        per_head = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        merged_heads = per_head.reshape(-1, self.num_local_heads * self.v_head_dim)
        projected, _bias = self.o_proj(merged_heads)
        return projected

Faraz's avatar
Faraz committed
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
    def _fuse_rope_for_trtllm_mla(self, forward_batch: ForwardBatch) -> bool:
        """
        Check if we should skip rope and do fused rope+quantize for TRTLLM MLA decode in fp8_e4m3 path.

        True only when the current backend is ``trtllm_mla``, the batch is
        decode/idle, and the attention backend's data type is
        ``torch.float8_e4m3fn``.
        """
        on_trtllm = self.current_attention_backend == "trtllm_mla"
        if not on_trtllm:
            return on_trtllm
        decoding = forward_batch.forward_mode.is_decode_or_idle()
        if not decoding:
            return decoding
        return forward_batch.attn_backend.data_type == torch.float8_e4m3fn

1194
    def forward_absorb_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare inputs for the weight-absorbed MLA attention path.

        Steps:

        1. Compute the query projection and the KV latent.
        2. Normalize both with their layernorms. During graph capture with an
           ``alt_stream``, the two layernorms are overlapped on separate streams.
        3. Multiply ``q_nope`` by ``self.w_kc`` so it lives in the
           ``kv_lora_rank`` latent space.
        4. Apply rotary embedding to ``q_pe`` / ``k_pe``, unless
           ``_fuse_rope_for_trtllm_mla`` says the TRT-LLM MLA fp8 decode kernel
           will do it.

        Returns:
            ``(q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator)`` for
            ``forward_absorb_core``. ``k_nope`` and ``k_pe`` carry a size-1 axis at
            dim 1. ``q_nope_out`` is (tokens, num_local_heads, kv_lora_rank) per
            the deep-gemm buffer shape; other branches presumably match — TODO
            confirm.
        """
        # Local import; presumably avoids an import cycle — TODO confirm.
        from sglang.srt.model_executor.graph_runner import get_is_capture_mode

        if self.q_lora_rank is not None:
            # Small batches (<= 16 tokens) use the dedicated dsv3 fused-A GEMM when
            # the weight qualifies (see use_min_latency_fused_a_gemm in __init__).
            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                )
            else:
                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
            q, latent_cache = fused_qkv_a_proj_out.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            k_nope = latent_cache[..., : self.kv_lora_rank]

            # overlap qk norm
            # Under graph capture, kv_a_layernorm runs on alt_stream concurrently
            # with q_a_layernorm; the wait_stream calls order the two streams.
            if self.alt_stream is not None and get_is_capture_mode():
                current_stream = torch.cuda.current_stream()
                self.alt_stream.wait_stream(current_stream)
                q = self.q_a_layernorm(q)
                with torch.cuda.stream(self.alt_stream):
                    k_nope = self.kv_a_layernorm(k_nope)
                current_stream.wait_stream(self.alt_stream)
            else:
                q = self.q_a_layernorm(q)
                k_nope = self.kv_a_layernorm(k_nope)

            k_nope = k_nope.unsqueeze(1)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
            k_nope = latent_cache[..., : self.kv_lora_rank]
            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

        # Absorb w_kc into the query. Each branch does a heads-major batched
        # matmul on q_nope.transpose(0, 1) and is transposed back below.
        if self.use_deep_gemm_bmm:
            # DeepGEMM masked fp8 grouped GEMM. The output buffer is padded to
            # aligned_m rows and trimmed to expected_m afterwards.
            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
            )
            q_nope_out = q_nope.new_empty(
                (self.num_local_heads, aligned_m, self.kv_lora_rank)
            )
            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                (q_nope_val, q_nope_scale),
                (self.w_kc, self.w_scale_k),
                q_nope_out,
                masked_m,
                expected_m,
            )
            q_nope_out = q_nope_out[:, :expected_m, :]
        elif _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            # ROCm: upcast to bf16 and fold w_scale into the weight.
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            # fp8 weights: per-tensor quantize q_nope (scale slot taken from
            # zero_allocator), then fp8 bmm producing bf16.
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)

        # Back to token-major layout.
        q_nope_out = q_nope_out.transpose(0, 1)

        # Rope is skipped here when the TRT-LLM fp8 decode kernel fuses it.
        if not self._fuse_rope_for_trtllm_mla(forward_batch):
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator

    def forward_absorb_core(
        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
    ):
        """Run MLA attention on absorbed inputs and project back to hidden size.

        Attention call:

        * ``fa3`` / ``flashinfer`` / ``cutlass_mla`` / ``trtllm_mla``: the nope
          and rope parts go to ``attn_mqa`` separately. The rotary cache is also
          passed when rope is fused for TRT-LLM MLA fp8 decode.
        * Other backends: ``attn_mqa`` receives the concatenated q / k.

        Output projection: the ``kv_lora_rank``-wide per-head attention output
        is multiplied by ``self.w_vc`` (DeepGEMM fp8, ROCm bf16, fp8 bmm, or
        plain bmm). It is then flattened to ``num_local_heads * v_head_dim``
        and passed through ``o_proj``.

        Returns:
            The ``o_proj`` output tensor.
        """
        if (
            self.current_attention_backend == "fa3"
            or self.current_attention_backend == "flashinfer"
            or self.current_attention_backend == "cutlass_mla"
            or self.current_attention_backend == "trtllm_mla"
        ):
            extra_args = {}
            if self._fuse_rope_for_trtllm_mla(forward_batch):
                # forward_absorb_prepare skipped rope; the kernel applies it.
                extra_args = {
                    "cos_sin_cache": self.rotary_emb.cos_sin_cache,
                    "is_neox": self.rotary_emb.is_neox_style,
                }
            # k_nope doubles as the value: attn_mqa was built with
            # v_head_dim=kv_lora_rank.
            attn_output = self.attn_mqa(
                q_nope_out,
                k_nope,
                k_nope,
                forward_batch,
                q_rope=q_pe,
                k_rope=k_pe,
                **extra_args,
            )
        else:
            q = torch.cat([q_nope_out, q_pe], dim=-1)
            k = torch.cat([k_nope, k_pe], dim=-1)
            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # Apply w_vc per head (heads-major batched matmul), then flatten heads.
        if self.use_deep_gemm_bmm:
            # DeepGEMM masked fp8 grouped GEMM. The buffer is padded to aligned_m
            # rows and trimmed to expected_m.
            attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
                per_token_group_quant_mla_deep_gemm_masked_fp8(
                    attn_output.transpose(0, 1)
                )
            )
            attn_bmm_output = attn_output.new_empty(
                (self.num_local_heads, aligned_m, self.v_head_dim)
            )
            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                (attn_output_val, attn_output_scale),
                (self.w_vc, self.w_scale_v),
                attn_bmm_output,
                masked_m,
                expected_m,
            )
            attn_bmm_output = (
                attn_bmm_output[:, :expected_m, :].transpose(0, 1).flatten(1, 2)
            )
        elif _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            # ROCm: upcast to bf16 and fold w_scale into the weight.
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            # fp8 weights: per-tensor quantize (scale slot from zero_allocator),
            # then fp8 bmm producing bf16.
            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                attn_output.transpose(0, 1),
                zero_allocator.allocate(1),
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        else:
            # Write the bmm result straight into the flattened output buffer
            # through a transposed view, avoiding a separate transpose/flatten copy.
            attn_bmm_output = torch.empty(
                (attn_output.shape[0], self.num_local_heads * self.v_head_dim),
                dtype=attn_output.dtype,
                device=attn_output.device,
            )
            torch.bmm(
                attn_output.transpose(0, 1),
                self.w_vc,
                out=attn_bmm_output.view(
                    -1, self.num_local_heads, self.v_head_dim
                ).transpose(0, 1),
            )
        output, _ = self.o_proj(attn_bmm_output)

        return output

1365
    def forward_absorb_fused_mla_rope_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare inputs for the fused MLA decode kernel (weight-absorbed path).

        Builds the absorbed query ``q_input`` of shape
        ``[q_len, num_local_heads, kv_lora_rank + qk_rope_head_dim]``. It also
        builds the latent key ``k_input``: the normalized latent followed by the
        rope part. ``k_input`` is saved into the KV pool, and the metadata the
        fused kernel needs is collected.

        When the env var ``SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION`` is "1" (the
        default), RoPE is not applied here. The unrotated rope parts are passed
        through, and ``forward_absorb_fused_mla_rope_core`` runs the kernel with
        ``use_rope=True``. It then writes ``k_pe_output`` back into the cache.

        Args:
            positions: Token positions passed to the rotary embedding.
            hidden_states: Layer input; ``shape[0]`` is used as ``q_len``.
            forward_batch: Batch providing the attention backend metadata and
                the KV pool.
            zero_allocator: Supplies one scratch entry for fp8 per-tensor quant.

        Returns:
            A 16-tuple whose order matches the parameter list of
            ``forward_absorb_fused_mla_rope_core``.
        """
        enable_rope_fusion = (
            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
        )
        q_len = hidden_states.shape[0]
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Absorb w_kc into the nope part of the query. The bmm is batched over
        # heads (after transpose(0, 1)) and the result fills the first
        # kv_lora_rank channels of q_input.
        if _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        # k_input is a view of latent_cache, so these writes also update
        # latent_cache in place.
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input

        if enable_rope_fusion:
            # RoPE is deferred to the decode kernel. Pass the unrotated q_pe and
            # give the kernel a buffer for the rope part of the key.
            q_input[..., self.kv_lora_rank :] = q_pe
            k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])
        else:
            k_pe = k_input[..., self.kv_lora_rank :]
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
            q_input[..., self.kv_lora_rank :] = q_pe
            k_input[..., self.kv_lora_rank :] = k_pe
            k_pe_output = None

        # Output buffer for the fused decode kernel. In this path the kernel
        # replaces the self.attn_mqa call.
        attn_output = torch.empty(
            (q_len, self.num_local_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )
        attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
            forward_batch.attn_backend.forward_metadata
        )
        cos_sin_cache = self.rotary_emb.cos_sin_cache
        num_kv_split = forward_batch.attn_backend.num_kv_splits
        sm_scale = self.attn_mqa.scaling
        if attn_logits is None:
            # The backend did not preallocate the split-KV scratch buffer.
            attn_logits = torch.empty(
                (
                    forward_batch.batch_size,
                    self.num_local_heads,
                    num_kv_split,
                    self.kv_lora_rank + 1,
                ),
                dtype=torch.float32,
                device=q.device,
            )

        # Save the current latent cache. With RoPE fusion the rope part is still
        # unrotated here; forward_absorb_fused_mla_rope_core re-saves it after
        # the kernel has produced k_pe_output.
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mqa, forward_batch.out_cache_loc, k_input, None
        )
        key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mqa.layer_id
        )
        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

        return (
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            enable_rope_fusion,
            k_input,
            forward_batch,
            zero_allocator,
        )

1479
1480
1481
1482
1483
1484
1485
    def forward_absorb_fused_mla_rope_cpu_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Run the fused CPU QKV-projection + RoPE kernel (Intel AMX backend).

        A single ``torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight`` call
        produces the attention inputs from ``hidden_states``.

        Args:
            positions: Token positions used with ``rotary_emb.cos_sin_cache``.
            hidden_states: Layer input.
            forward_batch: Passed through unchanged.
            zero_allocator: Passed through unchanged.

        Returns:
            ``(q_input, k_input, v_input, forward_batch, zero_allocator)``. The
            order matches the parameter list of
            ``forward_absorb_fused_mla_rope_cpu_core``.

        Raises:
            AssertionError: If ``q_lora_rank`` is None or the AMX backend is not
                in use.
        """
        assert self.q_lora_rank is not None and use_intel_amx_backend(
            self
        ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"

        # Weight scales for the two fused projections. int8 weights carry
        # `weight_scale`, fp8 weights carry `weight_scale_inv`, and unquantized
        # weights pass None.
        if self.qkv_proj_with_rope_is_int8:
            fused_qkv_a_scale = self.fused_qkv_a_proj_with_mqa.weight_scale
            q_b_scale = self.q_b_proj.weight_scale
        elif self.qkv_proj_with_rope_is_fp8:
            fused_qkv_a_scale = self.fused_qkv_a_proj_with_mqa.weight_scale_inv
            q_b_scale = self.q_b_proj.weight_scale_inv
        else:
            fused_qkv_a_scale = None
            q_b_scale = None

        q_input, k_input, v_input = (
            torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight(
                hidden_states,
                self.fused_qkv_a_proj_with_mqa.weight,
                self.q_b_proj.weight,
                self.w_kc,
                self.q_a_layernorm.weight,
                self.kv_a_layernorm.weight,
                positions,
                self.rotary_emb.cos_sin_cache,
                self.kv_a_layernorm.variance_epsilon,
                self.qkv_proj_with_rope_is_int8,
                self.qkv_proj_with_rope_is_fp8,
                fused_qkv_a_scale,
                q_b_scale,
                True,  # is_vnni
                self.weight_block_size,
                self.q_lora_rank,
                self.kv_lora_rank,
                self.qk_rope_head_dim,
            )
        )
        return (q_input, k_input, v_input, forward_batch, zero_allocator)

1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
    def forward_absorb_fused_mla_rope_core(
        self,
        q_input,
        key_cache_buf,
        val_cache_buf,
        attn_output,
        kv_indptr,
        kv_indices,
        k_pe_output,
        cos_sin_cache,
        positions,
        attn_logits,
        num_kv_split,
        sm_scale,
        enable_rope_fusion,
        k_input,
        forward_batch,
        zero_allocator,
    ):
        """Run fused MLA decode attention, then the value/output projections.

        The arguments are, in order, the tuple returned by
        ``forward_absorb_fused_mla_rope_prepare``.

        Steps:
            1. Run ``decode_attention_fwd_grouped_rope`` into ``attn_output``.
            2. If RoPE fusion is enabled, copy ``k_pe_output`` into the rope part
               of ``k_input`` and re-save the latent cache.
            3. Multiply the latent attention output by ``w_vc``, head-batched.
            4. Apply ``o_proj``.

        Returns:
            The output of ``self.o_proj``.
        """
        decode_attention_fwd_grouped_rope(
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            self.kv_lora_rank,
            self.rotary_emb.rotary_dim,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            logit_cap=self.attn_mqa.logit_cap,
            use_rope=enable_rope_fusion,
            is_neox_style=self.rotary_emb.is_neox_style,
        )

        if enable_rope_fusion:
            # prepare() saved the cache entry with an unrotated rope part.
            # Overwrite that entry with the kernel-produced k_pe_output.
            k_input[..., self.kv_lora_rank :] = k_pe_output
            forward_batch.token_to_kv_pool.set_kv_buffer(
                self.attn_mqa, forward_batch.out_cache_loc, k_input, None
            )

        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # Map the latent attention output through w_vc, batched over heads
        # after transpose(0, 1).
        if _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                attn_output.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
        attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        output, _ = self.o_proj(attn_output)

        return output

1603
1604
1605
    def forward_absorb_fused_mla_rope_cpu_core(
        self, q_input, k_input, v_input, forward_batch, zero_allocator
    ):
        """Run MLA attention and the value/output projections on CPU (Intel AMX).

        The arguments are the tuple returned by
        ``forward_absorb_fused_mla_rope_cpu_prepare``. ``zero_allocator`` is not
        used here.

        Returns:
            The output of ``self.o_proj``.

        Raises:
            AssertionError: If ``q_lora_rank`` is None or the AMX backend is not
                in use.
        """
        assert self.q_lora_rank is not None and use_intel_amx_backend(
            self
        ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"

        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # [Note] Align shapes of bmm inputs.
        # Shapes of inputs:
        #   attn_output: [M, B, K]   (B = heads, K = kv_lora_rank)
        #   self.w_vc (already converted by PackWeightMethod): [B, N, K]
        # Shapes of inputs to sgl_kernel bmm_cpu:
        #   out: [B, M, N]
        #   mat1: [B, M, K]
        #   mat2: [B, N, K]
        B = self.w_vc.size(0)
        N = self.w_vc.size(1)
        M = attn_output.size(0)
        output = torch.empty(
            [M, int(B * N)], dtype=attn_output.dtype, device=attn_output.device
        )
        # Write the bmm result straight into `output`, laid out as [M, B*N],
        # through a transposed [B, M, N] view.
        attn_bmm_output = output.view([M, B, N]).transpose_(0, 1)
        torch.ops.sgl_kernel.bmm_cpu(
            attn_bmm_output,
            attn_output.transpose(0, 1),
            self.w_vc,
            True,  # is_vnni
            None,  # scale
        )
        attn_output = output
        output, _ = self.o_proj(attn_output)

        return output

1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
    def _chunked_prefix_attn_mha(
        self,
        q: torch.Tensor,
        accum_output: torch.Tensor,
        accum_lse: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Attend ``q`` over the cached prefix chunk by chunk, merging as it goes.

        For each prefix chunk prepared on ``forward_batch``:
            1. Gather the chunk's latent cache rows.
            2. Expand them into per-head K/V with ``kv_b_proj``.
            3. Run ``attn_mha`` without saving to the KV cache.
            4. Merge the result into the running ``(accum_output, accum_lse)``
               with ``merge_state_v2``.

        Args:
            q: Queries for the extended tokens.
            accum_output: Attention output accumulated so far.
            accum_lse: Running lse paired with ``accum_output``. Presumably the
                log-sum-exp, in the same (transposed, contiguous) layout that
                this method produces per chunk.
            forward_batch: Must already have chunked-prefix info prepared
                (``num_prefix_chunks`` not None).

        Returns:
            The merged attention output after all prefix chunks.
        """
        assert forward_batch.num_prefix_chunks is not None

        # The latent buffer for this layer is the same for every chunk; only
        # the gathered indices change. Fetch it once.
        latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mha.layer_id
        )
        for i in range(forward_batch.num_prefix_chunks):
            forward_batch.set_prefix_chunk_idx(i)

            # Gather this chunk's latent cache with the precomputed chunked kv indices.
            latent_cache = latent_cache_buf[
                forward_batch.prefix_chunk_kv_indices[i]
            ].contiguous()

            kv_a_normed, k_pe = latent_cache.split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )
            kv_a_normed = kv_a_normed.squeeze(1).contiguous()
            kv = self.kv_b_proj(kv_a_normed)[0]
            kv = kv.view(
                -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
            )
            v = kv[..., self.qk_nope_head_dim :]
            k_nope = kv[..., : self.qk_nope_head_dim]

            k = torch.empty(
                (
                    k_nope.shape[0],
                    self.num_local_heads,
                    self.qk_nope_head_dim + self.qk_rope_head_dim,
                ),
                dtype=v.dtype,
                device=v.device,
            )
            k[..., : self.qk_nope_head_dim] = k_nope
            # k_pe is assigned into every head slot, so it is shared across heads.
            k[..., self.qk_nope_head_dim :] = k_pe

            output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
            lse = torch.transpose(lse, 0, 1).contiguous()
            tmp_output = torch.empty_like(accum_output)
            tmp_lse = torch.empty_like(accum_lse)
            merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
            accum_output, accum_lse = tmp_output, tmp_lse

        return accum_output

1692
    def forward_normal_chunked_kv_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Build full-MHA Q/K/V for the new tokens and save their latent cache.

        In normal MHA, the K and V tensors become very large when the prefix is
        long. To avoid this, the prefix KV cache is split into chunks and
        processed one after another in ``forward_normal_chunked_kv_core``.
        Because MHA is compute friendly, that loop adds little overhead. The
        top comments in
        https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
        explain the purpose of this scheme.

        This step covers only the extended (new) tokens:
            1. Project ``hidden_states`` to ``q`` and the latent cache.
            2. Expand the normalized latent into per-head ``k_nope``/``v``.
            3. Apply RoPE to ``q_pe``/``k_pe``.
            4. Write the latent cache (normalized ``kv_a`` plus rotated
               ``k_pe``) into the KV pool.

        ``zero_allocator`` is not used here. It is presumably kept for
        signature parity with the other prepare methods.

        Returns:
            ``(q, k, v, forward_batch)``. The order matches the parameter list
            of ``forward_normal_chunked_kv_core``.
        """
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]

        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a)
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope = kv[..., : self.qk_nope_head_dim]
        v = kv[..., self.qk_nope_head_dim :]
        k_pe = latent_cache[:, :, self.kv_lora_rank :]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe

        # Store the normalized latent and the rotated k_pe in the cache layout.
        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
        latent_cache[:, :, self.kv_lora_rank :] = k_pe

        # Save latent cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )

        return q, k, v, forward_batch

    def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
        """Run MHA over the new tokens, then merge in chunked-prefix attention.

        Steps:
            1. Attend without the prefix cache, and without saving KV.
            2. If any sequence in the batch has a prefix, prepare the chunk info
               once and fold the prefix attention in via
               ``_chunked_prefix_attn_mha``.
            3. Apply ``o_proj``.

        Args:
            q, k, v: Tensors from ``forward_normal_chunked_kv_prepare``.
            forward_batch: Current batch; its attend-prefix flag is toggled here.

        Returns:
            The output of ``self.o_proj``.
        """
        # Do mha for extended part without prefix
        forward_batch.set_attn_attend_prefix_cache(False)
        attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        lse = torch.transpose(lse, 0, 1).contiguous()

        # Do mha attention with chunked prefix cache if there are any sequence with prefix
        if any(forward_batch.extend_prefix_lens_cpu):
            # Only initialize the info once
            if forward_batch.num_prefix_chunks is None:
                forward_batch.prepare_chunked_prefix_cache_info(q.device)

            forward_batch.set_attn_attend_prefix_cache(True)
            attn_output = self._chunked_prefix_attn_mha(
                q=q,
                accum_output=attn_output,
                accum_lse=lse,
                forward_batch=forward_batch,
            )

        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output

1767

Liangsheng Yin's avatar
Liangsheng Yin committed
1768
1769
1770
1771
1772
1773
1774
class DeepseekV2DecoderLayer(nn.Module):
    """One DeepSeek-V2 decoder block: MLA self-attention followed by an MLP.

    The MLP is a ``DeepseekV2MoE`` on sparse layers (see ``_is_layer_sparse``)
    and a dense ``DeepseekV2MLP`` otherwise. Layernorms and the communication
    around attention and the MLP are delegated to ``LayerCommunicator``.
    """

    # Largest batch size (rows of input_ids) for which the MLP all-reduce may be
    # fused with the next layer's residual RMSNorm.
    _FUSE_ALLREDUCE_MAX_BATCH_SIZE = 128

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        is_nextn: bool = False,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
        self.layer_id = layer_id
        self.is_nextn = is_nextn
        self.self_attn = DeepseekV2AttentionMLA(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            layer_id=layer_id,
            reduce_results=False,
            prefix=add_prefix("self_attn", prefix),
            alt_stream=alt_stream,
        )

        self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
        is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=1 if is_nextn else config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=is_previous_layer_sparse,
        )

        if self.is_layer_sparse:
            self.mlp = DeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                layer_id=self.layer_id,
                alt_stream=alt_stream,
                is_nextn=is_nextn,
            )
        else:
            if enable_moe_dense_fully_dp():
                # Dense MLP is not sharded: each rank holds the full weights.
                mlp_tp_rank, mlp_tp_size = 0, 1
            else:
                mlp_tp_rank, mlp_tp_size = None, None
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                tp_rank=mlp_tp_rank,
                tp_size=mlp_tp_size,
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
            allow_reduce_scatter=True,
        )

        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()

    def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
        """Return True if ``layer_id`` uses the MoE MLP.

        NextN layers are always sparse. Otherwise a layer is sparse when the
        model has routed experts, the layer is at or past
        ``first_k_dense_replace``, and it falls on a ``moe_layer_freq`` stride.
        """
        return is_nextn or (
            self.config.n_routed_experts is not None
            and layer_id >= self.config.first_k_dense_replace
            and layer_id % self.config.moe_layer_freq == 0
        )

    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
        # The number of rows in input_ids is used as the batch size. A missing
        # or None input_ids counts as an empty batch.
        input_ids = getattr(forward_batch, "input_ids", None)
        batch_size = input_ids.shape[0] if input_ids is not None else 0

        if batch_size > self._FUSE_ALLREDUCE_MAX_BATCH_SIZE:
            return False

        return self._fuse_allreduce_lookup_table.get(batch_size, False)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention + MLP for this layer.

        Returns:
            ``(hidden_states, residual)``. When the MLP all-reduce is fused
            with the next layer, ``postprocess_layer`` is skipped and
            ``hidden_states`` is tagged with ``_sglang_needs_allreduce_fusion``.
        """
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )

        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )

        should_allreduce_fusion = (
            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
            and not (
                is_dp_attention_enabled() and self.speculative_algorithm.is_eagle()
            )
            and not self.is_nextn
        )

        # For DP with padding, reduce scatter can be used instead of all-reduce.
        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
            forward_batch
        )
        hidden_states = self.mlp(
            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
        )

        if should_allreduce_fusion:
            # Defer the all-reduce. The flag is presumably read downstream to
            # fuse it with the next layer's residual RMSNorm.
            hidden_states._sglang_needs_allreduce_fusion = True
        else:
            hidden_states, residual = self.layer_communicator.postprocess_layer(
                hidden_states, residual, forward_batch
            )

        return hidden_states, residual

    def op_comm_prepare_attn(
        self,
        state,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
        tbo_subbatch_index: Optional[int] = None,
    ):
        """Staged op: pre-attention communication/layernorm.

        Stores the results, plus the per-call context later ops read, on
        ``state``.
        """
        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
        )
        state.update(
            dict(
                forward_batch=forward_batch,
                positions=positions,
                zero_allocator=zero_allocator,
                tbo_subbatch_index=tbo_subbatch_index,
            )
        )

    def op_comm_prepare_mlp(self, state):
        """Staged op: pre-MLP communication/layernorm on the attention output."""
        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
            self.layer_communicator.prepare_mlp(
                state.pop("hidden_states_after_attn"),
                state.pop("residual_after_input_ln"),
                state.forward_batch,
            )
        )

    def op_mlp(self, state):
        """Staged op: run the MLP on ``state.hidden_states_mlp_input``."""
        hidden_states = state.pop("hidden_states_mlp_input")
        # With fully data-parallel dense layers, an empty local batch skips the
        # dense MLP and passes through unchanged.
        if not (
            enable_moe_dense_fully_dp()
            and (not self.is_layer_sparse)
            and hidden_states.shape[0] == 0
        ):
            state.hidden_states_mlp_output = self.mlp(
                hidden_states, state.forward_batch
            )
        else:
            state.hidden_states_mlp_output = hidden_states

    def op_comm_postprocess_layer(self, state):
        """Staged op: post-layer communication; returns inputs for the next layer.

        Clears ``state`` except for the per-call context keys.
        """
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            state.pop("hidden_states_mlp_output"),
            state.pop("residual_after_comm_pre_mlp"),
            state.forward_batch,
        )

        output = dict(
            positions=state.positions,
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=state.forward_batch,
            zero_allocator=state.zero_allocator,
            tbo_subbatch_index=state.tbo_subbatch_index,
        )

        state.clear(
            expect_keys={
                "positions",
                "forward_batch",
                "zero_allocator",
                "tbo_subbatch_index",
            }
        )
        return output

    def _build_fuse_allreduce_lookup_table(self):
        """Precompute, per batch size, whether MLP all-reduce fusion applies.

        Returns an empty dict when any static precondition fails:
            - this is the last layer,
            - TP world size is not > 1,
            - ``enable_flashinfer_allreduce_fusion`` is off,
            - SM100 support is missing,
            - flashinfer is unavailable.

        Otherwise the dict maps every batch size in
        ``[0, _FUSE_ALLREDUCE_MAX_BATCH_SIZE]`` to ``batch_size > 0``.
        """
        static_conditions_met = (
            self.layer_id != self.config.num_hidden_layers - 1
            and get_tensor_model_parallel_world_size() > 1
            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
            and _is_sm100_supported
            and _is_flashinfer_available
        )

        if not static_conditions_met:
            return {}

        # The last layer is already excluded above, so only an empty batch is
        # ineligible here.
        return {
            batch_size: batch_size > 0
            for batch_size in range(self._FUSE_ALLREDUCE_MAX_BATCH_SIZE + 1)
        }

Liangsheng Yin's avatar
Liangsheng Yin committed
2018
2019
2020
2021
2022
2023
2024
2025

class DeepseekV2Model(nn.Module):
    """DeepSeek-V2 decoder stack: token embedding, decoder layers, final RMSNorm."""

    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.first_k_dense_replace = config.first_k_dense_replace

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            enable_tp=not is_dp_attention_enabled(),
        )
        # One side stream shared by every layer (CUDA only).
        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
        self.layers = nn.ModuleList(
            [
                DeepseekV2DecoderLayer(
                    config,
                    layer_id,
                    quant_config=quant_config,
                    prefix=add_prefix(f"layers.{layer_id}", prefix),
                    alt_stream=self.alt_stream,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def get_input_embeddings(self) -> VocabParallelEmbedding:
        """Return the token embedding module."""
        return self.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run all decoder layers and the final norm.

        Args:
            input_ids: Token ids. Ignored when ``input_embeds`` is given.
            positions: Token positions passed to every layer.
            forward_batch: Current batch.
            input_embeds: Optional precomputed embeddings that replace
                ``embed_tokens(input_ids)``.

        Returns:
            Final hidden states. The norm is skipped in idle forward mode.
        """
        total_num_layers = len(self.layers)
        device = input_embeds.device if input_embeds is not None else input_ids.device
        # Scratch buffer sized at two entries per layer, doubled when
        # two-batch overlap (TBO) can run.
        zero_allocator = BumpAllocator(
            buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
            dtype=torch.float32,
            device=device,
        )

        if input_embeds is None:
            hidden_states = self.embed_tokens(input_ids)
        else:
            hidden_states = input_embeds

        residual = None

        # With TBO, only the first `first_k_dense_replace` layers run here one
        # by one; the remaining layers go through model_forward_maybe_tbo.
        normal_num_layers = (
            self.first_k_dense_replace
            if forward_batch.can_run_tbo
            else total_num_layers
        )
        for i in range(normal_num_layers):
            with get_global_expert_distribution_recorder().with_current_layer(i):
                layer = self.layers[i]
                hidden_states, residual = layer(
                    positions, hidden_states, forward_batch, residual, zero_allocator
                )

        if normal_num_layers != total_num_layers:
            hidden_states, residual = model_forward_maybe_tbo(
                layers=self.layers[normal_num_layers:],
                enable_tbo=True,
                positions=positions,
                forward_batch=forward_batch,
                hidden_states=hidden_states,
                residual=residual,
                # NOTE(review): assumes normal_num_layers >= 1. With
                # first_k_dense_replace == 0 this wraps to layers[-1] (the last
                # layer) -- confirm that configuration is never used with TBO.
                input_data_scatter_mode=self.layers[
                    normal_num_layers - 1
                ].layer_scatter_modes.layer_output_mode,
                zero_allocator=zero_allocator,
            )

        if not forward_batch.forward_mode.is_idle():
            if residual is None:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekV2ForCausalLM(nn.Module):
    """DeepSeek-V2 causal LM: ``DeepseekV2Model`` + parallel LM head + logits processor.

    Besides the forward pass, this class owns checkpoint loading
    (``load_weights``) and the post-load rewriting of the MLA ``kv_b_proj``
    weights into per-head ``w_kc`` / ``w_vc`` tensors (``post_load_weights``).
    """

    # for quark model load
    # NOTE(review): this class-level dict is mutated in place by ``__init__``
    # (not copied per instance), so it is shared by every instance and by
    # subclasses. Presumably the model loader hands this very object to the
    # quantization config before the model is constructed, which would make the
    # in-place update load-bearing — confirm before turning it into a copy.
    packed_modules_mapping = {}

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the decoder stack, LM head and logits processor.

        Args:
            config: HF model config (reads ``q_lora_rank``, ``vocab_size``,
                ``hidden_size`` here and more in submodules).
            quant_config: Optional quantization config forwarded to submodules.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()

        # for quark model load
        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
        self.fuse_qkv_a_proj = (
            hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
        )
        if self.fuse_qkv_a_proj:
            self.packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
                "q_a_proj",
                "kv_a_proj_with_mqa",
            ]

        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        # Must run before DeepseekV2Model is built: it decides (and may globally
        # disable) shared-experts fusion.
        self.determine_num_fused_shared_experts()
        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)

        # Computed on first access: {layer index: mlp.get_moe_weights()} for every
        # decoder layer whose MLP is a DeepseekV2MoE.
        self._routed_experts_weights_of_layer = LazyValue(
            lambda: {
                layer_id: layer.mlp.get_moe_weights()
                for layer_id, layer in enumerate(self.model.layers)
                if isinstance(layer.mlp, DeepseekV2MoE)
            }
        )

    @property
    def routed_experts_weights_of_layer(self):
        """Lazily-built map from layer index to that layer's MoE weights."""
        return self._routed_experts_weights_of_layer.value

    def determine_num_fused_shared_experts(
        self, architecture: str = "DeepseekV3ForCausalLM"
    ):
        """Decide whether the shared expert is fused into the routed-expert MoE.

        Sets ``self.num_fused_shared_experts`` to ``config.n_shared_experts``
        when fusion is allowed, otherwise leaves it at 0. When fusion is not
        applicable it also sets
        ``global_server_args_dict["disable_shared_experts_fusion"] = True``
        (a process-wide side effect) and logs the reason on rank 0.

        Args:
            architecture: The only ``config.architectures[0]`` value for which
                fusion may be enabled.
        """
        self.num_fused_shared_experts = 0
        if global_server_args_dict["disable_shared_experts_fusion"]:
            return

        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
        disable_reason = None
        if (
            not _is_cuda
            or torch.cuda.get_device_capability("cuda") < (8, 0)
            or self.config.architectures[0] != architecture
            or self.config.n_routed_experts != 256
            or self.config.n_shared_experts != 1
        ):
            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
        elif get_moe_expert_parallel_world_size() > 1:
            disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."

        if disable_reason is not None:
            global_server_args_dict["disable_shared_experts_fusion"] = True
            log_info_on_rank0(
                logger,
                f"{disable_reason} Shared experts fusion optimization is disabled.",
            )
            return

        self.num_fused_shared_experts = self.config.n_shared_experts

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding module of the decoder stack."""
        return self.model.embed_tokens

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run the decoder stack and turn the final hidden states into logits.

        ``input_embeds``, when given, is passed to the model in place of
        looking up ``input_ids``. Runs without autograd.
        """
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)

        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def post_load_weights(self, is_nextn=False, weight_names=None):
        """Post-process MLA ``kv_b_proj`` weights after checkpoint loading.

        For each selected layer, the ``kv_b_proj`` weight is turned into a
        bf16 or per-tensor-quantized matrix (or kept block-quantized for the
        DeepGEMM bmm path), split per head into the K part (``w_kc``) and the
        V part (``w_vc``), and bound onto the attention module together with
        any scales. Afterwards block-fp8 weights may be requantized with UE8M0
        scales for DeepGEMM.

        Args:
            is_nextn: Only process the next-n (MTP) decoder ``model.decoder``.
            weight_names: Names that were just loaded. When given, only layers
                with a ``kv_b_proj`` entry among them are processed (layer index
                parsed from the third dot-separated field). When None, every
                layer in ``range(num_hidden_layers)`` is processed.
        """
        # Perform post-processing after loading weights
        if is_nextn:
            layer_ids = [self.config.num_hidden_layers]
        else:
            if weight_names is None:
                layer_ids = range(self.config.num_hidden_layers)
            else:
                layer_ids = set()
                for name in weight_names:
                    if "kv_b_proj" in name:
                        layer_id = int(name.split(".")[2])
                        if layer_id < self.config.num_hidden_layers:
                            layer_ids.add(layer_id)

        for layer_id in layer_ids:
            self_attn = (
                self.model.layers[layer_id].self_attn
                if not is_nextn
                else self.model.decoder.self_attn
            )
            if hasattr(self_attn.kv_b_proj, "qweight"):
                # AWQ compatible
                if _is_cuda or _is_hip:
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                    ).T
                else:
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                        0,
                        0,
                        0,
                    ).T
            else:
                w = self_attn.kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
            # This may affect the accuracy of fp8 model.
            # Fix deepseek v3 blockwise bmm by using deep_gemm
            use_deep_gemm_bmm = False

            if w.dtype in (
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
            ):
                if (
                    hasattr(self.quant_config, "weight_block_size")
                    and self.quant_config.weight_block_size is not None
                ):
                    # Block-wise fp8: scales live in ``weight_scale_inv``.
                    weight_block_size = self.quant_config.weight_block_size
                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                    if _is_fp8_fnuz:
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                            weight=w,
                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                            input_scale=None,
                        )
                    else:
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv

                    if (
                        _is_cuda
                        and weight_block_size[0] == 128
                        and weight_block_size[1] == 128
                    ):
                        if (
                            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
                            and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL
                            and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false")
                        ):
                            # Keep the block-quantized weight; its block scales
                            # are split per head further below.
                            block_scale = weight_scale
                            use_deep_gemm_bmm = True
                        else:
                            w = block_quant_dequant(
                                weight,
                                weight_scale,
                                weight_block_size,
                                torch.bfloat16,
                            )
                    else:
                        w, scale = block_quant_to_tensor_quant(
                            weight, weight_scale, weight_block_size
                        )
                        self_attn.w_scale = scale
                else:
                    # Channel-wise fp8: scales live in ``weight_scale``.
                    if _is_fp8_fnuz:
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                            weight=w,
                            weight_scale=self_attn.kv_b_proj.weight_scale,
                            input_scale=None,
                        )
                    else:
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale

                    w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                    self_attn.w_scale = scale

            if w.dtype == torch.int8:
                if hasattr(self.quant_config, "weight_block_size"):
                    # block-wise int8 need it
                    # NOTE(review): if weight_block_size is None here, neither
                    # dequant path runs and ``w`` stays int8 — confirm intended.
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
                        w = int8_block_dequant(
                            weight, weight_scale, weight_block_size
                        ).to(torch.bfloat16)
                else:
                    # channel-wise int8 need it
                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
                        torch.bfloat16
                    )

            # Rows of ``w`` are grouped per head: qk_nope_head_dim K rows followed
            # by v_head_dim V rows. Split them into per-head K and V matrices.
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            if not use_deep_gemm_bmm:
                # Same logical view as w_kc, but materialized with dims 1/2
                # swapped in memory (presumably the layout the bmm kernel wants).
                self_attn.w_kc = bind_or_assign(
                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                )
                self_attn.w_vc = bind_or_assign(
                    self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
                )
                if (
                    hasattr(self_attn.kv_b_proj, "weight_scale")
                    and self_attn.w_scale is None
                ):
                    self_attn.w_scale = bind_or_assign(
                        self_attn.w_scale, self_attn.kv_b_proj.weight_scale
                    )
                    if _is_hip:
                        self_attn.w_scale *= 2.0
                # TODO: remove this after adding FP8 support in bmm cpu kernel
                if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn:
                    self_attn.w_kc = (
                        self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale
                    )
                    self_attn.w_vc = (
                        self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale
                    )
            else:
                # DeepGEMM bmm path: split the block scales per head the same way
                # as the weight rows (K tiles followed by V tiles).
                num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
                num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
                ws_kc, ws_vc = block_scale.unflatten(
                    0, (-1, (num_tiles_k + num_tiles_n))
                ).split([num_tiles_k, num_tiles_n], dim=1)
                self_attn.w_scale_k = bind_or_assign(
                    self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()
                )
                self_attn.w_scale_v = bind_or_assign(
                    self_attn.w_scale_v, ws_vc.contiguous()
                )
                self_attn.w_kc = bind_or_assign(
                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous()
                )
                self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
                self_attn.use_deep_gemm_bmm = True

        if (
            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
            and hasattr(self.quant_config, "weight_block_size")
            and self.quant_config.weight_block_size is not None
        ):
            self._weight_requant_ue8m0(is_nextn)

    def _weight_requant_ue8m0(self, is_nextn=False):
        """Requantize block-quantized weights in place with UE8M0 scales.

        Covers the attention projections of every layer (or only the MTP
        decoder when ``is_nextn``), the shared experts and DeepEP expert
        weights of MoE layers, and the dense MLP of non-MoE layers.
        """
        weight_block_size = self.quant_config.weight_block_size

        # Layers treated as MoE here: first_k_dense_replace, then every
        # moe_layer_freq-th layer. A range gives O(1) membership tests.
        moe_layers = range(
            self.config.first_k_dense_replace,
            self.config.num_hidden_layers,
            self.config.moe_layer_freq,
        )

        num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
        for layer_id in range(num_hidden_layers):
            if is_nextn:
                layer = self.model.decoder
            else:
                layer = self.model.layers[layer_id]

            for module in [
                layer.self_attn.fused_qkv_a_proj_with_mqa,
                layer.self_attn.q_b_proj,
                layer.self_attn.kv_b_proj,
                layer.self_attn.o_proj,
            ]:
                requant_weight_ue8m0_inplace(
                    module.weight, module.weight_scale_inv, weight_block_size
                )

            if layer_id in moe_layers or is_nextn:
                shared_experts = getattr(layer.mlp, "shared_experts", None)
                if shared_experts is not None:
                    for module in [
                        shared_experts.gate_up_proj,
                        shared_experts.down_proj,
                    ]:
                        requant_weight_ue8m0_inplace(
                            module.weight, module.weight_scale_inv, weight_block_size
                        )

                experts = layer.mlp.experts
                if isinstance(experts, DeepEPMoE):
                    # Each entry is indexed as (weight, scale).
                    for w in [
                        experts.w13_weight_fp8,
                        experts.w2_weight_fp8,
                    ]:
                        requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
            else:
                mlp = layer.mlp
                assert isinstance(mlp, DeepseekV2MLP)
                for module in [
                    mlp.gate_up_proj,
                    mlp.down_proj,
                ]:
                    requant_weight_ue8m0_inplace(
                        module.weight, module.weight_scale_inv, weight_block_size
                    )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
        """Load checkpoint tensors into this model's parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.
            is_nextn: Load only the next-n (MTP) layer. Names under
                ``model.layers.<nextn_layer_id>`` are remapped to ``model``
                (for the nextn-specific weights) or ``model.decoder``; all
                other names are skipped.

        Raises:
            ValueError: ``is_nextn`` is set but the config has no
                ``num_nextn_predict_layers``.

        Tensor copies are dispatched to a thread pool; an exception raised by
        any weight loader is re-raised when its future is collected. Finally
        ``post_load_weights`` runs on the layers whose ``kv_b_proj`` was loaded.
        """
        if is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
                # compatible with old design
                nextn_layer_id = (
                    0
                    if self.config.num_hidden_layers == 1
                    else self.config.num_hidden_layers
                )
            else:
                raise ValueError("num_nextn_predict_layers is not in the config")

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
        )
        if self.quant_config and self.quant_config.get_name() == "w4afp8":
            expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
                num_experts=self.config.n_routed_experts
            )

        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None.
        # Each half is buffered in ``cached_a_proj`` until its partner arrives.
        fuse_qkv_a_proj = self.fuse_qkv_a_proj
        cached_a_proj = {} if fuse_qkv_a_proj else None
        # Concatenation axis for a fused pair (loop-invariant): awq, awq_marlin
        # and moe_wna16 checkpoints are concatenated along dim 1, others dim 0.
        a_proj_cat_dim = (
            1
            if self.quant_config is not None
            and self.quant_config.get_name() in ("awq", "awq_marlin", "moe_wna16")
            else 0
        )

        if is_nextn:
            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
            nextn_spec_weight_names = [
                "shared_head.norm",
                "eh_proj",
                "enorm",
                "hnorm",
            ]

        if self.num_fused_shared_experts > 0:
            assert self.num_fused_shared_experts == 1
            log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = []
            params_dict = dict(self.named_parameters())
            weight_names = []
            for name, loaded_weight in weights:
                if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
                    # The fused shared expert occupies routed-expert slot
                    # index ``n_routed_experts``.
                    name = name.replace(
                        "mlp.shared_experts",
                        f"mlp.experts.{self.config.n_routed_experts}",
                    )

                weight_names.append(name)

                if not is_nextn:
                    # Main model: skip MTP layers stored at index >= num_hidden_layers.
                    if hasattr(self.config, "num_nextn_predict_layers"):
                        num_nextn_layers = self.config.num_nextn_predict_layers
                        if num_nextn_layers > 0 and name.startswith("model.layers"):
                            name_list = name.split(".")
                            if (
                                len(name_list) >= 3
                                and int(name_list[2]) >= self.config.num_hidden_layers
                            ):
                                continue
                else:
                    if not name.startswith(nextn_layer_prefix):
                        continue

                    # Use shared head and embed weights from target model
                    if "shared_head.head" in name or "embed_tokens" in name:
                        continue

                    is_decoder = True
                    # For nextn specific weights
                    for weight_name in nextn_spec_weight_names:
                        if weight_name in name:
                            name = name.replace(nextn_layer_prefix, "model")
                            is_decoder = False
                            break
                    # For decoder layer weights
                    if is_decoder:
                        name = name.replace(nextn_layer_prefix, "model.decoder")

                if "rotary_emb.inv_freq" in name:
                    continue
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    # Skip non-stacked layers and experts (experts handled below).
                    if weight_name not in name:
                        continue
                    # We have mlp.experts[0].gate_proj in the checkpoint.
                    # Since we handle the experts below in expert_params_mapping,
                    # we need to skip here BEFORE we update the name, otherwise
                    # name will be updated to mlp.experts[0].gate_up_proj, which
                    # will then be updated below in expert_params_mapping
                    # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                    if ("mlp.experts." in name) and name not in params_dict:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    futures.append(
                        executor.submit(weight_loader, param, loaded_weight, shard_id)
                    )
                    break
                else:
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        futures.append(
                            executor.submit(
                                weight_loader,
                                param,
                                loaded_weight,
                                name,
                                shard_id=shard_id,
                                expert_id=expert_id,
                            )
                        )
                        break
                    else:
                        # Skip loading extra bias for GPTQ models.
                        if name.endswith(".bias") and name not in params_dict:
                            continue
                        if fuse_qkv_a_proj and (
                            "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                        ):
                            cached_a_proj[name] = loaded_weight
                            q_a_proj_name = (
                                name
                                if "q_a_proj" in name
                                else name.replace("kv_a_proj_with_mqa", "q_a_proj")
                            )
                            kv_a_proj_name = (
                                name
                                if "kv_a_proj_with_mqa" in name
                                else name.replace("q_a_proj", "kv_a_proj_with_mqa")
                            )

                            # When both q_a_proj and kv_a_proj_with_mqa have been cached, load the fused weight to parameter
                            if (
                                q_a_proj_name in cached_a_proj
                                and kv_a_proj_name in cached_a_proj
                            ):
                                q_a_proj_weight = cached_a_proj.pop(q_a_proj_name)
                                kv_a_proj_weight = cached_a_proj.pop(kv_a_proj_name)
                                fused_weight = torch.cat(
                                    [q_a_proj_weight, kv_a_proj_weight],
                                    dim=a_proj_cat_dim,
                                )
                                param_name = (
                                    name.replace(
                                        "q_a_proj", "fused_qkv_a_proj_with_mqa"
                                    )
                                    if "q_a_proj" in name
                                    else name.replace(
                                        "kv_a_proj_with_mqa",
                                        "fused_qkv_a_proj_with_mqa",
                                    )
                                )
                                param = params_dict[param_name]

                                weight_loader = getattr(
                                    param, "weight_loader", default_weight_loader
                                )
                                futures.append(
                                    executor.submit(weight_loader, param, fused_weight)
                                )
                        else:
                            if (
                                "k_scale" in name or "v_scale" in name
                            ) and name not in params_dict:
                                # modelopt attn kv scale is named differently
                                for scale in ["k_scale", "v_scale"]:
                                    if scale in name:
                                        name = name.replace(
                                            f"{scale[0]}_proj", "attn_mqa"
                                        )
                                        break
                            if name not in params_dict:
                                # modelopt ckpt contains not needed weights for MTP module:
                                # model.decoder.self_attn.attn_mqa.v_scale and
                                # model.decoder.self_attn.attn_mqa.k_scale
                                logger.warning(f"{name} not found in params_dict.")
                                continue
                            param = params_dict[name]
                            weight_loader = getattr(
                                param, "weight_loader", default_weight_loader
                            )
                            futures.append(
                                executor.submit(weight_loader, param, loaded_weight)
                            )

            # Wait for all tasks to complete and raise any exceptions.
            for future in concurrent.futures.as_completed(futures):
                future.result()

        if cached_a_proj:
            # A q_a_proj / kv_a_proj_with_mqa half arrived without its partner,
            # so its fused parameter was never written. Previously this was
            # dropped silently.
            logger.warning(
                "Unpaired q_a_proj/kv_a_proj_with_mqa weights were not loaded: %s",
                sorted(cached_a_proj),
            )

        self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)

    def get_embed_and_head(self):
        """Return ``(embedding weight, lm_head weight)``."""
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        """Replace the embedding and lm_head weights with the given tensors.

        The old weights are deleted first, then the CUDA cache is emptied and
        the device synchronized.
        """
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    @classmethod
    def get_model_config_for_expert_location(cls, config):
        """Describe layer/expert/group counts for expert-location planning."""
        return ModelConfigForExpertLocation(
            num_layers=config.num_hidden_layers,
            num_logical_experts=config.n_routed_experts,
            num_groups=config.n_group,
        )

Liangsheng Yin's avatar
Liangsheng Yin committed
2686

HandH1998's avatar
HandH1998 committed
2687
2688
2689
2690
2691
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek-V3 / R1 entry point; reuses the V2 implementation unchanged."""


# Model classes this module exports to the model registry.
EntryClass = [
    DeepseekV2ForCausalLM,
    DeepseekV3ForCausalLM,
]