Commit 14d198c6 authored by Casper Hansen

Implement from_pretrained. Fix static methods and classmethods

parent 35ac58c7
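For orientation, a minimal sketch of the search pass this commit enables. The model path and q_config values below are illustrative placeholders, not part of the commit; only the AutoAWQForCausalLM and quantize calls mirror the diff.

import torch
from transformers import AutoTokenizer
from awq.models.auto import AutoAWQForCausalLM

model_path = "mosaicml/mpt-7b-instruct"               # hypothetical example model
q_config = {"zero_point": True, "q_group_size": 128}  # assumed AWQ-style config

# New entry point: load the FP16 model wrapped in an AWQ model class
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Pass 1: search for scales/clips only (run_quant=False), then persist the results
awq_results = model.quantize(model.model, tokenizer, w_bit=4, q_config=q_config,
                             run_search=True, run_quant=False)
torch.save(awq_results, "awq_search_results.pt")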
@@ -2,37 +2,12 @@ import os
 import torch
 import argparse
 from lm_eval import evaluator
+from transformers import AutoTokenizer
+from awq.models.auto import AutoAWQForCausalLM
 from awq.quantize.auto_clip import apply_clip
 from awq.quantize.auto_scale import apply_scale
 from awq.utils.lm_eval_adaptor import LMEvalAdaptor
-from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
-
-def get_awq_model(model):
-    from awq.models import MptAWQForCausalLM
-
-    if "mpt" in str(model.__class__).lower():
-        return MptAWQForCausalLM()
-    else:
-        raise NotImplementedError(type(model))
-
-def load_unquantized(model_path):
-    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
-
-    kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path, config=config, trust_remote_code=True, **kwargs)
-    model.eval()
-
-    return model, tokenizer
-
-def load_quantized(model_path, quant_path, w_bit, q_config, device):
-    from awq.models.auto import AutoAWQForCausalLM
-    model = AutoAWQForCausalLM.from_quantized(model_path, quant_path, w_bit, q_config, device)
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-
-    return model, tokenizer
-
 def load_search_result_into_memory(model, search_path):
     awq_results = torch.load(search_path, map_location="cpu")
@@ -41,27 +16,27 @@ def load_search_result_into_memory(model, search_path):
     apply_clip(model, awq_results["clip"])

 def run_search(model_path, dump_path, w_bit, q_config):
-    model, tokenizer = load_unquantized(model_path)
-    awq_model = get_awq_model(model)
-    awq_results = awq_model.quantize(model, tokenizer, w_bit=w_bit, q_config=q_config, run_search=True, run_quant=False)
+    model = AutoAWQForCausalLM.from_pretrained(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    awq_results = model.quantize(model.model, tokenizer, w_bit=w_bit, q_config=q_config, run_search=True, run_quant=False)

     dirpath = os.path.dirname(dump_path)
     os.makedirs(dirpath, exist_ok=True)
     torch.save(awq_results, dump_path)

-def run_quant(model_path, search_path, dump_path, w_bit, q_config, device):
-    model, tokenizer = load_unquantized(model_path, device)
-    load_search_result_into_memory(model, search_path)
-
-    awq_model = get_awq_model(model)
-    awq_model.quantize(model, w_bit=w_bit, q_config=q_config, run_search=False, run_quant=True)
+def run_quant(model_path, search_path, dump_path, w_bit, q_config):
+    model = AutoAWQForCausalLM.from_pretrained(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    load_search_result_into_memory(model.model, search_path)
+    model.quantize(model.model, w_bit=w_bit, q_config=q_config, run_search=False, run_quant=True)

     dirpath = os.path.dirname(dump_path)
     os.makedirs(dirpath, exist_ok=True)
-    torch.save(model.cpu().state_dict(), dump_path)
+    torch.save(model.model.cpu().state_dict(), dump_path)

 def run_perplexity(model_path, quant_path, w_bit, q_config, device):
-    model, tokenizer = load_quantized(model_path, quant_path, w_bit, q_config, device)
+    model = AutoAWQForCausalLM.from_quantized(model_path, quant_path, w_bit, q_config, device)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

     lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=1)
     results = evaluator.simple_evaluate(
...
@@ -18,13 +18,17 @@ class AutoAWQForCausalLM:
                          'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')

     @classmethod
-    def from_pretrained():
-        pass
+    def from_pretrained(self, model_path, trust_remote_code=True):
+        model_type = check_and_get_model_type(model_path, trust_remote_code)
+
+        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
+            model_path, model_type, trust_remote_code=trust_remote_code
+        )

     @classmethod
     def from_quantized(self, model_path, quant_path, w_bit, q_config, device, trust_remote_code=True):
         model_type = check_and_get_model_type(model_path, trust_remote_code)

-        return AWQ_CAUSAL_LM_MODEL_MAP[model_type]().from_quantized(
-            model_path, quant_path, w_bit, q_config, device
+        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
+            model_path, quant_path, w_bit, q_config, device, trust_remote_code
         )
\ No newline at end of file
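Note the `()` dropped from `AWQ_CAUSAL_LM_MODEL_MAP[model_type]`: the per-model loaders are now called directly on the mapped class and construct the wrapper themselves. A reduced sketch of that classmethod-factory pattern (toy class, not the repository's code):

class ToyAWQWrapper:
    def __init__(self, model, model_type, is_quantized):
        self.model = model
        self.model_type = model_type
        self.is_quantized = is_quantized

    @classmethod
    def from_pretrained(cls, model, model_type):
        # cls(...) builds the wrapper, so callers invoke the class itself,
        # never an instance: ToyAWQWrapper.from_pretrained(...)
        return cls(model, model_type, is_quantized=False)

wrapper = ToyAWQWrapper.from_pretrained(model="dummy", model_type="mpt")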
@@ -15,6 +15,11 @@ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name

 class BaseAWQForCausalLM:
+    def __init__(self, model, model_type, is_quantized):
+        self.model = model
+        self.model_type = model_type
+        self.is_quantized = is_quantized
+
     @torch.no_grad()
     def quantize(self, model, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
                  auto_scale=True, mse_range=True, run_search=False, run_quant=True,
@@ -39,7 +44,7 @@ class BaseAWQForCausalLM:
         for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
             layer = layers[i]
             named_linears = get_named_linears(layer)
-            self._scale_activations(layer)
+            self._scale_activations(self, layer)

             for name, module in named_linears.items():
                 module.cuda()
@@ -167,9 +172,21 @@ class BaseAWQForCausalLM:
     def save_quantized():
         pass

-    def from_pretrained():
-        pass
+    @classmethod
+    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16, trust_remote_code=True):
+        # Load config
+        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
+
+        # Load empty weights
+        with init_empty_weights():
+            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.float16, trust_remote_code=True)
+
+        # Load model weights
+        model = load_checkpoint_and_dispatch(model, model_path, device_map="balanced", no_split_module_classes=[self.layer_type])
+
+        return self(model, model_type, is_quantized=False)

+    @classmethod
     def from_quantized(self, model_path, quant_path, w_bit, q_config, device, trust_remote_code=True):
         # Load config
         config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
@@ -183,7 +200,7 @@ class BaseAWQForCausalLM:
         for i in tqdm(range(len(layers)), desc="Replacing layers..."):
             layer = layers[i]
             named_linears = get_named_linears(layer)
-            self._scale_activations(layer)
+            self._scale_activations(self, layer)

             for name, module in named_linears.items():
                 q_linear = WQLinear.from_linear(
@@ -196,10 +213,11 @@ class BaseAWQForCausalLM:
         model.tie_weights()

-        model = load_checkpoint_and_dispatch(model, quant_path, device_map="balanced")
+        model = load_checkpoint_and_dispatch(model, quant_path, device_map="balanced", no_split_module_classes=[self.layer_type])

         return model

+    @staticmethod
     def _scale_activations(self, layer):
         act_function = self.get_act_from_layer(layer)
...
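The new `from_pretrained` above relies on accelerate's empty-weights loading: the architecture is built on the meta device with no memory allocated, then the checkpoint is streamed in and dispatched across devices without splitting a decoder block. A self-contained sketch of that pattern, assuming `model_path` points at a local checkpoint folder and using the MPT `layer_type` from this repo as the default:

import torch
from transformers import AutoConfig, AutoModelForCausalLM
from accelerate import init_empty_weights, load_checkpoint_and_dispatch

def load_fp16_sharded(model_path, no_split=("MPTBlock",)):
    # Build the model skeleton without allocating weight tensors
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16,
                                                 trust_remote_code=True)
    # Stream real weights from disk and place modules across available devices,
    # keeping each decoder block on a single device
    model = load_checkpoint_and_dispatch(model, model_path, device_map="balanced",
                                         no_split_module_classes=list(no_split))
    return model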
@@ -3,10 +3,12 @@ from .base import BaseAWQForCausalLM
 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"

-    def get_model_layers(self, model):
+    @staticmethod
+    def get_model_layers(model):
         return model.transformer.blocks

-    def get_layers_for_scaling(self, module, input_feat, module_kwargs):
+    @staticmethod
+    def get_layers_for_scaling(module, input_feat, module_kwargs):
         layers = []

         # attention input
@@ -42,16 +44,19 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
         return layers

-    def get_act_from_layer(self, layer):
+    @staticmethod
+    def get_act_from_layer(layer):
         return layer.ffn.act

-    def get_act_for_scaling(self, module):
+    @staticmethod
+    def get_act_for_scaling(module):
         return dict(
             scale_name="ffn.act",
             scale_layer=module.ffn.act,
             scale_shape=module.ffn.up_proj.out_features
         )

-    def move_embed(self, model, device):
+    @staticmethod
+    def move_embed(model, device):
         model.transformer.wte = model.transformer.wte.to(device)
         model.transformer.emb_drop = model.transformer.emb_drop.to(device)
\ No newline at end of file
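One subtlety behind these `@staticmethod` changes: `_scale_activations` in the base class is also a staticmethod now, yet its first parameter is still named `self`, which is why the call sites read `self._scale_activations(self, layer)`. A toy illustration (not repository code) of why the wrapper has to be forwarded explicitly:

class Toy:
    scale = 3

    @staticmethod
    def hook(obj, x):
        # obj is an ordinary parameter; staticmethods receive no implicit instance
        return obj.scale * x

    def run(self, x):
        # Accessing a staticmethod through self does not bind self, so it must be
        # passed by hand, mirroring self._scale_activations(self, layer) above
        return self.hook(self, x)

assert Toy().run(2) == 6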