Unverified Commit 64480df4 authored by yiakwy-xpu-ml-framework-team's avatar yiakwy-xpu-ml-framework-team Committed by GitHub
Browse files

[BUG] fix moe benchmark when bs*seq is small (#3382)

parent 4530136e
......@@ -157,7 +157,7 @@ def calculate_diff(batch_size, seq_len):
)
sorted_ids_cuda.fill_(topk_ids.numel())
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids_cuda = torch.empty(
expert_ids_cuda = torch.zeros(
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_pad_cuda = torch.empty(
......@@ -172,7 +172,7 @@ def calculate_diff(batch_size, seq_len):
sorted_ids_triton = torch.empty_like(sorted_ids_cuda)
sorted_ids_triton.fill_(topk_ids.numel())
expert_ids_triton = torch.empty_like(expert_ids_cuda)
expert_ids_triton = torch.zeros_like(expert_ids_cuda)
num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
# compare the performance of cuda and triton implementation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment