Unverified commit c4e314f9, authored by Xiaoyu Zhang, committed by GitHub

Restruct sgl-kernel benchmark (#10861)

parent 7a06ef98
@@ -5,7 +5,8 @@ import tilelang
 import tilelang.language as T
 import torch
 import triton
-from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor
+from deep_gemm import ceil_div
+from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
 from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
 )
@@ -131,7 +132,7 @@ def fp8_gemm_deepgemm(
     out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
     # Run DeepGEMM kernel
-    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
+    deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
     return out
@@ -179,7 +180,7 @@ def calculate_diff(m: int, n: int, k: int):
     x_fp8, x_scale = per_token_cast_to_fp8(x.clone())
     y_fp8, y_scale = per_block_cast_to_fp8(y.clone())
-    x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
+    x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())
     out_deepgemm = fp8_gemm_deepgemm(
         x_fp8.clone(),
@@ -300,7 +301,7 @@ def get_benchmark(tp_size):
     # Preprocess data before benchmarking
     x_fp8, x_scale = per_token_cast_to_fp8(x)
     y_fp8, y_scale = per_block_cast_to_fp8(y)
-    x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
+    x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())
     quantiles = [0.5, 0.2, 0.8]
...
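For context on the renamed DeepGEMM entry point, here is a minimal sketch of an FP8 GEMM round trip as this benchmark drives it. The two cast helpers are simplified stand-ins for the `per_token_cast_to_fp8` / `per_block_cast_to_fp8` helpers the benchmark defines (they assume dimensions divisible by 128 and use the 448 max of `float8_e4m3fn`); the shapes are illustrative, and a DeepGEMM build exposing `fp8_gemm_nt` is assumed.

```python
import deep_gemm
import torch
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor


def per_token_cast_to_fp8(x: torch.Tensor):
    # One FP8 scale per (row, 128-column) slice; assumes k % 128 == 0.
    m, k = x.shape
    xv = x.view(m, k // 128, 128)
    amax = xv.abs().float().amax(dim=2, keepdim=True).clamp(1e-4)
    x_fp8 = (xv * (448.0 / amax)).to(torch.float8_e4m3fn).view(m, k)
    return x_fp8, (amax / 448.0).view(m, k // 128)


def per_block_cast_to_fp8(y: torch.Tensor):
    # One FP8 scale per 128x128 block; assumes both dims divisible by 128.
    n, k = y.shape
    yv = y.view(n // 128, 128, k // 128, 128)
    amax = yv.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    y_fp8 = (yv * (448.0 / amax)).to(torch.float8_e4m3fn).view(n, k)
    return y_fp8, (amax / 448.0).view(n // 128, k // 128)


m, n, k = 256, 4096, 7168  # illustrative shapes
x = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
y = torch.randn(n, k, device="cuda", dtype=torch.bfloat16)

x_fp8, x_scale = per_token_cast_to_fp8(x)
y_fp8, y_scale = per_block_cast_to_fp8(y)
# DeepGEMM expects activation scales in its MN-major, TMA-aligned layout.
x_scale = get_mn_major_tma_aligned_tensor(x_scale)

out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
```

This mirrors the preparation that `calculate_diff` and `get_benchmark` above perform before timing.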
@@ -4,7 +4,8 @@ import deep_gemm
 import torch
 import triton
 import triton.language as tl
-from deep_gemm import calc_diff, get_col_major_tma_aligned_tensor
+from deep_gemm import calc_diff
+from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
 # Import shared functionality from the regular GEMM benchmark
 from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import (
@@ -71,9 +72,9 @@ def construct_grouped_and_flat_fp8(
     # Transpose earlier for testing
     x_fp8_grouped = (
         x_fp8_grouped[0],
-        get_col_major_tma_aligned_tensor(x_fp8_grouped[1]),
+        get_mn_major_tma_aligned_tensor(x_fp8_grouped[1]),
     )
-    x_fp8_flat = (x_fp8_flat[0], get_col_major_tma_aligned_tensor(x_fp8_flat[1]))
+    x_fp8_flat = (x_fp8_flat[0], get_mn_major_tma_aligned_tensor(x_fp8_flat[1]))
     return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out
@@ -240,7 +241,7 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
 def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
-    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+    deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
         x_fp8_grouped,
         y_fp8_grouped,
         out,
...
@@ -19,10 +19,6 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import is_npu, set_weight_attrs
-_is_npu = is_npu()
-if not _is_npu:
-    from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
 if TYPE_CHECKING:
     from sglang.srt.layers.moe import MoeRunnerConfig
     from sglang.srt.layers.moe.ep_moe.layer import EPMoE
...
@@ -621,11 +621,11 @@ class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
             w_s,
         )
-        from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
+        from deep_gemm import fp8_m_grouped_gemm_nt_masked
         with torch.inference_mode():
             ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
-            m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m)
+            fp8_m_grouped_gemm_nt_masked(lhs, rhs, oe, masked_m, expected_m)
             out = oe[:, :M, :]
         self.assertTrue(
...
@@ -251,6 +251,14 @@ To use this with your library functions, simply wrap them with make_pytorch_shim
 ```
 2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark)
+   **We recommend `triton.testing.do_bench_cudagraph` for kernel benchmarking.**
+   Compared to `triton.testing.do_bench`, `do_bench_cudagraph` provides:
+   - Less CPU launch overhead in the measurement, so kernel timings are more accurate
+   - Inclusion of PDL (Programmatic Dependent Launch) effects in per-kernel results
+   - More realistic performance numbers on PDL-capable architectures (SM >= 90)
 3. Run test suite
 ### FAQ
...
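As a reference for new benchmarks, the `do_bench_cudagraph` pattern recommended above looks roughly like this minimal sketch (the matmul workload, shapes, and warm-up count are illustrative, not taken from this PR):

```python
import torch
import triton.testing

a = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)
b = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)


def fn():
    # Workload under test: any capturable CUDA kernel sequence works here.
    torch.matmul(a, b)


# Warm up eagerly, then let do_bench_cudagraph capture fn() into a CUDA graph
# and replay it, which keeps CPU launch overhead out of the measurement.
for _ in range(5):
    fn()
torch.cuda.synchronize()

ms, p20, p80 = triton.testing.do_bench_cudagraph(fn, quantiles=[0.5, 0.2, 0.8])
print(f"median {ms:.4f} ms (p20 {p20:.4f} ms, p80 {p80:.4f} ms)")
```

As in the benchmarks in this PR, `quantiles=[0.5, 0.2, 0.8]` returns the median plus the 20th- and 80th-percentile timings.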
@@ -10,10 +10,18 @@ import torch
 import torch.nn.functional as F
 import triton
 import triton.testing
-from sgl_kernel import gelu_quick  # activation-only kernel
 from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 from vllm import _custom_ops as vllm_ops
+# gelu_quick is only available on HIP/ROCm platforms
+try:
+    from sgl_kernel import gelu_quick
+
+    GELU_QUICK_AVAILABLE = True
+except ImportError:
+    GELU_QUICK_AVAILABLE = False
+    gelu_quick = None
 if not hasattr(vllm_ops, "silu_and_mul"):
     vllm_ops = torch.ops._C
@@ -34,6 +42,12 @@ def calculate_diff(
     # activation-only quick GELU
     if kernel == "gelu_quick":
+        if not GELU_QUICK_AVAILABLE:
+            print(
+                f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
+                f"L={seq_len:3d} | D={dim:5d}] ⚠️ not available on this platform"
+            )
+            return True
         x = torch.randn(batch_size, seq_len, dim, dtype=dtype, device=device)
         ref_out = torch.zeros_like(x)
         getattr(vllm_ops, kernel)(ref_out, x)
@@ -54,7 +68,9 @@ def calculate_diff(
     return ok
-kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick"]
+kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
+if GELU_QUICK_AVAILABLE:
+    kernels.append("gelu_quick")
 dtypes = [torch.float16, torch.bfloat16]
@@ -64,7 +80,7 @@ def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[
 default_batch_sizes = [2**i for i in range(0, 5, 2)]  # 1, 4, 16
 default_seq_lens = [2**i for i in range(0, 8, 2)]  # 1, 4, 16, 64
-default_dims = [2**i for i in range(7, 15)]  # 128 ... 16384
+default_dims = [2**i for i in range(10, 15)]  # 1024 ... 16384
 @triton.testing.perf_report(
@@ -87,6 +103,9 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
     y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
     vllm_kernel = getattr(vllm_ops, kernel)
+    if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE:
+        # Skip the benchmark for gelu_quick when it is not available
+        return (0, 0, 0)
     sglang_kernel = getattr(sgl_kernel, kernel)
     def baseline():
@@ -97,18 +116,14 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
     def sglang():
         return sglang_kernel(x)
-    # one-time correctness check
-    if provider == "vllm" and not calculate_diff(
-        kernel, dtype, batch_size, seq_len, dim
-    ):
-        raise ValueError("Mismatch – abort benchmark")
     # timing helper
     def timed(fn):
         for _ in range(5):
             fn()
         torch.cuda.synchronize()
-        ms, qmin, qmax = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
+        ms, qmin, qmax = triton.testing.do_bench_cudagraph(
+            fn, quantiles=[0.5, 0.2, 0.8]
+        )
         return 1000 * ms, 1000 * qmax, 1000 * qmin
     if provider == "vllm":
@@ -147,7 +162,9 @@ if __name__ == "__main__":
     benchmark.benchmark.x_vals = benchmark_grid
     if args.verify_only:
-        ok = calculate_diff("gelu_quick", torch.float16, 1, 1, args.dims[0])
+        # Test with the first available kernel
+        test_kernel = kernels[0]
+        ok = calculate_diff(test_kernel, torch.float16, 1, 1, args.dims[0])
         print("✅ sanity pass" if ok else "❌ mismatch")
     else:
         benchmark.run(print_data=True)
@@ -108,7 +108,7 @@ def benchmark(qweight_row, qweight_col, provider):
         qweight.clone(), scales.clone(), qzeros.clone()
     )
-    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
+    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
...
@@ -87,7 +87,7 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
     workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
     quantiles = [0.5, 0.2, 0.8]
-    ms, min_ms, max_ms = triton.testing.do_bench(
+    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
         lambda: cutlass_mla_decode(
             qn.transpose(0, 1),
             qr,
@@ -136,8 +136,6 @@ if __name__ == "__main__":
         print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
         benchmark.run(
             print_data=True,
-            show_plots=True,
-            save_path="bench_blackwell_mla_res",
             block_size=block_size,
             num_kv_splits=kv_split,
         )
...
@@ -41,7 +41,7 @@ def benchmark(num_tokens, impl):
     def runner():
         dsv3_fused_a_gemm(mat_a, mat_b)
-    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
+    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
     def tflops(t_ms):
         flops = 2 * M * K * N
@@ -54,4 +54,4 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     args = parser.parse_args()
-    benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_gemm")
+    benchmark.run(print_data=True)
@@ -52,7 +52,7 @@ def benchmark_bf16_output(num_tokens, impl):
     def runner():
         dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
-    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
+    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
     def tflops(t_ms):
         flops = 2 * M * K * N
@@ -106,7 +106,7 @@ def benchmark_float_output(num_tokens, impl):
     def runner():
         dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
-    ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
+    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
     def tflops(t_ms):
         flops = 2 * M * K * N
@@ -119,9 +119,5 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     args = parser.parse_args()
-    benchmark_bf16_output.run(
-        print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
-    )
-    benchmark_float_output.run(
-        print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
-    )
+    benchmark_bf16_output.run(print_data=True)
+    benchmark_float_output.run(print_data=True)
@@ -198,8 +198,6 @@ if __name__ == "__main__":
         print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
         benchmark.run(
             print_data=True,
-            show_plots=True,
-            save_path="bench_fp4_res",
             N=N,
             K=K,
             dtype=args.dtype,
...
@@ -5,7 +5,7 @@ import itertools
 import deep_gemm
 import torch
 import triton
-from deep_gemm import get_col_major_tma_aligned_tensor
+from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
 from sgl_kernel import fp8_blockwise_scaled_mm
 from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
@@ -71,7 +71,7 @@ def fp8_gemm_deepgemm(
     out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
     # Run DeepGEMM kernel
-    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
+    deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
     return out
@@ -117,7 +117,7 @@ def benchmark(batch_size, provider, N, K):
     if provider == "sgl-kernel":
         scale_a = scale_a.t().contiguous().t()
         b_fp8, scale_b = b_fp8.t(), scale_b.t()
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
            lambda: fp8_blockwise_scaled_mm(
                a_fp8, b_fp8, scale_a, scale_b, torch.float16
            ),
@@ -126,20 +126,20 @@ def benchmark(batch_size, provider, N, K):
     if provider == "vllm":
         scale_a = scale_a.t().contiguous().t()
         b_fp8, scale_b = b_fp8.t(), scale_b.t()
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
             quantiles=quantiles,
         )
     if provider == "triton":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: w8a8_block_fp8_matmul(
                 a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
             ),
             quantiles=quantiles,
         )
     if provider == "deepgemm":
-        scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone())
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        scale_a_col_major = get_mn_major_tma_aligned_tensor(scale_a.clone())
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: fp8_gemm_deepgemm(
                 a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K
             ),
@@ -174,8 +174,6 @@ if __name__ == "__main__":
         print(f"{model_name} N={N} K={K}: ")
         benchmark.run(
             print_data=True,
-            show_plots=True,
-            save_path="bench_fp8_blockwise_res",
             N=N,
             K=K,
         )
...
@@ -125,7 +125,7 @@ def benchmark(batch_size, provider, N, K):
         a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
         b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
         b_fp8 = b_fp8.t()
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
             quantiles=quantiles,
         )
@@ -133,7 +133,7 @@ def benchmark(batch_size, provider, N, K):
         a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
         b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
         b_fp8 = b_fp8.t()
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: sgl_scaled_mm(
                 a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
             ),
@@ -177,8 +177,6 @@ if __name__ == "__main__":
     KN_model_names = prepare_shapes(args)
     for K, N, model_name in KN_model_names:
         print(f"{model_name} N={N} K={K}: ")
-        benchmark.run(
-            print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K
-        )
+        benchmark.run(print_data=True, N=N, K=K)
     print("Benchmark finished!")
@@ -86,12 +86,12 @@ def benchmark(batch_size, provider, N, K):
     quantiles = [0.5, 0.2, 0.8]
     if provider == "sgl-kernel":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
             quantiles=quantiles,
         )
     if provider == "vllm":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
             quantiles=quantiles,
         )
@@ -139,8 +139,6 @@ if __name__ == "__main__":
     KN_model_names = prepare_shapes(args)
     for K, N, model_name in KN_model_names:
         print(f"{model_name} N={N} K={K}: ")
-        benchmark.run(
-            print_data=True, show_plots=True, save_path="bench_int8_res", N=N, K=K
-        )
+        benchmark.run(print_data=True, N=N, K=K)
     print("Benchmark finished!")
@@ -246,7 +246,7 @@ def benchmark(batch_size, provider):
     quantiles = [0.5, 0.2, 0.8]
     if provider == "naive":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: lightning_attention_decode_naive(
                 q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
             ),
@@ -257,7 +257,7 @@ def benchmark(batch_size, provider):
             batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
         )
         new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: lightning_attention_decode_kernel(
                 q.clone(),
                 k.clone(),
@@ -270,7 +270,7 @@ def benchmark(batch_size, provider):
             quantiles=quantiles,
         )
     elif provider == "triton":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: triton_lightning_attn_decode(
                 q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
             ),
...
@@ -324,7 +324,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
     quantiles = [0.5, 0.2, 0.8]
     if provider == "sgl":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: sgl_moe_align_block_size_with_empty(
                 topk_ids,
                 num_experts,
@@ -336,7 +336,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
             quantiles=quantiles,
         )
     elif provider == "sgl_fusion":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: sgl_moe_align_block_size_with_empty(
                 topk_ids,
                 num_experts,
@@ -350,7 +350,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
         )
     elif provider == "triton":
         sorted_ids.fill_(topk_ids.numel())
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: moe_align_block_size_triton(
                 topk_ids,
                 num_experts,
...
@@ -63,7 +63,9 @@ def benchmark(batch_size, provider):
             block_size,
         )
-        ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
+            run_triton, quantiles=quantiles
+        )
     else:
         raise ValueError(f"Unknown provider: {provider}")
...
@@ -46,7 +46,7 @@ configs = [(sq,) for sq in seq_length_range]
     )
 )
 def benchmark(seq_length, provider):
-    dtype = torch.bfloat16
+    dtype = torch.float32
     device = torch.device("cuda")
     num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8
@@ -56,14 +56,14 @@ def benchmark(seq_length, provider):
     quantiles = [0.5, 0.2, 0.8]
     if provider == "original":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: biased_grouped_topk_org(
                 scores.clone(), bias.clone(), num_expert_group, topk_group, topk
             ),
             quantiles=quantiles,
         )
     elif provider == "kernel":
-        ms, min_ms, max_ms = triton.testing.do_bench(
+        ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
             lambda: biased_grouped_topk_org_fuse_kernel(
                 scores.clone(), bias.clone(), num_expert_group, topk_group, topk
             ),
...
@@ -97,7 +97,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
     fn = lambda: sglang_topk_softmax(gating_output, topk)
     quantiles = [0.5, 0.2, 0.8]
-    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
+    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
...
@@ -165,8 +165,6 @@ if __name__ == "__main__":
     KN_model_names = prepare_shapes(args)
    for K, N, model_name in KN_model_names:
         print(f"{model_name} N={N} K={K}: ")
-        benchmark.run(
-            print_data=True, show_plots=True, save_path="bench_fp4_res", N=N, K=K
-        )
+        benchmark.run(print_data=True, N=N, K=K)
     print("Benchmark finished!")