Unverified commit 1c96fa86, authored by yiakwy-xpu-ml-framework-team and committed by GitHub

[MOE] enable efficient moe_alignment multi-blocks execution (3x~6x) (#3613)

parent bc20e93f
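The user-visible change in this diff: both the CUDA op `moe_align_block_size` and the reference `moe_align_block_size_triton` wrapper now take caller-allocated `token_cnts`/`cumsum` scratch buffers instead of allocating them internally, and the CUDA kernel becomes a cooperative multi-block launch. A minimal sketch of the new calling convention, assuming a hypothetical import path for the compiled op (adjust to your sgl-kernel build); the allocation shapes follow the benchmark code in this diff:

import torch

# Assumption: import path of the compiled CUDA op; adjust to your build.
from sgl_kernel import moe_align_block_size

num_experts, block_size, topk = 256, 128, 8
num_tokens = 4 * 1024
# Random expert assignments (duplicates allowed; the benchmark uses randperm per token).
topk_ids = torch.randint(
    0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda"
)

# Worst-case padded length, as used by the benchmark in this file.
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
sorted_ids.fill_(topk_ids.numel())  # sentinel value for padded slots
expert_ids = torch.zeros(
    (max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
)
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")

# New in this change: the caller owns the zero-initialized scratch buffers.
token_cnts_buffer = torch.zeros(
    (num_experts + 1) * num_experts, dtype=torch.int32, device="cuda"
)
cumsum_buffer = torch.zeros(num_experts + 1, dtype=torch.int32, device="cuda")

moe_align_block_size(
    topk_ids,
    num_experts,
    block_size,
    sorted_ids,
    expert_ids,
    num_tokens_post_pad,
    token_cnts_buffer,
    cumsum_buffer,
)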
@@ -99,13 +99,12 @@ def moe_align_block_size_triton(
     sorted_token_ids: torch.Tensor,
     expert_ids: torch.Tensor,
     num_tokens_post_pad: torch.Tensor,
+    tokens_cnts: torch.Tensor,
+    cumsum: torch.Tensor,
 ) -> None:
     numel = topk_ids.numel()
     grid = (num_experts,)
-    tokens_cnts = torch.zeros(
-        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
-    )
-    cumsum = torch.zeros((num_experts + 1,), dtype=torch.int32, device=topk_ids.device)

     tokens_per_thread = ceil_div(numel, num_experts)

     moe_align_block_size_stage1[grid](
@@ -139,11 +138,18 @@ def moe_align_block_size_triton(
     )


-def calculate_diff(batch_size, seq_len):
-    num_experts = 256
+def calculate_diff(batch_size, seq_len, num_experts):
+    num_experts = num_experts
     block_size = 128
     topk = 8

+    assert batch_size >= 1
+    assert seq_len >= 1
+    assert num_experts >= 4
+
+    if topk > num_experts:
+        topk = num_experts
+
     topk_ids = torch.stack(
         [
             torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
@@ -175,6 +181,13 @@ def calculate_diff(batch_size, seq_len):
     expert_ids_triton = torch.zeros_like(expert_ids_cuda)
     num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)

+    token_cnts_buffer_triton = torch.zeros(
+        (num_experts + 1, num_experts), dtype=torch.int32, device=topk_ids.device
+    )
+    cumsum_buffer_triton = torch.zeros(
+        (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+    )
+
     # compare the performance of cuda and triton implementation
     moe_align_block_size(
         topk_ids,
@@ -193,14 +206,27 @@ def calculate_diff(batch_size, seq_len):
         sorted_ids_triton,
         expert_ids_triton,
         num_tokens_post_pad_triton,
+        token_cnts_buffer_triton,
+        cumsum_buffer_triton,
     )

-    if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
-        num_tokens_post_pad_cuda, num_tokens_post_pad_triton
+    sorted_ids_cuda_snapshot = sorted_ids_cuda[: cumsum_buffer[1]].sort().values
+    sorted_ids_triton_snapshot = sorted_ids_triton[: cumsum_buffer[1]].sort().values
+
+    if (
+        torch.allclose(expert_ids_cuda, expert_ids_triton)
+        and torch.allclose(num_tokens_post_pad_cuda, num_tokens_post_pad_triton)
+        and torch.allclose(sorted_ids_cuda_snapshot, sorted_ids_triton_snapshot)
     ):
-        print("✅ CUDA and Triton implementations match")
+        print(
+            "✅ CUDA and Triton implementations match : num_tokens={}, num_experts={}".format(
+                batch_size * seq_len, num_experts
+            )
+        )
     else:
         print("❌ CUDA and Triton implementations do not match")
+        print("CUDA sorted ids:", sorted_ids_cuda_snapshot)
+        print("Triton sorted ids:", sorted_ids_triton_snapshot)
         print("CUDA expert_ids:", expert_ids_cuda)
         print("Triton expert_ids:", expert_ids_triton)
         print("CUDA num_tokens_post_pad:", num_tokens_post_pad_cuda)
@@ -256,7 +282,7 @@ def benchmark(batch_size, seq_len, provider):
     )
     sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = max_num_tokens_padded // block_size
-    expert_ids = torch.empty(
+    expert_ids = torch.zeros(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
     )
     num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
@@ -267,34 +293,37 @@ def benchmark(batch_size, seq_len, provider):
         num_experts + 1, dtype=torch.int32, device=topk_ids.device
     )

-    quantiles = [0.5, 0.2, 0.8]
-
-    if provider == "cuda":
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: moe_align_block_size(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids.clone(),
-                expert_ids.clone(),
-                num_tokens_post_pad.clone(),
-                token_cnts_buffer,
-                cumsum_buffer,
-            ),
-            quantiles=quantiles,
-        )
-    else:
-        ms, min_ms, max_ms = triton.testing.do_bench(
-            lambda: moe_align_block_size_triton(
-                topk_ids,
-                num_experts,
-                block_size,
-                sorted_ids.clone(),
-                expert_ids.clone(),
-                num_tokens_post_pad.clone(),
-            ),
-            quantiles=quantiles,
-        )
+    # Warm up
+    api_func = (
+        moe_align_block_size if provider == "cuda" else moe_align_block_size_triton
+    )
+    for _ in range(10):
+        api_func(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+            token_cnts_buffer.clone(),
+            cumsum_buffer.clone(),
+        )
+    torch.cuda.synchronize()
+
+    quantiles = [0.5, 0.2, 0.8]
+    ms, min_ms, max_ms = triton.testing.do_bench(
+        lambda: api_func(
+            topk_ids,
+            num_experts,
+            block_size,
+            sorted_ids,
+            expert_ids,
+            num_tokens_post_pad,
+            token_cnts_buffer.clone(),
+            cumsum_buffer.clone(),
+        ),
+        quantiles=quantiles,
+    )

     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
@@ -306,8 +335,22 @@ if __name__ == "__main__":
         default="./configs/benchmark_ops/moe_align_blocks/",
         help="Path to save moe align benchmark results",
     )
+    parser.add_argument(
+        "--verify",
+        action="store_true",
+        help="verify kernel",
+    )
     args = parser.parse_args()

-    calculate_diff(batch_size=4, seq_len=1024)
-
-    benchmark.run(print_data=True)
+    if args.verify:
+        num_experts_range = [2**i for i in range(3, 9)]
+        configs = list(
+            itertools.product(batch_size_range, seq_length_range, num_experts_range)
+        )
+        for bs, seq, num_experts in configs:
+            calculate_diff(batch_size=bs, seq_len=seq, num_experts=num_experts)
+    else:
+        benchmark.run(print_data=True, save_path=args.save_path)
 [build-system]
-requires = ["setuptools>=61.0", "wheel", "torch"]
+requires = ["setuptools>=61.0", "wheel", "torch<=2.5.1"]
 build-backend = "setuptools.build_meta"

 [project]
...
@@ -16,77 +16,284 @@ limitations under the License.
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <c10/cuda/CUDAGuard.h>
+#include <cooperative_groups.h>
 #include <torch/extension.h>

 #include <THC/THCAtomics.cuh>

 #include "utils.h"

-#define WARP_SIZE 32
+#define MAX_NUM_EXPERTS 256
+#define EXPERTS_PER_WARP ((MAX_NUM_EXPERTS) / (WARP_SIZE))
+#define FRAGS_PER_BLOCK 4

-template <typename scalar_t>
-__global__ void count_and_sort_expert_tokens_kernel(const scalar_t* __restrict__ topk_ids,
-                                                    int32_t* __restrict__ sorted_token_ids,
-                                                    int32_t* __restrict__ cumsum_buffer, size_t numel) {
-  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
-  const size_t stride = blockDim.x * gridDim.x;
-
-  for (size_t i = tid; i < numel; i += stride) {
-    int32_t expert_id = topk_ids[i];
-    int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
-    sorted_token_ids[rank_post_pad] = i;
-  }
-}
+#define FRAG_SIZE_M 16
+#define FRAG_SIZE_N 16
+
+#ifndef USE_ROCM
+#define kWarpsToLoad 2
+#else
+#define kWarpsToLoad 1
+#endif
+#define kElementsPerAccess 4
+#define kElementsPerThr 16
+
+#define SGLANG_FORCE_INLINE_DEVICE_FUNC static __forceinline__ __attribute__((always_inline)) __device__
+
+namespace cg = cooperative_groups;
+
+SGLANG_FORCE_INLINE_DEVICE_FUNC void store_global_cumsum(int* cumsum /*dest*/, int* total_tokens_post_pad /*dest*/,
+                                                         const int32_t* local_offsets, const int& tid,
+                                                         const int& num_experts, cg::grid_group& grid) {
+  int active_threads = CEILDIV(num_experts + 1, kElementsPerThr);
+  if (tid < active_threads - 1) {
+    for (int i = tid * kElementsPerThr; i < (tid + 1) * kElementsPerThr; i += kElementsPerAccess) {
+      *(int4*)(cumsum + i) = *(int4*)(local_offsets + i);
+    }
+  }
+
+  if (tid == active_threads - 1) {
+#pragma unroll
+    for (int i = tid * kElementsPerThr; i < num_experts + 1; i++) {
+      *(cumsum + i) = *(local_offsets + i);
+    }
+  }
+
+  if (tid == active_threads) {
+    *total_tokens_post_pad = local_offsets[num_experts];
+  }
+
+  __threadfence_system();
+  grid.sync();
+}
+
+SGLANG_FORCE_INLINE_DEVICE_FUNC void align_global_cumsum(int32_t* local_offsets /*src_and_dest*/,
+                                                         int32_t* local_offsets_buf, int* smem_ptr, const int tid,
+                                                         const int32_t& block_size, const int32_t& num_experts) {
+  int active_threads = CEILDIV(num_experts, kElementsPerThr);
+  int start = tid * kElementsPerThr + 1;
+  int end = MIN((tid + 1) * kElementsPerThr, num_experts) + 1;
+
+  if (tid == 0) {
+    smem_ptr[0] = 0;
+  }
+  if (tid < active_threads) {
+    for (int i = start; i < end; ++i) {
+      smem_ptr[i] = local_offsets[i] - local_offsets[i - 1];
+    }
+  }
+  __syncthreads();
+
+  if (tid < active_threads) {
+    for (int i = start; i < end; ++i) {
+      int last_val = (i - 1) % kElementsPerThr == 0 ? 0 : local_offsets[i - 1];
+      local_offsets[i] = last_val + CEILDIV(smem_ptr[i], block_size) * block_size;
+    }
+    local_offsets_buf[tid] = local_offsets[end - 1];
+  }
+  __syncthreads();
+
+  if (tid < active_threads && tid > 0) {
+    int offset = 0;
+    for (int j = 0; j < tid; ++j) {
+      offset += local_offsets_buf[j];
+    }
+    for (int i = start; i < end; ++i) {
+      local_offsets[i] += offset;
+    }
+  }
+  __syncthreads();
+}
+
+SGLANG_FORCE_INLINE_DEVICE_FUNC void reduce_unaligned_cumsum(int* tokens_cnts_ptr /*src_and_dest*/, int* smem_ptr,
+                                                             int32_t* local_offsets, const int& tid, const int& lane_id,
+                                                             const int& warp_id, const int32_t& num_experts,
+                                                             cg::grid_group& grid) {
+  int total_fragments = CEILDIV(num_experts, FRAG_SIZE_N);
+  int fragments_per_block = CEILDIV(total_fragments, gridDim.x);
+  int fragments_per_warp = CEILDIV(fragments_per_block, FRAGS_PER_BLOCK);
+
+  for (int i = 0; i < gridDim.x; i += FRAG_SIZE_M) {
+    for (int j = 0; j < fragments_per_warp; j++) {
+      if (warp_id * fragments_per_warp < kWarpsToLoad * fragments_per_block) {
+        const int kNumThrPerRow = WARP_SIZE / FRAG_SIZE_N;
+
+        int sRow = lane_id / kNumThrPerRow;
+        int sWarpColStride = kNumThrPerRow * kElementsPerAccess;
+        int sWarpColOff = warp_id * sWarpColStride;
+        int sThrColOff = lane_id % kNumThrPerRow * kElementsPerAccess;
+        int sCol = sThrColOff + sWarpColOff;
+
+        int gRow = i + sRow;
+        int gBlockColOff = blockIdx.x * fragments_per_block * FRAG_SIZE_N;
+        int gWarpColOff_0 = (warp_id / kWarpsToLoad * fragments_per_warp + j) * FRAG_SIZE_N;
+        int gWarpColOff_1 = warp_id % kWarpsToLoad * sWarpColStride;
+        int gCol = gBlockColOff + gWarpColOff_0 + gWarpColOff_1 + sThrColOff;
+
+        if (gRow < num_experts && gCol < num_experts) {
+          int4* tokens_cnts_4i_ptr = (int4*)(tokens_cnts_ptr + (gRow + 1) * num_experts + gCol);
+          int4* smem_4i_ptr = (int4*)(smem_ptr + sRow * FRAGS_PER_BLOCK * FRAG_SIZE_N + sCol);
+          *smem_4i_ptr = *tokens_cnts_4i_ptr;
+        }
+      }
+      __syncthreads();
+
+      if (warp_id * fragments_per_warp < kWarpsToLoad * fragments_per_block) {
+        if (warp_id % kWarpsToLoad == 0) {
+          for (int k = 0; k < FRAG_SIZE_M; k += (WARP_SIZE / FRAG_SIZE_N)) {
+            int sRow = lane_id / FRAG_SIZE_N + k;
+            int sThrColOff = lane_id % FRAG_SIZE_N;
+            int sCol = sThrColOff + (warp_id / kWarpsToLoad) * FRAG_SIZE_N;
+
+            int gBlockColOff = blockIdx.x * fragments_per_block * FRAG_SIZE_N;
+            int gWarpColOff_0 = (warp_id / kWarpsToLoad * fragments_per_warp + j) * FRAG_SIZE_N;
+            int gCol = gBlockColOff + gWarpColOff_0 + sThrColOff;
+
+            if (gCol < num_experts) {
+              atomicAdd(local_offsets + gCol + 1, *(smem_ptr + sRow * FRAGS_PER_BLOCK * FRAG_SIZE_N + sCol));
+              // atomicAdd(tokens_cnts_ptr + gCol, *(smem_ptr + sRow * FRAGS_PER_BLOCK * FRAG_SIZE_N + sCol));
+            }
+          }
+        }
+      }
+      __syncthreads();
+    }  // end of j
+  }    // end of i
+
+  if (threadIdx.x < num_experts) {
+    atomicAdd(tokens_cnts_ptr + threadIdx.x, *(local_offsets + threadIdx.x + 1));
+  }
+
+  __threadfence_system();
+  grid.sync();
+
+  if (tid < num_experts) {
+    *(local_offsets + tid + 1) = *(tokens_cnts_ptr + tid);
+  }
+  __syncthreads();
+}
+
+SGLANG_FORCE_INLINE_DEVICE_FUNC void parallel_unaligned_local_cumsum(
+    const int& tid, int* tokens_cnts_ptr /*dest*/, int32_t* local_offsets /*dest*/, int32_t* local_offsets_buf,
+    const int32_t (*shared_counts)[EXPERTS_PER_WARP] /*src*/, const int& experts_per_warp, const int32_t& num_experts,
+    cg::grid_group& grid) {
+  int active_threads = CEILDIV(num_experts, kElementsPerThr);
+  if (threadIdx.x == 0) {
+    local_offsets[0] = 0;
+  }
+
+  if (threadIdx.x < active_threads) {
+    for (int i = threadIdx.x * kElementsPerThr + 1; i < MIN((threadIdx.x + 1) * kElementsPerThr, num_experts) + 1;
+         ++i) {
+      int warp_idx = (i - 1) / experts_per_warp;
+      int expert_offset = (i - 1) % experts_per_warp;
+      int expert_count = shared_counts[warp_idx][expert_offset];
+
+      int last_val = (i - 1) % kElementsPerThr == 0 ? 0 : local_offsets[i - 1];
+      local_offsets[i] = last_val + expert_count;
+    }
+    local_offsets_buf[threadIdx.x] = local_offsets[MIN((threadIdx.x + 1) * kElementsPerThr, num_experts)];
+  }
+  __syncthreads();
+
+  if (threadIdx.x < active_threads && threadIdx.x > 0) {
+    int offset = 0;
+    for (int j = 0; j < threadIdx.x; ++j) {
+      offset += local_offsets_buf[j];
+    }
+    for (int i = threadIdx.x * kElementsPerThr + 1; i < MIN((threadIdx.x + 1) * kElementsPerThr, num_experts) + 1;
+         ++i) {
+      local_offsets[i] += offset;
+    }
+  }
+  __syncthreads();
+
+  if (tid < num_experts) {
+    *(tokens_cnts_ptr + tid) = 0;
+  }
+
+  if (threadIdx.x < num_experts) {
+    *(tokens_cnts_ptr + (blockIdx.x + 1) * num_experts + threadIdx.x) = *(local_offsets + threadIdx.x + 1);
+    *(local_offsets + threadIdx.x + 1) = 0;
+  } else if (threadIdx.x < MAX_NUM_EXPERTS) {
+    *(local_offsets + threadIdx.x + 1) = 0;
+  }
+
+  __threadfence_system();
+  grid.sync();
+}
 template <typename scalar_t>
 __global__ void moe_align_block_size_kernel(const scalar_t* __restrict__ topk_ids,
                                             int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
                                             int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
-                                            int32_t block_size, size_t numel, int32_t* __restrict__ cumsum) {
-  __shared__ int32_t shared_counts[WARP_SIZE][8];
+                                            int32_t block_size, size_t numel, int32_t* __restrict__ tokens_cnts,
+                                            int32_t* __restrict__ cumsum, const int tokens_per_block,
+                                            const int tokens_per_thread, const int K) {
+  __shared__ int32_t smem[FRAG_SIZE_M * FRAG_SIZE_N * FRAGS_PER_BLOCK];
+  int32_t(*shared_counts)[EXPERTS_PER_WARP] = (int32_t(*)[EXPERTS_PER_WARP]) & smem[0];
+
+  __shared__ int32_t local_offsets[MAX_NUM_EXPERTS + 1];
+  __shared__ int32_t local_offsets_buf[CEILDIV(MAX_NUM_EXPERTS, kElementsPerThr)];
+
+  const int tid = threadIdx.x + blockDim.x * blockIdx.x;
   const int warp_id = threadIdx.x / WARP_SIZE;
-  const int experts_per_warp = 8;
-  const int my_expert_start = warp_id * experts_per_warp;
-
-  for (int i = 0; i < experts_per_warp; ++i) {
-    if (my_expert_start + i < num_experts) {
-      shared_counts[warp_id][i] = 0;
-    }
-  }
+  const int lane_id = threadIdx.x % WARP_SIZE;
+  const int experts_per_warp = EXPERTS_PER_WARP;
+
+  int* tokens_cnts_ptr = &(tokens_cnts[0]);
+  int* smem_ptr = &(smem[0]);
+
+  cg::grid_group grid = cg::this_grid();
+
+  if (threadIdx.x < FRAG_SIZE_M * FRAG_SIZE_N) {
+    for (int i = 0; i < FRAG_SIZE_M * FRAG_SIZE_N * FRAGS_PER_BLOCK; i += FRAG_SIZE_M * FRAG_SIZE_N) {
+      smem[threadIdx.x + i] = 0;
+    }
+  }
   __syncthreads();

-  const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
-  const size_t start_idx = threadIdx.x * tokens_per_thread;
-
-  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
-    int expert_id = topk_ids[i];
-    int warp_idx = expert_id / experts_per_warp;
-    int expert_offset = expert_id % experts_per_warp;
-    atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
-  }
+  const size_t start_idx = tokens_per_block * blockIdx.x + tokens_per_thread * threadIdx.x;
+  const size_t end_idx = start_idx + tokens_per_thread;
+
+  if (threadIdx.x * tokens_per_thread < tokens_per_block) {
+    for (int i = start_idx; i < MIN(numel, end_idx); ++i) {
+      int expert_id = topk_ids[i];
+      int warp_idx = expert_id / experts_per_warp;
+      int expert_offset = expert_id % experts_per_warp;
+      atomicAdd(&shared_counts[warp_idx][expert_offset], 1);
+    }
+  }

   __syncthreads();

-  if (threadIdx.x == 0) {
-    cumsum[0] = 0;
-    for (int i = 1; i <= num_experts; ++i) {
-      int expert_count = 0;
-      int warp_idx = (i - 1) / experts_per_warp;
-      int expert_offset = (i - 1) % experts_per_warp;
-      expert_count = shared_counts[warp_idx][expert_offset];
-      cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
-    }
-    *total_tokens_post_pad = cumsum[num_experts];
-  }
+  parallel_unaligned_local_cumsum(tid, tokens_cnts_ptr /*dest*/, local_offsets, local_offsets_buf, shared_counts,
+                                  experts_per_warp, num_experts, grid);
+
+  reduce_unaligned_cumsum(tokens_cnts_ptr /*src_and_dest*/, smem_ptr, local_offsets, tid, lane_id, warp_id, num_experts,
+                          grid);
+
+  align_global_cumsum(local_offsets /*src_and_dest*/, local_offsets_buf, smem_ptr, tid, block_size, num_experts);
+
+  store_global_cumsum(cumsum /*dest*/, total_tokens_post_pad /*dest*/, local_offsets /*src*/, tid, num_experts, grid);
+
+  if (tid < num_experts) {
+    for (int i = local_offsets[tid]; i < local_offsets[tid + 1]; i += block_size) {
+      expert_ids[i / block_size] = tid;
+    }
+  }

   __syncthreads();

-  if (threadIdx.x < num_experts) {
-    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
-      expert_ids[i / block_size] = threadIdx.x;
-    }
-  }
+  if (threadIdx.x * tokens_per_thread < tokens_per_block) {
+    for (int i = start_idx; i < MIN(numel, end_idx); ++i) {
+      int32_t expert_id = topk_ids[i];
+      int32_t rank_post_pad = atomicAdd(&cumsum[expert_id], 1);
+      sorted_token_ids[rank_post_pad] = i;
+    }
+  }
 }
@@ -95,22 +302,29 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t b
                           torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad,
                           torch::Tensor token_cnts_buffer, torch::Tensor cumsum_buffer) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  TORCH_CHECK(num_experts == 256, "moe_align_block_size kernel only support deepseek v3 now.");
+
   DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-    auto align_kernel = moe_align_block_size_kernel<scalar_t>;
-    align_kernel<<<1, 1024, 0, stream>>>(topk_ids.data_ptr<scalar_t>(), sorted_token_ids.data_ptr<int32_t>(),
-                                         experts_ids.data_ptr<int32_t>(), num_tokens_post_pad.data_ptr<int32_t>(),
-                                         num_experts, block_size, topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
+    auto kernel = moe_align_block_size_kernel<scalar_t>;

     const int block_threads = 256;
-    const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
-    const int max_blocks = 65535;
-    const int actual_blocks = std::min(num_blocks, max_blocks);
+    const int num_blocks = MIN(CEILDIV(topk_ids.sizes()[0], block_threads), num_experts);

-    auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
-    sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(topk_ids.data_ptr<scalar_t>(),
-                                                             sorted_token_ids.data_ptr<int32_t>(),
-                                                             cumsum_buffer.data_ptr<int32_t>(), topk_ids.numel());
+    scalar_t* topk_ids_ptr = topk_ids.data_ptr<scalar_t>();
+    int32_t* sorted_token_ids_ptr = sorted_token_ids.data_ptr<int32_t>();
+    int32_t* experts_ids_ptr = experts_ids.data_ptr<int32_t>();
+    int32_t* num_tokens_post_pad_ptr = num_tokens_post_pad.data_ptr<int32_t>();
+    size_t num_tokens = topk_ids.numel();
+    int32_t* token_cnts_buffer_ptr = token_cnts_buffer.data_ptr<int32_t>();
+    int32_t* cumsum_buffer_ptr = cumsum_buffer.data_ptr<int32_t>();
+
+    int tokens_per_block = CEILDIV(topk_ids.sizes()[0], num_blocks) * topk_ids.sizes()[1];
+    int tokens_per_thread = CEILDIV(tokens_per_block, block_threads);
+    int K = topk_ids.sizes()[1];
+
+    void* kernelArgs[] = {&topk_ids_ptr, &sorted_token_ids_ptr, &experts_ids_ptr, &num_tokens_post_pad_ptr,
+                          &num_experts, &block_size, &num_tokens, &token_cnts_buffer_ptr,
+                          &cumsum_buffer_ptr, &tokens_per_block, &tokens_per_thread, &K};
+
+    cudaLaunchCooperativeKernel((void*)kernel, num_blocks, block_threads, kernelArgs);
   });
 }
@@ -49,6 +49,17 @@ struct cuda_error : public std::runtime_error {
   }                     \
   } while (0)

+#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__)
+
+template <typename T>
+void check(T result, char const* const func, const char* const file, int const line) {
+  if (result) {
+    fprintf(stderr, "CUDA error at %s:%d code=%d(%s) \"%s\" \n", file, line, static_cast<unsigned int>(result),
+            cudaGetErrorString(result), func);
+    cudaDeviceReset();
+    exit(EXIT_FAILURE);
+  }
+}
+
 #define CHECK_IS_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_IS_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_CUDA_INPUT(x) \
@@ -95,3 +106,22 @@ inline int getSMVersion() {
   AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))

 #define CEILDIV(x, y) (((x) + (y)-1) / (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+
+#ifndef USE_ROCM
+#define WARP_SIZE 32
+#else
+#define WARP_SIZE warpSize  // 64
+#endif
+
+#if defined(__HIP_PLATFORM_AMD__)
+#include <hip/hip_cooperative_groups.h>
+#include <hip/hip_runtime.h>
+
+static __inline__ __host__ __device__ hipError_t cudaLaunchCooperativeKernel(const void* f, dim3 gridDim,
+                                                                             dim3 blockDimX, void** kernelParams) {
+  return hipLaunchCooperativeKernel(f, gridDim, blockDimX, kernelParams, 0, hipStreamDefault);
+}
+#endif
@@ -171,12 +171,12 @@ def test_moe_align_block_size_compare_implementations(block_size, num_tokens, to
     num_tokens_post_pad_cuda = torch.empty(
         (1), dtype=torch.int32, device=topk_ids.device
     )
-    token_cnts_buffer = torch.empty(
+    token_cnts_buffer = torch.zeros(
         (num_experts + 1) * num_experts,
         dtype=torch.int32,
        device=topk_ids.device,
     )
-    cumsum_buffer = torch.empty(
+    cumsum_buffer = torch.zeros(
         num_experts + 1, dtype=torch.int32, device=topk_ids.device
     )
...
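A note on the `torch.empty` to `torch.zeros` change in the test above: the multi-block kernel accumulates per-block expert counts into `token_cnts_buffer` and scatters token ids by atomically bumping `cumsum_buffer`, so the scratch buffers are expected to start zeroed. A minimal sketch of allocating and reusing them; the shapes come from the test, while the explicit `zero_()` reuse pattern is an assumption (the benchmark instead passes `.clone()` of zeroed buffers on each call):

import torch

num_experts = 256  # shapes follow the test above

# Zero-initialized scratch buffers for the multi-block kernel.
token_cnts_buffer = torch.zeros(
    (num_experts + 1) * num_experts, dtype=torch.int32, device="cuda"
)
cumsum_buffer = torch.zeros(num_experts + 1, dtype=torch.int32, device="cuda")

# Assumption: when reusing the buffers across calls, clear them first.
token_cnts_buffer.zero_()
cumsum_buffer.zero_()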