Commit 5d4ab5dc authored by Casper Hansen's avatar Casper Hansen

Refactor entry.py [WIP]

parent 290f45e7
from lm_eval import evaluator, tasks
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import argparse
import os
import json
from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_in_model
from accelerate.utils.modeling import get_balanced_memory
from awq.utils.parallel import auto_parallel
from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
from awq.utils.utils import simple_dispatch_model
from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, help='path of the hf model')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument("--tasks", default=None, type=str)
parser.add_argument("--output_path", default=None, type=str)
parser.add_argument('--num_fewshot', type=int, default=0)
# model config
parser.add_argument('--parallel', action='store_true',
help="enable model parallelism")
# max memory to offload larger models to CPU
parser.add_argument('--max_memory', type=str, nargs='*',
help="List of device_id:max_memory pairs to be parsed into a dictionary; " \
+ "Example: 0:10GiB 1:10GiB cpu:30GiB; " \
+ "mode details here: " \
+ "https://huggingface.co/docs/accelerate/usage_guides/big_modeling")
parser.add_argument('--auto_parallel', action='store_true',
help="automatically set parallel and batch_size")
# quantization config
parser.add_argument('--w_bit', type=int, default=None)
parser.add_argument('--q_group_size', type=int, default=-1)
parser.add_argument('--no_zero_point', action='store_true',
help="disable zero_point")
parser.add_argument('--q_backend', type=str,
default="fake", choices=["fake", "real"])
# save/load real quantized weights
parser.add_argument('--dump_quant', type=str, default=None,
help='save quantized model')
parser.add_argument('--load_quant', type=str, default=None,
help='load quantized model')
# apply/save/load awq
parser.add_argument('--run_awq', action='store_true',
help="perform awq search process")
parser.add_argument('--dump_awq', type=str, default=None,
help="save the awq search results")
parser.add_argument('--load_awq', type=str, default=None,
help="load the awq search results")
args = parser.parse_args()
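# Example invocation (illustrative, not from the commit; paths and task name are hypothetical):
#   python entry.py --model_path /path/to/llama-7b --tasks wikitext --w_bit 4 --q_group_size 128 \
#       --run_awq --dump_awq awq_cache/llama-7b-w4-g128.pt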
max_memory = [v.split(':') for v in (args.max_memory or [])]
max_memory = {(int(k) if k.isdigit() else k):v for k,v in max_memory}
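# Illustration (not part of the commit): given --max_memory 0:10GiB 1:10GiB cpu:30GiB, the two
# lines above yield {0: "10GiB", 1: "10GiB", "cpu": "30GiB"}, the dict form that accelerate's
# infer_auto_device_map / get_balanced_memory calls below accept.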
if args.auto_parallel:
gpu_list = auto_parallel(args)
# get quantization config (apart from w_bit)
q_config = {
"zero_point": not args.no_zero_point, # by default True
"q_group_size": args.q_group_size, # whether to use group quantization
}
print("Quantization config:", q_config)
def get_awq_model(model):
from awq.models import MptAWQForCausalLM
@@ -73,149 +15,45 @@ def get_awq_model(model):
else:
raise NotImplementedError(type(model))
def build_model_and_enc(model_path):
if not os.path.exists(model_path): # look into ssd
raise FileNotFoundError(f"{model_path} not found!")
print(f"* Building model {model_path}")
# all hf model
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if "mpt" in config.__class__.__name__.lower():
enc = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
else:
enc = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
if args.load_quant: # directly load quantized weights
print("Loading pre-computed quantized weights...")
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.float16, trust_remote_code=True)
real_quantize_model_weight(
model, w_bit=args.w_bit, q_config=q_config, init_only=True)
model.tie_weights()
# Infer device map
kwargs = {"max_memory": max_memory} if len(max_memory) else {}
device_map = infer_auto_device_map(
model,
no_split_module_classes=[
"OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
**kwargs
)
# Load checkpoint in the model
load_checkpoint_in_model(
model,
checkpoint=args.load_quant,
device_map=device_map,
offload_state_dict=True,
)
# Dispatch model
model = simple_dispatch_model(model, device_map=device_map)
model.eval()
else: # fp16 to quantized
args.run_awq &= not args.load_awq # if load_awq, no need to run awq
# Init model on CPU:
kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
model = AutoModelForCausalLM.from_pretrained(
model_path, config=config, trust_remote_code=True, **kwargs)
model.eval()
if args.run_awq:
assert args.dump_awq, "Please save the awq results with --dump_awq"
awq_model = get_awq_model(model)
awq_results = awq_model.quantize(model, enc, args.w_bit, q_config)
if args.dump_awq:
dirpath = os.path.dirname(args.dump_awq)
os.makedirs(dirpath, exist_ok=True)
torch.save(awq_results, args.dump_awq)
print("AWQ results saved at", args.dump_awq)
exit(0)
if args.load_awq:
print("Loading pre-computed AWQ results from", args.load_awq)
awq_results = torch.load(args.load_awq, map_location="cpu")
apply_scale(model, awq_results["scale"])
apply_clip(model, awq_results["clip"])
# weight quantization
if args.w_bit is not None:
if args.q_backend == "fake":
assert args.dump_quant is None, \
"Need to use real quantization to dump quantized weights"
pseudo_quantize_model_weight(
model, w_bit=args.w_bit, q_config=q_config
)
elif args.q_backend == "real": # real quantization
real_quantize_model_weight(
model, w_bit=args.w_bit, q_config=q_config
)
if args.dump_quant:
dirpath = os.path.dirname(args.dump_quant)
os.makedirs(dirpath, exist_ok=True)
print(
f"Saving the quantized model at {args.dump_quant}...")
torch.save(model.cpu().state_dict(), args.dump_quant)
exit(0)
else:
raise NotImplementedError
# Move the model to GPU (as much as possible) for LM evaluation
kwargs = {"max_memory": get_balanced_memory(model, max_memory if len(max_memory) > 0 else None)}
device_map = infer_auto_device_map(
model,
# TODO: can we remove this?
no_split_module_classes=[
"OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
**kwargs
)
model = dispatch_model(model, device_map=device_map)
return model, enc
def main():
if args.output_path is not None and os.path.exists(args.output_path):
# print(f"Results {args.output_path} already generated. Exit.")
print(f"Results {args.output_path} already generated. Overwrite.")
# exit()
if args.dump_awq and os.path.exists(args.dump_awq):
print(f"Found existing AWQ results {args.dump_awq}, exit.")
exit()
# a hack here to auto set model group
model, enc = build_model_and_enc(args.model_path)
if args.tasks is not None:
task_names = args.tasks.split(",")
lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
results = evaluator.simple_evaluate(
model=lm_eval_model,
tasks=task_names,
batch_size=args.batch_size,
no_cache=True,
num_fewshot=args.num_fewshot,
)
print(evaluator.make_table(results))
if args.output_path is not None:
os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
# otherwise cannot save
results["config"]["model"] = args.model_path
with open(args.output_path, "w") as f:
json.dump(results, f, indent=2)
if __name__ == '__main__':
main()

# --- Refactored entry.py [WIP]: new helpers introduced by this commit replace the argparse-driven flow above ---
def load_unquantized(model_path):
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
    kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
    model = AutoModelForCausalLM.from_pretrained(
        model_path, config=config, trust_remote_code=True, **kwargs)
    model.eval()
    return model, tokenizer

def load_search_result_into_memory(model, search_path):
    awq_results = torch.load(search_path, map_location="cpu")
    apply_scale(model, awq_results["scale"])
    apply_clip(model, awq_results["clip"])

def run_search(model, dump_path):
    model, tokenizer = load_unquantized(model_path)
    awq_model = get_awq_model(model)
    awq_results = awq_model.quantize(model, tokenizer, w_bit=4, q_config=q_config, run_search=True, run_quant=False)
    dirpath = os.path.dirname(dump_path)
    os.makedirs(dirpath, exist_ok=True)
    torch.save(awq_results, dump_path)

def run_quant(model, search_path, dump_path):
    model, tokenizer = load_unquantized(model_path)
    load_search_result_into_memory(model, search_path)
    awq_model = get_awq_model(model)
    awq_model.quantize(model, w_bit=4, q_config=q_config, run_search=False, run_quant=True)
    dirpath = os.path.dirname(dump_path)
    os.makedirs(dirpath, exist_ok=True)
    torch.save(model.cpu().state_dict(), dump_path)

model_path = "./mpt-7b-8k-chat"
search_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-awq-search.pt"
quant_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-w4-g128.pt"
q_config = { "zero_point": True, "q_group_size": 128 }
import gc
import torch
import functools
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation
class BaseAWQForCausalLM:
@torch.no_grad()
def quantize(self, model, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, run_search=False, run_quant=True,
calib_data="pileval", init_only=False):
if run_search:
self._awq_search(model, tokenizer, w_bit, q_config, n_samples=n_samples, seqlen=seqlen,
auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)
if run_quant:
self._awq_quant(model, w_bit, q_config, init_only)
def _awq_quant(self, model, w_bit, q_config, init_only):
assert q_config["zero_point"], "We only support zero_point quantization now."
layers = self.get_model_layers(model)
# Run AWQ quantization
for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
layer = layers[i]
named_linears = get_named_linears(layer)
if not isinstance(layer.ffn.act, ScaledActivation):
param = next(layer.parameters())
# get activation scale
scale_dict = self.get_act_for_scaling(layer)
scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)
# scale activation
scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
for name, module in named_linears.items():
if init_only:
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], True)
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
else:
module.cuda()
module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], False, scales, zeros)
module.cpu()
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def _awq_search(self, model, tokenizer, w_bit, q_config, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, calib_data="pileval"):
layers = self.get_model_layers(model)
samples = get_calib_dataset(
@@ -62,8 +117,8 @@ class BaseAWQForCausalLM:
"clip": [],
}
# Run AWQ search layer by layer
for i in tqdm(range(len(layers)), desc="AWQ Search:"):
layer = layers[i]
layer = layer.cuda()
named_linears = get_named_linears(layer)
...
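# Usage sketch (illustrative, not from the commit): with the new flags, the same quantize() entry
# point drives both phases, mirroring run_search/run_quant in the refactored entry.py above.
# `_example_two_phase_quantize` is a hypothetical helper; it assumes `awq_model`, `model`,
# `tokenizer`, and `q_config` are prepared the way entry.py prepares them.
def _example_two_phase_quantize(awq_model, model, tokenizer, q_config):
    # Phase 1: AWQ search only -- collects per-channel scales and clipping results.
    awq_results = awq_model.quantize(model, tokenizer, w_bit=4, q_config=q_config,
                                     run_search=True, run_quant=False)
    # Phase 2: real quantization only -- replaces Linear modules with WQLinear in place.
    awq_model.quantize(model, w_bit=4, q_config=q_config, run_search=False, run_quant=True)
    return awq_results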
import torch
import torch.nn as nn
from tqdm import tqdm
import gc
from .qmodule import ScaledActivation
from ..utils.module import set_op_by_name
from transformers.models.bloom.modeling_bloom import BloomBlock
EMBEDDING_KEYWORDS = ["embed"]
LM_HEAD_KEYWORDS = ["lm_head", "embed_out", "output"]
def scale_activations(module):
param = next(module.parameters())
dtype = param.dtype
device = param.device
if isinstance(module, BloomBlock):
if isinstance(module.mlp.gelu_impl, ScaledActivation):
return
c = module.mlp.dense_h_to_4h.out_features
act = ScaledActivation(
module.mlp.gelu_impl,
torch.ones(c, dtype=dtype, device=device)
)
set_op_by_name(module, "mlp.gelu_impl", act)
elif 'mptblock' in str(module.__class__.__name__).lower():
if isinstance(module.ffn.act, ScaledActivation):
return
# get activation scale
scale_dict = MptAWQForCausalLM().get_act_for_scaling(module)
scale_like = torch.ones(scale_dict['scale_shape'], dtype=dtype, device=device)
# scale activation
scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
set_op_by_name(module, scale_dict['scale_name'], scaled_act)
elif 'falcon' in str(module.__class__).lower():
if isinstance(module.mlp.act, ScaledActivation):
return
c = module.mlp.dense_h_to_4h.out_features
act = ScaledActivation(
module.mlp.act,
torch.ones(c, dtype=dtype, device=device)
)
set_op_by_name(module, "mlp.act", act)
# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,
zero_point=True, q_group_size=-1,
@@ -107,38 +63,8 @@ def pseudo_quantize_model_weight(
@torch.no_grad()
def real_quantize_model_weight(
model, w_bit, q_config,
init_only=False
):
from .qmodule import WQLinear
from .pre_quant import get_blocks, get_named_linears
assert q_config["zero_point"], "We only support zero_point quantization now."
layers = get_blocks(model)
for i in tqdm(range(len(layers)), desc="real weight quantization..." + ("(init only)" if init_only else "")):
layer = layers[i]
named_linears = get_named_linears(layer)
scale_activations(layer)
for name, module in named_linears.items():
if init_only:
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], True)
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
else:
module.cuda()
module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], False, scales, zeros)
module.cpu()
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
torch.cuda.empty_cache()
gc.collect()
torch.cuda.empty_cache()
gc.collect()

# --- Refactored real_quantize_model_weight introduced by this commit [WIP]: the per-layer packing
# logic moves into the AWQ model class, leaving only a thin wrapper here ---
@torch.no_grad()
def real_quantize_model_weight(model, awq_model):
    layers = awq_model.get_model_layers(model)
    for i in tqdm(range(len(layers)), desc="real weight quantization..."):
        layer = layers[i]
        del layer
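# Illustrative sketch (not from the commit): how pseudo_quantize_tensor is used at the call sites
# above to fake-quantize a weight tensor and recover the scales/zeros needed for packing.
# `_example_pack_weight` is a hypothetical helper; it relies on the imports at the top of this file.
def _example_pack_weight(linear_weight, w_bit=4, q_group_size=128):
    w_q, scales, zeros = pseudo_quantize_tensor(
        linear_weight, n_bit=w_bit, get_scale_zp=True,
        zero_point=True, q_group_size=q_group_size)
    # WQLinear.from_linear expects transposed, contiguous scales/zeros (see the loop above).
    return w_q, scales.t().contiguous(), zeros.t().contiguous()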