change / sglang · Commits

Commit a60f88b5 (unverified), authored Aug 08, 2025 by Trevor Morris, committed by GitHub on Aug 08, 2025
Add unit test for flashinfer fp4 moe (#8330)

Co-authored-by: Yineng Zhang <me@zhyncs.com>

Parent: 591c232f
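The changed file is a standard pytest suite, so individual cases can be selected with pytest's usual filters. The snippet below is not part of the commit; it is a minimal sketch of running only the new flashinfer test, assuming a CUDA machine with sglang, sgl_kernel, and flashinfer installed:

# Not part of the commit: one way to run only the new flashinfer fp4 MoE test.
import pytest

exit_code = pytest.main(
    ["python/sglang/test/test_fp4_moe.py", "-k", "test_flashinfer_fp4_moe_no_graph", "-v"]
)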
Showing 1 changed file with 118 additions and 36 deletions.
python/sglang/test/test_fp4_moe.py (+118, -36)
 # SPDX-License-Identifier: Apache-2.0
+from typing import Callable
+
 import pytest
 import torch
+from flashinfer.fused_moe import cutlass_fused_moe as flashinfer_cutlass_fused_moe
 from sgl_kernel import scaled_fp4_quant
 
 from sglang.srt.layers.activation import SiluAndMul
...
@@ -111,15 +114,16 @@ def torch_moe(a, w1, w2, score, topk, expert_map):
     ).sum(dim=1)
 
 
-@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
-@pytest.mark.parametrize("e", [40, 64, 256])
-@pytest.mark.parametrize("topk", [1, 6, 8])
-@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
-@torch.inference_mode()
-def test_cutlass_fp4_moe_no_graph(
-    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
+def check_moe(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+    moe_impl: Callable,
+    flip_w13: bool,
 ):
     torch.manual_seed(7)
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
...
@@ -167,38 +171,18 @@ def test_cutlass_fp4_moe_no_graph(
     a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
     a2_gs = torch.ones((e,), device="cuda", dtype=torch.float32)
-    # strides for the cutlass moe_fp4 kernel
-    ab_strides_13 = torch.full((e,), w1_q.shape[2] * 2, dtype=torch.int64, device=w1_q.device)
-    c_strides_13 = torch.full((e,), w1_q.shape[1], dtype=torch.int64, device=w1_q.device)
-    ab_strides_2 = torch.full((e,), w2_q.shape[2] * 2, dtype=torch.int64, device=w2_q.device)
-    c_strides_2 = torch.full((e,), w2_q.shape[1], dtype=torch.int64, device=w2_q.device)
-    params = CutlassMoEParams(
-        CutlassMoEType.BlockscaledFP4,
-        device=a.device,
-        num_experts=e,
-        intermediate_size_per_partition=n,  # n
-        hidden_size=k,
-    )  # k
-    cutlass_output = cutlass_moe_fp4(
+    test_output = moe_impl(
         a=a,
-        a1_gscale=a1_gs,
-        w1_fp4=w1_q,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        w1_q=w1_q,
+        w2_q=w2_q,
+        a1_gs=a1_gs,
         w1_blockscale=w1_blockscale,
         w1_alphas=(1 / w1_gs),
-        a2_gscale=a2_gs,
-        w2_fp4=w2_q,
+        a2_gs=a2_gs,
         w2_blockscale=w2_blockscale,
         w2_alphas=(1 / w2_gs),
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        params=params,
-        apply_router_weight_on_input=False,
     )
 
     # Reference check:
...
@@ -237,10 +221,108 @@ def test_cutlass_fp4_moe_no_graph(
         block_size=quant_blocksize,
     )
 
+    if flip_w13:
+        dim = -2
+        size = w1_d.size(dim)
+        assert size % 2 == 0, f"Expected even size in dim {dim}, got {size}"
+        half = size // 2
+        # Reorder weight
+        w1, w3 = w1_d.split(half, dim=dim)
+        w1_d = torch.cat([w3, w1], dim=dim).contiguous()
+
     torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None)
 
-    torch.testing.assert_close(torch_output, cutlass_output, atol=1e-1, rtol=1e-1)
+    torch.testing.assert_close(torch_output, test_output, atol=1e-1, rtol=1e-1)
+
+
+@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
+@pytest.mark.parametrize("e", [40, 64, 256])
+@pytest.mark.parametrize("topk", [1, 6, 8])
+@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
+@torch.inference_mode()
+def test_cutlass_fp4_moe_no_graph(
+    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
+):
+    def cutlass_moe_impl(
+        a,
+        topk_weights,
+        topk_ids,
+        w1_q,
+        w2_q,
+        a1_gs,
+        w1_blockscale,
+        w1_alphas,
+        a2_gs,
+        w2_blockscale,
+        w2_alphas,
+    ):
+        params = CutlassMoEParams(
+            CutlassMoEType.BlockscaledFP4,
+            device=a.device,
+            num_experts=e,
+            intermediate_size_per_partition=n,  # n
+            hidden_size=k,
+        )  # k
+        return cutlass_moe_fp4(
+            a=a,
+            a1_gscale=a1_gs,
+            w1_fp4=w1_q,
+            w1_blockscale=w1_blockscale,
+            w1_alphas=w1_alphas,
+            a2_gscale=a2_gs,
+            w2_fp4=w2_q,
+            w2_blockscale=w2_blockscale,
+            w2_alphas=w2_alphas,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            params=params,
+            apply_router_weight_on_input=False,
+        )
+
+    check_moe(m, n, k, e, topk, dtype, cutlass_moe_impl, flip_w13=False)
+
+
+@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
+@pytest.mark.parametrize("e", [40, 64, 256])
+@pytest.mark.parametrize("topk", [1, 6, 8])
+@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
+@torch.inference_mode()
+def test_flashinfer_fp4_moe_no_graph(
+    m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
+):
+    def flashinfer_moe_impl(
+        a,
+        topk_weights,
+        topk_ids,
+        w1_q,
+        w2_q,
+        a1_gs,
+        w1_blockscale,
+        w1_alphas,
+        a2_gs,
+        w2_blockscale,
+        w2_alphas,
+    ):
+        return flashinfer_cutlass_fused_moe(
+            a,
+            topk_ids.to(torch.int),
+            topk_weights,
+            w1_q.view(torch.long),
+            w2_q.view(torch.long),
+            a.dtype,
+            quant_scales=[
+                a1_gs,
+                w1_blockscale.view(torch.int32),
+                w1_alphas,
+                a2_gs,
+                w2_blockscale.view(torch.int32),
+                w2_alphas,
+            ],
+        )[0]
+
+    check_moe(m, n, k, e, topk, dtype, flashinfer_moe_impl, flip_w13=True)
 
 
 if __name__ == "__main__":
     test_cutlass_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
+    test_flashinfer_fp4_moe_no_graph(224, 1024, 1024, 256, 8, torch.half)
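In check_moe above, flip_w13=True (used only for the flashinfer path) swaps the two halves of the dequantized fused gate/up weight along its intermediate dimension before the torch reference is computed, presumably to match the half ordering that the flashinfer kernel assumes for the fused w13 weight. Below is a self-contained, CPU-only sketch of that reordering; it is not part of the commit and the shapes are hypothetical:

# Not part of the commit: illustrates the flip_w13 reordering used in check_moe.
import torch

e, n, k = 2, 4, 8                                    # hypothetical tiny shapes
w1_d = torch.arange(e * 2 * n * k, dtype=torch.float32).reshape(e, 2 * n, k)

dim = -2                                             # the fused gate/up (2*n) dimension
half = w1_d.size(dim) // 2
w1, w3 = w1_d.split(half, dim=dim)                   # first half, second half
flipped = torch.cat([w3, w1], dim=dim).contiguous()  # halves trade places

# The second half now leads; applying the flip twice restores the original tensor.
assert torch.equal(flipped[:, :half], w1_d[:, half:])
first, second = flipped.split(half, dim=dim)
assert torch.equal(torch.cat([second, first], dim=dim), w1_d)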