Unverified commit 77c09e11, authored by Wentao Ye, committed by GitHub
Browse files

[Refactor] Remove align block size logic in `moe_permute` (#33449)


Signed-off-by: yewentao256 <zhyanwentao@126.com>
parent 16786da7
...@@ -44,10 +44,8 @@ def benchmark_permute( ...@@ -44,10 +44,8 @@ def benchmark_permute(
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
# output_hidden_states = torch.empty_like(hidden_states) # output_hidden_states = torch.empty_like(hidden_states)
if use_fp8_w8a8: if use_fp8_w8a8:
align_block_size = 128 # deepgemm needs 128 m aligned block
qhidden_states, scale = _fp8_quantize(hidden_states, None, None) qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
else: else:
align_block_size = None
qhidden_states = hidden_states qhidden_states = hidden_states
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32) gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
...@@ -67,7 +65,6 @@ def benchmark_permute( ...@@ -67,7 +65,6 @@ def benchmark_permute(
topk_ids=topk_ids, topk_ids=topk_ids,
n_expert=num_experts, n_expert=num_experts,
expert_map=None, expert_map=None,
align_block_size=align_block_size,
) )
# JIT compilation & warmup # JIT compilation & warmup
...@@ -117,10 +114,8 @@ def benchmark_unpermute( ...@@ -117,10 +114,8 @@ def benchmark_unpermute(
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype # init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_fp8_w8a8: if use_fp8_w8a8:
align_block_size = 128 # deepgemm needs 128 m aligned block
qhidden_states, scale = _fp8_quantize(hidden_states, None, None) qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
else: else:
align_block_size = None
qhidden_states = hidden_states qhidden_states = hidden_states
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
...@@ -142,7 +137,6 @@ def benchmark_unpermute( ...@@ -142,7 +137,6 @@ def benchmark_unpermute(
topk_ids=topk_ids, topk_ids=topk_ids,
n_expert=num_experts, n_expert=num_experts,
expert_map=None, expert_map=None,
align_block_size=align_block_size,
) )
# convert to fp16/bf16 as gemm output # convert to fp16/bf16 as gemm output
return ( return (
......
...@@ -14,12 +14,10 @@ void moe_permute( ...@@ -14,12 +14,10 @@ void moe_permute(
const torch::Tensor& token_expert_indices, // [n_token, topk] const torch::Tensor& token_expert_indices, // [n_token, topk]
const std::optional<torch::Tensor>& expert_map, // [n_expert] const std::optional<torch::Tensor>& expert_map, // [n_expert]
int64_t n_expert, int64_t n_local_expert, int64_t topk, int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size,
torch::Tensor& permuted_input, // [permuted_size, hidden] torch::Tensor& permuted_input, // [permuted_size, hidden]
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1] torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
torch::Tensor& inv_permuted_idx, // [n_token, topk] torch::Tensor& inv_permuted_idx, // [n_token, topk]
torch::Tensor& permuted_idx, // [permute_size] torch::Tensor& permuted_idx) { // [permute_size]
torch::Tensor& m_indices) { // [align_expand_m]
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long, TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
"expert_first_token_offset must be int64"); "expert_first_token_offset must be int64");
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
...@@ -34,8 +32,6 @@ void moe_permute( ...@@ -34,8 +32,6 @@ void moe_permute(
"token_expert_indices shape must be same as inv_permuted_idx"); "token_expert_indices shape must be same as inv_permuted_idx");
auto n_token = input.sizes()[0]; auto n_token = input.sizes()[0];
auto n_hidden = input.sizes()[1]; auto n_hidden = input.sizes()[1];
auto align_block_size_value =
align_block_size.has_value() ? align_block_size.value() : -1;
auto stream = at::cuda::getCurrentCUDAStream().stream(); auto stream = at::cuda::getCurrentCUDAStream().stream();
const long sorter_size = const long sorter_size =
CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert); CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert);
...@@ -73,42 +69,15 @@ void moe_permute( ...@@ -73,42 +69,15 @@ void moe_permute(
get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert, get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream); n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);
// DeepGEMM: use getMIndices kernel to compute
// 1) align_expert_first_token_offset (aligned prefix offsets)
// 2) m_indices (expert id for each aligned row)
// eg. expert0: 3, expert1: 5, expert2: 2 tokens respectively
// expert_first_token_offset = [0, 3, 8, 10], align_block_size = 4
// expert0: 3->4, expert1: 5->8, expert2: 2->4
// align_expert_first_token_offset = [0, 4, 12, 16]
// so m_indices = [0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2]
torch::Tensor align_expert_first_token_offset;
const int64_t* aligned_expert_first_token_offset_ptr = nullptr;
if (align_block_size.has_value()) {
align_expert_first_token_offset =
torch::zeros_like(expert_first_token_offset);
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
get_ptr<int64_t>(align_expert_first_token_offset),
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
stream);
aligned_expert_first_token_offset_ptr =
get_ptr<int64_t>(align_expert_first_token_offset);
}
// dispatch expandInputRowsKernelLauncher // dispatch expandInputRowsKernelLauncher
MOE_DISPATCH(input.scalar_type(), [&] { MOE_DISPATCH(input.scalar_type(), [&] {
expandInputRowsKernelLauncher<scalar_t>( expandInputRowsKernelLauncher<scalar_t>(
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input), get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx), get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
get_ptr<int>(inv_permuted_idx), get_ptr<int>(permuted_idx), get_ptr<int>(inv_permuted_idx), get_ptr<int>(permuted_idx),
get_ptr<int64_t>(expert_first_token_offset), get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
aligned_expert_first_token_offset_ptr, n_token, valid_num_ptr, n_hidden, n_hidden, topk, n_local_expert, stream);
topk, n_local_expert, align_block_size_value, stream);
}); });
// this is only required for DeepGemm and not required for CUTLASS group gemm
if (align_block_size.has_value()) {
expert_first_token_offset.copy_(align_expert_first_token_offset);
}
} }
void moe_unpermute( void moe_unpermute(
...@@ -201,16 +170,13 @@ void shuffle_rows(const torch::Tensor& input_tensor, ...@@ -201,16 +170,13 @@ void shuffle_rows(const torch::Tensor& input_tensor,
#else #else
void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_ids,
torch::Tensor& topk_ids,
const torch::Tensor& token_expert_indices, const torch::Tensor& token_expert_indices,
const std::optional<torch::Tensor>& expert_map, const std::optional<torch::Tensor>& expert_map,
int64_t n_expert, int64_t n_local_expert, int64_t topk, int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size,
torch::Tensor& permuted_input, torch::Tensor& permuted_input,
torch::Tensor& expert_first_token_offset, torch::Tensor& expert_first_token_offset,
torch::Tensor& src_row_id2dst_row_id_map, torch::Tensor& inv_permuted_idx, torch::Tensor& permuted_idx) {
torch::Tensor& m_indices) {
TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0"); TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
} }
......
...@@ -168,64 +168,4 @@ void preprocessTopkIdLauncher(int* topk_id_ptr, int size, ...@@ -168,64 +168,4 @@ void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
topk_id_ptr, size, expert_map_ptr, num_experts); topk_id_ptr, size, expert_map_ptr, num_experts);
} }
// Writes, for every row owned by an expert, that expert's id into m_indices.
//
// blockIdx.x is used directly as the expert index (eidx), so each thread
// block handles exactly one expert; the threads of the block stride over that
// expert's rows.
//
// Template parameter:
//   ALIGN_BLOCK_SIZE - when true, each expert's row count is rounded up to a
//                      multiple of align_block_size and the rows are placed at
//                      aligned prefix offsets; when false, the offsets read
//                      from expert_first_token_offset are used unchanged.
//
// Parameters:
//   expert_first_token_offset       - prefix offsets; entries 0..num_local_expert
//                                     are read.
//   align_expert_first_token_offset - output; entries 1..num_local_expert are
//                                     written, and only when ALIGN_BLOCK_SIZE
//                                     is true. Entry 0 is never written here.
//                                     NOTE(review): presumably the caller
//                                     zero-initialises it - confirm.
//   m_indices                       - output; expert id per row.
//   num_local_expert                - number of experts covered by the offsets.
//   align_block_size                - alignment granularity; only used when
//                                     ALIGN_BLOCK_SIZE is true.
//
// Requires dynamic shared memory for (num_local_expert + 1) int64_t values.
template <bool ALIGN_BLOCK_SIZE>
__global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset,
int* m_indices, const int num_local_expert,
const int align_block_size) {
int eidx = blockIdx.x;
int tidx = threadIdx.x;
// Dynamically sized shared buffer holding a copy of the offset table.
extern __shared__ int64_t smem_expert_first_token_offset[];
// Cooperative copy of all num_local_expert + 1 offsets (note the <=).
for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i);
}
// Make the whole table visible to every thread before it is read below.
__syncthreads();
auto last_token_offset = smem_expert_first_token_offset[eidx + 1];
auto first_token_offset = smem_expert_first_token_offset[eidx];
// The int64_t difference is narrowed to int here.
int n_token_in_expert = last_token_offset - first_token_offset;
if constexpr (ALIGN_BLOCK_SIZE) {
// Round this expert's row count up to a multiple of align_block_size.
n_token_in_expert = (n_token_in_expert + align_block_size - 1) /
align_block_size * align_block_size;
// Rebuild the aligned prefix sum serially over experts 0..eidx. Every thread
// of the block executes this loop; it is not restricted to one thread.
int64_t accumulate_align_offset = 0;
for (int i = 1; i <= eidx + 1; i++) {
int n_token = smem_expert_first_token_offset[i] -
smem_expert_first_token_offset[i - 1];
accumulate_align_offset =
accumulate_align_offset + (n_token + align_block_size - 1) /
align_block_size * align_block_size;
// After i iterations the running sum is the aligned end of expert i - 1,
// i.e. the aligned start of expert i. For eidx == 0 this never fires and
// first_token_offset keeps the value loaded from entry 0 of the table.
if (i == eidx) {
first_token_offset = accumulate_align_offset;
}
// Only thread 0 of the block for the last expert stores the aligned
// offsets; its loop visits every i in 1..num_local_expert.
if (eidx == num_local_expert - 1 && threadIdx.x == 0) {
align_expert_first_token_offset[i] = accumulate_align_offset;
}
}
}
// Threads stride over this expert's rows (including any padding rows added by
// the rounding above) and tag each one with the expert id.
for (int idx = tidx; idx < n_token_in_expert; idx += blockDim.x) {
// update m_indices with expert id
m_indices[first_token_offset + idx] = eidx;
}
}
// Host-side launcher for getMIndicesKernel.
//
// Launches one block of 256 threads per local expert on `stream`, with
// dynamic shared memory sized for (num_local_expert + 1) int64_t offsets.
//
// Parameters:
//   expert_first_token_offset       - forwarded to the kernel as the input
//                                     offset table.
//   align_expert_first_token_offset - forwarded to the kernel as the aligned
//                                     offset output.
//   m_indices                       - forwarded to the kernel as the expert-id
//                                     output.
//   num_local_expert                - used as the grid size and to size shared
//                                     memory.
//   align_block_size                - -1 is a sentinel meaning "no alignment"
//                                     and selects the <false> instantiation;
//                                     any other value selects <true>.
//   stream                          - CUDA stream the kernel is enqueued on.
//
// The launch result is not checked in this function.
void getMIndices(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset, int* m_indices,
int num_local_expert, const int align_block_size,
cudaStream_t stream) {
int block = 256;
// One thread block per local expert.
int grid = num_local_expert;
// Bytes of dynamic shared memory: one int64_t per offset-table entry.
int smem_size = sizeof(int64_t) * (num_local_expert + 1);
if (align_block_size == -1) {
getMIndicesKernel<false><<<grid, block, smem_size, stream>>>(
expert_first_token_offset, align_expert_first_token_offset, m_indices,
num_local_expert, align_block_size);
} else {
getMIndicesKernel<true><<<grid, block, smem_size, stream>>>(
expert_first_token_offset, align_expert_first_token_offset, m_indices,
num_local_expert, align_block_size);
}
}
#endif #endif
...@@ -60,10 +60,9 @@ void expandInputRowsKernelLauncher( ...@@ -60,10 +60,9 @@ void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output, int* sorted_experts, T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row, int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t const* expert_first_token_offset, int64_t const* expert_first_token_offset, int64_t const num_rows,
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream); int num_local_experts, cudaStream_t stream);
template <class T, class OutputType> template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher( void finalizeMoeRoutingKernelLauncher(
...@@ -76,9 +75,4 @@ void preprocessTopkIdLauncher(int* topk_id_ptr, int size, ...@@ -76,9 +75,4 @@ void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
const int* expert_map_ptr, int num_experts, const int* expert_map_ptr, int num_experts,
cudaStream_t stream); cudaStream_t stream);
void getMIndices(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset, int* m_indices,
int num_local_expert, const int align_block_size,
cudaStream_t stream);
#include "moe_permute_unpermute_kernel.inl" #include "moe_permute_unpermute_kernel.inl"
#pragma once #pragma once
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE> template <typename T, bool CHECK_SKIPPED>
__global__ void expandInputRowsKernel( __global__ void expandInputRowsKernel(
T const* unpermuted_input, T* permuted_output, int* sorted_experts, T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row, int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t const* expert_first_token_offset, int64_t const* expert_first_token_offset, int64_t const num_rows,
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
int64_t const* num_dest_rows, int64_t const cols, int64_t k, int64_t const* num_dest_rows, int64_t const cols, int64_t k,
int num_local_experts, int align_block_size) { int num_local_experts) {
// Reverse permutation map. // Reverse permutation map.
// I do this so that later, we can use the source -> dest map to do the k-way // I do this so that later, we can use the source -> dest map to do the k-way
// reduction and unpermuting. I need the reverse map for that reduction to // reduction and unpermuting. I need the reverse map for that reduction to
...@@ -19,24 +18,6 @@ __global__ void expandInputRowsKernel( ...@@ -19,24 +18,6 @@ __global__ void expandInputRowsKernel(
expanded_dest_row_to_expanded_source_row[expanded_dest_row]; expanded_dest_row_to_expanded_source_row[expanded_dest_row];
int expert_id = sorted_experts[expanded_dest_row]; int expert_id = sorted_experts[expanded_dest_row];
if constexpr (ALIGN_BLOCK_SIZE) {
// convert (unaligned) expanded_dest_row -> aligned expanded_dest_row.
// aligned_expert_first_token_offset[e] provides the aligned prefix start
// for expert e. For non-local experts we map to the end (total aligned M).
int64_t aligned_base = 0;
int64_t token_offset_in_expert = 0;
if (expert_id >= num_local_experts) {
aligned_base =
__ldg(aligned_expert_first_token_offset + num_local_experts);
token_offset_in_expert = 0;
} else {
aligned_base = __ldg(aligned_expert_first_token_offset + expert_id);
token_offset_in_expert =
expanded_dest_row - __ldg(expert_first_token_offset + expert_id);
}
expanded_dest_row = aligned_base + token_offset_in_expert;
}
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
assert(expanded_dest_row <= INT32_MAX); assert(expanded_dest_row <= INT32_MAX);
expanded_source_row_to_expanded_dest_row[expanded_source_row] = expanded_source_row_to_expanded_dest_row[expanded_source_row] =
...@@ -76,29 +57,25 @@ void expandInputRowsKernelLauncher( ...@@ -76,29 +57,25 @@ void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output, int* sorted_experts, T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row, int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx, int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t const* expert_first_token_offset, int64_t const* expert_first_token_offset, int64_t const num_rows,
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream) { int num_local_experts, cudaStream_t stream) {
int64_t const blocks = num_rows * k; int64_t const blocks = num_rows * k;
int64_t const threads = 256; int64_t const threads = 256;
using FuncPtr = decltype(&expandInputRowsKernel<T, true, true>); using FuncPtr = decltype(&expandInputRowsKernel<T, true>);
FuncPtr func_map[2][2] = { FuncPtr func_map[2] = {
{&expandInputRowsKernel<T, false, false>, &expandInputRowsKernel<T, false>,
&expandInputRowsKernel<T, false, true>}, &expandInputRowsKernel<T, true>,
{&expandInputRowsKernel<T, true, false>,
&expandInputRowsKernel<T, true, true>},
}; };
bool is_check_skip = num_valid_tokens_ptr != nullptr; bool is_check_skip = num_valid_tokens_ptr != nullptr;
bool is_align_block_size = align_block_size != -1; auto func = func_map[is_check_skip];
auto func = func_map[is_check_skip][is_align_block_size];
func<<<blocks, threads, 0, stream>>>( func<<<blocks, threads, 0, stream>>>(
unpermuted_input, permuted_output, sorted_experts, unpermuted_input, permuted_output, sorted_experts,
expanded_dest_row_to_expanded_source_row, expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row, permuted_idx, expanded_source_row_to_expanded_dest_row, permuted_idx,
expert_first_token_offset, aligned_expert_first_token_offset, num_rows, expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k,
num_valid_tokens_ptr, cols, k, num_local_experts, align_block_size); num_local_experts);
} }
template <class T, class U> template <class T, class U>
......
...@@ -99,9 +99,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { ...@@ -99,9 +99,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"moe_permute(Tensor input, Tensor topk_ids," "moe_permute(Tensor input, Tensor topk_ids,"
"Tensor token_expert_indices, Tensor? expert_map, int n_expert," "Tensor token_expert_indices, Tensor? expert_map, int n_expert,"
"int n_local_expert," "int n_local_expert,"
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! " "int topk, Tensor! permuted_input, Tensor! "
"expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! " "expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! "
"permuted_idx, Tensor! m_indices)->()"); "permuted_idx)->()");
m.def( m.def(
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights," "moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
......
...@@ -40,10 +40,8 @@ def torch_permute( ...@@ -40,10 +40,8 @@ def torch_permute(
n_local_expert: int, n_local_expert: int,
start_expert: int, start_expert: int,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
align_block_size: int | None = None,
fill_invalid_expert: int = -1,
) -> list[torch.Tensor]: ) -> list[torch.Tensor]:
n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1] n_token = hidden_states.shape[0]
if expert_map is not None: if expert_map is not None:
is_local_expert = expert_map[topk_ids] != -1 is_local_expert = expert_map[topk_ids] != -1
not_local_expert = expert_map[topk_ids] == -1 not_local_expert = expert_map[topk_ids] == -1
...@@ -70,16 +68,7 @@ def torch_permute( ...@@ -70,16 +68,7 @@ def torch_permute(
_, src2dst_idx = torch.sort(dst_row_id2src_row_id_map) _, src2dst_idx = torch.sort(dst_row_id2src_row_id_map)
valid_row_idx = [] valid_row_idx = []
if align_block_size is None:
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...] permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...]
permuted_row_size = permuted_hidden_states.shape[0]
m_indices = torch.empty(
permuted_row_size, device="cuda", dtype=torch.int32
).fill_(fill_invalid_expert)
for i in range(1, n_local_expert + 1):
first_token_offset = expert_first_token_offset[i - 1]
last_token_offset = expert_first_token_offset[i]
m_indices[first_token_offset:last_token_offset] = i - 1
src_row_id2dst_row_id_map = torch.arange( src_row_id2dst_row_id_map = torch.arange(
0, n_token * topk, device="cuda", dtype=torch.int32 0, n_token * topk, device="cuda", dtype=torch.int32
)[src2dst_idx].reshape((n_token, topk)) )[src2dst_idx].reshape((n_token, topk))
...@@ -90,85 +79,6 @@ def torch_permute( ...@@ -90,85 +79,6 @@ def torch_permute(
expert_first_token_offset, expert_first_token_offset,
src_row_id2dst_row_id_map, src_row_id2dst_row_id_map,
dst_row_id2src_row_id_map, dst_row_id2src_row_id_map,
m_indices,
valid_row_idx,
]
else:
permuted_row_size = (
(topk * n_token + n_expert * (align_block_size - 1) + align_block_size - 1)
// align_block_size
* align_block_size
)
permuted_idx = torch.full(
(permuted_row_size,),
n_token * topk,
dtype=torch.int32,
device=hidden_states.device,
)
permuted_hidden_states = torch.empty(
(permuted_row_size, n_hidden), device="cuda", dtype=hidden_states.dtype
)
align_src_row_id2dst_row_id = torch.empty(
n_token * topk, device="cuda", dtype=torch.int32
)
align_expert_first_token_offset = torch.zeros_like(expert_first_token_offset)
m_indices = torch.empty(
permuted_row_size, device="cuda", dtype=torch.int32
).fill_(fill_invalid_expert)
# get align_permuted_hidden_states,
# valid row_idx and align_expert_first_token_offset
for i in range(1, n_local_expert + 1):
first_token_offset = expert_first_token_offset[i - 1]
last_token_offset = expert_first_token_offset[i]
n_token_in_expert = last_token_offset - first_token_offset
align_expert_first_token_offset[i] = (
align_expert_first_token_offset[i - 1]
+ (n_token_in_expert + align_block_size - 1)
// align_block_size
* align_block_size
)
align_first_token_offset = align_expert_first_token_offset[i - 1]
align_last_token_offset = align_expert_first_token_offset[i]
dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[
first_token_offset : first_token_offset + n_token_in_expert
]
# store token in current expert with align_first_token_offset
permuted_hidden_states[
align_first_token_offset : align_first_token_offset + n_token_in_expert,
...,
] = hidden_states[dst_row_id2src_row_id_in_expert // topk, ...]
permuted_idx[
align_first_token_offset : align_first_token_offset + n_token_in_expert
] = dst_row_id2src_row_id_in_expert
# set current expert m_indices
m_indices[align_first_token_offset:align_last_token_offset] = i - 1
valid_row_idx += [
i
for i in range(
align_first_token_offset,
align_first_token_offset + n_token_in_expert,
)
]
# get align_src_row_id2dst_row_id
for i in range(n_token * topk):
eid = sorted_topk_ids[i]
if eid >= n_local_expert:
# check token not in local expert
align_src_row_id2dst_row_id[i] = align_expert_first_token_offset[-1]
continue
first_token_offset = expert_first_token_offset[eid]
align_first_token_offset = align_expert_first_token_offset[eid]
token_offset = i - first_token_offset
align_src_row_id2dst_row_id[i] = align_first_token_offset + token_offset
align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[src2dst_idx].reshape(
(n_token, topk)
)
return [
permuted_hidden_states,
align_expert_first_token_offset,
align_src_row_id2dst_row_id,
permuted_idx,
m_indices,
valid_row_idx, valid_row_idx,
] ]
...@@ -207,7 +117,6 @@ def torch_unpermute( ...@@ -207,7 +117,6 @@ def torch_unpermute(
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("align_block_size", [None, 128])
def test_moe_permute_unpermute( def test_moe_permute_unpermute(
n_token: int, n_token: int,
n_hidden: int, n_hidden: int,
...@@ -215,11 +124,9 @@ def test_moe_permute_unpermute( ...@@ -215,11 +124,9 @@ def test_moe_permute_unpermute(
n_expert: int, n_expert: int,
ep_size: int, ep_size: int,
dtype: torch.dtype, dtype: torch.dtype,
align_block_size: int | None,
): ):
if not moe_permute_unpermute_supported(): if not moe_permute_unpermute_supported():
pytest.skip("moe_permute_unpermute is not supported on this platform.") pytest.skip("moe_permute_unpermute is not supported on this platform.")
fill_invalid_expert = 0
ep_rank = np.random.randint(0, ep_size) ep_rank = np.random.randint(0, ep_size)
expert_map = None expert_map = None
n_local_expert = n_expert n_local_expert = n_expert
...@@ -238,7 +145,6 @@ def test_moe_permute_unpermute( ...@@ -238,7 +145,6 @@ def test_moe_permute_unpermute(
gold_expert_first_token_offset, gold_expert_first_token_offset,
gold_inv_permuted_idx, gold_inv_permuted_idx,
gold_permuted_idx, gold_permuted_idx,
gold_m_indices,
valid_row_idx, valid_row_idx,
) = torch_permute( ) = torch_permute(
hidden_states, hidden_states,
...@@ -249,8 +155,6 @@ def test_moe_permute_unpermute( ...@@ -249,8 +155,6 @@ def test_moe_permute_unpermute(
n_local_expert, n_local_expert,
start_expert, start_expert,
expert_map=expert_map, expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert,
) )
( (
...@@ -258,7 +162,7 @@ def test_moe_permute_unpermute( ...@@ -258,7 +162,7 @@ def test_moe_permute_unpermute(
_, _,
expert_first_token_offset, expert_first_token_offset,
inv_permuted_idx, inv_permuted_idx,
m_indices, _,
) = moe_permute( ) = moe_permute(
hidden_states=hidden_states, hidden_states=hidden_states,
a1q_scale=None, a1q_scale=None,
...@@ -266,8 +170,6 @@ def test_moe_permute_unpermute( ...@@ -266,8 +170,6 @@ def test_moe_permute_unpermute(
n_expert=n_expert, n_expert=n_expert,
n_local_expert=n_local_expert, n_local_expert=n_local_expert,
expert_map=expert_map, expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert,
) )
# check expert_first_token_offset # check expert_first_token_offset
...@@ -278,11 +180,6 @@ def test_moe_permute_unpermute( ...@@ -278,11 +180,6 @@ def test_moe_permute_unpermute(
torch.testing.assert_close( torch.testing.assert_close(
gold_inv_permuted_idx.flatten(), inv_permuted_idx, atol=0, rtol=0 gold_inv_permuted_idx.flatten(), inv_permuted_idx, atol=0, rtol=0
) )
# check mindice
# current kernel usage assumes deepgemm requires align_block_size
# when it's not provided then we don't compute m_indices (for cutlass)
if align_block_size is not None:
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
# check permuted_hidden_states, only valid token # check permuted_hidden_states, only valid token
torch.testing.assert_close( torch.testing.assert_close(
......
...@@ -11,8 +11,6 @@ def moe_permute( ...@@ -11,8 +11,6 @@ def moe_permute(
n_expert: int, n_expert: int,
n_local_expert: int = -1, n_local_expert: int = -1,
expert_map: torch.Tensor | None = None, expert_map: torch.Tensor | None = None,
align_block_size: int | None = None,
fill_invalid_expert: int = -1,
permuted_hidden_states: torch.Tensor | None = None, permuted_hidden_states: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
...@@ -27,9 +25,6 @@ def moe_permute( ...@@ -27,9 +25,6 @@ def moe_permute(
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices - expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert from the global expert space to the local expert space of the expert
parallel shard. parallel shard.
- align_block_size (Optional[int]): align group gemm block size for deepgemm
- fill_invalid_expert(int): fill expert id in m_indices for invalid expert
to workaround DeepGemm unsupported -1 in m_indices
- permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor. - permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor.
If None, the output tensor will be created in this function. If None, the output tensor will be created in this function.
Returns: Returns:
...@@ -37,12 +32,9 @@ def moe_permute( ...@@ -37,12 +32,9 @@ def moe_permute(
- a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states - a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states
if original scale not per-tensor scaling if original scale not per-tensor scaling
- expert_first_token_offset (torch.Tensor): offset of the first token - expert_first_token_offset (torch.Tensor): offset of the first token
of each expert for standard grouped gemm. if enable 'align_block_size' of each expert for standard grouped gemm.
expert_first_token_offset will align up to 'align_block_size'.
- inv_permuted_idx (torch.Tensor): idx map for moe_unpermute. - inv_permuted_idx (torch.Tensor): idx map for moe_unpermute.
- permuted_idx (torch.Tensor): idx map from hidden to permuted_hidden. - permuted_idx (torch.Tensor): idx map from hidden to permuted_hidden.
- m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records
the group which the j-th row of the LHS belong to.`
""" """
n_token, n_hidden = hidden_states.size() n_token, n_hidden = hidden_states.size()
topk = topk_ids.size(1) topk = topk_ids.size(1)
...@@ -50,17 +42,6 @@ def moe_permute( ...@@ -50,17 +42,6 @@ def moe_permute(
"permue kernel need hidden dim align to 16B" "permue kernel need hidden dim align to 16B"
) )
permuted_row_size = n_token * topk permuted_row_size = n_token * topk
if align_block_size is not None:
permuted_row_size = (
(
permuted_row_size
+ n_expert * (align_block_size - 1)
+ align_block_size
- 1
)
// align_block_size
* align_block_size
)
if n_local_expert == -1: if n_local_expert == -1:
n_local_expert = n_expert n_local_expert = n_expert
if permuted_hidden_states is None: if permuted_hidden_states is None:
...@@ -78,12 +59,6 @@ def moe_permute( ...@@ -78,12 +59,6 @@ def moe_permute(
0, n_token * topk, dtype=torch.int32, device=hidden_states.device 0, n_token * topk, dtype=torch.int32, device=hidden_states.device
).reshape((n_token, topk)) ).reshape((n_token, topk))
m_indices = torch.full(
(permuted_row_size,),
fill_invalid_expert,
dtype=torch.int32,
device=hidden_states.device,
)
expert_first_token_offset = torch.empty( expert_first_token_offset = torch.empty(
n_local_expert + 1, dtype=torch.int64, device=hidden_states.device n_local_expert + 1, dtype=torch.int64, device=hidden_states.device
) )
...@@ -105,12 +80,10 @@ def moe_permute( ...@@ -105,12 +80,10 @@ def moe_permute(
n_expert, n_expert,
n_local_expert, n_local_expert,
topk, topk,
align_block_size,
permuted_hidden_states, permuted_hidden_states,
expert_first_token_offset, expert_first_token_offset,
inv_permuted_idx, inv_permuted_idx,
permuted_idx, permuted_idx,
m_indices,
) )
if a1q_scale is not None and a1q_scale.dim() > 1: if a1q_scale is not None and a1q_scale.dim() > 1:
...@@ -120,7 +93,7 @@ def moe_permute( ...@@ -120,7 +93,7 @@ def moe_permute(
a1q_scale, a1q_scale,
expert_first_token_offset, expert_first_token_offset,
inv_permuted_idx.flatten(), inv_permuted_idx.flatten(),
m_indices, permuted_idx,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment