Commit 2bc3c1a4 authored by Jared Casper

Merge branch 'factual' into 'main'

Implements the top_p decay and top_p bound parameters from the Factual Sampling work

See merge request ADLR/megatron-lm!423
parents badcb5a7 2eea6216
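
For context: the two new knobs implement the decayed nucleus ("factual") sampling schedule referenced in the commit message, where the top-p threshold shrinks geometrically as generation proceeds and is clamped from below so it never collapses entirely. The sketch below is illustrative only; sample_with_top_p, decode, and the model(...) call are hypothetical stand-ins, and only top_p, top_p_decay, and top_p_bound correspond to the parameters added by this merge request.

import torch

def sample_with_top_p(logits, top_p):
    # Plain nucleus sampling: keep the smallest set of tokens whose
    # cumulative probability reaches top_p, then sample from that set.
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0
    sorted_probs = sorted_probs / sorted_probs.sum()
    return sorted_idx[torch.multinomial(sorted_probs, 1)]

def decode(model, tokens, steps, top_p=0.9, top_p_decay=0.9, top_p_bound=0.2):
    tokens = list(tokens)
    for _ in range(steps):
        logits = model(torch.tensor(tokens))[-1]      # next-token logits (stand-in call)
        tokens.append(sample_with_top_p(logits, top_p).item())
        if top_p > 0.0 and top_p_decay > 0.0:
            top_p = top_p * top_p_decay               # shrink the nucleus every step
            if top_p_bound > 0.0:
                top_p = max(top_p, top_p_bound)       # but never below the bound
    return tokens

The idea, per the factual-sampling line of work, is to keep early tokens diverse while sampling later tokens from an increasingly narrow nucleus, with the bound preventing the threshold from decaying to zero.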
@@ -34,6 +34,8 @@ def generate_and_post_process(model,
         return_output_log_probs=False,
         top_k_sampling=0,
         top_p_sampling=0.0,
+        top_p_decay=0.0,
+        top_p_bound=0.0,
         temperature=1.0,
         add_BOS=False,
         use_eod_token_for_early_termination=True,
@@ -51,6 +53,8 @@ def generate_and_post_process(model,
         return_output_log_probs=return_output_log_probs,
         top_k_sampling=top_k_sampling,
         top_p_sampling=top_p_sampling,
+        top_p_decay=top_p_decay,
+        top_p_bound=top_p_bound,
         temperature=temperature,
         add_BOS=add_BOS,
         use_eod_token_for_early_termination=use_eod_token_for_early_termination,
@@ -79,6 +83,8 @@ def generate(model,
         return_output_log_probs=False,
         top_k_sampling=0,
         top_p_sampling=0.0,
+        top_p_decay=0.0,
+        top_p_bound=0.0,
         temperature=1.0,
         add_BOS=False,
         use_eod_token_for_early_termination=True,
@@ -96,22 +102,24 @@ def generate(model,
     # Make sure input params are available to all ranks.
     values = [tokens_to_generate,
               return_output_log_probs,
-              top_k_sampling, top_p_sampling,
+              top_k_sampling, top_p_sampling, top_p_decay, top_p_bound,
               temperature, add_BOS, use_eod_token_for_early_termination,
               stop_on_double_eol,
               stop_on_eol,
               random_seed]
-    values_float_tensor = broadcast_float_list(10, float_list=values)
+    values_float_tensor = broadcast_float_list(12, float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     return_output_log_probs = bool(values_float_tensor[1].item())
     top_k_sampling = int(values_float_tensor[2].item())
     top_p_sampling = values_float_tensor[3].item()
-    temperature = values_float_tensor[4].item()
-    add_BOS = bool(values_float_tensor[5].item())
-    use_eod_token_for_early_termination = bool(values_float_tensor[6].item())
-    stop_on_double_eol = bool(values_float_tensor[7].item())
-    stop_on_eol = bool(values_float_tensor[8].item())
-    random_seed = int(values_float_tensor[9].item())
+    top_p_decay = values_float_tensor[4].item()
+    top_p_bound = values_float_tensor[5].item()
+    temperature = values_float_tensor[6].item()
+    add_BOS = bool(values_float_tensor[7].item())
+    use_eod_token_for_early_termination = bool(values_float_tensor[8].item())
+    stop_on_double_eol = bool(values_float_tensor[9].item())
+    stop_on_eol = bool(values_float_tensor[10].item())
+    random_seed = int(values_float_tensor[11].item())
     if random_seed != -1:
         torch.random.manual_seed(random_seed)
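
The broadcast above grows from 10 to 12 elements (and every index after top_p_sampling shifts by two) because generate() packs all of its scalar options, booleans included, into one flat float list and broadcasts it so every model-parallel rank samples with identical settings; the booleans are then decoded back with bool(...) and the counts with int(...). Roughly, a helper like broadcast_float_list can be sketched as below; the real implementation lives in Megatron's text-generation communication utilities and may differ in detail, so treat this body as an assumption.

import torch

def broadcast_float_list(size, float_list=None, src=0):
    # Sketch only: rank `src` fills a float32 tensor from the Python list,
    # the other ranks allocate an empty buffer of the same size, and a
    # collective broadcast makes the values identical on every rank.
    tensor = torch.empty(size, dtype=torch.float32, device=torch.cuda.current_device())
    if torch.distributed.get_rank() == src:
        tensor.copy_(torch.tensor(float_list, dtype=torch.float32))
    torch.distributed.broadcast(tensor, src)
    return tensor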
@@ -135,6 +143,8 @@ def generate(model,
         return_output_log_probs=return_output_log_probs,
         top_k=top_k_sampling,
         top_p=top_p_sampling,
+        top_p_decay=top_p_decay,
+        top_p_bound=top_p_bound,
         temperature=temperature,
         use_eod_token_for_early_termination=use_eod_token_for_early_termination,
         stop_on_double_eol=stop_on_double_eol,
......
@@ -95,7 +95,7 @@ def score_and_return_on_first_stage(model, tokens, lengths):
 def generate_tokens_probs_and_return_on_first_stage(
         model, tokens, lengths,
         return_output_log_probs=False,
-        top_k=0, top_p=0.0,
+        top_k=0, top_p=0.0, top_p_decay=0.0, top_p_bound=0.0,
         temperature=1.0,
         use_eod_token_for_early_termination=True,
         stop_on_double_eol=False,
@@ -201,7 +201,11 @@ def generate_tokens_probs_and_return_on_first_stage(
                     top_p=top_p,
                     temperature=temperature,
                     vocab_size=tokenizer.vocab_size)
+            if top_p > 0.0 and top_p_decay > 0.0:
+                top_p = top_p * top_p_decay
+                if top_p_bound > 0.0:
+                    top_p = max(top_p, top_p_bound)
             # If a prompt length is smaller than or equal to the current context
             # length, it means we have started generating tokens
             started = lengths <= context_length
......
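
Concretely, with top_p=0.9, top_p_decay=0.9 and top_p_bound=0.3 (illustrative values, not defaults from this merge request), the threshold used at successive decoding steps follows 0.9, 0.81, 0.729, ... until the decay would drop it below 0.3, after which it stays pinned at the bound. A few lines reproduce the trajectory of the update added above:

top_p, top_p_decay, top_p_bound = 0.9, 0.9, 0.3    # illustrative values only
trajectory = []
for step in range(15):
    trajectory.append(round(top_p, 4))
    if top_p > 0.0 and top_p_decay > 0.0:          # same update as in the hunk above
        top_p = top_p * top_p_decay
        if top_p_bound > 0.0:
            top_p = max(top_p, top_p_bound)
print(trajectory)
# [0.9, 0.81, 0.729, 0.6561, 0.5905, 0.5314, 0.4783, 0.4305, 0.3874, 0.3487, 0.3138, 0.3, 0.3, 0.3, 0.3]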
@@ -100,6 +100,26 @@ class MegatronGenerate(Resource):
         if not (0 <= top_p <= 1.0):
             return "top_p must be less than or equal to 1.0"
 
+        top_p_decay = 0.0
+        if "top_p_decay" in request.get_json():
+            top_p_decay = request.get_json()["top_p_decay"]
+            if not (type(top_p_decay) == float):
+                return "top_p_decay must be a positive float less than or equal to 1.0"
+            if top_p == 0.0:
+                return "top_p_decay cannot be set without top_p"
+            if not (0 <= top_p_decay <= 1.0):
+                return "top_p_decay must be less than or equal to 1.0"
+
+        top_p_bound = 0.0
+        if "top_p_bound" in request.get_json():
+            top_p_bound = request.get_json()["top_p_bound"]
+            if not (type(top_p_bound) == float):
+                return "top_p_bound must be a positive float less than or equal to top_p"
+            if top_p == 0.0:
+                return "top_p_bound cannot be set without top_p"
+            if not (0.0 < top_p_bound <= top_p):
+                return "top_p_bound must be greater than 0 and less than top_p"
+
         add_BOS = False
         if "add_BOS" in request.get_json():
             add_BOS = request.get_json()["add_BOS"]
@@ -192,6 +212,8 @@ class MegatronGenerate(Resource):
             return_output_log_probs=logprobs,
             top_k_sampling=top_k,
             top_p_sampling=top_p,
+            top_p_decay=top_p_decay,
+            top_p_bound=top_p_bound,
             temperature=temperature,
             add_BOS=add_BOS,
             use_eod_token_for_early_termination=True,
......
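
End to end, a client of the text generation server can now pass the two new fields next to the existing sampling options. The request below is a plausible shape rather than something taken from this merge request: the prompts/tokens_to_generate fields and the PUT /api route are assumptions based on the existing server script, and host and port are placeholders.

import requests

# Host, port, and prompt are placeholders; the server validates top_p_decay and
# top_p_bound exactly as in the hunk above and rejects them when top_p is unset.
response = requests.put(
    "http://localhost:5000/api",
    json={
        "prompts": ["Megatron-LM is"],
        "tokens_to_generate": 32,
        "top_p": 0.9,
        "top_p_decay": 0.9,   # multiply top_p by this factor after each generated token
        "top_p_bound": 0.2,   # floor below which the decayed top_p is not allowed to fall
    },
)
print(response.json())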