Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6c97b9b9
Unverified
Commit
6c97b9b9
authored
Jan 20, 2026
by
Wentao Ye
Committed by
GitHub
Jan 20, 2026
Browse files
[Perf] Only clone when needed for `moe_permute` (#32273)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
4ca62a0d
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
6 additions
and
5 deletions
+6
-5
csrc/moe/moe_permute_unpermute_op.cu
csrc/moe/moe_permute_unpermute_op.cu
+4
-3
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
...permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
+1
-1
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
.../permute_unpermute_kernels/moe_permute_unpermute_kernel.h
+1
-1
No files found.
csrc/moe/moe_permute_unpermute_op.cu
View file @
6c97b9b9
...
...
@@ -42,7 +42,7 @@ void moe_permute(
auto
sort_workspace
=
torch
::
empty
(
{
sorter_size
},
torch
::
dtype
(
torch
::
kInt8
).
device
(
torch
::
kCUDA
).
requires_grad
(
false
));
auto
copy_topk_ids
=
topk_ids
.
clone
();
// copy topk_ids for preprocess
torch
::
Tensor
topk_ids_for_sort
=
topk_ids
;
auto
permuted_experts_id
=
torch
::
empty_like
(
topk_ids
);
auto
sorted_row_idx
=
torch
::
empty_like
(
inv_permuted_idx
);
...
...
@@ -62,12 +62,13 @@ void moe_permute(
const
int
*
expert_map_ptr
=
get_ptr
<
int
>
(
expert_map
.
value
());
valid_num_ptr
=
get_ptr
<
int64_t
>
(
expert_first_token_offset
)
+
n_local_expert
;
preprocessTopkIdLauncher
(
get_ptr
<
int
>
(
copy_topk_ids
),
n_token
*
topk
,
topk_ids_for_sort
=
topk_ids
.
clone
();
preprocessTopkIdLauncher
(
get_ptr
<
int
>
(
topk_ids_for_sort
),
n_token
*
topk
,
expert_map_ptr
,
n_expert
,
stream
);
}
// expert sort topk expert id and scan expert id get expert_first_token_offset
sortAndScanExpert
(
get_ptr
<
int
>
(
copy_topk_ids
),
get_ptr
<
int
>
(
token_expert_indices
),
get_ptr
<
const
int
>
(
topk_ids_for_sort
),
get_ptr
<
int
>
(
token_expert_indices
),
get_ptr
<
int
>
(
permuted_experts_id
),
get_ptr
<
int
>
(
sorted_row_idx
),
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_token
,
n_expert
,
n_local_expert
,
topk
,
sorter
,
get_ptr
<
int
>
(
sort_workspace
),
stream
);
...
...
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
View file @
6c97b9b9
...
...
@@ -109,7 +109,7 @@ void computeExpertFirstTokenOffset(int const* sorted_indices,
sorted_indices
,
total_indices
,
num_experts
,
expert_first_token_offset
);
}
void
sortAndScanExpert
(
int
*
expert_for_source_row
,
const
int
*
source_rows
,
void
sortAndScanExpert
(
const
int
*
expert_for_source_row
,
const
int
*
source_rows
,
int
*
permuted_experts
,
int
*
permuted_rows
,
int64_t
*
expert_first_token_offset
,
int
num_rows
,
int
num_experts
,
int
num_experts_per_node
,
int
k
,
...
...
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
View file @
6c97b9b9
...
...
@@ -48,7 +48,7 @@ void computeExpertFirstTokenOffset(int const* sorted_indices,
int64_t
*
expert_first_token_offset
,
cudaStream_t
stream
);
void
sortAndScanExpert
(
int
*
expert_for_source_row
,
const
int
*
source_rows
,
void
sortAndScanExpert
(
const
int
*
expert_for_source_row
,
const
int
*
source_rows
,
int
*
permuted_experts
,
int
*
permuted_rows
,
int64_t
*
expert_first_token_offset
,
int
num_rows
,
int
num_experts
,
int
num_experts_per_node
,
int
k
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment