Unverified commit c4e314f9 authored by Xiaoyu Zhang, committed by GitHub

Restructure sgl-kernel benchmark (#10861)

parent 7a06ef98
@@ -88,7 +88,7 @@ def benchmark(batch_size, seq_len, provider):
     elif provider == "sglang":
         fn = lambda: sglang_scaled_fp8_quant(x.clone())
-    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
+    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
...
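The recurring change in this commit is swapping `triton.testing.do_bench` for `triton.testing.do_bench_cudagraph`, which captures the benchmarked callable in a CUDA graph and replays it, so per-launch CPU overhead is largely excluded from the timing. A minimal sketch of the call pattern, using an illustrative workload rather than the actual quant kernel:

    import torch
    import triton
    import triton.testing

    # Any CUDA callable can be timed this way; clone() inside the lambda mirrors
    # how the benchmarks in this commit rebuild their inputs on every call.
    x = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    fn = lambda: torch.relu(x.clone())

    # quantiles=[0.5, 0.2, 0.8] -> (median, 20th percentile, 80th percentile) in ms.
    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=[0.5, 0.2, 0.8])
    print(f"median latency: {ms * 1000:.1f} us")

The `(median, low, high)` unpacking and the `* 1000` conversion to microseconds follow the same convention the benchmark files above use.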
@@ -160,7 +160,7 @@ def benchmark_quantization(batch_size, seq_len, hidden_dim, provider):
     elif provider == "sglang":
         fn = lambda: sglang_per_token_quant_fp8(x.clone())
-    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
+    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
...
@@ -117,17 +117,17 @@ def benchmark(batch_size, provider, N, K):
     quantiles = [0.5, 0.2, 0.8]
     if provider == "FP16":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: torch.matmul(a_fp16, b_fp16),
             quantiles=quantiles,
         )
     if provider == "W8A8":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16),
             quantiles=quantiles,
         )
     if provider == "Qserve_W4A8_Per_Channel":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: qserve_w4a8_per_chn_gemm(
                 a_qserve_chn,
                 b_qserve_chn,
@@ -139,7 +139,7 @@ def benchmark(batch_size, provider, N, K):
             quantiles=quantiles,
         )
     if provider == "Qserve_W4A8_Per_Group":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: qserve_w4a8_per_group_gemm(
                 a_qserve_group,
                 b_qserve_group,
@@ -189,8 +189,6 @@ if __name__ == "__main__":
         print(f"{model_name} N={N} K={K}: ")
         benchmark.run(
             print_data=True,
-            show_plots=True,
-            save_path="bench_qserve_w4a8_gemm_res",
             N=N,
             K=K,
         )
...
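For context on the `benchmark.run(print_data=True, N=N, K=K)` call above: keyword arguments that are not benchmark axes are forwarded by `run()` to the decorated function, and with `show_plots`/`save_path` dropped the run only prints the timing table instead of also writing plots and CSVs. A self-contained sketch under those assumptions, with a toy matmul standing in for the real GEMM kernels:

    import torch
    import triton
    import triton.testing

    @triton.testing.perf_report(
        triton.testing.Benchmark(
            x_names=["batch_size"],
            x_vals=[1, 16, 64],
            line_arg="provider",
            line_vals=["fp16"],
            line_names=["FP16"],
            styles=[("blue", "-")],
            ylabel="ms",
            plot_name="toy-gemm",
            args={},
        )
    )
    def toy_benchmark(batch_size, provider, N, K):
        a = torch.randn(batch_size, K, device="cuda", dtype=torch.float16)
        b = torch.randn(K, N, device="cuda", dtype=torch.float16)
        # Same ordering convention as the sgl-kernel benchmarks: (median, high, low).
        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: torch.matmul(a, b), quantiles=[0.5, 0.2, 0.8]
        )
        return ms, max_ms, min_ms

    # N and K are not benchmark axes, so they are passed through run() to the function.
    toy_benchmark.run(print_data=True, N=4096, K=4096)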
+# Benchmarks SGLang RMSNorm kernels versus vLLM and FlashInfer across
+# (batch_size, seq_len, hidden_size) and prints speed-up.
+import argparse
 import itertools
-from typing import Optional, Tuple, Union
+import re
+from typing import List, Optional, Tuple, Union
+
+import sgl_kernel
 import torch
+import torch.nn as nn
 import triton
+import triton.testing
 from flashinfer.norm import fused_add_rmsnorm, rmsnorm
-from torch import nn
+from sgl_kernel.utils import is_arch_support_pdl
 from vllm import _custom_ops as vllm_ops
+
+
+def str2int_list(arg: str) -> List[int]:
+    if arg in ("", None):
+        return []
+    if re.fullmatch(r"\d+(,\d+)*", arg.strip()) is None:
+        raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
+    return [int(x) for x in arg.split(",")]
 
 
 class HuggingFaceRMSNorm(nn.Module):
     def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
         super().__init__()
@@ -108,6 +123,36 @@ def rmsnorm_vllm(
     return output
 
 
+def rmsnorm_sglang(
+    x: torch.Tensor,
+    weight: torch.Tensor,
+    residual: Optional[torch.Tensor] = None,
+    eps: float = 1e-6,
+    enable_pdl: Optional[bool] = None,
+):
+    orig_shape = x.shape
+    x = x.view(-1, x.shape[-1])
+    if residual is not None:
+        residual = residual.view(-1, residual.shape[-1])
+    if enable_pdl is None:
+        enable_pdl = is_arch_support_pdl()
+
+    if residual is not None:
+        sgl_kernel.fused_add_rmsnorm(x, residual, weight, eps, enable_pdl=enable_pdl)
+        output = (x, residual)
+    else:
+        out = torch.empty_like(x)
+        sgl_kernel.rmsnorm(x, weight, eps, out=out, enable_pdl=enable_pdl)
+        output = out
+
+    if isinstance(output, tuple):
+        output = (output[0].view(orig_shape), output[1].view(orig_shape))
+    else:
+        output = output.view(orig_shape)
+    return output
+
+
 def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
     dtype = torch.bfloat16
     x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
@@ -123,108 +168,151 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
     output_vllm = rmsnorm_vllm(
         x.clone(), weight, residual.clone() if residual is not None else None
     )
+    output_sglang = rmsnorm_sglang(
+        x.clone(), weight, residual.clone() if residual is not None else None
+    )
 
     if use_residual:
         output_naive = output_naive[0]
         output_flashinfer = output_flashinfer[0]
         output_vllm = output_vllm[0]
+        output_sglang = output_sglang[0]
 
     print(f"Naive output={output_naive}")
     print(f"FlashInfer output={output_flashinfer}")
     print(f"VLLM output={output_vllm}")
+    print(f"SGLang output={output_sglang}")
 
-    if torch.allclose(
-        output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
-    ) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
+    if (
+        torch.allclose(output_naive, output_flashinfer, atol=1e-2, rtol=1e-2)
+        and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2)
+        and torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
+    ):
         print("✅ All implementations match")
     else:
         print("❌ Implementations differ")
 
 
-batch_size_range = [2**i for i in range(0, 7, 2)]
-seq_length_range = [2**i for i in range(6, 11, 1)]
-head_num_range = [32, 48]
-configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
+default_batch_sizes = [2**i for i in range(0, 7, 2)]  # 1, 4, 16, 64
+default_seq_lens = [2**i for i in range(6, 11, 1)]  # 64, 128, 256, 512, 1024
+default_hidden_sizes = [32 * 128, 48 * 128]  # 4096, 6144
 
 
-def get_benchmark(use_residual):
-    @triton.testing.perf_report(
+def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]:
+    return list(itertools.product(bsizes, slens, hsizes))
+
+
+@triton.testing.perf_report(
     triton.testing.Benchmark(
-        x_names=["head_num", "batch_size", "seq_len"],
-        x_vals=[list(_) for _ in configs],
+        x_names=["batch_size", "seq_len", "hidden_size"],
+        x_vals=[],
         line_arg="provider",
-        line_vals=["huggingface", "flashinfer", "vllm"],
-        line_names=["HuggingFace", "FlashInfer", "vLLM"],
-        styles=[("blue", "-"), ("green", "-"), ("red", "-")],
-        ylabel="us",
-        plot_name=f"rmsnorm-performance-{'with' if use_residual else 'without'}-residual",
+        line_vals=["huggingface", "flashinfer", "vllm", "sglang"],
+        line_names=["HuggingFace", "FlashInfer", "vLLM", "SGL Kernel"],
+        styles=[("blue", "-"), ("green", "-"), ("red", "-"), ("orange", "-")],
+        ylabel="µs (median) or × (speed-up)",
+        plot_name="rmsnorm-performance",
         args={},
     )
 )
-def benchmark(head_num, batch_size, seq_len, provider):
+def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
+    device = torch.device("cuda")
     dtype = torch.bfloat16
-    hidden_size = head_num * 128  # assuming head_dim = 128
-    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
-    weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
+    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)
+    weight = torch.ones(hidden_size, dtype=dtype, device=device)
     residual = torch.randn_like(x) if use_residual else None
-    quantiles = [0.5, 0.2, 0.8]
+
+    # timing helper
+    def timed(fn):
+        for _ in range(5):
+            fn()
+        torch.cuda.synchronize()
+        ms, qmin, qmax = triton.testing.do_bench_cudagraph(
+            fn, quantiles=[0.5, 0.2, 0.8]
+        )
+        return 1000 * ms, 1000 * qmax, 1000 * qmin
 
     if provider == "huggingface":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        return timed(
             lambda: rmsnorm_naive(
                 x.clone(),
                 weight,
                 residual.clone() if residual is not None else None,
-            ),
-            quantiles=quantiles,
+            )
         )
     elif provider == "flashinfer":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        return timed(
             lambda: rmsnorm_flashinfer(
                 x.clone(),
                 weight,
                 residual.clone() if residual is not None else None,
-            ),
-            quantiles=quantiles,
+            )
         )
-    else:
-        ms, min_ms, max_ms = triton.testing.do_bench(
+    elif provider == "vllm":
+        return timed(
             lambda: rmsnorm_vllm(
                 x.clone(),
                 weight,
                 residual.clone() if residual is not None else None,
-            ),
-            quantiles=quantiles,
+            )
         )
+    elif provider == "sglang":
+        return timed(
+            lambda: rmsnorm_sglang(
+                x.clone(),
+                weight,
+                residual.clone() if residual is not None else None,
+            )
+        )
 
-    return 1000 * ms, 1000 * max_ms, 1000 * min_ms
-
-    return benchmark
+    # provider == "speedup"
+    t_ref, _, _ = timed(
+        lambda: rmsnorm_vllm(
+            x.clone(),
+            weight,
+            residual.clone() if residual is not None else None,
+        )
+    )
+    t_sgl, _, _ = timed(
+        lambda: rmsnorm_sglang(
+            x.clone(),
+            weight,
+            residual.clone() if residual is not None else None,
+        )
+    )
+    spd = t_ref / t_sgl
+    return (spd, spd, spd)
 
 
 if __name__ == "__main__":
-    import argparse
-
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
+    p = argparse.ArgumentParser("RMSNorm kernel benchmark")
+    p.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes)
+    p.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens)
+    p.add_argument("--hidden_sizes", type=str2int_list, default=default_hidden_sizes)
+    p.add_argument(
         "--use_residual", action="store_true", help="Whether to use residual connection"
     )
-    parser.add_argument(
-        "--save_path",
-        type=str,
-        default="./configs/benchmark_ops/rmsnorm/",
-        help="Path to save rmsnorm benchmark results",
-    )
-    args = parser.parse_args()
+    p.add_argument("--verify_only", action="store_true")
+    args = p.parse_args()
 
-    # Run correctness test
-    calculate_diff(
-        batch_size=4, seq_len=128, hidden_size=4096, use_residual=args.use_residual
-    )
+    # coerce lists
+    if isinstance(args.batch_sizes, str):
+        args.batch_sizes = str2int_list(args.batch_sizes)
+    if isinstance(args.seq_lens, str):
+        args.seq_lens = str2int_list(args.seq_lens)
+    if isinstance(args.hidden_sizes, str):
+        args.hidden_sizes = str2int_list(args.hidden_sizes)
 
-    # Get the benchmark function with proper use_residual setting
-    benchmark = get_benchmark(args.use_residual)
+    # patch perf_report grid
+    benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.hidden_sizes)
+    if hasattr(benchmark, "benchmarks"):
+        benchmark.benchmarks.x_vals = benchmark_grid
+    else:
+        benchmark.benchmark.x_vals = benchmark_grid
 
-    # Run performance benchmark
-    benchmark.run(print_data=True, save_path=args.save_path)
+    if args.verify_only:
+        ok = calculate_diff(4, 128, args.hidden_sizes[0], args.use_residual)
+        print("✅ sanity pass" if ok else "❌ mismatch")
+    else:
+        benchmark.run(print_data=True, use_residual=args.use_residual)
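The restructured script replaces the fixed `(head_num, batch_size, seq_len)` grid with CLI-selectable sweeps parsed by `str2int_list`. A short sketch of how the comma-separated list arguments behave; the helper is copied from the diff above, while the `parse_args` input is only an example:

    import argparse
    import re
    from typing import List

    def str2int_list(arg: str) -> List[int]:
        # "1,4,16" -> [1, 4, 16]; empty string -> []
        if arg in ("", None):
            return []
        if re.fullmatch(r"\d+(,\d+)*", arg.strip()) is None:
            raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
        return [int(x) for x in arg.split(",")]

    p = argparse.ArgumentParser("RMSNorm kernel benchmark")
    p.add_argument("--batch_sizes", type=str2int_list, default=[1, 4, 16, 64])
    p.add_argument("--seq_lens", type=str2int_list, default=[64, 128, 256, 512, 1024])
    p.add_argument("--hidden_sizes", type=str2int_list, default=[4096, 6144])
    p.add_argument("--use_residual", action="store_true")

    # Example: sweep only batch sizes 1 and 8 at hidden size 4096.
    args = p.parse_args(["--batch_sizes", "1,8", "--hidden_sizes", "4096"])
    print(args.batch_sizes, args.seq_lens, args.hidden_sizes)
    # -> [1, 8] [64, 128, 256, 512, 1024] [4096]

Because `argparse` already applies `type=str2int_list` to user-supplied values and the defaults are lists, the later `isinstance(..., str)` coercion in the script acts only as a safety net.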
@@ -114,7 +114,9 @@ def benchmark_sampling(batch_size, vocab_size, p, provider):
         filter_apply_order="joint",
     )
-    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
+    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
+        fn, quantiles=[0.5, 0.2, 0.8]
+    )
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
...
@@ -3,6 +3,7 @@
 import pytest
 import sgl_kernel
 import torch
+from sgl_kernel.utils import is_arch_support_pdl
 
 
 def llama_rms_norm(x, w, eps=1e-6):
@@ -58,11 +59,12 @@ def test_norm(batch_size, hidden_size, dtype, specify_out):
     w = torch.randn(hidden_size).to(0).to(dtype)
     y_ref = llama_rms_norm(x, w)
+    enable_pdl = is_arch_support_pdl()
     if specify_out:
         y = torch.empty_like(x)
-        sgl_kernel.rmsnorm(x, w, out=y)
+        sgl_kernel.rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
     else:
-        y = sgl_kernel.rmsnorm(x, w)
+        y = sgl_kernel.rmsnorm(x, w, enable_pdl=enable_pdl)
     torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
@@ -83,7 +85,10 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype):
     x_fused = x.clone()
     residual_fused = residual.clone()
-    sgl_kernel.fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
+    enable_pdl = is_arch_support_pdl()
+    sgl_kernel.fused_add_rmsnorm(
+        x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
+    )
     torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
     torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
@@ -98,11 +103,12 @@ def test_gemma_norm(batch_size, hidden_size, dtype, specify_out):
     w = torch.randn(hidden_size).to(0).to(dtype)
     y_ref = gemma_rms_norm(x, w)
+    enable_pdl = is_arch_support_pdl()
     if specify_out:
         y = torch.empty_like(x)
-        sgl_kernel.gemma_rmsnorm(x, w, out=y)
+        sgl_kernel.gemma_rmsnorm(x, w, out=y, enable_pdl=enable_pdl)
     else:
-        y = sgl_kernel.gemma_rmsnorm(x, w)
+        y = sgl_kernel.gemma_rmsnorm(x, w, enable_pdl=enable_pdl)
     torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
@@ -123,7 +129,10 @@ def test_gemma_fused_add_rmsnorm(batch_size, hidden_size, dtype):
     x_fused = x.clone()
     residual_fused = residual.clone()
-    sgl_kernel.gemma_fused_add_rmsnorm(x_fused, residual_fused, weight, eps)
+    enable_pdl = is_arch_support_pdl()
+    sgl_kernel.gemma_fused_add_rmsnorm(
+        x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
+    )
     torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
     torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
...
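The test updates follow a single pattern: query `is_arch_support_pdl()` once and pass the result as `enable_pdl` to every sgl-kernel norm call, so PDL (programmatic dependent launch) is only requested on GPU architectures that support it. A minimal sketch of the same gating outside the test suite; the shapes and dtype here are illustrative:

    import torch
    import sgl_kernel
    from sgl_kernel.utils import is_arch_support_pdl

    x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(4096, device="cuda", dtype=torch.bfloat16)

    # Enable PDL only when the current GPU architecture supports it.
    enable_pdl = is_arch_support_pdl()
    y = sgl_kernel.rmsnorm(x, w, enable_pdl=enable_pdl)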