Commit fe314160 authored by Casper Hansen

Refactor modules, create separate GEMM and GEMV

parent ef6b60e2
@@ -7,10 +7,11 @@ import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from awq.modules.qlinear import WQLinear_GEMM
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
@@ -76,7 +77,7 @@ class BaseAWQForCausalLM(nn.Module):
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear = WQLinear.from_linear(
q_linear = WQLinear_GEMM.from_linear(
module,
self.quant_config['w_bit'],
self.quant_config['q_group_size'],
@@ -351,7 +352,7 @@ class BaseAWQForCausalLM(nn.Module):
# Replace nn.Linear with WQLinear
for name, module in named_linears.items():
q_linear = WQLinear.from_linear(
q_linear = WQLinear_GEMM.from_linear(
module, quant_config['w_bit'], quant_config['q_group_size'], True)
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
......
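Editor's note, not part of the commit: both call sites above keep the old WQLinear.from_linear calling convention, so WQLinear_GEMM is a drop-in rename here. A minimal sketch of the init_only path used when loading an already-quantized checkpoint follows; it assumes the import path shown in this diff, that WQLinear_GEMM.from_linear keeps the same (linear, w_bit, group_size, init_only, ...) signature as the GEMV version further down, that the awq_inference_engine extension is importable, and purely illustrative layer sizes.

import torch.nn as nn
from awq.modules.qlinear import WQLinear_GEMM  # import path taken from this diff

# Illustrative sizes; any in_features divisible by group_size works.
linear = nn.Linear(4096, 4096, bias=False)

# init_only=True (the fourth positional argument, as in base.py above) only
# allocates the empty packed buffers (qweight/qzeros/scales, as named in the
# GEMV class below) so a quantized state dict can be loaded into them later;
# no real quantization happens on this path.
q_linear = WQLinear_GEMM.from_linear(linear, 4, 128, True)
print(q_linear)  # repr uses extra_repr: in_features, out_features, bias, w_bit, group_size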
@@ -67,11 +67,11 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
import torch
from typing import List, Tuple
from awq.quantize.qmodule import WQLinear
from awq.modules.qlinear import WQLinear_GEMM
from awq.utils.utils import set_module_name
from awq.modules.fused_mlp import QuantLlamaMLP
from awq.modules.fused_norm import FTLlamaRMSNorm
from awq.modules.fused_attn import QuantLlamaAttention
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FTLlamaRMSNorm
from awq.modules.fused.attn import QuantLlamaAttention
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
class LlamaFuser:
@@ -95,7 +97,7 @@ class LlamaFuser:
def fuse_attention(self):
for name, module in self.attention_modules:
qkv_layer: WQLinear = self._fuse_qkv(module)
qkv_layer: WQLinear_GEMM = self._fuse_qkv(module)
attn = QuantLlamaAttention(
module.hidden_size,
module.num_heads,
@@ -113,7 +113,7 @@ class LlamaFuser:
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
# create module
qkv_layer = WQLinear(
qkv_layer = WQLinear_GEMM(
q_proj.w_bit,
q_proj.group_size,
q_proj.in_features,
......
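The rest of _fuse_qkv is collapsed in the hunk above, so the following is only a rough sketch of the shell it presumably builds: a single WQLinear_GEMM whose out_features is the sum of the three projections, using the constructor signature shown in awq/modules/qlinear.py below. The helper name, the summed out_features, and the device choice are the editor's guesses; copying the packed qweight/qzeros/scales buffers is deliberately omitted because that code is not shown in this diff.

from awq.modules.qlinear import WQLinear_GEMM

def make_fused_qkv_shell(q_proj, k_proj, v_proj, bias=None):
    # Hypothetical helper (not from this commit): allocate the fused QKV layer
    # with the constructor arguments (w_bit, group_size, in_features,
    # out_features, bias, dev) shown in the qlinear diff below. Filling in the
    # packed buffers from the three projections is left out here.
    return WQLinear_GEMM(
        q_proj.w_bit,
        q_proj.group_size,
        q_proj.in_features,
        q_proj.out_features + k_proj.out_features + v_proj.out_features,
        bias is not None,
        q_proj.qweight.device,  # assumes the usual qweight buffer on the layer
    )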
@@ -67,7 +67,7 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
from typing import List, Tuple
from awq.utils.utils import set_module_name
from awq.modules.fused_mlp import QuantMPTMLP
from awq.modules.fused.mlp import QuantMPTMLP
class MptFuser:
def __init__(self, model):
......
from .fused_norm import *
from .fused_attn import *
from .fused_mlp import *
import torch.nn as nn
class ScaledActivation(nn.Module):
def __init__(self, module, scales):
super().__init__()
self.act = module
self.scales = nn.Parameter(scales.data)
def forward(self, x):
return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
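For reference, a small usage sketch of the relocated ScaledActivation (the wrapper itself is unchanged, only its module path moves to awq.modules.act). The shapes are made up for illustration; only the import path and the divide-by-scales behaviour come from the code above.

import torch
import torch.nn as nn
from awq.modules.act import ScaledActivation  # new location per this commit

# Wrap an activation so its output is divided per-channel by the stored
# scales, matching the forward() shown above.
act = ScaledActivation(nn.GELU(), torch.ones(4096))
x = torch.randn(2, 16, 4096)
y = act(x)  # same shape as x, each channel divided by its scale
assert y.shape == x.shape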
@@ -4,17 +4,24 @@ import torch.nn as nn
import awq_inference_engine # with CUDA kernels
class ScaledActivation(nn.Module):
def __init__(self, module, scales):
super().__init__()
self.act = module
self.scales = nn.Parameter(scales.data)
def forward(self, x):
return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
def make_divisible(c, divisor):
return (c + divisor - 1) // divisor
def calculate_zeros_width(in_features, group_size=128, pack_num=8):
if group_size >= 128:
size_multiplier = 1
elif group_size == 64:
size_multiplier = 2
elif group_size == 32:
size_multiplier = 4
else:
raise NotImplementedError
base_width = make_divisible(in_features // group_size, pack_num)
base_width = make_divisible(base_width, size_multiplier) * size_multiplier
return base_width
class WQLinear(nn.Module):
class WQLinear_GEMM(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
@@ -25,6 +32,7 @@ class WQLinear(nn.Module):
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
# quick sanity check (make sure of alignment)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
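Aside from the alignment asserts, the calculate_zeros_width helper added above decides how wide the GEMV qzeros and scales buffers are (see WQLinear_GEMV further down). A few hand-computed values, offered only as an editor's sanity check of the code as shown, assuming the helper is importable from awq.modules.qlinear:

from awq.modules.qlinear import calculate_zeros_width  # assumed import path

# For 4-bit weights pack_num is 32 // 4 = 8. The helper returns
# ceil(groups / pack_num) int32 columns, where groups = in_features // group_size,
# then pads that up to a multiple of size_multiplier (1, 2, or 4 by group_size).
print(calculate_zeros_width(4096, group_size=128))  # 32 groups -> 4 int32 columns
print(calculate_zeros_width(4096, group_size=64))   # 64 groups -> 8 int32 columns
# The GEMV scales buffer is this width times pack_num fp16 values per row,
# i.e. 32 and 64 columns respectively: one scale per group, plus padding.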
@@ -74,7 +82,7 @@ class WQLinear(nn.Module):
zeros = zeros.to(dtype=torch.int32)
qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=zeros.device)
for col in range(zeros.shape[1] // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
else:
@@ -97,3 +105,100 @@ class WQLinear(nn.Module):
return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
)
class WQLinear_GEMV(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit is supported for now.")
self.in_features = in_features
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
self.split_k_iters = 8
# quick sanity check (make sure of alignment)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
pack_num = (32 // self.w_bit)
self.register_buffer('qweight', torch.zeros((out_features, in_features // pack_num), dtype=torch.int32, device=dev))
self.register_buffer('qzeros', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size)), dtype=torch.int32, device=dev))
self.register_buffer('scales', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size) * pack_num), dtype=torch.float16, device=dev))
if bias:
self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, device=dev))
else:
self.bias = None
@classmethod
def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None):
awq_linear = cls(w_bit, group_size, linear.in_features, linear.out_features, linear.bias is not None, linear.weight.device)
if init_only: # just prepare for loading sd
return awq_linear
# need scales and zeros info for real quantization
assert scales is not None and zeros is not None
scale_zeros = zeros * scales
pack_num = 32 // awq_linear.w_bit
qscales = torch.zeros(
(scales.shape[0], calculate_zeros_width(linear.in_features, group_size) * pack_num),
dtype=torch.float16,
device=scales.device
)
qscales[:, :scales.shape[1]] = scales
awq_linear.scales = qscales
if linear.bias is not None:
awq_linear.bias = linear.bias.clone().half()
intweight = []
for idx in range(awq_linear.in_features):
intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[:, idx // group_size]) / awq_linear.scales[:, idx // group_size]).to(torch.int)[:, None])
intweight = torch.cat(intweight, dim=1)
intweight = intweight.to(dtype=torch.int32)
qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=intweight.device)
for col in range(intweight.shape[1] // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 1, 2, 3, 4, 5, 6, 7]
else:
raise NotImplementedError("Only 4-bit is supported for now.")
for i in range(pack_num):
qweight_col = intweight[:, col * pack_num + order_map[i]]
qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
awq_linear.qweight = qweight
zeros = zeros.to(dtype=torch.int32)
qzeros = torch.zeros(
(zeros.shape[0], calculate_zeros_width(linear.in_features, group_size)),
dtype=torch.int32,
device=zeros.device,
)
for col in range((zeros.shape[1] + pack_num - 1) // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 1, 2, 3, 4, 5, 6, 7]
else:
raise NotImplementedError("Only 4-bit is supported for now.")
for i in range(pack_num):
if col * pack_num + order_map[i] >= zeros.shape[1]:
continue
qzero_col = zeros[:, col * pack_num + order_map[i]]
qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
awq_linear.qzeros = qzeros
return awq_linear
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features, )
out = awq_inference_engine.gemv_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.group_size)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
)
\ No newline at end of file
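One functional difference worth calling out: the GEMM class packs eight 4-bit values per int32 column in the interleaved order [0, 2, 4, 6, 1, 3, 5, 7] (see the packing loop in the @@ -74,7 +82,7 @@ hunk above), while the new GEMV class packs them sequentially. The toy sketch below is pure Python, written by the editor and not part of the commit; it only illustrates what those order_map loops produce for one column, and says nothing about what the CUDA kernels expect.

# Toy illustration: pack eight 4-bit values into one int32 column the way the
# order_map loops above do.
def pack_int4(values, order_map, w_bit=4):
    word = 0
    for i in range(len(order_map)):
        word |= values[order_map[i]] << (i * w_bit)
    return word

vals = [1, 2, 3, 4, 5, 6, 7, 8]
print(hex(pack_int4(vals, [0, 1, 2, 3, 4, 5, 6, 7])))  # GEMV order -> 0x87654321
print(hex(pack_int4(vals, [0, 2, 4, 6, 1, 3, 5, 7])))  # GEMM order -> 0x86427531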
@@ -7,7 +7,7 @@ from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.activations import NewGELUActivation
from .qmodule import ScaledActivation
from awq.modules.act import ScaledActivation
from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name
__all__ = ["auto_scale_block", "apply_scale"]
......