Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
a3398d84
"docs/en/vscode:/vscode.git/clone" did not exist on "e55dbdcd3990b01de0ff1e266afa6e95598c0c4c"
Unverified
Commit
a3398d84
authored
Jul 07, 2025
by
Ke Bao
Committed by
GitHub
Jul 07, 2025
Browse files
Optimize moe align block size kernel (#7794)
parent
ba69c153
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
100 additions
and
63 deletions
+100
-63
sgl-kernel/csrc/moe/moe_align_kernel.cu
sgl-kernel/csrc/moe/moe_align_kernel.cu
+94
-63
sgl-kernel/include/utils.h
sgl-kernel/include/utils.h
+6
-0
No files found.
sgl-kernel/csrc/moe/moe_align_kernel.cu
View file @
a3398d84
...
@@ -21,16 +21,10 @@ limitations under the License.
...
@@ -21,16 +21,10 @@ limitations under the License.
#include "utils.h"
#include "utils.h"
// A fixed-size array of N elements of T whose storage is over-aligned
// (default: to the full byte size of the array). The extra alignment lets
// a whole AlignedArray be moved with a single wide, vectorized load/store
// instead of N scalar accesses.
template <typename T, int N, int Alignment = sizeof(T) * N>
class alignas(Alignment) AlignedArray {
 public:
  T data[N];  // the raw element storage; plain aggregate, no invariants
};
#define WARP_SIZE 32
#define WARP_SIZE 32
#define VEC_SIZE 4
#define VEC_SIZE 4
using
Vec
=
AlignedArray
<
int32_t
,
VEC_SIZE
>
;
using
Vec
=
int4
;
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
void
count_and_sort_expert_tokens_kernel
(
__global__
void
count_and_sort_expert_tokens_kernel
(
...
@@ -55,73 +49,119 @@ __global__ void moe_align_block_size_kernel(
...
@@ -55,73 +49,119 @@ __global__ void moe_align_block_size_kernel(
int32_t
*
__restrict__
expert_ids
,
int32_t
*
__restrict__
expert_ids
,
int32_t
*
__restrict__
total_tokens_post_pad
,
int32_t
*
__restrict__
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
num_experts
,
int32_t
padded_num_experts
,
int32_t
experts_per_warp
,
int32_t
block_size
,
int32_t
block_size
,
size_t
numel
,
size_t
numel
,
int32_t
*
__restrict__
cumsum
,
int32_t
*
__restrict__
cumsum
,
bool
pad_sorted_token_ids
)
{
bool
pad_sorted_token_ids
,
extern
__shared__
int32_t
shared_counts
[];
const
int32_t
scan_size
)
{
extern
__shared__
int32_t
smem
[];
int32_t
*
shared_counts
=
smem
;
// [num_experts]
int32_t
*
prefix
=
shared_counts
+
num_experts
;
// [num_experts + 1]
int32_t
*
scan_buf
=
prefix
+
num_experts
+
1
;
// [scan_size]
__shared__
int32_t
s_total_tokens_post_pad
;
const
int
warp_
id
=
threadIdx
.
x
/
WARP_SIZE
;
const
size_t
t
id
=
threadIdx
.
x
;
const
int
my_expert_start
=
warp_id
*
experts_per_warp
;
const
size_t
stride
=
blockDim
.
x
;
for
(
int
i
=
0
;
i
<
experts_per_warp
;
++
i
)
{
if
(
tid
<
num_experts
)
{
if
(
my_expert_start
+
i
<
padded_num_experts
)
{
shared_counts
[
tid
]
=
0
;
shared_counts
[
warp_id
*
experts_per_warp
+
i
]
=
0
;
}
}
}
__syncthreads
();
__syncthreads
();
const
size_t
tid
=
threadIdx
.
x
;
const
size_t
stride
=
blockDim
.
x
;
for
(
size_t
i
=
tid
;
i
<
numel
;
i
+=
stride
)
{
for
(
size_t
i
=
tid
;
i
<
numel
;
i
+=
stride
)
{
int
expert_id
=
topk_ids
[
i
];
int
expert_id
=
topk_ids
[
i
];
int
warp_idx
=
expert_id
/
experts_per_warp
;
atomicAdd
(
&
shared_counts
[
expert_id
],
1
);
int
expert_offset
=
expert_id
%
experts_per_warp
;
atomicAdd
(
&
shared_counts
[
warp_idx
*
experts_per_warp
+
expert_offset
],
1
);
}
}
__syncthreads
();
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
int32_t
padded_count
=
0
;
cumsum
[
0
]
=
0
;
if
(
tid
<
num_experts
)
{
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
{
int32_t
count
=
shared_counts
[
tid
];
int
expert_count
=
0
;
padded_count
=
(
count
+
block_size
-
1
)
/
block_size
*
block_size
;
int
warp_idx
=
(
i
-
1
)
/
experts_per_warp
;
scan_buf
[
tid
]
=
padded_count
;
int
expert_offset
=
(
i
-
1
)
%
experts_per_warp
;
expert_count
=
shared_counts
[
warp_idx
*
experts_per_warp
+
expert_offset
];
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
CEILDIV
(
expert_count
,
block_size
)
*
block_size
;
}
}
*
total_tokens_post_pad
=
cumsum
[
num_experts
];
if
(
tid
>=
num_experts
&&
tid
<
scan_size
)
{
scan_buf
[
tid
]
=
0
;
}
}
__syncthreads
();
__syncthreads
();
if
(
threadIdx
.
x
<
num_experts
)
{
// Blelloch scan
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
i
+=
block_size
)
{
int
offset
=
1
;
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
#pragma unroll
for
(
int
d
=
scan_size
>>
1
;
d
>
0
;
d
>>=
1
)
{
if
(
tid
<
d
)
{
int
ai
=
offset
*
(
2
*
tid
+
1
)
-
1
;
int
bi
=
offset
*
(
2
*
tid
+
2
)
-
1
;
scan_buf
[
bi
]
+=
scan_buf
[
ai
];
}
}
offset
<<=
1
;
__syncthreads
();
}
}
if
(
pad_sorted_token_ids
)
{
// down-sweep
int32_t
fill_val
=
static_cast
<
int32_t
>
(
numel
);
if
(
tid
==
0
)
{
int32_t
total
=
*
total_tokens_post_pad
;
prefix
[
num_experts
]
=
scan_buf
[
scan_size
-
1
];
scan_buf
[
scan_size
-
1
]
=
0
;
}
__syncthreads
();
Vec
fill_vec
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
for
(
int
d
=
1
;
d
<
scan_size
;
d
<<=
1
)
{
fill_vec
.
data
[
i
]
=
fill_val
;
offset
>>=
1
;
if
(
tid
<
d
)
{
int
ai
=
offset
*
(
2
*
tid
+
1
)
-
1
;
int
bi
=
offset
*
(
2
*
tid
+
2
)
-
1
;
if
(
bi
<
scan_size
)
{
int
temp
=
scan_buf
[
ai
];
scan_buf
[
ai
]
=
scan_buf
[
bi
];
scan_buf
[
bi
]
+=
temp
;
}
}
__syncthreads
();
}
}
int32_t
total_vec_count
=
(
total
+
VEC_SIZE
-
1
)
/
VEC_SIZE
;
if
(
tid
<
num_experts
)
{
Vec
*
out_ptr
=
reinterpret_cast
<
Vec
*>
(
sorted_token_ids
);
prefix
[
tid
]
=
scan_buf
[
tid
];
}
if
(
tid
==
0
)
{
s_total_tokens_post_pad
=
prefix
[
num_experts
];
*
total_tokens_post_pad
=
s_total_tokens_post_pad
;
}
__syncthreads
();
if
(
tid
<=
num_experts
)
{
cumsum
[
tid
]
=
prefix
[
tid
];
}
for
(
int32_t
idx
=
tid
;
idx
<
total_vec_count
;
idx
+=
stride
)
{
// fill expert_ids
out_ptr
[
idx
]
=
fill_vec
;
const
int32_t
num_blocks
=
s_total_tokens_post_pad
/
block_size
;
for
(
int32_t
i
=
tid
;
i
<
num_blocks
;
i
+=
stride
)
{
int32_t
block_start
=
i
*
block_size
;
int
left
=
0
,
right
=
num_experts
;
while
(
left
<
right
)
{
int
mid
=
(
left
+
right
)
>>
1
;
if
(
prefix
[
mid
]
<=
block_start
)
{
left
=
mid
+
1
;
}
else
{
right
=
mid
;
}
}
expert_ids
[
i
]
=
left
-
1
;
}
if
(
pad_sorted_token_ids
)
{
Vec
fill_vec
;
fill_vec
.
x
=
fill_vec
.
y
=
fill_vec
.
z
=
fill_vec
.
w
=
numel
;
int32_t
total_vecs
=
(
s_total_tokens_post_pad
+
VEC_SIZE
-
1
)
/
VEC_SIZE
;
Vec
*
out_ptr
=
reinterpret_cast
<
Vec
*>
(
sorted_token_ids
);
for
(
int32_t
i
=
tid
;
i
<
total_vecs
;
i
+=
stride
)
{
out_ptr
[
i
]
=
fill_vec
;
}
}
}
}
}
}
...
@@ -179,20 +219,12 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
...
@@ -179,20 +219,12 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
}
}
if
(
pad_sorted_token_ids
)
{
if
(
pad_sorted_token_ids
)
{
int32_t
fill_val
=
static_cast
<
int32_t
>
(
numel
);
int32_t
total
=
*
total_tokens_post_pad
;
Vec
fill_vec
;
Vec
fill_vec
;
#pragma unroll
fill_vec
.
x
=
fill_vec
.
y
=
fill_vec
.
z
=
fill_vec
.
w
=
numel
;
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
int32_t
total_vecs
=
(
*
total_tokens_post_pad
+
VEC_SIZE
-
1
)
/
VEC_SIZE
;
fill_vec
.
data
[
i
]
=
fill_val
;
}
int32_t
total_vec_count
=
(
total
+
VEC_SIZE
-
1
)
/
VEC_SIZE
;
Vec
*
out_ptr
=
reinterpret_cast
<
Vec
*>
(
sorted_token_ids
);
Vec
*
out_ptr
=
reinterpret_cast
<
Vec
*>
(
sorted_token_ids
);
for
(
int32_t
i
=
tid
;
i
<
total_vecs
;
i
+=
stride
)
{
for
(
int32_t
idx
=
tid
;
idx
<
total_vec_count
;
idx
+=
stride
)
{
out_ptr
[
i
]
=
fill_vec
;
out_ptr
[
idx
]
=
fill_vec
;
}
}
}
}
...
@@ -245,8 +277,8 @@ void moe_align_block_size(
...
@@ -245,8 +277,8 @@ void moe_align_block_size(
}
else
{
}
else
{
auto
align_kernel
=
moe_align_block_size_kernel
<
scalar_t
>
;
auto
align_kernel
=
moe_align_block_size_kernel
<
scalar_t
>
;
size_t
num_warps
=
CEILDIV
(
padded_num_experts
,
experts_per_warp
);
const
size_t
scan_size
=
next_pow2
(
num_experts
);
size_t
shared_mem_size
=
num_
warps
*
experts_per_warp
*
sizeof
(
int32_t
);
const
size_t
shared_mem_size
=
(
num_
experts
+
(
num_experts
+
1
)
+
scan_size
)
*
sizeof
(
int32_t
);
align_kernel
<<<
1
,
threads
,
shared_mem_size
,
stream
>>>
(
align_kernel
<<<
1
,
threads
,
shared_mem_size
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
topk_ids
.
data_ptr
<
scalar_t
>
(),
...
@@ -254,12 +286,11 @@ void moe_align_block_size(
...
@@ -254,12 +286,11 @@ void moe_align_block_size(
experts_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
num_experts
,
padded_num_experts
,
experts_per_warp
,
block_size
,
block_size
,
topk_ids
.
numel
(),
topk_ids
.
numel
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
(),
cumsum_buffer
.
data_ptr
<
int32_t
>
(),
pad_sorted_token_ids
);
pad_sorted_token_ids
,
scan_size
);
const
int
block_threads
=
std
::
min
(
256
,
(
int
)
threads
);
const
int
block_threads
=
std
::
min
(
256
,
(
int
)
threads
);
const
int
num_blocks
=
(
topk_ids
.
numel
()
+
block_threads
-
1
)
/
block_threads
;
const
int
num_blocks
=
(
topk_ids
.
numel
()
+
block_threads
-
1
)
/
block_threads
;
...
...
sgl-kernel/include/utils.h
View file @
a3398d84
...
@@ -363,3 +363,9 @@ inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment =
...
@@ -363,3 +363,9 @@ inline torch::Tensor pad_tensor(const torch::Tensor& tensor, int64_t alignment =
}
}
return
tensor_padded
;
return
tensor_padded
;
}
}
// Get the next power of 2 of a number.
// Returns the smallest power of two >= x, with next_pow2(0) == next_pow2(1) == 1.
// For x > 2^31 the next power of two is not representable in uint32_t; the
// previous code evaluated `1u << 32` in that case, which is undefined
// behavior. We now return 0 to signal overflow instead (same convention as
// wrap-around; all defined inputs are unaffected).
inline uint32_t next_pow2(uint32_t x) noexcept {
  if (x <= 1) return 1;
  // Number of bits needed to represent x - 1; the next power of two is 1 << bits.
  const int bits = 32 - __builtin_clz(x - 1);
  if (bits >= 32) return 0;  // x > 2^31: result would overflow uint32_t
  return 1u << bits;
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment