Unverified Commit 8b5f83ed authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

reduce torch.zeros overhead in moe align block size kernel (#6369)

parent 2a413829
...@@ -30,6 +30,7 @@ from sglang.srt.utils import ( ...@@ -30,6 +30,7 @@ from sglang.srt.utils import (
is_cuda, is_cuda,
is_hip, is_hip,
log_info_on_rank0, log_info_on_rank0,
next_power_of_2,
) )
_is_hip = is_hip() _is_hip = is_hip()
...@@ -650,6 +651,61 @@ def moe_align_block_size_triton( ...@@ -650,6 +651,61 @@ def moe_align_block_size_triton(
) )
@triton.jit
def init_sorted_ids_and_cumsum_buffer_kernel(
    sorted_ids_ptr,
    cumsum_buffer_ptr,
    max_num_tokens_padded,
    topk_ids_numel,
    num_experts: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    ALIGNED_NUM_EXPERTS_P1: tl.constexpr,
):
    """Initialize both MoE-align scratch buffers in a single kernel launch.

    Replaces two host-side initializations (a `fill_` on `sorted_ids` and a
    `zero_` on `cumsum_buffer`) with one fused launch to reduce overhead:
      - programs [0, cdiv(max_num_tokens_padded, BLOCK_SIZE)) fill
        `sorted_ids` with the sentinel value `topk_ids_numel`;
      - the single program right after them zeroes `cumsum_buffer`
        (num_experts + 1 int32 entries).

    ALIGNED_NUM_EXPERTS_P1 is num_experts + 1 rounded up to a power of two,
    as tl.arange requires a power-of-two extent; the mask trims the excess.
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Number of programs dedicated to filling sorted_ids.
    sorted_ids_blocks = tl.cdiv(max_num_tokens_padded, BLOCK_SIZE)
    if pid < sorted_ids_blocks:
        # Fill this program's BLOCK_SIZE slice of sorted_ids with the
        # out-of-range sentinel topk_ids_numel (marks "padding" slots).
        mask = offsets < max_num_tokens_padded
        tl.store(
            sorted_ids_ptr + offsets,
            tl.full((BLOCK_SIZE,), topk_ids_numel, dtype=tl.int32),
            mask=mask,
        )
    elif pid == sorted_ids_blocks:
        # The one extra program zeroes the (num_experts + 1)-entry cumsum buffer.
        offset_e = tl.arange(0, ALIGNED_NUM_EXPERTS_P1)
        mask_e = offset_e < num_experts + 1
        tl.store(
            cumsum_buffer_ptr + offset_e,
            tl.zeros((ALIGNED_NUM_EXPERTS_P1,), dtype=tl.int32),
            mask=mask_e,
        )
def init_sorted_ids_and_cumsum_buffer(
    max_num_tokens_padded: int, topk_ids_numel: int, num_experts: int, device="cuda"
):
    """Allocate the MoE-align scratch buffers and initialize them on device.

    Returns a tuple ``(sorted_ids, cumsum_buffer)``:
      - ``sorted_ids``: int32 tensor of shape (max_num_tokens_padded,),
        filled with the sentinel value ``topk_ids_numel``;
      - ``cumsum_buffer``: int32 tensor of shape (num_experts + 1,), zeroed.

    Both are initialized by one fused Triton kernel launch instead of
    separate `fill_`/`zero_` host calls, cutting per-call overhead.
    """
    elems_per_block = 1024
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=device
    )
    cumsum_buffer = torch.empty(
        (num_experts + 1,), dtype=torch.int32, device=device
    )
    # One program per 1024-element slice of sorted_ids, plus a single
    # trailing program that zeroes cumsum_buffer.
    fill_blocks = triton.cdiv(max_num_tokens_padded, elems_per_block)
    init_sorted_ids_and_cumsum_buffer_kernel[(fill_blocks + 1,)](
        sorted_ids,
        cumsum_buffer,
        max_num_tokens_padded,
        topk_ids_numel,
        num_experts,
        elems_per_block,
        next_power_of_2(num_experts + 1),
    )
    return sorted_ids, cumsum_buffer
def moe_align_block_size( def moe_align_block_size(
topk_ids: torch.Tensor, block_size: int, num_experts: int topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
...@@ -691,10 +747,9 @@ def moe_align_block_size( ...@@ -691,10 +747,9 @@ def moe_align_block_size(
by block_size for proper block matrix operations. by block_size for proper block matrix operations.
""" """
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty( sorted_ids, cumsum_buffer = init_sorted_ids_and_cumsum_buffer(
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device max_num_tokens_padded, topk_ids.numel(), num_experts, topk_ids.device
) )
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
expert_ids = torch.empty( expert_ids = torch.empty(
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
...@@ -715,9 +770,6 @@ def moe_align_block_size( ...@@ -715,9 +770,6 @@ def moe_align_block_size(
dtype=torch.int32, dtype=torch.int32,
device=topk_ids.device, device=topk_ids.device,
) )
cumsum_buffer = torch.empty(
num_experts + 1, dtype=torch.int32, device=topk_ids.device
)
sgl_moe_align_block_size( sgl_moe_align_block_size(
topk_ids, topk_ids,
......
...@@ -197,8 +197,6 @@ void moe_align_block_size( ...@@ -197,8 +197,6 @@ void moe_align_block_size(
size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp);
size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t); size_t shared_mem_size = num_warps * experts_per_warp * sizeof(int32_t);
cumsum_buffer.zero_();
align_kernel<<<1, threads, shared_mem_size, stream>>>( align_kernel<<<1, threads, shared_mem_size, stream>>>(
topk_ids.data_ptr<scalar_t>(), topk_ids.data_ptr<scalar_t>(),
sorted_token_ids.data_ptr<int32_t>(), sorted_token_ids.data_ptr<int32_t>(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment