deepseek_v2.py 105 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
14

Liangsheng Yin's avatar
Liangsheng Yin committed
15
16
17
# Adapted from:
# https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
"""Inference-only DeepseekV2 model."""
18

19
import concurrent.futures
20
import logging
21
import os
22
from enum import IntEnum, auto
23
from typing import Any, Dict, Iterable, Optional, Tuple, Union
Liangsheng Yin's avatar
Liangsheng Yin committed
24
25

import torch
Ke Bao's avatar
Ke Bao committed
26
import torch.nn.functional as F
Liangsheng Yin's avatar
Liangsheng Yin committed
27
from torch import nn
28
from tqdm import tqdm
Liangsheng Yin's avatar
Liangsheng Yin committed
29
from transformers import PretrainedConfig
30
31

from sglang.srt.distributed import (
32
    get_moe_expert_parallel_world_size,
33
    get_pp_group,
Liangsheng Yin's avatar
Liangsheng Yin committed
34
    get_tensor_model_parallel_world_size,
35
    parallel_state,
Liangsheng Yin's avatar
Liangsheng Yin committed
36
37
    tensor_model_parallel_all_reduce,
)
38
39
40
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)
fzyzcjy's avatar
fzyzcjy committed
41
42
43
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
44
from sglang.srt.layers.activation import SiluAndMul
45
from sglang.srt.layers.amx_utils import PackWeightMethod
46
47
48
49
50
from sglang.srt.layers.communicator import (
    LayerCommunicator,
    LayerScatterModes,
    enable_moe_dense_fully_dp,
)
Lianmin Zheng's avatar
Lianmin Zheng committed
51
52
53
from sglang.srt.layers.dp_attention import (
    get_attention_tp_rank,
    get_attention_tp_size,
54
    is_dp_attention_enabled,
Lianmin Zheng's avatar
Lianmin Zheng committed
55
)
56
from sglang.srt.layers.layernorm import RMSNorm
57
58
59
60
61
62
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
63
from sglang.srt.layers.logits_processor import LogitsProcessor
64
65
66
67
68
from sglang.srt.layers.moe import (
    get_deepep_mode,
    get_moe_a2a_backend,
    should_use_flashinfer_cutlass_moe_fp4_allgather,
)
69
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
70
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
71
from sglang.srt.layers.moe.topk import TopK
72
from sglang.srt.layers.quantization import deep_gemm_wrapper
73
from sglang.srt.layers.quantization.base_config import QuantizationConfig
74
from sglang.srt.layers.quantization.fp8_kernel import (
75
    is_fp8_fnuz,
76
    per_tensor_quant_mla_fp8,
77
    per_token_group_quant_mla_deep_gemm_masked_fp8,
78
)
HandH1998's avatar
HandH1998 committed
79
from sglang.srt.layers.quantization.fp8_utils import (
80
    block_quant_dequant,
HandH1998's avatar
HandH1998 committed
81
    block_quant_to_tensor_quant,
82
    channel_quant_to_tensor_quant,
83
    normalize_e4m3fn_to_e4m3fnuz,
84
    requant_weight_ue8m0_inplace,
HandH1998's avatar
HandH1998 committed
85
)
86
87
88
from sglang.srt.layers.quantization.int8_utils import (
    block_dequant as int8_block_dequant,
)
Liangsheng Yin's avatar
Liangsheng Yin committed
89
from sglang.srt.layers.radix_attention import RadixAttention
90
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
91
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
92
93
94
95
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
96
from sglang.srt.managers.schedule_batch import global_server_args_dict
97
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
98
from sglang.srt.model_loader.weight_utils import default_weight_loader
99
100
101
102
from sglang.srt.two_batch_overlap import (
    MaybeTboDeepEPDispatcher,
    model_forward_maybe_tbo,
)
103
104
from sglang.srt.utils import (
    BumpAllocator,
105
    LazyValue,
106
    add_prefix,
107
    bind_or_assign,
108
    cpu_has_amx_support,
109
    get_bool_env_var,
110
    get_device_sm,
111
    get_int_env_var,
112
    is_cpu,
113
    is_cuda,
114
    is_flashinfer_available,
115
    is_hip,
116
    is_non_idle_and_non_empty,
117
    log_info_on_rank0,
118
    make_layers,
119
    use_intel_amx_backend,
120
)
121

122
# Platform / feature flags, evaluated once at import time and read by the
# model classes below.
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()
# True only when running on HIP *and* the SGLANG_USE_AITER env var is truthy.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
# Device SM version from get_device_sm(); MoEGate compares it against 90 to
# decide whether dsv3_router_gemm may be used.
_device_sm = get_device_sm()
129

Yineng Zhang's avatar
Yineng Zhang committed
130
if _is_cuda:
131
132
133
134
135
136
137
    from sgl_kernel import (
        awq_dequantize,
        bmm_fp8,
        dsv3_fused_a_gemm,
        dsv3_router_gemm,
        merge_state_v2,
    )
138
139
elif _is_cpu and _is_cpu_amx_available:
    pass
140
141
142
143
elif _is_hip:
    from sglang.srt.layers.quantization.awq_triton import (
        awq_dequantize_triton as awq_dequantize,
    )
Yineng Zhang's avatar
Yineng Zhang committed
144
else:
Lianmin Zheng's avatar
Lianmin Zheng committed
145
    from vllm._custom_ops import awq_dequantize
Liangsheng Yin's avatar
Liangsheng Yin committed
146

147
148
149
150
151
if _is_hip:
    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
        decode_attention_fwd_grouped_rope,
    )

152
153
154
_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()

155

156
157
logger = logging.getLogger(__name__)

Liangsheng Yin's avatar
Liangsheng Yin committed
158

159
160
161
162
163
164
165
166
167
168
169
class AttnForwardMethod(IntEnum):
    """Attention execution strategies a DeepSeek-V2 attention layer can choose.

    Values are pinned explicitly (1..5, the same numbers ``auto()`` would
    assign in this declaration order) so that inserting a member later cannot
    silently renumber the existing ones.
    """

    # Plain multi-head attention.
    MHA = 1

    # Multi-latent attention using the absorbed projections.
    MLA = 2

    # Multi-head attention over a KV cache that is processed in chunks;
    # intended to avoid OOM when prefix lengths are long.
    MHA_CHUNKED_KV = 3

    # MLA with a fused RoPE kernel.
    MLA_FUSED_ROPE = 4

    # MLA with a fused RoPE kernel for CPU.
    MLA_FUSED_ROPE_CPU = 5

176

Liangsheng Yin's avatar
Liangsheng Yin committed
177
178
179
180
181
182
183
184
class DeepseekV2MLP(nn.Module):
    """Gated SiLU feed-forward block: ``down_proj(SiluAndMul(gate_up_proj(x)))``.

    ``gate_up_proj`` is a merged column-parallel projection producing the gate
    and up halves (each ``intermediate_size`` wide); ``down_proj`` is a
    row-parallel projection back to ``hidden_size``.

    Args:
        hidden_size: Input/output feature size.
        intermediate_size: Width of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        quant_config: Optional quantization config forwarded to both linears.
        reduce_results: Forwarded to ``down_proj`` as ``reduce_results``.
        prefix: Weight-name prefix for the sub-layers.
        tp_rank: Optional explicit tensor-parallel rank for both linears.
        tp_size: Optional explicit tensor-parallel size for both linears.

    Raises:
        ValueError: if ``hidden_act`` is not ``"silu"`` (checked after the
            projections are constructed).
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
        tp_rank: Optional[int] = None,
        tp_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.tp_size = tp_size

        # Settings shared by both parallel linears.
        shared_linear_kwargs = dict(
            bias=False,
            quant_config=quant_config,
            tp_rank=tp_rank,
            tp_size=tp_size,
        )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            prefix=add_prefix("gate_up_proj", prefix),
            **shared_linear_kwargs,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            reduce_results=reduce_results,
            prefix=add_prefix("down_proj", prefix),
            **shared_linear_kwargs,
        )

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )
        self.act_fn = SiluAndMul()

    def forward(
        self,
        x,
        forward_batch=None,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
    ):
        """Apply the MLP to ``x``.

        ``forward_batch`` is accepted for interface compatibility and unused.
        When either ``should_allreduce_fusion`` or ``use_reduce_scatter`` is
        set, ``down_proj`` is told to skip its all-reduce so the caller can
        perform the reduction itself.
        """
        # An empty batch is returned unchanged, but only when tp_size == 1;
        # presumably sharded ranks must still enter the linears' collectives
        # (TODO confirm).
        if self.tp_size == 1 and x.shape[0] == 0:
            return x

        skip_reduce = should_allreduce_fusion or use_reduce_scatter
        gate_and_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_and_up)
        projected, _ = self.down_proj(activated, skip_all_reduce=skip_reduce)
        return projected


Ke Bao's avatar
Ke Bao committed
236
class MoEGate(nn.Module):
    """Router projection that scores every routed expert for each token.

    Holds ``weight`` with shape ``(n_routed_experts, hidden_size)``.
    When ``config.topk_method == "noaux_tc"`` it also holds a float32
    ``e_score_correction_bias`` with one entry per routed expert; otherwise
    that attribute is ``None``.
    """

    def __init__(
        self,
        config,
        prefix: str = "",
        is_nextn: bool = False,
    ):
        super().__init__()
        self.is_nextn = is_nextn
        num_experts = config.n_routed_experts
        self.weight = nn.Parameter(torch.empty((num_experts, config.hidden_size)))

        has_correction_bias = config.topk_method == "noaux_tc"
        self.e_score_correction_bias = (
            nn.Parameter(torch.empty((num_experts), dtype=torch.float32))
            if has_correction_bias
            else None
        )

        # On AMX-capable CPUs, register a packing method for the gate weight.
        if _is_cpu and _is_cpu_amx_available:
            self.quant_method = PackWeightMethod(weight_names=["weight"])

    def forward(self, hidden_states):
        """Return router logits for ``hidden_states`` (one column per expert)."""
        if use_intel_amx_backend(self):
            return torch.ops.sgl_kernel.weight_packed_linear(
                hidden_states,
                self.weight,
                None,  # bias
                True,  # is_vnni
            )

        # NOTE: For some unknown reason, router_gemm seems to degrade accept
        # length.
        # The specialized kernel is limited to CUDA on SM >= 90, at most 16
        # tokens, hidden size 7168 and 256 routed experts.
        can_use_router_gemm = (
            _is_cuda
            and hidden_states.shape[0] <= 16
            and hidden_states.shape[1] == 7168
            and self.weight.shape[0] == 256
            and _device_sm >= 90
        )
        if not can_use_router_gemm:
            return F.linear(hidden_states, self.weight, None)

        # router gemm outputs float32 logits
        return dsv3_router_gemm(hidden_states, self.weight)


Liangsheng Yin's avatar
Liangsheng Yin committed
282
283
284
285
286
class DeepseekV2MoE(nn.Module):
    """Mixture-of-experts feed-forward block of a DeepSeek-V2 decoder layer.

    Components:
      * ``gate``: a ``MoEGate`` that produces router logits.
      * ``topk``: a ``TopK`` selector over the router logits.
      * ``experts``: the routed-expert layer from ``get_moe_impl_class``.
      * ``shared_experts``: a dense ``DeepseekV2MLP``. It is created only when
        ``config.n_shared_experts`` is set and shared experts are *not* fused
        into ``experts``.

    The routed output is scaled by ``config.routed_scaling_factor``. On some
    backends that scaling is done inside the topk/expert kernels instead of
    here. The scaled routed output is added to the shared-expert output, and
    the sum is all-reduced across tensor-parallel ranks unless the caller
    handles the reduction itself.

    There are two ways to run this block:
      * ``forward``, a single call that picks the normal, dual-stream, CPU/AMX
        or DeepEP path.
      * The ``op_*`` stage methods, which implement the DeepEP path step by
        step on a shared ``state`` object. They read ``tbo_subbatch_index``,
        presumably for two-batch-overlap scheduling (TODO confirm).
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
        is_nextn: bool = False,
    ):
        """Build the router, top-k selector, routed experts and shared experts.

        Args:
            config: Model config. This method reads ``routed_scaling_factor``,
                ``n_shared_experts``, ``n_routed_experts``, ``hidden_act``,
                ``num_experts_per_tok``, ``hidden_size``,
                ``moe_intermediate_size``, ``norm_topk_prob``, ``n_group``,
                ``topk_group`` and ``torch_dtype``.
            layer_id: Index of this layer. It is passed to the expert layer and
                used for expert-location dispatch and distribution recording.
            quant_config: Optional quantization config. When it is ``None``,
                ``TopK`` is built with ``force_topk=True``.
            prefix: Weight-name prefix for the sub-modules.
            alt_stream: Optional second CUDA stream. It enables
                ``forward_normal_dual_stream``.
            is_nextn: Forwarded to ``MoEGate``.

        Raises:
            ValueError: If the tensor-parallel size exceeds the number of
                routed experts, or if ``config.hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        # When shared-expert fusion is enabled, the shared experts are
        # appended to the routed expert layer and selected by topk as extra
        # experts. Otherwise they run as a separate dense MLP (see below).
        self.num_fused_shared_experts = (
            0
            if global_server_args_dict["disable_shared_experts_fusion"]
            else config.n_shared_experts
        )
        self.config = config
        self.layer_id = layer_id
        self.alt_stream = alt_stream

        if self.tp_size > config.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {config.n_routed_experts}."
            )

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        self.gate = MoEGate(
            config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
        )

        self.experts = get_moe_impl_class(quant_config)(
            num_experts=config.n_routed_experts
            + self.num_fused_shared_experts
            + global_server_args_dict["ep_num_redundant_experts"],
            num_fused_shared_experts=self.num_fused_shared_experts,
            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            layer_id=self.layer_id,
            quant_config=quant_config,
            routed_scaling_factor=self.routed_scaling_factor,
            prefix=add_prefix("experts", prefix),
        )

        self.topk = TopK(
            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
            renormalize=config.norm_topk_prob,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            num_fused_shared_experts=self.num_fused_shared_experts,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            routed_scaling_factor=self.routed_scaling_factor,
            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
            force_topk=quant_config is None,
        )

        # Weight-format flags consumed by forward_cpu's shared_expert_cpu call.
        self.shared_experts_is_int8 = False
        self.shared_experts_is_fp8 = False
        self.shared_experts_weight_block_size = None
        if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            # Disable TP for the shared experts when DeepEP MoE is enabled or
            # when FP4 all-gather is used.
            # reduce_results=False: the forward_* methods add the shared and
            # routed outputs first and then reduce the sum once.
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=add_prefix("shared_experts", prefix),
                **(
                    dict(tp_rank=0, tp_size=1)
                    if get_moe_a2a_backend().is_deepep()
                    or should_use_flashinfer_cutlass_moe_fp4_allgather()
                    else {}
                ),
            )
            # Packed-weight quantization schemes store weights in non-native
            # dtypes, so they are excluded from the int8/fp8 dtype checks.
            is_packed_weight = hasattr(
                self.shared_experts.gate_up_proj.quant_method, "quant_config"
            ) and self.shared_experts.gate_up_proj.quant_method.quant_config.get_name() in {
                "awq",
                "awq_marlin",
                "moe_wna16",
            }
            self.shared_experts_is_int8 = (
                not is_packed_weight
                and self.shared_experts.gate_up_proj.weight.dtype == torch.int8
            )
            self.shared_experts_is_fp8 = (
                not is_packed_weight
                and self.shared_experts.gate_up_proj.weight.dtype == torch.float8_e4m3fn
            )
            if self.shared_experts_is_fp8:
                # A single block size is passed to the CPU kernel for both
                # projections, so they must agree.
                assert (
                    self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
                    == self.shared_experts.down_proj.quant_method.quant_config.weight_block_size
                )
                self.shared_experts_weight_block_size = (
                    self.shared_experts.gate_up_proj.quant_method.quant_config.weight_block_size
                )

        self.top_k = config.num_experts_per_tok

        if get_moe_a2a_backend().is_deepep():
            # TODO: we will support tp < ep in the future
            self.ep_size = get_moe_expert_parallel_world_size()
            self.num_experts = (
                config.n_routed_experts
                + global_server_args_dict["ep_num_redundant_experts"]
            )
            self.renormalize = config.norm_topk_prob
            self.topk_group = config.topk_group
            self.num_expert_group = config.n_group
            self.correction_bias = (
                self.gate.e_score_correction_bias.data
                if self.gate.e_score_correction_bias is not None
                else None
            )

            # NOTE(review): the op_* stage methods below use
            # self.experts.deepep_dispatcher, not this attribute. Confirm
            # whether any external caller still relies on it.
            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
                group=parallel_state.get_tp_group().device_group,
                router_topk=self.top_k,
                permute_fusion=True,
                num_experts=self.num_experts,
                num_local_experts=config.n_routed_experts // self.tp_size,
                hidden_size=config.hidden_size,
                params_dtype=config.torch_dtype,
                deepep_mode=get_deepep_mode(),
                async_finish=True,
                return_recv_hook=True,
            )

        self._enable_deepep_moe = get_moe_a2a_backend().is_deepep()

    def get_moe_weights(self):
        """Return the ``.data`` tensors of the expert layer's parameters.

        Parameters named ``correction_bias`` are excluded.
        """
        return [
            x.data
            for name, x in self.experts.named_parameters()
            if name not in ["correction_bias"]
        ]

    def forward(
        self,
        hidden_states: torch.Tensor,
        forward_batch: Optional[ForwardBatch] = None,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
    ) -> torch.Tensor:
        """Run the MoE block on ``hidden_states`` using the appropriate path.

        The DeepEP path (``forward_deepep``) is used when the DeepEP
        all-to-all backend is enabled. In that case ``forward_batch`` is
        required.

        Otherwise the dual-stream path is used when all of these hold:
          * ``alt_stream`` is set,
          * shared experts are not fused,
          * the batch has between 1 and 1024 tokens.

        Every other case goes to ``forward_normal``.

        ``should_allreduce_fusion`` and ``use_reduce_scatter`` tell the
        non-DeepEP paths to skip their own all-reduce, so the caller can do
        the reduction.
        """
        if not self._enable_deepep_moe:
            DUAL_STREAM_TOKEN_THRESHOLD = 1024
            if (
                self.alt_stream is not None
                and self.num_fused_shared_experts == 0
                and hidden_states.shape[0] > 0
                and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
            ):
                return self.forward_normal_dual_stream(
                    hidden_states,
                    should_allreduce_fusion,
                    use_reduce_scatter,
                )
            else:
                return self.forward_normal(
                    hidden_states,
                    should_allreduce_fusion,
                    use_reduce_scatter,
                )
        else:
            return self.forward_deepep(hidden_states, forward_batch)

    def forward_normal_dual_stream(
        self,
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
    ) -> torch.Tensor:
        """Overlap the shared experts with routing and the routed experts.

        The shared experts run on the current stream, while the gate, topk and
        routed experts run on ``alt_stream``. The statement order matters:
          * ``alt_stream`` first waits on the current stream, so it observes
            ``hidden_states``.
          * The current stream then waits on ``alt_stream`` before the two
            outputs are combined.
        """
        current_stream = torch.cuda.current_stream()
        self.alt_stream.wait_stream(current_stream)
        shared_output = self._forward_shared_experts(hidden_states)

        with torch.cuda.stream(self.alt_stream):
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            topk_output = self.topk(hidden_states, router_logits)
            final_hidden_states = self.experts(hidden_states, topk_output)
            if not _is_cuda:
                final_hidden_states *= self.routed_scaling_factor

        current_stream.wait_stream(self.alt_stream)
        # The combined output is allocated inside the symmetric-memory context
        # and tagged afterwards. Presumably this is so the following
        # all-reduce can use symmetric memory (TODO confirm).
        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
            final_hidden_states_out = torch.empty_like(final_hidden_states)

        torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
        final_hidden_states = final_hidden_states_out
        sm.tag(final_hidden_states)
        if (
            self.tp_size > 1
            and not should_allreduce_fusion
            and not use_reduce_scatter
            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
        ):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def forward_normal(
        self,
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
    ) -> torch.Tensor:
        """Run the shared experts, the router and the routed experts sequentially.

        If the shared experts use the Intel AMX backend, the call is delegated
        to ``forward_cpu``. For an empty batch the router is skipped and the
        routed experts receive an empty top-k selection.
        """
        if hasattr(self, "shared_experts") and use_intel_amx_backend(
            self.shared_experts.gate_up_proj
        ):
            # Forward use_reduce_scatter too, so that the CPU path also skips
            # its all-reduce when the caller reduce-scatters the output.
            return self.forward_cpu(
                hidden_states,
                should_allreduce_fusion,
                use_reduce_scatter=use_reduce_scatter,
            )

        if hidden_states.shape[0] > 0:
            shared_output = self._forward_shared_experts(hidden_states)
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            topk_output = self.topk(hidden_states, router_logits)
        else:
            shared_output = None
            topk_output = self.topk.empty_topk_output(hidden_states.device)

        final_hidden_states = self.experts(hidden_states, topk_output)
        if not _is_cuda and not _use_aiter:
            # fused in biased_grouped_topk so we can skip here
            final_hidden_states *= self.routed_scaling_factor
        if shared_output is not None:
            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
                final_hidden_states_out = torch.empty_like(final_hidden_states)
            torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
            final_hidden_states = final_hidden_states_out
            sm.tag(final_hidden_states)
        if (
            self.tp_size > 1
            and not should_allreduce_fusion
            and not use_reduce_scatter
            and not should_use_flashinfer_cutlass_moe_fp4_allgather()
        ):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def forward_cpu(
        self,
        hidden_states: torch.Tensor,
        should_allreduce_fusion: bool = False,
        use_reduce_scatter: bool = False,
    ) -> torch.Tensor:
        """CPU/AMX path: routed experts, then ``sgl_kernel.shared_expert_cpu``.

        The shared-expert kernel receives the routed output
        (``fused_experts_out``) and ``routed_scaling_factor``. Presumably it
        computes the shared MLP and combines it with the scaled routed output
        (TODO confirm against the kernel). The int8/fp8 flags and scales
        computed in ``__init__`` choose the weight format.

        The result is all-reduced only when ``tp_size > 1`` and the caller
        neither fuses the all-reduce (``should_allreduce_fusion``) nor
        reduce-scatters the output itself (``use_reduce_scatter``). These are
        the same rules the GPU paths follow.
        """
        # router_logits: (num_tokens, n_experts)
        router_logits = self.gate(hidden_states)
        topk_output = self.topk(hidden_states, router_logits)
        fused_experts_out = self.experts(
            hidden_states=hidden_states, topk_output=topk_output
        )

        assert use_intel_amx_backend(
            self.shared_experts.gate_up_proj
        ) == use_intel_amx_backend(self.shared_experts.down_proj)
        # [Note] inplace should be False in fused_experts.
        # If inplace is True in fused_experts (self.experts), hidden_states will be changed after fused_experts
        # While hidden_states is still needed in shared_expert.
        final_hidden_states = torch.ops.sgl_kernel.shared_expert_cpu(
            hidden_states,
            self.shared_experts.gate_up_proj.weight,
            self.shared_experts.down_proj.weight,
            fused_experts_out,
            self.routed_scaling_factor,
            True,  # inplace
            self.shared_experts_is_int8,  # use_int8_w8a8
            self.shared_experts_is_fp8,  # use_fp8_w8a16
            (
                self.shared_experts.gate_up_proj.weight_scale
                if self.shared_experts_is_int8
                else (
                    self.shared_experts.gate_up_proj.weight_scale_inv
                    if self.shared_experts_is_fp8
                    else None
                )
            ),  # w1_scale
            (
                self.shared_experts.down_proj.weight_scale
                if self.shared_experts_is_int8
                else (
                    self.shared_experts.down_proj.weight_scale_inv
                    if self.shared_experts_is_fp8
                    else None
                )
            ),  # w2_scale
            (
                self.shared_experts_weight_block_size
                if self.shared_experts_is_fp8
                else None
            ),  # block_size
            None,  # a1_scale
            None,  # a2_scale
            True,  # is_vnni
        )
        # Reducing here as well as in the caller's reduce-scatter would
        # reduce the output twice.
        if (
            self.tp_size > 1
            and not should_allreduce_fusion
            and not use_reduce_scatter
        ):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states

    def forward_deepep(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """DeepEP path: route tokens, then run the expert layer with ``forward_batch``.

        The routed output is always scaled by ``routed_scaling_factor`` here.
        When a shared-expert output exists, the scaled routed output is
        accumulated into it in place (``shared + factor * routed``).
        No tensor-parallel all-reduce is issued by this method.
        """
        shared_output = None
        if hidden_states.shape[0] > 0:
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
            shared_output = self._forward_shared_experts(hidden_states)
            topk_weights, topk_idx, _ = self.topk(
                hidden_states,
                router_logits,
                num_token_non_padded=forward_batch.num_token_non_padded,
                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                    layer_id=self.layer_id,
                ),
            )
        else:
            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
                hidden_states.device
            )

        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            forward_batch=forward_batch,
        )

        if shared_output is not None:
            x = shared_output
            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
            final_hidden_states = x
        else:
            final_hidden_states *= self.routed_scaling_factor

        return final_hidden_states

    def _forward_shared_experts(self, hidden_states):
        """Return the shared-expert MLP output.

        Returns ``None`` when the shared experts are fused into ``experts``.
        """
        if self.num_fused_shared_experts == 0:
            return self.shared_experts(hidden_states)
        else:
            return None

    # ------------------------------------------------------------------
    # Stage methods: the DeepEP forward split into steps that share a
    # ``state`` object. Each op consumes/produces attributes on ``state``.
    # ------------------------------------------------------------------

    def op_gate(self, state):
        """Set ``state.router_logits`` from ``state.hidden_states_mlp_input``.

        The logits are ``None`` when the batch is idle or empty.
        """
        if is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, state.hidden_states_mlp_input
        ):
            # router_logits: (num_tokens, n_experts)
            state.router_logits = self.gate(state.hidden_states_mlp_input)
        else:
            state.router_logits = None

    def op_shared_experts(self, state):
        """Pop ``hidden_states_mlp_input`` and set ``state.shared_output``.

        ``shared_output`` is ``None`` when the shared experts are fused or the
        batch is idle/empty.
        """
        hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
        if (self.num_fused_shared_experts == 0) and is_non_idle_and_non_empty(
            state.forward_batch.forward_mode, hidden_states_mlp_input
        ):
            state.shared_output = self.shared_experts(hidden_states_mlp_input)
        else:
            state.shared_output = None

    def op_select_experts(self, state):
        """Pop ``router_logits`` and set ``topk_weights_local`` / ``topk_idx_local``.

        When there are no logits, empty ``(0, top_k)`` tensors are stored
        instead: indices filled with -1 and float32 weights.
        """
        router_logits = state.pop("router_logits")
        hidden_states = state.hidden_states_mlp_input

        if router_logits is not None:
            with get_global_expert_distribution_recorder().with_current_layer(
                self.layer_id
            ):
                state.topk_weights_local, state.topk_idx_local, _ = self.topk(
                    hidden_states=hidden_states,
                    router_logits=router_logits,
                    num_token_non_padded=state.forward_batch.num_token_non_padded,
                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                        layer_id=self.layer_id,
                    ),
                )
        else:
            state.topk_idx_local = torch.full(
                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
            )
            state.topk_weights_local = torch.empty(
                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
            )

    def op_dispatch_a(self, state):
        """Start the DeepEP dispatch of tokens to experts (only when ``ep_size > 1``)."""
        if self.ep_size > 1:
            self.experts.deepep_dispatcher.dispatch_a(
                hidden_states=state.hidden_states_mlp_input,
                topk_idx=state.pop("topk_idx_local"),
                topk_weights=state.pop("topk_weights_local"),
                forward_batch=state.forward_batch,
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )

    def op_dispatch_b(self, state):
        """Finish the dispatch and store ``state.dispatch_output`` (only when ``ep_size > 1``)."""
        if self.ep_size > 1:
            with get_global_expert_distribution_recorder().with_current_layer(
                self.layer_id
            ):
                state.dispatch_output = self.experts.deepep_dispatcher.dispatch_b(
                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
                )

    def op_experts(self, state):
        """Run the expert computation on ``state.dispatch_output``."""
        state.hidden_states_experts_output = self.experts.moe_impl(
            dispatch_output=state.dispatch_output,
        )

    def op_combine_a(self, state):
        """Start combining expert outputs back to their tokens (only when ``ep_size > 1``)."""
        if self.ep_size > 1:
            self.experts.deepep_dispatcher.combine_a(
                hidden_states=state.pop("hidden_states_experts_output"),
                topk_idx=state.dispatch_output.topk_idx,
                topk_weights=state.dispatch_output.topk_weights,
                forward_batch=state.forward_batch,
                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )
            state.pop("dispatch_output")

    def op_combine_b(self, state):
        """Finish the combine and store ``state.hidden_states_after_combine``."""
        if self.ep_size > 1:
            state.hidden_states_after_combine = (
                self.experts.deepep_dispatcher.combine_b(
                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
                )
            )

    def op_output(self, state):
        """Scale the combined routed output and add the shared output, if any.

        The result is stored in ``state.hidden_states_mlp_output``, mirroring
        the final step of ``forward_deepep``.
        """
        final_hidden_states = state.pop("hidden_states_after_combine")

        if (shared_output := state.pop("shared_output")) is not None:
            x = shared_output
            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
            final_hidden_states = x
        else:
            final_hidden_states *= self.routed_scaling_factor

        state.hidden_states_mlp_output = final_hidden_states
733

Liangsheng Yin's avatar
Liangsheng Yin committed
734
735
736
737
738
739
740
741
742

def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Return the YaRN magnitude-scaling factor for a RoPE extension ``scale``.

    No extension (``scale <= 1``) yields exactly ``1.0``; otherwise the factor
    grows logarithmically: ``1 + 0.1 * mscale * ln(scale)``.
    """
    from math import log

    return 1.0 if scale <= 1 else 1.0 + 0.1 * mscale * log(scale)


class DeepseekV2AttentionMLA(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: int,
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        layer_id: Optional[int] = None,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        """Build projections, rotary embedding and attention modules for one MLA layer.

        Args:
            config: model config; only ``config.rms_norm_eps`` is read here.
            hidden_size: model hidden size (input of the A-projections, output
                of ``o_proj``).
            num_heads: total attention heads; must be divisible by the
                attention tensor-parallel size.
            qk_nope_head_dim: per-head q/k dim that does not get rotary embedding.
            qk_rope_head_dim: per-head q/k dim that gets rotary embedding.
            v_head_dim: per-head value dim.
            q_lora_rank: rank of the compressed query; ``None`` selects a dense
                ``q_proj`` + ``kv_a_proj_with_mqa`` instead of the fused A-proj.
            kv_lora_rank: rank of the compressed KV latent.
            rope_theta: rotary base.
            rope_scaling: optional rotary-scaling dict. If truthy, it is
                *mutated in place* to ``rope_type="deepseek_yarn"`` and its
                ``factor``/``mscale_all_dim`` also rescale the softmax scale.
            max_position_embeddings: maximum position for the rotary cache.
            quant_config: forwarded to every linear and attention module.
            reduce_results: forwarded to ``o_proj``.
            layer_id: layer index forwarded to the attention modules.
            prefix: parameter-name prefix.
            alt_stream: optional CUDA stream used to overlap the q/kv
                layernorms in the absorbed forward path.
        """
        super().__init__()
        self.layer_id = layer_id
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        attn_tp_rank = get_attention_tp_rank()
        attn_tp_size = get_attention_tp_size()

        self.num_heads = num_heads
        assert num_heads % attn_tp_size == 0
        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # For tensor parallel attention
        if self.q_lora_rank is not None:
            # One replicated GEMM produces [q_a | kv_a | k_pe] in a single pass.
            self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("fused_qkv_a_proj_with_mqa", prefix),
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
                tp_rank=attn_tp_rank,
                tp_size=attn_tp_size,
            )
            self.kv_a_proj_with_mqa = ReplicatedLinear(
                self.hidden_size,
                self.kv_lora_rank + self.qk_rope_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("kv_a_proj_with_mqa", prefix),
            )

        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            reduce_results=reduce_results,
            prefix=add_prefix("o_proj", prefix),
            tp_rank=attn_tp_rank,
            tp_size=attn_tp_size,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

        if rope_scaling:
            # NOTE(review): mutates the caller-supplied dict in place.
            rope_scaling["rope_type"] = "deepseek_yarn"

        self.rotary_emb = get_rope_wrapper(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
            device=global_server_args_dict["device"],
        )

        if rope_scaling:
            # YaRN: the softmax scale is multiplied by mscale**2.
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale
        else:
            self.rotary_emb.forward = self.rotary_emb.forward_native

        # Absorbed path: a single shared KV head over the latent (+ rope) dims.
        self.attn_mqa = RadixAttention(
            self.num_local_heads,
            self.kv_lora_rank + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=1,
            layer_id=layer_id,
            v_head_dim=self.kv_lora_rank,
            quant_config=quant_config,
            prefix=add_prefix("attn_mqa", prefix),
        )

        # Non-absorbed path: ordinary multi-head attention on decompressed K/V.
        self.attn_mha = RadixAttention(
            self.num_local_heads,
            self.qk_nope_head_dim + self.qk_rope_head_dim,
            self.scaling,
            num_kv_heads=self.num_local_heads,
            layer_id=layer_id,
            v_head_dim=self.v_head_dim,
            quant_config=quant_config,
            prefix=add_prefix("attn_mha", prefix),
        )

        self.alt_stream = alt_stream
        # Attached lazily on the first forward_prepare call.
        self.attn_mha.kv_b_proj = None

        # Absorbed K/V up-projection weights and scales; populated elsewhere
        # (presumably during weight loading — TODO confirm).
        self.w_kc = None
        self.w_vc = None
        self.w_scale = 1.0

        self.w_scale_k = None
        self.w_scale_v = None
        self.use_deep_gemm_bmm = False

        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
        self.disable_chunked_prefix_cache = global_server_args_dict[
            "disable_chunked_prefix_cache"
        ]

        self.current_attention_backend = (
            None  # Attention backend used by current forward batch
        )
        self.rocm_fused_decode_mla = get_bool_env_var(
            "SGLANG_ROCM_FUSED_DECODE_MLA", "false"
        )

        # TODO: Design a finer way to determine the threshold
        self.chunked_prefix_cache_threshold = get_int_env_var(
            "SGL_CHUNKED_PREFIX_CACHE_THRESHOLD", 8192
        )

        # If we have self.fused_qkv_a_proj_with_mqa and we're running on CPU, we will choose the torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight kernel
        # which requires self.w_kc and self.w_vc to be packed.
        # If not, we will use torch.bmm and weight shouldn't be packed in this case
        has_fused_proj = hasattr(self, "fused_qkv_a_proj_with_mqa")
        if has_fused_proj and _is_cpu and _is_cpu_amx_available:
            self.quant_method = PackWeightMethod(
                weight_names=["w_kc", "w_vc"], transpose_dims=[[1, 2], [1, 2]]
            )

        is_packed_weight = (
            has_fused_proj
            and hasattr(self.fused_qkv_a_proj_with_mqa.quant_method, "quant_config")
            and self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.get_name()
            in {"awq", "awq_marlin", "moe_wna16"}
        )
        # Shape gate for dsv3_fused_a_gemm; presumably the DeepSeek-V3 fused
        # A-projection (7168 -> 1536 + 512 + 64 = 2112) — confirm with the kernel.
        self.use_min_latency_fused_a_gemm = (
            has_fused_proj
            and not is_packed_weight
            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.bfloat16
            and self.fused_qkv_a_proj_with_mqa.weight.shape[0] == 2112
            and self.fused_qkv_a_proj_with_mqa.weight.shape[1] == 7168
            and _is_cuda
            and _device_sm >= 90
        )

        self.qkv_proj_with_rope_is_int8 = (
            has_fused_proj
            and not is_packed_weight
            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.int8
        )
        self.qkv_proj_with_rope_is_fp8 = (
            has_fused_proj
            and not is_packed_weight
            and self.fused_qkv_a_proj_with_mqa.weight.dtype == torch.float8_e4m3fn
        )

        self.weight_block_size = None
        if self.qkv_proj_with_rope_is_fp8 and _is_cpu and _is_cpu_amx_available:
            # The fused A-proj and q_b_proj must agree on block quantization.
            assert getattr(
                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
            ) == getattr(self.q_b_proj.quant_method, "block_quant", False)
            use_block_quant = getattr(
                self.fused_qkv_a_proj_with_mqa.quant_method, "block_quant", False
            )

            if use_block_quant:
                assert (
                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
                    == self.q_b_proj.quant_method.quant_config.weight_block_size
                )
                self.weight_block_size = (
                    self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
                )

    def dispatch_attn_forward_method(
        self, forward_batch: ForwardBatch
    ) -> AttnForwardMethod:
        """Choose the attention implementation for ``forward_batch``.

        Side effect: records the selected backend name (decode backend for
        decode/idle batches, prefill backend otherwise) in
        ``self.current_attention_backend``.

        Returns an ``AttnForwardMethod``. The MHA variants are chosen only for
        extend batches that are neither target-verify nor draft-extend, with
        backend-specific extra conditions; everything else uses an MLA variant.
        """

        def _dispatch_mla_subtype():
            # Pick the MLA flavour for the current platform.
            if _is_hip:
                if (
                    self.rocm_fused_decode_mla
                    and forward_batch.forward_mode.is_decode()
                ):
                    return AttnForwardMethod.MLA_FUSED_ROPE
                else:
                    return AttnForwardMethod.MLA
            else:
                if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
                    self
                ):
                    return AttnForwardMethod.MLA_FUSED_ROPE_CPU
                else:
                    return AttnForwardMethod.MLA

        # Determine attention backend used by current forward batch
        if forward_batch.forward_mode.is_decode_or_idle():
            attention_backend = global_server_args_dict["decode_attention_backend"]
        else:
            attention_backend = global_server_args_dict["prefill_attention_backend"]
        self.current_attention_backend = attention_backend

        if attention_backend == "ascend":
            return AttnForwardMethod.MLA
        elif attention_backend == "flashinfer":
            # Flashinfer MLA: Do not absorb when enabling ragged prefill
            if (
                not self.flashinfer_mla_disable_ragged
                and forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and sum(forward_batch.extend_prefix_lens_cpu) == 0
            ):
                return AttnForwardMethod.MHA
            else:
                return _dispatch_mla_subtype()
        elif attention_backend == "fa3":
            # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
            # When prefix lengths are unavailable the chunked-KV criterion cannot
            # be evaluated, so the batch falls back to the MLA path.
            prefix_lens = forward_batch.extend_prefix_lens_cpu
            sum_extend_prefix_lens = (
                sum(prefix_lens) if prefix_lens is not None else None
            )
            if (
                forward_batch.forward_mode.is_extend()
                and not self.disable_chunked_prefix_cache
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and sum_extend_prefix_lens is not None
                and (
                    sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
                    or sum_extend_prefix_lens == 0
                )
            ):
                return AttnForwardMethod.MHA_CHUNKED_KV
            else:
                return _dispatch_mla_subtype()
        elif attention_backend == "aiter":
            if (
                forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
            ):
                return AttnForwardMethod.MHA
            else:
                return AttnForwardMethod.MLA
        else:
            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
            if (
                forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
                and sum(forward_batch.extend_prefix_lens_cpu) == 0
            ):
                return AttnForwardMethod.MHA
            else:
                return _dispatch_mla_subtype()

    def op_prepare(self, state):
        """Op-style step: run ``forward_prepare`` and stash its result.

        Consumes ``state["hidden_states_after_comm_pre_attn"]`` and stores the
        prepared intermediate in ``state.attn_intermediate_state`` for ``op_core``.
        """
        positions = state.positions
        hidden_states = state.pop("hidden_states_after_comm_pre_attn")
        prepared = self.forward_prepare(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=state.forward_batch,
            zero_allocator=state.zero_allocator,
        )
        state.attn_intermediate_state = prepared

    def op_core(self, state):
        """Op-style step: finish attention from the stashed intermediate.

        Pops ``state["attn_intermediate_state"]`` (written by ``op_prepare``),
        runs ``forward_core`` on it and stores the result in
        ``state.hidden_states_after_attn``.
        """
        state.hidden_states_after_attn = self.forward_core(
            state.pop("attn_intermediate_state")
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Full attention forward: ``forward_prepare`` followed by ``forward_core``."""
        s = self.forward_prepare(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )
        return self.forward_core(s)

    def forward_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """First half of attention: dispatch and run the method-specific prepare.

        Returns a 4-tuple ``(passthrough, attn_forward_method, forward_batch,
        inner_state)`` consumed by ``forward_core``. For an empty batch the input
        ``hidden_states`` is returned as ``passthrough`` with ``None`` method and
        state; otherwise ``passthrough`` is ``None``.

        Raises:
            NotImplementedError: if the dispatched method has no prepare stage.
        """
        # Share kv_b_proj with the MHA attention module on first use.
        if self.attn_mha.kv_b_proj is None:
            self.attn_mha.kv_b_proj = self.kv_b_proj

        if hidden_states.shape[0] == 0:
            # Skipping the layer is only safe when o_proj would not all-reduce.
            assert (
                not self.o_proj.reduce_results
            ), "short-circuiting allreduce will lead to hangs"
            return hidden_states, None, forward_batch, None

        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)

        if attn_forward_method == AttnForwardMethod.MHA:
            inner_state = self.forward_normal_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
            inner_state = self.forward_normal_chunked_kv_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA:
            inner_state = self.forward_absorb_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
            inner_state = self.forward_absorb_fused_mla_rope_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
            inner_state = self.forward_absorb_fused_mla_rope_cpu_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        else:
            raise NotImplementedError(
                f"Unsupported attention forward method: {attn_forward_method}"
            )
        return None, attn_forward_method, forward_batch, inner_state

    def forward_core(self, intermediate_state):
        """Second half of attention: run the core matching ``forward_prepare``.

        ``intermediate_state`` is the 4-tuple returned by ``forward_prepare``.
        When it carries no inner state (empty batch), the passthrough hidden
        states are returned unchanged.

        Raises:
            NotImplementedError: if the method has no core stage.
        """
        hidden_states, attn_forward_method, _, inner_state = intermediate_state
        if inner_state is None:
            return hidden_states

        if attn_forward_method == AttnForwardMethod.MHA:
            return self.forward_normal_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
            return self.forward_normal_chunked_kv_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA:
            return self.forward_absorb_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
            return self.forward_absorb_fused_mla_rope_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
            return self.forward_absorb_fused_mla_rope_cpu_core(*inner_state)
        else:
            raise NotImplementedError(
                f"Unsupported attention forward method: {attn_forward_method}"
            )

    def forward_normal_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare non-absorbed (plain MHA) attention inputs.

        Decompresses the KV latent through ``kv_b_proj`` into per-head K/V,
        applies rotary embedding to the rope part of q/k, and writes the
        compressed latent (normalized kv_a + rotated k_pe) into the KV pool.

        Returns:
            ``(q, k, v, forward_batch)`` for ``forward_normal_core``.
        """
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]

        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a)
        kv = self.kv_b_proj(kv_a)[0]
        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
        k_nope = kv[..., : self.qk_nope_head_dim]
        v = kv[..., self.qk_nope_head_dim :]
        # Single rope key slice (head axis of size 1), broadcast to all heads below.
        k_pe = latent_cache[:, :, self.kv_lora_rank :]
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
        k[..., : self.qk_nope_head_dim] = k_nope
        k[..., self.qk_nope_head_dim :] = k_pe

        # Store normalized latent and rotated rope key back into the latent buffer.
        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
        latent_cache[:, :, self.kv_lora_rank :] = k_pe

        # Save latent cache
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )

        return q, k, v, forward_batch

    def forward_normal_core(self, q, k, v, forward_batch):
        """Run plain MHA on prepared q/k/v and apply the output projection.

        ``save_kv_cache=False`` because ``forward_normal_prepare`` already wrote
        the latent cache.
        """
        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
        output, _ = self.o_proj(attn_output)
        return output

    def _fuse_rope_for_trtllm_mla(self, forward_batch: ForwardBatch) -> bool:
        """
        Check if we should skip rope and do fused rope+quantize for TRTLLM MLA decode in fp8_e4m3 path.

        True only when the current backend is ``trtllm_mla``, the batch is
        decode/idle, and the attention backend's ``data_type`` is
        ``torch.float8_e4m3fn``.
        """
        return (
            self.current_attention_backend == "trtllm_mla"
            and forward_batch.forward_mode.is_decode_or_idle()
            and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
        )

    def forward_absorb_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare weight-absorbed MLA inputs.

        Projects ``q_nope`` through ``w_kc`` into the ``kv_lora_rank`` latent
        space (deep_gemm / ROCm bf16 / fp8 / plain bmm variants), normalizes the
        KV latent, and applies rotary embedding unless the TRTLLM fp8 decode path
        fuses it into the kernel.

        Returns:
            ``(q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator)``
            for ``forward_absorb_core``.
        """
        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

        if self.q_lora_rank is not None:
            # Small batches use the min-latency fused A GEMM when eligible
            # (the <= 16 row cutoff is a tuning choice — TODO confirm origin).
            if hidden_states.shape[0] <= 16 and self.use_min_latency_fused_a_gemm:
                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                )
            else:
                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
            q, latent_cache = fused_qkv_a_proj_out.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            k_nope = latent_cache[..., : self.kv_lora_rank]

            # overlap qk norm: during CUDA-graph capture, run the kv norm on
            # alt_stream concurrently with the q norm on the current stream.
            if self.alt_stream is not None and get_is_capture_mode():
                current_stream = torch.cuda.current_stream()
                self.alt_stream.wait_stream(current_stream)
                q = self.q_a_layernorm(q)
                with torch.cuda.stream(self.alt_stream):
                    k_nope = self.kv_a_layernorm(k_nope)
                current_stream.wait_stream(self.alt_stream)
            else:
                q = self.q_a_layernorm(q)
                k_nope = self.kv_a_layernorm(k_nope)

            k_nope = k_nope.unsqueeze(1)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
            k_nope = latent_cache[..., : self.kv_lora_rank]
            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)

        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)

        # Absorb W_UK: (heads, tokens, nope) x w_kc -> (heads, tokens, kv_lora_rank).
        if self.use_deep_gemm_bmm:
            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
                per_token_group_quant_mla_deep_gemm_masked_fp8(q_nope.transpose(0, 1))
            )
            q_nope_out = q_nope.new_empty(
                (self.num_local_heads, aligned_m, self.kv_lora_rank)
            )
            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                (q_nope_val, q_nope_scale),
                (self.w_kc, self.w_scale_k),
                q_nope_out,
                masked_m,
                expected_m,
            )
            # Drop the alignment padding rows.
            q_nope_out = q_nope_out[:, :expected_m, :]
        elif _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)

        q_nope_out = q_nope_out.transpose(0, 1)

        if not self._fuse_rope_for_trtllm_mla(forward_batch):
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator

    def forward_absorb_core(
        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
    ):
        """Run absorbed MLA attention and project back to the hidden size.

        Attention runs over the latent with the single-KV-head ``attn_mqa``.
        Backends that accept split rope inputs get ``q_rope``/``k_rope``
        separately; others get concatenated q/k. The latent output
        (``kv_lora_rank`` per head) is then up-projected by ``w_vc`` to
        ``v_head_dim`` per head and passed through ``o_proj``.
        """
        if self.current_attention_backend in {
            "fa3",
            "flashinfer",
            "cutlass_mla",
            "trtllm_mla",
        }:
            extra_args = {}
            if self._fuse_rope_for_trtllm_mla(forward_batch):
                # Rope was skipped in prepare; the kernel applies it.
                extra_args = {
                    "cos_sin_cache": self.rotary_emb.cos_sin_cache,
                    "is_neox": self.rotary_emb.is_neox_style,
                }
            attn_output = self.attn_mqa(
                q_nope_out,
                k_nope,
                k_nope,
                forward_batch,
                q_rope=q_pe,
                k_rope=k_pe,
                **extra_args,
            )
        else:
            q = torch.cat([q_nope_out, q_pe], dim=-1)
            k = torch.cat([k_nope, k_pe], dim=-1)
            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # Absorb W_UV: (heads, tokens, kv_lora_rank) x w_vc -> (tokens, heads * v_head_dim).
        if self.use_deep_gemm_bmm:
            attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
                per_token_group_quant_mla_deep_gemm_masked_fp8(
                    attn_output.transpose(0, 1)
                )
            )
            attn_bmm_output = attn_output.new_empty(
                (self.num_local_heads, aligned_m, self.v_head_dim)
            )
            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                (attn_output_val, attn_output_scale),
                (self.w_vc, self.w_scale_v),
                attn_bmm_output,
                masked_m,
                expected_m,
            )
            attn_bmm_output = (
                attn_bmm_output[:, :expected_m, :].transpose(0, 1).flatten(1, 2)
            )
        elif _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                attn_output.transpose(0, 1),
                zero_allocator.allocate(1),
            )
            attn_bmm_output = bmm_fp8(
                attn_output_val,
                self.w_vc,
                attn_output_scale,
                self.w_scale,
                torch.bfloat16,
            )
            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
        else:
            # Write the bmm result directly into the (tokens, heads * v) buffer
            # through a transposed view, avoiding a separate copy.
            attn_bmm_output = torch.empty(
                (attn_output.shape[0], self.num_local_heads * self.v_head_dim),
                dtype=attn_output.dtype,
                device=attn_output.device,
            )
            torch.bmm(
                attn_output.transpose(0, 1),
                self.w_vc,
                out=attn_bmm_output.view(
                    -1, self.num_local_heads, self.v_head_dim
                ).transpose(0, 1),
            )
        output, _ = self.o_proj(attn_bmm_output)

        return output

    def forward_absorb_fused_mla_rope_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Prepare inputs for the fused MLA rope decode kernel (absorbed weights).

        Projects ``hidden_states`` to queries and the latent KV entry, absorbs
        ``w_kc`` into the no-rope query part, writes the current latent entry
        into the KV pool and collects the buffers/metadata consumed by
        ``forward_absorb_fused_mla_rope_core``.

        Rope handling is controlled by the ``SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION``
        environment variable (default ``"1"``):

        * enabled: rope is left to the fused kernel. The rope slice of
          ``q_input`` holds the un-rotated ``q_pe`` and ``k_pe_output`` is an
          uninitialized buffer that the kernel fills with the rotated key
          rope part.
        * disabled: rope is applied here via ``self.rotary_emb`` and
          ``k_pe_output`` is ``None``.

        Returns:
            A 16-tuple in the exact positional order of
            ``forward_absorb_fused_mla_rope_core``'s parameters.
        """
        enable_rope_fusion = (
            os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
        )
        q_len = hidden_states.shape[0]
        # Per head: [absorbed q_nope (kv_lora_rank) | q_pe (qk_rope_head_dim)].
        q_input = hidden_states.new_empty(
            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
        )
        if self.q_lora_rank is not None:
            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
            )
            q = self.q_a_layernorm(q)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
            q = self.q_proj(hidden_states)[0].view(
                -1, self.num_local_heads, self.qk_head_dim
            )
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        # Absorb w_kc into the no-rope query: q_nope is transposed to
        # [heads, tokens, qk_nope_head_dim] for the batched matmul.
        if _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
                self.w_kc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_kc.dtype == torch.float8_e4m3fn:
            q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                q_nope.transpose(0, 1),
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            q_nope_out = bmm_fp8(
                q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
            )
        else:
            q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
        q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)

        # k_input is a view of latent_cache: overwrite its latent part with the
        # normalized latent; the trailing slice is the key rope part.
        v_input = latent_cache[..., : self.kv_lora_rank]
        v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
        k_input = latent_cache.unsqueeze(1)
        k_input[..., : self.kv_lora_rank] = v_input

        if not enable_rope_fusion:
            k_pe = k_input[..., self.kv_lora_rank :]
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
            k_input[..., self.kv_lora_rank :] = k_pe
            k_pe_output = None
        else:
            k_pe_output = torch.empty_like(k_input[..., self.kv_lora_rank :])

        # Single write of the query rope slice: rotated q_pe when rope ran
        # above, un-rotated q_pe when the fused kernel applies rope itself.
        q_input[..., self.kv_lora_rank :] = q_pe

        # Output buffer filled by decode_attention_fwd_grouped_rope in the
        # core step.
        attn_output = torch.empty(
            (q_len, self.num_local_heads, self.kv_lora_rank),
            dtype=q.dtype,
            device=q.device,
        )
        attn_logits, _, kv_indptr, kv_indices, _, _, _ = (
            forward_batch.attn_backend.forward_metadata
        )
        cos_sin_cache = self.rotary_emb.cos_sin_cache
        num_kv_split = forward_batch.attn_backend.num_kv_splits
        sm_scale = self.attn_mqa.scaling
        if attn_logits is None:
            attn_logits = torch.empty(
                (
                    forward_batch.batch_size,
                    self.num_local_heads,
                    num_kv_split,
                    self.kv_lora_rank + 1,
                ),
                dtype=torch.float32,
                device=q.device,
            )

        # Save the current latent cache entry. With rope fusion its rope slice
        # is still un-rotated here; forward_absorb_fused_mla_rope_core rewrites
        # the entry once the kernel has produced k_pe_output.
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mqa, forward_batch.out_cache_loc, k_input, None
        )
        key_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
            self.attn_mqa.layer_id
        )
        val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

        return (
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            enable_rope_fusion,
            k_input,
            forward_batch,
            zero_allocator,
        )

1483
1484
1485
1486
1487
1488
1489
    def forward_absorb_fused_mla_rope_cpu_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Compute q/k/v for the CPU (Intel AMX) absorbed-MLA path.

        A single ``sgl_kernel.qkv_proj_with_rope_fused_weight`` call performs
        the q_a/kv_a projections, both layernorms, the q_b projection, the
        ``w_kc`` absorption and rope. Weight scales are passed according to the
        quantization flags: int8 uses ``weight_scale``, fp8 uses
        ``weight_scale_inv``, otherwise no scale.

        Returns:
            ``(q_input, k_input, v_input, forward_batch, zero_allocator)`` in
            the positional order of ``forward_absorb_fused_mla_rope_cpu_core``.
        """
        assert (
            self.q_lora_rank is not None and use_intel_amx_backend(self)
        ), "forward_absorb_fused_mla_rope_cpu_prepare requires q_lora_rank is not None and use_intel_amx_backend"

        qkv_a = self.fused_qkv_a_proj_with_mqa
        q_b = self.q_b_proj
        is_int8 = self.qkv_proj_with_rope_is_int8
        is_fp8 = self.qkv_proj_with_rope_is_fp8

        if is_int8:
            qkv_a_scale, q_b_scale = qkv_a.weight_scale, q_b.weight_scale
        elif is_fp8:
            qkv_a_scale, q_b_scale = qkv_a.weight_scale_inv, q_b.weight_scale_inv
        else:
            qkv_a_scale, q_b_scale = None, None

        fused_kernel = torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight
        q_input, k_input, v_input = fused_kernel(
            hidden_states,
            qkv_a.weight,
            q_b.weight,
            self.w_kc,
            self.q_a_layernorm.weight,
            self.kv_a_layernorm.weight,
            positions,
            self.rotary_emb.cos_sin_cache,
            self.kv_a_layernorm.variance_epsilon,
            is_int8,
            is_fp8,
            qkv_a_scale,
            q_b_scale,
            True,  # is_vnni
            self.weight_block_size,
            self.q_lora_rank,
            self.kv_lora_rank,
            self.qk_rope_head_dim,
        )
        return q_input, k_input, v_input, forward_batch, zero_allocator

1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
    def forward_absorb_fused_mla_rope_core(
        self,
        q_input,
        key_cache_buf,
        val_cache_buf,
        attn_output,
        kv_indptr,
        kv_indices,
        k_pe_output,
        cos_sin_cache,
        positions,
        attn_logits,
        num_kv_split,
        sm_scale,
        enable_rope_fusion,
        k_input,
        forward_batch,
        zero_allocator,
    ):
        """Run the fused MLA decode kernel, then project back to hidden size.

        Consumes the tuple produced by
        ``forward_absorb_fused_mla_rope_prepare`` (same positional order).
        When ``enable_rope_fusion`` is set, the kernel writes the rotated key
        rope part into ``k_pe_output``. It is copied into ``k_input`` and the
        latent cache entry is stored again. The latent attention result is
        then multiplied by ``w_vc`` per head and passed through ``o_proj``.
        """
        decode_attention_fwd_grouped_rope(
            q_input,
            key_cache_buf,
            val_cache_buf,
            attn_output,
            kv_indptr,
            kv_indices,
            k_pe_output,
            self.kv_lora_rank,
            self.rotary_emb.rotary_dim,
            cos_sin_cache,
            positions,
            attn_logits,
            num_kv_split,
            sm_scale,
            logit_cap=self.attn_mqa.logit_cap,
            use_rope=enable_rope_fusion,
            is_neox_style=self.rotary_emb.is_neox_style,
        )

        if enable_rope_fusion:
            # Replace the un-rotated rope slice saved during prepare with the
            # kernel's rotated output and re-store the latent entry.
            k_input[..., self.kv_lora_rank :] = k_pe_output
            forward_batch.token_to_kv_pool.set_kv_buffer(
                self.attn_mqa, forward_batch.out_cache_loc, k_input, None
            )

        latent = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
        # [heads, tokens, kv_lora_rank] layout for the batched matmul.
        latent_by_head = latent.transpose(0, 1)

        # Up-project the latent result with w_vc, one matrix per head.
        if _is_hip:
            # TODO(haishaw): add bmm_fp8 to ROCm
            per_head_out = torch.bmm(
                latent.to(torch.bfloat16).transpose(0, 1),
                self.w_vc.to(torch.bfloat16) * self.w_scale,
            )
        elif self.w_vc.dtype == torch.float8_e4m3fn:
            quant_val, quant_scale = per_tensor_quant_mla_fp8(
                latent_by_head,
                zero_allocator.allocate(1),
                dtype=torch.float8_e4m3fn,
            )
            per_head_out = bmm_fp8(
                quant_val,
                self.w_vc,
                quant_scale,
                self.w_scale,
                torch.bfloat16,
            )
        else:
            per_head_out = torch.bmm(latent_by_head, self.w_vc)

        projected, _ = self.o_proj(per_head_out.transpose(0, 1).flatten(1, 2))
        return projected

1607
1608
1609
    def forward_absorb_fused_mla_rope_cpu_core(
        self, q_input, k_input, v_input, forward_batch, zero_allocator
    ):
        """Run MLA attention plus the w_vc / o_proj projections on CPU (Intel AMX).

        Consumes the tuple produced by
        ``forward_absorb_fused_mla_rope_cpu_prepare``. ``zero_allocator`` is
        accepted for positional parity with that tuple and is not used here.

        Returns:
            The ``o_proj`` output for the batch.
        """
        assert self.q_lora_rank is not None and use_intel_amx_backend(
            self
        ), "forward_absorb_fused_mla_rope_cpu_core requires q_lora_rank is not None and use_intel_amx_backend"

        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

        # [Note] Align shapes of bmm inputs.
        # Shapes of inputs:
        #   attn_output: [M, B, K]
        #   original self.w_vc: [B, K, N]
        #   current self.w_vc (packed for the VNNI kernel): [B, N, K]
        #
        # Shapes of inputs to sgl_kernel.cpu.bmm:
        #   out: [B, M, N]
        #   mat1: [B, M, K]
        #   mat2: [B, N, K]
        B = self.w_vc.size(0)
        N = self.w_vc.size(1)
        M = attn_output.size(0)
        # Allocate [M, B*N] so that the [B, M, N] view the kernel writes into
        # is already the flattened per-token input of o_proj. Allocate on the
        # input's device rather than torch's default device.
        output = torch.empty(
            [M, int(B * N)], dtype=attn_output.dtype, device=attn_output.device
        )
        attn_bmm_output = output.view([M, B, N]).transpose_(0, 1)
        torch.ops.sgl_kernel.bmm_cpu(
            attn_bmm_output,
            attn_output.transpose(0, 1),
            self.w_vc,
            True,  # is_vnni
            None,  # scale
        )
        output, _ = self.o_proj(output)

        return output

1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
    def _chunked_prefix_attn_mha(
        self,
        q: torch.Tensor,
        accum_output: torch.Tensor,
        accum_lse: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Fold attention over each cached prefix chunk into a running result.

        For every prefix chunk, the function:

        1. gathers that chunk's latent entries from the KV pool using the
           precomputed ``prefix_chunk_kv_indices``;
        2. expands the normalized latent into per-head ``k_nope``/``v`` with
           ``kv_b_proj``;
        3. attends ``q`` against the chunk;
        4. merges the chunk's ``(output, lse)`` into the accumulator with
           ``merge_state_v2``.

        Returns:
            The merged attention output after all chunks.
        """
        assert forward_batch.num_prefix_chunks is not None
        nope_dim = self.qk_nope_head_dim
        rope_dim = self.qk_rope_head_dim
        running_out, running_lse = accum_output, accum_lse

        for chunk_idx in range(forward_batch.num_prefix_chunks):
            forward_batch.set_prefix_chunk_idx(chunk_idx)

            # Fetch latent cache from memory pool with precomputed chunked kv indices
            kv_pool = forward_batch.token_to_kv_pool.get_key_buffer(
                self.attn_mha.layer_id
            )
            chunk_latent = kv_pool[
                forward_batch.prefix_chunk_kv_indices[chunk_idx]
            ].contiguous()
            latent_normed, rope_k = chunk_latent.split(
                [self.kv_lora_rank, rope_dim], dim=-1
            )

            # Expand the latent to per-head [k_nope | v].
            expanded = self.kv_b_proj(latent_normed.squeeze(1).contiguous())[0]
            expanded = expanded.view(
                -1, self.num_local_heads, nope_dim + self.v_head_dim
            )
            chunk_v = expanded[..., nope_dim:]

            chunk_k = torch.empty(
                (expanded.shape[0], self.num_local_heads, nope_dim + rope_dim),
                dtype=chunk_v.dtype,
                device=chunk_v.device,
            )
            chunk_k[..., :nope_dim] = expanded[..., :nope_dim]
            # rope_k presumably has a singleton head dim (see squeeze(1) above)
            # and broadcasts across all heads here.
            chunk_k[..., nope_dim:] = rope_k

            chunk_out, chunk_lse = self.attn_mha(
                q, chunk_k, chunk_v, forward_batch, save_kv_cache=False
            )
            chunk_lse = torch.transpose(chunk_lse, 0, 1).contiguous()

            merged_out = torch.empty_like(running_out)
            merged_lse = torch.empty_like(running_lse)
            merge_state_v2(
                chunk_out, chunk_lse, running_out, running_lse, merged_out, merged_lse
            )
            running_out, running_lse = merged_out, merged_lse

        return running_out

1696
    def forward_normal_chunked_kv_prepare(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
        """Build MHA q/k/v for the extended tokens and store their latent cache.

        In normal MHA, the k and v tensors become overly large when the prefix
        is long, so the prefix KV cache is split into chunks and processed one
        after another (see ``forward_normal_chunked_kv_core``). Since MHA is
        compute friendly, that loop adds little overhead. The top comments in
        https://github.com/vllm-project/vllm/blob/main/vllm/v1/attention/backends/mla/common.py
        explain the approach.

        This step covers only the extended (new) tokens:

        * it computes rotated q and k, and v;
        * it writes ``[normalized kv_a | rotated k_pe]`` to the KV pool.

        ``zero_allocator`` is not used.

        Returns:
            ``(q, k, v, forward_batch)`` for ``forward_normal_chunked_kv_core``.
        """
        nope_dim = self.qk_nope_head_dim
        rope_dim = self.qk_rope_head_dim
        lora_rank = self.kv_lora_rank
        heads = self.num_local_heads

        if self.q_lora_rank is None:
            q = self.q_proj(hidden_states)[0].view(-1, heads, self.qk_head_dim)
            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
        else:
            q_a, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                [self.q_lora_rank, lora_rank + rope_dim], dim=-1
            )
            q = self.q_b_proj(self.q_a_layernorm(q_a))[0].view(
                -1, heads, self.qk_head_dim
            )

        _, q_pe = q.split([nope_dim, rope_dim], dim=-1)
        kv_a, _ = latent_cache.split([lora_rank, rope_dim], dim=-1)
        latent_cache = latent_cache.unsqueeze(1)
        kv_a = self.kv_a_layernorm(kv_a)

        # Expand the normalized latent to per-head [k_nope | v].
        kv = self.kv_b_proj(kv_a)[0].view(-1, heads, nope_dim + self.v_head_dim)
        k_nope, v = kv[..., :nope_dim], kv[..., nope_dim:]
        k_pe = latent_cache[:, :, lora_rank:]

        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., nope_dim:] = q_pe
        k = torch.empty_like(q)
        k[..., :nope_dim] = k_nope
        k[..., nope_dim:] = k_pe

        # Latent cache entry for this layer: normalized kv_a followed by the
        # rotated k_pe.
        latent_cache[:, :, :lora_rank] = kv_a.unsqueeze(1)
        latent_cache[:, :, lora_rank:] = k_pe
        forward_batch.token_to_kv_pool.set_kv_buffer(
            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
        )

        return q, k, v, forward_batch

    def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
        """Attend over the new tokens, then fold in any cached prefix by chunks.

        The first pass runs MHA over the extended tokens only, with the prefix
        cache excluded. If any sequence in the batch has a prefix, the prefix
        KV is then attended chunk by chunk via ``_chunked_prefix_attn_mha``.
        That call merges each chunk into the running output using the
        first pass's log-sum-exp.

        Returns:
            The ``o_proj`` output for the extended tokens.
        """
        forward_batch.set_attn_attend_prefix_cache(False)
        partial_out, partial_lse = self.attn_mha(
            q, k, v, forward_batch, save_kv_cache=False
        )
        partial_lse = torch.transpose(partial_lse, 0, 1).contiguous()

        has_prefix = any(forward_batch.extend_prefix_lens_cpu)
        if has_prefix:
            # Chunk info is prepared only once; later calls find
            # num_prefix_chunks already set on forward_batch.
            if forward_batch.num_prefix_chunks is None:
                forward_batch.prepare_chunked_prefix_cache_info(q.device)
            forward_batch.set_attn_attend_prefix_cache(True)
            partial_out = self._chunked_prefix_attn_mha(
                q=q,
                accum_output=partial_out,
                accum_lse=partial_lse,
                forward_batch=forward_batch,
            )

        flat_out = partial_out.reshape(-1, self.num_local_heads * self.v_head_dim)
        projected, _ = self.o_proj(flat_out)
        return projected

1771

Liangsheng Yin's avatar
Liangsheng Yin committed
1772
1773
1774
1775
1776
1777
1778
class DeepseekV2DecoderLayer(nn.Module):
    """One DeepSeek-V2 decoder layer: MLA self-attention followed by an MLP.

    The MLP is a ``DeepseekV2MoE`` when ``_is_layer_sparse`` holds for this
    layer, and a dense ``DeepseekV2MLP`` otherwise. Layernorms and the
    communication around attention and the MLP go through
    ``LayerCommunicator``.

    ``forward`` runs the layer end to end. The ``op_comm_*`` / ``op_mlp``
    methods expose the communication and MLP steps as separate stages that
    pass intermediates through a ``state`` object. The attention stage that
    sets ``state.hidden_states_after_attn`` is not defined in this class.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig] = None,
        is_nextn: bool = False,
        prefix: str = "",
        alt_stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        self.config = config
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
        self.speculative_algorithm = global_server_args_dict["speculative_algorithm"]
        self.layer_id = layer_id
        self.is_nextn = is_nextn
        self.self_attn = DeepseekV2AttentionMLA(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=getattr(config, "q_lora_rank", None),
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            quant_config=quant_config,
            layer_id=layer_id,
            reduce_results=False,
            prefix=add_prefix("self_attn", prefix),
            alt_stream=alt_stream,
        )

        self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
        is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)

        self.layer_scatter_modes = LayerScatterModes.init_new(
            layer_id=layer_id,
            num_layers=1 if is_nextn else config.num_hidden_layers,
            is_layer_sparse=self.is_layer_sparse,
            is_previous_layer_sparse=is_previous_layer_sparse,
        )

        if self.is_layer_sparse:
            self.mlp = DeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                layer_id=self.layer_id,
                alt_stream=alt_stream,
                is_nextn=is_nextn,
            )
        else:
            if enable_moe_dense_fully_dp():
                # Dense MLP with a tensor-parallel size of 1 (rank 0 of 1).
                mlp_tp_rank, mlp_tp_size = 0, 1
            else:
                mlp_tp_rank, mlp_tp_size = None, None
            self.mlp = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=add_prefix("mlp", prefix),
                tp_rank=mlp_tp_rank,
                tp_size=mlp_tp_size,
            )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

        self.layer_communicator = LayerCommunicator(
            layer_scatter_modes=self.layer_scatter_modes,
            input_layernorm=self.input_layernorm,
            post_attention_layernorm=self.post_attention_layernorm,
            allow_reduce_scatter=True,
            is_last_layer=(
                is_nextn or (self.layer_id == self.config.num_hidden_layers - 1)
            ),
        )

    def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
        """Return True if ``layer_id`` uses the MoE MLP.

        NextN layers are always MoE. Otherwise a layer is MoE when the config
        has routed experts, ``layer_id >= first_k_dense_replace`` and
        ``layer_id`` is a multiple of ``moe_layer_freq``.
        """
        return is_nextn or (
            self.config.n_routed_experts is not None
            and layer_id >= self.config.first_k_dense_replace
            and layer_id % self.config.moe_layer_freq == 0
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Run attention and MLP for this layer.

        Returns:
            ``(hidden_states, residual)``. When the MLP all-reduce is to be
            fused with the next layer, ``postprocess_layer`` is skipped and
            ``hidden_states`` is tagged with
            ``_sglang_needs_allreduce_fusion = True``.
        """
        hidden_states, residual = self.layer_communicator.prepare_attn(
            hidden_states, residual, forward_batch
        )

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )

        hidden_states, residual = self.layer_communicator.prepare_mlp(
            hidden_states, residual, forward_batch
        )

        should_allreduce_fusion = (
            self.layer_communicator.should_fuse_mlp_allreduce_with_next_layer(
                forward_batch
            )
        )

        # For DP with padding, reduce scatter can be used instead of all-reduce.
        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
            forward_batch
        )
        hidden_states = self.mlp(
            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
        )

        if should_allreduce_fusion:
            # Mark the tensor so downstream code knows the MLP all-reduce was
            # deferred for fusion with the next layer.
            hidden_states._sglang_needs_allreduce_fusion = True
        else:
            hidden_states, residual = self.layer_communicator.postprocess_layer(
                hidden_states, residual, forward_batch
            )

        return hidden_states, residual

    def op_comm_prepare_attn(
        self,
        state,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
        tbo_subbatch_index: Optional[int] = None,
    ):
        """Stage: pre-attention communication/norm; stash shared inputs on ``state``."""
        state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
            self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
        )
        state.update(
            dict(
                forward_batch=forward_batch,
                positions=positions,
                zero_allocator=zero_allocator,
                tbo_subbatch_index=tbo_subbatch_index,
            )
        )

    def op_comm_prepare_mlp(self, state):
        """Stage: post-attention communication/norm producing the MLP input."""
        state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
            self.layer_communicator.prepare_mlp(
                state.pop("hidden_states_after_attn"),
                state.pop("residual_after_input_ln"),
                state.forward_batch,
            )
        )

    def op_mlp(self, state):
        """Stage: run the MLP on ``state.hidden_states_mlp_input``."""
        hidden_states = state.pop("hidden_states_mlp_input")
        # With fully data-parallel dense MLPs, a dense layer with an empty
        # local batch skips the MLP call and passes the tensor through.
        if not (
            enable_moe_dense_fully_dp()
            and (not self.is_layer_sparse)
            and hidden_states.shape[0] == 0
        ):
            state.hidden_states_mlp_output = self.mlp(
                hidden_states, state.forward_batch
            )
        else:
            state.hidden_states_mlp_output = hidden_states

    def op_comm_postprocess_layer(self, state):
        """Stage: post-MLP communication; build the next layer's inputs.

        Clears ``state`` except for the keys shared across layers and returns
        the keyword arguments for the next layer's first stage.
        """
        hidden_states, residual = self.layer_communicator.postprocess_layer(
            state.pop("hidden_states_mlp_output"),
            state.pop("residual_after_comm_pre_mlp"),
            state.forward_batch,
        )

        output = dict(
            positions=state.positions,
            hidden_states=hidden_states,
            residual=residual,
            forward_batch=state.forward_batch,
            zero_allocator=state.zero_allocator,
            tbo_subbatch_index=state.tbo_subbatch_index,
        )

        state.clear(
            expect_keys={
                "positions",
                "forward_batch",
                "zero_allocator",
                "tbo_subbatch_index",
            }
        )
        return output
1986

Liangsheng Yin's avatar
Liangsheng Yin committed
1987
1988
1989
1990
1991
1992
1993
1994

class DeepseekV2Model(nn.Module):
    """Stack of ``DeepseekV2DecoderLayer`` with embedding and final norm.

    Under pipeline parallelism this rank runs layers
    ``[start_layer, end_layer)``. Only the first rank owns the embedding and
    only the last rank owns the final norm; other ranks hold
    ``PPMissingLayer`` placeholders. Ranks after the first receive
    ``hidden_states``/``residual`` through ``PPProxyTensors``.
    """

    fall_back_to_pt_during_load = False

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.first_k_dense_replace = config.first_k_dense_replace
        self.pp_group = get_pp_group()

        if self.pp_group.is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                enable_tp=not is_dp_attention_enabled(),
            )
        else:
            self.embed_tokens = PPMissingLayer()

        # One side stream shared by every layer (CUDA only).
        self.alt_stream = torch.cuda.Stream() if _is_cuda else None
        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: DeepseekV2DecoderLayer(
                config=config,
                layer_id=idx,
                quant_config=quant_config,
                prefix=prefix,
                alt_stream=self.alt_stream,
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix=add_prefix("layers", prefix),
        )
        if self.pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer(return_tuple=True)

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding layer (a placeholder on non-first PP ranks)."""
        return self.embed_tokens

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        """Run this pipeline rank's layers.

        Args:
            input_ids: Token ids. They are embedded only on the first pipeline
                rank and only when ``input_embeds`` is not given.
            positions: Token positions, passed to every layer.
            forward_batch: Batch metadata.
            input_embeds: Optional precomputed embeddings, used on the first
                rank instead of ``input_ids``.
            pp_proxy_tensors: ``hidden_states`` and ``residual`` from the
                previous pipeline rank. Required on every rank except the first.

        Returns:
            On ranks other than the last, a ``PPProxyTensors`` holding
            ``hidden_states`` and ``residual``. On the last rank, the
            final-normed hidden states; the norm is skipped for idle batches.
        """
        total_num_layers = self.end_layer - self.start_layer
        device = input_embeds.device if input_embeds is not None else input_ids.device
        # Scratch buffer for BumpAllocator.allocate(1) calls made by the layers:
        # two slots per layer, doubled when two-batch overlap (TBO) may run.
        zero_allocator = BumpAllocator(
            buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
            dtype=torch.float32,
            device=device,
        )

        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
            else:
                hidden_states = input_embeds
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        # With TBO, layers from first_k_dense_replace onward are handed to
        # model_forward_maybe_tbo; the layers before it run one by one below.
        normal_start_layer = self.start_layer
        normal_end_layer = self.end_layer
        if forward_batch.can_run_tbo:
            if (
                self.first_k_dense_replace > normal_start_layer
                and self.first_k_dense_replace < normal_end_layer
            ):
                normal_end_layer = self.first_k_dense_replace
            elif self.first_k_dense_replace < normal_start_layer:
                # NOTE(review): this branch has two possible problems; verify
                # whether each is intended.
                # - It passes layers[0:end_layer] to the TBO path and reads the
                #   input scatter mode from layers[-1]. When start_layer > 0,
                #   that slice also contains indices below start_layer.
                # - first_k_dense_replace == start_layer matches neither branch,
                #   so TBO is skipped.
                normal_end_layer = normal_start_layer = 0

        for i in range(normal_start_layer, normal_end_layer):
            with get_global_expert_distribution_recorder().with_current_layer(i):
                layer = self.layers[i]
                hidden_states, residual = layer(
                    positions, hidden_states, forward_batch, residual, zero_allocator
                )

        if normal_end_layer != self.end_layer:
            hidden_states, residual = model_forward_maybe_tbo(
                layers=self.layers[normal_end_layer : self.end_layer],
                enable_tbo=True,
                positions=positions,
                forward_batch=forward_batch,
                hidden_states=hidden_states,
                residual=residual,
                input_data_scatter_mode=self.layers[
                    normal_end_layer - 1
                ].layer_scatter_modes.layer_output_mode,
                zero_allocator=zero_allocator,
            )

        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            if not forward_batch.forward_mode.is_idle():
                if residual is None:
                    hidden_states = self.norm(hidden_states)
                else:
                    hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekV2ForCausalLM(nn.Module):
    """Causal-LM model for DeepSeek V2/V3 checkpoints.

    Wraps ``DeepseekV2Model`` with a ``ParallelLMHead`` and a ``LogitsProcessor``
    and implements checkpoint loading / post-load weight preparation.
    ``DeepseekV3ForCausalLM`` subclasses it without changes.
    """

    # for quark model load
    # NOTE(review): this dict lives on the class and ``__init__`` adds the
    # ``fused_qkv_a_proj_with_mqa`` entry to it, so the entry is shared by all
    # instances and subclasses; confirm loaders rely on that before changing it.
    packed_modules_mapping = {}

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the decoder stack, LM head and logits processor.

        Args:
            config: HF config of the DeepSeek checkpoint.
            quant_config: Optional quantization config passed to submodules.
            prefix: Parameter-name prefix for this module.
        """
        super().__init__()

        # for quark model load
        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
        self.fuse_qkv_a_proj = (
            hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
        )
        if self.fuse_qkv_a_proj:
            # NOTE(review): `self.packed_modules_mapping` resolves to the
            # class-level dict, so this entry is written onto the class itself
            # (shared with other instances/subclasses) — confirm intended.
            self.packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
                "q_a_proj",
                "kv_a_proj_with_mqa",
            ]

        self.pp_group = get_pp_group()
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        # Reads self.config, so it must run after the assignment above.
        self.determine_num_fused_shared_experts()
        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)

        # Map layer index -> MoE weights for every DeepseekV2MoE layer, wrapped
        # in LazyValue (presumably computed on first `.value` access).
        self._routed_experts_weights_of_layer = LazyValue(
            lambda: {
                layer_id: layer.mlp.get_moe_weights()
                for layer_id, layer in enumerate(self.model.layers)
                if isinstance(layer.mlp, DeepseekV2MoE)
            }
        )

    @property
    def routed_experts_weights_of_layer(self):
        return self._routed_experts_weights_of_layer.value

2161
    def determine_num_fused_shared_experts(
        self, architecture: str = "DeepseekV3ForCausalLM"
    ):
        """Decide whether the shared expert can be fused into the routed MoE.

        Sets ``self.num_fused_shared_experts`` to ``config.n_shared_experts`` when
        fusion is allowed, otherwise to 0. When fusion is rejected here, the
        global ``disable_shared_experts_fusion`` server flag is set to True and
        the reason is logged on rank 0.

        Args:
            architecture: Only configs whose first listed architecture equals
                this name are eligible.
        """
        self.num_fused_shared_experts = 0
        if global_server_args_dict["disable_shared_experts_fusion"]:
            return

        # `architectures` may be None/empty on some configs; treat that as
        # "not the target architecture" instead of raising on `[0]`.
        architectures = getattr(self.config, "architectures", None) or []

        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
        disable_reason = None
        if (
            not _is_cuda
            or torch.cuda.get_device_capability("cuda") < (8, 0)
            or not architectures
            or architectures[0] != architecture
            or self.config.n_routed_experts != 256
            or self.config.n_shared_experts != 1
        ):
            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
        elif get_moe_expert_parallel_world_size() > 1:
            disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."

        if disable_reason is not None:
            # Record the decision globally; num_fused_shared_experts is already 0.
            global_server_args_dict["disable_shared_experts_fusion"] = True
            log_info_on_rank0(
                logger,
                f"{disable_reason} Shared experts fusion optimization is disabled.",
            )
            return

        self.num_fused_shared_experts = self.config.n_shared_experts

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens

2195
    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        """Run the decoder stack; on the last pipeline rank, also compute logits.

        Args:
            input_ids: Token ids of the batch.
            positions: Position ids matching ``input_ids``.
            forward_batch: Batch metadata passed through to the model.
            input_embeds: Optional precomputed input embeddings.
            pp_proxy_tensors: Tensors received from the previous pipeline stage.

        Returns:
            The logits-processor output on the last pipeline rank. On any other
            rank, the output of ``self.model`` unchanged, so the next stage can
            consume it.
        """
        hidden_states = self.model(
            input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
        )

        if self.pp_group.is_last_rank:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )
        else:
            return hidden_states

    @property
    def start_layer(self):
        return self.model.start_layer

    @property
    def end_layer(self):
        return self.model.end_layer

2223
    def post_load_weights(self, is_nextn=False, weight_names=None):
        """Prepare per-layer attention weights derived from ``kv_b_proj``.

        For each selected layer, ``kv_b_proj`` is brought to a usable dtype
        depending on its format: AWQ dequant, block-/channel-FP8 requant or
        dequant, or block-/channel-INT8 dequant. It is then split into the
        ``qk_nope_head_dim`` part (``w_kc``) and the ``v_head_dim`` part
        (``w_vc``), and the parts are bound on the attention module together
        with any scales. Finally, weights are requantized to UE8M0 scales when
        DeepGEMM is configured for that.

        Args:
            is_nextn: Process only the nextn decoder (``self.model.decoder``).
            weight_names: Names of the loaded weights. When given, only layers
                whose ``kv_b_proj`` appears in it (and whose id is below
                ``num_hidden_layers``) are processed. When None, all layers in
                ``[start_layer, end_layer)`` are processed.
        """
        # Perform post-processing after loading weights
        if is_nextn:
            layer_ids = [self.config.num_hidden_layers]
        else:
            if weight_names is None:
                layer_ids = range(self.model.start_layer, self.model.end_layer)
            else:
                layer_ids = set()
                for name in weight_names:
                    if "kv_b_proj" in name:
                        # Names look like "model.layers.<id>.self_attn.kv_b_proj...".
                        layer_id = int(name.split(".")[2])
                        if layer_id < self.config.num_hidden_layers:
                            layer_ids.add(layer_id)

        for layer_id in layer_ids:
            self_attn = (
                self.model.layers[layer_id].self_attn
                if not is_nextn
                else self.model.decoder.self_attn
            )
            if hasattr(self_attn.kv_b_proj, "qweight"):
                # AWQ compatible
                if _is_cuda or _is_hip:
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                    ).T
                else:
                    # NOTE(review): the non-CUDA/HIP awq_dequantize presumably
                    # takes three extra positional args — confirm their meaning.
                    w = awq_dequantize(
                        self_attn.kv_b_proj.qweight,
                        self_attn.kv_b_proj.scales,
                        self_attn.kv_b_proj.qzeros,
                        0,
                        0,
                        0,
                    ).T
            else:
                w = self_attn.kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
            # This may affect the accuracy of fp8 model.
            # Fix deepseek v3 blockwise bmm by using deep_gemm
            use_deep_gemm_bmm = False

            if w.dtype in (
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
            ):
                if (
                    hasattr(self.quant_config, "weight_block_size")
                    and self.quant_config.weight_block_size is not None
                ):
                    # Block-wise FP8: scales live in `weight_scale_inv`.
                    weight_block_size = self.quant_config.weight_block_size
                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                    if _is_fp8_fnuz:
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                            weight=w,
                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
                            input_scale=None,
                        )
                    else:
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv

                    if (
                        _is_cuda
                        and weight_block_size[0] == 128
                        and weight_block_size[1] == 128
                    ):
                        if (
                            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
                            and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL
                            and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false")
                        ):
                            # Keep block scales; they are split per part below.
                            block_scale = weight_scale
                            use_deep_gemm_bmm = True
                        else:
                            w = block_quant_dequant(
                                weight,
                                weight_scale,
                                weight_block_size,
                                torch.bfloat16,
                            )
                    else:
                        w, scale = block_quant_to_tensor_quant(
                            weight, weight_scale, weight_block_size
                        )
                        self_attn.w_scale = scale
                else:
                    # Channel-wise FP8: scales live in `weight_scale`.
                    if _is_fp8_fnuz:
                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                            weight=w,
                            weight_scale=self_attn.kv_b_proj.weight_scale,
                            input_scale=None,
                        )
                    else:
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale

                    w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                    self_attn.w_scale = scale

            if w.dtype == torch.int8:
                # NOTE(review): if quant_config has `weight_block_size` set to
                # None, neither branch dequantizes and `w` stays int8 — confirm
                # that combination cannot occur.
                if hasattr(self.quant_config, "weight_block_size"):
                    # block-wise int8 need it
                    weight_block_size = self.quant_config.weight_block_size
                    if weight_block_size is not None:
                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
                        weight = w
                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
                        w = int8_block_dequant(
                            weight, weight_scale, weight_block_size
                        ).to(torch.bfloat16)
                else:
                    # channel-wise int8 need it
                    w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
                        torch.bfloat16
                    )

            # Split dim 0 into groups of (qk_nope_head_dim + v_head_dim) rows
            # (the -1 axis is presumably heads), then into the two parts.
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            if not use_deep_gemm_bmm:
                self_attn.w_kc = bind_or_assign(
                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                )
                self_attn.w_vc = bind_or_assign(
                    self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
                )
                if (
                    hasattr(self_attn.kv_b_proj, "weight_scale")
                    and self_attn.w_scale is None
                ):
                    self_attn.w_scale = bind_or_assign(
                        self_attn.w_scale, self_attn.kv_b_proj.weight_scale
                    )
                    if _is_hip:
                        # NOTE(review): presumably compensates for the fnuz FP8
                        # encoding on ROCm — confirm.
                        self_attn.w_scale *= 2.0
                # TODO: remove this after adding FP8 support in bmm cpu kernel
                if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn:
                    self_attn.w_kc = (
                        self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale
                    )
                    self_attn.w_vc = (
                        self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale
                    )
            else:
                # DeepGEMM bmm path: `block_scale` and `weight_block_size` were
                # set in the block-FP8 branch above. Split the block scales the
                # same way as the weight.
                num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1]
                num_tiles_n = self_attn.v_head_dim // weight_block_size[0]
                ws_kc, ws_vc = block_scale.unflatten(
                    0, (-1, (num_tiles_k + num_tiles_n))
                ).split([num_tiles_k, num_tiles_n], dim=1)
                self_attn.w_scale_k = bind_or_assign(
                    self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()
                )
                self_attn.w_scale_v = bind_or_assign(
                    self_attn.w_scale_v, ws_vc.contiguous()
                )
                self_attn.w_kc = bind_or_assign(
                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous()
                )
                self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
                self_attn.use_deep_gemm_bmm = True

        if (
            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
            and hasattr(self.quant_config, "weight_block_size")
            and self.quant_config.weight_block_size is not None
        ):
            self._weight_requant_ue8m0(is_nextn)

    def _weight_requant_ue8m0(self, is_nextn=False):
        """Requantize block-FP8 weights in place with UE8M0 scales.

        For every processed layer, the attention projections are requantized.
        MoE layers (and the nextn decoder) then have their shared-expert and
        DeepEP expert weights requantized; dense layers have their MLP weights
        requantized.

        Args:
            is_nextn: Process the single nextn decoder (``self.model.decoder``)
                instead of ``self.model.layers``.
        """
        weight_block_size = self.quant_config.weight_block_size

        # Indices of MoE layers. A `range` answers `in` exactly like the list it
        # used to be expanded into, without materializing it.
        moe_layers = range(
            self.config.first_k_dense_replace,
            self.config.num_hidden_layers,
            self.config.moe_layer_freq,
        )

        if is_nextn:
            layer_ids = range(1)
        else:
            # load_weights skips layers outside [start_layer, end_layer), so only
            # those are materialized with real weights on this pipeline stage.
            # Iterating all num_hidden_layers would touch the others.
            # NOTE(review): out-of-range entries are presumably placeholder
            # modules without `self_attn` — confirm.
            layer_ids = range(self.model.start_layer, self.model.end_layer)

        for layer_id in layer_ids:
            layer = self.model.decoder if is_nextn else self.model.layers[layer_id]

            # NOTE(review): `fused_qkv_a_proj_with_mqa` / `q_b_proj` presumably
            # exist only when config.q_lora_rank is set — confirm for V2-Lite.
            for module in [
                layer.self_attn.fused_qkv_a_proj_with_mqa,
                layer.self_attn.q_b_proj,
                layer.self_attn.kv_b_proj,
                layer.self_attn.o_proj,
            ]:
                requant_weight_ue8m0_inplace(
                    module.weight, module.weight_scale_inv, weight_block_size
                )

            if layer_id in moe_layers or is_nextn:
                # `shared_experts` may be missing or None; skip it then.
                shared_experts = getattr(layer.mlp, "shared_experts", None)
                if shared_experts is not None:
                    for module in [
                        shared_experts.gate_up_proj,
                        shared_experts.down_proj,
                    ]:
                        requant_weight_ue8m0_inplace(
                            module.weight, module.weight_scale_inv, weight_block_size
                        )

                experts = layer.mlp.experts
                if isinstance(experts, DeepEPMoE):
                    # Each entry is indexed as (weight, scale).
                    for w in [
                        experts.w13_weight_fp8,
                        experts.w2_weight_fp8,
                    ]:
                        requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
            else:
                mlp = layer.mlp
                assert isinstance(mlp, DeepseekV2MLP)
                for module in [
                    mlp.gate_up_proj,
                    mlp.down_proj,
                ]:
                    requant_weight_ue8m0_inplace(
                        module.weight, module.weight_scale_inv, weight_block_size
                    )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
        """Load checkpoint tensors into this model's parameters.

        Each tensor is dispatched to the matching parameter's ``weight_loader``
        on a thread pool. All futures are awaited, and their exceptions
        re-raised, before ``post_load_weights`` runs.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.
            is_nextn: Load only the nextn layer ``model.layers.{nextn_layer_id}``.
                Its nextn-specific tensors are renamed to ``model.*`` and the
                rest to ``model.decoder.*``.

        Raises:
            ValueError: If ``is_nextn`` and the config has no
                ``num_nextn_predict_layers``.
        """
        if is_nextn:
            if hasattr(self.config, "num_nextn_predict_layers"):
                num_nextn_layers = self.config.num_nextn_predict_layers
                assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
                # compatible with old design
                nextn_layer_id = (
                    0
                    if self.config.num_hidden_layers == 1
                    else self.config.num_hidden_layers
                )
            else:
                raise ValueError("num_nextn_predict_layers is not in the config")

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
        )
        if self.quant_config and self.quant_config.get_name() == "w4afp8":
            # NOTE(review): unlike the mapping above, this does not add
            # num_fused_shared_experts; confirm w4afp8 never runs with
            # shared-expert fusion.
            expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
                num_experts=self.config.n_routed_experts
            )

        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
        # (Recomputed from config rather than read from self.fuse_qkv_a_proj,
        # which a subclass __init__ may not set.)
        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
            self.config.q_lora_rank is not None
        )
        # Holds one half of a q_a_proj / kv_a_proj_with_mqa pair until the
        # other half arrives.
        cached_a_proj = {} if fuse_qkv_a_proj else None

        if is_nextn:
            nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
            nextn_spec_weight_names = [
                "shared_head.norm",
                "eh_proj",
                "enorm",
                "hnorm",
            ]

        if self.num_fused_shared_experts > 0:
            assert self.num_fused_shared_experts == 1
            log_info_on_rank0(logger, "Shared experts fusion optimization enabled.")

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = []
            params_dict = dict(self.named_parameters())
            weight_names = []
            for name, loaded_weight in weights:
                # Skip layers that belong to other pipeline stages.
                layer_id = get_layer_id(name)
                if (
                    layer_id is not None
                    and hasattr(self.model, "start_layer")
                    and (
                        layer_id < self.model.start_layer
                        or layer_id >= self.model.end_layer
                    )
                ):
                    continue
                # With fusion, the shared expert is loaded as an extra routed
                # expert with index n_routed_experts.
                if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
                    name = name.replace(
                        "mlp.shared_experts",
                        f"mlp.experts.{self.config.n_routed_experts}",
                    )

                # Recorded before the nextn filtering/renaming below;
                # post_load_weights filters these by layer id.
                weight_names.append(name)

                if not is_nextn:
                    # Drop nextn (MTP) layers stored after the main layers.
                    if hasattr(self.config, "num_nextn_predict_layers"):
                        num_nextn_layers = self.config.num_nextn_predict_layers
                        if num_nextn_layers > 0 and name.startswith("model.layers"):
                            name_list = name.split(".")
                            if (
                                len(name_list) >= 3
                                and int(name_list[2]) >= self.config.num_hidden_layers
                            ):
                                continue
                else:
                    if not name.startswith(nextn_layer_prefix):
                        continue

                    # Use shared head and embed weights from target model
                    if "shared_head.head" in name or "embed_tokens" in name:
                        continue

                    is_decoder = True
                    # For nextn specific weights
                    for weight_name in nextn_spec_weight_names:
                        if weight_name in name:
                            name = name.replace(nextn_layer_prefix, "model")
                            is_decoder = False
                            break
                    # For decoder layer weights
                    if is_decoder:
                        name = name.replace(nextn_layer_prefix, "model.decoder")

                if "rotary_emb.inv_freq" in name:
                    continue
                for param_name, weight_name, shard_id in stacked_params_mapping:
                    # Skip non-stacked layers and experts (experts handled below).
                    if weight_name not in name:
                        continue
                    # We have mlp.experts[0].gate_proj in the checkpoint.
                    # Since we handle the experts below in expert_params_mapping,
                    # we need to skip here BEFORE we update the name, otherwise
                    # name will be updated to mlp.experts[0].gate_up_proj, which
                    # will then be updated below in expert_params_mapping
                    # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                    if ("mlp.experts." in name) and name not in params_dict:
                        continue
                    name = name.replace(weight_name, param_name)
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    futures.append(
                        executor.submit(weight_loader, param, loaded_weight, shard_id)
                    )
                    break
                else:
                    for mapping in expert_params_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in name:
                            continue
                        name = name.replace(weight_name, param_name)
                        param = params_dict[name]
                        weight_loader = param.weight_loader
                        futures.append(
                            executor.submit(
                                weight_loader,
                                param,
                                loaded_weight,
                                name,
                                shard_id=shard_id,
                                expert_id=expert_id,
                            )
                        )
                        break
                    else:
                        # Skip loading extra bias for GPTQ models.
                        if name.endswith(".bias") and name not in params_dict:
                            continue
                        # Skip loading embed_tokens if not first rank in pipeline parallelism
                        if ".embed_tokens." in name and not self.pp_group.is_first_rank:
                            continue
                        # Skip loading norm if not last rank in pipeline parallelism
                        if ".norm." in name and not self.pp_group.is_last_rank:
                            continue
                        if fuse_qkv_a_proj and (
                            "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                        ):
                            cached_a_proj[name] = loaded_weight
                            q_a_proj_name = (
                                name
                                if "q_a_proj" in name
                                else name.replace("kv_a_proj_with_mqa", "q_a_proj")
                            )
                            kv_a_proj_name = (
                                name
                                if "kv_a_proj_with_mqa" in name
                                else name.replace("q_a_proj", "kv_a_proj_with_mqa")
                            )

                            # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter
                            if (
                                q_a_proj_name in cached_a_proj
                                and kv_a_proj_name in cached_a_proj
                            ):
                                q_a_proj_weight = cached_a_proj[q_a_proj_name]
                                kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
                                # These formats are concatenated along dim 1 instead of dim 0.
                                cat_dim = 0
                                if self.quant_config is not None and (
                                    self.quant_config.get_name()
                                    in ("awq", "awq_marlin", "moe_wna16")
                                ):
                                    cat_dim = 1
                                fused_weight = torch.cat(
                                    [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
                                )
                                param_name = (
                                    name.replace(
                                        "q_a_proj", "fused_qkv_a_proj_with_mqa"
                                    )
                                    if "q_a_proj" in name
                                    else name.replace(
                                        "kv_a_proj_with_mqa",
                                        "fused_qkv_a_proj_with_mqa",
                                    )
                                )
                                param = params_dict[param_name]

                                weight_loader = getattr(
                                    param, "weight_loader", default_weight_loader
                                )
                                futures.append(
                                    executor.submit(weight_loader, param, fused_weight)
                                )
                                cached_a_proj.pop(q_a_proj_name)
                                cached_a_proj.pop(kv_a_proj_name)
                        else:
                            if (
                                "k_scale" in name or "v_scale" in name
                            ) and name not in params_dict:
                                # modelopt attn kv scale is named differently
                                for scale in ["k_scale", "v_scale"]:
                                    if scale in name:
                                        name = name.replace(
                                            f"{scale[0]}_proj", "attn_mqa"
                                        )
                                        break
                            if name not in params_dict:
                                # modelopt ckpt contains not needed weights for MTP module:
                                # model.decoder.self_attn.attn_mqa.v_scale and
                                # model.decoder.self_attn.attn_mqa.k_scale
                                logger.warning("%s not found in params_dict.", name)
                                continue
                            param = params_dict[name]
                            weight_loader = getattr(
                                param, "weight_loader", default_weight_loader
                            )
                            futures.append(
                                executor.submit(weight_loader, param, loaded_weight)
                            )

            # A half-pair left in the cache means the fused parameter was never
            # written; surface it instead of dropping it silently.
            if cached_a_proj:
                logger.warning(
                    "Unpaired q_a_proj/kv_a_proj_with_mqa weights were not loaded: %s",
                    sorted(cached_a_proj),
                )

            # Wait for all tasks to complete and raise any exceptions.
            for future in concurrent.futures.as_completed(futures):
                future.result()

        self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        """Replace the embedding and LM-head weights with the given tensors.

        The existing weight attributes are deleted before re-assignment. Then
        the CUDA caching allocator is asked to release unused cached memory,
        and the device is synchronized.

        Args:
            embed: New weight for ``self.model.embed_tokens``.
            head: New weight for ``self.lm_head``.
        """
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    @classmethod
    def get_model_config_for_expert_location(cls, config):
        """Describe the expert layout (layer count, logical experts, groups) of ``config``."""
        return ModelConfigForExpertLocation(
            num_layers=config.num_hidden_layers,
            num_logical_experts=config.n_routed_experts,
            num_groups=config.n_group,
        )


class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
    """DeepSeek-V3 entry point; reuses the V2 implementation unchanged."""


# Model classes exposed by this module (presumably picked up by the model registry).
EntryClass = [
    DeepseekV2ForCausalLM,
    DeepseekV3ForCausalLM,
]