Commit 46750ff9 authored by Casper Hansen's avatar Casper Hansen
Browse files

Add speed benchmark

parent 63346c34
import os
import time
import torch
import argparse
from lm_eval import evaluator
......@@ -72,6 +73,56 @@ def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot
print(evaluator.make_table(results))
@torch.inference_mode()
def run_speed(model_path, quant_file, device, n_generate=128, max_new_tokens=256):
def _timer(func):
start = time.time()
out = func()
return out, time.time() - start
def _generate(model, model_out, n_generate):
past_key_values = model_out.past_key_values
for i in range(n_generate):
logits = model_out.logits[0, -1, :]
probs = torch.softmax(logits, dim=-1)
token = torch.multinomial(probs, num_samples=1)
token = torch.as_tensor([token], device=device).unsqueeze(0)
model_out = model(token, use_cache=True, past_key_values=past_key_values)
def _warmup(device:str):
warm_up = torch.randn((4096,4096)).to(device)
torch.mm(warm_up,warm_up)
# Load model
model, load_time = _timer(lambda: AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True))
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
_warmup(device)
# Generate random inputs
n_context = max_new_tokens - n_generate
ids = torch.randint(0, tokenizer.vocab_size, (1, n_context)).cuda()
# Context stage
model_out, context_time = _timer(lambda: model(ids, use_cache=True))
# Generation stage
_, generation_time = _timer(lambda: _generate(model, model_out, n_generate))
# Prints
memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
context_tokens_per_second = n_context / context_time
context_ms_per_token = (context_time*1000) / n_context
inference_tokens_per_second = n_generate / generation_time
inference_ms_per_token = (generation_time*1000) / n_generate
print(f"[======] Model summary: {model_path} [======]")
print(f"[*] Load time: {load_time:.2f} seconds")
print(f"[*] Context speed: {context_tokens_per_second:.2f} tokens/second ({context_ms_per_token:.2f} ms/token)")
print(f"[*] Generation speed: {inference_tokens_per_second:.2f} tokens/second ({inference_ms_per_token:.2f} ms/token)")
print(f"[*] VRAM: {memory_used:.2f} MB")
if __name__ == '__main__':
"""
- Run AWQ search and save result:
......@@ -85,9 +136,12 @@ if __name__ == '__main__':
- Run perplexity unquantized FP16 model:
python -m awq.entry --entry_type eval --model_path lmsys/vicuna-7b-v1.5 --task_use_pretrained
- Run a speedtest to benchmark the quantized model:
python -m awq.entry --entry_type speed --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt
"""
parser = argparse.ArgumentParser()
parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|eval)')
parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|eval|speed)')
parser.add_argument('--model_path', type=str, help='Path to hf model')
parser.add_argument('--search_path', type=str, help='Path to save/load AWQ search results')
parser.add_argument('--quant_path', type=str, help='Path to save AWQ model to directory')
......@@ -102,6 +156,8 @@ if __name__ == '__main__':
help="Pass '--task_use_pretrained' to use a pretrained model running FP16")
parser.add_argument('--task_batch_size', type=int, default=1)
parser.add_argument('--task_n_shot', type=int, default=0)
parser.add_argument('--n_generate', type=int, default=128)
parser.add_argument('--n_context', type=int, default=256)
args = parser.parse_args()
quant_config = { "zero_point": True, "q_group_size": args.q_group_size, "w_bit": args.w_bit }
......@@ -113,5 +169,7 @@ if __name__ == '__main__':
elif args.entry_type == 'eval':
run_eval(args.model_path, args.quant_file, args.device,
args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
elif args.entry_type == 'speed':
run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context)
else:
raise Exception('--entry_type must be one of (search|quant|eval)')
\ No newline at end of file
raise Exception('--entry_type must be one of (search|quant|eval|speed)')
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment