Commit 5db43a7f authored by Casper Hansen

Implement GEMM/GEMV in quantize function and fused modules

parent 9b2946b6
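
In practical terms, this commit threads a `version` key ("GEMM" or "GEMV") through quantization, checkpoint loading, and layer fusion, selecting between `WQLinear_GEMM` and `WQLinear_GEMV`. A minimal usage sketch, assuming the usual `AutoAWQForCausalLM` entry point and its `quantize`/`save_quantized` API; the model paths are placeholders:

```python
# Minimal sketch, assuming the AutoAWQForCausalLM entry point; paths are placeholders.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "path/to/fp16-model"    # placeholder
quant_path = "path/to/awq-output"    # placeholder
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMV"}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# "version" is read from quant_config; if it is omitted, the hunk below
# falls back to "GEMM" and logs a deprecation warning.
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
```
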
@@ -8,11 +8,11 @@ import torch.nn as nn
 from tqdm import tqdm
 from collections import defaultdict
-from awq.modules.qlinear import WQLinear_GEMM, WQLinear_GEMV
 from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
 from awq.utils.calib_data import get_calib_dataset
 from awq.quantize.quantizer import pseudo_quantize_tensor
+from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from awq.quantize.auto_clip import auto_clip_block, apply_clip
 from awq.quantize.auto_scale import auto_scale_block, apply_scale
 from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
@@ -43,6 +43,11 @@ class BaseAWQForCausalLM(nn.Module):
                  auto_scale=True, mse_range=True, run_search=True, run_quant=True,
                  calib_data="pileval"):
         self.quant_config = quant_config
+        quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
+        if quant_config["version"] == "GEMM":
+            logging.warning('Deprecated model weight format. Re-quantize '
+                            'your weights again with version="GEMV" for a speedup. '
+                            'In the next AutoAWQ version, GEMM will be deprecated.')
 
         if run_search:
             self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
@@ -53,7 +58,7 @@ class BaseAWQForCausalLM(nn.Module):
         self.is_quantized = True
 
     @staticmethod
-    def fuse_layers(model):
+    def fuse_layers(model, quant_config):
         pass
 
     def _awq_quant(self):
@@ -78,12 +83,17 @@ class BaseAWQForCausalLM(nn.Module):
             scales = scales.t().contiguous()
             zeros = zeros.t().contiguous()
-            q_linear = WQLinear_GEMM.from_linear(
-                module,
-                self.quant_config['w_bit'],
-                self.quant_config['q_group_size'],
-                False,
-                scales,
+            if self.quant_config["version"] == 'GEMM':
+                q_linear_module = WQLinear_GEMM
+            elif self.quant_config["version"] == 'GEMV':
+                q_linear_module = WQLinear_GEMV
+
+            q_linear = q_linear_module.from_linear(
+                module,
+                self.quant_config['w_bit'],
+                self.quant_config['q_group_size'],
+                False,
+                scales,
                 zeros
             )
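
Note that the dispatch above binds `q_linear_module` only for the two known versions; any other string in `quant_config["version"]` would surface later as an `UnboundLocalError`. A sketch of the same selection with an explicit failure mode, using the classes imported at the top of the file (illustrative only, not part of the commit):

```python
# Same mapping as the dispatch above, plus an explicit error for unknown
# versions. WQLinear_GEMM / WQLinear_GEMV come from awq.modules.linear.
def select_linear_module(version: str):
    version_map = {"GEMM": WQLinear_GEMM, "GEMV": WQLinear_GEMV}
    if version not in version_map:
        raise ValueError(f"Unknown quantization version {version!r}; expected 'GEMM' or 'GEMV'")
    return version_map[version]
```
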
@@ -275,9 +285,12 @@ class BaseAWQForCausalLM(nn.Module):
         if os.path.exists(quant_config_path):
             with open(quant_config_path, 'r') as file:
                 quant_config = json.loads(file.read())
+
+            if "version" not in quant_config.keys():
+                quant_config["version"] = version
         else:
             # Default config that works for most models
-            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
+            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}
 
         # Load model config and set max generation length
         if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
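
The added check keeps older checkpoints loadable: a `quant_config.json` written before this change carries no `"version"` entry, so the `version` argument passed to `from_quantized` is injected instead. A small self-contained sketch of that behaviour, with illustrative file contents:

```python
import json

# Illustrative contents of an older quant_config.json (pre-"version").
raw = '{"zero_point": true, "q_group_size": 128, "w_bit": 4}'
quant_config = json.loads(raw)

version = "GEMM"  # value forwarded from the from_quantized() argument
if "version" not in quant_config.keys():
    quant_config["version"] = version

assert quant_config == {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
```
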
@@ -295,7 +308,7 @@ class BaseAWQForCausalLM(nn.Module):
 
         # Only need to replace layers if a model is AWQ quantized
         if is_quantized:
             # Prepare WQLinear layers, replace nn.Linear
-            self._load_quantized_modules(self, model, quant_config, version)
+            self._load_quantized_modules(self, model, quant_config, quant_config["version"])
 
             model.tie_weights()
@@ -315,7 +328,7 @@ class BaseAWQForCausalLM(nn.Module):
             )
 
             if fuse_layers:
-                self.fuse_layers(model)
+                self.fuse_layers(model, quant_config)
 
         else:
             # If not quantized, must load with AutoModelForCausalLM
@@ -364,9 +377,9 @@ class BaseAWQForCausalLM(nn.Module):
             q_linear_module = WQLinear_GEMV
 
             q_linear = q_linear_module.from_linear(
                 module,
                 quant_config['w_bit'],
                 quant_config['q_group_size'],
                 True
             )
             q_linear.to(next(layer.parameters()).device)
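
The two `from_linear` call sites in this file differ in their fourth positional argument: the quantization path passes `False` plus real `scales`/`zeros`, while the loading path passes `True` to allocate an empty layer whose buffers are filled from the checkpoint afterwards. A hedged sketch of that contract as a helper; the "init-only" reading of the fourth argument is inferred from the call sites, not stated in the diff:

```python
# Sketch of the two from_linear usage patterns in this file. Only the
# positional pattern (False + scales/zeros vs. True alone) comes from the
# diff; the interpretation of the fourth argument is an assumption.
def build_wqlinear(q_linear_module, module, w_bit, group_size,
                   scales=None, zeros=None):
    if scales is not None and zeros is not None:
        # Quantization path: pack real weights using the searched scales/zeros.
        return q_linear_module.from_linear(module, w_bit, group_size, False, scales, zeros)
    # Loading path: allocate an empty quantized layer; weights come from disk.
    return q_linear_module.from_linear(module, w_bit, group_size, True)
```
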
...
@@ -6,8 +6,8 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     max_new_tokens_key = "max_position_embeddings"
 
     @staticmethod
-    def fuse_layers(model: LlamaForCausalLM):
-        fuser = LlamaFuser(model)
+    def fuse_layers(model: LlamaForCausalLM, quant_config: dict):
+        fuser = LlamaFuser(model, quant_config)
         fuser.fuse_attention()
         fuser.fuse_rmsnorm()
         fuser.fuse_mlp()
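
End to end, the fuser is reached through `from_quantized(..., fuse_layers=True)`, which, per the base-class hunk above, now forwards `quant_config` so the fuser can pick GEMM or GEMV modules. A usage sketch with a placeholder checkpoint path; the exact keyword set is assumed from the surrounding hunks:

```python
# Illustrative only: loading an AWQ checkpoint with layer fusion enabled.
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "path/to/awq-checkpoint",  # placeholder
    fuse_layers=True,          # ends up calling fuse_layers(model, quant_config)
)
```
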
@@ -66,17 +66,18 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
         return layers
 
 import torch
-from typing import List, Tuple
-from awq.modules.qlinear import WQLinear_GEMM
+from typing import List, Tuple, Union
 from awq.utils.utils import set_module_name
 from awq.modules.fused.mlp import QuantLlamaMLP
 from awq.modules.fused.norm import FTLlamaRMSNorm
 from awq.modules.fused.attn import QuantLlamaAttention
+from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
 
 class LlamaFuser:
-    def __init__(self, model):
+    def __init__(self, model, quant_config):
         self.model = model
+        self.quant_config = quant_config
 
         self.attention_modules: List[Tuple[str, LlamaAttention]] = [
             (name, module) for name, module in self.model.named_modules()
@@ -95,7 +96,7 @@ class LlamaFuser:
 
     def fuse_attention(self):
        for name, module in self.attention_modules:
-            qkv_layer: WQLinear_GEMM = self._fuse_qkv(module)
+            qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module)
             attn = QuantLlamaAttention(
                 module.hidden_size,
                 module.num_heads,
@@ -113,7 +114,12 @@ class LlamaFuser:
         bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
 
         # create module
-        qkv_layer = WQLinear_GEMM(
+        if self.quant_config["version"] == 'GEMM':
+            qkv_module = WQLinear_GEMM
+        elif self.quant_config["version"] == 'GEMV':
+            qkv_module = WQLinear_GEMV
+
+        qkv_layer = qkv_module(
             q_proj.w_bit,
             q_proj.group_size,
             q_proj.in_features,
...
@@ -6,7 +6,7 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
     max_new_tokens_key = "max_seq_len"
 
     @staticmethod
-    def fuse_layers(model: MptForCausalLM):
+    def fuse_layers(model: MptForCausalLM, quant_config: dict):
         fuser = MptFuser(model)
         fuser.fuse_mlp()
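
`MptAWQForCausalLM.fuse_layers` gains the `quant_config` parameter to match the updated base-class call, but `MptFuser` is still constructed without it, so MPT fusion does not yet branch on the version. A hypothetical skeleton for a new model integration following the Llama pattern; the class and fuser names are invented for illustration:

```python
# Hypothetical skeleton, mirroring the Llama pattern above;
# NewModelFuser is an invented name.
from awq.models.base import BaseAWQForCausalLM  # assumed module path

class NewModelAWQForCausalLM(BaseAWQForCausalLM):
    @staticmethod
    def fuse_layers(model, quant_config: dict):
        # Forward quant_config so the fuser can choose between
        # WQLinear_GEMM and WQLinear_GEMV, as LlamaFuser does.
        fuser = NewModelFuser(model, quant_config)
        fuser.fuse_attention()
```
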