Unverified Commit abdc726c authored by Casper's avatar Casper Committed by GitHub
Browse files

Merge pull request #25 from wanzhenchn/main

support speedtest to benchmark FP16 model
parents 637d4abd 4f42f509
...@@ -74,7 +74,7 @@ def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot ...@@ -74,7 +74,7 @@ def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot
print(evaluator.make_table(results)) print(evaluator.make_table(results))
@torch.inference_mode() @torch.inference_mode()
def run_speed(model_path, quant_file, device, n_generate=128, max_new_tokens=256): def run_speed(model_path, quant_file, device, n_generate=128, n_context=256):
def _timer(func): def _timer(func):
start = time.time() start = time.time()
out = func() out = func()
...@@ -95,13 +95,16 @@ def run_speed(model_path, quant_file, device, n_generate=128, max_new_tokens=256 ...@@ -95,13 +95,16 @@ def run_speed(model_path, quant_file, device, n_generate=128, max_new_tokens=256
warm_up = torch.randn((4096,4096)).to(device) warm_up = torch.randn((4096,4096)).to(device)
torch.mm(warm_up,warm_up) torch.mm(warm_up,warm_up)
# Load model if quant_file:
model, load_time = _timer(lambda: AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True)) model, load_time = _timer(lambda: AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True))
else:
model, load_time = _timer(lambda: AutoAWQForCausalLM.from_pretrained(model_path))
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
_warmup(device) _warmup(device)
# Generate random inputs # Generate random inputs
n_context = max_new_tokens - n_generate n_context = n_context - n_generate
ids = torch.randint(0, tokenizer.vocab_size, (1, n_context)).cuda() ids = torch.randint(0, tokenizer.vocab_size, (1, n_context)).cuda()
# Context stage # Context stage
...@@ -138,7 +141,10 @@ if __name__ == '__main__': ...@@ -138,7 +141,10 @@ if __name__ == '__main__':
python -m awq.entry --entry_type eval --model_path lmsys/vicuna-7b-v1.5 --task_use_pretrained python -m awq.entry --entry_type eval --model_path lmsys/vicuna-7b-v1.5 --task_use_pretrained
- Run a speedtest to benchmark the quantized model: - Run a speedtest to benchmark the quantized model:
python -m awq.entry --entry_type speed --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt python -m awq.entry --entry_type speed --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt --n_generate 128 --n_context 256
- Run a speedtest to benchmark the unquantized FP16 model:
python -m awq.entry --entry_type speed --model_path lmsys/vicuna-7b-v1.5 --n_generate 128 --n_context 256
""" """
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|eval|speed)') parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|eval|speed)')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment