Unverified commit 7f89ed24, authored by shixianc, committed by GitHub
Browse files

[Fix] enable swap_ab for pplx problem size computation (#22991)


Signed-off-by: Shixian Cui <shixian@amazon.com>
Co-authored-by: Shixian Cui <shixian@amazon.com>
parent 8a87cd27
...@@ -161,6 +161,7 @@ void get_cutlass_moe_mm_data_caller( ...@@ -161,6 +161,7 @@ void get_cutlass_moe_mm_data_caller(
topk_ids.size(1)); topk_ids.size(1));
} }
template <bool SWAP_AB>
__global__ void compute_pplx_data(int32_t* expert_offsets, __global__ void compute_pplx_data(int32_t* expert_offsets,
int32_t* problem_sizes1, int32_t* problem_sizes1,
int32_t* problem_sizes2, int32_t* problem_sizes2,
...@@ -168,14 +169,23 @@ __global__ void compute_pplx_data(int32_t* expert_offsets, ...@@ -168,14 +169,23 @@ __global__ void compute_pplx_data(int32_t* expert_offsets,
const int padded_m, const int n, const int padded_m, const int n,
const int k) { const int k) {
int expert_idx = threadIdx.x; int expert_idx = threadIdx.x;
expert_offsets[expert_idx] = expert_idx * padded_m; expert_offsets[expert_idx] = expert_idx * padded_m;
if constexpr (!SWAP_AB) {
problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx]; problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx];
problem_sizes1[expert_idx * 3 + 1] = 2 * n; problem_sizes1[expert_idx * 3 + 1] = 2 * n;
problem_sizes1[expert_idx * 3 + 2] = k; problem_sizes1[expert_idx * 3 + 2] = k;
problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx]; problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx];
problem_sizes2[expert_idx * 3 + 1] = k; problem_sizes2[expert_idx * 3 + 1] = k;
problem_sizes2[expert_idx * 3 + 2] = n; problem_sizes2[expert_idx * 3 + 2] = n;
} else {
problem_sizes1[expert_idx * 3] = 2 * n;
problem_sizes1[expert_idx * 3 + 1] = expert_num_tokens[expert_idx];
problem_sizes1[expert_idx * 3 + 2] = k;
problem_sizes2[expert_idx * 3] = k;
problem_sizes2[expert_idx * 3 + 1] = expert_num_tokens[expert_idx];
problem_sizes2[expert_idx * 3 + 2] = n;
}
} }
void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
...@@ -187,10 +197,19 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, ...@@ -187,10 +197,19 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
const int64_t n, const int64_t k) { const int64_t n, const int64_t k) {
auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index()); auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index());
compute_pplx_data<<<1, num_local_experts, 0, stream>>>( if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) {
compute_pplx_data<false><<<1, num_local_experts, 0, stream>>>(
static_cast<int32_t*>(expert_offsets.data_ptr()), static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()), static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()), static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n, static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
k); k);
} else {
compute_pplx_data<true><<<1, num_local_experts, 0, stream>>>(
static_cast<int32_t*>(expert_offsets.data_ptr()),
static_cast<int32_t*>(problem_sizes1.data_ptr()),
static_cast<int32_t*>(problem_sizes2.data_ptr()),
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
k);
}
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment