"...text-generation-inference.git" did not exist on "f4a073ae6d2cbcf6ee353b4e27ea90586893fe8b"
Commit a4626828 authored by Casper Hansen

Switch to model.generate()

parent e71181bd
@@ -3,11 +3,11 @@ import time
 import torch
 import argparse
 from lm_eval import evaluator
-from transformers import AutoTokenizer
 from awq import AutoAWQForCausalLM
 from awq.quantize.auto_clip import apply_clip
 from awq.quantize.auto_scale import apply_scale
 from awq.utils.lm_eval_adaptor import LMEvalAdaptor
+from transformers import AutoTokenizer, GenerationConfig
 
 
 def load_search_result_into_memory(model, search_path):
@@ -80,22 +80,6 @@ def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, bat
         out = func()
         return out, time.time() - start
 
-    def _generate(model, model_out, n_generate, batch_size):
-        past_key_values = model_out.past_key_values
-
-        for i in range(n_generate):
-            logits = model_out.logits[:, -1, :]
-
-            new_tokens = []
-            for batch_index in range(batch_size):
-                probs = torch.softmax(logits[batch_index], dim=-1)
-                token = torch.multinomial(probs, num_samples=1)
-                new_tokens.append(token)
-
-            tokens = torch.as_tensor(new_tokens, device=device).unsqueeze(-1)
-            model_out = model(tokens, use_cache=True, past_key_values=past_key_values)
-
     def _warmup(device:str):
         warm_up = torch.randn((4096,4096)).to(device)
         torch.mm(warm_up,warm_up)
@@ -114,10 +98,27 @@ def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, bat
     ids = torch.randint(0, tokenizer.vocab_size, (batch_size, n_context)).cuda()
 
     # Context stage
-    model_out, context_time = _timer(lambda: model(ids, use_cache=True))
+    _, context_time = _timer(lambda: model.generate(
+        ids,
+        generation_config=GenerationConfig(
+            max_new_tokens=0,
+            min_new_tokens=0,
+            use_cache=True
+        )
+    ))
 
     # Generation stage
-    _, generation_time = _timer(lambda: _generate(model, model_out, n_generate, batch_size))
+    _, generation_time = _timer(lambda: model.generate(
+        ids,
+        generation_config=GenerationConfig(
+            max_new_tokens=n_context,
+            min_new_tokens=n_context,
+            forced_eos_token_id=-100,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=-100,
+            use_cache=True
+        )
+    ))
 
     # Prints
     memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
@@ -126,7 +127,7 @@ def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, bat
     inference_tokens_per_second = n_generate / generation_time * batch_size
     inference_ms_per_token = (generation_time*1000) / n_generate / batch_size
 
-    print(f"[======] Model summary: {model_path} [======]")
+    print(f"[=] Model summary: {model_path} [=]")
     print(f"[*] Load time: {load_time:.2f} seconds")
     print(f"[*] Context speed: {context_tokens_per_second:.2f} tokens/second ({context_ms_per_token:.2f} ms/token)")
     print(f"[*] Generation speed: {inference_tokens_per_second:.2f} tokens/second ({inference_ms_per_token:.2f} ms/token)")
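For reference, a minimal sketch of the timing pattern the new code relies on: pinning min_new_tokens and max_new_tokens to the same value and pointing eos_token_id at an id that can never be sampled (-100), so model.generate() decodes an exact number of tokens and cannot stop early. The helper name time_generate and the usage lines below are illustrative, not part of the commit, and the sketch forces n_generate new tokens, the count the tokens/second figures above divide by.

import time
import torch
from transformers import GenerationConfig

def time_generate(model, ids, n_new_tokens, pad_token_id):
    # Pin min/max_new_tokens so exactly n_new_tokens are decoded, and set
    # eos_token_id to -100 so no sampled token can trigger early stopping.
    config = GenerationConfig(
        min_new_tokens=n_new_tokens,
        max_new_tokens=n_new_tokens,
        eos_token_id=-100,
        pad_token_id=pad_token_id,
        use_cache=True,
    )
    start = time.time()
    out = model.generate(ids, generation_config=config)
    return out, time.time() - start

# Hypothetical usage with random context tokens, as in run_speed():
# ids = torch.randint(0, tokenizer.vocab_size, (batch_size, n_context)).cuda()
# _, seconds = time_generate(model, ids, n_generate, tokenizer.pad_token_id)
# print(f"{n_generate / seconds * batch_size:.2f} tokens/second")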