Unverified Commit c4e314f9 authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

Restruct sgl-kernel benchmark (#10861)

parent 7a06ef98
......@@ -88,7 +88,7 @@ def benchmark(batch_size, seq_len, provider):
elif provider == "sglang":
fn = lambda: sglang_scaled_fp8_quant(x.clone())
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
......
......@@ -160,7 +160,7 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
elif provider == "sglang":
fn = lambda: sglang_per_token_quant_fp8(x.clone())
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
......
......@@ -117,17 +117,17 @@ def benchmark(batch_size, provider, N, K):
quantiles = [0.5, 0.2, 0.8]
if provider == "FP16":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: torch.matmul(a_fp16, b_fp16),
quantiles=quantiles,
)
if provider == "W8A8":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16),
quantiles=quantiles,
)
if provider == "Qserve_W4A8_Per_Channel":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: qserve_w4a8_per_chn_gemm(
a_qserve_chn,
b_qserve_chn,
......@@ -139,7 +139,7 @@ def benchmark(batch_size, provider, N, K):
quantiles=quantiles,
)
if provider == "Qserve_W4A8_Per_Group":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: qserve_w4a8_per_group_gemm(
a_qserve_group,
b_qserve_group,
......@@ -189,8 +189,6 @@ if __name__ == "__main__":
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_qserve_w4a8_gemm_res",
N=N,
K=K,
)
......
# Benchmarks SGLang RMSNorm kernels versus vLLM and FlashInfer across
# (batch_size, seq_len, hidden_size) and prints speed-up.
import argparse
import itertools
from typing import Optional, Tuple, Union
import re
from typing import List, Optional, Tuple, Union
import sgl_kernel
import torch
import torch.nn as nn
import triton
import triton.testing
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from torch import nn
from sgl_kernel.utils import is_arch_support_pdl
from vllm import _custom_ops as vllm_ops
def str2int_list(arg: str) -> List[int]:
if arg in ("", None):
return []
if re.fullmatch(r"\d+(,\d+)*", arg.strip()) is None:
raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
return [int(x) for x in arg.split(",")]
class HuggingFaceRMSNorm(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
super().__init__()
......@@ -108,6 +123,36 @@ def rmsnorm_vllm(
return output
def rmsnorm_sglang(
x: torch.Tensor,
weight: torch.Tensor,
residual: Optional[torch.Tensor] = None,
eps: float = 1e-6,
enable_pdl: Optional[bool] = None,
):
orig_shape = x.shape
x = x.view(-1, x.shape[-1])
if residual is not None:
residual = residual.view(-1, residual.shape[-1])
if enable_pdl is None:
enable_pdl = is_arch_support_pdl()
if residual is not None:
sgl_kernel.fused_add_rmsnorm(x, residual, weight, eps, enable_pdl=enable_pdl)
output = (x, residual)
else:
out = torch.empty_like(x)
sgl_kernel.rmsnorm(x, weight, eps, out=out, enable_pdl=enable_pdl)
output = out
if isinstance(output, tuple):
output = (output[0].view(orig_shape), output[1].view(orig_shape))
else:
output = output.view(orig_shape)
return output
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
dtype = torch.bfloat16
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
......@@ -123,108 +168,151 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
output_vllm = rmsnorm_vllm(
x.clone(), weight, residual.clone() if residual is not None else None
)
output_sglang = rmsnorm_sglang(
x.clone(), weight, residual.clone() if residual is not None else None
)
if use_residual:
output_naive = output_naive[0]
output_flashinfer = output_flashinfer[0]
output_vllm = output_vllm[0]
output_sglang = output_sglang[0]
print(f"Naive output={output_naive}")
print(f"FlashInfer output={output_flashinfer}")
print(f"VLLM output={output_vllm}")
print(f"SGLang output={output_sglang}")
if torch.allclose(
output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
if (
torch.allclose(output_naive, output_flashinfer, atol=1e-2, rtol=1e-2)
and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2)
and torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
):
print("✅ All implementations match")
else:
print("❌ Implementations differ")
batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)]
head_num_range = [32, 48]
configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
def get_benchmark(use_residual):
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["head_num", "batch_size", "seq_len"],
x_vals=[list(_) for _ in configs],
line_arg="provider",
line_vals=["huggingface", "flashinfer", "vllm"],
line_names=["HuggingFace", "FlashInfer", "vLLM"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us",
plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual",
args={},
)
default_batch_sizes = [2**i for i in range(0, 7, 2)] # 1, 4, 16, 64
default_seq_lens = [2**i for i in range(6, 11, 1)] # 64, 128, 256, 512, 1024
default_hidden_sizes = [32 * 128, 48 * 128] # 4096, 6144
def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]:
return list(itertools.product(bsizes, slens, hsizes))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=["batch_size", "seq_len", "hidden_size"],
x_vals=[],
line_arg="provider",
line_vals=["huggingface", "flashinfer", "vllm", "sglang"],
line_names=["HuggingFace", "FlashInfer", "vLLM", "SGL Kernel"],
styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("orange", "-")],
ylabel="µs (median) or × (speed-up)",
plot_name="rmsnorm-performance",
args={},
)
def benchmark(head_num, batch_size, seq_len, provider):
dtype = torch.bfloat16
hidden_size = head_num * 128 # assuming head_dim = 128
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None
quantiles = [0.5, 0.2, 0.8]
if provider == "huggingface":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_naive(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)
def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
device = torch.device("cuda")
dtype = torch.bfloat16
x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)
weight = torch.ones(hidden_size, dtype=dtype, device=device)
residual = torch.randn_like(x) if use_residual else None
# timing helper
def timed(fn):
for _ in range(5):
fn()
torch.cuda.synchronize()
ms, qmin, qmax = triton.testing.do_bench_cudagraph(
fn, quantiles=[0.5, 0.2, 0.8]
)
return 1000 * ms, 1000 * qmax, 1000 * qmin
if provider == "huggingface":
return timed(
lambda: rmsnorm_naive(
x.clone(),
weight,
residual.clone() if residual is not None else None,
)
elif provider == "flashinfer":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_flashinfer(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)
elif provider == "flashinfer":
return timed(
lambda: rmsnorm_flashinfer(
x.clone(),
weight,
residual.clone() if residual is not None else None,
)
else:
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: rmsnorm_vllm(
x.clone(),
weight,
residual.clone() if residual is not None else None,
),
quantiles=quantiles,
)
elif provider == "vllm":
return timed(
lambda: rmsnorm_vllm(
x.clone(),
weight,
residual.clone() if residual is not None else None,
)
)
elif provider == "sglang":
return timed(
lambda: rmsnorm_sglang(
x.clone(),
weight,
residual.clone() if residual is not None else None,
)
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
return benchmark
# provider == "speedup"
t_ref, _, _ = timed(
lambda: rmsnorm_vllm(
x.clone(),
weight,
residual.clone() if residual is not None else None,
)
)
t_sgl, _, _ = timed(
lambda: rmsnorm_sglang(
x.clone(),
weight,
residual.clone() if residual is not None else None,
)
)
spd = t_ref / t_sgl
return (spd, spd, spd)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument(
p = argparse.ArgumentParser("RMSNorm kernel benchmark")
p.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes)
p.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens)
p.add_argument("--hidden_sizes", type=str2int_list, default=default_hidden_sizes)
p.add_argument(
"--use_residual", action="store_true", help="Whether to use residual connection"
)
parser.add_argument(
"--save_path",
type=str,
default="./configs/benchmark_ops/rmsnorm/",
help="Path to save rmsnorm benchmark results",
)
args = parser.parse_args()
# Run correctness test
calculate_diff(
batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual
)
p.add_argument("--verify_only", action="store_true")
args = p.parse_args()
# coerce lists
if isinstance(args.batch_sizes, str):
args.batch_sizes = str2int_list(args.batch_sizes)
if isinstance(args.seq_lens, str):
args.seq_lens = str2int_list(args.seq_lens)
if isinstance(args.hidden_sizes, str):
args.hidden_sizes = str2int_list(args.hidden_sizes)
# patch perf_report grid
benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.hidden_sizes)
if hasattr(benchmark, "benchmarks"):
benchmark.benchmarks.x_vals = benchmark_grid
else:
benchmark.benchmark.x_vals = benchmark_grid
# Get the benchmark function with proper use_residual setting
benchmark = get_benchmark(args.use_residual)
# Run performance benchmark
benchmark.run(print_data=True, save_path=args.save_path)
if args.verify_only:
ok = calculate_diff(4, 128, args.hidden_sizes[0], args.use_residual)
print("✅ sanity pass" if ok else "❌ mismatch")
else:
benchmark.run(print_data=True, use_residual=args.use_residual)
......@@ -114,7 +114,9 @@ def benchmark_sampling(batch_size, vocab_size, p, provider):
filter_apply_order="joint",
)
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
fn, quantiles=[0.5, 0.2, 0.8]
)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
......
......@@ -3,6 +3,7 @@
import pytest
import sgl_kernel
import torch
from sgl_kernel.utils import is_arch_support_pdl
def llama_rms_norm(x, w, eps=1e-6):
......@@ -58,11 +59,12 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):
w = torch.randn(hidden_size).to(0).to(dtype)
y_ref = llama_rms_norm(x, w)
enable_pdl = is_arch_support_pdl()
if specify_out:
y = torch.empty_like(x)
sgl_kernel.rmsnorm(x, w, out=y)
sgl_kernel.rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
else:
y = sgl_kernel.rmsnorm(x, w)
y = sgl_kernel.rmsnorm(x, w, enable_pdl=enable_pdl)
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
......@@ -83,7 +85,10 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
x_fused = x.clone()
residual_fused = residual.clone()
sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
enable_pdl = is_arch_support_pdl()
sgl_kernel.fused_add_rmsnorm(
x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
)
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
......@@ -98,11 +103,12 @@ def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
w = torch.randn(hidden_size).to(0).to(dtype)
y_ref = gemma_rms_norm(x, w)
enable_pdl = is_arch_support_pdl()
if specify_out:
y = torch.empty_like(x)
sgl_kernel.gemma_rmsnorm(x, w, out=y)
sgl_kernel.gemma_rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
else:
y = sgl_kernel.gemma_rmsnorm(x, w)
y = sgl_kernel.gemma_rmsnorm(x, w, enable_pdl=enable_pdl)
torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
......@@ -123,7 +129,10 @@ def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
x_fused = x.clone()
residual_fused = residual.clone()
sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
enable_pdl = is_arch_support_pdl()
sgl_kernel.gemma_fused_add_rmsnorm(
x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
)
torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment