deepseek_v2.py 114 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14

Liangsheng Yin's avatar
Liangsheng Yin committed
15
16
17
# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
18

19
import concurrent.futures
20
import logging
21
import os
22
from enum import IntEnum, auto
23
from typing import Any, Dict, Iterable, Optional, Tuple, Union
Liangsheng Yin's avatar
Liangsheng Yin committed
24
25

import torch
Ke Bao's avatar
Ke Bao committed
26
import torch.nn.functional as F
Liangsheng Yin's avatar
Liangsheng Yin committed
27
from torch import nn
28
from tqdm import tqdm
Liangsheng Yin's avatar
Liangsheng Yin committed
29
from transformers import PretrainedConfig
30
31

from sglang.srt.distributed import (
32
    get_moe_expert_parallel_world_size,
33
    get_pp_group,
Liangsheng Yin's avatar
Liangsheng Yin committed
34
    get_tensor_model_parallel_world_size,
35
    parallel_state,
Liangsheng Yin's avatar
Liangsheng Yin committed
36
37
    tensor_model_parallel_all_reduce,
)
38
39
40
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
fzyzcjy's avatar
fzyzcjy committed
41
42
43
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
44
from sglang.srt.layers.activation import SiluAndMul
45
from sglang.srt.layers.amx_utils import PackWeightMethod
46
47
48
49
50
from sglang.srt.layers.communicator import (
    LayerCommunicator,
    LayerScatterModes,
    enable_moe_dense_fully_dp,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
51
52
53
from sglang.srt.layers.dp_attention import (
    get_attention_tp_rank,
    get_attention_tp_size,
54
    is_dp_attention_enabled,
Lianmin Zheng's avatar
Lianmin Zheng committed
55
)
56
from sglang.srt.layers.layernorm import RMSNorm
57
58
59
60
61
62
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
63
from sglang.srt.layers.logits_processor import LogitsProcessor
64
65
66
67
68
from sglang.srt.layers.moe import (
    get_deepep_mode,
    get_moe_a2a_backend,
    should_use_flashinfer_cutlass_moe_fp4_allgather,
)
69
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
70
71
72
73
from sglang.srt.layers.moe.fused_moe_triton.layer import (
    FusedMoE,
    _is_fp4_quantization_enabled,
)
74
from sglang.srt.layers.moe.topk import TopK
75
from sglang.srt.layers.quantization import deep_gemm_wrapper
76
from sglang.srt.layers.quantization.base_config import QuantizationConfig
77
from sglang.srt.layers.quantization.fp8_kernel import (
78
    is_fp8_fnuz,
79
    per_tensor_quant_mla_fp8,
80
    per_token_group_quant_mla_deep_gemm_masked_fp8,
81
)
HandH1998's avatar
HandH1998 committed
82
from sglang.srt.layers.quantization.fp8_utils import (
83
    block_quant_dequant,
HandH1998's avatar
HandH1998 committed
84
    block_quant_to_tensor_quant,
85
    channel_quant_to_tensor_quant,
86
    normalize_e4m3fn_to_e4m3fnuz,
87
    requant_weight_ue8m0_inplace,
HandH1998's avatar
HandH1998 committed
88
)
89
90
91
from sglang.srt.layers.quantization.int8_utils import (
    block_dequant as int8_block_dequant,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
92
from sglang.srt.layers.radix_attention import RadixAttention
93
94
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
95
96
97
98
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
99
from sglang.srt.managers.schedule_batch import global_server_args_dict
100
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
101
from sglang.srt.model_loader.weight_utils import default_weight_loader
102
103
104
105
from sglang.srt.two_batch_overlap import (
    MaybeTboDeepEPDispatcher,
    model_forward_maybe_tbo,
)
106
107
from sglang.srt.utils import (
    BumpAllocator,
108
    LazyValue,
109
    add_prefix,
110
    bind_or_assign,
111
    cpu_has_amx_support,
112
    get_bool_env_var,
113
    get_device_sm,
114
    get_int_env_var,
115
    is_cpu,
116
    is_cuda,
117
    is_flashinfer_available,
118
    is_gfx95_supported,
119
    is_hip,
120
    is_non_idle_and_non_empty,
121
    is_npu,
122
    is_sm100_supported,
123
    log_info_on_rank0,
124
    make_layers,
125
    use_intel_amx_backend,
126
)
127

128
# Platform / backend capability flags. Each helper is called once here, at
# import time, and the cached result is reused by the classes below.
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
# The SGLANG_USE_AITER environment variable only takes effect on HIP.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_device_sm = get_device_sm()
_is_gfx95_supported = is_gfx95_supported()

_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported

# These names exist only when AITER is enabled on a gfx95-capable device;
# code that uses them must be guarded by `_use_aiter_gfx95`.
if _use_aiter_gfx95:
    from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
    from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
        batched_gemm_afp4wfp4_pre_quant,
        fused_flatten_mxfp4_quant,
        fused_rms_mxfp4_quant,
    )
    from sglang.srt.layers.rocm_linear_utils import (
        aiter_dsv3_router_gemm,
        fused_qk_rope_cat,
        get_dsv3_gemm_output_zero_allocator_size,
    )

# Select the `awq_dequantize` implementation (and CUDA-only kernels) for the
# current platform. The first matching branch wins.
if _is_cuda:
    from sgl_kernel import (
        awq_dequantize,
        bmm_fp8,
        dsv3_fused_a_gemm,
        dsv3_router_gemm,
        merge_state_v2,
    )
elif _is_cpu and _is_cpu_amx_available:
    # Nothing is imported on CPU with AMX, so `awq_dequantize` is left
    # unbound in this configuration.
    pass
elif _is_hip:
    from sglang.srt.layers.quantization.awq_triton import (
        awq_dequantize_triton as awq_dequantize,
    )
else:
    from vllm._custom_ops import awq_dequantize

# Checked independently of the chain above, so it applies whenever HIP is
# detected.
if _is_hip:
    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
        decode_attention_fwd_grouped_rope,
    )

_is_flashinfer_available = is_flashinfer_available()
# `is_sm100_supported()` is only consulted when `is_cuda()` is truthy.
_is_sm100_supported = is_cuda() and is_sm100_supported()

178

179
180
logger = logging.getLogger(__name__)

Liangsheng Yin's avatar
Liangsheng Yin committed
181

182
183
184
185
186
187
188
189
190
191
192
class AttnForwardMethod(IntEnum):
    """Enumerates the attention forward strategies.

    Values come from ``auto()`` and therefore follow declaration order
    (starting at 1).
    """

    # Plain multi-head attention.
    MHA = auto()

    # Multi-latent attention in its absorbed form.
    MLA = auto()

    # Multi-head attention with the KV cache chunked; this avoids OOM
    # when prefix lengths are long.
    MHA_CHUNKED_KV = auto()

    # MLA combined with fused RoPE.
    MLA_FUSED_ROPE = auto()

    # MLA using the fused RoPE kernel for CPU.
    MLA_FUSED_ROPE_CPU = auto()

199

Liangsheng Yin's avatar
Liangsheng Yin committed
200
201
202
203
204
205
206
207
class DeepseekV2MLP(nn.Module):
    """Feed-forward block: merged gate/up projection, ``SiluAndMul``, down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
    ) -> None:
        """Build the two projections and the activation.

        Args:
            hidden_size: Input width of ``gate_up_proj`` and output width of
                ``down_proj``.
            intermediate_size: Width of each of the two merged gate/up outputs
                and input width of ``down_proj``.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            quant_config: Forwarded to both linear layers.
            reduce_results: Forwarded to ``down_proj``.
            prefix: Parameter-name prefix, combined via ``add_prefix``.
            tp_rank: Forwarded to both linear layers.
            tp_size: Forwarded to both linear layers and kept on ``self``.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"`` (raised after the
                projections have been constructed).
        """
        super().__init__()
        self.tp_size = tp_size

        # Arguments that both projections receive unchanged.
        shared_linear_kwargs = dict(
            bias=False,
            quant_config=quant_config,
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            prefix=add_prefix("gate_up_proj", prefix),
            **shared_linear_kwargs,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            reduce_results=reduce_results,
            prefix=add_prefix("down_proj", prefix),
            **shared_linear_kwargs,
        )

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(
        self,
        x,
        forward_batch=None,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
        gemm_output_zero_allocator: BumpAllocator = None,
    ):
        """Apply gate/up projection, activation and down projection to ``x``.

        ``forward_batch`` is accepted but not used here. When either
        ``should_allreduce_fusion`` or ``use_reduce_scatter`` is set,
        ``down_proj`` is asked to skip its all-reduce.
        """
        token_count = x.shape[0]

        # Nothing to compute for an empty batch when this MLP is unsharded.
        if self.tp_size == 1 and token_count == 0:
            return x

        wants_preallocated_output = (
            gemm_output_zero_allocator is not None
            and token_count <= 256
            and self.gate_up_proj.weight.dtype == torch.uint8
        )
        if wants_preallocated_output:
            out_width = self.gate_up_proj.output_size_per_partition
            out_buffer = gemm_output_zero_allocator.allocate(
                token_count * out_width
            ).view(token_count, out_width)
            # The projection is handed a 3-tuple whose last element is the
            # buffer carved out of the allocator.
            # NOTE(review): the meaning of the middle `None` slot is defined
            # by the linear layer's quant method - confirm there.
            x = (x, None, out_buffer)

        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        skip_all_reduce = should_allreduce_fusion or use_reduce_scatter
        output, _ = self.down_proj(activated, skip_all_reduce=skip_all_reduce)
        return output


Ke Bao's avatar
Ke Bao committed
270
class MoEGate(nn.Module):
    """Router that maps hidden states to one logit per routed expert."""

    def __init__(
        self,
        config,
        prefix: str = "",
        is_nextn: bool = False,
    ):
        """Allocate the (uninitialised) router weight and optional score bias.

        ``config`` must provide ``n_routed_experts``, ``hidden_size`` and
        ``topk_method``. ``prefix`` is accepted but not used in this
        constructor.
        """
        super().__init__()
        self.is_nextn = is_nextn

        expert_count = config.n_routed_experts
        self.weight = nn.Parameter(torch.empty((expert_count, config.hidden_size)))

        # The per-expert correction bias only exists for "noaux_tc".
        if config.topk_method != "noaux_tc":
            self.e_score_correction_bias = None
        else:
            self.e_score_correction_bias = nn.Parameter(
                torch.empty(expert_count, dtype=torch.float32)
            )

        if _is_cpu and _is_cpu_amx_available:
            self.quant_method = PackWeightMethod(weight_names=["weight"])

    def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
        """Return router logits for ``hidden_states``.

        The kernel is chosen by the first matching case: Intel AMX packed
        linear, the CUDA ``dsv3_router_gemm`` fast path, the AITER gfx95
        router GEMM, and finally ``F.linear``.
        """
        if use_intel_amx_backend(self):
            return torch.ops.sgl_kernel.weight_packed_linear(
                hidden_states,
                self.weight,
                None,  # bias
                True,  # is_vnni
            )

        token_count = hidden_states.shape[0]

        # NOTE: For some unknown reason, router_gemm seems to degrade accept length.
        cuda_fast_path = (
            _is_cuda
            and token_count <= 16
            and hidden_states.shape[1] == 7168
            and self.weight.shape[0] == 256
            and _device_sm >= 90
        )
        if cuda_fast_path:
            # The router GEMM is asked for float32 output.
            return dsv3_router_gemm(
                hidden_states, self.weight, out_dtype=torch.float32
            )

        if _use_aiter_gfx95 and token_count <= 256:
            return aiter_dsv3_router_gemm(
                hidden_states, self.weight, gemm_output_zero_allocator
            )

        return F.linear(hidden_states, self.weight, None)


Liangsheng Yin's avatar
Liangsheng Yin committed
322
323
324
325
326
class DeepseekV2MoE(nn.Module):
    """Mixture-of-experts block: router gate, top-k selection, routed experts
    and (when they are not fused into the routed experts) shared experts.

    ``forward`` dispatches to one of three execution paths
    (``forward_normal_dual_stream``, ``forward_normal``, ``forward_deepep``);
    ``forward_normal`` may further hand off to ``forward_cpu``. The ``op_*``
    methods expose the same computation as individual steps that read and
    write fields on a shared ``state`` object.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
        is_nextn: bool = False,
    ):
        """Build the gate, routed experts, top-k selector and shared experts.

        Args:
            config: Model config; the attributes read here include
                ``n_routed_experts``, ``n_shared_experts``,
                ``num_experts_per_tok``, ``moe_intermediate_size``,
                ``routed_scaling_factor``, ``norm_topk_prob``, ``n_group``,
                ``topk_group``, ``hidden_size``, ``hidden_act`` and
                ``torch_dtype``.
            layer_id: Index of this layer, forwarded to the experts and used
                for expert-location dispatch / distribution recording.
            quant_config: Quantization config forwarded to sub-modules. When
                it is ``None`` the top-k selector is built with
                ``force_topk=True``.
            prefix: Parameter-name prefix, combined via ``add_prefix``.
            alt_stream: Optional secondary CUDA stream used by
                ``forward_normal_dual_stream``.
            is_nextn: Forwarded to ``MoEGate``.

        Raises:
            ValueError: If the tensor-parallel size exceeds the number of
                routed experts, or ``config.hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        # Number of shared experts folded into the routed-expert module;
        # zero when shared-expert fusion is disabled by server args.
        self.num_fused_shared_experts = (
            0
            if global_server_args_dict["disable_shared_experts_fusion"]
            else config.n_shared_experts
        )
        self.config = config
        self.layer_id = layer_id
        self.alt_stream = alt_stream

        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}."
            )

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = MoEGate(
            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
        )

        self.experts = get_moe_impl_class(quant_config)(
            num_experts=config.n_routed_experts
            + self.num_fused_shared_experts
            + global_server_args_dict["ep_num_redundant_experts"],
            num_fused_shared_experts=self.num_fused_shared_experts,
            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            layer_id=self.layer_id,
            quant_config=quant_config,
            routed_scaling_factor=self.routed_scaling_factor,
            prefix=add_prefix("experts", prefix),
        )

        # MoEGate leaves `e_score_correction_bias` as None unless
        # `topk_method == "noaux_tc"`, so the FP4 cast must tolerate None.
        correction_bias = self.gate.e_score_correction_bias
        if _is_fp4_quantization_enabled() and correction_bias is not None:
            correction_bias = correction_bias.to(torch.bfloat16)

        self.topk = TopK(
            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
            renormalize=config.norm_topk_prob,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            num_fused_shared_experts=self.num_fused_shared_experts,
            topk_group=config.topk_group,
            correction_bias=correction_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
            force_topk=quant_config is None,
        )

        # Defaults used when no stand-alone shared-expert MLP is created.
        self.shared_experts_is_int8 = False
        self.shared_experts_is_fp8 = False
        self.shared_experts_weight_block_size = None
        if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            # disable tp for shared experts when enable deepep moe, or with fp4 allgather
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=add_prefix("shared_experts", prefix),
                **(
                    dict(tp_rank=0, tp_size=1)
                    if get_moe_a2a_backend().is_deepep()
                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
                    else {}
                ),
            )
            # For these quantization schemes the weight dtype is not inspected
            # below; the int8/fp8 flags simply stay False.
            is_packed_weight = hasattr(
                self.shared_experts.gate_up_proj.quant_method, "quant_config"
            ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in {
                "awq",
                "awq_marlin",
                "moe_wna16",
            }
            self.shared_experts_is_int8 = (
                not is_packed_weight
                and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
            )
            self.shared_experts_is_fp8 = (
                not is_packed_weight
                and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
            )
            if self.shared_experts_is_fp8:
                # Both projections must share one block size because only a
                # single value is stored and later passed to the CPU kernel.
                assert (
                    self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
                    == self.shared_experts.down_proj.quant_method.quant_config.weight_block_size
                )
                self.shared_experts_weight_block_size = (
                    self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
                )

        self.top_k = config.num_experts_per_tok

        # NOTE: `ep_size`, `num_experts`, `renormalize`, `topk_group`,
        # `num_expert_group`, `correction_bias` and `deepep_dispatcher` are
        # only defined on the instance when the DeepEP backend is active.
        if get_moe_a2a_backend().is_deepep():
            # TODO: we will support tp < ep in the future
            self.ep_size = get_moe_expert_parallel_world_size()
            self.num_experts = (
                config.n_routed_experts
                + global_server_args_dict["ep_num_redundant_experts"]
            )
            self.renormalize = config.norm_topk_prob
            self.topk_group = config.topk_group
            self.num_expert_group = config.n_group
            self.correction_bias = (
                self.gate.e_score_correction_bias.data
                if self.gate.e_score_correction_bias is not None
                else None
            )

            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
                group=parallel_state.get_tp_group().device_group,
                router_topk=self.top_k,
                permute_fusion=True,
                num_experts=self.num_experts,
                num_local_experts=config.n_routed_experts // self.tp_size,
                hidden_size=config.hidden_size,
                params_dtype=config.torch_dtype,
                deepep_mode=get_deepep_mode(),
                async_finish=True,
                return_recv_hook=True,
            )

        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()

    def get_moe_weights(self):
        """Return the ``.data`` tensors of all routed-expert parameters except
        the one named ``"correction_bias"``."""
        return [
            x.data
            for name, x in self.experts.named_parameters()
            if name not in ["correction_bias"]
        ]

    def forward(
        self,
        hidden_states: torch.Tensor,
        forward_batch: Optional[ForwardBatch] = None,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
        gemm_output_zero_allocator: Optional[BumpAllocator] = None,
    ) -> torch.Tensor:
        """Run the MoE block, choosing the execution path.

        With DeepEP enabled, ``forward_deepep`` is used (and only
        ``hidden_states`` / ``forward_batch`` are forwarded). Otherwise the
        dual-stream path is taken when an alternate stream exists, no shared
        experts are fused and the batch holds between 1 and 1024 tokens; all
        remaining cases go through ``forward_normal``.
        """
        if self._enable_deepep_moe:
            return self.forward_deepep(hidden_states, forward_batch)

        DUAL_STREAM_TOKEN_THRESHOLD = 1024
        if (
            self.alt_stream is not None
            and self.num_fused_shared_experts == 0
            and hidden_states.shape[0] > 0
            and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
        ):
            return self.forward_normal_dual_stream(
                hidden_states,
                should_allreduce_fusion,
                use_reduce_scatter,
                gemm_output_zero_allocator,
            )

        return self.forward_normal(
            hidden_states,
            should_allreduce_fusion,
            use_reduce_scatter,
            gemm_output_zero_allocator,
        )

    def forward_normal_dual_stream(
        self,
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
        gemm_output_zero_allocator: Optional[BumpAllocator] = None,
    ) -> torch.Tensor:
        """Run shared experts on the current stream while gate, top-k and
        routed experts run on ``self.alt_stream``, then sum both results.

        The caller (``forward``) only selects this path when ``alt_stream`` is
        set and no shared experts are fused, so ``shared_output`` is expected
        to be a tensor here.
        """
        current_stream = torch.cuda.current_stream()
        # The alternate stream must not start before work already queued on
        # the current stream (which produced `hidden_states`) has been ordered.
        self.alt_stream.wait_stream(current_stream)
        shared_output = self._forward_shared_experts(
            hidden_states, gemm_output_zero_allocator
        )

        with torch.cuda.stream(self.alt_stream):
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
            topk_output = self.topk(hidden_states, router_logits)
            final_hidden_states = self.experts(hidden_states, topk_output)
            # NOTE(review): `forward_normal` additionally checks
            # `not _use_aiter` before rescaling; this path does not. Confirm
            # whether the difference is intentional.
            if not _is_cuda:
                final_hidden_states *= self.routed_scaling_factor

        # Join: the current stream waits for the routed-expert result.
        current_stream.wait_stream(self.alt_stream)
        # Allocate the sum's destination inside the symmetric-memory context,
        # then tag the filled tensor.
        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
            final_hidden_states_out = torch.empty_like(final_hidden_states)

        torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
        final_hidden_states = final_hidden_states_out
        sm.tag(final_hidden_states)
        # All-reduce here only if no later stage is going to perform the
        # reduction in another form.
        if (
            self.tp_size > 1
            and not should_allreduce_fusion
            and not use_reduce_scatter
            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
        ):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def forward_normal(
        self,
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
        gemm_output_zero_allocator: Optional[BumpAllocator] = None,
    ) -> torch.Tensor:
        """Single-stream path: shared experts, gate, top-k, routed experts,
        sum, and an optional tensor-parallel all-reduce.

        Delegates to ``forward_cpu`` when stand-alone shared experts exist and
        use the Intel AMX backend. For an empty batch the gate and shared
        experts are skipped and an empty top-k output is fed to the experts.
        """
        if hasattr(self, "shared_experts") and use_intel_amx_backend(
            self.shared_experts.gate_up_proj
        ):
            # NOTE(review): `use_reduce_scatter` is not forwarded to the CPU
            # path - confirm this is intended.
            return self.forward_cpu(hidden_states, should_allreduce_fusion)

        if hidden_states.shape[0] > 0:
            shared_output = self._forward_shared_experts(
                hidden_states, gemm_output_zero_allocator
            )
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
            topk_output = self.topk(hidden_states, router_logits)
        else:
            shared_output = None
            topk_output = self.topk.empty_topk_output(hidden_states.device)

        final_hidden_states = self.experts(hidden_states, topk_output)
        if not _is_cuda and not _use_aiter:
            # fused in biased_grouped_topk so we can skip here
            final_hidden_states *= self.routed_scaling_factor
        if shared_output is not None:
            # Allocate the sum's destination inside the symmetric-memory
            # context, then tag the filled tensor.
            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
                final_hidden_states_out = torch.empty_like(final_hidden_states)
            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
            final_hidden_states = final_hidden_states_out
            sm.tag(final_hidden_states)
        if (
            self.tp_size > 1
            and not should_allreduce_fusion
            and not use_reduce_scatter
            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
        ):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def forward_cpu(
        self,
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
    ) -> torch.Tensor:
        """CPU (Intel AMX) path: routed experts followed by the fused
        ``shared_expert_cpu`` kernel, which also receives the routed output
        and the routed scaling factor.

        Requires ``self.shared_experts`` to exist; both of its projections
        must agree on whether they use the AMX backend.
        """
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        topk_output = self.topk(hidden_states, router_logits)
        fused_experts_out = self.experts(
            hidden_states=hidden_states, topk_output=topk_output
        )

        assert use_intel_amx_backend(
            self.shared_experts.gate_up_proj
        ) == use_intel_amx_backend(self.shared_experts.down_proj)
        # [Note] inplace should be False in fused_experts.
        # If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts
        # While hidden_states is still needed in shared_expert.
        final_hidden_states = torch.ops.sgl_kernel.shared_expert_cpu(
            hidden_states,
            self.shared_experts.gate_up_proj.weight,
            self.shared_experts.down_proj.weight,
            fused_experts_out,
            self.routed_scaling_factor,
            True,  # inplace
            self.shared_experts_is_int8,  # use_int8_w8a8
            self.shared_experts_is_fp8,  # use_fp8_w8a16
            (
                self.shared_experts.gate_up_proj.weight_scale
                if self.shared_experts_is_int8
                else (
                    self.shared_experts.gate_up_proj.weight_scale_inv
                    if self.shared_experts_is_fp8
                    else None
                )
            ),  # w1_scale
            (
                self.shared_experts.down_proj.weight_scale
                if self.shared_experts_is_int8
                else (
                    self.shared_experts.down_proj.weight_scale_inv
                    if self.shared_experts_is_fp8
                    else None
                )
            ),  # w2_scale
            (
                self.shared_experts_weight_block_size
                if self.shared_experts_is_fp8
                else None
            ),  # block_size
            None,  # a1_scale
            None,  # a2_scale
            True,  # is_vnni
        )
        if self.tp_size > 1 and not should_allreduce_fusion:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """DeepEP path: the experts module receives top-k indices/weights and
        the forward batch; the routed output is scaled by
        ``routed_scaling_factor`` and added to the shared-expert output when
        one exists.
        """
        shared_output = None
        if hidden_states.shape[0] > 0:
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            shared_output = self._forward_shared_experts(hidden_states)
            topk_weights, topk_idx, _ = self.topk(
                hidden_states,
                router_logits,
                num_token_non_padded=forward_batch.num_token_non_padded,
                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                    layer_id=self.layer_id,
                ),
            )
        else:
            # Empty batch: still call the experts, but with an empty selection.
            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
                hidden_states.device
            )

        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            forward_batch=forward_batch,
        )

        if shared_output is not None:
            # In-place: shared_output += routed_scaling_factor * routed_output.
            x = shared_output
            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
            final_hidden_states = x
        else:
            final_hidden_states *= self.routed_scaling_factor

        return final_hidden_states

    def _forward_shared_experts(
        self, hidden_states, gemm_output_zero_allocator: Optional[BumpAllocator] = None
    ):
        """Run the stand-alone shared experts, or return ``None`` when the
        shared experts are fused into the routed experts."""
        if self.num_fused_shared_experts == 0:
            return self.shared_experts(
                hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
            )
        else:
            return None

    def op_gate(self, state):
        """Step: write ``state.router_logits`` (``None`` when
        ``is_non_idle_and_non_empty`` rejects the batch)."""
        if is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, state.hidden_states_mlp_input
        ):
            # router_logits: (num_tokens, n_experts)
            state.router_logits = self.gate(state.hidden_states_mlp_input)
        else:
            state.router_logits = None

    def op_shared_experts(self, state):
        """Step: pop ``hidden_states_mlp_input`` from ``state`` and write
        ``state.shared_output`` (``None`` when shared experts are fused or the
        batch is rejected by ``is_non_idle_and_non_empty``)."""
        hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
        if (self.num_fused_shared_experts == 0) and is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, hidden_states_mlp_input
        ):
            state.shared_output = self.shared_experts(hidden_states_mlp_input)
        else:
            state.shared_output = None

    def op_select_experts(self, state):
        """Step: pop ``router_logits`` and write ``state.topk_weights_local``
        and ``state.topk_idx_local``.

        When no logits were produced, zero-row placeholders of width
        ``self.top_k`` are written instead.
        """
        router_logits = state.pop("router_logits")
        hidden_states = state.hidden_states_mlp_input

        if router_logits is not None:
            with get_global_expert_distribution_recorder().with_current_layer(
                self.layer_id
            ):
                state.topk_weights_local, state.topk_idx_local, _ = self.topk(
                    hidden_states=hidden_states,
                    router_logits=router_logits,
                    num_token_non_padded=state.forward_batch.num_token_non_padded,
                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                        layer_id=self.layer_id,
                    ),
                )
        else:
            state.topk_idx_local = torch.full(
                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
            )
            state.topk_weights_local = torch.empty(
                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
            )

    def op_dispatch_a(self, state):
        """Step: start the dispatch on the experts' DeepEP dispatcher,
        consuming the local top-k tensors from ``state``.

        NOTE(review): this uses ``self.experts.deepep_dispatcher`` while
        ``__init__`` creates ``self.deepep_dispatcher`` - confirm which
        dispatcher is intended.
        """
        if self.ep_size > 1:
            self.experts.deepep_dispatcher.dispatch_a(
                hidden_states=state.hidden_states_mlp_input,
                topk_idx=state.pop("topk_idx_local"),
                topk_weights=state.pop("topk_weights_local"),
                forward_batch=state.forward_batch,
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )

    def op_dispatch_b(self, state):
        """Step: finish the dispatch and store its result in
        ``state.dispatch_output``."""
        if self.ep_size > 1:
            with get_global_expert_distribution_recorder().with_current_layer(
                self.layer_id
            ):
                state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
                )

    def op_experts(self, state):
        """Step: run the expert computation on ``state.dispatch_output`` and
        store it in ``state.hidden_states_experts_output``."""
        state.hidden_states_experts_output = self.experts.moe_impl(
            dispatch_output=state.dispatch_output,
        )

    def op_combine_a(self, state):
        """Step: start the combine, consuming the expert output and then
        dropping ``dispatch_output`` from ``state``."""
        if self.ep_size > 1:
            self.experts.deepep_dispatcher.combine_a(
                hidden_states=state.pop("hidden_states_experts_output"),
                topk_idx=state.dispatch_output.topk_idx,
                topk_weights=state.dispatch_output.topk_weights,
                forward_batch=state.forward_batch,
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )
            state.pop("dispatch_output")

    def op_combine_b(self, state):
        """Step: finish the combine and store the result in
        ``state.hidden_states_after_combine``."""
        if self.ep_size > 1:
            state.hidden_states_after_combine = (
                self.experts.deepep_dispatcher.combine_b(
                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
                )
            )

    def op_output(self, state):
        """Step: scale the combined routed output by ``routed_scaling_factor``,
        add it to the shared-expert output when present, and write
        ``state.hidden_states_mlp_output``."""
        final_hidden_states = state.pop("hidden_states_after_combine")

        if (shared_output := state.pop("shared_output")) is not None:
            # In-place: shared_output += routed_scaling_factor * routed_output.
            x = shared_output
            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
            final_hidden_states = x
        else:
            final_hidden_states *= self.routed_scaling_factor

        state.hidden_states_mlp_output = final_hidden_states
789

Liangsheng Yin's avatar
Liangsheng Yin committed
790
791
792
793
794
795
796
797
798

def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN magnitude-scaling factor for a RoPE scaling ratio.

    Args:
        scale: RoPE scaling ratio. Ratios of 1 or less need no correction.
        mscale: Multiplier applied to the logarithmic correction term.

    Returns:
        ``1.0`` when ``scale <= 1``, otherwise
        ``0.1 * mscale * ln(scale) + 1.0``.
    """
    from math import log

    return 1.0 if scale <= 1 else 0.1 * mscale * log(scale) + 1.0


799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
class DeepseekV2AttentionMLA(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        layer_id: Optional[int] = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """Build the projections, rotary embedding and attention ops of one MLA layer.

        Args:
            config: Model config; only ``rms_norm_eps`` is read here.
            hidden_size: Width of the incoming hidden states.
            num_heads: Total number of attention heads; must be divisible by
                the attention tensor-parallel size.
            qk_nope_head_dim: Per-head query/key width that does not get RoPE.
            qk_rope_head_dim: Per-head query/key width that gets RoPE.
            v_head_dim: Per-head value width.
            q_lora_rank: Rank of the low-rank query projection. When ``None``
                a direct ``q_proj`` plus a separate ``kv_a_proj_with_mqa`` are
                created instead of the fused ``fused_qkv_a_proj_with_mqa``.
            kv_lora_rank: Rank of the compressed key/value latent.
            rope_theta: RoPE base.
            rope_scaling: Optional RoPE scaling dict. NOTE: when truthy, its
                ``"rope_type"`` entry is overwritten in place with
                ``"deepseek_yarn"``, and ``"factor"`` / ``"mscale_all_dim"``
                are read to rescale ``self.scaling``.
            max_position_embeddings: Maximum position passed to the rotary
                embedding.
            quant_config: Optional quantization config forwarded to the linear
                and attention layers.
            reduce_results: Forwarded to the output projection ``o_proj``.
            layer_id: Index of this layer, forwarded to the attention ops.
            prefix: Parameter-name prefix for the sub-modules.
            alt_stream: Optional secondary CUDA stream stored on the module.
        """
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        assert (
            num_heads % attn_tp_size == 0
        ), f"num_heads ({num_heads}) must be divisible by attn_tp_size ({attn_tp_size})"
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # For tensor parallel attention
        if self.q_lora_rank is not None:
            # One replicated GEMM produces the low-rank query and the KV latent
            # (plus the RoPE part of the key) together.
            self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("fused_qkv_a_proj_with_mqa", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("kv_a_proj_with_mqa", prefix),
            )

        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

        if rope_scaling:
            # NOTE: this writes into the caller's dict (not a copy).
            rope_scaling["rope_type"] = "deepseek_yarn"

        self.rotary_emb = get_rope_wrapper(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
            device=global_server_args_dict["device"],
        )

        if rope_scaling:
            # Fold the YaRN magnitude correction (squared) into the softmax scale.
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale
        else:
            self.rotary_emb.forward = self.rotary_emb.forward_native

        # Attention over the compressed latent: a single KV head whose key is
        # (latent + rope) and whose value is the latent itself.
        self.attn_mqa = RadixAttention(
            self.num_local_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=1,
            layer_id=layer_id,
            v_head_dim=self.kv_lora_rank,
            quant_config=quant_config,
            prefix=add_prefix("attn_mqa", prefix),
        )

        # Regular multi-head attention over fully expanded keys/values.
        self.attn_mha = RadixAttention(
            self.num_local_heads,
            self.qk_nope_head_dim + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            v_head_dim=self.v_head_dim,
            quant_config=quant_config,
            prefix=add_prefix("attn_mha", prefix),
        )

        self.alt_stream = alt_stream
        # Filled in lazily with self.kv_b_proj by forward_prepare.
        self.attn_mha.kv_b_proj = None

        # Absorbed weights and their scales; initialised to placeholders here
        # and expected to be populated outside this constructor.
        self.w_kc = None
        self.w_vc = None
        self.w_scale = 1.0

        self.w_scale_k = None
        self.w_scale_v = None
        self.use_deep_gemm_bmm = False

        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
        self.disable_chunked_prefix_cache = global_server_args_dict[
            "disable_chunked_prefix_cache"
        ]

        self.current_attention_backend = (
            None  # Attention backend used by current forward batch
        )
        self.rocm_fused_decode_mla = get_bool_env_var(
            "SGLANG_ROCM_FUSED_DECODE_MLA", "false"
        )

        # TODO: Design a finer way to determine the threshold
        self.chunked_prefix_cache_threshold = get_int_env_var(
            "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
        )

        # If we have self.fused_qkv_a_proj_with_mqa and we're running on CPU, we will choose the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel
        # which requires self.w_kc and self.w_vc to be packed.
        # If not, we will use torch.bmm and weight shouldn't be packed in this case
        has_fused_proj = hasattr(self, "fused_qkv_a_proj_with_mqa")
        if has_fused_proj and _is_cpu and _is_cpu_amx_available:
            self.quant_method = PackWeightMethod(
                weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
            )

        # For these quantization schemes the `.weight` dtype/shape checks below
        # are skipped.
        is_packed_weight = (
            has_fused_proj
            and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
            and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
            in {"awq", "awq_marlin", "moe_wna16"}
        )
        # The min-latency fused-A GEMM is only enabled for a bf16 weight of
        # exactly 2112 x 7168 on CUDA devices with SM >= 90.
        self.use_min_latency_fused_a_gemm = (
            has_fused_proj
            and not is_packed_weight
            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
            and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
            and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
            and _is_cuda
            and _device_sm >= 90
        )

        self.qkv_proj_with_rope_is_int8 = (
            has_fused_proj
            and not is_packed_weight
            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
        )
        self.qkv_proj_with_rope_is_fp8 = (
            has_fused_proj
            and not is_packed_weight
            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
        )

        self.weight_block_size = None
        if self.qkv_proj_with_rope_is_fp8 and _is_cpu and _is_cpu_amx_available:
            fused_a_quant = self.fused_qkv_a_proj_with_mqa.quant_method
            q_b_quant = self.q_b_proj.quant_method

            # Both projections must agree on whether block quantization is used.
            use_block_quant = getattr(fused_a_quant, "block_quant", False)
            assert use_block_quant == getattr(
                q_b_quant, "block_quant", False
            ), "fused_qkv_a_proj_with_mqa and q_b_proj disagree on block_quant"

            if use_block_quant:
                assert (
                    fused_a_quant.quant_config.weight_block_size
                    == q_b_quant.quant_config.weight_block_size
                ), "fused_qkv_a_proj_with_mqa and q_b_proj disagree on weight_block_size"
                self.weight_block_size = fused_a_quant.quant_config.weight_block_size

1025
1026
1027
    def dispatch_attn_forward_method(
        self, forward_batch: ForwardBatch
    ) -> AttnForwardMethod:
        """Choose the attention forward path for ``forward_batch``.

        The attention backend is first resolved from the server args according
        to the forward mode (decode/idle, speculative verify/draft-extend, or
        prefill) and recorded in ``self.current_attention_backend``. The
        backend, the forward mode and the amount of cached prefix then select
        one of the ``AttnForwardMethod`` variants.

        A missing ``forward_batch.extend_prefix_lens_cpu`` (``None``) is
        treated as a total prefix length of 0 in every branch.

        Args:
            forward_batch: The batch about to be executed.

        Returns:
            The ``AttnForwardMethod`` member naming the path to use.
        """
        forward_mode = forward_batch.forward_mode

        def _dispatch_mla_subtype():
            # Pick the weight-absorbed MLA variant for the current platform.
            if _is_hip:
                if self.rocm_fused_decode_mla and forward_mode.is_decode():
                    return AttnForwardMethod.MLA_FUSED_ROPE
                return AttnForwardMethod.MLA
            if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
                self
            ):
                return AttnForwardMethod.MLA_FUSED_ROPE_CPU
            return AttnForwardMethod.MLA

        def _sum_extend_prefix_lens():
            # None-safe total of the cached prefix lengths in this batch.
            prefix_lens = forward_batch.extend_prefix_lens_cpu
            return sum(prefix_lens) if prefix_lens is not None else 0

        # Determine attention backend used by current forward batch
        if forward_mode.is_decode_or_idle():
            attention_backend = global_server_args_dict["decode_attention_backend"]
        elif forward_mode.is_target_verify() or forward_mode.is_draft_extend():
            # Use the specified backend for speculative operations (both verify and draft extend)
            if global_server_args_dict["speculative_attention_backend"] == "decode":
                attention_backend = global_server_args_dict["decode_attention_backend"]
            else:  # default to prefill
                attention_backend = global_server_args_dict["prefill_attention_backend"]
        else:
            attention_backend = global_server_args_dict["prefill_attention_backend"]
        self.current_attention_backend = attention_backend

        # An extend that is neither a speculative verify nor a draft extend.
        is_plain_extend = (
            forward_mode.is_extend()
            and not forward_mode.is_target_verify()
            and not forward_mode.is_draft_extend()
        )

        if attention_backend == "ascend":
            if is_plain_extend:
                return AttnForwardMethod.MHA
            return AttnForwardMethod.MLA

        if attention_backend in ("flashinfer", "fa3", "flashmla", "cutlass_mla"):
            # Use MHA with chunked KV cache when prefilling on long sequences.
            sum_extend_prefix_lens = _sum_extend_prefix_lens()
            # Flashinfer MLA: Do not absorb when enabling ragged prefill
            disable_ragged = (
                attention_backend in ("flashinfer", "flashmla")
                and self.flashinfer_mla_disable_ragged
            )
            long_prefix_chunking = (
                sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
                and not self.disable_chunked_prefix_cache
            )
            if (
                not disable_ragged
                and is_plain_extend
                and (long_prefix_chunking or sum_extend_prefix_lens == 0)
            ):
                return AttnForwardMethod.MHA_CHUNKED_KV
            return _dispatch_mla_subtype()

        if attention_backend == "trtllm_mla":
            if is_plain_extend:
                return AttnForwardMethod.MHA_CHUNKED_KV
            return _dispatch_mla_subtype()

        if attention_backend == "aiter":
            if not is_plain_extend:
                return AttnForwardMethod.MLA
            if is_dp_attention_enabled() and _sum_extend_prefix_lens() != 0:
                return AttnForwardMethod.MLA
            return AttnForwardMethod.MHA

        # Triton: Use normal computation for prefill and use weight absorption for extend/decode
        if is_plain_extend and _sum_extend_prefix_lens() == 0:
            return AttnForwardMethod.MHA
        return _dispatch_mla_subtype()
Lianmin Zheng's avatar
Lianmin Zheng committed
1137

1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
    def op_prepare(self, state):
        """Run ``forward_prepare`` on the inputs held in ``state``.

        ``hidden_states_after_comm_pre_attn`` is removed from ``state``; the
        prepared result is stored in ``state.attn_intermediate_state``.
        """
        prepare = self.forward_prepare
        prepare_kwargs = dict(
            positions=state.positions,
            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
            forward_batch=state.forward_batch,
            zero_allocator=state.zero_allocator,
        )
        state.attn_intermediate_state = prepare(**prepare_kwargs)

    def op_core(self, state):
        """Run ``forward_core`` on the prepared state and publish its output.

        ``attn_intermediate_state`` is removed from ``state``; the attention
        output is stored in ``state.hidden_states_after_attn``.
        """
        core = self.forward_core
        prepared = state.pop("attn_intermediate_state")
        state.hidden_states_after_attn = core(prepared)

1151
1152
1153
1154
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Run the attention layer: ``forward_prepare`` followed by ``forward_core``.

        All arguments are passed through unchanged to ``forward_prepare``; the
        return value is whatever ``forward_core`` produces for the prepared
        state.
        """
        prepared_state = self.forward_prepare(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )
        attn_out = self.forward_core(prepared_state)
        return attn_out

    def forward_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Run the pre-attention stage and package its result for ``forward_core``.

        Args:
            positions: Token positions, forwarded to the selected prepare path.
            hidden_states: Layer input. Either a tensor, or a tuple whose first
                element is the tensor (see the comment in the body).
            forward_batch: The batch being executed.
            zero_allocator: Allocator forwarded to the selected prepare path.

        Returns:
            A 4-tuple ``(hidden_states, attn_forward_method, forward_batch,
            inner_state)``. For an empty batch this is
            ``(tensor, None, forward_batch, None)`` so that ``forward_core``
            returns the tensor unchanged; otherwise the first element is
            ``None`` and ``inner_state`` holds the arguments for the matching
            ``*_core`` method.

        Raises:
            NotImplementedError: If the dispatched method has no prepare path.
        """
        if self.attn_mha.kv_b_proj is None:
            self.attn_mha.kv_b_proj = self.kv_b_proj

        # when hidden_states is a tuple of tensors, the tuple will include quantized weight and scale tensor
        if isinstance(hidden_states, tuple):
            token_tensor = hidden_states[0]
        else:
            token_tensor = hidden_states

        if token_tensor.shape[0] == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            # Always hand back the 4-tuple that forward_core unpacks. The tuple
            # input path used to return the bare tensor, which cannot be
            # unpacked into four values.
            return token_tensor, None, forward_batch, None

        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)

        if attn_forward_method == AttnForwardMethod.MHA:
            inner_state = self.forward_normal_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
            inner_state = self.forward_normal_chunked_kv_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA:
            inner_state = self.forward_absorb_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
            inner_state = self.forward_absorb_fused_mla_rope_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
            inner_state = self.forward_absorb_fused_mla_rope_cpu_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        else:
            raise NotImplementedError
        return None, attn_forward_method, forward_batch, inner_state
1215

1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
    def forward_core(self, intermediate_state):
        """Execute the attention path selected by ``forward_prepare``.

        Args:
            intermediate_state: The 4-tuple ``(hidden_states,
                attn_forward_method, forward_batch, inner_state)``.

        Returns:
            ``hidden_states`` unchanged when ``inner_state`` is ``None``
            (nothing was prepared); otherwise the result of the ``*_core``
            method matching ``attn_forward_method``.

        Raises:
            NotImplementedError: If ``attn_forward_method`` is not recognised.
        """
        passthrough, method, _forward_batch, prepared = intermediate_state

        # Nothing was prepared: hand the input straight back.
        if prepared is None:
            return passthrough

        if method == AttnForwardMethod.MHA:
            return self.forward_normal_core(*prepared)
        if method == AttnForwardMethod.MHA_CHUNKED_KV:
            return self.forward_normal_chunked_kv_core(*prepared)
        if method == AttnForwardMethod.MLA:
            return self.forward_absorb_core(*prepared)
        if method == AttnForwardMethod.MLA_FUSED_ROPE:
            return self.forward_absorb_fused_mla_rope_core(*prepared)
        if method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
            return self.forward_absorb_fused_mla_rope_cpu_core(*prepared)
        raise NotImplementedError

    def forward_normal_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Build full per-head q/k/v for the MHA path and store the latent KV.

        The query and the compressed KV latent are projected from
        ``hidden_states``, the latent is expanded through ``kv_b_proj`` into
        per-head keys and values, RoPE is applied to the rope slices, and the
        latent is written to the KV pool. ``zero_allocator`` is accepted for
        signature parity with the other prepare methods and is not used here.

        Returns:
            ``(q, k, v, forward_batch)`` — the arguments of
            ``forward_normal_core``.
        """
        nope_dim = self.qk_nope_head_dim
        rope_dim = self.qk_rope_head_dim
        lora_rank = self.kv_lora_rank

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        else:
            fused_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
            q, latent_cache = fused_out.split(
                [self.q_lora_rank, lora_rank + rope_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)

        _, q_pe = q.split([nope_dim, rope_dim], dim=-1)
        kv_a, _ = latent_cache.split([lora_rank, rope_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a)

        # Expand the normalised latent into per-head no-rope keys and values.
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, nope_dim + self.v_head_dim)
        k_nope = kv[..., :nope_dim]
        v = kv[..., nope_dim:]

        k_pe = latent_cache[:, :, lora_rank:]
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., nope_dim:] = q_pe

        # Assemble the full key; the single k_pe head is broadcast across heads.
        k = torch.empty_like(q)
        k[..., :nope_dim] = k_nope
        k[..., nope_dim:] = k_pe

        kv_pool = forward_batch.token_to_kv_pool
        if _is_npu:
            # To reduce a time-costing split operation
            kv_pool.set_kv_buffer(
                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
            )
        else:
            latent_cache[:, :, :lora_rank] = kv_a.unsqueeze(1)
            latent_cache[:, :, lora_rank:] = k_pe

            # Save latent cache
            kv_pool.set_kv_buffer(
                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
            )

        return q, k, v, forward_batch

    def forward_normal_core(self, q, k, v, forward_batch):
        """Run MHA attention and the output projection.

        The KV cache is not written here (``save_kv_cache=False``). The
        per-head attention output is flattened to
        ``num_local_heads * v_head_dim`` columns before ``o_proj``.

        Returns:
            The first element of the ``o_proj`` result.
        """
        per_head_out = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        flat_width = self.num_local_heads * self.v_head_dim
        projected, _unused = self.o_proj(per_head_out.reshape(-1, flat_width))
        return projected

Faraz's avatar
Faraz committed
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
    def _fuse_rope_for_trtllm_mla(self, forward_batch: ForwardBatch) -> bool:
        """
        Tell whether RoPE should be skipped here and left to the fused
        rope+quantize of the TRTLLM MLA decode path (fp8_e4m3).

        True only when the current backend is ``"trtllm_mla"``, the batch is in
        decode-or-idle mode, and the attention backend's data type is
        ``torch.float8_e4m3fn``.
        """
        if self.current_attention_backend != "trtllm_mla":
            return False
        if not forward_batch.forward_mode.is_decode_or_idle():
            return False
        return forward_batch.attn_backend.data_type == torch.float8_e4m3fn

1302
    def forward_absorb_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare the inputs of the weight-absorbed MLA path.

        Projects the query and the compressed KV latent, normalises them,
        multiplies the no-rope part of the query by ``self.w_kc`` (using the
        kernel that matches the platform / weight dtype), and applies RoPE
        unless it is deferred to a fused kernel later on.

        Returns:
            ``(q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator,
            positions)`` — the arguments of ``forward_absorb_core``.
        """
        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

        lora_rank = self.kv_lora_rank

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
            k_nope = latent_cache[..., :lora_rank]
            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
        else:
            # The dedicated fused-A GEMM is only taken for plain tensors with
            # at most 16 rows, and only when it was enabled in __init__.
            take_min_latency_gemm = (
                (not isinstance(hidden_states, tuple))
                and hidden_states.shape[0] <= 16
                and self.use_min_latency_fused_a_gemm
            )
            if take_min_latency_gemm:
                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                )
            else:
                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]

            q, latent_cache = fused_qkv_a_proj_out.split(
                [self.q_lora_rank, lora_rank + self.qk_rope_head_dim], dim=-1
            )
            k_nope = latent_cache[..., :lora_rank]

            # overlap qk norm
            if self.alt_stream is not None and get_is_capture_mode():
                # q is normalised on the current stream while k_nope is
                # normalised on alt_stream; each stream waits on the other at
                # the boundaries.
                current_stream = torch.cuda.current_stream()
                self.alt_stream.wait_stream(current_stream)
                q = self.q_a_layernorm(q)
                with torch.cuda.stream(self.alt_stream):
                    k_nope = self.kv_a_layernorm(k_nope)
                current_stream.wait_stream(self.alt_stream)
            elif _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
                q, k_nope = fused_rms_mxfp4_quant(
                    q,
                    self.q_a_layernorm.weight,
                    self.q_a_layernorm.variance_epsilon,
                    k_nope,
                    self.kv_a_layernorm.weight,
                    self.kv_a_layernorm.variance_epsilon,
                )
            else:
                q = self.q_a_layernorm(q)
                k_nope = self.kv_a_layernorm(k_nope)

            k_nope = k_nope.unsqueeze(1)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        k_pe = latent_cache[..., lora_rank:].unsqueeze(1)

        # Absorb w_kc into the no-rope query: one batched matmul per head.
        if self.use_deep_gemm_bmm:
            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
            )
            q_nope_out = q_nope.new_empty(
                (self.num_local_heads, aligned_m, lora_rank)
            )
            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                (q_nope_val, q_nope_scale),
                (self.w_kc, self.w_scale_k),
                q_nope_out,
                masked_m,
                expected_m,
            )
            # Keep only the first expected_m rows of the aligned_m-row buffer.
            q_nope_out = q_nope_out[:, :expected_m, :]
        elif _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
                x = q_nope.transpose(0, 1)
                q_nope_out = torch.empty(
                    x.shape[0],
                    x.shape[1],
                    self.w_kc.shape[2],
                    device=x.device,
                    dtype=torch.bfloat16,
                )
                batched_gemm_afp4wfp4_pre_quant(
                    x,
                    self.w_kc.transpose(-2, -1),
                    self.w_scale_k.transpose(-2, -1),
                    torch.bfloat16,
                    q_nope_out,
                )
            else:
                q_nope_out = torch.bmm(
                    q_nope.to(torch.bfloat16).transpose(0, 1),
                    self.w_kc.to(torch.bfloat16) * self.w_scale,
                )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)

        q_nope_out = q_nope_out.transpose(0, 1)

        # RoPE is skipped here when the TRTLLM fused path will apply it, or
        # when both the aiter and gfx95 flags are set.
        rope_deferred = self._fuse_rope_for_trtllm_mla(forward_batch) or (
            _use_aiter and _is_gfx95_supported
        )
        if not rope_deferred:
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
1419
1420

    def forward_absorb_core(
1421
        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
1422
    ):
1423
        if (
1424
1425
1426
            self.current_attention_backend == "fa3"
            or self.current_attention_backend == "flashinfer"
            or self.current_attention_backend == "cutlass_mla"
1427
            or self.current_attention_backend == "trtllm_mla"
1428
            or self.current_attention_backend == "ascend"
1429
        ):
Faraz's avatar
Faraz committed
1430
1431
1432
1433
1434
1435
            extra_args = {}
            if self._fuse_rope_for_trtllm_mla(forward_batch):
                extra_args = {
                    "cos_sin_cache": self.rotary_emb.cos_sin_cache,
                    "is_neox": self.rotary_emb.is_neox_style,
                }
1436
            attn_output = self.attn_mqa(
Faraz's avatar
Faraz committed
1437
1438
1439
1440
1441
1442
1443
                q_nope_out,
                k_nope,
                k_nope,
                forward_batch,
                q_rope=q_pe,
                k_rope=k_pe,
                **extra_args,
1444
1445
            )
        else:
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
            if _use_aiter_gfx95:
                cos = self.rotary_emb.cos_cache
                sin = self.rotary_emb.sin_cache
                q, k = fused_qk_rope_cat(
                    q_nope_out,
                    q_pe,
                    k_nope,
                    k_pe,
                    positions,
                    cos,
                    sin,
                    self.rotary_emb.is_neox_style,
                )
            else:
                q = torch.cat([q_nope_out, q_pe], dim=-1)
                k = torch.cat([k_nope, k_pe], dim=-1)

1463
            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
1464
1465
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

1466
1467
        if self.use_deep_gemm_bmm:
            attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
1468
1469
                per_token_group_quant_mla_deep_gemm_masked_fp8(
                    attn_output.transpose(0, 1)
1470
1471
1472
1473
1474
                )
            )
            attn_bmm_output = attn_output.new_empty(
                (self.num_local_heads, aligned_m, self.v_head_dim)
            )
1475
            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
1476
1477
1478
1479
1480
1481
                (attn_output_val, attn_output_scale),
                (self.w_vc, self.w_scale_v),
                attn_bmm_output,
                masked_m,
                expected_m,
            )
Ke Bao's avatar
Ke Bao committed
1482
1483
1484
            attn_bmm_output = (
                attn_bmm_output[:, :expected_m, :].transpose(0, 1).flatten(1, 2)
            )
1485
1486
        elif _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
            if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
                x = attn_output.transpose(0, 1)
                attn_bmm_output = torch.empty(
                    x.shape[0],
                    x.shape[1],
                    self.w_vc.shape[2],
                    device=x.device,
                    dtype=torch.bfloat16,
                )
                batched_gemm_afp4wfp4_pre_quant(
                    x,
                    self.w_vc.transpose(-2, -1),
                    self.w_scale_v.transpose(-2, -1),
                    torch.bfloat16,
                    attn_bmm_output,
                )
            else:
                attn_bmm_output = torch.bmm(
                    attn_output.to(torch.bfloat16).transpose(0, 1),
                    self.w_vc.to(torch.bfloat16) * self.w_scale,
                )

            if self.o_proj.weight.dtype == torch.uint8:
                attn_bmm_output = attn_bmm_output.transpose(0, 1)
                attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
            else:
                attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)

1515
        elif self.w_vc.dtype == torch.float8_e4m3fn:
1516
            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
Lianmin Zheng's avatar
Lianmin Zheng committed
1517
                attn_output.transpose(0, 1),
1518
                zero_allocator.allocate(1),
1519
1520
1521
1522
1523
1524
1525
1526
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
Ke Bao's avatar
Ke Bao committed
1527
            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
1528
        else:
Ke Bao's avatar
Ke Bao committed
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
            attn_bmm_output = torch.empty(
                (attn_output.shape[0], self.num_local_heads * self.v_head_dim),
                dtype=attn_output.dtype,
                device=attn_output.device,
            )
            torch.bmm(
                attn_output.transpose(0, 1),
                self.w_vc,
                out=attn_bmm_output.view(
                    -1, self.num_local_heads, self.v_head_dim
                ).transpose(0, 1),
            )
        output, _ = self.o_proj(attn_bmm_output)
1542
1543
1544

        return output

1545
    def forward_absorb_fused_mla_rope_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Build the arguments for ``forward_absorb_fused_mla_rope_core``.

        Projects ``hidden_states`` into queries and the latent KV entry,
        multiplies the no-RoPE query part by ``self.w_kc``, stores the latent
        entry in the KV pool and allocates the output / logits buffers that are
        handed to the core step.

        When the environment variable ``SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION``
        is ``"1"`` (its default) ``self.rotary_emb`` is not called here and an
        uninitialized ``k_pe_output`` buffer is returned for the core step.
        Otherwise rotary embedding is applied here and ``k_pe_output`` is
        ``None``.

        Args:
            positions: Forwarded to ``self.rotary_emb`` in the non-fused case
                and returned unchanged for the core step.
            hidden_states: Layer input; its first dimension is used as the
                number of query rows.
            forward_batch: Source of the attention backend metadata, the KV
                pool and ``out_cache_loc``.
            zero_allocator: Supplies the one-element scratch buffer used by the
                FP8 quantization path.

        Returns:
            A 16-tuple laid out in the positional-parameter order of
            ``forward_absorb_fused_mla_rope_core``.
        """
        enable_rope_fusion = (
            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
        )
        q_len = hidden_states.shape[0]
        # Query buffer: [absorbed no-RoPE part | RoPE part] along the last dim.
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        if _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

        # k_input is a view of latent_cache, so the normalized values written
        # below land in latent_cache as well.
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input

        if not enable_rope_fusion:
            k_pe = k_input[..., self.kv_lora_rank :]
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
            k_input[..., self.kv_lora_rank :] = k_pe
            k_pe_output = None
        else:
            k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])

        # Single write of the RoPE part of the query for both branches
        # (rotated above when fusion is disabled, raw otherwise).
        q_input[..., self.kv_lora_rank :] = q_pe

        # The core step calls the grouped-RoPE decode kernel instead of
        # self.attn_mqa(q_input, k_input, v_input, forward_batch), so the
        # output and logits buffers are prepared here.
        attn_output = torch.empty(
            (q_len, self.num_local_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )
        attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
            forward_batch.attn_backend.forward_metadata
        )
        cos_sin_cache = self.rotary_emb.cos_sin_cache
        num_kv_split = forward_batch.attn_backend.num_kv_splits
        sm_scale = self.attn_mqa.scaling
        if attn_logits is None:
            attn_logits = torch.empty(
                (
                    forward_batch.batch_size,
                    self.num_local_heads,
                    num_kv_split,
                    self.kv_lora_rank + 1,
                ),
                dtype=torch.float32,
                device=q.device,
            )

        # Save the current latent cache.
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mqa, forward_batch.out_cache_loc, k_input, None
        )
        key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mqa.layer_id
        )
        # The value buffer is the leading kv_lora_rank slice of the key buffer.
        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

        return (
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            enable_rope_fusion,
            k_input,
            forward_batch,
            zero_allocator,
        )

1659
1660
1661
1662
1663
1664
1665
    def forward_absorb_fused_mla_rope_cpu_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Produce q/k/v inputs with the fused ``sgl_kernel`` QKV + RoPE op.

        Asserts that ``self.q_lora_rank`` is set and that
        ``use_intel_amx_backend(self)`` holds, then makes a single call to
        ``torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight``.

        Returns:
            ``(q_input, k_input, v_input, forward_batch, zero_allocator)``; the
            last two are passed through untouched.
        """
        assert (
            self.q_lora_rank is not None and use_intel_amx_backend(self)
        ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"

        is_int8 = self.qkv_proj_with_rope_is_int8
        is_fp8 = self.qkv_proj_with_rope_is_fp8

        # Select the scale tensors matching the weight format; int8 takes
        # precedence over fp8, and unquantized weights pass no scale at all.
        if is_int8:
            qkv_a_scale = self.fused_qkv_a_proj_with_mqa.weight_scale
            q_b_scale = self.q_b_proj.weight_scale
        elif is_fp8:
            qkv_a_scale = self.fused_qkv_a_proj_with_mqa.weight_scale_inv
            q_b_scale = self.q_b_proj.weight_scale_inv
        else:
            qkv_a_scale = None
            q_b_scale = None

        fused_qkv_rope = torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight
        q_input, k_input, v_input = fused_qkv_rope(
            hidden_states,
            self.fused_qkv_a_proj_with_mqa.weight,
            self.q_b_proj.weight,
            self.w_kc,
            self.q_a_layernorm.weight,
            self.kv_a_layernorm.weight,
            positions,
            self.rotary_emb.cos_sin_cache,
            self.kv_a_layernorm.variance_epsilon,
            is_int8,
            is_fp8,
            qkv_a_scale,
            q_b_scale,
            True,  # is_vnni
            self.weight_block_size,
            self.q_lora_rank,
            self.kv_lora_rank,
            self.qk_rope_head_dim,
        )
        return (q_input, k_input, v_input, forward_batch, zero_allocator)

1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
    def forward_absorb_fused_mla_rope_core(
        self,
        q_input,
        key_cache_buf,
        val_cache_buf,
        attn_output,
        kv_indptr,
        kv_indices,
        k_pe_output,
        cos_sin_cache,
        positions,
        attn_logits,
        num_kv_split,
        sm_scale,
        enable_rope_fusion,
        k_input,
        forward_batch,
        zero_allocator,
    ):
        """Run the grouped-RoPE decode kernel and project its result.

        The parameters are the tuple returned by
        ``forward_absorb_fused_mla_rope_prepare``, in the same order.

        Returns:
            The first element of ``self.o_proj`` applied to the attention
            result after the ``self.w_vc`` batched matmul.
        """
        rope_offset = self.kv_lora_rank

        decode_attention_fwd_grouped_rope(
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            rope_offset,
            self.rotary_emb.rotary_dim,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            logit_cap=self.attn_mqa.logit_cap,
            use_rope=enable_rope_fusion,
            is_neox_style=self.rotary_emb.is_neox_style,
        )

        if enable_rope_fusion:
            # Copy k_pe_output into the RoPE slice of k_input and store the
            # entry in the KV pool again.
            k_input[..., rope_offset:] = k_pe_output
            kv_pool = forward_batch.token_to_kv_pool
            kv_pool.set_kv_buffer(
                self.attn_mqa, forward_batch.out_cache_loc, k_input, None
            )

        per_token = attn_output.view(-1, self.num_local_heads, rope_offset)

        if _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            lhs = per_token.to(torch.bfloat16).transpose(0, 1)
            rhs = self.w_vc.to(torch.bfloat16) * self.w_scale
            per_head = torch.bmm(lhs, rhs)
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            quantized, quant_scale = per_tensor_quant_mla_fp8(
                per_token.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            per_head = bmm_fp8(
                quantized,
                self.w_vc,
                quant_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            per_head = torch.bmm(per_token.transpose(0, 1), self.w_vc)

        # Swap the first two dims back and merge the last two for o_proj.
        flattened = per_head.transpose(0, 1).flatten(1, 2)
        projected, _ = self.o_proj(flattened)
        return projected

1783
1784
1785
    def forward_absorb_fused_mla_rope_cpu_core(
        self, q_input, k_input, v_input, forward_batch, zero_allocator
    ):
        """Attention plus ``w_vc`` batched matmul and output projection on CPU.

        Takes the tuple returned by ``forward_absorb_fused_mla_rope_cpu_prepare``.
        ``zero_allocator`` is accepted for signature parity and is not used.

        Returns:
            The first element of ``self.o_proj`` applied to the bmm result.
        """
        assert (
            self.q_lora_rank is not None and use_intel_amx_backend(self)
        ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"

        attn_result = self.attn_mqa(q_input, k_input, v_input, forward_batch)
        attn_result = attn_result.view(-1, self.num_local_heads, self.kv_lora_rank)

        # Operand layout handed to sgl_kernel bmm_cpu (per the original note):
        #   out:  [B, M, N]
        #   mat1: [B, M, K]
        #   mat2: [B, N, K]
        # Here mat1 is attn_result with its first two dims swapped and mat2 is
        # self.w_vc, whose first two sizes are read as B and N.
        # NOTE(review): the original note describes the weight as repacked by
        # PackWeightMethod from [B, K, N] to [B, N, K]; it names w_kc, and
        # presumably the same applies to w_vc -- confirm.
        num_batches = self.w_vc.size(0)
        out_features = self.w_vc.size(1)
        num_rows = attn_result.size(0)

        flat_out = torch.empty(
            [num_rows, int(num_batches * out_features)], dtype=attn_result.dtype
        )
        # Batch-major view over flat_out's storage; the kernel fills flat_out
        # through it.
        bmm_out = flat_out.view([num_rows, num_batches, out_features]).transpose(0, 1)
        torch.ops.sgl_kernel.bmm_cpu(
            bmm_out,
            attn_result.transpose(0, 1),
            self.w_vc,
            True,  # is_vnni
            None,  # scale
        )

        projected, _ = self.o_proj(flat_out)
        return projected

1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
    def _chunked_prefix_attn_mha(
        self,
        q: torch.Tensor,
        accum_output: torch.Tensor,
        accum_lse: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Merge attention over every cached prefix chunk into a running result.

        For each of ``forward_batch.num_prefix_chunks`` chunks the latent cache
        rows selected by ``forward_batch.prefix_chunk_kv_indices[i]`` are
        expanded through ``self.kv_b_proj`` into keys and values, attended with
        ``self.attn_mha`` (without saving KV), and merged with the running
        ``(output, lse)`` pair via ``merge_state_v2``.

        Args:
            q: Queries passed unchanged to every ``self.attn_mha`` call.
            accum_output: Attention output to start merging from.
            accum_lse: Log-sum-exp values paired with ``accum_output``.
            forward_batch: Batch holding the chunk metadata and the KV pool.

        Returns:
            The merged attention output (``accum_output`` itself when there
            are zero chunks).
        """
        assert forward_batch.num_prefix_chunks is not None

        nope_dim = self.qk_nope_head_dim
        merged_out, merged_lse = accum_output, accum_lse

        for chunk_idx in range(forward_batch.num_prefix_chunks):
            forward_batch.set_prefix_chunk_idx(chunk_idx)

            # Gather this chunk's latent cache rows using the precomputed
            # chunked kv indices.
            pool_buf = forward_batch.token_to_kv_pool.get_key_buffer(
                self.attn_mha.layer_id
            )
            chunk_rows = pool_buf[forward_batch.prefix_chunk_kv_indices[chunk_idx]]
            chunk_latent = chunk_rows.contiguous().to(q.dtype)

            latent_kv, rope_k = chunk_latent.split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )
            latent_kv = latent_kv.squeeze(1).contiguous()

            # Expand the latent into per-head [k_nope | v].
            expanded = self.kv_b_proj(latent_kv)[0].view(
                -1, self.num_local_heads, nope_dim + self.v_head_dim
            )
            nope_k, chunk_v = expanded.split([nope_dim, self.v_head_dim], dim=-1)

            # Full key = [k_nope | k_pe] along the last dim.
            chunk_k = torch.empty(
                (
                    nope_k.shape[0],
                    self.num_local_heads,
                    nope_dim + self.qk_rope_head_dim,
                ),
                dtype=chunk_v.dtype,
                device=chunk_v.device,
            )
            chunk_k[..., :nope_dim] = nope_k
            chunk_k[..., nope_dim:] = rope_k

            chunk_out, chunk_lse = self.attn_mha(
                q, chunk_k, chunk_v, forward_batch, save_kv_cache=False
            )

            # merge_state_v2 writes into fresh buffers, which then become the
            # running state for the next chunk.
            next_out = torch.empty_like(merged_out)
            next_lse = torch.empty_like(merged_lse)
            merge_state_v2(
                chunk_out, chunk_lse, merged_out, merged_lse, next_out, next_lse
            )
            merged_out, merged_lse = next_out, next_lse

        return merged_out

1873
    def forward_normal_chunked_kv_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare step of the chunked-KV MHA path.

        In normal MHA the k and v tensors become overly large when the prefix
        is long. To avoid that, the KV cache is split into chunks that are
        processed one after another; since MHA is compute friendly, the loop
        this introduces adds no significant overhead. The top comments in
        https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
        help in understanding the purpose of this path.

        The first stage is an ordinary MHA forward for the extended part, so
        this method delegates to ``forward_normal_prepare`` with the same
        arguments.
        """
        prepared = self.forward_normal_prepare(
            positions, hidden_states, forward_batch, zero_allocator
        )
        return prepared

1891
    def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
        """Run MHA on the extended tokens, then fold in any cached prefix.

        Sets ``forward_batch.mha_return_lse`` and the attend-prefix-cache flag
        as a side effect.

        Returns:
            The first element of ``self.o_proj`` applied to the attention
            output reshaped to ``(-1, num_local_heads * v_head_dim)``.
        """
        prefix_present = any(forward_batch.extend_prefix_lens_cpu)

        # The chunk bookkeeping is built only once per batch.
        if prefix_present and forward_batch.num_prefix_chunks is None:
            forward_batch.prepare_chunked_prefix_cache_info(q.device)
            backend = forward_batch.attn_backend
            if hasattr(backend, "init_mha_chunk_metadata"):
                backend.init_mha_chunk_metadata(forward_batch)

        forward_batch.mha_return_lse = prefix_present

        # Pass 1: attention over the extended part only, ignoring the prefix.
        forward_batch.set_attn_attend_prefix_cache(False)
        extend_result = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)

        if not prefix_present:
            merged = extend_result
        else:
            # Pass 2: at least one sequence has a prefix, so attend to the
            # chunked prefix cache and merge with the (output, lse) pair.
            extend_out, extend_lse = extend_result
            forward_batch.set_attn_attend_prefix_cache(True)
            merged = self._chunked_prefix_attn_mha(
                q=q,
                accum_output=extend_out,
                accum_lse=extend_lse,
                forward_batch=forward_batch,
            )

        flat = merged.reshape(-1, self.num_local_heads * self.v_head_dim)
        projected, _ = self.o_proj(flat)
        return projected

1919

Liangsheng Yin's avatar
Liangsheng Yin committed
1920
1921
1922
1923
1924
1925
1926
class DeepseekV2DecoderLayer(nn.Module):
    """One DeepseekV2 decoder layer: MLA self-attention followed by an MLP.

    The MLP is a ``DeepseekV2MoE`` when the layer is sparse (see
    ``_is_layer_sparse``) and a ``DeepseekV2MLP`` otherwise. Layer norms and
    the communication around attention / MLP are delegated to a
    ``LayerCommunicator``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        is_nextn: bool = False,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """Create the attention, MLP, norm and communicator sub-modules.

        Args:
            config: Model configuration the sizes and MoE settings are read from.
            layer_id: Index of this layer.
            quant_config: Quantization configuration handed to the sub-modules.
            is_nextn: When true the layer is always sparse and is treated as
                the only (and last) layer by the scatter modes / communicator.
            prefix: Parameter-name prefix for the sub-modules.
            alt_stream: Stream object forwarded to the attention and MoE modules.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
        self.layer_id = layer_id
        self.is_nextn = is_nextn
        self.self_attn = DeepseekV2AttentionMLA(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=(
                config.q_lora_rank if hasattr(config, "q_lora_rank") else None
            ),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            layer_id=layer_id,
            reduce_results=False,
            prefix=add_prefix("self_attn", prefix),
            alt_stream=alt_stream,
        )

        self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
        is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=1 if is_nextn else config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=is_previous_layer_sparse,
        )

        if self.is_layer_sparse:
            self.mlp = DeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                layer_id=self.layer_id,
                alt_stream=alt_stream,
                is_nextn=is_nextn,
            )
        else:
            # With fully-DP dense MLPs the MLP is built as a single-rank
            # tensor-parallel group; otherwise the defaults (None) are used.
            if enable_moe_dense_fully_dp():
                mlp_tp_rank, mlp_tp_size = 0, 1
            else:
                mlp_tp_rank, mlp_tp_size = None, None
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                tp_rank=mlp_tp_rank,
                tp_size=mlp_tp_size,
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
            allow_reduce_scatter=True,
            is_last_layer=(
                is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
            ),
        )

    def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
        """Return whether ``layer_id`` uses the MoE MLP.

        A layer is sparse when ``is_nextn`` is true, or when routed experts
        are configured, ``layer_id`` is at least ``first_k_dense_replace`` and
        ``layer_id`` is a multiple of ``moe_layer_freq``.
        """
        return is_nextn or (
            self.config.n_routed_experts is not None
            and layer_id >= self.config.first_k_dense_replace
            and layer_id % self.config.moe_layer_freq == 0
        )

    def _attn_input_quant_format(self) -> str:
        """Return the quant format string passed to ``prepare_attn``.

        ``"mxfp4"`` is returned only when gfx95 support is enabled and the
        fused QKV-A projection weight has dtype ``torch.uint8``; in every
        other case, including when the attention module has no fused
        projection or the projection has no ``weight``, the result is ``""``.
        """
        if not _is_gfx95_supported:
            return ""
        fused_proj = getattr(self.self_attn, "fused_qkv_a_proj_with_mqa", None)
        weight = getattr(fused_proj, "weight", None)
        if weight is None:
            return ""
        # Compare the dtype, not the tensor: `weight == torch.uint8` does not
        # perform a dtype check.
        return "mxfp4" if weight.dtype == torch.uint8 else ""

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
        gemm_output_zero_allocator: Optional[BumpAllocator] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply attention and MLP with the surrounding communication steps.

        Args:
            positions: Passed to the attention module.
            hidden_states: Layer input.
            forward_batch: Batch information for this forward pass.
            residual: Residual carried from the previous layer, if any.
            zero_allocator: Allocator passed to the attention module.
            gemm_output_zero_allocator: Allocator passed to the MoE MLP; it is
                replaced by ``None`` when the MLP is a ``DeepseekV2MLP``.

        Returns:
            ``(hidden_states, residual)``. When the MLP all-reduce is fused
            with the next layer, ``postprocess_layer`` is skipped and the
            returned ``hidden_states`` is tagged with
            ``_sglang_needs_allreduce_fusion = True``.
        """
        quant_format = self._attn_input_quant_format()

        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states,
            residual,
            forward_batch,
            quant_format,
        )

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )

        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )

        should_allreduce_fusion = (
            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
                forward_batch
            )
        )

        # For DP with padding, reduce scatter can be used instead of all-reduce.
        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
            forward_batch
        )

        if isinstance(self.mlp, DeepseekV2MLP):
            gemm_output_zero_allocator = None

        hidden_states = self.mlp(
            hidden_states,
            forward_batch,
            should_allreduce_fusion,
            use_reduce_scatter,
            gemm_output_zero_allocator,
        )

        if should_allreduce_fusion:
            # Mark the tensor instead of post-processing it here.
            hidden_states._sglang_needs_allreduce_fusion = True
        else:
            hidden_states, residual = self.layer_communicator.postprocess_layer(
                hidden_states, residual, forward_batch
            )

        return hidden_states, residual

    def op_comm_prepare_attn(
        self,
        state,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
        tbo_subbatch_index: Optional[int] = None,
    ):
        """Run ``prepare_attn`` and stash its results and the inputs on ``state``."""
        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
        )
        state.update(
            dict(
                forward_batch=forward_batch,
                positions=positions,
                zero_allocator=zero_allocator,
                tbo_subbatch_index=tbo_subbatch_index,
            )
        )

    def op_comm_prepare_mlp(self, state):
        """Consume the post-attention entries of ``state`` and run ``prepare_mlp``."""
        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
            self.layer_communicator.prepare_mlp(
                state.pop("hidden_states_after_attn"),
                state.pop("residual_after_input_ln"),
                state.forward_batch,
            )
        )

    def op_mlp(self, state):
        """Apply the MLP to ``state.hidden_states_mlp_input``.

        The MLP call is skipped (input passed through) when fully-DP dense
        MLPs are enabled, this layer is dense and the input has zero rows.
        """
        hidden_states = state.pop("hidden_states_mlp_input")
        if not (
            enable_moe_dense_fully_dp()
            and (not self.is_layer_sparse)
            and hidden_states.shape[0] == 0
        ):
            state.hidden_states_mlp_output = self.mlp(
                hidden_states, state.forward_batch
            )
        else:
            state.hidden_states_mlp_output = hidden_states

    def op_comm_postprocess_layer(self, state):
        """Run ``postprocess_layer`` and return the inputs for the next layer.

        Returns:
            A dict with ``positions``, ``hidden_states``, ``residual``,
            ``forward_batch``, ``zero_allocator`` and ``tbo_subbatch_index``.
            ``state`` is cleared afterwards, expecting exactly the four
            pass-through keys to remain.
        """
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            state.pop("hidden_states_mlp_output"),
            state.pop("residual_after_comm_pre_mlp"),
            state.forward_batch,
        )

        output = dict(
            positions=state.positions,
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=state.forward_batch,
            zero_allocator=state.zero_allocator,
            tbo_subbatch_index=state.tbo_subbatch_index,
        )

        state.clear(
            expect_keys={
                "positions",
                "forward_batch",
                "zero_allocator",
                "tbo_subbatch_index",
            }
        )
        return output
2153

Liangsheng Yin's avatar
Liangsheng Yin committed
2154
2155
2156
2157
2158
2159
2160
2161

class DeepseekV2Model(nn.Module):
    """DeepseekV2 backbone: token embedding, decoder layers, and final norm.

    Under pipeline parallelism only the first rank builds the embedding and
    only the last rank builds the final norm; the other ranks hold
    ``PPMissingLayer`` placeholders in their place.
    """

    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the embedding, this rank's decoder layers, and the final norm.

        Args:
            config: Model configuration (vocab/hidden sizes, layer counts, ...).
            quant_config: Optional quantization configuration forwarded to
                every decoder layer.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.first_k_dense_replace = config.first_k_dense_replace
        self.pp_group = get_pp_group()

        if self.pp_group.is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                enable_tp=not is_dp_attention_enabled(),
            )
        else:
            self.embed_tokens = PPMissingLayer()

        # Secondary CUDA stream handed to every decoder layer (None off-CUDA).
        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: DeepseekV2DecoderLayer(
                config=config,
                layer_id=idx,
                quant_config=quant_config,
                prefix=prefix,
                alt_stream=self.alt_stream,
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix=add_prefix("layers", prefix),
            offloader_kwargs=dict(
                submodule_accessor=lambda layer: (
                    layer.mlp.experts
                    if isinstance(layer.mlp, DeepseekV2MoE)
                    else layer.mlp
                ),
                whitelist_param_names_creator=lambda module: (
                    [
                        "w13_weight",
                        "w2_weight",
                        "w13_blockscale_swizzled",
                        "w2_blockscale_swizzled",
                    ]
                    if isinstance(module, FusedMoE)
                    else []
                ),
            ),
        )
        if self.pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer(return_tuple=True)

        self.gemm_output_zero_allocator_size = 0
        # On non-first pipeline ranks ``embed_tokens`` is a PPMissingLayer, so
        # fall back to the configured hidden size when it has no embedding_dim.
        hidden_size = getattr(self.embed_tokens, "embedding_dim", config.hidden_size)
        if (
            _use_aiter_gfx95
            and config.n_routed_experts == 256
            and hidden_size == 7168
        ):
            # NOTE(review): entries of self.layers outside this rank's slice are
            # presumably placeholders without an ``mlp`` attribute; getattr
            # keeps them from raising here.
            moe_mlps = [
                mlp
                for mlp in (getattr(layer, "mlp", None) for layer in self.layers)
                if isinstance(mlp, DeepseekV2MoE)
            ]
            num_moe_layers = len(moe_mlps)
            # Size is taken from the first MoE layer's shared-expert projection.
            allocate_size = (
                moe_mlps[0].shared_experts.gate_up_proj.output_size_per_partition
                if moe_mlps
                else 0
            )

            self.gemm_output_zero_allocator_size = (
                get_dsv3_gemm_output_zero_allocator_size(
                    config.n_routed_experts,
                    num_moe_layers,
                    allocate_size,
                    hidden_size,
                )
            )

    def get_input_embeddings(self) -> nn.Module:
        """Return the token-embedding module (a placeholder on non-first ranks)."""
        return self.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        """Run this rank's decoder layers.

        Args:
            input_ids: Token ids; embedded on the first rank when
                ``input_embeds`` is not given.
            positions: Positions forwarded to every layer.
            forward_batch: Batch metadata forwarded to every layer.
            input_embeds: Optional pre-computed embeddings used instead of
                ``input_ids`` on the first rank.
            pp_proxy_tensors: Hidden states and residual received from the
                previous pipeline rank; required on non-first ranks.

        Returns:
            ``PPProxyTensors`` with ``hidden_states`` and ``residual`` on
            non-last ranks, otherwise the hidden states (normalized unless the
            forward mode is idle).
        """
        total_num_layers = self.end_layer - self.start_layer
        device = input_embeds.device if input_embeds is not None else input_ids.device
        zero_allocator = BumpAllocator(
            buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
            dtype=torch.float32,
            device=device,
        )

        has_gemm_output_zero_allocator = hasattr(
            self, "gemm_output_zero_allocator_size"
        )

        gemm_output_zero_allocator = (
            BumpAllocator(
                buffer_size=self.gemm_output_zero_allocator_size,
                dtype=torch.float32,
                device=device,
            )
            if has_gemm_output_zero_allocator
            and self.gemm_output_zero_allocator_size > 0
            else None
        )

        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        # Layers in [normal_start_layer, normal_end_layer) run one by one; any
        # remaining layers up to end_layer go through model_forward_maybe_tbo.
        normal_start_layer = self.start_layer
        normal_end_layer = self.end_layer
        if forward_batch.can_run_tbo:
            if (
                self.first_k_dense_replace > normal_start_layer
                and self.first_k_dense_replace < normal_end_layer
            ):
                normal_end_layer = self.first_k_dense_replace
            elif self.first_k_dense_replace < normal_start_layer:
                normal_end_layer = normal_start_layer = 0

        for i in range(normal_start_layer, normal_end_layer):
            with get_global_expert_distribution_recorder().with_current_layer(i):
                layer = self.layers[i]
                hidden_states, residual = layer(
                    positions,
                    hidden_states,
                    forward_batch,
                    residual,
                    zero_allocator,
                    gemm_output_zero_allocator,
                )

        if normal_end_layer != self.end_layer:
            hidden_states, residual = model_forward_maybe_tbo(
                layers=self.layers[normal_end_layer : self.end_layer],
                enable_tbo=True,
                positions=positions,
                forward_batch=forward_batch,
                hidden_states=hidden_states,
                residual=residual,
                input_data_scatter_mode=self.layers[
                    normal_end_layer - 1
                ].layer_scatter_modes.layer_output_mode,
                zero_allocator=zero_allocator,
            )

        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            if not forward_batch.forward_mode.is_idle():
                if residual is None:
                    hidden_states = self.norm(hidden_states)
                else:
                    hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekV2ForCausalLM(nn.Module):
2345
2346
    # for quark model load
    packed_modules_mapping = {}
Liangsheng Yin's avatar
Liangsheng Yin committed
2347
2348
2349
2350
2351

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the backbone, LM head, and logits processor.

        Args:
            config: Model configuration.
            quant_config: Optional quantization configuration.
            prefix: Parameter-name prefix of this module.
        """
        super().__init__()

        # for quark model load
        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
        self.fuse_qkv_a_proj = (
            hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
        )
        if self.fuse_qkv_a_proj:
            # NOTE(review): packed_modules_mapping is not rebound on the
            # instance here, so this writes into the class-level dict.
            self.packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
                "q_a_proj",
                "kv_a_proj_with_mqa",
            ]

        self.pp_group = get_pp_group()
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        # Must run after config/quant_config are set: it reads both.
        self.determine_num_fused_shared_experts()
        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)

        # Deferred: the per-layer MoE weights are only collected on first access.
        self._routed_experts_weights_of_layer = LazyValue(
            lambda: {
                layer_id: layer.mlp.get_moe_weights()
                for layer_id, layer in enumerate(self.model.layers)
                if isinstance(layer.mlp, DeepseekV2MoE)
            }
        )

    @property
    def routed_experts_weights_of_layer(self):
        """Value of the lazily evaluated per-layer routed-expert weights holder."""
        lazy_holder = self._routed_experts_weights_of_layer
        return lazy_holder.value

2396
    def determine_num_fused_shared_experts(
        self, architecture: str = "DeepseekV3ForCausalLM"
    ):
        """Decide whether shared experts are fused and set ``num_fused_shared_experts``.

        The attribute is set to ``config.n_shared_experts`` when fusion is
        usable and to 0 otherwise. When fusion is found unusable, the global
        ``disable_shared_experts_fusion`` server argument is also set to True
        and the reason is logged on rank 0.

        Args:
            architecture: Value ``config.architectures[0]`` must equal for
                fusion to be considered.
        """
        self.num_fused_shared_experts = 0
        if global_server_args_dict["disable_shared_experts_fusion"]:
            return

        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
        disable_reason = None
        if (
            not _is_cuda
            or torch.cuda.get_device_capability("cuda") < (8, 0)
            or self.config.architectures[0] != architecture
            or self.config.n_routed_experts != 256
            or self.config.n_shared_experts != 1
        ):
            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
        elif get_moe_expert_parallel_world_size() > 1:
            disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
        elif (
            # quant_config is Optional; an unquantized model must not crash here.
            self.quant_config is not None
            and self.quant_config.get_name() == "w4afp8"
        ):
            disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."

        if disable_reason is not None:
            # num_fused_shared_experts is already 0 from the top of the method.
            global_server_args_dict["disable_shared_experts_fusion"] = True
            log_info_on_rank0(
                logger,
                f"{disable_reason} Shared experts fusion optimization is disabled.",
            )
            return

        self.num_fused_shared_experts = self.config.n_shared_experts
2428

Mick's avatar
Mick committed
2429
2430
2431
    def get_input_embeddings(self) -> nn.Embedding:
        """Return the ``embed_tokens`` module of the wrapped backbone model."""
        backbone = self.model
        return backbone.embed_tokens

2432
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        """Run the backbone and, on the last pipeline rank, the logits processor.

        Args:
            input_ids: Token ids.
            positions: Positions forwarded to the backbone.
            forward_batch: Batch metadata.
            input_embeds: Optional pre-computed input embeddings.
            pp_proxy_tensors: Tensors received from the previous pipeline rank.

        Returns:
            The logits processor's output on the last pipeline rank; on other
            ranks, whatever the backbone returned, unchanged.
        """
        hidden_states = self.model(
            input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
        )

        if self.pp_group.is_last_rank:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )
        else:
            return hidden_states

    @property
    def start_layer(self):
        """Index of the first decoder layer held by the wrapped backbone."""
        backbone = self.model
        return backbone.start_layer

    @property
    def end_layer(self):
        """End index (``end_layer``) of the wrapped backbone's decoder-layer slice."""
        backbone = self.model
        return backbone.end_layer

2460
    def post_load_weights(self, is_nextn=False, weight_names=None):
        """Derive the per-layer ``w_kc`` / ``w_vc`` tensors from ``kv_b_proj``.

        For each selected layer the ``kv_b_proj`` weight is dequantized or
        requantized as its dtype requires, split into a key part and a value
        part, and stored on the attention module together with any scales.

        Args:
            is_nextn: When True, process the single ``self.model.decoder``
                layer instead of ``self.model.layers``.
            weight_names: Optional names of the weights that were just loaded;
                when given, only layers whose ``kv_b_proj`` appears among them
                are processed. Otherwise every layer in
                ``[start_layer, end_layer)`` is processed.
        """
        # Perform post-processing after loading weights
        if is_nextn:
            layer_ids = [self.config.num_hidden_layers]
        else:
            if weight_names is None:
                layer_ids = range(self.model.start_layer, self.model.end_layer)
            else:
                layer_ids = set()
                for name in weight_names:
                    if "kv_b_proj" in name:
                        # The third dot-separated component is parsed as the layer index.
                        layer_id = int(name.split(".")[2])
                        if layer_id < self.config.num_hidden_layers:
                            layer_ids.add(layer_id)

        for layer_id in layer_ids:
            self_attn = (
                self.model.layers[layer_id].self_attn
                if not is_nextn
                else self.model.decoder.self_attn
            )
            if hasattr(self_attn.kv_b_proj, "qweight"):
                # AWQ compatible
                if _is_cuda or _is_hip:
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                    ).T
                else:
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                        0,
                        0,
                        0,
                    ).T
            else:
                w = self_attn.kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
            # This may affect the accuracy of fp8 model.
            # Fix deepseek v3 blockwise bmm by using deep_gemm
            use_deep_gemm_bmm = False

            if w.dtype in (
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
            ):
                if (
                    hasattr(self.quant_config, "weight_block_size")
                    and self.quant_config.weight_block_size is not None
                ):
                    weight_block_size = self.quant_config.weight_block_size
                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                    if _is_fp8_fnuz:
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                            weight=w,
                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                            input_scale=None,
                        )
                    else:
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv

                    if (
                        _is_cuda
                        and weight_block_size[0] == 128
                        and weight_block_size[1] == 128
                    ):
                        if (
                            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
                            and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL
                            and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false")
                        ):
                            # Keep the block-quantized weight; the block scales
                            # are split per head further below.
                            block_scale = weight_scale
                            use_deep_gemm_bmm = True
                        else:
                            w = block_quant_dequant(
                                weight,
                                weight_scale,
                                weight_block_size,
                                torch.bfloat16,
                            )
                    else:
                        w, scale = block_quant_to_tensor_quant(
                            weight, weight_scale, weight_block_size
                        )
                        self_attn.w_scale = scale
                else:
                    if _is_fp8_fnuz:
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                            weight=w,
                            weight_scale=self_attn.kv_b_proj.weight_scale,
                            input_scale=None,
                        )
                    else:
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale

                    w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                    self_attn.w_scale = scale

            if w.dtype == torch.int8:
                if hasattr(self.quant_config, "weight_block_size"):
                    # block-wise int8 need it
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
                        w = int8_block_dequant(
                            weight, weight_scale, weight_block_size
                        ).to(torch.bfloat16)
                else:
                    # channel-wise int8 need it
                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
                        torch.bfloat16
                    )

            # Group rows into chunks of (qk_nope_head_dim + v_head_dim), then
            # split each chunk into its key part and its value part.
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)

            if (
                _use_aiter_gfx95
                # quant_config is Optional; guard before calling get_name().
                and self.quant_config is not None
                and self.quant_config.get_name() == "quark"
            ):
                w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
                    quark_post_load_weights(self_attn, w, "mxfp4")
                )

            if not use_deep_gemm_bmm:
                self_attn.w_kc = bind_or_assign(
                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                )
                self_attn.w_vc = bind_or_assign(
                    self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
                )
                if (
                    hasattr(self_attn.kv_b_proj, "weight_scale")
                    and self_attn.w_scale is None
                ):
                    self_attn.w_scale = bind_or_assign(
                        self_attn.w_scale, self_attn.kv_b_proj.weight_scale
                    )
                    if _is_hip:
                        self_attn.w_scale *= 2.0
                # TODO: remove this after adding FP8 support in bmm cpu kernel
                if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn:
                    self_attn.w_kc = (
                        self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale
                    )
                    self_attn.w_vc = (
                        self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale
                    )
            else:
                # Only reached when the deep_gemm branch above set both
                # weight_block_size and block_scale.
                num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
                num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
                ws_kc, ws_vc = block_scale.unflatten(
                    0, (-1, (num_tiles_k + num_tiles_n))
                ).split([num_tiles_k, num_tiles_n], dim=1)
                self_attn.w_scale_k = bind_or_assign(
                    self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()
                )
                self_attn.w_scale_v = bind_or_assign(
                    self_attn.w_scale_v, ws_vc.contiguous()
                )
                self_attn.w_kc = bind_or_assign(
                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous()
                )
                self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
                self_attn.use_deep_gemm_bmm = True

        if (
            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
            and hasattr(self.quant_config, "weight_block_size")
            and self.quant_config.weight_block_size is not None
        ):
            self._weight_requant_ue8m0(is_nextn)
2639

2640
    def _weight_requant_ue8m0(self, is_nextn=False):
        """Requantize block-quantized weights in place via ``requant_weight_ue8m0_inplace``.

        Covers the attention projections of every processed layer, plus the
        shared experts and (for ``DeepEPMoE``) the routed experts of MoE
        layers, or the dense MLP projections of the other layers.

        Args:
            is_nextn: When True, process only ``self.model.decoder``; otherwise
                process the layers this rank holds in ``self.model.layers``.
        """
        weight_block_size = self.quant_config.weight_block_size

        # A range answers membership tests without scanning a list.
        moe_layers = range(
            self.config.first_k_dense_replace,
            self.config.num_hidden_layers,
            self.config.moe_layer_freq,
        )

        if is_nextn:
            layer_ids = range(1)
        else:
            # Restrict to this rank's slice, matching post_load_weights; fall
            # back to all layers when the model exposes no slice bounds.
            layer_ids = range(
                getattr(self.model, "start_layer", 0),
                getattr(self.model, "end_layer", self.config.num_hidden_layers),
            )

        has_q_lora = getattr(self.config, "q_lora_rank", None) is not None

        for layer_id in layer_ids:
            if is_nextn:
                layer = self.model.decoder
            else:
                layer = self.model.layers[layer_id]

            module_list = [
                layer.self_attn.kv_b_proj,
                layer.self_attn.o_proj,
            ]

            if has_q_lora:
                module_list.append(layer.self_attn.fused_qkv_a_proj_with_mqa)
                module_list.append(layer.self_attn.q_b_proj)
            else:
                module_list.append(layer.self_attn.kv_a_proj_with_mqa)
                module_list.append(layer.self_attn.q_proj)

            for module in module_list:
                requant_weight_ue8m0_inplace(
                    module.weight, module.weight_scale_inv, weight_block_size
                )

            if layer_id in moe_layers or is_nextn:
                shared_experts = getattr(layer.mlp, "shared_experts", None)
                if shared_experts is not None:
                    for module in [
                        shared_experts.gate_up_proj,
                        shared_experts.down_proj,
                    ]:
                        requant_weight_ue8m0_inplace(
                            module.weight, module.weight_scale_inv, weight_block_size
                        )

                experts = layer.mlp.experts
                if isinstance(experts, DeepEPMoE):
                    # Each entry is indexed as (weight, scale).
                    for w in [
                        experts.w13_weight_fp8,
                        experts.w2_weight_fp8,
                    ]:
                        requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
            else:
                mlp = layer.mlp
                assert isinstance(mlp, DeepseekV2MLP)
                for module in [
                    mlp.gate_up_proj,
                    mlp.down_proj,
                ]:
                    requant_weight_ue8m0_inplace(
                        module.weight, module.weight_scale_inv, weight_block_size
                    )

2705
    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
2706

2707
2708
2709
        if is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
2710
                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
2711
2712
2713
2714
2715
2716
2717
2718
2719
                # compatible with old design
                nextn_layer_id = (
                    0
                    if self.config.num_hidden_layers == 1
                    else self.config.num_hidden_layers
                )
            else:
                raise ValueError("num_nextn_predict_layers is not in the config")

Liangsheng Yin's avatar
Liangsheng Yin committed
2720
2721
2722
2723
2724
2725
2726
2727
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
2728
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
Liangsheng Yin's avatar
Liangsheng Yin committed
2729
2730
2731
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
2732
            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
Liangsheng Yin's avatar
Liangsheng Yin committed
2733
        )
2734
2735
2736
        # Params for special naming rules in mixed-precision models, for example:
        # model.layers.xx.mlp.experts.xx.w1.input_scale. For details,
        # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main.
2737
        if self.quant_config and self.quant_config.get_name() == "w4afp8":
2738
2739
            expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
                num_experts=self.config.n_routed_experts
2740
            )
Liangsheng Yin's avatar
Liangsheng Yin committed
2741

2742
2743
2744
2745
2746
2747
        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
            self.config.q_lora_rank is not None
        )
        cached_a_proj = {} if fuse_qkv_a_proj else None

2748
2749
2750
2751
2752
2753
2754
2755
2756
        if is_nextn:
            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
            nextn_spec_weight_names = [
                "shared_head.norm",
                "eh_proj",
                "enorm",
                "hnorm",
            ]

2757
2758
        if self.num_fused_shared_experts > 0:
            assert self.num_fused_shared_experts == 1
2759
            log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")
2760

2761
2762
2763
2764
2765
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = []
            params_dict = dict(self.named_parameters())
            weight_names = []
            for name, loaded_weight in weights:
2766
2767
2768
2769
2770
2771
2772
2773
2774
2775
                layer_id = get_layer_id(name)
                if (
                    layer_id is not None
                    and hasattr(self.model, "start_layer")
                    and (
                        layer_id < self.model.start_layer
                        or layer_id >= self.model.end_layer
                    )
                ):
                    continue
2776
2777
2778
2779
2780
                if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
                    name = name.replace(
                        "mlp.shared_experts",
                        f"mlp.experts.{self.config.n_routed_experts}",
                    )
2781

2782
                weight_names.append(name)
2783

2784
2785
2786
2787
2788
2789
2790
2791
2792
2793
2794
2795
2796
                if not is_nextn:
                    if hasattr(self.config, "num_nextn_predict_layers"):
                        num_nextn_layers = self.config.num_nextn_predict_layers
                        if num_nextn_layers > 0 and name.startswith("model.layers"):
                            name_list = name.split(".")
                            if (
                                len(name_list) >= 3
                                and int(name_list[2]) >= self.config.num_hidden_layers
                            ):
                                continue
                else:
                    if not name.startswith(nextn_layer_prefix):
                        continue
2797

2798
2799
2800
                    # Use shared head and embed weights from target model
                    if "shared_head.head" in name or "embed_tokens" in name:
                        continue
2801

2802
2803
2804
2805
2806
2807
2808
2809
2810
2811
2812
2813
                    is_decoder = True
                    # For nextn specific weights
                    for weight_name in nextn_spec_weight_names:
                        if weight_name in name:
                            name = name.replace(nextn_layer_prefix, "model")
                            is_decoder = False
                            break
                    # For decoder layer weights
                    if is_decoder:
                        name = name.replace(nextn_layer_prefix, "model.decoder")

                if "rotary_emb.inv_freq" in name:
Liangsheng Yin's avatar
Liangsheng Yin committed
2814
                    continue
2815
2816
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    # Skip non-stacked layers and experts (experts handled below).
Liangsheng Yin's avatar
Liangsheng Yin committed
2817
2818
                    if weight_name not in name:
                        continue
2819
2820
2821
2822
2823
2824
2825
2826
                    # We have mlp.experts[0].gate_proj in the checkpoint.
                    # Since we handle the experts below in expert_params_mapping,
                    # we need to skip here BEFORE we update the name, otherwise
                    # name will be updated to mlp.experts[0].gate_up_proj, which
                    # will then be updated below in expert_params_mapping
                    # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                    if ("mlp.experts." in name) and name not in params_dict:
                        continue
Liangsheng Yin's avatar
Liangsheng Yin committed
2827
                    name = name.replace(weight_name, param_name)
2828
2829
2830
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
Liangsheng Yin's avatar
Liangsheng Yin committed
2831
2832
                    param = params_dict[name]
                    weight_loader = param.weight_loader
2833
2834
                    futures.append(
                        executor.submit(weight_loader, param, loaded_weight, shard_id)
Liangsheng Yin's avatar
Liangsheng Yin committed
2835
2836
2837
                    )
                    break
                else:
2838
2839
2840
2841
2842
2843
2844
2845
2846
2847
2848
2849
2850
2851
2852
2853
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        futures.append(
                            executor.submit(
                                weight_loader,
                                param,
                                loaded_weight,
                                name,
                                shard_id=shard_id,
                                expert_id=expert_id,
                            )
2854
                        )
2855
2856
2857
2858
2859
                        break
                    else:
                        # Skip loading extra bias for GPTQ models.
                        if name.endswith(".bias") and name not in params_dict:
                            continue
2860
2861
2862
2863
2864
2865
                        # Skip loading embed_tokens if not first rank in pipeline parallelism
                        if ".embed_tokens." in name and not self.pp_group.is_first_rank:
                            continue
                        # Skip loading norm if not last rank in pipeline parallelism
                        if ".norm." in name and not self.pp_group.is_last_rank:
                            continue
2866
2867
                        if fuse_qkv_a_proj and (
                            "q_a_proj" in name or "kv_a_proj_with_mqa" in name
2868
                        ):
2869
2870
2871
                            cached_a_proj[name] = loaded_weight
                            q_a_proj_name = (
                                name
2872
                                if "q_a_proj" in name
2873
2874
2875
2876
2877
2878
                                else name.replace("kv_a_proj_with_mqa", "q_a_proj")
                            )
                            kv_a_proj_name = (
                                name
                                if "kv_a_proj_with_mqa" in name
                                else name.replace("q_a_proj", "kv_a_proj_with_mqa")
2879
2880
                            )

2881
2882
2883
2884
2885
2886
2887
2888
2889
2890
                            # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
                            if (
                                q_a_proj_name in cached_a_proj
                                and kv_a_proj_name in cached_a_proj
                            ):
                                q_a_proj_weight = cached_a_proj[q_a_proj_name]
                                kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                                cat_dim = 0
                                if self.quant_config is not None and (
                                    self.quant_config.get_name() == "awq"
2891
                                    or self.quant_config.get_name() == "awq_marlin"
2892
2893
2894
2895
2896
2897
2898
2899
2900
2901
2902
2903
2904
2905
2906
2907
2908
2909
2910
2911
2912
2913
2914
2915
2916
2917
2918
2919
2920
2921
2922
2923
2924
2925
2926
2927
2928
2929
2930
2931
2932
2933
2934
2935
                                    or self.quant_config.get_name() == "moe_wna16"
                                ):
                                    cat_dim = 1
                                fused_weight = torch.cat(
                                    [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                                )
                                param_name = (
                                    name.replace(
                                        "q_a_proj", "fused_qkv_a_proj_with_mqa"
                                    )
                                    if "q_a_proj" in name
                                    else name.replace(
                                        "kv_a_proj_with_mqa",
                                        "fused_qkv_a_proj_with_mqa",
                                    )
                                )
                                param = params_dict[param_name]

                                weight_loader = getattr(
                                    param, "weight_loader", default_weight_loader
                                )
                                futures.append(
                                    executor.submit(weight_loader, param, fused_weight)
                                )
                                cached_a_proj.pop(q_a_proj_name)
                                cached_a_proj.pop(kv_a_proj_name)
                        else:
                            if (
                                "k_scale" in name or "v_scale" in name
                            ) and name not in params_dict:
                                # modelopt attn kv scale is named differently
                                for scale in ["k_scale", "v_scale"]:
                                    if scale in name:
                                        name = name.replace(
                                            f"{scale[0]}_proj", "attn_mqa"
                                        )
                                        break
                            if name not in params_dict:
                                # modelopt ckpt contains not needed weights for MTP module:
                                # model.decoder.self_attn.attn_mqa.v_scale and
                                # model.decoder.self_attn.attn_mqa.k_scale
                                logger.warning(f"{name} not found in params_dict.")
                                continue
                            param = params_dict[name]
2936
2937
2938
                            weight_loader = getattr(
                                param, "weight_loader", default_weight_loader
                            )
2939
2940
2941
2942
2943
2944
2945
                            futures.append(
                                executor.submit(weight_loader, param, loaded_weight)
                            )

            # Wait for all tasks to complete and raise any exceptions.
            for future in concurrent.futures.as_completed(futures):
                future.result()
Liangsheng Yin's avatar
Liangsheng Yin committed
2946

2947
        self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
Ke Bao's avatar
Ke Bao committed
2948

2949
2950
2951
2952
2953
2954
2955
2956
2957
2958
2959
    def get_embed_and_head(self):
        """Return the token-embedding weight and the LM-head weight as a pair.

        Returns:
            A 2-tuple ``(embed, head)`` where ``embed`` is
            ``self.model.embed_tokens.weight`` and ``head`` is
            ``self.lm_head.weight``.
        """
        embed = self.model.embed_tokens.weight
        head = self.lm_head.weight
        return embed, head

    def set_embed_and_head(self, embed, head):
        """Replace the token-embedding and LM-head weights with the given ones.

        The existing ``weight`` attributes on ``self.model.embed_tokens`` and
        ``self.lm_head`` are deleted first and then rebound to ``embed`` and
        ``head``. When a CUDA device is available, the CUDA cache is emptied and
        the device is synchronized afterwards (presumably so the memory of the
        replaced weights can be reclaimed right away -- confirm with callers).

        Args:
            embed: New value for ``self.model.embed_tokens.weight``.
            head: New value for ``self.lm_head.weight``.
        """
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        # torch.cuda.synchronize() raises when no CUDA device/runtime is present,
        # so only touch the CUDA allocator when CUDA is actually usable.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

2960
2961
2962
2963
2964
2965
2966
2967
    @classmethod
    def get_model_config_for_expert_location(cls, config):
        """Build a ``ModelConfigForExpertLocation`` from a model config.

        Args:
            config: Object exposing ``num_hidden_layers``, ``n_routed_experts``
                and ``n_group`` attributes.

        Returns:
            A ``ModelConfigForExpertLocation`` whose ``num_layers``,
            ``num_logical_experts`` and ``num_groups`` fields are taken from
            those attributes, respectively.
        """
        layer_count = config.num_hidden_layers
        routed_expert_count = config.n_routed_experts
        group_count = config.n_group
        return ModelConfigForExpertLocation(
            num_layers=layer_count,
            num_logical_experts=routed_expert_count,
            num_groups=group_count,
        )

Liangsheng Yin's avatar
Liangsheng Yin committed
2968

HandH1998's avatar
HandH1998 committed
2969
2970
2971
2972
2973
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepseekV3 entry point; inherits everything from ``DeepseekV2ForCausalLM`` unchanged."""


# Model classes exposed by this module under the ``EntryClass`` name
# (presumably picked up by the model registry -- confirm against the loader).
EntryClass = [
    DeepseekV2ForCausalLM,
    DeepseekV3ForCausalLM,
]