deepseek_v2.py 116 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14

Liangsheng Yin's avatar
Liangsheng Yin committed
15
16
17
# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
18

19
import concurrent.futures
20
import logging
21
import os
22
from enum import IntEnum, auto
23
from typing import Any, Dict, Iterable, Optional, Tuple, Union
Liangsheng Yin's avatar
Liangsheng Yin committed
24
25

import torch
Ke Bao's avatar
Ke Bao committed
26
import torch.nn.functional as F
Liangsheng Yin's avatar
Liangsheng Yin committed
27
from torch import nn
28
from tqdm import tqdm
Liangsheng Yin's avatar
Liangsheng Yin committed
29
from transformers import PretrainedConfig
30
31

from sglang.srt.distributed import (
32
    get_moe_expert_parallel_world_size,
33
    get_pp_group,
Liangsheng Yin's avatar
Liangsheng Yin committed
34
    get_tensor_model_parallel_world_size,
35
    parallel_state,
Liangsheng Yin's avatar
Liangsheng Yin committed
36
37
    tensor_model_parallel_all_reduce,
)
38
39
40
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
fzyzcjy's avatar
fzyzcjy committed
41
42
43
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
44
from sglang.srt.layers.activation import SiluAndMul
45
from sglang.srt.layers.amx_utils import PackWeightMethod
46
47
48
49
50
from sglang.srt.layers.communicator import (
    LayerCommunicator,
    LayerScatterModes,
    enable_moe_dense_fully_dp,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
51
52
53
from sglang.srt.layers.dp_attention import (
    get_attention_tp_rank,
    get_attention_tp_size,
54
    is_dp_attention_enabled,
Lianmin Zheng's avatar
Lianmin Zheng committed
55
)
56
from sglang.srt.layers.layernorm import RMSNorm
57
58
59
60
61
62
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
63
from sglang.srt.layers.logits_processor import LogitsProcessor
64
65
66
67
from sglang.srt.layers.moe import (
    get_deepep_mode,
    get_moe_a2a_backend,
    should_use_flashinfer_cutlass_moe_fp4_allgather,
68
    should_use_flashinfer_trtllm_moe,
69
)
70
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
71
72
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
73
from sglang.srt.layers.quantization import deep_gemm_wrapper
74
from sglang.srt.layers.quantization.base_config import QuantizationConfig
75
from sglang.srt.layers.quantization.fp8_kernel import (
76
    is_fp8_fnuz,
77
    per_tensor_quant_mla_fp8,
78
    per_token_group_quant_mla_deep_gemm_masked_fp8,
79
)
HandH1998's avatar
HandH1998 committed
80
from sglang.srt.layers.quantization.fp8_utils import (
81
    block_quant_dequant,
HandH1998's avatar
HandH1998 committed
82
    block_quant_to_tensor_quant,
83
    channel_quant_to_tensor_quant,
84
    normalize_e4m3fn_to_e4m3fnuz,
85
    requant_weight_ue8m0_inplace,
HandH1998's avatar
HandH1998 committed
86
)
87
88
89
from sglang.srt.layers.quantization.int8_utils import (
    block_dequant as int8_block_dequant,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
90
from sglang.srt.layers.radix_attention import RadixAttention
91
92
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
93
94
95
96
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
97
from sglang.srt.managers.schedule_batch import global_server_args_dict
98
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
99
from sglang.srt.model_loader.weight_utils import default_weight_loader
100
101
102
103
from sglang.srt.two_batch_overlap import (
    MaybeTboDeepEPDispatcher,
    model_forward_maybe_tbo,
)
104
105
from sglang.srt.utils import (
    BumpAllocator,
106
    LazyValue,
107
    add_prefix,
108
    bind_or_assign,
109
    cpu_has_amx_support,
110
    get_bool_env_var,
111
    get_device_sm,
112
    get_int_env_var,
113
    is_cpu,
114
    is_cuda,
115
    is_flashinfer_available,
116
    is_gfx95_supported,
117
    is_hip,
118
    is_non_idle_and_non_empty,
119
    is_npu,
120
    is_sm100_supported,
121
    log_info_on_rank0,
122
    make_layers,
123
    use_intel_amx_backend,
124
)
125

126
# ---------------------------------------------------------------------------
# Platform / backend feature flags, evaluated once at import time.
# ---------------------------------------------------------------------------
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
# aiter kernels are opt-in via SGLANG_USE_AITER and only considered on HIP.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_device_sm = get_device_sm()
_is_gfx95_supported = is_gfx95_supported()

# The gfx95-specific aiter helpers are imported only when both the aiter
# switch is on and the device reports gfx95 support.
_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
if _use_aiter_gfx95:
    from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
    from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
        batched_gemm_afp4wfp4_pre_quant,
        fused_flatten_mxfp4_quant,
        fused_rms_mxfp4_quant,
    )
    from sglang.srt.layers.rocm_linear_utils import (
        aiter_dsv3_router_gemm,
        fused_qk_rope_cat,
        get_dsv3_gemm_output_zero_allocator_size,
    )

# Platform-specific kernel imports. Branch order matters: the AMX-capable CPU
# branch is checked before the HIP branch, so it wins if both flags are set.
if _is_cuda:
    from sgl_kernel import (
        awq_dequantize,
        bmm_fp8,
        concat_mla_k,
        dsv3_fused_a_gemm,
        dsv3_router_gemm,
        merge_state_v2,
    )
elif _is_cpu and _is_cpu_amx_available:
    # Nothing extra is imported here for the AMX CPU path.
    pass
elif _is_hip:
    from sglang.srt.layers.quantization.awq_triton import (
        awq_dequantize_triton as awq_dequantize,
    )

# Independent of the chain above: the ROCm grouped-rope decode kernel is
# imported whenever running on HIP.
if _is_hip:
    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
        decode_attention_fwd_grouped_rope,
    )

_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()

logger = logging.getLogger(__name__)

180
181
182
183
184
185
186
187
188
189
190
191
192
193
# Names of attention backends treated as "forward absorb core attention"
# backends (the code consuming this list lives elsewhere in the file). It is a
# mutable list on purpose: callers can extend it at runtime through
# add_forward_absorb_core_attention_backend().
FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
    "fa3",
    "flashinfer",
    "cutlass_mla",
    "trtllm_mla",
    "ascend",
]


def add_forward_absorb_core_attention_backend(backend_name):
    """Append *backend_name* to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.

    Registering a name that is already present is a silent no-op; an info log
    line is written only when the list actually grows.
    """
    already_registered = backend_name in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS
    if already_registered:
        return
    FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.append(backend_name)
    logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")

Liangsheng Yin's avatar
Liangsheng Yin committed
194

195
196
197
198
199
200
201
202
203
204
205
class AttnForwardMethod(IntEnum):
    """Attention computation strategies that can be selected per forward batch.

    Values are assigned by ``auto()`` in declaration order (1..5), so
    reordering the members changes their integer values.
    """

    MHA = auto()  # plain multi-head attention
    MLA = auto()  # absorbed multi-latent attention
    # Multi-head attention with the KV cache processed in chunks; this avoids
    # OOM when prefix lengths are long.
    MHA_CHUNKED_KV = auto()
    MLA_FUSED_ROPE = auto()  # MLA with fused RoPE
    MLA_FUSED_ROPE_CPU = auto()  # MLA with a fused-RoPE kernel for CPU

212

213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
def _dispatch_mla_subtype(attn, forward_batch):
    """Pick the concrete MLA variant once MLA has been chosen.

    On HIP, use the fused-RoPE variant when ``attn.rocm_fused_decode_mla`` is
    set and the batch is in decode mode. On other platforms, use the CPU
    fused-RoPE variant when ``attn`` has a ``fused_qkv_a_proj_with_mqa``
    attribute and runs on the Intel AMX backend. Everything else gets plain
    MLA.
    """
    if _is_hip:
        fused_decode = (
            attn.rocm_fused_decode_mla and forward_batch.forward_mode.is_decode()
        )
        return (
            AttnForwardMethod.MLA_FUSED_ROPE
            if fused_decode
            else AttnForwardMethod.MLA
        )

    if hasattr(attn, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(attn):
        return AttnForwardMethod.MLA_FUSED_ROPE_CPU
    return AttnForwardMethod.MLA


class BackendRegistry:
    """Class-level mapping from attention-backend name to a handler function.

    The handlers in this module take ``(attn, forward_batch)`` and return an
    ``AttnForwardMethod``. Looking up an unregistered name falls back to the
    handler registered as ``"triton"``. If no ``"triton"`` handler exists, the
    lookup returns ``None``.
    """

    _handlers = {}

    @classmethod
    def register(cls, backend_name, handler_func):
        """Register, or replace, the handler for *backend_name*."""
        cls._handlers[backend_name] = handler_func

    @classmethod
    def get_handler(cls, backend_name):
        """Return the handler for *backend_name*, defaulting to the triton one."""
        registry = cls._handlers
        if backend_name in registry:
            return registry[backend_name]
        return registry.get("triton")


def handle_ascend(attn, forward_batch):
    """Select the attention method for the Ascend backend.

    Extend batches that are neither target-verify nor draft-extend use MHA.
    Every other forward mode uses MLA. The mode test is delegated to
    ``_is_extend_without_speculative`` so it stays identical to the other
    handlers instead of being re-implemented here.
    """
    if _is_extend_without_speculative(forward_batch):
        return AttnForwardMethod.MHA
    return AttnForwardMethod.MLA


def _get_sum_extend_prefix_lens(forward_batch):
    """Return ``sum(forward_batch.extend_prefix_lens_cpu)``, or 0 when it is None.

    The entries are presumably per-request prefix lengths for the extend
    batch; a missing list is treated as a total of zero.
    """
    prefix_lens = forward_batch.extend_prefix_lens_cpu
    if prefix_lens is None:
        return 0
    return sum(prefix_lens)


def _is_extend_without_speculative(forward_batch):
    """True for an extend batch that is neither target-verify nor draft-extend."""
    mode = forward_batch.forward_mode
    return mode.is_extend() and not (
        mode.is_target_verify() or mode.is_draft_extend()
    )


def _handle_backend(attn, forward_batch, backend_name):
    """Shared method selection for the flashinfer/fa3/flashmla/cutlass_mla handlers.

    Returns ``MHA_CHUNKED_KV`` for a non-speculative extend batch when ragged
    prefill is not disabled and either:

    * the total prefix length is 0, or
    * the total reaches ``attn.chunked_prefix_cache_threshold`` and the
      chunked prefix cache is not disabled.

    ``attn.flashinfer_mla_disable_ragged`` is consulted only for the
    "flashinfer" and "flashmla" backends. Every other case is delegated to
    ``_dispatch_mla_subtype``.
    """
    prefix_total = _get_sum_extend_prefix_lens(forward_batch)
    ragged_disabled = (
        backend_name in ("flashinfer", "flashmla")
        and attn.flashinfer_mla_disable_ragged
    )

    if ragged_disabled or not _is_extend_without_speculative(forward_batch):
        return _dispatch_mla_subtype(attn, forward_batch)

    long_prefix_with_chunking = (
        prefix_total >= attn.chunked_prefix_cache_threshold
        and not attn.disable_chunked_prefix_cache
    )
    if long_prefix_with_chunking or prefix_total == 0:
        return AttnForwardMethod.MHA_CHUNKED_KV
    return _dispatch_mla_subtype(attn, forward_batch)


def handle_flashinfer(attn, forward_batch):
    """Attention method selection for the flashinfer backend."""
    return _handle_backend(attn, forward_batch, backend_name="flashinfer")


def handle_fa3(attn, forward_batch):
    """Attention method selection for the FlashAttention-3 backend."""
    return _handle_backend(attn, forward_batch, backend_name="fa3")


def handle_flashmla(attn, forward_batch):
    """Attention method selection for the FlashMLA backend."""
    return _handle_backend(attn, forward_batch, backend_name="flashmla")


def handle_cutlass_mla(attn, forward_batch):
    """Attention method selection for the CUTLASS MLA backend."""
    return _handle_backend(attn, forward_batch, backend_name="cutlass_mla")


def handle_fa4(attn, forward_batch):
    """Attention method selection for FA4: always chunked-KV MHA.

    Both arguments are accepted only for handler-signature compatibility and
    are ignored.
    """
    # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
    chosen = AttnForwardMethod.MHA_CHUNKED_KV
    return chosen


def handle_trtllm_mla(attn, forward_batch):
    """Attention method selection for the TRT-LLM MLA backend.

    A non-speculative extend batch uses ``MHA_CHUNKED_KV`` unless the chunked
    prefix cache is disabled AND the batch has a non-zero total prefix
    length. Unlike ``_handle_backend``, no prefix-length threshold applies.
    All other cases are delegated to ``_dispatch_mla_subtype``.
    """
    prefix_total = _get_sum_extend_prefix_lens(forward_batch)
    if not _is_extend_without_speculative(forward_batch):
        return _dispatch_mla_subtype(attn, forward_batch)
    if attn.disable_chunked_prefix_cache and prefix_total != 0:
        return _dispatch_mla_subtype(attn, forward_batch)
    return AttnForwardMethod.MHA_CHUNKED_KV


def handle_aiter(attn, forward_batch):
    """Attention method selection for the ROCm aiter backend.

    Non-speculative extend batches use MHA. The exception is when DP
    attention is enabled and the summed ``extend_prefix_lens_cpu`` is
    non-zero; that case uses MLA. All other forward modes use MLA.
    """
    if not _is_extend_without_speculative(forward_batch):
        return AttnForwardMethod.MLA
    if not is_dp_attention_enabled():
        return AttnForwardMethod.MHA
    has_prefix = sum(forward_batch.extend_prefix_lens_cpu) != 0
    return AttnForwardMethod.MLA if has_prefix else AttnForwardMethod.MHA


def handle_triton(attn, forward_batch):
    """Attention method selection for the triton backend.

    MHA is used only for a non-speculative extend batch whose summed
    ``extend_prefix_lens_cpu`` is zero. Everything else goes through
    ``_dispatch_mla_subtype``.
    """
    extend_without_prefix = (
        _is_extend_without_speculative(forward_batch)
        and sum(forward_batch.extend_prefix_lens_cpu) == 0
    )
    if extend_without_prefix:
        return AttnForwardMethod.MHA
    return _dispatch_mla_subtype(attn, forward_batch)


Liangsheng Yin's avatar
Liangsheng Yin committed
341
342
343
344
345
346
347
348
class DeepseekV2MLP(nn.Module):
    """Feed-forward block: merged gate/up projection -> SiluAndMul -> down projection.

    DeepseekV2MoE uses it for its shared experts. When ``tp_rank`` and
    ``tp_size`` are given they are forwarded to both linear layers; when they
    are ``None`` the linear layers use their own defaults.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.tp_size = tp_size

        # Settings shared by both projections.
        common = dict(
            bias=False,
            quant_config=quant_config,
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        # gate and up projections are computed by a single merged linear layer.
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            prefix=add_prefix("gate_up_proj", prefix),
            **common,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            reduce_results=reduce_results,
            prefix=add_prefix("down_proj", prefix),
            **common,
        )
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(
        self,
        x,
        forward_batch=None,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
        gemm_output_zero_allocator: BumpAllocator = None,
    ):
        """Run gate_up_proj -> act_fn -> down_proj on ``x``.

        ``forward_batch`` is accepted but unused. The all-reduce inside
        ``down_proj`` is skipped when ``should_allreduce_fusion`` or
        ``use_reduce_scatter`` is set.
        """
        num_tokens = x.shape[0]
        # An empty batch short-circuits only when tp_size == 1. With other
        # tp_size values the projections still run (presumably so that
        # down_proj's collective stays matched across ranks).
        if self.tp_size == 1 and num_tokens == 0:
            return x

        # For batches of at most 256 tokens with uint8 gate_up weights
        # (presumably packed low-precision weights), pass gate_up_proj an
        # output buffer taken from gemm_output_zero_allocator, bundled as the
        # tuple (x, None, buffer).
        if (
            gemm_output_zero_allocator is not None
            and num_tokens <= 256
            and self.gate_up_proj.weight.dtype == torch.uint8
        ):
            width = self.gate_up_proj.output_size_per_partition
            out_buf = gemm_output_zero_allocator.allocate(num_tokens * width).view(
                num_tokens, width
            )
            x = (x, None, out_buf)

        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        skip_reduce = should_allreduce_fusion or use_reduce_scatter
        output, _ = self.down_proj(activated, skip_all_reduce=skip_reduce)
        return output


Ke Bao's avatar
Ke Bao committed
411
class MoEGate(nn.Module):
    """Router that produces per-expert logits for DeepseekV2MoE.

    ``weight`` has shape (n_routed_experts, hidden_size). When
    ``config.topk_method == "noaux_tc"``, an ``e_score_correction_bias``
    parameter of length n_routed_experts is also created. Its dtype is
    bfloat16 only for a ``modelopt_fp4`` quant config with the flashinfer
    TRT-LLM MoE path enabled, and float32 otherwise. For other top-k methods
    the bias is ``None``.
    """

    def __init__(
        self,
        config,
        quant_config,
        prefix: str = "",
        is_nextn: bool = False,
    ):
        super().__init__()
        self.is_nextn = is_nextn
        num_experts = config.n_routed_experts
        self.weight = nn.Parameter(torch.empty((num_experts, config.hidden_size)))

        if config.topk_method != "noaux_tc":
            self.e_score_correction_bias = None
        else:
            bf16_bias = (
                quant_config is not None
                and quant_config.get_name() == "modelopt_fp4"
                and should_use_flashinfer_trtllm_moe()
            )
            bias_dtype = torch.bfloat16 if bf16_bias else torch.float32
            self.e_score_correction_bias = nn.Parameter(
                torch.empty(num_experts, dtype=bias_dtype)
            )

        if _is_cpu and _is_cpu_amx_available:
            self.quant_method = PackWeightMethod(weight_names=["weight"])

    def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
        """Return router logits for ``hidden_states``.

        Kernel choice, in priority order:

        1. the Intel AMX packed kernel;
        2. ``dsv3_router_gemm`` (float32 output) on CUDA SM90+ for at most 16
           tokens, hidden size 7168 and 256 experts;
        3. the aiter gfx95 router GEMM for at most 256 tokens;
        4. plain ``F.linear``.
        """
        if use_intel_amx_backend(self):
            return torch.ops.sgl_kernel.weight_packed_linear(
                hidden_states,
                self.weight,
                None,  # bias
                True,  # is_vnni
            )

        num_tokens = hidden_states.shape[0]
        # NOTE: for some unknown reason, router_gemm seems to degrade accept length.
        # shape[1] is read only after the cheaper checks pass.
        use_router_gemm = (
            _is_cuda
            and num_tokens <= 16
            and hidden_states.shape[1] == 7168
            and self.weight.shape[0] == 256
            and _device_sm >= 90
        )
        if use_router_gemm:
            # router gemm output float32
            return dsv3_router_gemm(hidden_states, self.weight, out_dtype=torch.float32)
        if _use_aiter_gfx95 and num_tokens <= 256:
            return aiter_dsv3_router_gemm(
                hidden_states, self.weight, gemm_output_zero_allocator
            )
        return F.linear(hidden_states, self.weight, None)


Liangsheng Yin's avatar
Liangsheng Yin committed
471
472
473
474
475
class DeepseekV2MoE(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
        is_nextn: bool = False,
    ):
        """Build the gate, routed experts, top-k selector and optional shared experts.

        Raises:
            ValueError: if the TP world size exceeds ``config.n_routed_experts``
                or ``config.hidden_act`` is not "silu".
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        # Unless the server disables shared-expert fusion, the shared experts
        # are counted into the routed-expert pool and into top_k below.
        fusion_disabled = global_server_args_dict["disable_shared_experts_fusion"]
        self.num_fused_shared_experts = (
            0 if fusion_disabled else config.n_shared_experts
        )
        self.config = config
        self.layer_id = layer_id
        self.alt_stream = alt_stream

        num_routed = config.n_routed_experts
        if self.tp_size > num_routed:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}."
            )
        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = MoEGate(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("gate", prefix),
            is_nextn=is_nextn,
        )

        routed_top_k = config.num_experts_per_tok + self.num_fused_shared_experts
        self.experts = get_moe_impl_class(quant_config)(
            num_experts=num_routed
            + self.num_fused_shared_experts
            + global_server_args_dict["ep_num_redundant_experts"],
            num_fused_shared_experts=self.num_fused_shared_experts,
            top_k=routed_top_k,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            layer_id=self.layer_id,
            quant_config=quant_config,
            routed_scaling_factor=self.routed_scaling_factor,
            prefix=add_prefix("experts", prefix),
        )

        # TopK is built after self.experts because it asks the experts whether
        # the routed scaling factor should be applied on the top-k output.
        self.topk = TopK(
            top_k=routed_top_k,
            renormalize=config.norm_topk_prob,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            num_fused_shared_experts=self.num_fused_shared_experts,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            quant_config=quant_config,
            routed_scaling_factor=self.routed_scaling_factor,
            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
            # Some FP4 MoE backends need a bypassed output format, but the MTP
            # layers are unquantized and need the standard one, so the format
            # is derived from whether a quant_config is present.
            output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
        )

        self.shared_experts_is_int8 = False
        self.shared_experts_is_fp8 = False
        self.shared_experts_weight_block_size = None
        if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
            # Shared experts run without TP when DeepEP all-to-all or the
            # flashinfer cutlass FP4 all-gather is in use.
            shared_without_tp = (
                get_moe_a2a_backend().is_deepep()
                or should_use_flashinfer_cutlass_moe_fp4_allgather()
            )
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.moe_intermediate_size
                * config.n_shared_experts,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=add_prefix("shared_experts", prefix),
                **(dict(tp_rank=0, tp_size=1) if shared_without_tp else {}),
            )

            gate_up = self.shared_experts.gate_up_proj
            gate_up_method = gate_up.quant_method
            is_packed_weight = hasattr(
                gate_up_method, "quant_config"
            ) and gate_up_method.quant_config.get_name() in {
                "awq",
                "awq_marlin",
                "moe_wna16",
            }
            # `.weight` is only inspected for non-packed layers; both flags
            # stay False for packed (awq / moe_wna16) weights.
            if not is_packed_weight:
                weight_dtype = gate_up.weight.dtype
                self.shared_experts_is_int8 = weight_dtype == torch.int8
                self.shared_experts_is_fp8 = weight_dtype == torch.float8_e4m3fn
            if self.shared_experts_is_fp8:
                assert (
                    gate_up_method.quant_config.weight_block_size
                    == self.shared_experts.down_proj.quant_method.quant_config.weight_block_size
                )
                self.shared_experts_weight_block_size = (
                    gate_up_method.quant_config.weight_block_size
                )

        self.top_k = config.num_experts_per_tok

        if get_moe_a2a_backend().is_deepep():
            # TODO: we will support tp < ep in the future
            self.ep_size = get_moe_expert_parallel_world_size()
            self.num_experts = (
                config.n_routed_experts
                + global_server_args_dict["ep_num_redundant_experts"]
            )
            self.renormalize = config.norm_topk_prob
            self.topk_group = config.topk_group
            self.num_expert_group = config.n_group
            gate_bias = self.gate.e_score_correction_bias
            self.correction_bias = None if gate_bias is None else gate_bias.data

            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
                group=parallel_state.get_tp_group().device_group,
                router_topk=self.top_k,
                permute_fusion=True,
                num_experts=self.num_experts,
                num_local_experts=config.n_routed_experts // self.tp_size,
                hidden_size=config.hidden_size,
                params_dtype=config.torch_dtype,
                deepep_mode=get_deepep_mode(),
                async_finish=True,
                return_recv_hook=True,
            )

        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()
620

621
622
623
624
625
626
627
    def get_moe_weights(self):
        """Return ``.data`` of every expert parameter except ``correction_bias``."""
        weights = []
        for param_name, param in self.experts.named_parameters():
            if param_name != "correction_bias":
                weights.append(param.data)
        return weights

628
    def forward(
        self,
        hidden_states: torch.Tensor,
        forward_batch: Optional[ForwardBatch] = None,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:
        """Dispatch to the DeepEP, dual-stream, or single-stream MoE path.

        The dual-stream path is used only when an alternate stream exists, no
        shared experts are fused, and the batch has 1..1024 tokens.
        """
        if self._enable_deepep_moe:
            return self.forward_deepep(hidden_states, forward_batch)

        DUAL_STREAM_TOKEN_THRESHOLD = 1024
        use_dual_stream = (
            self.alt_stream is not None
            and self.num_fused_shared_experts == 0
            and 0 < hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
        )
        runner = (
            self.forward_normal_dual_stream if use_dual_stream else self.forward_normal
        )
        return runner(
            hidden_states,
            should_allreduce_fusion,
            use_reduce_scatter,
            gemm_output_zero_allocator,
        )

660
    def forward_normal_dual_stream(
        self,
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:
        """Non-DeepEP MoE forward that overlaps shared and routed experts.

        The shared experts run on the current CUDA stream. Gating, top-k
        selection and the routed experts run on ``self.alt_stream``. The two
        streams are joined before the outputs are summed.

        Args:
            hidden_states: token activations fed to the gate and experts.
            should_allreduce_fusion: skip the TP all-reduce here (presumably
                fused by the caller).
            use_reduce_scatter: also skips the TP all-reduce here.
            gemm_output_zero_allocator: optional allocator forwarded to the
                shared experts and the gate.

        Returns:
            The MoE output. It is all-reduced across TP ranks unless one of
            the skip conditions above applies, or the flashinfer cutlass FP4
            all-gather path is active.
        """
        current_stream = torch.cuda.current_stream()
        # alt_stream must not start reading hidden_states before the work
        # already queued on the current stream has finished.
        self.alt_stream.wait_stream(current_stream)
        shared_output = self._forward_shared_experts(
            hidden_states, gemm_output_zero_allocator
        )

        with torch.cuda.stream(self.alt_stream):
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
            topk_output = self.topk(hidden_states, router_logits)
            final_hidden_states = self.experts(hidden_states, topk_output)
            # Keep in sync with forward_normal: on CUDA and with aiter the
            # routed scaling factor is fused earlier (biased_grouped_topk), so
            # scaling here again would apply it twice.
            if not _is_cuda and not _use_aiter:
                final_hidden_states *= self.routed_scaling_factor

        current_stream.wait_stream(self.alt_stream)

        # Same None guard as forward_normal, in case no shared-expert output
        # was produced.
        if shared_output is not None:
            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
                final_hidden_states_out = torch.empty_like(final_hidden_states)
            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
            final_hidden_states = final_hidden_states_out
            sm.tag(final_hidden_states)

        if (
            self.tp_size > 1
            and not should_allreduce_fusion
            and not use_reduce_scatter
            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
        ):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

698
    def forward_normal(
        self,
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:
        """Single-stream, non-DeepEP MoE forward.

        Delegates to ``forward_cpu`` when the shared experts use the Intel
        AMX backend. An empty batch skips the shared experts and the gate, but
        still calls the routed experts with an empty top-k output.
        """
        amx_shared_experts = hasattr(self, "shared_experts") and use_intel_amx_backend(
            self.shared_experts.gate_up_proj
        )
        if amx_shared_experts:
            return self.forward_cpu(hidden_states, should_allreduce_fusion)

        shared_output = None
        if hidden_states.shape[0] > 0:
            shared_output = self._forward_shared_experts(
                hidden_states, gemm_output_zero_allocator
            )
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
            topk_output = self.topk(hidden_states, router_logits)
        else:
            topk_output = self.topk.empty_topk_output(hidden_states.device)

        final_hidden_states = self.experts(hidden_states, topk_output)
        if not (_is_cuda or _use_aiter):
            # On CUDA / aiter the scaling is fused in biased_grouped_topk, so
            # it is applied here only on the remaining platforms.
            final_hidden_states *= self.routed_scaling_factor

        if shared_output is not None:
            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
                combined = torch.empty_like(final_hidden_states)
            torch.add(final_hidden_states, shared_output, out=combined)
            final_hidden_states = combined
            sm.tag(final_hidden_states)

        needs_all_reduce = (
            self.tp_size > 1
            and not should_allreduce_fusion
            and not use_reduce_scatter
            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
        )
        if needs_all_reduce:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

740
    def forward_cpu(
        self,
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
    ) -> torch.Tensor:
        """MoE forward for the Intel AMX CPU path.

        The routed experts run first. Their output and the shared-expert
        weights are then passed to ``torch.ops.sgl_kernel.shared_expert_cpu``.
        Weight scales and block size are chosen from the int8/fp8 flags set in
        ``__init__``. The result is all-reduced across TP ranks unless
        ``should_allreduce_fusion`` is set.
        """
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        topk_output = self.topk(hidden_states, router_logits)
        fused_experts_out = self.experts(
            hidden_states=hidden_states, topk_output=topk_output
        )

        shared = self.shared_experts
        assert use_intel_amx_backend(shared.gate_up_proj) == use_intel_amx_backend(
            shared.down_proj
        )

        # int8 weights carry `weight_scale`, fp8 weights carry
        # `weight_scale_inv`; unquantized weights have no scales.
        if self.shared_experts_is_int8:
            w1_scale = shared.gate_up_proj.weight_scale
            w2_scale = shared.down_proj.weight_scale
        elif self.shared_experts_is_fp8:
            w1_scale = shared.gate_up_proj.weight_scale_inv
            w2_scale = shared.down_proj.weight_scale_inv
        else:
            w1_scale = w2_scale = None
        block_size = (
            self.shared_experts_weight_block_size if self.shared_experts_is_fp8 else None
        )

        # [Note] inplace should be False in fused_experts (self.experts):
        # if it were True, hidden_states would be modified by fused_experts,
        # while the shared experts below still need the original values.
        final_hidden_states = torch.ops.sgl_kernel.shared_expert_cpu(
            hidden_states,
            shared.gate_up_proj.weight,
            shared.down_proj.weight,
            fused_experts_out,
            self.routed_scaling_factor,
            True,  # inplace
            self.shared_experts_is_int8,  # use_int8_w8a8
            self.shared_experts_is_fp8,  # use_fp8_w8a16
            w1_scale,
            w2_scale,
            block_size,
            None,  # a1_scale
            None,  # a2_scale
            True,  # is_vnni
        )
        if self.tp_size > 1 and not should_allreduce_fusion:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

798
799
800
801
    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """MoE forward for the DeepEP expert-parallel path.

        Non-empty inputs compute router logits, the shared-expert output and the
        top-k routing locally. Empty inputs still join the expert call, with an
        empty top-k result. ``routed_scaling_factor`` is applied here unless the
        experts report it is already fused into top-k.
        """
        shared_output = None
        if hidden_states.shape[0] > 0:
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            shared_output = self._forward_shared_experts(hidden_states)
            topk_weights, topk_idx, _ = self.topk(
                hidden_states,
                router_logits,
                num_token_non_padded=forward_batch.num_token_non_padded,
                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                    layer_id=self.layer_id,
                ),
            )
        else:
            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
                hidden_states.device
            )

        routed = self.experts(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            forward_batch=forward_batch,
        )

        scale_in_topk = self.experts.should_fuse_routed_scaling_factor_in_topk()
        if shared_output is None:
            if not scale_in_topk:
                routed *= self.routed_scaling_factor
            return routed

        # Accumulate the routed output into the shared-expert buffer in place.
        if scale_in_topk:
            shared_output.add_(routed)
        else:
            shared_output.add_(routed, alpha=self.routed_scaling_factor)
        return shared_output

839
840
841
    def _forward_shared_experts(
        self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
    ):
        """Run the standalone shared experts, or return None when they are fused.

        Shared experts are computed here only when ``num_fused_shared_experts``
        is 0. Otherwise this returns None; presumably the routed expert layer
        handles them (not visible here).
        """
        if self.num_fused_shared_experts != 0:
            return None
        return self.shared_experts(
            hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
        )

849
    def op_gate(self, state):
        """Compute router logits for the MLP input into ``state.router_logits``.

        Stores None when the forward mode is idle or the input is empty, as
        decided by ``is_non_idle_and_non_empty``.
        """
        mlp_input = state.hidden_states_mlp_input
        has_tokens = is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, mlp_input
        )
        # router_logits: (num_tokens, n_experts)
        state.router_logits = self.gate(mlp_input) if has_tokens else None
857

858
    def op_shared_experts(self, state):
        """Consume ``hidden_states_mlp_input`` and store ``state.shared_output``.

        The shared experts run only when none are fused
        (``num_fused_shared_experts == 0``) and the batch is non-idle and
        non-empty. Otherwise ``state.shared_output`` is None.
        """
        mlp_input = state.pop("hidden_states_mlp_input")
        run_shared = self.num_fused_shared_experts == 0 and is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, mlp_input
        )
        state.shared_output = self.shared_experts(mlp_input) if run_shared else None
866

867
    def op_select_experts(self, state):
        """Pick top-k experts per token into ``state.topk_{weights,idx}_local``.

        Consumes ``state.router_logits``. When it is None, empty ``(0, top_k)``
        tensors are stored instead: index ``-1`` filled int32 ids and float32
        weights.
        """
        router_logits = state.pop("router_logits")
        hidden_states = state.hidden_states_mlp_input

        if router_logits is None:
            device = hidden_states.device
            state.topk_idx_local = torch.full(
                (0, self.top_k), -1, dtype=torch.int, device=device
            )
            state.topk_weights_local = torch.empty(
                (0, self.top_k), dtype=torch.float32, device=device
            )
            return

        recorder = get_global_expert_distribution_recorder()
        with recorder.with_current_layer(self.layer_id):
            state.topk_weights_local, state.topk_idx_local, _ = self.topk(
                hidden_states=hidden_states,
                router_logits=router_logits,
                num_token_non_padded=state.forward_batch.num_token_non_padded,
                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                    layer_id=self.layer_id,
                ),
            )
890

891
    def op_dispatch_a(self, state):
        """Start the DeepEP dispatch (first half) when ``self.ep_size > 1``.

        Consumes ``topk_idx_local`` and ``topk_weights_local`` from ``state``.
        When expert parallelism is off, nothing is popped or called.
        """
        if self.ep_size <= 1:
            return
        dispatcher = self.experts.deepep_dispatcher
        dispatcher.dispatch_a(
            hidden_states=state.hidden_states_mlp_input,
            topk_idx=state.pop("topk_idx_local"),
            topk_weights=state.pop("topk_weights_local"),
            forward_batch=state.forward_batch,
            tbo_subbatch_index=state.get("tbo_subbatch_index"),
        )
900

901
    def op_dispatch_b(self, state):
        """Finish the DeepEP dispatch into ``state.dispatch_output``.

        Runs under the global expert-distribution recorder scoped to this layer.
        Does nothing when ``self.ep_size <= 1``.
        """
        if self.ep_size <= 1:
            return
        recorder = get_global_expert_distribution_recorder()
        with recorder.with_current_layer(self.layer_id):
            state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )
909
910

    def op_experts(self, state):
        """Run the expert computation on ``state.dispatch_output``.

        The result goes to ``state.hidden_states_experts_output``.
        ``dispatch_output`` is read but not removed; the combine step still
        reads its top-k fields.
        """
        dispatched = state.dispatch_output
        state.hidden_states_experts_output = self.experts.moe_impl(
            dispatch_output=dispatched
        )
914

915
    def op_combine_a(self, state):
        """Start the DeepEP combine (first half) when ``self.ep_size > 1``.

        Consumes ``hidden_states_experts_output`` and, after the call, drops
        ``dispatch_output`` from ``state``. Does nothing when expert
        parallelism is off.
        """
        if self.ep_size <= 1:
            return
        dispatched = state.dispatch_output
        self.experts.deepep_dispatcher.combine_a(
            hidden_states=state.pop("hidden_states_experts_output"),
            topk_idx=dispatched.topk_idx,
            topk_weights=dispatched.topk_weights,
            forward_batch=state.forward_batch,
            tbo_subbatch_index=state.get("tbo_subbatch_index"),
        )
        state.pop("dispatch_output")
925

926
    def op_combine_b(self, state):
        """Finish the DeepEP combine into ``state.hidden_states_after_combine``.

        Does nothing when ``self.ep_size <= 1``.
        """
        if self.ep_size <= 1:
            return
        dispatcher = self.experts.deepep_dispatcher
        state.hidden_states_after_combine = dispatcher.combine_b(
            tbo_subbatch_index=state.get("tbo_subbatch_index"),
        )
933
934

    def op_output(self, state):
        """Merge the combined routed-expert output with the shared-expert output.

        Consumes ``hidden_states_after_combine`` and ``shared_output`` from
        ``state`` and stores the merged tensor in
        ``state.hidden_states_mlp_output``.

        ``routed_scaling_factor`` is applied only when the experts do not already
        fuse it into top-k (``should_fuse_routed_scaling_factor_in_topk()``).
        This matches ``forward_deepep``; applying it unconditionally would scale
        the routed output twice when the factor is fused.
        """
        final_hidden_states = state.pop("hidden_states_after_combine")
        scale_in_topk = self.experts.should_fuse_routed_scaling_factor_in_topk()

        if (shared_output := state.pop("shared_output")) is not None:
            # Accumulate the routed output into the shared-expert buffer in place.
            if scale_in_topk:
                shared_output.add_(final_hidden_states)
            else:
                shared_output.add_(
                    final_hidden_states, alpha=self.routed_scaling_factor
                )
            final_hidden_states = shared_output
        elif not scale_in_topk:
            final_hidden_states *= self.routed_scaling_factor

        state.hidden_states_mlp_output = final_hidden_states
945

Liangsheng Yin's avatar
Liangsheng Yin committed
946
947
948
949
950
951
952
953
954

def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN magnitude scale for a RoPE extension factor.

    Returns exactly 1.0 when ``scale <= 1``; otherwise
    ``0.1 * mscale * ln(scale) + 1.0``.
    """
    import math

    if scale > 1:
        return 0.1 * mscale * math.log(scale) + 1.0
    return 1.0


955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
class DeepseekV2AttentionMLA(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        layer_id: int = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """Build the projections, rotary embedding and attention layers for MLA.

        Heads are sharded across the attention tensor-parallel group, and
        ``num_heads`` must be divisible by its size. With ``q_lora_rank`` set,
        one replicated linear produces the q low-rank input together with the
        KV latent and key rope part. Without it, separate ``q_proj`` and
        ``kv_a_proj_with_mqa`` layers are created.

        NOTE: when ``rope_scaling`` is truthy it is modified in place
        (``rope_type`` is set to ``"deepseek_yarn"``).
        """
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        tp_rank = get_attention_tp_rank()
        tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # Width of the KV latent plus the decoupled rope slice of the key.
        kv_a_width = self.kv_lora_rank + self.qk_rope_head_dim
        q_out_width = self.num_heads * self.qk_head_dim

        # For tensor parallel attention
        if self.q_lora_rank is not None:
            # One replicated GEMM whose output is split as [q_lora | kv_lora + rope].
            self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank + kv_a_width,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("fused_qkv_a_proj_with_mqa", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                q_out_width,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
                tp_rank=tp_rank,
                tp_size=tp_size,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                q_out_width,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
                tp_rank=tp_rank,
                tp_size=tp_size,
            )
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                kv_a_width,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("kv_a_proj_with_mqa", prefix),
            )

        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

        if rope_scaling:
            rope_scaling["rope_type"] = "deepseek_yarn"

        self.rotary_emb = get_rope_wrapper(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
            device=global_server_args_dict["device"],
        )

        if not rope_scaling:
            self.rotary_emb.forward = self.rotary_emb.forward_native
        else:
            # YaRN: fold the squared magnitude correction into the softmax scale.
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        # Absorbed (MQA-style) attention over the compressed latent.
        self.attn_mqa = RadixAttention(
            self.num_local_heads,
            kv_a_width,
            self.scaling,
            num_kv_heads=1,
            layer_id=layer_id,
            v_head_dim=self.kv_lora_rank,
            quant_config=quant_config,
            prefix=add_prefix("attn_mqa", prefix),
        )
        # Plain multi-head attention over fully expanded K/V.
        self.attn_mha = RadixAttention(
            self.num_local_heads,
            self.qk_nope_head_dim + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            v_head_dim=self.v_head_dim,
            quant_config=quant_config,
            prefix=add_prefix("attn_mha", prefix),
        )

        self.alt_stream = alt_stream
        # Filled in lazily by forward_prepare on first use.
        self.attn_mha.kv_b_proj = None

        # Absorbed kv_b weights and their scales; assigned outside __init__.
        self.w_kc = None
        self.w_vc = None
        self.w_scale = 1.0
        self.w_scale_k = None
        self.w_scale_v = None
        self.use_deep_gemm_bmm = False

        args = global_server_args_dict
        self.flashinfer_mla_disable_ragged = args["flashinfer_mla_disable_ragged"]
        self.disable_chunked_prefix_cache = args["disable_chunked_prefix_cache"]

        # Attention backend used by current forward batch
        self.current_attention_backend = None

        self.rocm_fused_decode_mla = get_bool_env_var(
            "SGLANG_ROCM_FUSED_DECODE_MLA", "false"
        )

        # TODO: Design a finer way to determine the threshold
        self.chunked_prefix_cache_threshold = get_int_env_var(
            "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
        )

        # With the fused qkv_a projection on an AMX-capable CPU, the
        # torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel is used and
        # needs w_kc/w_vc packed; otherwise torch.bmm is used on unpacked weights.
        has_fused_proj = hasattr(self, "fused_qkv_a_proj_with_mqa")
        if has_fused_proj and _is_cpu and _is_cpu_amx_available:
            self.quant_method = PackWeightMethod(
                weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
            )

        fused_a = self.fused_qkv_a_proj_with_mqa if has_fused_proj else None

        # AWQ / moe_wna16 store packed weights, so dtype/shape checks below
        # must not be applied to them.
        is_packed_weight = (
            has_fused_proj
            and hasattr(fused_a.quant_method, "quant_config")
            and fused_a.quant_method.quant_config.get_name()
            in {"awq", "awq_marlin", "moe_wna16"}
        )
        self.use_min_latency_fused_a_gemm = (
            has_fused_proj
            and not is_packed_weight
            and fused_a.weight.dtype == torch.bfloat16
            and fused_a.weight.shape[0] == 2112
            and fused_a.weight.shape[1] == 7168
            and _is_cuda
            and _device_sm >= 90
        )
        self.qkv_proj_with_rope_is_int8 = (
            has_fused_proj
            and not is_packed_weight
            and fused_a.weight.dtype == torch.int8
        )
        self.qkv_proj_with_rope_is_fp8 = (
            has_fused_proj
            and not is_packed_weight
            and fused_a.weight.dtype == torch.float8_e4m3fn
        )

        self.weight_block_size = None
        if self.qkv_proj_with_rope_is_fp8 and _is_cpu and _is_cpu_amx_available:
            # The fused qkv_a and q_b projections must share the same fp8
            # block-quantization settings.
            use_block_quant = getattr(fused_a.quant_method, "block_quant", False)
            assert use_block_quant == getattr(
                self.q_b_proj.quant_method, "block_quant", False
            )
            if use_block_quant:
                fused_block_size = fused_a.quant_method.quant_config.weight_block_size
                assert (
                    fused_block_size
                    == self.q_b_proj.quant_method.quant_config.weight_block_size
                )
                self.weight_block_size = fused_block_size

1181
1182
1183
    def dispatch_attn_forward_method(
        self, forward_batch: ForwardBatch
    ) -> AttnForwardMethod:
        """Choose the attention forward method for ``forward_batch``.

        The attention backend comes from the server args, selected by forward
        mode:

        - decode/idle uses the decode backend;
        - target-verify and draft-extend use the decode backend when
          ``speculative_attention_mode == "decode"`` and the prefill backend
          otherwise;
        - everything else uses the prefill backend.

        The chosen backend is recorded on ``self.current_attention_backend``,
        and the handler registered for it picks the method.
        """
        mode = forward_batch.forward_mode
        if mode.is_decode_or_idle():
            prefer_decode = True
        elif mode.is_target_verify() or mode.is_draft_extend():
            # Speculative operations (verify and draft extend) default to prefill.
            prefer_decode = (
                global_server_args_dict["speculative_attention_mode"] == "decode"
            )
        else:
            prefer_decode = False

        backend_key = (
            "decode_attention_backend" if prefer_decode else "prefill_attention_backend"
        )
        attention_backend = global_server_args_dict[backend_key]
        self.current_attention_backend = attention_backend

        handler = BackendRegistry.get_handler(attention_backend)
        return handler(self, forward_batch)
Lianmin Zheng's avatar
Lianmin Zheng committed
1202

1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
    def op_prepare(self, state):
        """Staged-op wrapper around ``forward_prepare``.

        Consumes ``hidden_states_after_comm_pre_attn`` from ``state`` and
        stores the prepared intermediate in ``state.attn_intermediate_state``.
        """
        positions = state.positions
        pre_attn_hidden = state.pop("hidden_states_after_comm_pre_attn")
        state.attn_intermediate_state = self.forward_prepare(
            positions=positions,
            hidden_states=pre_attn_hidden,
            forward_batch=state.forward_batch,
            zero_allocator=state.zero_allocator,
        )

    def op_core(self, state):
        """Staged-op wrapper around ``forward_core``.

        Consumes ``attn_intermediate_state`` and stores the attention output in
        ``state.hidden_states_after_attn``.
        """
        intermediate = state.pop("attn_intermediate_state")
        state.hidden_states_after_attn = self.forward_core(intermediate)

1216
1217
1218
1219
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Run MLA attention in one step: ``forward_prepare`` then ``forward_core``."""
        prepared = self.forward_prepare(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )
        output = self.forward_core(prepared)
        return output

    def forward_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Run the preparation half of MLA attention for ``forward_batch``.

        Always returns the 4-tuple ``(hidden_states, attn_forward_method,
        forward_batch, inner_state)`` that ``forward_core`` unpacks.

        - Empty batch: the first element is the (empty) activation tensor and
          ``inner_state`` is None, so ``forward_core`` returns that tensor
          unchanged.
        - Otherwise: the first element is None and ``inner_state`` is whatever
          the dispatched ``*_prepare`` method returned.

        ``hidden_states`` may be a tuple; its first element is the activation
        tensor.

        Raises:
            NotImplementedError: if the dispatched attention method is unknown.
        """
        # Lazily hand kv_b_proj to the MHA attention layer on first use.
        if self.attn_mha.kv_b_proj is None:
            self.attn_mha.kv_b_proj = self.kv_b_proj

        # When hidden_states is a tuple of tensors, the tuple includes the
        # quantized data and its scale tensor; element 0 carries the token dim.
        if isinstance(hidden_states, tuple):
            activations = hidden_states[0]
        else:
            activations = hidden_states
        if activations.shape[0] == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            # Return the 4-tuple for both input forms. Returning the bare tensor
            # for tuple input would break the unpacking in forward_core.
            return activations, None, forward_batch, None

        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)

        if attn_forward_method == AttnForwardMethod.MHA:
            prepare = self.forward_normal_prepare
        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
            prepare = self.forward_normal_chunked_kv_prepare
        elif attn_forward_method == AttnForwardMethod.MLA:
            prepare = self.forward_absorb_prepare
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
            prepare = self.forward_absorb_fused_mla_rope_prepare
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
            prepare = self.forward_absorb_fused_mla_rope_cpu_prepare
        else:
            raise NotImplementedError
        inner_state = prepare(positions, hidden_states, forward_batch, zero_allocator)
        return None, attn_forward_method, forward_batch, inner_state
1280

1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
    def forward_core(self, intermediate_state):
        """Run the attention core chosen in ``forward_prepare``.

        ``intermediate_state`` is ``(hidden_states, attn_forward_method,
        forward_batch, inner_state)``. When ``inner_state`` is None, the
        batch was short-circuited and ``hidden_states`` is returned as-is.
        Otherwise ``inner_state`` is splatted into the matching ``*_core``
        method.

        Raises:
            NotImplementedError: for an unsupported ``attn_forward_method``.
        """
        hidden_states, attn_forward_method, _forward_batch, inner_state = (
            intermediate_state
        )
        if inner_state is None:
            return hidden_states

        # Map each method to the *name* of its core so lookup stays lazy.
        core_name = {
            AttnForwardMethod.MHA: "forward_normal_core",
            AttnForwardMethod.MHA_CHUNKED_KV: "forward_normal_chunked_kv_core",
            AttnForwardMethod.MLA: "forward_absorb_core",
            AttnForwardMethod.MLA_FUSED_ROPE: "forward_absorb_fused_mla_rope_core",
            AttnForwardMethod.MLA_FUSED_ROPE_CPU: "forward_absorb_fused_mla_rope_cpu_core",
        }.get(attn_forward_method)
        if core_name is None:
            raise NotImplementedError
        return getattr(self, core_name)(*inner_state)

    def forward_normal_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Build full per-head Q/K/V for the MHA path and write the latent KV.

        Steps:

        1. Project Q, via the q-LoRA path when ``q_lora_rank`` is set.
        2. RMS-normalize the KV latent and expand it through ``kv_b_proj`` into
           per-head ``k_nope`` and ``v``.
        3. Rotate the rope slices of Q and K, and assemble K into a tensor
           shaped like Q.
        4. Store the normalized latent and rotated ``k_pe`` in the KV pool.

        ``zero_allocator`` is not used on this path.

        Returns:
            ``(q, k, v, forward_batch)``, the inputs of ``forward_normal_core``.
        """
        nope_dim = self.qk_nope_head_dim
        rope_dim = self.qk_rope_head_dim
        lora_rank = self.kv_lora_rank
        n_heads = self.num_local_heads

        if self.q_lora_rank is None:
            query = self.q_proj(hidden_states)[0].view(-1, n_heads, self.qk_head_dim)
            latent = self.kv_a_proj_with_mqa(hidden_states)[0]
        else:
            fused_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
            q_lowrank, latent = fused_out.split(
                [self.q_lora_rank, lora_rank + rope_dim], dim=-1
            )
            q_lowrank = self.q_a_layernorm(q_lowrank)
            query = self.q_b_proj(q_lowrank)[0].view(-1, n_heads, self.qk_head_dim)

        q_rope = query.split([nope_dim, rope_dim], dim=-1)[1]
        kv_latent = latent.split([lora_rank, rope_dim], dim=-1)[0]
        latent = latent.unsqueeze(1)
        kv_latent = self.kv_a_layernorm(kv_latent)

        kv = self.kv_b_proj(kv_latent)[0].view(-1, n_heads, nope_dim + self.v_head_dim)
        key_nope = kv[..., :nope_dim]
        value = kv[..., nope_dim:]
        key_rope = latent[:, :, lora_rank:]
        q_rope, key_rope = self.rotary_emb(positions, q_rope, key_rope)
        query[..., nope_dim:] = q_rope

        key = torch.empty_like(query)
        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
        use_concat_kernel = (
            _is_cuda and n_heads == 128 and nope_dim == 128 and rope_dim == 64
        )
        if use_concat_kernel:
            concat_mla_k(k=key, k_nope=key_nope, k_rope=key_rope)
        else:
            key[..., :nope_dim] = key_nope
            key[..., nope_dim:] = key_rope

        kv_pool = forward_batch.token_to_kv_pool
        if _is_npu:
            # To reduce a time-costing split operation
            kv_pool.set_kv_buffer(
                self.attn_mha,
                forward_batch.out_cache_loc,
                kv_latent.unsqueeze(1),
                key_rope,
            )
        else:
            latent[:, :, :lora_rank] = kv_latent.unsqueeze(1)
            latent[:, :, lora_rank:] = key_rope
            # Save latent cache
            kv_pool.set_kv_buffer(
                self.attn_mha, forward_batch.out_cache_loc, latent, None
            )

        return query, key, value, forward_batch

    def forward_normal_core(self, q, k, v, forward_batch):
        """Run MHA over full Q/K/V, flatten heads, and apply ``o_proj``.

        The KV cache is not written here (``save_kv_cache=False``). The
        per-head output is reshaped to ``(-1, num_local_heads * v_head_dim)``
        before the output projection.
        """
        attn_out = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        flat = attn_out.reshape(-1, self.num_local_heads * self.v_head_dim)
        projected, _bias = self.o_proj(flat)
        return projected

Faraz's avatar
Faraz committed
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
    def _fuse_rope_for_trtllm_mla(self, forward_batch: ForwardBatch) -> bool:
        """Return whether to skip rope in favor of fused rope+quantize.

        This is true only for TRTLLM MLA decode in the fp8_e4m3 path: the current
        backend is ``"trtllm_mla"``, the forward mode is decode/idle, and the
        attention backend's ``data_type`` is ``torch.float8_e4m3fn``.
        """
        if self.current_attention_backend != "trtllm_mla":
            return False
        in_decode = forward_batch.forward_mode.is_decode_or_idle()
        if not in_decode:
            return in_decode
        return forward_batch.attn_backend.data_type == torch.float8_e4m3fn

1377
    def forward_absorb_prepare(
1378
1379
1380
1381
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
1382
        zero_allocator: BumpAllocator,
1383
    ):
1384
        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
1385

1386
        if self.q_lora_rank is not None:
1387
1388
1389
1390
1391
            if (
                (not isinstance(hidden_states, tuple))
                and hidden_states.shape[0] <= 16
                and self.use_min_latency_fused_a_gemm
            ):
1392
1393
1394
1395
1396
1397
                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                )
            else:
                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
            q, latent_cache = fused_qkv_a_proj_out.split(
1398
1399
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
1400
1401
1402
            k_nope = latent_cache[..., : self.kv_lora_rank]

            # overlap qk norm
1403
            if self.alt_stream is not None and get_is_capture_mode():
1404
1405
1406
1407
1408
1409
1410
                current_stream = torch.cuda.current_stream()
                self.alt_stream.wait_stream(current_stream)
                q = self.q_a_layernorm(q)
                with torch.cuda.stream(self.alt_stream):
                    k_nope = self.kv_a_layernorm(k_nope)
                current_stream.wait_stream(self.alt_stream)
            else:
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
                if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
                    q, k_nope = fused_rms_mxfp4_quant(
                        q,
                        self.q_a_layernorm.weight,
                        self.q_a_layernorm.variance_epsilon,
                        k_nope,
                        self.kv_a_layernorm.weight,
                        self.kv_a_layernorm.variance_epsilon,
                    )
                else:
                    q = self.q_a_layernorm(q)
                    k_nope = self.kv_a_layernorm(k_nope)
1423
1424

            k_nope = k_nope.unsqueeze(1)
1425
1426
1427
1428
1429
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
1430
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
1431
1432
1433
            k_nope = latent_cache[..., : self.kv_lora_rank]
            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)

1434
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
1435
        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
1436

1437
1438
        if self.use_deep_gemm_bmm:
            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
1439
                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
1440
1441
1442
1443
            )
            q_nope_out = q_nope.new_empty(
                (self.num_local_heads, aligned_m, self.kv_lora_rank)
            )
1444
            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
1445
1446
1447
1448
1449
1450
1451
                (q_nope_val, q_nope_scale),
                (self.w_kc, self.w_scale_k),
                q_nope_out,
                masked_m,
                expected_m,
            )
            q_nope_out = q_nope_out[:, :expected_m, :]
1452
1453
        elif _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
            if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
                x = q_nope.transpose(0, 1)
                q_nope_out = torch.empty(
                    x.shape[0],
                    x.shape[1],
                    self.w_kc.shape[2],
                    device=x.device,
                    dtype=torch.bfloat16,
                )
                batched_gemm_afp4wfp4_pre_quant(
                    x,
                    self.w_kc.transpose(-2, -1),
                    self.w_scale_k.transpose(-2, -1),
                    torch.bfloat16,
                    q_nope_out,
                )
            else:
                q_nope_out = torch.bmm(
                    q_nope.to(torch.bfloat16).transpose(0, 1),
                    self.w_kc.to(torch.bfloat16) * self.w_scale,
                )
1475
        elif self.w_kc.dtype == torch.float8_e4m3fn:
1476
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
Lianmin Zheng's avatar
Lianmin Zheng committed
1477
                q_nope.transpose(0, 1),
1478
                zero_allocator.allocate(1),
1479
1480
1481
1482
1483
1484
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
1485
1486

        q_nope_out = q_nope_out.transpose(0, 1)
Faraz's avatar
Faraz committed
1487

1488
1489
1490
        if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
            not _use_aiter or not _is_gfx95_supported
        ):
Faraz's avatar
Faraz committed
1491
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
1492

1493
        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
1494
1495

    def forward_absorb_core(
        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
    ):
        """Attend in the absorbed latent space and project back to hidden size.

        Runs ``attn_mqa`` on the latent query (``q_nope_out`` plus ``q_pe``)
        against ``k_nope`` / ``k_pe``. The per-head latent result is then mapped
        through ``w_vc`` by whichever bmm flavor matches the weights (DeepGEMM
        masked FP8, ROCm, per-tensor FP8 or plain bmm) and fed into ``o_proj``.

        Returns:
            The first element of ``self.o_proj(...)``.
        """
        if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
            # These backends take the nope and rope halves separately. When RoPE
            # is fused into the TRT-LLM MLA kernel, also pass the rotary tables.
            rope_kwargs = (
                {
                    "cos_sin_cache": self.rotary_emb.cos_sin_cache,
                    "is_neox": self.rotary_emb.is_neox_style,
                }
                if self._fuse_rope_for_trtllm_mla(forward_batch)
                else {}
            )
            latent_out = self.attn_mqa(
                q_nope_out,
                k_nope,
                k_nope,
                forward_batch,
                q_rope=q_pe,
                k_rope=k_pe,
                **rope_kwargs,
            )
        else:
            # Remaining backends take the query and key already concatenated as
            # [nope | rope].
            if _use_aiter_gfx95:
                q_full, k_full = fused_qk_rope_cat(
                    q_nope_out,
                    q_pe,
                    k_nope,
                    k_pe,
                    positions,
                    self.rotary_emb.cos_cache,
                    self.rotary_emb.sin_cache,
                    self.rotary_emb.is_neox_style,
                )
            else:
                q_full = torch.cat([q_nope_out, q_pe], dim=-1)
                k_full = torch.cat([k_nope, k_pe], dim=-1)
            latent_out = self.attn_mqa(q_full, k_full, k_nope, forward_batch)

        latent_out = latent_out.view(-1, self.num_local_heads, self.kv_lora_rank)

        # Latent -> value projection through w_vc. Every branch leaves a
        # (tokens, heads * v_head_dim)-shaped input for o_proj in `flat_out`.
        if self.use_deep_gemm_bmm:
            quant_val, quant_scale, masked_m, expected_m, aligned_m = (
                per_token_group_quant_mla_deep_gemm_masked_fp8(
                    latent_out.transpose(0, 1)
                )
            )
            v_out = latent_out.new_empty(
                (self.num_local_heads, aligned_m, self.v_head_dim)
            )
            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                (quant_val, quant_scale),
                (self.w_vc, self.w_scale_v),
                v_out,
                masked_m,
                expected_m,
            )
            # Drop the alignment padding rows before flattening the heads.
            flat_out = v_out[:, :expected_m, :].transpose(0, 1).flatten(1, 2)
        elif _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
                by_head = latent_out.transpose(0, 1)
                v_out = torch.empty(
                    by_head.shape[0],
                    by_head.shape[1],
                    self.w_vc.shape[2],
                    device=by_head.device,
                    dtype=torch.bfloat16,
                )
                batched_gemm_afp4wfp4_pre_quant(
                    by_head,
                    self.w_vc.transpose(-2, -1),
                    self.w_scale_v.transpose(-2, -1),
                    torch.bfloat16,
                    v_out,
                )
            else:
                v_out = torch.bmm(
                    latent_out.to(torch.bfloat16).transpose(0, 1),
                    self.w_vc.to(torch.bfloat16) * self.w_scale,
                )

            v_out = v_out.transpose(0, 1)
            if self.o_proj.weight.dtype == torch.uint8:
                # An MXFP4 o_proj takes input that is flattened and quantized
                # in one fused step.
                flat_out = fused_flatten_mxfp4_quant(v_out)
            else:
                flat_out = v_out.flatten(1, 2)
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            quant_val, quant_scale = per_tensor_quant_mla_fp8(
                latent_out.transpose(0, 1),
                zero_allocator.allocate(1),
            )
            v_out = bmm_fp8(
                quant_val,
                self.w_vc,
                quant_scale,
                self.w_scale,
                torch.bfloat16,
            )
            flat_out = v_out.transpose(0, 1).flatten(1, 2)
        else:
            # The bmm writes straight into a head-major view of the flat
            # (tokens, heads * v_head_dim) buffer, so no transpose/flatten
            # copy is needed afterwards.
            flat_out = torch.empty(
                (latent_out.shape[0], self.num_local_heads * self.v_head_dim),
                dtype=latent_out.dtype,
                device=latent_out.device,
            )
            head_major_out = flat_out.view(
                -1, self.num_local_heads, self.v_head_dim
            ).transpose(0, 1)
            torch.bmm(latent_out.transpose(0, 1), self.w_vc, out=head_major_out)

        output, _ = self.o_proj(flat_out)
        return output

1614
    def forward_absorb_fused_mla_rope_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Build the inputs for ``forward_absorb_fused_mla_rope_core``.

        Projects ``hidden_states`` into two tensors:

        * ``q_input``: the nope part multiplied by ``w_kc`` (absorbed into the
          latent space), followed by the rope part.
        * ``k_input``: the normalized latent, followed by the rope part.

        ``k_input`` is written into the KV pool, and the decode-attention
        metadata the fused kernel needs is collected.

        When ``SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION`` is not ``"1"``, RoPE is
        applied here. Otherwise it is left to the kernel, which receives an
        uninitialized ``k_pe_output`` buffer that it is expected to fill.

        Returns:
            The positional argument tuple consumed by
            ``forward_absorb_fused_mla_rope_core``.
        """
        enable_rope_fusion = (
            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
        )
        q_len = hidden_states.shape[0]
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Absorb w_kc into the query's nope part (head-major bmm).
        if _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        # k_input is a view of latent_cache. Its latent slice is overwritten in
        # place with the normalized value; the rope slice is kept as-is.
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input

        if enable_rope_fusion:
            # The fused kernel rotates k_pe itself and writes it here.
            k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])
        else:
            k_pe = k_input[..., self.kv_lora_rank :]
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
            k_input[..., self.kv_lora_rank :] = k_pe
            k_pe_output = None

        # Single write of the rope part for both paths. Previously the
        # non-fusion path copied the same q_pe into q_input twice.
        q_input[..., self.kv_lora_rank :] = q_pe

        # Output buffer for the fused decode kernel, which replaces a plain
        # self.attn_mqa(q_input, k_input, v_input, forward_batch) call.
        attn_output = torch.empty(
            (q_len, self.num_local_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )
        attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
            forward_batch.attn_backend.forward_metadata
        )
        cos_sin_cache = self.rotary_emb.cos_sin_cache
        num_kv_split = forward_batch.attn_backend.num_kv_splits
        sm_scale = self.attn_mqa.scaling
        if attn_logits is None:
            # The backend did not provide a logits buffer, so allocate one
            # (last dim holds kv_lora_rank values + 1 extra slot).
            attn_logits = torch.empty(
                (
                    forward_batch.batch_size,
                    self.num_local_heads,
                    num_kv_split,
                    self.kv_lora_rank + 1,
                ),
                dtype=torch.float32,
                device=q.device,
            )

        # Save the current latent cache. With rope fusion enabled, the rope
        # slice saved here has not been rotated yet.
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mqa, forward_batch.out_cache_loc, k_input, None
        )
        key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mqa.layer_id
        )
        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

        return (
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            enable_rope_fusion,
            k_input,
            forward_batch,
            zero_allocator,
        )

1728
1729
1730
1731
1732
1733
1734
    def forward_absorb_fused_mla_rope_cpu_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Compute q/k/v for the Intel AMX CPU path with one fused sgl_kernel op.

        ``qkv_proj_with_rope_fused_weight`` performs the QKV-A projection, both
        layernorms, the q_b projection, the ``w_kc`` absorption and RoPE.

        Returns:
            ``(q_input, k_input, v_input, forward_batch, zero_allocator)``, the
            arguments for ``forward_absorb_fused_mla_rope_cpu_core``.
        """
        assert self.q_lora_rank is not None and use_intel_amx_backend(
            self
        ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"

        # Choose the scale tensors matching the weight quantization flavor.
        if self.qkv_proj_with_rope_is_int8:
            qkv_a_scale = self.fused_qkv_a_proj_with_mqa.weight_scale
            q_b_scale = self.q_b_proj.weight_scale
        elif self.qkv_proj_with_rope_is_fp8:
            qkv_a_scale = self.fused_qkv_a_proj_with_mqa.weight_scale_inv
            q_b_scale = self.q_b_proj.weight_scale_inv
        else:
            qkv_a_scale = None
            q_b_scale = None

        fused_qkv_rope = torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight
        q_input, k_input, v_input = fused_qkv_rope(
            hidden_states,
            self.fused_qkv_a_proj_with_mqa.weight,
            self.q_b_proj.weight,
            self.w_kc,
            self.q_a_layernorm.weight,
            self.kv_a_layernorm.weight,
            positions,
            self.rotary_emb.cos_sin_cache,
            self.kv_a_layernorm.variance_epsilon,
            self.qkv_proj_with_rope_is_int8,
            self.qkv_proj_with_rope_is_fp8,
            qkv_a_scale,
            q_b_scale,
            True,  # is_vnni
            self.weight_block_size,
            self.q_lora_rank,
            self.kv_lora_rank,
            self.qk_rope_head_dim,
        )
        return q_input, k_input, v_input, forward_batch, zero_allocator

1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
    def forward_absorb_fused_mla_rope_core(
        self,
        q_input,
        key_cache_buf,
        val_cache_buf,
        attn_output,
        kv_indptr,
        kv_indices,
        k_pe_output,
        cos_sin_cache,
        positions,
        attn_logits,
        num_kv_split,
        sm_scale,
        enable_rope_fusion,
        k_input,
        forward_batch,
        zero_allocator,
    ):
        """Run the grouped decode kernel (optionally with fused RoPE), then project.

        Takes the tuple produced by ``forward_absorb_fused_mla_rope_prepare``.
        The kernel writes into ``attn_output``. The latent result is mapped
        through ``w_vc`` and then ``o_proj``.

        Returns:
            The first element of ``self.o_proj(...)``.
        """
        decode_attention_fwd_grouped_rope(
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            self.kv_lora_rank,
            self.rotary_emb.rotary_dim,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            logit_cap=self.attn_mqa.logit_cap,
            use_rope=enable_rope_fusion,
            is_neox_style=self.rotary_emb.is_neox_style,
        )

        if enable_rope_fusion:
            # Copy the kernel-produced k_pe (presumably the rotated one) into
            # the latent row and store the row in the KV pool again.
            k_input[..., self.kv_lora_rank :] = k_pe_output
            forward_batch.token_to_kv_pool.set_kv_buffer(
                self.attn_mqa, forward_batch.out_cache_loc, k_input, None
            )

        per_head = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        if _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            v_out = torch.bmm(
                per_head.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            quantized, quant_scale = per_tensor_quant_mla_fp8(
                per_head.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            v_out = bmm_fp8(
                quantized,
                self.w_vc,
                quant_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            v_out = torch.bmm(per_head.transpose(0, 1), self.w_vc)

        # Back to token-major and merge the head dims for o_proj.
        output, _ = self.o_proj(v_out.transpose(0, 1).flatten(1, 2))
        return output

1852
1853
1854
    def forward_absorb_fused_mla_rope_cpu_core(
        self, q_input, k_input, v_input, forward_batch, zero_allocator
    ):
        """Run absorbed MLA attention on the Intel AMX CPU backend.

        Consumes ``(q_input, k_input, v_input)`` from
        ``forward_absorb_fused_mla_rope_cpu_prepare`` and runs ``attn_mqa``.
        The per-head latent output is mapped through ``w_vc`` with the
        sgl_kernel CPU bmm, and the result goes through ``o_proj``.
        ``zero_allocator`` is not used by this path.

        Returns:
            The first element of ``self.o_proj(...)``.
        """
        assert self.q_lora_rank is not None and use_intel_amx_backend(
            self
        ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"

        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # [Note] Align shapes of bmm inputs.
        # Shapes of inputs:
        #   attn_output: [M, B, K]  (M tokens, B local heads, K = kv_lora_rank)
        #   original self.w_vc: [B, K, N]
        #   current self.w_vc (packed for VNNI; N is read from dim 1 below,
        #   presumably converted in PackWeightMethod like self.w_kc): [B, N, K]
        #
        # Shapes of inputs to sgl_kernel.cpu.bmm:
        #   out: [B, M, N]
        #   mat1: [B, M, K]
        #   mat2: [B, N, K]
        B = self.w_vc.size(0)
        N = self.w_vc.size(1)
        M = attn_output.size(0)
        output = torch.empty(
            [M, B * N], dtype=attn_output.dtype, device=attn_output.device
        )
        # `out` is a [B, M, N] view over the [M, B * N] buffer, so the bmm
        # result lands already laid out as o_proj's input.
        attn_bmm_output = output.view([M, B, N]).transpose_(0, 1)
        torch.ops.sgl_kernel.bmm_cpu(
            attn_bmm_output,
            attn_output.transpose(0, 1),
            self.w_vc,
            True,  # is_vnni
            None,  # scale
        )
        output, _ = self.o_proj(output)

        return output

1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
    def _chunked_prefix_attn_mha(
        self,
        q: torch.Tensor,
        accum_output: torch.Tensor,
        accum_lse: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Fold attention over every cached prefix chunk into a running result.

        For each chunk:

        1. Gather its latent cache rows.
        2. Re-expand the latent into per-head ``k_nope`` / ``v`` via
           ``kv_b_proj``.
        3. Attend with ``q``.
        4. Merge the chunk's (output, lse) into the accumulated state with
           ``merge_state_v2``.

        Returns:
            The merged attention output after all chunks.
        """
        assert forward_batch.num_prefix_chunks is not None
        nope_dim = self.qk_nope_head_dim
        for chunk_idx in range(forward_batch.num_prefix_chunks):
            forward_batch.set_prefix_chunk_idx(chunk_idx)

            # Gather this chunk's latent rows using the precomputed kv indices.
            cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
                self.attn_mha.layer_id
            )
            chunk_latent = (
                cache_buf[forward_batch.prefix_chunk_kv_indices[chunk_idx]]
                .contiguous()
                .to(q.dtype)
            )

            latent_part, k_pe = chunk_latent.split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )
            kv = self.kv_b_proj(latent_part.squeeze(1).contiguous())[0].view(
                -1, self.num_local_heads, nope_dim + self.v_head_dim
            )
            k_nope, v = kv.split([nope_dim, self.v_head_dim], dim=-1)

            k = torch.empty(
                (
                    k_nope.shape[0],
                    self.num_local_heads,
                    nope_dim + self.qk_rope_head_dim,
                ),
                dtype=v.dtype,
                device=v.device,
            )
            k[..., :nope_dim] = k_nope
            # k_pe presumably has a singleton head dim and broadcasts to all
            # heads here.
            k[..., nope_dim:] = k_pe

            chunk_out, chunk_lse = self.attn_mha(
                q, k, v, forward_batch, save_kv_cache=False
            )
            merged_out = torch.empty_like(accum_output)
            merged_lse = torch.empty_like(accum_lse)
            merge_state_v2(
                chunk_out,
                chunk_lse,
                accum_output,
                accum_lse,
                merged_out,
                merged_lse,
            )
            accum_output, accum_lse = merged_out, merged_lse

        return accum_output

1942
    def forward_normal_chunked_kv_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare q/k/v for MHA with a chunked prefix KV cache.

        With long prefixes, materializing the full k and v for normal MHA
        becomes too large. The chunked path therefore processes the cached
        prefix piece by piece in ``forward_normal_chunked_kv_core``. MHA is
        compute-bound, so that loop adds little overhead. The header comments
        of vLLM's ``vllm/v1/attention/backends/mla/common.py`` explain the
        approach in more detail.

        Preparation is identical to the plain MHA path, because only the
        extended (new) tokens are projected here.
        """
        prepared = self.forward_normal_prepare(
            positions, hidden_states, forward_batch, zero_allocator
        )
        return prepared

1960
    def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
        """MHA over the extended tokens, then merge in the cached prefix chunks.

        If any sequence in the batch has a cached prefix, the first attention
        call also returns its LSE. ``_chunked_prefix_attn_mha`` then folds each
        prefix chunk into that partial result.

        Returns:
            The first element of ``self.o_proj(...)``.
        """
        has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)

        # Chunk metadata is built lazily, at most once per batch.
        if has_extend_prefix and forward_batch.num_prefix_chunks is None:
            forward_batch.prepare_chunked_prefix_cache_info(q.device)
            if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"):
                forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)

        # The LSE is only needed when prefix chunks will be merged afterwards.
        forward_batch.mha_return_lse = has_extend_prefix

        # Attention over the extended tokens only, without the prefix.
        forward_batch.set_attn_attend_prefix_cache(False)
        result = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)

        if has_extend_prefix:
            partial_out, partial_lse = result
            forward_batch.set_attn_attend_prefix_cache(True)
            result = self._chunked_prefix_attn_mha(
                q=q,
                accum_output=partial_out,
                accum_lse=partial_lse,
                forward_batch=forward_batch,
            )

        output, _ = self.o_proj(
            result.reshape(-1, self.num_local_heads * self.v_head_dim)
        )
        return output

1988

Liangsheng Yin's avatar
Liangsheng Yin committed
1989
1990
1991
1992
1993
1994
1995
class DeepseekV2DecoderLayer(nn.Module):
    """One DeepSeek-V2 decoder block: MLA self-attention followed by an MLP.

    On sparse layers (see ``_is_layer_sparse``) the MLP is a routed-expert
    ``DeepseekV2MoE``; otherwise it is a dense ``DeepseekV2MLP``. The input and
    post-attention layernorms and the tensor scatter/gather around attention
    and MLP are delegated to a ``LayerCommunicator``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        is_nextn: bool = False,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
        self.layer_id = layer_id
        self.is_nextn = is_nextn
        self.self_attn = DeepseekV2AttentionMLA(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            layer_id=layer_id,
            reduce_results=False,
            prefix=add_prefix("self_attn", prefix),
            alt_stream=alt_stream,
        )

        self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
        is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=1 if is_nextn else config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=is_previous_layer_sparse,
        )

        if self.is_layer_sparse:
            self.mlp = DeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                layer_id=self.layer_id,
                alt_stream=alt_stream,
                is_nextn=is_nextn,
            )
        else:
            # With moe_dense fully data-parallel, each rank holds the whole
            # dense MLP (tp size 1). Otherwise the MLP picks its own TP defaults.
            if enable_moe_dense_fully_dp():
                mlp_tp_rank, mlp_tp_size = 0, 1
            else:
                mlp_tp_rank, mlp_tp_size = None, None
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                tp_rank=mlp_tp_rank,
                tp_size=mlp_tp_size,
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
            allow_reduce_scatter=True,
            is_last_layer=(
                is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
            ),
        )

    def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
        """Return True if ``layer_id`` uses the routed-expert MoE MLP.

        NextN layers are always sparse. Any other layer is sparse when all of
        the following hold:

        * the config defines routed experts;
        * the layer index is at least ``first_k_dense_replace``;
        * the layer index is a multiple of ``moe_layer_freq``.
        """
        return is_nextn or (
            self.config.n_routed_experts is not None
            and layer_id >= self.config.first_k_dense_replace
            and layer_id % self.config.moe_layer_freq == 0
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
        gemm_output_zero_allocator: Optional[BumpAllocator] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention and MLP for one layer.

        Returns:
            ``(hidden_states, residual)``. When the MLP all-reduce is deferred
            for fusion with the next layer, ``postprocess_layer`` is skipped and
            ``hidden_states`` is tagged with ``_sglang_needs_allreduce_fusion``.
        """
        # quant_format handed to prepare_attn: "mxfp4" only on gfx95 when the
        # fused QKV-A projection weight is stored as uint8, otherwise "".
        quant_format = (
            "mxfp4"
            if _is_gfx95_supported
            and getattr(self.self_attn, "fused_qkv_a_proj_with_mqa", None) is not None
            and getattr(self.self_attn.fused_qkv_a_proj_with_mqa, "weight", None)
            is not None
            and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype == torch.uint8
            else ""
        )

        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states,
            residual,
            forward_batch,
            quant_format,
        )

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )

        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )

        should_allreduce_fusion = (
            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
                forward_batch
            )
        )

        # For DP with padding, reduce scatter can be used instead of all-reduce.
        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
            forward_batch
        )

        # The dense MLP path does not take the GEMM output allocator.
        if isinstance(self.mlp, DeepseekV2MLP):
            gemm_output_zero_allocator = None

        hidden_states = self.mlp(
            hidden_states,
            forward_batch,
            should_allreduce_fusion,
            use_reduce_scatter,
            gemm_output_zero_allocator,
        )

        if should_allreduce_fusion:
            # Leave the all-reduce to be fused with the next layer (per
            # should_fuse_mlp_allreduce_with_next_layer); just tag the tensor.
            hidden_states._sglang_needs_allreduce_fusion = True
        else:
            hidden_states, residual = self.layer_communicator.postprocess_layer(
                hidden_states, residual, forward_batch
            )

        return hidden_states, residual

    def op_comm_prepare_attn(
        self,
        state,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
        tbo_subbatch_index: Optional[int] = None,
    ):
        """Op step: pre-attention communication/layernorm; stash batch context."""
        # NOTE(review): unlike ``forward``, no quant_format is passed to
        # prepare_attn here, so this path never requests "mxfp4" input.
        # Confirm this is intended for gfx95 with uint8 fused QKV-A weights.
        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
        )
        state.update(
            dict(
                forward_batch=forward_batch,
                positions=positions,
                zero_allocator=zero_allocator,
                tbo_subbatch_index=tbo_subbatch_index,
            )
        )

    def op_comm_prepare_mlp(self, state):
        """Op step: pre-MLP communication on the attention output and residual."""
        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
            self.layer_communicator.prepare_mlp(
                state.pop("hidden_states_after_attn"),
                state.pop("residual_after_input_ln"),
                state.forward_batch,
            )
        )

    def op_mlp(self, state):
        """Op step: run the MLP.

        The MLP is skipped (input passed through) for an empty batch on a dense
        layer in fully-DP mode.
        """
        hidden_states = state.pop("hidden_states_mlp_input")
        if not (
            enable_moe_dense_fully_dp()
            and (not self.is_layer_sparse)
            and hidden_states.shape[0] == 0
        ):
            state.hidden_states_mlp_output = self.mlp(
                hidden_states, state.forward_batch
            )
        else:
            state.hidden_states_mlp_output = hidden_states

    def op_comm_postprocess_layer(self, state):
        """Op step: post-layer communication.

        Returns the keyword inputs for the next layer and clears ``state``,
        keeping only the batch-context keys.
        """
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            state.pop("hidden_states_mlp_output"),
            state.pop("residual_after_comm_pre_mlp"),
            state.forward_batch,
        )

        output = dict(
            positions=state.positions,
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=state.forward_batch,
            zero_allocator=state.zero_allocator,
            tbo_subbatch_index=state.tbo_subbatch_index,
        )

        state.clear(
            expect_keys={
                "positions",
                "forward_batch",
                "zero_allocator",
                "tbo_subbatch_index",
            }
        )
        return output
2225

Liangsheng Yin's avatar
Liangsheng Yin committed
2226
2227
2228
2229
2230
2231
2232
2233

class DeepseekV2Model(nn.Module):
    """DeepseekV2 decoder stack: token embedding, decoder layers, final norm.

    Under pipeline parallelism each rank builds real layers only for
    ``[start_layer, end_layer)``. The embedding exists only on the first rank
    and the final norm only on the last rank. Other ranks hold
    ``PPMissingLayer`` placeholders for those modules.
    """

    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.first_k_dense_replace = config.first_k_dense_replace
        self.pp_group = get_pp_group()

        if self.pp_group.is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                enable_tp=not is_dp_attention_enabled(),
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: DeepseekV2DecoderLayer(
                config=config,
                layer_id=idx,
                quant_config=quant_config,
                prefix=prefix,
                alt_stream=self.alt_stream,
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix=add_prefix("layers", prefix),
            offloader_kwargs=dict(
                submodule_accessor=lambda layer: (
                    layer.mlp.experts
                    if isinstance(layer.mlp, DeepseekV2MoE)
                    else layer.mlp
                ),
                whitelist_param_names_creator=lambda module: (
                    [
                        "w13_weight",
                        "w2_weight",
                        # only for nvfp4
                        *(
                            [
                                "w13_blockscale_swizzled",
                                "w2_blockscale_swizzled",
                            ]
                            if hasattr(module, "w13_blockscale_swizzled")
                            else []
                        ),
                    ]
                    if isinstance(module, FusedMoE)
                    else []
                ),
            ),
        )
        if self.pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer(return_tuple=True)

        # Size of the per-forward zero-initialized buffer used by the DSV3 gemm
        # output path (only on aiter gfx95 for the 256-expert / 7168-dim model).
        self.gemm_output_zero_allocator_size = 0
        # The embedding width is config.hidden_size (see VocabParallelEmbedding
        # above). Read it from the config, because on non-first PP ranks
        # embed_tokens is a PPMissingLayer without an `embedding_dim`.
        hidden_size = config.hidden_size
        if (
            _use_aiter_gfx95
            and config.n_routed_experts == 256
            and hidden_size == 7168
        ):
            # Only layers in [start_layer, end_layer) are materialized on this
            # rank; forward() likewise only runs that range.
            local_moe_layers = [
                self.layers[i]
                for i in range(self.start_layer, self.end_layer)
                if isinstance(self.layers[i].mlp, DeepseekV2MoE)
            ]
            num_moe_layers = len(local_moe_layers)

            allocate_size = 0
            if local_moe_layers:
                shared_experts = local_moe_layers[0].mlp.shared_experts
                allocate_size = shared_experts.gate_up_proj.output_size_per_partition

            self.gemm_output_zero_allocator_size = (
                get_dsv3_gemm_output_zero_allocator_size(
                    config.n_routed_experts,
                    num_moe_layers,
                    allocate_size,
                    hidden_size,
                )
            )

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding module (a placeholder on non-first PP ranks)."""
        return self.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        """Run this rank's slice of decoder layers.

        Args:
            input_ids: token ids. Used for embedding on the first PP rank when
                ``input_embeds`` is None, and for picking the allocator device.
            positions: token positions forwarded to every layer.
            forward_batch: batch metadata (TBO eligibility, forward mode, ...).
            input_embeds: optional precomputed embeddings (first rank only).
            pp_proxy_tensors: hidden states / residual from the previous PP
                stage. Required on non-first ranks.

        Returns:
            ``PPProxyTensors`` on non-last PP ranks. On the last rank, the final
            hidden states, passed through the final norm unless the batch is
            idle.
        """
        total_num_layers = self.end_layer - self.start_layer
        device = input_embeds.device if input_embeds is not None else input_ids.device
        zero_allocator = BumpAllocator(
            buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
            dtype=torch.float32,
            device=device,
        )

        # getattr keeps the original hasattr guard (presumably for subclasses
        # that do not run this __init__); missing or 0 means "no allocator".
        gemm_output_zero_allocator_size = getattr(
            self, "gemm_output_zero_allocator_size", 0
        )
        gemm_output_zero_allocator = (
            BumpAllocator(
                buffer_size=gemm_output_zero_allocator_size,
                dtype=torch.float32,
                device=device,
            )
            if gemm_output_zero_allocator_size > 0
            else None
        )

        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        # Layers in [normal_start_layer, normal_end_layer) run one by one. With
        # two-batch overlap (TBO), the remaining layers of this rank (from
        # first_k_dense_replace on) run through model_forward_maybe_tbo.
        normal_start_layer = self.start_layer
        normal_end_layer = self.end_layer
        if forward_batch.can_run_tbo:
            if (
                self.first_k_dense_replace > normal_start_layer
                and self.first_k_dense_replace < normal_end_layer
            ):
                normal_end_layer = self.first_k_dense_replace
            elif self.first_k_dense_replace < normal_start_layer:
                # NOTE(review): this resets both bounds to 0. TBO then receives
                # self.layers[0:end_layer], with the input scatter mode taken
                # from self.layers[-1]. Under PP that slice includes layers
                # before start_layer. Confirm this is intended.
                normal_end_layer = normal_start_layer = 0

        for i in range(normal_start_layer, normal_end_layer):
            with get_global_expert_distribution_recorder().with_current_layer(i):
                layer = self.layers[i]
                hidden_states, residual = layer(
                    positions,
                    hidden_states,
                    forward_batch,
                    residual,
                    zero_allocator,
                    gemm_output_zero_allocator,
                )

        if normal_end_layer != self.end_layer:
            hidden_states, residual = model_forward_maybe_tbo(
                layers=self.layers[normal_end_layer : self.end_layer],
                enable_tbo=True,
                positions=positions,
                forward_batch=forward_batch,
                hidden_states=hidden_states,
                residual=residual,
                input_data_scatter_mode=self.layers[
                    normal_end_layer - 1
                ].layer_scatter_modes.layer_output_mode,
                zero_allocator=zero_allocator,
            )

        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            if not forward_batch.forward_mode.is_idle():
                if residual is None:
                    hidden_states = self.norm(hidden_states)
                else:
                    hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekV2ForCausalLM(nn.Module):
2424
2425
    # Class-level mapping of fused module names to their checkpoint parts,
    # for quark model load.
    # NOTE(review): __init__ adds "fused_qkv_a_proj_with_mqa" through
    # `self.packed_modules_mapping[...] = ...` without making an instance copy.
    # That mutates this shared class dict, so the entry persists for every
    # instance (and subclasses that do not override it). Confirm this is
    # relied upon before changing it.
    packed_modules_mapping = {}
Liangsheng Yin's avatar
Liangsheng Yin committed
2426
2427
2428
2429
2430

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the causal-LM wrapper: decoder backbone, LM head, logits processor.

        Args:
            config: model config of the DeepseekV2/V3 checkpoint.
            quant_config: optional quantization config, forwarded to submodules.
            prefix: parameter-name prefix of this module.
        """
        super().__init__()

        # When q_lora_rank is set, q_a_proj and kv_a_proj_with_mqa are fused
        # along the output dimension. Register that fusion for quark model load.
        self.fuse_qkv_a_proj = getattr(config, "q_lora_rank", None) is not None
        if self.fuse_qkv_a_proj:
            self.packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
                "q_a_proj",
                "kv_a_proj_with_mqa",
            ]

        self.pp_group = get_pp_group()
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        # Reads self.config and self.quant_config, so it must run after both are set.
        self.determine_num_fused_shared_experts()

        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)

        def _collect_routed_expert_weights():
            # layer index -> get_moe_weights() for every MoE layer of the backbone.
            return {
                layer_id: layer.mlp.get_moe_weights()
                for layer_id, layer in enumerate(self.model.layers)
                if isinstance(layer.mlp, DeepseekV2MoE)
            }

        # LazyValue presumably defers the call until `.value` is first read.
        self._routed_experts_weights_of_layer = LazyValue(
            _collect_routed_expert_weights
        )

    @property
    def routed_experts_weights_of_layer(self):
        return self._routed_experts_weights_of_layer.value

2475
    def determine_num_fused_shared_experts(
        self, architecture: str = "DeepseekV3ForCausalLM"
    ):
        """Decide whether the shared expert can be fused into the routed experts.

        Sets ``self.num_fused_shared_experts`` to ``config.n_shared_experts``
        when fusion is allowed, and to 0 otherwise. If fusion is requested but
        unsupported, this also sets
        ``global_server_args_dict["disable_shared_experts_fusion"]`` to True and
        logs the reason on rank 0.

        Args:
            architecture: the only ``config.architectures[0]`` value for which
                the fusion is allowed.
        """
        self.num_fused_shared_experts = 0
        if global_server_args_dict["disable_shared_experts_fusion"]:
            return

        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
        disable_reason = None
        if (
            not _is_cuda
            or torch.cuda.get_device_capability("cuda") < (8, 0)
            or self.config.architectures[0] != architecture
            or self.config.n_routed_experts != 256
            or self.config.n_shared_experts != 1
        ):
            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
        elif get_moe_expert_parallel_world_size() > 1:
            disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
        elif (
            # quant_config is None for unquantized checkpoints (load_weights
            # guards the same call the same way).
            self.quant_config is not None
            and self.quant_config.get_name() == "w4afp8"
        ):
            disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."

        if disable_reason is not None:
            global_server_args_dict["disable_shared_experts_fusion"] = True
            log_info_on_rank0(
                logger,
                f"{disable_reason} Shared experts fusion optimization is disabled.",
            )
            return

        self.num_fused_shared_experts = self.config.n_shared_experts
2507

Mick's avatar
Mick committed
2508
2509
2510
    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens

2511
    @torch.no_grad()
Liangsheng Yin's avatar
Liangsheng Yin committed
2512
2513
2514
2515
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
2516
        forward_batch: ForwardBatch,
2517
        input_embeds: torch.Tensor = None,
2518
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
Liangsheng Yin's avatar
Liangsheng Yin committed
2519
    ) -> torch.Tensor:
2520
2521
        hidden_states = self.model(
            input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
2522
        )
Liangsheng Yin's avatar
Liangsheng Yin committed
2523

2524
2525
2526
2527
2528
2529
2530
2531
2532
2533
2534
2535
2536
2537
2538
        if self.pp_group.is_last_rank:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )
        else:
            return hidden_states

    @property
    def start_layer(self):
        return self.model.start_layer

    @property
    def end_layer(self):
        return self.model.end_layer

2539
    def post_load_weights(self, is_nextn=False, weight_names=None):
        """Post-process MLA ``kv_b_proj`` weights after checkpoint loading.

        For each selected layer, ``kv_b_proj`` is dequantized or requantized as
        needed. It is then split along its output dimension into ``w_kc``
        (``qk_nope_head_dim`` rows per head) and ``w_vc`` (``v_head_dim`` rows
        per head), which are stored on the attention module with their scales
        (``w_scale``, or ``w_scale_k``/``w_scale_v`` on the DeepGEMM bmm path).
        Finally, block-quantized weights may be requantized to UE8M0 scales.

        Args:
            is_nextn: process the single NextN decoder at ``self.model.decoder``
                instead of the regular decoder layers.
            weight_names: names of the weights just loaded. When given, only
                layers whose ``kv_b_proj`` appears among them are processed.
                ``None`` means every layer owned by this pipeline rank.
        """
        # Select the layers whose kv_b_proj needs post-processing.
        if is_nextn:
            layer_ids = [self.config.num_hidden_layers]
        elif weight_names is None:
            layer_ids = range(self.model.start_layer, self.model.end_layer)
        else:
            layer_ids = set()
            for name in weight_names:
                if "kv_b_proj" in name:
                    # Assumes names of the form "model.layers.<id>...."
                    layer_id = int(name.split(".")[2])
                    if layer_id < self.config.num_hidden_layers:
                        layer_ids.add(layer_id)

        for layer_id in layer_ids:
            self_attn = (
                self.model.layers[layer_id].self_attn
                if not is_nextn
                else self.model.decoder.self_attn
            )

            # Obtain the (possibly still quantized) kv_b_proj weight.
            if hasattr(self_attn.kv_b_proj, "qweight"):
                # AWQ compatible
                if _is_cuda or _is_hip:
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                    ).T
                else:
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                        0,
                        0,
                        0,
                    ).T
            else:
                w = self_attn.kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
            # This may affect the accuracy of fp8 model.
            # Fix deepseek v3 blockwise bmm by using deep_gemm
            use_deep_gemm_bmm = False

            if w.dtype in (
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
            ):
                if (
                    hasattr(self.quant_config, "weight_block_size")
                    and self.quant_config.weight_block_size is not None
                ):
                    # Block-wise fp8 scales.
                    weight_block_size = self.quant_config.weight_block_size
                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                    if _is_fp8_fnuz:
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                            weight=w,
                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                            input_scale=None,
                        )
                    else:
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv

                    if (
                        _is_cuda
                        and weight_block_size[0] == 128
                        and weight_block_size[1] == 128
                    ):
                        if (
                            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
                            and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL
                            and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false")
                        ):
                            # Keep fp8 w as-is; the block scales are split below.
                            block_scale = weight_scale
                            use_deep_gemm_bmm = True
                        else:
                            w = block_quant_dequant(
                                weight,
                                weight_scale,
                                weight_block_size,
                                torch.bfloat16,
                            )
                    else:
                        w, scale = block_quant_to_tensor_quant(
                            weight, weight_scale, weight_block_size
                        )
                        self_attn.w_scale = scale
                else:
                    # Per-channel fp8 scales.
                    if _is_fp8_fnuz:
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                            weight=w,
                            weight_scale=self_attn.kv_b_proj.weight_scale,
                            input_scale=None,
                        )
                    else:
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale

                    w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                    self_attn.w_scale = scale

            if w.dtype == torch.int8:
                # Same rule as the fp8 branch above: a non-None weight_block_size
                # means block-wise scales, anything else means per-channel scales.
                # (Previously a present-but-None weight_block_size skipped both
                # paths and left w as int8.)
                if getattr(self.quant_config, "weight_block_size", None) is not None:
                    # block-wise int8 need it
                    weight_block_size = self.quant_config.weight_block_size
                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                    weight = w
                    weight_scale = self_attn.kv_b_proj.weight_scale_inv
                    w = int8_block_dequant(
                        weight, weight_scale, weight_block_size
                    ).to(torch.bfloat16)
                else:
                    # channel-wise int8 need it
                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
                        torch.bfloat16
                    )

            # Split the fused output dim into per-head key (nope) and value parts.
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)

            if (
                _use_aiter_gfx95
                and self.quant_config is not None
                and self.quant_config.get_name() == "quark"
            ):
                w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
                    quark_post_load_weights(self_attn, w, "mxfp4")
                )

            if not use_deep_gemm_bmm:
                self_attn.w_kc = bind_or_assign(
                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                )
                self_attn.w_vc = bind_or_assign(
                    self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
                )
                if (
                    hasattr(self_attn.kv_b_proj, "weight_scale")
                    and self_attn.w_scale is None
                ):
                    self_attn.w_scale = bind_or_assign(
                        self_attn.w_scale, self_attn.kv_b_proj.weight_scale
                    )
                    if _is_hip:
                        self_attn.w_scale *= 2.0
                # TODO: remove this after adding FP8 support in bmm cpu kernel
                if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn:
                    self_attn.w_kc = (
                        self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale
                    )
                    self_attn.w_vc = (
                        self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale
                    )
            else:
                # DeepGEMM bmm path: weight_block_size and block_scale were set
                # in the block-wise fp8 branch of this same iteration.
                num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
                num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
                ws_kc, ws_vc = block_scale.unflatten(
                    0, (-1, (num_tiles_k + num_tiles_n))
                ).split([num_tiles_k, num_tiles_n], dim=1)
                self_attn.w_scale_k = bind_or_assign(
                    self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()
                )
                self_attn.w_scale_v = bind_or_assign(
                    self_attn.w_scale_v, ws_vc.contiguous()
                )
                self_attn.w_kc = bind_or_assign(
                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous()
                )
                self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
                self_attn.use_deep_gemm_bmm = True

        if (
            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
            and hasattr(self.quant_config, "weight_block_size")
            and self.quant_config.weight_block_size is not None
        ):
            self._weight_requant_ue8m0(is_nextn)
2722

2723
    def _weight_requant_ue8m0(self, is_nextn=False):
        """Requantize block-quantized linear weights in place to UE8M0 scales.

        Covers each layer's attention projections, plus either the shared
        experts (and DeepEP routed expert weights) of MoE layers or the dense
        MLP of non-MoE layers. Every tensor goes through
        ``requant_weight_ue8m0_inplace`` with ``quant_config.weight_block_size``.

        Args:
            is_nextn: process only ``self.model.decoder`` (the NextN layer),
                which is always treated as an MoE layer.
        """
        weight_block_size = self.quant_config.weight_block_size

        moe_layers = set(
            range(
                self.config.first_k_dense_replace,
                self.config.num_hidden_layers,
                self.config.moe_layer_freq,
            )
        )

        if is_nextn:
            layer_ids = [0]
        else:
            # Only this pipeline rank's layers [start_layer, end_layer) are real
            # modules (forward() and post_load_weights() use the same range);
            # the others are placeholders without attention/MLP submodules.
            layer_ids = range(self.model.start_layer, self.model.end_layer)

        for layer_id in layer_ids:
            layer = self.model.decoder if is_nextn else self.model.layers[layer_id]

            module_list = [
                layer.self_attn.kv_b_proj,
                layer.self_attn.o_proj,
            ]
            if self.config.q_lora_rank is not None:
                module_list.append(layer.self_attn.fused_qkv_a_proj_with_mqa)
                module_list.append(layer.self_attn.q_b_proj)
            else:
                module_list.append(layer.self_attn.kv_a_proj_with_mqa)
                module_list.append(layer.self_attn.q_proj)

            for module in module_list:
                requant_weight_ue8m0_inplace(
                    module.weight, module.weight_scale_inv, weight_block_size
                )

            if layer_id in moe_layers or is_nextn:
                shared_experts = getattr(layer.mlp, "shared_experts", None)
                if shared_experts is not None:
                    for module in [
                        shared_experts.gate_up_proj,
                        shared_experts.down_proj,
                    ]:
                        requant_weight_ue8m0_inplace(
                            module.weight, module.weight_scale_inv, weight_block_size
                        )

                experts = layer.mlp.experts
                if isinstance(experts, DeepEPMoE):
                    # Each entry is a (weight, scale) pair.
                    for w in [
                        experts.w13_weight_fp8,
                        experts.w2_weight_fp8,
                    ]:
                        requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
            else:
                mlp = layer.mlp
                assert isinstance(mlp, DeepseekV2MLP)
                for module in [
                    mlp.gate_up_proj,
                    mlp.down_proj,
                ]:
                    requant_weight_ue8m0_inplace(
                        module.weight, module.weight_scale_inv, weight_block_size
                    )

2788
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
2789

2790
2791
2792
        if is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
2793
                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
2794
2795
2796
2797
2798
2799
2800
2801
2802
                # compatible with old design
                nextn_layer_id = (
                    0
                    if self.config.num_hidden_layers == 1
                    else self.config.num_hidden_layers
                )
            else:
                raise ValueError("num_nextn_predict_layers is not in the config")

Liangsheng Yin's avatar
Liangsheng Yin committed
2803
2804
2805
2806
2807
2808
2809
2810
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
2811
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
Liangsheng Yin's avatar
Liangsheng Yin committed
2812
2813
2814
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
2815
            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
Liangsheng Yin's avatar
Liangsheng Yin committed
2816
        )
2817
2818
2819
        # Params for special naming rules in mixed-precision models, for example:
        # model.layers.xx.mlp.experts.xx.w1.input_scale. For details,
        # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main.
2820
        if self.quant_config and self.quant_config.get_name() == "w4afp8":
2821
2822
            expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
                num_experts=self.config.n_routed_experts
2823
            )
Liangsheng Yin's avatar
Liangsheng Yin committed
2824

2825
2826
2827
2828
2829
2830
        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
            self.config.q_lora_rank is not None
        )
        cached_a_proj = {} if fuse_qkv_a_proj else None

2831
2832
2833
2834
2835
2836
2837
2838
2839
        if is_nextn:
            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
            nextn_spec_weight_names = [
                "shared_head.norm",
                "eh_proj",
                "enorm",
                "hnorm",
            ]

2840
2841
        if self.num_fused_shared_experts > 0:
            assert self.num_fused_shared_experts == 1
2842
            log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
2843

2844
2845
2846
2847
2848
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = []
            params_dict = dict(self.named_parameters())
            weight_names = []
            for name, loaded_weight in weights:
2849
2850
2851
2852
2853
2854
2855
2856
2857
2858
                layer_id = get_layer_id(name)
                if (
                    layer_id is not None
                    and hasattr(self.model, "start_layer")
                    and (
                        layer_id < self.model.start_layer
                        or layer_id >= self.model.end_layer
                    )
                ):
                    continue
2859
2860
2861
2862
2863
                if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
                    name = name.replace(
                        "mlp.shared_experts",
                        f"mlp.experts.{self.config.n_routed_experts}",
                    )
2864

2865
                weight_names.append(name)
2866

2867
2868
2869
2870
2871
2872
2873
2874
2875
2876
2877
2878
2879
                if not is_nextn:
                    if hasattr(self.config, "num_nextn_predict_layers"):
                        num_nextn_layers = self.config.num_nextn_predict_layers
                        if num_nextn_layers > 0 and name.startswith("model.layers"):
                            name_list = name.split(".")
                            if (
                                len(name_list) >= 3
                                and int(name_list[2]) >= self.config.num_hidden_layers
                            ):
                                continue
                else:
                    if not name.startswith(nextn_layer_prefix):
                        continue
2880

2881
2882
2883
                    # Use shared head and embed weights from target model
                    if "shared_head.head" in name or "embed_tokens" in name:
                        continue
2884

2885
2886
2887
2888
2889
2890
2891
2892
2893
2894
2895
2896
                    is_decoder = True
                    # For nextn specific weights
                    for weight_name in nextn_spec_weight_names:
                        if weight_name in name:
                            name = name.replace(nextn_layer_prefix, "model")
                            is_decoder = False
                            break
                    # For decoder layer weights
                    if is_decoder:
                        name = name.replace(nextn_layer_prefix, "model.decoder")

                if "rotary_emb.inv_freq" in name:
Liangsheng Yin's avatar
Liangsheng Yin committed
2897
                    continue
2898
2899
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    # Skip non-stacked layers and experts (experts handled below).
Liangsheng Yin's avatar
Liangsheng Yin committed
2900
2901
                    if weight_name not in name:
                        continue
2902
2903
2904
2905
2906
2907
2908
2909
                    # We have mlp.experts[0].gate_proj in the checkpoint.
                    # Since we handle the experts below in expert_params_mapping,
                    # we need to skip here BEFORE we update the name, otherwise
                    # name will be updated to mlp.experts[0].gate_up_proj, which
                    # will then be updated below in expert_params_mapping
                    # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                    if ("mlp.experts." in name) and name not in params_dict:
                        continue
Liangsheng Yin's avatar
Liangsheng Yin committed
2910
                    name = name.replace(weight_name, param_name)
2911
2912
2913
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
Liangsheng Yin's avatar
Liangsheng Yin committed
2914
2915
                    param = params_dict[name]
                    weight_loader = param.weight_loader
2916
2917
                    futures.append(
                        executor.submit(weight_loader, param, loaded_weight, shard_id)
Liangsheng Yin's avatar
Liangsheng Yin committed
2918
2919
2920
                    )
                    break
                else:
2921
2922
2923
2924
2925
2926
2927
2928
2929
2930
2931
2932
2933
2934
2935
2936
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        futures.append(
                            executor.submit(
                                weight_loader,
                                param,
                                loaded_weight,
                                name,
                                shard_id=shard_id,
                                expert_id=expert_id,
                            )
2937
                        )
2938
2939
2940
2941
2942
                        break
                    else:
                        # Skip loading extra bias for GPTQ models.
                        if name.endswith(".bias") and name not in params_dict:
                            continue
2943
2944
2945
2946
2947
2948
                        # Skip loading embed_tokens if not first rank in pipeline parallelism
                        if ".embed_tokens." in name and not self.pp_group.is_first_rank:
                            continue
                        # Skip loading norm if not last rank in pipeline parallelism
                        if ".norm." in name and not self.pp_group.is_last_rank:
                            continue
2949
2950
                        if fuse_qkv_a_proj and (
                            "q_a_proj" in name or "kv_a_proj_with_mqa" in name
2951
                        ):
2952
2953
2954
                            cached_a_proj[name] = loaded_weight
                            q_a_proj_name = (
                                name
2955
                                if "q_a_proj" in name
2956
2957
2958
2959
2960
2961
                                else name.replace("kv_a_proj_with_mqa", "q_a_proj")
                            )
                            kv_a_proj_name = (
                                name
                                if "kv_a_proj_with_mqa" in name
                                else name.replace("q_a_proj", "kv_a_proj_with_mqa")
2962
2963
                            )

2964
2965
2966
2967
2968
2969
2970
2971
2972
2973
                            # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
                            if (
                                q_a_proj_name in cached_a_proj
                                and kv_a_proj_name in cached_a_proj
                            ):
                                q_a_proj_weight = cached_a_proj[q_a_proj_name]
                                kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                                cat_dim = 0
                                if self.quant_config is not None and (
                                    self.quant_config.get_name() == "awq"
2974
                                    or self.quant_config.get_name() == "awq_marlin"
2975
2976
2977
2978
2979
2980
2981
2982
2983
2984
2985
2986
2987
2988
2989
2990
2991
2992
2993
2994
2995
2996
2997
2998
2999
3000
3001
3002
3003
3004
3005
3006
3007
3008
3009
3010
3011
3012
3013
3014
3015
3016
3017
3018
                                    or self.quant_config.get_name() == "moe_wna16"
                                ):
                                    cat_dim = 1
                                fused_weight = torch.cat(
                                    [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                                )
                                param_name = (
                                    name.replace(
                                        "q_a_proj", "fused_qkv_a_proj_with_mqa"
                                    )
                                    if "q_a_proj" in name
                                    else name.replace(
                                        "kv_a_proj_with_mqa",
                                        "fused_qkv_a_proj_with_mqa",
                                    )
                                )
                                param = params_dict[param_name]

                                weight_loader = getattr(
                                    param, "weight_loader", default_weight_loader
                                )
                                futures.append(
                                    executor.submit(weight_loader, param, fused_weight)
                                )
                                cached_a_proj.pop(q_a_proj_name)
                                cached_a_proj.pop(kv_a_proj_name)
                        else:
                            if (
                                "k_scale" in name or "v_scale" in name
                            ) and name not in params_dict:
                                # modelopt attn kv scale is named differently
                                for scale in ["k_scale", "v_scale"]:
                                    if scale in name:
                                        name = name.replace(
                                            f"{scale[0]}_proj", "attn_mqa"
                                        )
                                        break
                            if name not in params_dict:
                                # modelopt ckpt contains not needed weights for MTP module:
                                # model.decoder.self_attn.attn_mqa.v_scale and
                                # model.decoder.self_attn.attn_mqa.k_scale
                                logger.warning(f"{name} not found in params_dict.")
                                continue
                            param = params_dict[name]
3019
3020
3021
                            weight_loader = getattr(
                                param, "weight_loader", default_weight_loader
                            )
3022
3023
3024
3025
3026
3027
3028
                            futures.append(
                                executor.submit(weight_loader, param, loaded_weight)
                            )

            # Wait for all tasks to complete and raise any exceptions.
            for future in concurrent.futures.as_completed(futures):
                future.result()
Liangsheng Yin's avatar
Liangsheng Yin committed
3029

3030
        self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
Ke Bao's avatar
Ke Bao committed
3031

3032
3033
3034
3035
3036
3037
3038
3039
3040
3041
3042
    def get_embed_and_head(self):
        """Return the input-embedding weight and the LM-head weight as a pair.

        Returns:
            A 2-tuple ``(self.model.embed_tokens.weight, self.lm_head.weight)``;
            the objects are returned as-is, without copying.
        """
        embed_weight = self.model.embed_tokens.weight
        head_weight = self.lm_head.weight
        return embed_weight, head_weight

    def set_embed_and_head(self, embed, head):
        """Swap in new input-embedding and LM-head weights.

        Both existing ``weight`` attributes are deleted before the new values
        are assigned. Afterwards ``torch.cuda.empty_cache()`` and
        ``torch.cuda.synchronize()`` are called.

        Args:
            embed: New value for ``self.model.embed_tokens.weight``.
            head: New value for ``self.lm_head.weight``.
        """
        embed_module = self.model.embed_tokens
        head_module = self.lm_head
        # Remove the old attributes first. Presumably this lets a value that
        # is not an nn.Parameter be assigned where a Parameter was registered.
        # Confirm against torch.nn.Module.__setattr__.
        del embed_module.weight
        del head_module.weight
        embed_module.weight = embed
        head_module.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

3043
3044
3045
3046
3047
3048
3049
3050
    @classmethod
    def get_model_config_for_expert_location(cls, config):
        """Build a ``ModelConfigForExpertLocation`` from a model ``config``.

        Reads ``num_hidden_layers``, ``n_routed_experts`` and ``n_group`` from
        ``config``. These become the layer count, the logical expert count and
        the group count respectively.
        """
        layer_count = config.num_hidden_layers
        expert_count = config.n_routed_experts
        group_count = config.n_group
        return ModelConfigForExpertLocation(
            num_layers=layer_count,
            num_logical_experts=expert_count,
            num_groups=group_count,
        )

Liangsheng Yin's avatar
Liangsheng Yin committed
3051

3052
3053
3054
3055
3056
3057
3058
3059
3060
3061
3062
# Register one handler function per attention-backend name with
# BackendRegistry. The handle_* functions are defined elsewhere in this file.
# NOTE(review): presumably the registry dispatches on these name strings at
# runtime; confirm against BackendRegistry's definition.
BackendRegistry.register("ascend", handle_ascend)
BackendRegistry.register("flashinfer", handle_flashinfer)
BackendRegistry.register("fa3", handle_fa3)
BackendRegistry.register("flashmla", handle_flashmla)
BackendRegistry.register("cutlass_mla", handle_cutlass_mla)
BackendRegistry.register("fa4", handle_fa4)
BackendRegistry.register("trtllm_mla", handle_trtllm_mla)
BackendRegistry.register("aiter", handle_aiter)
BackendRegistry.register("triton", handle_triton)


HandH1998's avatar
HandH1998 committed
3063
3064
3065
3066
3067
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek-V3 causal LM: inherits everything from DeepseekV2ForCausalLM.

    It adds no behavior of its own. Presumably it exists so the V3
    architecture name maps to a distinct class; confirm against the model
    registry.
    """

    pass


# Model classes exported by this module.
# NOTE(review): presumably SGLang's model registry discovers models through the
# module-level `EntryClass` name; confirm against the registry loader.
EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]