test_awq_dequant.py 5.43 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# Adapted from https://github.com/vllm-project/vllm/blob/main/tests/kernels/quantization/test_awq_triton.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
unittest port of the AWQ Triton kernel tests (``awq_dequantize_triton`` and ``awq_gemm_triton``).

Run with:
    python -m unittest test_awq_dequant.py
"""
import itertools
import unittest

import torch

from sglang.srt.layers.quantization.awq_triton import (
    AWQ_TRITON_SUPPORTED_GROUP_SIZES,
    awq_dequantize_triton,
    awq_gemm_triton,
)
from sglang.test.test_utils import CustomTestCase

# Device on which every test tensor is allocated; running these tests
# therefore requires a CUDA-capable GPU.
device = "cuda"


def reverse_awq_order(t: torch.Tensor) -> torch.Tensor:
    """Undo AWQ's interleaved 4-bit ordering along the last dimension.

    The last dimension of ``t`` is treated as consecutive groups of eight
    values (one unpacked int32 word each). Within every group, positions are
    gathered in the order [0, 4, 1, 5, 2, 6, 3, 7], applied to every row.
    The result is masked to its low 4 bits and returned as a contiguous tensor.
    """
    values_per_word = 32 // 4
    undo_interleave = [0, 4, 1, 5, 2, 6, 3, 7]

    columns = torch.arange(t.shape[-1], dtype=torch.int32, device=t.device)
    columns = columns.reshape(-1, values_per_word)[:, undo_interleave].flatten()

    reordered = t[:, columns]
    return reordered.bitwise_and(0xF).contiguous()


def awq_dequantize_torch(
    qweight: torch.Tensor,
    scales: torch.Tensor,
    qzeros: torch.Tensor,
    group_size: int,
) -> torch.Tensor:
    """Pure-PyTorch reference for AWQ 4-bit dequantization.

    Every element of ``qweight`` / ``qzeros`` packs eight 4-bit values in AWQ's
    interleaved order. They are unpacked, put back in logical order with
    ``reverse_awq_order``, and combined as ``(weight - zero) * scale``.

    Args:
        qweight: packed weights, shape (rows, cols) (int32 in the tests).
        scales: per-group scales, shape (rows // group_size, cols * 8).
        qzeros: packed zero points, shape (rows // group_size, cols).
        group_size: number of consecutive rows sharing one scale/zero row;
            -1 means a single group spanning all rows.

    Returns:
        Dequantized weights of shape (rows, cols * 8), in ``scales``'s dtype.
    """
    if group_size == -1:
        group_size = qweight.shape[0]

    bits = 4
    mask = (2**bits) - 1
    # int32 shift amounts keep ``int32 >> shifts`` in int32. The default int64
    # arange would promote the (rows, cols, 8) intermediate to int64 and double
    # its peak memory (several GB for the largest shapes in test_dequantize).
    # The unpacked nibbles are bit-identical either way.
    shifts = torch.arange(0, 32, bits, dtype=torch.int32, device=qzeros.device)

    def _unpack(packed: torch.Tensor) -> torch.Tensor:
        # (r, c) packed words -> (r, c * 8) 4-bit values in logical order.
        nibbles = torch.bitwise_right_shift(
            packed[:, :, None], shifts[None, None, :]
        ).to(torch.int8)  # narrowing keeps the low 8 bits, so masking stays exact
        nibbles = reverse_awq_order(nibbles.view(packed.shape[0], -1))
        return torch.bitwise_and(nibbles, mask)

    iweights = _unpack(qweight)
    zeros = _unpack(qzeros)

    # Expand per-group scales/zeros to one row per weight row.
    scales = scales.repeat_interleave(group_size, dim=0)
    zeros = zeros.repeat_interleave(group_size, dim=0)
    return (iweights - zeros) * scales


@unittest.skipUnless(
    torch.cuda.is_available(), "AWQ Triton kernels require a CUDA device"
)
class TestAWQTriton(CustomTestCase):
    """Compare the AWQ Triton kernels against reference computations.

    ``test_dequantize`` checks ``awq_dequantize_triton`` against the
    pure-PyTorch ``awq_dequantize_torch``. ``test_gemm`` checks
    ``awq_gemm_triton`` against ``torch.matmul`` on weights dequantized by
    ``awq_dequantize_triton``.
    """

    def _assert_finite(self, t: torch.Tensor, what: str) -> None:
        """Fail the current (sub)test if ``t`` contains any inf or NaN."""
        self.assertTrue(bool(torch.isfinite(t).all()), f"{what} contains inf/NaN")

    def test_dequantize(self):
        """Dequantize random packed weights across shapes and group sizes."""
        rows_list = [3584, 18944, 128, 256, 512, 1024]
        cols_list = [448, 576, 4736, 16, 32, 64, 128]

        # Same iteration order as nested loops over rows, cols, group size.
        for qweight_rows, qweight_cols, group_size in itertools.product(
            rows_list, cols_list, AWQ_TRITON_SUPPORTED_GROUP_SIZES
        ):
            with self.subTest(rows=qweight_rows, cols=qweight_cols, g=group_size):
                self._run_dequant_case(
                    qweight_rows=qweight_rows,
                    qweight_cols=qweight_cols,
                    group_size=group_size,
                )

    def _run_dequant_case(self, qweight_rows, qweight_cols, group_size):
        """Check one dequantization shape against the PyTorch reference.

        ``qweight`` is (qweight_rows, qweight_cols) int32, with each element
        packing eight 4-bit weights. That is why ``scales`` has
        ``qweight_cols * 8`` columns.
        """
        if group_size == -1:
            # -1 means a single group spanning every row.
            group_size = qweight_rows

        torch.manual_seed(0)
        num_groups = qweight_rows // group_size

        qweight = torch.randint(
            0,
            torch.iinfo(torch.int32).max,
            (qweight_rows, qweight_cols),
            dtype=torch.int32,
            device=device,
        )
        scales = torch.rand(
            num_groups,
            qweight_cols * 8,
            dtype=torch.float16,
            device=device,
        )
        zeros = torch.randint(
            0,
            torch.iinfo(torch.int32).max,
            (num_groups, qweight_cols),
            dtype=torch.int32,
            device=device,
        )

        ref = awq_dequantize_torch(qweight, scales, zeros, group_size)
        # The Triton kernel receives no group_size argument here.
        tri = awq_dequantize_triton(qweight, scales, zeros)

        self._assert_finite(tri, "awq_dequantize_triton output")
        torch.testing.assert_close(ref, tri)

    def test_gemm(self):
        """Run the AWQ GEMM over batch sizes, widths, group sizes and split-K."""
        N_list = [1, 2, 4, 8, 14, 17, 23, 32]
        K_list = [128]
        M_list = [16, 24, 32]
        splitK_list = [1, 8]

        # Same iteration order as nested loops over N, K, M, group size, splitK.
        for N, K, M, group_size, splitK in itertools.product(
            N_list, K_list, M_list, AWQ_TRITON_SUPPORTED_GROUP_SIZES, splitK_list
        ):
            with self.subTest(N=N, K=K, M=M, g=group_size, sk=splitK):
                self._run_gemm_case(
                    N=N,
                    K=K,
                    M=M,
                    group_size=group_size,
                    splitK=splitK,
                )

    def _run_gemm_case(self, N, K, M, group_size, splitK):
        """Check one AWQ GEMM configuration.

        ``x`` is (N, K). The packed weight is (K, M // 8) int32, with per-group
        scales/zeros, so the expected output is ``x @ W``, where ``W`` is the
        dequantized (K, M) weight. ``splitK`` is forwarded to
        ``awq_gemm_triton``.
        """
        if group_size == -1:
            # -1 means a single group spanning the whole reduction dimension.
            group_size = K

        torch.manual_seed(0)
        num_groups = K // group_size

        x = torch.rand((N, K), dtype=torch.float32, device=device)
        qweight = torch.randint(
            0,
            torch.iinfo(torch.int32).max,
            (K, M // 8),
            dtype=torch.int32,
            device=device,
        )
        qzeros = torch.randint(
            0,
            torch.iinfo(torch.int32).max,
            (num_groups, M // 8),
            dtype=torch.int32,
            device=device,
        )
        scales = torch.rand((num_groups, M), dtype=torch.float32, device=device)

        tri_out = awq_gemm_triton(x, qweight, scales, qzeros, splitK)
        self._assert_finite(tri_out, "awq_gemm_triton output")

        # Reference: dequantize, then run a dense matmul.
        w_deq = awq_dequantize_triton(qweight, scales, qzeros)
        ref_out = torch.matmul(x, w_deq)
        self._assert_finite(ref_out, "dequantize + matmul reference")

        # Loose tolerances: the fused kernel and the dequantize + matmul path
        # presumably accumulate partial sums in different orders.
        torch.testing.assert_close(tri_out.cpu(), ref_out.cpu(), atol=1e-1, rtol=1e-1)


if __name__ == "__main__":
    unittest.main(verbosity=2)