# bench_per_token_group_quant_8bit.py: benchmarks per-token-group 8-bit
# quantization (Triton kernel vs. SGL kernel).
import itertools
2
3
4
import time
from functools import partial
from pathlib import Path
5
6
7
8

import torch
import triton

9
10
11
12
13
14
15
16
from sglang.srt.bench_utils import bench_kineto
from sglang.srt.layers.quantization.fp8_kernel import (
    create_per_token_group_quant_fp8_output_scale,
)
from sglang.srt.layers.quantization.fp8_kernel import (
    per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
)
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
17
from sglang.srt.utils import is_hip
18

19
20
_is_hip = is_hip()
# HIP (ROCm) builds use the "fnuz" FP8 e4m3 variant; all other builds use e4m3fn.
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn


# Sweep dimensions for the benchmark grid.
num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
hidden_dim_range = [1536, 7168, 18432]  # For DeepSeek V3/R1
group_size_range = [128]  # For DeepSeek V3/R1
# TODO test int8
dst_dtype_range = [fp8_type_]
# Scale-layout variants. Each dict is forwarded verbatim as keyword arguments
# to the quantization function under test (via `**flags` in `benchmark`).
# Progression: row-major scales -> column-major -> + TMA alignment -> + UE8M0.
flags_range = [
    dict(
        column_major_scales=False,
        scale_tma_aligned=False,
        scale_ue8m0=False,
    ),
    dict(
        column_major_scales=True,
        scale_tma_aligned=False,
        scale_ue8m0=False,
    ),
    dict(
        column_major_scales=True,
        scale_tma_aligned=True,
        scale_ue8m0=False,
    ),
    dict(
        column_major_scales=True,
        scale_tma_aligned=True,
        scale_ue8m0=True,
    ),
]

# Cartesian product of all sweep dimensions. The tuple field order must match
# `x_names` in the `perf_report` decorator on `benchmark`.
configs = list(
    itertools.product(
        num_tokens_range,
        hidden_dim_range,
        group_size_range,
        dst_dtype_range,
        flags_range,
    )
)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["triton", "sglang"],
        line_names=["Triton", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-group-quant-8bit-performance",
        args={},
    )
)
def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
    """Time one quantization kernel for a single benchmark configuration.

    Args:
        num_tokens: Number of rows of the random bf16 input.
        hidden_dim: Number of columns of the random bf16 input.
        group_size: Group size forwarded to the quantization function.
        dst_dtype: Quantized output dtype forwarded to the quantization function.
        flags: Scale-layout keyword arguments (``column_major_scales``,
            ``scale_tma_aligned``, ``scale_ue8m0``) forwarded verbatim.
        provider: ``"triton"`` or ``"sglang"``; selects both the Python entry
            point and the kernel name that ``bench_kineto`` filters the
            profiler trace by.

    Returns:
        Kernel time in microseconds, or ``None`` when the configuration is
        skipped.

    Raises:
        KeyError: If ``provider`` is not one of the two known providers.
    """
    # UE8M0 scales are only benchmarked with a group size of 128.
    if flags["scale_ue8m0"] and group_size != 128:
        return

    device = torch.device("cuda")

    x = torch.randn(num_tokens, hidden_dim, device=device, dtype=torch.bfloat16)

    # (callable under test, name of the kernel to extract from the trace)
    fn, kernel_names = {
        "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_fp8"),
        "sglang": (
            sglang_per_token_group_quant_8bit,
            "per_token_group_quant_8bit_kernel",
        ),
    }[provider]
    # Bind all arguments up front so the profiled callable takes none.
    bench_fn = partial(fn, x=x, group_size=group_size, dst_dtype=dst_dtype, **flags)

    time_s = bench_kineto(bench_fn, kernel_names=kernel_names)
    # `time_s` is in seconds; report microseconds to match ylabel="us".
    return time_s * 1e6


if __name__ == "__main__":
    # Run every entry of `configs` for each provider and print the latency
    # table (microseconds).
    benchmark.run(
        print_data=True,
    )