# bench_per_token_group_quant_8bit.py
# Benchmark per-token-group 8-bit quantization: Triton kernel vs SGLang kernel.
import itertools
import os
import time
from functools import partial
from pathlib import Path

import torch
import triton

from sglang.srt.layers.quantization.fp8_kernel import (
    create_per_token_group_quant_fp8_output_scale,
)
from sglang.srt.layers.quantization.fp8_kernel import (
    per_token_group_quant_8bit as triton_per_token_group_quant_8bit,
)
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_8bit
from sglang.srt.utils import is_hip
from sglang.srt.utils.bench_utils import bench_kineto
# CI environment detection: treat either the generic CI flag or the
# GitHub-Actions flag being "true" as running under CI. Used below to
# shrink the benchmark sweep so CI jobs stay fast.
def _env_flag(name: str) -> bool:
    # Environment variables are strings; compare case-insensitively to "true".
    return os.getenv(name, "false").lower() == "true"


IS_CI = _env_flag("CI") or _env_flag("GITHUB_ACTIONS")

# Pick the fp8 dtype matching the platform: AMD (HIP) GPUs use the
# e4m3fnuz variant, NVIDIA GPUs use e4m3fn.
_is_hip = is_hip()
if _is_hip:
    fp8_type_ = torch.float8_e4m3fnuz
else:
    fp8_type_ = torch.float8_e4m3fn


# Benchmark sweep axes. Under CI only a single small shape is timed so the
# job finishes quickly; locally the full DeepSeek V3/R1 shape sweep runs.
if IS_CI:
    num_tokens_range = [64]
    hidden_dim_range = [1536]
else:
    num_tokens_range = [1, 4, 16, 64, 256, 768, 2048, 8192, 16384]
    hidden_dim_range = [1536, 7168, 18432]  # For DeepSeek V3/R1
group_size_range = [128]  # For DeepSeek V3/R1
# TODO test int8
dst_dtype_range = [fp8_type_]
# Scale-layout flag combinations to benchmark. Each successive entry turns
# on one more layout option than the previous one (plain row-major scales,
# then column-major, then TMA-aligned, then ue8m0-packed).
_FLAG_KEYS = ("column_major_scales", "scale_tma_aligned", "scale_ue8m0")
flags_range = [
    dict(zip(_FLAG_KEYS, bits))
    for bits in (
        (False, False, False),
        (True, False, False),
        (True, True, False),
        (True, True, True),
    )
]


# Full Cartesian product of the sweep axes; each tuple is one benchmark case
# (num_tokens, hidden_dim, group_size, dst_dtype, flags).
_sweep_axes = (
    num_tokens_range,
    hidden_dim_range,
    group_size_range,
    dst_dtype_range,
    flags_range,
)
configs = list(itertools.product(*_sweep_axes))


# Implementations under test: provider name -> (callable, kernel name that
# bench_kineto should look for in the profiler trace).
_PROVIDERS = {
    "triton": (triton_per_token_group_quant_8bit, "_per_token_group_quant_8bit"),
    "sglang": (sglang_per_token_group_quant_8bit, "per_token_group_quant_8bit_kernel"),
}


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["num_tokens", "hidden_dim", "group_size", "dst_dtype", "flags"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["triton", "sglang"],
        line_names=["Triton", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-group-quant-8bit-performance",
        args={},
    )
)
def benchmark(num_tokens, hidden_dim, group_size, dst_dtype, flags, provider):
    """Time one quantization configuration and return the latency in us."""
    # ue8m0 scale packing is only exercised with group_size 128; skip the rest.
    if flags["scale_ue8m0"] and group_size != 128:
        return

    quant_fn, kernel_name = _PROVIDERS[provider]

    x = torch.randn(
        num_tokens, hidden_dim, device=torch.device("cuda"), dtype=torch.bfloat16
    )

    def run_once():
        quant_fn(x=x, group_size=group_size, dst_dtype=dst_dtype, **flags)

    # bench_kineto reports seconds (consistent with ylabel="us" after the
    # conversion below) — scale to microseconds for the plot.
    elapsed_s = bench_kineto(run_once, kernel_names=kernel_name)
    return elapsed_s * 1e6


# Script entry point: run the whole sweep and print the timing table.
if __name__ == "__main__":
    benchmark.run(print_data=True)