Commit bbbd525e authored by Zhen Wan

Support speedtest to benchmark the FP16 model

parent 637d4abd
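
In short, this change lets run_speed benchmark the plain FP16 checkpoint as well: when no --quant_file is supplied, the model is loaded with from_pretrained instead of from_quantized, and the prompt/generation split is now controlled by --max_seq_len rather than --n_context. A minimal sketch of the new loading branch, assuming AutoAWQForCausalLM is importable from the awq package as awq/entry.py already uses it (the load_benchmark_model helper name is illustrative, not part of the patch):

import time

from awq import AutoAWQForCausalLM  # assumed import path

def _timer(func):
    # Same pattern as the _timer helper inside run_speed: return (result, elapsed seconds).
    start = time.time()
    out = func()
    return out, time.time() - start

def load_benchmark_model(model_path, quant_file=None):
    # Illustrative helper mirroring the new branch in run_speed: a truthy quant_file
    # selects the fused AWQ path, otherwise the unquantized FP16 weights are loaded.
    if quant_file:
        return _timer(lambda: AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True))
    return _timer(lambda: AutoAWQForCausalLM.from_pretrained(model_path))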
@@ -12,7 +12,7 @@ from awq.utils.lm_eval_adaptor import LMEvalAdaptor
 def load_search_result_into_memory(model, search_path):
     awq_results = torch.load(search_path, map_location="cpu")
     apply_scale(model, awq_results["scale"])
     apply_clip(model, awq_results["clip"])
@@ -56,7 +56,7 @@ def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot
         model = AutoAWQForCausalLM.from_pretrained(model_path)
     else:
         model = AutoAWQForCausalLM.from_quantized(model_path, quant_file)
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     # Load adapter
@@ -74,12 +74,12 @@ def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot
     print(evaluator.make_table(results))
 
 @torch.inference_mode()
-def run_speed(model_path, quant_file, device, n_generate=128, max_new_tokens=256):
+def run_speed(model_path, quant_file, device, n_generate=128, max_seq_len=256):
     def _timer(func):
         start = time.time()
         out = func()
         return out, time.time() - start
 
     def _generate(model, model_out, n_generate):
         past_key_values = model_out.past_key_values
@@ -90,18 +90,23 @@ def run_speed(model_path, quant_file, device, n_generate=128, max_new_tokens=256):
             token = torch.as_tensor([token], device=device).unsqueeze(0)
             model_out = model(token, use_cache=True, past_key_values=past_key_values)
 
     def _warmup(device:str):
         warm_up = torch.randn((4096,4096)).to(device)
         torch.mm(warm_up,warm_up)
 
     # Load model
-    model, load_time = _timer(lambda: AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True))
+    if quant_file:
+        model, load_time = _timer(lambda: AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True))
+    else:
+        # fp16 model
+        model, load_time = _timer(lambda: AutoAWQForCausalLM.from_pretrained(model_path))
+
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     _warmup(device)
 
     # Generate random inputs
-    n_context = max_new_tokens - n_generate
+    n_context = max_seq_len - n_generate
     ids = torch.randint(0, tokenizer.vocab_size, (1, n_context)).cuda()
 
     # Context stage
@@ -109,7 +114,7 @@ def run_speed(model_path, quant_file, device, n_generate=128, max_new_tokens=256):
     # Generation stage
     _, generation_time = _timer(lambda: _generate(model, model_out, n_generate))
     # Prints
     memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
     context_tokens_per_second = n_context / context_time
@@ -138,7 +143,11 @@ if __name__ == '__main__':
     python -m awq.entry --entry_type eval --model_path lmsys/vicuna-7b-v1.5 --task_use_pretrained
 
     - Run a speedtest to benchmark the quantized model:
-    python -m awq.entry --entry_type speed --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt
+    python -m awq.entry --entry_type speed --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt \
+        --n_generate 128 --max_seq_len 256
+
+    - Run a speedtest to benchmark the unquantized FP16 model:
+    python -m awq.entry --entry_type speed --model_path lmsys/vicuna-7b-v1.5 --n_generate 128 --max_seq_len 256
     """
     parser = argparse.ArgumentParser()
     parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|eval|speed)')
@@ -157,19 +166,19 @@ if __name__ == '__main__':
     parser.add_argument('--task_batch_size', type=int, default=1)
     parser.add_argument('--task_n_shot', type=int, default=0)
     parser.add_argument('--n_generate', type=int, default=128)
-    parser.add_argument('--n_context', type=int, default=256)
+    parser.add_argument('--max_seq_len', type=int, default=256)
     args = parser.parse_args()
 
     quant_config = { "zero_point": True, "q_group_size": args.q_group_size, "w_bit": args.w_bit }
 
     if args.entry_type == 'search':
         run_search(args.model_path, args.search_path, quant_config)
     elif args.entry_type == 'quant':
         run_quant(args.model_path, args.search_path, args.quant_path, quant_config)
     elif args.entry_type == 'eval':
         run_eval(args.model_path, args.quant_file, args.device,
                  args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
     elif args.entry_type == 'speed':
-        run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context)
+        run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.max_seq_len)
     else:
         raise Exception('--entry_type must be one of (search|quant|eval|speed)')
\ No newline at end of file
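
With the patch applied, the same benchmark can also be driven from Python instead of the CLI. A rough usage sketch, assuming run_speed is importable from awq.entry; the empty quant_file string and the "cuda:0" device are illustrative values chosen to reach the new FP16 branch:

from awq.entry import run_speed  # assumed import; run_speed is defined in the patched module

# Quantized AWQ checkpoint: quant_file selects from_quantized(..., fuse_layers=True).
run_speed("vicuna-7b-v1.5-awq", "awq_model_w4_g128.pt", "cuda:0",
          n_generate=128, max_seq_len=256)

# Unquantized FP16 baseline added by this commit: the falsy quant_file triggers from_pretrained().
run_speed("lmsys/vicuna-7b-v1.5", "", "cuda:0",
          n_generate=128, max_seq_len=256)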