change / sglang / Commits / af1cc8fe
"git@developer.sourcefind.cn:change/sglang.git" did not exist on "f1cf6eefbec615bfec1c026f29c0f5bb06f00ba6"
Unverified commit af1cc8fe, authored Jul 17, 2025 by Yuan Luo, committed by GitHub on Jul 17, 2025
[kernel] opt moe align block kernel by block/warp scan algorithm (#7884)
parent 49b87774
Showing 1 changed file with 51 additions and 42 deletions.

sgl-kernel/csrc/moe/moe_align_kernel.cu (+51, -42)
@@ -26,6 +26,12 @@ limitations under the License.
 #define VEC_SIZE 4
 using Vec = int4;
 
+#ifndef __CUDA_ARCH__  // HIP
+#define SHFL_UP(mask, val, delta) __shfl_up((val), (delta))
+#else  // CUDA
+#define SHFL_UP(mask, val, delta) __shfl_up_sync((mask), (val), (delta))
+#endif
+
 template <typename scalar_t>
 __global__ void count_and_sort_expert_tokens_kernel(
     const scalar_t* __restrict__ topk_ids,
@@ -42,6 +48,16 @@ __global__ void count_and_sort_expert_tokens_kernel(
   }
 }
 
+__device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffffffffu) {
+  int original = v;
+#pragma unroll
+  for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
+    int n = SHFL_UP(mask, v, offset);
+    if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
+  }
+  return v - original;
+}
+
 template <typename scalar_t>
 __global__ void moe_align_block_size_kernel(
     const scalar_t* __restrict__ topk_ids,
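For readers unfamiliar with the pattern, the helper added above is a Hillis-Steele style warp scan: each lane shuffles in the value from `offset` lanes below and accumulates, then subtracts its own original value to turn the inclusive result into an exclusive one. Below is a minimal standalone sketch of that behavior; it is not part of the commit, assumes 32-lane warps, and redefines the macro and helper locally so it compiles on its own.

// sketch_warp_scan.cu -- illustrative only, mirrors the helper added above
#include <cstdio>

#define WARP_SIZE 32
#define SHFL_UP(mask, val, delta) __shfl_up_sync((mask), (val), (delta))

__device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffffffffu) {
  int original = v;
#pragma unroll
  for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
    int n = SHFL_UP(mask, v, offset);
    if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;  // lanes below `offset` keep their value
  }
  return v - original;  // inclusive sum minus own contribution = exclusive sum
}

__global__ void scan_demo(const int* in, int* out) {
  out[threadIdx.x] = warp_exclusive_scan(in[threadIdx.x]);
}

int main() {
  int h_in[WARP_SIZE], h_out[WARP_SIZE];
  for (int i = 0; i < WARP_SIZE; ++i) h_in[i] = i + 1;  // 1, 2, 3, ...

  int *d_in, *d_out;
  cudaMalloc((void**)&d_in, sizeof(h_in));
  cudaMalloc((void**)&d_out, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  scan_demo<<<1, WARP_SIZE>>>(d_in, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

  for (int i = 0; i < WARP_SIZE; ++i) printf("%d ", h_out[i]);  // expected: 0 1 3 6 10 ...
  printf("\n");
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}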
@@ -58,6 +74,7 @@ __global__ void moe_align_block_size_kernel(
   int32_t* shared_counts = smem;                  // [num_experts]
   int32_t* prefix = shared_counts + num_experts;  // [num_experts + 1]
   int32_t* scan_buf = prefix + num_experts + 1;   // [scan_size]
+  int32_t* warp_sums = scan_buf + scan_size;      // [<= 32]
   __shared__ int32_t s_total_tokens_post_pad;
 
   const size_t tid = threadIdx.x;
@@ -76,6 +93,7 @@ __global__ void moe_align_block_size_kernel(
   __syncthreads();
 
+  // Calculate padded_cnt, write scan_buf, directly prefix sum
   int32_t padded_count = 0;
   if (tid < num_experts) {
     int32_t count = shared_counts[tid];
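As context for the `padded_count` set up in the hunk above and consumed in the next one: each expert's token count is rounded up to a multiple of the block size before the prefix sum (the exact expression sits in the collapsed lines of this hunk). A tiny host-side illustration with made-up numbers:

// Illustrative only: the usual round-up-to-block_size padding that padded_count holds.
#include <cstdio>

int main() {
  const int block_size = 16;
  const int counts[] = {5, 0, 33, 16};  // tokens routed to four hypothetical experts
  for (int c : counts) {
    int padded = (c + block_size - 1) / block_size * block_size;
    printf("count=%2d -> padded_count=%2d\n", c, padded);  // 16, 0, 48, 16
  }
  return 0;
}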
@@ -83,58 +101,52 @@ __global__ void moe_align_block_size_kernel(
     scan_buf[tid] = padded_count;
   }
 
-  if (tid >= num_experts && tid < scan_size) {
-    scan_buf[tid] = 0;
-  }
-
-  __syncthreads();
-
-  // Blelloch scan
-  int offset = 1;
-#pragma unroll
-  for (int d = scan_size >> 1; d > 0; d >>= 1) {
-    if (tid < d) {
-      int ai = offset * (2 * tid + 1) - 1;
-      int bi = offset * (2 * tid + 2) - 1;
-      scan_buf[bi] += scan_buf[ai];
-    }
-    offset <<= 1;
-    __syncthreads();
-  }
-
-  // down-sweep
-  if (tid == 0) {
-    prefix[num_experts] = scan_buf[scan_size - 1];
-    scan_buf[scan_size - 1] = 0;
-  }
-  __syncthreads();
-
-#pragma unroll
-  for (int d = 1; d < scan_size; d <<= 1) {
-    offset >>= 1;
-    __syncthreads();
-    if (tid < d) {
-      int ai = offset * (2 * tid + 1) - 1;
-      int bi = offset * (2 * tid + 2) - 1;
-      if (bi < scan_size) {
-        int temp = scan_buf[ai];
-        scan_buf[ai] = scan_buf[bi];
-        scan_buf[bi] += temp;
-      }
-    }
-    __syncthreads();
-  }
-
-  if (tid < num_experts) {
-    prefix[tid] = scan_buf[tid];
-  }
-  __syncthreads();
-
-  if (tid == 0) {
-    s_total_tokens_post_pad = prefix[num_experts];
-    *total_tokens_post_pad = s_total_tokens_post_pad;
-  }
+  // Intra warp prefix sum
+  const int warp_id = tid / WARP_SIZE;
+  const int lane_id = tid & (WARP_SIZE - 1);
+  const int num_warps_for_scan = (scan_size + WARP_SIZE - 1) / WARP_SIZE;
+  const int warp_sum = warp_exclusive_scan(padded_count) + padded_count;
+  if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = warp_sum;
+  __syncthreads();
+
+  // warp0 accumulate all the block's prefix sum
+  if (tid < WARP_SIZE) {
+    int val = (tid < num_warps_for_scan) ? warp_sums[tid] : 0;
+    int incl = warp_exclusive_scan(val) + val;
+    warp_sums[tid] = incl;
+  }
+  __syncthreads();
+
+  // Every thread obtains the whole block's sum
+  if (tid == 0) {
+    prefix[num_experts] = warp_sums[num_warps_for_scan - 1];
+    s_total_tokens_post_pad = prefix[num_experts];
+    *total_tokens_post_pad = s_total_tokens_post_pad;
+  }
+  __syncthreads();
+
+  // Fill 0 to scan_buf extended area (tid >= num_expert)
+  if (tid >= num_experts && tid < scan_size) scan_buf[tid] = 0;
+  __syncthreads();
+
+  // Perform 2 level exclusive-prefix-sum to scan_buf
+  int v = (tid < scan_size) ? scan_buf[tid] : 0;
+  int pre = warp_exclusive_scan(v);
+  if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = pre + v;
+  __syncthreads();
+
+  if (warp_id == 0) {
+    int val = (lane_id < num_warps_for_scan) ? warp_sums[lane_id] : 0;
+    warp_sums[lane_id] = warp_exclusive_scan(val);
+  }
+  __syncthreads();
+
+  int offset = warp_sums[warp_id];
+  if (tid < scan_size) scan_buf[tid] = pre + offset;
   __syncthreads();
 
+  // Write prefix[0..num_experts - 1] and cumsum
+  if (tid < num_experts) prefix[tid] = scan_buf[tid];
+
   if (tid <= num_experts) {
     cumsum[tid] = prefix[tid];
   }
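The hunk above replaces the shared-memory Blelloch up/down-sweep with a two-level scan: each warp scans its own values with `warp_exclusive_scan`, per-warp totals go to `warp_sums`, warp 0 scans those totals, and every thread adds its warp's base offset back onto its intra-warp result. The sketch below shows that structure in isolation; it is a simplified standalone illustration under assumed sizes, not the kernel itself, and it assumes a single block whose size is a multiple of 32 and at most 1024.

// sketch_block_scan.cu -- two-level (warp + block) exclusive prefix sum, illustrative only
#include <cstdio>

#define WARP_SIZE 32
#define SHFL_UP(mask, val, delta) __shfl_up_sync((mask), (val), (delta))

__device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffffffffu) {
  int original = v;
#pragma unroll
  for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
    int n = SHFL_UP(mask, v, offset);
    if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
  }
  return v - original;
}

__global__ void block_exclusive_scan(const int* in, int* out, int n) {
  __shared__ int warp_sums[WARP_SIZE];  // one partial sum per warp (<= 32 warps)

  const int tid = threadIdx.x;
  const int warp_id = tid / WARP_SIZE;
  const int lane_id = tid & (WARP_SIZE - 1);

  // Level 1: exclusive scan within each warp; the last lane records the warp total.
  int v = (tid < n) ? in[tid] : 0;
  int pre = warp_exclusive_scan(v);
  if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = pre + v;
  __syncthreads();

  // Level 2: warp 0 scans the per-warp totals to get each warp's base offset.
  if (warp_id == 0) {
    int val = (lane_id < blockDim.x / WARP_SIZE) ? warp_sums[lane_id] : 0;
    warp_sums[lane_id] = warp_exclusive_scan(val);
  }
  __syncthreads();

  // Combine: warp base offset + intra-warp exclusive sum.
  if (tid < n) out[tid] = warp_sums[warp_id] + pre;
}

int main() {
  const int n = 256;
  int h_in[n], h_out[n];
  for (int i = 0; i < n; ++i) h_in[i] = 1;  // all ones -> exclusive scan is 0 .. n-1

  int *d_in, *d_out;
  cudaMalloc((void**)&d_in, n * sizeof(int));
  cudaMalloc((void**)&d_out, n * sizeof(int));
  cudaMemcpy(d_in, h_in, n * sizeof(int), cudaMemcpyHostToDevice);
  block_exclusive_scan<<<1, n>>>(d_in, d_out, n);
  cudaMemcpy(h_out, d_out, n * sizeof(int), cudaMemcpyDeviceToHost);

  printf("out[0]=%d out[1]=%d out[255]=%d\n", h_out[0], h_out[1], h_out[n - 1]);  // 0 1 255
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}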
@@ -250,9 +262,6 @@ void moe_align_block_size(
     bool pad_sorted_token_ids) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-  int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
-  int experts_per_warp = WARP_SIZE;
   int threads = 1024;
   threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
@@ -278,7 +287,7 @@ void moe_align_block_size(
     auto align_kernel = moe_align_block_size_kernel<scalar_t>;
     const size_t scan_size = next_pow2(num_experts);
-    const size_t shared_mem_size = (num_experts + (num_experts + 1) + scan_size) * sizeof(int32_t);
+    const size_t shared_mem_size = (num_experts + (num_experts + 1) + scan_size + WARP_SIZE) * sizeof(int32_t);
     align_kernel<<<1, threads, shared_mem_size, stream>>>(
         topk_ids.data_ptr<scalar_t>(),
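The only functional change in this last hunk is the extra WARP_SIZE slots reserved for warp_sums in the dynamic shared memory. A small host-side sketch of how that budget breaks down; the next_pow2 here is a local stand-in for the file's helper, and the expert count is just an example.

// Illustrative breakdown of the kernel's dynamic shared memory:
// shared_counts[num_experts] + prefix[num_experts + 1] + scan_buf[scan_size]
// + warp_sums[WARP_SIZE], all int32_t. WARP_SIZE and num_experts are example values.
#include <cstdio>
#include <cstdint>
#include <cstddef>

static size_t next_pow2(size_t x) {  // local stand-in for the helper used in the file
  size_t p = 1;
  while (p < x) p <<= 1;
  return p;
}

int main() {
  const size_t WARP_SIZE = 32;
  const size_t num_experts = 256;                   // example expert count
  const size_t scan_size = next_pow2(num_experts);  // 256
  const size_t shared_mem_size =
      (num_experts + (num_experts + 1) + scan_size + WARP_SIZE) * sizeof(int32_t);
  printf("scan_size=%zu shared_mem_size=%zu bytes\n", scan_size, shared_mem_size);  // 256, 3204
  return 0;
}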