Commit 913922ca authored by Tri Dao

[Gen] Refactor decoding function

parent 3557e0bb
@@ -84,6 +84,7 @@ def sample(logits, top_k=1, top_p=0.0, temperature=1.0):
         )


+@torch.inference_mode()
 def decode(
     input_ids,
     model,
@@ -97,7 +98,7 @@ def decode(
     tensor_parallel=1,
     fused_ft_kernel=False,
     cg=False,
-    timing=False,
+    enable_timing=False,
 ):
     """Decoding, either greedy or with top-k or top-p sampling.
     If top-k = 0, don't limit the number of candidates (pure sampling).
@@ -137,73 +138,67 @@ def decode(
         max_sequence_len=max_length, max_batch_size=batch_size, fused_ft_kernel=fused_ft_kernel
     )
-    def logits_forward_fn(input_ids, position_ids, inference_params):
-        if not cg:
-            return model(
-                input_ids,
-                position_ids=position_ids,
-                inference_params=inference_params,
-                num_last_tokens=1,
-            ).logits.squeeze(dim=1)
-        else:
-            return model._decoding_cache.run(
-                input_ids, position_ids, inference_params.sequence_len_offset
-            ).clone()
-
-    logits_postprocess_fn = (
-        lambda logits: logits[..., :vocab_size] if vocab_size is not None else logits
-    )
-    scores = []
-    with torch.inference_mode():
-        if timing:
-            if tensor_parallel > 1:
-                torch.distributed.barrier()
-            torch.cuda.synchronize()
-            start = time.time()
-        logits = model(
-            input_ids, inference_params=inference_params, num_last_tokens=1
-        ).logits.squeeze(dim=1)
-        logits = logits_postprocess_fn(logits)
-        scores.append(logits if not cg else logits.clone())
-        if teacher_outputs is None or teacher_output_len <= seqlen_og:
-            next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
-        else:
-            next_token = teacher_outputs[:, seqlen_og]
-        sequences = [next_token]
-        inference_params.sequence_len_offset = seqlen_og
-        while True:
-            position_ids = torch.full(
-                (batch_size, 1),
-                inference_params.sequence_len_offset,
-                dtype=torch.long,
-                device=input_ids.device,
-            )
-            logits = logits_postprocess_fn(
-                logits_forward_fn(rearrange(next_token, "b -> b 1"), position_ids, inference_params)
-            )
-            scores.append(logits)
-            if (
-                teacher_outputs is None
-                or teacher_output_len <= inference_params.sequence_len_offset + 1
-            ):
-                next_token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
-            else:
-                next_token = teacher_outputs[:, inference_params.sequence_len_offset + 1]
-            sequences.append(next_token)
-            inference_params.sequence_len_offset += 1
-            if eos_token_id is not None and (next_token == eos_token_id).all():
-                break
-            if inference_params.sequence_len_offset >= max_length - 1:
-                break
-        if timing:
-            if tensor_parallel > 1:
-                torch.distributed.barrier()
-            torch.cuda.synchronize()
-            print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
+    def get_logits(input_ids, inference_params):
+        decoding = inference_params.sequence_len_offset > 0
+        if decoding:
+            position_ids = torch.full(
+                (batch_size, 1),
+                inference_params.sequence_len_offset,
+                dtype=torch.long,
+                device=input_ids.device,
+            )
+        else:
+            position_ids = None
+        if not cg or not decoding:
+            logits = model(
+                input_ids,
+                position_ids=position_ids,
+                inference_params=inference_params,
+                num_last_tokens=1,
+            ).logits.squeeze(dim=1)
+        else:
+            logits = model._decoding_cache.run(
+                input_ids, position_ids, inference_params.sequence_len_offset
+            ).clone()
+        return logits[..., :vocab_size] if vocab_size is not None else logits
+
+    def sample_tokens(logits, inference_params):
+        if teacher_outputs is None or teacher_output_len <= inference_params.sequence_len_offset:
+            token = sample(logits, top_k=top_k, top_p=top_p, temperature=temperature)
+        else:
+            token = teacher_outputs[:, inference_params.sequence_len_offset]
+        return rearrange(token, "b -> b 1")
+
+    def should_stop(current_token, inference_params):
+        if inference_params.sequence_len_offset == 0:
+            return False
+        if eos_token_id is not None and (current_token == eos_token_id).all():
+            return True
+        if inference_params.sequence_len_offset >= max_length - 1:
+            return True
+        return False
+
+    start = torch.cuda.Event(enable_timing=enable_timing)
+    end = torch.cuda.Event(enable_timing=enable_timing)
+
+    if enable_timing:
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
+        start.record()
+    scores, sequences = [], [input_ids]
+    while not should_stop(sequences[-1], inference_params):
+        scores.append(get_logits(sequences[-1], inference_params))
+        inference_params.sequence_len_offset += sequences[-1].shape[1]
+        sequences.append(sample_tokens(scores[-1], inference_params))
+    if enable_timing:
+        end.record()
+        if tensor_parallel > 1:
+            torch.distributed.barrier()
+        torch.cuda.synchronize()
+        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
     output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
     return output_cls(
-        sequences=torch.cat([input_ids, torch.stack(sequences, dim=1)], dim=1), scores=tuple(scores)
+        sequences=torch.cat(sequences, dim=1), scores=tuple(scores)
     )
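The refactored decode above also switches from wall-clock timing (time.time() bracketed by torch.cuda.synchronize()) to CUDA events recorded around the whole prompt-processing-plus-decoding loop. A minimal, self-contained sketch of that event-based timing pattern is below; the matrix multiply is only a stand-in workload, not the library's model call:

import torch

def timed_step(x, weight, enable_timing=True):
    # CUDA events record timestamps on the GPU stream, so the measurement does not
    # require synchronizing before the work is launched.
    start = torch.cuda.Event(enable_timing=enable_timing)
    end = torch.cuda.Event(enable_timing=enable_timing)
    if enable_timing:
        start.record()
    out = x @ weight  # stand-in for the model forward pass
    if enable_timing:
        end.record()
        # elapsed_time (milliseconds) is only valid after both events have completed.
        torch.cuda.synchronize()
        print(f"Step time: {start.elapsed_time(end):.0f}ms")
    return out

if torch.cuda.is_available():
    x = torch.randn(8, 1024, device="cuda")
    w = torch.randn(1024, 1024, device="cuda")
    timed_step(x, w)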
@@ -280,7 +275,7 @@ def decode_speculative(
     tensor_parallel=1,
     fused_ft_kernel=False,
     cg=False,
-    timing=False,
+    enable_timing=False,
     debug=False,
 ):
     """
@@ -446,7 +441,7 @@ def decode_speculative(
     sequences = [input_ids]
     scores = []
     with torch.inference_mode():
-        if timing:
+        if enable_timing:
             if tensor_parallel > 1:
                 torch.distributed.barrier()
             torch.cuda.synchronize()
@@ -566,7 +561,7 @@ def decode_speculative(
             ).logits
             print((scores[-1] - scores_ref[:, :-1]).abs().max())
-        if timing:
+        if enable_timing:
             if tensor_parallel > 1:
                 torch.distributed.barrier()
             torch.cuda.synchronize()
...
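For orientation, the control flow introduced in decode (per-step helpers for logits, sampling, and the stop check, driven by a single while-not-should_stop loop) can be sketched in isolation. Everything below is an illustrative stand-in, not the library's code: fake_logits replaces the model forward, sampling is plain greedy argmax, and the vocabulary and lengths are made up.

import torch

vocab_size, eos_token_id, max_length = 50, 3, 12

def fake_logits(tokens):
    # Stand-in for the model forward: random logits for the last position of each sequence.
    return torch.randn(tokens.shape[0], vocab_size)

def sample_tokens(logits):
    # Greedy sampling, kept as shape (batch, 1) so it can be fed back into the loop.
    return logits.argmax(dim=-1, keepdim=True)

def should_stop(current_token, seqlen_offset):
    if seqlen_offset == 0:
        return False  # always run at least the prompt step
    if eos_token_id is not None and (current_token == eos_token_id).all():
        return True
    return seqlen_offset >= max_length - 1

input_ids = torch.randint(0, vocab_size, (2, 4))  # (batch, prompt length)
seqlen_offset = 0
scores, sequences = [], [input_ids]
while not should_stop(sequences[-1], seqlen_offset):
    scores.append(fake_logits(sequences[-1]))
    seqlen_offset += sequences[-1].shape[1]
    sequences.append(sample_tokens(scores[-1]))
print(torch.cat(sequences, dim=1).shape)  # prompt plus generated tokens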
@@ -289,7 +289,7 @@ def test_baichuan_generation(model_name):
         fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -310,7 +310,7 @@ def test_baichuan_generation(model_name):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -400,7 +400,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     # Capture graph outside the timing loop
@@ -419,7 +419,7 @@ def test_baichuan_parallel_generation(model_name, world_size):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     del model
     parallel_state.destroy_model_parallel()
...
@@ -245,7 +245,7 @@ def test_falcon_generation(model_name):
         fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -264,7 +264,7 @@ def test_falcon_generation(model_name):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -351,7 +351,7 @@ def test_falcon_parallel_generation(model_name, world_size):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     # Capture graph outside the timing loop
@@ -368,7 +368,7 @@ def test_falcon_parallel_generation(model_name, world_size):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     del model
     parallel_state.destroy_model_parallel()
...
@@ -200,7 +200,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
         fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     print(out.sequences)
     print(tokenizer.batch_decode(out.sequences.tolist()))
@@ -212,7 +212,7 @@ def test_gpt2_generation(model_name, rotary, optimized, fused_ft_kernel):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     print(out_cg.sequences)
@@ -267,7 +267,7 @@ def get_logits(model, input_ids, max_length, teacher_outputs=None, **kwargs):
         teacher_outputs=teacher_outputs,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         **kwargs,
     )
     return torch.stack(out.scores, dim=1)
@@ -431,7 +431,7 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
         fused_ft_kernel=fused_ft_kernel,
         cg=cg,
         speculative_lookahead=4,
-        timing=True,
+        enable_timing=True,
     )
     print(tokenizer.batch_decode(out.sequences))
     out_og = model.generate(
@@ -440,7 +440,7 @@ def test_gpt2_speculative_decoding(model_name, optimized, fused_ft_kernel, cg):
         top_k=5,
         fused_ft_kernel=fused_ft_kernel,
         cg=False,
-        timing=True,
+        enable_timing=True,
         return_dict_in_generate=True,
     )
     print(tokenizer.batch_decode(out_og.sequences))
...
@@ -114,7 +114,7 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
         fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     print(out.sequences)
     if fused_ft_kernel:
@@ -127,7 +127,7 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     print(out_cg.sequences)
...
@@ -144,7 +144,7 @@ def test_gptj_generation(model_name):
         # eos_token_id=eos_token_id, fused_ft_kernel=False,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -163,7 +163,7 @@ def test_gptj_generation(model_name):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
...
@@ -295,7 +295,7 @@ def test_llama_generation(model_name, checkpoint_format):
         fused_ft_kernel=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -314,7 +314,7 @@ def test_llama_generation(model_name, checkpoint_format):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
         teacher_outputs=out_hf.sequences,
     )
     torch.cuda.synchronize()
@@ -403,7 +403,7 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     # Capture graph outside the timing loop
@@ -420,7 +420,7 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
         # teacher_outputs=out_hf.sequences,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     del model
     parallel_state.destroy_model_parallel()
...
@@ -158,7 +158,7 @@ def test_opt_generation(model_name):
         fused_ft_kernel=fused_ft_kernel,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     torch.cuda.synchronize()
     print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
@@ -179,7 +179,7 @@ def test_opt_generation(model_name):
         cg=True,
         return_dict_in_generate=True,
         output_scores=True,
-        timing=True,
+        enable_timing=True,
     )
     torch.cuda.synchronize()
     print(f"Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms")
...