Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
f730362e
Unverified
Commit
f730362e
authored
Apr 10, 2025
by
Xiaoyu Zhang
Committed by
GitHub
Apr 09, 2025
Browse files
reduce moe_align_block_size_kernel small batch mode overhead (#5086)
parent
e3c4bd31
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
143 additions
and
56 deletions
+143
-56
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
+1
-1
sgl-kernel/benchmark/bench_moe_align_block_size.py
sgl-kernel/benchmark/bench_moe_align_block_size.py
+31
-10
sgl-kernel/csrc/moe/moe_align_kernel.cu
sgl-kernel/csrc/moe/moe_align_kernel.cu
+111
-44
sgl-kernel/tests/test_moe_align.py
sgl-kernel/tests/test_moe_align.py
+0
-1
No files found.
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
View file @
f730362e
...
...
@@ -702,7 +702,7 @@ def moe_align_block_size(
num_tokens_post_pad
,
)
else
:
token_cnts_buffer
=
torch
.
zeros
(
token_cnts_buffer
=
torch
.
empty
(
(
num_experts
+
1
)
*
num_experts
,
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
,
...
...
sgl-kernel/benchmark/bench_moe_align_block_size.py
View file @
f730362e
...
...
@@ -241,9 +241,9 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
# Test range
num_tokens_range
=
[
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
]
num_tokens_range
=
[
1
,
8
,
16
,
32
,
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
]
num_experts_range
=
[
8
,
32
,
64
,
128
,
256
]
topk_range
=
[
2
,
4
,
8
]
topk_range
=
[
1
,
2
,
4
,
8
]
configs
=
list
(
itertools
.
product
(
num_tokens_range
,
num_experts_range
,
topk_range
))
...
...
@@ -294,17 +294,28 @@ def benchmark(num_tokens, num_experts, topk, provider):
(
max_num_m_blocks
,),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
num_tokens_post_pad
=
torch
.
empty
((
1
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
token_cnts_buffer
=
torch
.
zeros
(
(
num_experts
+
1
)
*
num_experts
,
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
cumsum_buffer
=
torch
.
zeros
(
num_experts
+
1
,
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
quantiles
=
[
0.5
,
0.2
,
0.8
]
if
provider
==
"sgl"
:
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
sgl_moe_align_block_size
(
def
sgl_moe_align_block_size_with_empty
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
,
):
token_cnts_buffer
=
torch
.
empty
(
(
num_experts
+
1
)
*
num_experts
,
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
,
)
cumsum_buffer
=
torch
.
empty
(
num_experts
+
1
,
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
sgl_moe_align_block_size
(
topk_ids
,
num_experts
,
block_size
,
...
...
@@ -313,6 +324,16 @@ def benchmark(num_tokens, num_experts, topk, provider):
num_tokens_post_pad
.
clone
(),
token_cnts_buffer
,
cumsum_buffer
,
)
ms
,
min_ms
,
max_ms
=
triton
.
testing
.
do_bench
(
lambda
:
sgl_moe_align_block_size_with_empty
(
topk_ids
,
num_experts
,
block_size
,
sorted_ids
,
expert_ids
,
num_tokens_post_pad
,
),
quantiles
=
quantiles
,
)
...
...
sgl-kernel/csrc/moe/moe_align_kernel.cu
View file @
f730362e
...
...
@@ -64,10 +64,10 @@ __global__ void moe_align_block_size_kernel(
__syncthreads
();
const
size_t
t
okens_per_thread
=
CEILDIV
(
numel
,
blockDim
.
x
)
;
const
size_t
st
art_
id
x
=
threadIdx
.
x
*
tokens_per_thread
;
const
size_t
t
id
=
threadIdx
.
x
;
const
size_t
st
r
id
e
=
blockDim
.
x
;
for
(
in
t
i
=
start_
id
x
;
i
<
numel
&&
i
<
st
art_idx
+
tokens_per_thread
;
++
i
)
{
for
(
size_
t
i
=
t
id
;
i
<
numel
;
i
+=
st
ride
)
{
int
expert_id
=
topk_ids
[
i
];
int
warp_idx
=
expert_id
/
experts_per_warp
;
int
expert_offset
=
expert_id
%
experts_per_warp
;
...
...
@@ -98,6 +98,65 @@ __global__ void moe_align_block_size_kernel(
}
}
// Single-block kernel specialized for small batches / few experts: sorts token
// indices by expert and pads each expert's segment to a multiple of block_size,
// doing all counting in shared memory instead of global buffers.
// NOTE(review): assumes a one-block launch (indices use only threadIdx.x) and
// that blockDim.x >= num_experts for the per-expert phases — confirm at the
// launch site.
//
// Outputs:
//   sorted_token_ids      - token indices grouped by expert, padded per expert
//   expert_ids            - expert id owning each block_size-sized chunk
//   total_tokens_post_pad - total padded token count (cumsum[num_experts])
template <typename scalar_t>
__global__ void moe_align_block_size_small_batch_expert_kernel(
    const scalar_t* __restrict__ topk_ids,
    int32_t* __restrict__ sorted_token_ids,
    int32_t* __restrict__ expert_ids,
    int32_t* __restrict__ total_tokens_post_pad,
    int32_t num_experts,
    int32_t block_size,
    size_t numel) {
  const size_t tid = threadIdx.x;
  const size_t stride = blockDim.x;

  // Dynamic shared memory layout:
  //   cumsum      : num_experts + 1 entries (padded prefix sums per expert)
  //   tokens_cnts : (blockDim.x + 1) x num_experts counts; row 0 is a zero
  //                 base row, row t+1 holds thread t's per-expert counts.
  // Host must size the allocation accordingly.
  extern __shared__ int32_t shared_mem[];
  int32_t* cumsum = shared_mem;
  int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1);

  // Zero this thread's count row (row threadIdx.x + 1).
  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0;
  }

  // Each thread counts its strided share of tokens into its own row
  // (no atomics needed since rows are private per thread).
  for (size_t i = tid; i < numel; i += stride) {
    ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]];
  }

  __syncthreads();

  // Column-wise inclusive prefix sum over thread rows: after this,
  // tokens_cnts[t * num_experts + e] is the number of tokens for expert e
  // counted by threads 0..t-1, and the last row holds the per-expert totals.
  // One thread per expert (requires blockDim.x >= num_experts).
  if (threadIdx.x < num_experts) {
    tokens_cnts[threadIdx.x] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
      tokens_cnts[i * num_experts + threadIdx.x] += tokens_cnts[(i - 1) * num_experts + threadIdx.x];
    }
  }

  __syncthreads();

  // Thread 0 builds the padded prefix sum over experts: each expert's total
  // is rounded up to a multiple of block_size before accumulating.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) * block_size;
    }
    *total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
  }

  __syncthreads();

  // Label every block_size-sized chunk of the padded output with its expert.
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
      expert_ids[i / block_size] = threadIdx.x;
    }
  }

  // Scatter: thread t's slot for a token of expert e starts at
  // cumsum[e] + (tokens counted by threads 0..t-1 for e), i.e. row t of the
  // prefix-summed table; bump the row entry as this thread consumes slots.
  // Tokens are revisited in the same strided order as the counting pass, so
  // the placement is deterministic.
  for (size_t i = tid; i < numel; i += stride) {
    int32_t expert_id = topk_ids[i];
    int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
    sorted_token_ids[rank_post_pad] = i;
    ++tokens_cnts[threadIdx.x * num_experts + expert_id];
  }
}
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int64_t
num_experts
,
...
...
@@ -111,50 +170,58 @@ void moe_align_block_size(
int64_t
padded_num_experts
=
((
num_experts
+
WARP_SIZE
-
1
)
/
WARP_SIZE
)
*
WARP_SIZE
;
int
experts_per_warp
;
int
threads
;
if
(
num_experts
<=
8
)
{
experts_per_warp
=
8
;
threads
=
256
;
}
else
if
(
num_experts
<=
16
)
{
experts_per_warp
=
16
;
threads
=
512
;
}
else
{
experts_per_warp
=
WARP_SIZE
;
threads
=
1024
;
}
int
experts_per_warp
=
WARP_SIZE
;
int
threads
=
1024
;
threads
=
((
threads
+
WARP_SIZE
-
1
)
/
WARP_SIZE
)
*
WARP_SIZE
;
DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
auto
align_kernel
=
moe_align_block_size_kernel
<
scalar_t
>
;
size_t
num_warps
=
CEILDIV
(
padded_num_experts
,
experts_per_warp
);
size_t
shared_mem_size
=
num_warps
*
experts_per_warp
*
sizeof
(
int32_t
);
align_kernel
<<<
1
,
threads
,
shared_mem_size
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
padded_num_experts
,
experts_per_warp
,
block_size
,
topk_ids
.
numel
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
());
const
int
block_threads
=
std
::
min
(
256
,
(
int
)
threads
);
const
int
num_blocks
=
(
topk_ids
.
numel
()
+
block_threads
-
1
)
/
block_threads
;
const
int
max_blocks
=
65535
;
const
int
actual_blocks
=
std
::
min
(
num_blocks
,
max_blocks
);
auto
sort_kernel
=
count_and_sort_expert_tokens_kernel
<
scalar_t
>
;
sort_kernel
<<<
actual_blocks
,
block_threads
,
0
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
(),
topk_ids
.
numel
());
bool
small_batch_expert_mode
=
(
topk_ids
.
numel
()
<
1024
)
&&
(
num_experts
<=
64
);
if
(
small_batch_expert_mode
)
{
const
int32_t
threads
=
max
((
int32_t
)
num_experts
,
WARP_SIZE
);
const
int32_t
shared_mem_size
=
((
threads
+
1
)
*
num_experts
+
(
num_experts
+
1
))
*
sizeof
(
int32_t
);
auto
small_batch_expert_kernel
=
moe_align_block_size_small_batch_expert_kernel
<
scalar_t
>
;
small_batch_expert_kernel
<<<
1
,
threads
,
shared_mem_size
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
());
}
else
{
auto
align_kernel
=
moe_align_block_size_kernel
<
scalar_t
>
;
size_t
num_warps
=
CEILDIV
(
padded_num_experts
,
experts_per_warp
);
size_t
shared_mem_size
=
num_warps
*
experts_per_warp
*
sizeof
(
int32_t
);
cumsum_buffer
.
zero_
();
align_kernel
<<<
1
,
threads
,
shared_mem_size
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
padded_num_experts
,
experts_per_warp
,
block_size
,
topk_ids
.
numel
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
());
const
int
block_threads
=
std
::
min
(
256
,
(
int
)
threads
);
const
int
num_blocks
=
(
topk_ids
.
numel
()
+
block_threads
-
1
)
/
block_threads
;
const
int
max_blocks
=
65535
;
const
int
actual_blocks
=
std
::
min
(
num_blocks
,
max_blocks
);
auto
sort_kernel
=
count_and_sort_expert_tokens_kernel
<
scalar_t
>
;
sort_kernel
<<<
actual_blocks
,
block_threads
,
0
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
(),
topk_ids
.
numel
());
}
});
}
sgl-kernel/tests/test_moe_align.py
View file @
f730362e
...
...
@@ -151,7 +151,6 @@ def moe_align_block_size_triton(
def
test_moe_align_block_size_compare_implementations
(
block_size
,
num_tokens
,
topk
,
num_experts
):
# For DeepSeek V3, we have 256 experts
topk_ids
=
torch
.
stack
(
[
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment