import torch

try:
    import awq_v2_ext  # with CUDA kernels (AutoAWQ_kernels)

    AWQ_INSTALLED = True
except ImportError:
    AWQ_INSTALLED = False


def make_divisible(c, divisor):
    return (c + divisor - 1) // divisor


def calculate_zeros_width(in_features, group_size=128, pack_num=8):
    if group_size >= 128:
        size_multiplier = 1
    elif group_size == 64:
        size_multiplier = 2
    elif group_size == 32:
        size_multiplier = 4
    else:
        raise NotImplementedError

    base_width = make_divisible(in_features // group_size, pack_num)
    base_width = make_divisible(base_width, size_multiplier) * size_multiplier
    return base_width


def pack_intweight(unpacked_qweight, interleave, kstride):
    # unpacked_qweight: [N, K]
    N = unpacked_qweight.shape[0]
    K = unpacked_qweight.shape[1]

    Packed_Kernel = unpacked_qweight.cpu().numpy().reshape(N, K // 32, 32)

    # np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...]
    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4)
    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 32)

    # reorder each 8 weights for fast dequantization
    # [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]
    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 8)
    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3)
    Packed_Kernel = Packed_Kernel.reshape(N, K)

    # interleaving every four rows
    Packed_Kernel = Packed_Kernel.reshape(
        N // interleave, interleave, K // kstride, kstride
    )
    # N // 4, K // 64, 4, 64
    Packed_Kernel = Packed_Kernel.transpose(0, 2, 1, 3)
    Packed_Kernel = Packed_Kernel.reshape(
        N // interleave, K // kstride, kstride, interleave
    )
    # Packing -> (N // 4, K // 64, 64)
    Packed_Kernel = (
        Packed_Kernel[..., 0]
        | (Packed_Kernel[..., 1] << 4)
        | (Packed_Kernel[..., 2] << 8)
        | (Packed_Kernel[..., 3] << 12)
    )
    # reshape to (N // 4, K), FP16 format
    Packed_Kernel = Packed_Kernel.reshape(N // interleave, K)
    qweight = (
        torch.tensor(Packed_Kernel.astype("int16"))
        .to(unpacked_qweight.device)
        .contiguous()
    )
    return qweight


class WQLinear_GEMVFast(torch.nn.Module):
    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else in_features
        self.split_k_iters = 8
        self.interleave = 4

        # quick sanity check (make sure alignment requirements hold)
        assert self.in_features % self.group_size == 0
        assert out_features % (32 // self.w_bit) == 0
        pack_num = 32 // self.w_bit
        int16_pack_num = 16 // self.w_bit

        assert out_features % (self.interleave) == 0
        self.register_buffer(
            "qweight",
            torch.zeros(
                (
                    out_features // self.interleave,
                    in_features // int16_pack_num * self.interleave,
                ),
                dtype=torch.int16,
                device=dev,
            ),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (
                    calculate_zeros_width(in_features, self.group_size) * pack_num,
                    out_features,
                ),
                dtype=torch.float16,
                device=dev,
            ),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros(
                (
                    calculate_zeros_width(in_features, self.group_size) * pack_num,
                    out_features,
                ),
                dtype=torch.float16,
                device=dev,
            ),
        )

        if bias:
            self.register_buffer(
                "bias", torch.zeros((out_features), dtype=torch.float16, device=dev)
            )
        else:
            self.bias = None

    @classmethod
    def from_linear(
        cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
    ):
        awq_linear = cls(
            w_bit,
            group_size,
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
        )
        if init_only:
            return awq_linear

        # need scales and zeros info for real quantization
        assert scales is not None and zeros is not None
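        # Each group of `group_size` input channels shares one fp16 scale and
        # zero-point. Pre-multiplying the zero-point by the scale
        # (scale_zeros = zeros * scales) lets the per-column rounding below compute
        # q = round(w / scale + zero); dequantization is then w ~ q * scale - zero * scale,
        # which is why qzeros later stores -(scale * zero) rather than the raw zero-point.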
        scale_zeros = zeros * scales
        pack_num = 32 // awq_linear.w_bit
        qscales = torch.zeros(
            (
                scales.shape[0],
                calculate_zeros_width(linear.in_features, group_size) * pack_num,
            ),
            dtype=torch.float16,
            device=scales.device,
        )
        qscales[:, : scales.shape[1]] = scales
        # awq_linear.scales = scales.clone().half()
        awq_linear.scales = qscales.transpose(1, 0).contiguous()
        if linear.bias is not None:
            awq_linear.bias = linear.bias.clone().half()

        intweight = []
        for idx in range(awq_linear.in_features):
            intweight.append(
                torch.round(
                    (linear.weight.data[:, idx] + scale_zeros[:, idx // group_size])
                    / qscales[:, idx // group_size]
                ).to(torch.int)[:, None]
            )
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.to(dtype=torch.int32)
        awq_linear.qweight = pack_intweight(
            intweight.contiguous(), interleave=4, kstride=64
        )

        zeros = zeros.to(dtype=torch.int32)
        qzeros = torch.zeros_like(qscales)
        qzeros[:, : scales.shape[1]] = -(
            qscales[:, : scales.shape[1]] * (zeros.to(torch.float32))
        ).to(torch.float16)
        awq_linear.qzeros = qzeros.transpose(1, 0).contiguous()

        return awq_linear

    @torch.no_grad()
    def forward(self, x):
        inputs = x
        batch_size, n_tokens, _ = inputs.shape
        if batch_size < 8 and n_tokens == 1:
            # single-token decode: dedicated GEMV kernel
            out = awq_v2_ext.gemv_forward_cuda_decode(
                inputs,
                self.qweight,
                self.scales,
                self.qzeros,
                inputs.numel() // inputs.shape[-1],
                self.out_features,
                self.in_features,
                self.group_size,
            )
        else:
            # prefill / larger batches: GEMM kernel
            out = awq_v2_ext.gemm_forward_cuda_prefill(
                inputs, self.qweight, self.scales, self.qzeros
            )
        out = out + self.bias if self.bias is not None else out

        return out
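

# Minimal usage sketch (illustrative only; not part of the module). It assumes the
# awq_v2_ext kernels are installed and a CUDA device is available. The per-group
# scales/zero-points below come from plain min/max quantization of the weights; a
# real AWQ pipeline would compute activation-aware scales first. As in that pipeline,
# the layer's weight is replaced by its pseudo-quantized (round-tripped) value before
# calling from_linear, so the rounding there recovers integers in [0, 2**w_bit - 1].
if __name__ == "__main__" and AWQ_INSTALLED and torch.cuda.is_available():
    torch.manual_seed(0)
    w_bit, group_size = 4, 128
    linear = torch.nn.Linear(
        4096, 4096, bias=False, device="cuda", dtype=torch.float16
    )

    # naive per-group min/max quantization params: [out_features, in_features // group_size]
    max_int = 2**w_bit - 1
    w = linear.weight.data.view(linear.out_features, -1, group_size)
    w_min, w_max = w.amin(dim=-1), w.amax(dim=-1)
    scales = (w_max - w_min).clamp(min=1e-5) / max_int
    zeros = (-torch.round(w_min / scales)).clamp(0, max_int)

    # pseudo-quantize the weight (clamp onto the 4-bit grid, then dequantize in place)
    q = torch.clamp(torch.round(w / scales[..., None]) + zeros[..., None], 0, max_int)
    linear.weight.data = ((q - zeros[..., None]) * scales[..., None]).reshape(
        linear.weight.shape
    )

    q_linear = WQLinear_GEMVFast.from_linear(
        linear, w_bit, group_size, scales=scales, zeros=zeros
    )
    x = torch.randn(1, 1, linear.in_features, device="cuda", dtype=torch.float16)
    print(q_linear(x).shape)  # expected: torch.Size([1, 1, 4096])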