import copy
import unittest

import autogptq_marlin_cuda
import torch
import torch.nn as nn

from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import QuantLinear as CudaOldQuantLinear
from auto_gptq.nn_modules.qlinear.qlinear_marlin import QuantLinear as MarlinQuantLinear
from auto_gptq.nn_modules.qlinear.qlinear_marlin import _get_perms, dequantize_weight


def gen_quant4(k, n, groupsize=-1):
    maxq = 2**4 - 1
    w = torch.randn((k, n), dtype=torch.half, device="cpu")

    original_w = w.clone()

    if groupsize != -1:
        w = w.reshape((-1, groupsize, n))
        w = w.permute(1, 0, 2)
        w = w.reshape((groupsize, -1))

    s = torch.max(torch.abs(w), 0, keepdim=True)[0]
    s *= 2 / maxq

    # Quantize.
    w = torch.round(w / s).int()

    # Unsigned storage.
    w += (maxq + 1) // 2
    w = torch.clamp(w, 0, maxq)

    # Dequantize.
    ref = (w - (maxq + 1) // 2).half() * s

    if groupsize != -1:

        def reshape(w):
            w = w.reshape((groupsize, -1, n))
            w = w.permute(1, 0, 2)
            w = w.reshape((k, n)).contiguous()
            return w

        ref = reshape(ref)
        w = reshape(w)

    s = s.reshape((-1, n)).contiguous()
    linear = nn.Linear(k, n, bias=False)
    linear.weight.data = ref.t()

    return original_w, linear, s


original_w, linear, s = gen_quant4(64, 128)


class TestRepacking(unittest.TestCase):
    def test_marlin_fast_repacking(self):
        k = 2048
        n = 1024
        m = 5
        group_size = 128

        _, linear, s = gen_quant4(k, n, group_size)

        cuda_old_linear = CudaOldQuantLinear(
            bits=4,
            group_size=group_size,
            infeatures=k,
            outfeatures=n,
            bias=False,
        )

        zeros = torch.full((k // group_size, n), 8, dtype=torch.int32)
        cuda_old_linear.pack(linear, s.T, zeros.T, g_idx=None)

        # Adapted from utils.marlin_utils.convert_to_marlin
        dequantized_weight, dequantized_qzeros = dequantize_weight(cuda_old_linear)
        dequantized_weight = dequantized_weight.to(torch.float16)

        self.assertTrue(torch.all(dequantized_qzeros == 8))

        linear_module = torch.nn.Linear(
            in_features=k,
            out_features=n,
            bias=False,
            dtype=torch.float16,
            device="cuda",
        )
        # Not using dequantized_weight here to avoid approximation error.
        linear_module.weight.data.copy_(linear.weight.data)

        # Create new linear method and copy to model.
        marlin_linear = MarlinQuantLinear(
            bits=4,
            group_size=group_size,
            infeatures=k,
            outfeatures=n,
            bias=False,
            trainable=False,
        )
        marlin_linear.pack(
            linear_module.to("cuda"),
            scales=copy.deepcopy(cuda_old_linear.scales.data.t()).to("cuda"),
        )

        inp = torch.rand(m, k, dtype=torch.float16, device="cuda")

        cuda_old_linear = cuda_old_linear.to("cuda")
        marlin_linear = marlin_linear.to("cuda")

        with torch.no_grad():
            res_cuda_old = cuda_old_linear(inp)
            res_marlin = marlin_linear(inp)

        # The Marlin output should closely match the reference CUDA (old) kernel output.
        reldiff = (res_cuda_old - res_marlin).abs() / (res_cuda_old.abs() + 1e-12)
        self.assertTrue(torch.mean(reldiff) < 4e-3)

        # Repacking the GPTQ qweight directly should reproduce Marlin's packed weights.
        weight_repacked = autogptq_marlin_cuda.gptq_repack(cuda_old_linear.qweight)
        self.assertTrue(torch.allclose(weight_repacked, marlin_linear.B))

        # Permute the scales the same way Marlin does and compare against marlin_linear.s.
        _, _scale_perm, _scale_perm_single = _get_perms()

        s = cuda_old_linear.scales.data.clone()
        if group_size != k:
            s = s.reshape((1, -1))
            s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
        else:
            s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
        s = s.reshape((-1, n)).contiguous()

        self.assertTrue(torch.allclose(s, marlin_linear.s))
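

# Optional entry point (an addition, not part of the original test file): lets the
# repacking test be run directly as a script, in addition to `pytest` or
# `python -m unittest`.
if __name__ == "__main__":
    unittest.main()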