Unverified Commit 7130a7ce authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

refine sgl_moe_align_block_size_benchmark (#4327)

parent 8f1f614e
...@@ -4,7 +4,8 @@ import itertools ...@@ -4,7 +4,8 @@ import itertools
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from sgl_kernel import moe_align_block_size from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
from vllm import _custom_ops as ops
USE_RANDOM_PERM = False USE_RANDOM_PERM = False
...@@ -139,15 +140,11 @@ def moe_align_block_size_triton( ...@@ -139,15 +140,11 @@ def moe_align_block_size_triton(
) )
def calculate_diff(batch_size, seq_len): def calculate_diff(num_tokens, num_experts=256, block_size=128, topk=8):
num_experts = 256
block_size = 128
topk = 8
topk_ids = torch.stack( topk_ids = torch.stack(
[ [
torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk]
for _ in range(batch_size * seq_len) for _ in range(num_tokens)
] ]
) )
...@@ -175,8 +172,13 @@ def calculate_diff(batch_size, seq_len): ...@@ -175,8 +172,13 @@ def calculate_diff(batch_size, seq_len):
expert_ids_triton = torch.zeros_like(expert_ids_cuda) expert_ids_triton = torch.zeros_like(expert_ids_cuda)
num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda) num_tokens_post_pad_triton = torch.empty_like(num_tokens_post_pad_cuda)
# compare the performance of cuda and triton implementation sorted_ids_vllm = torch.empty_like(sorted_ids_cuda)
moe_align_block_size( sorted_ids_vllm.fill_(topk_ids.numel())
expert_ids_vllm = torch.zeros_like(expert_ids_cuda)
num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_cuda)
# compare the performance of cuda, triton and vllm implementation
sgl_moe_align_block_size(
topk_ids, topk_ids,
num_experts, num_experts,
block_size, block_size,
...@@ -194,22 +196,43 @@ def calculate_diff(batch_size, seq_len): ...@@ -194,22 +196,43 @@ def calculate_diff(batch_size, seq_len):
expert_ids_triton, expert_ids_triton,
num_tokens_post_pad_triton, num_tokens_post_pad_triton,
) )
ops.moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids_vllm,
expert_ids_vllm,
num_tokens_post_pad_vllm,
)
if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose( if torch.allclose(expert_ids_cuda, expert_ids_triton) and torch.allclose(
num_tokens_post_pad_cuda, num_tokens_post_pad_triton num_tokens_post_pad_cuda, num_tokens_post_pad_triton
): ):
print("✅ CUDA and Triton implementations match") print("✅ SGL and Triton implementations match")
else: else:
print("❌ CUDA and Triton implementations do not match") print("❌ SGL and Triton implementations do not match")
print("CUDA expert_ids:", expert_ids_cuda) print("SGL expert_ids:", expert_ids_cuda)
print("Triton expert_ids:", expert_ids_triton) print("Triton expert_ids:", expert_ids_triton)
print("CUDA num_tokens_post_pad:", num_tokens_post_pad_cuda) print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton) print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
if torch.allclose(expert_ids_cuda, expert_ids_vllm) and torch.allclose(
num_tokens_post_pad_cuda, num_tokens_post_pad_vllm
):
print("✅ SGL and VLLM implementations match")
else:
print("❌ SGL and VLLM implementations do not match")
print("SGL expert_ids:", expert_ids_cuda)
print("VLLM expert_ids:", expert_ids_vllm)
print("SGL num_tokens_post_pad:", num_tokens_post_pad_cuda)
print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)
batch_size_range = [2**i for i in range(0, 8)] num_tokens_range = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]
seq_length_range = [2**i for i in range(0, 16)] num_experts_range = [32, 64, 128, 256]
configs = list(itertools.product(batch_size_range, seq_length_range)) topk_range = [2, 4, 8]
configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range))
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
...@@ -223,29 +246,27 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: ...@@ -223,29 +246,27 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["batch_size", "seq_len"], x_names=["num_tokens", "num_experts", "topk"],
x_vals=[list(_) for _ in configs], x_vals=configs,
line_arg="provider", line_arg="provider",
line_vals=["cuda", "triton"], line_vals=["sgl", "triton", "vllm"],
line_names=["CUDA", "Triton"], line_names=["SGL", "Triton", "VLLM"],
styles=[("blue", "-"), ("red", "-")], styles=[("blue", "-"), ("red", "-"), ("green", "-")],
ylabel="us", ylabel="us",
plot_name="moe-align-block-size-performance", plot_name="moe-align-block-size-performance",
args={}, args={},
) )
) )
def benchmark(batch_size, seq_len, provider): def benchmark(num_tokens, num_experts, topk, provider):
num_experts = 256
block_size = 128 block_size = 128
topk = 8
if USE_RANDOM_PERM: if USE_RANDOM_PERM:
topk_ids = get_topk_ids(batch_size * seq_len, num_experts, topk) topk_ids = get_topk_ids(num_tokens, num_experts, topk)
else: else:
topk_ids = torch.randint( topk_ids = torch.randint(
0, 0,
num_experts, num_experts,
(batch_size * seq_len, topk), (num_tokens, topk),
dtype=torch.int32, dtype=torch.int32,
device="cuda", device="cuda",
) )
...@@ -268,9 +289,9 @@ def benchmark(batch_size, seq_len, provider): ...@@ -268,9 +289,9 @@ def benchmark(batch_size, seq_len, provider):
) )
quantiles = [0.5, 0.2, 0.8] quantiles = [0.5, 0.2, 0.8]
if provider == "cuda": if provider == "sgl":
ms, min_ms, max_ms = triton.testing.do_bench( ms, min_ms, max_ms = triton.testing.do_bench(
lambda: moe_align_block_size( lambda: sgl_moe_align_block_size(
topk_ids, topk_ids,
num_experts, num_experts,
block_size, block_size,
...@@ -282,7 +303,7 @@ def benchmark(batch_size, seq_len, provider): ...@@ -282,7 +303,7 @@ def benchmark(batch_size, seq_len, provider):
), ),
quantiles=quantiles, quantiles=quantiles,
) )
else: elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench( ms, min_ms, max_ms = triton.testing.do_bench(
lambda: moe_align_block_size_triton( lambda: moe_align_block_size_triton(
topk_ids, topk_ids,
...@@ -294,6 +315,18 @@ def benchmark(batch_size, seq_len, provider): ...@@ -294,6 +315,18 @@ def benchmark(batch_size, seq_len, provider):
), ),
quantiles=quantiles, quantiles=quantiles,
) )
else: # vllm
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: ops.moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids.clone(),
expert_ids.clone(),
num_tokens_post_pad.clone(),
),
quantiles=quantiles,
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms return 1000 * ms, 1000 * max_ms, 1000 * min_ms
...@@ -306,8 +339,22 @@ if __name__ == "__main__": ...@@ -306,8 +339,22 @@ if __name__ == "__main__":
default="./configs/benchmark_ops/moe_align_blocks/", default="./configs/benchmark_ops/moe_align_blocks/",
help="Path to save moe align benchmark results", help="Path to save moe align benchmark results",
) )
parser.add_argument(
"--num_experts",
type=int,
default=256,
choices=[8, 64, 128, 256],
help="Number of experts for benchmark",
)
parser.add_argument(
"--topk",
type=int,
default=8,
choices=[2, 4, 8],
help="Top-k value for benchmark",
)
args = parser.parse_args() args = parser.parse_args()
calculate_diff(batch_size=4, seq_len=1024) calculate_diff(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
benchmark.run(print_data=True) benchmark.run(print_data=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment