sglang · Commit a3398d84

Optimize moe align block size kernel (#7794)

Authored Jul 07, 2025 by Ke Bao; committed by GitHub on Jul 07, 2025 (unverified signature).
Parent: ba69c153

Showing 2 changed files with 100 additions and 63 deletions (+100 -63):

  sgl-kernel/csrc/moe/moe_align_kernel.cu  +94 -63
  sgl-kernel/include/utils.h               +6 -0
sgl-kernel/csrc/moe/moe_align_kernel.cu

@@ -21,16 +21,10 @@ limitations under the License.
 #include "utils.h"
 
-template <typename T, int N, int Alignment = sizeof(T) * N>
-class alignas(Alignment) AlignedArray {
- public:
-  T data[N];
-};
-
 #define WARP_SIZE 32
 #define VEC_SIZE 4
-using Vec = AlignedArray<int32_t, VEC_SIZE>;
+using Vec = int4;
 
 template <typename scalar_t>
 __global__ void count_and_sort_expert_tokens_kernel(
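Note: `int4` is CUDA's built-in four-lane 32-bit vector type with 16-byte alignment, so it can stand in for the removed `AlignedArray<int32_t, VEC_SIZE>` without changing the layout that the vectorized fills below rely on. A minimal sketch (not part of the commit) that checks that assumption:

// Illustrative only: int4 matches the size/alignment the old AlignedArray provided.
#include <cstdint>
#include <cuda_runtime.h>  // vector types: int4, make_int4

static_assert(sizeof(int4) == 4 * sizeof(int32_t), "four 32-bit lanes");
static_assert(alignof(int4) == 16, "16-byte alignment enables single 128-bit stores");

// Splatting one value into all four lanes, as the kernels do for their fill value:
__host__ __device__ inline int4 splat4(int32_t v) {
  return make_int4(v, v, v, v);
}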
@@ -55,73 +49,119 @@ __global__ void moe_align_block_size_kernel(
     int32_t* __restrict__ expert_ids,
     int32_t* __restrict__ total_tokens_post_pad,
     int32_t num_experts,
     int32_t padded_num_experts,
     int32_t experts_per_warp,
     int32_t block_size,
     size_t numel,
     int32_t* __restrict__ cumsum,
-    bool pad_sorted_token_ids) {
-  extern __shared__ int32_t shared_counts[];
+    bool pad_sorted_token_ids,
+    const int32_t scan_size) {
+  extern __shared__ int32_t smem[];
+  int32_t* shared_counts = smem;                  // [num_experts]
+  int32_t* prefix = shared_counts + num_experts;  // [num_experts + 1]
+  int32_t* scan_buf = prefix + num_experts + 1;   // [scan_size]
+  __shared__ int32_t s_total_tokens_post_pad;
 
-  const int warp_id = threadIdx.x / WARP_SIZE;
-  const int my_expert_start = warp_id * experts_per_warp;
+  const size_t tid = threadIdx.x;
+  const size_t stride = blockDim.x;
 
-  for (int i = 0; i < experts_per_warp; ++i) {
-    if (my_expert_start + i < padded_num_experts) {
-      shared_counts[warp_id * experts_per_warp + i] = 0;
-    }
+  if (tid < num_experts) {
+    shared_counts[tid] = 0;
   }
 
   __syncthreads();
 
-  const size_t tid = threadIdx.x;
-  const size_t stride = blockDim.x;
-
   for (size_t i = tid; i < numel; i += stride) {
     int expert_id = topk_ids[i];
-    int warp_idx = expert_id / experts_per_warp;
-    int expert_offset = expert_id % experts_per_warp;
-    atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1);
+    atomicAdd(&shared_counts[expert_id], 1);
   }
 
   __syncthreads();
 
-  if (threadIdx.x == 0) {
-    cumsum[0] = 0;
-    for (int i = 1; i <= num_experts; ++i) {
-      int expert_count = 0;
-      int warp_idx = (i - 1) / experts_per_warp;
-      int expert_offset = (i - 1) % experts_per_warp;
-      expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
-
-      cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
-    }
-    *total_tokens_post_pad = cumsum[num_experts];
+  int32_t padded_count = 0;
+  if (tid < num_experts) {
+    int32_t count = shared_counts[tid];
+    padded_count = (count + block_size - 1) / block_size * block_size;
+    scan_buf[tid] = padded_count;
+  }
+  if (tid >= num_experts && tid < scan_size) {
+    scan_buf[tid] = 0;
   }
 
   __syncthreads();
 
-  if (threadIdx.x < num_experts) {
-    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
-      expert_ids[i / block_size] = threadIdx.x;
-    }
-  }
-
-  if (pad_sorted_token_ids) {
-    int32_t fill_val = static_cast<int32_t>(numel);
-    int32_t total = *total_tokens_post_pad;
-    Vec fill_vec;
-#pragma unroll
-    for (int i = 0; i < VEC_SIZE; ++i) {
-      fill_vec.data[i] = fill_val;
-    }
-    int32_t total_vec_count = (total + VEC_SIZE - 1) / VEC_SIZE;
-    Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
-    for (int32_t idx = tid; idx < total_vec_count; idx += stride) {
-      out_ptr[idx] = fill_vec;
-    }
-  }
-}
+  // Blelloch scan
+  int offset = 1;
+#pragma unroll
+  for (int d = scan_size >> 1; d > 0; d >>= 1) {
+    if (tid < d) {
+      int ai = offset * (2 * tid + 1) - 1;
+      int bi = offset * (2 * tid + 2) - 1;
+      scan_buf[bi] += scan_buf[ai];
+    }
+    offset <<= 1;
+    __syncthreads();
+  }
+
+  // down-sweep
+  if (tid == 0) {
+    prefix[num_experts] = scan_buf[scan_size - 1];
+    scan_buf[scan_size - 1] = 0;
+  }
+  __syncthreads();
+
+  for (int d = 1; d < scan_size; d <<= 1) {
+    offset >>= 1;
+    if (tid < d) {
+      int ai = offset * (2 * tid + 1) - 1;
+      int bi = offset * (2 * tid + 2) - 1;
+      if (bi < scan_size) {
+        int temp = scan_buf[ai];
+        scan_buf[ai] = scan_buf[bi];
+        scan_buf[bi] += temp;
+      }
+    }
+    __syncthreads();
+  }
+
+  if (tid < num_experts) {
+    prefix[tid] = scan_buf[tid];
+  }
+
+  if (tid == 0) {
+    s_total_tokens_post_pad = prefix[num_experts];
+    *total_tokens_post_pad = s_total_tokens_post_pad;
+  }
+  __syncthreads();
+
+  if (tid <= num_experts) {
+    cumsum[tid] = prefix[tid];
+  }
+
+  // fill expert_ids
+  const int32_t num_blocks = s_total_tokens_post_pad / block_size;
+  for (int32_t i = tid; i < num_blocks; i += stride) {
+    int32_t block_start = i * block_size;
+    int left = 0, right = num_experts;
+    while (left < right) {
+      int mid = (left + right) >> 1;
+      if (prefix[mid] <= block_start) {
+        left = mid + 1;
+      } else {
+        right = mid;
+      }
+    }
+    expert_ids[i] = left - 1;
+  }
+
+  if (pad_sorted_token_ids) {
+    Vec fill_vec;
+    fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel;
+    int32_t total_vecs = (s_total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE;
+    Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
+    for (int32_t i = tid; i < total_vecs; i += stride) {
+      out_ptr[i] = fill_vec;
+    }
+  }
+}
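For intuition, here is what the rewritten kernel computes, as a small host-side C++ sketch (illustrative, not from the commit): per-expert counts are rounded up to a multiple of `block_size`, an exclusive prefix sum over the padded counts produces `cumsum`/`prefix`, and each output block finds its owning expert by binary search over that prefix. The in-kernel Blelloch up-sweep/down-sweep is a parallel, power-of-two-sized version of the serial scan shown here.

// Host-side reference sketch of the kernel's new logic (illustrative only).
#include <cstdint>
#include <vector>

void moe_align_reference(const std::vector<int32_t>& topk_ids,
                         int32_t num_experts,
                         int32_t block_size,
                         std::vector<int32_t>& cumsum,      // size num_experts + 1
                         std::vector<int32_t>& expert_ids,  // one entry per output block
                         int32_t& total_tokens_post_pad) {
  // 1) Count tokens per expert (kernel: atomicAdd into shared_counts).
  std::vector<int32_t> counts(num_experts, 0);
  for (int32_t e : topk_ids) counts[e]++;

  // 2) Pad each count to a multiple of block_size and take an exclusive
  //    prefix sum (kernel: scan_buf plus the Blelloch up/down-sweep).
  cumsum.assign(num_experts + 1, 0);
  for (int32_t i = 0; i < num_experts; ++i) {
    int32_t padded = (counts[i] + block_size - 1) / block_size * block_size;
    cumsum[i + 1] = cumsum[i] + padded;
  }
  total_tokens_post_pad = cumsum[num_experts];

  // 3) For each output block, binary-search the expert whose padded range
  //    contains the block start (kernel: the left/right loop over prefix).
  int32_t num_blocks = total_tokens_post_pad / block_size;
  expert_ids.assign(num_blocks, 0);
  for (int32_t b = 0; b < num_blocks; ++b) {
    int32_t block_start = b * block_size;
    int32_t left = 0, right = num_experts;
    while (left < right) {
      int32_t mid = (left + right) >> 1;
      if (cumsum[mid] <= block_start) left = mid + 1; else right = mid;
    }
    expert_ids[b] = left - 1;
  }
}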
@@ -179,20 +219,12 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
   }
 
   if (pad_sorted_token_ids) {
-    int32_t fill_val = static_cast<int32_t>(numel);
-    int32_t total = *total_tokens_post_pad;
     Vec fill_vec;
-#pragma unroll
-    for (int i = 0; i < VEC_SIZE; ++i) {
-      fill_vec.data[i] = fill_val;
-    }
-    int32_t total_vec_count = (total + VEC_SIZE - 1) / VEC_SIZE;
+    fill_vec.x = fill_vec.y = fill_vec.z = fill_vec.w = numel;
+    int32_t total_vecs = (*total_tokens_post_pad + VEC_SIZE - 1) / VEC_SIZE;
     Vec* out_ptr = reinterpret_cast<Vec*>(sorted_token_ids);
-    for (int32_t idx = tid; idx < total_vec_count; idx += stride) {
-      out_ptr[idx] = fill_vec;
+    for (int32_t i = tid; i < total_vecs; i += stride) {
+      out_ptr[i] = fill_vec;
     }
   }
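Both kernels now pad `sorted_token_ids` the same way: splat `numel` into an `int4` and write it with a grid-stride loop of 16-byte stores. A standalone sketch of that fill pattern (illustrative; it assumes, as the rounded-up vector count implies, that the output buffer has room for a whole number of `int4` stores):

// Illustrative grid-stride vectorized fill, mirroring the pad path above.
// Assumes out has capacity for a multiple of 4 int32_t elements.
__global__ void fill_with_value_vec4(int32_t* out, int32_t n, int32_t fill_val) {
  int4 fill_vec = make_int4(fill_val, fill_val, fill_val, fill_val);
  int4* out4 = reinterpret_cast<int4*>(out);
  int32_t total_vecs = (n + 3) / 4;  // round up to whole int4 stores
  for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < total_vecs;
       i += gridDim.x * blockDim.x) {
    out4[i] = fill_vec;
  }
}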
@@ -245,8 +277,8 @@ void moe_align_block_size(
   } else {
     auto align_kernel = moe_align_block_size_kernel<scalar_t>;
 
-    size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
-    size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
+    const size_t scan_size = next_pow2(num_experts);
+    const size_t shared_mem_size = (num_experts + (num_experts + 1) + scan_size) * sizeof(int32_t);
 
     align_kernel<<<1, threads, shared_mem_size, stream>>>(
         topk_ids.data_ptr<scalar_t>(),
@@ -254,12 +286,11 @@ void moe_align_block_size(
         experts_ids.data_ptr<int32_t>(),
         num_tokens_post_pad.data_ptr<int32_t>(),
         num_experts,
         padded_num_experts,
         experts_per_warp,
         block_size,
         topk_ids.numel(),
         cumsum_buffer.data_ptr<int32_t>(),
-        pad_sorted_token_ids);
+        pad_sorted_token_ids,
+        scan_size);
 
     const int block_threads = std::min(256, (int)threads);
     const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
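The dynamic shared-memory request now has to cover the three arrays the kernel carves out of `smem`: `num_experts` counters, `num_experts + 1` prefix entries, and a power-of-two `scan_size` scan buffer. A quick host-side sketch of that arithmetic (the expert count is only an example value, not taken from the commit):

// Example only: shared-memory budget for the single-block align kernel.
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t num_experts = 256;  // example value
  const uint32_t scan_size = 256;    // next_pow2(num_experts)
  const size_t shared_mem_size =
      (num_experts + (num_experts + 1) + scan_size) * sizeof(int32_t);
  // 256 + 257 + 256 = 769 ints, i.e. 3076 bytes, well under the 48 KiB default limit.
  std::printf("shared_mem_size = %zu bytes\n", shared_mem_size);
  return 0;
}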
sgl-kernel/include/utils.h

@@ -363,3 +363,9 @@ inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment =
   }
   return tensor_padded;
 }
+
+// Get the next power of 2 of a number
+inline uint32_t next_pow2(uint32_t x) noexcept {
+  if (x <= 1) return 1;
+  return 1u << (32 - __builtin_clz(x - 1));
+}
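A few spot checks of `next_pow2` (illustrative; the helper is copied into the snippet so it compiles standalone, and it relies on `__builtin_clz`, i.e. GCC/Clang host builds):

// Illustrative spot checks for the next_pow2 helper; the real definition
// lives in sgl-kernel/include/utils.h.
#include <cassert>
#include <cstdint>

inline uint32_t next_pow2(uint32_t x) noexcept {
  if (x <= 1) return 1;
  return 1u << (32 - __builtin_clz(x - 1));
}

int main() {
  assert(next_pow2(0) == 1);     // x <= 1 maps to 1
  assert(next_pow2(1) == 1);
  assert(next_pow2(2) == 2);
  assert(next_pow2(60) == 64);   // e.g. 60 experts -> a 64-slot scan buffer
  assert(next_pow2(64) == 64);   // exact powers of two are unchanged
  assert(next_pow2(257) == 512);
  return 0;
}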