sglang / Commits / 3c8ac78d

Unverified commit 3c8ac78d, authored Feb 03, 2025 by Xiaoyu Zhang, committed Feb 03, 2025 via GitHub

optimize test_fused_moe style (#3268)

parent 455bfe8d

Showing 1 changed file with 90 additions and 39 deletions.

test/srt/test_fused_moe.py (+90, -39)
import unittest

import torch
import torch.nn.functional as F
from tqdm import tqdm
from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm

from sglang.srt.layers.activation import SiluAndMul
...
...
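Two of the imports are worth flagging: vLLM's fused_moe is aliased to fused_moe_vllm so the fp8 path below can be cross-checked against it, and SiluAndMul is sglang's gated activation op. For reference, a SiluAndMul-style op is equivalent to the following minimal sketch (assuming the conventional layout where the last dimension holds the gate and up halves):

import torch
import torch.nn.functional as F

# SiluAndMul reference semantics: split the last dim in half,
# apply SiLU to the gate half, multiply elementwise by the up half.
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    gate, up = x.chunk(2, dim=-1)
    return F.silu(gate) * up

x = torch.randn(4, 8)
print(silu_and_mul(x).shape)  # torch.Size([4, 4])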
@@ -11,6 +13,37 @@ class TestFusedMOE(unittest.TestCase):
     NUM_EXPERTS = [8, 64]
     TOP_KS = [2, 6]
 
+    @staticmethod
+    def create_random_cuda_tensor(shape, dtype, mean=0, std=0.01):
+        """Create a random CUDA tensor
+
+        Args:
+            shape: Tensor shape
+            dtype: Data type
+            mean: Mean value
+            std: Standard deviation
+
+        Returns:
+            torch.Tensor: Randomly initialized CUDA tensor
+        """
+        return torch.empty(shape, dtype=dtype, device="cuda").normal_(mean, std)
+
+    def get_tolerance(self, dtype):
+        """Get tolerance values for different data types
+
+        Args:
+            dtype: Data type
+
+        Returns:
+            tuple: (relative tolerance, absolute tolerance)
+        """
+        if dtype == torch.float32:
+            return 1e-3, 1e-5
+        elif dtype in [torch.float16, torch.bfloat16]:
+            return 1e-1, 1e-2
+        else:
+            return 1e-2, 1e-2  # Default values for other types
+
     def torch_naive_moe(self, a, w1, w2, score, topk):
         B, D = a.shape
         a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
...
...
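The create_random_cuda_tensor helper added above replaces the scattered torch.randn(...) / 10 calls in the hunks below. Note that the two initializers are not statistically identical: randn(shape) / 10 samples from N(0, 0.1^2), while the helper's defaults sample from N(0, 0.01^2). A minimal CPU-side sketch of the difference:

import torch

# Old style: standard normal scaled down, std ~ 0.1.
old_style = torch.randn((1024, 1024)) / 10
# New helper's default: in-place fill from N(mean=0, std=0.01).
new_style = torch.empty((1024, 1024)).normal_(0, 0.01)
print(old_style.std().item())  # ~ 0.1
print(new_style.std().item())  # ~ 0.01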
@@ -30,23 +63,25 @@ class TestFusedMOE(unittest.TestCase):
         ).sum(dim=1)
 
     def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False):
+        rtol, atol = self.get_tolerance(dtype)
+
         if use_fp8_w8a8:
             # AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
             capability = torch.cuda.get_device_capability()
             if not (capability[0] >= 9 or capability == (8, 9)):
                 return
 
-            a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
-            w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
-            w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
+            a = self.create_random_cuda_tensor((m, k), dtype)
+            w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype)
+            w2 = self.create_random_cuda_tensor((e, k, n), dtype)
             w1 = w1.to(torch.float8_e4m3fn)
             w2 = w2.to(torch.float8_e4m3fn)
-            score = torch.randn((m, e), device="cuda", dtype=dtype)
+            score = self.create_random_cuda_tensor((m, e), dtype)
 
-            w1_scale = torch.randn(e, dtype=torch.float32, device="cuda")
-            w2_scale = torch.randn(e, dtype=torch.float32, device="cuda")
-            a1_scale = torch.randn(1, dtype=torch.float32, device="cuda")
-            a2_scale = torch.randn(1, dtype=torch.float32, device="cuda")
+            w1_scale = self.create_random_cuda_tensor(e, torch.float32)
+            w2_scale = self.create_random_cuda_tensor(e, torch.float32)
+            a1_scale = self.create_random_cuda_tensor(1, torch.float32)
+            a2_scale = self.create_random_cuda_tensor(1, torch.float32)
 
             sglang_output = fused_moe(
                 a,
...
...
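The early return in the fp8 branch keys off CUDA compute capability: float8_e4m3fn (Triton's fp8e4nv) requires sm_89 (Ada) or sm_90 and newer (Hopper). A standalone sketch of the same gate:

import torch

# Skip fp8 work unless the GPU is Ada (8, 9) or Hopper-or-newer (major >= 9).
if torch.cuda.is_available():
    capability = torch.cuda.get_device_capability()  # e.g. (8, 9) on an RTX 4090
    fp8_ok = capability[0] >= 9 or capability == (8, 9)
    print(f"sm_{capability[0]}{capability[1]}: fp8e4nv supported = {fp8_ok}")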
@@ -76,17 +111,19 @@ class TestFusedMOE(unittest.TestCase):
                 a2_scale=a2_scale,
             )
 
-            torch.testing.assert_close(sglang_output, vllm_output, atol=2e-2, rtol=0)
+            torch.testing.assert_close(sglang_output, vllm_output, rtol=rtol, atol=atol)
 
         else:
-            a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
-            w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
-            w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
-            score = torch.randn((m, e), device="cuda", dtype=dtype)
+            a = self.create_random_cuda_tensor((m, k), dtype)
+            w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype)
+            w2 = self.create_random_cuda_tensor((e, k, n), dtype)
+            score = self.create_random_cuda_tensor((m, e), dtype)
 
             triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
             torch_output = self.torch_naive_moe(a, w1, w2, score, topk)
-            torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
+            torch.testing.assert_close(
+                triton_output, torch_output, rtol=rtol, atol=atol
+            )
 
     def test_various_configurations(self):
         m_values = [1, 33, 64, 222, 1024 * 128]
...
...
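The non-fp8 branch validates the Triton kernel against torch_naive_moe, whose body is elided in this view. For orientation only (a hedged sketch, not the elided body), a reference of this shape typically softmaxes the router scores, picks the top-k experts per token, runs each expert's gated MLP, and sums the expert outputs weighted by the router probabilities:

import torch
import torch.nn.functional as F

def naive_moe_reference(a, w1, w2, score, topk):
    # a: (B, D) tokens; w1: (E, 2n, D) gate+up projection; w2: (E, D, n) down projection.
    B, D = a.shape
    a_rep = a.view(B, 1, D).repeat(1, topk, 1).reshape(-1, D)  # (B*topk, D)
    weights, ids = torch.topk(torch.softmax(score.float(), dim=-1), topk)  # (B, topk)
    out = torch.zeros(B * topk, D, dtype=a.dtype, device=a.device)
    for e in range(w1.shape[0]):
        mask = ids.reshape(-1) == e  # tokens routed to expert e
        if mask.any():
            gate, up = (a_rep[mask] @ w1[e].T).chunk(2, dim=-1)
            out[mask] = (F.silu(gate) * up) @ w2[e].T
    return (out.view(B, topk, D) * weights.view(B, topk, 1).to(out.dtype)).sum(dim=1)

With the fp16/bf16 values from get_tolerance (rtol=1e-1, atol=1e-2), the rewritten assert_close passes when |triton_output - torch_output| <= 1e-2 + 1e-1 * |torch_output| holds elementwise, replacing the previous fixed atol=2e-2, rtol=0 with dtype-aware thresholds.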
@@ -95,31 +132,45 @@ class TestFusedMOE(unittest.TestCase):
         dtypes = [torch.float16, torch.bfloat16]
         fp8_modes = [False, True]
 
-        for m in m_values:
-            for n in n_values:
-                for k in k_values:
-                    for e in self.NUM_EXPERTS:
-                        for topk in self.TOP_KS:
-                            for dtype in dtypes:
-                                for use_fp8_w8a8 in fp8_modes:
-                                    with self.subTest(
-                                        m=m,
-                                        n=n,
-                                        k=k,
-                                        e=e,
-                                        topk=topk,
-                                        dtype=dtype,
-                                        fp8=use_fp8_w8a8,
-                                    ):
-                                        self._test_case(
-                                            m,
-                                            n,
-                                            k,
-                                            e,
-                                            topk,
-                                            dtype,
-                                            use_fp8_w8a8=use_fp8_w8a8,
-                                        )
+        # Calculate total number of tests
+        total_tests = (
+            len(m_values)
+            * len(n_values)
+            * len(k_values)
+            * len(self.NUM_EXPERTS)
+            * len(self.TOP_KS)
+            * len(dtypes)
+            * len(fp8_modes)
+        )
+
+        # Create progress bar
+        with tqdm(total=total_tests, desc="Running MoE tests") as pbar:
+            for m in m_values:
+                for n in n_values:
+                    for k in k_values:
+                        for e in self.NUM_EXPERTS:
+                            for topk in self.TOP_KS:
+                                for dtype in dtypes:
+                                    for use_fp8_w8a8 in fp8_modes:
+                                        with self.subTest(
+                                            m=m,
+                                            n=n,
+                                            k=k,
+                                            e=e,
+                                            topk=topk,
+                                            dtype=dtype,
+                                            fp8=use_fp8_w8a8,
+                                        ):
+                                            self._test_case(
+                                                m,
+                                                n,
+                                                k,
+                                                e,
+                                                topk,
+                                                dtype,
+                                                use_fp8_w8a8=use_fp8_w8a8,
+                                            )
+                                            pbar.update(1)
 
 
 if __name__ == "__main__":
...
...
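The rewritten sweep sizes a tqdm progress bar from the product of all swept dimensions and ticks it once per configuration, so long runs report progress instead of appearing hung. A minimal sketch of the pattern with illustrative values (not the test's own):

from tqdm import tqdm

m_values = [1, 33, 64]  # illustrative values only
dtypes = ["float16", "bfloat16"]  # illustrative values only

total_tests = len(m_values) * len(dtypes)
with tqdm(total=total_tests, desc="Running MoE tests") as pbar:
    for m in m_values:
        for dtype in dtypes:
            pbar.update(1)  # one tick per completed configuration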