Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6e8d8c4a
Unverified
Commit
6e8d8c4a
authored
Aug 01, 2025
by
Wentao Ye
Committed by
GitHub
Aug 02, 2025
Browse files
[Test] Add Unit Test for Batched DeepGEMM (#21559)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
8d524ce7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
107 additions
and
8 deletions
+107
-8
tests/kernels/moe/test_batched_deepgemm.py
tests/kernels/moe/test_batched_deepgemm.py
+103
-0
tests/kernels/moe/test_deepgemm.py
tests/kernels/moe/test_deepgemm.py
+2
-6
vllm/utils/deep_gemm.py
vllm/utils/deep_gemm.py
+2
-2
No files found.
tests/kernels/moe/test_batched_deepgemm.py
0 → 100644
View file @
6e8d8c4a
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
pytest
import
torch
from
vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe
import
(
BatchedDeepGemmExperts
)
from
vllm.model_executor.layers.fused_moe.fused_batched_moe
import
(
BatchedPrepareAndFinalize
,
BatchedTritonExperts
)
from
vllm.model_executor.layers.fused_moe.modular_kernel
import
(
FusedMoEModularKernel
)
from
vllm.utils.deep_gemm
import
calc_diff
,
is_deep_gemm_supported
from
.test_deepgemm
import
make_block_quant_fp8_weights
# Quantization block shape used by this test: it is handed to
# make_block_quant_fp8_weights() and passed as ``block_shape`` to both the
# Triton and the DeepGEMM batched expert implementations.
BLOCK_SIZE = [128, 128]
@pytest.mark.skipif(not is_deep_gemm_supported(),
                    reason="Requires deep_gemm kernels")
@pytest.mark.parametrize("E", [16, 32])  # number of experts
@pytest.mark.parametrize("T", [256, 512])  # tokens per expert
@pytest.mark.parametrize("K", [128, 256])  # hidden dim
@pytest.mark.parametrize("N", [512, 1024])  # intermediate dim per expert
@pytest.mark.parametrize("topk", [2, 4])
def test_batched_deepgemm_vs_triton(E: int, T: int, K: int, N: int, topk: int,
                                    monkeypatch):
    """Check BatchedDeepGemmExperts against the BatchedTritonExperts reference.

    Both expert implementations are wrapped in a ``FusedMoEModularKernel``
    that shares one ``BatchedPrepareAndFinalize`` instance, are fed the same
    block-quantized FP8 weights, activations and routing, and their outputs
    must satisfy ``calc_diff(out_deepgemm, out_triton) < 1e-3``.

    Args:
        E: Number of experts.
        T: Multiplier used to derive the total token count (``M = E * T``).
        K: Hidden dimension of the activations.
        N: Intermediate dimension per expert.
        topk: Number of experts selected per token.
        monkeypatch: pytest fixture, used to set ``VLLM_USE_DEEP_GEMM=1``.
    """
    monkeypatch.setenv("VLLM_USE_DEEP_GEMM", "1")

    device = "cuda"
    w1, w2, w1_s, w2_s = make_block_quant_fp8_weights(E, N, K, BLOCK_SIZE)

    # Activations: scaled-down gaussian noise, clamped in place to the
    # representable range of float8_e4m3fn.
    total_tokens = E * T
    fp8_limits = torch.finfo(torch.float8_e4m3fn)
    hidden = torch.randn(total_tokens, K, device=device,
                         dtype=torch.bfloat16) / 10.0
    hidden.clamp_(fp8_limits.min, fp8_limits.max)

    # Routing: random logits -> top-k expert ids, with the selected logits
    # normalized by a softmax to form the routing weights.
    router_logits = torch.randn(total_tokens,
                                E,
                                device=device,
                                dtype=torch.float32)
    selected_logits, topk_ids = torch.topk(router_logits, k=topk, dim=-1)
    topk_weights = torch.nn.functional.softmax(selected_logits, dim=-1)

    # Size the per-expert batch as the next power of two of the busiest
    # expert's token count.
    tokens_per_expert = torch.bincount(topk_ids.flatten(), minlength=E)
    busiest = int(tokens_per_expert.max().item())
    max_num_tokens = 1 << (busiest - 1).bit_length()

    prep_finalize = BatchedPrepareAndFinalize(
        max_num_tokens=max_num_tokens,
        num_local_experts=E,
        num_dispatchers=1,
        rank=0,
    )

    def run_experts(experts):
        # Wrap ``experts`` in a modular kernel and run it on the shared inputs.
        kernel = FusedMoEModularKernel(prep_finalize, experts)
        return kernel(
            hidden_states=hidden,
            w1=w1,
            w2=w2,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=False,
            w1_scale=w1_s,
            w2_scale=w2_s,
            global_num_experts=E,
        )

    # triton (reference)
    out_triton = run_experts(
        BatchedTritonExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            use_fp8_w8a8=True,
            per_act_token_quant=False,
            block_shape=BLOCK_SIZE,
        ))

    # deepgemm
    out_deepgemm = run_experts(
        BatchedDeepGemmExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=1,
            block_shape=BLOCK_SIZE,
            per_act_token_quant=False,
        ))

    diff = calc_diff(out_deepgemm, out_triton)
    assert diff < 1e-3, f"Output diff too large: {diff}"
tests/kernels/moe/test_deepgemm.py
View file @
6e8d8c4a
...
...
@@ -20,11 +20,6 @@ from vllm.utils.deep_gemm import (calc_diff, is_deep_gemm_supported,
BLOCK_SIZE
=
[
128
,
128
]
requires_deep_gemm
=
pytest
.
mark
.
skipif
(
not
is_deep_gemm_supported
(),
reason
=
"Requires deep_gemm kernels"
,
)
def
make_block_quant_fp8_weights
(
e
:
int
,
...
...
@@ -152,7 +147,8 @@ NUM_EXPERTS = [32]
@
pytest
.
mark
.
parametrize
(
"mnk"
,
MNKs
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOPKS
)
@
pytest
.
mark
.
parametrize
(
"num_experts"
,
NUM_EXPERTS
)
@
requires_deep_gemm
@
pytest
.
mark
.
skipif
(
not
is_deep_gemm_supported
(),
reason
=
"Requires deep_gemm kernels"
)
def
test_deepgemm_vs_triton
(
mnk
,
topk
,
num_experts
,
monkeypatch
):
with
monkeypatch
.
context
()
as
m
:
...
...
vllm/utils/deep_gemm.py
View file @
6e8d8c4a
...
...
@@ -23,10 +23,10 @@ def is_deep_gemm_supported() -> bool:
"""Return ``True`` if DeepGEMM is supported on the current platform.
Currently, only Hopper and Blackwell GPUs are supported.
"""
supported_arch
=
current_platform
.
is_cuda
()
and
(
is_
supported_arch
=
current_platform
.
is_cuda
()
and
(
current_platform
.
is_device_capability
(
90
)
or
current_platform
.
is_device_capability
(
100
))
return
has_deep_gemm
()
and
supported_arch
return
has_deep_gemm
()
and
is_
supported_arch
@
functools
.
cache
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment