Unverified Commit 7151194b authored by Ke Bao, committed by GitHub

Remove cumsum_buffer initialization (#7439)

parent 2ed68d7a
@@ -750,9 +750,11 @@ def moe_align_block_size(
     by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids, cumsum_buffer = init_sorted_ids_and_cumsum_buffer(
-        max_num_tokens_padded, topk_ids.numel(), num_experts, topk_ids.device
-    )
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    )
+    sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
     expert_ids = torch.empty(
         (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
@@ -768,6 +770,9 @@ def moe_align_block_size(
             num_tokens_post_pad,
         )
     else:
+        cumsum_buffer = torch.empty(
+            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
+        )
         token_cnts_buffer = torch.empty(
             (num_experts + 1) * num_experts,
             dtype=torch.int32,
...
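
For readers skimming the diff, a minimal sketch of the buffer setup after this commit follows. The function name allocate_moe_align_buffers and the use_fused_kernel flag are hypothetical stand-ins for surrounding code the truncated diff does not show; the allocation pattern itself mirrors the added lines above: sorted_ids is created with torch.empty and filled with topk_ids.numel() as the padding sentinel, and cumsum_buffer is only allocated on the non-fused (else) path instead of via the removed init_sorted_ids_and_cumsum_buffer helper.

import torch
import triton


def allocate_moe_align_buffers(topk_ids, num_experts, block_size, use_fused_kernel):
    # Worst case: every expert's token segment is padded up to a multiple of block_size.
    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
    sorted_ids = torch.empty(
        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
    )
    # topk_ids.numel() serves as the padding sentinel, so a plain fill_ replaces
    # the removed init_sorted_ids_and_cumsum_buffer helper.
    sorted_ids.fill_(topk_ids.numel())
    max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
    expert_ids = torch.empty(
        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
    )

    cumsum_buffer = None
    if not use_fused_kernel:
        # Per the diff, the cumsum buffer is now allocated inside the else branch
        # rather than up front, so the other path skips this allocation entirely.
        cumsum_buffer = torch.empty(
            (num_experts + 1,), dtype=torch.int32, device=topk_ids.device
        )
    return sorted_ids, expert_ids, cumsum_buffer, max_num_m_blocks

As far as the visible hunks show, the practical effect is that the path taken in the if branch no longer pays for a cumsum_buffer allocation it does not use.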