change / sglang · Commits

Unverified commit 14127804, authored Nov 05, 2025 by Kaixi Hou, committed via GitHub on Nov 05, 2025.
[NVIDIA] Fix unit test of MoE and add it to nightly ci (#12709)
parent 82f39dc1

Showing 2 changed files with 5 additions and 245 deletions (+5 −245):

    test/srt/run_suite.py    (+3 −0)
    test/srt/test_fp4_moe.py (+2 −245)
test/srt/run_suite.py

@@ -215,6 +215,9 @@ suites = {
         TestFile("batch_invariant/test_batch_invariant_ops.py", 10),
         TestFile("test_deepseek_v3_deterministic.py", 240),
     ],
+    "nightly-4-gpu-b200": [
+        TestFile("test_fp4_moe.py", 300),
+    ],
     "nightly-8-gpu": [],
     "__not_in_ci__": [
         TestFile("ascend/test_ascend_w8a8_quantization.py"),
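The new nightly-4-gpu-b200 suite registers test_fp4_moe.py with a budget of 300, following the pattern of the neighboring entries (the second TestFile argument is presumably the estimated runtime in seconds). For context, a minimal sketch of how a suite entry like this gets driven, under the assumption that run_suite.py simply iterates the selected suite's files; the names and runner logic below are illustrative, not the real code:

import subprocess
from dataclasses import dataclass

@dataclass
class TestFile:  # illustrative stand-in for the class defined in run_suite.py
    name: str
    estimated_time: float = 60  # assumption: estimated runtime in seconds

def run_suite(suites: dict, suite_name: str) -> None:
    # Run each registered test file in sequence; the real runner layers
    # timeouts, retries, and reporting on top of something like this.
    for tf in suites[suite_name]:
        subprocess.run(["python3", tf.name], check=True)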
python/sglang/test/test_fp4_moe.py → test/srt/test_fp4_moe.py (renamed)
@@ -5,10 +5,9 @@ import pytest
 import torch
 from flashinfer import fp4_quantize
 from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
-from sgl_kernel import scaled_fp4_grouped_quant, scaled_fp4_quant
+from sgl_kernel import scaled_fp4_grouped_quant, scaled_fp4_quant, silu_and_mul
 from torch.nn import functional as F

-from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams, CutlassMoEType
 from sglang.srt.layers.moe.flashinfer_cutedsl_moe import flashinfer_cutedsl_moe_masked
@@ -140,7 +139,7 @@ def torch_moe(a, w1, w2, score, topk, expert_map):
     for i in range(w1.shape[0]):
         mask = topk_ids == i
         if mask.sum():
-            out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
+            out[mask] = silu_and_mul(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(
                 0, 1
             )
     return (
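The two changes above swap the Python-level SiluAndMul module from sglang.srt.layers.activation for the fused silu_and_mul op from sgl_kernel. Both compute the standard SiLU-gated product over a tensor whose last dimension packs the gate and up halves; a minimal PyTorch reference of that semantic, as a sketch for readers rather than the kernel implementation:

import torch
import torch.nn.functional as F

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x packs [gate | up] along the last dimension.
    gate, up = x.chunk(2, dim=-1)
    # SiLU(gate) * up, the SwiGLU-style activation used in the MoE MLP.
    return F.silu(gate) * up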
@@ -451,248 +450,6 @@ def test_flashinfer_fp4_moe_no_graph(
     check_moe(m, n, k, e, topk, dtype, flashinfer_moe_impl, flip_w13=True)
-
-
-@pytest.mark.parametrize("bs, hidden_dim, inter_dim", [(2, 128, 256), (16, 128, 512)])
-@pytest.mark.parametrize("topk", [1, 2, 4])
-@torch.inference_mode()
-def test_flashinfer_cutedsl_moe_masked(
-    bs: int, hidden_dim: int, inter_dim: int, topk: int
-):
-    torch.manual_seed(42)
-    device = "cuda"
-    dtype = torch.bfloat16
-    num_experts = 8
-    hidden_states = (
-        torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device) / 5.0
-    )
-    w1 = (
-        torch.randn(
-            num_experts, 2 * inter_dim, hidden_dim, dtype=torch.bfloat16, device=device
-        )
-        / 10.0
-    )
-    w2 = (
-        torch.randn(
-            num_experts, hidden_dim, inter_dim, dtype=torch.bfloat16, device=device
-        )
-        / 10.0
-    )
-    router_logits = torch.randn(bs, num_experts, dtype=torch.float32)
-
-    hidden_states_expanded = (
-        hidden_states.view(bs, -1, hidden_dim)
-        .repeat(1, topk, 1)
-        .reshape(-1, hidden_dim)
-    )
-    hidden_states_3d, masked_m, topk_idx, routing_weights = prepare_inputs(
-        hidden_states_expanded, router_logits, num_experts, topk
-    )
-
-    w1_amax = w1.abs().amax(dim=(1, 2)).to(torch.float32).to(w1.device)
-    w2_amax = w2.abs().amax(dim=(1, 2)).to(torch.float32).to(w2.device)
-    input_global_scale = torch.ones(
-        (num_experts,), dtype=torch.float32, device=hidden_states.device
-    )
-    w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
-    w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
-    a2_global_scale = torch.ones(
-        (num_experts,), dtype=torch.float32, device=hidden_states.device
-    )
-    # assume intermediate scale is 1.0
-    w1_fp4, w1_blockscale = scaled_fp4_grouped_quant(
-        w1,
-        w1_global_scale,
-        torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim,
-    )
-    w2_fp4, w2_blockscale = scaled_fp4_grouped_quant(
-        w2,
-        w2_global_scale,
-        torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim,
-    )
-    w1_alpha = 1.0 / (input_global_scale * w1_global_scale)
-    w2_alpha = 1.0 / (a2_global_scale * w2_global_scale)
-
-    out = flashinfer_cutedsl_moe_masked(
-        hidden_states_3d.to(hidden_states.device),
-        input_global_scale,
-        w1_fp4.permute(2, 0, 1),
-        w1_blockscale,
-        w1_alpha,
-        w2_fp4.permute(2, 0, 1),
-        a2_global_scale,
-        w2_blockscale,
-        w2_alpha,
-        masked_m.to(hidden_states.device),
-    )
-
-    # reference
-    a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, input_global_scale)
-    a_in_dtype = dequantize_nvfp4_to_dtype(
-        a_fp4,
-        a_scale_interleaved,
-        input_global_scale,
-        dtype=hidden_states.dtype,
-        device=hidden_states.device,
-        block_size=16,
-    )
-    w1_d = torch.empty(
-        (num_experts, 2 * inter_dim, hidden_dim), device=w1.device, dtype=w1.dtype
-    )
-    w2_d = torch.empty(
-        (num_experts, hidden_dim, inter_dim), device=w2.device, dtype=w2.dtype
-    )
-
-    for idx in range(0, num_experts):
-        w1_fp4_sliced, w1_blockscale_sliced = fp4_quantize(
-            w1[idx], w1_global_scale[idx]
-        )
-        w2_fp4_sliced, w2_blockscale_sliced = fp4_quantize(
-            w2[idx], w2_global_scale[idx]
-        )
-        w1_d[idx] = dequantize_nvfp4_to_dtype(
-            w1_fp4_sliced,
-            w1_blockscale_sliced,
-            w1_global_scale[idx],
-            dtype=w1.dtype,
-            device=w1.device,
-            block_size=16,
-        )
-        w2_d[idx] = dequantize_nvfp4_to_dtype(
-            w2_fp4_sliced,
-            w2_blockscale_sliced,
-            w2_global_scale[idx],
-            dtype=w2.dtype,
-            device=w2.device,
-            block_size=16,
-        )
-
-    ref_output = torch_moe_nvfp4(
-        a_in_dtype,
-        w1_d,
-        w2_d,
-        topk,
-        routing_weights.to(a_in_dtype.device),
-        topk_idx.to(a_in_dtype.device),
-    )
-
-    out_weighted = torch.zeros_like(ref_output, device=out.device, dtype=out.dtype)
-    positions = torch.nonzero(masked_m[topk_idx], as_tuple=False)
-    rows, cols = positions[:, 0], positions[:, 1]
-    experts = topk_idx[rows, cols]
-    for i in range(num_experts):
-        mask = experts == i
-        if mask.any():
-            idx = torch.nonzero(mask, as_tuple=False).squeeze(-1)
-            r, c = rows[idx], cols[idx]
-            out_weighted[r] += out[i, : len(r), :] * routing_weights[r, c].to(
-                out.device
-            ).unsqueeze(-1)
-
-    torch.testing.assert_close(
-        out_weighted.cpu(), ref_output.cpu(), atol=5e-2, rtol=5e-2
-    )
-
-
-@pytest.mark.parametrize(
-    "bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)]
-)
-@torch.inference_mode()
-def test_grouped_gemm_nt_masked(
-    bs: int, hidden_dim: int, inter_dim: int, topk: int
-) -> None:
-    torch.manual_seed(42)
-    B = bs
-    D = hidden_dim
-    N = inter_dim
-    num_experts = 8
-    hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda")
-    weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda")
-    router_logits = torch.randn(B, num_experts, dtype=torch.float32)
-
-    hidden_states_expanded = (
-        hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
-    )
-    hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs(
-        hidden_states_expanded, router_logits, num_experts, topk
-    )
-
-    # reference
-    out = torch.zeros(
-        (B * topk, weights.shape[1]), dtype=weights.dtype, device=weights.device
-    )
-    for i in range(num_experts):
-        mask = topk_idx.view(-1) == i
-        if mask.sum():
-            lhs = hidden_states_expanded[mask]
-            rhs = weights[i]
-            a_amax = lhs.abs().max().to(torch.float32).to(hidden_states.device)
-            b_amax = rhs.abs().amax().to(torch.float32).to(weights.device)
-            a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
-            b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
-            lhsq, lhsq_sf = fp4_quantize(
-                lhs,
-                a_gs,
-            )
-            rhsq, rhsq_sf = fp4_quantize(
-                rhs,
-                b_gs,
-            )
-            lhs_in_dtype = dequantize_nvfp4_to_dtype(
-                lhsq,
-                lhsq_sf,
-                a_gs,
-                dtype=hidden_states.dtype,
-                device=hidden_states.device,
-                block_size=16,
-            )
-            rhs_in_dtype = dequantize_nvfp4_to_dtype(
-                rhsq,
-                rhsq_sf,
-                b_gs,
-                dtype=hidden_states.dtype,
-                device=hidden_states.device,
-                block_size=16,
-            )
-            out[mask] = lhs_in_dtype @ rhs_in_dtype.t()
-
-    a_amax = (
-        hidden_states_3d.abs()
-        .amax(dim=(1, 2))
-        .to(torch.float32)
-        .to(hidden_states.device)
-    )
-    b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device)
-    a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
-    b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
-    out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked(
-        hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m
-    )
-
-    # re-pack out into [num_experts, max_m, n]
-    out_ref = torch.zeros(
-        (num_experts, max(masked_m), weights.shape[1]), dtype=out.dtype
-    )
-    expert_slot = [0] * num_experts
-    for i, expert_id in enumerate(topk_idx.view(-1).tolist()):
-        out_ref[expert_id, expert_slot[expert_id], :] = out[i]
-        expert_slot[expert_id] += 1
-
-    # Note: compare only the masked positions, since cutedsl may write NaN
-    # into the unmasked positions.
-    for i in range(num_experts):
-        torch.testing.assert_close(
-            out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]],
-            out_ref.to(out_flashinfer.device)[i, : masked_m[i]],
-            atol=1e-1,
-            rtol=5e-2,
-        )


 if __name__ == "__main__":
     test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
     test_flashinfer_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
-    test_flashinfer_cutedsl_moe_masked(16, 128, 512, 4)
-    test_grouped_gemm_nt_masked(16, 128, 512, 4)
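The deleted tests (like the ones that remain) build NVFP4 global scales as FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax. A hedged sketch of that computation for readers outside the file; the constant values below are the standard e4m3/e2m1 maxima and are assumed to match the definitions earlier in test_fp4_moe.py, which this diff does not show:

import torch

FLOAT8_E4M3_MAX = 448.0  # largest finite float8_e4m3 value (assumed to match the file)
FLOAT4_E2M1_MAX = 6.0    # largest finite fp4 e2m1 value (assumed to match the file)

def nvfp4_global_scale(w: torch.Tensor) -> torch.Tensor:
    # Per-expert absolute max over the weight dims, as in the tests above.
    amax = w.abs().amax(dim=(1, 2)).to(torch.float32)
    # A dequantized value is fp4_code * block_scale, so its magnitude tops out at
    # FLOAT4_E2M1_MAX * FLOAT8_E4M3_MAX; this scale maps each expert's amax onto
    # that limit so the full quantized range is used.
    return FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / amax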