Unverified Commit a5f5ab40 authored by Cheng Wan, committed by GitHub

update sgl-kernel for EP: kernel part (#8514)


Co-authored-by: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com>
Co-authored-by: Ke Bao <ispobaoke@gmail.com>
parent 59aab76f
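
This change drops the (num_experts + 1) * num_experts token_cnts_buffer scratch tensor from the moe_align_block_size interface, leaving cumsum_buffer as the only workspace, and shifts expert ids by one inside the CUDA kernels. A minimal sketch of a call against the new signature, mirroring the benchmark file in the diff below; the import path and the expert_ids sizing are assumptions for illustration, not taken from this diff:

import torch
from sgl_kernel import moe_align_block_size  # assumed import path

num_tokens, topk, num_experts, block_size = 16, 8, 256, 128
topk_ids = torch.randint(
    0, num_experts, (num_tokens, topk), dtype=torch.int32, device="cuda"
)
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty(
    (max_num_tokens_padded,), dtype=torch.int32, device="cuda"
)
# Hypothetical sizing: one expert id per output block.
expert_ids = torch.empty(
    ((max_num_tokens_padded + block_size - 1) // block_size,),
    dtype=torch.int32,
    device="cuda",
)
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")
# Only the per-expert prefix-sum scratch remains; token_cnts_buffer is gone.
cumsum_buffer = torch.zeros(num_experts + 1, dtype=torch.int32, device="cuda")

moe_align_block_size(
    topk_ids,
    num_experts,
    block_size,
    sorted_ids,
    expert_ids,
    num_tokens_post_pad,
    cumsum_buffer,
    False,  # pad_sorted_token_ids
)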
@@ -164,9 +164,6 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
     num_tokens_post_pad_cuda = torch.empty(
         (1), dtype=torch.int32, device=topk_ids.device
     )
-    token_cnts_buffer = torch.zeros(
-        (num_experts + 1) * num_experts, dtype=torch.int32, device=topk_ids.device
-    )
     cumsum_buffer = torch.zeros(
         num_experts + 1, dtype=torch.int32, device=topk_ids.device
     )
@@ -189,7 +186,6 @@ def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
         sorted_ids_cuda,
         expert_ids_cuda,
         num_tokens_post_pad_cuda,
-        token_cnts_buffer,
         cumsum_buffer,
     )
     moe_align_block_size_triton(
@@ -273,11 +269,6 @@ def sgl_moe_align_block_size_with_empty(
     if not pad_sorted_token_ids:
         sorted_ids.fill_(topk_ids.numel())

-    token_cnts_buffer = torch.empty(
-        (num_experts + 1) * num_experts,
-        dtype=torch.int32,
-        device=topk_ids.device,
-    )
     cumsum_buffer = torch.empty(
         num_experts + 1, dtype=torch.int32, device=topk_ids.device
     )
@@ -289,7 +280,6 @@ def sgl_moe_align_block_size_with_empty(
         sorted_ids.clone(),
         expert_ids.clone(),
         num_tokens_post_pad.clone(),
-        token_cnts_buffer,
         cumsum_buffer,
         pad_sorted_token_ids,
     )
...
@@ -165,7 +165,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
    */
   m.def(
       "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
-      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer, bool "
+      "experts_ids, Tensor! num_tokens_post_pad, Tensor! cumsum_buffer, bool "
       "pad_sorted_token_ids) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
...
@@ -36,7 +36,7 @@ __global__ void count_and_sort_expert_tokens_kernel(
   const size_t stride = blockDim.x * gridDim.x;

   for (size_t i = tid; i < numel; i += stride) {
-    int32_t expert_id = topk_ids[i];
+    int32_t expert_id = topk_ids[i] + 1;
     int32_t rank_post_pad = atomicAdd(&cumsum_buffer[expert_id], 1);
     sorted_token_ids[rank_post_pad] = i;
   }
@@ -82,7 +82,7 @@ __global__ void moe_align_block_size_kernel(
   __syncthreads();

   for (size_t i = tid; i < numel; i += stride) {
-    int expert_id = topk_ids[i];
+    int expert_id = topk_ids[i] + 1;
     atomicAdd(&shared_counts[expert_id], 1);
   }
@@ -215,7 +215,7 @@ __global__ void moe_align_block_size_kernel(
         right = mid;
       }
     }
-    expert_ids[i] = left - 1;
+    expert_ids[i] = left - 2;
   }

   if (pad_sorted_token_ids) {
@@ -251,7 +251,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
   }

   for (size_t i = tid; i < numel; i += stride) {
-    ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]];
+    ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i] + 1];
   }

   __syncthreads();
@@ -277,7 +277,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
   if (threadIdx.x < num_experts) {
     for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
-      expert_ids[i / block_size] = threadIdx.x;
+      expert_ids[i / block_size] = threadIdx.x - 1;
     }
   }
@@ -294,7 +294,7 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
   __syncthreads();

   for (size_t i = tid; i < numel; i += stride) {
-    int32_t expert_id = topk_ids[i];
+    int32_t expert_id = topk_ids[i] + 1;
     int32_t rank_post_pad = tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id];
     sorted_token_ids[rank_post_pad] = i;
     ++tokens_cnts[threadIdx.x * num_experts + expert_id];
@@ -308,7 +308,6 @@ void moe_align_block_size(
     torch::Tensor sorted_token_ids,
     torch::Tensor experts_ids,
     torch::Tensor num_tokens_post_pad,
-    torch::Tensor token_cnts_buffer,
     torch::Tensor cumsum_buffer,
     bool pad_sorted_token_ids) {
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
...
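
The recurring topk_ids[i] + 1 in the kernels above shifts every expert into the next counter slot, freeing slot 0, and the matching left - 2 and threadIdx.x - 1 adjustments shift the block-to-expert mapping back to the caller's numbering. Read together with the EP theme of this PR, the apparent intent (an assumption, not stated in the diff) is that topk_ids entries of -1, e.g. tokens routed to experts owned by another expert-parallel rank, are counted in the sentinel slot instead of indexing out of bounds. A rough, non-authoritative Python rendering of the shifted counting:

import torch

def shifted_expert_counts(topk_ids, num_experts, block_size):
    # Slot e + 1 counts expert e; slot 0 absorbs topk_ids == -1
    # (tokens this rank does not own, under the EP reading above).
    counts = torch.bincount(
        (topk_ids.flatten() + 1).to(torch.int64), minlength=num_experts + 1
    )
    # Pad each slot up to a multiple of block_size, then prefix-sum,
    # mirroring what the fused kernel builds in cumsum_buffer.
    padded = (counts + block_size - 1) // block_size * block_size
    cumsum = torch.zeros(num_experts + 2, dtype=torch.int64)
    cumsum[1:] = torch.cumsum(padded, dim=0)
    return counts, cumsum

topk_ids = torch.tensor([[0, 2], [-1, 2], [1, -1]], dtype=torch.int32)
counts, cumsum = shifted_expert_counts(topk_ids, num_experts=3, block_size=4)
print(counts.tolist())  # [2, 1, 1, 2]: two -1 sentinels, one token each for experts 0 and 1, two for expert 2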
@@ -92,7 +92,7 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
    */
   m.def(
       "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
-      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer, bool "
+      "experts_ids, Tensor! num_tokens_post_pad, Tensor! cumsum_buffer, bool "
       "pad_sorted_token_ids) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
...
@@ -230,7 +230,6 @@ void moe_align_block_size(
     torch::Tensor sorted_token_ids,
     torch::Tensor experts_ids,
     torch::Tensor num_tokens_post_pad,
-    torch::Tensor token_cnts_buffer,
     torch::Tensor cumsum_buffer,
     bool pad_sorted_token_ids);
...
@@ -10,7 +10,6 @@ def moe_align_block_size(
     sorted_token_ids,
     experts_ids,
     num_tokens_post_pad,
-    token_cnts_buffer,
     cumsum_buffer,
     pad_sorted_token_ids=False,
 ):
@@ -21,7 +20,6 @@ def moe_align_block_size(
         sorted_token_ids,
         experts_ids,
         num_tokens_post_pad,
-        token_cnts_buffer,
         cumsum_buffer,
         pad_sorted_token_ids,
     )
...
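
The wrapper above is a thin pass-through, so calling the registered op directly should be equivalent; a sketch, assuming the sgl_kernel library name from the m.def fragments earlier in this diff:

torch.ops.sgl_kernel.moe_align_block_size(
    topk_ids,
    num_experts,
    block_size,
    sorted_token_ids,
    experts_ids,
    num_tokens_post_pad,
    cumsum_buffer,
    False,  # pad_sorted_token_ids
)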
@@ -157,7 +157,7 @@ def test_moe_align_block_size_compare_implementations(
         :, :topk
     ]

-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
+    max_num_tokens_padded = topk_ids.numel() + (num_experts + 1) * (block_size - 1)
     sorted_ids_cuda = torch.empty(
         (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
@@ -171,13 +171,8 @@ def test_moe_align_block_size_compare_implementations(
     num_tokens_post_pad_cuda = torch.empty(
         (1), dtype=torch.int32, device=topk_ids.device
     )
-    token_cnts_buffer = torch.empty(
-        (num_experts + 1) * num_experts,
-        dtype=torch.int32,
-        device=topk_ids.device,
-    )
     cumsum_buffer = torch.empty(
-        num_experts + 1, dtype=torch.int32, device=topk_ids.device
+        num_experts + 2, dtype=torch.int32, device=topk_ids.device
     )

     sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
@@ -187,19 +182,18 @@ def test_moe_align_block_size_compare_implementations(
     moe_align_block_size(
         topk_ids,
-        num_experts,
+        num_experts + 1,
         block_size,
         sorted_ids_cuda,
         expert_ids_cuda,
         num_tokens_post_pad_cuda,
-        token_cnts_buffer,
         cumsum_buffer,
         pad_sorted_token_ids,
     )
     moe_align_block_size_triton(
         topk_ids,
-        num_experts,
+        num_experts + 1,
         block_size,
         sorted_ids_triton,
         expert_ids_triton,
...
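
The test now budgets for one extra slot throughout: both implementations receive num_experts + 1, the padding bound grows to (num_experts + 1) * (block_size - 1), and cumsum_buffer gets num_experts + 2 entries (a leading zero plus one running total per slot). A worked instance of that arithmetic with small, made-up numbers:

num_tokens, topk, num_experts, block_size = 4, 2, 4, 8
numel = num_tokens * topk                                             # 8 (token, expert) pairs
max_num_tokens_padded = numel + (num_experts + 1) * (block_size - 1)  # 8 + 5 * 7 = 43
cumsum_entries = (num_experts + 1) + 1                                # 6 int32 scratch entries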