# flake8: noqa
import unittest

import torch


def random_tensor(shape, dtype, device, mean=0, std=1):
    return torch.empty(shape, dtype=dtype, device=device).normal_(mean, std)


class TestGemmDequantize(unittest.TestCase):

    def setUp(self) -> None:
        torch.classes.load_library('lib/libth_transformer.so')
        torch.classes.load_library('lib/libgemm_dq_unit_ops.so')
        self.unpack_packed_int4s = torch.ops.turbomind.unpack_int4_packed_tensor_to_int8
        self.pack_int4s = torch.ops.turbomind.pack_int8_tensor_to_packed_int4
        self.fused_gemm_dq = torch.ops.gemm_dq_unit_ops.fused_gemm_dq
        self.fused_gemm_dq_bias_act = torch.ops.gemm_dq_unit_ops.fused_gemm_dq_bias_act
        self.bench = torch.ops.gemm_dq_unit_ops.benchmark_against_cublas_fp
        self.preprocess_weights_for_mixed_gemm = torch.ops.turbomind.preprocess_weights_for_mixed_gemm
        self.symmetric_quantizer = torch.ops.turbomind._symmetric_quantize_last_axis_of_batched_matrix

        torch.manual_seed(734876213)

    def dequantize_test_helper(self, weight_type, quant_type):
        assert quant_type == torch.int8 or quant_type == torch.quint4x2

        lower_bound = -128 if quant_type == torch.int8 else -8
        upper_bound = 127 if quant_type == torch.int8 else 7

        m, n, k = 64, 128, 64
        weights = torch.randint(lower_bound,
                                upper_bound, [k, n],
                                dtype=torch.int8,
                                device='cpu')

        packed_weight = self.pack_int4s(
            weights) if quant_type == torch.quint4x2 else weights
        cuda_weights = self.preprocess_weights_for_mixed_gemm(
            packed_weight, quant_type).to('cuda')
        weights = weights.to('cuda')

        act = torch.eye(m, dtype=weight_type, device='cuda')
        scales = torch.ones([n], dtype=weight_type, device='cuda')

        actual = self.fused_gemm_dq(act, cuda_weights, scales)
        torch.testing.assert_close(actual,
                                   weights,
                                   atol=0,
                                   rtol=0,
                                   check_dtype=False)

    def test_fp16_int8_dequantize(self):
        self.dequantize_test_helper(torch.float16, torch.int8)

    def test_bf16_int8_dequantize(self):
        self.dequantize_test_helper(torch.bfloat16, torch.int8)

    def test_fp16_int4_dequantize(self):
        self.dequantize_test_helper(torch.float16, torch.quint4x2)

    def test_bf16_int4_dequantize(self):
        self.dequantize_test_helper(torch.bfloat16, torch.quint4x2)

    def apply_act(self, inp, act_str):
        if act_str == 'identity':
            return inp
        elif act_str == 'silu':
            return torch.nn.SiLU()(inp)
        elif act_str == 'relu':
            return torch.nn.ReLU()(inp)
        elif act_str == 'gelu':
            return torch.nn.GELU(approximate='tanh')(inp)
        else:
            assert False, 'Unsupported activation'

    def gemm_dequant_test_helper(self,
                                 compute_type,
                                 weight_dtype,
                                 gemm_ms,
                                 gemm_ns,
                                 gemm_ks,
                                 rtol,
                                 atol,
                                 act_str='only_gemm',
                                 benchmark=False):
        assert weight_dtype == torch.int8 or weight_dtype == torch.quint4x2, 'Weight must be quantized'

        for gemm_k in gemm_ks:
            for gemm_n in gemm_ns:
                torch_weights_cpu = random_tensor((gemm_k, gemm_n),
                                                  dtype=compute_type,
                                                  device='cpu',
                                                  mean=0,
                                                  std=0.002)
                ref_torch_weights, processed_torch_weights, torch_weight_scales = self.symmetric_quantizer(
                    torch_weights_cpu, weight_dtype)
                ref_torch_weights = self.unpack_packed_int4s(
                    ref_torch_weights
                ) if weight_dtype == torch.quint4x2 else ref_torch_weights
                ref_torch_weights = ref_torch_weights.to('cuda')
                processed_torch_weights = processed_torch_weights.to('cuda')
                torch_weight_scales = torch_weight_scales.to('cuda')

                torch_biases = random_tensor((gemm_n),
                                             dtype=compute_type,
                                             device='cuda',
                                             mean=0,
                                             std=0.1)

                for num_rows in gemm_ms:
                    torch_activations = torch.randn(size=(num_rows, gemm_k),
                                                    dtype=compute_type,
                                                    device='cuda')

                    scales_unsqueezed = torch_weight_scales.unsqueeze(0)
                    casted_weights = ref_torch_weights.to(
                        torch_activations.dtype)
                    dequantized_weights = torch.multiply(
                        casted_weights, scales_unsqueezed)

                    if benchmark:
                        assert act_str == 'only_gemm', 'Benchmarks against cublas must use just GEMM.'
                        torch.cuda.profiler.start()
                        times, results = self.bench(torch_activations,
                                                    processed_torch_weights,
                                                    torch_weight_scales,
                                                    dequantized_weights, 200)
                        torch.cuda.profiler.stop()

                        times = times[0]
                        cublas_time = times[0].item()
                        ft_time = times[1].item()
                        ft_speedup = cublas_time / ft_time
                        print('{},{},{},{},{},{}'.format(
                            num_rows, gemm_n, gemm_k, cublas_time, ft_time,
                            ft_speedup))

                        reference_result = results[0]
                        ft_result = results[1]
                    else:
                        if act_str == 'only_gemm':
                            reference_result = torch.matmul(
                                torch_activations, dequantized_weights)
                            ft_result = self.fused_gemm_dq(
                                torch_activations, processed_torch_weights,
                                torch_weight_scales)
                        else:
                            reference_result = torch.matmul(
                                torch_activations, dequantized_weights)
                            reference_result += torch_biases.unsqueeze(0)
                            reference_result = self.apply_act(
                                reference_result, act_str)

                            ft_result = self.fused_gemm_dq_bias_act(
                                torch_activations, processed_torch_weights,
                                torch_weight_scales, torch_biases, act_str)

                    msg = 'FC1 Failed on m={}, n={}, k={}'.format(
                        num_rows, gemm_n, gemm_k)
                    torch.testing.assert_close(ft_result,
                                               reference_result,
                                               rtol=rtol,
                                               atol=atol,
                                               msg=msg,
                                               check_dtype=False)
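
    # Illustrative sketch (not exercised by the tests): the reference path
    # above assumes the weights come from a symmetric, per-output-channel
    # quantizer, i.e. dequantization is just `int_weight * scale` with one
    # scale per column. A plain-PyTorch approximation of that scheme for int8
    # could look like the hypothetical helper below; the real
    # _symmetric_quantize_last_axis_of_batched_matrix op additionally returns
    # weights preprocessed for the fused kernel, so its exact rounding and
    # clamping behavior may differ.
    @staticmethod
    def _reference_symmetric_quantize_int8(weights):
        # One symmetric scale per column (last axis), mapping the largest
        # magnitude in that column to the int8 extreme 127.
        scales = weights.abs().amax(dim=0) / 127.0
        quantized = torch.round(weights / scales).clamp(-128, 127).to(torch.int8)
        return quantized, scales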

    def test_fp16_int8_gemm(self):
        self.gemm_dequant_test_helper(
            torch.float16,
            torch.int8,
            gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
            gemm_ns=[1024, 2048, 4096],
            gemm_ks=[4096, 8192, 16384],
            rtol=0.001,
            atol=0.002)

    def test_fp16_int4_gemm(self):
        self.gemm_dequant_test_helper(
            torch.float16,
            torch.quint4x2,
            gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
            gemm_ns=[1024, 2048, 4096],
            gemm_ks=[4096, 8192, 16384],
            rtol=0.001,
            atol=0.002)

    def test_bf16_int8_gemm(self):
        self.gemm_dequant_test_helper(
            torch.bfloat16,
            torch.int8,
            gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
            gemm_ns=[1024, 2048, 4096],
            gemm_ks=[4096, 8192, 16384],
            rtol=0.01,
            atol=0.01)

    def test_bf16_int4_gemm(self):
        self.gemm_dequant_test_helper(
            torch.bfloat16,
            torch.quint4x2,
            gemm_ms=[256, 177, 195, 125, 66, 33, 8, 2, 1],
            gemm_ns=[1024, 2048, 4096],
            gemm_ks=[4096, 8192, 16384],
            rtol=0.01,
            atol=0.01)

    def test_fp16_int8_gemm_bias(self):
        self.gemm_dequant_test_helper(torch.float16,
                                      torch.int8,
                                      gemm_ms=[256],
                                      gemm_ns=[1024],
                                      gemm_ks=[8192],
                                      rtol=0.001,
                                      atol=0.002,
                                      act_str='identity')

    def test_fp16_int8_gemm_bias_relu(self):
        self.gemm_dequant_test_helper(torch.float16,
                                      torch.int8,
                                      gemm_ms=[256],
                                      gemm_ns=[1024],
                                      gemm_ks=[8192],
                                      rtol=0.001,
                                      atol=0.002,
                                      act_str='relu')

    def test_fp16_int8_gemm_bias_gelu(self):
        self.gemm_dequant_test_helper(torch.float16,
                                      torch.int8,
                                      gemm_ms=[256],
                                      gemm_ns=[1024],
                                      gemm_ks=[8192],
                                      rtol=0.001,
                                      atol=0.002,
                                      act_str='gelu')

    def test_fp16_int8_gemm_bias_silu(self):
        self.gemm_dequant_test_helper(torch.float16,
                                      torch.int8,
                                      gemm_ms=[256],
                                      gemm_ns=[1024],
                                      gemm_ks=[8192],
                                      rtol=0.001,
                                      atol=0.002,
                                      act_str='silu')

    def bench_helper(self, act_type, quant_type, rtol, atol):
        # Warm up; using bfloat16 here since it seems to reliably use cublas.
        x = random_tensor([20480, 20480], torch.bfloat16, device='cuda')
        warm_iters = 30
        for iter in range(warm_iters):
            res = x @ x

        m_shapes = torch.arange(0, 12)
        m_shapes = 2**m_shapes

        self.gemm_dequant_test_helper(act_type,
                                      quant_type,
                                      gemm_ms=[128],
                                      gemm_ns=[1536],
                                      gemm_ks=[12288],
                                      rtol=rtol,
                                      atol=atol,
                                      benchmark=True)
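
    # Illustrative sketch (not exercised by the tests): pack_int4s /
    # unpack_packed_int4s above are assumed to store two signed 4-bit values
    # per int8 byte. A plain-PyTorch packing along the last axis, assuming an
    # even element count and that the even-indexed element goes into the low
    # nibble, could look like the hypothetical helper below; the real op's
    # nibble order and layout may differ.
    @staticmethod
    def _reference_pack_int4(int8_vals):
        # int8_vals is expected to hold values in [-8, 7]; consecutive pairs
        # along the last axis are merged into a single byte.
        low = (int8_vals[..., 0::2] & 0x0F).to(torch.uint8)
        high = (int8_vals[..., 1::2] & 0x0F).to(torch.uint8) << 4
        return (low | high).to(torch.int8)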

    @unittest.skip("This is a benchmark so don't run by default")
    def test_fp16_int8_cublas(self):
        self.bench_helper(torch.float16, torch.int8, 1e-3, 0.002)

    @unittest.skip("This is a benchmark so don't run by default")
    def test_bf16_int8_cublas(self):
        self.bench_helper(torch.bfloat16, torch.int8, 1e-2, 1e-2)

    @unittest.skip("This is a benchmark so don't run by default")
    def test_fp16_int4_cublas(self):
        self.bench_helper(torch.float16, torch.quint4x2, 1e-3, 0.002)

    @unittest.skip("This is a benchmark so don't run by default")
    def test_bf16_int4_cublas(self):
        self.bench_helper(torch.bfloat16, torch.quint4x2, 1e-2, 1e-2)


if __name__ == '__main__':
    unittest.main()
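
# Note: the cublas benchmark cases above are skipped by default. To run one,
# comment out its @unittest.skip decorator and select it directly, e.g.
# (assuming this file is importable as test_gemm_dequantize):
#
#   python -m unittest test_gemm_dequantize.TestGemmDequantize.test_fp16_int8_cublas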