Unverified Commit eea08aa6 authored by Casper, committed by GitHub

AwqConfig class (#132)

parent a7d87540
......@@ -74,6 +74,7 @@ The detailed support list:
| ---------| ----------------------------|
| LLaMA-2 | 7B/13B/70B |
| LLaMA | 7B/13B/30B/65B |
| Mistral | 7B |
| Vicuna | 7B/13B |
| MPT | 7B/30B |
| Falcon | 7B/40B |
......@@ -97,6 +98,8 @@ There are two versions of AWQ: GEMM and GEMV. Both names relate to how matrix multiplication runs under the hood.
### Examples
More examples can be found in the [examples directory](examples).
<details>
<summary>Quantization</summary>
......@@ -109,7 +112,7 @@ from transformers import AutoTokenizer
model_path = 'lmsys/vicuna-7b-v1.5'
quant_path = 'vicuna-7b-v1.5-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4 }
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path)
......@@ -134,10 +137,9 @@ from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer
quant_path = "casperhansen/vicuna-7b-v1.5-awq"
quant_file = "awq_model_w4_g128.pt"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, fuse_layers=True)
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_special_tokens=True)
......
import os
import json
import logging
from typing import Dict
from dataclasses import dataclass, field, fields
from transformers.utils.hub import PushToHubMixin, cached_file
@dataclass
class AwqConfig(PushToHubMixin):
    quant_method: str = field(default="awq")
    zero_point: bool = field(default=True)
    q_group_size: int = field(default=128)
    w_bit: int = field(default=4)
    version: str = field(default="GEMM")
    config_file_name = "quant_config.json"

    def save_pretrained(self, save_dir: str, **kwargs):
        logging.warning(
            "`quant_config.json` is being deprecated in the future"
            " in favor of quantization_config in config.json."
        )

        with open(os.path.join(save_dir, self.config_file_name), "w+", encoding="utf-8") as file:
            file.write(json.dumps(self.to_dict(), indent=4))

    @classmethod
    def from_dict(cls, quant_config: Dict={}):
        if not quant_config:
            quant_config = cls()
        else:
            quant_config = cls(**quant_config)

        return quant_config

    @classmethod
    def from_pretrained(cls, save_dir: str, **kwargs):
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        commit_hash = kwargs.pop("_commit_hash", None)

        if os.path.isdir(save_dir): # Local
            resolved_config_file = os.path.join(save_dir, cls.config_file_name)
        else: # Remote
            resolved_config_file = cached_file(
                save_dir,
                cls.config_file_name,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                use_auth_token=use_auth_token,
                revision=revision,
                local_files_only=local_files_only,
                subfolder=subfolder,
                _raise_exceptions_for_missing_entries=False,
                _raise_exceptions_for_connection_errors=False,
                _commit_hash=commit_hash,
            )

        if os.path.exists(resolved_config_file):
            with open(resolved_config_file, 'r', encoding="utf-8") as file:
                loaded_config = json.loads(file.read())
                quant_config = cls(**loaded_config)
        else:
            quant_config = cls()

        return quant_config

    def to_dict(self):
        return {
            "zero_point": self.zero_point,
            "q_group_size": self.q_group_size,
            "w_bit": self.w_bit,
            "version": self.version
        }

    def to_transformers_dict(self):
        return {
            "quant_method": self.quant_method,
            "zero_point": self.zero_point,
            "group_size": self.q_group_size,
            "bits": self.w_bit,
            "version": self.version.lower(),
        }
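
For orientation, here is a minimal usage sketch of the `AwqConfig` class added above: the dict round-trip (`from_dict`/`to_dict`), the transformers-style export, and the `save_pretrained`/`from_pretrained` round-trip through `quant_config.json`. The temporary directory and printed values are illustrative only; the defaults mirror the field definitions above.

import tempfile

from awq.models._config import AwqConfig

# Illustrative only: exercises the class defined above, not part of the diff.
quant_config = AwqConfig.from_dict({"zero_point": True, "q_group_size": 128, "w_bit": 4})
print(quant_config.version)                  # "GEMM" (default filled in)
print(quant_config.to_dict())                # legacy quant_config.json payload
print(quant_config.to_transformers_dict())   # {"quant_method": "awq", "bits": 4, "group_size": 128, ...}

# Round-trip through quant_config.json (a deprecation warning is expected).
with tempfile.TemporaryDirectory() as save_dir:
    quant_config.save_pretrained(save_dir)
    reloaded = AwqConfig.from_pretrained(save_dir)
    assert reloaded.to_dict() == quant_config.to_dict()
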
## Reference from llama.py
from .base import BaseAWQForCausalLM
from typing import Dict
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer as AquilaDecoderLayer,
    LlamaForCausalLM as AquilaForCausalLM,
......@@ -14,8 +13,8 @@ class AquilaAWQForCausalLM(BaseAWQForCausalLM):
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: AquilaForCausalLM, quant_config: Dict):
        fuser = AquilaFuser(model, quant_config)
    def fuse_layers(model: AquilaForCausalLM):
        fuser = AquilaFuser(model)
        fuser.fuse_attention()
        fuser.fuse_rmsnorm()
        fuser.fuse_mlp()
......@@ -82,9 +81,8 @@ from awq.modules.fused.norm import FasterTransformerRMSNorm
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

class AquilaFuser:
    def __init__(self, model, quant_config):
    def __init__(self, model):
        self.model = model
        self.quant_config = quant_config

        self.attention_modules: List[Tuple[str, AquilaAttention]] = [
            (name, module) for name, module in self.model.named_modules()
......
......@@ -4,18 +4,19 @@ import json
import torch
import torch.nn as nn
from tqdm import tqdm
from typing import List, Union, Dict
from typing import List, Union
from safetensors.torch import save_file
from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.quantize.quantizer import AwqQuantizer
from awq.utils.utils import simple_dispatch_model
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import get_named_linears, set_op_by_name
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
class BaseAWQForCausalLM(nn.Module):
    def __init__(self, model, model_type, is_quantized, quant_config):
        super().__init__()
......@@ -23,7 +24,7 @@ class BaseAWQForCausalLM(nn.Module):
        self.model_type:str = model_type
        self.is_quantized:bool = is_quantized
        self.search_result = None
        self.quant_config: Dict = quant_config
        self.quant_config: AwqConfig = quant_config

    def to(self, device: str):
        return self.model.to(device)
......@@ -39,18 +40,17 @@ class BaseAWQForCausalLM(nn.Module):
    def quantize(self, tokenizer=None, quant_config={},
                       calib_data: Union[str, List[str]]="pileval",
                       split="train", text_column="text"):
        self.quant_config = quant_config
        quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
        self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

        quantizer = AwqQuantizer(
            self, self.model, tokenizer, quant_config["w_bit"], quant_config["q_group_size"],
            quant_config["version"], calib_data, split, text_column
            self, self.model, tokenizer, self.quant_config.w_bit, self.quant_config.q_group_size,
            self.quant_config.version, calib_data, split, text_column
        )
        quantizer.quantize()
        self.is_quantized = True

    @staticmethod
    def fuse_layers(model, quant_config):
    def fuse_layers(model):
        pass
    def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
......@@ -61,8 +61,10 @@ class BaseAWQForCausalLM(nn.Module):
            def __init__(self): super(EmptyModule, self).__init__()
            def forward(self, x): return x

        # Save model files with empty state dict
        # Save model and config files with empty state dict
        self.model.config.quantization_config = self.quant_config.to_transformers_dict()
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
        self.quant_config.save_pretrained(save_dir)

        # Remove empty state dict
        os.remove(f'{save_dir}/pytorch_model.bin')
......@@ -89,10 +91,6 @@ class BaseAWQForCausalLM(nn.Module):
        if index is not None:
            with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
                file.write(json.dumps(index, indent=4))

        # Save config
        with open(f'{save_dir}/quant_config.json', 'w+') as file:
            file.write(json.dumps(self.quant_config, indent=4))
    @classmethod
......@@ -146,7 +144,7 @@ class BaseAWQForCausalLM(nn.Module):
        model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

        # Prepare WQLinear layers, replace nn.Linear
        self._load_quantized_modules(self, model, quant_config, quant_config["version"])
        self._load_quantized_modules(self, model, quant_config, quant_config.version)
        model.tie_weights()
......@@ -169,7 +167,7 @@ class BaseAWQForCausalLM(nn.Module):
        # Dispatch to devices
        if fuse_layers:
            self.fuse_layers(model, quant_config)
            self.fuse_layers(model)

        # Offloading dispatch
        from accelerate import dispatch_model
......@@ -201,16 +199,7 @@ class BaseAWQForCausalLM(nn.Module):
        # [STEP 2] Load config and set sequence length
        # TODO: Create BaseAWQConfig class
        quant_config_path = f'{model_path}/quant_config.json'
        if os.path.exists(quant_config_path):
            with open(quant_config_path, 'r') as file:
                quant_config = json.loads(file.read())

            if "version" not in quant_config.keys():
                quant_config["version"] = version
        else:
            # Default config that works for most models
            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": version}
        quant_config = AwqConfig.from_pretrained(model_path)

        # Load model config and set max generation length
        if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
......@@ -225,7 +214,7 @@ class BaseAWQForCausalLM(nn.Module):
    def _load_quantized_modules(self, model, quant_config, version):
        # Real quantization of weights
        assert quant_config["zero_point"], "We only support zero_point quantization now."
        assert quant_config.zero_point, "We only support zero_point quantization now."

        # Get blocks of model
        layers = self.get_model_layers(model)
......@@ -248,8 +237,8 @@ class BaseAWQForCausalLM(nn.Module):
                q_linear = q_linear_module.from_linear(
                    module,
                    quant_config['w_bit'],
                    quant_config['q_group_size'],
                    quant_config.w_bit,
                    quant_config.q_group_size,
                    True
                )
                q_linear.to(next(layer.parameters()).device)
......
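
Putting the base-class changes together: callers still pass a plain dict to `quantize()`, which is normalized through `AwqConfig.from_dict`; `save_quantized()` now writes both `quant_config.json` and a `quantization_config` entry in `config.json`; and `from_quantized()` no longer takes a weights filename. Below is a hedged end-to-end sketch reusing the model paths from the README example earlier in this diff; the exact `quantize()` call mirrors the README example whose tail is truncated in that hunk.

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

# Illustrative only: the dict below is converted to an AwqConfig inside quantize().
model_path = 'lmsys/vicuna-7b-v1.5'
quant_path = 'vicuna-7b-v1.5-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model.quantize(tokenizer, quant_config=quant_config)
print(model.quant_config.version)   # attribute access instead of quant_config["version"]

# Writes quant_config.json plus quantization_config in config.json
model.save_quantized(quant_path)

# Loading no longer needs a quant_file argument
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
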
from .base import BaseAWQForCausalLM
from typing import Dict
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer as OldFalconDecoderLayer, FalconForCausalLM, FalconAttention
class FalconAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "FalconDecoderLayer"
@staticmethod
def fuse_layers(model: FalconForCausalLM, quant_config: Dict):
def fuse_layers(model: FalconForCausalLM):
fuser = FalconFuser(model)
# TODO: Implement correctly fused modules for Falcon 40B and Falcon 180B
......
from .base import BaseAWQForCausalLM
from typing import Dict
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
class LlamaAWQForCausalLM(BaseAWQForCausalLM):
......@@ -7,8 +6,8 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
max_new_tokens_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model: LlamaForCausalLM, quant_config: Dict):
fuser = LlamaFuser(model, quant_config)
def fuse_layers(model: LlamaForCausalLM):
fuser = LlamaFuser(model)
fuser.fuse_attention()
fuser.fuse_rmsnorm()
fuser.fuse_mlp()
......@@ -76,9 +75,8 @@ from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
class LlamaFuser:
def __init__(self, model, quant_config):
def __init__(self, model):
self.model = model
self.quant_config = quant_config
self.attention_modules: List[Tuple[str, LlamaAttention]] = [
(name, module) for name, module in self.model.named_modules()
......
from typing import Dict
from .base import BaseAWQForCausalLM
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralForCausalLM
......@@ -7,8 +6,8 @@ class MistralAWQForCausalLM(BaseAWQForCausalLM):
max_new_tokens_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model: MistralForCausalLM, quant_config: Dict):
fuser = MistralFuser(model, quant_config)
def fuse_layers(model: MistralForCausalLM):
fuser = MistralFuser(model)
fuser.fuse_attention()
fuser.fuse_rmsnorm()
fuser.fuse_mlp()
......@@ -76,9 +75,8 @@ from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from transformers.models.mistral.modeling_mistral import MistralAttention, MistralRMSNorm, MistralMLP
class MistralFuser:
def __init__(self, model, quant_config):
def __init__(self, model):
self.model = model
self.quant_config = quant_config
self.attention_modules: List[Tuple[str, MistralAttention]] = [
(name, module) for name, module in self.model.named_modules()
......
from .base import BaseAWQForCausalLM
from typing import Dict
from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM
class MptAWQForCausalLM(BaseAWQForCausalLM):
......@@ -7,7 +6,7 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
max_new_tokens_key = "max_seq_len"
@staticmethod
def fuse_layers(model: MptForCausalLM, quant_config: Dict):
def fuse_layers(model: MptForCausalLM):
fuser = MptFuser(model)
fuser.fuse_transformer()
......
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
# NOTE: Must install from PR until merged
# pip install --upgrade git+https://github.com/younesbelkada/transformers.git@add-awq
model_id = "casperhansen/mistral-7b-instruct-v0.1-awq"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="cuda:0"
)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Convert prompt to tokens
text = "[INST] What are the basic steps to use the Huggingface transformers library? [/INST]"
tokens = tokenizer(
    text,
    return_tensors='pt'
).input_ids.cuda()
# Generate output
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=512
)
\ No newline at end of file
......@@ -85,7 +85,7 @@ def run_round(model_path, quant_file, n_generate, input_ids, batch_size, safeten
"Prefill tokens/s": prefill_tokens_per_second,
"Decode tokens/s": decode_tokens_per_second,
"Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)"
}, model.quant_config["version"]
}, model.quant_config.version
def main(args):
    rounds = [
......@@ -126,8 +126,8 @@ def main(args):
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="casperhansen/vicuna-7b-v1.5-awq", help="path to the model")
    parser.add_argument("--quant_file", type=str, default="awq_model_w4_g128.pt", help="weights filename")
    parser.add_argument("--model_path", type=str, default="casperhansen/mistral-7b-instruct-v0.1-awq", help="path to the model")
    parser.add_argument("--quant_file", type=str, default="", help="weights filename")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for cache and generation")
    parser.add_argument("--safetensors", default=False, action="store_true", help="Use for enabling safetensors")
    args = parser.parse_args()
......
......@@ -33,7 +33,7 @@ def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot
if __name__ == '__main__':
"""
- Run perplexity of quantized model:
python examples/eval.py --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt
python examples/eval.py --model_path casperhansen/mistral-7b-instruct-v0.1-awq
- Run perplexity unquantized FP16 model:
python examples/eval.py --use_pretrained --model_path lmsys/vicuna-7b-v1.5
......