sglang · Commit 47824c14 (Unverified)
Authored Aug 07, 2025 by Xiaoyu Zhang, committed by GitHub on Aug 07, 2025
[Perf] Auto enable best flashinfer mxfp4 kernel in b200 (#8898)
Parent: c36a6693

Showing 5 changed files with 48 additions and 48 deletions:

  python/sglang/srt/layers/moe/fused_moe_triton/layer.py   +4  -4
  python/sglang/srt/layers/quantization/mxfp4.py            +24 -27
  python/sglang/srt/managers/schedule_batch.py              +1  -0
  python/sglang/srt/models/gpt_oss.py                       +9  -5
  python/sglang/srt/server_args.py                          +10 -12
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -206,13 +206,13 @@ class FusedMoE(torch.nn.Module):
         assert self.quant_method is not None
         self.quant_config = quant_config
+        self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
+            "enable_flashinfer_mxfp4_moe", False
+        )
         if (
             self.quant_config is not None
             and self.quant_config.get_name() == "mxfp4"
-            and (
-                get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE")
-                or get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE")
-            )
+            and self.use_enable_flashinfer_mxfp4_moe
         ):
             hidden_size = round_up(hidden_size, 256)
         self.hidden_size = hidden_size
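For reference, the alignment applied here relies on round_up padding the hidden size to the next multiple of 256. A minimal sketch of that behaviour, assuming the standard definition of round_up (the concrete sizes are illustrative only):

def round_up(x: int, m: int) -> int:
    # Pad x up to the next multiple of m (no-op when already aligned).
    return ((x + m - 1) // m) * m

# Example: a 2880-wide hidden state would be padded to 3072 for the
# FlashInfer MXFP4 MoE path, while an already-aligned 4096 stays unchanged.
assert round_up(2880, 256) == 3072
assert round_up(4096, 256) == 4096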
python/sglang/srt/layers/quantization/mxfp4.py

@@ -3,22 +3,20 @@
 from __future__ import annotations
 
 import importlib
 import importlib.util
 import logging
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 from torch.nn.parameter import Parameter
 
-# from vllm.model_executor.layers.fused_moe import (
-#     FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
-#     FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
 from sglang.srt.layers.quantization.base_config import (
     FusedMoEMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.utils import is_layer_skipped
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,

@@ -32,11 +30,6 @@ from sglang.srt.utils import (
 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None
 
-# Environment variables for FlashInfer MXFP4 MoE backend
-USE_FLASHINFER_MXFP4_MOE = get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE", "false")
-USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false")
-
 if is_flashinfer_available():
     # from flashinfer.fused_moe import cutlass_fused_moe

@@ -193,7 +186,12 @@ class Mxfp4Config(QuantizationConfig):
         ):
             return UnquantizedLinearMethod()
         elif isinstance(layer, FusedMoE):
-            return Mxfp4MoEMethod(use_triton_kernels=True, with_bias=True)
+            use_flashinfer = global_server_args_dict.get(
+                "enable_flashinfer_mxfp4_moe", False
+            )
+            return Mxfp4MoEMethod(
+                use_triton_kernels=True, with_bias=True, use_flashinfer=use_flashinfer
+            )
         else:
             raise NotImplementedError("Mxfp4 attention layer is not implemented")
         return None

@@ -204,11 +202,18 @@ class Mxfp4Config(QuantizationConfig):
 class Mxfp4MoEMethod(FusedMoEMethodBase):
 
-    def __init__(self, use_triton_kernels: bool = True, with_bias: bool = True):
+    def __init__(
+        self,
+        use_triton_kernels: bool = True,
+        with_bias: bool = True,
+        use_flashinfer: bool = False,
+    ):
         super().__init__()
         self.topk_indices_dtype = None
         self.use_triton_kernels = use_triton_kernels
         self.with_bias = with_bias
+        self.use_flashinfer = use_flashinfer
         self.triton_kernel_moe_forward = None
         self.triton_kernel_moe_with_bias_forward = None
         if torch.cuda.is_available() and has_triton_kernels:

@@ -239,7 +244,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         # pad the intermediate size to be a multiple of 2 * mxfp4_block
         # for to hold non-uniform sharded tensor as well as swizzling
-        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
+        if self.use_flashinfer:
             intermediate_size_per_partition_after_pad = round_up(intermediate_size, 256)
             hidden_size = round_up(hidden_size, 256)
         elif is_hip():

@@ -319,7 +324,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_bias, extra_weight_attrs)
 
     def process_weights_after_loading(self, layer):
-        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
+        if self.use_flashinfer:
             logger.info("Shuffling MoE weights for FlashInfer, it might take a while...")

@@ -544,20 +549,12 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         activation_alpha: Optional[float] = None,
         swiglu_limit: Optional[float] = None,
     ) -> torch.Tensor:
-        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
-            # When USE_FLASHINFER_MXFP4_BF16_MOE is enabled, we don't need to quantize the input,
-            # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
-            # which can theoretically improve performance
-            if USE_FLASHINFER_MXFP4_BF16_MOE:
-                assert x.dtype == torch.bfloat16
-                x_quant = x
-                x_scale = None
-            else:
+        if self.use_flashinfer:
+            # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
             x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
             x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
 
-            topk_weights, topk_ids, router_logits = topk_output
-            top_k = topk_weights.shape[-1]
+            top_k, router_logits = topk_output
 
             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 router_logits.to(torch.bfloat16),
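Taken together, the changes in this file move the backend choice from import-time environment variables to a per-instance flag read from the server args. A condensed, self-contained sketch of that control flow (the stub dict and class below are illustrative stand-ins, not the real sglang objects):

# Stand-in for the real global_server_args_dict populated by the server.
global_server_args_dict = {"enable_flashinfer_mxfp4_moe": True}

class Mxfp4MoEMethodSketch:
    def __init__(self, use_triton_kernels=True, with_bias=True, use_flashinfer=False):
        self.use_triton_kernels = use_triton_kernels
        self.with_bias = with_bias
        # The per-instance flag replaces the module-level env-var constants
        # USE_FLASHINFER_MXFP4_MOE / USE_FLASHINFER_MXFP4_BF16_MOE.
        self.use_flashinfer = use_flashinfer

    def pick_backend(self):
        # Every site that previously consulted the env vars now checks the flag.
        return "flashinfer_trtllm" if self.use_flashinfer else "triton"

method = Mxfp4MoEMethodSketch(
    use_flashinfer=global_server_args_dict.get("enable_flashinfer_mxfp4_moe", False)
)
print(method.pick_backend())  # -> flashinfer_trtllm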
python/sglang/srt/managers/schedule_batch.py

@@ -107,6 +107,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
     "enable_triton_kernel_moe",
+    "enable_flashinfer_mxfp4_moe",
     "enable_multimodal",
     "enable_symm_mem",
     "quantization",
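Why this one-line change is needed: a rough sketch, under the assumption that GLOBAL_SERVER_ARGS_KEYS is the allow-list used to mirror ServerArgs fields into the process-wide global_server_args_dict that model and layer code reads (the dataclass below is a hypothetical stand-in):

from dataclasses import dataclass

# Hypothetical stand-ins for the real ServerArgs and key list.
GLOBAL_SERVER_ARGS_KEYS = ["enable_triton_kernel_moe", "enable_flashinfer_mxfp4_moe"]

@dataclass
class FakeServerArgs:
    enable_triton_kernel_moe: bool = False
    enable_flashinfer_mxfp4_moe: bool = True

args = FakeServerArgs()
# Only registered keys become visible to layers such as FusedMoE and Mxfp4Config.
global_server_args_dict = {key: getattr(args, key) for key in GLOBAL_SERVER_ARGS_KEYS}
print(global_server_args_dict["enable_flashinfer_mxfp4_moe"])  # True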
python/sglang/srt/models/gpt_oss.py

@@ -102,11 +102,15 @@ class GptOssSparseMoeBlock(nn.Module):
                 f"the number of experts {config.num_local_experts}."
             )
 
+        if global_server_args_dict["enable_flashinfer_mxfp4_moe"]:
+            self.topk = None
+        else:
+            self.topk = TopK(
+                top_k=config.num_experts_per_tok,
+                renormalize=True,
+            )
+        self.top_k = config.num_experts_per_tok
 
         experts_type = get_moe_impl_class()
         extra_kwargs = {}
         if experts_type.__name__ == "FusedMoE":

@@ -176,7 +180,7 @@ class GptOssSparseMoeBlock(nn.Module):
         if self.topk is not None:
             kwargs["topk_output"] = self.topk(hidden_states, router_logits)
         else:
-            kwargs["router_logits"] = router_logits
+            kwargs["topk_output"] = (self.top_k, router_logits)
         final_hidden_states = self.experts(**kwargs)
 
         if self.tp_size > 1:
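A small illustration of the two calling conventions the block above switches between (build_experts_kwargs is a hypothetical helper written for this sketch, not sglang code): when the FlashInfer MXFP4 path is active, routing is deferred to the fused kernel, so the experts receive (top_k, router_logits) instead of a precomputed TopK output.

import torch

def build_experts_kwargs(hidden_states, router_logits, topk, top_k):
    kwargs = {"hidden_states": hidden_states}
    if topk is not None:
        # Default path: compute top-k routing on the Python side.
        kwargs["topk_output"] = topk(hidden_states, router_logits)
    else:
        # FlashInfer MXFP4 path: let the fused kernel do the routing itself.
        kwargs["topk_output"] = (top_k, router_logits)
    return kwargs

# Example with the FlashInfer path (topk=None), 4 experts, top_k=2:
h = torch.randn(8, 16)
logits = torch.randn(8, 4)
print(build_experts_kwargs(h, logits, topk=None, top_k=2)["topk_output"][0])  # 2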
python/sglang/srt/server_args.py

@@ -248,6 +248,7 @@ class ServerArgs:
     disable_fast_image_processor: bool = False
     enable_return_hidden_states: bool = False
     enable_triton_kernel_moe: bool = False
+    enable_flashinfer_mxfp4_moe: bool = False
 
     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None

@@ -476,18 +477,10 @@
             or self.attention_backend == "triton"
         )
 
-        # Check if FlashInfer MXFP4 MoE is enabled
-        from sglang.srt.utils import get_bool_env_var
-
-        USE_FLASHINFER_MXFP4_MOE = get_bool_env_var(
-            "SGLANG_USE_FLASHINFER_MXFP4_MOE", "false"
-        )
-        USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
-            "SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
-        )
-
-        # Only enable Triton kernel MoE if FlashInfer is not enabled
-        if not (USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE):
+        if is_sm100_supported():
+            self.enable_flashinfer_mxfp4_moe = True
+            self.enable_triton_kernel_moe = False
+        else:
             self.enable_triton_kernel_moe = True
 
         self.disable_hybrid_swa_memory = True

@@ -1846,6 +1839,11 @@
             action="store_true",
             help="Use triton moe grouped gemm kernel.",
         )
+        parser.add_argument(
+            "--enable-flashinfer-mxfp4-moe",
+            action="store_true",
+            help="Enable FlashInfer MXFP4 MoE backend for modelopt_fp4 quant on Blackwell.",
+        )
 
         # Debug tensor dumps
         parser.add_argument(
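In short, the commit's headline behaviour: on SM100-class GPUs such as B200, the FlashInfer MXFP4 MoE backend is now turned on automatically and the Triton grouped-GEMM MoE kernel is turned off, while --enable-flashinfer-mxfp4-moe remains available to set the flag explicitly. A condensed sketch of that decision, assuming is_sm100_supported() reports True exactly on SM100/Blackwell hardware:

def pick_moe_backend(is_sm100: bool) -> dict:
    # Mirrors the auto-selection above: prefer FlashInfer MXFP4 on Blackwell,
    # otherwise fall back to the Triton kernel MoE.
    if is_sm100:
        return {"enable_flashinfer_mxfp4_moe": True, "enable_triton_kernel_moe": False}
    return {"enable_flashinfer_mxfp4_moe": False, "enable_triton_kernel_moe": True}

print(pick_moe_backend(is_sm100=True))   # B200: FlashInfer MXFP4 MoE
print(pick_moe_backend(is_sm100=False))  # other GPUs: Triton kernel MoE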