Commit 9b131fad authored by rprenger

Adding stop token logic and random seed for deterministic answers

parent 83bc79d6
@@ -37,7 +37,8 @@ def generate_and_post_process(model,
                               add_BOS=False,
                               use_eod_token_for_early_termination=True,
                               stop_on_double_eol=False,
-                              stop_on_eol=False):
+                              stop_on_eol=False,
+                              random_seed=-1):
     """Run inference and post-process outputs, i.e., detokenize,
     move to cpu and convert to list."""
@@ -53,7 +54,8 @@ def generate_and_post_process(model,
         add_BOS=add_BOS,
         use_eod_token_for_early_termination=use_eod_token_for_early_termination,
         stop_on_double_eol=stop_on_double_eol,
-        stop_on_eol=stop_on_eol)
+        stop_on_eol=stop_on_eol,
+        random_seed=random_seed)

     # Only post-process on first stage.
     if mpu.is_pipeline_first_stage():
@@ -80,7 +82,8 @@ def generate(model,
              add_BOS=False,
              use_eod_token_for_early_termination=True,
              stop_on_double_eol=False,
-             stop_on_eol=False):
+             stop_on_eol=False,
+             random_seed=-1):
     """Given prompts and input parameters, run inference and return:
        tokens: prompts plus the generated tokens.
        lengths: length of the prompt + generations. Note that we can
@@ -95,8 +98,9 @@ def generate(model,
               top_k_sampling, top_p_sampling,
               temperature, add_BOS, use_eod_token_for_early_termination,
               stop_on_double_eol,
-              stop_on_eol]
-    values_float_tensor = broadcast_float_list(9, float_list=values)
+              stop_on_eol,
+              random_seed]
+    values_float_tensor = broadcast_float_list(10, float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     return_output_log_probs = bool(values_float_tensor[1].item())
     top_k_sampling = int(values_float_tensor[2].item())
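For context on the hunk above: only one rank receives the request, so every generation parameter, including the new random_seed, is packed into a single float tensor, broadcast to all ranks, and cast back to int/bool on arrival. Below is a minimal standalone sketch of that pattern using plain torch.distributed; it approximates what Megatron's broadcast_float_list helper does and is not its actual implementation.

```python
# Sketch of the "pack scalars as floats and broadcast" pattern, assuming
# torch.distributed has already been initialized. This approximates
# Megatron's broadcast_float_list helper; internals of the real one differ.
import torch
import torch.distributed as dist

def broadcast_float_list(size, float_list=None, src=0):
    """Broadcast a list of scalars from `src` so all ranks agree on them."""
    if dist.get_rank() == src:
        tensor = torch.tensor(float_list, dtype=torch.float32).cuda()
    else:
        tensor = torch.empty(size, dtype=torch.float32).cuda()
    dist.broadcast(tensor, src)
    return tensor

# Receiving ranks then recover typed values exactly as in the diff:
#   stop_on_eol = bool(values_float_tensor[8].item())
#   random_seed = int(values_float_tensor[9].item())
```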
@@ -106,6 +110,10 @@ def generate(model,
     use_eod_token_for_early_termination = bool(values_float_tensor[6].item())
     stop_on_double_eol = bool(values_float_tensor[7].item())
     stop_on_eol = bool(values_float_tensor[8].item())
+    random_seed = int(values_float_tensor[9].item())
+
+    if random_seed != -1:
+        torch.random.manual_seed(random_seed)

     # Tokenize prompts and get the batch.
     # Note that these tensors are broadcast to all ranks.
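Seeding the global torch RNG is what makes the answers deterministic: top-k/top-p sampling ultimately draws tokens with torch.multinomial, which consumes torch's global generator, so resetting it to the same state before generation reproduces the same draws. Note the seed is applied after the broadcast, so every rank seeds identically. A small self-contained illustration (not part of the commit):

```python
# Demonstrates why manual_seed yields reproducible sampling: identical RNG
# state before torch.multinomial produces identical token draws.
import torch

logits = torch.tensor([2.0, 1.0, 0.5, 0.1])
probs = torch.softmax(logits, dim=-1)

torch.random.manual_seed(1234)
first = torch.multinomial(probs, num_samples=5, replacement=True)

torch.random.manual_seed(1234)
second = torch.multinomial(probs, num_samples=5, replacement=True)

assert torch.equal(first, second)  # same seed, same draws, same answer
```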
@@ -114,6 +114,14 @@ class MegatronGenerate(Resource):
         if not isinstance(stop_on_eol, bool):
             return "stop_on_eol must be a boolean value"

+        random_seed = -1
+        if "random_seed" in request.get_json():
+            random_seed = request.get_json()["random_seed"]
+            if not isinstance(random_seed, int):
+                return "random_seed must be an integer"
+            if random_seed < 0:
+                return "random_seed must be a non-negative integer"
+
         # if str(request.remote_addr) == "10.14.68.146":
         #   return "Too many tokens requested from this IP address. Contact Ryan Prenger rprenger@nvidia.com"
@@ -135,7 +143,8 @@ class MegatronGenerate(Resource):
                 add_BOS=add_BOS,
                 use_eod_token_for_early_termination=True,
                 stop_on_double_eol=stop_on_double_eol,
-                stop_on_eol=stop_on_eol)
+                stop_on_eol=stop_on_eol,
+                random_seed=random_seed)
         except ValueError as ve:
             return "Length of prompt + tokens_to_generate longer than allowed"
         print("end time: ", datetime.datetime.now())
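With the server-side changes above, a client can pin the seed per request, so two identical requests with the same random_seed should return identical completions. A hedged usage example, assuming the text generation server's usual PUT endpoint at /api and a host/port of localhost:5000 (endpoint, host, port, and prompt are illustrative, not taken from this commit):

```python
import requests

payload = {
    "prompts": ["Deterministic decoding is"],  # illustrative prompt
    "tokens_to_generate": 16,
    "random_seed": 1234,  # new optional field added by this commit
}
# Omit "random_seed" to keep the previous nondeterministic behavior.
resp = requests.put("http://localhost:5000/api", json=payload)
print(resp.json())
```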