# flake8: noqa
import unittest

import torch


def random_tensor(shape, dtype, device, mean=0, std=1):
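    """Return a tensor of the given shape sampled from Normal(mean, std)."""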
    return torch.empty(shape, dtype=dtype, device=device).normal_(mean, std)


class TestGemmDequantize(unittest.TestCase):
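    """Tests for the fused GEMM + dequantize custom ops (torch.ops.turbomind / gemm_dq_unit_ops)."""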

    def setUp(self) -> None:
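        """Load the custom-op libraries and bind the packing, quantization and fused GEMM ops used by the tests."""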
        torch.classes.load_library('lib/libth_transformer.so')
        torch.classes.load_library('lib/libgemm_dq_unit_ops.so')
        self.unpack_packed_int4s = torch.ops.turbomind.unpack_int4_packed_tensor_to_int8
        self.pack_int4s = torch.ops.turbomind.pack_int8_tensor_to_packed_int4
        self.fused_gemm_dq = torch.ops.gemm_dq_unit_ops.fused_gemm_dq
        self.fused_gemm_dq_bias_act = torch.ops.gemm_dq_unit_ops.fused_gemm_dq_bias_act
        self.bench = torch.ops.gemm_dq_unit_ops.benchmark_against_cublas_fp
        self.preprocess_weights_for_mixed_gemm = torch.ops.turbomind.preprocess_weights_for_mixed_gemm

        self.symmetric_quantizer = torch.ops.turbomind._symmetric_quantize_last_axis_of_batched_matrix

        torch.manual_seed(734876213)

    def dequantize_test_helper(self, weight_type, quant_type):
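        """With identity activations and unit scales, fused_gemm_dq should reproduce the raw integer weights exactly."""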
        assert quant_type == torch.int8 or quant_type == torch.quint4x2

        lower_bound = -128 if quant_type == torch.int8 else -8
        upper_bound = 127 if quant_type == torch.int8 else 7

        m, n, k = 64, 128, 64
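        # Note: torch.randint's upper bound is exclusive, so the drawn values
        # lie in [lower_bound, upper_bound - 1].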
        weights = torch.randint(lower_bound,
                                upper_bound, [k, n],
                                dtype=torch.int8,
                                device='cpu')

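        # For int4, two values are packed per int8 before the weights are
        # rearranged into the layout expected by the mixed-type GEMM kernel.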
        packed_weight = self.pack_int4s(
            weights) if quant_type == torch.quint4x2 else weights
        cuda_weights = self.preprocess_weights_for_mixed_gemm(
            packed_weight, quant_type).to('cuda')
        weights = weights.to('cuda')

        act = torch.eye(m, dtype=weight_type, device='cuda')
        scales = torch.ones([n], dtype=weight_type, device='cuda')

        actual = self.fused_gemm_dq(act, cuda_weights, scales)
        torch.testing.assert_close(actual,
                                   weights,
                                   atol=0,
                                   rtol=0,
                                   check_dtype=False)

    def test_fp16_int8_dequantize(self):
        self.dequantize_test_helper(torch.float16, torch.int8)

    def test_bf16_int8_dequantize(self):
        self.dequantize_test_helper(torch.bfloat16, torch.int8)

    def test_fp16_int4_dequantize(self):
        self.dequantize_test_helper(torch.float16, torch.quint4x2)

    def test_bf16_int4_dequantize(self):
        self.dequantize_test_helper(torch.bfloat16, torch.quint4x2)

    def apply_act(self, inp, act_str):
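        """Reference activation used to check the fused bias + activation path."""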
        if act_str == 'identity':
            return inp
        elif act_str == 'silu':
            return torch.nn.SiLU()(inp)
        elif act_str == 'relu':
            return torch.nn.ReLU()(inp)
        elif act_str == 'gelu':
            return torch.nn.GELU(approximate='tanh')(inp)
        else:
            raise ValueError('Unsupported activation: {}'.format(act_str))

    def gemm_dequant_test_helper(self,
                                 compute_type,
                                 weight_dtype,
                                 gemm_ms,
                                 gemm_ns,
                                 gemm_ks,
                                 rtol,
                                 atol,
                                 act_str='only_gemm',
                                 benchmark=False):
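        """Compare the fused GEMM + dequantize op (optionally with bias and
        activation) against a PyTorch reference built from the symmetrically
        quantized weights."""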
        assert weight_dtype == torch.int8 or weight_dtype == torch.quint4x2, 'Weight must be quantized'

        for gemm_k in gemm_ks:
            for gemm_n in gemm_ns:
                torch_weights_cpu = random_tensor((gemm_k, gemm_n),
                                                  dtype=compute_type,
                                                  device='cpu',
                                                  mean=0,
                                                  std=0.002)
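                # The quantizer returns the raw integer weights (reference),
                # the kernel-layout weights and the per-column scales.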
                ref_torch_weights, processed_torch_weights, torch_weight_scales = self.symmetric_quantizer(
                    torch_weights_cpu, weight_dtype)
                ref_torch_weights = self.unpack_packed_int4s(
                    ref_torch_weights
                ) if weight_dtype == torch.quint4x2 else ref_torch_weights
                ref_torch_weights = ref_torch_weights.to('cuda')
                processed_torch_weights = processed_torch_weights.to('cuda')
                torch_weight_scales = torch_weight_scales.to('cuda')
                torch_biases = random_tensor((gemm_n, ),
                                             dtype=compute_type,
                                             device='cuda',
                                             mean=0,
                                             std=0.1)

                for num_rows in gemm_ms:
                    torch_activations = torch.randn(size=(num_rows, gemm_k),
                                                    dtype=compute_type,
                                                    device='cuda')

                    scales_unsqueezed = torch_weight_scales.unsqueeze(0)
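                    # Reference path: dequantize by broadcasting the
                    # per-column scales over the casted integer weights.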
                    casted_weights = ref_torch_weights.to(
                        torch_activations.dtype)
                    dequantized_weights = torch.multiply(
                        casted_weights, scales_unsqueezed)
                    if benchmark:
                        assert act_str == 'only_gemm', 'Benchmarks against cublas must use just GEMM.'
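                        # The benchmark op returns per-kernel timings and
                        # outputs; as indexed below, entry 0 is the cuBLAS
                        # reference and entry 1 the fused kernel.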
                        torch.cuda.profiler.start()
                        times, results = self.bench(torch_activations,
                                                    processed_torch_weights,
                                                    torch_weight_scales,
                                                    dequantized_weights, 200)
                        torch.cuda.profiler.stop()
                        times = times[0]
                        cublas_time = times[0].item()
                        ft_time = times[1].item()
                        ft_speedup = cublas_time / ft_time
                        print('{},{},{},{},{},{}'.format(
                            num_rows, gemm_n, gemm_k, cublas_time, ft_time,
                            ft_speedup))
                        reference_result = results[0]
                        ft_result = results[1]
                    else:
                        if act_str == 'only_gemm':
                            reference_result = torch.matmul(
                                torch_activations, dequantized_weights)
                            ft_result = self.fused_gemm_dq(
                                torch_activations, processed_torch_weights,
                                torch_weight_scales)
                        else:
                            reference_result = torch.matmul(
                                torch_activations, dequantized_weights)
                            reference_result += torch_biases.unsqueeze(0)
                            reference_result = self.apply_act(
                                reference_result, act_str)

                            ft_result = self.fused_gemm_dq_bias_act(
                                torch_activations, processed_torch_weights,
                                torch_weight_scales, torch_biases, act_str)

                    msg = 'Failed on m={}, n={}, k={}'.format(
                        num_rows, gemm_n, gemm_k)
                    torch.testing.assert_close(ft_result,
                                               reference_result,
                                               rtol=rtol,
                                               atol=atol,
                                               msg=msg,
                                               check_dtype=False)

    def test_fp16_int8_gemm(self):
        self.gemm_dequant_test_helper(
            torch.float16,
            torch.int8,
            gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
            gemm_ns=[1024, 2048, 4096],
            gemm_ks=[4096, 8192, 16384],
            rtol=0.001,
            atol=0.002)

    def test_fp16_int4_gemm(self):
        self.gemm_dequant_test_helper(
            torch.float16,
            torch.quint4x2,
            gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
            gemm_ns=[1024, 2048, 4096],
            gemm_ks=[4096, 8192, 16384],
            rtol=0.001,
            atol=0.002)

    def test_bf16_int8_gemm(self):
        self.gemm_dequant_test_helper(
            torch.bfloat16,
            torch.int8,
            gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
            gemm_ns=[1024, 2048, 4096],
            gemm_ks=[4096, 8192, 16384],
            rtol=0.01,
            atol=0.01)

    def test_bf16_int4_gemm(self):
        self.gemm_dequant_test_helper(
            torch.bfloat16,
            torch.quint4x2,
            gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
            gemm_ns=[1024, 2048, 4096],
            gemm_ks=[4096, 8192, 16384],
            rtol=0.01,
            atol=0.01)

    def test_fp16_int8_gemm_bias(self):
        self.gemm_dequant_test_helper(torch.float16,
                                      torch.int8,
                                      gemm_ms=[256],
                                      gemm_ns=[1024],
                                      gemm_ks=[8192],
                                      rtol=0.001,
                                      atol=0.002,
                                      act_str='identity')

    def test_fp16_int8_gemm_bias_relu(self):
        self.gemm_dequant_test_helper(torch.float16,
                                      torch.int8,
                                      gemm_ms=[256],
                                      gemm_ns=[1024],
                                      gemm_ks=[8192],
                                      rtol=0.001,
                                      atol=0.002,
                                      act_str='relu')

    def test_fp16_int8_gemm_bias_gelu(self):
        self.gemm_dequant_test_helper(torch.float16,
                                      torch.int8,
                                      gemm_ms=[256],
                                      gemm_ns=[1024],
                                      gemm_ks=[8192],
                                      rtol=0.001,
                                      atol=0.002,
                                      act_str='gelu')

    def test_fp16_int8_gemm_bias_silu(self):
        self.gemm_dequant_test_helper(torch.float16,
                                      torch.int8,
                                      gemm_ms=[256],
                                      gemm_ns=[1024],
                                      gemm_ks=[8192],
                                      rtol=0.001,
                                      atol=0.002,
                                      act_str='silu')

    def bench_helper(self, act_type, quant_type, rtol, atol):
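        """Warm up the GPU with large matmuls, then benchmark the fused kernel against cuBLAS for a fixed shape."""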
        # Warm up; bfloat16 is used here since it seems to reliably use cuBLAS.
        x = random_tensor([20480, 20480], torch.bfloat16, device='cuda')
        warm_iters = 30
        for _ in range(warm_iters):
            _ = x @ x

        # Candidate m sizes (powers of two); currently unused, since the call
        # below benchmarks a single fixed shape.
        m_shapes = 2**torch.arange(0, 12)

        self.gemm_dequant_test_helper(act_type,
                                      quant_type,
                                      gemm_ms=[128],
                                      gemm_ns=[1536],
                                      gemm_ks=[12288],
                                      rtol=rtol,
                                      atol=atol,
                                      benchmark=True)

    @unittest.skip("This is a benchmark so don't run by default")
    def test_fp16_int8_cublas(self):
        self.bench_helper(torch.float16, torch.int8, 1e-3, 0.002)

    @unittest.skip("This is a benchmark so don't run by default")
    def test_bf16_int8_cublas(self):
        self.bench_helper(torch.bfloat16, torch.int8, 1e-2, 1e-2)

    @unittest.skip("This is a benchmark so don't run by default")
    def test_fp16_int4_cublas(self):
        self.bench_helper(torch.float16, torch.quint4x2, 1e-3, 0.002)

    @unittest.skip("This is a benchmark so don't run by default")
    def test_bf16_int4_cublas(self):
        self.bench_helper(torch.bfloat16, torch.quint4x2, 1e-2, 1e-2)


if __name__ == '__main__':
    unittest.main()