Commit 1c9ef9b3 authored by Tri Dao
Browse files

[Gen] Measure prompt processing + decoding time, not just decoding

parent 6f6e9a9a
...@@ -107,10 +107,12 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, ...@@ -107,10 +107,12 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
fused_ft_kernel=fused_ft_kernel) fused_ft_kernel=fused_ft_kernel)
scores = [] scores = []
with torch.inference_mode(): with torch.inference_mode():
logits = model(input_ids, inference_params=inference_params).logits[:, -1]
if timing: if timing:
if tensor_parallel > 1:
torch.distributed.barrier()
torch.cuda.synchronize() torch.cuda.synchronize()
start = time.time() start = time.time()
logits = model(input_ids, inference_params=inference_params).logits[:, -1]
if vocab_size is not None: if vocab_size is not None:
logits = logits[..., :vocab_size] logits = logits[..., :vocab_size]
scores.append(logits if not cg else logits.clone()) scores.append(logits if not cg else logits.clone())
...@@ -143,8 +145,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0, ...@@ -143,8 +145,10 @@ def decode(input_ids, model, max_length, top_k=1, top_p=0.0, temperature=1.0,
if inference_params.sequence_len_offset >= max_length - 1: if inference_params.sequence_len_offset >= max_length - 1:
break break
if timing: if timing:
if tensor_parallel > 1:
torch.distributed.barrier()
torch.cuda.synchronize() torch.cuda.synchronize()
print(f'Decoding time: {(time.time() - start) * 1000:.0f}ms') print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
return output_cls( return output_cls(
sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment