Commit 14d198c6 authored by Casper Hansen

Implement from_pretrained. Fix static methods and classmethods

parent 35ac58c7
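As context for the diff below, here is a minimal usage sketch of the API this commit introduces, mirroring the updated example script; the checkpoint path, cache path, and q_config values are placeholders rather than values taken from the diff:

import torch
from transformers import AutoTokenizer
from awq.models.auto import AutoAWQForCausalLM

model_path = "mosaicml/mpt-7b"             # placeholder checkpoint
dump_path = "awq_cache/mpt-7b-search.pt"   # placeholder output file

# Wrap an unquantized HF model (new classmethod added in this commit) and
# run the AWQ search pass only, as run_search() in the script below does.
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

awq_results = model.quantize(
    model.model, tokenizer, w_bit=4,
    q_config={"zero_point": True, "q_group_size": 128},  # assumed typical settings
    run_search=True, run_quant=False,
)
torch.save(awq_results, dump_path)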
@@ -2,37 +2,12 @@ import os
 import torch
 import argparse
 from lm_eval import evaluator
+from transformers import AutoTokenizer
+from awq.models.auto import AutoAWQForCausalLM
 from awq.quantize.auto_clip import apply_clip
 from awq.quantize.auto_scale import apply_scale
 from awq.utils.lm_eval_adaptor import LMEvalAdaptor
-from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
-
-def get_awq_model(model):
-    from awq.models import MptAWQForCausalLM
-
-    if "mpt" in str(model.__class__).lower():
-        return MptAWQForCausalLM()
-    else:
-        raise NotImplementedError(type(model))
-
-def load_unquantized(model_path):
-    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
-
-    kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path, config=config, trust_remote_code=True, **kwargs)
-    model.eval()
-
-    return model, tokenizer
-
-def load_quantized(model_path, quant_path, w_bit, q_config, device):
-    from awq.models.auto import AutoAWQForCausalLM
-    model = AutoAWQForCausalLM.from_quantized(model_path, quant_path, w_bit, q_config, device)
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-
-    return model, tokenizer
 
 def load_search_result_into_memory(model, search_path):
     awq_results = torch.load(search_path, map_location="cpu")
@@ -41,27 +16,27 @@ def load_search_result_into_memory(model, search_path):
     apply_clip(model, awq_results["clip"])
 
 def run_search(model_path, dump_path, w_bit, q_config):
-    model, tokenizer = load_unquantized(model_path)
-    awq_model = get_awq_model(model)
-    awq_results = awq_model.quantize(model, tokenizer, w_bit=w_bit, q_config=q_config, run_search=True, run_quant=False)
+    model = AutoAWQForCausalLM.from_pretrained(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+    awq_results = model.quantize(model.model, tokenizer, w_bit=w_bit, q_config=q_config, run_search=True, run_quant=False)
 
     dirpath = os.path.dirname(dump_path)
     os.makedirs(dirpath, exist_ok=True)
     torch.save(awq_results, dump_path)
 
-def run_quant(model_path, search_path, dump_path, w_bit, q_config, device):
-    model, tokenizer = load_unquantized(model_path, device)
-    load_search_result_into_memory(model, search_path)
-
-    awq_model = get_awq_model(model)
-    awq_model.quantize(model, w_bit=w_bit, q_config=q_config, run_search=False, run_quant=True)
+def run_quant(model_path, search_path, dump_path, w_bit, q_config):
+    model = AutoAWQForCausalLM.from_pretrained(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+    load_search_result_into_memory(model.model, search_path)
+    model.quantize(model.model, w_bit=w_bit, q_config=q_config, run_search=False, run_quant=True)
 
     dirpath = os.path.dirname(dump_path)
     os.makedirs(dirpath, exist_ok=True)
-    torch.save(model.cpu().state_dict(), dump_path)
+    torch.save(model.model.cpu().state_dict(), dump_path)
 
 def run_perplexity(model_path, quant_path, w_bit, q_config, device):
-    model, tokenizer = load_quantized(model_path, quant_path, w_bit, q_config, device)
+    model = AutoAWQForCausalLM.from_quantized(model_path, quant_path, w_bit, q_config, device)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=1)
 
     results = evaluator.simple_evaluate(
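The second stage of the workflow above reloads the saved search results, applies them, and dumps the quantized state dict. A hedged sketch of that flow, mirroring run_quant(); the paths are placeholders, and the exact q_config keys and the "scale" result key are assumed rather than shown in this diff:

import os
import torch
from awq.models.auto import AutoAWQForCausalLM
from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale

model_path = "mosaicml/mpt-7b"                 # placeholder checkpoint
search_path = "awq_cache/mpt-7b-search.pt"     # placeholder search results
dump_path = "awq_cache/mpt-7b-w4-quant.pt"     # placeholder output

model = AutoAWQForCausalLM.from_pretrained(model_path)

# Re-apply the scales and clips found during the search pass.
awq_results = torch.load(search_path, map_location="cpu")
apply_scale(model.model, awq_results["scale"])
apply_clip(model.model, awq_results["clip"])

# Real quantization, then save the packed weights.
model.quantize(model.model, w_bit=4,
               q_config={"zero_point": True, "q_group_size": 128},  # assumed
               run_search=False, run_quant=True)
os.makedirs(os.path.dirname(dump_path), exist_ok=True)
torch.save(model.model.cpu().state_dict(), dump_path)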
@@ -18,13 +18,17 @@ class AutoAWQForCausalLM:
                                'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')
 
     @classmethod
-    def from_pretrained():
-        pass
+    def from_pretrained(self, model_path, trust_remote_code=True):
+        model_type = check_and_get_model_type(model_path, trust_remote_code)
+
+        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
+            model_path, model_type, trust_remote_code=trust_remote_code
+        )
 
     @classmethod
     def from_quantized(self, model_path, quant_path, w_bit, q_config, device, trust_remote_code=True):
         model_type = check_and_get_model_type(model_path, trust_remote_code)
-        return AWQ_CAUSAL_LM_MODEL_MAP[model_type]().from_quantized(
-            model_path, quant_path, w_bit, q_config, device
+        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
+            model_path, quant_path, w_bit, q_config, device, trust_remote_code
         )
\ No newline at end of file
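Note that the dispatch no longer instantiates the mapped class before calling it; both entry points now call classmethods on the class itself, which construct and return the wrapper. A simplified sketch of that path, where the contents of AWQ_CAUSAL_LM_MODEL_MAP and the behaviour of check_and_get_model_type are assumptions (their definitions are not part of this diff):

from transformers import AutoConfig
from awq.models import MptAWQForCausalLM

# Assumed shape of the registry used by AutoAWQForCausalLM.
AWQ_CAUSAL_LM_MODEL_MAP = {
    "mpt": MptAWQForCausalLM,
}

def check_and_get_model_type(model_path, trust_remote_code=True):
    # Assumed behaviour: read the HF config and return its model_type key.
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
    return config.model_type

# Dispatch happens on the class, not an instance; BaseAWQForCausalLM.from_pretrained
# builds the wrapper object itself and returns it.
model_type = check_and_get_model_type("mosaicml/mpt-7b")  # placeholder path
model = AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained("mosaicml/mpt-7b", model_type)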
@@ -15,6 +15,11 @@ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
 
 class BaseAWQForCausalLM:
+    def __init__(self, model, model_type, is_quantized):
+        self.model = model
+        self.model_type = model_type
+        self.is_quantized = is_quantized
+
     @torch.no_grad()
     def quantize(self, model, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
                  auto_scale=True, mse_range=True, run_search=False, run_quant=True,
@@ -39,7 +44,7 @@ class BaseAWQForCausalLM:
         for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
             layer = layers[i]
             named_linears = get_named_linears(layer)
-            self._scale_activations(layer)
+            self._scale_activations(self, layer)
 
             for name, module in named_linears.items():
                 module.cuda()
@@ -167,9 +172,21 @@ class BaseAWQForCausalLM:
     def save_quantized():
         pass
 
-    def from_pretrained():
-        pass
+    @classmethod
+    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16, trust_remote_code=True):
+        # Load config
+        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
+
+        # Load empty weights
+        with init_empty_weights():
+            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.float16, trust_remote_code=True)
+
+        # Load model weights
+        model = load_checkpoint_and_dispatch(model, model_path, device_map="balanced", no_split_module_classes=[self.layer_type])
+
+        return self(model, model_type, is_quantized=False)
 
+    @classmethod
     def from_quantized(self, model_path, quant_path, w_bit, q_config, device, trust_remote_code=True):
         # Load config
         config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
@@ -183,7 +200,7 @@ class BaseAWQForCausalLM:
         for i in tqdm(range(len(layers)), desc="Replacing layers..."):
             layer = layers[i]
             named_linears = get_named_linears(layer)
-            self._scale_activations(layer)
+            self._scale_activations(self, layer)
 
             for name, module in named_linears.items():
                 q_linear = WQLinear.from_linear(
@@ -196,10 +213,11 @@ class BaseAWQForCausalLM:
         model.tie_weights()
 
-        model = load_checkpoint_and_dispatch(model, quant_path, device_map="balanced")
+        model = load_checkpoint_and_dispatch(model, quant_path, device_map="balanced", no_split_module_classes=[self.layer_type])
 
         return model
 
+    @staticmethod
     def _scale_activations(self, layer):
         act_function = self.get_act_from_layer(layer)
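The new BaseAWQForCausalLM.from_pretrained follows accelerate's meta-device loading pattern: build the module structure without allocating weight memory, then stream the checkpoint in and place layers across devices. A standalone sketch of that pattern under assumed placeholder paths:

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

model_path = "path/to/local/checkpoint"  # placeholder

config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

# 1) Build the module structure on the meta device: no weight memory is allocated.
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16, trust_remote_code=True)

# 2) Load real weights from disk and distribute layers across the available devices.
#    no_split_module_classes keeps each transformer block on a single device.
model = load_checkpoint_and_dispatch(
    model, model_path, device_map="balanced", no_split_module_classes=["MPTBlock"]
)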
@@ -3,10 +3,12 @@ from .base import BaseAWQForCausalLM
 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"
 
-    def get_model_layers(self, model):
+    @staticmethod
+    def get_model_layers(model):
         return model.transformer.blocks
 
-    def get_layers_for_scaling(self, module, input_feat, module_kwargs):
+    @staticmethod
+    def get_layers_for_scaling(module, input_feat, module_kwargs):
         layers = []
 
         # attention input
@@ -42,16 +44,19 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
         return layers
 
-    def get_act_from_layer(self, layer):
+    @staticmethod
+    def get_act_from_layer(layer):
         return layer.ffn.act
 
-    def get_act_for_scaling(self, module):
+    @staticmethod
+    def get_act_for_scaling(module):
         return dict(
             scale_name="ffn.act",
             scale_layer=module.ffn.act,
             scale_shape=module.ffn.up_proj.out_features
         )
 
-    def move_embed(self, model, device):
+    @staticmethod
+    def move_embed(model, device):
         model.transformer.wte = model.transformer.wte.to(device)
         model.transformer.emb_drop = model.transformer.emb_drop.to(device)
\ No newline at end of file
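These per-model hooks carry no instance state, so converting them to @staticmethod lets the base class invoke them either through an instance or through the class object. A minimal, self-contained illustration of that calling pattern (the Hooks class and fake_model dict are illustrative only, not from the repository):

class Hooks:
    # Mirrors the @staticmethod conversion above: no instance state is needed.
    @staticmethod
    def get_model_layers(model):
        return model["blocks"]

fake_model = {"blocks": ["block0", "block1"]}

# Both call styles resolve to the same plain function:
assert Hooks.get_model_layers(fake_model) == Hooks().get_model_layers(fake_model)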