Commit 1e82667b authored by Casper

Merge remote-tracking branch 'origin/refactor-models'

parents 05abe7d6 46750ff9
@@ -69,13 +69,16 @@ git clone https://huggingface.co/datasets/mit-han-lab/awq-model-zoo awq_cache
The detailed support list:
| Models   | Sizes                       | INT4-g128 | INT3-g128 |
| -------- | --------------------------- | --------- | --------- |
| LLaMA-2  | 7B/13B/70B                  | ✅        | ✅        |
| LLaMA    | 7B/13B/30B/65B              | ✅        | ✅        |
| Vicuna   | 7B/13B                      | ✅        |           |
| MPT      | 7B/30B                      | ✅        |           |
| Falcon   | 7B/40B                      | ✅        |           |
| OPT      | 125m/1.3B/2.7B/6.7B/13B/30B | ✅        | ✅        |
| Bloom    | 560m/3B/7B                  | ✅        | ✅        |
| LLaVA-v0 | 13B                         | ✅        |           |
## Examples
@@ -89,39 +92,30 @@ Note that we perform AWQ using only textual calibration data, despite we are run
## Usage
We provide several sample scripts to run AWQ (please refer to `./scripts`). We use Vicuna 7B v1.5 as an example.
1. Perform AWQ search and save the search results:
```bash
python -m awq.entry --entry_type search \
    --model_path lmsys/vicuna-7b-v1.5 \
    --search_path vicuna-7b-v1.5-awq
```
Note: if you use Falcon 7B, please pass `--q_group_size 64` in order for it to work.

2. Generate real quantized weights (INT4) and save them:
```bash
python -m awq.entry --entry_type quant \
    --model_path lmsys/vicuna-7b-v1.5 \
    --search_path vicuna-7b-v1.5-awq/awq_model_search_result.pt \
    --quant_path vicuna-7b-v1.5-awq
```

3. Load the real quantized weights and evaluate perplexity (faster and uses less GPU memory):
```bash
python -m awq.entry --entry_type eval \
    --model_path vicuna-7b-v1.5-awq \
    --quant_file awq_model_w4_g128.pt
```
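If you prefer to drive the same pipeline from Python instead of the CLI, the minimal sketch below mirrors the `run_search`/`run_quant` helpers in `awq/entry.py`; the paths and `quant_config` values are just the example defaults used above, not requirements.
```python
from transformers import AutoTokenizer
from awq.models.auto import AutoAWQForCausalLM

model_path = "lmsys/vicuna-7b-v1.5"      # example model
quant_path = "vicuna-7b-v1.5-awq"        # example output directory
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

# Load the FP16 model and its tokenizer
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Run the AWQ search and quantize the weights in one call
# (the CLI splits this into the `search` and `quant` entry types)
model.quantize(tokenizer, quant_config=quant_config, run_search=True, run_quant=True)

# Save the INT4 weights and the tokenizer next to them
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```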
## Reference
......
import os
import time
import torch
import argparse
from lm_eval import evaluator
from transformers import AutoTokenizer
from awq.models.auto import AutoAWQForCausalLM
from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, help='path of the hf model')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument("--tasks", default=None, type=str)
parser.add_argument("--output_path", default=None, type=str)
parser.add_argument('--num_fewshot', type=int, default=0)
# model config
parser.add_argument('--parallel', action='store_true',
help="enable model parallelism")
# max memory to offload larger models to CPU
parser.add_argument('--max_memory', type=str, nargs='*',
help="List of device_id:max_memory pairs to be parsed into a dictionary; " \
+ "Example: 0:10GiB 1:10GiB cpu:30GiB; " \
+ "mode details here: " \
+ "https://huggingface.co/docs/accelerate/usage_guides/big_modeling")
parser.add_argument('--auto_parallel', action='store_true',
help="automatically set parallel and batch_size")
# quantization config
parser.add_argument('--w_bit', type=int, default=None)
parser.add_argument('--q_group_size', type=int, default=-1)
parser.add_argument('--no_zero_point', action='store_true',
help="disable zero_point")
parser.add_argument('--q_backend', type=str,
default="fake", choices=["fake", "real"])
# save/load real quantized weights
parser.add_argument('--dump_quant', type=str, default=None,
help='save quantized model')
parser.add_argument('--load_quant', type=str, default=None,
help='load quantized model')
# apply/save/load awq
parser.add_argument('--run_awq', action='store_true',
help="perform awq search process")
parser.add_argument('--dump_awq', type=str, default=None,
help="save the awq search results")
parser.add_argument('--load_awq', type=str, default=None,
help="load the awq search results")
args = parser.parse_args()
max_memory = [v.split(':') for v in (args.max_memory or [])]
max_memory = {(int(k) if k.isdigit() else k):v for k,v in max_memory}
if args.auto_parallel:
gpu_list = auto_parallel(args)
# get quantization config (apart from w_bit)
q_config = {
"zero_point": not args.no_zero_point, # by default True
"q_group_size": args.q_group_size, # whether to use group quantization
}
print("Quantization config:", q_config)
# build model and tokenizer
def build_model_and_enc(model_path):
if not os.path.exists(model_path): # look into ssd
raise FileNotFoundError(f"{model_path} not found!")
print(f"* Building model {model_path}")
# all hf model
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if "mpt" in config.__class__.__name__.lower():
enc = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
else:
enc = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
if args.load_quant: # directly load quantized weights
print("Loading pre-computed quantized weights...")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.float16, trust_remote_code=True)
real_quantize_model_weight(
model, w_bit=args.w_bit, q_config=q_config, init_only=True)
model.tie_weights()
# Infer device map
kwargs = {"max_memory": max_memory} if len(max_memory) else {}
device_map = infer_auto_device_map(
model,
no_split_module_classes=[
"OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
**kwargs
)
# Load checkpoint in the model
load_checkpoint_in_model(
model,
checkpoint=args.load_quant,
device_map=device_map,
offload_state_dict=True,
)
# Dispatch model
model = simple_dispatch_model(model, device_map=device_map)
model.eval()
else: # fp16 to quantized
args.run_awq &= not args.load_awq # if load_awq, no need to run awq
# Init model on CPU:
kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
model = AutoModelForCausalLM.from_pretrained(
model_path, config=config, trust_remote_code=True, **kwargs)
model.eval()
if args.run_awq:
assert args.dump_awq, "Please save the awq results with --dump_awq"
awq_results = run_awq(
model, enc,
w_bit=args.w_bit, q_config=q_config,
n_samples=128, seqlen=512,
)
if args.dump_awq:
dirpath = os.path.dirname(args.dump_awq)
os.makedirs(dirpath, exist_ok=True)
torch.save(awq_results, args.dump_awq)
print("AWQ results saved at", args.dump_awq)
exit(0)
if args.load_awq:
print("Loading pre-computed AWQ results from", args.load_awq)
awq_results = torch.load(args.load_awq, map_location="cpu")
apply_awq(model, awq_results)
# weight quantization
if args.w_bit is not None:
if args.q_backend == "fake":
assert args.dump_quant is None, \
"Need to use real quantization to dump quantized weights"
pseudo_quantize_model_weight(
model, w_bit=args.w_bit, q_config=q_config
)
elif args.q_backend == "real": # real quantization
real_quantize_model_weight(
model, w_bit=args.w_bit, q_config=q_config
)
if args.dump_quant:
dirpath = os.path.dirname(args.dump_quant)
os.makedirs(dirpath, exist_ok=True)
print(
f"Saving the quantized model at {args.dump_quant}...")
torch.save(model.cpu().state_dict(), args.dump_quant)
exit(0)
else:
raise NotImplementedError
# Move the model to GPU (as much as possible) for LM evaluation
kwargs = {"max_memory": max_memory} if len(max_memory) else {}
device_map = infer_auto_device_map(
model,
# TODO: can we remove this?
no_split_module_classes=[
"OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
**kwargs
)
model = dispatch_model(model, device_map=device_map)
return model, enc
def main():
if args.output_path is not None and os.path.exists(args.output_path):
# print(f"Results {args.output_path} already generated. Exit.")
print(f"Results {args.output_path} already generated. Overwrite.")
# exit()
if args.dump_awq and os.path.exists(args.dump_awq):
print(f"Found existing AWQ results {args.dump_awq}, exit.")
exit()
# a hack here to auto set model group
model, enc = build_model_and_enc(args.model_path)
if args.tasks is not None:
task_names = args.tasks.split(",")
lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
results = evaluator.simple_evaluate(
model=lm_eval_model,
tasks=task_names,
batch_size=args.batch_size,
no_cache=True,
num_fewshot=args.num_fewshot,
)
print(evaluator.make_table(results))
if args.output_path is not None:
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
# otherwise cannot save
results["config"]["model"] = args.model_path
with open(args.output_path, "w") as f:
json.dump(results, f, indent=2)
def load_search_result_into_memory(model, search_path):
awq_results = torch.load(search_path, map_location="cpu")
apply_scale(model, awq_results["scale"])
apply_clip(model, awq_results["clip"])
def run_search(model_path, dump_path, quant_config):
"""
Step 1/2: Search the pile for an optimal scaling factor.
"""
# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
model.quantize(tokenizer, quant_config=quant_config, run_search=True, run_quant=False)
# Save search results
model.save_quantized(dump_path)
# Save tokenizer
tokenizer.save_pretrained(dump_path)
def run_quant(model_path, search_path, dump_path, quant_config):
"""
Step 2/2: Use the search results to quantize model weights
"""
# Load model and search results
model = AutoAWQForCausalLM.from_pretrained(model_path)
load_search_result_into_memory(model.model, search_path)
# Run actual weight quantization
model.quantize(quant_config=quant_config, run_search=False, run_quant=True)
# Save quantized model
model.save_quantized(dump_path)
def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot, task_use_pretrained):
"""
Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness
"""
# Load model
if task_use_pretrained:
model = AutoAWQForCausalLM.from_pretrained(model_path)
else:
model = AutoAWQForCausalLM.from_quantized(model_path, quant_file)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Load adapter
lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size)
# Evaluate perplexity of quantized model
results = evaluator.simple_evaluate(
model=lm_eval_model,
tasks=tasks.split(','),
batch_size=task_batch_size,
no_cache=True,
num_fewshot=task_n_shot,
)
print(evaluator.make_table(results))
@torch.inference_mode()
def run_speed(model_path, quant_file, device, n_generate=128, max_new_tokens=256):
def _timer(func):
start = time.time()
out = func()
return out, time.time() - start
def _generate(model, model_out, n_generate):
past_key_values = model_out.past_key_values
for i in range(n_generate):
logits = model_out.logits[0, -1, :]
probs = torch.softmax(logits, dim=-1)
token = torch.multinomial(probs, num_samples=1)
token = torch.as_tensor([token], device=device).unsqueeze(0)
model_out = model(token, use_cache=True, past_key_values=past_key_values)
def _warmup(device:str):
warm_up = torch.randn((4096,4096)).to(device)
torch.mm(warm_up,warm_up)
# Load model
model, load_time = _timer(lambda: AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True))
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
_warmup(device)
# Generate random inputs
n_context = max_new_tokens - n_generate
ids = torch.randint(0, tokenizer.vocab_size, (1, n_context)).cuda()
# Context stage
model_out, context_time = _timer(lambda: model(ids, use_cache=True))
# Generation stage
_, generation_time = _timer(lambda: _generate(model, model_out, n_generate))
# Prints
memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
context_tokens_per_second = n_context / context_time
context_ms_per_token = (context_time*1000) / n_context
inference_tokens_per_second = n_generate / generation_time
inference_ms_per_token = (generation_time*1000) / n_generate
print(f"[======] Model summary: {model_path} [======]")
print(f"[*] Load time: {load_time:.2f} seconds")
print(f"[*] Context speed: {context_tokens_per_second:.2f} tokens/second ({context_ms_per_token:.2f} ms/token)")
print(f"[*] Generation speed: {inference_tokens_per_second:.2f} tokens/second ({inference_ms_per_token:.2f} ms/token)")
print(f"[*] VRAM: {memory_used:.2f} MB")
if __name__ == '__main__':
"""
- Run AWQ search and save result:
python -m awq.entry --entry_type search --model_path lmsys/vicuna-7b-v1.5 --search_path vicuna-7b-v1.5-awq
- Run AWQ to save the real quantized weights at the quant_path:
python -m awq.entry --entry_type quant --model_path lmsys/vicuna-7b-v1.5 --search_path vicuna-7b-v1.5-awq/awq_model_search_result.pt --quant_path vicuna-7b-v1.5-awq
- Run perplexity of quantized model:
python -m awq.entry --entry_type eval --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt
- Run perplexity of the unquantized FP16 model:
python -m awq.entry --entry_type eval --model_path lmsys/vicuna-7b-v1.5 --task_use_pretrained
- Run a speedtest to benchmark the quantized model:
python -m awq.entry --entry_type speed --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt
"""
parser = argparse.ArgumentParser()
parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|eval|speed)')
parser.add_argument('--model_path', type=str, help='Path to hf model')
parser.add_argument('--search_path', type=str, help='Path to save/load AWQ search results')
parser.add_argument('--quant_path', type=str, help='Path to save AWQ model to directory')
parser.add_argument('--quant_file', type=str, help='Path to quantized AWQ model file')
parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
parser.add_argument('--w_bit', type=int, default=4)
parser.add_argument('--q_group_size', type=int, default=128)
parser.add_argument('--tasks', type=str, default='wikitext', help='Tasks to evaluate. '
'Separate tasks by comma for multiple tasks.'
'https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md')
parser.add_argument("--task_use_pretrained", default=False, action=argparse.BooleanOptionalAction,
help="Pass '--task_use_pretrained' to use a pretrained model running FP16")
parser.add_argument('--task_batch_size', type=int, default=1)
parser.add_argument('--task_n_shot', type=int, default=0)
parser.add_argument('--n_generate', type=int, default=128)
parser.add_argument('--n_context', type=int, default=256)
args = parser.parse_args()
quant_config = { "zero_point": True, "q_group_size": args.q_group_size, "w_bit": args.w_bit }
if args.entry_type == 'search':
run_search(args.model_path, args.search_path, quant_config)
elif args.entry_type == 'quant':
run_quant(args.model_path, args.search_path, args.quant_path, quant_config)
elif args.entry_type == 'eval':
run_eval(args.model_path, args.quant_file, args.device,
args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
elif args.entry_type == 'speed':
run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context)
else:
raise Exception('--entry_type must be one of (search|quant|eval|speed)')
\ No newline at end of file
from .mpt import MptAWQForCausalLM
from .llama import LlamaAWQForCausalLM
from .opt import OptAWQForCausalLM
from .falcon import FalconAWQForCausalLM
from .bloom import BloomAWQForCausalLM
\ No newline at end of file
from transformers import AutoConfig
from awq.models import *
from awq.models.base import BaseAWQForCausalLM
AWQ_CAUSAL_LM_MODEL_MAP = {
"mpt": MptAWQForCausalLM,
"llama": LlamaAWQForCausalLM,
"opt": OptAWQForCausalLM,
"RefinedWeb": FalconAWQForCausalLM,
"RefinedWebModel": FalconAWQForCausalLM,
"bloom": BloomAWQForCausalLM
}
def check_and_get_model_type(model_dir, trust_remote_code=True):
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP.keys():
raise TypeError(f"{config.model_type} isn't supported yet.")
model_type = config.model_type
return model_type
class AutoAWQForCausalLM:
def __init__(self):
raise EnvironmentError('You must instantiate AutoAWQForCausalLM with\n'
'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')
@classmethod
def from_pretrained(self, model_path, trust_remote_code=True, safetensors=False) -> BaseAWQForCausalLM:
model_type = check_and_get_model_type(model_path, trust_remote_code)
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
model_path, model_type, trust_remote_code=trust_remote_code, safetensors=safetensors
)
@classmethod
def from_quantized(self, quant_path, quant_filename, max_new_tokens=None,
device='balanced', trust_remote_code=True, fuse_layers=True) -> BaseAWQForCausalLM:
model_type = check_and_get_model_type(quant_path, trust_remote_code)
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code, fuse_layers=fuse_layers
)
\ No newline at end of file
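A brief usage sketch of the wrapper defined above (hypothetical paths and prompt; the weight filename follows the `awq_model_w{w_bit}_g{q_group_size}.pt` pattern written by `save_quantized`):
```python
from transformers import AutoTokenizer
from awq.models.auto import AutoAWQForCausalLM

# Example paths: `quant_path` is a directory produced by the quant step,
# `quant_file` is the weight file saved by `save_quantized`.
quant_path = "vicuna-7b-v1.5-awq"
quant_file = "awq_model_w4_g128.pt"

model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

inputs = tokenizer("The capital of France is", return_tensors="pt").input_ids.cuda()
outputs = model.generate(inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```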
import os
import gc
import json
import torch
import functools
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from huggingface_hub import snapshot_download
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_auto_device_map
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
class BaseAWQForCausalLM(nn.Module):
def __init__(self, model, model_type, is_quantized, quant_config):
super().__init__()
self.model:PreTrainedModel = model
self.model_type:str = model_type
self.is_quantized:bool = is_quantized
self.search_result = None
self.quant_config:dict = quant_config
def to(self, device: str):
return self.model.to(device)
def forward(self, *args, **kwargs):
return self.model(*args, **kwargs)
def generate(self, *args, **kwargs):
with torch.inference_mode():
return self.model.generate(*args, **kwargs)
@torch.no_grad()
def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, run_search=True, run_quant=True,
calib_data="pileval"):
self.quant_config = quant_config
if run_search:
self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)
if run_quant:
self._awq_quant()
self.is_quantized = True
def _awq_quant(self):
assert self.quant_config["zero_point"], "We only support zero_point quantization now."
layers = self.get_model_layers(self.model)
# Run AWQ quantization
for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
layer = layers[i]
named_linears = get_named_linears(layer)
self._scale_activations(self, layer)
for name, module in named_linears.items():
module.cuda()
module.weight.data, scales, zeros = pseudo_quantize_tensor(
module.weight.data,
get_scale_zp=True,
**self.quant_config
)
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear = WQLinear.from_linear(
module,
self.quant_config['w_bit'],
self.quant_config['q_group_size'],
False,
scales,
zeros
)
module.cpu()
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, calib_data="pileval"):
layers = self.get_model_layers(self.model)
samples = get_calib_dataset(
data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
samples = torch.cat(samples, dim=0)
inps = []
layer_kwargs = {}
layers[0] = layers[0].cuda()
self.move_embed(self.model, "cuda")
# get input and kwargs to layer 0
# with_kwargs is only supported in PyTorch 2.0
# use this Catcher hack for now
class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, inp, **kwargs):
inps.append(inp)
layer_kwargs.update(kwargs)
raise ValueError # early exit to break later inference
# patch layer 0 to catch input and kwargs
layers[0] = Catcher(layers[0])
try:
self.model(samples.to(next(self.model.parameters()).device))
except ValueError: # work with early exit
pass
del samples
layers[0] = layers[0].module # restore
inps = inps[0]
layers[0] = layers[0].cpu()
self.move_embed(self.model, "cpu")
gc.collect()
torch.cuda.empty_cache()
awq_results = {
"scale": [],
"clip": [],
}
# Run AWQ search layer by layer
for i in tqdm(range(len(layers)), desc="AWQ Search"):
layer = layers[i]
layer = layer.cuda()
named_linears = get_named_linears(layer)
# firstly, get input features of all linear layers
def cache_input_hook(m, x, y, name, feat_dict):
x = x[0]
x = x.detach().cpu()
feat_dict[name].append(x)
input_feat = defaultdict(list)
handles = []
for name in named_linears:
handles.append(named_linears[name].register_forward_hook(
functools.partial(cache_input_hook, name=name,
feat_dict=input_feat)))
inps = inps.to(next(layer.parameters()).device) # in case multi-gpu
# get output as next layer's input
inps = layer(inps, **layer_kwargs)[0]
for h in handles:
h.remove()
# now solve for scaling and clipping
input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
# Clear GPU memory
torch.cuda.empty_cache()
if auto_scale: # if it applies, we should also modify the input_feat with scales
scales_list = auto_scale_block(
self,
layer,
layer_kwargs,
quant_config=quant_config,
input_feat=input_feat,
)
apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
# append prefix to make names global
awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")
# Clear GPU memory
torch.cuda.empty_cache()
if mse_range:
clip_list = auto_clip_block(
layer,
quant_config=quant_config,
input_feat=input_feat
)
apply_clip(layer, clip_list)
# append prefix to make names global
awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")
layer = layer.cpu()
# Haotian: check activation replacement
del input_feat
gc.collect()
torch.cuda.empty_cache()
return awq_results
def save_quantized(self, save_dir):
def _save_files(save_dir, model_name, model):
class EmptyModule(nn.Module):
def __init__(self): super(EmptyModule, self).__init__()
def forward(self, x): return x
# Save model files without search results
self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
# Remove empty module
os.remove(f'{save_dir}/pytorch_model.bin')
# Save search results
torch.save(model, f'{save_dir}/{model_name}')
# Save config
with open(f'{save_dir}/quant_config.json', 'w+') as file:
file.write(json.dumps(self.quant_config, indent=4))
save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
# Save model
if self.search_result is None or self.is_quantized:
model_name = f'awq_model_w{self.quant_config["w_bit"]}_g{self.quant_config["q_group_size"]}.pt'
_save_files(save_dir, model_name, self.model.state_dict())
else:
model_name = 'awq_model_search_result.pt'
_save_files(save_dir, model_name, self.search_result)
@classmethod
def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
trust_remote_code=True, safetensors=False):
return self.from_quantized(
model_path,
model_type,
model_filename='',
max_new_tokens=None,
device='balanced',
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
safetensors=safetensors,
is_quantized=False
)
@classmethod
def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
safetensors=False, is_quantized=True, fuse_layers=False):
# [STEP 1] Download model if path is not a directory
if not os.path.isdir(model_path):
ignore_patterns = ["*msgpack*", "*h5*"]
if safetensors:
ignore_patterns.extend(["*.pt", "*.bin"])
else:
ignore_patterns.append("*safetensors*")
model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
# TODO: Better naming, model_filename becomes a directory
model_filename = model_path + f'/{model_filename}'
# [STEP 2] Load config and set sequence length
# TODO: Create BaseAWQConfig class
quant_config_path = f'{model_path}/quant_config.json'
if os.path.exists(quant_config_path):
with open(quant_config_path, 'r') as file:
quant_config = json.loads(file.read())
else:
# Default config that works for most models
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
# Load model config and set max generation length
if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
config.max_new_tokens = getattr(config, self.max_new_tokens_key)
else:
max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
config.max_new_tokens = max_new_tokens
# [STEP 3] Load model
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
# Only need to replace layers if a model is AWQ quantized
if is_quantized:
# Prepare WQLinear layers, replace nn.Linear
self._load_quantized_modules(self, model, quant_config)
model.tie_weights()
# Load model weights
if is_quantized:
model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
if fuse_layers:
self.fuse_layers(model)
else:
# If not quantized, must load with AutoModelForCausalLM
device_map = infer_auto_device_map(
model,
no_split_module_classes=[self.layer_type],
dtype=torch_dtype
)
del model
# Load model weights
model = AutoModelForCausalLM.from_pretrained(
model_filename, device_map=device_map, offload_folder="offload", offload_state_dict=True, torch_dtype=torch_dtype, use_safetensors=safetensors
)
model.eval()
return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
def _load_quantized_modules(self, model, quant_config):
# Real quantization of weights
assert quant_config["zero_point"], "We only support zero_point quantization now."
# Get blocks of model
layers = self.get_model_layers(model)
for i in tqdm(range(len(layers)), desc="Replacing layers..."):
layer = layers[i]
# Get every linear layer in a block
named_linears = get_named_linears(layer)
# Replace activation functions
self._scale_activations(self, layer)
# Replace nn.Linear with WQLinear
for name, module in named_linears.items():
q_linear = WQLinear.from_linear(
module, quant_config['w_bit'], quant_config['q_group_size'], True)
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
torch.cuda.empty_cache()
gc.collect()
@staticmethod
def _scale_activations(self, layer):
scale_dict = self.get_act_for_scaling(layer)
if scale_dict['is_scalable']:
if not isinstance(scale_dict['scale_layer'], ScaledActivation):
param = next(layer.parameters())
# get activation scale
scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)
# scale activation
scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
\ No newline at end of file
from .base import BaseAWQForCausalLM
from transformers.models.bloom.modeling_bloom import BloomForCausalLM, BloomBlock
class BloomAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "BloomBlock"
@staticmethod
def get_model_layers(model: BloomForCausalLM):
return model.transformer.h
@staticmethod
def get_act_for_scaling(module: BloomBlock):
return dict(
is_scalable=True,
scale_name="mlp.gelu_impl",
scale_layer=module.mlp.gelu_impl,
scale_shape=module.mlp.dense_h_to_4h.out_features
)
@staticmethod
def move_embed(model: BloomForCausalLM, device: str):
model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device)
@staticmethod
def get_layers_for_scaling(module: BloomBlock, input_feat, module_kwargs):
layers = []
# attention input
layers.append(dict(
prev_op=module.input_layernorm,
layers=[module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module, kwargs=module_kwargs,
))
# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469
"""
scales_list.append(_auto_get_scale(
prev_op=module.self_attention.query_key_value,
layers=[module.self_attention.dense],
inp=input_feat['self_attention.dense'],
))
"""
# linear 1
layers.append(dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.dense_h_to_4h],
inp=input_feat['mlp.dense_h_to_4h'],
module2inspect=module, kwargs=module_kwargs,
))
# linear 2
layers.append(dict(
prev_op=module.mlp.gelu_impl,
layers=[module.mlp.dense_4h_to_h],
inp=input_feat['mlp.dense_4h_to_h'],
))
return layers
\ No newline at end of file
from .base import BaseAWQForCausalLM
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconForCausalLM
class FalconAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "FalconDecoderLayer"
@staticmethod
def get_model_layers(model: FalconForCausalLM):
return model.transformer.h
@staticmethod
def get_act_for_scaling(module: FalconDecoderLayer):
return dict(
is_scalable=True,
scale_name="mlp.act",
scale_layer=module.mlp.act,
scale_shape=module.mlp.dense_h_to_4h.out_features
)
@staticmethod
def move_embed(model: FalconForCausalLM, device):
model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
@staticmethod
def get_layers_for_scaling(module: FalconDecoderLayer, input_feat, module_kwargs):
layers = []
# Falcon 7B (older architecture)
if module.config.num_attention_heads == 71:
# linear 1 + attention
layers.append(dict(
prev_op=module.input_layernorm,
layers=[module.mlp.dense_h_to_4h, module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module,
kwargs=module_kwargs,
))
# Falcon 40B (newer architecture)
else:
# linear 1 + attention
layers.append(dict(
prev_op=module.ln_attn,
layers=[module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module,
kwargs=module_kwargs,
))
# linear 2
layers.append(dict(
prev_op=module.ln_mlp,
layers=[module.mlp.dense_h_to_4h],
inp=input_feat['mlp.dense_h_to_4h'],
module2inspect=module,
kwargs=module_kwargs,
))
return layers
\ No newline at end of file
from .base import BaseAWQForCausalLM
from awq.modules import make_quant_norm, make_quant_attn, make_fused_mlp
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM
class LlamaAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "LlamaDecoderLayer"
max_new_tokens_key = "max_position_embeddings"
@staticmethod
def fuse_layers(awq_model):
make_quant_attn(awq_model, awq_model.device)
make_quant_norm(awq_model)
make_fused_mlp(awq_model)
@staticmethod
def get_model_layers(model: LlamaForCausalLM):
return model.model.layers
@staticmethod
def get_act_for_scaling(module: LlamaDecoderLayer):
return dict(
is_scalable=False
)
@staticmethod
def move_embed(model: LlamaForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
@staticmethod
def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs):
layers = []
# attention input
layers.append(dict(
prev_op=module.input_layernorm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))
# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'],
))
# linear 1
layers.append(dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat['mlp.gate_proj'],
module2inspect=module.mlp,
))
# linear 2
layers.append(dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat['mlp.down_proj'],
))
return layers
\ No newline at end of file
from .base import BaseAWQForCausalLM
from awq.modules import make_fused_mlp
class MptAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MPTBlock"
max_new_tokens_key = "max_seq_len"
@staticmethod
def fuse_layers(awq_model):
make_fused_mlp(awq_model)
@staticmethod
def get_model_layers(model):
return model.transformer.blocks
@staticmethod
def get_act_for_scaling(module):
return dict(
is_scalable=True,
scale_name="ffn.act",
scale_layer=module.ffn.act,
scale_shape=module.ffn.up_proj.out_features
)
@staticmethod
def move_embed(model, device):
model.transformer.wte = model.transformer.wte.to(device)
model.transformer.emb_drop = model.transformer.emb_drop.to(device)
@staticmethod
def get_layers_for_scaling(module, input_feat, module_kwargs):
layers = []
# attention input
layers.append(dict(
prev_op=module.norm_1,
layers=[module.attn.Wqkv],
inp=input_feat['attn.Wqkv'],
module2inspect=module.attn,
kwargs=module_kwargs
))
# attention output
layers.append(dict(
prev_op=module.attn.Wqkv,
layers=[module.attn.out_proj],
inp=input_feat['attn.out_proj']
))
# linear 1
layers.append(dict(
prev_op=module.norm_2,
layers=[module.ffn.up_proj],
inp=input_feat['ffn.up_proj'],
module2inspect=module.ffn
))
# linear 2
layers.append(dict(
prev_op=module.ffn.act,
layers=[module.ffn.down_proj],
inp=input_feat['ffn.down_proj']
))
return layers
\ No newline at end of file
from .base import BaseAWQForCausalLM
from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTDecoderLayer
class OptAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "OPTDecoderLayer"
max_new_tokens_key = "max_position_embeddings"
@staticmethod
def get_model_layers(model: OPTForCausalLM):
return model.model.decoder.layers
@staticmethod
def get_act_for_scaling(module: OPTDecoderLayer):
return dict(
is_scalable=False
)
@staticmethod
def move_embed(model: OPTForCausalLM, device: str):
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device)
@staticmethod
def get_layers_for_scaling(module: OPTDecoderLayer, input_feat, module_kwargs):
layers = []
# attention input
layers.append(dict(
prev_op=module.self_attn_layer_norm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))
# attention out
layers.append(dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.out_proj],
inp=input_feat['self_attn.out_proj'],
))
# linear 1
layers.append(dict(
prev_op=module.final_layer_norm,
layers=[module.fc1],
inp=input_feat['fc1'],
))
# linear 2
layers.append(dict(
prev_op=module.fc1,
layers=[module.fc2],
inp=input_feat['fc2'],
))
return layers
\ No newline at end of file
@@ -7,6 +7,27 @@ from transformers.models.llama.modeling_llama import LlamaMLP
import awq_inference_engine
class QuantMPTMLP(nn.Module):
def __init__(
self,
up_proj,
act,
down_proj
):
super().__init__()
self.register_buffer('up_proj_qweight', up_proj.qweight)
self.register_buffer('up_proj_scales', up_proj.scales)
self.register_buffer('up_proj_qzeros', up_proj.qzeros)
self.up_proj = up_proj
self.act = act
self.down_proj = down_proj
def forward(self, x: torch.Tensor):
x = x.reshape(-1, x.shape[-1])
x = awq_inference_engine.gemm_forward_cuda(x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, 8)
return self.down_proj(self.act(x))
class QuantLlamaMLP(nn.Module):
@@ -57,10 +78,15 @@ def make_fused_mlp(m, parent_name=''):
"""
if isinstance(m, LlamaMLP):
return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj)
elif "mptmlp" in str(m.__class__).lower():
return QuantMPTMLP(m.up_proj, m.act, m.down_proj)
for name, child in m.named_children():
child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
if isinstance(child, QuantLlamaMLP):
setattr(m, name, child)
elif isinstance(child, QuantMPTMLP):
setattr(m, name, child)
return m
\ No newline at end of file
@@ -8,7 +8,9 @@ __all__ = ["auto_clip_block"]
# weight quantization
@torch.no_grad()
def auto_clip_layer(w,
input_feat,
quant_config,
n_grid=20,
max_shrink=0.5,
n_sample_token=512):
@@ -16,7 +18,7 @@ def auto_clip_layer(w, input_feat, n_bit, q_config,
org_w_shape = w.shape
# w [co, ci] -> [co, 1, n_group, group size]
# input_feat [n_token, ci] -> [1, n_token, n_group, group size]
group_size = quant_config["q_group_size"] if quant_config["q_group_size"] > 0 else w.shape[1]
input_feat = input_feat.view(-1, input_feat.shape[-1])
input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
@@ -41,7 +43,7 @@ def auto_clip_layer(w, input_feat, n_bit, q_config,
max_val = org_max_val * (1 - i_s / n_grid)
min_val = - max_val
cur_w = torch.clamp(w, min_val, max_val)
q_w = pseudo_quantize_tensor(cur_w, **quant_config)
cur_out = (input_feat * q_w).sum(dim=-1)
# co, 1, n_group, 1
@@ -64,7 +66,7 @@ def auto_clip_layer(w, input_feat, n_bit, q_config,
@torch.no_grad()
def auto_clip_block(module,
quant_config,
input_feat):
named_linears = {name: m for name,
@@ -77,7 +79,7 @@ def auto_clip_block(module,
continue
named_linears[name].cuda()
max_val = auto_clip_layer(
named_linears[name].weight, input_feat[name], quant_config=quant_config)
clip_list.append((name, max_val))
named_linears[name].cpu()
return clip_list
......
@@ -7,7 +7,7 @@ from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from .qmodule import ScaledActivation
from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name
__all__ = ["auto_scale_block", "apply_scale"]
@@ -89,15 +89,15 @@ def scale_gelu_fc(gelu, fc, scales):
@torch.no_grad()
def auto_scale_block(awq_model,
module,
module_kwargs,
quant_config,
input_feat):
from .quantizer import pseudo_quantize_tensor
# firstly, get the weight quantize function
if quant_config['w_bit'] is not None:
def w_quantize_func(p): return pseudo_quantize_tensor(p, **quant_config).detach()
else:
def w_quantize_func(p): return p
@@ -110,7 +110,7 @@ def auto_scale_block(module, module_kwargs,
# x: n, ci
weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
w_max = get_weight_scale(
weight, q_group_size=quant_config.get("q_group_size", -1))
# Clear GPU memory
del weight
gc.collect()
@@ -173,170 +173,10 @@ def auto_scale_block(module, module_kwargs,
# prev_op_name, [layer_name], scale
return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), scales)
layers: list[dict] = awq_model.get_layers_for_scaling(
module, input_feat, module_kwargs
)
scales_list = [_auto_get_scale(**layer) for layer in layers]
scales_list.append(_auto_get_scale(
prev_op=module.self_attn_layer_norm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))
# attn out
scales_list.append(_auto_get_scale(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.out_proj],
inp=input_feat['self_attn.out_proj'],
))
# fc1
scales_list.append(_auto_get_scale(
prev_op=module.final_layer_norm,
layers=[module.fc1],
inp=input_feat['fc1'],
))
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.fc1,
layers=[module.fc2],
inp=input_feat['fc2'],
))
elif isinstance(module, LlamaDecoderLayer):
# attention input
scales_list.append(_auto_get_scale(
prev_op=module.input_layernorm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))
# attn out
scales_list.append(_auto_get_scale(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'],
))
# fc1
scales_list.append(_auto_get_scale(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat['mlp.gate_proj'],
module2inspect=module.mlp,
))
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat['mlp.down_proj'],
))
elif isinstance(module, BloomBlock):
# attention input
scales_list.append(_auto_get_scale(
prev_op=module.input_layernorm,
layers=[module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module, kwargs=module_kwargs,
))
# attn out
# Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469
"""
scales_list.append(_auto_get_scale(
prev_op=module.self_attention.query_key_value,
layers=[module.self_attention.dense],
inp=input_feat['self_attention.dense'],
))
"""
# fc1
scales_list.append(_auto_get_scale(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.dense_h_to_4h],
inp=input_feat['mlp.dense_h_to_4h'],
module2inspect=module, kwargs=module_kwargs,
))
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.mlp.gelu_impl,
layers=[module.mlp.dense_4h_to_h],
inp=input_feat['mlp.dense_4h_to_h'],
))
elif "mpt" in str(module.__class__).lower():
# attention input
scales_list.append(_auto_get_scale(
prev_op=module.norm_1,
layers=[module.attn.Wqkv],
inp=input_feat['attn.Wqkv'],
module2inspect=module.attn,
kwargs=module_kwargs,
))
# attn out
scales_list.append(_auto_get_scale(
prev_op=module.attn.Wqkv,
layers=[module.attn.out_proj],
inp=input_feat['attn.out_proj'],
))
# fc1
scales_list.append(_auto_get_scale(
prev_op=module.norm_2,
layers=[module.ffn.up_proj],
inp=input_feat['ffn.up_proj'],
module2inspect=module.ffn,
))
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.ffn.act,
layers=[module.ffn.down_proj],
inp=input_feat['ffn.down_proj'],
))
elif "falcon" in str(module.__class__).lower():
# attn out
# Haotian: TBD: need to handle repeated scales for MQ
"""
scales_list.append(_auto_get_scale(
prev_op=module.self_attention.query_key_value,
layers=[module.self_attention.dense],
inp=input_feat['self_attention.dense'],
))
"""
# fc1, as long as it is scaled, everything is screwed up
if "falcon-7b" in str(module.__class__).lower():
scales_list.append(_auto_get_scale(
prev_op=module.input_layernorm,
layers=[module.mlp.dense_h_to_4h, module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module,
kwargs=module_kwargs,
))
elif "falcon-40b" in str(module.__class__).lower():
scales_list.append(_auto_get_scale(
prev_op=module.ln_attn,
layers=[module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module,
kwargs=module_kwargs,
))
scales_list.append(_auto_get_scale(
prev_op=module.ln_mlp,
layers=[module.mlp.dense_h_to_4h],
inp=input_feat['mlp.dense_h_to_4h'],
module2inspect=module,
kwargs=module_kwargs,
))
else:
raise NotImplementedError("Unknown Falcon architecture, currently only falcon-7b and falcon-40b are supported")
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.mlp.act,
layers=[module.mlp.dense_4h_to_h],
inp=input_feat['mlp.dense_4h_to_h'],
))
else:
raise NotImplementedError(f"{type(module)} not supported yet!")
return scales_list
......
import torch
import torch.nn as nn
import tqdm
import gc
import functools
from collections import defaultdict
from transformers.models.bloom.modeling_bloom import BloomForCausalLM
from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from .auto_scale import auto_scale_block, apply_scale
from .auto_clip import auto_clip_block, apply_clip
__all__ = ["run_awq"]
def get_named_linears(module):
return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}
def get_blocks(model):
if isinstance(model, LlamaForCausalLM):
layers = model.model.layers
elif isinstance(model, OPTForCausalLM):
layers = model.model.decoder.layers
elif isinstance(model, BloomForCausalLM):
layers = model.transformer.h
elif "mpt" in str(model.__class__).lower():
layers = model.transformer.blocks
elif "falcon" in str(model.__class__).lower():
layers = model.transformer.h
else:
raise NotImplementedError(type(model))
return layers
def move_embed(model, device):
if isinstance(model, LlamaForCausalLM):
model.model.embed_tokens = model.model.embed_tokens.to(device)
elif isinstance(model, OPTForCausalLM):
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device)
elif isinstance(model, BloomForCausalLM):
model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device)
elif "mpt" in str(model.__class__).lower():
model.transformer.wte = model.transformer.wte.to(device)
model.transformer.emb_drop = model.transformer.emb_drop.to(device)
elif "falcon" in str(model.__class__).lower():
model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
else:
raise NotImplementedError(type(model))
@torch.no_grad()
def run_awq(
model, enc,
w_bit, q_config,
n_samples=512, seqlen=512,
auto_scale=True, mse_range=True,
# some configs for ablation study
calib_data="pileval",
):
from ..utils.calib_data import get_calib_dataset
from ..utils.module import append_str_prefix, get_op_name
layers = get_blocks(model)
samples = get_calib_dataset(
data=calib_data, tokenizer=enc, n_samples=n_samples, block_size=seqlen)
samples = torch.cat(samples, dim=0)
inps = []
layer_kwargs = {}
layers[0] = layers[0].cuda()
move_embed(model, "cuda")
# get input and kwargs to layer 0
# with_kwargs is only supported in PyTorch 2.0
# use this Catcher hack for now
class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, inp, **kwargs):
inps.append(inp)
layer_kwargs.update(kwargs)
raise ValueError # early exit to break later inference
# patch layer 0 to catch input and kwargs
layers[0] = Catcher(layers[0])
try:
model(samples.to(next(model.parameters()).device))
except ValueError: # work with early exit
pass
del samples
layers[0] = layers[0].module # restore
inps = inps[0]
layers[0] = layers[0].cpu()
move_embed(model, "cpu")
gc.collect()
torch.cuda.empty_cache()
awq_results = {
"scale": [],
"clip": [],
}
# solve layer by layer
for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
layer = layers[i]
layer = layer.cuda()
named_linears = get_named_linears(layer)
# firstly, get input features of all linear layers
def cache_input_hook(m, x, y, name, feat_dict):
x = x[0]
x = x.detach().cpu()
feat_dict[name].append(x)
input_feat = defaultdict(list)
handles = []
for name in named_linears:
handles.append(named_linears[name].register_forward_hook(
functools.partial(cache_input_hook, name=name,
feat_dict=input_feat)))
inps = inps.to(next(layer.parameters()).device) # in case multi-gpu
# get output as next layer's input
inps = layer(inps, **layer_kwargs)[0]
for h in handles:
h.remove()
# now solve for scaling and clipping
input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
# Clear GPU memory
torch.cuda.empty_cache()
if auto_scale: # if it applies, we should also modify the input_feat with scales
scales_list = auto_scale_block(
layer, layer_kwargs,
w_bit=w_bit, q_config=q_config,
input_feat=input_feat,
)
# apply_scale(layer, scales_list, input_feat_dict=input_feat)
apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
# append prefix to make names global
awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")
# Clear GPU memory
torch.cuda.empty_cache()
if mse_range:
clip_list = auto_clip_block(layer,
w_bit=w_bit, q_config=q_config,
input_feat=input_feat,)
apply_clip(layer, clip_list)
# append prefix to make names global
awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")
layer = layer.cpu()
# Haotian: check activation replacement
del input_feat
gc.collect()
torch.cuda.empty_cache()
return awq_results
def apply_awq(model, awq_results):
apply_scale(model, awq_results["scale"])
apply_clip(model, awq_results["clip"])
import torch
import torch.nn as nn
from tqdm import tqdm
import gc
from .qmodule import ScaledActivation
from ..utils.module import set_op_by_name
from transformers.models.bloom.modeling_bloom import BloomBlock
EMBEDDING_KEYWORDS = ["embed"]
LM_HEAD_KEYWORDS = ["lm_head", "embed_out", "output"]
def scale_activations(module):
param = next(module.parameters())
dtype = param.dtype
device = param.device
if isinstance(module, BloomBlock):
if isinstance(module.mlp.gelu_impl, ScaledActivation):
return
c = module.mlp.dense_h_to_4h.out_features
act = ScaledActivation(
module.mlp.gelu_impl,
torch.ones(c, dtype=dtype, device=device)
)
set_op_by_name(module, "mlp.gelu_impl", act)
elif 'mptblock' in str(module.__class__.__name__).lower():
if isinstance(module.ffn.act, ScaledActivation):
return
c = module.ffn.up_proj.out_features
act = ScaledActivation(
module.ffn.act,
torch.ones(c, dtype=dtype, device=device)
)
set_op_by_name(module, "ffn.act", act)
elif 'falcon' in str(module.__class__).lower():
if isinstance(module.mlp.act, ScaledActivation):
return
c = module.mlp.dense_h_to_4h.out_features
act = ScaledActivation(
module.mlp.act,
torch.ones(c, dtype=dtype, device=device)
)
set_op_by_name(module, "mlp.act", act)
# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, w_bit=4,
zero_point=True,
q_group_size=-1,
inplace=False,
get_scale_zp=False
):
@@ -58,7 +15,7 @@ def pseudo_quantize_tensor(w, n_bit=8,
if zero_point:
max_val = w.amax(dim=1, keepdim=True)
min_val = w.amin(dim=1, keepdim=True)
max_int = 2 ** w_bit - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
@@ -66,8 +23,8 @@ def pseudo_quantize_tensor(w, n_bit=8,
assert min_val is None
max_val = w.abs().amax(dim=1, keepdim=True)
max_val = max_val.clamp(min=1e-5)
max_int = 2 ** (w_bit - 1) - 1
min_int = - 2 ** (w_bit - 1)
scales = max_val / max_int
zeros = 0
@@ -88,54 +45,3 @@ def pseudo_quantize_tensor(w, n_bit=8,
return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
else:
return w
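The function above uses the zero-point ("asymmetric") scheme. As a quick illustration, the toy snippet below reproduces the scale/zero-point computation shown above together with the usual round-clamp-dequantize step on a single group of weights; it is not part of the repository, just a worked example.
```python
import torch

# One group of weights quantized to w_bit=4 with a zero point:
# scales = (max - min) / (2**w_bit - 1), zeros = round(-min / scales),
# w_q    = (clamp(round(w / scales) + zeros, 0, 2**w_bit - 1) - zeros) * scales
w = torch.tensor([[-1.0, -0.5, 0.0, 0.75]])
w_bit = 4
max_int = 2 ** w_bit - 1

max_val = w.amax(dim=1, keepdim=True)
min_val = w.amin(dim=1, keepdim=True)
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(0, max_int)

w_q = (torch.clamp(torch.round(w / scales) + zeros, 0, max_int) - zeros) * scales
print(w_q)  # dequantized values, close to the original weights
```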
@torch.no_grad()
def pseudo_quantize_model_weight(
model, w_bit, q_config,
):
from .pre_quant import get_blocks, get_named_linears
layers = get_blocks(model)
for i in tqdm(range(len(layers)), desc="pseudo weight quantization..."):
named_linears = get_named_linears(layers[i])
for n, m in named_linears.items():
m.cuda()
m.weight.data = pseudo_quantize_tensor(m.weight.data, n_bit=w_bit, **q_config)
m.cpu()
@torch.no_grad()
def real_quantize_model_weight(
model, w_bit, q_config,
init_only=False
):
from .qmodule import WQLinear
from .pre_quant import get_blocks, get_named_linears
assert q_config["zero_point"], "We only support zero_point quantization now."
layers = get_blocks(model)
for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):
layer = layers[i]
named_linears = get_named_linears(layer)
scale_activations(layer)
for name, module in named_linears.items():
if init_only:
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], True)
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
else:
module.cuda()
module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], False, scales, zeros)
module.cpu()
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
gc.collect()
import torch
from datasets import load_dataset
def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=512):
if data == "pileval":
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
else:
raise NotImplementedError
dataset = dataset.shuffle(seed=42)
......
@@ -6,13 +6,13 @@ import fnmatch
class LMEvalAdaptor(BaseLM):
def __init__(self, model_name, model, tokenizer, device, batch_size=1, max_length=-1):
super().__init__()
assert isinstance(batch_size, int)
self.model_name = model_name
self.model = model.to(device)
self.model.eval()
self.tokenizer = tokenizer
......