Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c4e744db
Unverified
Commit
c4e744db
authored
Jan 28, 2026
by
Wentao Ye
Committed by
GitHub
Jan 28, 2026
Browse files
[Perf] Optimize `moe_permute` for CUTLASS FP8 (#32892)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
8ebf372e
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
47 additions
and
44 deletions
+47
-44
csrc/moe/moe_permute_unpermute_op.cu
csrc/moe/moe_permute_unpermute_op.cu
+24
-9
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
.../permute_unpermute_kernels/moe_permute_unpermute_kernel.h
+2
-1
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
...ermute_unpermute_kernels/moe_permute_unpermute_kernel.inl
+21
-34
No files found.
csrc/moe/moe_permute_unpermute_op.cu
View file @
c4e744db
...
@@ -73,25 +73,40 @@ void moe_permute(
...
@@ -73,25 +73,40 @@ void moe_permute(
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_token
,
n_expert
,
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_token
,
n_expert
,
n_local_expert
,
topk
,
sorter
,
get_ptr
<
int
>
(
sort_workspace
),
stream
);
n_local_expert
,
topk
,
sorter
,
get_ptr
<
int
>
(
sort_workspace
),
stream
);
// DeepGEMM: use getMIndices kernel to compute
// 1) align_expert_first_token_offset (aligned prefix offsets)
// 2) m_indices (expert id for each aligned row)
// eg. expert0: 3, expert1: 5, expert2: 2 tokens respectively
// expert_first_token_offset = [0, 3, 8, 10], align_block_size = 4
// expert0: 3->4, expert1: 5->8, expert2: 2->4
// align_expert_first_token_offset = [0, 4, 12, 16]
// so m_indices = [0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2]
torch
::
Tensor
align_expert_first_token_offset
;
const
int64_t
*
aligned_expert_first_token_offset_ptr
=
nullptr
;
if
(
align_block_size
.
has_value
())
{
align_expert_first_token_offset
=
torch
::
zeros_like
(
expert_first_token_offset
);
getMIndices
(
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
get_ptr
<
int64_t
>
(
align_expert_first_token_offset
),
get_ptr
<
int
>
(
m_indices
),
n_local_expert
,
align_block_size_value
,
stream
);
aligned_expert_first_token_offset_ptr
=
get_ptr
<
int64_t
>
(
align_expert_first_token_offset
);
}
// dispatch expandInputRowsKernelLauncher
// dispatch expandInputRowsKernelLauncher
MOE_DISPATCH
(
input
.
scalar_type
(),
[
&
]
{
MOE_DISPATCH
(
input
.
scalar_type
(),
[
&
]
{
expandInputRowsKernelLauncher
<
scalar_t
>
(
expandInputRowsKernelLauncher
<
scalar_t
>
(
get_ptr
<
scalar_t
>
(
input
),
get_ptr
<
scalar_t
>
(
permuted_input
),
get_ptr
<
scalar_t
>
(
input
),
get_ptr
<
scalar_t
>
(
permuted_input
),
get_ptr
<
int
>
(
permuted_experts_id
),
get_ptr
<
int
>
(
sorted_row_idx
),
get_ptr
<
int
>
(
permuted_experts_id
),
get_ptr
<
int
>
(
sorted_row_idx
),
get_ptr
<
int
>
(
inv_permuted_idx
),
get_ptr
<
int
>
(
permuted_idx
),
get_ptr
<
int
>
(
inv_permuted_idx
),
get_ptr
<
int
>
(
permuted_idx
),
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_token
,
valid_num_ptr
,
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_hidden
,
topk
,
n_local_expert
,
align_block_size_value
,
stream
);
aligned_expert_first_token_offset_ptr
,
n_token
,
valid_num_ptr
,
n_hidden
,
topk
,
n_local_expert
,
align_block_size_value
,
stream
);
});
});
// get m_indices and update expert_first_token_offset with align block
// this is only required for DeepGemm and not required for CUTLASS group gemm
// this is only required for DeepGemm and not required for CUTLASS group gemm
if
(
align_block_size
.
has_value
())
{
if
(
align_block_size
.
has_value
())
{
auto
align_expert_first_token_offset
=
torch
::
zeros_like
(
expert_first_token_offset
);
getMIndices
(
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
get_ptr
<
int64_t
>
(
align_expert_first_token_offset
),
get_ptr
<
int
>
(
m_indices
),
n_local_expert
,
align_block_size_value
,
stream
);
expert_first_token_offset
.
copy_
(
align_expert_first_token_offset
);
expert_first_token_offset
.
copy_
(
align_expert_first_token_offset
);
}
}
}
}
...
...
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
View file @
c4e744db
...
@@ -60,7 +60,8 @@ void expandInputRowsKernelLauncher(
...
@@ -60,7 +60,8 @@ void expandInputRowsKernelLauncher(
T
const
*
unpermuted_input
,
T
*
permuted_output
,
int
*
sorted_experts
,
T
const
*
unpermuted_input
,
T
*
permuted_output
,
int
*
sorted_experts
,
int
const
*
expanded_dest_row_to_expanded_source_row
,
int
const
*
expanded_dest_row_to_expanded_source_row
,
int
*
expanded_source_row_to_expanded_dest_row
,
int
*
permuted_idx
,
int
*
expanded_source_row_to_expanded_dest_row
,
int
*
permuted_idx
,
int64_t
*
expert_first_token_offset
,
int64_t
const
num_rows
,
int64_t
const
*
expert_first_token_offset
,
int64_t
const
*
aligned_expert_first_token_offset
,
int64_t
const
num_rows
,
int64_t
const
*
num_valid_tokens_ptr
,
int64_t
const
cols
,
int
const
k
,
int64_t
const
*
num_valid_tokens_ptr
,
int64_t
const
cols
,
int
const
k
,
int
num_local_experts
,
const
int
&
align_block_size
,
cudaStream_t
stream
);
int
num_local_experts
,
const
int
&
align_block_size
,
cudaStream_t
stream
);
...
...
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
View file @
c4e744db
...
@@ -5,7 +5,8 @@ __global__ void expandInputRowsKernel(
...
@@ -5,7 +5,8 @@ __global__ void expandInputRowsKernel(
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* expert_first_token_offset,
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
int num_local_experts, int align_block_size) {
int num_local_experts, int align_block_size) {
// Reverse permutation map.
// Reverse permutation map.
...
@@ -18,35 +19,22 @@ __global__ void expandInputRowsKernel(
...
@@ -18,35 +19,22 @@ __global__ void expandInputRowsKernel(
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
int expert_id = sorted_experts[expanded_dest_row];
int expert_id = sorted_experts[expanded_dest_row];
extern __shared__ int64_t smem_expert_first_token_offset[];
if constexpr (ALIGN_BLOCK_SIZE) {
if constexpr (ALIGN_BLOCK_SIZE) {
// load g2s
// convert (unaligned) expanded_dest_row -> aligned expanded_dest_row.
for (int idx = threadIdx.x; idx < num_local_experts + 1;
// aligned_expert_first_token_offset[e] provides the aligned prefix start
idx += blockDim.x) {
// for expert e. For non-local experts we map to the end (total aligned M).
smem_expert_first_token_offset[idx] =
int64_t aligned_base = 0;
__ldg(expert_first_token_offset + idx);
int64_t token_offset_in_expert = 0;
if (expert_id >= num_local_experts) {
aligned_base =
__ldg(aligned_expert_first_token_offset + num_local_experts);
token_offset_in_expert = 0;
} else {
aligned_base = __ldg(aligned_expert_first_token_offset + expert_id);
token_offset_in_expert =
expanded_dest_row - __ldg(expert_first_token_offset + expert_id);
}
}
__syncthreads();
expanded_dest_row = aligned_base + token_offset_in_expert;
int lane_idx = threadIdx.x & 31;
if (lane_idx == 0) {
// set token_offset_in_expert = 0 if this expert is not local expert
int token_offset_in_expert =
expert_id >= num_local_experts
? 0
: expanded_dest_row - smem_expert_first_token_offset[expert_id];
int64_t accumulate_align_offset = 0;
#pragma unroll 1
for (int eidx = 1; eidx <= min(expert_id, num_local_experts); eidx++) {
auto n_token_in_expert = smem_expert_first_token_offset[eidx] -
smem_expert_first_token_offset[eidx - 1];
accumulate_align_offset += (n_token_in_expert + align_block_size - 1) /
align_block_size * align_block_size;
}
expanded_dest_row = accumulate_align_offset + token_offset_in_expert;
}
// lane0 shuffle broadcast align_expanded_dest_row
expanded_dest_row = __shfl_sync(0xffffffff, expanded_dest_row, 0);
}
}
if (threadIdx.x == 0) {
if (threadIdx.x == 0) {
...
@@ -88,7 +76,8 @@ void expandInputRowsKernelLauncher(
...
@@ -88,7 +76,8 @@ void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* expert_first_token_offset,
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream) {
int num_local_experts, const int& align_block_size, cudaStream_t stream) {
int64_t const blocks = num_rows * k;
int64_t const blocks = num_rows * k;
...
@@ -104,14 +93,12 @@ void expandInputRowsKernelLauncher(
...
@@ -104,14 +93,12 @@ void expandInputRowsKernelLauncher(
bool is_align_block_size = align_block_size != -1;
bool is_align_block_size = align_block_size != -1;
auto func = func_map[is_check_skip][is_align_block_size];
auto func = func_map[is_check_skip][is_align_block_size];
int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1);
func<<<blocks, threads, 0, stream>>>(
func<<<blocks, threads, smem_size, stream>>>(
unpermuted_input, permuted_output, sorted_experts,
unpermuted_input, permuted_output, sorted_experts,
expanded_dest_row_to_expanded_source_row,
expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row, permuted_idx,
expanded_source_row_to_expanded_dest_row, permuted_idx,
expert_first_token_offset,
num_rows, num_valid_tokens_ptr, cols, k
,
expert_first_token_offset,
aligned_expert_first_token_offset, num_rows
,
num_local_experts, align_block_size);
num_valid_tokens_ptr, cols, k,
num_local_experts, align_block_size);
}
}
template <class T, class U>
template <class T, class U>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment