Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
57c22e57
Unverified
Commit
57c22e57
authored
Jul 27, 2025
by
Caleb_Du
Committed by
GitHub
Jul 27, 2025
Browse files
Fix CUDA permute/unpermute for use with DeepGemm Moe (#17934)
Signed-off-by:
Caleb_Du
<
Caleb_Du@zju.edu.cn
>
parent
bda9d053
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
236 additions
and
209 deletions
+236
-209
benchmarks/kernels/benchmark_moe_permute_unpermute.py
benchmarks/kernels/benchmark_moe_permute_unpermute.py
+43
-33
csrc/moe/moe_permute_unpermute_op.cu
csrc/moe/moe_permute_unpermute_op.cu
+35
-38
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
...permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
+1
-1
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
.../permute_unpermute_kernels/moe_permute_unpermute_kernel.h
+4
-16
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
...ermute_unpermute_kernels/moe_permute_unpermute_kernel.inl
+18
-21
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+5
-6
tests/kernels/moe/test_moe_permute_unpermute.py
tests/kernels/moe/test_moe_permute_unpermute.py
+82
-50
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
.../model_executor/layers/fused_moe/moe_permute_unpermute.py
+48
-44
No files found.
benchmarks/kernels/benchmark_moe_permute_unpermute.py
View file @
57c22e57
...
...
@@ -8,12 +8,13 @@ import ray
import
torch
from
transformers
import
AutoConfig
from
vllm.model_executor.layers.fused_moe.deep_gemm_moe
import
(
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
(
_moe_permute
,
_moe_unpermute_and_reduce
,
moe_permute
,
moe_unpermute
,
)
from
vllm.model_executor.layers.fused_moe.fused_moe
import
*
from
vllm.model_executor.layers.fused_moe.moe_permute_unpermute
import
*
from
vllm.model_executor.layers.fused_moe.utils
import
_fp8_quantize
from
vllm.platforms
import
current_platform
from
vllm.utils
import
FlexibleArgumentParser
...
...
@@ -63,19 +64,20 @@ def benchmark_permute(
def
run
():
if
use_customized_permute
:
(
permuted_hidden_states
,
first_token_off
,
inv_perm_idx
,
m_indices
)
=
(
moe_permute
(
(
permuted_hidden_states
,
a1q_scale
,
first_token_off
,
inv_perm_idx
,
m_indices
,
)
=
moe_permute
(
qhidden_states
,
topk_weights
=
topk_weights
,
a1q_scale
=
None
,
topk_ids
=
topk_ids
,
token_expert_indices
=
token_expert_indices
,
topk
=
topk
,
n_expert
=
num_experts
,
n_local_expert
=
num_experts
,
expert_map
=
None
,
align_block_size
=
align_block_size
,
)
)
else
:
(
permuted_hidden_states
,
...
...
@@ -150,19 +152,20 @@ def benchmark_unpermute(
def
prepare
():
if
use_customized_permute
:
(
permuted_hidden_states
,
first_token_off
,
inv_perm_idx
,
m_indices
)
=
(
moe_permute
(
(
permuted_hidden_states
,
a1q_scale
,
first_token_off
,
inv_perm_idx
,
m_indices
,
)
=
moe_permute
(
qhidden_states
,
topk_weights
=
topk_weights
,
a1q_scale
=
None
,
topk_ids
=
topk_ids
,
token_expert_indices
=
token_expert_indices
,
topk
=
topk
,
n_expert
=
num_experts
,
n_local_expert
=
num_experts
,
expert_map
=
None
,
align_block_size
=
align_block_size
,
)
)
# convert to fp16/bf16 as gemm output
return
(
permuted_hidden_states
.
to
(
dtype
),
...
...
@@ -191,16 +194,19 @@ def benchmark_unpermute(
def
run
(
input
:
tuple
):
if
use_customized_permute
:
(
permuted_hidden_states
,
first_token_off
,
inv_perm_idx
,
m_indices
)
=
input
(
permuted_hidden_states
,
first_token_off
,
inv_perm_idx
,
m_indices
,
)
=
input
output
=
torch
.
empty_like
(
hidden_states
)
moe_unpermute
(
output
,
permuted_hidden_states
,
topk_weights
,
topk_ids
,
inv_perm_idx
,
first_token_off
,
topk
,
num_experts
,
num_experts
,
)
else
:
(
...
...
@@ -211,7 +217,11 @@ def benchmark_unpermute(
inv_perm
,
)
=
input
_moe_unpermute_and_reduce
(
output_hidden_states
,
permuted_hidden_states
,
inv_perm
,
topk_weights
output_hidden_states
,
permuted_hidden_states
,
inv_perm
,
topk_weights
,
True
,
)
# JIT compilation & warmup
...
...
csrc/moe/moe_permute_unpermute_op.cu
View file @
57c22e57
...
...
@@ -10,32 +10,28 @@
void
moe_permute
(
const
torch
::
Tensor
&
input
,
// [n_token, hidden]
const
torch
::
Tensor
&
topk_weights
,
//[n_token, topk]
torch
::
Tensor
&
topk_ids
,
// [n_token, topk]
const
torch
::
Tensor
&
topk_ids
,
// [n_token, topk]
const
torch
::
Tensor
&
token_expert_indices
,
// [n_token, topk]
const
std
::
optional
<
torch
::
Tensor
>&
expert_map
,
// [n_expert]
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
const
std
::
optional
<
int64_t
>&
align_block_size
,
torch
::
Tensor
&
permuted_input
,
// [topk * n_token/align_block_size_m, hidden]
torch
::
Tensor
&
permuted_input
,
// [permuted_size, hidden]
torch
::
Tensor
&
expert_first_token_offset
,
// [n_local_expert + 1]
torch
::
Tensor
&
src_row_id2dst_row_id_map
,
// [n_token, topk]
torch
::
Tensor
&
inv_permuted_idx
,
// [n_token, topk]
torch
::
Tensor
&
permuted_idx
,
// [permute_size]
torch
::
Tensor
&
m_indices
)
{
// [align_expand_m]
TORCH_CHECK
(
topk_weights
.
scalar_type
()
==
at
::
ScalarType
::
Float
,
"topk_weights must be float32"
);
TORCH_CHECK
(
expert_first_token_offset
.
scalar_type
()
==
at
::
ScalarType
::
Long
,
"expert_first_token_offset must be int64"
);
TORCH_CHECK
(
topk_ids
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
"topk_ids must be int32"
);
TORCH_CHECK
(
token_expert_indices
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
"token_expert_indices must be int32"
);
TORCH_CHECK
(
src_row_id2dst_row_id_map
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
"
src_row_id2dst_row_id_map
must be int32"
);
TORCH_CHECK
(
inv_permuted_idx
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
"
inv_permuted_idx
must be int32"
);
TORCH_CHECK
(
expert_first_token_offset
.
size
(
0
)
==
n_local_expert
+
1
,
"expert_first_token_offset shape != n_local_expert+1"
)
TORCH_CHECK
(
src_row_id2dst_row_id_map
.
sizes
()
==
token_expert_indices
.
sizes
(),
"token_expert_indices shape must be same as src_row_id2dst_row_id_map"
);
TORCH_CHECK
(
inv_permuted_idx
.
sizes
()
==
token_expert_indices
.
sizes
(),
"token_expert_indices shape must be same as inv_permuted_idx"
);
auto
n_token
=
input
.
sizes
()[
0
];
auto
n_hidden
=
input
.
sizes
()[
1
];
auto
align_block_size_value
=
...
...
@@ -46,8 +42,9 @@ void moe_permute(
auto
sort_workspace
=
torch
::
empty
(
{
sorter_size
},
torch
::
dtype
(
torch
::
kInt8
).
device
(
torch
::
kCUDA
).
requires_grad
(
false
));
auto
copy_topk_ids
=
topk_ids
.
clone
();
// copy topk_ids for preprocess
auto
permuted_experts_id
=
torch
::
empty_like
(
topk_ids
);
auto
dst_row_id2src
_row_id
_map
=
torch
::
empty_like
(
src_row_id2dst_row_id_map
);
auto
sorted
_row_id
x
=
torch
::
empty_like
(
inv_permuted_idx
);
auto
align_expert_first_token_offset
=
torch
::
zeros_like
(
expert_first_token_offset
);
...
...
@@ -67,24 +64,22 @@ void moe_permute(
const
int
*
expert_map_ptr
=
get_ptr
<
int
>
(
expert_map
.
value
());
valid_num_ptr
=
get_ptr
<
int64_t
>
(
expert_first_token_offset
)
+
n_local_expert
;
preprocessTopkIdLauncher
(
get_ptr
<
int
>
(
topk_ids
),
n_token
*
topk
,
preprocessTopkIdLauncher
(
get_ptr
<
int
>
(
copy_
topk_ids
),
n_token
*
topk
,
expert_map_ptr
,
n_expert
,
stream
);
}
// sort the top-k expert ids, then scan them to obtain expert_first_token_offset
sortAndScanExpert
(
get_ptr
<
int
>
(
topk_ids
),
get_ptr
<
int
>
(
token_expert_indices
),
get_ptr
<
int
>
(
permuted_experts_id
),
get_ptr
<
int
>
(
dst_row_id2src_row_id_map
),
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_token
,
n_expert
,
n_local_expert
,
topk
,
sorter
,
get_ptr
<
int
>
(
sort_workspace
),
stream
);
sortAndScanExpert
(
get_ptr
<
int
>
(
copy_topk_ids
),
get_ptr
<
int
>
(
token_expert_indices
),
get_ptr
<
int
>
(
permuted_experts_id
),
get_ptr
<
int
>
(
sorted_row_idx
),
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_token
,
n_expert
,
n_local_expert
,
topk
,
sorter
,
get_ptr
<
int
>
(
sort_workspace
),
stream
);
// dispatch expandInputRowsKernelLauncher
MOE_DISPATCH
(
input
.
scalar_type
(),
[
&
]
{
expandInputRowsKernelLauncher
<
scalar_t
>
(
get_ptr
<
scalar_t
>
(
input
),
get_ptr
<
scalar_t
>
(
permuted_input
),
get_ptr
<
float
>
(
topk_weights
),
get_ptr
<
int
>
(
permuted_experts_id
),
get_ptr
<
int
>
(
dst_row_id2src_row_id_map
),
get_ptr
<
int
>
(
src_row_id2dst_row_id_map
),
get_ptr
<
int
>
(
permuted_experts_id
),
get_ptr
<
int
>
(
sorted_row_idx
),
get_ptr
<
int
>
(
inv_permuted_idx
),
get_ptr
<
int
>
(
permuted_idx
),
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_token
,
valid_num_ptr
,
n_hidden
,
topk
,
n_local_expert
,
align_block_size_value
,
stream
);
});
...
...
@@ -102,31 +97,33 @@ void moe_permute(
void
moe_unpermute
(
const
torch
::
Tensor
&
permuted_hidden_states
,
// [n_token * topk, hidden]
const
torch
::
Tensor
&
topk_weights
,
//[n_token, topk]
const
torch
::
Tensor
&
topk_ids
,
// [n_token, topk]
const
torch
::
Tensor
&
src_row_id2dst_row_id_map
,
// [n_token, topk]
const
torch
::
Tensor
&
expert_first_token_offset
,
// [n_local_expert+1]
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
const
torch
::
Tensor
&
topk_weights
,
//
[n_token, topk]
const
torch
::
Tensor
&
inv_permuted_idx
,
// [n_token, topk]
const
std
::
optional
<
torch
::
Tensor
>&
expert_first_token_offset
,
// [n_local_expert+1]
int64_t
topk
,
torch
::
Tensor
&
hidden_states
// [n_token, hidden]
)
{
TORCH_CHECK
(
src_row_id2dst_row_id_map
.
sizes
()
==
topk_ids
.
sizes
(),
"topk_ids shape must be same as src_row_id2dst_row_id_map"
);
TORCH_CHECK
(
topk_ids
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
"topk_ids must be int32"
);
TORCH_CHECK
(
permuted_hidden_states
.
scalar_type
()
==
hidden_states
.
scalar_type
(),
"
topk_id
s dtype must be same as
src_row_id2dst_row_id_map
"
);
"
permuted_hidden_state
s dtype must be same as
hidden_states
"
);
auto
n_token
=
hidden_states
.
size
(
0
);
auto
n_hidden
=
hidden_states
.
size
(
1
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
const
int64_t
*
valid_ptr
=
get_ptr
<
int64_t
>
(
expert_first_token_offset
)
+
n_local_expert
;
int64_t
const
*
valid_ptr
=
nullptr
;
if
(
expert_first_token_offset
.
has_value
())
{
int
n_local_expert
=
expert_first_token_offset
.
value
().
size
(
0
)
-
1
;
valid_ptr
=
get_ptr
<
int64_t
>
(
expert_first_token_offset
.
value
())
+
n_local_expert
;
}
MOE_DISPATCH
(
hidden_states
.
scalar_type
(),
[
&
]
{
finalizeMoeRoutingKernelLauncher
<
scalar_t
,
scalar_t
>
(
get_ptr
<
scalar_t
>
(
permuted_hidden_states
),
get_ptr
<
scalar_t
>
(
hidden_states
),
get_ptr
<
float
>
(
topk_weights
),
get_ptr
<
int
>
(
src_row_id2dst_row_id_map
),
get_ptr
<
int
>
(
topk_ids
)
,
n_token
,
n_hidden
,
topk
,
valid_ptr
,
stream
);
get_ptr
<
int
>
(
inv_permuted_idx
),
n_token
,
n_hidden
,
topk
,
valid_ptr
,
stream
);
});
}
...
...
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
View file @
57c22e57
...
...
@@ -177,7 +177,7 @@ __global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
int
tidx
=
threadIdx
.
x
;
extern
__shared__
int64_t
smem_expert_first_token_offset
[];
for
(
int
i
=
tidx
;
i
<=
num_local_expert
;
i
+=
blockDim
.
x
)
{
smem_expert_first_token_offset
[
tidx
]
=
__ldg
(
expert_first_token_offset
+
i
);
smem_expert_first_token_offset
[
i
]
=
__ldg
(
expert_first_token_offset
+
i
);
}
__syncthreads
();
auto
last_token_offset
=
smem_expert_first_token_offset
[
eidx
+
1
];
...
...
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
View file @
57c22e57
...
...
@@ -57,31 +57,19 @@ void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
template
<
typename
T
>
void
expandInputRowsKernelLauncher
(
T
const
*
unpermuted_input
,
T
*
permuted_output
,
const
float
*
unpermuted_scales
,
int
*
sorted_experts
,
T
const
*
unpermuted_input
,
T
*
permuted_output
,
int
*
sorted_experts
,
int
const
*
expanded_dest_row_to_expanded_source_row
,
int
*
expanded_source_row_to_expanded_dest_row
,
int
*
expanded_source_row_to_expanded_dest_row
,
int
*
permuted_idx
,
int64_t
*
expert_first_token_offset
,
int64_t
const
num_rows
,
int64_t
const
*
num_valid_tokens_ptr
,
int64_t
const
cols
,
int
const
k
,
int
num_local_experts
,
const
int
&
align_block_size
,
cudaStream_t
stream
);
// Final kernel to unpermute and scale
// This kernel unpermutes the original data, does the k-way reduction and
// performs the final skip connection.
template
<
typename
T
,
typename
OutputType
,
bool
CHECK_SKIPPED
>
__global__
void
finalizeMoeRoutingKernel
(
T
const
*
expanded_permuted_rows
,
OutputType
*
reduced_unpermuted_output
,
float
const
*
scales
,
int
const
*
expanded_source_row_to_expanded_dest_row
,
int
const
*
expert_for_source_row
,
int64_t
const
orig_cols
,
int64_t
const
k
,
int64_t
const
*
num_valid_ptr
);
template
<
class
T
,
class
OutputType
>
void
finalizeMoeRoutingKernelLauncher
(
T
const
*
expanded_permuted_rows
,
OutputType
*
reduced_unpermuted_output
,
float
const
*
scales
,
int
const
*
expanded_source_row_to_expanded_dest_row
,
int
const
*
expert_for_source_row
,
int64_t
const
num_rows
,
int64_t
const
cols
,
int64_t
const
k
,
int64_t
const
*
num_valid_ptr
,
cudaStream_t
stream
);
int64_t
const
num_rows
,
int64_t
const
cols
,
int64_t
const
k
,
int64_t
const
*
num_valid_ptr
,
cudaStream_t
stream
);
void
preprocessTopkIdLauncher
(
int
*
topk_id_ptr
,
int
size
,
const
int
*
expert_map_ptr
,
int
num_experts
,
...
...
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
View file @
57c22e57
...
...
@@ -2,10 +2,9 @@
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE>
__global__ void expandInputRowsKernel(
T const* unpermuted_input, T* permuted_output,
const float* unpermuted_scales, int* sorted_experts,
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
int* expanded_source_row_to_expanded_dest_row,
int* permuted_idx,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
int num_local_experts, int align_block_size) {
...
...
@@ -54,6 +53,10 @@ __global__ void expandInputRowsKernel(
assert(expanded_dest_row <= INT32_MAX);
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
static_cast<int>(expanded_dest_row);
// skip non local expert token
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
permuted_idx[expanded_dest_row] = expanded_source_row;
}
}
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
...
...
@@ -62,7 +65,7 @@ __global__ void expandInputRowsKernel(
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
// Duplicate and permute rows
int64_t const source_row = expanded_source_row
% num_rows
;
int64_t const source_row = expanded_source_row
/ k
;
auto const* source_row_ptr =
reinterpret_cast<DataElem const*>(unpermuted_input + source_row * cols);
...
...
@@ -82,10 +85,9 @@ __global__ void expandInputRowsKernel(
template <typename T>
void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output,
const float* unpermuted_scales, int* sorted_experts,
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
int* expanded_source_row_to_expanded_dest_row,
int* permuted_idx,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream) {
...
...
@@ -105,11 +107,11 @@ void expandInputRowsKernelLauncher(
int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1);
func<<<blocks, threads, smem_size, stream>>>(
unpermuted_input, permuted_output,
unpermuted_scales,
sorted_experts,
unpermuted_input, permuted_output, sorted_experts,
expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row,
ex
per
t_first_token_offset
,
num_rows, num_valid_tokens_ptr, cols, k,
num_local_experts,
align_block_size);
expanded_source_row_to_expanded_dest_row, per
muted_idx
,
expert_first_token_offset,
num_rows, num_valid_tokens_ptr, cols, k,
num_local_experts,
align_block_size);
}
template <class T, class U>
...
...
@@ -128,11 +130,9 @@ template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
int64_t const* num_valid_ptr) {
int64_t const orig_cols, int64_t const k, int64_t const* num_valid_ptr) {
assert(orig_cols % 4 == 0);
int64_t const original_row = blockIdx.x;
int64_t const num_rows = gridDim.x;
auto const offset = original_row * orig_cols;
OutputType* reduced_row_ptr = reduced_unpermuted_output + offset;
int64_t const num_valid = *num_valid_ptr;
...
...
@@ -159,14 +159,13 @@ __global__ void finalizeMoeRoutingKernel(
ComputeElem thread_output;
thread_output.fill(0);
for (int k_idx = 0; k_idx < k; ++k_idx) {
int64_t const expanded_original_row = original_row + k_idx
* num_rows
;
int64_t const expanded_original_row = original_row
* k
+ k_idx;
int64_t const expanded_permuted_row =
expanded_source_row_to_expanded_dest_row[expanded_original_row];
int64_t const k_offset = original_row * k + k_idx;
float const row_scale = scales[k_offset];
// Check after row_rescale has accumulated
if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) {
continue;
}
...
...
@@ -189,9 +188,8 @@ template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const num_rows,
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
cudaStream_t stream) {
int64_t const num_rows, int64_t const cols, int64_t const k,
int64_t const* num_valid_ptr, cudaStream_t stream) {
int64_t const blocks = num_rows;
int64_t const threads = 256;
bool const check_finished = num_valid_ptr != nullptr;
...
...
@@ -201,6 +199,5 @@ void finalizeMoeRoutingKernelLauncher(
auto* const kernel = func_map[check_finished];
kernel<<<blocks, threads, 0, stream>>>(
expanded_permuted_rows, reduced_unpermuted_output, scales,
expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k,
num_valid_ptr);
expanded_source_row_to_expanded_dest_row, cols, k, num_valid_ptr);
}
csrc/moe/torch_bindings.cpp
View file @
57c22e57
...
...
@@ -56,18 +56,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" -> Tensor"
);
m
.
def
(
"moe_permute(Tensor input, Tensor
topk_weight, Tensor!
topk_ids,"
"moe_permute(Tensor input, Tensor topk_ids,"
"Tensor token_expert_indices, Tensor? expert_map, int n_expert,"
"int n_local_expert,"
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
"expert_first_token_offset, Tensor!
src_row_id2dst_row_id_map
, Tensor! "
"m_indices)->()"
);
"expert_first_token_offset, Tensor!
inv_permuted_idx
, Tensor! "
"
permuted_idx, Tensor!
m_indices)->()"
);
m
.
def
(
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
"expert_first_token_offset, int n_expert, int n_local_expert,int "
"topk, Tensor! hidden_states)->()"
);
"Tensor inv_permuted_idx, Tensor? expert_first_token_offset, "
"int topk, Tensor! hidden_states)->()"
);
m
.
def
(
"moe_permute_unpermute_supported() -> bool"
);
m
.
impl
(
"moe_permute_unpermute_supported"
,
&
moe_permute_unpermute_supported
);
...
...
tests/kernels/moe/test_moe_permute_unpermute.py
View file @
57c22e57
...
...
@@ -17,15 +17,16 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute
,
moe_permute_unpermute_supported
,
moe_unpermute
)
from
vllm.platforms
import
current_platform
NUM_EXPERTS
=
[
16
,
64
]
NUM_EXPERTS
=
[
16
,
64
,
256
]
TOP_KS
=
[
2
,
4
,
6
,
8
]
EP_SIZE
=
[
1
,
4
,
16
]
current_platform
.
seed_everything
(
0
)
def
torch_permute
(
hidden_states
:
torch
.
Tensor
,
def
torch_permute
(
hidden_states
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
token_expert_indices
:
torch
.
Tensor
,
#
token_expert_indices: torch.Tensor,
topk
:
int
,
n_expert
:
int
,
n_local_expert
:
int
,
...
...
@@ -39,6 +40,11 @@ def torch_permute(hidden_states: torch.Tensor,
not_local_expert
=
(
expert_map
[
topk_ids
]
==
-
1
)
topk_ids
=
is_local_expert
*
(
topk_ids
-
start_expert
)
+
not_local_expert
*
(
topk_ids
+
n_expert
)
token_expert_indices
=
torch
.
arange
(
0
,
n_token
*
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
).
reshape
(
(
n_token
,
topk
))
sorted_topk_ids
,
sorted_indices
=
torch
.
sort
(
topk_ids
.
flatten
(),
stable
=
True
)
...
...
@@ -59,8 +65,8 @@ def torch_permute(hidden_states: torch.Tensor,
valid_row_idx
=
[]
if
align_block_size
is
None
:
permuted_hidden_states
=
hidden_states
[
dst_row_id2src_row_id_map
%
n_token
,
...]
permuted_hidden_states
=
hidden_states
[
dst_row_id2src_row_id_map
//
topk
,
...]
permuted_row_size
=
permuted_hidden_states
.
shape
[
0
]
m_indices
=
torch
.
empty
(
permuted_row_size
,
device
=
"cuda"
,
...
...
@@ -73,14 +79,21 @@ def torch_permute(hidden_states: torch.Tensor,
0
,
n_token
*
topk
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)[
src2dst_idx
].
reshape
((
n_token
,
topk
))
valid_row_idx
+=
[
i
for
i
in
range
(
expert_first_token_offset
[
-
1
])]
dst_row_id2src_row_id_map
[
expert_first_token_offset
[
-
1
]:]
=
n_token
*
topk
return
[
permuted_hidden_states
,
expert_first_token_offset
,
src_row_id2dst_row_id_map
,
m_indices
,
valid_row_idx
src_row_id2dst_row_id_map
,
dst_row_id2src_row_id_map
,
m_indices
,
valid_row_idx
]
else
:
permuted_row_size
=
(
topk
*
n_token
+
n_expert
*
(
align_block_size
-
1
)
+
align_block_size
-
1
)
//
align_block_size
*
align_block_size
permuted_idx
=
torch
.
full
((
permuted_row_size
,
),
n_token
*
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
permuted_hidden_states
=
torch
.
empty
((
permuted_row_size
,
n_hidden
),
device
=
"cuda"
,
dtype
=
hidden_states
.
dtype
)
...
...
@@ -105,13 +118,16 @@ def torch_permute(hidden_states: torch.Tensor,
align_first_token_offset
=
align_expert_first_token_offset
[
i
-
1
]
align_last_token_offset
=
align_expert_first_token_offset
[
i
]
dst_row_id2src_row_id_in_expert
=
dst_row_id2src_row_id_map
[
first_token_offset
:
first_token_offset
+
n_token_in_expert
]
%
n_token
first_token_offset
:
first_token_offset
+
n_token_in_expert
]
# store token in current expert with align_first_token_offset
permuted_hidden_states
[
align_first_token_offset
:
\
align_first_token_offset
+
n_token_in_expert
,
\
...]
=
hidden_states
[
\
dst_row_id2src_row_id_in_expert
,
...]
dst_row_id2src_row_id_in_expert
//
topk
,
\
...]
permuted_idx
[
align_first_token_offset
:
\
align_first_token_offset
+
\
n_token_in_expert
]
=
dst_row_id2src_row_id_in_expert
# set current expert m_indices
m_indices
[
align_first_token_offset
:
align_last_token_offset
]
=
i
-
1
valid_row_idx
+=
[
...
...
@@ -135,7 +151,7 @@ def torch_permute(hidden_states: torch.Tensor,
src2dst_idx
].
reshape
((
n_token
,
topk
))
return
[
permuted_hidden_states
,
align_expert_first_token_offset
,
align_src_row_id2dst_row_id
,
m_indices
,
valid_row_idx
align_src_row_id2dst_row_id
,
permuted_idx
,
m_indices
,
valid_row_idx
]
...
...
@@ -146,15 +162,18 @@ def torch_unpermute(permuted_hidden_states: torch.Tensor,
valid_row_idx
:
torch
.
Tensor
,
topk
:
int
,
n_expert
:
int
)
->
torch
.
Tensor
:
# ignore invalid row
n_hidden
=
permuted_hidden_states
.
shape
[
1
]
mask
=
torch
.
zeros
(
permuted_hidden_states
.
shape
[
0
],
dtype
=
bool
,
device
=
"cuda"
)
mask
[
valid_row_idx
]
=
True
permuted_hidden_states
[
~
mask
]
=
0
idx
=
src_row_id2dst_row_id_map
.
flatten
()[
token_expert_indices
.
flatten
()].
reshape
(
token_expert_indices
.
shape
)
output
=
permuted_hidden_states
[
idx
,
...]
*
topk_weights
[...,
None
]
output
=
output
.
sum
(
dim
=
1
).
to
(
permuted_hidden_states
.
dtype
)
permuted_hidden_states
=
permuted_hidden_states
[
src_row_id2dst_row_id_map
.
flatten
(),
...]
permuted_hidden_states
=
permuted_hidden_states
.
view
(
-
1
,
topk
,
n_hidden
)
output
=
(
permuted_hidden_states
*
topk_weights
.
unsqueeze
(
2
)).
sum
(
1
).
to
(
permuted_hidden_states
.
dtype
)
return
output
...
...
@@ -184,10 +203,12 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
gating_output
=
torch
.
randn
((
n_token
,
n_expert
),
device
=
"cuda"
).
to
(
dtype
)
topk_weights
,
topk_ids
,
token_expert_indices
=
fused_topk
(
hidden_states
,
gating_output
,
topk
,
False
)
gold0
,
gold1
,
gold2
,
gold3
,
valid_row_idx
=
torch_permute
(
(
gold_permuted_hidden_states
,
gold_expert_first_token_offset
,
gold_inv_permuted_idx
,
gold_permuted_idx
,
gold_m_indices
,
valid_row_idx
)
=
torch_permute
(
hidden_states
,
topk_ids
,
token_expert_indices
,
#
token_expert_indices,
topk
,
n_expert
,
n_local_expert
,
...
...
@@ -196,31 +217,42 @@ def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
align_block_size
=
align_block_size
,
fill_invalid_expert
=
fill_invalid_expert
)
result0
,
result1
,
result2
,
result3
=
moe_permute
(
hidden_states
,
topk_weights
,
topk_ids
,
token_expert_indices
,
topk
,
n_expert
,
n_local_expert
,
expert_map
,
align_block_size
,
fill_invalid_expert
)
(
permuted_hidden_states
,
_
,
expert_first_token_offset
,
inv_permuted_idx
,
m_indices
)
=
moe_permute
(
hidden_states
=
hidden_states
,
a1q_scale
=
None
,
topk_ids
=
topk_ids
,
n_expert
=
n_expert
,
n_local_expert
=
n_local_expert
,
expert_map
=
expert_map
,
align_block_size
=
align_block_size
,
fill_invalid_expert
=
fill_invalid_expert
)
# check expert_first_token_offset
torch
.
testing
.
assert_close
(
gold1
,
result1
,
atol
=
0
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
gold_expert_first_token_offset
,
expert_first_token_offset
,
atol
=
0
,
rtol
=
0
)
# check src_row_id2dst_row_id_map
torch
.
testing
.
assert_close
(
gold2
,
result2
,
atol
=
0
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
gold_inv_permuted_idx
.
flatten
(),
inv_permuted_idx
,
atol
=
0
,
rtol
=
0
)
# check m_indices
torch
.
testing
.
assert_close
(
gold
3
,
result3
,
atol
=
0
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
gold
_m_indices
,
m_indices
,
atol
=
0
,
rtol
=
0
)
# check permuted_hidden_states, only valid token
torch
.
testing
.
assert_close
(
gold
0
[
valid_row_idx
],
result0
[
valid_row_idx
],
torch
.
testing
.
assert_close
(
gold
_permuted_hidden_states
[
valid_row_idx
],
permuted_hidden_states
[
valid_row_idx
],
atol
=
0
,
rtol
=
0
)
# add a random tensor to simulate group gemm
result0
=
0.5
*
result0
+
torch
.
randn_like
(
result0
)
result0
=
0.5
*
permuted_hidden_states
+
torch
.
randn_like
(
permuted_hidden_states
)
result4
=
torch
.
empty_like
(
hidden_states
)
moe_unpermute
(
result4
,
result0
,
topk_weights
,
inv_permuted_idx
,
expert_first_token_offset
)
result4
=
moe_unpermute
(
result0
,
topk_weights
,
topk_ids
,
result2
,
result1
,
topk
,
n_expert
,
n_local_expert
)
gold4
=
torch_unpermute
(
result0
,
topk_weights
,
topk_ids
,
token_expert_indices
,
result2
,
valid_row_idx
,
topk
,
n_local_expert
)
token_expert_indices
,
inv_permuted_idx
,
valid_row_idx
,
topk
,
n_local_expert
)
# check unpermuted hidden
torch
.
testing
.
assert_close
(
result4
,
gold4
,
atol
=
2e-2
,
rtol
=
0
)
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
View file @
57c22e57
...
...
@@ -76,25 +76,22 @@ def _moe_unpermute_and_reduce(
def
moe_permute
(
hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
a1q_scale
:
Optional
[
torch
.
Tensor
]
,
topk_ids
:
torch
.
Tensor
,
token_expert_indices
:
torch
.
Tensor
,
topk
:
int
,
n_expert
:
int
,
n_local_expert
:
int
,
n_local_expert
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
align_block_size
:
Optional
[
int
]
=
None
,
fill_invalid_expert
:
int
=
-
1
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
This function expands and permutes activation to gather non-contiguous tokens
for each expert.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
-
topk_weights (
torch.Tensor):
topk expert route weight for each token.
-
a1q_scale (Optional[
torch.Tensor
]
):
quant scale for hidden_states
- topk_ids (torch.Tensor): topk expert route id for each token.
- token_expert_indices (torch.Tensor): indices for the expanded hidden states.
- topk (int): The number of top-k experts to select.
- n_expert (int): The number of experts.
- n_local_expert (int): The number of experts in the current EP rank.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
...
...
@@ -105,14 +102,17 @@ def moe_permute(
to workaround DeepGemm unsupported -1 in m_indices
Returns:
- permuted_hidden_states (torch.Tensor): permuted activation.
- a1q_scale (Optional[torch.Tensor]): quant scale for hidden_states
- expert_first_token_offset (torch.Tensor): offset of the first token
of each expert for standard grouped gemm. If 'align_block_size' is enabled,
expert_first_token_offset is aligned up to 'align_block_size'.
- src_row_id2dst_row_id_map (torch.Tensor): idx map for moe_unpermute.
- inv_permuted_idx (torch.Tensor): idx map for moe_unpermute.
- permuted_idx (torch.Tensor): idx map from hidden to permuted_hidden.
- m_indices: m_indices for grouped gemm in deepgemm, `m_indices[i]` records
the group to which the i-th row of the LHS belongs.
"""
n_token
,
n_hidden
=
hidden_states
.
size
()
topk
=
topk_ids
.
size
(
1
)
assert
(
n_hidden
*
hidden_states
.
element_size
()
)
%
16
==
0
,
"permue kernel need hidden dim align to 16B"
permuted_row_size
=
n_token
*
topk
...
...
@@ -120,12 +120,19 @@ def moe_permute(
permuted_row_size
=
(
permuted_row_size
+
n_expert
*
(
align_block_size
-
1
)
+
align_block_size
-
1
)
//
align_block_size
*
align_block_size
if
n_local_expert
==
-
1
:
n_local_expert
=
n_expert
permuted_hidden_states
=
torch
.
empty
(
(
permuted_row_size
,
n_hidden
),
dtype
=
hidden_states
.
dtype
,
device
=
hidden_states
.
device
,
)
token_expert_indices
=
torch
.
arange
(
0
,
n_token
*
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
).
reshape
(
(
n_token
,
topk
))
m_indices
=
torch
.
full
((
permuted_row_size
,
),
fill_invalid_expert
,
dtype
=
torch
.
int32
,
...
...
@@ -133,57 +140,54 @@ def moe_permute(
expert_first_token_offset
=
torch
.
empty
(
n_local_expert
+
1
,
dtype
=
torch
.
int64
,
device
=
hidden_states
.
device
)
src_row_id2dst_row_id_map
=
torch
.
empty
((
n_token
,
topk
),
permuted_idx
=
torch
.
full
((
permuted_row_size
,
),
n_token
*
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
inv_permuted_idx
=
torch
.
empty
((
n_token
,
topk
),
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
)
torch
.
ops
.
_moe_C
.
moe_permute
(
hidden_states
,
topk_weights
,
topk_ids
,
token_expert_indices
,
expert_map
,
n_expert
,
n_local_expert
,
topk
,
align_block_size
,
permuted_hidden_states
,
expert_first_token_offset
,
src_row_id2dst_row_id_map
,
m_indices
)
return
(
permuted_hidden_states
,
expert_first_token_offset
,
src_row_id2dst_row_id_map
,
m_indices
)
topk_ids
=
topk_ids
.
to
(
torch
.
int32
)
torch
.
ops
.
_moe_C
.
moe_permute
(
hidden_states
,
topk_ids
,
token_expert_indices
,
expert_map
,
n_expert
,
n_local_expert
,
topk
,
align_block_size
,
permuted_hidden_states
,
expert_first_token_offset
,
inv_permuted_idx
,
permuted_idx
,
m_indices
)
if
a1q_scale
is
not
None
:
a1q_scale
=
a1q_scale
[
permuted_idx
.
clamp
(
max
=
n_token
*
topk
-
1
)
//
topk
]
return
(
permuted_hidden_states
,
a1q_scale
,
expert_first_token_offset
,
inv_permuted_idx
.
flatten
(),
m_indices
)
def
moe_unpermute
(
out
:
torch
.
Tensor
,
permuted_hidden_states
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
src_row_id2dst_row_id_map
:
torch
.
Tensor
,
expert_first_token_offset
:
torch
.
Tensor
,
topk
:
int
,
n_expert
:
int
,
n_local_expert
:
int
,
)
->
torch
.
Tensor
:
inv_permuted_idx
:
torch
.
Tensor
,
expert_first_token_offset
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
None
:
"""
This function unpermutes the permuted activation and reduces the top-k
expert outputs of each token, weighted by topk_weights.
Parameters:
- out (torch.Tensor): output tensor
- permuted_hidden_states (torch.Tensor): permuted activation.
- topk_weights (torch.Tensor): topk expert route weight for each token.
- topk_ids (torch.Tensor): topk expert route id for each token.
- expert_first_token_offset (torch.Tensor): offset of the first token
of each expert for grouped gemm.
- topk (int): The number of top-k experts to select.
- n_expert (int): The number of expert.
- n_local_expert (int): The number of expert in current EP rank.
- inv_permuted_idx (torch.Tensor): row idx map for moe_unpermute.
- expert_first_token_offset (Optional[torch.Tensor]): offset of the first
token of each expert for grouped gemm.
Returns:
- hidden_states (torch.Tensor): The reduced and unpermuted activation
tensor.
"""
n_token
,
n_hidden
=
topk_weights
.
size
(
0
),
permuted_hidden_states
.
size
(
-
1
)
topk
=
topk_weights
.
size
(
1
)
n_hidden
=
permuted_hidden_states
.
size
(
-
1
)
assert
(
n_hidden
*
permuted_hidden_states
.
element_size
()
)
%
16
==
0
,
"unpermue kernel need hidden dim align to 16B"
hidden_states
=
torch
.
empty
((
n_token
,
n_hidden
),
dtype
=
permuted_hidden_states
.
dtype
,
device
=
permuted_hidden_states
.
device
)
torch
.
ops
.
_moe_C
.
moe_unpermute
(
permuted_hidden_states
,
topk_weights
,
topk_ids
,
src_row_id2dst_row_id_map
,
expert_first_token_offset
,
n_expert
,
n_local_expert
,
topk
,
hidden_states
)
return
hidden_states
inv_permuted_idx
,
expert_first_token_offset
,
topk
,
out
)
def
moe_permute_unpermute_supported
():
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment