Commit 47824c14 (unverified)
Authored Aug 07, 2025 by Xiaoyu Zhang; committed via GitHub on Aug 07, 2025
[Perf] Auto enable best flashinfer mxfp4 kernel in b200 (#8898)
Parent: c36a6693

5 changed files with 48 additions and 48 deletions:

  python/sglang/srt/layers/moe/fused_moe_triton/layer.py   +4  -4
  python/sglang/srt/layers/quantization/mxfp4.py            +24 -27
  python/sglang/srt/managers/schedule_batch.py              +1  -0
  python/sglang/srt/models/gpt_oss.py                       +9  -5
  python/sglang/srt/server_args.py                          +10 -12
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -206,13 +206,13 @@ class FusedMoE(torch.nn.Module):
         assert self.quant_method is not None
         self.quant_config = quant_config
+        self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
+            "enable_flashinfer_mxfp4_moe", False
+        )
         if (
             self.quant_config is not None
             and self.quant_config.get_name() == "mxfp4"
-            and (
-                get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE")
-                or get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE")
-            )
+            and self.use_enable_flashinfer_mxfp4_moe
         ):
             hidden_size = round_up(hidden_size, 256)
         self.hidden_size = hidden_size
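For readers skimming the hunk above: the hidden-size padding is now gated on the server-level `enable_flashinfer_mxfp4_moe` flag instead of the two `SGLANG_USE_FLASHINFER_MXFP4_*` environment variables. A minimal standalone sketch of that gate; the `round_up` helper and the dict-based config below are illustrative stand-ins, not sglang's actual imports:

from typing import Optional

# Illustrative sketch (not sglang's actual FusedMoE code): shows how the
# hidden-size padding is gated on the quant method and the server arg.

def round_up(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple of `multiple`.
    return ((x + multiple - 1) // multiple) * multiple

def padded_hidden_size(hidden_size: int, quant_name: Optional[str], server_args: dict) -> int:
    use_flashinfer_mxfp4 = server_args.get("enable_flashinfer_mxfp4_moe", False)
    if quant_name == "mxfp4" and use_flashinfer_mxfp4:
        # The FlashInfer MXFP4 kernels expect 256-aligned hidden sizes.
        return round_up(hidden_size, 256)
    return hidden_size

# Example: padding only happens when both conditions hold.
assert padded_hidden_size(2880, "mxfp4", {"enable_flashinfer_mxfp4_moe": True}) == 3072
assert padded_hidden_size(2880, "mxfp4", {}) == 2880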
python/sglang/srt/layers/quantization/mxfp4.py

@@ -3,22 +3,20 @@
 from __future__ import annotations
 
 import importlib
 import importlib.util
 import logging
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
 
-# from vllm.model_executor.layers.fused_moe import (
-#     FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
-#     FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,

@@ -32,11 +30,6 @@ from sglang.srt.utils import (
 
 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
 
-# Environment variables for FlashInfer MXFP4 MoE backend
-USE_FLASHINFER_MXFP4_MOE = get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE", "false")
-USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
-    "SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
-)
 
 if is_flashinfer_available():
     # from flashinfer.fused_moe import cutlass_fused_moe

@@ -193,7 +186,12 @@ class Mxfp4Config(QuantizationConfig):
             ):
                 return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
-            return Mxfp4MoEMethod(use_triton_kernels=True, with_bias=True)
+            use_flashinfer = global_server_args_dict.get(
+                "enable_flashinfer_mxfp4_moe", False
+            )
+            return Mxfp4MoEMethod(
+                use_triton_kernels=True, with_bias=True, use_flashinfer=use_flashinfer
+            )
         else:
             raise NotImplementedError("Mxfp4 attention layer is not implemented")
         return None

@@ -204,11 +202,18 @@ class Mxfp4Config(QuantizationConfig):
 class Mxfp4MoEMethod(FusedMoEMethodBase):
 
-    def __init__(self, use_triton_kernels: bool = True, with_bias: bool = True):
+    def __init__(
+        self,
+        use_triton_kernels: bool = True,
+        with_bias: bool = True,
+        use_flashinfer: bool = False,
+    ):
         super().__init__()
 
         self.topk_indices_dtype = None
         self.use_triton_kernels = use_triton_kernels
         self.with_bias = with_bias
+        self.use_flashinfer = use_flashinfer
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
         if torch.cuda.is_available() and has_triton_kernels:

@@ -239,7 +244,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
 
         # pad the intermediate size to be a multiple of 2 * mxfp4_block
         # for to hold non-uniform sharded tensor as well as swizzling
-        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
+        if self.use_flashinfer:
             intermediate_size_per_partition_after_pad = round_up(intermediate_size, 256)
             hidden_size = round_up(hidden_size, 256)
         elif is_hip():

@@ -319,7 +324,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             set_weight_attrs(w2_weight_bias, extra_weight_attrs)
 
     def process_weights_after_loading(self, layer):
-        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
+        if self.use_flashinfer:
             logger.info("Shuffling MoE weights for FlashInfer, it might take a while...")

@@ -544,20 +549,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         activation_alpha: Optional[float] = None,
         swiglu_limit: Optional[float] = None,
     ) -> torch.Tensor:
-        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
-            # When USE_FLASHINFER_MXFP4_BF16_MOE is enabled, we don't need to quantize the input,
-            # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
-            # which can theoretically improve performance
-            if USE_FLASHINFER_MXFP4_BF16_MOE:
-                assert x.dtype == torch.bfloat16
-                x_quant = x
-                x_scale = None
-            else:
-                x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
-                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+        if self.use_flashinfer:
+            # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
+            x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
+            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
 
-            topk_weights, topk_ids, router_logits = topk_output
-            top_k = topk_weights.shape[-1]
+            top_k, router_logits = topk_output
 
             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 router_logits.to(torch.bfloat16),
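Taken together, the hunks above replace the module-level `USE_FLASHINFER_MXFP4_MOE` / `USE_FLASHINFER_MXFP4_BF16_MOE` environment flags with a per-instance `use_flashinfer` constructor argument sourced from the global server args. A minimal sketch of that plumbing, using stand-in class and function names rather than sglang's real `Mxfp4MoEMethod`:

# Illustrative sketch of the flag plumbing (stand-in names, no FlashInfer calls).

class MoEMethodSketch:
    def __init__(self, use_triton_kernels: bool = True, with_bias: bool = True,
                 use_flashinfer: bool = False):
        # The backend choice is now an explicit constructor argument instead of
        # a module-level environment variable read at import time.
        self.use_triton_kernels = use_triton_kernels
        self.with_bias = with_bias
        self.use_flashinfer = use_flashinfer

    def apply(self, x):
        if self.use_flashinfer:
            # FlashInfer/TRT-LLM path: quantize activations to mxfp8, then call
            # the fused block-scale MoE kernel (elided in this sketch).
            return "flashinfer_mxfp4_path"
        # Default path: Triton grouped-GEMM MoE kernels (elided in this sketch).
        return "triton_path"


def make_moe_method(server_args: dict) -> MoEMethodSketch:
    # Mirrors the shape of Mxfp4Config.get_quant_method above: the flag comes
    # from the global server args dict, not SGLANG_USE_FLASHINFER_MXFP4_* env vars.
    use_flashinfer = server_args.get("enable_flashinfer_mxfp4_moe", False)
    return MoEMethodSketch(use_triton_kernels=True, with_bias=True,
                           use_flashinfer=use_flashinfer)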
python/sglang/srt/managers/schedule_batch.py

@@ -107,6 +107,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
     "enable_triton_kernel_moe",
+    "enable_flashinfer_mxfp4_moe",
     "enable_multimodal",
     "enable_symm_mem",
     "quantization",
python/sglang/srt/models/gpt_oss.py

@@ -102,11 +102,15 @@ class GptOssSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_local_experts}."
             )
 
-        self.topk = TopK(
-            top_k=config.num_experts_per_tok,
-            renormalize=True,
-        )
+        if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
+            self.topk = None
+        else:
+            self.topk = TopK(
+                top_k=config.num_experts_per_tok,
+                renormalize=True,
+            )
+        self.top_k = config.num_experts_per_tok
 
         experts_type = get_moe_impl_class()
         extra_kwargs = {}
         if experts_type.__name__ == "FusedMoE":

@@ -176,7 +180,7 @@ class GptOssSparseMoeBlock(nn.Module):
         if self.topk is not None:
             kwargs["topk_output"] = self.topk(hidden_states, router_logits)
         else:
-            kwargs["router_logits"] = router_logits
+            kwargs["topk_output"] = (self.top_k, router_logits)
 
         final_hidden_states = self.experts(**kwargs)
 
         if self.tp_size > 1:
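As the hunks above show, when the FlashInfer MXFP4 path is enabled the block skips the Python-side `TopK` module and instead passes `(top_k, router_logits)` so the fused kernel performs routing itself. A small torch-only sketch of the two calling conventions; the helper names and dummy tensors below are illustrative, not the model's actual code:

# Sketch of the two ways the experts receive their routing input (stand-in code).
import torch

def build_expert_kwargs(hidden_states, router_logits, topk_module, top_k: int):
    kwargs = {"hidden_states": hidden_states}
    if topk_module is not None:
        # Triton path: top-k routing happens in Python before the expert call.
        kwargs["topk_output"] = topk_module(hidden_states, router_logits)
    else:
        # FlashInfer path: hand over (top_k, router_logits) and let the fused
        # kernel select the experts internally.
        kwargs["topk_output"] = (top_k, router_logits)
    return kwargs

# Example with dummy tensors and a dummy top-k callable.
def dummy_topk(hidden_states, router_logits):
    return torch.topk(torch.softmax(router_logits, dim=-1), k=2)

h = torch.randn(2, 8)
logits = torch.randn(2, 4)
print(build_expert_kwargs(h, logits, dummy_topk, top_k=2)["topk_output"])
print(build_expert_kwargs(h, logits, None, top_k=2)["topk_output"])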
python/sglang/srt/server_args.py

@@ -248,6 +248,7 @@ class ServerArgs:
     disable_fast_image_processor: bool = False
     enable_return_hidden_states: bool = False
     enable_triton_kernel_moe: bool = False
+    enable_flashinfer_mxfp4_moe: bool = False
 
     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None

@@ -476,18 +477,10 @@ class ServerArgs:
                 or self.attention_backend == "triton"
             )
 
-            # Check if FlashInfer MXFP4 MoE is enabled
-            from sglang.srt.utils import get_bool_env_var
-
-            USE_FLASHINFER_MXFP4_MOE = get_bool_env_var(
-                "SGLANG_USE_FLASHINFER_MXFP4_MOE", "false"
-            )
-            USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
-                "SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
-            )
-
-            # Only enable Triton kernel MoE if FlashInfer is not enabled
-            if not (USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE):
-                self.enable_triton_kernel_moe = True
+            if is_sm100_supported():
+                self.enable_flashinfer_mxfp4_moe = True
+                self.enable_triton_kernel_moe = False
+            else:
+                self.enable_triton_kernel_moe = True
 
             self.disable_hybrid_swa_memory = True

@@ -1846,6 +1839,11 @@ class ServerArgs:
             action="store_true",
             help="Use triton moe grouped gemm kernel.",
         )
+        parser.add_argument(
+            "--enable-flashinfer-mxfp4-moe",
+            action="store_true",
+            help="Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
+        )
 
         # Debug tensor dumps
         parser.add_argument(
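The second hunk is the auto-enable behaviour named in the commit title: on SM100-class (Blackwell/B200) GPUs the FlashInfer MXFP4 MoE backend is switched on and the Triton kernel MoE is switched off; on other GPUs the Triton path is kept. A hedged standalone sketch of that selection, using `torch.cuda.get_device_capability()` as an assumed proxy for sglang's `is_sm100_supported()` (stand-in code, not the real ServerArgs logic):

# Illustrative stand-in for the auto-enable logic (not sglang's actual code).
import torch

def is_sm100_supported_sketch() -> bool:
    # Assumed proxy: Blackwell-class GPUs report compute capability 10.x.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 10

def pick_moe_backend() -> dict:
    if is_sm100_supported_sketch():
        # B200: prefer the FlashInfer MXFP4 kernels automatically.
        return {"enable_flashinfer_mxfp4_moe": True, "enable_triton_kernel_moe": False}
    # Other GPUs: fall back to the Triton grouped-GEMM MoE kernels.
    return {"enable_flashinfer_mxfp4_moe": False, "enable_triton_kernel_moe": True}

print(pick_moe_backend())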