Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
77c09e11
Unverified
Commit
77c09e11
authored
Feb 06, 2026
by
Wentao Ye
Committed by
GitHub
Feb 06, 2026
Browse files
[Refactor] Remove align block size logic in `moe_permute` (#33449)
Signed-off-by:
yewentao256
<
zhyanwentao@126.com
>
parent
16786da7
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
38 additions
and
297 deletions
+38
-297
benchmarks/kernels/benchmark_moe_permute_unpermute.py
benchmarks/kernels/benchmark_moe_permute_unpermute.py
+0
-6
csrc/moe/moe_permute_unpermute_op.cu
csrc/moe/moe_permute_unpermute_op.cu
+5
-39
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
...permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
+0
-60
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
.../permute_unpermute_kernels/moe_permute_unpermute_kernel.h
+2
-8
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
...ermute_unpermute_kernels/moe_permute_unpermute_kernel.inl
+12
-35
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+2
-2
tests/kernels/moe/test_moe_permute_unpermute.py
tests/kernels/moe/test_moe_permute_unpermute.py
+15
-118
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
.../model_executor/layers/fused_moe/moe_permute_unpermute.py
+2
-29
No files found.
benchmarks/kernels/benchmark_moe_permute_unpermute.py
View file @
77c09e11
...
@@ -44,10 +44,8 @@ def benchmark_permute(
...
@@ -44,10 +44,8 @@ def benchmark_permute(
hidden_states
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
hidden_states
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
# output_hidden_states = torch.empty_like(hidden_states)
# output_hidden_states = torch.empty_like(hidden_states)
if
use_fp8_w8a8
:
if
use_fp8_w8a8
:
align_block_size
=
128
# deepgemm needs 128 m aligned block
qhidden_states
,
scale
=
_fp8_quantize
(
hidden_states
,
None
,
None
)
qhidden_states
,
scale
=
_fp8_quantize
(
hidden_states
,
None
,
None
)
else
:
else
:
align_block_size
=
None
qhidden_states
=
hidden_states
qhidden_states
=
hidden_states
gating_output
=
torch
.
randn
(
num_iters
,
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
gating_output
=
torch
.
randn
(
num_iters
,
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
...
@@ -67,7 +65,6 @@ def benchmark_permute(
...
@@ -67,7 +65,6 @@ def benchmark_permute(
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
n_expert
=
num_experts
,
n_expert
=
num_experts
,
expert_map
=
None
,
expert_map
=
None
,
align_block_size
=
align_block_size
,
)
)
# JIT compilation & warmup
# JIT compilation & warmup
...
@@ -117,10 +114,8 @@ def benchmark_unpermute(
...
@@ -117,10 +114,8 @@ def benchmark_unpermute(
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
hidden_states
=
torch
.
randn
(
num_tokens
,
hidden_size
,
dtype
=
dtype
)
if
use_fp8_w8a8
:
if
use_fp8_w8a8
:
align_block_size
=
128
# deepgemm needs 128 m aligned block
qhidden_states
,
scale
=
_fp8_quantize
(
hidden_states
,
None
,
None
)
qhidden_states
,
scale
=
_fp8_quantize
(
hidden_states
,
None
,
None
)
else
:
else
:
align_block_size
=
None
qhidden_states
=
hidden_states
qhidden_states
=
hidden_states
input_gating
=
torch
.
randn
(
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
input_gating
=
torch
.
randn
(
num_tokens
,
num_experts
,
dtype
=
torch
.
float32
)
...
@@ -142,7 +137,6 @@ def benchmark_unpermute(
...
@@ -142,7 +137,6 @@ def benchmark_unpermute(
topk_ids
=
topk_ids
,
topk_ids
=
topk_ids
,
n_expert
=
num_experts
,
n_expert
=
num_experts
,
expert_map
=
None
,
expert_map
=
None
,
align_block_size
=
align_block_size
,
)
)
# convert to fp16/bf16 as gemm output
# convert to fp16/bf16 as gemm output
return
(
return
(
...
...
csrc/moe/moe_permute_unpermute_op.cu
View file @
77c09e11
...
@@ -14,12 +14,10 @@ void moe_permute(
...
@@ -14,12 +14,10 @@ void moe_permute(
const
torch
::
Tensor
&
token_expert_indices
,
// [n_token, topk]
const
torch
::
Tensor
&
token_expert_indices
,
// [n_token, topk]
const
std
::
optional
<
torch
::
Tensor
>&
expert_map
,
// [n_expert]
const
std
::
optional
<
torch
::
Tensor
>&
expert_map
,
// [n_expert]
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
const
std
::
optional
<
int64_t
>&
align_block_size
,
torch
::
Tensor
&
permuted_input
,
// [permuted_size, hidden]
torch
::
Tensor
&
permuted_input
,
// [permuted_size, hidden]
torch
::
Tensor
&
expert_first_token_offset
,
// [n_local_expert + 1]
torch
::
Tensor
&
expert_first_token_offset
,
// [n_local_expert + 1]
torch
::
Tensor
&
inv_permuted_idx
,
// [n_token, topk]
torch
::
Tensor
&
inv_permuted_idx
,
// [n_token, topk]
torch
::
Tensor
&
permuted_idx
,
// [permute_size]
torch
::
Tensor
&
permuted_idx
)
{
// [permute_size]
torch
::
Tensor
&
m_indices
)
{
// [align_expand_m]
TORCH_CHECK
(
expert_first_token_offset
.
scalar_type
()
==
at
::
ScalarType
::
Long
,
TORCH_CHECK
(
expert_first_token_offset
.
scalar_type
()
==
at
::
ScalarType
::
Long
,
"expert_first_token_offset must be int64"
);
"expert_first_token_offset must be int64"
);
TORCH_CHECK
(
topk_ids
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
TORCH_CHECK
(
topk_ids
.
scalar_type
()
==
at
::
ScalarType
::
Int
,
...
@@ -34,8 +32,6 @@ void moe_permute(
...
@@ -34,8 +32,6 @@ void moe_permute(
"token_expert_indices shape must be same as inv_permuted_idx"
);
"token_expert_indices shape must be same as inv_permuted_idx"
);
auto
n_token
=
input
.
sizes
()[
0
];
auto
n_token
=
input
.
sizes
()[
0
];
auto
n_hidden
=
input
.
sizes
()[
1
];
auto
n_hidden
=
input
.
sizes
()[
1
];
auto
align_block_size_value
=
align_block_size
.
has_value
()
?
align_block_size
.
value
()
:
-
1
;
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
const
long
sorter_size
=
const
long
sorter_size
=
CubKeyValueSorter
::
getWorkspaceSize
(
n_token
*
topk
,
n_expert
);
CubKeyValueSorter
::
getWorkspaceSize
(
n_token
*
topk
,
n_expert
);
...
@@ -73,42 +69,15 @@ void moe_permute(
...
@@ -73,42 +69,15 @@ void moe_permute(
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_token
,
n_expert
,
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_token
,
n_expert
,
n_local_expert
,
topk
,
sorter
,
get_ptr
<
int
>
(
sort_workspace
),
stream
);
n_local_expert
,
topk
,
sorter
,
get_ptr
<
int
>
(
sort_workspace
),
stream
);
// DeepGEMM: use getMIndices kernel to compute
// 1) align_expert_first_token_offset (aligned prefix offsets)
// 2) m_indices (expert id for each aligned row)
// eg. expert0: 3, expert1: 5, expert2: 2 tokens respectively
// expert_first_token_offset = [0, 3, 8, 10], align_block_size = 4
// expert0: 3->4, expert1: 5->8, expert2: 2->4
// align_expert_first_token_offset = [0, 4, 12, 16]
// so m_indices = [0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2]
torch
::
Tensor
align_expert_first_token_offset
;
const
int64_t
*
aligned_expert_first_token_offset_ptr
=
nullptr
;
if
(
align_block_size
.
has_value
())
{
align_expert_first_token_offset
=
torch
::
zeros_like
(
expert_first_token_offset
);
getMIndices
(
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
get_ptr
<
int64_t
>
(
align_expert_first_token_offset
),
get_ptr
<
int
>
(
m_indices
),
n_local_expert
,
align_block_size_value
,
stream
);
aligned_expert_first_token_offset_ptr
=
get_ptr
<
int64_t
>
(
align_expert_first_token_offset
);
}
// dispatch expandInputRowsKernelLauncher
// dispatch expandInputRowsKernelLauncher
MOE_DISPATCH
(
input
.
scalar_type
(),
[
&
]
{
MOE_DISPATCH
(
input
.
scalar_type
(),
[
&
]
{
expandInputRowsKernelLauncher
<
scalar_t
>
(
expandInputRowsKernelLauncher
<
scalar_t
>
(
get_ptr
<
scalar_t
>
(
input
),
get_ptr
<
scalar_t
>
(
permuted_input
),
get_ptr
<
scalar_t
>
(
input
),
get_ptr
<
scalar_t
>
(
permuted_input
),
get_ptr
<
int
>
(
permuted_experts_id
),
get_ptr
<
int
>
(
sorted_row_idx
),
get_ptr
<
int
>
(
permuted_experts_id
),
get_ptr
<
int
>
(
sorted_row_idx
),
get_ptr
<
int
>
(
inv_permuted_idx
),
get_ptr
<
int
>
(
permuted_idx
),
get_ptr
<
int
>
(
inv_permuted_idx
),
get_ptr
<
int
>
(
permuted_idx
),
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
get_ptr
<
int64_t
>
(
expert_first_token_offset
),
n_token
,
valid_num_ptr
,
aligned_expert_first_token_offset_ptr
,
n_token
,
valid_num_ptr
,
n_hidden
,
n_hidden
,
topk
,
n_local_expert
,
stream
);
topk
,
n_local_expert
,
align_block_size_value
,
stream
);
});
});
// this is only required for DeepGemm and not required for CUTLASS group gemm
if
(
align_block_size
.
has_value
())
{
expert_first_token_offset
.
copy_
(
align_expert_first_token_offset
);
}
}
}
void
moe_unpermute
(
void
moe_unpermute
(
...
@@ -201,16 +170,13 @@ void shuffle_rows(const torch::Tensor& input_tensor,
...
@@ -201,16 +170,13 @@ void shuffle_rows(const torch::Tensor& input_tensor,
#else
#else
void
moe_permute
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
topk_weights
,
void
moe_permute
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
topk_ids
,
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
token_expert_indices
,
const
torch
::
Tensor
&
token_expert_indices
,
const
std
::
optional
<
torch
::
Tensor
>&
expert_map
,
const
std
::
optional
<
torch
::
Tensor
>&
expert_map
,
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
int64_t
n_expert
,
int64_t
n_local_expert
,
int64_t
topk
,
const
std
::
optional
<
int64_t
>&
align_block_size
,
torch
::
Tensor
&
permuted_input
,
torch
::
Tensor
&
permuted_input
,
torch
::
Tensor
&
expert_first_token_offset
,
torch
::
Tensor
&
expert_first_token_offset
,
torch
::
Tensor
&
src_row_id2dst_row_id_map
,
torch
::
Tensor
&
inv_permuted_idx
,
torch
::
Tensor
&
permuted_idx
)
{
torch
::
Tensor
&
m_indices
)
{
TORCH_CHECK
(
false
,
"moe_permute is not supported on CUDA < 12.0"
);
TORCH_CHECK
(
false
,
"moe_permute is not supported on CUDA < 12.0"
);
}
}
...
...
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu
View file @
77c09e11
...
@@ -168,64 +168,4 @@ void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
...
@@ -168,64 +168,4 @@ void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
topk_id_ptr
,
size
,
expert_map_ptr
,
num_experts
);
topk_id_ptr
,
size
,
expert_map_ptr
,
num_experts
);
}
}
// Fills `m_indices` with the owning expert id for every permuted row of one
// expert; blockIdx.x is used as the expert index.
//
// Template parameter:
//   ALIGN_BLOCK_SIZE  when true, each expert's row count is rounded up to a
//                     multiple of `align_block_size`, rows are written at the
//                     aligned prefix offset, and the aligned prefix offsets
//                     are stored to `align_expert_first_token_offset`.
//                     When false, `align_block_size` and
//                     `align_expert_first_token_offset` are not used.
//
// Parameters:
//   expert_first_token_offset        unaligned prefix offsets; entries
//                                    [0, num_local_expert] are read.
//   align_expert_first_token_offset  output, entries [1, num_local_expert]
//                                    are written (aligned mode only); entry 0
//                                    is never written by this kernel.
//   m_indices                        output, expert id per (aligned) row.
//   num_local_expert                 number of experts covered by the offsets.
//   align_block_size                 rounding granularity in aligned mode.
//
// Requires dynamic shared memory for (num_local_expert + 1) int64_t values.
template <bool ALIGN_BLOCK_SIZE>
__global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
                                  int64_t* align_expert_first_token_offset,
                                  int* m_indices, const int num_local_expert,
                                  const int align_block_size) {
  const int expert = blockIdx.x;
  const int lane = threadIdx.x;

  // Stage the whole offset table in shared memory; it is re-read by the
  // prefix loop below.
  extern __shared__ int64_t smem_expert_first_token_offset[];
  for (int slot = lane; slot <= num_local_expert; slot += blockDim.x) {
    smem_expert_first_token_offset[slot] =
        __ldg(expert_first_token_offset + slot);
  }
  __syncthreads();

  // Destination of this expert's first row and number of rows to label.
  int64_t write_base = smem_expert_first_token_offset[expert];
  int row_count = smem_expert_first_token_offset[expert + 1] - write_base;

  if constexpr (ALIGN_BLOCK_SIZE) {
    // Round this expert's row count up to a multiple of align_block_size, so
    // the padding rows are labelled with the expert id as well.
    row_count =
        (row_count + align_block_size - 1) / align_block_size * align_block_size;

    // Only thread 0 of the block handling the last expert stores the offsets.
    const bool stores_offsets = (expert == num_local_expert - 1) && (lane == 0);

    // Running sum of the rounded-up row counts of experts 0 .. step - 1.
    int64_t aligned_prefix = 0;
    for (int step = 1; step <= expert + 1; ++step) {
      const int unaligned_rows = smem_expert_first_token_offset[step] -
                                 smem_expert_first_token_offset[step - 1];
      aligned_prefix += (unaligned_rows + align_block_size - 1) /
                        align_block_size * align_block_size;

      // After `expert` steps the sum is the aligned start of this expert.
      // Expert 0 never hits this and keeps the unaligned start read above.
      if (step == expert) {
        write_base = aligned_prefix;
      }
      if (stores_offsets) {
        align_expert_first_token_offset[step] = aligned_prefix;
      }
    }
  }

  // Label every row of this expert, strided across the threads of the block.
  for (int row = lane; row < row_count; row += blockDim.x) {
    m_indices[write_base + row] = expert;
  }
}
void
getMIndices
(
int64_t
*
expert_first_token_offset
,
int64_t
*
align_expert_first_token_offset
,
int
*
m_indices
,
int
num_local_expert
,
const
int
align_block_size
,
cudaStream_t
stream
)
{
int
block
=
256
;
int
grid
=
num_local_expert
;
int
smem_size
=
sizeof
(
int64_t
)
*
(
num_local_expert
+
1
);
if
(
align_block_size
==
-
1
)
{
getMIndicesKernel
<
false
><<<
grid
,
block
,
smem_size
,
stream
>>>
(
expert_first_token_offset
,
align_expert_first_token_offset
,
m_indices
,
num_local_expert
,
align_block_size
);
}
else
{
getMIndicesKernel
<
true
><<<
grid
,
block
,
smem_size
,
stream
>>>
(
expert_first_token_offset
,
align_expert_first_token_offset
,
m_indices
,
num_local_expert
,
align_block_size
);
}
}
#endif
#endif
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h
View file @
77c09e11
...
@@ -60,10 +60,9 @@ void expandInputRowsKernelLauncher(
...
@@ -60,10 +60,9 @@ void expandInputRowsKernelLauncher(
T
const
*
unpermuted_input
,
T
*
permuted_output
,
int
*
sorted_experts
,
T
const
*
unpermuted_input
,
T
*
permuted_output
,
int
*
sorted_experts
,
int
const
*
expanded_dest_row_to_expanded_source_row
,
int
const
*
expanded_dest_row_to_expanded_source_row
,
int
*
expanded_source_row_to_expanded_dest_row
,
int
*
permuted_idx
,
int
*
expanded_source_row_to_expanded_dest_row
,
int
*
permuted_idx
,
int64_t
const
*
expert_first_token_offset
,
int64_t
const
*
expert_first_token_offset
,
int64_t
const
num_rows
,
int64_t
const
*
aligned_expert_first_token_offset
,
int64_t
const
num_rows
,
int64_t
const
*
num_valid_tokens_ptr
,
int64_t
const
cols
,
int
const
k
,
int64_t
const
*
num_valid_tokens_ptr
,
int64_t
const
cols
,
int
const
k
,
int
num_local_experts
,
const
int
&
align_block_size
,
cudaStream_t
stream
);
int
num_local_experts
,
cudaStream_t
stream
);
template
<
class
T
,
class
OutputType
>
template
<
class
T
,
class
OutputType
>
void
finalizeMoeRoutingKernelLauncher
(
void
finalizeMoeRoutingKernelLauncher
(
...
@@ -76,9 +75,4 @@ void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
...
@@ -76,9 +75,4 @@ void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
const
int
*
expert_map_ptr
,
int
num_experts
,
const
int
*
expert_map_ptr
,
int
num_experts
,
cudaStream_t
stream
);
cudaStream_t
stream
);
void
getMIndices
(
int64_t
*
expert_first_token_offset
,
int64_t
*
align_expert_first_token_offset
,
int
*
m_indices
,
int
num_local_expert
,
const
int
align_block_size
,
cudaStream_t
stream
);
#include "moe_permute_unpermute_kernel.inl"
#include "moe_permute_unpermute_kernel.inl"
csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl
View file @
77c09e11
#pragma once
#pragma once
template <typename T, bool CHECK_SKIPPED
, bool ALIGN_BLOCK_SIZE
>
template <typename T, bool CHECK_SKIPPED>
__global__ void expandInputRowsKernel(
__global__ void expandInputRowsKernel(
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t const* expert_first_token_offset,
int64_t const* expert_first_token_offset, int64_t const num_rows,
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
int num_local_experts
, int align_block_size
) {
int num_local_experts) {
// Reverse permutation map.
// Reverse permutation map.
// I do this so that later, we can use the source -> dest map to do the k-way
// I do this so that later, we can use the source -> dest map to do the k-way
// reduction and unpermuting. I need the reverse map for that reduction to
// reduction and unpermuting. I need the reverse map for that reduction to
...
@@ -19,24 +18,6 @@ __global__ void expandInputRowsKernel(
...
@@ -19,24 +18,6 @@ __global__ void expandInputRowsKernel(
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
int expert_id = sorted_experts[expanded_dest_row];
int expert_id = sorted_experts[expanded_dest_row];
if constexpr (ALIGN_BLOCK_SIZE) {
// convert (unaligned) expanded_dest_row -> aligned expanded_dest_row.
// aligned_expert_first_token_offset[e] provides the aligned prefix start
// for expert e. For non-local experts we map to the end (total aligned M).
int64_t aligned_base = 0;
int64_t token_offset_in_expert = 0;
if (expert_id >= num_local_experts) {
aligned_base =
__ldg(aligned_expert_first_token_offset + num_local_experts);
token_offset_in_expert = 0;
} else {
aligned_base = __ldg(aligned_expert_first_token_offset + expert_id);
token_offset_in_expert =
expanded_dest_row - __ldg(expert_first_token_offset + expert_id);
}
expanded_dest_row = aligned_base + token_offset_in_expert;
}
if (threadIdx.x == 0) {
if (threadIdx.x == 0) {
assert(expanded_dest_row <= INT32_MAX);
assert(expanded_dest_row <= INT32_MAX);
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
...
@@ -76,29 +57,25 @@ void expandInputRowsKernelLauncher(
...
@@ -76,29 +57,25 @@ void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t const* expert_first_token_offset,
int64_t const* expert_first_token_offset, int64_t const num_rows,
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts,
const int& align_block_size,
cudaStream_t stream) {
int num_local_experts, cudaStream_t stream) {
int64_t const blocks = num_rows * k;
int64_t const blocks = num_rows * k;
int64_t const threads = 256;
int64_t const threads = 256;
using FuncPtr = decltype(&expandInputRowsKernel<T, true, true>);
using FuncPtr = decltype(&expandInputRowsKernel<T, true>);
FuncPtr func_map[2][2] = {
FuncPtr func_map[2] = {
{&expandInputRowsKernel<T, false, false>,
&expandInputRowsKernel<T, false>,
&expandInputRowsKernel<T, false, true>},
&expandInputRowsKernel<T, true>,
{&expandInputRowsKernel<T, true, false>,
&expandInputRowsKernel<T, true, true>},
};
};
bool is_check_skip = num_valid_tokens_ptr != nullptr;
bool is_check_skip = num_valid_tokens_ptr != nullptr;
bool is_align_block_size = align_block_size != -1;
auto func = func_map[is_check_skip];
auto func = func_map[is_check_skip][is_align_block_size];
func<<<blocks, threads, 0, stream>>>(
func<<<blocks, threads, 0, stream>>>(
unpermuted_input, permuted_output, sorted_experts,
unpermuted_input, permuted_output, sorted_experts,
expanded_dest_row_to_expanded_source_row,
expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row, permuted_idx,
expanded_source_row_to_expanded_dest_row, permuted_idx,
expert_first_token_offset,
aligned_expert_first_token_offset, num_rows
,
expert_first_token_offset,
num_rows, num_valid_tokens_ptr, cols, k
,
num_
valid_tokens_ptr, cols, k, num_local_experts, align_block_size
);
num_
local_experts
);
}
}
template <class T, class U>
template <class T, class U>
...
...
csrc/moe/torch_bindings.cpp
View file @
77c09e11
...
@@ -99,9 +99,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
...
@@ -99,9 +99,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"moe_permute(Tensor input, Tensor topk_ids,"
"moe_permute(Tensor input, Tensor topk_ids,"
"Tensor token_expert_indices, Tensor? expert_map, int n_expert,"
"Tensor token_expert_indices, Tensor? expert_map, int n_expert,"
"int n_local_expert,"
"int n_local_expert,"
"int topk,
int? align_block_size,
Tensor! permuted_input, Tensor! "
"int topk, Tensor! permuted_input, Tensor! "
"expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! "
"expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! "
"permuted_idx
, Tensor! m_indices
)->()"
);
"permuted_idx)->()"
);
m
.
def
(
m
.
def
(
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
...
...
tests/kernels/moe/test_moe_permute_unpermute.py
View file @
77c09e11
...
@@ -40,10 +40,8 @@ def torch_permute(
...
@@ -40,10 +40,8 @@ def torch_permute(
n_local_expert
:
int
,
n_local_expert
:
int
,
start_expert
:
int
,
start_expert
:
int
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
align_block_size
:
int
|
None
=
None
,
fill_invalid_expert
:
int
=
-
1
,
)
->
list
[
torch
.
Tensor
]:
)
->
list
[
torch
.
Tensor
]:
n_token
,
n_hidden
=
hidden_states
.
shape
[
0
]
,
hidden_states
.
shape
[
1
]
n_token
=
hidden_states
.
shape
[
0
]
if
expert_map
is
not
None
:
if
expert_map
is
not
None
:
is_local_expert
=
expert_map
[
topk_ids
]
!=
-
1
is_local_expert
=
expert_map
[
topk_ids
]
!=
-
1
not_local_expert
=
expert_map
[
topk_ids
]
==
-
1
not_local_expert
=
expert_map
[
topk_ids
]
==
-
1
...
@@ -70,107 +68,19 @@ def torch_permute(
...
@@ -70,107 +68,19 @@ def torch_permute(
_
,
src2dst_idx
=
torch
.
sort
(
dst_row_id2src_row_id_map
)
_
,
src2dst_idx
=
torch
.
sort
(
dst_row_id2src_row_id_map
)
valid_row_idx
=
[]
valid_row_idx
=
[]
if
align_block_size
is
None
:
permuted_hidden_states
=
hidden_states
[
dst_row_id2src_row_id_map
//
topk
,
...]
permuted_hidden_states
=
hidden_states
[
dst_row_id2src_row_id_map
//
topk
,
...]
src_row_id2dst_row_id_map
=
torch
.
arange
(
permuted_row_size
=
permuted_hidden_states
.
shape
[
0
]
0
,
n_token
*
topk
,
device
=
"cuda"
,
dtype
=
torch
.
int32
m_indices
=
torch
.
empty
(
)[
src2dst_idx
].
reshape
((
n_token
,
topk
))
permuted_row_size
,
device
=
"cuda"
,
dtype
=
torch
.
int32
valid_row_idx
+=
[
i
for
i
in
range
(
expert_first_token_offset
[
-
1
])]
).
fill_
(
fill_invalid_expert
)
dst_row_id2src_row_id_map
[
expert_first_token_offset
[
-
1
]
:]
=
n_token
*
topk
for
i
in
range
(
1
,
n_local_expert
+
1
):
return
[
first_token_offset
=
expert_first_token_offset
[
i
-
1
]
permuted_hidden_states
,
last_token_offset
=
expert_first_token_offset
[
i
]
expert_first_token_offset
,
m_indices
[
first_token_offset
:
last_token_offset
]
=
i
-
1
src_row_id2dst_row_id_map
,
src_row_id2dst_row_id_map
=
torch
.
arange
(
dst_row_id2src_row_id_map
,
0
,
n_token
*
topk
,
device
=
"cuda"
,
dtype
=
torch
.
int32
valid_row_idx
,
)[
src2dst_idx
].
reshape
((
n_token
,
topk
))
]
valid_row_idx
+=
[
i
for
i
in
range
(
expert_first_token_offset
[
-
1
])]
dst_row_id2src_row_id_map
[
expert_first_token_offset
[
-
1
]
:]
=
n_token
*
topk
return
[
permuted_hidden_states
,
expert_first_token_offset
,
src_row_id2dst_row_id_map
,
dst_row_id2src_row_id_map
,
m_indices
,
valid_row_idx
,
]
else
:
permuted_row_size
=
(
(
topk
*
n_token
+
n_expert
*
(
align_block_size
-
1
)
+
align_block_size
-
1
)
//
align_block_size
*
align_block_size
)
permuted_idx
=
torch
.
full
(
(
permuted_row_size
,),
n_token
*
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
,
)
permuted_hidden_states
=
torch
.
empty
(
(
permuted_row_size
,
n_hidden
),
device
=
"cuda"
,
dtype
=
hidden_states
.
dtype
)
align_src_row_id2dst_row_id
=
torch
.
empty
(
n_token
*
topk
,
device
=
"cuda"
,
dtype
=
torch
.
int32
)
align_expert_first_token_offset
=
torch
.
zeros_like
(
expert_first_token_offset
)
m_indices
=
torch
.
empty
(
permuted_row_size
,
device
=
"cuda"
,
dtype
=
torch
.
int32
).
fill_
(
fill_invalid_expert
)
# get align_permuted_hidden_states,
# valid row_idx and align_expert_first_token_offset
for
i
in
range
(
1
,
n_local_expert
+
1
):
first_token_offset
=
expert_first_token_offset
[
i
-
1
]
last_token_offset
=
expert_first_token_offset
[
i
]
n_token_in_expert
=
last_token_offset
-
first_token_offset
align_expert_first_token_offset
[
i
]
=
(
align_expert_first_token_offset
[
i
-
1
]
+
(
n_token_in_expert
+
align_block_size
-
1
)
//
align_block_size
*
align_block_size
)
align_first_token_offset
=
align_expert_first_token_offset
[
i
-
1
]
align_last_token_offset
=
align_expert_first_token_offset
[
i
]
dst_row_id2src_row_id_in_expert
=
dst_row_id2src_row_id_map
[
first_token_offset
:
first_token_offset
+
n_token_in_expert
]
# store token in current expert with align_first_token_offset
permuted_hidden_states
[
align_first_token_offset
:
align_first_token_offset
+
n_token_in_expert
,
...,
]
=
hidden_states
[
dst_row_id2src_row_id_in_expert
//
topk
,
...]
permuted_idx
[
align_first_token_offset
:
align_first_token_offset
+
n_token_in_expert
]
=
dst_row_id2src_row_id_in_expert
# set current expert m_indices
m_indices
[
align_first_token_offset
:
align_last_token_offset
]
=
i
-
1
valid_row_idx
+=
[
i
for
i
in
range
(
align_first_token_offset
,
align_first_token_offset
+
n_token_in_expert
,
)
]
# get align_src_row_id2dst_row_id
for
i
in
range
(
n_token
*
topk
):
eid
=
sorted_topk_ids
[
i
]
if
eid
>=
n_local_expert
:
# check token not in local expert
align_src_row_id2dst_row_id
[
i
]
=
align_expert_first_token_offset
[
-
1
]
continue
first_token_offset
=
expert_first_token_offset
[
eid
]
align_first_token_offset
=
align_expert_first_token_offset
[
eid
]
token_offset
=
i
-
first_token_offset
align_src_row_id2dst_row_id
[
i
]
=
align_first_token_offset
+
token_offset
align_src_row_id2dst_row_id
=
align_src_row_id2dst_row_id
[
src2dst_idx
].
reshape
(
(
n_token
,
topk
)
)
return
[
permuted_hidden_states
,
align_expert_first_token_offset
,
align_src_row_id2dst_row_id
,
permuted_idx
,
m_indices
,
valid_row_idx
,
]
def
torch_unpermute
(
def
torch_unpermute
(
...
@@ -207,7 +117,6 @@ def torch_unpermute(
...
@@ -207,7 +117,6 @@ def torch_unpermute(
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"topk"
,
TOP_KS
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
EP_SIZE
)
@
pytest
.
mark
.
parametrize
(
"ep_size"
,
EP_SIZE
)
@
pytest
.
mark
.
parametrize
(
"align_block_size"
,
[
None
,
128
])
def
test_moe_permute_unpermute
(
def
test_moe_permute_unpermute
(
n_token
:
int
,
n_token
:
int
,
n_hidden
:
int
,
n_hidden
:
int
,
...
@@ -215,11 +124,9 @@ def test_moe_permute_unpermute(
...
@@ -215,11 +124,9 @@ def test_moe_permute_unpermute(
n_expert
:
int
,
n_expert
:
int
,
ep_size
:
int
,
ep_size
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
align_block_size
:
int
|
None
,
):
):
if
not
moe_permute_unpermute_supported
():
if
not
moe_permute_unpermute_supported
():
pytest
.
skip
(
"moe_permute_unpermute is not supported on this platform."
)
pytest
.
skip
(
"moe_permute_unpermute is not supported on this platform."
)
fill_invalid_expert
=
0
ep_rank
=
np
.
random
.
randint
(
0
,
ep_size
)
ep_rank
=
np
.
random
.
randint
(
0
,
ep_size
)
expert_map
=
None
expert_map
=
None
n_local_expert
=
n_expert
n_local_expert
=
n_expert
...
@@ -238,7 +145,6 @@ def test_moe_permute_unpermute(
...
@@ -238,7 +145,6 @@ def test_moe_permute_unpermute(
gold_expert_first_token_offset
,
gold_expert_first_token_offset
,
gold_inv_permuted_idx
,
gold_inv_permuted_idx
,
gold_permuted_idx
,
gold_permuted_idx
,
gold_m_indices
,
valid_row_idx
,
valid_row_idx
,
)
=
torch_permute
(
)
=
torch_permute
(
hidden_states
,
hidden_states
,
...
@@ -249,8 +155,6 @@ def test_moe_permute_unpermute(
...
@@ -249,8 +155,6 @@ def test_moe_permute_unpermute(
n_local_expert
,
n_local_expert
,
start_expert
,
start_expert
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
align_block_size
=
align_block_size
,
fill_invalid_expert
=
fill_invalid_expert
,
)
)
(
(
...
@@ -258,7 +162,7 @@ def test_moe_permute_unpermute(
...
@@ -258,7 +162,7 @@ def test_moe_permute_unpermute(
_
,
_
,
expert_first_token_offset
,
expert_first_token_offset
,
inv_permuted_idx
,
inv_permuted_idx
,
m_indices
,
_
,
)
=
moe_permute
(
)
=
moe_permute
(
hidden_states
=
hidden_states
,
hidden_states
=
hidden_states
,
a1q_scale
=
None
,
a1q_scale
=
None
,
...
@@ -266,8 +170,6 @@ def test_moe_permute_unpermute(
...
@@ -266,8 +170,6 @@ def test_moe_permute_unpermute(
n_expert
=
n_expert
,
n_expert
=
n_expert
,
n_local_expert
=
n_local_expert
,
n_local_expert
=
n_local_expert
,
expert_map
=
expert_map
,
expert_map
=
expert_map
,
align_block_size
=
align_block_size
,
fill_invalid_expert
=
fill_invalid_expert
,
)
)
# check expert_first_token_offset
# check expert_first_token_offset
...
@@ -278,11 +180,6 @@ def test_moe_permute_unpermute(
...
@@ -278,11 +180,6 @@ def test_moe_permute_unpermute(
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
gold_inv_permuted_idx
.
flatten
(),
inv_permuted_idx
,
atol
=
0
,
rtol
=
0
gold_inv_permuted_idx
.
flatten
(),
inv_permuted_idx
,
atol
=
0
,
rtol
=
0
)
)
# check mindice
# current kernel usage assumes deepgemm requires align_block_size
# when it's not provided then we don't compute m_indices (for cutlass)
if
align_block_size
is
not
None
:
torch
.
testing
.
assert_close
(
gold_m_indices
,
m_indices
,
atol
=
0
,
rtol
=
0
)
# check permuted_hidden_states, only valid token
# check permuted_hidden_states, only valid token
torch
.
testing
.
assert_close
(
torch
.
testing
.
assert_close
(
...
...
vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py
View file @
77c09e11
...
@@ -11,8 +11,6 @@ def moe_permute(
...
@@ -11,8 +11,6 @@ def moe_permute(
n_expert
:
int
,
n_expert
:
int
,
n_local_expert
:
int
=
-
1
,
n_local_expert
:
int
=
-
1
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
expert_map
:
torch
.
Tensor
|
None
=
None
,
align_block_size
:
int
|
None
=
None
,
fill_invalid_expert
:
int
=
-
1
,
permuted_hidden_states
:
torch
.
Tensor
|
None
=
None
,
permuted_hidden_states
:
torch
.
Tensor
|
None
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
|
None
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""
"""
...
@@ -27,9 +25,6 @@ def moe_permute(
...
@@ -27,9 +25,6 @@ def moe_permute(
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
from the global expert space to the local expert space of the expert
parallel shard.
parallel shard.
- align_block_size (Optional[int]): align group gemm block size for deepgemm
- fill_invalid_expert(int): fill expert id in m_indices for invalid expert
to workaround DeepGemm unsupported -1 in m_indices
- permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor.
- permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor.
If None, the output tensor will be created in this function.
If None, the output tensor will be created in this function.
Returns:
Returns:
...
@@ -37,12 +32,9 @@ def moe_permute(
...
@@ -37,12 +32,9 @@ def moe_permute(
- a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states
- a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states
if original scale not per-tensor scaling
if original scale not per-tensor scaling
- expert_first_token_offset (torch.Tensor): offset of the first token
- expert_first_token_offset (torch.Tensor): offset of the first token
of each expert for standard grouped gemm. if enable 'align_block_size'
of each expert for standard grouped gemm.
expert_first_token_offset will align up to 'align_block_size'.
- inv_permuted_idx (torch.Tensor): idx map for moe_unpermute.
- inv_permuted_idx (torch.Tensor): idx map for moe_unpermute.
- permuted_idx (torch.Tensor): idx map from hidden to permuted_hidden.
- permuted_idx (torch.Tensor): idx map from hidden to permuted_hidden.
- m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records
the group which the j-th row of the LHS belong to.`
"""
"""
n_token
,
n_hidden
=
hidden_states
.
size
()
n_token
,
n_hidden
=
hidden_states
.
size
()
topk
=
topk_ids
.
size
(
1
)
topk
=
topk_ids
.
size
(
1
)
...
@@ -50,17 +42,6 @@ def moe_permute(
...
@@ -50,17 +42,6 @@ def moe_permute(
"permue kernel need hidden dim align to 16B"
"permue kernel need hidden dim align to 16B"
)
)
permuted_row_size
=
n_token
*
topk
permuted_row_size
=
n_token
*
topk
if
align_block_size
is
not
None
:
permuted_row_size
=
(
(
permuted_row_size
+
n_expert
*
(
align_block_size
-
1
)
+
align_block_size
-
1
)
//
align_block_size
*
align_block_size
)
if
n_local_expert
==
-
1
:
if
n_local_expert
==
-
1
:
n_local_expert
=
n_expert
n_local_expert
=
n_expert
if
permuted_hidden_states
is
None
:
if
permuted_hidden_states
is
None
:
...
@@ -78,12 +59,6 @@ def moe_permute(
...
@@ -78,12 +59,6 @@ def moe_permute(
0
,
n_token
*
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
0
,
n_token
*
topk
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
).
reshape
((
n_token
,
topk
))
).
reshape
((
n_token
,
topk
))
m_indices
=
torch
.
full
(
(
permuted_row_size
,),
fill_invalid_expert
,
dtype
=
torch
.
int32
,
device
=
hidden_states
.
device
,
)
expert_first_token_offset
=
torch
.
empty
(
expert_first_token_offset
=
torch
.
empty
(
n_local_expert
+
1
,
dtype
=
torch
.
int64
,
device
=
hidden_states
.
device
n_local_expert
+
1
,
dtype
=
torch
.
int64
,
device
=
hidden_states
.
device
)
)
...
@@ -105,12 +80,10 @@ def moe_permute(
...
@@ -105,12 +80,10 @@ def moe_permute(
n_expert
,
n_expert
,
n_local_expert
,
n_local_expert
,
topk
,
topk
,
align_block_size
,
permuted_hidden_states
,
permuted_hidden_states
,
expert_first_token_offset
,
expert_first_token_offset
,
inv_permuted_idx
,
inv_permuted_idx
,
permuted_idx
,
permuted_idx
,
m_indices
,
)
)
if
a1q_scale
is
not
None
and
a1q_scale
.
dim
()
>
1
:
if
a1q_scale
is
not
None
and
a1q_scale
.
dim
()
>
1
:
...
@@ -120,7 +93,7 @@ def moe_permute(
...
@@ -120,7 +93,7 @@ def moe_permute(
a1q_scale
,
a1q_scale
,
expert_first_token_offset
,
expert_first_token_offset
,
inv_permuted_idx
.
flatten
(),
inv_permuted_idx
.
flatten
(),
m_indices
,
permuted_idx
,
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment