change / sglang · Commits · 8b5f83ed

Unverified commit 8b5f83ed, authored Jun 07, 2025 by Xiaoyu Zhang, committed by GitHub on Jun 07, 2025

reduce torch.zeros overhead in moe align block size kernel (#6369)

Parent: 2a413829

Showing 2 changed files with 58 additions and 8 deletions (+58, -8)
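The hunks below replace the separate buffer-initialization launches in moe_align_block_size (a fill_ on sorted_ids on the Python side and a zero_() memset on cumsum_buffer inside the CUDA op) with plain torch.empty allocations plus one fused Triton init kernel. A minimal sketch of the before/after launch pattern, with illustrative shapes that are not taken from the commit:

import torch

# Illustrative sizes only (not from the commit).
max_num_tokens_padded, num_experts, topk_numel = 48_640, 256, 16_384

# Before: each buffer needed its own initialization launch.
sorted_ids = torch.empty(max_num_tokens_padded, dtype=torch.int32, device="cuda")
sorted_ids.fill_(topk_numel)                    # launch 1: write the sentinel value
cumsum = torch.empty(num_experts + 1, dtype=torch.int32, device="cuda")
cumsum.zero_()                                  # launch 2: memset (previously done inside the CUDA op)

# After: both buffers still come from torch.empty, and one small Triton
# kernel (init_sorted_ids_and_cumsum_buffer_kernel, added below) writes the
# sentinel fill and the zeros in a single launch.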
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py  +58 -6
sgl-kernel/csrc/moe/moe_align_kernel.cu                     +0  -2
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -30,6 +30,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     log_info_on_rank0,
+    next_power_of_2,
 )
 
 _is_hip = is_hip()
@@ -650,6 +651,61 @@ def moe_align_block_size_triton(
     )
 
 
+@triton.jit
+def init_sorted_ids_and_cumsum_buffer_kernel(
+    sorted_ids_ptr,
+    cumsum_buffer_ptr,
+    max_num_tokens_padded,
+    topk_ids_numel,
+    num_experts: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+    ALIGNED_NUM_EXPERTS_P1: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    sorted_ids_blocks = tl.cdiv(max_num_tokens_padded, BLOCK_SIZE)
+    if pid < sorted_ids_blocks:
+        mask = offsets < max_num_tokens_padded
+        tl.store(
+            sorted_ids_ptr + offsets,
+            tl.full((BLOCK_SIZE,), topk_ids_numel, dtype=tl.int32),
+            mask=mask,
+        )
+    elif pid == sorted_ids_blocks:
+        offset_e = tl.arange(0, ALIGNED_NUM_EXPERTS_P1)
+        mask_e = offset_e < num_experts + 1
+        tl.store(
+            cumsum_buffer_ptr + offset_e,
+            tl.zeros((ALIGNED_NUM_EXPERTS_P1,), dtype=tl.int32),
+            mask=mask_e,
+        )
+
+
+def init_sorted_ids_and_cumsum_buffer(
+    max_num_tokens_padded: int, topk_ids_numel: int, num_experts: int, device="cuda"
+):
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=device
+    )
+    cumsum_buffer = torch.empty(
+        (num_experts + 1,), dtype=torch.int32, device=device
+    )
+    BLOCK_SIZE = 1024
+    sorted_ids_blocks = triton.cdiv(max_num_tokens_padded, BLOCK_SIZE)
+    grid = (sorted_ids_blocks + 1,)
+    init_sorted_ids_and_cumsum_buffer_kernel[grid](
+        sorted_ids,
+        cumsum_buffer,
+        max_num_tokens_padded,
+        topk_ids_numel,
+        num_experts,
+        BLOCK_SIZE,
+        next_power_of_2(num_experts + 1),
+    )
+    return sorted_ids, cumsum_buffer
+
+
 def moe_align_block_size(
     topk_ids: torch.Tensor, block_size: int, num_experts: int
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -691,10 +747,9 @@ def moe_align_block_size(
         by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty(
-        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    sorted_ids, cumsum_buffer = init_sorted_ids_and_cumsum_buffer(
+        max_num_tokens_padded, topk_ids.numel(), num_experts, topk_ids.device
     )
-    sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -715,9 +770,6 @@ def moe_align_block_size(
                 dtype=torch.int32,
                 device=topk_ids.device,
             )
-            cumsum_buffer = torch.empty(
-                num_experts + 1, dtype=torch.int32, device=topk_ids.device
-            )
             sgl_moe_align_block_size(
                 topk_ids,
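A quick sanity check of the new helper (the sizes below and running it standalone are illustrative assumptions; the semantics it must reproduce are the old fill_/zero pattern removed above):

import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
    init_sorted_ids_and_cumsum_buffer,
)

topk_ids_numel, num_experts, block_size = 4096, 64, 128   # illustrative values
max_num_tokens_padded = topk_ids_numel + num_experts * (block_size - 1)

sorted_ids, cumsum_buffer = init_sorted_ids_and_cumsum_buffer(
    max_num_tokens_padded, topk_ids_numel, num_experts
)
# The fused kernel should leave the buffers exactly as the removed code did:
# sorted_ids holds the sentinel (topk_ids.numel()), cumsum_buffer is all zeros.
assert torch.all(sorted_ids == topk_ids_numel)
assert torch.all(cumsum_buffer == 0)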
sgl-kernel/csrc/moe/moe_align_kernel.cu
@@ -197,8 +197,6 @@ void moe_align_block_size(
     size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
     size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
 
-    cumsum_buffer.zero_();
-
     align_kernel<<<1, threads, shared_mem_size, stream>>>(
         topk_ids.data_ptr<scalar_t>(),
         sorted_token_ids.data_ptr<int32_t>(),
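The dropped cumsum_buffer.zero_() is the CUDA-side counterpart of the Python change: the buffer now reaches the op already zeroed by the Triton init kernel, so the per-call device memset would be redundant. A rough sketch of what that removed memset costs on its own (timing harness and sizes are illustrative, not a benchmark from the commit):

import torch

num_experts, iters = 256, 1000                 # illustrative values
buf = torch.empty(num_experts + 1, dtype=torch.int32, device="cuda")

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
for _ in range(iters):
    buf.zero_()                                # one tiny memset launch per call
end.record()
torch.cuda.synchronize()
print(f"zero_() memset: {start.elapsed_time(end) * 1000 / iters:.2f} us per call")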