Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
56e544f2
Unverified
Commit
56e544f2
authored
Jul 26, 2025
by
Wentao Ye
Committed by
GitHub
Jul 26, 2025
Browse files
[Refactor] Remove `moe_align_block_size_triton` (#21335)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
97d6c30c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
224 deletions
+6
-224
benchmarks/kernels/benchmark_moe_align_block_size.py
benchmarks/kernels/benchmark_moe_align_block_size.py
+4
-86
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
+2
-138
No files found.
benchmarks/kernels/benchmark_moe_align_block_size.py
View file @
56e544f2
...
...
@@ -5,9 +5,8 @@ import itertools
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.fused_moe.moe_align_block_size
import
(
moe_align_block_size
_triton
,
moe_align_block_size
,
)
from
vllm.triton_utils
import
triton
...
...
@@ -21,60 +20,6 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
)
def
check_correctness
(
num_tokens
,
num_experts
=
256
,
block_size
=
256
,
topk
=
8
):
"""
Verifies vllm vs. Triton
"""
topk_ids
=
get_topk_ids
(
num_tokens
,
num_experts
,
topk
)
# 1. malloc space for triton and vllm
# malloc enough space (max_num_tokens_padded) for the sorted ids
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
sorted_ids_triton
=
torch
.
empty
(
(
max_num_tokens_padded
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
expert_ids_triton
=
torch
.
empty
(
(
max_num_tokens_padded
//
block_size
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
num_tokens_post_pad_triton
=
torch
.
empty
((
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
sorted_ids_vllm
=
torch
.
empty_like
(
sorted_ids_triton
)
expert_ids_vllm
=
torch
.
empty_like
(
expert_ids_triton
)
num_tokens_post_pad_vllm
=
torch
.
empty_like
(
num_tokens_post_pad_triton
)
# 2. run implementations
moe_align_block_size_triton
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids_triton
,
expert_ids_triton
,
num_tokens_post_pad_triton
,
)
ops
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids_vllm
,
expert_ids_vllm
,
num_tokens_post_pad_vllm
,
)
print
(
f
"✅ VLLM implementation works with
{
num_experts
}
experts!"
)
# 3. compare results
if
torch
.
allclose
(
expert_ids_triton
,
expert_ids_vllm
)
and
torch
.
allclose
(
num_tokens_post_pad_triton
,
num_tokens_post_pad_vllm
):
print
(
"✅ Triton and VLLM implementations match."
)
else
:
print
(
"❌ Triton and VLLM implementations DO NOT match."
)
print
(
"Triton expert_ids:"
,
expert_ids_triton
)
print
(
"VLLM expert_ids:"
,
expert_ids_vllm
)
print
(
"Triton num_tokens_post_pad:"
,
num_tokens_post_pad_triton
)
print
(
"VLLM num_tokens_post_pad:"
,
num_tokens_post_pad_vllm
)
# test configurations
num_tokens_range
=
[
1
,
16
,
256
,
4096
]
num_experts_range
=
[
16
,
64
,
224
,
256
,
280
,
512
]
...
...
@@ -87,8 +32,8 @@ configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range
x_names
=
[
"num_tokens"
,
"num_experts"
,
"topk"
],
x_vals
=
configs
,
line_arg
=
"provider"
,
line_vals
=
[
"vllm"
,
"triton"
],
# "triton"
line_names
=
[
"
V
LLM"
,
"Triton"
],
# "Triton"
line_vals
=
[
"vllm"
]
,
line_names
=
[
"
v
LLM"
]
,
plot_name
=
"moe-align-block-size-performance"
,
args
=
{},
)
...
...
@@ -98,36 +43,11 @@ def benchmark(num_tokens, num_experts, topk, provider):
block_size
=
256
topk_ids
=
get_topk_ids
(
num_tokens
,
num_experts
,
topk
)
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
sorted_ids
=
torch
.
empty
((
max_num_tokens_padded
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
max_num_m_blocks
=
max_num_tokens_padded
//
block_size
expert_ids
=
torch
.
empty
((
max_num_m_blocks
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
num_tokens_post_pad
=
torch
.
empty
((
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
"vllm"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
ops
.
moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
.
clone
(),
expert_ids
.
clone
(),
num_tokens_post_pad
.
clone
(),
),
quantiles
=
quantiles
,
)
elif
provider
==
"triton"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
moe_align_block_size_triton
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
.
clone
(),
expert_ids
.
clone
(),
num_tokens_post_pad
.
clone
(),
),
lambda
:
moe_align_block_size
(
topk_ids
,
block_size
,
num_experts
),
quantiles
=
quantiles
,
)
...
...
@@ -151,6 +71,4 @@ if __name__ == "__main__":
)
args
=
parser
.
parse_args
()
print
(
"Running correctness check..."
)
check_correctness
(
num_tokens
=
1024
,
num_experts
=
args
.
num_experts
,
topk
=
args
.
topk
)
benchmark
.
run
(
print_data
=
True
,
show_plots
=
True
)
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
View file @
56e544f2
...
...
@@ -5,144 +5,8 @@ from typing import Optional
import
torch
from
vllm
import
_custom_ops
as
ops
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
cdiv
,
round_up
@
triton
.
jit
def
moe_align_block_size_stage1
(
topk_ids_ptr
,
tokens_cnts_ptr
,
num_experts
:
tl
.
constexpr
,
numel
:
tl
.
constexpr
,
tokens_per_thread
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
start_idx
=
pid
*
tokens_per_thread
off_c
=
(
pid
+
1
)
*
num_experts
for
i
in
range
(
tokens_per_thread
):
if
start_idx
+
i
<
numel
:
idx
=
tl
.
load
(
topk_ids_ptr
+
start_idx
+
i
)
token_cnt
=
tl
.
load
(
tokens_cnts_ptr
+
off_c
+
idx
)
tl
.
store
(
tokens_cnts_ptr
+
off_c
+
idx
,
token_cnt
+
1
)
@
triton
.
jit
def
moe_align_block_size_stage2
(
tokens_cnts_ptr
,
num_experts
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
last_cnt
=
0
for
i
in
range
(
1
,
num_experts
+
1
):
token_cnt
=
tl
.
load
(
tokens_cnts_ptr
+
i
*
num_experts
+
pid
)
last_cnt
=
last_cnt
+
token_cnt
tl
.
store
(
tokens_cnts_ptr
+
i
*
num_experts
+
pid
,
last_cnt
)
@
triton
.
jit
def
moe_align_block_size_stage3
(
total_tokens_post_pad_ptr
,
tokens_cnts_ptr
,
cumsum_ptr
,
num_experts
:
tl
.
constexpr
,
block_size
:
tl
.
constexpr
,
):
last_cumsum
=
0
off_cnt
=
num_experts
*
num_experts
for
i
in
range
(
1
,
num_experts
+
1
):
token_cnt
=
tl
.
load
(
tokens_cnts_ptr
+
off_cnt
+
i
-
1
)
last_cumsum
=
last_cumsum
+
tl
.
cdiv
(
token_cnt
,
block_size
)
*
block_size
tl
.
store
(
cumsum_ptr
+
i
,
last_cumsum
)
tl
.
store
(
total_tokens_post_pad_ptr
,
last_cumsum
)
@
triton
.
jit
def
moe_align_block_size_stage4
(
topk_ids_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
tokens_cnts_ptr
,
cumsum_ptr
,
num_experts
:
tl
.
constexpr
,
block_size
:
tl
.
constexpr
,
numel
:
tl
.
constexpr
,
tokens_per_thread
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
start_idx
=
tl
.
load
(
cumsum_ptr
+
pid
)
end_idx
=
tl
.
load
(
cumsum_ptr
+
pid
+
1
)
for
i
in
range
(
start_idx
,
end_idx
,
block_size
):
tl
.
store
(
expert_ids_ptr
+
i
//
block_size
,
pid
)
start_idx
=
pid
*
tokens_per_thread
off_t
=
pid
*
num_experts
for
i
in
range
(
start_idx
,
tl
.
minimum
(
start_idx
+
tokens_per_thread
,
numel
)):
expert_id
=
tl
.
load
(
topk_ids_ptr
+
i
)
token_cnt
=
tl
.
load
(
tokens_cnts_ptr
+
off_t
+
expert_id
)
rank_post_pad
=
token_cnt
+
tl
.
load
(
cumsum_ptr
+
expert_id
)
tl
.
store
(
sorted_token_ids_ptr
+
rank_post_pad
,
i
)
tl
.
store
(
tokens_cnts_ptr
+
off_t
+
expert_id
,
token_cnt
+
1
)
# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def
moe_align_block_size_triton
(
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
block_size
:
int
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_pad
:
torch
.
Tensor
,
)
->
None
:
numel
=
topk_ids
.
numel
()
grid
=
(
num_experts
,
)
tokens_cnts
=
torch
.
zeros
((
num_experts
+
1
,
num_experts
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
cumsum
=
torch
.
zeros
((
num_experts
+
1
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
tokens_per_thread
=
cdiv
(
numel
,
num_experts
)
sorted_token_ids
.
fill_
(
numel
)
expert_ids
.
zero_
()
moe_align_block_size_stage1
[
grid
](
topk_ids
,
tokens_cnts
,
num_experts
,
numel
,
tokens_per_thread
,
)
moe_align_block_size_stage2
[
grid
](
tokens_cnts
,
num_experts
,
)
moe_align_block_size_stage3
[(
1
,
)](
num_tokens_post_pad
,
tokens_cnts
,
cumsum
,
num_experts
,
block_size
,
)
moe_align_block_size_stage4
[
grid
](
topk_ids
,
sorted_token_ids
,
expert_ids
,
tokens_cnts
,
cumsum
,
num_experts
,
block_size
,
numel
,
tokens_per_thread
,
)
from
vllm.triton_utils
import
triton
from
vllm.utils
import
round_up
def
moe_align_block_size
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment