Commit 9cc286ba authored by rprenger

Getting tokens_to_generate=0 to work

parent 21d2b0fe
@@ -20,7 +20,9 @@ import torch
 from megatron import mpu
 from .communication import broadcast_float_list
-from .generation import generate_tokens_probs_and_return_on_first_stage
+from .generation import (
+    generate_tokens_probs_and_return_on_first_stage,
+    score_and_return_on_first_stage)
 from .tokenization import (
     tokenize_prompts,
     detokenize_generations)
@@ -31,7 +33,6 @@ def generate_and_post_process(model,
                               prompts=None,
                               tokens_to_generate=0,
                               return_output_log_probs=False,
-                              return_all_log_probs=False,
                               greedy_sampling=False,
                               top_k_sampling=0,
                               top_p_sampling=0.0,
@@ -43,12 +44,11 @@ def generate_and_post_process(model,
     move to cpu and convert to list."""

     # Main inference.
-    tokens, lengths, output_log_probs, all_log_probs = generate(
+    tokens, lengths, output_log_probs = generate(
         model,
         prompts=prompts,
         tokens_to_generate=tokens_to_generate,
         return_output_log_probs=return_output_log_probs,
-        return_all_log_probs=return_all_log_probs,
         greedy_sampling=greedy_sampling,
         top_k_sampling=top_k_sampling,
         top_p_sampling=top_p_sampling,
@@ -59,17 +59,16 @@ def generate_and_post_process(model,
     # Only post-process on first stage.
     if mpu.is_pipeline_first_stage():
         tokens, prompts_plus_generations, prompts_plus_generations_segments = \
             detokenize_generations(tokens, lengths, True)

         if return_output_log_probs:
             output_log_probs = output_log_probs.cpu().numpy().tolist()
-            if return_all_log_probs:
-                all_log_probs = all_log_probs.cpu().numpy().tolist()
+            for i, (prob, seg) in enumerate(zip(output_log_probs, prompts_plus_generations_segments)):
+                output_log_probs[i] = prob[:len(seg)-1]

         return prompts_plus_generations, prompts_plus_generations_segments, \
-            output_log_probs, all_log_probs, tokens
+            output_log_probs, tokens

     return None
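The new loop trims each row of output_log_probs to its own detokenized segment, since the batch tensor is padded to the longest sequence in the batch and a segment of N tokens only has N-1 next-token log probs. A toy, self-contained illustration of that trimming (made-up values, plain Python lists, not Megatron tensors):

```python
# Toy illustration of the per-sequence trimming done above (hypothetical values).
output_log_probs = [
    [-0.1, -0.2, -0.3, -0.4, -0.5],   # padded to the batch-wide max length
    [-0.6, -0.7, -0.8, -0.9, -1.0],
]
segments = [
    ["Hello", ",", "world", "!"],      # 4 tokens -> keep 3 log probs
    ["Hi", "there"],                   # 2 tokens -> keep 1 log prob
]

for i, (prob, seg) in enumerate(zip(output_log_probs, segments)):
    output_log_probs[i] = prob[:len(seg) - 1]

print(output_log_probs)  # [[-0.1, -0.2, -0.3], [-0.6]]
```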
@@ -79,7 +78,6 @@ def generate(model,
              prompts=None,
              tokens_to_generate=0,
              return_output_log_probs=False,
-             return_all_log_probs=False,
              greedy_sampling=False,
              top_k_sampling=0,
              top_p_sampling=0.0,
@@ -93,25 +91,23 @@ def generate(model,
         discard tokens in the tokens tensor that are after the
         corresponding length.
         output_log_probs: log probs of the tokens.
-        all_log_probs: full log probs for all of tokens.
     """

     # Make sure input params are avaialble to all ranks.
     values = [tokens_to_generate,
-              return_output_log_probs, return_all_log_probs,
+              return_output_log_probs,
               greedy_sampling, top_k_sampling, top_p_sampling,
               temperature, add_BOS, use_eod_token_for_early_termination, just_score]
-    values_float_tensor = broadcast_float_list(10, float_list=values)
+    values_float_tensor = broadcast_float_list(9, float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     return_output_log_probs = bool(values_float_tensor[1].item())
-    return_all_log_probs = bool(values_float_tensor[2].item())
-    greedy_sampling = bool(values_float_tensor[3].item())
-    top_k_sampling = int(values_float_tensor[4].item())
-    top_p_sampling = values_float_tensor[5].item()
-    temperature = values_float_tensor[6].item()
-    add_BOS = bool(values_float_tensor[7].item())
-    use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
-    just_score = bool(values_float_tensor[9].item())
+    greedy_sampling = bool(values_float_tensor[2].item())
+    top_k_sampling = int(values_float_tensor[3].item())
+    top_p_sampling = values_float_tensor[4].item()
+    temperature = values_float_tensor[5].item()
+    add_BOS = bool(values_float_tensor[6].item())
+    use_eod_token_for_early_termination = bool(values_float_tensor[7].item())
+    just_score = bool(values_float_tensor[8].item())

     # Tokenize prompts and get the batch.
     # Note that these tensors are broadcaseted to all ranks.
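The hunk above re-packs the generation parameters into a 9-element float list (return_all_log_probs is gone), so every positional index after slot 1 shifts down by one on the unpacking side. A minimal local sketch of that pack/unpack ordering, without the actual torch.distributed broadcast that broadcast_float_list performs; the just_score value here is a made-up example, and the diff does not show where it is computed:

```python
# Standalone sketch (no torch.distributed) of the 9-slot pack/unpack ordering.
import torch

tokens_to_generate = 0
return_output_log_probs = True
greedy_sampling = False
top_k_sampling = 0
top_p_sampling = 0.0
temperature = 1.0
add_BOS = False
use_eod_token_for_early_termination = True
just_score = True  # assumption: presumably derived from tokens_to_generate == 0

values = [tokens_to_generate,
          return_output_log_probs,
          greedy_sampling, top_k_sampling, top_p_sampling,
          temperature, add_BOS, use_eod_token_for_early_termination, just_score]

# Stand-in for the tensor broadcast_float_list() would send from rank 0.
values_float_tensor = torch.tensor(values, dtype=torch.float32)

assert int(values_float_tensor[0].item()) == tokens_to_generate
assert bool(values_float_tensor[1].item()) == return_output_log_probs
assert bool(values_float_tensor[8].item()) == just_score
```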
@@ -121,13 +117,15 @@ def generate(model,
     context_tokens_tensor, context_length_tensor = tokenize_prompts(
         prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)

+    if just_score:
+        return score_and_return_on_first_stage(
+            model, context_tokens_tensor, context_length_tensor)
+
     # Main inference function.
     # Note that the outputs are available on the first stage.
     return generate_tokens_probs_and_return_on_first_stage(
         model, context_tokens_tensor, context_length_tensor,
         return_output_log_probs=return_output_log_probs,
-        return_all_log_probs=return_all_log_probs,
         greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling,
         temperature=temperature,
-        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
-        just_score=just_score)
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
@@ -27,15 +27,76 @@ from .communication import (
 from .forward_step import ForwardStep
 from .sampling import sample


+def score_and_return_on_first_stage(model, tokens, lengths):
+    """Function for just scoring.
+    Arguments:
+        model: no interleaving is supported.
+        tokens: prompt tokens extended to be of size [b, max_prompt_length]
+        lengths: original prompt length, size: [b]
+    Note: Outside of model, other parameters only need to be available on
+        rank 0.
+    Outputs:
+        output_log_probs: log probability of the selected tokens. size: [b, s]
+    """
+
+    args = get_args()
+
+    batch_size = tokens.size(0)
+    max_prompt_length = lengths.max().item()
+    assert max_prompt_length == tokens.size(1)
+    max_sequence_length = min(max_prompt_length, args.max_position_embeddings)
+
+    # forward step.
+    forward_step = ForwardStep(model, batch_size, max_sequence_length)
+
+    # ===================
+    # Pre-allocate memory
+    # ===================
+
+    # Log probability of the sequence (prompt + generated tokens).
+    output_log_probs = None
+    output_log_probs_size = (batch_size, max_sequence_length - 1)
+
+    if mpu.is_pipeline_last_stage():
+        output_log_probs = torch.empty(output_log_probs_size,
+                                       dtype=torch.float32,
+                                       device=torch.cuda.current_device())
+
+    # =============
+    # Run infernece
+    # =============
+    with torch.no_grad():
+        attention_mask, position_ids = _build_attention_mask_and_position_ids(tokens)
+
+        # logits will be meanigful only in the last pipeline stage.
+        logits = forward_step(tokens, position_ids, attention_mask)
+
+        if mpu.is_pipeline_last_stage():
+            # Always the last stage should have an output.
+            assert logits is not None
+
+            log_probs = F.log_softmax(logits, dim=2)
+
+            # Pick the tokens that we need to get the log
+            # probabilities for. Note that next input token is
+            # the token which we selected in the current logits,
+            # so shift by 1.
+            indices = torch.unsqueeze(tokens[:, 1:], 2)
+            output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)
+
+    # ======================================
+    # Broadcast to the first pipeline stage.
+    # ======================================
+    output_log_probs = broadcast_from_last_to_first_pipeline_stage(
+        output_log_probs_size, torch.float32, output_log_probs)
+
+    return tokens, lengths, output_log_probs
+
+
 def generate_tokens_probs_and_return_on_first_stage(
         model, tokens, lengths,
         return_output_log_probs=False,
-        return_all_log_probs=False,
         greedy=False, top_k=0, top_p=0.0,
         temperature=1.0,
-        use_eod_token_for_early_termination=True,
-        just_score=False):
+        use_eod_token_for_early_termination=True):
     """Main token generation function.

     Arguments:
         model: no interleaving is supported.
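The heart of score_and_return_on_first_stage is the shift-by-one gather: the logits at position t score the token at position t+1, so a prompt of length s yields s-1 log probabilities. A CPU-only sketch of that gather with random logits (toy shapes and names, no Megatron dependencies):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, seq_len, vocab_size = 2, 5, 11          # toy shapes
tokens = torch.randint(vocab_size, (batch_size, seq_len))
logits = torch.randn(batch_size, seq_len, vocab_size)

log_probs = F.log_softmax(logits, dim=2)

# Token at position t+1 is scored by the logits at position t, hence the shift.
indices = torch.unsqueeze(tokens[:, 1:], 2)                          # [b, s-1, 1]
output_log_probs = torch.gather(log_probs, 2, indices).squeeze(2)    # [b, s-1]

assert output_log_probs.shape == (batch_size, seq_len - 1)
```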
@@ -44,9 +105,6 @@ def generate_tokens_probs_and_return_on_first_stage(
         return_output_log_probs: flag to calculate the log probability of
             the generated tokens. Note that the log probability is the one
             after logits are modifed for sampling.
-        return_all_log_probs: flag to calculate the log probability of across
-            all the tokens (vocab size). Note that the log probability is the
-            one after logits are modifed for sampling.
         greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters.
             Note that these three paramters are exclusive meaning that:
             if greedy = true then we should have top-k=top-p=0.
@@ -63,8 +121,6 @@ def generate_tokens_probs_and_return_on_first_stage(
         generated_sequence_lengths: total length (including prompt) of
             the generated sequence. size: [b]
         output_log_probs: log probability of the selected tokens. size: [b, s]
-        all_log_probs: log probability of all the tokens.
-            size: [b, s, vocab-size]
     """

     args = get_args()
@@ -93,9 +149,7 @@ def generate_tokens_probs_and_return_on_first_stage(
     output_log_probs = None
     output_log_probs_size = (batch_size, max_sequence_length - 1)
     # Log probability of all tokens for the sequence.
-    all_log_probs = None
-    all_log_probs_size = (batch_size, max_sequence_length -1,
-                          args.padded_vocab_size)
     # Lengths of generated seuquence including including prompts.
     generated_sequence_lengths = None
     if mpu.is_pipeline_last_stage():
@@ -103,10 +157,6 @@ def generate_tokens_probs_and_return_on_first_stage(
         output_log_probs = torch.empty(output_log_probs_size,
                                        dtype=torch.float32,
                                        device=torch.cuda.current_device())
-        if return_all_log_probs:
-            all_log_probs = torch.empty(all_log_probs_size,
-                                        dtype=torch.float32,
-                                        device=torch.cuda.current_device())
         generated_sequence_lengths = torch.ones(
             batch_size, dtype=torch.int64,
             device=torch.cuda.current_device()) * max_sequence_length
@@ -159,12 +209,8 @@ def generate_tokens_probs_and_return_on_first_stage(
             tokens[started, context_length] = new_sample[started]

             # Calculate the log probabilities.
-            if return_output_log_probs or return_all_log_probs:
+            if return_output_log_probs:
                 log_probs = F.log_softmax(logits, dim=2)
-                if return_all_log_probs:
-                    all_log_probs[:,
-                                  prev_context_length:context_length,
-                                  :] = log_probs
                 if return_output_log_probs:
                     # Pick the tokens that we need to get the log
                     # probabilities for. Note that next input token is
@@ -210,8 +256,6 @@ def generate_tokens_probs_and_return_on_first_stage(
     if mpu.is_pipeline_last_stage():
         if return_output_log_probs:
             output_log_probs = output_log_probs[:, :context_length].contiguous()
-        if return_all_log_probs:
-            all_log_probs = all_log_probs[:, :context_length, :]

     # ======================================
     # Broadcast to the first pipeline stage.
@@ -223,14 +267,8 @@ def generate_tokens_probs_and_return_on_first_stage(
         output_log_probs_size = (batch_size, context_length)
         output_log_probs = broadcast_from_last_to_first_pipeline_stage(
             output_log_probs_size, torch.float32, output_log_probs)
-        if return_all_log_probs:
-            all_log_probs_size = (batch_size, context_length,
-                                  args.padded_vocab_size)
-            all_log_probs = broadcast_from_last_to_first_pipeline_stage(
-                all_log_probs_size, torch.float32, all_log_probs)

-    return tokens, generated_sequence_lengths, output_log_probs, \
-        all_log_probs
+    return tokens, generated_sequence_lengths, output_log_probs
...
@@ -69,6 +69,8 @@ class MegatronGenerate(Resource):
             logprobs = request.get_json()["logprobs"]
             if not isinstance(logprobs, bool):
                 return "logprobs must be a boolean value"
+        if just_score and not logprobs:
+            return "tokens_to_generate=0 implies logprobs=True"

         temperature = 1.0
         if "temperature" in request.get_json():
@@ -83,7 +85,7 @@ class MegatronGenerate(Resource):
             top_k = request.get_json()["top_k"]
             if not (type(top_k) == int):
                 return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000"
-            if not (0 < top_k <= 1000):
+            if not (0 <= top_k <= 1000):
                 return "top_k must be equal to or greater than 0 and less than or equal to 1000"

         top_p = 0.0
@@ -93,7 +95,7 @@ class MegatronGenerate(Resource):
                 return "top_p must be a positive float less than or equal to 1.0"
             if top_p > 0.0 and top_k > 0.0:
                 return "cannot set both top-k and top-p samplings."
-            if not (0 < top_p <= 1.0):
+            if not (0 <= top_p <= 1.0):
                 return "top_p must be less than or equal to 1.0"

         add_BOS = False
@@ -104,13 +106,12 @@ class MegatronGenerate(Resource):
         with lock:  # Need to get lock to keep multiple threads from hitting code
             MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-            response, response_seg, response_logprobs, _, _ = \
+            response, response_seg, response_logprobs, _ = \
                 generate_and_post_process(
                     self.model,
                     prompts=prompts,
                     tokens_to_generate=tokens_to_generate,
                     return_output_log_probs=logprobs,
-                    return_all_log_probs=False,
                     greedy_sampling=args.greedy,
                     top_k_sampling=top_k,
                     top_p_sampling=top_p,
...
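With these server changes, a request with tokens_to_generate=0 and logprobs=true should come back with per-token log probabilities for the prompt itself rather than any generated text. A hypothetical client call, assuming the server is reachable at http://localhost:5000/api over PUT; the URL, port, and HTTP verb depend on the deployment and are not shown in this diff:

```python
# Hypothetical client sketch for the new scoring mode (assumed endpoint).
import requests

resp = requests.put(
    "http://localhost:5000/api",
    json={
        "prompts": ["The quick brown fox jumps over the lazy dog."],
        "tokens_to_generate": 0,   # score only, generate nothing
        "logprobs": True,          # required: tokens_to_generate=0 implies logprobs=True
    },
)
print(resp.json())
```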