Unverified commit 85430ddc, authored by Casper, committed by GitHub

Merge pull request #26 from casper-hansen/speed_batch_size

Implement batch size for speed test
Parents: abdc726c 42655843
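Summary of the change, as read from the diff below: run_speed gains batch_size and disable_fused_layers parameters, the random input ids are generated with shape (batch_size, n_context) instead of (1, n_context), the decode loop samples one token per sequence at each step, and the reported throughput figures are scaled by batch_size. Because the fused layers only support batch_size=1, batched runs must pass the new --disable_fused_layers flag.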
@@ -74,29 +74,35 @@ def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot
     print(evaluator.make_table(results))
 
 @torch.inference_mode()
-def run_speed(model_path, quant_file, device, n_generate=128, n_context=256):
+def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, batch_size=1, disable_fused_layers=False):
     def _timer(func):
         start = time.time()
         out = func()
         return out, time.time() - start
 
-    def _generate(model, model_out, n_generate):
+    def _generate(model, model_out, n_generate, batch_size):
         past_key_values = model_out.past_key_values
 
         for i in range(n_generate):
-            logits = model_out.logits[0, -1, :]
-            probs = torch.softmax(logits, dim=-1)
-            token = torch.multinomial(probs, num_samples=1)
-            token = torch.as_tensor([token], device=device).unsqueeze(0)
+            logits = model_out.logits[:, -1, :]
+            new_tokens = []
+
+            for batch_index in range(batch_size):
+                probs = torch.softmax(logits[batch_index], dim=-1)
+                token = torch.multinomial(probs, num_samples=1)
+                new_tokens.append(token)
 
-            model_out = model(token, use_cache=True, past_key_values=past_key_values)
+            tokens = torch.as_tensor(new_tokens, device=device).unsqueeze(-1)
+            model_out = model(tokens, use_cache=True, past_key_values=past_key_values)
 
     def _warmup(device:str):
         warm_up = torch.randn((4096,4096)).to(device)
         torch.mm(warm_up,warm_up)
 
     if quant_file:
-        model, load_time = _timer(lambda: AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True))
+        fuse_layers = False if disable_fused_layers else True
+        model, load_time = _timer(lambda: AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=fuse_layers))
     else:
         model, load_time = _timer(lambda: AutoAWQForCausalLM.from_pretrained(model_path))
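As an aside (illustrative only, not part of the commit), the batched decode loop this hunk introduces can be sketched standalone against a plain Hugging Face causal LM. The per-row sampling loop also collapses into a single call, since torch.multinomial accepts a 2-D (batch, vocab) matrix and draws one sample per row; "gpt2" below is an arbitrary stand-in model.

# Illustrative sketch -- not part of this commit. "gpt2" is an arbitrary
# stand-in; any Hugging Face causal LM with a KV cache behaves the same.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

batch_size, n_context, n_generate = 2, 32, 8
# Random prompt ids with a leading batch dimension, as in the hunk above
ids = torch.randint(0, tokenizer.vocab_size, (batch_size, n_context), device=device)

with torch.inference_mode():
    out = model(ids, use_cache=True)  # context stage: fills the KV cache
    for _ in range(n_generate):       # generation stage: one token per sequence per step
        logits = out.logits[:, -1, :]                      # (batch, vocab)
        probs = torch.softmax(logits, dim=-1)
        tokens = torch.multinomial(probs, num_samples=1)   # (batch, 1), one draw per row
        out = model(tokens, use_cache=True, past_key_values=out.past_key_values)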
@@ -105,20 +111,20 @@ def run_speed(model_path, quant_file, device, n_generate=128, n_context=256):
     # Generate random inputs
     n_context = n_context - n_generate
-    ids = torch.randint(0, tokenizer.vocab_size, (1, n_context)).cuda()
+    ids = torch.randint(0, tokenizer.vocab_size, (batch_size, n_context)).cuda()
 
     # Context stage
     model_out, context_time = _timer(lambda: model(ids, use_cache=True))
 
     # Generation stage
-    _, generation_time = _timer(lambda: _generate(model, model_out, n_generate))
+    _, generation_time = _timer(lambda: _generate(model, model_out, n_generate, batch_size))
 
     # Prints
     memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
-    context_tokens_per_second = n_context / context_time
-    context_ms_per_token = (context_time*1000) / n_context
-    inference_tokens_per_second = n_generate / generation_time
-    inference_ms_per_token = (generation_time*1000) / n_generate
+    context_tokens_per_second = n_context / context_time * batch_size
+    context_ms_per_token = (context_time*1000) / n_context * batch_size
+    inference_tokens_per_second = n_generate / generation_time * batch_size
+    inference_ms_per_token = (generation_time*1000) / n_generate * batch_size
 
     print(f"[======] Model summary: {model_path} [======]")
     print(f"[*] Load time: {load_time:.2f} seconds")
@@ -164,6 +170,9 @@ if __name__ == '__main__':
     parser.add_argument('--task_n_shot', type=int, default=0)
     parser.add_argument('--n_generate', type=int, default=128)
     parser.add_argument('--n_context', type=int, default=256)
+    parser.add_argument('--batch_size', type=int, default=1)
+    parser.add_argument("--disable_fused_layers", default=False, action='store_true',
+                        help="Pass '--disable_fused_layers' to disable fused layers")
     args = parser.parse_args()
 
     quant_config = { "zero_point": True, "q_group_size": args.q_group_size, "w_bit": args.w_bit }
@@ -176,6 +185,9 @@ if __name__ == '__main__':
         run_eval(args.model_path, args.quant_file, args.device,
                  args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
     elif args.entry_type == 'speed':
-        run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context)
+        if args.batch_size > 1 and not args.disable_fused_layers:
+            raise Exception('Fused layers only support batch_size=1. Pass --disable_fused_layers to run batch_size>1 (much slower).')
+
+        run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context, args.batch_size, args.disable_fused_layers)
     else:
         raise Exception('--entry_type must be one of (search|quant|eval|speed)')
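Assuming the script is invoked the same way as the other entry types (the filename is not shown in this diff), a batched speed run would then look something like `python <entry_script> --entry_type speed --model_path <model> --quant_file <file> --batch_size 8 --disable_fused_layers`, while the default batch_size=1 keeps the previous fused-layer fast path.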