Unverified Commit ab536fb1 authored by Jiaming Tang's avatar Jiaming Tang Committed by GitHub
Browse files

Merge pull request #22 from abhinavkulkarni/dev/more_models

parents 8e7e9ccc 6371c3a0
from lm_eval import evaluator, tasks
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import argparse
import os
import json
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_and_dispatch
from awq.utils.parallel import auto_parallel
from awq.quantize.pre_quant import run_awq, apply_awq
from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
......@@ -20,6 +20,12 @@ parser.add_argument('--num_fewshot', type=int, default=0)
# model config
parser.add_argument('--parallel', action='store_true',
help="enable model parallelism")
# max memory to offload larger models to CPU
parser.add_argument('--max_memory', type=str, nargs='*',
help="List of device_id:max_memory pairs to be parsed into a dictionary; " \
+ "Example: 0:10GiB 1:10GiB cpu:30GiB; " \
+ "mode details here: " \
+ "https://huggingface.co/docs/accelerate/usage_guides/big_modeling")
parser.add_argument('--auto_parallel', action='store_true',
help="automatically set parallel and batch_size")
# quantization config
......@@ -43,6 +49,9 @@ parser.add_argument('--load_awq', type=str, default=None,
help="load the awq search results")
args = parser.parse_args()
max_memory = [v.split(':') for v in (args.max_memory or [])]
max_memory = {(int(k) if k.isdigit() else k):v for k,v in max_memory}
if args.auto_parallel:
gpu_list = auto_parallel(args)
......@@ -69,7 +78,6 @@ def build_model_and_enc(model_path):
enc = AutoTokenizer.from_pretrained(model_path, use_fast=False)
if args.load_quant: # directly load quantized weights
# no need to really load the fp16 weights... just to get the model structure
print("Loading pre-computed quantized weights...")
with init_empty_weights():
model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
......@@ -84,21 +92,14 @@ def build_model_and_enc(model_path):
)
else: # fp16 to quantized
args.run_awq &= not args.load_awq # if load_awq, no need to run awq
# Init model on CPU:
kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
model = AutoModelForCausalLM.from_pretrained(
model_path, config=config, trust_remote_code=True, **kwargs)
if args.run_awq:
assert args.dump_awq, "Please save the awq results with --dump_awq"
# Init model on CPU
def skip(*args, **kwargs):
pass
torch.nn.init.kaiming_normal_ = skip
torch.nn.init.kaiming_uniform_ = skip
torch.nn.init.uniform_ = skip
torch.nn.init.normal_ = skip
model = AutoModelForCausalLM.from_pretrained(
model_path, config=config, trust_remote_code=True, torch_dtype=torch.float16)
awq_results = run_awq(
model, enc,
w_bit=args.w_bit, q_config=q_config,
......@@ -112,12 +113,6 @@ def build_model_and_enc(model_path):
print("AWQ results saved at", args.dump_awq)
exit(0)
else:
# Inference with fake quant
# Init model on GPUs:
kwargs = {"device_map": "balanced", "torch_dtype": torch.float16}
model = AutoModelForCausalLM.from_pretrained(
model_path, config=config, trust_remote_code=True, **kwargs)
if args.load_awq:
print("Loading pre-computed AWQ results from", args.load_awq)
......@@ -147,6 +142,17 @@ def build_model_and_enc(model_path):
else:
raise NotImplementedError
# Move the model to GPU (as much as possible) for LM evaluation
kwargs = {"max_memory": max_memory} if len(max_memory) else {}
device_map = infer_auto_device_map(
model,
# TODO: can we remove this?
no_split_module_classes=[
"OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
**kwargs
)
model = dispatch_model(model, device_map=device_map)
return model, enc
......@@ -163,11 +169,10 @@ def main():
# a hack here to auto set model group
model, enc = build_model_and_enc(args.model_path)
lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
if args.tasks is not None:
task_names = args.tasks.split(",")
lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
results = evaluator.simple_evaluate(
model=lm_eval_model,
tasks=task_names,
......
......@@ -75,9 +75,11 @@ def auto_clip_block(module,
# due to qk bmm, it is hard to clip precisely
if any([_ in name for _ in ["q_", "k_", "query", "key", "Wqkv"]]):
continue
named_linears[name].cuda()
max_val = auto_clip_layer(
named_linears[name].weight, input_feat[name], n_bit=w_bit, q_config=q_config)
clip_list.append((name, max_val))
named_linears[name].cpu()
return clip_list
......@@ -86,8 +88,10 @@ def apply_clip(module, clip_list):
from ..utils.module import get_op_by_name
for name, max_val in clip_list:
layer = get_op_by_name(module, name)
layer.cuda()
max_val = max_val.to(layer.weight.device)
org_shape = layer.weight.shape
layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
layer.weight.data = layer.weight.data.reshape(org_shape)
layer.cpu()
......@@ -321,6 +321,10 @@ def apply_scale(module, scales_list, input_feat_dict=None):
prev_op = get_op_by_name(module, prev_op_name)
layers = [get_op_by_name(module, name) for name in layer_names]
prev_op.cuda()
for layer in layers:
layer.cuda()
if isinstance(prev_op, nn.Linear):
assert len(layers) == 1
scale_fc_fc(prev_op, layers[0], scales)
......@@ -339,3 +343,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
for layer_name in layer_names:
inp = input_feat_dict[layer_name]
inp.div_(scales.view(1, -1).to(inp.device))
prev_op.cpu()
for layer in layers:
layer.cpu()
......@@ -98,7 +98,9 @@ def pseudo_quantize_model_weight(
for i in tqdm(range(len(layers)), desc="pseudo weight quantization..."):
named_linears = get_named_linears(layers[i])
for n, m in named_linears.items():
m.cuda()
m.weight.data = pseudo_quantize_tensor(m.weight.data, n_bit=w_bit, **q_config)
m.cpu()
@torch.no_grad()
......@@ -121,11 +123,13 @@ def real_quantize_model_weight(
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], True)
else:
module.cuda()
module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], False, scales, zeros)
module.cpu()
set_op_by_name(layer, name, q_linear)
torch.cuda.empty_cache()
gc.collect()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment