"tests/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "ab5fc48f4d601a87b28e8fe91e4bf553862c255d"
Commit f4a8b1d9 authored by Jared Casper

Merge branch 'disable_newline_after_colon' into 'main'

Disable newline after colon

See merge request ADLR/megatron-lm!469
parents abf60f75 544e2502
@@ -28,6 +28,7 @@ def generate_and_post_process(model,
                               use_eod_token_for_early_termination=True,
                               stop_on_double_eol=False,
                               stop_on_eol=False,
+                              prevent_newline_after_colon=False,
                               random_seed=-1):
     """Run inference and post-process outputs, i.e., detokenize,
     move to cpu and convert to list."""
@@ -47,6 +48,7 @@ def generate_and_post_process(model,
         use_eod_token_for_early_termination=use_eod_token_for_early_termination,
         stop_on_double_eol=stop_on_double_eol,
         stop_on_eol=stop_on_eol,
+        prevent_newline_after_colon=prevent_newline_after_colon,
         random_seed=random_seed)
     # Only post-process on first stage.
@@ -77,6 +79,7 @@ def generate(model,
             use_eod_token_for_early_termination=True,
             stop_on_double_eol=False,
             stop_on_eol=False,
+            prevent_newline_after_colon=False,
             random_seed=-1):
     """Given prompts and input parameters, run inference and return:
        tokens: prompts plus the generated tokens.
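For callers that drive inference through the Python API rather than the REST server, the new option is just another keyword argument on generate_and_post_process (and, further down, on beam_search_and_post_process). A minimal usage sketch, assuming a GPT model has already been built and loaded the way Megatron's generation entry points expect; the prompt and token budget are illustrative:

    # Hedged usage sketch: `model` is assumed to be an already-initialized
    # Megatron GPT model; distributed/pipeline state is set up as usual.
    from megatron.text_generation import generate_and_post_process

    response = generate_and_post_process(
        model,
        prompts=["Q: What is the capital of France?\nA:"],
        tokens_to_generate=32,
        prevent_newline_after_colon=True)  # suppress "\n" right after ":"
    # Only the first pipeline stage returns the detokenized outputs;
    # other stages get None, matching the post-processing note above.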
@@ -93,8 +96,9 @@ def generate(model,
               temperature, add_BOS, use_eod_token_for_early_termination,
               stop_on_double_eol,
               stop_on_eol,
+              prevent_newline_after_colon,
               random_seed]
-    values_float_tensor = broadcast_float_list(12, float_list=values)
+    values_float_tensor = broadcast_float_list(len(values), float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     return_output_log_probs = bool(values_float_tensor[1].item())
     top_k_sampling = int(values_float_tensor[2].item())
@@ -106,7 +110,8 @@ def generate(model,
     use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
     stop_on_double_eol = bool(values_float_tensor[9].item())
     stop_on_eol = bool(values_float_tensor[10].item())
-    random_seed = int(values_float_tensor[11].item())
+    prevent_newline_after_colon = bool(values_float_tensor[11].item())
+    random_seed = int(values_float_tensor[12].item())
     if random_seed != -1:
         torch.random.manual_seed(random_seed)
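generate packs every scalar option, including the new flag, into one float list and broadcasts it so that every rank unpacks the same values by position; replacing the hard-coded 12 with len(values) keeps the broadcast size in sync as options are added. The underlying pattern, sketched standalone with a hypothetical helper (not Megatron's broadcast_float_list; assumes torch.distributed is initialized with a CUDA-capable backend such as NCCL):

    import torch
    import torch.distributed as dist

    def broadcast_scalar_options(values, src=0):
        # Hypothetical helper: pack bools/ints/floats into one float tensor,
        # broadcast it from rank `src`, and let every rank unpack by position.
        tensor = torch.tensor(values, dtype=torch.float32, device='cuda')
        dist.broadcast(tensor, src)
        return tensor

    # All ranks call with lists of the same length; only rank 0's values matter.
    opts = broadcast_scalar_options([64, 0.9, True, -1])
    tokens_to_generate = int(opts[0].item())
    top_p = opts[1].item()
    prevent_newline_after_colon = bool(opts[2].item())
    random_seed = int(opts[3].item())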
@@ -135,7 +140,8 @@ def generate(model,
         temperature=temperature,
         use_eod_token_for_early_termination=use_eod_token_for_early_termination,
         stop_on_double_eol=stop_on_double_eol,
-        stop_on_eol=stop_on_eol)
+        stop_on_eol=stop_on_eol,
+        prevent_newline_after_colon=prevent_newline_after_colon)

 def beam_search_and_post_process(model,
                                  prompts=None,
@@ -144,7 +150,8 @@ def beam_search_and_post_process(model,
                                  add_BOS=False,
                                  stop_token=50256,
                                  num_return_gen=1,
-                                 length_penalty=1):
+                                 length_penalty=1,
+                                 prevent_newline_after_colon=False):
     """Run beam search and post-process outputs, i.e., detokenize,
     move to cpu and convert to list."""
@@ -156,7 +163,8 @@ def beam_search_and_post_process(model,
         add_BOS=add_BOS,
         stop_token=stop_token,
         num_return_gen=num_return_gen,
-        length_penalty=length_penalty)
+        length_penalty=length_penalty,
+        prevent_newline_after_colon=prevent_newline_after_colon)
     # Only post-process on first stage.
     if mpu.is_pipeline_first_stage():
         lengths = tokens.size(1)*torch.ones(beam_size, dtype=torch.int64, device=torch.cuda.current_device())
@@ -166,24 +174,27 @@ def beam_search_and_post_process(model,
     return None

-def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256, num_return_gen=1, length_penalty=1):
+def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256, num_return_gen=1, length_penalty=1, prevent_newline_after_colon=False):
     # Make sure input params are avaialble to all ranks.
     values = [tokens_to_generate,
               beam_size,
               add_BOS,
               stop_token,
               num_return_gen,
-              length_penalty]
-    values_float_tensor = broadcast_float_list(6, float_list=values)
+              length_penalty,
+              prevent_newline_after_colon]
+    values_float_tensor = broadcast_float_list(len(values), float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     beam_size = int(values_float_tensor[1].item())
     add_BOS = bool(values_float_tensor[2].item())
     stop_token = int(values_float_tensor[3].item())
     num_return_gen = int(values_float_tensor[4].item())
     length_penalty = values_float_tensor[5].item()
+    prevent_newline_after_colon = values_float_tensor[6].item()

     context_tokens_tensor, context_length_tensor = tokenize_prompts(
         prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)

     return beam_search_and_return_on_first_stage(model, context_tokens_tensor, context_length_tensor,
-                      beam_size, stop_token=stop_token, num_return_gen=num_return_gen, length_penalty=length_penalty)
+                      beam_size, stop_token=stop_token, num_return_gen=num_return_gen, length_penalty=length_penalty,
+                      prevent_newline_after_colon=prevent_newline_after_colon)
......
@@ -93,7 +93,8 @@ def generate_tokens_probs_and_return_on_first_stage(
         temperature=1.0,
         use_eod_token_for_early_termination=True,
         stop_on_double_eol=False,
-        stop_on_eol=False
+        stop_on_eol=False,
+        prevent_newline_after_colon=True
         ):
     """Main token generation function.
     Arguments:
@@ -111,6 +112,7 @@ def generate_tokens_probs_and_return_on_first_stage(
         temperature: sampling temperature.
         use_eod_token_for_early_termination: if True, do early termination if
             all the sequences have reached this token.
+        prevent_newline_after_colon: if True, it will disable generating new line \n after :
     Note: Outside of model, other parameters only need to be available on
         rank 0.
     Outputs: Note that is size is adjusted to a lower value than
@@ -186,6 +188,8 @@ def generate_tokens_probs_and_return_on_first_stage(
             logits = forward_step(tokens2use, positions2use, attention_mask2use)

             if mpu.is_pipeline_last_stage():
+                if prevent_newline_after_colon:
+                    logits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10  # disable "\n" after ":"
                 # Always the last stage should have an output.
                 assert logits is not None
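The mechanism is a straight logit mask applied on the last pipeline stage: for every sequence whose most recent token is ':', the logit of the newline token at the current position is set to -1e10, so its probability after softmax is effectively zero and it cannot be sampled. A standalone illustration of the same idea; the token ids below are only illustrative stand-ins for tokenizer.tokenize(':')[0] and tokenizer.tokenize('\n')[0]:

    import torch

    COLON_ID, NEWLINE_ID = 25, 198                 # illustrative ids (GPT-2-style BPE)
    logits = torch.randn(4, 7, 50257)              # [batch, seq, vocab] from the last stage
    last_tokens = torch.tensor([25, 11, 25, 13])   # most recent token per sequence

    # Sequences that just emitted ':' have the newline logit pushed to -1e10 at
    # the final position, which drives its softmax probability to ~0.
    logits[last_tokens == COLON_ID, -1, NEWLINE_ID] = -1e10

    probs = torch.softmax(logits[:, -1, :], dim=-1)
    assert probs[last_tokens == COLON_ID, NEWLINE_ID].max() < 1e-9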
@@ -281,7 +285,7 @@ def generate_tokens_probs_and_return_on_first_stage(
     return tokens, generated_sequence_lengths, output_log_probs

-def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen, length_penalty):
+def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen, length_penalty, prevent_newline_after_colon=True):
     args = get_args()
     tokenizer = get_tokenizer()
@@ -324,6 +328,8 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
             logits = forward_step(tokens2use, positions2use, attention_mask2use)
             if mpu.is_pipeline_last_stage():
+                if prevent_newline_after_colon:
+                    logits[tokens2use[:, -1] == tokenizer.tokenize(':')[0], -1, tokenizer.tokenize('\n')[0]] = -1e10  # disable "\n" after ":"
                 vocab_size = logits.size(2)
                 log_probs = F.log_softmax(logits, dim=2)
                 new_scores = log_probs[:, -1, :] + scores
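The beam-search path applies the identical mask just before log_softmax, so the suppressed token's log-probability is roughly -1e10 and it can never survive the top-k selection when beams are extended. A small sketch of that interaction with the scoring step shown above; shapes and the id are illustrative, and the mask is applied unconditionally here for brevity, whereas the real code masks only beams whose last token is ':':

    import torch
    import torch.nn.functional as F

    NEWLINE_ID = 198                          # illustrative id
    logits = torch.randn(2, 1, 1000)          # [beam, seq, vocab]
    logits[:, -1, NEWLINE_ID] = -1e10         # mask, as in the hunk above

    log_probs = F.log_softmax(logits, dim=2)
    new_scores = log_probs[:, -1, :] + torch.zeros(2, 1)   # add running beam scores

    # The masked token never shows up among the candidate continuations.
    top_vals, top_idx = new_scores.view(-1).topk(4)
    assert NEWLINE_ID not in (top_idx % 1000).tolist()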
......
@@ -128,6 +128,12 @@ class MegatronGenerate(Resource):
             if not isinstance(stop_on_eol, bool):
                 return "stop_on_eol must be a boolean value"

+        prevent_newline_after_colon = False
+        if "prevent_newline_after_colon" in request.get_json():
+            prevent_newline_after_colon = request.get_json()["prevent_newline_after_colon"]
+            if not isinstance(prevent_newline_after_colon, bool):
+                return "prevent_newline_after_colon must be a boolean value"
+
         random_seed = -1
         if "random_seed" in request.get_json():
             random_seed = request.get_json()["random_seed"]
@@ -183,7 +189,8 @@ class MegatronGenerate(Resource):
                     add_BOS=add_BOS,
                     stop_token=stop_token,
                     num_return_gen=beam_width, # Returning whole beam
-                    length_penalty=length_penalty
+                    length_penalty=length_penalty,
+                    prevent_newline_after_colon=prevent_newline_after_colon
                     )
                 return jsonify({"text": response,
@@ -206,6 +213,7 @@ class MegatronGenerate(Resource):
                     use_eod_token_for_early_termination=True,
                     stop_on_double_eol=stop_on_double_eol,
                     stop_on_eol=stop_on_eol,
+                    prevent_newline_after_colon=prevent_newline_after_colon,
                     random_seed=random_seed)
                 return jsonify({"text": response,
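On the server side the flag is read from the JSON body, type-checked like the other options, and forwarded to both the sampling and the beam-search entry points. A hedged client-side example follows; it assumes the text-generation server (tools/run_text_generation_server.py) is running locally on port 5000 and serves its usual PUT endpoint at /api, so adjust host, port, and path to your deployment:

    import json
    import requests  # third-party HTTP client, used only for illustration

    payload = {
        "prompts": ["Q: Name three primary colors.\nA:"],
        "tokens_to_generate": 32,
        "prevent_newline_after_colon": True,   # must be a JSON boolean, per the check above
    }
    resp = requests.put("http://localhost:5000/api",
                        data=json.dumps(payload),
                        headers={"Content-Type": "application/json"})
    print(resp.json()["text"])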
......