Commit f80371ff (unverified), authored Oct 23, 2025 by b8zhong; committed via GitHub on Oct 23, 2025.

Use flashinfer_trtllm moe runner backend to gain around 10% perf on b200 fp8 dpsk (#11816)
Parent: 62eff37b
Showing 3 changed files with 76 additions and 65 deletions (+76, -65).
Changed files:
- python/sglang/srt/layers/moe/fused_moe_triton/layer.py (+1, -1)
- python/sglang/srt/layers/quantization/fp8.py (+65, -59)
- python/sglang/srt/server_args.py (+10, -5)
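In effect, on sm100 (B200) the server now auto-selects the flashinfer_trtllm MoE runner backend for DeepSeek V3/R1 and, when no quantization is given, defaults it to fp8. A minimal standalone sketch of that selection logic, paraphrased from the server_args.py hunk further below (the function name and signature are illustrative, not sglang API):

# Hypothetical paraphrase of the new server_args.py behavior; not sglang code.
from typing import Optional, Tuple

def pick_moe_backend(
    model_arch: str, backend: str, quantization: Optional[str], is_sm100: bool
) -> Tuple[str, Optional[str]]:
    if model_arch == "DeepseekV3ForCausalLM" and is_sm100 and backend == "auto":
        backend = "flashinfer_trtllm"
        if quantization is None:
            # DeepSeek V3/R1 ship native FP8 weights, so fp8 is the default here.
            quantization = "fp8"
    return backend, quantization

# An otherwise unconfigured run on B200 now resolves to the faster backend.
assert pick_moe_backend("DeepseekV3ForCausalLM", "auto", None, True) == (
    "flashinfer_trtllm",
    "fp8",
)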
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -232,7 +232,7 @@ class FusedMoE(torch.nn.Module):
             self.quant_method, ModelOptNvFp4FusedMoEMethod
         ) or (
             isinstance(self.quant_method, Fp8MoEMethod)
-            and self.quant_method.use_cutlass_fused_experts_fp8
+            and self.quant_method._should_use_cutlass_fused_experts()
         )

     def _load_per_tensor_weight_scale(
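The one-line change above swaps a boolean attribute computed at construction time for a predicate method on the quant method. A standalone sketch of the pattern (illustrative names, not sglang code); the point is that the decision can now consult state resolved after __init__, such as the chosen MoE runner backend:

# Illustration of the refactor pattern only; class and attribute names are made up.
class EagerFlag:
    def __init__(self, env_force: bool):
        # Decision frozen at construction; cannot react to a backend chosen later.
        self.use_cutlass = env_force

class LazyPredicate:
    def __init__(self, env_force: bool):
        self._env_force = env_force

    def should_use_cutlass(self, backend: str) -> bool:
        # Evaluated at call time, so the resolved backend can participate.
        return self._env_force or backend == "flashinfer_cutlass"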
python/sglang/srt/layers/quantization/fp8.py
@@ -33,6 +33,7 @@ from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
 from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmMoeQuantInfo
 from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
+from sglang.srt.layers.moe.utils import get_moe_runner_backend
 from sglang.srt.layers.parameter import (
     BlockQuantScaleParameter,
     ModelWeightParameter,
@@ -525,12 +526,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
         self.cutlass_fp8_supported = cutlass_fp8_supported()
-        self.use_cutlass_fused_experts_fp8 = (
-            get_bool_env_var("SGLANG_CUTLASS_MOE")
-            and self.cutlass_fp8_supported
-            and self.block_quant
-            and (is_sm100_supported() or is_sm90_supported())
-        )

     def create_weights(
         self,
@@ -638,58 +633,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
             assert self.quant_config.activation_scheme == "dynamic"
-            if self.use_cutlass_fused_experts_fp8:
-                self.ab_strides1 = torch.full(
-                    (num_experts,),
-                    hidden_size,
-                    device=w13_weight.device,
-                    dtype=torch.int64,
-                )
-                self.c_strides1 = torch.full(
-                    (num_experts,),
-                    2 * intermediate_size_per_partition,
-                    device=w13_weight.device,
-                    dtype=torch.int64,
-                )
-                self.ab_strides2 = torch.full(
-                    (num_experts,),
-                    intermediate_size_per_partition,
-                    device=w2_weight.device,
-                    dtype=torch.int64,
-                )
-                self.c_strides2 = torch.full(
-                    (num_experts,),
-                    hidden_size,
-                    device=w2_weight.device,
-                    dtype=torch.int64,
-                )
-                self.workspace = torch.empty(
-                    90000, device=w13_weight.device, dtype=torch.uint8
-                )
-                self.a_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.b_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.out_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.a_scales_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.b_scales_ptr = torch.empty(
-                    num_experts, device=w13_weight.device, dtype=torch.int64
-                )
-                self.expert_offsets = torch.empty(
-                    num_experts + 1, device=w13_weight.device, dtype=torch.int32
-                )
-                self.problem_sizes1 = torch.empty(
-                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
-                )
-                self.problem_sizes2 = torch.empty(
-                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
-                )
+            if self._should_use_cutlass_fused_experts():
+                self._ensure_cutlass_buffers_initialized(layer)
             else:
                 # Allocate 2 scales for w1 and w3 respectively.
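The allocation block removed above is replaced by a single call whose definition is added in the last hunk of this file: the stride, pointer, and workspace tensors are now created only on first use. A condensed sketch of that guard-flag idiom (standalone illustration, not the sglang implementation):

import torch

class LazyBuffers:
    def ensure_buffers(self, num_experts: int, device: torch.device) -> None:
        # Skip if an earlier call already materialized the buffers.
        if getattr(self, "_ready", False):
            return
        self.expert_offsets = torch.empty(
            num_experts + 1, device=device, dtype=torch.int32
        )
        self._ready = True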
@@ -1079,7 +1024,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             if ret is not None:
                 return StandardCombineInput(hidden_states=ret)

-        if self.use_cutlass_fused_experts_fp8:
+        if self._should_use_cutlass_fused_experts():
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8

             topk_weights, topk_ids, _ = topk_output
@@ -1171,6 +1116,67 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         return self.runner.run(dispatch_output, quant_info)

+    def _should_use_cutlass_fused_experts(self) -> bool:
+        """Decide whether to use Cutlass FP8 fused-experts path based on moe runner backend,
+        with env var override via `SGLANG_CUTLASS_MOE`.
+        """
+        backend = get_moe_runner_backend()
+        env_force = get_bool_env_var("SGLANG_CUTLASS_MOE")
+        # TODO: remove env var in the future, it should be handled by moe runner backend
+        if env_force:
+            return True
+        return (
+            backend.is_flashinfer_cutlass()
+            and self.cutlass_fp8_supported
+            and self.block_quant
+            and (is_sm100_supported() or is_sm90_supported())
+        )
+
+    def _ensure_cutlass_buffers_initialized(self, layer: Module) -> None:
+        if getattr(self, "_cutlass_buffers_ready", False):
+            return
+
+        device = layer.w13_weight.device
+        num_experts = layer.w13_weight.shape[0]
+        hidden_size = layer.w2_weight.shape[1]
+        intermediate_size_per_partition = layer.intermediate_size_per_partition
+
+        self.ab_strides1 = torch.full(
+            (num_experts,), hidden_size, device=device, dtype=torch.int64
+        )
+        self.c_strides1 = torch.full(
+            (num_experts,),
+            2 * intermediate_size_per_partition,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.ab_strides2 = torch.full(
+            (num_experts,),
+            intermediate_size_per_partition,
+            device=device,
+            dtype=torch.int64,
+        )
+        self.c_strides2 = torch.full(
+            (num_experts,), hidden_size, device=device, dtype=torch.int64
+        )
+        self.workspace = torch.empty(90000, device=device, dtype=torch.uint8)
+        self.a_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.b_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.out_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.a_scales_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.b_scales_ptr = torch.empty(num_experts, device=device, dtype=torch.int64)
+        self.expert_offsets = torch.empty(
+            num_experts + 1, device=device, dtype=torch.int32
+        )
+        self.problem_sizes1 = torch.empty(
+            num_experts, 3, device=device, dtype=torch.int32
+        )
+        self.problem_sizes2 = torch.empty(
+            num_experts, 3, device=device, dtype=torch.int32
+        )
+        self._cutlass_buffers_ready = True
+
     def apply_with_router_logits(
         self,
         layer: torch.nn.Module,
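As the docstring and TODO above note, SGLANG_CUTLASS_MOE remains a temporary override that forces the Cutlass FP8 fused-experts path regardless of the runner backend. A hedged usage sketch; the variable must be set in the environment of the serving process before these checks run:

import os

# Force the Cutlass FP8 fused-experts path: per the hunk above,
# _should_use_cutlass_fused_experts() returns True whenever this is truthy.
os.environ["SGLANG_CUTLASS_MOE"] = "1"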
python/sglang/srt/server_args.py
@@ -892,14 +892,19 @@ class ServerArgs:
                 logger.info(
                     "Enable FlashInfer AllReduce Fusion on sm100 for DeepseekV3ForCausalLM"
                 )
-            if (
-                self.quantization == "modelopt_fp4"
-                and self.moe_runner_backend == "auto"
-            ):
+            if self.moe_runner_backend == "auto":
                 self.moe_runner_backend = "flashinfer_trtllm"
                 logger.info(
-                    "Use flashinfer_trtllm as moe runner backend on sm100 for DeepseekV3ForCausalLM"
+                    "Use flashinfer_trtllm as MoE runner backend on sm100 for DeepseekV3ForCausalLM"
                 )
+                if self.quantization is None:
+                    # Default DeepSeek V3/R1 native FP8 when not explicitly set,
+                    # Because we need this condition for an assertion in
+                    # flashinfer_trtllm MoE runner backend.
+                    self.quantization = "fp8"
+                    logger.info(
+                        "Quantization not specified, default to fp8 for DeepSeek on sm100"
+                    )
         elif model_arch in ["GptOssForCausalLM"]:
             if (