Commit c6e7c7fd authored by mshoeybi's avatar mshoeybi
Browse files

removed return all probs

parent 8d405805
......@@ -31,7 +31,6 @@ def generate_and_post_process(model,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
return_all_log_probs=False,
greedy_sampling=False,
top_k_sampling=0,
top_p_sampling=0.0,
......@@ -42,12 +41,11 @@ def generate_and_post_process(model,
move to cpu and convert to list."""
# Main inference.
tokens, lengths, output_log_probs, all_log_probs = generate(
tokens, lengths, output_log_probs = generate(
model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
return_output_log_probs=return_output_log_probs,
return_all_log_probs=return_all_log_probs,
greedy_sampling=greedy_sampling,
top_k_sampling=top_k_sampling,
top_p_sampling=top_p_sampling,
......@@ -63,11 +61,9 @@ def generate_and_post_process(model,
if return_output_log_probs:
output_log_probs = output_log_probs.cpu().numpy().tolist()
if return_all_log_probs:
all_log_probs = all_log_probs.cpu().numpy().tolist()
return prompts_plus_generations, prompts_plus_generations_segments, \
output_log_probs, all_log_probs, tokens
output_log_probs, tokens
return None
......@@ -77,7 +73,6 @@ def generate(model,
prompts=None,
tokens_to_generate=0,
return_output_log_probs=False,
return_all_log_probs=False,
greedy_sampling=False,
top_k_sampling=0,
top_p_sampling=0.0,
......@@ -90,24 +85,21 @@ def generate(model,
discard tokens in the tokens tensor that are after the
corresponding length.
output_log_probs: log probs of the tokens.
all_log_probs: full log probs for all of tokens.
"""
# Make sure input params are avaialble to all ranks.
values = [tokens_to_generate,
return_output_log_probs, return_all_log_probs,
values = [tokens_to_generate, return_output_log_probs,
greedy_sampling, top_k_sampling, top_p_sampling,
temperature, add_BOS, use_eod_token_for_early_termination]
values_float_tensor = broadcast_float_list(9, float_list=values)
values_float_tensor = broadcast_float_list(8, float_list=values)
tokens_to_generate = int(values_float_tensor[0].item())
return_output_log_probs = bool(values_float_tensor[1].item())
return_all_log_probs = bool(values_float_tensor[2].item())
greedy_sampling = bool(values_float_tensor[3].item())
top_k_sampling = int(values_float_tensor[4].item())
top_p_sampling = values_float_tensor[5].item()
temperature = values_float_tensor[6].item()
add_BOS = bool(values_float_tensor[7].item())
use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
greedy_sampling = bool(values_float_tensor[2].item())
top_k_sampling = int(values_float_tensor[3].item())
top_p_sampling = values_float_tensor[4].item()
temperature = values_float_tensor[5].item()
add_BOS = bool(values_float_tensor[6].item())
use_eod_token_for_early_termination = bool(values_float_tensor[7].item())
# Tokenize prompts and get the batch.
# Note that these tensors are broadcaseted to all ranks.
......@@ -122,7 +114,6 @@ def generate(model,
return generate_tokens_probs_and_return_on_first_stage(
model, context_tokens_tensor, context_length_tensor,
return_output_log_probs=return_output_log_probs,
return_all_log_probs=return_all_log_probs,
greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling,
temperature=temperature,
use_eod_token_for_early_termination=use_eod_token_for_early_termination)
......@@ -55,13 +55,31 @@ def send_to_next_pipeline_rank(tensor=None):
def _is_cuda(tensor):
"""Check if a tensor is not none and is cuda."""
assert tensor is not None
assert tensor.is_cuda
def _is_cuda_contiguous(tensor):
"""Check if a tensor is not none, is cuda, and is contiguous."""
_is_cuda(tensor)
assert tensor.is_contiguous()
def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
"""Broadcast a tensor from last pipeline stage to all ranks."""
if mpu.is_pipeline_last_stage():
assert tensor is not None
assert tensor.is_cuda
assert tensor.is_contiguous()
is_last_stage = mpu.is_pipeline_last_stage()
# If first stage and last state are the same, then there is no
# pipeline parallelism and no need to communicate.
if mpu.is_pipeline_first_stage() and is_last_stage:
return tensor
if is_last_stage:
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
......@@ -78,14 +96,16 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Broadcast tensor values from last stage into the first stage."""
# Only first and last stage pipeline stages need to be involved.
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
# If first stage and last state are the same, then there is no
# pipeline parallelism and no need to communicate.
if is_first_stage and is_last_stage:
return tensor
# Only first and last stage pipeline stages need to be involved.
if is_last_stage or is_first_stage:
if is_last_stage:
assert tensor is not None
assert tensor.is_cuda
assert tensor.is_contiguous()
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
......@@ -105,12 +125,15 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
"""Copy tensor values from last stage into the first stage.
Note that the input tensor is updated in place."""
# Only first and last stage pipeline stages need to be involved.
is_last_stage = mpu.is_pipeline_last_stage()
is_first_stage = mpu.is_pipeline_first_stage()
# If first stage and last state are the same, then there is no
# pipeline parallelism and no need to communicate.
if is_first_stage and is_last_stage:
return
# Only first and last stage pipeline stages need to be involved.
if is_last_stage or is_first_stage:
assert tensor is not None
assert tensor.is_cuda
_is_cuda(tensor)
is_contiguous = tensor.is_contiguous()
src = mpu.get_pipeline_model_parallel_last_rank()
group = mpu.get_embedding_group()
......@@ -137,8 +160,7 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0):
"""
if torch.distributed.get_rank() == rank:
assert tensor is not None
assert tensor.is_cuda
_is_cuda_contiguous(tensor)
else:
tensor = torch.empty(size,
dtype=dtype,
......
......@@ -31,7 +31,6 @@ from .sampling import sample
def generate_tokens_probs_and_return_on_first_stage(
model, tokens, lengths,
return_output_log_probs=False,
return_all_log_probs=False,
greedy=False, top_k=0, top_p=0.0,
temperature=1.0,
use_eod_token_for_early_termination=True):
......@@ -43,9 +42,6 @@ def generate_tokens_probs_and_return_on_first_stage(
return_output_log_probs: flag to calculate the log probability of
the generated tokens. Note that the log probability is the one
after logits are modifed for sampling.
return_all_log_probs: flag to calculate the log probability of across
all the tokens (vocab size). Note that the log probability is the
one after logits are modifed for sampling.
greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters.
Note that these three paramters are exclusive meaning that:
if greedy = true then we should have top-k=top-p=0.
......@@ -62,8 +58,6 @@ def generate_tokens_probs_and_return_on_first_stage(
generated_sequence_lengths: total length (including prompt) of
the generated sequence. size: [b]
output_log_probs: log probability of the selected tokens. size: [b, s]
all_log_probs: log probability of all the tokens.
size: [b, s, vocab-size]
"""
args = get_args()
......@@ -91,10 +85,6 @@ def generate_tokens_probs_and_return_on_first_stage(
# Log probability of the sequence (prompt + generated tokens).
output_log_probs = None
output_log_probs_size = (batch_size, max_sequence_length - 1)
# Log probability of all tokens for the sequence.
all_log_probs = None
all_log_probs_size = (batch_size, max_sequence_length -1,
args.padded_vocab_size)
# Lengths of generated seuquence including including prompts.
generated_sequence_lengths = None
if mpu.is_pipeline_last_stage():
......@@ -102,10 +92,6 @@ def generate_tokens_probs_and_return_on_first_stage(
output_log_probs = torch.empty(output_log_probs_size,
dtype=torch.float32,
device=torch.cuda.current_device())
if return_all_log_probs:
all_log_probs = torch.empty(all_log_probs_size,
dtype=torch.float32,
device=torch.cuda.current_device())
generated_sequence_lengths = torch.ones(
batch_size, dtype=torch.int64,
device=torch.cuda.current_device()) * max_sequence_length
......@@ -157,12 +143,8 @@ def generate_tokens_probs_and_return_on_first_stage(
tokens[started, context_length] = new_sample[started]
# Calculate the log probabilities.
if return_output_log_probs or return_all_log_probs:
if return_output_log_probs:
log_probs = F.log_softmax(logits, dim=2)
if return_all_log_probs:
all_log_probs[:,
prev_context_length:context_length,
:] = log_probs
if return_output_log_probs:
# Pick the tokens that we need to get the log
# probabilities for. Note that next input token is
......@@ -208,8 +190,6 @@ def generate_tokens_probs_and_return_on_first_stage(
if mpu.is_pipeline_last_stage():
if return_output_log_probs:
output_log_probs = output_log_probs[:, :context_length]
if return_all_log_probs:
all_log_probs = all_log_probs[:, :context_length, :]
# ======================================
# Broadcast to the first pipeline stage.
......@@ -221,14 +201,8 @@ def generate_tokens_probs_and_return_on_first_stage(
output_log_probs_size = (batch_size, context_length)
output_log_probs = broadcast_from_last_to_first_pipeline_stage(
output_log_probs_size, torch.float32, output_log_probs)
if return_all_log_probs:
all_log_probs_size = (batch_size, context_length,
args.padded_vocab_size)
all_log_probs = broadcast_from_last_to_first_pipeline_stage(
all_log_probs_size, torch.float32, all_log_probs)
return tokens, generated_sequence_lengths, output_log_probs, \
all_log_probs
return tokens, generated_sequence_lengths, output_log_probs
......
......@@ -101,13 +101,12 @@ class MegatronGenerate(Resource):
with lock: # Need to get lock to keep multiple threads from hitting code
MegatronGenerate.send_do_generate() # Tell other ranks we're doing generate
response, response_seg, response_logprobs, _, _ = \
response, response_seg, response_logprobs, _ = \
generate_and_post_process(
self.model,
prompts=prompts,
tokens_to_generate=tokens_to_generate,
return_output_log_probs=logprobs,
return_all_log_probs=False,
greedy_sampling=args.greedy,
top_k_sampling=top_k,
top_p_sampling=top_p,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment