Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
ad349985
Unverified
Commit
ad349985
authored
Feb 06, 2025
by
Xiaoyu Zhang
Committed by
GitHub
Feb 06, 2025
Browse files
clean moe align block kernel code and add acc test (#3332)
parent
32de54ed
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
241 additions
and
84 deletions
+241
-84
benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py
...fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py
+1
-1
sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
+14
-26
sgl-kernel/src/sgl-kernel/include/utils.h
sgl-kernel/src/sgl-kernel/include/utils.h
+12
-0
sgl-kernel/tests/test_moe_align.py
sgl-kernel/tests/test_moe_align.py
+214
-57
No files found.
benchmark/kernels/fused_moe_triton/benchmark_deepseekv3_moe_align_blocks.py
View file @
ad349985
...
@@ -310,4 +310,4 @@ if __name__ == "__main__":
...
@@ -310,4 +310,4 @@ if __name__ == "__main__":
calculate_diff
(
batch_size
=
4
,
seq_len
=
1024
)
calculate_diff
(
batch_size
=
4
,
seq_len
=
1024
)
benchmark
.
run
(
print_data
=
True
,
save_path
=
args
.
save_path
)
benchmark
.
run
(
print_data
=
True
)
sgl-kernel/src/sgl-kernel/csrc/moe_align_kernel.cu
View file @
ad349985
...
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
...
@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
==============================================================================*/
==============================================================================*/
// Adapted from https://github.com/vllm-project/vllm/blob/v0.6.5/csrc/moe/moe_align_sum_kernels.cu
#include <ATen/ATen.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
...
@@ -22,32 +20,15 @@ limitations under the License.
...
@@ -22,32 +20,15 @@ limitations under the License.
#include <THC/THCAtomics.cuh>
#include <THC/THCAtomics.cuh>
#define WARP_SIZE 32
#include "utils.h"
#define DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
#define CEILDIV(x, y) (((x) + (y)-1) / (y))
#define DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
#define WARP_SIZE 32
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
__device__
__forceinline__
int32_t
index
(
int32_t
total_col
,
int32_t
row
,
int32_t
col
)
{
return
row
*
total_col
+
col
;
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
void
moe_align_block_size_kernel
(
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
sorted_token_ids
,
__global__
void
moe_align_block_size_kernel
(
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
sorted_token_ids
,
int32_t
*
expert_ids
,
int32_t
*
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
*
expert_ids
,
int32_t
*
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
block_size
,
size_t
numel
,
int32_t
*
cumsum
)
{
int32_t
block_size
,
size_t
numel
,
int32_t
*
cumsum
)
{
__shared__
int32_t
shared_counts
[
32
][
8
];
__shared__
int32_t
shared_counts
[
WARP_SIZE
][
8
];
__shared__
int32_t
local_offsets
[
256
];
__shared__
int32_t
local_offsets
[
256
];
const
int
warp_id
=
threadIdx
.
x
/
WARP_SIZE
;
const
int
warp_id
=
threadIdx
.
x
/
WARP_SIZE
;
...
@@ -96,6 +77,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int
...
@@ -96,6 +77,11 @@ __global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, int
__syncthreads
();
__syncthreads
();
// Note: For the moe_align_kernel, the primary bottleneck lies in the atomic add and non-coalesced memory writes here.
// If these operations can be performed using multiple blocks, similar to the Triton version, the performance of this
// kernel can achieve state-of-the-art performance across all token cases. However, once multiple blocks are used,
// illegal memory access occurs. Even replacing these lines of code with the stage 4 kernel from the Triton version
// results in the same issue, and a correct solution has not yet been found.
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
int32_t
expert_id
=
topk_ids
[
i
];
int32_t
expert_id
=
topk_ids
[
i
];
int32_t
rank_post_pad
=
atomicAdd
(
&
local_offsets
[
expert_id
],
1
);
int32_t
rank_post_pad
=
atomicAdd
(
&
local_offsets
[
expert_id
],
1
);
...
@@ -107,10 +93,12 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
...
@@ -107,10 +93,12 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
,
torch
::
Tensor
token_cnts_buffer
,
torch
::
Tensor
cumsum_buffer
)
{
torch
::
Tensor
token_cnts_buffer
,
torch
::
Tensor
cumsum_buffer
)
{
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
TORCH_CHECK
(
num_experts
==
256
,
"moe_align_block_size kernel only support deepseek v3 now."
);
DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
auto
kernel
=
moe_align_block_size_kernel
<
scalar_t
>
;
auto
align_
kernel
=
moe_align_block_size_kernel
<
scalar_t
>
;
kernel
<<<
1
,
1024
,
0
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
align_
kernel
<<<
1
,
1024
,
0
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
());
num_experts
,
block_size
,
topk_ids
.
numel
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
());
});
});
}
}
sgl-kernel/src/sgl-kernel/include/utils.h
View file @
ad349985
...
@@ -79,3 +79,15 @@ inline int getSMVersion() {
...
@@ -79,3 +79,15 @@ inline int getSMVersion() {
return false; \
return false; \
} \
} \
}()
}()
#define DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define CEILDIV(x, y) (((x) + (y)-1) / (y))
sgl-kernel/tests/test_moe_align.py
View file @
ad349985
import
itertools
import
pytest
import
torch
import
torch
import
triton
import
triton.language
as
tl
from
sgl_kernel
import
moe_align_block_size
from
sgl_kernel
import
moe_align_block_size
def
test_moe_align_block_size
():
def
ceil_div
(
a
,
b
):
return
(
a
+
b
-
1
)
//
b
@
triton
.
jit
def
moe_align_block_size_stage1
(
topk_ids_ptr
,
tokens_cnts_ptr
,
num_experts
:
tl
.
constexpr
,
numel
:
tl
.
constexpr
,
tokens_per_thread
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
start_idx
=
pid
*
tokens_per_thread
off_c
=
(
pid
+
1
)
*
num_experts
for
i
in
range
(
tokens_per_thread
):
if
start_idx
+
i
<
numel
:
idx
=
tl
.
load
(
topk_ids_ptr
+
start_idx
+
i
)
token_cnt
=
tl
.
load
(
tokens_cnts_ptr
+
off_c
+
idx
)
tl
.
store
(
tokens_cnts_ptr
+
off_c
+
idx
,
token_cnt
+
1
)
@
triton
.
jit
def
moe_align_block_size_stage2
(
tokens_cnts_ptr
,
num_experts
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
last_cnt
=
0
for
i
in
range
(
1
,
num_experts
+
1
):
token_cnt
=
tl
.
load
(
tokens_cnts_ptr
+
i
*
num_experts
+
pid
)
last_cnt
=
last_cnt
+
token_cnt
tl
.
store
(
tokens_cnts_ptr
+
i
*
num_experts
+
pid
,
last_cnt
)
@
triton
.
jit
def
moe_align_block_size_stage3
(
total_tokens_post_pad_ptr
,
tokens_cnts_ptr
,
cumsum_ptr
,
num_experts
:
tl
.
constexpr
,
block_size
:
tl
.
constexpr
,
):
last_cumsum
=
0
off_cnt
=
num_experts
*
num_experts
for
i
in
range
(
1
,
num_experts
+
1
):
token_cnt
=
tl
.
load
(
tokens_cnts_ptr
+
off_cnt
+
i
-
1
)
last_cumsum
=
last_cumsum
+
tl
.
cdiv
(
token_cnt
,
block_size
)
*
block_size
tl
.
store
(
cumsum_ptr
+
i
,
last_cumsum
)
tl
.
store
(
total_tokens_post_pad_ptr
,
last_cumsum
)
@
triton
.
jit
def
moe_align_block_size_stage4
(
topk_ids_ptr
,
sorted_token_ids_ptr
,
expert_ids_ptr
,
tokens_cnts_ptr
,
cumsum_ptr
,
num_experts
:
tl
.
constexpr
,
block_size
:
tl
.
constexpr
,
numel
:
tl
.
constexpr
,
tokens_per_thread
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
start_idx
=
tl
.
load
(
cumsum_ptr
+
pid
)
end_idx
=
tl
.
load
(
cumsum_ptr
+
pid
+
1
)
for
i
in
range
(
start_idx
,
end_idx
,
block_size
):
tl
.
store
(
expert_ids_ptr
+
i
//
block_size
,
pid
)
start_idx
=
pid
*
tokens_per_thread
off_t
=
pid
*
num_experts
for
i
in
range
(
start_idx
,
tl
.
minimum
(
start_idx
+
tokens_per_thread
,
numel
)):
expert_id
=
tl
.
load
(
topk_ids_ptr
+
i
)
token_cnt
=
tl
.
load
(
tokens_cnts_ptr
+
off_t
+
expert_id
)
rank_post_pad
=
token_cnt
+
tl
.
load
(
cumsum_ptr
+
expert_id
)
tl
.
store
(
sorted_token_ids_ptr
+
rank_post_pad
,
i
)
tl
.
store
(
tokens_cnts_ptr
+
off_t
+
expert_id
,
token_cnt
+
1
)
def
moe_align_block_size_triton
(
topk_ids
:
torch
.
Tensor
,
num_experts
:
int
,
block_size
:
int
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_pad
:
torch
.
Tensor
,
)
->
None
:
numel
=
topk_ids
.
numel
()
grid
=
(
num_experts
,)
tokens_cnts
=
torch
.
zeros
(
(
num_experts
+
1
,
num_experts
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
cumsum
=
torch
.
zeros
((
num_experts
+
1
,),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
tokens_per_thread
=
ceil_div
(
numel
,
num_experts
)
moe_align_block_size_stage1
[
grid
](
topk_ids
,
tokens_cnts
,
num_experts
,
numel
,
tokens_per_thread
,
)
moe_align_block_size_stage2
[
grid
](
tokens_cnts
,
num_experts
,
)
moe_align_block_size_stage3
[(
1
,)](
num_tokens_post_pad
,
tokens_cnts
,
cumsum
,
num_experts
,
block_size
,
)
moe_align_block_size_stage4
[
grid
](
topk_ids
,
sorted_token_ids
,
expert_ids
,
tokens_cnts
,
cumsum
,
num_experts
,
block_size
,
numel
,
tokens_per_thread
,
)
@
pytest
.
mark
.
parametrize
(
"block_size,num_tokens,topk"
,
list
(
itertools
.
product
(
[
32
,
64
,
128
,
256
],
# block_size
[
1
,
2
,
4
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
],
# num_tokens
[
1
,
2
,
4
,
8
,
16
,
32
,
64
],
# topk
)
),
)
def
test_moe_align_block_size_compare_implementations
(
block_size
,
num_tokens
,
topk
):
# For DeepSeek V3, we have 256 experts
# For DeepSeek V3, we have 256 experts
num_experts
=
256
num_experts
=
256
# Test different combinations of block_size, num_tokens and topk
topk_ids
=
torch
.
stack
(
for
block_size
in
[
32
,
64
,
128
,
256
]:
[
print
(
f
"
\n
Testing block_size=
{
block_size
}
"
)
torch
.
randperm
(
num_experts
,
dtype
=
torch
.
int32
,
device
=
"cuda"
)[:
topk
]
for
num_tokens
in
[
1
,
2
,
4
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
]:
for
_
in
range
(
num_tokens
)
for
topk
in
[
1
,
2
,
4
,
8
,
16
,
32
,
64
]:
]
print
(
)
f
"Testing block_size=
{
block_size
}
, num_tokens=
{
num_tokens
}
, topk=
{
topk
}
"
)
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
# Create random topk_ids with shape [num_tokens, topk]
sorted_ids_cuda
=
torch
.
empty
(
topk_ids
=
torch
.
randint
(
(
max_num_tokens_padded
,),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
0
,
num_experts
,
(
num_tokens
,
topk
),
dtype
=
torch
.
int32
,
device
=
"cuda"
)
)
sorted_ids_cuda
.
fill_
(
topk_ids
.
numel
())
max_num_m_blocks
=
max_num_tokens_padded
//
block_size
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
expert_ids_cuda
=
torch
.
zeros
(
block_size
-
1
(
max_num_m_blocks
,),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
)
sorted_ids
=
torch
.
empty
(
num_tokens_post_pad_cuda
=
torch
.
empty
(
(
max_num_tokens_padded
,),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
(
1
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
)
sorted_ids
.
fill_
(
topk_ids
.
numel
())
token_cnts_buffer
=
torch
.
empty
(
max_num_m_blocks
=
max_num_tokens_padded
//
block_size
(
num_experts
+
1
)
*
num_experts
,
expert_ids
=
torch
.
empty
(
dtype
=
torch
.
int32
,
(
max_num_m_blocks
,),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
device
=
topk_ids
.
device
,
)
)
num_tokens_post_pad
=
torch
.
empty
(
cumsum_buffer
=
torch
.
empty
(
(
1
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
num_experts
+
1
,
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
)
token_cnts_buffer
=
torch
.
empty
(
sorted_ids_triton
=
torch
.
empty_like
(
sorted_ids_cuda
)
(
num_experts
+
1
)
*
num_experts
,
sorted_ids_triton
.
fill_
(
topk_ids
.
numel
())
dtype
=
torch
.
int32
,
expert_ids_triton
=
torch
.
zeros_like
(
expert_ids_cuda
)
device
=
topk_ids
.
device
,
num_tokens_post_pad_triton
=
torch
.
empty_like
(
num_tokens_post_pad_cuda
)
)
cumsum_buffer
=
torch
.
empty
(
moe_align_block_size
(
num_experts
+
1
,
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
topk_ids
,
)
num_experts
,
block_size
,
try
:
sorted_ids_cuda
,
moe_align_block_size
(
expert_ids_cuda
,
topk_ids
,
num_tokens_post_pad_cuda
,
num_experts
,
token_cnts_buffer
,
block_size
,
cumsum_buffer
,
sorted_ids
,
)
expert_ids
,
num_tokens_post_pad
,
moe_align_block_size_triton
(
token_cnts_buffer
,
topk_ids
,
cumsum_buffer
,
num_experts
,
)
block_size
,
except
Exception
as
e
:
sorted_ids_triton
,
print
(
expert_ids_triton
,
f
"Error occurred with block_size=
{
block_size
}
, num_tokens=
{
num_tokens
}
, topk=
{
topk
}
"
num_tokens_post_pad_triton
,
)
)
print
(
f
"Error message:
{
str
(
e
)
}
"
)
raise
e
assert
torch
.
allclose
(
expert_ids_cuda
,
expert_ids_triton
),
(
f
"Expert IDs mismatch for block_size=
{
block_size
}
, "
f
"num_tokens=
{
num_tokens
}
, topk=
{
topk
}
\n
"
f
"CUDA expert_ids:
{
expert_ids_cuda
}
\n
"
f
"Triton expert_ids:
{
expert_ids_triton
}
"
)
assert
torch
.
allclose
(
num_tokens_post_pad_cuda
,
num_tokens_post_pad_triton
),
(
f
"Num tokens post pad mismatch for block_size=
{
block_size
}
, "
f
"num_tokens=
{
num_tokens
}
, topk=
{
topk
}
\n
"
f
"CUDA num_tokens_post_pad:
{
num_tokens_post_pad_cuda
}
\n
"
f
"Triton num_tokens_post_pad:
{
num_tokens_post_pad_triton
}
"
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
test
_moe_align_block_size
(
)
py
test
.
main
([
__file__
]
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment