sglang / Commits / 7130a7ce

Unverified commit 7130a7ce, authored Mar 12, 2025 by Xiaoyu Zhang, committed by GitHub Mar 11, 2025
Parent: 8f1f614e

refine sgl_moe_align_block_size_benchmark (#4327)

Showing 1 changed file with 77 additions and 30 deletions:
sgl-kernel/benchmark/bench_moe_align_block_size.py (+77 / -30)
benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py → sgl-kernel/benchmark/bench_moe_align_block_size.py
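In brief: the benchmark moves from benchmark/kernels/fused_moe_triton/ into sgl-kernel/benchmark/, the sweep is re-parameterized from (batch_size, seq_len) to (num_tokens, num_experts, topk), and a third provider backed by vLLM's _custom_ops.moe_align_block_size is added alongside the SGL CUDA kernel and the Triton reference.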
@@ -4,7 +4,8 @@ import itertools
 import torch
 import triton
 import triton.language as tl
-from sgl_kernel import moe_align_block_size
+from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
+from vllm import _custom_ops as ops

 USE_RANDOM_PERM = False
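For context (not part of the diff): moe_align_block_size groups the flattened (token, expert) assignments in topk_ids by expert and pads each expert's segment up to a multiple of block_size, so the fused MoE kernel can process fixed-size blocks that each touch a single expert. A minimal PyTorch sketch of the padded-count output, assuming this reading of the kernel's contract (ref_num_tokens_post_pad is a hypothetical helper, illustration only):

    import torch

    def ref_num_tokens_post_pad(topk_ids: torch.Tensor, num_experts: int, block_size: int) -> int:
        # Count assignments per expert, round each count up to a multiple
        # of block_size, and sum the padded counts.
        counts = torch.bincount(topk_ids.flatten().long(), minlength=num_experts)
        padded = (counts + block_size - 1) // block_size * block_size
        return int(padded.sum().item())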
@@ -139,15 +140,11 @@ def moe_align_block_size_triton(
 )


-def calculate_diff(batch_size, seq_len):
-    num_experts = 256
-    block_size = 128
-    topk = 8
+def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
     topk_ids = torch.stack(
         [
             torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
-            for _ in range(batch_size * seq_len)
+            for _ in range(num_tokens)
         ]
     )
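The switch from batch_size * seq_len to a single num_tokens parameter flattens the sweep space. Note also why randperm is used here: randperm(...)[:topk] gives each token topk distinct expert ids, the way real top-k routing does, whereas plain torch.randint can assign the same expert to a token twice. A small illustration (assumes a CUDA device):

    import torch

    topk_ids = torch.stack(
        [torch.randperm(256, dtype=torch.int32, device="cuda")[:8] for _ in range(4)]
    )
    print(topk_ids.shape)                # torch.Size([4, 8])
    print(topk_ids[0].unique().numel())  # 8: no duplicate experts within a row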
@@ -175,8 +172,13 @@ def calculate_diff(batch_size, seq_len):
     expert_ids_triton = torch.zeros_like(expert_ids_cuda)
     num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)

-    # compare the performance of cuda and triton implementation
-    moe_align_block_size(
+    sorted_ids_vllm = torch.empty_like(sorted_ids_cuda)
+    sorted_ids_vllm.fill_(topk_ids.numel())
+    expert_ids_vllm = torch.zeros_like(expert_ids_cuda)
+    num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_cuda)
+
+    # compare the performance of cuda, triton and vllm implementation
+    sgl_moe_align_block_size(
         topk_ids,
         num_experts,
         block_size,
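The fill value is worth a note: topk_ids.numel() is one past the largest valid assignment index, so after the kernel writes, any slot in sorted_ids_vllm still holding that sentinel is padding rather than a real entry. A tiny sketch of the idea (hypothetical sizes, assumes a CUDA device):

    import torch

    topk_ids = torch.randint(0, 4, (3, 2), dtype=torch.int32, device="cuda")
    sorted_ids = torch.empty(16, dtype=torch.int32, device="cuda")
    sorted_ids.fill_(topk_ids.numel())  # 6; any slot still equal to 6 is padding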
@@ -194,22 +196,43 @@ def calculate_diff(batch_size, seq_len):
         expert_ids_triton,
         num_tokens_post_pad_triton,
     )

+    ops.moe_align_block_size(
+        topk_ids,
+        num_experts,
+        block_size,
+        sorted_ids_vllm,
+        expert_ids_vllm,
+        num_tokens_post_pad_vllm,
+    )
+
     if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
         num_tokens_post_pad_cuda, num_tokens_post_pad_triton
     ):
-        print("✅ CUDA and Triton implementations match")
+        print("✅ SGL and Triton implementations match")
     else:
-        print("❌ CUDA and Triton implementations do not match")
-        print("CUDA expert_ids:", expert_ids_cuda)
+        print("❌ SGL and Triton implementations do not match")
+        print("SGL expert_ids:", expert_ids_cuda)
         print("Triton expert_ids:", expert_ids_triton)
-        print("CUDA num_tokens_post_pad:", num_tokens_post_pad_cuda)
+        print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
         print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)

+    if torch.allclose(expert_ids_cuda, expert_ids_vllm) and torch.allclose(
+        num_tokens_post_pad_cuda, num_tokens_post_pad_vllm
+    ):
+        print("✅ SGL and VLLM implementations match")
+    else:
+        print("❌ SGL and VLLM implementations do not match")
+        print("SGL expert_ids:", expert_ids_cuda)
+        print("VLLM expert_ids:", expert_ids_vllm)
+        print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
+        print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
+

-batch_size_range = [2**i for i in range(0, 8)]
-seq_length_range = [2**i for i in range(0, 16)]
-configs = list(itertools.product(batch_size_range, seq_length_range))
+num_tokens_range = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
+num_experts_range = [32, 64, 128, 256]
+topk_range = [2, 4, 8]
+configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))


 def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
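With the new ranges, itertools.product expands to 10 × 4 × 3 = 120 (num_tokens, num_experts, topk) tuples in place of the old 8 × 16 (batch_size, seq_len) grid. A truncated illustration of what the sweep produces:

    import itertools

    num_tokens_range = [16, 32]   # truncated ranges for illustration
    num_experts_range = [32, 64]
    topk_range = [2, 4]
    configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
    print(len(configs))  # 8
    print(configs[0])    # (16, 32, 2)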
@@ -223,29 +246,27 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
 @triton.testing.perf_report(
     triton.testing.Benchmark(
-        x_names=["batch_size", "seq_len"],
-        x_vals=[list(_) for _ in configs],
+        x_names=["num_tokens", "num_experts", "topk"],
+        x_vals=configs,
         line_arg="provider",
-        line_vals=["cuda", "triton"],
-        line_names=["CUDA", "Triton"],
-        styles=[("blue", "-"), ("red", "-")],
+        line_vals=["sgl", "triton", "vllm"],
+        line_names=["SGL", "Triton", "VLLM"],
+        styles=[("blue", "-"), ("red", "-"), ("green", "-")],
         ylabel="us",
         plot_name="moe-align-block-size-performance",
         args={},
     )
 )
-def benchmark(batch_size, seq_len, provider):
-    num_experts = 256
+def benchmark(num_tokens, num_experts, topk, provider):
     block_size = 128
-    topk = 8
     if USE_RANDOM_PERM:
-        topk_ids = get_topk_ids(batch_size * seq_len, num_experts, topk)
+        topk_ids = get_topk_ids(num_tokens, num_experts, topk)
     else:
         topk_ids = torch.randint(
             0,
             num_experts,
-            (batch_size * seq_len, topk),
+            (num_tokens, topk),
             dtype=torch.int32,
             device="cuda",
         )
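Because each element of configs is already a tuple in the same order as x_names, it can now be passed straight to x_vals without the old [list(_) for _ in configs] conversion. Conceptually, the decorator calls the function once per config per provider, roughly as in this sketch (an assumption about triton.testing.perf_report's dispatch, not its actual code):

    for num_tokens, num_experts, topk in configs:
        for provider in ["sgl", "triton", "vllm"]:
            benchmark(num_tokens, num_experts, topk, provider)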
@@ -268,9 +289,9 @@ def benchmark(batch_size, seq_len, provider):
     )
     quantiles = [0.5, 0.2, 0.8]
-    if provider == "cuda":
+    if provider == "sgl":
         ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: moe_align_block_size(
+            lambda: sgl_moe_align_block_size(
                 topk_ids,
                 num_experts,
                 block_size,
@@ -282,7 +303,7 @@ def benchmark(batch_size, seq_len, provider):
             ),
             quantiles=quantiles,
         )
-    else:
+    elif provider == "triton":
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: moe_align_block_size_triton(
                 topk_ids,
@@ -294,6 +315,18 @@ def benchmark(batch_size, seq_len, provider):
             ),
             quantiles=quantiles,
         )
+    else:  # vllm
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: ops.moe_align_block_size(
+                topk_ids,
+                num_experts,
+                block_size,
+                sorted_ids.clone(),
+                expert_ids.clone(),
+                num_tokens_post_pad.clone(),
+            ),
+            quantiles=quantiles,
+        )
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
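On the timing pattern used in all three branches: triton.testing.do_bench runs the closure repeatedly and, when quantiles=[0.5, 0.2, 0.8] is given, returns the latencies at those quantiles in milliseconds (median, 20th, 80th percentile); the final return 1000 * ms, ... converts them to microseconds to match ylabel="us". A standalone example (assumes a CUDA device):

    import torch
    import triton

    x = torch.randn(1 << 20, device="cuda")
    ms, min_ms, max_ms = triton.testing.do_bench(lambda: x * 2, quantiles=[0.5, 0.2, 0.8])
    print(f"median={1000 * ms:.1f}us  p20={1000 * min_ms:.1f}us  p80={1000 * max_ms:.1f}us")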
@@ -306,8 +339,22 @@ if __name__ == "__main__":
         default="./configs/benchmark_ops/moe_align_blocks/",
         help="Path to save moe align benchmark results",
     )
+    parser.add_argument(
+        "--num_experts",
+        type=int,
+        default=256,
+        choices=[8, 64, 128, 256],
+        help="Number of experts for benchmark",
+    )
+    parser.add_argument(
+        "--topk",
+        type=int,
+        default=8,
+        choices=[2, 4, 8],
+        help="Top-k value for benchmark",
+    )
     args = parser.parse_args()

-    calculate_diff(batch_size=4, seq_len=1024)
+    calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)

     benchmark.run(print_data=True)
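With the new flags, a run such as python sgl-kernel/benchmark/bench_moe_align_block_size.py --num_experts 128 --topk 4 should first cross-check the SGL, Triton, and vLLM outputs via calculate_diff and then print the latency table for all 120 configurations (benchmark.run(print_data=True)).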