Unverified Commit 8dbf009b authored by Casper, committed by GitHub

Benchmark hf generate (#237)

parent d1112e1c
@@ -4,14 +4,38 @@ import argparse
 import numpy as np
 import pandas as pd
 from awq import AutoAWQForCausalLM
-from transformers import AutoTokenizer
-from torch.cuda import OutOfMemoryError
+from awq.models.base import BaseAWQForCausalLM
+from transformers import AutoTokenizer, GenerationConfig, LogitsProcessor, LogitsProcessorList
+
+class TimeMeasuringLogitsProcessor(LogitsProcessor):
+    def __init__(self):
+        self.token_times = [time.time()]
+
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+        """The logits processor is called after each model forward."""
+        # CUDA runs operations asynchronously, so we synchronize for accurate time measurement
+        torch.cuda.synchronize()
+
+        # measure time
+        start_time = time.time()
+        self.token_times.append(start_time)
+        return scores
+
+    def get_prefill_duration(self):
+        return self.token_times[1] - self.token_times[0]
+
+    def get_decode_durations(self):
+        token_times = self.token_times[1:]
+        token_durations = [token_times[i + 1] - token_times[i] for i in range(len(token_times) - 1)]
+
+        return token_durations
 
 def warmup(model):
     warm_up = torch.randn((4096,4096)).to(next(model.parameters()).device)
     torch.mm(warm_up,warm_up)
 
-def generate(model, input_ids, n_generate):
+def generate_torch(model, input_ids, n_generate):
     context_time = 0
     generate_time = []
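Note on how the measurements fall out of the processor above: generate() calls the logits processor once after every forward pass, so the gap between the first two timestamps covers the prefill over the whole prompt, and every later gap covers one decoded token. Below is a minimal sketch of how those durations could be turned into throughput numbers; the real reporting code lives in the unchanged part of benchmark.py and is not shown in this diff, so summarize and its arguments are hypothetical.

import numpy as np

def summarize(prefill_duration, decode_durations, n_context, batch_size=1):
    # One prefill forward processes the whole prompt for every sequence in the batch.
    prefill_tokens_per_second = n_context * batch_size / prefill_duration
    # Each decode duration corresponds to one generated token per sequence.
    decode_tokens_per_second = batch_size / np.median(decode_durations)
    return prefill_tokens_per_second, decode_tokens_per_second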
@@ -39,8 +63,39 @@ def generate(model, input_ids, n_generate):
     return context_time, generate_time
 
-def run_round(model_path, quant_file, n_generate, input_ids, batch_size, no_safetensors):
+def generate_hf(model: BaseAWQForCausalLM, input_ids, n_generate):
+    generation_config = GenerationConfig(
+        min_new_tokens=n_generate,
+        max_new_tokens=n_generate,
+        use_cache=True,
+        forced_eos_token_id=-100,
+        eos_token_id=-100,
+    )
+
+    time_processor = TimeMeasuringLogitsProcessor()
+    model.generate(
+        input_ids,
+        generation_config=generation_config,
+        logits_processor=LogitsProcessorList([time_processor]),
+    )
+
+    context_time = time_processor.get_prefill_duration()
+    generate_time = time_processor.get_decode_durations()
+
+    return context_time, generate_time
+
+def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_size, no_safetensors, pretrained):
     print(f" -- Loading model...")
 
-    model = AutoAWQForCausalLM.from_quantized(
-        model_path, quant_file, fuse_layers=True,
-        max_new_tokens=n_generate, batch_size=batch_size,
+    if pretrained:
+        model = AutoAWQForCausalLM.from_pretrained(
+            model_path,
+            safetensors=not no_safetensors,
+            device_map="cuda",
+            torch_dtype=torch.float16,
+        )
+    else:
+        model = AutoAWQForCausalLM.from_quantized(
+            model_path, quant_file, fuse_layers=True,
+            max_new_tokens=n_generate, batch_size=batch_size,
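For reference, a usage sketch of the new generate_hf path outside run_round. The checkpoint name and sizes are placeholders, and this assumes generate_hf and TimeMeasuringLogitsProcessor are defined in the same session (e.g. by running the snippet inside benchmark.py or importing it).

import torch
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "TheBloke/zephyr-7B-beta-AWQ"  # placeholder AWQ checkpoint

model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Random token ids stand in for a real prompt, mirroring what main() does.
input_ids = torch.randint(0, tokenizer.vocab_size, (1, 64)).cuda()

context_time, generate_time = generate_hf(model, input_ids, n_generate=32)
print(f"prefill: {context_time * 1000:.1f} ms")
print(f"decode:  {len(generate_time) / sum(generate_time):.1f} tokens/s")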
@@ -53,7 +108,7 @@ def run_round(model_path, quant_file, n_generate, input_ids, batch_size, no_safetensors):
     print(f" -- Generating {n_generate} tokens, {input_ids.shape[1]} in context...")
 
     try:
-        context_time, generate_time = generate(model, input_ids, n_generate)
+        context_time, generate_time = generator(model, input_ids, n_generate)
         successful_generate = True
     except RuntimeError as ex:
         if 'cuda out of memory' in str(ex).lower():
@@ -78,6 +133,11 @@ def run_round(model_path, quant_file, n_generate, input_ids, batch_size, no_safetensors):
         prefill_tokens_per_second = 'OOM'
         decode_tokens_per_second = 'OOM'
 
+    if pretrained:
+        version = "FP16"
+    else:
+        version = model.quant_config.version
+
     return {
         "Batch Size": batch_size,
         "Prefill Length": input_ids.shape[1],
@@ -85,7 +145,7 @@ def run_round(model_path, quant_file, n_generate, input_ids, batch_size, no_safetensors):
         "Prefill tokens/s": prefill_tokens_per_second,
         "Decode tokens/s": decode_tokens_per_second,
         "Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)"
-    }, model.quant_config.version
+    }, version
 
 def main(args):
     rounds = [
@@ -98,6 +158,13 @@ def main(args):
         {"context": 2048, "n_generate": 2048},
     ]
 
+    if args.generator == "torch":
+        generator = generate_torch
+    elif args.generator == "hf":
+        generator = generate_hf
+    else:
+        raise ValueError(f"Unknown generator method passed: {args.generator}")
+
     all_stats = []
     tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)
@@ -105,12 +172,14 @@
         input_ids = torch.randint(0, tokenizer.vocab_size, (args.batch_size, settings["context"])).cuda()
 
         stats, model_version = run_round(
+            generator,
             args.model_path,
             args.quant_file,
             settings["n_generate"],
             input_ids,
             args.batch_size,
-            args.no_safetensors
+            args.no_safetensors,
+            args.pretrained
         )
 
         all_stats.append(stats)
@@ -130,6 +199,8 @@ if __name__ == "__main__":
     parser.add_argument("--quant_file", type=str, default="", help="weights filename")
     parser.add_argument("--batch_size", type=int, default=1, help="Batch size for cache and generation")
    parser.add_argument("--no_safetensors", default=False, action="store_true", help="Use for disabling safetensors")
+    parser.add_argument("--generator", type=str, default="torch", choices=["torch", "hf"], help="Method used to generate tokens")
+    parser.add_argument("--pretrained", default=False, action="store_true", help="Measure pretrained model.")
     args = parser.parse_args()
 
     main(args)
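Example invocations with the new flags. The script path and model ids are assumptions; adjust them to your checkout and checkpoints.

python examples/benchmark.py --model_path TheBloke/zephyr-7B-beta-AWQ --generator hf
python examples/benchmark.py --model_path HuggingFaceH4/zephyr-7b-beta --pretrained --generator hf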