Unverified Commit f730362e authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

reduce moe_align_block_size_kernel small batch mode overhead (#5086)

parent e3c4bd31
...@@ -702,7 +702,7 @@ def moe_align_block_size( ...@@ -702,7 +702,7 @@ def moe_align_block_size(
num_tokens_post_pad, num_tokens_post_pad,
) )
else: else:
token_cnts_buffer = torch.zeros( token_cnts_buffer = torch.empty(
(num_experts + 1) * num_experts, (num_experts + 1) * num_experts,
dtype=torch.int32, dtype=torch.int32,
device=topk_ids.device, device=topk_ids.device,
......
...@@ -241,9 +241,9 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8): ...@@ -241,9 +241,9 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
# Test range # Test range
num_tokens_range = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] num_tokens_range = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
num_experts_range = [8, 32, 64, 128, 256] num_experts_range = [8, 32, 64, 128, 256]
topk_range = [2, 4, 8] topk_range = [1, 2, 4, 8]
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range)) configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
...@@ -294,17 +294,28 @@ def benchmark(num_tokens, num_experts, topk, provider): ...@@ -294,17 +294,28 @@ def benchmark(num_tokens, num_experts, topk, provider):
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
) )
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
token_cnts_buffer = torch.zeros(
(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device quantiles = [0.5, 0.2, 0.8]
if provider == "sgl":
def sgl_moe_align_block_size_with_empty(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
):
token_cnts_buffer = torch.empty(
(num_experts + 1) * num_experts,
dtype=torch.int32,
device=topk_ids.device,
) )
cumsum_buffer = torch.zeros( cumsum_buffer = torch.empty(
num_experts + 1, dtype=torch.int32, device=topk_ids.device num_experts + 1, dtype=torch.int32, device=topk_ids.device
) )
quantiles = [0.5, 0.2, 0.8] sgl_moe_align_block_size(
if provider == "sgl":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: sgl_moe_align_block_size(
topk_ids, topk_ids,
num_experts, num_experts,
block_size, block_size,
...@@ -313,6 +324,16 @@ def benchmark(num_tokens, num_experts, topk, provider): ...@@ -313,6 +324,16 @@ def benchmark(num_tokens, num_experts, topk, provider):
num_tokens_post_pad.clone(), num_tokens_post_pad.clone(),
token_cnts_buffer, token_cnts_buffer,
cumsum_buffer, cumsum_buffer,
)
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: sgl_moe_align_block_size_with_empty(
topk_ids,
num_experts,
block_size,
sorted_ids,
expert_ids,
num_tokens_post_pad,
), ),
quantiles=quantiles, quantiles=quantiles,
) )
......
...@@ -64,10 +64,10 @@ __global__ void moe_align_block_size_kernel( ...@@ -64,10 +64,10 @@ __global__ void moe_align_block_size_kernel(
__syncthreads(); __syncthreads();
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); const size_t tid = threadIdx.x;
const size_t start_idx = threadIdx.x * tokens_per_thread; const size_t stride = blockDim.x;
for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { for (size_t i = tid; i < numel; i += stride) {
int expert_id = topk_ids[i]; int expert_id = topk_ids[i];
int warp_idx = expert_id / experts_per_warp; int warp_idx = expert_id / experts_per_warp;
int expert_offset = expert_id % experts_per_warp; int expert_offset = expert_id % experts_per_warp;
...@@ -98,6 +98,65 @@ __global__ void moe_align_block_size_kernel( ...@@ -98,6 +98,65 @@ __global__ void moe_align_block_size_kernel(
} }
} }
template <typename scalar_t>
__global__ void moe_align_block_size_small_batch_expert_kernel(
const scalar_t* __restrict__ topk_ids,
int32_t* __restrict__ sorted_token_ids,
int32_t* __restrict__ expert_ids,
int32_t* __restrict__ total_tokens_post_pad,
int32_t num_experts,
int32_t block_size,
size_t numel) {
const size_t tid = threadIdx.x;
const size_t stride = blockDim.x;
extern __shared__ int32_t shared_mem[];
int32_t* cumsum = shared_mem;
int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1);
for (int i = 0; i < num_experts; ++i) {
tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0;
}
for (size_t i = tid; i < numel; i += stride) {
++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]];
}
__syncthreads();
if (threadIdx.x < num_experts) {
tokens_cnts[threadIdx.x] = 0;
for (int i = 1; i <= blockDim.x; ++i) {
tokens_cnts[i * num_experts + threadIdx.x] += tokens_cnts[(i - 1) * num_experts + threadIdx.x];
}
}
__syncthreads();
if (threadIdx.x == 0) {
cumsum[0] = 0;
for (int i = 1; i <= num_experts; ++i) {
cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) * block_size;
}
*total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
}
__syncthreads();
if (threadIdx.x < num_experts) {
for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
expert_ids[i / block_size] = threadIdx.x;
}
}
for (size_t i = tid; i < numel; i += stride) {
int32_t expert_id = topk_ids[i];
int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
sorted_token_ids[rank_post_pad] = i;
++tokens_cnts[threadIdx.x * num_experts + expert_id];
}
}
void moe_align_block_size( void moe_align_block_size(
torch::Tensor topk_ids, torch::Tensor topk_ids,
int64_t num_experts, int64_t num_experts,
...@@ -111,28 +170,35 @@ void moe_align_block_size( ...@@ -111,28 +170,35 @@ void moe_align_block_size(
int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
int experts_per_warp; int experts_per_warp = WARP_SIZE;
int threads; int threads = 1024;
if (num_experts <= 8) {
experts_per_warp = 8;
threads = 256;
} else if (num_experts <= 16) {
experts_per_warp = 16;
threads = 512;
} else {
experts_per_warp = WARP_SIZE;
threads = 1024;
}
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64);
if (small_batch_expert_mode) {
const int32_t threads = max((int32_t)num_experts, WARP_SIZE);
const int32_t shared_mem_size = ((threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
auto small_batch_expert_kernel = moe_align_block_size_small_batch_expert_kernel<scalar_t>;
small_batch_expert_kernel<<<1, threads, shared_mem_size, stream>>>(
topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(),
experts_ids.data_ptr<int32_t>(),
num_tokens_post_pad.data_ptr<int32_t>(),
num_experts,
block_size,
topk_ids.numel());
} else {
auto align_kernel = moe_align_block_size_kernel<scalar_t>; auto align_kernel = moe_align_block_size_kernel<scalar_t>;
size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t); size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
cumsum_buffer.zero_();
align_kernel<<<1, threads, shared_mem_size, stream>>>( align_kernel<<<1, threads, shared_mem_size, stream>>>(
topk_ids.data_ptr<scalar_t>(), topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(), sorted_token_ids.data_ptr<int32_t>(),
...@@ -156,5 +222,6 @@ void moe_align_block_size( ...@@ -156,5 +222,6 @@ void moe_align_block_size(
sorted_token_ids.data_ptr<int32_t>(), sorted_token_ids.data_ptr<int32_t>(),
cumsum_buffer.data_ptr<int32_t>(), cumsum_buffer.data_ptr<int32_t>(),
topk_ids.numel()); topk_ids.numel());
}
}); });
} }
...@@ -151,7 +151,6 @@ def moe_align_block_size_triton( ...@@ -151,7 +151,6 @@ def moe_align_block_size_triton(
def test_moe_align_block_size_compare_implementations( def test_moe_align_block_size_compare_implementations(
block_size, num_tokens, topk, num_experts block_size, num_tokens, topk, num_experts
): ):
# For DeepSeek V3, we have 256 experts
topk_ids = torch.stack( topk_ids = torch.stack(
[ [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment