Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0ec82edd
Unverified
Commit
0ec82edd
authored
Jul 21, 2025
by
Himanshu Jaju
Committed by
GitHub
Jul 21, 2025
Browse files
[perf] Speed up align sum kernels (#21079)
Signed-off-by:
Himanshu Jaju
<
hj@mistral.ai
>
parent
005ae9be
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
60 additions
and
25 deletions
+60
-25
benchmarks/kernels/benchmark_moe_align_block_size.py
benchmarks/kernels/benchmark_moe_align_block_size.py
+2
-5
csrc/moe/moe_align_sum_kernels.cu
csrc/moe/moe_align_sum_kernels.cu
+55
-16
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
+3
-4
No files found.
benchmarks/kernels/benchmark_moe_align_block_size.py
View file @
0ec82edd
...
...
@@ -33,15 +33,13 @@ def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
sorted_ids_triton
=
torch
.
empty
(
(
max_num_tokens_padded
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
sorted_ids_triton
.
fill_
(
topk_ids
.
numel
())
# fill with sentinel value
expert_ids_triton
=
torch
.
zeros
(
expert_ids_triton
=
torch
.
empty
(
(
max_num_tokens_padded
//
block_size
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
num_tokens_post_pad_triton
=
torch
.
empty
((
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
sorted_ids_vllm
=
torch
.
empty_like
(
sorted_ids_triton
)
sorted_ids_vllm
.
fill_
(
topk_ids
.
numel
())
expert_ids_vllm
=
torch
.
zeros_like
(
expert_ids_triton
)
expert_ids_vllm
=
torch
.
empty_like
(
expert_ids_triton
)
num_tokens_post_pad_vllm
=
torch
.
empty_like
(
num_tokens_post_pad_triton
)
# 2. run implementations
...
...
@@ -102,7 +100,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
sorted_ids
=
torch
.
empty
((
max_num_tokens_padded
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
sorted_ids
.
fill_
(
topk_ids
.
numel
())
max_num_m_blocks
=
max_num_tokens_padded
//
block_size
expert_ids
=
torch
.
empty
((
max_num_m_blocks
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
num_tokens_post_pad
=
torch
.
empty
((
1
,),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
...
...
csrc/moe/moe_align_sum_kernels.cu
View file @
0ec82edd
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cub/cub.cuh>
#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>
...
...
@@ -19,9 +20,14 @@ __global__ void moe_align_block_size_kernel(
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
expert_ids
,
int32_t
*
__restrict__
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
padded_num_experts
,
int32_t
experts_per_warp
,
int32_t
block_size
,
size_t
numel
,
int32_t
*
__restrict__
cumsum
)
{
size_t
numel
,
int32_t
*
__restrict__
cumsum
,
int32_t
max_num_tokens_padded
)
{
extern
__shared__
int32_t
shared_counts
[];
// Initialize sorted_token_ids with numel
for
(
size_t
it
=
threadIdx
.
x
;
it
<
max_num_tokens_padded
;
it
+=
blockDim
.
x
)
{
sorted_token_ids
[
it
]
=
numel
;
}
const
int
warp_id
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
my_expert_start
=
warp_id
*
experts_per_warp
;
...
...
@@ -45,18 +51,27 @@ __global__ void moe_align_block_size_kernel(
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
cumsum
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
{
int
expert_count
=
0
;
int
warp_idx
=
(
i
-
1
)
/
experts_per_warp
;
int
expert_offset
=
(
i
-
1
)
%
experts_per_warp
;
expert_count
=
shared_counts
[
warp_idx
*
experts_per_warp
+
expert_offset
];
// Compute prefix sum over token counts per expert
using
BlockScan
=
cub
::
BlockScan
<
int32_t
,
1024
>
;
__shared__
typename
BlockScan
::
TempStorage
temp_storage
;
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
CEILDIV
(
expert_count
,
block_size
)
*
block_size
;
}
*
total_tokens_post_pad
=
cumsum
[
num_experts
];
int
expert_count
=
0
;
int
expert_id
=
threadIdx
.
x
;
if
(
expert_id
<
num_experts
)
{
int
warp_idx
=
expert_id
/
experts_per_warp
;
int
expert_offset
=
expert_id
%
experts_per_warp
;
expert_count
=
shared_counts
[
warp_idx
*
experts_per_warp
+
expert_offset
];
expert_count
=
CEILDIV
(
expert_count
,
block_size
)
*
block_size
;
}
int
cumsum_val
;
BlockScan
(
temp_storage
).
ExclusiveSum
(
expert_count
,
cumsum_val
);
if
(
expert_id
<=
num_experts
)
{
cumsum
[
expert_id
]
=
cumsum_val
;
}
if
(
expert_id
==
num_experts
)
{
*
total_tokens_post_pad
=
cumsum_val
;
}
__syncthreads
();
...
...
@@ -67,6 +82,13 @@ __global__ void moe_align_block_size_kernel(
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
}
}
// Fill remaining expert_ids with 0
const
size_t
fill_start_idx
=
cumsum
[
num_experts
]
/
block_size
+
threadIdx
.
x
;
const
size_t
expert_ids_size
=
CEILDIV
(
max_num_tokens_padded
,
block_size
);
for
(
size_t
i
=
fill_start_idx
;
i
<
expert_ids_size
;
i
+=
blockDim
.
x
)
{
expert_ids
[
i
]
=
0
;
}
}
template
<
typename
scalar_t
>
...
...
@@ -105,7 +127,12 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
const
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
__restrict__
sorted_token_ids
,
int32_t
*
__restrict__
expert_ids
,
int32_t
*
__restrict__
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
block_size
,
size_t
numel
)
{
int32_t
block_size
,
size_t
numel
,
int32_t
max_num_tokens_padded
)
{
// Initialize sorted_token_ids with numel
for
(
size_t
it
=
threadIdx
.
x
;
it
<
max_num_tokens_padded
;
it
+=
blockDim
.
x
)
{
sorted_token_ids
[
it
]
=
numel
;
}
const
size_t
tid
=
threadIdx
.
x
;
const
size_t
stride
=
blockDim
.
x
;
...
...
@@ -153,6 +180,13 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
}
}
// Fill remaining expert_ids with 0
const
size_t
fill_start_idx
=
cumsum
[
num_experts
]
/
block_size
+
threadIdx
.
x
;
const
size_t
expert_ids_size
=
CEILDIV
(
max_num_tokens_padded
,
block_size
);
for
(
size_t
i
=
fill_start_idx
;
i
<
expert_ids_size
;
i
+=
blockDim
.
x
)
{
expert_ids
[
i
]
=
0
;
}
for
(
size_t
i
=
tid
;
i
<
numel
;
i
+=
stride
)
{
int32_t
expert_id
=
topk_ids
[
i
];
int32_t
rank_post_pad
=
...
...
@@ -179,13 +213,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
int
threads
=
1024
;
threads
=
((
threads
+
WARP_SIZE
-
1
)
/
WARP_SIZE
)
*
WARP_SIZE
;
// BlockScan uses 1024 threads and assigns one thread per expert.
TORCH_CHECK
(
padded_num_experts
<
1024
,
"padded_num_experts must be less than 1024"
);
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `cumsum` tensors
auto
options_int
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt
).
device
(
topk_ids
.
device
());
torch
::
Tensor
cumsum_buffer
=
torch
::
zeros
({
num_experts
+
1
},
options_int
);
torch
::
empty
({
num_experts
+
1
},
options_int
);
bool
small_batch_expert_mode
=
(
topk_ids
.
numel
()
<
1024
)
&&
(
num_experts
<=
64
);
...
...
@@ -203,7 +241,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
());
topk_ids
.
numel
()
,
sorted_token_ids
.
size
(
0
)
);
}
else
{
auto
align_kernel
=
vllm
::
moe
::
moe_align_block_size_kernel
<
scalar_t
>
;
...
...
@@ -217,7 +255,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
padded_num_experts
,
experts_per_warp
,
block_size
,
topk_ids
.
numel
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
());
topk_ids
.
numel
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
(),
sorted_token_ids
.
size
(
0
));
const
int
block_threads
=
std
::
min
(
256
,
(
int
)
threads
);
const
int
num_blocks
=
...
...
vllm/model_executor/layers/fused_moe/moe_align_block_size.py
View file @
0ec82edd
...
...
@@ -111,6 +111,8 @@ def moe_align_block_size_triton(
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
tokens_per_thread
=
cdiv
(
numel
,
num_experts
)
sorted_token_ids
.
fill_
(
numel
)
expert_ids
.
zero_
()
moe_align_block_size_stage1
[
grid
](
topk_ids
,
...
...
@@ -205,11 +207,8 @@ def moe_align_block_size(
sorted_ids
=
torch
.
empty
((
max_num_tokens_padded
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
sorted_ids
.
fill_
(
topk_ids
.
numel
())
max_num_m_blocks
=
triton
.
cdiv
(
max_num_tokens_padded
,
block_size
)
# Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism.
expert_ids
=
torch
.
zeros
((
max_num_m_blocks
,
),
expert_ids
=
torch
.
empty
((
max_num_m_blocks
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
num_tokens_post_pad
=
torch
.
empty
((
1
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment