"torchvision/vscode:/vscode.git/clone" did not exist on "4c0f44145792adca866a1668a79f2e11ed966491"
Unverified Commit c4e314f9 authored by Xiaoyu Zhang's avatar Xiaoyu Zhang Committed by GitHub
Browse files

Restruct sgl-kernel benchmark (#10861)

parent 7a06ef98
......@@ -5,7 +5,8 @@ import tilelang
import tilelang.language as T
import torch
import triton
from deep_gemm import ceil_div, get_col_major_tma_aligned_tensor
from deep_gemm import ceil_div
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
)
......@@ -131,7 +132,7 @@ def fp8_gemm_deepgemm(
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
# Run DeepGEMM kernel
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
return out
......@@ -179,7 +180,7 @@ def calculate_diff(m: int, n: int, k: int):
x_fp8, x_scale = per_token_cast_to_fp8(x.clone())
y_fp8, y_scale = per_block_cast_to_fp8(y.clone())
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())
out_deepgemm = fp8_gemm_deepgemm(
x_fp8.clone(),
......@@ -300,7 +301,7 @@ def get_benchmark(tp_size):
# Preprocess data before benchmarking
x_fp8, x_scale = per_token_cast_to_fp8(x)
y_fp8, y_scale = per_block_cast_to_fp8(y)
x_scale_col_major = get_col_major_tma_aligned_tensor(x_scale.clone())
x_scale_col_major = get_mn_major_tma_aligned_tensor(x_scale.clone())
quantiles = [0.5, 0.2, 0.8]
......
......@@ -4,7 +4,8 @@ import deep_gemm
import torch
import triton
import triton.language as tl
from deep_gemm import calc_diff, get_col_major_tma_aligned_tensor
from deep_gemm import calc_diff
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
# Import shared functionality from the regular GEMM benchmark
from sglang.benchmark.kernels.deepseek.benchmark_deepgemm_fp8_gemm import (
......@@ -71,9 +72,9 @@ def construct_grouped_and_flat_fp8(
# Transpose earlier for testing
x_fp8_grouped = (
x_fp8_grouped[0],
get_col_major_tma_aligned_tensor(x_fp8_grouped[1]),
get_mn_major_tma_aligned_tensor(x_fp8_grouped[1]),
)
x_fp8_flat = (x_fp8_flat[0], get_col_major_tma_aligned_tensor(x_fp8_flat[1]))
x_fp8_flat = (x_fp8_flat[0], get_mn_major_tma_aligned_tensor(x_fp8_flat[1]))
return x_fp8_grouped, y_fp8_grouped, x_fp8_flat, y_fp8_flat, out, ref_out
......@@ -240,7 +241,7 @@ def fp8_gemm_group_triton(a_tuple, b_tuple, c, num_groups):
def fp8_gemm_group_deepgemm(x_fp8_grouped, y_fp8_grouped, out, m_indices):
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
deep_gemm.m_grouped_fp8_gemm_nt_contiguous(
x_fp8_grouped,
y_fp8_grouped,
out,
......
......@@ -19,10 +19,6 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import is_npu, set_weight_attrs
_is_npu = is_npu()
if not _is_npu:
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
if TYPE_CHECKING:
from sglang.srt.layers.moe import MoeRunnerConfig
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
......
......@@ -621,11 +621,11 @@ class TestW8A8BlockFP8BatchedDeepGemm(CustomTestCase):
w_s,
)
from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
from deep_gemm import fp8_m_grouped_gemm_nt_masked
with torch.inference_mode():
ref_out = torch_w8a8_block_fp8_bmm(a, a_s, w, w_s, block_size, dtype)
m_grouped_gemm_fp8_fp8_bf16_nt_masked(lhs, rhs, oe, masked_m, expected_m)
fp8_m_grouped_gemm_nt_masked(lhs, rhs, oe, masked_m, expected_m)
out = oe[:, :M, :]
self.assertTrue(
......
......@@ -251,6 +251,14 @@ To use this with your library functions, simply wrap them with make_pytorch_shim
```
2. Add benchmarks using [triton benchmark](https://triton-lang.org/main/python-api/generated/triton.testing.Benchmark.html) in [benchmark/](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/benchmark)
**We recommend using `triton.testing.do_bench_cudagraph` for kernel benchmarking**:
Compared to `triton.testing.do_bench`, `do_bench_cudagraph` provides:
- Reduced CPU overhead impact for more accurate kernel performance measurements
- Incorporation of PDL (Programmatic Dependent Launch) effects into individual kernel results
- More realistic performance data on PDL-supported architectures (SM >= 90)
3. Run test suite
### FAQ
......
......@@ -10,10 +10,18 @@ import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import gelu_quick # activation-only kernel
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from vllm import _custom_ops as vllm_ops
# gelu_quick is only available on HIP/ROCm platforms
try:
from sgl_kernel import gelu_quick
GELU_QUICK_AVAILABLE = True
except ImportError:
GELU_QUICK_AVAILABLE = False
gelu_quick = None
if not hasattr(vllm_ops, "silu_and_mul"):
vllm_ops = torch.ops._C
......@@ -34,6 +42,12 @@ def calculate_diff(
# activation-only quick GELU
if kernel == "gelu_quick":
if not GELU_QUICK_AVAILABLE:
print(
f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
f"L={seq_len:3d} | D={dim:5d}] ⚠️ not available on this platform"
)
return True
x = torch.randn(batch_size, seq_len, dim, dtype=dtype, device=device)
ref_out = torch.zeros_like(x)
getattr(vllm_ops, kernel)(ref_out, x)
......@@ -54,7 +68,9 @@ def calculate_diff(
return ok
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul", "gelu_quick"]
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"]
if GELU_QUICK_AVAILABLE:
kernels.append("gelu_quick")
dtypes = [torch.float16, torch.bfloat16]
......@@ -64,7 +80,7 @@ def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[
default_batch_sizes = [2**i for i in range(0, 5, 2)] # 1,4,16
default_seq_lens = [2**i for i in range(0, 8, 2)] # 1,4,16,64
default_dims = [2**i for i in range(7, 15)] # 128...16384
default_dims = [2**i for i in range(10, 15)] # 1024...16384
@triton.testing.perf_report(
......@@ -87,6 +103,9 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
vllm_kernel = getattr(vllm_ops, kernel)
if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE:
# Skip benchmark for gelu_quick if not available
return (0, 0, 0)
sglang_kernel = getattr(sgl_kernel, kernel)
def baseline():
......@@ -97,18 +116,14 @@ def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
def sglang():
return sglang_kernel(x)
# one-time correctness check
if provider == "vllm" and not calculate_diff(
kernel, dtype, batch_size, seq_len, dim
):
raise ValueError("Mismatch – abort benchmark")
# timing helper
def timed(fn):
for _ in range(5):
fn()
torch.cuda.synchronize()
ms, qmin, qmax = triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])
ms, qmin, qmax = triton.testing.do_bench_cudagraph(
fn, quantiles=[0.5, 0.2, 0.8]
)
return 1000 * ms, 1000 * qmax, 1000 * qmin
if provider == "vllm":
......@@ -147,7 +162,9 @@ if __name__ == "__main__":
benchmark.benchmark.x_vals = benchmark_grid
if args.verify_only:
ok = calculate_diff("gelu_quick", torch.float16, 1, 1, args.dims[0])
# Test with the first available kernel
test_kernel = kernels[0]
ok = calculate_diff(test_kernel, torch.float16, 1, 1, args.dims[0])
print("✅ sanity pass" if ok else "❌ mismatch")
else:
benchmark.run(print_data=True)
......@@ -108,7 +108,7 @@ def benchmark(qweight_row, qweight_col, provider):
qweight.clone(), scales.clone(), qzeros.clone()
)
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
......
......@@ -87,7 +87,7 @@ def benchmark(batch_size, seq_len, provider, block_size, num_kv_splits):
workspace = torch.empty(workspace_size, device="cuda", dtype=torch.uint8)
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: cutlass_mla_decode(
qn.transpose(0, 1),
qr,
......@@ -136,8 +136,6 @@ if __name__ == "__main__":
print(f"block_size={block_size}, num_kv_splits={kv_split}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_blackwell_mla_res",
block_size=block_size,
num_kv_splits=kv_split,
)
......
......@@ -41,7 +41,7 @@ def benchmark(num_tokens, impl):
def runner():
dsv3_fused_a_gemm(mat_a, mat_b)
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
def tflops(t_ms):
flops = 2 * M * K * N
......@@ -54,4 +54,4 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
args = parser.parse_args()
benchmark.run(print_data=True, show_plots=True, save_path="bench_dsv3_gemm")
benchmark.run(print_data=True)
......@@ -52,7 +52,7 @@ def benchmark_bf16_output(num_tokens, impl):
def runner():
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.bfloat16)
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
def tflops(t_ms):
flops = 2 * M * K * N
......@@ -106,7 +106,7 @@ def benchmark_float_output(num_tokens, impl):
def runner():
dsv3_router_gemm(mat_a, mat_b, out_dtype=torch.float32)
ms, min_ms, max_ms = triton.testing.do_bench(runner, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(runner, quantiles=quantiles)
def tflops(t_ms):
flops = 2 * M * K * N
......@@ -119,9 +119,5 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
args = parser.parse_args()
benchmark_bf16_output.run(
print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
)
benchmark_float_output.run(
print_data=True, show_plots=True, save_path="bench_dsv3_router_gemm"
)
benchmark_bf16_output.run(print_data=True)
benchmark_float_output.run(print_data=True)
......@@ -198,8 +198,6 @@ if __name__ == "__main__":
print(f"DeepSeek-R1-0528-FP4 N={N} K={K}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_fp4_res",
N=N,
K=K,
dtype=args.dtype,
......
......@@ -5,7 +5,7 @@ import itertools
import deep_gemm
import torch
import triton
from deep_gemm import get_col_major_tma_aligned_tensor
from deep_gemm.utils.layout import get_mn_major_tma_aligned_tensor
from sgl_kernel import fp8_blockwise_scaled_mm
from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
......@@ -71,7 +71,7 @@ def fp8_gemm_deepgemm(
out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
# Run DeepGEMM kernel
deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
deep_gemm.fp8_gemm_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
return out
......@@ -117,7 +117,7 @@ def benchmark(batch_size, provider, N, K):
if provider == "sgl-kernel":
scale_a = scale_a.t().contiguous().t()
b_fp8, scale_b = b_fp8.t(), scale_b.t()
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: fp8_blockwise_scaled_mm(
a_fp8, b_fp8, scale_a, scale_b, torch.float16
),
......@@ -126,20 +126,20 @@ def benchmark(batch_size, provider, N, K):
if provider == "vllm":
scale_a = scale_a.t().contiguous().t()
b_fp8, scale_b = b_fp8.t(), scale_b.t()
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
quantiles=quantiles,
)
if provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: w8a8_block_fp8_matmul(
a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
),
quantiles=quantiles,
)
if provider == "deepgemm":
scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone())
ms, min_ms, max_ms = triton.testing.do_bench(
scale_a_col_major = get_mn_major_tma_aligned_tensor(scale_a.clone())
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: fp8_gemm_deepgemm(
a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K
),
......@@ -174,8 +174,6 @@ if __name__ == "__main__":
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True,
show_plots=True,
save_path="bench_fp8_blockwise_res",
N=N,
K=K,
)
......
......@@ -125,7 +125,7 @@ def benchmark(batch_size, provider, N, K):
a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a)
b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b)
b_fp8 = b_fp8.t()
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype),
quantiles=quantiles,
)
......@@ -133,7 +133,7 @@ def benchmark(batch_size, provider, N, K):
a_fp8, scale_a_fp8 = sglang_scaled_fp8_quant(a, scale_a)
b_fp8, scale_b_fp8 = sglang_scaled_fp8_quant(b, scale_b)
b_fp8 = b_fp8.t()
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: sgl_scaled_mm(
a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype, bias=None
),
......@@ -177,8 +177,6 @@ if __name__ == "__main__":
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True, show_plots=True, save_path="bench_fp8_res", N=N, K=K
)
benchmark.run(print_data=True, N=N, K=K)
print("Benchmark finished!")
......@@ -86,12 +86,12 @@ def benchmark(batch_size, provider, N, K):
quantiles = [0.5, 0.2, 0.8]
if provider == "sgl-kernel":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: int8_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
quantiles=quantiles,
)
if provider == "vllm":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: vllm_scaled_mm(a, b, scale_a, scale_b, torch.float16, bias),
quantiles=quantiles,
)
......@@ -139,8 +139,6 @@ if __name__ == "__main__":
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True, show_plots=True, save_path="bench_int8_res", N=N, K=K
)
benchmark.run(print_data=True, N=N, K=K)
print("Benchmark finished!")
......@@ -246,7 +246,7 @@ def benchmark(batch_size, provider):
quantiles = [0.5, 0.2, 0.8]
if provider == "naive":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: lightning_attention_decode_naive(
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
),
......@@ -257,7 +257,7 @@ def benchmark(batch_size, provider):
batch_size, num_heads, seq_len, head_dim, device=device, dtype=dtype
)
new_kv = torch.empty(batch_size, num_heads, head_dim, head_dim, device=device)
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: lightning_attention_decode_kernel(
q.clone(),
k.clone(),
......@@ -270,7 +270,7 @@ def benchmark(batch_size, provider):
quantiles=quantiles,
)
elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: triton_lightning_attn_decode(
q.clone(), k.clone(), v.clone(), past_kv.clone(), slope.clone()
),
......
......@@ -324,7 +324,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
quantiles = [0.5, 0.2, 0.8]
if provider == "sgl":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: sgl_moe_align_block_size_with_empty(
topk_ids,
num_experts,
......@@ -336,7 +336,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
quantiles=quantiles,
)
elif provider == "sgl_fusion":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: sgl_moe_align_block_size_with_empty(
topk_ids,
num_experts,
......@@ -350,7 +350,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
)
elif provider == "triton":
sorted_ids.fill_(topk_ids.numel())
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: moe_align_block_size_triton(
topk_ids,
num_experts,
......
......@@ -63,7 +63,9 @@ def benchmark(batch_size, provider):
block_size,
)
ms, min_ms, max_ms = triton.testing.do_bench(run_triton, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
run_triton, quantiles=quantiles
)
else:
raise ValueError(f"Unknown provider: {provider}")
......
......@@ -46,7 +46,7 @@ configs = [(sq,) for sq in seq_length_range]
)
)
def benchmark(seq_length, provider):
dtype = torch.bfloat16
dtype = torch.float32
device = torch.device("cuda")
num_experts, num_expert_group, topk_group, topk = 256, 8, 4, 8
......@@ -56,14 +56,14 @@ def benchmark(seq_length, provider):
quantiles = [0.5, 0.2, 0.8]
if provider == "original":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: biased_grouped_topk_org(
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
),
quantiles=quantiles,
)
elif provider == "kernel":
ms, min_ms, max_ms = triton.testing.do_bench(
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
lambda: biased_grouped_topk_org_fuse_kernel(
scores.clone(), bias.clone(), num_expert_group, topk_group, topk
),
......
......@@ -97,7 +97,7 @@ def benchmark(num_tokens, num_experts, topk, provider):
fn = lambda: sglang_topk_softmax(gating_output, topk)
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
......
......@@ -165,8 +165,6 @@ if __name__ == "__main__":
KN_model_names = prepare_shapes(args)
for K, N, model_name in KN_model_names:
print(f"{model_name} N={N} K={K}: ")
benchmark.run(
print_data=True, show_plots=True, save_path="bench_fp4_res", N=N, K=K
)
benchmark.run(print_data=True, N=N, K=K)
print("Benchmark finished!")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment