Unverified commit ce4a6bb1 authored by Jiaming Tang, committed by GitHub

Merge pull request #36 from abhinavkulkarni/dev/more_models

Parents: 3b9f2875, ba01560f
@@ -4,11 +4,12 @@ import torch
import argparse
import os
import json
from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_and_dispatch
from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_in_model
from awq.utils.parallel import auto_parallel
from awq.quantize.pre_quant import run_awq, apply_awq
from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
from awq.utils.utils import simple_dispatch_model
parser = argparse.ArgumentParser()
@@ -80,16 +81,32 @@ def build_model_and_enc(model_path):
if args.load_quant: # directly load quantized weights
print("Loading pre-computed quantized weights...")
with init_empty_weights():
model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
torch_dtype=torch.float16, trust_remote_code=True)
model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.float16, trust_remote_code=True)
real_quantize_model_weight(
model, w_bit=args.w_bit, q_config=q_config, init_only=True)
model = load_checkpoint_and_dispatch(
model, args.load_quant, device_map="balanced",
# TODO: can we remove this?
model.tie_weights()
# Infer device map
kwargs = {"max_memory": max_memory} if len(max_memory) else {}
device_map = infer_auto_device_map(
model,
no_split_module_classes=[
"OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"]
"OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
**kwargs
)
# Load checkpoint in the model
load_checkpoint_in_model(
model,
checkpoint=args.load_quant,
device_map=device_map,
offload_state_dict=True,
)
# Dispatch model
model = simple_dispatch_model(model, device_map=device_map)
model.eval()
else: # fp16 to quantized
args.run_awq &= not args.load_awq # if load_awq, no need to run awq
# Init model on CPU:
@@ -97,6 +114,8 @@ def build_model_and_enc(model_path):
model = AutoModelForCausalLM.from_pretrained(
model_path, config=config, trust_remote_code=True, **kwargs)
model.eval()
if args.run_awq:
assert args.dump_awq, "Please save the awq results with --dump_awq"
......
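The hunks above replace accelerate's single load_checkpoint_and_dispatch call with an explicit sequence: build the model skeleton under init_empty_weights, size a device map for it with infer_auto_device_map, fill it with load_checkpoint_in_model, and attach device hooks with the new simple_dispatch_model helper from awq.utils.utils. A minimal sketch of that flow; the model id, checkpoint path and memory budget below are placeholders, not values from this PR:

import torch
from transformers import AutoConfig, AutoModelForCausalLM
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_in_model
from awq.utils.utils import simple_dispatch_model

config = AutoConfig.from_pretrained("facebook/opt-1.3b", trust_remote_code=True)
with init_empty_weights():  # build the module tree without allocating real storage
    model = AutoModelForCausalLM.from_config(
        config=config, torch_dtype=torch.float16, trust_remote_code=True)
# (build_model_and_enc also calls real_quantize_model_weight(..., init_only=True) at this
#  point, so the device map is sized for the packed WQLinear modules rather than fp16 Linears.)
model.tie_weights()
device_map = infer_auto_device_map(
    model,
    max_memory={0: "12GiB", "cpu": "48GiB"},      # assumed budget; omit to use all visible memory
    no_split_module_classes=["OPTDecoderLayer"],  # never split a decoder block across devices
)
load_checkpoint_in_model(
    model,
    checkpoint="quantized-weights.pt",            # placeholder path
    device_map=device_map,
    offload_state_dict=True,
)
model = simple_dispatch_model(model, device_map=device_map)
model.eval()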
import gc
import torch
import torch.nn as nn
@@ -112,6 +113,7 @@ def auto_scale_block(module, module_kwargs,
weight, q_group_size=q_config.get("q_group_size", -1))
# Clear GPU memory
del weight
gc.collect()
torch.cuda.empty_cache()
x = x.to(next(block.parameters()).device)
@@ -129,8 +131,6 @@ def auto_scale_block(module, module_kwargs,
n_grid = 20
history = []
# Clear GPU memory
torch.cuda.empty_cache()
org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
for ratio in range(n_grid):
ratio = ratio * 1 / n_grid
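Only the memory clearing changes in this hunk, but for context: _search_module_scale sweeps the n_grid ratios above, folds a per-input-channel scale derived from the activation magnitudes into the weights, pseudo-quantizes them, and keeps the ratio whose output drifts least from the fp16 block, restoring org_sd between trials. A simplified sketch of that search (not the repo's exact code; block, linears, inp, x_max and quantize_fn are assumed inputs, and the block is assumed to return a plain tensor):

import torch

@torch.no_grad()
def search_best_scales(block, linears, inp, x_max, quantize_fn, n_grid=20):
    org_out = block(inp)                                         # fp16 reference output
    org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
    best_error, best_scales = float("inf"), None
    for i in range(n_grid):
        ratio = i / n_grid
        scales = x_max.pow(ratio).clamp(min=1e-4)
        scales = scales / (scales.max() * scales.min()).sqrt()   # balance the dynamic range
        for fc in linears:
            s = scales.view(1, -1).to(fc.weight.device)
            fc.weight.data = quantize_fn(fc.weight.data * s) / s # W -> Q(W*s)/s
        error = (org_out - block(inp)).float().pow(2).mean().item()
        if error < best_error:
            best_error, best_scales = error, scales.clone()
        block.load_state_dict(org_sd)                            # restore fp16 weights for the next trial
    return best_scales.detach().cpu()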
@@ -169,6 +169,7 @@ def auto_scale_block(module, module_kwargs,
module2inspect = layers[0]
scales = _search_module_scale(module2inspect, layers, inp, kwargs)
scales = scales.detach().cpu()
# prev_op_name, [layer_name], scale
return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), scales)
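Each triple returned here, (prev_op_name, layer_names, scales) as the new comment notes, is what apply_scale later consumes: it resolves the names back to modules and folds the scales in. Illustratively, an OPT entry might look like ('self_attn_layer_norm', ('self_attn.q_proj', 'self_attn.k_proj', 'self_attn.v_proj'), tensor([...])); these module names are an example, not taken from this diff.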
@@ -302,13 +303,31 @@ def auto_scale_block(module, module_kwargs,
))
"""
# fc1, as long as it is scaled, everything is screwed up
scales_list.append(_auto_get_scale(
prev_op=module.input_layernorm,
layers=[module.mlp.dense_h_to_4h, module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module,
kwargs=module_kwargs,
))
if "falcon-7b" in str(module.__class__).lower():
scales_list.append(_auto_get_scale(
prev_op=module.input_layernorm,
layers=[module.mlp.dense_h_to_4h, module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module,
kwargs=module_kwargs,
))
elif "falcon-40b" in str(module.__class__).lower():
scales_list.append(_auto_get_scale(
prev_op=module.ln_attn,
layers=[module.self_attention.query_key_value],
inp=input_feat['self_attention.query_key_value'],
module2inspect=module,
kwargs=module_kwargs,
))
scales_list.append(_auto_get_scale(
prev_op=module.ln_mlp,
layers=[module.mlp.dense_h_to_4h],
inp=input_feat['mlp.dense_h_to_4h'],
module2inspect=module,
kwargs=module_kwargs,
))
else:
raise NotImplementedError("Unknown Falcon architecture, currently only falcon-7b and falcon-40b are supported")
# fc2
scales_list.append(_auto_get_scale(
prev_op=module.mlp.act,
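The new branch reflects the two Falcon block layouts: falcon-7b blocks have a single input_layernorm whose output feeds the attention projection and the MLP in parallel, so query_key_value and dense_h_to_4h have to share one scale against that norm, while falcon-40b blocks carry separate ln_attn and ln_mlp norms, so the two can be scaled independently. A quick way to check which layout a loaded block uses (the module path is an assumption):

block = model.transformer.h[0]   # first Falcon decoder block, path assumed
print(hasattr(block, "input_layernorm"), hasattr(block, "ln_attn"), hasattr(block, "ln_mlp"))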
@@ -329,6 +348,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
prev_op.cuda()
for layer in layers:
layer.cuda()
scales.cuda()
if isinstance(prev_op, nn.Linear):
assert len(layers) == 1
@@ -352,3 +372,4 @@ def apply_scale(module, scales_list, input_feat_dict=None):
prev_op.cpu()
for layer in layers:
layer.cpu()
scales.cpu()
@@ -95,6 +95,7 @@ def run_awq(
model(samples.to(next(model.parameters()).device))
except ValueError: # work with early exit
pass
del samples
layers[0] = layers[0].module # restore
inps = inps[0]
......
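The del samples added above follows the same memory-hygiene pattern as the gc.collect() added in auto_scale_block: drop the last Python reference to a large tensor, let the garbage collector reclaim it, then return the freed blocks to the CUDA caching allocator. A tiny helper capturing that pattern (illustrative only; the diff simply inlines the calls):

import gc
import torch

def release_cuda_memory():
    # Collect unreachable Python objects first so their CUDA storage is actually freed...
    gc.collect()
    # ...then hand the cached, now-unused blocks back to the device.
    torch.cuda.empty_cache()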
@@ -93,3 +93,7 @@ class WQLinear(nn.Module):
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
)
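With extra_repr defined, printing a model that has gone through real_quantize_model_weight now shows each packed layer's shape and quantization settings inline; a line of the module tree would render roughly as follows (sizes made up for illustration):

WQLinear(in_features=4096, out_features=4096, bias=False, w_bit=4, group_size=128)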
@@ -130,6 +130,7 @@ def real_quantize_model_weight(
q_linear = WQLinear.from_linear(
module, w_bit, q_config['q_group_size'], False, scales, zeros)
module.cpu()
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
torch.cuda.empty_cache()
gc.collect()
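The added q_linear.to(next(layer.parameters()).device) moves the packed WQLinear onto whatever device the host layer already occupies before it is registered, rather than leaving it wherever from_linear happened to build it. For reference, an in-place submodule swap of the kind set_op_by_name performs looks roughly like this (hypothetical helper, not the repo's implementation):

def set_submodule(root, dotted_name, new_module):
    # Walk the dotted path ("mlp.dense_h_to_4h", "self_attn.q_proj", ...) to the parent
    # module, then rebind the leaf attribute to the replacement.
    parts = dotted_name.split(".")
    parent = root
    for part in parts[:-1]:
        parent = getattr(parent, part)
    setattr(parent, parts[-1], new_module)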
import torch
import accelerate


def get_module_by_name_suffix(model, module_name: str):
    for name, module in model.named_modules():
        if name.endswith(module_name):
            return module


def simple_dispatch_model(model, device_map):
    from accelerate.hooks import add_hook_to_module, AlignDevicesHook

    if "" in device_map:
        d = device_map[""]
        model = model.to(torch.device(d))
        model.hf_device_map = device_map
        return model

    tied_params = accelerate.utils.modeling.find_tied_parameters(model)
    if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
        main_device = "cpu"
    else:
        main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]

    cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
    prev_hook = None
    for idx, (n, d) in enumerate(cpu_offload_group):
        m = get_module_by_name_suffix(model, n)
        _, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook)
    # set first cpu offload module's prev_module_hook to the last cpu offload module's hook
    if len(cpu_offload_group) > 1:
        get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook

    for n, d in device_map.items():
        m = get_module_by_name_suffix(model, n)
        if d != "cpu":
            d = torch.device(d)
            hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
            add_hook_to_module(m, hook)
    accelerate.utils.modeling.retie_parameters(model, tied_params)
    model.hf_device_map = device_map

    return model
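simple_dispatch_model is a lightweight alternative to accelerate's own dispatching: modules the device map sends to "cpu" get chained cpu_offload_with_hook hooks, so each one is pulled onto main_device just before its forward and pushed back to CPU when the next cpu-mapped module (or, via the wrap-around on the last hook, the first one) runs, while modules mapped to an accelerator get an AlignDevicesHook that places them there and keeps inputs and outputs on consistent devices. A hypothetical, CPU-only usage sketch (the toy model and hand-written device_map are assumptions; in build_model_and_enc the map comes from infer_auto_device_map):

import torch.nn as nn

toy = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
# Keys are module names, values are devices; a real map mixes GPU indices with "cpu".
device_map = {"0": "cpu", "1": "cpu"}
toy = simple_dispatch_model(toy, device_map)
print(toy.hf_device_map)   # {'0': 'cpu', '1': 'cpu'}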