Commit 0091f1e2 authored by Casper Hansen

Support batch size benchmark, stop if OOM

parent 28d52d81
@@ -4,6 +4,8 @@ import argparse
 import numpy as np
 import pandas as pd
 from awq import AutoAWQForCausalLM
+from transformers import AutoTokenizer
+from torch.cuda import OutOfMemoryError
 
 def warmup(model):
     warm_up = torch.randn((4096,4096)).to(next(model.parameters()).device)
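The two new imports serve the features added further down: AutoTokenizer is only used to read the vocabulary size when building random batched prompts, and torch.cuda.OutOfMemoryError is (in recent PyTorch releases) a subclass of RuntimeError, which is why the broad except RuntimeError clause later in this diff also traps CUDA out-of-memory failures. A minimal check of that relationship, assuming a recent PyTorch:

import torch

# torch.cuda.OutOfMemoryError specializes RuntimeError, so catching
# RuntimeError (as run_round() does below) also catches CUDA OOM.
print(issubclass(torch.cuda.OutOfMemoryError, RuntimeError))  # expected: True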
@@ -20,10 +22,10 @@ def generate(model, input_ids, n_generate):
         if i == 0:
             # prefill context
-            inputs = torch.as_tensor([input_ids], device=next(model.parameters()).device)
+            inputs = torch.as_tensor(input_ids, device=next(model.parameters()).device)
         else:
             # decode tokens
-            inputs = torch.as_tensor([[token]], device=next(model.parameters()).device)
+            inputs = torch.as_tensor(token, device=next(model.parameters()).device)
 
         out = model(inputs, use_cache=True)
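Removing the extra list nesting means generate() now treats input_ids as an already-batched 2D tensor of shape (batch_size, seq_len), and expects token during decode to hold one next-token id per sequence. A small shape sketch under those assumptions (names and sizes here are illustrative, not taken from the file):

import torch

batch_size, context, vocab_size = 2, 8, 32000                     # illustrative sizes only
input_ids = torch.randint(0, vocab_size, (batch_size, context))   # prefill input: (batch, seq_len)
logits = torch.randn(batch_size, context, vocab_size)             # stand-in for the model's prefill output
token = logits[:, -1, :].argmax(dim=-1, keepdim=True)             # one next token per sequence
assert input_ids.shape == (batch_size, context)
assert token.shape == (batch_size, 1)                             # fed back unchanged on decode steps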
@@ -37,33 +39,46 @@ def generate(model, input_ids, n_generate):
     return context_time, generate_time
 
-def run_round(model_path, quant_file, n_generate, input_ids):
+def run_round(model_path, quant_file, n_generate, input_ids, batch_size):
     print(f" -- Loading model...")
-    model = AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True)
+    model = AutoAWQForCausalLM.from_quantized(
+        model_path, quant_file, fuse_layers=True,
+        max_new_tokens=n_generate
+    )
 
     print(f" -- Warming up...")
     warmup(model)
 
-    print(f" -- Generating {n_generate} tokens, {len(input_ids)} token prompt...")
-    context_time, generate_time = generate(model, input_ids, n_generate)
+    print(f" -- Generating {n_generate} tokens, {input_ids.shape[1]} in context...")
+    try:
+        context_time, generate_time = generate(model, input_ids, n_generate)
+        successful_generate = True
+    except RuntimeError as ex:
+        successful_generate = False
 
     device = next(model.parameters()).device
-    prefill_tokens_per_second = n_generate / context_time
-    decode_tokens_per_second = n_generate / generate_time
     memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
     memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100
 
-    print(f" ** Speed (Prefill): {prefill_tokens_per_second:.2f} tokens/second")
-    print(f" ** Speed (Decode): {decode_tokens_per_second:.2f} tokens/second")
-    print(f" ** Max Memory (VRAM): {memory_used:.2f} GB ({memory_pct:.2f}%)")
+    if successful_generate:
+        prefill_tokens_per_second = n_generate / context_time * batch_size
+        decode_tokens_per_second = n_generate / generate_time * batch_size
+        print(f" ** Speed (Prefill): {prefill_tokens_per_second:.2f} tokens/second")
+        print(f" ** Speed (Decode): {decode_tokens_per_second:.2f} tokens/second")
+        print(f" ** Max Memory (VRAM): {memory_used:.2f} GB ({memory_pct:.2f}%)")
+    else:
+        prefill_tokens_per_second = 'OOM'
+        decode_tokens_per_second = 'OOM'
 
     return {
-        "Prefill length": len(input_ids),
-        "Decode length": n_generate,
+        "Batch Size": batch_size,
+        "Prefill Length": input_ids.shape[1],
+        "Decode Length": n_generate,
         "Prefill tokens/s": prefill_tokens_per_second,
         "Decode tokens/s": decode_tokens_per_second,
-        "Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)",
-        "GPU": torch.cuda.get_device_name()
+        "Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)"
     }
 
 def main(args):
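Because each timed step now produces one token per sequence in the batch, run_round() scales the measured rates by batch_size to report aggregate tokens per second, and on an out-of-memory failure it records the string 'OOM' instead of a number. A quick arithmetic illustration with made-up timings (not measured results):

# Made-up numbers, purely to illustrate the scaling used in run_round():
n_generate, batch_size = 128, 4
context_time, generate_time = 0.5, 8.0                                # hypothetical seconds

prefill_tokens_per_second = n_generate / context_time * batch_size    # 1024.0
decode_tokens_per_second = n_generate / generate_time * batch_size    # 64.0
print(prefill_tokens_per_second, decode_tokens_per_second)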
@@ -79,26 +94,34 @@ def main(args):
     ]
 
     all_stats = []
+    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
 
     for settings in rounds:
-        input_ids = [1 for _ in range(settings["context"])]
+        input_ids = torch.randint(0, tokenizer.vocab_size, (args.batch_size, settings["context"])).cuda()
 
         stats = run_round(
             args.model_path,
             args.quant_file,
             settings["n_generate"],
-            input_ids
+            input_ids,
+            args.batch_size
         )
 
         all_stats.append(stats)
+        if stats["Prefill tokens/s"] == 'OOM':
+            break
 
     df = pd.DataFrame(all_stats)
+    print('GPU:', torch.cuda.get_device_name())
+    print('Model:', args.model_path)
     print(df.to_markdown(index=False))
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_path", type=str, default="vicuna-7b-v1.5-awq-gemv", help="path to the model")
+    parser.add_argument("--model_path", type=str, default="casperhansen/vicuna-7b-v1.5-awq", help="path to the model")
     parser.add_argument("--quant_file", type=str, default="awq_model_w4_g128.pt", help="weights filename")
+    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
     args = parser.parse_args()
     main(args)
\ No newline at end of file
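With these changes, the sweep runs at a chosen batch size, stops at the first round that runs out of memory, and prints the GPU name, model path, and a markdown table of per-round results. Assuming the script is invoked directly (its filename is not shown in this diff), a run might look like:

python benchmark.py --model_path casperhansen/vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt --batch_size 4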