# bench_activation.py
#
# Benchmarks SGLang activation kernels against their vLLM counterparts across
# (kernel, dtype, batch_size, seq_len, dim) and prints the speed-up.
import argparse
import itertools
import re
from typing import List, Tuple

import sgl_kernel
import torch
import torch.nn.functional as F
import triton
import triton.testing
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
from vllm import _custom_ops as vllm_ops

16
17
18
19
20
21
22
23
24
# gelu_quick is only available on HIP/ROCm platforms, so its import is
# optional: when it is missing, the name is bound to None and the flag below
# tells the rest of the script to skip that kernel.
try:
    from sgl_kernel import gelu_quick
except ImportError:
    gelu_quick = None
    GELU_QUICK_AVAILABLE = False
else:
    GELU_QUICK_AVAILABLE = True

25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# Use the imported vLLM ops module when it exposes ``silu_and_mul``; otherwise
# fall back to the raw ``torch.ops._C`` namespace for the reference kernels.
# NOTE(review): presumably needed for vLLM builds that only register the ops
# under torch.ops._C -- confirm against the supported vLLM versions.
vllm_ops = vllm_ops if hasattr(vllm_ops, "silu_and_mul") else torch.ops._C


def str2int_list(arg: str) -> List[int]:
    """Parse a comma-separated list of non-negative integers (argparse ``type=``).

    Surrounding whitespace and whitespace around the commas are ignored, so
    ``"1,2,3"``, ``" 1,2,3 "`` and ``"1, 2, 3"`` all yield ``[1, 2, 3]``.

    Args:
        arg: Raw command-line value. ``None``, the empty string and a
            whitespace-only string all mean "no values".

    Returns:
        The parsed integers in input order, or ``[]`` for an empty value.

    Raises:
        argparse.ArgumentTypeError: If ``arg`` is not a comma-separated list
            of decimal integers (e.g. ``"1,,2"``, ``"a,b"``, ``"-1"``).
    """
    # Normalise once so that validation and parsing look at the same text.
    text = "" if arg is None else arg.strip()
    if not text:
        return []
    if re.fullmatch(r"\d+(\s*,\s*\d+)*", text) is None:
        raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
    # int() tolerates the whitespace that may remain around each token.
    return [int(token) for token in text.split(",")]


def calculate_diff(
    kernel: str, dtype: torch.dtype, batch_size: int, seq_len: int, dim: int
) -> bool:
    """Compare the vLLM and SGLang results of one kernel for one shape.

    The kernel named ``kernel`` is looked up on both ``vllm_ops`` (called as
    ``op(out, x)`` with a pre-allocated output) and ``sgl_kernel`` (called as
    ``op(x)``, returning its output). Both run on the same random CUDA input
    and a one-line report is printed.

    Args:
        kernel: Attribute name of the op on ``vllm_ops`` and ``sgl_kernel``.
        dtype: Element type of the input and output tensors.
        batch_size: First dimension of the tensors.
        seq_len: Second dimension of the tensors.
        dim: Last dimension of the output. The input's last dimension is
            ``dim`` for ``gelu_quick`` and ``2 * dim`` for every other kernel.

    Returns:
        ``True`` when the two outputs agree under
        ``torch.allclose(rtol=1e-3, atol=1e-5)``, or when ``gelu_quick`` is
        requested but not available (reported and skipped); ``False`` otherwise.
    """
    # Shared prefix of every report line printed by this function.
    label = (
        f"[{kernel:14s} | {str(dtype):9s} | B={batch_size:3d} | "
        f"L={seq_len:3d} | D={dim:5d}]"
    )

    # gelu_quick is activation-only; every other kernel is a fused
    # activation x mul whose input is twice as wide as its output.
    activation_only = kernel == "gelu_quick"
    if activation_only and not GELU_QUICK_AVAILABLE:
        print(f"{label} ⚠️  not available on this platform")
        return True

    device = torch.device("cuda")
    in_dim = dim if activation_only else 2 * dim
    x = torch.randn(batch_size, seq_len, in_dim, dtype=dtype, device=device)
    ref_out = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)
    getattr(vllm_ops, kernel)(ref_out, x)  # vLLM writes into ref_out
    test_out = getattr(sgl_kernel, kernel)(x)

    ok = torch.allclose(ref_out, test_out, rtol=1e-3, atol=1e-5)
    tag = "✅ match" if ok else "❌ mismatch"
    print(f"{label} {tag}")
    return ok


71
72
73
# Kernel names swept by the benchmark; gelu_quick joins the list only when
# sgl_kernel actually provides it.
kernels = ["silu_and_mul", "gelu_and_mul", "gelu_tanh_and_mul"] + (
    ["gelu_quick"] if GELU_QUICK_AVAILABLE else []
)
# Element types swept by the benchmark.
dtypes = [torch.float16, torch.bfloat16]


def make_configs(bsizes: List[int], slens: List[int], dims_: List[int]) -> List[Tuple]:
    """Build the benchmark grid.

    Returns every ``(kernel, dtype, batch_size, seq_len, dim)`` combination of
    the module-level ``kernels`` and ``dtypes`` with the given sizes, ordered
    as ``itertools.product`` yields them (last axis varying fastest).
    """
    axes = (kernels, dtypes, bsizes, slens, dims_)
    return [*itertools.product(*axes)]


# Default sweep used when the matching command-line option is not given.
default_batch_sizes = [4**k for k in range(3)]  # 1, 4, 16
default_seq_lens = [4**k for k in range(4)]  # 1, 4, 16, 64
default_dims = [1 << k for k in range(10, 15)]  # 1024 ... 16384
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["kernel", "dtype", "batch_size", "seq_len", "dim"],
        x_vals=[],  # filled in by the __main__ block before benchmark.run()
        line_arg="provider",
        line_vals=["vllm", "sglang", "speedup"],
        line_names=["vLLM", "SGL Kernel", "Speed-up (x)"],
        styles=[("blue", "-"), ("green", "-"), ("red", "--")],
        ylabel="µs (median)  or  × (speed-up)",
        plot_name="activation-performance",
        args={},
    )
)
def benchmark(kernel, dtype, batch_size, seq_len, dim, provider):
    """Time one kernel for one shape and one provider.

    Args:
        kernel: Attribute name of the op on ``vllm_ops`` and ``sgl_kernel``.
        dtype: Element type of the input and output tensors.
        batch_size: First dimension of the tensors.
        seq_len: Second dimension of the tensors.
        dim: Last dimension of the output; the input is ``dim`` wide for
            ``gelu_quick`` and ``2 * dim`` wide otherwise.
        provider: ``"vllm"``, ``"sglang"`` or ``"speedup"``.

    Returns:
        A 3-tuple in the ``(value, min, max)`` order that ``perf_report``
        unpacks. For ``"vllm"`` / ``"sglang"`` it holds the median, 20th and
        80th percentile run time in microseconds. For ``"speedup"`` it holds
        the vLLM/SGLang median-time ratio repeated three times. ``(0, 0, 0)``
        is returned when ``gelu_quick`` is requested but unavailable.
    """
    # Skip before allocating tensors or resolving kernel attributes, so an
    # unavailable gelu_quick cannot fail on the lookups below.
    if kernel == "gelu_quick" and not GELU_QUICK_AVAILABLE:
        return (0, 0, 0)

    device = torch.device("cuda")
    in_mult = 1 if kernel == "gelu_quick" else 2
    x = torch.randn(batch_size, seq_len, in_mult * dim, dtype=dtype, device=device)
    y0 = torch.zeros(batch_size, seq_len, dim, dtype=dtype, device=device)

    vllm_kernel = getattr(vllm_ops, kernel)
    sglang_kernel = getattr(sgl_kernel, kernel)

    def baseline():
        # vLLM writes into a caller-provided output, so each call gets a
        # fresh copy of the zero buffer.
        tmp = y0.clone()
        vllm_kernel(tmp, x)
        return tmp

    def sglang():
        return sglang_kernel(x)

    # timing helper
    def timed(fn):
        # Warm up before measuring.
        for _ in range(5):
            fn()
        torch.cuda.synchronize()
        # Results follow the order of ``quantiles``: median, q20, q80.
        ms, qmin, qmax = triton.testing.do_bench_cudagraph(
            fn, quantiles=[0.5, 0.2, 0.8]
        )
        # Scale ms -> µs (the unit named in ylabel). These are times, not
        # throughputs, so the low quantile is the minimum and the high
        # quantile the maximum.
        return 1000 * ms, 1000 * qmin, 1000 * qmax

    if provider == "vllm":
        return timed(baseline)
    if provider == "sglang":
        return timed(sglang)

    # provider == "speedup"
    t_ref, _, _ = timed(baseline)
    t_sgl, _, _ = timed(sglang)
    spd = t_ref / t_sgl
    return (spd, spd, spd)


if __name__ == "__main__":
    # Command-line entry point: either run a single sanity comparison
    # (--verify_only) or benchmark the full grid and print the results.
    p = argparse.ArgumentParser("Activation kernel benchmark")
    p.add_argument("--batch_sizes", type=str2int_list, default=default_batch_sizes)
    p.add_argument("--seq_lens", type=str2int_list, default=default_seq_lens)
    p.add_argument("--dims", type=str2int_list, default=default_dims)
    p.add_argument("--verify_only", action="store_true")
    args = p.parse_args()

    # No manual coercion is needed: argparse applies type=str2int_list to
    # values given on the command line, and the defaults are already lists.

    # patch perf_report grid
    benchmark_grid = make_configs(args.batch_sizes, args.seq_lens, args.dims)
    # NOTE(review): the object returned by perf_report presumably names this
    # attribute differently across Triton versions -- both spellings are kept.
    if hasattr(benchmark, "benchmarks"):
        benchmark.benchmarks.x_vals = benchmark_grid
    else:
        benchmark.benchmark.x_vals = benchmark_grid

    if args.verify_only:
        # An empty --dims would otherwise fail with a bare IndexError below.
        if not args.dims:
            p.error("--dims must contain at least one value with --verify_only")
        # Test with the first available kernel
        test_kernel = kernels[0]
        ok = calculate_diff(test_kernel, torch.float16, 1, 1, args.dims[0])
        print("✅ sanity pass" if ok else "❌ mismatch")
    else:
        benchmark.run(print_data=True)