Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e855d380
Unverified
Commit
e855d380
authored
Mar 16, 2026
by
Wentao Ye
Committed by
GitHub
Mar 16, 2026
Browse files
[Compile] Fix compile warning in `moe_permute` (#36529)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
0e5a9382
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
14 deletions
+12
-14
csrc/moe/moe_permute_unpermute_op.cu
csrc/moe/moe_permute_unpermute_op.cu
+3
-4
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
.../permute_unpermute_kernels/moe_permute_unpermute_kernel.h
+1
-1
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
...ermute_unpermute_kernels/moe_permute_unpermute_kernel.inl
+8
-9
No files found.
csrc/moe/moe_permute_unpermute_op.cu
View file @
e855d380
...
@@ -73,10 +73,9 @@ void moe_permute(
...
@@ -73,10 +73,9 @@ void moe_permute(
MOE_DISPATCH
(
input
.
scalar_type
(),
[
&
]
{
MOE_DISPATCH
(
input
.
scalar_type
(),
[
&
]
{
expandInputRowsKernelLauncher
<
scalar_t
>
(
expandInputRowsKernelLauncher
<
scalar_t
>
(
get_ptr
<
scalar_t
>
(
input
),
get_ptr
<
scalar_t
>
(
permuted_input
),
get_ptr
<
scalar_t
>
(
input
),
get_ptr
<
scalar_t
>
(
permuted_input
),
get_ptr
<
int
>
(
permuted_experts_id
),
get_ptr
<
int
>
(
sorted_row_idx
),
get_ptr
<
int
>
(
sorted_row_idx
),
get_ptr
<
int
>
(
inv_permuted_idx
),
get_ptr
<
int
>
(
inv_permuted_idx
),
get_ptr
<
int
>
(
permuted_idx
),
get_ptr
<
int
>
(
permuted_idx
),
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_token
,
valid_num_ptr
,
n_token
,
valid_num_ptr
,
n_hidden
,
topk
,
n_local_expert
,
stream
);
n_hidden
,
topk
,
n_local_expert
,
stream
);
});
});
}
}
...
...
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
View file @
e855d380
...
@@ -57,7 +57,7 @@ void sortAndScanExpert(const int* expert_for_source_row, const int* source_rows,
...
@@ -57,7 +57,7 @@ void sortAndScanExpert(const int* expert_for_source_row, const int* source_rows,
template
<
typename
T
>
template
<
typename
T
>
void
expandInputRowsKernelLauncher
(
void
expandInputRowsKernelLauncher
(
T
const
*
unpermuted_input
,
T
*
permuted_output
,
int
*
sorted_experts
,
T
const
*
unpermuted_input
,
T
*
permuted_output
,
int
const
*
expanded_dest_row_to_expanded_source_row
,
int
const
*
expanded_dest_row_to_expanded_source_row
,
int
*
expanded_source_row_to_expanded_dest_row
,
int
*
permuted_idx
,
int
*
expanded_source_row_to_expanded_dest_row
,
int
*
permuted_idx
,
int64_t
const
*
expert_first_token_offset
,
int64_t
const
num_rows
,
int64_t
const
*
expert_first_token_offset
,
int64_t
const
num_rows
,
...
...
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
View file @
e855d380
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
template <typename T, bool CHECK_SKIPPED>
template <typename T, bool CHECK_SKIPPED>
__global__ void expandInputRowsKernel(
__global__ void expandInputRowsKernel(
T const* unpermuted_input, T* permuted_output,
int* sorted_experts,
T const* unpermuted_input, T* permuted_output,
int const* expanded_dest_row_to_expanded_source_row,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t const* expert_first_token_offset, int64_t const num_rows,
int64_t const* expert_first_token_offset, int64_t const num_rows,
...
@@ -16,7 +16,6 @@ __global__ void expandInputRowsKernel(
...
@@ -16,7 +16,6 @@ __global__ void expandInputRowsKernel(
int64_t expanded_dest_row = blockIdx.x;
int64_t expanded_dest_row = blockIdx.x;
int64_t const expanded_source_row =
int64_t const expanded_source_row =
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
int expert_id = sorted_experts[expanded_dest_row];
if (threadIdx.x == 0) {
if (threadIdx.x == 0) {
assert(expanded_dest_row <= INT32_MAX);
assert(expanded_dest_row <= INT32_MAX);
...
@@ -54,7 +53,7 @@ __global__ void expandInputRowsKernel(
...
@@ -54,7 +53,7 @@ __global__ void expandInputRowsKernel(
template <typename T>
template <typename T>
void expandInputRowsKernelLauncher(
void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output,
int* sorted_experts,
T const* unpermuted_input, T* permuted_output,
int const* expanded_dest_row_to_expanded_source_row,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t const* expert_first_token_offset, int64_t const num_rows,
int64_t const* expert_first_token_offset, int64_t const num_rows,
...
@@ -70,12 +69,12 @@ void expandInputRowsKernelLauncher(
...
@@ -70,12 +69,12 @@ void expandInputRowsKernelLauncher(
bool is_check_skip = num_valid_tokens_ptr != nullptr;
bool is_check_skip = num_valid_tokens_ptr != nullptr;
auto func = func_map[is_check_skip];
auto func = func_map[is_check_skip];
func<<<blocks, threads, 0, stream>>>(
func<<<blocks, threads, 0, stream>>>(
unpermuted_input, permuted_output,
unpermuted_input, permut
ed_ou
tput, sorted_experts
,
expanded_dest_row_to_expand
ed_
s
ou
rce_row
,
expanded_
dest
_row_to_expanded_
source
_row,
expanded_
source
_row_to_expanded_
dest
_row,
expanded_source_row_to_expanded_dest_row, permuted_idx
,
permuted_idx, expert_first_token_offset
,
expert_first_token_offset,
num_rows, num_valid_tokens_ptr, cols, k,
num_rows, num_valid_tokens_ptr, cols, k,
num_local_experts);
num_local_experts);
}
}
template <class T, class U>
template <class T, class U>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment