Commit e09dc751 authored by Casper Hansen

Create AutoAWQForCausalLM and load quantized models with from_quantized

parent 934ad336
@@ -27,8 +27,12 @@ def load_unquantized(model_path):
    return model, tokenizer

def load_quantized(model_path):
    awq_model = get_awq_model(model)

def load_quantized(model_path, quant_path, w_bit, q_config, device):
    from awq.models.auto import AutoAWQForCausalLM
    model = AutoAWQForCausalLM.from_quantized(model_path, quant_path, w_bit, q_config, device)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    return model, tokenizer

def load_search_result_into_memory(model, search_path):
    awq_results = torch.load(search_path, map_location="cpu")
@@ -56,8 +60,8 @@ def run_quant(model_path, search_path, dump_path, w_bit, q_config, device):
    os.makedirs(dirpath, exist_ok=True)
    torch.save(model.cpu().state_dict(), dump_path)

def run_perplexity(model_path, device):
    model, tokenizer = load_unquantized(model_path)

def run_perplexity(model_path, quant_path, w_bit, q_config, device):
    model, tokenizer = load_quantized(model_path, quant_path, w_bit, q_config, device)

    lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=1)
    results = evaluator.simple_evaluate(
@@ -91,6 +95,6 @@ if __name__ == '__main__':
    elif args.entry_type == 'quant':
        run_quant(args.model_path, args.search_path, args.quant_path, args.w_bit, q_config)
    elif args.entry_type == 'perplexity':
        run_perplexity(args.model_path, args.device)
        run_perplexity(args.model_path, args.quant_path, args.w_bit, q_config, args.device)
    else:
        raise Exception('--entry_type must be one of (search|quant|perplexity)')
\ No newline at end of file
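For reference, a sketch of driving the updated perplexity entry point directly from Python is shown below. The model id, checkpoint path, bit width, and q_config values are illustrative assumptions; the diff only fixes the function signature and the zero_point/q_group_size keys consumed later in base.py.

# Hedged sketch: calling the new run_perplexity() path by hand.
# Paths, model id, w_bit, and q_config values are assumptions for illustration.
q_config = {"zero_point": True, "q_group_size": 128}   # values assumed

run_perplexity(
    model_path="mosaicml/mpt-7b",              # assumed HF model id (MPT is the only mapped type)
    quant_path="awq_cache/mpt-7b-w4-g128.pt",  # assumed path to the quantized state dict
    w_bit=4,                                   # assumed bit width
    q_config=q_config,
    device="cuda:0",
)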
from transformers import AutoConfig
from awq.models import MptAWQForCausalLM

AWQ_CAUSAL_LM_MODEL_MAP = {
    "mpt": MptAWQForCausalLM,
}

def check_and_get_model_type(model_dir, trust_remote_code=True):
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
    if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP.keys():
        raise TypeError(f"{config.model_type} isn't supported yet.")
    model_type = config.model_type
    return model_type

class AutoAWQForCausalLM:
    def __init__(self):
        raise EnvironmentError('You must instantiate AutoAWQForCausalLM with\n'
                               'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')

    @classmethod
    def from_pretrained():
        pass

    @classmethod
    def from_quantized(self, model_path, quant_path, w_bit, q_config, device, trust_remote_code=True):
        model_type = check_and_get_model_type(model_path, trust_remote_code)

        return AWQ_CAUSAL_LM_MODEL_MAP[model_type]().from_quantized(
            model_path, quant_path, w_bit, q_config, device
        )
\ No newline at end of file
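A minimal usage sketch for the new Auto class follows, assuming it is importable as awq.models.auto per the entry-script change above. The paths and q_config values are placeholders; only "mpt" is registered in AWQ_CAUSAL_LM_MODEL_MAP, so any other model_type raises TypeError in check_and_get_model_type.

# Hedged sketch: loading a quantized checkpoint through the new Auto class.
# The model id, checkpoint path, and q_config values are assumptions.
from awq.models.auto import AutoAWQForCausalLM

q_config = {"zero_point": True, "q_group_size": 128}   # assumed values
model = AutoAWQForCausalLM.from_quantized(
    "mosaicml/mpt-7b",                  # assumed HF model id
    "awq_cache/mpt-7b-w4-g128.pt",      # assumed quantized state dict path
    w_bit=4,
    q_config=q_config,
    device="cuda:0",
)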
@@ -6,15 +6,15 @@ from tqdm import tqdm
from collections import defaultdict
from awq.utils.calib_data import get_calib_dataset
from transformers import AutoModelForCausalLM, AutoConfig
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation

class BaseAWQForCausalLM:

    @torch.no_grad()
    def quantize(self, model, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
                 auto_scale=True, mse_range=True, run_search=False, run_quant=True,
@@ -186,5 +186,46 @@ class BaseAWQForCausalLM:
    def from_pretrained():
        pass

    def from_quantized():
        pass

    def from_quantized(self, model_path, quant_path, w_bit, q_config, device, trust_remote_code=True):
        # Load config
        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)

        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.float16, trust_remote_code=True)

        # Initialize layers
        assert q_config["zero_point"], "We only support zero_point quantization now."
        layers = self.get_model_layers(model)

        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
            layer = layers[i]
            named_linears = get_named_linears(layer)
            self._scale_activations(layer)

            for name, module in named_linears.items():
                q_linear = WQLinear.from_linear(
                    module, w_bit, q_config['q_group_size'], True)
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

            torch.cuda.empty_cache()
            gc.collect()

        model.tie_weights()

        model = load_checkpoint_and_dispatch(model, quant_path, device_map="balanced")

        return model

    def _scale_activations(self, layer):
        act_function = self.get_act_from_layer(layer)

        if act_function is not None and not isinstance(act_function, ScaledActivation):
            param = next(layer.parameters())

            # get activation scale
            scale_dict = self.get_act_for_scaling(layer)
            scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

            # scale activation
            scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
            set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
\ No newline at end of file
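The replacement loop in from_quantized relies on set_op_by_name from awq.utils.module, whose implementation this diff does not show. The sketch below is an assumption of the standard dotted-name traversal such a helper typically performs, not the library's actual code.

import torch.nn as nn

def set_op_by_name(layer: nn.Module, name: str, new_module: nn.Module) -> None:
    # Assumed behaviour: walk the dotted path (e.g. "ffn.up_proj", "blocks.0.attn")
    # and replace the addressed submodule with new_module.
    parts = name.split(".")
    parent = layer
    for part in parts[:-1]:
        # Numeric components index into nn.ModuleList / nn.Sequential containers.
        parent = parent[int(part)] if part.isdigit() else getattr(parent, part)
    last = parts[-1]
    if last.isdigit():
        parent[int(last)] = new_module
    else:
        setattr(parent, last, new_module)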
@@ -42,6 +42,9 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
        return layers

    def get_act_from_layer(self, layer):
        return layer.ffn.act

    def get_act_for_scaling(self, module):
        return dict(
            scale_name="ffn.act",
......
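Supporting another architecture would follow the same shape as MptAWQForCausalLM: subclass BaseAWQForCausalLM, implement the layer and activation hooks, and register the class in AWQ_CAUSAL_LM_MODEL_MAP. The sketch below is hypothetical; the class name, attribute paths, and the import path for the base class are assumptions, while the hook names and dict keys come from the diff.

# Hypothetical sketch of registering another decoder-only architecture.
from awq.models.base import BaseAWQForCausalLM        # assumed module path
from awq.models.auto import AWQ_CAUSAL_LM_MODEL_MAP

class NewAWQForCausalLM(BaseAWQForCausalLM):
    def get_model_layers(self, model):
        # Return the list of transformer blocks to quantize.
        return model.transformer.blocks               # assumed attribute path

    def get_act_from_layer(self, layer):
        # Activation module that may be wrapped in ScaledActivation.
        return layer.mlp.act                          # assumed attribute path

    def get_act_for_scaling(self, module):
        return dict(
            scale_name="mlp.act",                     # dotted name of the activation (assumed)
            scale_layer=module.mlp.act,               # module to wrap (assumed)
            scale_shape=module.mlp.up_proj.out_features,  # size of the scale vector (assumed)
        )

# Registering the class lets AutoAWQForCausalLM dispatch to it when
# AutoConfig reports model_type == "new_model" (hypothetical key).
AWQ_CAUSAL_LM_MODEL_MAP["new_model"] = NewAWQForCausalLM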