deepseek_v2.py 105 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14

Liangsheng Yin's avatar
Liangsheng Yin committed
15
16
17
# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
18

19
import concurrent.futures
20
import logging
21
import os
22
from enum import IntEnum, auto
23
from typing import Any, Dict, Iterable, Optional, Tuple, Union
Liangsheng Yin's avatar
Liangsheng Yin committed
24
25

import torch
Ke Bao's avatar
Ke Bao committed
26
import torch.nn.functional as F
Liangsheng Yin's avatar
Liangsheng Yin committed
27
from torch import nn
28
from tqdm import tqdm
Liangsheng Yin's avatar
Liangsheng Yin committed
29
from transformers import PretrainedConfig
30
31

from sglang.srt.distributed import (
32
    get_moe_expert_parallel_world_size,
33
    get_pp_group,
Liangsheng Yin's avatar
Liangsheng Yin committed
34
    get_tensor_model_parallel_world_size,
35
    parallel_state,
Liangsheng Yin's avatar
Liangsheng Yin committed
36
37
    tensor_model_parallel_all_reduce,
)
38
39
40
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
fzyzcjy's avatar
fzyzcjy committed
41
42
43
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
44
from sglang.srt.layers.activation import SiluAndMul
45
from sglang.srt.layers.amx_utils import PackWeightMethod
46
47
48
49
50
from sglang.srt.layers.communicator import (
    LayerCommunicator,
    LayerScatterModes,
    enable_moe_dense_fully_dp,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
51
52
53
from sglang.srt.layers.dp_attention import (
    get_attention_tp_rank,
    get_attention_tp_size,
54
    is_dp_attention_enabled,
Lianmin Zheng's avatar
Lianmin Zheng committed
55
)
56
from sglang.srt.layers.layernorm import RMSNorm
57
58
59
60
61
62
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
63
from sglang.srt.layers.logits_processor import LogitsProcessor
64
65
66
67
68
from sglang.srt.layers.moe import (
    get_deepep_mode,
    get_moe_a2a_backend,
    should_use_flashinfer_cutlass_moe_fp4_allgather,
)
69
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
70
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
71
from sglang.srt.layers.moe.topk import TopK
72
from sglang.srt.layers.quantization import deep_gemm_wrapper
73
from sglang.srt.layers.quantization.base_config import QuantizationConfig
74
from sglang.srt.layers.quantization.fp8_kernel import (
75
    is_fp8_fnuz,
76
    per_tensor_quant_mla_fp8,
77
    per_token_group_quant_mla_deep_gemm_masked_fp8,
78
)
HandH1998's avatar
HandH1998 committed
79
from sglang.srt.layers.quantization.fp8_utils import (
80
    block_quant_dequant,
HandH1998's avatar
HandH1998 committed
81
    block_quant_to_tensor_quant,
82
    channel_quant_to_tensor_quant,
83
    normalize_e4m3fn_to_e4m3fnuz,
84
    requant_weight_ue8m0_inplace,
HandH1998's avatar
HandH1998 committed
85
)
86
87
88
from sglang.srt.layers.quantization.int8_utils import (
    block_dequant as int8_block_dequant,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
89
from sglang.srt.layers.radix_attention import RadixAttention
90
91
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
92
93
94
95
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
96
from sglang.srt.managers.schedule_batch import global_server_args_dict
97
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
98
from sglang.srt.model_loader.weight_utils import default_weight_loader
99
100
101
102
from sglang.srt.two_batch_overlap import (
    MaybeTboDeepEPDispatcher,
    model_forward_maybe_tbo,
)
103
104
from sglang.srt.utils import (
    BumpAllocator,
105
    LazyValue,
106
    add_prefix,
107
    bind_or_assign,
108
    cpu_has_amx_support,
109
    get_bool_env_var,
110
    get_device_sm,
111
    get_int_env_var,
112
    is_cpu,
113
    is_cuda,
114
    is_flashinfer_available,
115
    is_hip,
116
    is_non_idle_and_non_empty,
117
    is_npu,
118
    is_sm100_supported,
119
    log_info_on_rank0,
120
    make_layers,
121
    use_intel_amx_backend,
122
)
123

124
# Platform / capability flags, evaluated once at import time and read by the
# model classes below.
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_npu = is_npu()
_is_fp8_fnuz = is_fp8_fnuz()
# aiter kernels are only considered on ROCm (HIP) builds.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
# Compute capability; MoEGate uses it to gate the specialised router GEMM.
_device_sm = get_device_sm()

# Backend-specific kernels. Exactly one branch of this chain is taken.
if _is_cuda:
    from sgl_kernel import (
        awq_dequantize,
        bmm_fp8,
        dsv3_fused_a_gemm,
        dsv3_router_gemm,
        merge_state_v2,
    )
elif _is_cpu and _is_cpu_amx_available:
    # Nothing to import: the AMX kernels used in this file are called through
    # torch.ops.sgl_kernel.
    pass
elif _is_hip:
    from sglang.srt.layers.quantization.awq_triton import (
        awq_dequantize_triton as awq_dequantize,
    )
else:
    from vllm._custom_ops import awq_dequantize

if _is_hip:
    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
        decode_attention_fwd_grouped_rope,
    )

_is_flashinfer_available = is_flashinfer_available()
# Reuse the cached CUDA flag; is_sm100_supported() is only queried on CUDA.
_is_sm100_supported = _is_cuda and is_sm100_supported()

logger = logging.getLogger(__name__)


class AttnForwardMethod(IntEnum):
    """Attention forward strategies available to this model.

    Members are numbered with ``auto()`` in declaration order (MHA=1 ...
    MLA_FUSED_ROPE_CPU=5), so new members should be appended at the end to
    keep existing integer values stable.
    """

    # Use multi-head attention
    MHA = auto()

    # Use absorbed multi-latent attention
    MLA = auto()

    # Use multi-head attention, but with KV cache chunked.
    # This method can avoid OOM when prefix lengths are long.
    MHA_CHUNKED_KV = auto()

    # Use MLA but with fused RoPE
    MLA_FUSED_ROPE = auto()

    # Use MLA with fused RoPE kernel for CPU
    MLA_FUSED_ROPE_CPU = auto()


class DeepseekV2MLP(nn.Module):
    """Dense gated feed-forward block (also used for the MoE shared experts).

    A merged column-parallel projection produces the gate and up halves
    (``[intermediate_size] * 2``), ``SiluAndMul`` combines them, and a
    row-parallel ``down_proj`` maps back to ``hidden_size``.

    Args:
        hidden_size: input/output feature size.
        intermediate_size: size of each of the gate and up projections.
        hidden_act: activation name; only ``"silu"`` is accepted.
        quant_config: optional quantization config forwarded to both linears.
        reduce_results: forwarded to ``down_proj``'s ``reduce_results``.
        prefix: weight-name prefix for the sub-layers.
        tp_rank / tp_size: optional tensor-parallel overrides forwarded to both
            linears (e.g. ``tp_rank=0, tp_size=1`` to disable TP for this MLP).

    Raises:
        ValueError: if ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        # Validate before allocating any (possibly large) projection weights.
        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )

        # Stored for the empty-batch shortcut in forward(). None means no
        # override was given (the linears presumably use their default TP size).
        self.tp_size = tp_size

        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("gate_up_proj", prefix),
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("down_proj", prefix),
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        self.act_fn = SiluAndMul()

    def forward(
        self,
        x,
        forward_batch=None,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
    ):
        """Apply the gated MLP to ``x``.

        Args:
            x: input hidden states; dim 0 is the token dimension.
            forward_batch: unused here; presumably accepted so this MLP can be
                called with the same arguments as the MoE block.
            should_allreduce_fusion: if True, ``down_proj`` skips its all-reduce.
            use_reduce_scatter: if True, ``down_proj`` skips its all-reduce.

        Returns:
            The projected hidden states. When TP is explicitly 1 and the batch
            is empty, ``x`` is returned unchanged.
        """
        # Only short-circuit when TP is explicitly disabled; with TP the empty
        # rank presumably still has to take part in the layers' collectives.
        if self.tp_size == 1 and x.shape[0] == 0:
            return x

        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        # The caller either fuses the all-reduce or reduce-scatters instead.
        x, _ = self.down_proj(
            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
        )
        return x


class MoEGate(nn.Module):
    """Router for the routed experts: one logit per expert for every token.

    Args:
        config: model config; reads ``n_routed_experts``, ``hidden_size`` and
            ``topk_method``.
        prefix: weight-name prefix; accepted for API symmetry, not used here.
        is_nextn: stored as ``self.is_nextn``.
    """

    def __init__(
        self,
        config,
        prefix: str = "",
        is_nextn: bool = False,
    ):
        super().__init__()
        self.is_nextn = is_nextn
        self.weight = nn.Parameter(
            torch.empty((config.n_routed_experts, config.hidden_size))
        )
        # The "noaux_tc" top-k method carries a per-expert float32 score
        # correction bias; DeepseekV2MoE hands it to TopK as correction_bias.
        if config.topk_method == "noaux_tc":
            self.e_score_correction_bias = nn.Parameter(
                torch.empty((config.n_routed_experts,), dtype=torch.float32)
            )
        else:
            self.e_score_correction_bias = None
        if _is_cpu and _is_cpu_amx_available:
            # Presumably lets the AMX backend repack ``weight`` for the
            # weight_packed_linear kernel used in forward().
            self.quant_method = PackWeightMethod(weight_names=["weight"])

    def forward(self, hidden_states):
        """Return router logits, one column per routed expert.

        Three paths are used, depending on backend and shapes: the Intel AMX
        packed kernel, the specialised ``dsv3_router_gemm`` (float32 output),
        or a plain ``F.linear`` without bias.
        """
        if use_intel_amx_backend(self):
            return torch.ops.sgl_kernel.weight_packed_linear(
                hidden_states,
                self.weight,
                None,  # bias
                True,  # is_vnni
            )

        # NOTE: For an unknown reason, router_gemm seems to degrade accept length.
        if (
            _is_cuda
            and hidden_states.shape[0] <= 16
            and hidden_states.shape[1] == 7168
            and self.weight.shape[0] == 256
            and _device_sm >= 90
        ):
            # Small-batch kernel specialised for hidden size 7168 and
            # 256 experts on SM90+; router gemm output float32.
            logits = dsv3_router_gemm(hidden_states, self.weight)
        else:
            logits = F.linear(hidden_states, self.weight, None)

        return logits


class DeepseekV2MoE(nn.Module):
    """Mixture-of-experts FFN: routed experts plus optional shared experts.

    Shared experts either run as a separate dense ``DeepseekV2MLP``
    (``self.shared_experts``) or, unless ``disable_shared_experts_fusion`` is
    set, are fused into ``self.experts`` as ``num_fused_shared_experts``
    additional experts.

    Execution paths:
      * ``forward_normal`` / ``forward_normal_dual_stream``: non-DeepEP
        backends. The dual-stream variant overlaps the shared experts (current
        stream) with routing and the routed experts (``alt_stream``).
      * ``forward_cpu``: Intel AMX backend for the shared experts.
      * ``forward_deepep``: DeepEP all-to-all backend.
    The ``op_*`` methods split the DeepEP path into stages. They read
    ``tbo_subbatch_index`` and are presumably driven by two-batch overlap.
    """

    # Upper bound on tokens for which the dual-stream path is used.
    _DUAL_STREAM_TOKEN_THRESHOLD = 1024

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
        is_nextn: bool = False,
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        # With fusion enabled, shared experts are appended to the routed expert
        # set, which is why they are added to both num_experts and top_k below.
        self.num_fused_shared_experts = (
            0
            if global_server_args_dict["disable_shared_experts_fusion"]
            else config.n_shared_experts
        )
        self.config = config
        self.layer_id = layer_id
        self.alt_stream = alt_stream

        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}."
            )

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = MoEGate(
            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
        )

        self.experts = get_moe_impl_class(quant_config)(
            num_experts=config.n_routed_experts
            + self.num_fused_shared_experts
            + global_server_args_dict["ep_num_redundant_experts"],
            num_fused_shared_experts=self.num_fused_shared_experts,
            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            layer_id=self.layer_id,
            quant_config=quant_config,
            routed_scaling_factor=self.routed_scaling_factor,
            prefix=add_prefix("experts", prefix),
        )

        self.topk = TopK(
            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
            renormalize=config.norm_topk_prob,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            num_fused_shared_experts=self.num_fused_shared_experts,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
            force_topk=quant_config is None,
        )

        # Query the all-to-all backend once; it selects both the shared-expert
        # TP layout and the forward path.
        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()

        self.shared_experts_is_int8 = False
        self.shared_experts_is_fp8 = False
        self.shared_experts_weight_block_size = None
        if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            # disable tp for shared experts when enable deepep moe, or with fp4 allgather
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=add_prefix("shared_experts", prefix),
                **(
                    dict(tp_rank=0, tp_size=1)
                    if self._enable_deepep_moe
                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
                    else {}
                ),
            )
            gate_up_quant_method = self.shared_experts.gate_up_proj.quant_method
            is_packed_weight = hasattr(
                gate_up_quant_method, "quant_config"
            ) and gate_up_quant_method.quant_config.get_name() in {
                "awq",
                "awq_marlin",
                "moe_wna16",
            }
            # The dtype checks below must stay behind `not is_packed_weight`:
            # packed layers are not assumed to expose a plain `.weight`.
            self.shared_experts_is_int8 = (
                not is_packed_weight
                and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
            )
            self.shared_experts_is_fp8 = (
                not is_packed_weight
                and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
            )
            if self.shared_experts_is_fp8:
                # forward_cpu passes one block size for both shared projections.
                assert (
                    gate_up_quant_method.quant_config.weight_block_size
                    == self.shared_experts.down_proj.quant_method.quant_config.weight_block_size
                )
                self.shared_experts_weight_block_size = (
                    gate_up_quant_method.quant_config.weight_block_size
                )

        self.top_k = config.num_experts_per_tok

        if self._enable_deepep_moe:
            # TODO: we will support tp < ep in the future
            # NOTE: ep_size is only defined on this path; the op_dispatch_* /
            # op_combine_* stages read it.
            self.ep_size = get_moe_expert_parallel_world_size()
            self.num_experts = (
                config.n_routed_experts
                + global_server_args_dict["ep_num_redundant_experts"]
            )
            self.renormalize = config.norm_topk_prob
            self.topk_group = config.topk_group
            self.num_expert_group = config.n_group
            self.correction_bias = (
                self.gate.e_score_correction_bias.data
                if self.gate.e_score_correction_bias is not None
                else None
            )

            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
                group=parallel_state.get_tp_group().device_group,
                router_topk=self.top_k,
                permute_fusion=True,
                num_experts=self.num_experts,
                num_local_experts=config.n_routed_experts // self.tp_size,
                hidden_size=config.hidden_size,
                params_dtype=config.torch_dtype,
                deepep_mode=get_deepep_mode(),
                async_finish=True,
                return_recv_hook=True,
            )

    def get_moe_weights(self):
        """Return the tensors of all expert parameters except ``correction_bias``."""
        return [
            x.data
            for name, x in self.experts.named_parameters()
            if name != "correction_bias"
        ]

    def forward(
        self,
        hidden_states: torch.Tensor,
        forward_batch: Optional[ForwardBatch] = None,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
    ) -> torch.Tensor:
        """Dispatch to the DeepEP, dual-stream or normal MoE path.

        Args:
            hidden_states: token hidden states; dim 0 is the token dimension.
            forward_batch: batch metadata; only used on the DeepEP path.
            should_allreduce_fusion: skip the TP all-reduce here (presumably the
                caller fuses it into a later op).
            use_reduce_scatter: skip the TP all-reduce here (presumably the
                caller reduce-scatters instead).
        """
        if self._enable_deepep_moe:
            return self.forward_deepep(hidden_states, forward_batch)

        # The dual-stream path always adds a separate shared-expert output and
        # does not handle empty batches, hence the extra conditions.
        if (
            self.alt_stream is not None
            and self.num_fused_shared_experts == 0
            and 0 < hidden_states.shape[0] <= self._DUAL_STREAM_TOKEN_THRESHOLD
        ):
            return self.forward_normal_dual_stream(
                hidden_states,
                should_allreduce_fusion,
                use_reduce_scatter,
            )
        return self.forward_normal(
            hidden_states,
            should_allreduce_fusion,
            use_reduce_scatter,
        )

    def forward_normal_dual_stream(
        self,
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
    ) -> torch.Tensor:
        """Overlap the shared experts (current stream) with routing + routed
        experts (``alt_stream``), then sum and optionally all-reduce."""
        current_stream = torch.cuda.current_stream()
        # alt_stream must observe everything already queued for hidden_states.
        self.alt_stream.wait_stream(current_stream)
        shared_output = self._forward_shared_experts(hidden_states)

        with torch.cuda.stream(self.alt_stream):
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            topk_output = self.topk(hidden_states, router_logits)
            final_hidden_states = self.experts(hidden_states, topk_output)
            final_hidden_states = self._maybe_scale_routed_output(
                final_hidden_states
            )

        current_stream.wait_stream(self.alt_stream)
        final_hidden_states = self._add_shared_output(
            final_hidden_states, shared_output
        )
        return self._maybe_tp_all_reduce(
            final_hidden_states, should_allreduce_fusion, use_reduce_scatter
        )

    def forward_normal(
        self,
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
    ) -> torch.Tensor:
        """Single-stream MoE forward for non-DeepEP backends."""
        if hasattr(self, "shared_experts") and use_intel_amx_backend(
            self.shared_experts.gate_up_proj
        ):
            # NOTE(review): use_reduce_scatter is not forwarded to the CPU path,
            # which therefore all-reduces whenever tp_size > 1 and fusion is
            # off — confirm this is intended.
            return self.forward_cpu(hidden_states, should_allreduce_fusion)

        if hidden_states.shape[0] > 0:
            shared_output = self._forward_shared_experts(hidden_states)
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            topk_output = self.topk(hidden_states, router_logits)
        else:
            # Empty batch: still run self.experts with an empty routing result.
            shared_output = None
            topk_output = self.topk.empty_topk_output(hidden_states.device)

        final_hidden_states = self.experts(hidden_states, topk_output)
        final_hidden_states = self._maybe_scale_routed_output(final_hidden_states)
        if shared_output is not None:
            final_hidden_states = self._add_shared_output(
                final_hidden_states, shared_output
            )
        return self._maybe_tp_all_reduce(
            final_hidden_states, should_allreduce_fusion, use_reduce_scatter
        )

    def forward_cpu(
        self,
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
    ) -> torch.Tensor:
        """Intel AMX path: routed experts via ``self.experts``, then the shared
        experts combined with the routed output by
        ``torch.ops.sgl_kernel.shared_expert_cpu`` (which also receives
        ``routed_scaling_factor``)."""
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        topk_output = self.topk(hidden_states, router_logits)
        fused_experts_out = self.experts(
            hidden_states=hidden_states, topk_output=topk_output
        )

        gate_up = self.shared_experts.gate_up_proj
        down = self.shared_experts.down_proj
        assert use_intel_amx_backend(gate_up) == use_intel_amx_backend(down)

        # Select the weight scales matching the shared experts' quantization.
        if self.shared_experts_is_int8:
            w1_scale, w2_scale = gate_up.weight_scale, down.weight_scale
        elif self.shared_experts_is_fp8:
            w1_scale, w2_scale = gate_up.weight_scale_inv, down.weight_scale_inv
        else:
            w1_scale = w2_scale = None
        block_size = (
            self.shared_experts_weight_block_size
            if self.shared_experts_is_fp8
            else None
        )

        # [Note] inplace should be False in fused_experts.
        # If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts
        # While hidden_states is still needed in shared_expert.
        final_hidden_states = torch.ops.sgl_kernel.shared_expert_cpu(
            hidden_states,
            gate_up.weight,
            down.weight,
            fused_experts_out,
            self.routed_scaling_factor,
            True,  # inplace
            self.shared_experts_is_int8,  # use_int8_w8a8
            self.shared_experts_is_fp8,  # use_fp8_w8a16
            w1_scale,
            w2_scale,
            block_size,
            None,  # a1_scale
            None,  # a2_scale
            True,  # is_vnni
        )
        if self.tp_size > 1 and not should_allreduce_fusion:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """MoE forward using the DeepEP all-to-all backend (no TP all-reduce)."""
        shared_output = None
        if hidden_states.shape[0] > 0:
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            shared_output = self._forward_shared_experts(hidden_states)
            topk_weights, topk_idx, _ = self.topk(
                hidden_states,
                router_logits,
                num_token_non_padded=forward_batch.num_token_non_padded,
                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                    layer_id=self.layer_id,
                ),
            )
        else:
            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
                hidden_states.device
            )

        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            forward_batch=forward_batch,
        )
        return self._merge_shared_and_scaled_routed(shared_output, final_hidden_states)

    def _forward_shared_experts(self, hidden_states):
        """Run the separate shared-expert MLP; None when shared experts are fused."""
        if self.num_fused_shared_experts == 0:
            return self.shared_experts(hidden_states)
        return None

    def _maybe_scale_routed_output(self, final_hidden_states):
        """Apply routed_scaling_factor in place when it was not applied upstream.

        On CUDA the factor is presumably applied inside the experts/TopK (both
        receive routed_scaling_factor); with aiter it is fused in
        biased_grouped_topk, so we can skip here.
        """
        if not _is_cuda and not _use_aiter:
            final_hidden_states *= self.routed_scaling_factor
        return final_hidden_states

    def _add_shared_output(self, final_hidden_states, shared_output):
        """Return routed + shared output in a buffer allocated under
        ``use_symmetric_memory`` for the TP group (presumably so a following
        all-reduce can use it directly)."""
        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
            final_hidden_states_out = torch.empty_like(final_hidden_states)
        torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
        sm.tag(final_hidden_states_out)
        return final_hidden_states_out

    def _maybe_tp_all_reduce(
        self, final_hidden_states, should_allreduce_fusion, use_reduce_scatter
    ):
        """All-reduce across TP unless the reduction is fused, done by
        reduce-scatter, or handled by the flashinfer fp4 allgather path."""
        if (
            self.tp_size > 1
            and not should_allreduce_fusion
            and not use_reduce_scatter
            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
        ):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def _merge_shared_and_scaled_routed(self, shared_output, final_hidden_states):
        """Return ``shared_output + routed_scaling_factor * final_hidden_states``.

        Accumulates into ``shared_output`` when present; otherwise scales
        ``final_hidden_states`` in place.
        """
        if shared_output is not None:
            shared_output.add_(final_hidden_states, alpha=self.routed_scaling_factor)
            return shared_output
        final_hidden_states *= self.routed_scaling_factor
        return final_hidden_states

    def op_gate(self, state):
        """Stage: compute router logits (None for idle/empty batches)."""
        if is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, state.hidden_states_mlp_input
        ):
            # router_logits: (num_tokens, n_experts)
            state.router_logits = self.gate(state.hidden_states_mlp_input)
        else:
            state.router_logits = None

    def op_shared_experts(self, state):
        """Stage: run the separate shared experts (consumes hidden_states_mlp_input)."""
        hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
        if (self.num_fused_shared_experts == 0) and is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, hidden_states_mlp_input
        ):
            state.shared_output = self.shared_experts(hidden_states_mlp_input)
        else:
            state.shared_output = None

    def op_select_experts(self, state):
        """Stage: top-k expert selection, or empty (0, top_k) results."""
        router_logits = state.pop("router_logits")
        hidden_states = state.hidden_states_mlp_input

        if router_logits is not None:
            with get_global_expert_distribution_recorder().with_current_layer(
                self.layer_id
            ):
                state.topk_weights_local, state.topk_idx_local, _ = self.topk(
                    hidden_states=hidden_states,
                    router_logits=router_logits,
                    num_token_non_padded=state.forward_batch.num_token_non_padded,
                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                        layer_id=self.layer_id,
                    ),
                )
        else:
            state.topk_idx_local = torch.full(
                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
            )
            state.topk_weights_local = torch.empty(
                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
            )

    def op_dispatch_a(self, state):
        """Stage: start the DeepEP dispatch (only when ep_size > 1)."""
        if self.ep_size > 1:
            self.experts.deepep_dispatcher.dispatch_a(
                hidden_states=state.hidden_states_mlp_input,
                topk_idx=state.pop("topk_idx_local"),
                topk_weights=state.pop("topk_weights_local"),
                forward_batch=state.forward_batch,
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )

    def op_dispatch_b(self, state):
        """Stage: finish the DeepEP dispatch and store dispatch_output."""
        if self.ep_size > 1:
            with get_global_expert_distribution_recorder().with_current_layer(
                self.layer_id
            ):
                state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
                )

    def op_experts(self, state):
        """Stage: run the local experts on the dispatched tokens."""
        state.hidden_states_experts_output = self.experts.moe_impl(
            dispatch_output=state.dispatch_output,
        )

    def op_combine_a(self, state):
        """Stage: start the DeepEP combine (only when ep_size > 1)."""
        if self.ep_size > 1:
            self.experts.deepep_dispatcher.combine_a(
                hidden_states=state.pop("hidden_states_experts_output"),
                topk_idx=state.dispatch_output.topk_idx,
                topk_weights=state.dispatch_output.topk_weights,
                forward_batch=state.forward_batch,
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )
            state.pop("dispatch_output")

    def op_combine_b(self, state):
        """Stage: finish the DeepEP combine and store the combined hidden states."""
        if self.ep_size > 1:
            state.hidden_states_after_combine = (
                self.experts.deepep_dispatcher.combine_b(
                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
                )
            )

    def op_output(self, state):
        """Stage: merge shared output with the scaled routed output."""
        final_hidden_states = state.pop("hidden_states_after_combine")
        shared_output = state.pop("shared_output")
        state.hidden_states_mlp_output = self._merge_shared_and_scaled_routed(
            shared_output, final_hidden_states
        )


def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN attention magnitude correction for a context scale.

    Scales at or below 1 need no correction (1.0). Otherwise the factor is
    ``0.1 * mscale * ln(scale) + 1``.
    """
    import math

    if scale > 1:
        return 1.0 + 0.1 * mscale * math.log(scale)
    return 1.0


746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
class DeepseekV2AttentionMLA(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        layer_id: Optional[int] = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """Build the multi-head latent attention (MLA) layer.

        Args:
            config: model config; only ``rms_norm_eps`` is read here.
            hidden_size: width of the layer input and output.
            num_heads: total heads; must be divisible by the attention TP size.
            qk_nope_head_dim: per-head q/k dim that does not get rotary embedding.
            qk_rope_head_dim: per-head q/k dim that gets rotary embedding.
            v_head_dim: per-head value dim.
            q_lora_rank: rank of the compressed query path; ``None`` selects a
                direct ``q_proj`` plus a separate ``kv_a_proj_with_mqa``.
            kv_lora_rank: width of the compressed KV latent.
            rope_theta: rotary base.
            rope_scaling: optional rope-scaling dict. NOTE: when truthy it is
                mutated in place (``rope_type`` is set to ``"deepseek_yarn"``),
                and its ``factor``/``mscale_all_dim`` rescale ``self.scaling``.
            max_position_embeddings: max position for the rotary cache.
            quant_config: forwarded to the linear layers and attention modules.
            reduce_results: whether ``o_proj`` reduces its output across ranks.
            layer_id: forwarded to the RadixAttention modules.
            prefix: parameter-name prefix.
            alt_stream: optional side stream used to overlap the q/kv layernorms.
        """
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        assert num_heads % attn_tp_size == 0
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # For tensor parallel attention
        if self.q_lora_rank is not None:
            # One replicated GEMM produces [q_a | kv_a | k_rope] together.
            self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("fused_qkv_a_proj_with_mqa", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("kv_a_proj_with_mqa", prefix),
            )

        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

        if rope_scaling:
            # NOTE: mutates the caller's dict in place.
            rope_scaling["rope_type"] = "deepseek_yarn"

        self.rotary_emb = get_rope_wrapper(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
            device=global_server_args_dict["device"],
        )

        if rope_scaling:
            # Softmax scale gets the squared YaRN magnitude correction.
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale
        else:
            self.rotary_emb.forward = self.rotary_emb.forward_native

        # Absorbed (MLA) attention: a single shared KV head over the latent;
        # values are the kv_lora_rank part of the latent.
        self.attn_mqa = RadixAttention(
            self.num_local_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=1,
            layer_id=layer_id,
            v_head_dim=self.kv_lora_rank,
            quant_config=quant_config,
            prefix=add_prefix("attn_mqa", prefix),
        )

        # Non-absorbed (MHA) attention over decompressed per-head k/v.
        self.attn_mha = RadixAttention(
            self.num_local_heads,
            self.qk_nope_head_dim + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            v_head_dim=self.v_head_dim,
            quant_config=quant_config,
            prefix=add_prefix("attn_mha", prefix),
        )

        self.alt_stream = alt_stream
        # Attached lazily on the first forward_prepare call.
        self.attn_mha.kv_b_proj = None

        # Absorbed weights and their scales; left unset here (presumably
        # filled in after weight loading — verify against the loader).
        self.w_kc = None
        self.w_vc = None
        self.w_scale = 1.0

        self.w_scale_k = None
        self.w_scale_v = None
        self.use_deep_gemm_bmm = False

        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
        self.disable_chunked_prefix_cache = global_server_args_dict[
            "disable_chunked_prefix_cache"
        ]

        self.current_attention_backend = (
            None  # Attention backend used by current forward batch
        )
        self.rocm_fused_decode_mla = get_bool_env_var(
            "SGLANG_ROCM_FUSED_DECODE_MLA", "false"
        )

        # TODO: Design a finer way to determine the threshold
        self.chunked_prefix_cache_threshold = get_int_env_var(
            "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
        )

        # If we have self.fused_qkv_a_proj_with_mqa and we're running on CPU, we will choose the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel
        # which requires self.w_kc and self.w_vc to be packed.
        # If not, we will use torch.bmm and weight shouldn't be packed in this case
        has_fused_proj = hasattr(self, "fused_qkv_a_proj_with_mqa")
        if has_fused_proj and _is_cpu and _is_cpu_amx_available:
            self.quant_method = PackWeightMethod(
                weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
            )

        # Packed-weight quant formats are excluded from the dtype/shape based
        # fast paths below.
        is_packed_weight = (
            has_fused_proj
            and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
            and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
            in {"awq", "awq_marlin", "moe_wna16"}
        )
        # Specialized small-batch GEMM: bf16 weight of exactly 2112 x 7168
        # (presumably the DeepSeek-V3 fused q/kv-a projection) on CUDA SM90+.
        self.use_min_latency_fused_a_gemm = (
            has_fused_proj
            and not is_packed_weight
            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
            and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
            and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
            and _is_cuda
            and _device_sm >= 90
        )

        self.qkv_proj_with_rope_is_int8 = (
            has_fused_proj
            and not is_packed_weight
            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
        )
        self.qkv_proj_with_rope_is_fp8 = (
            has_fused_proj
            and not is_packed_weight
            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
        )

        self.weight_block_size = None
        if self.qkv_proj_with_rope_is_fp8 and _is_cpu and _is_cpu_amx_available:
            # The fused CPU kernel requires both projections to share the same
            # block-quant setting and block size.
            assert getattr(
                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
            ) == getattr(self.q_b_proj.quant_method, "block_quant", False)
            use_block_quant = getattr(
                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
            )

            if use_block_quant:
                assert (
                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
                    == self.q_b_proj.quant_method.quant_config.weight_block_size
                )
                self.weight_block_size = (
                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
                )

972
973
974
    def dispatch_attn_forward_method(
        self, forward_batch: ForwardBatch
    ) -> AttnForwardMethod:
        """Choose the attention code path for ``forward_batch``.

        The backend is taken from ``global_server_args_dict`` (decode backend
        for decode/idle batches, prefill backend otherwise). It is also cached
        in ``self.current_attention_backend`` for the later forward stages.

        Returns:
            ``MHA`` or ``MHA_CHUNKED_KV`` for non-absorbed prefill paths;
            otherwise one of the weight-absorbed ``MLA*`` variants.
        """

        def _dispatch_mla_subtype():
            if _is_hip:
                if (
                    self.rocm_fused_decode_mla
                    and forward_batch.forward_mode.is_decode()
                ):
                    return AttnForwardMethod.MLA_FUSED_ROPE
                else:
                    return AttnForwardMethod.MLA
            else:
                if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
                    self
                ):
                    return AttnForwardMethod.MLA_FUSED_ROPE_CPU
                else:
                    return AttnForwardMethod.MLA

        def _is_plain_extend():
            # Extend batches, excluding target-verify and draft-extend passes.
            mode = forward_batch.forward_mode
            return (
                mode.is_extend()
                and not mode.is_target_verify()
                and not mode.is_draft_extend()
            )

        def _sum_extend_prefix_lens():
            # Total cached-prefix length; a missing list counts as no prefix.
            prefix_lens = forward_batch.extend_prefix_lens_cpu
            return sum(prefix_lens) if prefix_lens is not None else 0

        # Determine attention backend used by current forward batch
        if forward_batch.forward_mode.is_decode_or_idle():
            attention_backend = global_server_args_dict["decode_attention_backend"]
        else:
            attention_backend = global_server_args_dict["prefill_attention_backend"]
        self.current_attention_backend = attention_backend

        if attention_backend in ("ascend", "aiter"):
            # Plain MHA for prefill, absorbed MLA for everything else.
            if _is_plain_extend():
                return AttnForwardMethod.MHA
            else:
                return AttnForwardMethod.MLA
        elif attention_backend in (
            "flashinfer",
            "fa3",
            "flashmla",
            "trtllm_mla",
            "cutlass_mla",
        ):
            # Use MHA with chunked KV cache when prefilling on long sequences.
            sum_extend_prefix_lens = _sum_extend_prefix_lens()
            # Flashinfer MLA: Do not absorb when enabling ragged prefill
            disable_ragged = (
                attention_backend in ("flashinfer", "flashmla")
                and self.flashinfer_mla_disable_ragged
            )
            if (
                not disable_ragged
                and _is_plain_extend()
                and (
                    (
                        sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
                        and not self.disable_chunked_prefix_cache
                    )
                    or sum_extend_prefix_lens == 0
                )
            ):
                return AttnForwardMethod.MHA_CHUNKED_KV
            else:
                return _dispatch_mla_subtype()
        else:
            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
            if _is_plain_extend() and _sum_extend_prefix_lens() == 0:
                return AttnForwardMethod.MHA
            else:
                return _dispatch_mla_subtype()
Lianmin Zheng's avatar
Lianmin Zheng committed
1061

1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
    def op_prepare(self, state):
        # First stage of the split attention op: consume the pre-attention
        # hidden states and park the prepared intermediate on ``state`` for
        # ``op_core``.
        batch = state.forward_batch
        state.attn_intermediate_state = self.forward_prepare(
            positions=state.positions,
            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
            forward_batch=batch,
            zero_allocator=state.zero_allocator,
        )

    def op_core(self, state):
        # Second stage: consume the prepared intermediate and store the
        # attention output on ``state``.
        prepared = state.pop("attn_intermediate_state")
        state.hidden_states_after_attn = self.forward_core(prepared)

1075
1076
1077
1078
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Run attention end to end: ``forward_prepare`` then ``forward_core``.

        Returns the output produced by ``forward_core``.
        """
        s = self.forward_prepare(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )
        return self.forward_core(s)

    def forward_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Select the attention path and run its prepare stage.

        Returns:
            A 4-tuple ``(hidden_states_or_None, attn_forward_method,
            forward_batch, inner_state)`` consumed by ``forward_core``. For an
            empty batch, this is ``(hidden_states, None, forward_batch, None)``,
            which makes ``forward_core`` return the input unchanged.

        Raises:
            NotImplementedError: if the dispatched method has no prepare stage.
        """
        # Attach kv_b_proj to the MHA attention module on first use.
        if self.attn_mha.kv_b_proj is None:
            self.attn_mha.kv_b_proj = self.kv_b_proj

        if hidden_states.shape[0] == 0:
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states, None, forward_batch, None

        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)

        if attn_forward_method == AttnForwardMethod.MHA:
            inner_state = self.forward_normal_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
            inner_state = self.forward_normal_chunked_kv_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA:
            inner_state = self.forward_absorb_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
            inner_state = self.forward_absorb_fused_mla_rope_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
            inner_state = self.forward_absorb_fused_mla_rope_cpu_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        else:
            raise NotImplementedError(
                f"Unsupported attention forward method: {attn_forward_method}"
            )
        return None, attn_forward_method, forward_batch, inner_state
1131

1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
    def forward_core(self, intermediate_state):
        """Run the core stage matching the method chosen in ``forward_prepare``.

        ``intermediate_state`` is the 4-tuple returned by ``forward_prepare``.
        When its ``inner_state`` is ``None`` (empty batch), the carried hidden
        states are returned unchanged.

        Raises:
            NotImplementedError: if the method has no core stage.
        """
        hidden_states, attn_forward_method, _, inner_state = intermediate_state
        if inner_state is None:
            return hidden_states

        if attn_forward_method == AttnForwardMethod.MHA:
            return self.forward_normal_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
            return self.forward_normal_chunked_kv_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA:
            return self.forward_absorb_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
            return self.forward_absorb_fused_mla_rope_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
            return self.forward_absorb_fused_mla_rope_cpu_core(*inner_state)
        else:
            raise NotImplementedError(
                f"Unsupported attention forward method: {attn_forward_method}"
            )

    def forward_normal_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare inputs for non-absorbed (plain MHA) attention.

        This stage does three things:

        * Decompresses the KV latent into per-head ``k``/``v`` via ``kv_b_proj``.
        * Applies rotary embedding to the rope parts of ``q`` and ``k``.
        * Writes the compressed latent (normed ``kv_a`` + roped ``k_pe``) to
          the KV pool.

        ``zero_allocator`` is unused here; it is kept for interface parity with
        the other prepare methods.

        Returns:
            ``(q, k, v, forward_batch)`` for ``forward_normal_core``.
        """
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]

        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        # Add a singleton head dim: the latent is shared by all heads.
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a)
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope = kv[..., : self.qk_nope_head_dim]
        v = kv[..., self.qk_nope_head_dim :]
        k_pe = latent_cache[:, :, self.kv_lora_rank :]
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        # q_pe is a view into q; write the roped values back explicitly.
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        # Single-head k_pe is broadcast across all local heads.
        k[..., self.qk_nope_head_dim :] = k_pe

        if not _is_npu:
            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
            latent_cache[:, :, self.kv_lora_rank :] = k_pe

            # Save latent cache
            forward_batch.token_to_kv_pool.set_kv_buffer(
                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
            )
        else:
            # To reduce a time-costing split operation
            forward_batch.token_to_kv_pool.set_kv_buffer(
                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
            )

        return q, k, v, forward_batch

    def forward_normal_core(self, q, k, v, forward_batch):
        """Run plain MHA attention on prepared ``q``/``k``/``v``, then ``o_proj``.

        ``save_kv_cache=False``: ``forward_normal_prepare`` already wrote the
        compressed latent to the KV pool.
        """
        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output

Faraz's avatar
Faraz committed
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
    def _fuse_rope_for_trtllm_mla(self, forward_batch: ForwardBatch) -> bool:
        """
        Tell whether rope should be deferred to the TRTLLM MLA kernel.

        True only for the ``trtllm_mla`` backend on a decode/idle batch whose
        attention backend stores data as ``float8_e4m3fn``. In that case, the
        rope and the quantization are fused inside the kernel.
        """
        if self.current_attention_backend != "trtllm_mla":
            return False
        if not forward_batch.forward_mode.is_decode_or_idle():
            return False
        return forward_batch.attn_backend.data_type == torch.float8_e4m3fn

1218
    def forward_absorb_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare inputs for weight-absorbed (MLA) attention.

        This stage does two things:

        * Projects and normalizes q and the KV latent.
        * Absorbs ``w_kc`` into the non-rope query part, so the query lives in
          the ``kv_lora_rank`` latent space.

        Rope is applied here, unless the TRTLLM MLA kernel will fuse it (see
        ``_fuse_rope_for_trtllm_mla``).

        Returns:
            ``(q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator)``
            for ``forward_absorb_core``.
        """
        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

        if self.q_lora_rank is not None:
            # Small batches may use the specialized fused-a GEMM when eligible.
            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                )
            else:
                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
            q, latent_cache = fused_qkv_a_proj_out.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            k_nope = latent_cache[..., : self.kv_lora_rank]

            # overlap qk norm
            if self.alt_stream is not None and get_is_capture_mode():
                # During CUDA-graph capture, the kv norm runs on the side
                # stream concurrently with the q norm, joined before continuing.
                current_stream = torch.cuda.current_stream()
                self.alt_stream.wait_stream(current_stream)
                q = self.q_a_layernorm(q)
                with torch.cuda.stream(self.alt_stream):
                    k_nope = self.kv_a_layernorm(k_nope)
                current_stream.wait_stream(self.alt_stream)
            else:
                q = self.q_a_layernorm(q)
                k_nope = self.kv_a_layernorm(k_nope)

            k_nope = k_nope.unsqueeze(1)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
            k_nope = latent_cache[..., : self.kv_lora_rank]
            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

        # Absorb w_kc: per-head batched matmul of q_nope (heads first after the
        # transpose) into the kv_lora_rank latent space.
        if self.use_deep_gemm_bmm:
            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
            )
            q_nope_out = q_nope.new_empty(
                (self.num_local_heads, aligned_m, self.kv_lora_rank)
            )
            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                (q_nope_val, q_nope_scale),
                (self.w_kc, self.w_scale_k),
                q_nope_out,
                masked_m,
                expected_m,
            )
            # Drop the alignment padding rows.
            q_nope_out = q_nope_out[:, :expected_m, :]
        elif _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)

        # Back to token-major layout.
        q_nope_out = q_nope_out.transpose(0, 1)

        if not self._fuse_rope_for_trtllm_mla(forward_batch):
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator

    def forward_absorb_core(
        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
    ):
        """Run absorbed MLA attention, then ``w_vc`` up-projection and ``o_proj``.

        Backends in the tuple below take the nope/rope parts separately
        (``q_rope``/``k_rope``). Other backends get concatenated q/k.

        The attention output lives in the ``kv_lora_rank`` latent space. It is
        mapped per head to ``v_head_dim`` through ``w_vc``, then flattened
        into ``o_proj``.
        """
        if self.current_attention_backend in (
            "fa3",
            "flashinfer",
            "cutlass_mla",
            "trtllm_mla",
            "ascend",
        ):
            extra_args = {}
            if self._fuse_rope_for_trtllm_mla(forward_batch):
                # Rope was skipped in prepare; the kernel applies it.
                extra_args = {
                    "cos_sin_cache": self.rotary_emb.cos_sin_cache,
                    "is_neox": self.rotary_emb.is_neox_style,
                }
            attn_output = self.attn_mqa(
                q_nope_out,
                k_nope,
                k_nope,
                forward_batch,
                q_rope=q_pe,
                k_rope=k_pe,
                **extra_args,
            )
        else:
            q = torch.cat([q_nope_out, q_pe], dim=-1)
            k = torch.cat([k_nope, k_pe], dim=-1)
            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        if self.use_deep_gemm_bmm:
            attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
                per_token_group_quant_mla_deep_gemm_masked_fp8(
                    attn_output.transpose(0, 1)
                )
            )
            attn_bmm_output = attn_output.new_empty(
                (self.num_local_heads, aligned_m, self.v_head_dim)
            )
            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                (attn_output_val, attn_output_scale),
                (self.w_vc, self.w_scale_v),
                attn_bmm_output,
                masked_m,
                expected_m,
            )
            attn_bmm_output = (
                attn_bmm_output[:, :expected_m, :].transpose(0, 1).flatten(1, 2)
            )
        elif _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                attn_output.transpose(0, 1),
                zero_allocator.allocate(1),
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        else:
            # Write the per-head bmm straight into a transposed view of the
            # final (tokens, heads * v_head_dim) buffer, avoiding a separate
            # transpose + flatten copy.
            attn_bmm_output = torch.empty(
                (attn_output.shape[0], self.num_local_heads * self.v_head_dim),
                dtype=attn_output.dtype,
                device=attn_output.device,
            )
            torch.bmm(
                attn_output.transpose(0, 1),
                self.w_vc,
                out=attn_bmm_output.view(
                    -1, self.num_local_heads, self.v_head_dim
                ).transpose(0, 1),
            )
        output, _ = self.o_proj(attn_bmm_output)

        return output

1390
    def forward_absorb_fused_mla_rope_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare inputs for the fused MLA decode kernel (optionally with fused RoPE).

        Projects ``hidden_states`` into the absorbed query ``q_input`` (per head:
        ``[kv_lora_rank | qk_rope_head_dim]``) and the latent KV entry
        ``k_input``. It then writes the current latent entry into the KV pool
        and gathers the attention-backend metadata the kernel needs.

        RoPE handling is controlled by the env var
        ``SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION`` (default ``"1"``):

        * ``"1"``: RoPE is NOT applied here. The rope slice of ``q_input``
          holds the un-rotated ``q_pe``, and an empty ``k_pe_output`` buffer is
          returned. ``forward_absorb_fused_mla_rope_core`` later copies that
          buffer back into ``k_input`` after the kernel runs.
        * anything else: RoPE is applied eagerly through ``self.rotary_emb``,
          and ``k_pe_output`` is ``None``.

        Returns:
            A tuple in the parameter order of
            ``forward_absorb_fused_mla_rope_core``.
        """
        enable_rope_fusion = (
            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
        )
        q_len = hidden_states.shape[0]
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Absorb w_kc into the non-rope query part (head-major bmm).
        if _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

        # v_input is the layernormed compressed KV (first kv_lora_rank channels).
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        # k_input is a view of latent_cache, so writing into it updates latent_cache too.
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input

        if not enable_rope_fusion:
            # Apply RoPE eagerly. The kernel then runs with use_rope=False.
            k_pe = k_input[..., self.kv_lora_rank :]
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
            k_input[..., self.kv_lora_rank :] = k_pe
            k_pe_output = None
        else:
            # Leave q_pe / k_pe un-rotated. The fused kernel applies RoPE and
            # produces the rotated key rope part in k_pe_output.
            k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])

        # Copy the rope part of the query exactly once (rotated or not, per the
        # branch above). Previously the eager path copied it twice.
        q_input[..., self.kv_lora_rank :] = q_pe

        # Output buffer for the fused decode kernel: one kv_lora_rank vector per head.
        attn_output = torch.empty(
            (q_len, self.num_local_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )
        attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
            forward_batch.attn_backend.forward_metadata
        )
        cos_sin_cache = self.rotary_emb.cos_sin_cache
        num_kv_split = forward_batch.attn_backend.num_kv_splits
        sm_scale = self.attn_mqa.scaling
        if attn_logits is None:
            # Backend metadata did not provide split-KV scratch space; allocate
            # it. The trailing +1 slot is presumably the per-split LSE (TODO confirm).
            attn_logits = torch.empty(
                (
                    forward_batch.batch_size,
                    self.num_local_heads,
                    num_kv_split,
                    self.kv_lora_rank + 1,
                ),
                dtype=torch.float32,
                device=q.device,
            )

        # Save the current token's latent cache entry. With RoPE fusion the rope
        # part written here is still un-rotated; the core step rewrites it.
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mqa, forward_batch.out_cache_loc, k_input, None
        )
        key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mqa.layer_id
        )
        # In MLA the values are the first kv_lora_rank channels of the key cache.
        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

        return (
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            enable_rope_fusion,
            k_input,
            forward_batch,
            zero_allocator,
        )

1504
1505
1506
1507
1508
1509
1510
    def forward_absorb_fused_mla_rope_cpu_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Build MLA q/k/v on the CPU (Intel AMX) path with a single kernel call.

        ``torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight`` receives every
        weight involved in one call:

        * the fused q_a/kv_a projection,
        * the q_b projection,
        * ``w_kc``,
        * both layernorm weights,
        * the RoPE cos/sin cache.

        Weight scales are forwarded only for int8 weights (``weight_scale``) or
        fp8 weights (``weight_scale_inv``); otherwise ``None`` is passed.

        Returns:
            ``(q_input, k_input, v_input, forward_batch, zero_allocator)``, the
            argument list of ``forward_absorb_fused_mla_rope_cpu_core``.
        """
        assert self.q_lora_rank is not None and use_intel_amx_backend(
            self
        ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"

        int8_weights = self.qkv_proj_with_rope_is_int8
        fp8_weights = self.qkv_proj_with_rope_is_fp8

        # Select the per-weight scales matching the quantization scheme.
        if int8_weights:
            qkv_a_scale = self.fused_qkv_a_proj_with_mqa.weight_scale
            q_b_scale = self.q_b_proj.weight_scale
        elif fp8_weights:
            qkv_a_scale = self.fused_qkv_a_proj_with_mqa.weight_scale_inv
            q_b_scale = self.q_b_proj.weight_scale_inv
        else:
            qkv_a_scale = None
            q_b_scale = None

        fused_op = torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight
        q_input, k_input, v_input = fused_op(
            hidden_states,
            self.fused_qkv_a_proj_with_mqa.weight,
            self.q_b_proj.weight,
            self.w_kc,
            self.q_a_layernorm.weight,
            self.kv_a_layernorm.weight,
            positions,
            self.rotary_emb.cos_sin_cache,
            self.kv_a_layernorm.variance_epsilon,
            int8_weights,
            fp8_weights,
            qkv_a_scale,
            q_b_scale,
            True,  # is_vnni
            self.weight_block_size,
            self.q_lora_rank,
            self.kv_lora_rank,
            self.qk_rope_head_dim,
        )
        return q_input, k_input, v_input, forward_batch, zero_allocator

1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
    def forward_absorb_fused_mla_rope_core(
        self,
        q_input,
        key_cache_buf,
        val_cache_buf,
        attn_output,
        kv_indptr,
        kv_indices,
        k_pe_output,
        cos_sin_cache,
        positions,
        attn_logits,
        num_kv_split,
        sm_scale,
        enable_rope_fusion,
        k_input,
        forward_batch,
        zero_allocator,
    ):
        """Run the fused MLA decode kernel, then apply the absorbed V and O projections.

        The arguments are the tuple produced by
        ``forward_absorb_fused_mla_rope_prepare``. When ``enable_rope_fusion``
        is set, the kernel runs with ``use_rope=True``. The rotated key rope part
        it leaves in ``k_pe_output`` is then patched into ``k_input``, and the
        KV-pool entry for the current tokens is rewritten. The latent attention
        result is finally multiplied by ``w_vc`` and passed through ``o_proj``.
        """
        decode_attention_fwd_grouped_rope(
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            self.kv_lora_rank,
            self.rotary_emb.rotary_dim,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            logit_cap=self.attn_mqa.logit_cap,
            use_rope=enable_rope_fusion,
            is_neox_style=self.rotary_emb.is_neox_style,
        )

        if enable_rope_fusion:
            # Persist the kernel-rotated key rope part back into the KV pool.
            k_input[..., self.kv_lora_rank :] = k_pe_output
            forward_batch.token_to_kv_pool.set_kv_buffer(
                self.attn_mqa, forward_batch.out_cache_loc, k_input, None
            )

        latent = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
        # Head-major view used as the bmm batch dimension.
        latent_by_head = latent.transpose(0, 1)

        if _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            projected = torch.bmm(
                latent.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            latent_fp8, latent_scale = per_tensor_quant_mla_fp8(
                latent_by_head,
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            projected = bmm_fp8(
                latent_fp8,
                self.w_vc,
                latent_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            projected = torch.bmm(latent_by_head, self.w_vc)

        # Back to token-major and concatenate the heads before o_proj.
        output, _ = self.o_proj(projected.transpose(0, 1).flatten(1, 2))
        return output

1628
1629
1630
    def forward_absorb_fused_mla_rope_cpu_core(
        self, q_input, k_input, v_input, forward_batch, zero_allocator
    ):
        """Run MLA attention on the CPU (Intel AMX) path and apply the V/O projections.

        Consumes the tuple returned by
        ``forward_absorb_fused_mla_rope_cpu_prepare``. ``zero_allocator`` is
        accepted for signature parity with the GPU path and is not used here.

        Returns:
            The ``o_proj`` output for the attended tokens.
        """
        assert self.q_lora_rank is not None and use_intel_amx_backend(
            self
        ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"

        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # [Note] Align shapes of bmm inputs.
        # Shapes of inputs:
        #   attn_output: [M, B, K]  (M tokens, B = num_local_heads, K = kv_lora_rank)
        #   original self.w_vc: [B, K, N]  (N presumably v_head_dim)
        #   current self.w_vc (which has been converted in PackWeightMethod): [B, N, K]
        # Shapes of inputs to sgl_kernel.cpu.bmm:
        #   out: [B, M, N]
        #   mat1: [B, M, K]
        #   mat2: [B, N, K]
        B = self.w_vc.size(0)
        N = self.w_vc.size(1)
        M = attn_output.size(0)
        # Allocate token-major [M, B*N] storage and give the kernel a head-major
        # [B, M, N] view of it, so o_proj can consume `output` without a copy.
        # Allocate on attn_output's device instead of relying on the global default.
        output = torch.empty(
            [M, B * N], dtype=attn_output.dtype, device=attn_output.device
        )
        attn_bmm_output = output.view([M, B, N]).transpose_(0, 1)
        torch.ops.sgl_kernel.bmm_cpu(
            attn_bmm_output,
            attn_output.transpose(0, 1),
            self.w_vc,
            True,  # is_vnni
            None,  # scale
        )
        output, _ = self.o_proj(output)

        return output

1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
    def _chunked_prefix_attn_mha(
        self,
        q: torch.Tensor,
        accum_output: torch.Tensor,
        accum_lse: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Attend ``q`` to the prefix KV cache chunk by chunk and merge the results.

        For each prefix chunk, the latent cache entries are gathered and
        expanded by ``kv_b_proj`` into per-head ``k_nope`` / ``v``. The cached
        rope part is appended to form ``k``. Attention is then run without
        saving KV. Each chunk's ``(output, lse)`` is folded into the running
        ``(accum_output, accum_lse)`` with ``merge_state_v2``.

        Returns:
            The merged attention output after all chunks.
        """
        num_chunks = forward_batch.num_prefix_chunks
        assert num_chunks is not None

        nope_dim = self.qk_nope_head_dim
        for chunk_idx in range(num_chunks):
            forward_batch.set_prefix_chunk_idx(chunk_idx)

            # Gather this chunk's latent cache with the precomputed kv indices.
            cache_pool = forward_batch.token_to_kv_pool.get_key_buffer(
                self.attn_mha.layer_id
            )
            chunk_latent = cache_pool[
                forward_batch.prefix_chunk_kv_indices[chunk_idx]
            ].contiguous()

            kv_lora, k_rope = chunk_latent.split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )
            kv_proj = self.kv_b_proj(kv_lora.squeeze(1).contiguous())[0].view(
                -1, self.num_local_heads, nope_dim + self.v_head_dim
            )
            k_nope = kv_proj[..., :nope_dim]
            v = kv_proj[..., nope_dim:]

            # Per-head key = [k_nope | k_rope]; k_rope broadcasts across heads.
            k = torch.empty(
                (
                    k_nope.shape[0],
                    self.num_local_heads,
                    nope_dim + self.qk_rope_head_dim,
                ),
                dtype=v.dtype,
                device=v.device,
            )
            k[..., :nope_dim] = k_nope
            k[..., nope_dim:] = k_rope

            chunk_out, chunk_lse = self.attn_mha(
                q, k, v, forward_batch, save_kv_cache=False
            )
            merged_out = torch.empty_like(accum_output)
            merged_lse = torch.empty_like(accum_lse)
            merge_state_v2(
                chunk_out, chunk_lse, accum_output, accum_lse, merged_out, merged_lse
            )
            accum_output, accum_lse = merged_out, merged_lse

        return accum_output

1716
    def forward_normal_chunked_kv_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare q/k/v for MHA with a chunked prefix KV cache.

        In plain MHA the k and v tensors grow very large when the prefix is
        long. The chunked variant therefore processes the prefix KV cache one
        chunk at a time in ``forward_normal_chunked_kv_core``. Because MHA is
        compute-bound, that loop adds little overhead.

        Preparation for the extended (non-prefix) part is identical to the
        normal MHA path, so this delegates to ``forward_normal_prepare``. For
        background, see the top comments in vLLM's
        ``vllm/v1/attention/backends/mla/common.py``.
        """
        prepared = self.forward_normal_prepare(
            positions, hidden_states, forward_batch, zero_allocator
        )
        return prepared

1734
    def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
        """Run MHA over the extended tokens, then merge in the chunked prefix cache.

        The first pass attends only to the new (extended) tokens. When any
        sequence in the batch has a prefix, that pass also returns the LSE, and
        ``_chunked_prefix_attn_mha`` folds in attention over the cached prefix
        chunk by chunk. The result is flattened per token and sent through
        ``o_proj``.
        """
        has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)

        # Chunk info is prepared only once per batch.
        needs_chunk_init = (
            has_extend_prefix and forward_batch.num_prefix_chunks is None
        )
        if needs_chunk_init:
            forward_batch.prepare_chunked_prefix_cache_info(q.device)
            backend = forward_batch.attn_backend
            if hasattr(backend, "init_mha_chunk_metadata"):
                backend.init_mha_chunk_metadata(forward_batch)

        # LSE is only needed when prefix chunks will be merged in afterwards.
        forward_batch.mha_return_lse = has_extend_prefix

        # Pass 1: extended part only, no prefix.
        forward_batch.set_attn_attend_prefix_cache(False)
        result = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)

        # Pass 2: merge in attention over the cached prefix, if any.
        if has_extend_prefix:
            partial_output, partial_lse = result
            forward_batch.set_attn_attend_prefix_cache(True)
            result = self._chunked_prefix_attn_mha(
                q=q,
                accum_output=partial_output,
                accum_lse=partial_lse,
                forward_batch=forward_batch,
            )

        flat = result.reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(flat)
        return output

1762

Liangsheng Yin's avatar
Liangsheng Yin committed
1763
1764
1765
1766
1767
1768
1769
class DeepseekV2DecoderLayer(nn.Module):
    """One DeepSeek-V2 transformer block: MLA self-attention followed by an MLP.

    The MLP is a ``DeepseekV2MoE`` for sparse layers (see ``_is_layer_sparse``)
    and a dense ``DeepseekV2MLP`` otherwise. Communication around attention and
    the MLP (layernorms, scatter/gather, all-reduce) is delegated to a
    ``LayerCommunicator``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        is_nextn: bool = False,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
        self.layer_id = layer_id
        self.is_nextn = is_nextn
        self.self_attn = DeepseekV2AttentionMLA(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            layer_id=layer_id,
            reduce_results=False,
            prefix=add_prefix("self_attn", prefix),
            alt_stream=alt_stream,
        )

        self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
        is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=1 if is_nextn else config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=is_previous_layer_sparse,
        )

        if self.is_layer_sparse:
            self.mlp = DeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                layer_id=self.layer_id,
                alt_stream=alt_stream,
                is_nextn=is_nextn,
            )
        else:
            # With fully-DP dense MLPs each rank holds the whole MLP (tp_size=1).
            if enable_moe_dense_fully_dp():
                mlp_tp_rank, mlp_tp_size = 0, 1
            else:
                mlp_tp_rank, mlp_tp_size = None, None
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                tp_rank=mlp_tp_rank,
                tp_size=mlp_tp_size,
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
            allow_reduce_scatter=True,
            is_last_layer=(
                is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
            ),
        )

    def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
        """Return True if ``layer_id`` uses an MoE MLP.

        That is the case for the NextN layer, or when the config has routed
        experts, ``layer_id >= first_k_dense_replace``, and ``layer_id`` is a
        multiple of ``moe_layer_freq``.
        """
        return is_nextn or (
            self.config.n_routed_experts is not None
            and layer_id >= self.config.first_k_dense_replace
            and layer_id % self.config.moe_layer_freq == 0
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention and MLP for one layer.

        Returns:
            ``(hidden_states, residual)``. When the MLP all-reduce is to be
            fused with the next layer, ``postprocess_layer`` is skipped and
            ``hidden_states`` is tagged with ``_sglang_needs_allreduce_fusion``.
        """
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )

        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )

        should_allreduce_fusion = (
            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
                forward_batch
            )
        )

        # For DP with padding, reduce scatter can be used instead of all-reduce.
        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
            forward_batch
        )
        hidden_states = self.mlp(
            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
        )

        if should_allreduce_fusion:
            # Defer the reduction; the consumer of this tag performs it fused.
            hidden_states._sglang_needs_allreduce_fusion = True
        else:
            hidden_states, residual = self.layer_communicator.postprocess_layer(
                hidden_states, residual, forward_batch
            )

        return hidden_states, residual

    def op_comm_prepare_attn(
        self,
        state,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
        tbo_subbatch_index: Optional[int] = None,
    ):
        """Operation step: pre-attention communication.

        Stores the results and the per-batch inputs on ``state``.
        """
        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
        )
        state.update(
            dict(
                forward_batch=forward_batch,
                positions=positions,
                zero_allocator=zero_allocator,
                tbo_subbatch_index=tbo_subbatch_index,
            )
        )

    def op_comm_prepare_mlp(self, state):
        """Operation step: pre-MLP communication.

        Expects ``state`` to hold ``hidden_states_after_attn``, which is set by
        an attention op not visible here.
        """
        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
            self.layer_communicator.prepare_mlp(
                state.pop("hidden_states_after_attn"),
                state.pop("residual_after_input_ln"),
                state.forward_batch,
            )
        )

    def op_mlp(self, state):
        """Operation step: run the MLP.

        The dense MLP is skipped on an empty (0-token) input when dense layers
        are fully data-parallel; the input is then passed through unchanged.
        """
        hidden_states = state.pop("hidden_states_mlp_input")
        skip_mlp = (
            enable_moe_dense_fully_dp()
            and (not self.is_layer_sparse)
            and hidden_states.shape[0] == 0
        )
        if skip_mlp:
            state.hidden_states_mlp_output = hidden_states
        else:
            state.hidden_states_mlp_output = self.mlp(
                hidden_states, state.forward_batch
            )

    def op_comm_postprocess_layer(self, state):
        """Operation step: post-layer communication.

        Returns the inputs for the next layer and clears ``state`` down to the
        per-batch keys.
        """
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            state.pop("hidden_states_mlp_output"),
            state.pop("residual_after_comm_pre_mlp"),
            state.forward_batch,
        )

        output = dict(
            positions=state.positions,
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=state.forward_batch,
            zero_allocator=state.zero_allocator,
            tbo_subbatch_index=state.tbo_subbatch_index,
        )

        state.clear(
            expect_keys={
                "positions",
                "forward_batch",
                "zero_allocator",
                "tbo_subbatch_index",
            }
        )
        return output
1977

Liangsheng Yin's avatar
Liangsheng Yin committed
1978
1979
1980
1981
1982
1983
1984
1985

class DeepseekV2Model(nn.Module):
    """DeepSeek-V2 decoder stack: embedding, decoder layers, final norm.

    Supports pipeline parallelism. The embedding exists only on the first PP
    rank and the final norm only on the last; other ranks hold
    ``PPMissingLayer`` placeholders.
    """

    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.first_k_dense_replace = config.first_k_dense_replace
        self.pp_group = get_pp_group()

        if self.pp_group.is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                enable_tp=not is_dp_attention_enabled(),
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: DeepseekV2DecoderLayer(
                config=config,
                layer_id=idx,
                quant_config=quant_config,
                prefix=prefix,
                alt_stream=self.alt_stream,
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix=add_prefix("layers", prefix),
            offloader_kwargs=dict(
                submodule_accessor=lambda layer: (
                    layer.mlp.experts
                    if isinstance(layer.mlp, DeepseekV2MoE)
                    else layer.mlp
                ),
                whitelist_param_names_creator=lambda module: (
                    [
                        "w13_weight",
                        "w2_weight",
                        "w13_blockscale_swizzled",
                        "w2_blockscale_swizzled",
                    ]
                    if isinstance(module, FusedMoE)
                    else []
                ),
            ),
        )
        if self.pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer(return_tuple=True)

    def get_input_embeddings(self) -> torch.Tensor:
        return self.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        """Run this PP stage's layers.

        Returns:
            ``PPProxyTensors`` holding ``hidden_states`` and ``residual`` on
            non-last PP ranks. On the last rank, returns the normalized hidden
            states; the norm is skipped for idle batches.
        """
        total_num_layers = self.end_layer - self.start_layer
        device = input_embeds.device if input_embeds is not None else input_ids.device
        # Two scratch slots per local layer, doubled when two-batch overlap can
        # run (presumably one set per TBO sub-batch; TODO confirm).
        zero_allocator = BumpAllocator(
            buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
            dtype=torch.float32,
            device=device,
        )

        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        # With TBO, only the leading dense layers run in the plain loop below;
        # the remaining layers go through model_forward_maybe_tbo.
        normal_start_layer = self.start_layer
        normal_end_layer = self.end_layer
        if forward_batch.can_run_tbo:
            if normal_start_layer < self.first_k_dense_replace < normal_end_layer:
                normal_end_layer = self.first_k_dense_replace
            elif self.first_k_dense_replace < normal_start_layer:
                # NOTE(review): this resets both bounds to 0, so the TBO call
                # below receives self.layers[0:end_layer]. Under PP that slice
                # would include layers owned by other stages, and the scatter
                # mode is read from self.layers[-1]. Confirm this is intended.
                normal_end_layer = normal_start_layer = 0

        for i in range(normal_start_layer, normal_end_layer):
            with get_global_expert_distribution_recorder().with_current_layer(i):
                layer = self.layers[i]
                hidden_states, residual = layer(
                    positions, hidden_states, forward_batch, residual, zero_allocator
                )

        if normal_end_layer != self.end_layer:
            hidden_states, residual = model_forward_maybe_tbo(
                layers=self.layers[normal_end_layer : self.end_layer],
                enable_tbo=True,
                positions=positions,
                forward_batch=forward_batch,
                hidden_states=hidden_states,
                residual=residual,
                input_data_scatter_mode=self.layers[
                    normal_end_layer - 1
                ].layer_scatter_modes.layer_output_mode,
                zero_allocator=zero_allocator,
            )

        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )

        if not forward_batch.forward_mode.is_idle():
            if residual is None:
                hidden_states = self.norm(hidden_states)
            else:
                hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekV2ForCausalLM(nn.Module):
    """Causal LM built on ``DeepseekV2Model``.

    Runs the decoder stack and, on the last pipeline rank, applies the LM head
    through the logits processor.
    """

    # for quark model load
    # NOTE(review): this dict is a class attribute and ``__init__`` mutates it in
    # place (adding the "fused_qkv_a_proj_with_mqa" entry), so the entry is shared
    # by every instance and subclass rather than copied per instance. Confirm the
    # quant loader depends on this shared object before changing it.
    packed_modules_mapping = {}

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the decoder stack, LM head and logits processor.

        Args:
            config: model config; ``q_lora_rank`` (optional), ``vocab_size`` and
                ``hidden_size`` are read here, the MoE fields by helpers.
            quant_config: optional quantization config passed to submodules.
            prefix: parameter-name prefix for the ``model`` / ``lm_head`` children.
        """
        super().__init__()

        # for quark model load
        # With a q_lora_rank, q_a_proj and kv_a_proj_with_mqa are fused along the
        # output dimension; register that packing in the (class-level) mapping.
        self.fuse_qkv_a_proj = getattr(config, "q_lora_rank", None) is not None
        if self.fuse_qkv_a_proj:
            self.packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
                "q_a_proj",
                "kv_a_proj_with_mqa",
            ]

        self.pp_group = get_pp_group()
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config

        # Keep this before DeepseekV2Model is built: it may set the global
        # "disable_shared_experts_fusion" flag, which is presumably read while
        # the decoder layers are constructed.
        self.determine_num_fused_shared_experts()

        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)

        def _collect_routed_expert_weights():
            # Maps decoder-layer index -> MoE weights, for MoE layers only.
            return {
                idx: decoder_layer.mlp.get_moe_weights()
                for idx, decoder_layer in enumerate(self.model.layers)
                if isinstance(decoder_layer.mlp, DeepseekV2MoE)
            }

        self._routed_experts_weights_of_layer = LazyValue(
            _collect_routed_expert_weights
        )

    @property
    def routed_experts_weights_of_layer(self):
        """Per-layer MoE expert weights, resolved through the ``LazyValue`` set up in ``__init__``."""
        lazy_weights = self._routed_experts_weights_of_layer
        return lazy_weights.value

    def determine_num_fused_shared_experts(
        self, architecture: str = "DeepseekV3ForCausalLM"
    ):
        """Decide how many shared experts are fused into the routed-expert MoE.

        Sets ``self.num_fused_shared_experts`` to ``config.n_shared_experts`` when
        fusion is allowed and to 0 otherwise. Fusion requires CUDA with compute
        capability >= (8, 0), ``architectures[0] == architecture``, 256 routed
        experts, one shared expert, and no expert parallelism. When any of these
        fails, the global ``disable_shared_experts_fusion`` flag is set to True and
        the reason is logged on rank 0.

        Args:
            architecture: the only ``config.architectures[0]`` eligible for fusion.
        """
        self.num_fused_shared_experts = 0
        if global_server_args_dict["disable_shared_experts_fusion"]:
            return

        cfg = self.config
        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
        # Short-circuit order matters: the device capability is only queried on CUDA.
        unsupported_model = (
            not _is_cuda
            or torch.cuda.get_device_capability("cuda") < (8, 0)
            or cfg.architectures[0] != architecture
            or cfg.n_routed_experts != 256
            or cfg.n_shared_experts != 1
        )
        if unsupported_model:
            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
        elif get_moe_expert_parallel_world_size() > 1:
            disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
        else:
            self.num_fused_shared_experts = cfg.n_shared_experts
            return

        # num_fused_shared_experts is already 0 here.
        global_server_args_dict["disable_shared_experts_fusion"] = True
        log_info_on_rank0(
            logger,
            f"{disable_reason} Shared experts fusion optimization is disabled.",
        )

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the decoder's token-embedding module."""
        embeddings = self.model.embed_tokens
        return embeddings

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        """Run the decoder and, on the last pipeline rank, compute logits.

        On the last PP rank the result of ``self.logits_processor`` is returned.
        On earlier ranks the decoder output is passed through untouched; that is
        a ``PPProxyTensors`` holding hidden_states/residual, per the tail of
        ``DeepseekV2Model.forward``.
        """
        model_output = self.model(
            input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
        )
        if not self.pp_group.is_last_rank:
            return model_output
        return self.logits_processor(
            input_ids, model_output, self.lm_head, forward_batch
        )

    @property
    def start_layer(self):
        """First decoder-layer index (inclusive) held by ``self.model``."""
        decoder = self.model
        return decoder.start_layer

    @property
    def end_layer(self):
        """Decoder-layer index (exclusive) where ``self.model``'s layer range stops."""
        decoder = self.model
        return decoder.end_layer

    def post_load_weights(self, is_nextn=False, weight_names=None):
        """Post-process each attention layer's ``kv_b_proj`` after weights are loaded.

        For every selected layer, ``kv_b_proj`` is dequantized or requantized as
        needed (AWQ, block-/channel-wise FP8, block-/channel-wise INT8). It is then
        reshaped per head along dim 0 and split into ``w_kc`` (first
        ``qk_nope_head_dim`` rows of each head) and ``w_vc`` (the remaining
        ``v_head_dim`` rows). Both are bound onto ``self_attn`` together with their
        scales. Finally, if DeepGEMM JIT with UE8M0 scales is enabled and the quant
        config is block-wise, ``_weight_requant_ue8m0`` is run.

        Args:
            is_nextn: process ``self.model.decoder`` instead of ``self.model.layers``.
            weight_names: names of the loaded checkpoint weights. When given, only
                layers that had a ``kv_b_proj`` weight loaded (layer id taken from
                index 2 of the dotted name, and below ``num_hidden_layers``) are
                processed. When None, every layer in ``[start_layer, end_layer)``
                is processed.
        """

        # Perform post-processing after loading weights
        if is_nextn:
            # One iteration; the id itself is unused because self.model.decoder
            # is selected below when is_nextn.
            layer_ids = [self.config.num_hidden_layers]
        else:
            if weight_names is None:
                layer_ids = range(self.model.start_layer, self.model.end_layer)
            else:
                # Assumes names look like "model.layers.<id>...." -- TODO confirm
                # for every checkpoint format that reaches this path.
                layer_ids = set()
                for name in weight_names:
                    if "kv_b_proj" in name:
                        layer_id = int(name.split(".")[2])
                        if layer_id < self.config.num_hidden_layers:
                            layer_ids.add(layer_id)

        for layer_id in layer_ids:
            self_attn = (
                self.model.layers[layer_id].self_attn
                if not is_nextn
                else self.model.decoder.self_attn
            )
            if hasattr(self_attn.kv_b_proj, "qweight"):
                # AWQ compatible
                if _is_cuda or _is_hip:
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                    ).T
                else:
                    # NOTE(review): the non-CUDA/HIP awq_dequantize takes three
                    # extra positional args (all 0 here); meaning not visible here.
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                        0,
                        0,
                        0,
                    ).T
            else:
                w = self_attn.kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
            # This may affect the accuracy of fp8 model.
            # Fix deepseek v3 blockwise bmm by using deep_gemm
            # When True, the block-FP8 weight and its block scales are kept as-is
            # for DeepGEMM bmm (see the final else-branch) instead of being
            # dequantized.
            use_deep_gemm_bmm = False

            if w.dtype in (
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
            ):
                if (
                    hasattr(self.quant_config, "weight_block_size")
                    and self.quant_config.weight_block_size is not None
                ):
                    # Block-wise FP8.
                    weight_block_size = self.quant_config.weight_block_size
                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                    if _is_fp8_fnuz:
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                            weight=w,
                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                            input_scale=None,
                        )
                    else:
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv

                    if (
                        _is_cuda
                        and weight_block_size[0] == 128
                        and weight_block_size[1] == 128
                    ):
                        if (
                            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
                            and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL
                            and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false")
                        ):
                            block_scale = weight_scale
                            use_deep_gemm_bmm = True
                        else:
                            # Fall back to a bf16 copy of the weight.
                            w = block_quant_dequant(
                                weight,
                                weight_scale,
                                weight_block_size,
                                torch.bfloat16,
                            )
                    else:
                        # Requantize to a single per-tensor scale for bmm_fp8.
                        w, scale = block_quant_to_tensor_quant(
                            weight, weight_scale, weight_block_size
                        )
                        self_attn.w_scale = scale
                else:
                    # Channel-wise FP8: requantize to a per-tensor scale.
                    if _is_fp8_fnuz:
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                            weight=w,
                            weight_scale=self_attn.kv_b_proj.weight_scale,
                            input_scale=None,
                        )
                    else:
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale

                    w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                    self_attn.w_scale = scale

            if w.dtype == torch.int8:
                # NOTE(review): if the quant config has a weight_block_size
                # attribute set to None, neither branch below runs and w stays
                # int8 -- confirm no int8 config reaches this path that way.
                if hasattr(self.quant_config, "weight_block_size"):
                    # block-wise int8 need it
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
                        w = int8_block_dequant(
                            weight, weight_scale, weight_block_size
                        ).to(torch.bfloat16)
                else:
                    # channel-wise int8 need it
                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
                        torch.bfloat16
                    )

            # Group dim 0 by head (size qk_nope_head_dim + v_head_dim each; the head
            # count is inferred), then split each head into its K and V parts.
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            if not use_deep_gemm_bmm:
                # transpose/contiguous/transpose keeps w_kc's logical shape but
                # stores it contiguous in the transposed layout.
                self_attn.w_kc = bind_or_assign(
                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                )
                self_attn.w_vc = bind_or_assign(
                    self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
                )
                if (
                    hasattr(self_attn.kv_b_proj, "weight_scale")
                    and self_attn.w_scale is None
                ):
                    self_attn.w_scale = bind_or_assign(
                        self_attn.w_scale, self_attn.kv_b_proj.weight_scale
                    )
                    if _is_hip:
                        # NOTE(review): in-place; if bind_or_assign returned the
                        # kv_b_proj.weight_scale tensor itself, that tensor is
                        # doubled too. Presumably mirrors the e4m3fn->e4m3fnuz
                        # scale adjustment -- confirm.
                        self_attn.w_scale *= 2.0
                # TODO: remove this after adding FP8 support in bmm cpu kernel
                if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn:
                    self_attn.w_kc = (
                        self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale
                    )
                    self_attn.w_vc = (
                        self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale
                    )
            else:
                # weight_block_size and block_scale were set in the block-wise
                # FP8 branch that enabled use_deep_gemm_bmm.
                # Per head: K scale tiles along qk_nope_head_dim (block[1]) and
                # V scale tiles along v_head_dim (block[0]).
                num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
                num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
                ws_kc, ws_vc = block_scale.unflatten(
                    0, (-1, (num_tiles_k + num_tiles_n))
                ).split([num_tiles_k, num_tiles_n], dim=1)
                self_attn.w_scale_k = bind_or_assign(
                    self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()
                )
                self_attn.w_scale_v = bind_or_assign(
                    self_attn.w_scale_v, ws_vc.contiguous()
                )
                self_attn.w_kc = bind_or_assign(
                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous()
                )
                self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
                self_attn.use_deep_gemm_bmm = True

        if (
            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
            and hasattr(self.quant_config, "weight_block_size")
            and self.quant_config.weight_block_size is not None
        ):
            self._weight_requant_ue8m0(is_nextn)

    def _weight_requant_ue8m0(self, is_nextn=False):
        """Requantize block-quantized weights in place for DeepGEMM UE8M0 scales.

        For each decoder layer handled by this pipeline stage,
        ``requant_weight_ue8m0_inplace`` is applied to the attention projections.
        It is then applied either to the MoE weights (shared experts and, for
        ``DeepEPMoE``, the routed-expert FP8 weights) or to the dense MLP.

        Args:
            is_nextn: process only ``self.model.decoder`` (always treated as an
                MoE layer) instead of ``self.model.layers``.
        """
        weight_block_size = self.quant_config.weight_block_size
        config = self.config

        # MoE layers start at first_k_dense_replace and repeat every
        # moe_layer_freq layers; a range gives O(1) membership tests.
        moe_layers = range(
            config.first_k_dense_replace,
            config.num_hidden_layers,
            config.moe_layer_freq,
        )

        if is_nextn:
            layer_ids = range(1)
        else:
            # Only this stage's [start_layer, end_layer): load_weights skips the
            # weights of layers outside that range, and post_load_weights
            # iterates the same range. Without pipeline parallelism this is
            # presumably every one of the num_hidden_layers layers.
            layer_ids = range(self.model.start_layer, self.model.end_layer)

        # Same condition __init__ uses to decide whether q_a/kv_a are fused.
        fused_qkv_a = getattr(config, "q_lora_rank", None) is not None

        def requant(module):
            requant_weight_ue8m0_inplace(
                module.weight, module.weight_scale_inv, weight_block_size
            )

        for layer_id in layer_ids:
            layer = self.model.decoder if is_nextn else self.model.layers[layer_id]
            attn = layer.self_attn

            attn_modules = [attn.kv_b_proj, attn.o_proj]
            if fused_qkv_a:
                attn_modules += [attn.fused_qkv_a_proj_with_mqa, attn.q_b_proj]
            else:
                attn_modules += [attn.kv_a_proj_with_mqa, attn.q_proj]
            for module in attn_modules:
                requant(module)

            if is_nextn or layer_id in moe_layers:
                shared_experts = getattr(layer.mlp, "shared_experts", None)
                if shared_experts is not None:
                    for module in (
                        shared_experts.gate_up_proj,
                        shared_experts.down_proj,
                    ):
                        requant(module)

                experts = layer.mlp.experts
                if isinstance(experts, DeepEPMoE):
                    # Each attribute holds the weight at [0] and its scale at [1].
                    for weight_and_scale in (
                        experts.w13_weight_fp8,
                        experts.w2_weight_fp8,
                    ):
                        requant_weight_ue8m0_inplace(
                            weight_and_scale[0], weight_and_scale[1], weight_block_size
                        )
            else:
                mlp = layer.mlp
                assert isinstance(mlp, DeepseekV2MLP)
                for module in (mlp.gate_up_proj, mlp.down_proj):
                    requant(module)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
        """Load checkpoint ``weights`` into this model's parameters.

        Each checkpoint name is filtered and rewritten in turn:
        - skip layers outside this stage's ``[start_layer, end_layer)``;
        - rename fused shared experts;
        - skip or remap nextn weights;
        - map stacked gate/up weights and per-expert weights;
        - fuse ``q_a_proj`` with ``kv_a_proj_with_mqa``.

        The matching parameter's ``weight_loader`` is then submitted to a thread
        pool. All futures are awaited (re-raising any loader exception) before
        ``post_load_weights`` runs with the collected names.

        Args:
            weights: iterable of ``(checkpoint_name, tensor)`` pairs.
            is_nextn: load only the nextn layer (``model.layers.<nextn_layer_id>``
                in the checkpoint) and remap it onto ``model`` / ``model.decoder``.

        Raises:
            ValueError: if ``is_nextn`` and the config lacks
                ``num_nextn_predict_layers``.
            AssertionError: if ``is_nextn`` and ``num_nextn_predict_layers != 1``.
        """

        if is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
                # compatible with old design
                nextn_layer_id = (
                    0
                    if self.config.num_hidden_layers == 1
                    else self.config.num_hidden_layers
                )
            else:
                raise ValueError("num_nextn_predict_layers is not in the config")

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        # A fused shared expert is loaded as one extra expert (id n_routed_experts).
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
        )
        if self.quant_config and self.quant_config.get_name() == "w4afp8":
            expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
                num_experts=self.config.n_routed_experts
            )

        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
            self.config.q_lora_rank is not None
        )
        # Holds whichever half of the q_a/kv_a pair arrived first until its partner
        # shows up. NOTE(review): unpaired entries left at the end are silently
        # dropped.
        cached_a_proj = {} if fuse_qkv_a_proj else None

        # nextn_layer_prefix / nextn_spec_weight_names exist only when is_nextn.
        if is_nextn:
            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
            nextn_spec_weight_names = [
                "shared_head.norm",
                "eh_proj",
                "enorm",
                "hnorm",
            ]

        if self.num_fused_shared_experts > 0:
            assert self.num_fused_shared_experts == 1
            log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")

        # Loaders run concurrently. NOTE(review): this relies on each
        # weight_loader being safe to run in parallel with the others.
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = []
            params_dict = dict(self.named_parameters())
            weight_names = []
            for name, loaded_weight in weights:
                # Skip layers owned by other pipeline stages.
                layer_id = get_layer_id(name)
                if (
                    layer_id is not None
                    and hasattr(self.model, "start_layer")
                    and (
                        layer_id < self.model.start_layer
                        or layer_id >= self.model.end_layer
                    )
                ):
                    continue
                if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
                    name = name.replace(
                        "mlp.shared_experts",
                        f"mlp.experts.{self.config.n_routed_experts}",
                    )

                # Recorded after the shared-expert rename but before the nextn
                # filtering/renaming below; post_load_weights selects layers from it.
                weight_names.append(name)

                if not is_nextn:
                    # Drop checkpoint weights of nextn layers (index >= num_hidden_layers).
                    if hasattr(self.config, "num_nextn_predict_layers"):
                        num_nextn_layers = self.config.num_nextn_predict_layers
                        if num_nextn_layers > 0 and name.startswith("model.layers"):
                            name_list = name.split(".")
                            if (
                                len(name_list) >= 3
                                and int(name_list[2]) >= self.config.num_hidden_layers
                            ):
                                continue
                else:
                    if not name.startswith(nextn_layer_prefix):
                        continue

                    # Use shared head and embed weights from target model
                    if "shared_head.head" in name or "embed_tokens" in name:
                        continue

                    is_decoder = True
                    # For nextn specific weights
                    for weight_name in nextn_spec_weight_names:
                        if weight_name in name:
                            name = name.replace(nextn_layer_prefix, "model")
                            is_decoder = False
                            break
                    # For decoder layer weights
                    if is_decoder:
                        name = name.replace(nextn_layer_prefix, "model.decoder")

                if "rotary_emb.inv_freq" in name:
                    continue
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    # Skip non-stacked layers and experts (experts handled below).
                    if weight_name not in name:
                        continue
                    # We have mlp.experts[0].gate_proj in the checkpoint.
                    # Since we handle the experts below in expert_params_mapping,
                    # we need to skip here BEFORE we update the name, otherwise
                    # name will be updated to mlp.experts[0].gate_up_proj, which
                    # will then be updated below in expert_params_mapping
                    # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                    if ("mlp.experts." in name) and name not in params_dict:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    futures.append(
                        executor.submit(weight_loader, param, loaded_weight, shard_id)
                    )
                    break
                else:
                    # Not a stacked dense weight: try the per-expert mappings.
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        futures.append(
                            executor.submit(
                                weight_loader,
                                param,
                                loaded_weight,
                                name,
                                shard_id=shard_id,
                                expert_id=expert_id,
                            )
                        )
                        break
                    else:
                        # Plain (non-stacked, non-expert) weight.
                        # Skip loading extra bias for GPTQ models.
                        if name.endswith(".bias") and name not in params_dict:
                            continue
                        # Skip loading embed_tokens if not first rank in pipeline parallelism
                        if ".embed_tokens." in name and not self.pp_group.is_first_rank:
                            continue
                        # Skip loading norm if not last rank in pipeline parallelism
                        if ".norm." in name and not self.pp_group.is_last_rank:
                            continue
                        if fuse_qkv_a_proj and (
                            "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                        ):
                            cached_a_proj[name] = loaded_weight
                            q_a_proj_name = (
                                name
                                if "q_a_proj" in name
                                else name.replace("kv_a_proj_with_mqa", "q_a_proj")
                            )
                            kv_a_proj_name = (
                                name
                                if "kv_a_proj_with_mqa" in name
                                else name.replace("q_a_proj", "kv_a_proj_with_mqa")
                            )

                            # When both q_a_proj and kv_a_proj_with_mqa have been cached, load the fused weight to parameter
                            if (
                                q_a_proj_name in cached_a_proj
                                and kv_a_proj_name in cached_a_proj
                            ):
                                q_a_proj_weight = cached_a_proj[q_a_proj_name]
                                kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                                # AWQ-style packed formats are concatenated on
                                # dim 1 (presumably their output features live
                                # there -- confirm).
                                cat_dim = 0
                                if self.quant_config is not None and (
                                    self.quant_config.get_name() == "awq"
                                    or self.quant_config.get_name() == "awq_marlin"
                                    or self.quant_config.get_name() == "moe_wna16"
                                ):
                                    cat_dim = 1
                                fused_weight = torch.cat(
                                    [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                                )
                                param_name = (
                                    name.replace(
                                        "q_a_proj", "fused_qkv_a_proj_with_mqa"
                                    )
                                    if "q_a_proj" in name
                                    else name.replace(
                                        "kv_a_proj_with_mqa",
                                        "fused_qkv_a_proj_with_mqa",
                                    )
                                )
                                param = params_dict[param_name]

                                weight_loader = getattr(
                                    param, "weight_loader", default_weight_loader
                                )
                                futures.append(
                                    executor.submit(weight_loader, param, fused_weight)
                                )
                                cached_a_proj.pop(q_a_proj_name)
                                cached_a_proj.pop(kv_a_proj_name)
                        else:
                            if (
                                "k_scale" in name or "v_scale" in name
                            ) and name not in params_dict:
                                # modelopt attn kv scale is named differently
                                for scale in ["k_scale", "v_scale"]:
                                    if scale in name:
                                        name = name.replace(
                                            f"{scale[0]}_proj", "attn_mqa"
                                        )
                                        break
                            if name not in params_dict:
                                # modelopt ckpt contains not needed weights for MTP module:
                                # model.decoder.self_attn.attn_mqa.v_scale and
                                # model.decoder.self_attn.attn_mqa.k_scale
                                logger.warning(f"{name} not found in params_dict.")
                                continue
                            param = params_dict[name]
                            weight_loader = getattr(
                                param, "weight_loader", default_weight_loader
                            )
                            futures.append(
                                executor.submit(weight_loader, param, loaded_weight)
                            )

            # Wait for all tasks to complete and raise any exceptions.
            for future in concurrent.futures.as_completed(futures):
                future.result()

        self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)

    def get_embed_and_head(self):
        """Return ``(embed_tokens.weight, lm_head.weight)``."""
        embed_weight = self.model.embed_tokens.weight
        head_weight = self.lm_head.weight
        return embed_weight, head_weight

    def set_embed_and_head(self, embed, head):
        """Replace the token-embedding and LM-head weights with ``embed`` / ``head``.

        The current weights are deleted before the new tensors are bound. When
        CUDA is available, the CUDA caching allocator is then emptied and the
        device synchronized. On hosts/builds without CUDA (e.g. the CPU backend)
        those calls are skipped, because ``torch.cuda.synchronize`` would raise.
        """
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    @classmethod
    def get_model_config_for_expert_location(cls, config):
        """Describe the MoE layout (layers, logical experts, groups) of ``config``."""
        layer_count = config.num_hidden_layers
        expert_count = config.n_routed_experts
        group_count = config.n_group
        return ModelConfigForExpertLocation(
            num_layers=layer_count,
            num_logical_experts=expert_count,
            num_groups=group_count,
        )


class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek-V3 model class; reuses the ``DeepseekV2ForCausalLM`` implementation unchanged."""


# Model classes this module exposes (presumably collected by the model registry).
EntryClass = [
    DeepseekV2ForCausalLM,
    DeepseekV3ForCausalLM,
]