Unverified commit f730362e, authored by Xiaoyu Zhang and committed by GitHub

reduce moe_align_block_size_kernel small batch mode overhead (#5086)

parent e3c4bd31
@@ -702,7 +702,7 @@ def moe_align_block_size(
num_tokens_post_pad,
)
else:
-token_cnts_buffer = torch.zeros(
+token_cnts_buffer = torch.empty(
(num_experts + 1) * num_experts,
dtype=torch.int32,
device=topk_ids.device,
......
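Note on the hunk above: token_cnts_buffer no longer has to arrive zero-filled, because the CUDA changes further down make the kernels responsible for any initialization they need (the new small-batch kernel clears its shared-memory counts, and the generic path gains an explicit cumsum_buffer.zero_()), so the host-side zero fill becomes avoidable overhead. A minimal allocation sketch, assuming a CUDA build of PyTorch; the sizes and names mirror the patch but the snippet itself is illustrative, not part of it:

import torch

# Sketch only: torch.zeros enqueues a fill kernel on top of the allocation,
# while torch.empty just reserves memory. Skipping the fill is safe here
# because the device code now initializes the counters it reads.
num_experts = 64
device = "cuda"  # assumes a CUDA-enabled PyTorch build

token_cnts_buffer = torch.empty(
    (num_experts + 1) * num_experts, dtype=torch.int32, device=device
)
cumsum_buffer = torch.empty(num_experts + 1, dtype=torch.int32, device=device)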
@@ -241,9 +241,9 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
# Test range
-num_tokens_range = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
+num_tokens_range = [1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
num_experts_range = [8, 32, 64, 128, 256]
-topk_range = [2, 4, 8]
+topk_range = [1, 2, 4, 8]
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
@@ -294,17 +294,28 @@ def benchmark(num_tokens, num_experts, topk, provider):
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
-token_cnts_buffer = torch.zeros(
-(num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
-)
-cumsum_buffer = torch.zeros(
-num_experts + 1, dtype=torch.int32, device=topk_ids.device
-)
quantiles = [0.5, 0.2, 0.8]
if provider == "sgl":
-ms, min_ms, max_ms = triton.testing.do_bench(
-lambda: sgl_moe_align_block_size(
+def sgl_moe_align_block_size_with_empty(
topk_ids,
num_experts,
block_size,
+sorted_ids,
+expert_ids,
+num_tokens_post_pad,
+):
+token_cnts_buffer = torch.empty(
+(num_experts + 1) * num_experts,
+dtype=torch.int32,
+device=topk_ids.device,
+)
+cumsum_buffer = torch.empty(
+num_experts + 1, dtype=torch.int32, device=topk_ids.device
+)
+sgl_moe_align_block_size(
+topk_ids,
+num_experts,
+block_size,
@@ -313,6 +324,16 @@ def benchmark(num_tokens, num_experts, topk, provider):
num_tokens_post_pad.clone(),
token_cnts_buffer,
cumsum_buffer,
)
+ms, min_ms, max_ms = triton.testing.do_bench(
+lambda: sgl_moe_align_block_size_with_empty(
+topk_ids,
+num_experts,
+block_size,
+sorted_ids,
+expert_ids,
+num_tokens_post_pad,
+),
quantiles=quantiles,
)
......
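In the benchmark, the scratch buffers are now allocated inside the timed closure (the sgl_moe_align_block_size_with_empty helper), so the per-call cost of creating token_cnts_buffer and cumsum_buffer is part of what triton.testing.do_bench measures instead of being paid once outside the timing loop. A stripped-down sketch of that timing pattern, with a placeholder body standing in for the real kernel call (illustrative only, not from the patch):

import torch
import triton

# With quantiles=[0.5, 0.2, 0.8], do_bench returns the median, 20th and 80th
# percentile latencies in milliseconds for the whole closure, including any
# tensor allocation done inside it.
quantiles = [0.5, 0.2, 0.8]

def run_once():
    # placeholder for "allocate scratch buffers, then launch the kernel"
    buf = torch.empty(1024, dtype=torch.int32, device="cuda")
    buf.zero_()

ms, min_ms, max_ms = triton.testing.do_bench(run_once, quantiles=quantiles)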
@@ -64,10 +64,10 @@ __global__ void moe_align_block_size_kernel(
__syncthreads();
-const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
-const size_t start_idx = threadIdx.x * tokens_per_thread;
+const size_t tid = threadIdx.x;
+const size_t stride = blockDim.x;
-for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
+for (size_t i = tid; i < numel; i += stride) {
int expert_id = topk_ids[i];
int warp_idx = expert_id / experts_per_warp;
int expert_offset = expert_id % experts_per_warp;
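The loop change above replaces per-thread contiguous chunks with a block-stride loop: instead of thread t scanning tokens_per_thread consecutive elements starting at t * tokens_per_thread, every thread starts at its own index and advances by blockDim.x, so on each iteration the threads of a warp touch adjacent topk_ids entries and the global loads coalesce. A small illustrative sketch of the two index patterns (not part of the patch):

# Per-thread indices under the old (chunked) and new (block-stride) loops,
# for a block of block_dim threads over numel token ids.
def chunked_indices(tid, block_dim, numel):
    tokens_per_thread = -(-numel // block_dim)  # CEILDIV(numel, blockDim.x)
    start = tid * tokens_per_thread
    return list(range(start, min(start + tokens_per_thread, numel)))

def strided_indices(tid, block_dim, numel):
    # neighbouring threads read neighbouring elements on every iteration
    return list(range(tid, numel, block_dim))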
@@ -98,6 +98,65 @@ __global__ void moe_align_block_size_kernel(
}
}
+template <typename scalar_t>
+__global__ void moe_align_block_size_small_batch_expert_kernel(
+const scalar_t* __restrict__ topk_ids,
+int32_t* __restrict__ sorted_token_ids,
+int32_t* __restrict__ expert_ids,
+int32_t* __restrict__ total_tokens_post_pad,
+int32_t num_experts,
+int32_t block_size,
+size_t numel) {
+const size_t tid = threadIdx.x;
+const size_t stride = blockDim.x;
+extern __shared__ int32_t shared_mem[];
+int32_t* cumsum = shared_mem;
+int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1);
+for (int i = 0; i < num_experts; ++i) {
+tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0;
+}
+for (size_t i = tid; i < numel; i += stride) {
+++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]];
+}
+__syncthreads();
+if (threadIdx.x < num_experts) {
+tokens_cnts[threadIdx.x] = 0;
+for (int i = 1; i <= blockDim.x; ++i) {
+tokens_cnts[i * num_experts + threadIdx.x] += tokens_cnts[(i - 1) * num_experts + threadIdx.x];
+}
+}
+__syncthreads();
+if (threadIdx.x == 0) {
+cumsum[0] = 0;
+for (int i = 1; i <= num_experts; ++i) {
+cumsum[i] = cumsum[i - 1] + CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) * block_size;
+}
+*total_tokens_post_pad = static_cast<int32_t>(cumsum[num_experts]);
+}
+__syncthreads();
+if (threadIdx.x < num_experts) {
+for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
+expert_ids[i / block_size] = threadIdx.x;
+}
+}
+for (size_t i = tid; i < numel; i += stride) {
+int32_t expert_id = topk_ids[i];
+int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
+sorted_token_ids[rank_post_pad] = i;
+++tokens_cnts[threadIdx.x * num_experts + expert_id];
+}
+}
void moe_align_block_size(
torch::Tensor topk_ids,
int64_t num_experts,
@@ -111,50 +170,58 @@ void moe_align_block_size(
int64_t padded_num_experts = ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
-int experts_per_warp;
-int threads;
-if (num_experts <= 8) {
-experts_per_warp = 8;
-threads = 256;
-} else if (num_experts <= 16) {
-experts_per_warp = 16;
-threads = 512;
-} else {
-experts_per_warp = WARP_SIZE;
-threads = 1024;
-}
+int experts_per_warp = WARP_SIZE;
+int threads = 1024;
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
DISPATCH_INTEGRAL_TYPES(topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
-auto align_kernel = moe_align_block_size_kernel<scalar_t>;
-size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
-size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
-align_kernel<<<1, threads, shared_mem_size, stream>>>(
-topk_ids.data_ptr<scalar_t>(),
-sorted_token_ids.data_ptr<int32_t>(),
-experts_ids.data_ptr<int32_t>(),
-num_tokens_post_pad.data_ptr<int32_t>(),
-num_experts,
-padded_num_experts,
-experts_per_warp,
-block_size,
-topk_ids.numel(),
-cumsum_buffer.data_ptr<int32_t>());
-const int block_threads = std::min(256, (int)threads);
-const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
-const int max_blocks = 65535;
-const int actual_blocks = std::min(num_blocks, max_blocks);
-auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
-sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
-topk_ids.data_ptr<scalar_t>(),
-sorted_token_ids.data_ptr<int32_t>(),
-cumsum_buffer.data_ptr<int32_t>(),
-topk_ids.numel());
+bool small_batch_expert_mode = (topk_ids.numel() < 1024) && (num_experts <= 64);
+if (small_batch_expert_mode) {
+const int32_t threads = max((int32_t)num_experts, WARP_SIZE);
+const int32_t shared_mem_size = ((threads + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t);
+auto small_batch_expert_kernel = moe_align_block_size_small_batch_expert_kernel<scalar_t>;
+small_batch_expert_kernel<<<1, threads, shared_mem_size, stream>>>(
+topk_ids.data_ptr<scalar_t>(),
+sorted_token_ids.data_ptr<int32_t>(),
+experts_ids.data_ptr<int32_t>(),
+num_tokens_post_pad.data_ptr<int32_t>(),
+num_experts,
+block_size,
+topk_ids.numel());
+} else {
+auto align_kernel = moe_align_block_size_kernel<scalar_t>;
+size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
+size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
+cumsum_buffer.zero_();
+align_kernel<<<1, threads, shared_mem_size, stream>>>(
+topk_ids.data_ptr<scalar_t>(),
+sorted_token_ids.data_ptr<int32_t>(),
+experts_ids.data_ptr<int32_t>(),
+num_tokens_post_pad.data_ptr<int32_t>(),
+num_experts,
+padded_num_experts,
+experts_per_warp,
+block_size,
+topk_ids.numel(),
+cumsum_buffer.data_ptr<int32_t>());
+const int block_threads = std::min(256, (int)threads);
+const int num_blocks = (topk_ids.numel() + block_threads - 1) / block_threads;
+const int max_blocks = 65535;
+const int actual_blocks = std::min(num_blocks, max_blocks);
+auto sort_kernel = count_and_sort_expert_tokens_kernel<scalar_t>;
+sort_kernel<<<actual_blocks, block_threads, 0, stream>>>(
+topk_ids.data_ptr<scalar_t>(),
+sorted_token_ids.data_ptr<int32_t>(),
+cumsum_buffer.data_ptr<int32_t>(),
+topk_ids.numel());
+}
});
}
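For orientation, the new moe_align_block_size_small_batch_expert_kernel runs in a single thread block and keeps all of its counters in shared memory: each thread counts the tokens of each expert over a block-stride loop, a per-expert prefix sum across the thread rows is taken, thread 0 builds a cumulative offset table with each expert's segment rounded up to a multiple of block_size and writes total_tokens_post_pad, expert_ids is then filled per block, and finally every thread scatters its tokens into the padded layout. The launcher only takes this path when topk_ids.numel() < 1024 and num_experts <= 64 (small_batch_expert_mode above); otherwise the original two-kernel path runs, now preceded by cumsum_buffer.zero_() so that buffer no longer needs to arrive pre-zeroed. Below is a hedged host-side reference sketch of the same computation, useful for eyeballing the padding semantics; it is not from the repository, and the token order within an expert's segment may differ from the device kernel's thread-ordered scatter:

import torch

def moe_align_block_size_reference(topk_ids: torch.Tensor, num_experts: int, block_size: int):
    # flatten the (num_tokens, topk) expert assignments into one id list
    flat = topk_ids.flatten().tolist()
    numel = len(flat)

    # per-expert token counts
    counts = [0] * num_experts
    for e in flat:
        counts[e] += 1

    # cumsum[i] = start offset of expert i's segment, each segment rounded up
    # to a multiple of block_size (mirrors the CEILDIV(...) * block_size loop)
    cumsum = [0] * (num_experts + 1)
    for e in range(num_experts):
        cumsum[e + 1] = cumsum[e] + -(-counts[e] // block_size) * block_size
    total_tokens_post_pad = cumsum[num_experts]

    # expert_ids[b] = expert that owns block b
    expert_ids = [0] * (total_tokens_post_pad // block_size)
    for e in range(num_experts):
        for off in range(cumsum[e], cumsum[e + 1], block_size):
            expert_ids[off // block_size] = e

    # scatter token indices into their expert's padded segment; padding slots
    # are filled with numel here (the real op keeps whatever the caller put
    # into sorted_token_ids beforehand)
    sorted_token_ids = [numel] * total_tokens_post_pad
    seen = [0] * num_experts
    for i, e in enumerate(flat):
        sorted_token_ids[cumsum[e] + seen[e]] = i
        seen[e] += 1

    return sorted_token_ids, expert_ids, total_tokens_post_pad

For example, moe_align_block_size_reference(torch.randint(0, 8, (4, 2)), num_experts=8, block_size=4) groups the 8 assignments by expert into 4-wide padded blocks.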
@@ -151,7 +151,6 @@ def moe_align_block_size_triton(
def test_moe_align_block_size_compare_implementations(
block_size, num_tokens, topk, num_experts
):
-# For DeepSeek V3, we have 256 experts
topk_ids = torch.stack(
[
......