import itertools
from typing import Tuple

import pytest
import torch
import triton
import triton.language as tl
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8

from sglang.srt.utils import is_hip

_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn


@triton.jit
def _per_token_group_quant_fp8(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    # Stride of input
    y_stride,
    # Columns of input
    N,
    # Epsilon to avoid division by zero
    eps,
    # Information for float8
    fp8_min,
    fp8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """A Triton-accelerated function to perform per-token-group quantization on a
    tensor.

    This function converts the tensor values into float8 values.
    """
    # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
    y_ptr += g_id * y_stride
    y_q_ptr += g_id * y_stride
    y_s_ptr += g_id

    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    mask = cols < N

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
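    # The group's absmax maps to fp8_max; multiplying by the reciprocal of the
    # scale avoids a per-element division.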
    y_s = _absmax / fp8_max
    y_s_inv = 1.0 / y_s
    y_q = tl.clamp(y * y_s_inv, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)


@triton.jit
def _per_token_group_quant_fp8_colmajor(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    group_size,
    # Num columns of y
    y_num_columns,
    # Stride from one column to the next of y_s
    y_s_col_stride,
    # Epsilon to avoid division by zero
    eps,
    # Information for float8
    fp8_min,
    fp8_max,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """A Triton-accelerated function to perform per-token-group
    quantization on a tensor.
    This function converts the tensor values into float8 values.
    """
    # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
    y_ptr += g_id * group_size
    y_q_ptr += g_id * group_size

    # Convert g_id, the flattened block coordinate, to 2D so we can index
    # into the output y_scales matrix
    blocks_per_row = y_num_columns // group_size
    scale_col = g_id % blocks_per_row
    scale_row = g_id // blocks_per_row
    y_s_ptr += scale_col * y_s_col_stride + scale_row
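    # y_s is logically (num_tokens, num_groups) but stored column-major, so
    # scales for consecutive tokens within one group column are contiguous.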

    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
    mask = cols < group_size

    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / fp8_max
    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)


def triton_per_token_group_quant_8bit(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = fp8_type_,
    column_major_scales: bool = False,
    scale_tma_aligned: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Function to perform per-token-group quantization on an input tensor `x`.

    It converts the tensor values into signed 8-bit values (float8 or int8,
    depending on `dtype`) and returns the quantized tensor along with the
    scaling factor used for quantization.

    Args:
        x: The input tensor with ndim >= 2.
        group_size: The group size used for quantization.
        eps: The minimum value to avoid division by zero.
        dtype: The dtype of the output tensor.
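        column_major_scales: Whether to store the per-group scales column-major
            instead of row-major.
        scale_tma_aligned: When used with `column_major_scales`, pad the token
            dimension of the scales to a multiple of 4 floats.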

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
    """
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` cannot be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    if dtype == torch.int8:
        finfo = torch.iinfo(dtype)
    else:
        finfo = torch.finfo(dtype)

    fp8_max = finfo.max
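    # For torch.int8 this is just the int8 bound (127); the `fp8_max` /
    # `fp8_min` names are reused for both dtypes. On HIP the fp8 range is
    # narrowed to +/-224 for float8_e4m3fnuz below.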

    if _is_hip:
        if dtype == torch.int8:
            fp8_max = 127.0
        else:
            fp8_max = 224.0

    fp8_min = -fp8_max

    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    M = x.numel() // group_size
    N = group_size
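    # One Triton program quantizes one group of `group_size` elements, so the
    # launch grid has M = total number of groups.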
    if column_major_scales:
        if scale_tma_aligned:
            # aligned to 4 * sizeof(float)
            aligned_size = (x.shape[-2] + 3) // 4 * 4
            x_s = torch.empty(
                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
                device=x.device,
                dtype=torch.float32,
            ).permute(-1, -2)[: x.shape[-2], :]
        else:
            x_s = torch.empty(
                (x.shape[-1] // group_size,) + x.shape[:-1],
                device=x.device,
                dtype=torch.float32,
            ).permute(-1, -2)
    else:
        x_s = torch.empty(
            x.shape[:-1] + (x.shape[-1] // group_size,),
            device=x.device,
            dtype=torch.float32,
        )
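    # For the 2-D inputs used in this test, x_s is logically
    # (num_tokens, num_groups) in every branch; the column-major variants only
    # change the memory layout (optionally padded along the token dim for TMA).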

    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    if column_major_scales:
        _per_token_group_quant_fp8_colmajor[(M,)](
            x,
            x_q,
            x_s,
            group_size,
            x.shape[1],
            x_s.stride(1),
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    else:
        _per_token_group_quant_fp8[(M,)](
            x,
            x_q,
            x_s,
            group_size,
            N,
            eps,
            fp8_min=fp8_min,
            fp8_max=fp8_max,
            BLOCK=BLOCK,
            num_warps=num_warps,
            num_stages=num_stages,
        )

    return x_q, x_s


def sglang_per_token_group_quant_8bit(
    x: torch.Tensor,
    group_size: int,
    eps: float = 1e-10,
    dtype: torch.dtype = fp8_type_,
    column_major_scales: bool = False,
    scale_tma_aligned: bool = False,
):
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` cannot be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

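    # Mirror the allocation logic of the Triton reference above so both
    # implementations produce directly comparable outputs.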
    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
    M = x.numel() // group_size
    N = group_size
    if column_major_scales:
        if scale_tma_aligned:
            # aligned to 4 * sizeof(float)
            aligned_size = (x.shape[-2] + 3) // 4 * 4
            x_s = torch.empty(
                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
                device=x.device,
                dtype=torch.float32,
            ).permute(-1, -2)[: x.shape[-2], :]
        else:
            x_s = torch.empty(
                (x.shape[-1] // group_size,) + x.shape[:-1],
                device=x.device,
                dtype=torch.float32,
            ).permute(-1, -2)
    else:
        x_s = torch.empty(
            x.shape[:-1] + (x.shape[-1] // group_size,),
            device=x.device,
            dtype=torch.float32,
        )

    if dtype == torch.int8:
        iinfo = torch.iinfo(dtype)
        int8_max = iinfo.max
        int8_min = iinfo.min
        sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
    else:
        f8_info = torch.finfo(dtype)
        fp8_max = f8_info.max
        fp8_min = f8_info.min
        sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)

    return x_q, x_s


@pytest.mark.parametrize(
    "num_tokens, hidden_dim, group_size, dst_dtype, column_major_scales, scale_tma_aligned",
    list(
        itertools.product(
            [127, 128, 512, 1024, 4096, 8192],  # num_tokens
            [256, 512, 1024, 2048, 4096],  # hidden_dim
            [8, 16, 32, 64, 128],  # group_size
            [torch.int8, fp8_type_],  # dtype
            [False, True],  # column_major_scales
            [False, True],  # scale_tma_aligned
        )
    ),
)
def test_per_token_group_quant_with_column_major(
    num_tokens,
    hidden_dim,
    group_size,
    dst_dtype,
    column_major_scales,
    scale_tma_aligned,
):
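    # `scale_tma_aligned` only applies on top of column-major scales, so skip
    # the combination that requests it without `column_major_scales`.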
    if not column_major_scales and scale_tma_aligned:
        return

    x = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.float16)

    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(
        x,
        group_size,
        eps=1e-10,
        dtype=dst_dtype,
        column_major_scales=column_major_scales,
        scale_tma_aligned=scale_tma_aligned,
    )

    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(
        x,
        group_size,
        eps=1e-10,
        dtype=dst_dtype,
        column_major_scales=column_major_scales,
        scale_tma_aligned=scale_tma_aligned,
    )

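    # Compare in float32 since the quantized tensors may be fp8 or int8.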
    torch.testing.assert_close(
        x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
    )
    torch.testing.assert_close(
        x_s_triton.contiguous(), x_s_sglang.contiguous(), rtol=1e-3, atol=1e-5
    )


if __name__ == "__main__":
    pytest.main([__file__])