Commit 5db43a7f authored by Casper Hansen

Implement GEMM/GEMV in quantize function and fused modules

parent 9b2946b6
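For context, here is a minimal, hypothetical usage sketch of the `version` switch this commit introduces. The model path, save path, and the exact `AutoAWQForCausalLM.from_pretrained` / `model.quantize` call shape are assumptions about the AutoAWQ API of this period, not something shown in the diff below; only the `quant_config` keys (`zero_point`, `q_group_size`, `w_bit`, `version`) come from the changed code.

```python
# Hypothetical usage sketch: quantize with an explicit kernel version.
# "GEMM" (the default) and "GEMV" are the two values this commit handles;
# the model path and save path are placeholders.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "facebook/opt-125m"  # placeholder model
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMV"}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# quant_config["version"] decides whether WQLinear_GEMM or WQLinear_GEMV
# modules are created during quantization (see the _awq_quant hunk below).
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(model_path + "-awq")  # placeholder output directory
```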
@@ -8,11 +8,11 @@ import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from awq.modules.qlinear import WQLinear_GEMM, WQLinear_GEMV
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
@@ -43,6 +43,11 @@ class BaseAWQForCausalLM(nn.Module):
auto_scale=True, mse_range=True, run_search=True, run_quant=True,
calib_data="pileval"):
self.quant_config = quant_config
quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
if quant_config["version"] == "GEMM":
logging.warning('Deprecated model weight format. Re-quantize '
'your weights again with version="GEMV" for a speedup. '
'In the next AutoAWQ version, GEMM will be deprecated.')
if run_search:
self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
@@ -53,7 +58,7 @@ class BaseAWQForCausalLM(nn.Module):
self.is_quantized = True
@staticmethod
def fuse_layers(model):
def fuse_layers(model, quant_config):
pass
def _awq_quant(self):
@@ -78,12 +83,17 @@ class BaseAWQForCausalLM(nn.Module):
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear = WQLinear_GEMM.from_linear(
module,
self.quant_config['w_bit'],
self.quant_config['q_group_size'],
False,
scales,
if self.quant_config["version"] == 'GEMM':
q_linear_module = WQLinear_GEMM
elif self.quant_config["version"] == 'GEMV':
q_linear_module = WQLinear_GEMV
q_linear = q_linear_module.from_linear(
module,
self.quant_config['w_bit'],
self.quant_config['q_group_size'],
False,
scales,
zeros
)
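The selection in the hunk above falls through silently if `quant_config["version"]` is anything other than "GEMM" or "GEMV", leaving `q_linear_module` unbound. A defensive variant, sketched here as a hypothetical standalone helper (`select_qlinear_module` is not part of this commit; the import matches the one added at the top of the file):

```python
# Sketch only: fail fast on an unknown version string instead of
# raising a NameError later when q_linear_module is never assigned.
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

def select_qlinear_module(version: str):
    version_to_module = {"GEMM": WQLinear_GEMM, "GEMV": WQLinear_GEMV}
    if version not in version_to_module:
        raise ValueError(
            f"Unknown quant_config version: {version!r}; expected 'GEMM' or 'GEMV'."
        )
    return version_to_module[version]
```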
@@ -275,9 +285,12 @@ class BaseAWQForCausalLM(nn.Module):
if os.path.exists(quant_config_path):
with open(quant_config_path, 'r') as file:
quant_config = json.loads(file.read())
if "version" not in quant_config.keys():
quant_config["version"] = version
else:
# Default config that works for most models
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}
# Load model config and set max generation length
if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
@@ -295,7 +308,7 @@ class BaseAWQForCausalLM(nn.Module):
# Only need to replace layers if a model is AWQ quantized
if is_quantized:
# Prepare WQLinear layers, replace nn.Linear
self._load_quantized_modules(self, model, quant_config, version)
self._load_quantized_modules(self, model, quant_config, quant_config["version"])
model.tie_weights()
@@ -315,7 +328,7 @@ class BaseAWQForCausalLM(nn.Module):
)
if fuse_layers:
self.fuse_layers(model)
self.fuse_layers(model, quant_config)
else:
# If not quantized, must load with AutoModelForCausalLM
@@ -364,9 +377,9 @@ class BaseAWQForCausalLM(nn.Module):
q_linear_module = WQLinear_GEMV
q_linear = q_linear_module.from_linear(
module,
quant_config['w_bit'],
quant_config['q_group_size'],
module,
quant_config['w_bit'],
quant_config['q_group_size'],
True
)
q_linear.to(next(layer.parameters()).device)
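The `from_quantized` hunks above read `version` from `quant_config.json` when present, fall back to the `version` argument otherwise, and pass the full `quant_config` into `fuse_layers`. A hedged loading sketch; the quantized-model directory is a placeholder, and the keyword names follow the signature fragments visible in this diff (`fuse_layers`, `max_new_tokens`):

```python
# Hypothetical loading sketch: version comes from quant_config.json if present,
# otherwise from the version argument / "GEMM" default.
from awq import AutoAWQForCausalLM

quant_path = "opt-125m-awq"  # placeholder directory containing quant_config.json
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)

# With fuse_layers=True, the model-specific fuser now receives quant_config,
# so the fused QKV projection is built as WQLinear_GEMM or WQLinear_GEMV
# to match the format the weights were quantized with.
```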
@@ -6,8 +6,8 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
max_new_tokens_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model: LlamaForCausalLM):
fuser = LlamaFuser(model)
def fuse_layers(model: LlamaForCausalLM, quant_config: dict):
fuser = LlamaFuser(model, quant_config)
fuser.fuse_attention()
fuser.fuse_rmsnorm()
fuser.fuse_mlp()
@@ -66,17 +66,18 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
return layers
import torch
from typing import List, Tuple
from awq.modules.qlinear import WQLinear_GEMM
from typing import List, Tuple, Union
from awq.utils.utils import set_module_name
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FTLlamaRMSNorm
from awq.modules.fused.attn import QuantLlamaAttention
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
class LlamaFuser:
def __init__(self, model):
def __init__(self, model, quant_config):
self.model = model
self.quant_config = quant_config
self.attention_modules: List[Tuple[str, LlamaAttention]] = [
(name, module) for name, module in self.model.named_modules()
@@ -95,7 +96,7 @@ class LlamaFuser:
def fuse_attention(self):
for name, module in self.attention_modules:
qkv_layer: WQLinear_GEMM = self._fuse_qkv(module)
qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module)
attn = QuantLlamaAttention(
module.hidden_size,
module.num_heads,
@@ -113,7 +114,12 @@ class LlamaFuser:
bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
# create module
qkv_layer = WQLinear_GEMM(
if self.quant_config["version"] == 'GEMM':
qkv_module = WQLinear_GEMM
elif self.quant_config["version"] == 'GEMV':
qkv_module = WQLinear_GEMV
qkv_layer = qkv_module(
q_proj.w_bit,
q_proj.group_size,
q_proj.in_features,
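In the `_fuse_qkv` hunk above, the merged projection is now instantiated as either `WQLinear_GEMM` or `WQLinear_GEMV`, and the fused attention/RMSNorm/MLP modules are swapped in by dotted name via `set_module_name` (imported from `awq.utils.utils` earlier in this diff). Below is a rough sketch of what such a by-name replacement helper typically does; the function name here is hypothetical and the real implementation in `awq.utils.utils` may differ:

```python
import torch.nn as nn

def set_module_name_sketch(model: nn.Module, name: str, new_module: nn.Module) -> None:
    # Walk a dotted module path such as "model.layers.0.self_attn" and replace
    # the final child with the fused module. Hypothetical stand-in for
    # awq.utils.utils.set_module_name.
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, new_module)
```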
@@ -6,7 +6,7 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
max_new_tokens_key = "max_seq_len"
@staticmethod
def fuse_layers(model: MptForCausalLM):
def fuse_layers(model: MptForCausalLM, quant_config:dict):
fuser = MptFuser(model)
fuser.fuse_mlp()