Commit b65db028 (unverified)
Authored Oct 02, 2025 by fzyzcjy; committed by GitHub on Oct 02, 2025
Parent: 948278f1

Tiny cleanup deepseek_v2.py (#11163)

Showing 2 changed files with 38 additions and 38 deletions:
    python/sglang/srt/layers/moe/fused_moe_triton/layer.py   +7  -6
    python/sglang/srt/models/deepseek_v2.py                  +31 -32
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -234,6 +234,13 @@ class FusedMoE(torch.nn.Module):
         self.quant_method.create_moe_runner(self, self.moe_runner_config)
         self.dispatcher = StandardDispatcher()

+        self.should_fuse_routed_scaling_factor_in_topk = isinstance(
+            self.quant_method, ModelOptNvFp4FusedMoEMethod
+        ) or (
+            isinstance(self.quant_method, Fp8MoEMethod)
+            and self.quant_method.use_cutlass_fused_experts_fp8
+        )
+
     def _load_per_tensor_weight_scale(
         self,
         shard_id: str,
@@ -936,12 +943,6 @@ class FusedMoE(torch.nn.Module):
             for shard_id in ["w1", "w2", "w3"]
         ]

-    def should_fuse_routed_scaling_factor_in_topk(self):
-        return isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) or (
-            isinstance(self.quant_method, Fp8MoEMethod)
-            and self.quant_method.use_cutlass_fused_experts_fp8
-        )
-

 class FlashInferFusedMoE(FusedMoE):
     def __init__(self, *args, **kwargs):
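The layer.py change replaces the `should_fuse_routed_scaling_factor_in_topk()` method with an attribute of the same name computed once in `__init__`, so callers read a plain boolean instead of re-running the isinstance checks on every call. A minimal sketch of the pattern, with stand-in classes rather than the real sglang quant-method types:

```python
# Minimal sketch of the method-to-attribute cleanup; Fp8MoEMethod and
# ModelOptNvFp4FusedMoEMethod are stand-ins for the real sglang classes.
class Fp8MoEMethod:
    def __init__(self, use_cutlass_fused_experts_fp8: bool = False):
        self.use_cutlass_fused_experts_fp8 = use_cutlass_fused_experts_fp8


class ModelOptNvFp4FusedMoEMethod:
    pass


class FusedMoESketch:
    def __init__(self, quant_method):
        self.quant_method = quant_method
        # Before: a method re-evaluated these isinstance checks on every call.
        # After: the flag is derived once at construction time.
        self.should_fuse_routed_scaling_factor_in_topk = isinstance(
            quant_method, ModelOptNvFp4FusedMoEMethod
        ) or (
            isinstance(quant_method, Fp8MoEMethod)
            and quant_method.use_cutlass_fused_experts_fp8
        )


moe = FusedMoESketch(Fp8MoEMethod(use_cutlass_fused_experts_fp8=True))
assert moe.should_fuse_routed_scaling_factor_in_topk  # note: no call parentheses
```

This is why the deepseek_v2.py hunks below drop the trailing `()` at every use site.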
python/sglang/srt/models/deepseek_v2.py

@@ -166,16 +166,15 @@ if _is_cuda:
 elif _is_cpu and _is_cpu_amx_available:
     pass
 elif _is_hip:
+    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+        decode_attention_fwd_grouped_rope,
+    )
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
     )
 else:
     pass

-if _is_hip:
-    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
-        decode_attention_fwd_grouped_rope,
-    )

 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
@@ -229,7 +228,7 @@ def _dispatch_mla_subtype(attn, forward_batch):
     return AttnForwardMethod.MLA


-class BackendRegistry:
+class AttentionBackendRegistry:
     _handlers = {}

     @classmethod
@@ -241,7 +240,7 @@ class BackendRegistry:
         return cls._handlers.get(backend_name, cls._handlers.get("triton"))


-def handle_ascend(attn, forward_batch):
+def handle_attention_ascend(attn, forward_batch):
     if (
         forward_batch.forward_mode.is_extend()
         and not forward_batch.forward_mode.is_target_verify()
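The renamed `AttentionBackendRegistry` keeps the same shape as before: a class-level dict mapping backend name to handler, a classmethod to register entries, and a lookup that falls back to the "triton" handler when the requested backend has no entry. A minimal sketch of that registry pattern (handler bodies simplified; the real handlers take `(attn, forward_batch)` and return an `AttnForwardMethod` value):

```python
# Minimal sketch of the registry pattern used in deepseek_v2.py.
class AttentionBackendRegistry:
    _handlers = {}

    @classmethod
    def register(cls, backend_name, handler):
        cls._handlers[backend_name] = handler

    @classmethod
    def get_handler(cls, backend_name):
        # Unknown backend names fall back to the "triton" handler.
        return cls._handlers.get(backend_name, cls._handlers.get("triton"))


def handle_attention_triton(attn, forward_batch):
    return "MLA"  # placeholder return value for the sketch


AttentionBackendRegistry.register("triton", handle_attention_triton)

handler = AttentionBackendRegistry.get_handler("some_unregistered_backend")
assert handler is handle_attention_triton  # fallback path
```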
@@ -268,7 +267,7 @@ def _is_extend_without_speculative(forward_batch):
     )


-def _handle_backend(attn, forward_batch, backend_name):
+def _handle_attention_backend(attn, forward_batch, backend_name):
     sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
     disable_ragged = (
         backend_name in ["flashinfer", "flashmla"]
@@ -290,28 +289,28 @@ def _handle_backend(attn, forward_batch, backend_name):
     return _dispatch_mla_subtype(attn, forward_batch)


-def handle_flashinfer(attn, forward_batch):
-    return _handle_backend(attn, forward_batch, "flashinfer")
+def handle_attention_flashinfer(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashinfer")


-def handle_fa3(attn, forward_batch):
-    return _handle_backend(attn, forward_batch, "fa3")
+def handle_attention_fa3(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "fa3")


-def handle_flashmla(attn, forward_batch):
-    return _handle_backend(attn, forward_batch, "flashmla")
+def handle_attention_flashmla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashmla")


-def handle_cutlass_mla(attn, forward_batch):
-    return _handle_backend(attn, forward_batch, "cutlass_mla")
+def handle_attention_cutlass_mla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "cutlass_mla")


-def handle_fa4(attn, forward_batch):
+def handle_attention_fa4(attn, forward_batch):
     # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
     return AttnForwardMethod.MHA_CHUNKED_KV


-def handle_trtllm_mla(attn, forward_batch):
+def handle_attention_trtllm_mla(attn, forward_batch):
     sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
     if _is_extend_without_speculative(forward_batch) and (
         not attn.disable_chunked_prefix_cache
         or sum_extend_prefix_lens == 0
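Apart from the rename, these handlers stay thin wrappers that forward to the shared `_handle_attention_backend(attn, forward_batch, backend_name)` helper with the backend name baked in. The same wiring could be expressed with `functools.partial`, as in the sketch below; it is shown only to illustrate the delegation, while the actual file keeps explicit `def` wrappers, which are easier to read in tracebacks and can diverge per backend later (as `handle_attention_fa4` and `handle_attention_trtllm_mla` already do):

```python
from functools import partial


# Hypothetical simplification of the wrapper functions in deepseek_v2.py.
def _handle_attention_backend(attn, forward_batch, backend_name):
    # Placeholder body; the real helper inspects forward_batch and returns an
    # AttnForwardMethod value.
    return f"dispatched via {backend_name}"


handle_attention_flashinfer = partial(_handle_attention_backend, backend_name="flashinfer")
handle_attention_fa3 = partial(_handle_attention_backend, backend_name="fa3")

print(handle_attention_fa3(attn=None, forward_batch=None))  # dispatched via fa3
```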
@@ -321,7 +320,7 @@ def handle_trtllm_mla(attn, forward_batch):
     return _dispatch_mla_subtype(attn, forward_batch)


-def handle_aiter(attn, forward_batch):
+def handle_attention_aiter(attn, forward_batch):
     if _is_extend_without_speculative(forward_batch):
         if is_dp_attention_enabled():
             if sum(forward_batch.extend_prefix_lens_cpu) == 0:
@@ -334,7 +333,7 @@ def handle_aiter(attn, forward_batch):
         return AttnForwardMethod.MLA


-def handle_triton(attn, forward_batch):
+def handle_attention_triton(attn, forward_batch):
     if (
         _is_extend_without_speculative(forward_batch)
         and sum(forward_batch.extend_prefix_lens_cpu) == 0
@@ -541,7 +540,7 @@ class DeepseekV2MoE(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
-            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
+            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
             # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
             # and requires the output format to be standard. We use quant_config to determine the output format.
             output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
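Two details in this hunk: `apply_routed_scaling_factor_on_output` now passes the precomputed boolean attribute (no call), and `output_format` is chosen from `quant_config` because, per the in-code comment, the unquantized MTP layers need the standard top-k output format. A tiny sketch of that conditional-argument idea, with a hypothetical enum standing in for sglang's `TopKOutputFormat`:

```python
from enum import Enum, auto


# Hypothetical stand-in for sglang's TopKOutputFormat; only STANDARD is needed
# for the sketch. Returning None means "let the MoE backend pick its format".
class TopKOutputFormat(Enum):
    STANDARD = auto()


def choose_output_format(quant_config):
    # Unquantized layers (e.g. the MTP layers mentioned in the code comment)
    # must use the standard format; otherwise defer to the backend.
    return TopKOutputFormat.STANDARD if quant_config is None else None


assert choose_output_format(None) is TopKOutputFormat.STANDARD
assert choose_output_format(object()) is None
```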
@@ -838,13 +837,13 @@ class DeepseekV2MoE(nn.Module):
         if shared_output is not None:
             x = shared_output
-            if self.experts.should_fuse_routed_scaling_factor_in_topk():
+            if self.experts.should_fuse_routed_scaling_factor_in_topk:
                 x.add_(final_hidden_states)
             else:
                 x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
             final_hidden_states = x
         else:
-            if not self.experts.should_fuse_routed_scaling_factor_in_topk():
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk:
                 final_hidden_states *= self.routed_scaling_factor
         return final_hidden_states
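When the scaling is not fused into top-k, the routed scaling factor is applied while combining the shared-expert output with the routed output: `x.add_(final_hidden_states, alpha=s)` computes `x + s * final_hidden_states` in place, which matches the `final_hidden_states *= s` path taken when there is no shared output. A small self-contained check of that equivalence (shapes and the scaling value are made up for the example):

```python
import torch

# Made-up tensors and scaling factor, just to check the combine arithmetic.
routed_scaling_factor = 2.5
shared_output = torch.randn(4, 8)
final_hidden_states = torch.randn(4, 8)

# Path where scaling is NOT fused into top-k: scale while accumulating.
x = shared_output.clone()
x.add_(final_hidden_states, alpha=routed_scaling_factor)

# Reference: explicit scale, then add.
expected = shared_output + routed_scaling_factor * final_hidden_states
assert torch.allclose(x, expected)
```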
@@ -1217,7 +1216,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             attention_backend = global_server_args_dict["prefill_attention_backend"]
         self.current_attention_backend = attention_backend

-        handler = BackendRegistry.get_handler(attention_backend)
+        handler = AttentionBackendRegistry.get_handler(attention_backend)
         return handler(self, forward_batch)

     def op_prepare(self, state):
@@ -3092,15 +3091,15 @@ class DeepseekV2ForCausalLM(nn.Module):
     )


-BackendRegistry.register("ascend", handle_ascend)
-BackendRegistry.register("flashinfer", handle_flashinfer)
-BackendRegistry.register("fa3", handle_fa3)
-BackendRegistry.register("flashmla", handle_flashmla)
-BackendRegistry.register("cutlass_mla", handle_cutlass_mla)
-BackendRegistry.register("fa4", handle_fa4)
-BackendRegistry.register("trtllm_mla", handle_trtllm_mla)
-BackendRegistry.register("aiter", handle_aiter)
-BackendRegistry.register("triton", handle_triton)
+AttentionBackendRegistry.register("ascend", handle_attention_ascend)
+AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
+AttentionBackendRegistry.register("fa3", handle_attention_fa3)
+AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
+AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
+AttentionBackendRegistry.register("fa4", handle_attention_fa4)
+AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
+AttentionBackendRegistry.register("aiter", handle_attention_aiter)
+AttentionBackendRegistry.register("triton", handle_attention_triton)


 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):