Unverified commit 826b82a2, authored by Woosuk Kwon and committed by GitHub
Browse files

[Misc] Fix expert_ids shape in MoE (#4517)

parent c9d852d6
...@@ -203,14 +203,15 @@ def moe_align_block_size( ...@@ -203,14 +203,15 @@ def moe_align_block_size(
- The padding ensures that the total number of tokens is now divisible - The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations. by block_size for proper block matrix operations.
""" """
sorted_ids = torch.empty( max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
(topk_ids.numel() + num_experts * (block_size - 1), ), sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
expert_ids = torch.empty((topk_ids.numel() + num_experts, ),
dtype=torch.int32, dtype=torch.int32,
device=topk_ids.device) device=topk_ids.device)
sorted_ids.fill_(topk_ids.numel()) sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1), num_tokens_post_pad = torch.empty((1),
dtype=torch.int32, dtype=torch.int32,
device=topk_ids.device) device=topk_ids.device)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment