Commit 71359e1f authored by mshoeybi

removed greedy argument

parent c6e7c7fd
@@ -427,7 +427,7 @@ Several downstream tasks are described for both GPT and BERT models below. They
 ## GPT Text Generation
-We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also a few optional parameters: `temperature`, `top-k`, `top-p`, and `greedy`. See `--help` or the source file for more information. See [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server.
+We have included a simple REST server to use for text generation in `tools/run_text_generation_server.py`. You run it much like you would start a pretraining job, specifying an appropriate pretrained checkpoint. There are also a few optional parameters: `temperature`, `top-k`, and `top-p`. See `--help` or the source file for more information. See [examples/run_text_generation_server_345M.sh](examples/run_text_generation_server_345M.sh) for an example of how to run the server.
 Once the server is running you can use `tools/text_generation_cli.py` to query it; it takes one argument, the host the server is running on.
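As an illustration of the workflow described above (not part of the commit), here is a minimal client sketch. It assumes the server is reachable on localhost at Flask's default port 5000 and that the JSON payload carries `prompts` and `tokens_to_generate` fields, as suggested by the `MegatronGenerate` resource further down in this diff; check `tools/run_text_generation_server.py` and `tools/text_generation_cli.py` for the actual route, port, and field names.

# Hypothetical client sketch; endpoint, port, and JSON field names are
# assumptions, not taken verbatim from this commit.
import requests

url = "http://localhost:5000/api"           # assumed host:port and route
payload = {
    "prompts": ["The quick brown fox"],     # assumed field names
    "tokens_to_generate": 32,
}
# The server is assumed to accept a JSON PUT and return the generated text.
response = requests.put(url, json=payload)
print(response.json())

Equivalently, `tools/text_generation_cli.py` can be pointed at the same host, as noted above.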
@@ -31,7 +31,6 @@ def generate_and_post_process(model,
                               prompts=None,
                               tokens_to_generate=0,
                               return_output_log_probs=False,
-                              greedy_sampling=False,
                               top_k_sampling=0,
                               top_p_sampling=0.0,
                               temperature=1.0,
@@ -46,7 +45,6 @@ def generate_and_post_process(model,
         prompts=prompts,
         tokens_to_generate=tokens_to_generate,
         return_output_log_probs=return_output_log_probs,
-        greedy_sampling=greedy_sampling,
         top_k_sampling=top_k_sampling,
         top_p_sampling=top_p_sampling,
         temperature=temperature,
@@ -73,7 +71,6 @@ def generate(model,
             prompts=None,
             tokens_to_generate=0,
             return_output_log_probs=False,
-            greedy_sampling=False,
             top_k_sampling=0,
             top_p_sampling=0.0,
             temperature=1.0,
@@ -89,17 +86,16 @@ def generate(model,
     # Make sure input params are avaialble to all ranks.
     values = [tokens_to_generate, return_output_log_probs,
-              greedy_sampling, top_k_sampling, top_p_sampling,
+              top_k_sampling, top_p_sampling,
               temperature, add_BOS, use_eod_token_for_early_termination]
-    values_float_tensor = broadcast_float_list(8, float_list=values)
+    values_float_tensor = broadcast_float_list(7, float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     return_output_log_probs = bool(values_float_tensor[1].item())
-    greedy_sampling = bool(values_float_tensor[2].item())
-    top_k_sampling = int(values_float_tensor[3].item())
-    top_p_sampling = values_float_tensor[4].item()
-    temperature = values_float_tensor[5].item()
-    add_BOS = bool(values_float_tensor[6].item())
-    use_eod_token_for_early_termination = bool(values_float_tensor[7].item())
+    top_k_sampling = int(values_float_tensor[2].item())
+    top_p_sampling = values_float_tensor[3].item()
+    temperature = values_float_tensor[4].item()
+    add_BOS = bool(values_float_tensor[5].item())
+    use_eod_token_for_early_termination = bool(values_float_tensor[6].item())

     # Tokenize prompts and get the batch.
     # Note that these tensors are broadcaseted to all ranks.
@@ -114,6 +110,7 @@ def generate(model,
     return generate_tokens_probs_and_return_on_first_stage(
         model, context_tokens_tensor, context_length_tensor,
         return_output_log_probs=return_output_log_probs,
-        greedy=greedy_sampling, top_k=top_k_sampling, top_p=top_p_sampling,
+        top_k=top_k_sampling,
+        top_p=top_p_sampling,
         temperature=temperature,
         use_eod_token_for_early_termination=use_eod_token_for_early_termination)
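The hunk above shrinks the broadcast list from 8 to 7 values, so every index after `return_output_log_probs` shifts down by one. Below is a minimal standalone sketch of that pack/unpack round trip; it only mimics the tensor packing locally (the real code broadcasts the tensor across ranks via `broadcast_float_list`, which needs an initialized process group).

# Sketch of how generate() now packs its 7 scalar knobs into one float
# tensor for the cross-rank broadcast. Standalone illustration only.
import torch

def pack_and_unpack(tokens_to_generate=64, return_output_log_probs=False,
                    top_k_sampling=1, top_p_sampling=0.0, temperature=1.0,
                    add_BOS=False, use_eod_token_for_early_termination=True):
    values = [tokens_to_generate, return_output_log_probs,
              top_k_sampling, top_p_sampling,
              temperature, add_BOS, use_eod_token_for_early_termination]
    # broadcast_float_list(7, float_list=values) would broadcast this tensor
    # from rank 0; here we only reproduce the pack/unpack round trip.
    values_float_tensor = torch.tensor(values, dtype=torch.float32)
    return dict(
        tokens_to_generate=int(values_float_tensor[0].item()),
        return_output_log_probs=bool(values_float_tensor[1].item()),
        top_k_sampling=int(values_float_tensor[2].item()),
        top_p_sampling=values_float_tensor[3].item(),
        temperature=values_float_tensor[4].item(),
        add_BOS=bool(values_float_tensor[5].item()),
        use_eod_token_for_early_termination=bool(values_float_tensor[6].item()),
    )

print(pack_and_unpack())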
@@ -31,7 +31,7 @@ from .sampling import sample
 def generate_tokens_probs_and_return_on_first_stage(
         model, tokens, lengths,
         return_output_log_probs=False,
-        greedy=False, top_k=0, top_p=0.0,
+        top_k=0, top_p=0.0,
         temperature=1.0,
         use_eod_token_for_early_termination=True):
     """Main token generation function.
@@ -41,12 +41,12 @@ def generate_tokens_probs_and_return_on_first_stage(
         lengths: original prompt length, size: [b]
         return_output_log_probs: flag to calculate the log probability of
             the generated tokens. Note that the log probability is the one
-            after logits are modifed for sampling.
-        greedy, top_k, top_p: greedy, top-k, and top-p sampling parameters.
-            Note that these three paramters are exclusive meaning that:
-                if greedy = true then we should have top-k=top-p=0.
-                if top-k > 0 then we expect greedy=false and top-p=0.
-                if top-p > 0 then we check for greedy=false and top-k=0.
+            from the original logit.
+        top_k, top_p: top-k and top-p sampling parameters.
+            Note that top-k = 1 is gready. Also, these paramters are
+            exclusive meaning that:
+                if top-k > 0 then we expect top-p=0.
+                if top-p > 0 then we check for top-k=0.
         temperature: sampling temperature.
         use_eod_token_for_early_termination: if True, do early termination if
             all the sequences have reached this token.
@@ -124,22 +124,15 @@ def generate_tokens_probs_and_return_on_first_stage(
             # Sample.
             last_token_logits = logits[:, -1, :]
-            new_sample, updated_last_token_logits = sample(
-                last_token_logits,
-                greedy=greedy,
-                top_k=top_k,
-                top_p=top_p,
-                temperature=temperature,
-                vocab_size=tokenizer.vocab_size)
-            # Now that we have the sample and updated logits,
-            # update the main logits and input tokens.
+            new_sample = sample(last_token_logits,
+                                top_k=top_k,
+                                top_p=top_p,
+                                temperature=temperature,
+                                vocab_size=tokenizer.vocab_size)

             # If a prompt length is smaller or equal th current context
             # length, it means we have started generating tokens
             started = lengths <= context_length
-            # Update the logits
-            last_token_logits.masked_scatter_(
-                started.unsqueeze(1), updated_last_token_logits[started])
-            # and the tokens.
+            # Update the tokens.
             tokens[started, context_length] = new_sample[started]

             # Calculate the log probabilities.
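With `sample()` no longer returning modified logits, the docstring change above states that the reported log probabilities now come from the original logits rather than from the top-k/top-p-filtered ones. A toy sketch of that idea follows; it is an illustration only, and the shapes and indexing are simplified compared with the commit's per-sequence code.

# Toy illustration: log-probabilities of sampled tokens taken from the
# unmodified distribution, before any top-k/top-p masking is applied.
import torch

logits = torch.randn(4, 16)                             # [batch, vocab] toy logits
log_probs = torch.log_softmax(logits, dim=-1)           # original (unfiltered) distribution
new_sample = torch.argmax(logits, dim=-1)               # stand-in for sample(...)
sampled_log_probs = log_probs.gather(
    1, new_sample.unsqueeze(1)).squeeze(1)              # log-prob of each chosen token
print(sampled_log_probs)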
@@ -55,8 +55,7 @@ def modify_logits_for_top_p_filtering(logits, top_p):
-def sample(logits, greedy=False, top_k=0, top_p=0.0, temperature=1.0,
-           vocab_size=None):
+def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
     """ Sample and generate a token.
     Note: logits has the dimension [b, v] where b is the batch size
     and v is the vocabulary size.
@@ -70,21 +69,21 @@ def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
     assert logits.type() == 'torch.cuda.FloatTensor', \
         'input logits should be floats.'

-    # Clone so we do not modify the inputs,
-    logits = logits.clone()

     # Greedy is just simple argmax.
-    if greedy:
-        assert top_k == 0, 'cannot set both greedy and top-k samplings.'
+    if top_k == 1:
         assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
         samples = torch.argmax(logits, dim=-1)

     # Top-k or top-p sampling.
     else:
+        # Clone so we do not modify the inputs,
+        logits = logits.clone()
         # Apply temperature in place.
-        logits.div_(temperature)
+        if temperature != 1.0:
+            logits.div_(temperature)

-        if top_k > 0:
+        if top_k > 1:
             assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
             assert top_k <= logits.size(1), 'top-k is larger than logit size.'
             if vocab_size:
@@ -104,4 +103,4 @@ def sample(logits, top_k=0, top_p=0.0, temperature=1.0, vocab_size=None):
     if vocab_size:
         samples = torch.clamp(samples, min=0, max=(vocab_size - 1))

-    return samples, logits
+    return samples
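To make the new control flow easier to follow, here is a condensed, self-contained re-sketch of `sample()` after this commit: `top_k == 1` is treated as greedy argmax, `top_k > 1` and `top_p > 0` stay mutually exclusive, and only the sampled token tensor is returned. The top-k/top-p filtering and the final draw are paraphrased, so treat this as an approximation rather than a copy of the Megatron implementation.

# Condensed approximation of the new sample() (CPU-friendly; the original
# asserts CUDA float logits). Filtering details are simplified.
import torch

def sample_sketch(logits, top_k=0, top_p=0.0, temperature=1.0):
    if top_k == 1:
        # Greedy is just a simple argmax; top-p must be off.
        assert top_p == 0.0, 'cannot set both greedy and top-p samplings.'
        return torch.argmax(logits, dim=-1)

    # Clone so we do not modify the caller's logits.
    logits = logits.clone()
    if temperature != 1.0:
        logits.div_(temperature)

    if top_k > 1:
        assert top_p == 0.0, 'cannot set both top-k and top-p samplings.'
        # Keep only the top-k logits per row (simplified filtering).
        kth = torch.topk(logits, top_k, dim=-1).values[:, -1, None]
        logits[logits < kth] = float('-inf')
    elif top_p > 0.0:
        # Nucleus filtering sketch: drop tokens outside the top-p mass.
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = torch.softmax(sorted_logits, dim=-1)
        cum_before = sorted_probs.cumsum(dim=-1) - sorted_probs
        sorted_logits[cum_before > top_p] = float('-inf')
        logits = torch.full_like(logits, float('-inf')).scatter_(
            1, sorted_idx, sorted_logits)

    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).squeeze(1)

tokens = sample_sketch(torch.randn(2, 10), top_k=1)   # greedy now means top_k=1
print(tokens)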
...@@ -107,7 +107,6 @@ class MegatronGenerate(Resource): ...@@ -107,7 +107,6 @@ class MegatronGenerate(Resource):
prompts=prompts, prompts=prompts,
tokens_to_generate=tokens_to_generate, tokens_to_generate=tokens_to_generate,
return_output_log_probs=logprobs, return_output_log_probs=logprobs,
greedy_sampling=args.greedy,
top_k_sampling=top_k, top_k_sampling=top_k,
top_p_sampling=top_p, top_p_sampling=top_p,
temperature=temperature, temperature=temperature,
......
@@ -43,8 +43,6 @@ def add_text_generate_args(parser):
     group.add_argument("--temperature", type=float, default=1.0,
                        help='Sampling temperature.')
-    group.add_argument("--greedy", action='store_true', default=False,
-                       help='Use greedy sampling.')
     group.add_argument("--top_p", type=float, default=0.0,
                        help='Top p sampling.')
     group.add_argument("--top_k", type=int, default=0,
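Finally, with `--greedy` removed from the server's argument group, greedy decoding is now requested by passing `--top_k 1`. A minimal argparse sketch mirroring the remaining sampling flags is shown below; the group title and the `--top_k` help string are guesses, since the hunk above is truncated, and the real script registers many more Megatron arguments.

# Sketch of the trimmed text-generation flags after this commit.
# Only the three sampling knobs are reproduced; titles and the --top_k
# help text are assumptions.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='text generation')
group.add_argument("--temperature", type=float, default=1.0,
                   help='Sampling temperature.')
group.add_argument("--top_p", type=float, default=0.0,
                   help='Top p sampling.')
group.add_argument("--top_k", type=int, default=0,
                   help='Top k sampling (assumed help text).')

args = parser.parse_args(["--top_k", "1"])   # greedy decoding without --greedy
print(args.top_k, args.top_p, args.temperature)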