Unverified Commit 3bdcdd13 authored by Yuan Luo, committed by GitHub

[Hot-Fix] moe_aligned_block_size CI failed in AMD (#8461)


Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: JieXin Liang <Alcanderian@users.noreply.github.com>
parent a730ce81
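
For context: the diff below splits the kernel's prefix sum into two compile-time paths, keeping the original shared-memory scan where __CUDA_ARCH__ is undefined (the HIP/AMD build) and adding a shuffle-based warp_exclusive_scan for CUDA. The following is a minimal, self-contained sketch of that warp-scan building block, assuming WARP_SIZE is 32; the demo kernel, buffer names, and all-ones input are illustrative, not part of the commit.

#include <cstdio>
#include <cuda_runtime.h>

#define WARP_SIZE 32

// Same pattern as the patch: Hillis-Steele inclusive scan via shuffles,
// then subtract the lane's own value to get the exclusive sum.
__device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffffffffu) {
  int original = v;
#pragma unroll
  for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
    int n = __shfl_up_sync(mask, v, offset);
    if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
  }
  return v - original;  // inclusive sum minus own value = exclusive sum
}

// Hypothetical demo kernel: scans a single warp's worth of values.
__global__ void demo_scan(const int* in, int* out) {
  out[threadIdx.x] = warp_exclusive_scan(in[threadIdx.x]);
}

int main() {
  int h_in[WARP_SIZE], h_out[WARP_SIZE];
  for (int i = 0; i < WARP_SIZE; ++i) h_in[i] = 1;  // all-ones input
  int *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  demo_scan<<<1, WARP_SIZE>>>(d_in, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  for (int i = 0; i < WARP_SIZE; ++i) printf("%d ", h_out[i]);  // prints 0 1 2 ... 31
  printf("\n");
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
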
@@ -42,6 +42,18 @@ __global__ void count_and_sort_expert_tokens_kernel(
}
}
#ifdef __CUDA_ARCH__
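// Shuffle-based exclusive prefix sum within one warp:
// lane i receives the sum of the values held by lanes 0..i-1 (lane 0 gets 0).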
__device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffffffffu) {
int original = v;
#pragma unroll
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
int n = __shfl_up_sync(mask, v, offset);
if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
}
return v - original;
}
#endif
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(
const scalar_t* __restrict__ topk_ids,
@@ -83,6 +95,8 @@ __global__ void moe_align_block_size_kernel(
scan_buf[tid] = padded_count;
}
#ifndef __CUDA_ARCH__ // HIP
if (tid >= num_experts && tid < scan_size) {
scan_buf[tid] = 0;
}
@@ -132,13 +146,62 @@ __global__ void moe_align_block_size_kernel(
s_total_tokens_post_pad = prefix[num_experts];
*total_tokens_post_pad = s_total_tokens_post_pad;
}
__syncthreads();
#else // CUDA
// Intra-warp prefix sum: each warp's last lane records its warp total below
int32_t* warp_sums = scan_buf + scan_size; // one partial sum per warp (<= WARP_SIZE entries)
const int warp_id = tid / WARP_SIZE;
const int lane_id = tid & (WARP_SIZE - 1);
const int num_warps_for_scan = (scan_size + WARP_SIZE - 1) / WARP_SIZE;
const int warp_sum = warp_exclusive_scan(padded_count) + padded_count;
if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = warp_sum;
__syncthreads();
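// warp_sums[w] now holds the total padded count contributed by warp w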
// Warp 0 scans the per-warp totals into a block-level inclusive prefix sum
if (tid < WARP_SIZE) {
int val = (tid < num_warps_for_scan) ? warp_sums[tid] : 0;
int incl = warp_exclusive_scan(val) + val;
warp_sums[tid] = incl;
}
__syncthreads();
// Thread 0 publishes the block-wide total; every thread can read it after the barrier
if (tid == 0) {
prefix[num_experts] = warp_sums[num_warps_for_scan - 1];
s_total_tokens_post_pad = prefix[num_experts];
*total_tokens_post_pad = s_total_tokens_post_pad;
}
__syncthreads();
// Zero-fill the extended region of scan_buf (tid >= num_experts)
if (tid >= num_experts && tid < scan_size) scan_buf[tid] = 0;
__syncthreads();
// Two-level exclusive prefix sum over scan_buf: warp-local scan, then offset by scanned warp totals
int v = (tid < scan_size) ? scan_buf[tid] : 0;
int pre = warp_exclusive_scan(v);
if (lane_id == WARP_SIZE - 1) warp_sums[warp_id] = pre + v;
__syncthreads();
if (warp_id == 0) {
int val = (lane_id < num_warps_for_scan) ? warp_sums[lane_id] : 0;
warp_sums[lane_id] = warp_exclusive_scan(val);
}
__syncthreads();
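// warp_sums[w] now holds the exclusive prefix sum of the warp totals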
int offset = warp_sums[warp_id];
if (tid < scan_size) scan_buf[tid] = pre + offset;
__syncthreads();
// Write prefix[0..num_experts - 1]; cumsum is copied from prefix below
if (tid < num_experts) prefix[tid] = scan_buf[tid];
#endif
if (tid <= num_experts) {
cumsum[tid] = prefix[tid];
}
// fill expert_ids
const int32_t num_blocks = s_total_tokens_post_pad / block_size;
for (int32_t i = tid; i < num_blocks; i += stride) {
@@ -250,9 +313,6 @@ void moe_align_block_size(
bool pad_sorted_token_ids) {
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
int experts_per_warp = WARP_SIZE;
int threads = 1024;
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
@@ -278,8 +338,7 @@
auto align_kernel = moe_align_block_size_kernel<scalar_t>;
const size_t scan_size = next_pow2(num_experts);
-const size_t shared_mem_size = (num_experts + (num_experts + 1) + scan_size) * sizeof(int32_t);
+const size_t shared_mem_size = (num_experts + (num_experts + 1) + scan_size + WARP_SIZE) * sizeof(int32_t);
align_kernel<<<1, threads, shared_mem_size, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
......
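
The `+ WARP_SIZE` added to shared_mem_size above backs the warp_sums array that the CUDA path places immediately after scan_buf (warp_sums = scan_buf + scan_size). Below is a host-side sketch of the sizing arithmetic, assuming next_pow2 rounds up to the next power of two and using a hypothetical expert count; neither detail is confirmed by the diff.

#include <cstddef>
#include <cstdint>
#include <cstdio>

constexpr size_t WARP_SIZE = 32;  // assumption: NVIDIA warp width

// Assumed behavior of the repo's next_pow2: round up to a power of two.
static size_t next_pow2(size_t n) {
  size_t p = 1;
  while (p < n) p <<= 1;
  return p;
}

int main() {
  const size_t num_experts = 96;                    // hypothetical value, for illustration
  const size_t scan_size = next_pow2(num_experts);  // 128
  // layout: per-expert counts | prefix[num_experts + 1] | scan_buf[scan_size] | warp_sums[WARP_SIZE]
  const size_t shared_mem_size =
      (num_experts + (num_experts + 1) + scan_size + WARP_SIZE) * sizeof(int32_t);
  printf("%zu bytes\n", shared_mem_size);  // (96 + 97 + 128 + 32) * 4 = 1412
  return 0;
}
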