# bench_per_token_group_quant_8bit.py
import itertools
2
from typing import Tuple
3
4
5
6

import torch
import triton
import triton.language as tl
7
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8
8

9
from sglang.srt.utils import is_hip
10

11
_is_hip = is_hip()
# FP8 dtype used by this benchmark: the "fnuz" e4m3 variant when is_hip()
# reports a HIP platform (presumably AMD/ROCm), otherwise the e4m3fn variant.
if _is_hip:
    fp8_type_ = torch.float8_e4m3fnuz
else:
    fp8_type_ = torch.float8_e4m3fn


@triton.jit
def _per_token_group_quant_8bit(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    # Offset (in elements) from the start of one group to the next; applied
    # to both the input and the quantized output
    y_stride,
    # Number of columns (elements) in each group
    N,
    # Lower bound on each group's absolute maximum, to avoid dividing by zero
    eps,
    # Representable range of the 8-bit output type (int8 or fp8_type_)
    max_8bit,
    min_8bit,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """Triton kernel performing per-token-group quantization to an 8-bit type.

    Each program instance handles one group of ``N`` elements. It computes
    ``scale = max(absmax(group), eps) / max_8bit``, writes
    ``clamp(group / scale, min_8bit, max_8bit)`` cast to the element type of
    ``y_q_ptr``, and stores ``scale`` at ``y_s_ptr[group_id]``.

    ``BLOCK`` must be at least ``N`` (the caller uses the next power of two);
    lanes ``>= N`` are masked off.
    """
    # One program per group. Widen the group index to int64 before multiplying
    # so the element offset cannot overflow int32 on very large inputs.
    g_id = tl.program_id(0).to(tl.int64)
    y_ptr += g_id * y_stride
    y_q_ptr += g_id * y_stride
    y_s_ptr += g_id

    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    mask = cols < N

    # Masked lanes load 0.0, whose absolute value cannot raise the maximum.
    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / max_8bit
    y_q = tl.clamp(y / y_s, min_8bit, max_8bit).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)


def triton_per_token_group_quant_8bit(
    x: torch.Tensor,
    group_size: int,
    dst_dtype: torch.dtype,
    eps: float = 1e-10,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` per token group to an 8-bit type using the Triton kernel.

    The last dimension of ``x`` is split into consecutive groups of
    ``group_size`` elements. Each group gets its own scale
    ``max(absmax(group), eps) / max_8bit``, and its values are divided by that
    scale and clamped to the range of ``dst_dtype``.

    Args:
        x: The input tensor. It must be contiguous, and its last dimension
            must be divisible by ``group_size``.
        group_size: The number of elements per quantization group.
        dst_dtype: The output dtype. ``torch.int8`` uses the ``torch.iinfo``
            range; any other dtype (e.g. an FP8 type) uses its ``torch.finfo``
            range.
        eps: Lower bound on each group's absolute maximum, so that an all-zero
            group does not produce a zero scale.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor (same shape as
        ``x``, dtype ``dst_dtype``) and the float32 scales, with shape
        ``x.shape[:-1] + (x.shape[-1] // group_size,)``.

    Raises:
        AssertionError: If the last dimension of ``x`` is not divisible by
            ``group_size``, or if ``x`` is not contiguous.
    """
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` is not divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    # Representable range of the destination dtype.
    if dst_dtype == torch.int8:
        iinfo = torch.iinfo(dst_dtype)
        max_8bit = iinfo.max
        min_8bit = iinfo.min
    else:
        finfo = torch.finfo(dst_dtype)
        max_8bit = finfo.max
        min_8bit = finfo.min

    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
    # x is contiguous and the groups tile its last dimension, so it can be
    # treated as M groups of N = group_size elements; one program per group.
    M = x.numel() // group_size
    N = group_size
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size,),
        device=x.device,
        dtype=torch.float32,
    )

    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    _per_token_group_quant_8bit[(M,)](
        x,
        x_q,
        x_s,
        group_size,  # y_stride: groups are packed back to back
        N,
        eps,
        max_8bit,
        min_8bit,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return x_q, x_s


def sglang_per_token_group_quant_8bit(
    x: torch.Tensor,
    group_size: int,
    dst_dtype: torch.dtype,
    eps: float = 1e-10,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` per token group to an 8-bit type using the sgl_kernel ops.

    This allocates the output buffers and dispatches to
    ``sgl_per_token_group_quant_int8`` when ``dst_dtype`` is ``torch.int8``.
    Every other dtype goes to ``sgl_per_token_group_quant_fp8``. The
    preallocated ``x_q`` / ``x_s`` buffers are passed to the op, which is
    expected to fill them.

    Args:
        x: The input tensor. It must be contiguous, and its last dimension
            must be divisible by ``group_size``.
        group_size: The number of elements per quantization group.
        dst_dtype: The output dtype. ``torch.int8`` uses the ``torch.iinfo``
            range; any other dtype uses its ``torch.finfo`` range.
        eps: Lower bound on each group's absolute maximum, forwarded to the op.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor (same shape as
        ``x``, dtype ``dst_dtype``) and the float32 scales, with shape
        ``x.shape[:-1] + (x.shape[-1] // group_size,)``.

    Raises:
        AssertionError: If the last dimension of ``x`` is not divisible by
            ``group_size``, or if ``x`` is not contiguous.
    """
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` is not divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    x_q = torch.empty_like(x, device=x.device, dtype=dst_dtype)
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size,),
        device=x.device,
        dtype=torch.float32,
    )

    if dst_dtype == torch.int8:
        iinfo = torch.iinfo(dst_dtype)
        int8_max = iinfo.max
        int8_min = iinfo.min
        sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
    else:
        f8_info = torch.finfo(dst_dtype)
        fp8_max = f8_info.max
        fp8_min = f8_info.min
        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)

    return x_q, x_s


def calculate_diff(batch_size, seq_len, group_size, dst_dtype, hidden_dim=7168):
    """Check that the Triton and sgl_kernel quantizers agree on one input.

    Both implementations run on the same random float16 CUDA tensor of shape
    ``(batch_size * seq_len, hidden_dim)``. The function prints whether the
    quantized values and the scales match within ``rtol=1e-3, atol=1e-5``.

    Args:
        batch_size: Number of sequences.
        seq_len: Tokens per sequence.
        group_size: Quantization group size; must divide ``hidden_dim``.
        dst_dtype: Output dtype (``torch.int8`` or an FP8 dtype).
        hidden_dim: Size of the last dimension (default 7168).
    """
    device = torch.device("cuda")

    x = torch.randn(
        batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
    )

    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(
        x.clone(), group_size, dst_dtype
    )
    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(
        x.clone(), group_size, dst_dtype
    )

    # Upcast the 8-bit outputs so allclose can compare int8 / fp8 values.
    q_match = torch.allclose(
        x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
    )
    s_match = torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5)

    if q_match and s_match:
        print(f"✅ {dst_dtype} implementations match")
    else:
        # Name the dtype: main() runs this check once per dtype.
        print(f"❌ {dst_dtype} implementations differ")


# Powers of two from 1 to 64.
batch_size_range = [1 << k for k in range(7)]
# Powers of two from 64 to 2048.
seq_len_range = [1 << k for k in range(6, 12)]
group_size_range = [128]  # For DeepSeek V3/R1
dst_dtype_range = [torch.int8, fp8_type_]

# Full Cartesian sweep. The tuple order (batch_size, seq_len, group_size,
# dst_dtype) must match the benchmark's ``x_names``.
configs = list(
    itertools.product(
        batch_size_range,
        seq_len_range,
        group_size_range,
        dst_dtype_range,
    )
)


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len", "group_size", "dst_dtype"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["triton", "sglang"],
        line_names=["Triton", "SGL Kernel"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-group-quant-8bit-performance",
        args={},
    )
)
def benchmark(batch_size, seq_len, group_size, dst_dtype, provider):
    """Time one quantization provider on a random float16 CUDA input.

    Returns:
        ``(median, p20, p80)`` latency in microseconds, in the
        ``(y, y_min, y_max)`` order that ``perf_report`` unpacks.

    Raises:
        ValueError: If ``provider`` is neither ``"triton"`` nor ``"sglang"``.
    """
    device = torch.device("cuda")
    hidden_dim = 7168

    x = torch.randn(
        batch_size * seq_len, hidden_dim, device=device, dtype=torch.float16
    )

    quantiles = [0.5, 0.2, 0.8]

    # x is passed without cloning: a clone inside the timed lambda adds a full
    # device-memory copy to every measurement. The Triton kernel only loads
    # from x. NOTE(review): the sgl_kernel ops are assumed to treat x as
    # read-only too (they receive separate output buffers) — confirm.
    if provider == "triton":
        fn = lambda: triton_per_token_group_quant_8bit(x, group_size, dst_dtype)
    elif provider == "sglang":
        fn = lambda: sglang_per_token_group_quant_8bit(x, group_size, dst_dtype)
    else:
        raise ValueError(f"unknown provider: {provider!r}")

    # do_bench returns the quantiles in the order requested above.
    p50_ms, p20_ms, p80_ms = triton.testing.do_bench(fn, quantiles=quantiles)

    # Latency (not throughput), so the low quantile is y_min and the high is y_max.
    return 1000 * p50_ms, 1000 * p20_ms, 1000 * p80_ms


if __name__ == "__main__":
    # Check that both implementations agree for each dtype before timing them.
    calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=torch.int8)
    calculate_diff(batch_size=4, seq_len=128, group_size=64, dst_dtype=fp8_type_)

    benchmark.run(print_data=True)