# Benchmarks SGLang RMSNorm kernels versus vLLM and FlashInfer across
# (batch_size, seq_len, hidden_size) configurations and prints the speed-up.
import argparse
import itertools
import os
import re
from typing import List, Optional, Tuple, Union

import sgl_kernel
import torch
import torch.nn as nn
import triton
import triton.testing
from sgl_kernel.utils import is_arch_support_pdl

# Optional imports: FlashInfer and vLLM are only benchmarked when importable.
# FlashInfer kernels: bind the two entry points, or None placeholders plus a
# False flag when the package cannot be imported.
try:
    from flashinfer.norm import fused_add_rmsnorm, rmsnorm
except ImportError:
    fused_add_rmsnorm = rmsnorm = None
    FLASHINFER_AVAILABLE = False
else:
    FLASHINFER_AVAILABLE = True

# vLLM custom ops: same pattern, exposed under the ``vllm_ops`` alias.
try:
    from vllm import _custom_ops as vllm_ops
except ImportError:
    vllm_ops = None
    VLLM_AVAILABLE = False
else:
    VLLM_AVAILABLE = True

# CI environment detection
# True when either the CI or the GITHUB_ACTIONS environment variable is set to
# "true" (case-insensitive); an unset variable is treated as "false".
IS_CI = any(
    os.getenv(flag, "false").lower() == "true" for flag in ("CI", "GITHUB_ACTIONS")
)


def str2int_list(arg: str) -> List[int]:
    """Parse a comma-separated string of non-negative integers.

    Intended for use as an argparse ``type=`` callable.

    Args:
        arg: Text such as ``"1,4,16"``. ``""`` and ``None`` yield an empty list.

    Returns:
        The parsed integers, in input order.

    Raises:
        argparse.ArgumentTypeError: If the stripped text is not digits
            separated by single commas.
    """
    if arg is None or arg == "":
        return []
    # Validation is done on the stripped text; int() itself tolerates the
    # leading/trailing whitespace that may remain on the first/last piece.
    if not re.fullmatch(r"\d+(,\d+)*", arg.strip()):
        raise argparse.ArgumentTypeError(f"Bad int list: {arg}")
    return list(map(int, arg.split(",")))


class HuggingFaceRMSNorm(nn.Module):
    """Plain PyTorch RMSNorm used as the reference implementation.

    The normalization is computed in float32 and cast back to the input dtype
    before being scaled by the learnable ``weight`` (initialized to ones).
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Normalize ``x`` over its last dimension.

        Args:
            x: Input tensor.
            residual: Optional tensor added to ``x`` (in float32) before
                normalizing.

        Returns:
            The normalized tensor when ``residual`` is None; otherwise a tuple
            ``(normalized, x + residual)`` where the sum is cast back to the
            dtype ``x`` had on entry.
        """
        input_dtype = x.dtype
        hidden = x.to(torch.float32)

        has_residual = residual is not None
        if has_residual:
            hidden = hidden + residual.to(torch.float32)
            # The pre-normalization sum is what gets handed back as residual.
            residual = hidden.to(input_dtype)

        mean_square = hidden.pow(2).mean(dim=-1, keepdim=True)
        hidden = hidden * torch.rsqrt(mean_square + self.variance_epsilon)
        normed = hidden.to(input_dtype) * self.weight

        return (normed, residual) if has_residual else normed


def rmsnorm_naive(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    """Reference RMSNorm built on ``HuggingFaceRMSNorm``.

    The inputs are flattened to 2-D (all leading dimensions merged) for the
    module call and the results are reshaped back to ``x``'s original shape.

    Returns:
        The normalized tensor, or ``(normalized, residual)`` when ``residual``
        is given.
    """
    input_shape = x.shape

    # A fresh module per call, carrying the caller's weight, on x's device.
    module = HuggingFaceRMSNorm(x.shape[-1], eps=eps)
    module.weight = nn.Parameter(weight)
    module = module.to(x.device)

    flat_x = x.view(-1, x.shape[-1])
    flat_residual = (
        None if residual is None else residual.view(-1, residual.shape[-1])
    )

    result = module(flat_x, flat_residual)
    if not isinstance(result, tuple):
        return result.view(input_shape)

    normed, summed = result[0], result[1]
    return (normed.view(input_shape), summed.view(input_shape))


def rmsnorm_flashinfer(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    """RMSNorm through FlashInfer, falling back to the naive reference.

    Inputs are flattened to 2-D (leading dimensions merged) for the kernel
    call and the results are viewed back to ``x``'s original shape.

    Args:
        x: Input tensor; normalized over its last dimension.
        weight: Scale vector passed to the kernel.
        residual: Optional tensor for the fused add + norm path.
        eps: Epsilon forwarded to the kernel.

    Returns:
        The normalized tensor, or ``(x, residual)`` on the fused path. On that
        path the kernel's return value is ignored and the tensors that were
        passed in are returned, i.e. the call is used for its in-place effect.
    """
    if not FLASHINFER_AVAILABLE:
        # Fallback to naive implementation if FlashInfer is not available
        return rmsnorm_naive(x, weight, residual, eps)

    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])

    if residual is None:
        return rmsnorm(x, weight, eps).view(orig_shape)

    residual = residual.view(-1, residual.shape[-1])
    fused_add_rmsnorm(x, residual, weight, eps)
    return (x.view(orig_shape), residual.view(orig_shape))


def rmsnorm_vllm(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    """RMSNorm through vLLM's custom ops, falling back to the naive reference.

    Inputs are flattened to 2-D (leading dimensions merged) for the kernel
    call and the results are viewed back to ``x``'s original shape.

    Args:
        x: Input tensor; normalized over its last dimension.
        weight: Scale vector passed to the kernel.
        residual: Optional tensor for the fused add + norm path.
        eps: Epsilon forwarded to the kernel.

    Returns:
        The normalized tensor (written into a freshly allocated buffer), or
        ``(x, residual)`` on the fused path. On that path the op's return
        value is ignored and the tensors that were passed in are returned,
        i.e. the call is used for its in-place effect.
    """
    if not VLLM_AVAILABLE:
        # Fallback to naive implementation if vLLM is not available
        return rmsnorm_naive(x, weight, residual, eps)

    orig_shape = x.shape
    x = x.view(-1, x.shape[-1])

    if residual is None:
        out = torch.empty_like(x)
        vllm_ops.rms_norm(out, x, weight, eps)
        return out.view(orig_shape)

    residual = residual.view(-1, residual.shape[-1])
    vllm_ops.fused_add_rms_norm(x, residual, weight, eps)
    return (x.view(orig_shape), residual.view(orig_shape))


def rmsnorm_sglang(
    x: torch.Tensor,
    weight: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
    enable_pdl: Optional[bool] = None,
):
    """RMSNorm through the ``sgl_kernel`` ops.

    Inputs are flattened to 2-D (leading dimensions merged) for the kernel
    call and the results are viewed back to ``x``'s original shape.

    Args:
        x: Input tensor; normalized over its last dimension.
        weight: Scale vector passed to the kernel.
        residual: Optional tensor for the fused add + norm path.
        eps: Epsilon forwarded to the kernel.
        enable_pdl: Forwarded to the kernel; when None the value comes from
            ``is_arch_support_pdl()``.

    Returns:
        The normalized tensor (written into a freshly allocated buffer), or
        ``(x, residual)`` on the fused path, where the tensors that were
        passed to the kernel are returned.
    """
    input_shape = x.shape
    flat_x = x.view(-1, x.shape[-1])
    flat_residual = (
        None if residual is None else residual.view(-1, residual.shape[-1])
    )

    use_pdl = is_arch_support_pdl() if enable_pdl is None else enable_pdl

    if flat_residual is None:
        result = torch.empty_like(flat_x)
        sgl_kernel.rmsnorm(flat_x, weight, eps, out=result, enable_pdl=use_pdl)
        return result.view(input_shape)

    sgl_kernel.fused_add_rmsnorm(
        flat_x, flat_residual, weight, eps, enable_pdl=use_pdl
    )
    return (flat_x.view(input_shape), flat_residual.view(input_shape))


def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
    """Run every implementation on one random bf16 input and compare outputs.

    Each implementation receives its own clones of ``x`` / ``residual``. The
    normalized outputs are printed, then SGLang (plus FlashInfer and vLLM when
    those packages are importable) is compared against the naive reference
    with ``atol=rtol=1e-2``.

    Args:
        batch_size: First dimension of the random CUDA input.
        seq_len: Second dimension of the random CUDA input.
        hidden_size: Last (normalized) dimension; also the weight length.
        use_residual: When True, exercise the fused add + norm path and
            compare only the normalized output (element 0 of each result).

    Returns:
        True if all compared implementations match the reference, else False.
    """
    dtype = torch.bfloat16
    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
    weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
    residual = torch.randn_like(x) if use_residual else None

    def run(impl):
        # Fresh clones per implementation so no call sees another's writes.
        return impl(
            x.clone(), weight, residual.clone() if residual is not None else None
        )

    output_naive = run(rmsnorm_naive)
    output_flashinfer = run(rmsnorm_flashinfer)
    output_vllm = run(rmsnorm_vllm)
    output_sglang = run(rmsnorm_sglang)

    if use_residual:
        output_naive = output_naive[0]
        output_flashinfer = output_flashinfer[0]
        output_vllm = output_vllm[0]
        output_sglang = output_sglang[0]

    print(f"Naive output={output_naive}")
    if FLASHINFER_AVAILABLE:
        print(f"FlashInfer output={output_flashinfer}")
    else:
        print("FlashInfer not available, skipped")
    if VLLM_AVAILABLE:
        print(f"VLLM output={output_vllm}")
    else:
        print("vLLM not available, skipped")
    print(f"SGLang output={output_sglang}")

    # Only compare available implementations
    all_match = torch.allclose(output_naive, output_sglang, atol=1e-2, rtol=1e-2)
    if FLASHINFER_AVAILABLE:
        all_match = all_match and torch.allclose(
            output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
        )
    if VLLM_AVAILABLE:
        all_match = all_match and torch.allclose(
            output_naive, output_vllm, atol=1e-2, rtol=1e-2
        )

    if all_match:
        print("✅ All available implementations match")
    else:
        print("❌ Implementations differ")

    # The script entry point uses this value to report pass/fail; without an
    # explicit return it would always receive None.
    return all_match


# In CI, benchmark a single configuration; otherwise sweep a grid of sizes.
# Each default is a one-element list under CI and a small sweep otherwise.
default_batch_sizes = [1] if IS_CI else [1 << shift for shift in (0, 2, 4, 6)]
default_seq_lens = [64] if IS_CI else [64, 128, 256, 512, 1024]
default_hidden_sizes = [4096] if IS_CI else [4096, 6144]


def make_configs(bsizes: List[int], slens: List[int], hsizes: List[int]) -> List[Tuple]:
    """Return every (batch_size, seq_len, hidden_size) combination.

    The result is ordered with ``hsizes`` varying fastest and ``bsizes``
    slowest; it is empty if any input list is empty.
    """
    return [
        (batch, seq, hidden)
        for batch in bsizes
        for seq in slens
        for hidden in hsizes
    ]


# Build the provider lists from the backends that are actually importable.
# One row per benchmark line: (line value, legend label, plot style, enabled).
# Row order is the order the lines are handed to the perf report.
_PROVIDER_TABLE = [
    ("huggingface", "HuggingFace", ("blue", "-"), True),
    ("flashinfer", "FlashInfer", ("green", "-"), FLASHINFER_AVAILABLE),
    ("vllm", "vLLM", ("red", "-"), VLLM_AVAILABLE),
    ("sglang", "SGL Kernel", ("orange", "-"), True),
]
_ENABLED_PROVIDERS = [row for row in _PROVIDER_TABLE if row[3]]

available_providers = [row[0] for row in _ENABLED_PROVIDERS]
available_names = [row[1] for row in _ENABLED_PROVIDERS]
available_styles = [row[2] for row in _ENABLED_PROVIDERS]


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len", "hidden_size"],
        x_vals=[],  # replaced by the script entry point once CLI args are parsed
        line_arg="provider",
        line_vals=available_providers,
        line_names=available_names,
        styles=available_styles,
        ylabel="µs (median)  or  × (speed-up)",
        plot_name="rmsnorm-performance",
        args={},
    )
)
def benchmark(batch_size, seq_len, hidden_size, provider, use_residual):
    """Time one RMSNorm provider on a random bf16 CUDA input.

    Args:
        batch_size: First dimension of the input.
        seq_len: Second dimension of the input.
        hidden_size: Last (normalized) dimension; also the weight length.
        provider: One of ``"huggingface"``, ``"flashinfer"``, ``"vllm"``,
            ``"sglang"``. Any other value is treated as ``"speedup"``.
        use_residual: When True, benchmark the fused add + norm path.

    Returns:
        ``(median, q20, q80)`` in microseconds, in the
        ``(y_mean, y_min, y_max)`` order that ``perf_report`` unpacks.
        ``(0, 0, 0)`` when the requested optional backend is not importable.
        For the speed-up case, the ratio reference_time / sglang_time repeated
        three times (reference is vLLM when available, else the naive one).
    """
    device = torch.device("cuda")
    dtype = torch.bfloat16

    x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device=device)
    weight = torch.ones(hidden_size, dtype=dtype, device=device)
    residual = torch.randn_like(x) if use_residual else None

    # timing helper
    def timed(impl):
        def fn():
            # Clone per call: the fused paths hand back the tensors they were
            # given, so every timed call must start from pristine inputs.
            return impl(
                x.clone(),
                weight,
                residual.clone() if residual is not None else None,
            )

        # Warm-up calls before measuring.
        for _ in range(5):
            fn()
        torch.cuda.synchronize()
        ms, qmin, qmax = triton.testing.do_bench_cudagraph(
            fn, quantiles=[0.5, 0.2, 0.8]
        )
        # These are times, so the lower quantile is the minimum: keep the
        # (value, min, max) order rather than swapping the last two.
        return 1000 * ms, 1000 * qmin, 1000 * qmax

    implementations = {
        "huggingface": rmsnorm_naive,
        "flashinfer": rmsnorm_flashinfer,
        "vllm": rmsnorm_vllm,
        "sglang": rmsnorm_sglang,
    }

    if provider == "flashinfer" and not FLASHINFER_AVAILABLE:
        return (0, 0, 0)
    if provider == "vllm" and not VLLM_AVAILABLE:
        return (0, 0, 0)
    if provider in implementations:
        return timed(implementations[provider])

    # provider == "speedup"
    reference = rmsnorm_vllm if VLLM_AVAILABLE else rmsnorm_naive
    t_ref, _, _ = timed(reference)
    t_sgl, _, _ = timed(rmsnorm_sglang)
    spd = t_ref / t_sgl if t_ref > 0 else 1.0
    return (spd, spd, spd)


if __name__ == "__main__":
    # Command-line entry point: parse the sweep, then verify or benchmark.
    parser = argparse.ArgumentParser("RMSNorm kernel benchmark")
    size_options = (
        ("--batch_sizes", default_batch_sizes),
        ("--seq_lens", default_seq_lens),
        ("--hidden_sizes", default_hidden_sizes),
    )
    for flag, fallback in size_options:
        parser.add_argument(flag, type=str2int_list, default=fallback)
    parser.add_argument(
        "--use_residual", action="store_true", help="Whether to use residual connection"
    )
    parser.add_argument("--verify_only", action="store_true")
    args = parser.parse_args()

    # coerce lists: only has an effect if a value is still a raw string
    for field in ("batch_sizes", "seq_lens", "hidden_sizes"):
        raw = getattr(args, field)
        if isinstance(raw, str):
            setattr(args, field, str2int_list(raw))

    # patch perf_report grid: the decorated object exposes its Benchmark as
    # either ``benchmarks`` or ``benchmark``, so handle both attribute names
    grid = make_configs(args.batch_sizes, args.seq_lens, args.hidden_sizes)
    report = (
        benchmark.benchmarks
        if hasattr(benchmark, "benchmarks")
        else benchmark.benchmark
    )
    report.x_vals = grid

    if not args.verify_only:
        benchmark.run(print_data=True, use_residual=args.use_residual)
    else:
        # Fixed (4, 128) batch/sequence shape with the first requested hidden size.
        ok = calculate_diff(4, 128, args.hidden_sizes[0], args.use_residual)
        print("✅ sanity pass" if ok else "❌ mismatch")