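"""Unit tests for the fused mixed-precision GEMM + dequantize ops exposed through TorchScript
(fp16/bf16 activations with int8 or packed-int4 weights), covering the pure GEMM path, bias +
activation epilogues, and an optional benchmark against cuBLAS."""
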
import torch
import unittest

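# Returns a tensor of the given shape/dtype/device sampled from a normal distribution N(mean, std^2).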
def random_tensor(shape, dtype, device, mean=0, std=1):
    return torch.empty(shape, dtype=dtype, device=device).normal_(mean, std)
  
class TestGemmDequantize(unittest.TestCase):
    def setUp(self) -> None:
        torch.classes.load_library("lib/libth_transformer.so")
        torch.classes.load_library("lib/libgemm_dq_unit_ops.so")
        self.unpack_packed_int4s = torch.ops.fastertransformer.unpack_int4_packed_tensor_to_int8
        self.pack_int4s = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
        self.fused_gemm_dq = torch.ops.gemm_dq_unit_ops.fused_gemm_dq
        self.fused_gemm_dq_bias_act = torch.ops.gemm_dq_unit_ops.fused_gemm_dq_bias_act
        self.bench = torch.ops.gemm_dq_unit_ops.benchmark_against_cublas_fp
        self.preprocess_weights_for_mixed_gemm = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm

        self.symmetric_quantizer = torch.ops.fastertransformer._symmetric_quantize_last_axis_of_batched_matrix

        torch.manual_seed(734876213)

    def dequantize_test_helper(self, weight_type, quant_type):
      assert quant_type == torch.int8 or quant_type == torch.quint4x2

      lower_bound = -128 if quant_type == torch.int8 else -8
      upper_bound = 127 if quant_type == torch.int8 else 7

      m, n, k = 64, 128, 64
      # torch.randint's upper bound is exclusive, so add 1 to include the maximum quantized value.
      weights = torch.randint(lower_bound, upper_bound + 1, [k, n], dtype=torch.int8, device="cpu")

      packed_weight = self.pack_int4s(weights) if quant_type == torch.quint4x2 else weights
      cuda_weights = self.preprocess_weights_for_mixed_gemm(packed_weight, quant_type).to("cuda")
      weights = weights.to("cuda")

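      # With an identity activation and unit scales, the fused GEMM + dequantize should return the
      # original integer weights (cast to weight_type), so an exact (atol=0, rtol=0) comparison is valid.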
      act = torch.eye(m, dtype=weight_type, device="cuda")
      scales = torch.ones([n], dtype=weight_type, device='cuda')

      actual = self.fused_gemm_dq(act, cuda_weights, scales)
      torch.testing.assert_close(actual, weights, atol=0, rtol=0, check_dtype=False)

    def test_fp16_int8_dequantize(self):
      self.dequantize_test_helper(torch.float16, torch.int8)

    def test_bf16_int8_dequantize(self):
      self.dequantize_test_helper(torch.bfloat16, torch.int8)

    def test_fp16_int4_dequantize(self):
      self.dequantize_test_helper(torch.float16, torch.quint4x2)

    def test_bf16_int4_dequantize(self):
      self.dequantize_test_helper(torch.bfloat16, torch.quint4x2)

    def apply_act(self, inp, act_str):
      if act_str == "identity":
        return inp
      elif act_str == "silu":
        return torch.nn.SiLU()(inp)
      elif act_str == "relu":
        return torch.nn.ReLU()(inp)
      elif act_str == "gelu":
        return torch.nn.GELU(approximate="tanh")(inp)
      else:
        raise ValueError("Unsupported activation: {}".format(act_str))

    def gemm_dequant_test_helper(self, compute_type, weight_dtype, gemm_ms, gemm_ns, gemm_ks, rtol, atol, act_str="only_gemm", benchmark=False):
        assert weight_dtype == torch.int8 or weight_dtype == torch.quint4x2, "Weight must be quantized"      

        for gemm_k in gemm_ks:
            for gemm_n in gemm_ns:
                torch_weights_cpu = random_tensor((gemm_k, gemm_n), dtype=compute_type, device="cpu", mean=0, std=0.002)
                ref_torch_weights, processed_torch_weights, torch_weight_scales = self.symmetric_quantizer(torch_weights_cpu, weight_dtype)
                ref_torch_weights = self.unpack_packed_int4s(ref_torch_weights) if weight_dtype == torch.quint4x2 else ref_torch_weights
                ref_torch_weights = ref_torch_weights.to("cuda")
                processed_torch_weights = processed_torch_weights.to("cuda")
                torch_weight_scales = torch_weight_scales.to("cuda")
                torch_biases = random_tensor((gemm_n), dtype=compute_type, device="cuda", mean=0, std=0.1)


                for num_rows in gemm_ms:
                    torch_activations = torch.randn(size=(num_rows, gemm_k), dtype=compute_type, device="cuda")

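                    # Reference path: dequantize the weights in compute_type by broadcasting the
                    # per-output-channel scales, then use a plain torch.matmul.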
                    scales_unsqueezed = torch_weight_scales.unsqueeze(0)
                    casted_weights = ref_torch_weights.to(torch_activations.dtype)
                    dequantized_weights = torch.multiply(casted_weights, scales_unsqueezed)
                    if benchmark:
                      assert act_str == "only_gemm", "Benchmarks against cublas must use just GEMM."
                      torch.cuda.profiler.start()
                      times, results = self.bench(torch_activations, processed_torch_weights, torch_weight_scales, dequantized_weights, 200)
                      torch.cuda.profiler.stop()
                      times = times[0]
                      cublas_time = times[0].item()
                      ft_time = times[1].item()
                      ft_speedup = cublas_time / ft_time
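                      # CSV-style row: m, n, k, cuBLAS time, fused GEMM-dequant time, speedup.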
                      print("{},{},{},{},{},{}".format(num_rows, gemm_n, gemm_k, cublas_time, ft_time, ft_speedup))
                      reference_result = results[0]
                      ft_result = results[1]
                    else:
                      if act_str == "only_gemm":
                        reference_result = torch.matmul(torch_activations, dequantized_weights)
                        ft_result = self.fused_gemm_dq(torch_activations, processed_torch_weights, torch_weight_scales)
                      else:
                        reference_result = torch.matmul(torch_activations, dequantized_weights)
                        reference_result += torch_biases.unsqueeze(0)
                        reference_result = self.apply_act(reference_result, act_str)

                        ft_result = self.fused_gemm_dq_bias_act(torch_activations, processed_torch_weights, torch_weight_scales, torch_biases, act_str)

                    msg = "FC1 Failed on m={}, n={}, k={}".format(num_rows, gemm_n, gemm_k)
                    torch.testing.assert_close(ft_result, reference_result, rtol=rtol, atol=atol, msg=msg, check_dtype=False)         

    def test_fp16_int8_gemm(self):
        self.gemm_dequant_test_helper(torch.float16, torch.int8,
                                      gemm_ms = [256, 177, 195, 125, 66, 33, 8, 2, 1],
                                      gemm_ns = [1024, 2048, 4096],
                                      gemm_ks = [4096, 8192, 16384],
                                      rtol=0.001, atol=0.002)

    def test_fp16_int4_gemm(self):
        self.gemm_dequant_test_helper(torch.float16, torch.quint4x2,
                                      gemm_ms = [256, 177, 195, 125, 66, 33, 8, 2, 1],
                                      gemm_ns = [1024, 2048, 4096],
                                      gemm_ks = [4096, 8192, 16384],
                                      rtol=0.001, atol=0.002)
    
    def test_bf16_int8_gemm(self):
        self.gemm_dequant_test_helper(torch.bfloat16, torch.int8,
                                      gemm_ms = [256, 177, 195, 125, 66, 33, 8, 2, 1],
                                      gemm_ns = [1024, 2048, 4096],
                                      gemm_ks = [4096, 8192, 16384],
                                      rtol=0.01, atol=0.01)

    def test_bf16_int4_gemm(self):
        self.gemm_dequant_test_helper(torch.bfloat16, torch.quint4x2,
                                      gemm_ms = [256, 177, 195, 125, 66, 33, 8, 2, 1],
                                      gemm_ns = [1024, 2048, 4096],
                                      gemm_ks = [4096, 8192, 16384],
                                      rtol=0.01, atol=0.01)

    def test_fp16_int8_gemm_bias(self):
        self.gemm_dequant_test_helper(torch.float16, torch.int8,
                                      gemm_ms = [256],
                                      gemm_ns = [1024],
                                      gemm_ks = [8192],
                                      rtol=0.001, atol=0.002,
                                      act_str="identity")
  
    def test_fp16_int8_gemm_bias_relu(self):
        self.gemm_dequant_test_helper(torch.float16, torch.int8,
                                      gemm_ms = [256],
                                      gemm_ns = [1024],
                                      gemm_ks = [8192],
                                      rtol=0.001, atol=0.002,
                                      act_str="relu")

    def test_fp16_int8_gemm_bias_gelu(self):
        self.gemm_dequant_test_helper(torch.float16, torch.int8,
                                      gemm_ms = [256],
                                      gemm_ns = [1024],
                                      gemm_ks = [8192],
                                      rtol=0.001, atol=0.002,
                                      act_str="gelu")                                    

    def test_fp16_int8_gemm_bias_silu(self):
        self.gemm_dequant_test_helper(torch.float16, torch.int8,
                                      gemm_ms = [256],
                                      gemm_ns = [1024],
                                      gemm_ks = [8192],
                                      rtol=0.001, atol=0.002,
                                      act_str="silu")  

    def bench_helper(self, act_type, quant_type, rtol, atol):
      # Warm-up; bfloat16 is used here since it seems to reliably dispatch to cuBLAS.
      x = random_tensor([20480, 20480], torch.bfloat16, device="cuda")
      warm_iters = 30
      for _ in range(warm_iters):
        _ = x @ x

      # Candidate power-of-two batch sizes; currently unused by the benchmark call below.
      m_shapes = 2 ** torch.arange(0, 12)

      self.gemm_dequant_test_helper(act_type, quant_type,
                                    gemm_ms = [128],
                                    gemm_ns = [1536],
                                    gemm_ks = [12288],
                                    rtol=rtol, atol=atol, benchmark=True)

    @unittest.skip("This is a benchmark so don't run by default")
    def test_fp16_int8_cublas(self):
      self.bench_helper(torch.float16, torch.int8, 1e-3, 0.002)

    
    @unittest.skip("This is a benchmark so don't run by default")
    def test_bf16_int8_cublas(self):
      self.bench_helper(torch.bfloat16, torch.int8, 1e-2, 1e-2)

    @unittest.skip("This is a benchmark so don't run by default")
    def test_fp16_int4_cublas(self):
      self.bench_helper(torch.float16, torch.quint4x2, 1e-3, 0.002)

    
    @unittest.skip("This is a benchmark so don't run by default")
    def test_bf16_int4_cublas(self):
      self.bench_helper(torch.bfloat16, torch.quint4x2, 1e-2, 1e-2)

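# Run the full suite with `python th_gemm_dequantize.py`, or a single case with e.g.
# `python th_gemm_dequantize.py TestGemmDequantize.test_fp16_int8_gemm`. The shared libraries
# loaded in setUp must exist at the relative "lib/" paths above.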
if __name__ == '__main__':
    unittest.main()