# test_per_token_group_quant_8bit.py
1
import itertools
2
from typing import Tuple
3
4
5
6
7

import pytest
import torch
import triton
import triton.language as tl
8
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_group_quant_int8
9

10
from sglang.srt.utils import is_hip
11
12
13
14
15
16

is_hip_ = is_hip()
# HIP (ROCm) builds select the "fnuz" flavour of FP8 e4m3; every other
# platform uses the standard e4m3fn type.
if is_hip_:
    fp8_type_ = torch.float8_e4m3fnuz
else:
    fp8_type_ = torch.float8_e4m3fn


@triton.jit
def _per_token_group_quant_8bit(
    # Pointers to inputs and output
    y_ptr,
    y_q_ptr,
    y_s_ptr,
    # Stride of input
    y_stride,
    # Columns of input (elements per group)
    N,
    # Lower bound on the group absmax, to avoid dividing by zero
    eps,
    # Information for 8bit data type (int8 or fp8_type_)
    max_8bit,
    min_8bit,
    # Meta-parameters
    BLOCK: tl.constexpr,
):
    """Triton kernel performing per-token-group quantization to 8-bit values.

    Each program instance handles one group of ``N`` elements starting at
    ``y_ptr + program_id * y_stride``. It:

    * computes the group's absolute maximum, floored at ``eps``;
    * derives the scale ``absmax / max_8bit``;
    * stores ``clamp(y / scale, min_8bit, max_8bit)`` cast to the element type
      of ``y_q_ptr`` (no explicit rounding is applied before the cast);
    * stores the scale at ``y_s_ptr + program_id``.

    Args:
        y_ptr: Pointer to the input values (loaded and promoted to float32).
        y_q_ptr: Pointer to the quantized output.
        y_s_ptr: Pointer to the per-group scales, one entry per program.
        y_stride: Element offset between consecutive groups in both ``y`` and
            ``y_q``.
        N: Number of valid elements in a group.
        eps: Minimum absmax, so the scale is never zero.
        max_8bit: Upper bound of the output dtype's range.
        min_8bit: Lower bound of the output dtype's range.
        BLOCK: Tile width; the caller must guarantee ``N <= BLOCK``.
    """
    # Map the program id to the row of X and Y it should compute.
    g_id = tl.program_id(0)
    y_ptr += g_id * y_stride
    y_q_ptr += g_id * y_stride
    y_s_ptr += g_id

    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    mask = cols < N

    # Masked-out lanes load 0.0, which cannot raise the absmax.
    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / max_8bit
    y_q = tl.clamp(y / y_s, min_8bit, max_8bit).to(y_q_ptr.dtype.element_ty)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)


57
def triton_per_token_group_quant_8bit(
    x: torch.Tensor,
    group_size: int,
    dst_dtype: torch.dtype,
    eps: float = 1e-10,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reference per-token-group 8-bit quantization implemented with Triton.

    The last dimension of ``x`` is split into contiguous groups of
    ``group_size`` elements. Each group is quantized independently with the
    scale ``max(abs(group).max(), eps) / max_8bit``.

    Args:
        x: The input tensor. It must be contiguous, and its last dimension
            must be divisible by ``group_size``.
        group_size: The number of consecutive elements sharing one scale.
        dst_dtype: The dtype of the quantized output. ``torch.int8`` takes its
            range from ``torch.iinfo``; any other dtype (e.g. an FP8 type)
            takes it from ``torch.finfo``.
        eps: The minimum group absmax, to avoid dividing by zero.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            * The quantized tensor, with the same shape as ``x`` and dtype
              ``dst_dtype``.
            * The float32 scales, with shape
              ``x.shape[:-1] + (x.shape[-1] // group_size,)``.

    Raises:
        AssertionError: If the last dimension of ``x`` is not divisible by
            ``group_size``, or if ``x`` is not contiguous.
    """
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    if dst_dtype == torch.int8:
        iinfo = torch.iinfo(dst_dtype)
        max_8bit, min_8bit = iinfo.max, iinfo.min
    else:
        finfo = torch.finfo(dst_dtype)
        max_8bit, min_8bit = finfo.max, finfo.min

    # empty_like already places the output on x's device.
    x_q = torch.empty_like(x, dtype=dst_dtype)
    # The kernel treats every group as one "row": M groups of N elements each.
    M = x.numel() // group_size
    N = group_size
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size,),
        device=x.device,
        dtype=torch.float32,
    )

    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    num_stages = 1
    _per_token_group_quant_8bit[(M,)](
        x,
        x_q,
        x_s,
        group_size,  # y_stride: x is contiguous, so groups are group_size apart
        N,
        eps,
        max_8bit,
        min_8bit,
        BLOCK=BLOCK,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return x_q, x_s


118
def sglang_per_token_group_quant_8bit(
    x: torch.Tensor,
    group_size: int,
    dst_dtype: torch.dtype,
    eps: float = 1e-10,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Per-token-group 8-bit quantization through the ``sgl_kernel`` ops.

    Allocates the output buffers and dispatches to
    ``sgl_per_token_group_quant_int8`` when ``dst_dtype`` is ``torch.int8``.
    Any other dtype is dispatched to ``sgl_per_token_group_quant_fp8``.

    Args:
        x: The input tensor. It must be contiguous, and its last dimension
            must be divisible by ``group_size``.
        group_size: The number of consecutive elements sharing one scale.
        dst_dtype: The dtype of the quantized output. ``torch.int8`` takes its
            range from ``torch.iinfo``; any other dtype takes it from
            ``torch.finfo``.
        eps: The minimum group absmax, forwarded to the kernel to avoid
            dividing by zero.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            * The quantized tensor, with the same shape as ``x`` and dtype
              ``dst_dtype``.
            * The float32 scales, with shape
              ``x.shape[:-1] + (x.shape[-1] // group_size,)``.

    Raises:
        AssertionError: If the last dimension of ``x`` is not divisible by
            ``group_size``, or if ``x`` is not contiguous.
    """
    assert (
        x.shape[-1] % group_size == 0
    ), "the last dimension of `x` must be divisible by `group_size`"
    assert x.is_contiguous(), "`x` is not contiguous"

    # empty_like already places the output on x's device.
    x_q = torch.empty_like(x, dtype=dst_dtype)
    x_s = torch.empty(
        x.shape[:-1] + (x.shape[-1] // group_size,),
        device=x.device,
        dtype=torch.float32,
    )

    # Both kernels take the dtype range as (min, max), in that order.
    if dst_dtype == torch.int8:
        iinfo = torch.iinfo(dst_dtype)
        sgl_per_token_group_quant_int8(
            x, x_q, x_s, group_size, eps, iinfo.min, iinfo.max
        )
    else:
        finfo = torch.finfo(dst_dtype)
        sgl_per_token_group_quant_fp8(
            x, x_q, x_s, group_size, eps, finfo.min, finfo.max
        )

    return x_q, x_s


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires a CUDA device")
@pytest.mark.parametrize(
    "batch_size, seq_len, group_size, dst_dtype",
    list(
        itertools.product(
            [1, 2, 4, 8, 16, 32, 64, 128],  # batch_size
            [64, 128, 256, 512, 1024, 2048],  # seq_len
            [16, 32, 64, 128, 256],  # group_size
            [torch.int8, fp8_type_],  # dtype
        )
    ),
)
def test_per_token_group_quant_compare_implementations(
    batch_size, seq_len, group_size, dst_dtype
):
    """Check that the sgl_kernel quantization matches the Triton reference.

    The last dimension is ``group_size * 2``, so every token is split into
    two groups and therefore produces two scales.
    """
    x = torch.randn(
        (batch_size, seq_len, group_size * 2), device="cuda", dtype=torch.float16
    )

    x_q_triton, x_s_triton = triton_per_token_group_quant_8bit(x, group_size, dst_dtype)
    x_q_sglang, x_s_sglang = sglang_per_token_group_quant_8bit(x, group_size, dst_dtype)

    # Quantized values are compared in float32 so int8 and fp8 share one check.
    assert torch.allclose(
        x_q_triton.to(torch.float32), x_q_sglang.to(torch.float32), rtol=1e-3, atol=1e-5
    )
    assert torch.allclose(x_s_triton, x_s_sglang, rtol=1e-3, atol=1e-5)


if __name__ == "__main__":
    # Propagate pytest's exit status so a script run reports failures to the
    # shell/CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__]))