Unverified commit d1112e1c, authored by Casper, committed by GitHub

New scaling to improve perplexity (#216)

parent 63d2aaec
@@ -49,12 +49,12 @@ class BaseAWQForCausalLM(nn.Module):
     @torch.no_grad()
     def quantize(self, tokenizer=None, quant_config={},
                  calib_data: Union[str, List[str]]="pileval",
-                 split="train", text_column="text"):
+                 split="train", text_column="text", duo_scaling=True):
         self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)
 
         quantizer = AwqQuantizer(
             self, self.model, tokenizer, self.quant_config.w_bit, self.quant_config.q_group_size,
-            self.quant_config.version, calib_data, split, text_column
+            self.quant_config.version, calib_data, split, text_column, duo_scaling
         )
         quantizer.quantize()
         self.is_quantized = True
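The new duo_scaling flag surfaces directly on the public quantize() call. A minimal usage sketch follows; the quant_config keys shown are the usual AutoAWQ ones (w_bit, q_group_size, version appear in the hunk above) but the exact dict is an assumption here, not part of this diff:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'mistralai/Mistral-7B-Instruct-v0.1'  # any supported causal LM
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# duo_scaling=True (the new default) balances activation and weight
# magnitudes when searching scales; False falls back to activation-only
# scaling.
model.quantize(
    tokenizer,
    quant_config={'w_bit': 4, 'q_group_size': 128, 'version': 'GEMM'},  # assumed keys
    duo_scaling=True,
)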
@@ -14,7 +14,7 @@ from awq.utils.module import append_str_prefix, get_op_name, get_named_linears,
 
 class AwqQuantizer:
     def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
-                 calib_data, split, text_column) -> None:
+                 calib_data, split, text_column, duo_scaling) -> None:
         self.awq_model = awq_model
         self.model = model
         self.tokenizer = tokenizer
@@ -24,6 +24,7 @@ class AwqQuantizer:
         self.calib_data = calib_data
         self.split = split
         self.text_column = text_column
+        self.duo_scaling = duo_scaling
         self.modules, self.module_kwargs, self.inps = self.init_quant()
 
     def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
@@ -197,7 +198,10 @@ class AwqQuantizer:
 
             ratio = ratio / n_grid
             # NOTE: s^-1 * x is fused here, according to paper
-            scales = (x_max.pow(ratio) / w_max.pow(1-ratio)).clamp(min=1e-4)
+            if self.duo_scaling:
+                scales = (x_max.pow(ratio) / w_max.pow(1-ratio)).clamp(min=1e-4)
+            else:
+                scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
             scales = scales / (scales.max() * scales.min()).sqrt()
             scales_view = scales.view(1, -1).to(device)
 
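For context, here is a standalone sketch of the grid search this hunk modifies. The toy tensors stand in for calibration statistics (x_max is the per-channel activation magnitude, w_max the per-channel weight magnitude), and the error measurement that picks the winning ratio is elided; this is an illustration, not the library code:

import torch

def search_scales(x_max, w_max, n_grid=20, duo_scaling=True):
    for grid in range(n_grid):
        ratio = grid / n_grid
        if duo_scaling:
            # balance activation and weight magnitudes: s = x_max^a / w_max^(1-a)
            scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4)
        else:
            # activation-only scaling: s = x_max^a
            scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
        # normalize so sqrt(max * min) == 1, centring the scales around 1
        scales = scales / (scales.max() * scales.min()).sqrt()
        yield ratio, scales

x_max = torch.rand(4096) + 0.5  # toy per-channel activation magnitudes
w_max = torch.rand(4096) + 0.5  # toy per-channel weight magnitudes
for ratio, scales in search_scales(x_max, w_max, duo_scaling=False):
    pass  # the real loop applies scales, pseudo-quantizes, and keeps the best ratio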
New file awq/utils/eval_utils.py (module path as imported by the eval script below):

import torch
import torch.nn as nn
from tqdm import tqdm
from datasets import load_dataset


def evaluate_perplexity(model, tokenizer):
    def _perplexity(nlls, n_samples, seqlen):
        # perplexity = exp(total NLL / number of predicted tokens)
        return torch.exp(torch.stack(nlls).sum() / (n_samples * seqlen))

    # load and prepare dataset: concatenate the wikitext-2 test split
    # into a single token stream
    data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    data = tokenizer("\n\n".join(data['text']), return_tensors='pt')
    data = data.input_ids.to(model.device)

    seqlen = 2048
    model = model.eval()
    n_samples = data.numel() // seqlen

    nlls = []

    with tqdm(range(n_samples), desc="Perplexity -") as progress_bar:
        for i in progress_bar:
            start_index = i * seqlen
            end_index = (i + 1) * seqlen
            batch = data[:, start_index:end_index].to(model.device)
            with torch.no_grad():
                logits = model(batch).logits

            # shift so that position t predicts token t+1
            shift_logits = logits[:, :-1, :].contiguous().float()
            shift_labels = data[:, start_index:end_index][:, 1:]

            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            # mean loss * seqlen approximates the window's total NLL
            neg_log_likelihood = loss.float() * seqlen
            nlls.append(neg_log_likelihood)

            # report running perplexity in the progress bar
            curr_ppl = _perplexity(nlls, i + 1, seqlen)
            progress_bar.set_description(f"Perplexity {curr_ppl:.3f}")

    ppl = _perplexity(nlls, n_samples, seqlen)

    return ppl.item()


if __name__ == '__main__':
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_path = 'mistralai/Mistral-7B-Instruct-v0.1'
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    evaluate_perplexity(model, tokenizer)
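The __main__ block above exercises the FP16 path. For an AWQ checkpoint the helper takes the wrapped transformers model, which is exactly what the eval script change below does via model.model. A minimal sketch; quant_path and quant_file are hypothetical placeholders:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from awq.utils.eval_utils import evaluate_perplexity

quant_path = 'path/to/quantized-model'  # hypothetical checkpoint directory
quant_file = 'awq_model_w4_g128.pt'     # hypothetical weights filename
model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

# pass the underlying transformers model, not the AWQ wrapper
evaluate_perplexity(model.model, tokenizer)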
@@ -3,32 +3,41 @@ from lm_eval import evaluator
 from awq import AutoAWQForCausalLM
 from transformers import AutoTokenizer
 from awq.utils.lm_eval_adaptor import LMEvalAdaptor
+from awq.utils.eval_utils import evaluate_perplexity
 
-def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot, task_use_pretrained):
+def run_eval(
+    model_path, quant_file, device, tasks, task_batch_size, task_n_shot,
+    task_use_pretrained, pretrained_safetensors
+):
     """
     Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness
     """
     # Load model
     if task_use_pretrained:
-        model = AutoAWQForCausalLM.from_pretrained(model_path)
+        model = AutoAWQForCausalLM.from_pretrained(model_path, safetensors=pretrained_safetensors)
     else:
         model = AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=False)
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     # Load adapter
-    lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size)
+    tasks = tasks.split(',')
+    if len(tasks) == 1 and tasks[0] == 'wikitext':
+        evaluate_perplexity(model.model, tokenizer)
 
-    # Evaluate perplexity of quantized model
-    results = evaluator.simple_evaluate(
-        model=lm_eval_model,
-        tasks=tasks.split(','),
-        batch_size=task_batch_size,
-        no_cache=True,
-        num_fewshot=task_n_shot,
-    )
+    else:
+        lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size)
 
-    print(evaluator.make_table(results))
+        # Evaluate perplexity of quantized model
+        results = evaluator.simple_evaluate(
+            model=lm_eval_model,
+            tasks=tasks,
+            batch_size=task_batch_size,
+            no_cache=True,
+            num_fewshot=task_n_shot,
+        )
+
+        print(evaluator.make_table(results))
 
 
 if __name__ == '__main__':
     """
@@ -45,6 +54,8 @@ if __name__ == '__main__':
     parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
     parser.add_argument("--use_pretrained", default=False, action='store_true',
                         help="Pass '--use_pretrained' to use a pretrained model running FP16")
+    parser.add_argument("--pretrained_safetensors", default=False, action='store_true',
+                        help="Load safetensors for FP16 model")
     parser.add_argument('--tasks', type=str, default='wikitext', help='Tasks to evaluate. '
                         'Separate tasks by comma for multiple tasks.'
                         'https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md')
@@ -52,5 +63,9 @@ if __name__ == '__main__':
     parser.add_argument('--n_shot', type=int, default=0)
     args = parser.parse_args()
 
-    run_eval(args.model_path, args.quant_file, args.device,
-             args.tasks, args.batch_size, args.n_shot, args.use_pretrained)
\ No newline at end of file
+    run_eval(
+        args.model_path, args.quant_file, args.device,
+        args.tasks, args.batch_size, args.n_shot, args.use_pretrained,
+        args.pretrained_safetensors
+    )
\ No newline at end of file
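Assuming the elided argparse lines define --model_path, --quant_file, and --batch_size (all of which are consumed via args above), and assuming the script file is named eval.py, the two paths through the updated script would be invoked roughly as:

# FP16 baseline; a lone 'wikitext' task now uses the built-in perplexity
# helper instead of the evaluation harness
python eval.py --model_path mistralai/Mistral-7B-Instruct-v0.1 \
    --use_pretrained --pretrained_safetensors --tasks wikitext

# quantized model on harness tasks
python eval.py --model_path <quantized-model-dir> --quant_file <weights-file> \
    --tasks hellaswag,winogrande --n_shot 0 --batch_size 1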