Commit 0374304a (unverified), authored Aug 23, 2025 by fzyzcjy, committed by GitHub on Aug 23, 2025

Add enable_flashinfer_mxfp4_bf16_moe for higher precision and slower moe backend (#9004)
Parent: 127d4b0d

Showing 3 changed files with 37 additions and 5 deletions (+37, -5)
python/sglang/srt/layers/quantization/mxfp4.py (+27, -5)
python/sglang/srt/managers/schedule_batch.py (+1, -0)
python/sglang/srt/server_args.py (+9, -0)
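For context, the new precision mode is selected at server launch. A hypothetical invocation (the model path is a placeholder; the backend and precision flags are the ones touched by this diff):

python -m sglang.launch_server \
    --model-path <model-path> \
    --moe-runner-backend flashinfer_mxfp4 \
    --flashinfer-mxfp4-moe-precision bf16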
python/sglang/srt/layers/quantization/mxfp4.py
@@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.base_config import (
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -262,6 +263,9 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.use_triton_kernels = get_moe_runner_backend().is_triton_kernel()
         self.with_bias = False
         self.use_flashinfer = get_moe_runner_backend().is_flashinfer_mxfp4()
+        self.flashinfer_mxfp4_moe_precision = global_server_args_dict[
+            "flashinfer_mxfp4_moe_precision"
+        ]
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
@@ -615,11 +619,29 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         from sglang.srt.layers.moe.topk import TopKOutputChecker

         if self.use_flashinfer:
-            # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
-            x_quant, x_scale = mxfp8_quantize(
-                x, False, alignment=self.hidden_size
-            )  # to mxfp8
-            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            # When bf16 mode is enabled, we don't need to quantize the input:
+            # TRT-LLM handles quantization inside the kernel and pipelines it
+            # with the GEMM operations, which can theoretically improve performance.
+            if self.flashinfer_mxfp4_moe_precision == "bf16":
+                assert x.dtype == torch.bfloat16
+                x_quant = x
+                x_scale = None
+
+                # May be fused later if this code branch is frequently needed.
+                origin_hidden_states_dim = x_quant.shape[-1]
+                if self.hidden_size != origin_hidden_states_dim:
+                    x_quant = torch.nn.functional.pad(
+                        x_quant,
+                        (0, self.hidden_size - origin_hidden_states_dim),
+                        mode="constant",
+                        value=0.0,
+                    )
+            elif self.flashinfer_mxfp4_moe_precision == "default":
+                x_quant, x_scale = mxfp8_quantize(
+                    x, False, alignment=self.hidden_size
+                )
+                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            else:
+                raise NotImplementedError

             assert x_quant.shape[-1] == self.hidden_size
             assert TopKOutputChecker.format_is_bypassed(topk_output)
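As a standalone illustration of the bf16 branch above, the sketch below (sizes are made up and this is not sglang code) shows how an activation whose last dimension is narrower than the kernel's expected hidden size is kept in bf16 and right-padded with zeros instead of being quantized:

import torch
import torch.nn.functional as F

hidden_size = 3072  # hypothetical padded hidden size the MoE kernel expects
x = torch.randn(4, 2880, dtype=torch.bfloat16)  # hypothetical unpadded activations

# Mirrors the bf16 branch: skip quantization, keep x in bf16, and zero-pad
# the last dimension so its width matches the kernel's hidden size.
pad = hidden_size - x.shape[-1]
if pad > 0:
    x = F.pad(x, (0, pad), mode="constant", value=0.0)
assert x.shape[-1] == hidden_size and x.dtype == torch.bfloat16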
python/sglang/srt/managers/schedule_batch.py
@@ -87,6 +87,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "disable_flashinfer_cutlass_moe_fp4_allgather",
     "disable_radix_cache",
     "enable_dp_lm_head",
+    "flashinfer_mxfp4_moe_precision",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
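Adding the key here is what makes global_server_args_dict["flashinfer_mxfp4_moe_precision"] resolvable in mxfp4.py. A minimal sketch of the underlying pattern (simplified, not the actual sglang implementation):

# Keys whitelisted here are copied from the parsed ServerArgs into a
# module-level dict that layer code (e.g. Mxfp4MoEMethod) reads directly,
# avoiding the need to thread the full config through every constructor.
GLOBAL_SERVER_ARGS_KEYS = ["flashinfer_mxfp4_moe_precision"]
global_server_args_dict = {}

def publish_server_args(server_args) -> None:
    # Hypothetical helper: mirror whitelisted attributes into the global dict.
    for key in GLOBAL_SERVER_ARGS_KEYS:
        global_server_args_dict[key] = getattr(server_args, key)

class FakeArgs:  # stand-in for a parsed ServerArgs instance
    flashinfer_mxfp4_moe_precision = "bf16"

publish_server_args(FakeArgs())
assert global_server_args_dict["flashinfer_mxfp4_moe_precision"] == "bf16"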
python/sglang/srt/server_args.py
@@ -190,6 +190,7 @@ class ServerArgs:
         "flashinfer_cutlass",
         "flashinfer_mxfp4",
     ] = "auto"
+    flashinfer_mxfp4_moe_precision: Literal["default", "bf16"] = "default"
     enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Literal["auto", "normal", "low_latency"] = "auto"
     ep_num_redundant_experts: int = 0
@@ -1496,10 +1497,18 @@
             "triton_kernel",
             "flashinfer_trtllm",
             "flashinfer_cutlass",
             "flashinfer_mxfp4",
         ],
         default=ServerArgs.moe_runner_backend,
         help="Choose the runner backend for MoE.",
     )
+    parser.add_argument(
+        "--flashinfer-mxfp4-moe-precision",
+        type=str,
+        choices=["default", "bf16"],
+        default=ServerArgs.flashinfer_mxfp4_moe_precision,
+        help="Choose the computation precision of the flashinfer mxfp4 MoE backend.",
+    )
     parser.add_argument(
         "--enable-flashinfer-allreduce-fusion",
         action="store_true",
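To see the new flag's parsing behavior in isolation, here is a minimal standalone reproduction (plain argparse, not sglang's actual parser):

import argparse

parser = argparse.ArgumentParser()
# Same shape as the new flag: a string restricted to the two supported modes,
# matching the Literal["default", "bf16"] annotation on ServerArgs.
parser.add_argument(
    "--flashinfer-mxfp4-moe-precision",
    type=str,
    choices=["default", "bf16"],
    default="default",
)

args = parser.parse_args(["--flashinfer-mxfp4-moe-precision", "bf16"])
assert args.flashinfer_mxfp4_moe_precision == "bf16"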