change/sglang · Commit 817d4370

Authored Mar 13, 2025 by Shi Shuai; committed by GitHub on Mar 12, 2025.
feat: support ep size < 32 for sgl kernel (#4348)
Parent: c550e52f

Showing 2 changed files with 82 additions and 35 deletions (+82 −35):
- sgl-kernel/benchmark/bench_moe_align_block_size.py (+54 −29)
- sgl-kernel/csrc/moe/moe_align_kernel.cu (+28 −6)
sgl-kernel/benchmark/bench_moe_align_block_size.py
@@ -196,14 +196,21 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
         expert_ids_triton,
         num_tokens_post_pad_triton,
     )
-    ops.moe_align_block_size(
-        topk_ids,
-        num_experts,
-        block_size,
-        sorted_ids_vllm,
-        expert_ids_vllm,
-        num_tokens_post_pad_vllm,
-    )
+    try:
+        ops.moe_align_block_size(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids_vllm,
+            expert_ids_vllm,
+            num_tokens_post_pad_vllm,
+        )
+        print(f"✅ VLLM implementation works with {num_experts} experts!")
+        vllm_works = True
+    except RuntimeError as e:
+        print(f"❌ VLLM implementation failed with {num_experts} experts: {e}")
+        vllm_works = False
 
     if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
         num_tokens_post_pad_cuda, num_tokens_post_pad_triton
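The hunk above wraps the vLLM kernel call in a try/except so that expert counts the vLLM kernel cannot handle are reported rather than crashing the comparison. A minimal sketch of that probe-then-compare pattern, with `run_kernel` as a hypothetical stand-in for `ops.moe_align_block_size`:

```python
import torch

def probe_and_compare(run_kernel, ref_out, test_out):
    """Run a kernel that may not support the current config, then compare.

    run_kernel is a hypothetical stand-in for ops.moe_align_block_size.
    """
    try:
        run_kernel()  # may raise RuntimeError for unsupported expert counts
        works = True
    except RuntimeError as e:
        print(f"kernel unavailable for this config: {e}")
        works = False
    # Only trust the output buffers if the kernel actually ran.
    return works and torch.allclose(ref_out, test_out)
```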
@@ -216,20 +223,26 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
         print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
         print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
 
-    if torch.allclose(expert_ids_cuda, expert_ids_vllm) and torch.allclose(
-        num_tokens_post_pad_cuda, num_tokens_post_pad_vllm
-    ):
+    if (
+        vllm_works
+        and torch.allclose(expert_ids_cuda, expert_ids_vllm)
+        and torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_vllm)
+    ):
         print("✅ SGL and VLLM implementations match")
     else:
-        print("❌ SGL and VLLM implementations do not match")
-        print("SGL expert_ids:", expert_ids_cuda)
-        print("VLLM expert_ids:", expert_ids_vllm)
-        print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
-        print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
+        if not vllm_works:
+            print("⚠️ VLLM comparison skipped due to failure")
+        else:
+            print("❌ SGL and VLLM implementations do not match")
+            print("SGL expert_ids:", expert_ids_cuda)
+            print("VLLM expert_ids:", expert_ids_vllm)
+            print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
+            print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
 
 
 # Test range
 num_tokens_range = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
-num_experts_range = [32, 64, 128, 256]
+num_experts_range = [8, 32, 64, 128, 256]
 topk_range = [2, 4, 8]
 
 configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
@@ -316,17 +329,22 @@ def benchmark(num_tokens, num_experts, topk, provider):
             quantiles=quantiles,
         )
     else:
         # vllm
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: ops.moe_align_block_size(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids.clone(),
-                expert_ids.clone(),
-                num_tokens_post_pad.clone(),
-            ),
-            quantiles=quantiles,
-        )
+        try:
+            ms, min_ms, max_ms = triton.testing.do_bench(
+                lambda: ops.moe_align_block_size(
+                    topk_ids,
+                    num_experts,
+                    block_size,
+                    sorted_ids.clone(),
+                    expert_ids.clone(),
+                    num_tokens_post_pad.clone(),
+                ),
+                quantiles=quantiles,
+            )
+        except RuntimeError as e:
+            print(f"❌ VLLM benchmark failed with {num_experts} experts: {e}")
+            # Return extreme values to indicate failure in the chart
+            return float("inf"), float("inf"), float("inf")
 
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
@@ -343,7 +361,7 @@ if __name__ == "__main__":
         "--num_experts",
         type=int,
         default=256,
-        choices=[8, 64, 128, 256],
+        choices=[8, 16, 32, 64, 128, 256],
         help="Number of experts for benchmark",
     )
     parser.add_argument(
@@ -353,8 +371,15 @@ if __name__ == "__main__":
         choices=[2, 4, 8],
         help="Top-k value for benchmark",
     )
+    parser.add_argument(
+        "--skip_full_benchmark",
+        action="store_true",
+        help="Only run the calculate_diff function, skip full benchmarking",
+    )
 
     args = parser.parse_args()
 
     calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
-    benchmark.run(print_data=True)
+    if not args.skip_full_benchmark:
+        print(f"\n📊 Running performance benchmark for {args.num_experts} experts...")
+        benchmark.run(print_data=True)
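With the new flag, a quick correctness check for a small expert-parallel size can skip the full sweep, e.g. `python bench_moe_align_block_size.py --num_experts 8 --skip_full_benchmark`. Equivalently, a hypothetical direct call into the module (assuming it is importable from the current directory):

```python
# Hypothetical direct use of the updated benchmark module.
from bench_moe_align_block_size import calculate_diff

# Probe an expert count below the warp size, which the old kernel rejected.
calculate_diff(num_tokens=1024, num_experts=8, topk=2)
```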
sgl-kernel/csrc/moe/moe_align_kernel.cu
@@ -47,6 +47,7 @@ __global__ void moe_align_block_size_kernel(
     int32_t* __restrict__ expert_ids,
     int32_t* __restrict__ total_tokens_post_pad,
     int32_t num_experts,
+    int32_t padded_num_experts,
     int32_t experts_per_warp,
     int32_t block_size,
     size_t numel,
@@ -57,7 +58,7 @@ __global__ void moe_align_block_size_kernel(
   const int my_expert_start = warp_id * experts_per_warp;
 
   for (int i = 0; i < experts_per_warp; ++i) {
-    if (my_expert_start + i < num_experts) {
+    if (my_expert_start + i < padded_num_experts) {
       shared_counts[warp_id * experts_per_warp + i] = 0;
     }
   }
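The zero-initialization loop is now bounded by `padded_num_experts` rather than `num_experts`, so every shared-counter slot the counting phase may later touch is cleared even when the expert count is not a multiple of the warp size. A rough Python rendering of that pass, not the kernel itself, with WARP_SIZE = 32 and a hypothetical small expert count:

```python
WARP_SIZE = 32
num_experts = 8        # an expert-parallel size below the warp size
experts_per_warp = 8   # the host code picks 8 for num_experts <= 8
padded_num_experts = ((num_experts + WARP_SIZE - 1) // WARP_SIZE) * WARP_SIZE  # -> 32

# Only the warps whose slots fall inside the padded range do work; in the
# kernel, higher warps simply fail the bound check and skip.
shared_counts = {}
num_warps = (padded_num_experts + experts_per_warp - 1) // experts_per_warp
for warp_id in range(num_warps):
    my_expert_start = warp_id * experts_per_warp
    for i in range(experts_per_warp):
        # Bounding by padded_num_experts (not num_experts) clears the padding
        # tail too, so no counter slot is read uninitialized later.
        if my_expert_start + i < padded_num_experts:
            shared_counts[warp_id * experts_per_warp + i] = 0

assert len(shared_counts) == padded_num_experts  # all 32 slots cleared
```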
@@ -108,23 +109,44 @@ void moe_align_block_size(
     torch::Tensor token_cnts_buffer,
     torch::Tensor cumsum_buffer) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  TORCH_CHECK(num_experts % WARP_SIZE == 0);
-  int experts_per_warp = num_experts / WARP_SIZE;
+  int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
+
+  int experts_per_warp;
+  int threads;
+  if (num_experts <= 8) {
+    experts_per_warp = 8;
+    threads = 256;
+  } else if (num_experts <= 16) {
+    experts_per_warp = 16;
+    threads = 512;
+  } else {
+    experts_per_warp = WARP_SIZE;
+    threads = 1024;
+  }
+  threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
 
   DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
     auto align_kernel = moe_align_block_size_kernel<scalar_t>;
-    size_t shared_mem_size = 32 * experts_per_warp * sizeof(int32_t);
-    align_kernel<<<1, 1024, shared_mem_size, stream>>>(
+    size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
+    size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
+    align_kernel<<<1, threads, shared_mem_size, stream>>>(
         topk_ids.data_ptr<scalar_t>(),
         sorted_token_ids.data_ptr<int32_t>(),
         experts_ids.data_ptr<int32_t>(),
         num_tokens_post_pad.data_ptr<int32_t>(),
         num_experts,
+        padded_num_experts,
         experts_per_warp,
         block_size,
         topk_ids.numel(),
         cumsum_buffer.data_ptr<int32_t>());
 
-  const int block_threads = 256;
+  const int block_threads = std::min(256, (int)threads);
   const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
   const int max_blocks = 65535;
   const int actual_blocks = std::min(num_blocks, max_blocks);
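Taken together, the host code now derives the launch configuration from the expert count instead of requiring `num_experts` to be a multiple of `WARP_SIZE` (the removed `TORCH_CHECK`). For 8 experts this launches 256 threads with a 128-byte counter buffer rather than the old fixed 1024-thread configuration. A Python sketch of the new arithmetic, assuming WARP_SIZE = 32 and mirroring the C++ above (`launch_config` and `ceildiv` are illustrative names, not part of the source):

```python
WARP_SIZE = 32

def ceildiv(a, b):
    return (a + b - 1) // b

def launch_config(num_experts):
    # Round the expert count up to a whole number of warps.
    padded_num_experts = ceildiv(num_experts, WARP_SIZE) * WARP_SIZE
    # Pick per-warp expert coverage and thread count by expert-count tier.
    if num_experts <= 8:
        experts_per_warp, threads = 8, 256
    elif num_experts <= 16:
        experts_per_warp, threads = 16, 512
    else:
        experts_per_warp, threads = WARP_SIZE, 1024
    threads = ceildiv(threads, WARP_SIZE) * WARP_SIZE  # round up to a warp multiple
    # Size shared memory for exactly the warps that hold expert counters.
    num_warps = ceildiv(padded_num_experts, experts_per_warp)
    shared_mem_bytes = num_warps * experts_per_warp * 4  # sizeof(int32_t)
    return padded_num_experts, experts_per_warp, threads, shared_mem_bytes

for n in (8, 16, 64, 256):
    print(n, launch_config(n))
# 8   -> (32, 8, 256, 128)
# 16  -> (32, 16, 512, 128)
# 64  -> (64, 32, 1024, 256)
# 256 -> (256, 32, 1024, 1024)
```

Note that for 256 experts the shared-memory footprint (1024 bytes) matches the old `32 * experts_per_warp * sizeof(int32_t)` formula, so the large-expert path is unchanged; only small expert counts get a smaller, legal configuration.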