Commit 0091f1e2 authored by Casper Hansen

Support batch size benchmark, stop if OOM

parent 28d52d81
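The new --batch_size flag added in this commit suggests an invocation along these lines (a sketch: the script filename and the batch size value here are assumptions; the other values are the defaults from the argparse section below):

    python benchmark.py --model_path casperhansen/vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt --batch_size 8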
@@ -4,6 +4,8 @@ import argparse
 import numpy as np
 import pandas as pd
 from awq import AutoAWQForCausalLM
+from transformers import AutoTokenizer
+from torch.cuda import OutOfMemoryError
 
 def warmup(model):
     warm_up = torch.randn((4096,4096)).to(next(model.parameters()).device)
@@ -20,10 +22,10 @@ def generate(model, input_ids, n_generate):
             if i == 0:
                 # prefill context
-                inputs = torch.as_tensor([input_ids], device=next(model.parameters()).device)
+                inputs = torch.as_tensor(input_ids, device=next(model.parameters()).device)
             else:
                 # decode tokens
-                inputs = torch.as_tensor([[token]], device=next(model.parameters()).device)
+                inputs = torch.as_tensor(token, device=next(model.parameters()).device)
 
             out = model(inputs, use_cache=True)
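The two replacements above drop the extra list wrapping because input_ids and token now arrive as batched tensors. A minimal sketch of the shapes involved (variable names and the vocab size are illustrative, not from this commit):

    import torch

    batch_size, context_len, vocab_size = 2, 8, 32000
    input_ids = torch.randint(0, vocab_size, (batch_size, context_len))  # prefill: one prompt per row
    token = input_ids[:, -1:]  # one decode step: shape (batch_size, 1)
    assert token.shape == (batch_size, 1)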
@@ -37,33 +39,46 @@ def generate(model, input_ids, n_generate):
     return context_time, generate_time
 
-def run_round(model_path, quant_file, n_generate, input_ids):
+def run_round(model_path, quant_file, n_generate, input_ids, batch_size):
     print(f" -- Loading model...")
-    model = AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=True)
+    model = AutoAWQForCausalLM.from_quantized(
+        model_path, quant_file, fuse_layers=True,
+        max_new_tokens=n_generate
+    )
 
     print(f" -- Warming up...")
     warmup(model)
 
-    print(f" -- Generating {n_generate} tokens, {len(input_ids)} token prompt...")
+    print(f" -- Generating {n_generate} tokens, {input_ids.shape[1]} in context...")
+    try:
+        context_time, generate_time = generate(model, input_ids, n_generate)
+        successful_generate = True
+    except OutOfMemoryError:
+        # catch the OutOfMemoryError imported above rather than any RuntimeError,
+        # so only a genuine OOM is reported as such
+        successful_generate = False
 
     device = next(model.parameters()).device
-    prefill_tokens_per_second = n_generate / context_time
-    decode_tokens_per_second = n_generate / generate_time
     memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
     memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100
+    if successful_generate:
+        # scale per-sequence throughput by batch size; prefill rate is based on
+        # the prompt length actually processed, not the number of generated tokens
+        prefill_tokens_per_second = input_ids.shape[1] / context_time * batch_size
+        decode_tokens_per_second = n_generate / generate_time * batch_size
 
         print(f" ** Speed (Prefill): {prefill_tokens_per_second:.2f} tokens/second")
         print(f" ** Speed (Decode): {decode_tokens_per_second:.2f} tokens/second")
         print(f" ** Max Memory (VRAM): {memory_used:.2f} GB ({memory_pct:.2f}%)")
+    else:
+        prefill_tokens_per_second = 'OOM'
+        decode_tokens_per_second = 'OOM'
     return {
-        "Prefill length": len(input_ids),
-        "Decode length": n_generate,
+        "Batch Size": batch_size,
+        "Prefill Length": input_ids.shape[1],
+        "Decode Length": n_generate,
         "Prefill tokens/s": prefill_tokens_per_second,
         "Decode tokens/s": decode_tokens_per_second,
-        "Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)",
-        "GPU": torch.cuda.get_device_name()
+        "Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)"
     }
def main(args):
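For reference, a self-contained sketch of the OOM-catch pattern used in run_round; torch.cuda.OutOfMemoryError is a RuntimeError subclass in recent PyTorch releases, and the allocation here is deliberately oversized to force the error path (requires a CUDA device):

    import torch
    from torch.cuda import OutOfMemoryError

    try:
        oversized = torch.empty((1 << 40,), device="cuda")  # ~4 TB of float32, guaranteed to fail
        successful = True
    except OutOfMemoryError:
        successful = False
    print("ok" if successful else "OOM")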
@@ -79,26 +94,34 @@ def main(args):
     ]
 
     all_stats = []
+    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
 
     for settings in rounds:
-        input_ids = [1 for _ in range(settings["context"])]
+        input_ids = torch.randint(0, tokenizer.vocab_size, (args.batch_size, settings["context"])).cuda()
 
         stats = run_round(
             args.model_path,
             args.quant_file,
             settings["n_generate"],
-            input_ids
+            input_ids,
+            args.batch_size
         )
 
         all_stats.append(stats)
 
+        # stop benchmarking at the first round that runs out of memory
+        if stats["Prefill tokens/s"] == 'OOM':
+            break
 
     df = pd.DataFrame(all_stats)
+    print('GPU:', torch.cuda.get_device_name())
+    print('Model:', args.model_path)
     print(df.to_markdown(index=False))
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_path", type=str, default="vicuna-7b-v1.5-awq-gemv", help="path to the model")
+    parser.add_argument("--model_path", type=str, default="casperhansen/vicuna-7b-v1.5-awq", help="path to the model")
     parser.add_argument("--quant_file", type=str, default="awq_model_w4_g128.pt", help="weights filename")
+    parser.add_argument("--batch_size", type=int, default=1, help="batch size used for generation")
 
     args = parser.parse_args()
     main(args)
\ No newline at end of file
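As a sanity check on the batch-scaled throughput above: with batch_size=8, n_generate=512, and a measured generate_time of 10 s, decode throughput is 512 / 10 * 8 = 409.6 tokens/second, i.e. the per-sequence decode rate multiplied by the number of sequences decoded in parallel.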