Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
77c09e11
Unverified
Commit
77c09e11
authored
Feb 06, 2026
by
Wentao Ye
Committed by
GitHub
Feb 06, 2026
Browse files
[Refactor] Remove align block size logic in `moe_permute` (#33449)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
16786da7
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
38 additions
and
297 deletions
+38
-297
benchmarks/kernels/benchmark_moe_permute_unpermute.py
benchmarks/kernels/benchmark_moe_permute_unpermute.py
+0
-6
csrc/moe/moe_permute_unpermute_op.cu
csrc/moe/moe_permute_unpermute_op.cu
+5
-39
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
...permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
+0
-60
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
.../permute_unpermute_kernels/moe_permute_unpermute_kernel.h
+2
-8
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
...ermute_unpermute_kernels/moe_permute_unpermute_kernel.inl
+12
-35
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+2
-2
tests/kernels/moe/test_moe_permute_unpermute.py
tests/kernels/moe/test_moe_permute_unpermute.py
+15
-118
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
.../model_executor/layers/fused_moe/moe_permute_unpermute.py
+2
-29
No files found.
benchmarks/kernels/benchmark_moe_permute_unpermute.py
View file @
77c09e11
...
...
@@ -44,10 +44,8 @@ def benchmark_permute(
hidden_states
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
# output_hidden_states = torch.empty_like(hidden_states)
if
use_fp8_w8a8
:
align_block_size
=
128
# deepgemm needs 128 m aligned block
qhidden_states
,
scale
=
_fp8_quantize
(
hidden_states
,
None
,
None
)
else
:
align_block_size
=
None
qhidden_states
=
hidden_states
gating_output
=
torch
.
randn
(
num_iters
,
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
...
...
@@ -67,7 +65,6 @@ def benchmark_permute(
topk_ids
=
topk_ids
,
n_expert
=
num_experts
,
expert_map
=
None
,
align_block_size
=
align_block_size
,
)
# JIT compilation & warmup
...
...
@@ -117,10 +114,8 @@ def benchmark_unpermute(
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
if
use_fp8_w8a8
:
align_block_size
=
128
# deepgemm needs 128 m aligned block
qhidden_states
,
scale
=
_fp8_quantize
(
hidden_states
,
None
,
None
)
else
:
align_block_size
=
None
qhidden_states
=
hidden_states
input_gating
=
torch
.
randn
(
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
...
...
@@ -142,7 +137,6 @@ def benchmark_unpermute(
topk_ids
=
topk_ids
,
n_expert
=
num_experts
,
expert_map
=
None
,
align_block_size
=
align_block_size
,
)
# convert to fp16/bf16 as gemm output
return
(
...
...
csrc/moe/moe_permute_unpermute_op.cu
View file @
77c09e11
...
...
@@ -14,12 +14,10 @@ void moe_permute(
const
torch
::
Tensor
&
token_expert_indices
,
// [n_token, topk]
const
std
::
optional
<
torch
::
Tensor
>&
expert_map
,
// [n_expert]
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
const
std
::
optional
<
int64_t
>&
align_block_size
,
torch
::
Tensor
&
permuted_input
,
// [permuted_size, hidden]
torch
::
Tensor
&
expert_first_token_offset
,
// [n_local_expert + 1]
torch
::
Tensor
&
inv_permuted_idx
,
// [n_token, topk]
torch
::
Tensor
&
permuted_idx
,
// [permute_size]
torch
::
Tensor
&
m_indices
)
{
// [align_expand_m]
torch
::
Tensor
&
permuted_idx
)
{
// [permute_size]
TORCH_CHECK
(
expert_first_token_offset
.
scalar_type
()
==
at
::
ScalarType
::
Long
,
"expert_first_token_offset must be int64"
);
TORCH_CHECK
(
topk_ids
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
...
...
@@ -34,8 +32,6 @@ void moe_permute(
"token_expert_indices shape must be same as inv_permuted_idx"
);
auto
n_token
=
input
.
sizes
()[
0
];
auto
n_hidden
=
input
.
sizes
()[
1
];
auto
align_block_size_value
=
align_block_size
.
has_value
()
?
align_block_size
.
value
()
:
-
1
;
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
const
long
sorter_size
=
CubKeyValueSorter
::
getWorkspaceSize
(
n_token
*
topk
,
n_expert
);
...
...
@@ -73,42 +69,15 @@ void moe_permute(
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_token
,
n_expert
,
n_local_expert
,
topk
,
sorter
,
get_ptr
<
int
>
(
sort_workspace
),
stream
);
// DeepGEMM: use getMIndices kernel to compute
// 1) align_expert_first_token_offset (aligned prefix offsets)
// 2) m_indices (expert id for each aligned row)
// eg. expert0: 3, expert1: 5, expert2: 2 tokens respectively
// expert_first_token_offset = [0, 3, 8, 10], align_block_size = 4
// expert0: 3->4, expert1: 5->8, expert2: 2->4
// align_expert_first_token_offset = [0, 4, 12, 16]
// so m_indices = [0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2]
torch
::
Tensor
align_expert_first_token_offset
;
const
int64_t
*
aligned_expert_first_token_offset_ptr
=
nullptr
;
if
(
align_block_size
.
has_value
())
{
align_expert_first_token_offset
=
torch
::
zeros_like
(
expert_first_token_offset
);
getMIndices
(
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
get_ptr
<
int64_t
>
(
align_expert_first_token_offset
),
get_ptr
<
int
>
(
m_indices
),
n_local_expert
,
align_block_size_value
,
stream
);
aligned_expert_first_token_offset_ptr
=
get_ptr
<
int64_t
>
(
align_expert_first_token_offset
);
}
// dispatch expandInputRowsKernelLauncher
MOE_DISPATCH
(
input
.
scalar_type
(),
[
&
]
{
expandInputRowsKernelLauncher
<
scalar_t
>
(
get_ptr
<
scalar_t
>
(
input
),
get_ptr
<
scalar_t
>
(
permuted_input
),
get_ptr
<
int
>
(
permuted_experts_id
),
get_ptr
<
int
>
(
sorted_row_idx
),
get_ptr
<
int
>
(
inv_permuted_idx
),
get_ptr
<
int
>
(
permuted_idx
),
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
aligned_expert_first_token_offset_ptr
,
n_token
,
valid_num_ptr
,
n_hidden
,
topk
,
n_local_expert
,
align_block_size_value
,
stream
);
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_token
,
valid_num_ptr
,
n_hidden
,
topk
,
n_local_expert
,
stream
);
});
// this is only required for DeepGemm and not required for CUTLASS group gemm
if
(
align_block_size
.
has_value
())
{
expert_first_token_offset
.
copy_
(
align_expert_first_token_offset
);
}
}
void
moe_unpermute
(
...
...
@@ -201,16 +170,13 @@ void shuffle_rows(const torch::Tensor& input_tensor,
#else
void
moe_permute
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_ids
,
void
moe_permute
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
token_expert_indices
,
const
std
::
optional
<
torch
::
Tensor
>&
expert_map
,
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
const
std
::
optional
<
int64_t
>&
align_block_size
,
torch
::
Tensor
&
permuted_input
,
torch
::
Tensor
&
expert_first_token_offset
,
torch
::
Tensor
&
src_row_id2dst_row_id_map
,
torch
::
Tensor
&
m_indices
)
{
torch
::
Tensor
&
inv_permuted_idx
,
torch
::
Tensor
&
permuted_idx
)
{
TORCH_CHECK
(
false
,
"moe_permute is not supported on CUDA < 12.0"
);
}
...
...
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
View file @
77c09e11
...
...
@@ -168,64 +168,4 @@ void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
topk_id_ptr
,
size
,
expert_map_ptr
,
num_experts
);
}
template
<
bool
ALIGN_BLOCK_SIZE
>
__global__
void
getMIndicesKernel
(
int64_t
*
expert_first_token_offset
,
int64_t
*
align_expert_first_token_offset
,
int
*
m_indices
,
const
int
num_local_expert
,
const
int
align_block_size
)
{
int
eidx
=
blockIdx
.
x
;
int
tidx
=
threadIdx
.
x
;
extern
__shared__
int64_t
smem_expert_first_token_offset
[];
for
(
int
i
=
tidx
;
i
<=
num_local_expert
;
i
+=
blockDim
.
x
)
{
smem_expert_first_token_offset
[
i
]
=
__ldg
(
expert_first_token_offset
+
i
);
}
__syncthreads
();
auto
last_token_offset
=
smem_expert_first_token_offset
[
eidx
+
1
];
auto
first_token_offset
=
smem_expert_first_token_offset
[
eidx
];
int
n_token_in_expert
=
last_token_offset
-
first_token_offset
;
if
constexpr
(
ALIGN_BLOCK_SIZE
)
{
n_token_in_expert
=
(
n_token_in_expert
+
align_block_size
-
1
)
/
align_block_size
*
align_block_size
;
// round up to ALIGN_BLOCK_SIZE
int64_t
accumulate_align_offset
=
0
;
for
(
int
i
=
1
;
i
<=
eidx
+
1
;
i
++
)
{
int
n_token
=
smem_expert_first_token_offset
[
i
]
-
smem_expert_first_token_offset
[
i
-
1
];
accumulate_align_offset
=
accumulate_align_offset
+
(
n_token
+
align_block_size
-
1
)
/
align_block_size
*
align_block_size
;
if
(
i
==
eidx
)
{
first_token_offset
=
accumulate_align_offset
;
}
// last block store align_expert_first_token_offset
if
(
eidx
==
num_local_expert
-
1
&&
threadIdx
.
x
==
0
)
{
align_expert_first_token_offset
[
i
]
=
accumulate_align_offset
;
}
}
}
for
(
int
idx
=
tidx
;
idx
<
n_token_in_expert
;
idx
+=
blockDim
.
x
)
{
// update m_indice with expert id
m_indices
[
first_token_offset
+
idx
]
=
eidx
;
}
}
void
getMIndices
(
int64_t
*
expert_first_token_offset
,
int64_t
*
align_expert_first_token_offset
,
int
*
m_indices
,
int
num_local_expert
,
const
int
align_block_size
,
cudaStream_t
stream
)
{
int
block
=
256
;
int
grid
=
num_local_expert
;
int
smem_size
=
sizeof
(
int64_t
)
*
(
num_local_expert
+
1
);
if
(
align_block_size
==
-
1
)
{
getMIndicesKernel
<
false
><<<
grid
,
block
,
smem_size
,
stream
>>>
(
expert_first_token_offset
,
align_expert_first_token_offset
,
m_indices
,
num_local_expert
,
align_block_size
);
}
else
{
getMIndicesKernel
<
true
><<<
grid
,
block
,
smem_size
,
stream
>>>
(
expert_first_token_offset
,
align_expert_first_token_offset
,
m_indices
,
num_local_expert
,
align_block_size
);
}
}
#endif
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
View file @
77c09e11
...
...
@@ -60,10 +60,9 @@ void expandInputRowsKernelLauncher(
T
const
*
unpermuted_input
,
T
*
permuted_output
,
int
*
sorted_experts
,
int
const
*
expanded_dest_row_to_expanded_source_row
,
int
*
expanded_source_row_to_expanded_dest_row
,
int
*
permuted_idx
,
int64_t
const
*
expert_first_token_offset
,
int64_t
const
*
aligned_expert_first_token_offset
,
int64_t
const
num_rows
,
int64_t
const
*
expert_first_token_offset
,
int64_t
const
num_rows
,
int64_t
const
*
num_valid_tokens_ptr
,
int64_t
const
cols
,
int
const
k
,
int
num_local_experts
,
const
int
&
align_block_size
,
cudaStream_t
stream
);
int
num_local_experts
,
cudaStream_t
stream
);
template
<
class
T
,
class
OutputType
>
void
finalizeMoeRoutingKernelLauncher
(
...
...
@@ -76,9 +75,4 @@ void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
const
int
*
expert_map_ptr
,
int
num_experts
,
cudaStream_t
stream
);
void
getMIndices
(
int64_t
*
expert_first_token_offset
,
int64_t
*
align_expert_first_token_offset
,
int
*
m_indices
,
int
num_local_expert
,
const
int
align_block_size
,
cudaStream_t
stream
);
#include "moe_permute_unpermute_kernel.inl"
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
View file @
77c09e11
#pragma once
template <typename T, bool CHECK_SKIPPED
, bool ALIGN_BLOCK_SIZE
>
template <typename T, bool CHECK_SKIPPED>
__global__ void expandInputRowsKernel(
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t const* expert_first_token_offset,
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
int64_t const* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
int num_local_experts
, int align_block_size
) {
int num_local_experts) {
// Reverse permutation map.
// I do this so that later, we can use the source -> dest map to do the k-way
// reduction and unpermuting. I need the reverse map for that reduction to
...
...
@@ -19,24 +18,6 @@ __global__ void expandInputRowsKernel(
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
int expert_id = sorted_experts[expanded_dest_row];
if constexpr (ALIGN_BLOCK_SIZE) {
// convert (unaligned) expanded_dest_row -> aligned expanded_dest_row.
// aligned_expert_first_token_offset[e] provides the aligned prefix start
// for expert e. For non-local experts we map to the end (total aligned M).
int64_t aligned_base = 0;
int64_t token_offset_in_expert = 0;
if (expert_id >= num_local_experts) {
aligned_base =
__ldg(aligned_expert_first_token_offset + num_local_experts);
token_offset_in_expert = 0;
} else {
aligned_base = __ldg(aligned_expert_first_token_offset + expert_id);
token_offset_in_expert =
expanded_dest_row - __ldg(expert_first_token_offset + expert_id);
}
expanded_dest_row = aligned_base + token_offset_in_expert;
}
if (threadIdx.x == 0) {
assert(expanded_dest_row <= INT32_MAX);
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
...
...
@@ -76,29 +57,25 @@ void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t const* expert_first_token_offset,
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
int64_t const* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts,
const int& align_block_size,
cudaStream_t stream) {
int num_local_experts, cudaStream_t stream) {
int64_t const blocks = num_rows * k;
int64_t const threads = 256;
using FuncPtr = decltype(&expandInputRowsKernel<T, true, true>);
FuncPtr func_map[2][2] = {
{&expandInputRowsKernel<T, false, false>,
&expandInputRowsKernel<T, false, true>},
{&expandInputRowsKernel<T, true, false>,
&expandInputRowsKernel<T, true, true>},
using FuncPtr = decltype(&expandInputRowsKernel<T, true>);
FuncPtr func_map[2] = {
&expandInputRowsKernel<T, false>,
&expandInputRowsKernel<T, true>,
};
bool is_check_skip = num_valid_tokens_ptr != nullptr;
bool is_align_block_size = align_block_size != -1;
auto func = func_map[is_check_skip][is_align_block_size];
auto func = func_map[is_check_skip];
func<<<blocks, threads, 0, stream>>>(
unpermuted_input, permuted_output, sorted_experts,
expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row, permuted_idx,
expert_first_token_offset,
aligned_expert_first_token_offset, num_rows
,
num_
valid_tokens_ptr, cols, k, num_local_experts, align_block_size
);
expert_first_token_offset,
num_rows, num_valid_tokens_ptr, cols, k
,
num_
local_experts
);
}
template <class T, class U>
...
...
csrc/moe/torch_bindings.cpp
View file @
77c09e11
...
...
@@ -99,9 +99,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"moe_permute(Tensor input, Tensor topk_ids,"
"Tensor token_expert_indices, Tensor? expert_map, int n_expert,"
"int n_local_expert,"
"int topk,
int? align_block_size,
Tensor! permuted_input, Tensor! "
"int topk, Tensor! permuted_input, Tensor! "
"expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! "
"permuted_idx
, Tensor! m_indices
)->()"
);
"permuted_idx)->()"
);
m
.
def
(
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
...
...
tests/kernels/moe/test_moe_permute_unpermute.py
View file @
77c09e11
...
...
@@ -40,10 +40,8 @@ def torch_permute(
n_local_expert
:
int
,
start_expert
:
int
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
align_block_size
:
int
|
None
=
None
,
fill_invalid_expert
:
int
=
-
1
,
)
->
list
[
torch
.
Tensor
]:
n_token
,
n_hidden
=
hidden_states
.
shape
[
0
]
,
hidden_states
.
shape
[
1
]
n_token
=
hidden_states
.
shape
[
0
]
if
expert_map
is
not
None
:
is_local_expert
=
expert_map
[
topk_ids
]
!=
-
1
not_local_expert
=
expert_map
[
topk_ids
]
==
-
1
...
...
@@ -70,16 +68,7 @@ def torch_permute(
_
,
src2dst_idx
=
torch
.
sort
(
dst_row_id2src_row_id_map
)
valid_row_idx
=
[]
if
align_block_size
is
None
:
permuted_hidden_states
=
hidden_states
[
dst_row_id2src_row_id_map
//
topk
,
...]
permuted_row_size
=
permuted_hidden_states
.
shape
[
0
]
m_indices
=
torch
.
empty
(
permuted_row_size
,
device
=
"cuda"
,
dtype
=
torch
.
int32
).
fill_
(
fill_invalid_expert
)
for
i
in
range
(
1
,
n_local_expert
+
1
):
first_token_offset
=
expert_first_token_offset
[
i
-
1
]
last_token_offset
=
expert_first_token_offset
[
i
]
m_indices
[
first_token_offset
:
last_token_offset
]
=
i
-
1
src_row_id2dst_row_id_map
=
torch
.
arange
(
0
,
n_token
*
topk
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)[
src2dst_idx
].
reshape
((
n_token
,
topk
))
...
...
@@ -90,85 +79,6 @@ def torch_permute(
expert_first_token_offset
,
src_row_id2dst_row_id_map
,
dst_row_id2src_row_id_map
,
m_indices
,
valid_row_idx
,
]
else
:
permuted_row_size
=
(
(
topk
*
n_token
+
n_expert
*
(
align_block_size
-
1
)
+
align_block_size
-
1
)
//
align_block_size
*
align_block_size
)
permuted_idx
=
torch
.
full
(
(
permuted_row_size
,),
n_token
*
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
,
)
permuted_hidden_states
=
torch
.
empty
(
(
permuted_row_size
,
n_hidden
),
device
=
"cuda"
,
dtype
=
hidden_states
.
dtype
)
align_src_row_id2dst_row_id
=
torch
.
empty
(
n_token
*
topk
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
align_expert_first_token_offset
=
torch
.
zeros_like
(
expert_first_token_offset
)
m_indices
=
torch
.
empty
(
permuted_row_size
,
device
=
"cuda"
,
dtype
=
torch
.
int32
).
fill_
(
fill_invalid_expert
)
# get align_permuted_hidden_states,
# valid row_idx and align_expert_first_token_offset
for
i
in
range
(
1
,
n_local_expert
+
1
):
first_token_offset
=
expert_first_token_offset
[
i
-
1
]
last_token_offset
=
expert_first_token_offset
[
i
]
n_token_in_expert
=
last_token_offset
-
first_token_offset
align_expert_first_token_offset
[
i
]
=
(
align_expert_first_token_offset
[
i
-
1
]
+
(
n_token_in_expert
+
align_block_size
-
1
)
//
align_block_size
*
align_block_size
)
align_first_token_offset
=
align_expert_first_token_offset
[
i
-
1
]
align_last_token_offset
=
align_expert_first_token_offset
[
i
]
dst_row_id2src_row_id_in_expert
=
dst_row_id2src_row_id_map
[
first_token_offset
:
first_token_offset
+
n_token_in_expert
]
# store token in current expert with align_first_token_offset
permuted_hidden_states
[
align_first_token_offset
:
align_first_token_offset
+
n_token_in_expert
,
...,
]
=
hidden_states
[
dst_row_id2src_row_id_in_expert
//
topk
,
...]
permuted_idx
[
align_first_token_offset
:
align_first_token_offset
+
n_token_in_expert
]
=
dst_row_id2src_row_id_in_expert
# set current expert m_indices
m_indices
[
align_first_token_offset
:
align_last_token_offset
]
=
i
-
1
valid_row_idx
+=
[
i
for
i
in
range
(
align_first_token_offset
,
align_first_token_offset
+
n_token_in_expert
,
)
]
# get align_src_row_id2dst_row_id
for
i
in
range
(
n_token
*
topk
):
eid
=
sorted_topk_ids
[
i
]
if
eid
>=
n_local_expert
:
# check token not in local expert
align_src_row_id2dst_row_id
[
i
]
=
align_expert_first_token_offset
[
-
1
]
continue
first_token_offset
=
expert_first_token_offset
[
eid
]
align_first_token_offset
=
align_expert_first_token_offset
[
eid
]
token_offset
=
i
-
first_token_offset
align_src_row_id2dst_row_id
[
i
]
=
align_first_token_offset
+
token_offset
align_src_row_id2dst_row_id
=
align_src_row_id2dst_row_id
[
src2dst_idx
].
reshape
(
(
n_token
,
topk
)
)
return
[
permuted_hidden_states
,
align_expert_first_token_offset
,
align_src_row_id2dst_row_id
,
permuted_idx
,
m_indices
,
valid_row_idx
,
]
...
...
@@ -207,7 +117,6 @@ def torch_unpermute(
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
EP_SIZE
)
@
pytest
.
mark
.
parametrize
(
"align_block_size"
,
[
None
,
128
])
def
test_moe_permute_unpermute
(
n_token
:
int
,
n_hidden
:
int
,
...
...
@@ -215,11 +124,9 @@ def test_moe_permute_unpermute(
n_expert
:
int
,
ep_size
:
int
,
dtype
:
torch
.
dtype
,
align_block_size
:
int
|
None
,
):
if
not
moe_permute_unpermute_supported
():
pytest
.
skip
(
"moe_permute_unpermute is not supported on this platform."
)
fill_invalid_expert
=
0
ep_rank
=
np
.
random
.
randint
(
0
,
ep_size
)
expert_map
=
None
n_local_expert
=
n_expert
...
...
@@ -238,7 +145,6 @@ def test_moe_permute_unpermute(
gold_expert_first_token_offset
,
gold_inv_permuted_idx
,
gold_permuted_idx
,
gold_m_indices
,
valid_row_idx
,
)
=
torch_permute
(
hidden_states
,
...
...
@@ -249,8 +155,6 @@ def test_moe_permute_unpermute(
n_local_expert
,
start_expert
,
expert_map
=
expert_map
,
align_block_size
=
align_block_size
,
fill_invalid_expert
=
fill_invalid_expert
,
)
(
...
...
@@ -258,7 +162,7 @@ def test_moe_permute_unpermute(
_
,
expert_first_token_offset
,
inv_permuted_idx
,
m_indices
,
_
,
)
=
moe_permute
(
hidden_states
=
hidden_states
,
a1q_scale
=
None
,
...
...
@@ -266,8 +170,6 @@ def test_moe_permute_unpermute(
n_expert
=
n_expert
,
n_local_expert
=
n_local_expert
,
expert_map
=
expert_map
,
align_block_size
=
align_block_size
,
fill_invalid_expert
=
fill_invalid_expert
,
)
# check expert_first_token_offset
...
...
@@ -278,11 +180,6 @@ def test_moe_permute_unpermute(
torch
.
testing
.
assert_close
(
gold_inv_permuted_idx
.
flatten
(),
inv_permuted_idx
,
atol
=
0
,
rtol
=
0
)
# check mindice
# current kernel usage assumes deepgemm requires align_block_size
# when it's not provided then we don't compute m_indices (for cutlass)
if
align_block_size
is
not
None
:
torch
.
testing
.
assert_close
(
gold_m_indices
,
m_indices
,
atol
=
0
,
rtol
=
0
)
# check permuted_hidden_states, only valid token
torch
.
testing
.
assert_close
(
...
...
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
View file @
77c09e11
...
...
@@ -11,8 +11,6 @@ def moe_permute(
n_expert
:
int
,
n_local_expert
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
align_block_size
:
int
|
None
=
None
,
fill_invalid_expert
:
int
=
-
1
,
permuted_hidden_states
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
...
...
@@ -27,9 +25,6 @@ def moe_permute(
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- align_block_size (Optional[int]): align group gemm block size for deepgemm
- fill_invalid_expert(int): fill expert id in m_indices for invalid expert
to workaround DeepGemm unsupported -1 in m_indices
- permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor.
If None, the output tensor will be created in this function.
Returns:
...
...
@@ -37,12 +32,9 @@ def moe_permute(
- a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states
if original scale not per-tensor scaling
- expert_first_token_offset (torch.Tensor): offset of the first token
of each expert for standard grouped gemm. if enable 'align_block_size'
expert_first_token_offset will align up to 'align_block_size'.
of each expert for standard grouped gemm.
- inv_permuted_idx (torch.Tensor): idx map for moe_unpermute.
- permuted_idx (torch.Tensor): idx map from hidden to permuted_hidden.
- m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records
the group which the j-th row of the LHS belong to.`
"""
n_token
,
n_hidden
=
hidden_states
.
size
()
topk
=
topk_ids
.
size
(
1
)
...
...
@@ -50,17 +42,6 @@ def moe_permute(
"permue kernel need hidden dim align to 16B"
)
permuted_row_size
=
n_token
*
topk
if
align_block_size
is
not
None
:
permuted_row_size
=
(
(
permuted_row_size
+
n_expert
*
(
align_block_size
-
1
)
+
align_block_size
-
1
)
//
align_block_size
*
align_block_size
)
if
n_local_expert
==
-
1
:
n_local_expert
=
n_expert
if
permuted_hidden_states
is
None
:
...
...
@@ -78,12 +59,6 @@ def moe_permute(
0
,
n_token
*
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
).
reshape
((
n_token
,
topk
))
m_indices
=
torch
.
full
(
(
permuted_row_size
,),
fill_invalid_expert
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
,
)
expert_first_token_offset
=
torch
.
empty
(
n_local_expert
+
1
,
dtype
=
torch
.
int64
,
device
=
hidden_states
.
device
)
...
...
@@ -105,12 +80,10 @@ def moe_permute(
n_expert
,
n_local_expert
,
topk
,
align_block_size
,
permuted_hidden_states
,
expert_first_token_offset
,
inv_permuted_idx
,
permuted_idx
,
m_indices
,
)
if
a1q_scale
is
not
None
and
a1q_scale
.
dim
()
>
1
:
...
...
@@ -120,7 +93,7 @@ def moe_permute(
a1q_scale
,
expert_first_token_offset
,
inv_permuted_idx
.
flatten
(),
m_indices
,
permuted_idx
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment