sglang · Commit 4373df55 (Unverified)

add flashinfer mxfp4 (#8847)

Authored Aug 07, 2025 by Xiaoyu Zhang; committed by GitHub Aug 06, 2025.
Parent: c0e84297
Showing 3 changed files with 230 additions and 22 deletions (+230 −22):

    python/sglang/srt/layers/moe/fused_moe_triton/layer.py   (+20 −2)
    python/sglang/srt/layers/quantization/mxfp4.py            (+195 −19)
    python/sglang/srt/server_args.py                          (+15 −1)
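For context, the new FlashInfer MXFP4 MoE path added by this commit is opt-in through two environment variables that the diff reads in several places. A minimal sketch of enabling it from Python before anything from sglang is imported; only the two variable names come from this diff, the rest is illustrative:

```python
import os

# SGLANG_USE_FLASHINFER_MXFP4_MOE: mxfp4 weights with activations quantized to mxfp8.
os.environ["SGLANG_USE_FLASHINFER_MXFP4_MOE"] = "1"
# Alternatively, SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE keeps activations in bfloat16
# and lets the TRT-LLM kernel quantize internally (see the apply() hunk below).

# The flags must be set before sglang.srt.layers.quantization.mxfp4 is imported,
# because that module snapshots them into module-level constants at import time.
```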
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -38,6 +38,7 @@ from sglang.srt.utils import (
    is_flashinfer_available,
    is_hip,
    next_power_of_2,
    round_up,
)

if is_flashinfer_available():
@@ -146,7 +147,6 @@ class FusedMoE(torch.nn.Module):
        self.layer_id = layer_id
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.num_experts = num_experts
        self.num_fused_shared_experts = num_fused_shared_experts
        self.expert_map_cpu = None
@@ -206,6 +206,16 @@ class FusedMoE(torch.nn.Module):
        assert self.quant_method is not None
        self.quant_config = quant_config
        if (
            self.quant_config is not None
            and self.quant_config.get_name() == "mxfp4"
            and (
                get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE")
                or get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE")
            )
        ):
            hidden_size = round_up(hidden_size, 256)
            self.hidden_size = hidden_size
        self.quant_method.create_weights(
            layer=self,
            num_experts=self.num_local_experts,
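The hunk above rounds the model's hidden size up to the next multiple of 256 so the FlashInfer kernels see aligned shapes. A small worked example of what `round_up` has to compute here; the real helper lives in `sglang.srt.utils`, so this standalone copy is only an illustration:

```python
def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x (illustrative re-implementation).
    return ((x + multiple - 1) // multiple) * multiple

# Example values, not taken from this diff: a hidden size of 2880 is padded to 2944,
# i.e. 64 extra zero columns per token; an already-aligned 4096 is left unchanged.
print(round_up(2880, 256))  # 2944
print(round_up(4096, 256))  # 4096
```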
@@ -784,6 +794,14 @@ class FusedMoE(torch.nn.Module):
    )

    def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
        origin_hidden_states_dim = hidden_states.shape[-1]
        if self.hidden_size != origin_hidden_states_dim:
            hidden_states = torch.nn.functional.pad(
                hidden_states,
                (0, self.hidden_size - origin_hidden_states_dim),
                mode="constant",
                value=0.0,
            )
        assert self.quant_method is not None

        if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
@@ -829,7 +847,7 @@ class FusedMoE(torch.nn.Module):
        if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

-       return final_hidden_states
+       return final_hidden_states[..., :origin_hidden_states_dim].contiguous()

    @classmethod
    def make_expert_params_mapping(
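Taken together with the padding added to forward above, the pattern is pad on entry, slice on exit: activations are widened with zeros before the MoE kernels run and trimmed back to the original width afterwards. A minimal standalone sketch of that round trip; the shapes are illustrative, not from a specific model:

```python
import torch
import torch.nn.functional as F

padded_hidden_size = 2944            # e.g. round_up(2880, 256); example values only
x = torch.randn(4, 2880)             # (num_tokens, original hidden size)

orig_dim = x.shape[-1]
x_padded = F.pad(x, (0, padded_hidden_size - orig_dim), mode="constant", value=0.0)
assert x_padded.shape == (4, 2944)

# ... the MoE kernels would run on x_padded here ...
out = x_padded[..., :orig_dim].contiguous()   # trim back to the model's hidden size
assert out.shape == (4, 2880)
```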
python/sglang/srt/layers/quantization/mxfp4.py
@@ -21,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import (
    direct_register_custom_op,
    get_bool_env_var,
    is_cuda,
    is_flashinfer_available,
    is_hip,
@@ -31,6 +32,12 @@ from sglang.srt.utils import (
has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None

# Environment variables for FlashInfer MXFP4 MoE backend
USE_FLASHINFER_MXFP4_MOE = get_bool_env_var(
    "SGLANG_USE_FLASHINFER_MXFP4_MOE", "false"
)
USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
    "SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
)

if is_flashinfer_available():
    # from flashinfer.fused_moe import cutlass_fused_moe
    from flashinfer import (
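These module-level flags are evaluated once, at import time, via get_bool_env_var. The helper's exact parsing lives in sglang.srt.utils and is not shown in this diff; a plausible sketch of the behavior the code relies on, assuming the usual "true"/"1" convention, is:

```python
import os

def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
    # Assumption: "true"/"1" (case-insensitive) enable the flag; anything else,
    # including an unset variable falling back to the "false" default, disables it.
    return os.getenv(name, default).strip().lower() in ("true", "1")

# With neither variable exported, both flags are False and the Triton path is used.
print(get_bool_env_var_sketch("SGLANG_USE_FLASHINFER_MXFP4_MOE"))  # False unless exported
```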
@@ -228,16 +235,28 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        self.num_experts = num_experts
        weight_dtype = torch.uint8
        scale_dtype = torch.uint8

        intermediate_size *= 2
        mxfp4_block = 32

        self.intermediate_size = intermediate_size
        # pad the intermediate size to be a multiple of 2 * mxfp4_block
        # for to hold non-uniform sharded tensor as well as swizzling
        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
            intermediate_size_per_partition_after_pad = round_up(intermediate_size, 256)
            hidden_size = round_up(hidden_size, 256)
        elif is_hip():
            intermediate_size_per_partition_after_pad = round_up(intermediate_size, 128)
        else:
            intermediate_size_per_partition_after_pad = round_up(intermediate_size, 64)
        self.intermediate_size = intermediate_size_per_partition_after_pad
        self.hidden_size = hidden_size

        # Fused gate_up_proj (column parallel)
        w13_weight = torch.nn.Parameter(
            torch.zeros(
-               num_experts,
-               2 * intermediate_size,
-               hidden_size // 2,
-               dtype=weight_dtype,
+               num_experts,
+               2 * intermediate_size_per_partition_after_pad,
+               hidden_size // 2,
+               dtype=weight_dtype,
            ),
            requires_grad=False,
        )
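The uint8 dtypes reflect the MXFP4 packing: two 4-bit weights share one byte, so the last weight dimension is hidden_size // 2, and each group of 32 elements shares one scale byte, so the scale tensor's last dimension is hidden_size // mxfp4_block. A small sketch of the resulting shapes under the FlashInfer padding rule; the sizes are illustrative and this does not reproduce the diff's exact intermediate-size bookkeeping:

```python
import torch

def round_up(x: int, multiple: int) -> int:
    # Illustrative copy of the helper used in the diff.
    return ((x + multiple - 1) // multiple) * multiple

num_experts, hidden_size, intermediate_size = 8, 2880, 2880   # example sizes
mxfp4_block = 32

inter_pad = round_up(intermediate_size, 256)    # FlashInfer path pads to 256
hidden_pad = round_up(hidden_size, 256)

w13 = torch.zeros(num_experts, 2 * inter_pad, hidden_pad // 2, dtype=torch.uint8)
w13_scale = torch.zeros(
    num_experts, 2 * inter_pad, hidden_pad // mxfp4_block, dtype=torch.uint8
)

print(w13.shape)        # two fp4 weights packed per byte along the last dim
print(w13_scale.shape)  # one scale byte per 32-element block
```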
@@ -247,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        w13_weight_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts,
-               2 * intermediate_size,
+               2 * intermediate_size_per_partition_after_pad,
                hidden_size // mxfp4_block,
                dtype=scale_dtype,
            ),
@@ -257,7 +276,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        w13_weight_bias = torch.nn.Parameter(
-           torch.zeros(num_experts, 2 * intermediate_size, dtype=torch.bfloat16),
+           torch.zeros(
+               num_experts,
+               2 * intermediate_size_per_partition_after_pad,
+               dtype=torch.bfloat16,
+           ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_bias", w13_weight_bias)
@@ -266,7 +289,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.zeros(
-               num_experts,
-               hidden_size,
-               intermediate_size // 2,
-               dtype=weight_dtype,
+               num_experts,
+               hidden_size,
+               intermediate_size_per_partition_after_pad // 2,
+               dtype=weight_dtype,
            ),
            requires_grad=False,
        )
@@ -277,7 +303,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
            torch.zeros(
                num_experts,
                hidden_size,
-               intermediate_size // mxfp4_block,
+               intermediate_size_per_partition_after_pad // mxfp4_block,
                dtype=scale_dtype,
            ),
            requires_grad=False,
@@ -293,6 +319,158 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        set_weight_attrs(w2_weight_bias, extra_weight_attrs)

    def process_weights_after_loading(self, layer):
        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
            logger.info(
                "Shuffling MoE weights for FlashInfer, it might take a while..."
            )
            layer.gemm1_alpha = Parameter(
                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_beta = Parameter(
                torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            layer.gemm1_clamp_limit = Parameter(
                torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
                requires_grad=False,
            )
            sf_block_size = 32  # mxfp4 block size

            assert (
                layer.w13_weight.dim() == 3
                and layer.w13_weight.shape[0] == self.num_experts
                and layer.w13_weight.shape[1] == self.intermediate_size * 2
                and layer.w13_weight.shape[2] == self.hidden_size // 2
            )
            assert (
                layer.w13_weight_scale.dim() == 3
                and layer.w13_weight_scale.shape[0] == self.num_experts
                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
            )
            assert (
                layer.w2_weight.dim() == 3
                and layer.w2_weight.shape[0] == self.num_experts
                and layer.w2_weight.shape[1] == self.hidden_size
                and layer.w2_weight.shape[2] == self.intermediate_size // 2
            )
            assert (
                layer.w2_weight_scale.dim() == 3
                and layer.w2_weight_scale.shape[1] == self.hidden_size
                and layer.w2_weight_scale.shape[2] == self.intermediate_size // sf_block_size
            )
            assert (
                layer.w13_weight_bias.dim() == 2
                and layer.w13_weight_bias.shape[0] == self.num_experts
                and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2
            )
            assert (
                layer.w2_weight_bias.dim() == 2
                and layer.w2_weight_bias.shape[0] == self.num_experts
                and layer.w2_weight_bias.shape[1] == self.hidden_size
            )

            w13_weight_scale = layer.w13_weight_scale.data
            w2_weight_scale = layer.w2_weight_scale.data
            w13_weight = layer.w13_weight.data
            w2_weight = layer.w2_weight.data
            w13_bias = layer.w13_weight_bias.data.to(torch.float32)
            w2_bias = layer.w2_weight_bias.data.to(torch.float32)

            # Swap w1 and w3 as the definition of
            # swiglu is different in the trtllm-gen
            def swap_every_two_rows(x, axis=-1):
                shape = x.shape
                if axis < 0:
                    axis = len(shape) + axis

                # Create a new shape with pairs swapped along specified axis
                new_shape = list(shape)
                new_shape[axis] = shape[axis] // 2
                new_shape.insert(axis + 1, 2)

                # Reshape to expose pairs, swap them, and reshape back
                x = x.reshape(*new_shape)
                x = x.flip(axis + 1)
                new_shape = list(shape)
                return x.reshape(*new_shape)

            w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
            w13_weight = swap_every_two_rows(w13_weight, -2)
            w13_bias = swap_every_two_rows(w13_bias, -1)

            # Shuffle weights and scaling factors for transposed mma output
            gemm1_weights_mxfp4_shuffled = []
            gemm1_scales_mxfp4_shuffled = []
            gemm2_weights_mxfp4_shuffled = []
            gemm2_scales_mxfp4_shuffled = []
            gemm1_bias_shuffled = []
            gemm2_bias_shuffled = []
            epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
            for i in range(self.num_experts):
                gemm1_weights_mxfp4_shuffled.append(
                    shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
                )
                gemm1_scales_mxfp4_shuffled.append(
                    shuffle_matrix_sf_a(
                        w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
                    )
                )
                gemm1_bias_shuffled.append(
                    shuffle_matrix_a(
                        w13_bias[i].clone().reshape(-1, 1), epilogue_tile_m
                    )
                )

                gemm2_weights_mxfp4_shuffled.append(
                    shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
                )
                gemm2_scales_mxfp4_shuffled.append(
                    shuffle_matrix_sf_a(
                        w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
                    )
                )
                gemm2_bias_shuffled.append(
                    shuffle_matrix_a(
                        w2_bias[i].clone().reshape(-1, 1), epilogue_tile_m
                    )
                )

            w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
            w13_weight_scale = (
                torch.stack(gemm1_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    2 * self.intermediate_size,
                    self.hidden_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )

            w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
            w2_weight_scale = (
                torch.stack(gemm2_scales_mxfp4_shuffled)
                .reshape(
                    self.num_experts,
                    self.hidden_size,
                    self.intermediate_size // sf_block_size,
                )
                .view(torch.float8_e4m3fn)
            )

            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
            layer.w13_weight_bias = Parameter(
                torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
                requires_grad=False,
            )
            layer.w2_weight_bias = Parameter(
                torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
                requires_grad=False,
            )
            return

        from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
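The swap_every_two_rows helper above reorders interleaved gate/up (w1/w3) rows because trtllm-gen defines SwiGLU with the opposite interleaving. A tiny self-contained demonstration of the row permutation it performs; the function is repeated here (slightly condensed) only so the snippet runs on its own:

```python
import torch

def swap_every_two_rows(x, axis=-1):
    # Condensed copy of the helper from the diff: swap each adjacent pair of
    # slices along `axis` by exposing the pairs as an extra dim and flipping it.
    shape = x.shape
    if axis < 0:
        axis = len(shape) + axis
    new_shape = list(shape)
    new_shape[axis] = shape[axis] // 2
    new_shape.insert(axis + 1, 2)
    return x.reshape(*new_shape).flip(axis + 1).reshape(*list(shape))

rows = torch.arange(8).reshape(4, 2)           # row order 0, 1, 2, 3
swapped = swap_every_two_rows(rows, axis=-2)   # row order becomes 1, 0, 3, 2
print(swapped[:, 0].tolist())                  # [2, 0, 6, 4]
```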
@@ -366,22 +544,21 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
        activation_alpha: Optional[float] = None,
        swiglu_limit: Optional[float] = None,
    ) -> torch.Tensor:
        # avoid import error when triton_kernel is not installed
        # from vllm.model_executor.layers.fused_moe.triton_kernels_moe import (
        #     triton_kernel_moe_forward)
-       """
-       if (envs.VLLM_USE_FLASHINFER_MXFP4_MOE
-               or envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE):
-           assert not self.moe.use_ep, (
-               "EP is not supported for flashinfer mxfp4 moe backend yet.")
-           if envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE:
+       if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
+           # When USE_FLASHINFER_MXFP4_BF16_MOE is enabled, we don't need to quantize the input,
+           # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
+           # which can theoretically improve performance
+           if USE_FLASHINFER_MXFP4_BF16_MOE:
                assert x.dtype == torch.bfloat16
                x_quant = x
                x_scale = None
            else:
                x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
                x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)

            topk_weights, topk_ids, router_logits = topk_output
            top_k = topk_weights.shape[-1]

            trtllm_gen_output = trtllm_fp4_block_scale_moe(
                router_logits.to(torch.bfloat16),
                None,  # routing_bias
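The two flags therefore differ only in how activations reach the kernel: the BF16 variant passes unquantized bfloat16 activations and lets TRT-LLM quantize inside the fused kernel, while the default MXFP4 variant quantizes activations to MXFP8 up front. A schematic of that dispatch with a stand-in quantizer; mock_mxfp8_quantize is hypothetical and only mimics the call shape of flashinfer's mxfp8_quantize used above:

```python
import torch

def mock_mxfp8_quantize(x: torch.Tensor):
    # Hypothetical stand-in: returns quantized data plus per-block scales.
    # Dtypes and scale layout are placeholders, not the real kernel contract.
    return x.to(torch.float8_e4m3fn), torch.ones(x.numel() // 32, dtype=torch.uint8)

def prepare_moe_input(x: torch.Tensor, use_bf16_activations: bool):
    if use_bf16_activations:
        assert x.dtype == torch.bfloat16
        return x, None                    # kernel quantizes internally
    return mock_mxfp8_quantize(x)         # quantize activations up front

x = torch.randn(4, 2944, dtype=torch.bfloat16)
x_quant, x_scale = prepare_moe_input(x, use_bf16_activations=True)
print(x_quant.dtype, x_scale)  # torch.bfloat16 None
```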
@@ -412,7 +589,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                True,  # do finalize
            )[0]
            return trtllm_gen_output
-       """

        if self.use_triton_kernels:
            if self.with_bias:
python/sglang/srt/server_args.py
@@ -464,7 +464,21 @@ class ServerArgs:
        model_arch = self.get_hf_config().architectures[0]
        if model_arch in ["GptOssForCausalLM"]:
            self.attention_backend = "triton"
-           self.enable_triton_kernel_moe = True
+           # Check if FlashInfer MXFP4 MoE is enabled
+           from sglang.srt.utils import get_bool_env_var
+
+           USE_FLASHINFER_MXFP4_MOE = get_bool_env_var(
+               "SGLANG_USE_FLASHINFER_MXFP4_MOE", "false"
+           )
+           USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
+               "SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
+           )
+           # Only enable Triton kernel MoE if FlashInfer is not enabled
+           if not (USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE):
+               self.enable_triton_kernel_moe = True
            self.disable_hybrid_swa_memory = True

        quantization_config = getattr(
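For GptOss models the Triton-kernel MoE thus stays the default, and exporting either FlashInfer flag switches the backend without any new CLI argument. A hedged sketch of the resulting decision; the helper below is illustrative and is not part of sglang's argument handling:

```python
import os

def pick_moe_backend_for_gpt_oss() -> str:
    # Mirrors the logic in the diff: FlashInfer wins if either flag is truthy,
    # otherwise the Triton kernel MoE path is enabled.
    truthy = ("true", "1")
    if os.getenv("SGLANG_USE_FLASHINFER_MXFP4_MOE", "false").lower() in truthy:
        return "flashinfer-mxfp4 (mxfp8 activations)"
    if os.getenv("SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false").lower() in truthy:
        return "flashinfer-mxfp4 (bf16 activations)"
    return "triton-kernel-moe"

print(pick_moe_backend_for_gpt_oss())
```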