Unverified Commit 10bfce71 authored by yiakwy-xpu-ml-framework-team, committed by GitHub

fix moe align blocks benchmark (#3003)

parent 583697cd
@@ -7,6 +7,8 @@ import triton
 import triton.language as tl
 from sgl_kernel import moe_align_block_size

+USE_RANDOM_PERM = False
+

 def ceil_div(a, b):
     return (a + b - 1) // b
@@ -141,8 +143,13 @@ def moe_align_block_size_triton(
 def calculate_diff(batch_size, seq_len):
     num_experts = 256
     block_size = 128
-    topk_ids = torch.randint(
-        0, num_experts, (batch_size, seq_len), dtype=torch.int32, device="cuda"
-    )
+    topk = 8
+    topk_ids = torch.stack(
+        [
+            torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
+            for _ in range(batch_size * seq_len)
+        ]
+    )

     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
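For context on the bound computed above: after alignment, each expert's token segment is padded up to a multiple of `block_size`, so padding adds at most `block_size - 1` slots per expert. A minimal sketch of the arithmetic using this benchmark's sizes (the concrete `num_tokens` value is illustrative):

```python
# Worst-case padded length: every expert rounds up by at most block_size - 1.
num_experts, block_size, topk = 256, 128, 8
num_tokens = 4 * 1024  # illustrative, e.g. batch_size=4, seq_len=1024
numel = num_tokens * topk  # matches topk_ids.numel()
max_num_tokens_padded = numel + num_experts * (block_size - 1)
assert max_num_tokens_padded == 65280  # 32768 + 256 * 127
```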
@@ -169,7 +176,7 @@ def calculate_diff(batch_size, seq_len):
     expert_ids_triton = torch.empty_like(expert_ids_cuda)
     num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)

-    # run both implementations
+    # compare the performance of cuda and triton implementation
     moe_align_block_size(
         topk_ids,
         num_experts,
@@ -206,6 +213,15 @@ seq_length_range = [2**i for i in range(0, 16)]
 configs = list(itertools.product(batch_size_range, seq_length_range))

+
+def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
+    topk_ids = torch.zeros((num_tokens, topk), dtype=torch.int32, device="cuda")
+    for i in range(num_tokens):
+        topk_ids[i, :] = torch.randperm(num_experts, dtype=torch.int32, device="cuda")[
+            :topk
+        ]
+    return topk_ids
+
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size", "seq_len"],
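The new `get_topk_ids` helper draws a fresh `torch.randperm` per token, which guarantees the `topk` expert ids in each row are distinct (unlike `torch.randint`, which can repeat an expert within a row), but it issues one kernel launch per token. A vectorized sketch of the same sampling-without-replacement idea (`get_topk_ids_vectorized` is a hypothetical name, not part of this commit):

```python
import torch

def get_topk_ids_vectorized(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
    # Taking the topk smallest of i.i.d. uniform noise per row selects topk
    # distinct experts uniformly at random, in one batched kernel.
    noise = torch.rand(num_tokens, num_experts, device="cuda")
    return noise.topk(topk, dim=1, largest=False).indices.to(torch.int32)
```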
@@ -223,8 +239,16 @@ def benchmark(batch_size, seq_len, provider):
     num_experts = 256
     block_size = 128
     topk = 8

-    topk_ids = torch.randint(
-        0, num_experts, (batch_size * seq_len, topk), dtype=torch.int32, device="cuda"
-    )
+    if USE_RANDOM_PERM:
+        topk_ids = get_topk_ids(batch_size * seq_len, num_experts, topk)
+    else:
+        topk_ids = torch.randint(
+            0,
+            num_experts,
+            (batch_size * seq_len, topk),
+            dtype=torch.int32,
+            device="cuda",
+        )

     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
...
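For reference, a typical way to exercise the fixed benchmark, assuming the script keeps the usual `__main__` block used by these sgl-kernel benchmarks (the rest of the file is elided from this diff, so the exact entry point is an assumption):

```python
if __name__ == "__main__":
    # Assumed entry point, not shown in this diff: check CUDA/Triton agreement
    # on one shape, then run the sweep; flip USE_RANDOM_PERM = True to use the
    # permutation-based topk_ids instead of torch.randint.
    calculate_diff(batch_size=4, seq_len=1024)
    benchmark.run(print_data=True)
```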