Commit 803ae5ee authored by rprenger

cleaning up merge conflicts

parents 9cc286ba 0b0e37f0
@@ -427,7 +427,7 @@ Several downstream tasks are described for both GPT and BERT models below. They
 ## GPT Text Generation
-We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also few optional parameters: `temperature`, `top-k`, `top-p`, and `greedy`. See `--help` or the source file for more information. See [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server.
+We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also a few optional parameters: `temperature`, `top-k`, and `top-p`. See `--help` or the source file for more information. See [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server.
 Once the server is running, you can use `tools/text_generation_cli.py` to query it; it takes one argument, which is the host the server is running on.

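For context, a client request might look like the following sketch. The `/api` endpoint, port 5000, and the exact JSON field names are assumptions for illustration only; they are not spelled out in this hunk, so check `tools/run_text_generation_server.py` for the actual interface:

```python
# Hypothetical client request; endpoint, port, and field names are assumed.
import json
import requests

response = requests.put(
    "http://localhost:5000/api",
    data=json.dumps({
        "prompts": ["Megatron-LM is"],
        "tokens_to_generate": 32,
        "top_k": 1,          # with the greedy flag removed, top_k=1 gives greedy decoding
        "temperature": 1.0,
    }),
    headers={"Content-Type": "application/json"},
)
print(response.json())
```
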
@@ -33,7 +33,6 @@ def generate_and_post_process(model,
                               prompts=None,
                               tokens_to_generate=0,
                               return_output_log_probs=False,
-                              greedy_sampling=False,
                               top_k_sampling=0,
                               top_p_sampling=0.0,
                               temperature=1.0,

@@ -49,7 +48,6 @@ def generate_and_post_process(model,
         prompts=prompts,
         tokens_to_generate=tokens_to_generate,
         return_output_log_probs=return_output_log_probs,
-        greedy_sampling=greedy_sampling,
         top_k_sampling=top_k_sampling,
         top_p_sampling=top_p_sampling,
         temperature=temperature,

@@ -78,7 +76,6 @@ def generate(model,
             prompts=None,
             tokens_to_generate=0,
             return_output_log_probs=False,
-             greedy_sampling=False,
             top_k_sampling=0,
             top_p_sampling=0.0,
             temperature=1.0,

@@ -98,16 +95,15 @@ def generate(model,
               return_output_log_probs,
-              greedy_sampling, top_k_sampling, top_p_sampling,
+              top_k_sampling, top_p_sampling,
               temperature, add_BOS, use_eod_token_for_early_termination, just_score]
-    values_float_tensor = broadcast_float_list(9, float_list=values)
+    values_float_tensor = broadcast_float_list(8, float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     return_output_log_probs = bool(values_float_tensor[1].item())
-    greedy_sampling = bool(values_float_tensor[2].item())
-    top_k_sampling = int(values_float_tensor[3].item())
-    top_p_sampling = values_float_tensor[4].item()
-    temperature = values_float_tensor[5].item()
-    add_BOS = bool(values_float_tensor[6].item())
-    use_eod_token_for_early_termination = bool(values_float_tensor[7].item())
-    just_score = bool(values_float_tensor[8].item())
+    top_k_sampling = int(values_float_tensor[2].item())
+    top_p_sampling = values_float_tensor[3].item()
+    temperature = values_float_tensor[4].item()
+    add_BOS = bool(values_float_tensor[5].item())
+    use_eod_token_for_early_termination = bool(values_float_tensor[6].item())
+    just_score = bool(values_float_tensor[7].item())

     # Tokenize prompts and get the batch.
     # Note that these tensors are broadcasted to all ranks.

@@ -126,6 +122,7 @@ def generate(model,
     return generate_tokens_probs_and_return_on_first_stage(
         model, context_tokens_tensor, context_length_tensor,
         return_output_log_probs=return_output_log_probs,
-        greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling,
+        top_k=top_k_sampling,
+        top_p=top_p_sampling,
         temperature=temperature,
         use_eod_token_for_early_termination=use_eod_token_for_early_termination)

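After this change, callers request greedy decoding through `top_k_sampling=1` instead of the removed `greedy_sampling` flag. A minimal sketch of a call with the updated keywords follows; the `model` object and how it is built are assumptions, not shown in this diff:

```python
# Hypothetical call site; `model` is assumed to be an already-built,
# checkpoint-loaded Megatron GPT model on this rank.
result = generate_and_post_process(
    model,
    prompts=["The quick brown fox"],
    tokens_to_generate=16,
    return_output_log_probs=True,
    top_k_sampling=1,     # replaces the removed greedy_sampling=True
    top_p_sampling=0.0,   # must stay 0 when top-k is used
    temperature=1.0)
```
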
@@ -55,13 +55,31 @@ def send_to_next_pipeline_rank(tensor=None):
+def _is_cuda(tensor):
+    """Check if a tensor is not None and is cuda."""
+    assert tensor is not None
+    assert tensor.is_cuda
+
+
+def _is_cuda_contiguous(tensor):
+    """Check if a tensor is not None, is cuda, and is contiguous."""
+    _is_cuda(tensor)
+    assert tensor.is_contiguous()
+
+
 def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
     """Broadcast a tensor from last pipeline stage to all ranks."""
-    if mpu.is_pipeline_last_stage():
-        assert tensor is not None
-        assert tensor.is_cuda
-        assert tensor.is_contiguous()
+    is_last_stage = mpu.is_pipeline_last_stage()
+    # If the first stage and the last stage are the same, then there is no
+    # pipeline parallelism and no need to communicate.
+    if mpu.is_pipeline_first_stage() and is_last_stage:
+        return tensor
+    if is_last_stage:
+        _is_cuda_contiguous(tensor)
     else:
         tensor = torch.empty(size,
                              dtype=dtype,

@@ -78,14 +96,16 @@ def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
 def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
     """Broadcast tensor values from last stage into the first stage."""
-    # Only first and last stage pipeline stages need to be involved.
     is_last_stage = mpu.is_pipeline_last_stage()
     is_first_stage = mpu.is_pipeline_first_stage()
+    # If the first stage and the last stage are the same, then there is no
+    # pipeline parallelism and no need to communicate.
+    if is_first_stage and is_last_stage:
+        return tensor
+    # Only first and last stage pipeline stages need to be involved.
     if is_last_stage or is_first_stage:
         if is_last_stage:
-            assert tensor is not None
-            assert tensor.is_cuda
-            assert tensor.is_contiguous()
+            _is_cuda_contiguous(tensor)
         else:
             tensor = torch.empty(size,
                                  dtype=dtype,

@@ -105,12 +125,15 @@ def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
     """Copy tensor values from last stage into the first stage.
     Note that the input tensor is updated in place."""
-    # Only first and last stage pipeline stages need to be involved.
     is_last_stage = mpu.is_pipeline_last_stage()
     is_first_stage = mpu.is_pipeline_first_stage()
+    # If the first stage and the last stage are the same, then there is no
+    # pipeline parallelism and no need to communicate.
+    if is_first_stage and is_last_stage:
+        return
+    # Only first and last stage pipeline stages need to be involved.
     if is_last_stage or is_first_stage:
-        assert tensor is not None
-        assert tensor.is_cuda
+        _is_cuda(tensor)
         is_contiguous = tensor.is_contiguous()
         src = mpu.get_pipeline_model_parallel_last_rank()
         group = mpu.get_embedding_group()

@@ -137,8 +160,7 @@ def broadcast_tensor(size, dtype, tensor=None, rank=0):
     """
     if torch.distributed.get_rank() == rank:
-        assert tensor is not None
-        assert tensor.is_cuda
+        _is_cuda_contiguous(tensor)
     else:
         tensor = torch.empty(size,
                              dtype=dtype,

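The new early return means that with a pipeline-parallel size of 1, where the first and last stages are the same rank, these helpers skip communication entirely. A rough usage sketch, assuming `mpu` and the broadcast helper are imported from their usual Megatron modules (import paths are assumptions, not shown in this diff):

```python
# Hypothetical usage of broadcast_from_last_pipeline_stage(); imports assumed.
import torch
from megatron import mpu  # assumed import location
# broadcast_from_last_pipeline_stage is assumed to be imported from the
# communication module edited above.

size = (4, 1024)
if mpu.is_pipeline_last_stage():
    # The last stage owns the data; it must be a contiguous CUDA tensor.
    tokens = torch.zeros(size, dtype=torch.int64, device='cuda')
else:
    tokens = None  # other ranks receive the broadcast
tokens = broadcast_from_last_pipeline_stage(size, torch.int64, tensor=tokens)
```
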
@@ -94,7 +94,7 @@ def score_and_return_on_first_stage(model, tokens, lengths):
 def generate_tokens_probs_and_return_on_first_stage(
         model, tokens, lengths,
         return_output_log_probs=False,
-        greedy=False, top_k=0, top_p=0.0,
+        top_k=0, top_p=0.0,
         temperature=1.0,
         use_eod_token_for_early_termination=True):
     """Main token generation function.

@@ -104,12 +104,12 @@ def generate_tokens_probs_and_return_on_first_stage(
         lengths: original prompt length, size: [b]
         return_output_log_probs: flag to calculate the log probability of
             the generated tokens. Note that the log probability is the one
-            after logits are modifed for sampling.
-        greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters.
-            Note that these three paramters are exclusive meaning that:
-            if greedy = true then we should have top-k=top-p=0.
-            if top-k > 0 then we expect greedy=false and top-p=0.
-            if top-p > 0 then we check for greedy=false and top-k=0.
+            from the original logit.
+        top_k, top_p: top-k and top-p sampling parameters.
+            Note that top-k = 1 is greedy. Also, these parameters are
+            exclusive, meaning that:
+            if top-k > 0 then we expect top-p=0.
+            if top-p > 0 then we check for top-k=0.
         temperature: sampling temperature.
         use_eod_token_for_early_termination: if True, do early termination if
             all the sequences have reached this token.

@@ -148,8 +148,6 @@ def generate_tokens_probs_and_return_on_first_stage(
     # Log probability of the sequence (prompt + generated tokens).
     output_log_probs = None
     output_log_probs_size = (batch_size, max_sequence_length - 1)
-    # Log probability of all tokens for the sequence.
     # Lengths of generated sequence, including prompts.
     generated_sequence_lengths = None
     if mpu.is_pipeline_last_stage():

@@ -190,22 +188,15 @@ def generate_tokens_probs_and_return_on_first_stage(
             # Sample.
             last_token_logits = logits[:, -1, :]
-            new_sample, updated_last_token_logits = sample(
-                last_token_logits,
-                greedy=greedy,
-                top_k=top_k,
-                top_p=top_p,
-                temperature=temperature,
-                vocab_size=tokenizer.vocab_size)
-            # Now that we have the sample and updated logits,
-            # update the main logits and input tokens.
+            new_sample = sample(last_token_logits,
+                                top_k=top_k,
+                                top_p=top_p,
+                                temperature=temperature,
+                                vocab_size=tokenizer.vocab_size)

             # If a prompt length is smaller than or equal to the current context
             # length, it means we have started generating tokens.
             started = lengths <= context_length
-            # Update the logits
-            last_token_logits.masked_scatter_(
-                started.unsqueeze(1), updated_last_token_logits[started])
-            # and the tokens.
+            # Update the tokens.
             tokens[started, context_length] = new_sample[started]

             # Calculate the log probabilities.

@@ -255,7 +246,7 @@ def generate_tokens_probs_and_return_on_first_stage(
     tokens = tokens[:, :(context_length + 1)]
     if mpu.is_pipeline_last_stage():
         if return_output_log_probs:
-            output_log_probs = output_log_probs[:, :context_length].contiguous()
+            output_log_probs = output_log_probs[:, :context_length]

     # ======================================
     # Broadcast to the first pipeline stage.

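The docstring change above ("the log probability is the one from the original logit") means the returned log probabilities come from a softmax over the unmodified logits rather than the top-k/top-p filtered ones. A small sketch of that interpretation; shapes follow the docstring, and this is an illustration, not the exact code in the file:

```python
# Illustration only: log-probs taken from the original (unfiltered) logits.
import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 100)          # [b, s, v], as in the docstring
log_probs = F.log_softmax(logits, dim=-1)
tokens = torch.randint(0, 100, (2, 5))   # token ids whose log-probs we want
token_log_probs = torch.gather(
    log_probs, 2, tokens.unsqueeze(-1)).squeeze(-1)  # [b, s]
```
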
@@ -55,8 +55,7 @@ def modify_logits_for_top_p_filtering(logits, top_p):
-def sample(logits, greedy=False, top_k=0, top_p=0.0, temperature=1.0,
-           vocab_size=None):
+def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
     """ Sample and generate a token.
     Note: logits has the dimension [b, v] where b is the batch size
           and v is the vocabulary size.

@@ -70,21 +69,21 @@ def sample(logits, greedy=False, top_k=0, top_p=0.0, temperature=1.0,
     assert logits.type() == 'torch.cuda.FloatTensor', \
         'input logits should be floats.'
-    # Clone so we do not modify the inputs,
-    logits = logits.clone()

     # Greedy is just simple argmax.
-    if greedy:
-        assert top_k == 0, 'cannot set both greedy and top-k samplings.'
+    if top_k == 1:
         assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
         samples = torch.argmax(logits, dim=-1)

     # Top-k or top-p sampling.
     else:
+        # Clone so we do not modify the inputs.
+        logits = logits.clone()
         # Apply temperature in place.
-        logits.div_(temperature)
-        if top_k > 0:
+        if temperature != 1.0:
+            logits.div_(temperature)
+        if top_k > 1:
             assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
             assert top_k <= logits.size(1), 'top-k is larger than logit size.'
             if vocab_size:

@@ -104,4 +103,4 @@ def sample(logits, greedy=False, top_k=0, top_p=0.0, temperature=1.0,
     if vocab_size:
         samples = torch.clamp(samples, min=0, max=(vocab_size - 1))

-    return samples, logits
+    return samples

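With the `greedy` argument gone, greedy decoding is requested through `top_k=1`. A short usage sketch of the new `sample()` signature; the module path is an assumption, since it is not shown in this diff:

```python
# Hypothetical usage of sample(); logits must be a [batch, vocab] CUDA float tensor.
import torch
# from megatron.text_generation.sampling import sample  # assumed module path

logits = torch.randn(4, 50257, device='cuda')

greedy_ids = sample(logits, top_k=1)                    # greedy: argmax over the vocab
top_k_ids = sample(logits, top_k=10, temperature=0.8)   # top-k sampling (top_p must stay 0)
top_p_ids = sample(logits, top_p=0.9)                   # nucleus sampling (top_k must stay 0)
```
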
@@ -112,7 +112,6 @@ class MegatronGenerate(Resource):
                 prompts=prompts,
                 tokens_to_generate=tokens_to_generate,
                 return_output_log_probs=logprobs,
-                greedy_sampling=args.greedy,
                 top_k_sampling=top_k,
                 top_p_sampling=top_p,
                 temperature=temperature,

@@ -43,8 +43,6 @@ def add_text_generate_args(parser):
     group.add_argument("--temperature", type=float, default=1.0,
                        help='Sampling temperature.')
-    group.add_argument("--greedy", action='store_true', default=False,
-                       help='Use greedy sampling.')
     group.add_argument("--top_p", type=float, default=0.0,
                        help='Top p sampling.')
     group.add_argument("--top_k", type=int, default=0,