Unverified Commit 77c09e11 authored by Wentao Ye's avatar Wentao Ye Committed by GitHub
Browse files

[Refactor] Remove align block size logic in `moe_permute` (#33449)


Signed-off-by: default avataryewentao256 <zhyanwentao@126.com>
parent 16786da7
......@@ -44,10 +44,8 @@ def benchmark_permute(
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
# output_hidden_states = torch.empty_like(hidden_states)
if use_fp8_w8a8:
align_block_size = 128 # deepgemm needs 128 m aligned block
qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
else:
align_block_size = None
qhidden_states = hidden_states
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
......@@ -67,7 +65,6 @@ def benchmark_permute(
topk_ids=topk_ids,
n_expert=num_experts,
expert_map=None,
align_block_size=align_block_size,
)
# JIT compilation & warmup
......@@ -117,10 +114,8 @@ def benchmark_unpermute(
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_fp8_w8a8:
align_block_size = 128 # deepgemm needs 128 m aligned block
qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
else:
align_block_size = None
qhidden_states = hidden_states
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
......@@ -142,7 +137,6 @@ def benchmark_unpermute(
topk_ids=topk_ids,
n_expert=num_experts,
expert_map=None,
align_block_size=align_block_size,
)
# convert to fp16/bf16 as gemm output
return (
......
......@@ -14,12 +14,10 @@ void moe_permute(
const torch::Tensor& token_expert_indices, // [n_token, topk]
const std::optional<torch::Tensor>& expert_map, // [n_expert]
int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size,
torch::Tensor& permuted_input, // [permuted_size, hidden]
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
torch::Tensor& inv_permuted_idx, // [n_token, topk]
torch::Tensor& permuted_idx, // [permute_size]
torch::Tensor& m_indices) { // [align_expand_m]
torch::Tensor& permuted_idx) { // [permute_size]
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
"expert_first_token_offset must be int64");
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
......@@ -34,8 +32,6 @@ void moe_permute(
"token_expert_indices shape must be same as inv_permuted_idx");
auto n_token = input.sizes()[0];
auto n_hidden = input.sizes()[1];
auto align_block_size_value =
align_block_size.has_value() ? align_block_size.value() : -1;
auto stream = at::cuda::getCurrentCUDAStream().stream();
const long sorter_size =
CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert);
......@@ -73,42 +69,15 @@ void moe_permute(
get_ptr<int64_t>(expert_first_token_offset), n_token, n_expert,
n_local_expert, topk, sorter, get_ptr<int>(sort_workspace), stream);
// DeepGEMM: use getMIndices kernel to compute
// 1) align_expert_first_token_offset (aligned prefix offsets)
// 2) m_indices (expert id for each aligned row)
// eg. expert0: 3, expert1: 5, expert2: 2 tokens respectively
// expert_first_token_offset = [0, 3, 8, 10], align_block_size = 4
// expert0: 3->4, expert1: 5->8, expert2: 2->4
// align_expert_first_token_offset = [0, 4, 12, 16]
// so m_indices = [0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2]
torch::Tensor align_expert_first_token_offset;
const int64_t* aligned_expert_first_token_offset_ptr = nullptr;
if (align_block_size.has_value()) {
align_expert_first_token_offset =
torch::zeros_like(expert_first_token_offset);
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
get_ptr<int64_t>(align_expert_first_token_offset),
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
stream);
aligned_expert_first_token_offset_ptr =
get_ptr<int64_t>(align_expert_first_token_offset);
}
// dispatch expandInputRowsKernelLauncher
MOE_DISPATCH(input.scalar_type(), [&] {
expandInputRowsKernelLauncher<scalar_t>(
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
get_ptr<int>(permuted_experts_id), get_ptr<int>(sorted_row_idx),
get_ptr<int>(inv_permuted_idx), get_ptr<int>(permuted_idx),
get_ptr<int64_t>(expert_first_token_offset),
aligned_expert_first_token_offset_ptr, n_token, valid_num_ptr, n_hidden,
topk, n_local_expert, align_block_size_value, stream);
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
n_hidden, topk, n_local_expert, stream);
});
// this is only required for DeepGemm and not required for CUTLASS group gemm
if (align_block_size.has_value()) {
expert_first_token_offset.copy_(align_expert_first_token_offset);
}
}
void moe_unpermute(
......@@ -201,16 +170,13 @@ void shuffle_rows(const torch::Tensor& input_tensor,
#else
void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights,
torch::Tensor& topk_ids,
void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_ids,
const torch::Tensor& token_expert_indices,
const std::optional<torch::Tensor>& expert_map,
int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size,
torch::Tensor& permuted_input,
torch::Tensor& expert_first_token_offset,
torch::Tensor& src_row_id2dst_row_id_map,
torch::Tensor& m_indices) {
torch::Tensor& inv_permuted_idx, torch::Tensor& permuted_idx) {
TORCH_CHECK(false, "moe_permute is not supported on CUDA < 12.0");
}
......
......@@ -168,64 +168,4 @@ void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
topk_id_ptr, size, expert_map_ptr, num_experts);
}
template <bool ALIGN_BLOCK_SIZE>
__global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset,
int* m_indices, const int num_local_expert,
const int align_block_size) {
int eidx = blockIdx.x;
int tidx = threadIdx.x;
extern __shared__ int64_t smem_expert_first_token_offset[];
for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i);
}
__syncthreads();
auto last_token_offset = smem_expert_first_token_offset[eidx + 1];
auto first_token_offset = smem_expert_first_token_offset[eidx];
int n_token_in_expert = last_token_offset - first_token_offset;
if constexpr (ALIGN_BLOCK_SIZE) {
n_token_in_expert = (n_token_in_expert + align_block_size - 1) /
align_block_size * align_block_size;
// round up to ALIGN_BLOCK_SIZE
int64_t accumulate_align_offset = 0;
for (int i = 1; i <= eidx + 1; i++) {
int n_token = smem_expert_first_token_offset[i] -
smem_expert_first_token_offset[i - 1];
accumulate_align_offset =
accumulate_align_offset + (n_token + align_block_size - 1) /
align_block_size * align_block_size;
if (i == eidx) {
first_token_offset = accumulate_align_offset;
}
// last block store align_expert_first_token_offset
if (eidx == num_local_expert - 1 && threadIdx.x == 0) {
align_expert_first_token_offset[i] = accumulate_align_offset;
}
}
}
for (int idx = tidx; idx < n_token_in_expert; idx += blockDim.x) {
// update m_indice with expert id
m_indices[first_token_offset + idx] = eidx;
}
}
void getMIndices(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset, int* m_indices,
int num_local_expert, const int align_block_size,
cudaStream_t stream) {
int block = 256;
int grid = num_local_expert;
int smem_size = sizeof(int64_t) * (num_local_expert + 1);
if (align_block_size == -1) {
getMIndicesKernel<false><<<grid, block, smem_size, stream>>>(
expert_first_token_offset, align_expert_first_token_offset, m_indices,
num_local_expert, align_block_size);
} else {
getMIndicesKernel<true><<<grid, block, smem_size, stream>>>(
expert_first_token_offset, align_expert_first_token_offset, m_indices,
num_local_expert, align_block_size);
}
}
#endif
......@@ -60,10 +60,9 @@ void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t const* expert_first_token_offset,
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
int64_t const* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream);
int num_local_experts, cudaStream_t stream);
template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher(
......@@ -76,9 +75,4 @@ void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
const int* expert_map_ptr, int num_experts,
cudaStream_t stream);
void getMIndices(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset, int* m_indices,
int num_local_expert, const int align_block_size,
cudaStream_t stream);
#include "moe_permute_unpermute_kernel.inl"
#pragma once
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE>
template <typename T, bool CHECK_SKIPPED>
__global__ void expandInputRowsKernel(
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t const* expert_first_token_offset,
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
int64_t const* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
int num_local_experts, int align_block_size) {
int num_local_experts) {
// Reverse permutation map.
// I do this so that later, we can use the source -> dest map to do the k-way
// reduction and unpermuting. I need the reverse map for that reduction to
......@@ -19,24 +18,6 @@ __global__ void expandInputRowsKernel(
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
int expert_id = sorted_experts[expanded_dest_row];
if constexpr (ALIGN_BLOCK_SIZE) {
// convert (unaligned) expanded_dest_row -> aligned expanded_dest_row.
// aligned_expert_first_token_offset[e] provides the aligned prefix start
// for expert e. For non-local experts we map to the end (total aligned M).
int64_t aligned_base = 0;
int64_t token_offset_in_expert = 0;
if (expert_id >= num_local_experts) {
aligned_base =
__ldg(aligned_expert_first_token_offset + num_local_experts);
token_offset_in_expert = 0;
} else {
aligned_base = __ldg(aligned_expert_first_token_offset + expert_id);
token_offset_in_expert =
expanded_dest_row - __ldg(expert_first_token_offset + expert_id);
}
expanded_dest_row = aligned_base + token_offset_in_expert;
}
if (threadIdx.x == 0) {
assert(expanded_dest_row <= INT32_MAX);
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
......@@ -76,29 +57,25 @@ void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row, int* permuted_idx,
int64_t const* expert_first_token_offset,
int64_t const* aligned_expert_first_token_offset, int64_t const num_rows,
int64_t const* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream) {
int num_local_experts, cudaStream_t stream) {
int64_t const blocks = num_rows * k;
int64_t const threads = 256;
using FuncPtr = decltype(&expandInputRowsKernel<T, true, true>);
FuncPtr func_map[2][2] = {
{&expandInputRowsKernel<T, false, false>,
&expandInputRowsKernel<T, false, true>},
{&expandInputRowsKernel<T, true, false>,
&expandInputRowsKernel<T, true, true>},
using FuncPtr = decltype(&expandInputRowsKernel<T, true>);
FuncPtr func_map[2] = {
&expandInputRowsKernel<T, false>,
&expandInputRowsKernel<T, true>,
};
bool is_check_skip = num_valid_tokens_ptr != nullptr;
bool is_align_block_size = align_block_size != -1;
auto func = func_map[is_check_skip][is_align_block_size];
auto func = func_map[is_check_skip];
func<<<blocks, threads, 0, stream>>>(
unpermuted_input, permuted_output, sorted_experts,
expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row, permuted_idx,
expert_first_token_offset, aligned_expert_first_token_offset, num_rows,
num_valid_tokens_ptr, cols, k, num_local_experts, align_block_size);
expert_first_token_offset, num_rows, num_valid_tokens_ptr, cols, k,
num_local_experts);
}
template <class T, class U>
......
......@@ -99,9 +99,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"moe_permute(Tensor input, Tensor topk_ids,"
"Tensor token_expert_indices, Tensor? expert_map, int n_expert,"
"int n_local_expert,"
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
"int topk, Tensor! permuted_input, Tensor! "
"expert_first_token_offset, Tensor! inv_permuted_idx, Tensor! "
"permuted_idx, Tensor! m_indices)->()");
"permuted_idx)->()");
m.def(
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
......
......@@ -40,10 +40,8 @@ def torch_permute(
n_local_expert: int,
start_expert: int,
expert_map: torch.Tensor | None = None,
align_block_size: int | None = None,
fill_invalid_expert: int = -1,
) -> list[torch.Tensor]:
n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
n_token = hidden_states.shape[0]
if expert_map is not None:
is_local_expert = expert_map[topk_ids] != -1
not_local_expert = expert_map[topk_ids] == -1
......@@ -70,16 +68,7 @@ def torch_permute(
_, src2dst_idx = torch.sort(dst_row_id2src_row_id_map)
valid_row_idx = []
if align_block_size is None:
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map // topk, ...]
permuted_row_size = permuted_hidden_states.shape[0]
m_indices = torch.empty(
permuted_row_size, device="cuda", dtype=torch.int32
).fill_(fill_invalid_expert)
for i in range(1, n_local_expert + 1):
first_token_offset = expert_first_token_offset[i - 1]
last_token_offset = expert_first_token_offset[i]
m_indices[first_token_offset:last_token_offset] = i - 1
src_row_id2dst_row_id_map = torch.arange(
0, n_token * topk, device="cuda", dtype=torch.int32
)[src2dst_idx].reshape((n_token, topk))
......@@ -90,85 +79,6 @@ def torch_permute(
expert_first_token_offset,
src_row_id2dst_row_id_map,
dst_row_id2src_row_id_map,
m_indices,
valid_row_idx,
]
else:
permuted_row_size = (
(topk * n_token + n_expert * (align_block_size - 1) + align_block_size - 1)
// align_block_size
* align_block_size
)
permuted_idx = torch.full(
(permuted_row_size,),
n_token * topk,
dtype=torch.int32,
device=hidden_states.device,
)
permuted_hidden_states = torch.empty(
(permuted_row_size, n_hidden), device="cuda", dtype=hidden_states.dtype
)
align_src_row_id2dst_row_id = torch.empty(
n_token * topk, device="cuda", dtype=torch.int32
)
align_expert_first_token_offset = torch.zeros_like(expert_first_token_offset)
m_indices = torch.empty(
permuted_row_size, device="cuda", dtype=torch.int32
).fill_(fill_invalid_expert)
# get align_permuted_hidden_states,
# valid row_idx and align_expert_first_token_offset
for i in range(1, n_local_expert + 1):
first_token_offset = expert_first_token_offset[i - 1]
last_token_offset = expert_first_token_offset[i]
n_token_in_expert = last_token_offset - first_token_offset
align_expert_first_token_offset[i] = (
align_expert_first_token_offset[i - 1]
+ (n_token_in_expert + align_block_size - 1)
// align_block_size
* align_block_size
)
align_first_token_offset = align_expert_first_token_offset[i - 1]
align_last_token_offset = align_expert_first_token_offset[i]
dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[
first_token_offset : first_token_offset + n_token_in_expert
]
# store token in current expert with align_first_token_offset
permuted_hidden_states[
align_first_token_offset : align_first_token_offset + n_token_in_expert,
...,
] = hidden_states[dst_row_id2src_row_id_in_expert // topk, ...]
permuted_idx[
align_first_token_offset : align_first_token_offset + n_token_in_expert
] = dst_row_id2src_row_id_in_expert
# set current expert m_indices
m_indices[align_first_token_offset:align_last_token_offset] = i - 1
valid_row_idx += [
i
for i in range(
align_first_token_offset,
align_first_token_offset + n_token_in_expert,
)
]
# get align_src_row_id2dst_row_id
for i in range(n_token * topk):
eid = sorted_topk_ids[i]
if eid >= n_local_expert:
# check token not in local expert
align_src_row_id2dst_row_id[i] = align_expert_first_token_offset[-1]
continue
first_token_offset = expert_first_token_offset[eid]
align_first_token_offset = align_expert_first_token_offset[eid]
token_offset = i - first_token_offset
align_src_row_id2dst_row_id[i] = align_first_token_offset + token_offset
align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[src2dst_idx].reshape(
(n_token, topk)
)
return [
permuted_hidden_states,
align_expert_first_token_offset,
align_src_row_id2dst_row_id,
permuted_idx,
m_indices,
valid_row_idx,
]
......@@ -207,7 +117,6 @@ def torch_unpermute(
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("align_block_size", [None, 128])
def test_moe_permute_unpermute(
n_token: int,
n_hidden: int,
......@@ -215,11 +124,9 @@ def test_moe_permute_unpermute(
n_expert: int,
ep_size: int,
dtype: torch.dtype,
align_block_size: int | None,
):
if not moe_permute_unpermute_supported():
pytest.skip("moe_permute_unpermute is not supported on this platform.")
fill_invalid_expert = 0
ep_rank = np.random.randint(0, ep_size)
expert_map = None
n_local_expert = n_expert
......@@ -238,7 +145,6 @@ def test_moe_permute_unpermute(
gold_expert_first_token_offset,
gold_inv_permuted_idx,
gold_permuted_idx,
gold_m_indices,
valid_row_idx,
) = torch_permute(
hidden_states,
......@@ -249,8 +155,6 @@ def test_moe_permute_unpermute(
n_local_expert,
start_expert,
expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert,
)
(
......@@ -258,7 +162,7 @@ def test_moe_permute_unpermute(
_,
expert_first_token_offset,
inv_permuted_idx,
m_indices,
_,
) = moe_permute(
hidden_states=hidden_states,
a1q_scale=None,
......@@ -266,8 +170,6 @@ def test_moe_permute_unpermute(
n_expert=n_expert,
n_local_expert=n_local_expert,
expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert,
)
# check expert_first_token_offset
......@@ -278,11 +180,6 @@ def test_moe_permute_unpermute(
torch.testing.assert_close(
gold_inv_permuted_idx.flatten(), inv_permuted_idx, atol=0, rtol=0
)
# check mindice
# current kernel usage assumes deepgemm requires align_block_size
# when it's not provided then we don't compute m_indices (for cutlass)
if align_block_size is not None:
torch.testing.assert_close(gold_m_indices, m_indices, atol=0, rtol=0)
# check permuted_hidden_states, only valid token
torch.testing.assert_close(
......
......@@ -11,8 +11,6 @@ def moe_permute(
n_expert: int,
n_local_expert: int = -1,
expert_map: torch.Tensor | None = None,
align_block_size: int | None = None,
fill_invalid_expert: int = -1,
permuted_hidden_states: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
......@@ -27,9 +25,6 @@ def moe_permute(
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- align_block_size (Optional[int]): align group gemm block size for deepgemm
- fill_invalid_expert(int): fill expert id in m_indices for invalid expert
to workaround DeepGemm unsupported -1 in m_indices
- permuted_hidden_states (Optional[torch.Tensor]): Optional output tensor.
If None, the output tensor will be created in this function.
Returns:
......@@ -37,12 +32,9 @@ def moe_permute(
- a1q_scale (Optional[torch.Tensor]): permuted quant scale for hidden_states
if original scale not per-tensor scaling
- expert_first_token_offset (torch.Tensor): offset of the first token
of each expert for standard grouped gemm. if enable 'align_block_size'
expert_first_token_offset will align up to 'align_block_size'.
of each expert for standard grouped gemm.
- inv_permuted_idx (torch.Tensor): idx map for moe_unpermute.
- permuted_idx (torch.Tensor): idx map from hidden to permuted_hidden.
- m_indices: m_indices for grouped gemm in deepgemm,`m_indices[i]` records
the group which the j-th row of the LHS belong to.`
"""
n_token, n_hidden = hidden_states.size()
topk = topk_ids.size(1)
......@@ -50,17 +42,6 @@ def moe_permute(
"permue kernel need hidden dim align to 16B"
)
permuted_row_size = n_token * topk
if align_block_size is not None:
permuted_row_size = (
(
permuted_row_size
+ n_expert * (align_block_size - 1)
+ align_block_size
- 1
)
// align_block_size
* align_block_size
)
if n_local_expert == -1:
n_local_expert = n_expert
if permuted_hidden_states is None:
......@@ -78,12 +59,6 @@ def moe_permute(
0, n_token * topk, dtype=torch.int32, device=hidden_states.device
).reshape((n_token, topk))
m_indices = torch.full(
(permuted_row_size,),
fill_invalid_expert,
dtype=torch.int32,
device=hidden_states.device,
)
expert_first_token_offset = torch.empty(
n_local_expert + 1, dtype=torch.int64, device=hidden_states.device
)
......@@ -105,12 +80,10 @@ def moe_permute(
n_expert,
n_local_expert,
topk,
align_block_size,
permuted_hidden_states,
expert_first_token_offset,
inv_permuted_idx,
permuted_idx,
m_indices,
)
if a1q_scale is not None and a1q_scale.dim() > 1:
......@@ -120,7 +93,7 @@ def moe_permute(
a1q_scale,
expert_first_token_offset,
inv_permuted_idx.flatten(),
m_indices,
permuted_idx,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment