Commit c6e7c7fd authored by mshoeybi

removed return all probs

parent 8d405805
@@ -31,7 +31,6 @@ def generate_and_post_process(model,
                               prompts=None,
                               tokens_to_generate=0,
                               return_output_log_probs=False,
-                              return_all_log_probs=False,
                               greedy_sampling=False,
                               top_k_sampling=0,
                               top_p_sampling=0.0,
@@ -42,12 +41,11 @@ def generate_and_post_process(model,
     move to cpu and convert to list."""

     # Main inference.
-    tokens, lengths, output_log_probs, all_log_probs = generate(
+    tokens, lengths, output_log_probs = generate(
         model,
         prompts=prompts,
         tokens_to_generate=tokens_to_generate,
         return_output_log_probs=return_output_log_probs,
-        return_all_log_probs=return_all_log_probs,
         greedy_sampling=greedy_sampling,
         top_k_sampling=top_k_sampling,
         top_p_sampling=top_p_sampling,
@@ -63,11 +61,9 @@ def generate_and_post_process(model,
         if return_output_log_probs:
             output_log_probs = output_log_probs.cpu().numpy().tolist()
-        if return_all_log_probs:
-            all_log_probs = all_log_probs.cpu().numpy().tolist()

         return prompts_plus_generations, prompts_plus_generations_segments, \
-            output_log_probs, all_log_probs, tokens
+            output_log_probs, tokens

     return None
@@ -77,7 +73,6 @@ def generate(model,
              prompts=None,
              tokens_to_generate=0,
              return_output_log_probs=False,
-             return_all_log_probs=False,
              greedy_sampling=False,
              top_k_sampling=0,
              top_p_sampling=0.0,
@@ -90,24 +85,21 @@ def generate(model,
            discard tokens in the tokens tensor that are after the
            corresponding length.
        output_log_probs: log probs of the tokens.
-       all_log_probs: full log probs for all of the tokens.
    """

    # Make sure input params are available to all ranks.
-    values = [tokens_to_generate,
-              return_output_log_probs, return_all_log_probs,
+    values = [tokens_to_generate, return_output_log_probs,
              greedy_sampling, top_k_sampling, top_p_sampling,
              temperature, add_BOS, use_eod_token_for_early_termination]
-    values_float_tensor = broadcast_float_list(9, float_list=values)
+    values_float_tensor = broadcast_float_list(8, float_list=values)
    tokens_to_generate = int(values_float_tensor[0].item())
    return_output_log_probs = bool(values_float_tensor[1].item())
-   return_all_log_probs = bool(values_float_tensor[2].item())
-   greedy_sampling = bool(values_float_tensor[3].item())
-   top_k_sampling = int(values_float_tensor[4].item())
-   top_p_sampling = values_float_tensor[5].item()
-   temperature = values_float_tensor[6].item()
-   add_BOS = bool(values_float_tensor[7].item())
-   use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
+   greedy_sampling = bool(values_float_tensor[2].item())
+   top_k_sampling = int(values_float_tensor[3].item())
+   top_p_sampling = values_float_tensor[4].item()
+   temperature = values_float_tensor[5].item()
+   add_BOS = bool(values_float_tensor[6].item())
+   use_eod_token_for_early_termination = bool(values_float_tensor[7].item())

    # Tokenize prompts and get the batch.
    # Note that these tensors are broadcasted to all ranks.
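The index shuffle above is the subtle part of this hunk: the scalar sampling parameters are packed into a single float tensor and broadcast so that every rank unpacks identical values by position, so dropping `return_all_log_probs` shifts every later index down by one and the element count from 9 to 8. A minimal sketch of that pattern, assuming `torch.distributed` is initialized with a CUDA-aware backend (the helper name `broadcast_scalars` is illustrative, not the commit's `broadcast_float_list`):

```python
import torch
import torch.distributed as dist

def broadcast_scalars(values, src=0):
    """Pack Python scalars into one float tensor and broadcast it from `src`
    so that every rank can unpack identical settings by position."""
    buf = torch.tensor([float(v) for v in values], dtype=torch.float32,
                       device=torch.cuda.current_device())
    dist.broadcast(buf, src)
    return buf

# Unpacking must mirror the packing order exactly, which is why removing one
# entry renumbers everything after it:
#   vals = broadcast_scalars([tokens_to_generate, return_output_log_probs, ...])
#   tokens_to_generate = int(vals[0].item())
#   return_output_log_probs = bool(vals[1].item())
#   greedy_sampling = bool(vals[2].item())
```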
@@ -122,7 +114,6 @@ def generate(model,
    return generate_tokens_probs_and_return_on_first_stage(
        model, context_tokens_tensor, context_length_tensor,
        return_output_log_probs=return_output_log_probs,
-       return_all_log_probs=return_all_log_probs,
        greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling,
        temperature=temperature,
        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
@@ -55,13 +55,31 @@ def send_to_next_pipeline_rank(tensor=None):
+def _is_cuda(tensor):
+    """Check if a tensor is not none and is cuda."""
+    assert tensor is not None
+    assert tensor.is_cuda
+
+
+def _is_cuda_contiguous(tensor):
+    """Check if a tensor is not none, is cuda, and is contiguous."""
+    _is_cuda(tensor)
+    assert tensor.is_contiguous()
+
+
 def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
     """Broadcast a tensor from last pipeline stage to all ranks."""

-    if mpu.is_pipeline_last_stage():
-        assert tensor is not None
-        assert tensor.is_cuda
-        assert tensor.is_contiguous()
+    is_last_stage = mpu.is_pipeline_last_stage()
+    # If the first stage and the last stage are the same, then there is no
+    # pipeline parallelism and no need to communicate.
+    if mpu.is_pipeline_first_stage() and is_last_stage:
+        return tensor
+
+    if is_last_stage:
+        _is_cuda_contiguous(tensor)
     else:
         tensor = torch.empty(size,
                              dtype=dtype,
@@ -78,14 +96,16 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
 def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
     """Broadcast tensor values from last stage into the first stage."""

-    # Only the first and last pipeline stages need to be involved.
     is_last_stage = mpu.is_pipeline_last_stage()
     is_first_stage = mpu.is_pipeline_first_stage()
+    # If the first stage and the last stage are the same, then there is no
+    # pipeline parallelism and no need to communicate.
+    if is_first_stage and is_last_stage:
+        return tensor
+    # Only the first and last pipeline stages need to be involved.
     if is_last_stage or is_first_stage:
         if is_last_stage:
-            assert tensor is not None
-            assert tensor.is_cuda
-            assert tensor.is_contiguous()
+            _is_cuda_contiguous(tensor)
         else:
             tensor = torch.empty(size,
                                  dtype=dtype,
@@ -105,12 +125,15 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
     """Copy tensor values from last stage into the first stage.
     Note that the input tensor is updated in place."""

-    # Only the first and last pipeline stages need to be involved.
     is_last_stage = mpu.is_pipeline_last_stage()
     is_first_stage = mpu.is_pipeline_first_stage()
+    # If the first stage and the last stage are the same, then there is no
+    # pipeline parallelism and no need to communicate.
+    if is_first_stage and is_last_stage:
+        return
+    # Only the first and last pipeline stages need to be involved.
     if is_last_stage or is_first_stage:
-        assert tensor is not None
-        assert tensor.is_cuda
+        _is_cuda(tensor)
         is_contiguous = tensor.is_contiguous()
         src = mpu.get_pipeline_model_parallel_last_rank()
         group = mpu.get_embedding_group()
@@ -137,8 +160,7 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0):
     """
     if torch.distributed.get_rank() == rank:
-        assert tensor is not None
-        assert tensor.is_cuda
+        _is_cuda_contiguous(tensor)
     else:
         tensor = torch.empty(size,
                              dtype=dtype,
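Each of the hunks above cuts off right before the collective call, so for context here is the complete shape such a helper typically has: the source rank validates that it holds a CUDA, contiguous tensor (now via the new `_is_cuda_contiguous` helper), every other rank allocates an empty receive buffer, and a single broadcast fills it. The trailing `dist.broadcast` line is an assumption about the elided part of the function, not text from the diff:

```python
import torch
import torch.distributed as dist

def _is_cuda_contiguous(tensor):
    assert tensor is not None
    assert tensor.is_cuda
    assert tensor.is_contiguous()

def broadcast_tensor_sketch(size, dtype, tensor=None, rank=0, group=None):
    """Broadcast `tensor` from `rank` to all ranks in `group` (sketch)."""
    if dist.get_rank() == rank:
        _is_cuda_contiguous(tensor)
    else:
        # Receivers allocate an empty buffer of the agreed size and dtype.
        tensor = torch.empty(size, dtype=dtype,
                             device=torch.cuda.current_device())
    dist.broadcast(tensor, rank, group=group)
    return tensor
```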
...
@@ -31,7 +31,6 @@ from .sampling import sample
 def generate_tokens_probs_and_return_on_first_stage(
         model, tokens, lengths,
         return_output_log_probs=False,
-        return_all_log_probs=False,
         greedy=False, top_k=0, top_p=0.0,
         temperature=1.0,
         use_eod_token_for_early_termination=True):
@@ -43,9 +42,6 @@ def generate_tokens_probs_and_return_on_first_stage(
         return_output_log_probs: flag to calculate the log probability of
             the generated tokens. Note that the log probability is the one
             after logits are modified for sampling.
-        return_all_log_probs: flag to calculate the log probability across
-            all the tokens (vocab size). Note that the log probability is the
-            one after logits are modified for sampling.
         greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters.
             Note that these three parameters are exclusive, meaning that
             if greedy = true then we should have top-k = top-p = 0.
@@ -62,8 +58,6 @@ def generate_tokens_probs_and_return_on_first_stage(
         generated_sequence_lengths: total length (including prompt) of
             the generated sequence. size: [b]
         output_log_probs: log probability of the selected tokens. size: [b, s]
-        all_log_probs: log probability of all the tokens.
-            size: [b, s, vocab-size]
    """

    args = get_args()
@@ -91,10 +85,6 @@ def generate_tokens_probs_and_return_on_first_stage(
    # Log probability of the sequence (prompt + generated tokens).
    output_log_probs = None
    output_log_probs_size = (batch_size, max_sequence_length - 1)
-    # Log probability of all tokens for the sequence.
-    all_log_probs = None
-    all_log_probs_size = (batch_size, max_sequence_length - 1,
-                          args.padded_vocab_size)
    # Lengths of the generated sequence, including prompts.
    generated_sequence_lengths = None
    if mpu.is_pipeline_last_stage():
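For a sense of scale, the deleted `all_log_probs` buffer holds one fp32 log probability per vocabulary entry for every position of every sequence in the batch. With illustrative numbers (not from the commit), that is already a few gigabytes per batch:

```python
# Illustrative sizes only; padded_vocab_size depends on the model and tokenizer.
batch_size, max_sequence_length, padded_vocab_size = 8, 2048, 51200
fp32_bytes = 4
all_log_probs_bytes = (batch_size * (max_sequence_length - 1)
                       * padded_vocab_size * fp32_bytes)
print(f"{all_log_probs_bytes / 2**30:.1f} GiB")  # ~3.1 GiB for one batch
```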
@@ -102,10 +92,6 @@ def generate_tokens_probs_and_return_on_first_stage(
            output_log_probs = torch.empty(output_log_probs_size,
                                           dtype=torch.float32,
                                           device=torch.cuda.current_device())
-        if return_all_log_probs:
-            all_log_probs = torch.empty(all_log_probs_size,
-                                        dtype=torch.float32,
-                                        device=torch.cuda.current_device())
        generated_sequence_lengths = torch.ones(
            batch_size, dtype=torch.int64,
            device=torch.cuda.current_device()) * max_sequence_length
@@ -157,12 +143,8 @@ def generate_tokens_probs_and_return_on_first_stage(
            tokens[started, context_length] = new_sample[started]

            # Calculate the log probabilities.
-            if return_output_log_probs or return_all_log_probs:
+            if return_output_log_probs:
                log_probs = F.log_softmax(logits, dim=2)
-                if return_all_log_probs:
-                    all_log_probs[:,
-                                  prev_context_length:context_length,
-                                  :] = log_probs
                if return_output_log_probs:
                    # Pick the tokens that we need to get the log
                    # probabilities for. Note that next input token is
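The hunk ends just before the picking step. A minimal sketch of how per-token log probabilities are typically extracted from the full `[b, s, vocab]` log-softmax output (the tensor names mirror the diff; the gather-based indexing and the dummy shapes are assumptions for illustration, not the commit's exact code):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: batch 2, vocab 50, positions 3..4 just decoded.
b, vocab, prev_context_length, context_length, max_len = 2, 50, 3, 5, 8
logits = torch.randn(b, context_length - prev_context_length, vocab)
tokens = torch.randint(vocab, (b, max_len))
output_log_probs = torch.zeros(b, max_len - 1)

# Log-softmax over the vocab dimension, as in the hunk above.
log_probs = F.log_softmax(logits, dim=2)

# The log prob stored at position i belongs to the token that appears at
# position i + 1, hence the one-token shift when indexing into `tokens`.
indices = tokens[:, prev_context_length + 1:context_length + 1].unsqueeze(2)
output_log_probs[:, prev_context_length:context_length] = \
    torch.gather(log_probs, 2, indices).squeeze(2)
```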
@@ -208,8 +190,6 @@ def generate_tokens_probs_and_return_on_first_stage(
    if mpu.is_pipeline_last_stage():
        if return_output_log_probs:
            output_log_probs = output_log_probs[:, :context_length]
-        if return_all_log_probs:
-            all_log_probs = all_log_probs[:, :context_length, :]

    # ======================================
    # Broadcast to the first pipeline stage.
@@ -221,14 +201,8 @@ def generate_tokens_probs_and_return_on_first_stage(
        output_log_probs_size = (batch_size, context_length)
        output_log_probs = broadcast_from_last_to_first_pipeline_stage(
            output_log_probs_size, torch.float32, output_log_probs)
-    if return_all_log_probs:
-        all_log_probs_size = (batch_size, context_length,
-                              args.padded_vocab_size)
-        all_log_probs = broadcast_from_last_to_first_pipeline_stage(
-            all_log_probs_size, torch.float32, all_log_probs)

-    return tokens, generated_sequence_lengths, output_log_probs, \
-        all_log_probs
+    return tokens, generated_sequence_lengths, output_log_probs
...
@@ -101,13 +101,12 @@ class MegatronGenerate(Resource):
        with lock:  # Need to get lock to keep multiple threads from hitting code
            MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-            response, response_seg, response_logprobs, _, _ = \
+            response, response_seg, response_logprobs, _ = \
                generate_and_post_process(
                    self.model,
                    prompts=prompts,
                    tokens_to_generate=tokens_to_generate,
                    return_output_log_probs=logprobs,
-                    return_all_log_probs=False,
                    greedy_sampling=args.greedy,
                    top_k_sampling=top_k,
                    top_p_sampling=top_p,
...
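After this commit, callers of `generate_and_post_process` unpack four values instead of five, as the server code above now does. A minimal usage sketch on the first pipeline stage (the model object, prompt text, and sampling settings here are illustrative; only keyword names visible in the diff are used, and other ranks get `None` back, per the `return None` path in the first file):

```python
# Hypothetical call site on the first pipeline stage.
result = generate_and_post_process(
    model,
    prompts=["Megatron-LM is"],
    tokens_to_generate=16,
    return_output_log_probs=True,
    greedy_sampling=False,
    top_k_sampling=0,
    top_p_sampling=0.9)
if result is not None:  # first pipeline stage only
    texts, segments, output_log_probs, tokens = result
```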