Commit 5d4ab5dc authored by Casper Hansen

Refactor entry.py [WIP]

parent 290f45e7
from lm_eval import evaluator, tasks
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import argparse
import os
import json
from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_in_model
from accelerate.utils.modeling import get_balanced_memory
from awq.utils.parallel import auto_parallel
from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
from awq.utils.utils import simple_dispatch_model
from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, help='path of the hf model')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument("--tasks", default=None, type=str)
parser.add_argument("--output_path", default=None, type=str)
parser.add_argument('--num_fewshot', type=int, default=0)
# model config
parser.add_argument('--parallel', action='store_true',
help="enable model parallelism")
# max memory to offload larger models to CPU
parser.add_argument('--max_memory', type=str, nargs='*',
help="List of device_id:max_memory pairs to be parsed into a dictionary; " \
+ "Example: 0:10GiB 1:10GiB cpu:30GiB; " \
+ "more details here: " \
+ "https://huggingface.co/docs/accelerate/usage_guides/big_modeling")
parser.add_argument('--auto_parallel', action='store_true',
help="automatically set parallel and batch_size")
# quantization config
parser.add_argument('--w_bit', type=int, default=None)
parser.add_argument('--q_group_size', type=int, default=-1)
parser.add_argument('--no_zero_point', action='store_true',
help="disable zero_point")
parser.add_argument('--q_backend', type=str,
default="fake", choices=["fake", "real"])
# save/load real quantized weights
parser.add_argument('--dump_quant', type=str, default=None,
help='save quantized model')
parser.add_argument('--load_quant', type=str, default=None,
help='load quantized model')
# apply/save/load awq
parser.add_argument('--run_awq', action='store_true',
help="perform awq search process")
parser.add_argument('--dump_awq', type=str, default=None,
help="save the awq search results")
parser.add_argument('--load_awq', type=str, default=None,
help="load the awq search results")
args = parser.parse_args()
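# Illustrative invocations of entry.py (example paths and cache names, not part of this commit):
#   AWQ search only, saving the scale/clip results:
#     python entry.py --model_path <hf_model_dir> --w_bit 4 --q_group_size 128 \
#       --run_awq --dump_awq awq_cache/model-w4-g128.pt
#   Load the search results and dump real INT4 weights:
#     python entry.py --model_path <hf_model_dir> --w_bit 4 --q_group_size 128 \
#       --load_awq awq_cache/model-w4-g128.pt --q_backend real --dump_quant quant_cache/model-w4-g128.pt
#   Evaluate a previously quantized checkpoint on lm-eval tasks:
#     python entry.py --model_path <hf_model_dir> --w_bit 4 --q_group_size 128 \
#       --load_quant quant_cache/model-w4-g128.pt --tasks wikitext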
max_memory = [v.split(':') for v in (args.max_memory or [])]
max_memory = {(int(k) if k.isdigit() else k):v for k,v in max_memory}
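# e.g. --max_memory 0:10GiB 1:10GiB cpu:30GiB parses to {0: "10GiB", 1: "10GiB", "cpu": "30GiB"}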
if args.auto_parallel:
gpu_list = auto_parallel(args)
# get quantization config (apart from w_bit)
q_config = {
"zero_point": not args.no_zero_point, # by default True
"q_group_size": args.q_group_size, # group size for grouped quantization (-1 disables grouping)
}
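# e.g. --q_group_size 128 with zero-point enabled gives q_config == {"zero_point": True, "q_group_size": 128}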
print("Quantization config:", q_config)
def get_awq_model(model):
from awq.models import MptAWQForCausalLM
@@ -73,149 +15,45 @@ def get_awq_model(model):
else:
raise NotImplementedError(type(model))
def build_model_and_enc(model_path):
if not os.path.exists(model_path): # look into ssd
raise FileNotFoundError(f"{model_path} not found!")
print(f"* Building model {model_path}")
# all hf model
def load_unquantized(model_path):
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if "mpt" in config.__class__.__name__.lower():
enc = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
else:
enc = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
if args.load_quant: # directly load quantized weights
print("Loading pre-computed quantized weights...")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.float16, trust_remote_code=True)
real_quantize_model_weight(
model, w_bit=args.w_bit, q_config=q_config, init_only=True)
model.tie_weights()
# Infer device map
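# (no_split_module_classes keeps each decoder block intact on a single device instead of sharding it)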
kwargs = {"max_memory": max_memory} if len(max_memory) else {}
device_map = infer_auto_device_map(
model,
no_split_module_classes=[
"OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
**kwargs
)
# Load checkpoint in the model
load_checkpoint_in_model(
model,
checkpoint=args.load_quant,
device_map=device_map,
offload_state_dict=True,
)
# Dispatch model
model = simple_dispatch_model(model, device_map=device_map)
kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
model = AutoModelForCausalLM.from_pretrained(
model_path, config=config, trust_remote_code=True, **kwargs)
model.eval()
else: # fp16 to quantized
args.run_awq &= not args.load_awq # if load_awq, no need to run awq
# Init model on CPU:
kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
model = AutoModelForCausalLM.from_pretrained(
model_path, config=config, trust_remote_code=True, **kwargs)
model.eval()
return model, tokenizer
if args.run_awq:
assert args.dump_awq, "Please save the awq results with --dump_awq"
def load_search_result_into_memory(model, search_path):
awq_results = torch.load(search_path, map_location="cpu")
awq_model = get_awq_model(model)
awq_results = awq_model.quantize(model, enc, args.w_bit, q_config)
if args.dump_awq:
dirpath = os.path.dirname(args.dump_awq)
os.makedirs(dirpath, exist_ok=True)
torch.save(awq_results, args.dump_awq)
print("AWQ results saved at", args.dump_awq)
exit(0)
if args.load_awq:
print("Loading pre-computed AWQ results from", args.load_awq)
awq_results = torch.load(args.load_awq, map_location="cpu")
apply_scale(model, awq_results["scale"])
apply_clip(model, awq_results["clip"])
# weight quantization
if args.w_bit is not None:
if args.q_backend == "fake":
assert args.dump_quant is None, \
"Need to use real quantization to dump quantized weights"
pseudo_quantize_model_weight(
model, w_bit=args.w_bit, q_config=q_config
)
elif args.q_backend == "real": # real quantization
real_quantize_model_weight(
model, w_bit=args.w_bit, q_config=q_config
)
if args.dump_quant:
dirpath = os.path.dirname(args.dump_quant)
os.makedirs(dirpath, exist_ok=True)
print(
f"Saving the quantized model at {args.dump_quant}...")
torch.save(model.cpu().state_dict(), args.dump_quant)
exit(0)
else:
raise NotImplementedError
# Move the model to GPU (as much as possible) for LM evaluation
kwargs = {"max_memory": get_balanced_memory(model, max_memory if len(max_memory) > 0 else None)}
device_map = infer_auto_device_map(
model,
# TODO: can we remove this?
no_split_module_classes=[
"OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
**kwargs
)
model = dispatch_model(model, device_map=device_map)
return model, enc
def main():
if args.output_path is not None and os.path.exists(args.output_path):
# print(f"Results {args.output_path} already generated. Exit.")
print(f"Results {args.output_path} already generated. Overwrite.")
# exit()
if args.dump_awq and os.path.exists(args.dump_awq):
print(f"Found existing AWQ results {args.dump_awq}, exit.")
exit()
# a hack here to auto set model group
model, enc = build_model_and_enc(args.model_path)
apply_scale(model, awq_results["scale"])
apply_clip(model, awq_results["clip"])
if args.tasks is not None:
task_names = args.tasks.split(",")
def run_search(model, dump_path):
model, tokenizer = load_unquantized(model_path)
awq_model = get_awq_model(model)
awq_results = awq_model.quantize(model, tokenizer, w_bit=4, q_config=q_config, run_search=True, run_quant=False)
lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
results = evaluator.simple_evaluate(
model=lm_eval_model,
tasks=task_names,
batch_size=args.batch_size,
no_cache=True,
num_fewshot=args.num_fewshot,
)
dirpath = os.path.dirname(dump_path)
os.makedirs(dirpath, exist_ok=True)
torch.save(awq_results, dump_path)
print(evaluator.make_table(results))
def run_quant(model, search_path, dump_path):
model, tokenizer = load_unquantized(model_path)
load_search_result_into_memory(model, search_path)
if args.output_path is not None:
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
# otherwise cannot save
results["config"]["model"] = args.model_path
with open(args.output_path, "w") as f:
json.dump(results, f, indent=2)
awq_model = get_awq_model(model)
awq_model.quantize(model, w_bit=4, q_config=q_config, run_search=False, run_quant=True)
dirpath = os.path.dirname(dump_path)
os.makedirs(dirpath, exist_ok=True)
torch.save(model.cpu().state_dict(), dump_path)
if __name__ == '__main__':
main()
model_path = "./mpt-7b-8k-chat"
search_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-awq-search.pt"
quant_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-w4-g128.pt"
q_config = { "zero_point": True, "q_group_size": 128 }
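# Illustrative wiring of the WIP helpers above (assumed usage; this commit does not call them yet).
# The model argument is re-loaded inside run_search/run_quant via load_unquantized(), so a placeholder works:
#   run_search(None, search_path)              # AWQ scale/clip search -> search_path
#   run_quant(None, search_path, quant_path)   # apply search results, pack W4 weights -> quant_path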
import gc
import tqdm
import torch
import functools
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation
class BaseAWQForCausalLM:
@torch.no_grad()
def quantize(self, model, tokenizer, w_bit, q_config, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, calib_data="pileval"):
def quantize(self, model, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, run_search=False, run_quant=True,
calib_data="pileval", init_only=False):
# collect and return the search results so callers (e.g. run_search in entry.py) can save them
awq_results = None
if run_search:
awq_results = self._awq_search(model, tokenizer, w_bit, q_config, n_samples=n_samples, seqlen=seqlen,
auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)
if run_quant:
self._awq_quant(model, w_bit, q_config, init_only)
return awq_results
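# Illustrative calls to the refactored API, mirroring the entry.py WIP above
# (get_awq_model returns the concrete subclass, e.g. MptAWQForCausalLM):
#   awq_model = get_awq_model(model)
#   awq_results = awq_model.quantize(model, tokenizer, w_bit=4, q_config=q_config,
#                                    run_search=True, run_quant=False)   # search only
#   awq_model.quantize(model, w_bit=4, q_config=q_config, run_search=False, run_quant=True)   # quantize only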
def _awq_quant(self, model, w_bit, q_config, init_only):
assert q_config["zero_point"], "We only support zero_point quantization now."
layers = self.get_model_layers(model)
# Run AWQ quantization
for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
layer = layers[i]
named_linears = get_named_linears(layer)
if not isinstance(layer.ffn.act, ScaledActivation):
param = next(layer.parameters())
# get activation scale
scale_dict = self.get_act_for_scaling(layer)
scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)
# scale activation
scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
for name, module in named_linears.items():
if init_only:
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], True)
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
else:
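# pseudo-quantize the fp16 weight to recover per-group scales/zero-points,
# then pack the layer into a WQLinear module using those parameters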
module.cuda()
module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], False, scales, zeros)
module.cpu()
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def _awq_search(self, model, tokenizer, w_bit, q_config, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, calib_data="pileval"):
layers = self.get_model_layers(model)
samples = get_calib_dataset(
@@ -62,8 +117,8 @@ class BaseAWQForCausalLM:
"clip": [],
}
# solve layer by layer
for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
# Run AWQ search layer by layer
for i in tqdm(range(len(layers)), desc="AWQ Search:"):
layer = layers[i]
layer = layer.cuda()
named_linears = get_named_linears(layer)
@@ -119,7 +174,7 @@ class BaseAWQForCausalLM:
del input_feat
gc.collect()
torch.cuda.empty_cache()
return awq_results
def save_quantized():
......
import torch
import torch.nn as nn
from tqdm import tqdm
import gc
from .qmodule import ScaledActivation
from ..utils.module import set_op_by_name
from transformers.models.bloom.modeling_bloom import BloomBlock
EMBEDDING_KEYWORDS = ["embed"]
LM_HEAD_KEYWORDS = ["lm_head", "embed_out", "output"]
def scale_activations(module):
param = next(module.parameters())
dtype = param.dtype
device = param.device
if isinstance(module, BloomBlock):
if isinstance(module.mlp.gelu_impl, ScaledActivation):
return
c = module.mlp.dense_h_to_4h.out_features
act = ScaledActivation(
module.mlp.gelu_impl,
torch.ones(c, dtype=dtype, device=device)
)
set_op_by_name(module, "mlp.gelu_impl", act)
elif 'mptblock' in str(module.__class__.__name__).lower():
if isinstance(module.ffn.act, ScaledActivation):
return
# get activation scale
scale_dict = MptAWQForCausalLM().get_act_for_scaling(module)
scale_like = torch.ones(scale_dict['scale_shape'], dtype=dtype, device=device)
# scale activation
scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
set_op_by_name(module, scale_dict['scale_name'], scaled_act)
elif 'falcon' in str(module.__class__).lower():
if isinstance(module.mlp.act, ScaledActivation):
return
c = module.mlp.dense_h_to_4h.out_features
act = ScaledActivation(
module.mlp.act,
torch.ones(c, dtype=dtype, device=device)
)
set_op_by_name(module, "mlp.act", act)
# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,
zero_point=True, q_group_size=-1,
@@ -107,38 +63,8 @@ def pseudo_quantize_model_weight(
@torch.no_grad()
def real_quantize_model_weight(
model, w_bit, q_config,
init_only=False
):
from .qmodule import WQLinear
from .pre_quant import get_blocks, get_named_linears
assert q_config["zero_point"], "We only support zero_point quantization now."
layers = get_blocks(model)
for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):
def real_quantize_model_weight(model, awq_model):
layers = awq_model.get_model_layers(model)
for i in tqdm(range(len(layers)), desc="real weight quantization..."):
layer = layers[i]
named_linears = get_named_linears(layer)
scale_activations(layer)
for name, module in named_linears.items():
if init_only:
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], True)
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
else:
module.cuda()
module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], False, scales, zeros)
module.cpu()
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
gc.collect()
del layer