sglang · Commits · 4373df55

Unverified commit 4373df55, authored Aug 07, 2025 by Xiaoyu Zhang, committed by GitHub on Aug 06, 2025
add flashinfer mxfp4 (#8847)
Parent: c0e84297
3 changed files with 230 additions and 22 deletions (+230 −22)
python/sglang/srt/layers/moe/fused_moe_triton/layer.py (+20 −2)
python/sglang/srt/layers/quantization/mxfp4.py (+195 −19)
python/sglang/srt/server_args.py (+15 −1)
python/sglang/srt/layers/moe/fused_moe_triton/layer.py
@@ -38,6 +38,7 @@ from sglang.srt.utils import (
     is_flashinfer_available,
     is_hip,
     next_power_of_2,
+    round_up,
 )

 if is_flashinfer_available():
@@ -146,7 +147,6 @@ class FusedMoE(torch.nn.Module):
         self.layer_id = layer_id
         self.top_k = top_k
         self.hidden_size = hidden_size
         self.num_experts = num_experts
         self.num_fused_shared_experts = num_fused_shared_experts
         self.expert_map_cpu = None
@@ -206,6 +206,16 @@ class FusedMoE(torch.nn.Module):
         assert self.quant_method is not None
         self.quant_config = quant_config
+        if (
+            self.quant_config is not None
+            and self.quant_config.get_name() == "mxfp4"
+            and (
+                get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE")
+                or get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE")
+            )
+        ):
+            hidden_size = round_up(hidden_size, 256)
+            self.hidden_size = hidden_size

         self.quant_method.create_weights(
             layer=self,
             num_experts=self.num_local_experts,
@@ -784,6 +794,14 @@ class FusedMoE(torch.nn.Module):
         )

     def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
+        origin_hidden_states_dim = hidden_states.shape[-1]
+        if self.hidden_size != origin_hidden_states_dim:
+            hidden_states = torch.nn.functional.pad(
+                hidden_states,
+                (0, self.hidden_size - origin_hidden_states_dim),
+                mode="constant",
+                value=0.0,
+            )
         assert self.quant_method is not None

         if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
@@ -829,7 +847,7 @@ class FusedMoE(torch.nn.Module):
         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

-        return final_hidden_states
+        return final_hidden_states[..., :origin_hidden_states_dim].contiguous()

     @classmethod
     def make_expert_params_mapping(
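The forward() change above pads the activation's last dimension up to the 256-aligned hidden size expected by the FlashInfer MXFP4 kernels and slices the padding off before returning. A minimal round-trip sketch in plain PyTorch (the helper and sizes below are illustrative stand-ins, not sglang's own code):

import torch
import torch.nn.functional as F

def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x (same role as sglang.srt.utils.round_up).
    return ((x + multiple - 1) // multiple) * multiple

def pad_and_restore(hidden_states: torch.Tensor, alignment: int = 256) -> torch.Tensor:
    orig_dim = hidden_states.shape[-1]
    padded_dim = round_up(orig_dim, alignment)
    if padded_dim != orig_dim:
        # Zero-pad only the last dimension, mirroring the new FusedMoE.forward logic.
        hidden_states = F.pad(
            hidden_states, (0, padded_dim - orig_dim), mode="constant", value=0.0
        )
    # ... the fused MoE kernel would consume the padded tensor here ...
    return hidden_states[..., :orig_dim].contiguous()

x = torch.randn(4, 2880)  # hypothetical hidden size that is not a multiple of 256
assert torch.equal(pad_and_restore(x), x)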
python/sglang/srt/layers/quantization/mxfp4.py
@@ -21,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import (
     direct_register_custom_op,
+    get_bool_env_var,
     is_cuda,
     is_flashinfer_available,
     is_hip,
@@ -31,6 +32,12 @@ from sglang.srt.utils import (
 has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None

+# Environment variables for FlashInfer MXFP4 MoE backend
+USE_FLASHINFER_MXFP4_MOE = get_bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE", "false")
+USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
+    "SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
+)
+
 if is_flashinfer_available():
     # from flashinfer.fused_moe import cutlass_fused_moe
     from flashinfer import (
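The two module-level flags above select the FlashInfer MXFP4 paths used throughout this diff: SGLANG_USE_FLASHINFER_MXFP4_MOE quantizes activations to MXFP8 before the fused MoE kernel, while SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE passes BF16 activations and lets the TRT-LLM kernel quantize internally. As a rough illustration of how such a boolean flag is read (a sketch only, not sglang's get_bool_env_var implementation):

import os

def bool_env_var(name: str, default: str = "false") -> bool:
    # Treat common truthy spellings as enabling the flag.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes", "on")

os.environ["SGLANG_USE_FLASHINFER_MXFP4_MOE"] = "1"
print(bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_MOE"))       # True
print(bool_env_var("SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE"))  # False (unset)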
@@ -228,16 +235,28 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         self.num_experts = num_experts
         weight_dtype = torch.uint8
         scale_dtype = torch.uint8

         intermediate_size *= 2
         mxfp4_block = 32

-        self.intermediate_size = intermediate_size
         # pad the intermediate size to be a multiple of 2 * mxfp4_block
         # for to hold non-uniform sharded tensor as well as swizzling
+        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
+            intermediate_size_per_partition_after_pad = round_up(intermediate_size, 256)
+            hidden_size = round_up(hidden_size, 256)
+        elif is_hip():
+            intermediate_size_per_partition_after_pad = round_up(intermediate_size, 128)
+        else:
+            intermediate_size_per_partition_after_pad = round_up(intermediate_size, 64)
+
+        self.intermediate_size = intermediate_size_per_partition_after_pad
+        self.hidden_size = hidden_size

         # Fused gate_up_proj (column parallel)
         w13_weight = torch.nn.Parameter(
             torch.zeros(
-                num_experts,
-                2 * intermediate_size,
-                hidden_size // 2,
-                dtype=weight_dtype
+                num_experts,
+                2 * intermediate_size_per_partition_after_pad,
+                hidden_size // 2,
+                dtype=weight_dtype,
             ),
             requires_grad=False,
         )
@@ -247,7 +266,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         w13_weight_scale = torch.nn.Parameter(
             torch.zeros(
                 num_experts,
-                2 * intermediate_size,
+                2 * intermediate_size_per_partition_after_pad,
                 hidden_size // mxfp4_block,
                 dtype=scale_dtype,
             ),
@@ -257,7 +276,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w13_weight_scale, extra_weight_attrs)

         w13_weight_bias = torch.nn.Parameter(
-            torch.zeros(num_experts, 2 * intermediate_size, dtype=torch.bfloat16),
+            torch.zeros(
+                num_experts,
+                2 * intermediate_size_per_partition_after_pad,
+                dtype=torch.bfloat16,
+            ),
             requires_grad=False,
         )
         layer.register_parameter("w13_weight_bias", w13_weight_bias)
@@ -266,7 +289,10 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         # down_proj (row parallel)
         w2_weight = torch.nn.Parameter(
             torch.zeros(
-                num_experts, hidden_size, intermediate_size // 2, dtype=weight_dtype
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition_after_pad // 2,
+                dtype=weight_dtype,
             ),
             requires_grad=False,
         )
@@ -277,7 +303,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
             torch.zeros(
                 num_experts,
                 hidden_size,
-                intermediate_size // mxfp4_block,
+                intermediate_size_per_partition_after_pad // mxfp4_block,
                 dtype=scale_dtype,
             ),
             requires_grad=False,
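The alignment chosen in create_weights() above determines the padded per-partition expert shapes. A quick numeric illustration of the three rules (the sizes are hypothetical, not taken from any particular model):

def round_up(x: int, multiple: int) -> int:
    return ((x + multiple - 1) // multiple) * multiple

intermediate_size = 1440  # hypothetical per-partition intermediate size
print(round_up(intermediate_size, 256))  # 1536 -> FlashInfer MXFP4 path (hidden_size is padded to 256 too)
print(round_up(intermediate_size, 128))  # 1536 -> ROCm (is_hip) path
print(round_up(intermediate_size, 64))   # 1472 -> default path

# MXFP4 packs two 4-bit values per uint8 byte, which is why the weight shapes above
# divide one dimension by 2, and scales are stored per 32-element block (mxfp4_block),
# which is why the scale shapes divide by mxfp4_block.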
@@ -293,6 +319,158 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         set_weight_attrs(w2_weight_bias, extra_weight_attrs)

     def process_weights_after_loading(self, layer):
+        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
+            logger.info(
+                "Shuffling MoE weights for FlashInfer, it might take a while..."
+            )
+            layer.gemm1_alpha = Parameter(
+                torch.tensor([1.702] * self.num_experts, dtype=torch.float32).cuda(),
+                requires_grad=False,
+            )
+            layer.gemm1_beta = Parameter(
+                torch.tensor([1.0] * self.num_experts, dtype=torch.float32).cuda(),
+                requires_grad=False,
+            )
+            layer.gemm1_clamp_limit = Parameter(
+                torch.tensor([7.0] * self.num_experts, dtype=torch.float32).cuda(),
+                requires_grad=False,
+            )
+            sf_block_size = 32  # mxfp4 block size
+
+            assert (
+                layer.w13_weight.dim() == 3
+                and layer.w13_weight.shape[0] == self.num_experts
+                and layer.w13_weight.shape[1] == self.intermediate_size * 2
+                and layer.w13_weight.shape[2] == self.hidden_size // 2
+            )
+            assert (
+                layer.w13_weight_scale.dim() == 3
+                and layer.w13_weight_scale.shape[0] == self.num_experts
+                and layer.w13_weight_scale.shape[1] == self.intermediate_size * 2
+                and layer.w13_weight_scale.shape[2] == self.hidden_size // sf_block_size
+            )
+            assert (
+                layer.w2_weight.dim() == 3
+                and layer.w2_weight.shape[0] == self.num_experts
+                and layer.w2_weight.shape[1] == self.hidden_size
+                and layer.w2_weight.shape[2] == self.intermediate_size // 2
+            )
+            assert (
+                layer.w2_weight_scale.dim() == 3
+                and layer.w2_weight_scale.shape[1] == self.hidden_size
+                and layer.w2_weight_scale.shape[2]
+                == self.intermediate_size // sf_block_size
+            )
+            assert (
+                layer.w13_weight_bias.dim() == 2
+                and layer.w13_weight_bias.shape[0] == self.num_experts
+                and layer.w13_weight_bias.shape[1] == self.intermediate_size * 2
+            )
+            assert (
+                layer.w2_weight_bias.dim() == 2
+                and layer.w2_weight_bias.shape[0] == self.num_experts
+                and layer.w2_weight_bias.shape[1] == self.hidden_size
+            )
+
+            w13_weight_scale = layer.w13_weight_scale.data
+            w2_weight_scale = layer.w2_weight_scale.data
+            w13_weight = layer.w13_weight.data
+            w2_weight = layer.w2_weight.data
+            w13_bias = layer.w13_weight_bias.data.to(torch.float32)
+            w2_bias = layer.w2_weight_bias.data.to(torch.float32)
+
+            # Swap w1 and w3 as the definition of
+            # swiglu is different in the trtllm-gen
+            def swap_every_two_rows(x, axis=-1):
+                shape = x.shape
+                if axis < 0:
+                    axis = len(shape) + axis
+
+                # Create a new shape with pairs swapped along specified axis
+                new_shape = list(shape)
+                new_shape[axis] = shape[axis] // 2
+                new_shape.insert(axis + 1, 2)
+
+                # Reshape to expose pairs, swap them, and reshape back
+                x = x.reshape(*new_shape)
+                x = x.flip(axis + 1)
+                new_shape = list(shape)
+                return x.reshape(*new_shape)
+
+            w13_weight_scale = swap_every_two_rows(w13_weight_scale, -2)
+            w13_weight = swap_every_two_rows(w13_weight, -2)
+            w13_bias = swap_every_two_rows(w13_bias, -1)
+
+            # Shuffle weights and scaling factors for transposed mma output
+            gemm1_weights_mxfp4_shuffled = []
+            gemm1_scales_mxfp4_shuffled = []
+            gemm2_weights_mxfp4_shuffled = []
+            gemm2_scales_mxfp4_shuffled = []
+            gemm1_bias_shuffled = []
+            gemm2_bias_shuffled = []
+            epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
+            for i in range(self.num_experts):
+                gemm1_weights_mxfp4_shuffled.append(
+                    shuffle_matrix_a(w13_weight[i].view(torch.uint8), epilogue_tile_m)
+                )
+                gemm1_scales_mxfp4_shuffled.append(
+                    shuffle_matrix_sf_a(
+                        w13_weight_scale[i].view(torch.uint8), epilogue_tile_m
+                    )
+                )
+                gemm1_bias_shuffled.append(
+                    shuffle_matrix_a(
+                        w13_bias[i].clone().reshape(-1, 1), epilogue_tile_m
+                    )
+                )
+
+                gemm2_weights_mxfp4_shuffled.append(
+                    shuffle_matrix_a(w2_weight[i].view(torch.uint8), epilogue_tile_m)
+                )
+                gemm2_scales_mxfp4_shuffled.append(
+                    shuffle_matrix_sf_a(
+                        w2_weight_scale[i].view(torch.uint8), epilogue_tile_m
+                    )
+                )
+                gemm2_bias_shuffled.append(
+                    shuffle_matrix_a(
+                        w2_bias[i].clone().reshape(-1, 1), epilogue_tile_m
+                    )
+                )
+
+            w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
+            w13_weight_scale = (
+                torch.stack(gemm1_scales_mxfp4_shuffled)
+                .reshape(
+                    self.num_experts,
+                    2 * self.intermediate_size,
+                    self.hidden_size // sf_block_size,
+                )
+                .view(torch.float8_e4m3fn)
+            )
+
+            w2_weight = torch.stack(gemm2_weights_mxfp4_shuffled)
+            w2_weight_scale = (
+                torch.stack(gemm2_scales_mxfp4_shuffled)
+                .reshape(
+                    self.num_experts,
+                    self.hidden_size,
+                    self.intermediate_size // sf_block_size,
+                )
+                .view(torch.float8_e4m3fn)
+            )
+
+            layer.w13_weight = Parameter(w13_weight, requires_grad=False)
+            layer.w13_weight_scale = Parameter(w13_weight_scale, requires_grad=False)
+            layer.w2_weight = Parameter(w2_weight, requires_grad=False)
+            layer.w2_weight_scale = Parameter(w2_weight_scale, requires_grad=False)
+            layer.w13_weight_bias = Parameter(
+                torch.stack(gemm1_bias_shuffled).reshape(self.num_experts, -1),
+                requires_grad=False,
+            )
+            layer.w2_weight_bias = Parameter(
+                torch.stack(gemm2_bias_shuffled).reshape(self.num_experts, -1),
+                requires_grad=False,
+            )
+
+            return
+
         from triton_kernels.matmul_ogs import FlexCtx, PrecisionConfig
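The swap_every_two_rows helper added above exchanges adjacent pairs of rows along an axis, converting between the w1/w3 interleaving in the checkpoint and the ordering expected by trtllm-gen's SwiGLU. A standalone check of its behavior (same logic, small input):

import torch

def swap_every_two_rows(x, axis=-1):
    # Identical logic to the helper in the diff: expose pairs along `axis`, flip each pair.
    shape = x.shape
    if axis < 0:
        axis = len(shape) + axis
    new_shape = list(shape)
    new_shape[axis] = shape[axis] // 2
    new_shape.insert(axis + 1, 2)
    x = x.reshape(*new_shape).flip(axis + 1)
    return x.reshape(*list(shape))

x = torch.arange(8).reshape(4, 2)    # rows: [0,1], [2,3], [4,5], [6,7]
y = swap_every_two_rows(x, axis=-2)  # adjacent rows are exchanged pairwise
print(y.tolist())                    # [[2, 3], [0, 1], [6, 7], [4, 5]]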
@@ -366,22 +544,21 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         activation_alpha: Optional[float] = None,
         swiglu_limit: Optional[float] = None,
     ) -> torch.Tensor:
         # avoid import error when triton_kernel is not installed
         # from vllm.model_executor.layers.fused_moe.triton_kernels_moe import (
         #     triton_kernel_moe_forward)
-        """
-        if (envs.VLLM_USE_FLASHINFER_MXFP4_MOE
-                or envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE):
-            assert not self.moe.use_ep, (
-                "EP is not supported for flashinfer mxfp4 moe backend yet.")
-            if envs.VLLM_USE_FLASHINFER_MXFP4_BF16_MOE:
+        if USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE:
+            # When USE_FLASHINFER_MXFP4_BF16_MOE is enabled, we don't need to quantize the input,
+            # TRT-LLM automatically handles quantization in the kernel implementation and pipelines it with GEMM operations,
+            # which can theoretically improve performance
+            if USE_FLASHINFER_MXFP4_BF16_MOE:
                 assert x.dtype == torch.bfloat16
                 x_quant = x
                 x_scale = None
             else:
                 x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
                 x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)

             topk_weights, topk_ids, router_logits = topk_output
             top_k = topk_weights.shape[-1]
             trtllm_gen_output = trtllm_fp4_block_scale_moe(
                 router_logits.to(torch.bfloat16),
                 None,  # routing_bias
@@ -412,7 +589,6 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
                 True,  # do finalize
             )[0]
             return trtllm_gen_output
-        """
         if self.use_triton_kernels:
             if self.with_bias:
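The activation handling in the apply path above reduces to a simple dispatch on the two flags; a sketch with a placeholder quantizer standing in for flashinfer's mxfp8_quantize (the stub below is hypothetical and only illustrates the control flow):

import torch

def fake_mxfp8_quantize(x: torch.Tensor):
    # Hypothetical stand-in: a real MXFP8 quantizer returns quantized activations
    # plus per-block scaling factors.
    scales = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6)
    return x / scales, scales

def prepare_moe_input(x: torch.Tensor, use_bf16_path: bool):
    if use_bf16_path:
        # SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE: hand BF16 activations to the kernel,
        # which quantizes internally and overlaps that work with the GEMMs.
        assert x.dtype == torch.bfloat16
        return x, None
    # SGLANG_USE_FLASHINFER_MXFP4_MOE: quantize activations before the fused MoE call.
    return fake_mxfp8_quantize(x)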
python/sglang/srt/server_args.py
@@ -464,7 +464,21 @@ class ServerArgs:
         model_arch = self.get_hf_config().architectures[0]
         if model_arch in ["GptOssForCausalLM"]:
             self.attention_backend = "triton"
-            self.enable_triton_kernel_moe = True
+
+            # Check if FlashInfer MXFP4 MoE is enabled
+            from sglang.srt.utils import get_bool_env_var
+
+            USE_FLASHINFER_MXFP4_MOE = get_bool_env_var(
+                "SGLANG_USE_FLASHINFER_MXFP4_MOE", "false"
+            )
+            USE_FLASHINFER_MXFP4_BF16_MOE = get_bool_env_var(
+                "SGLANG_USE_FLASHINFER_MXFP4_BF16_MOE", "false"
+            )
+
+            # Only enable Triton kernel MoE if FlashInfer is not enabled
+            if not (USE_FLASHINFER_MXFP4_MOE or USE_FLASHINFER_MXFP4_BF16_MOE):
+                self.enable_triton_kernel_moe = True
+
             self.disable_hybrid_swa_memory = True

         quantization_config = getattr(
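The net effect of this server_args change is a small backend-selection rule for GPT-OSS models; summarized as a standalone function (illustrative only, with hypothetical names):

def select_moe_backend(model_arch: str, flashinfer_mxfp4: bool, flashinfer_mxfp4_bf16: bool) -> str:
    # Mirrors the logic above: GPT-OSS models default to the Triton-kernel MoE
    # unless one of the FlashInfer MXFP4 flags is set.
    if model_arch == "GptOssForCausalLM":
        if flashinfer_mxfp4 or flashinfer_mxfp4_bf16:
            return "flashinfer_mxfp4"
        return "triton_kernel"
    return "default"

print(select_moe_backend("GptOssForCausalLM", True, False))  # flashinfer_mxfp4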