Commit 9e0ee6fd authored by rprenger

Hacked in a way to have stop tokens

parent 148a24ad
@@ -35,7 +35,9 @@ def generate_and_post_process(model,
                               top_p_sampling=0.0,
                               temperature=1.0,
                               add_BOS=False,
-                              use_eod_token_for_early_termination=True):
+                              use_eod_token_for_early_termination=True,
+                              stop_on_double_eol=False,
+                              stop_on_eol=False):
     """Run inference and post-process outputs, i.e., detokenize,
     move to cpu and convert to list."""
@@ -49,7 +51,9 @@ def generate_and_post_process(model,
         top_p_sampling=top_p_sampling,
         temperature=temperature,
         add_BOS=add_BOS,
-        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
+        stop_on_double_eol=stop_on_double_eol,
+        stop_on_eol=stop_on_eol)
 
     # Only post-process on first stage.
     if mpu.is_pipeline_first_stage():
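For context, a minimal usage sketch of the extended API. The model object, prompt text, and return handling here are illustrative assumptions, not part of this commit:

from megatron.text_generation import generate_and_post_process

# Sketch only: `model` is assumed to be an already-initialized Megatron GPT
# model set up for inference.
# stop_on_eol=True asks generation to stop at the first end-of-line token;
# stop_on_double_eol=True would instead stop at a blank line.
result = generate_and_post_process(
    model,
    prompts=["Q: What is a stop token?\nA:"],
    tokens_to_generate=32,
    stop_on_eol=True)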
@@ -74,7 +78,9 @@ def generate(model,
              top_p_sampling=0.0,
              temperature=1.0,
              add_BOS=False,
-             use_eod_token_for_early_termination=True):
+             use_eod_token_for_early_termination=True,
+             stop_on_double_eol=False,
+             stop_on_eol=False):
     """Given prompts and input parameters, run inference and return:
     tokens: prompts plus the generated tokens.
     lengths: length of the prompt + generations. Note that we can
@@ -87,8 +93,10 @@ def generate(model,
     values = [tokens_to_generate,
               return_output_log_probs,
               top_k_sampling, top_p_sampling,
-              temperature, add_BOS, use_eod_token_for_early_termination]
-    values_float_tensor = broadcast_float_list(7, float_list=values)
+              temperature, add_BOS, use_eod_token_for_early_termination,
+              stop_on_double_eol,
+              stop_on_eol]
+    values_float_tensor = broadcast_float_list(9, float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     return_output_log_probs = bool(values_float_tensor[1].item())
     top_k_sampling = int(values_float_tensor[2].item())
@@ -96,6 +104,8 @@ def generate(model,
     temperature = values_float_tensor[4].item()
     add_BOS = bool(values_float_tensor[5].item())
     use_eod_token_for_early_termination = bool(values_float_tensor[6].item())
+    stop_on_double_eol = bool(values_float_tensor[7].item())
+    stop_on_eol = bool(values_float_tensor[8].item())
 
     # Tokenize prompts and get the batch.
     # Note that these tensors are broadcaseted to all ranks.
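The pack-and-broadcast pattern above exists because only one rank receives the request: every scalar option is cast to float, packed into a single tensor, broadcast, and unpacked by position on all ranks, so the two new flags occupy indices 7 and 8. A standalone sketch of the same idea; broadcast_scalars is a hypothetical stand-in for Megatron's broadcast_float_list, and it assumes torch.distributed is already initialized on CUDA devices, as it is inside Megatron:

import torch
import torch.distributed as dist

def broadcast_scalars(values, src=0):
    # Cast every scalar (ints and bools included) to float, pack into one
    # tensor, broadcast it from `src`, and let each rank unpack by position.
    tensor = torch.tensor(values, dtype=torch.float32, device='cuda')
    dist.broadcast(tensor, src)
    return tensor

# Rank 0 supplies real values; other ranks pass placeholders of the same length.
vals = broadcast_scalars([64, 0, 0, 0.9, 1.0, 0, 1, 0, 1])
stop_on_double_eol = bool(vals[7].item())   # index 7, as in the diff above
stop_on_eol = bool(vals[8].item())          # index 8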
@@ -117,4 +127,6 @@ def generate(model,
         top_k=top_k_sampling,
         top_p=top_p_sampling,
         temperature=temperature,
-        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
+        stop_on_double_eol=stop_on_double_eol,
+        stop_on_eol=stop_on_eol)
@@ -96,7 +96,10 @@ def generate_tokens_probs_and_return_on_first_stage(
         return_output_log_probs=False,
         top_k=0, top_p=0.0,
         temperature=1.0,
-        use_eod_token_for_early_termination=True):
+        use_eod_token_for_early_termination=True,
+        stop_on_double_eol=False,
+        stop_on_eol=False
+        ):
     """Main token generation function.
     Arguments:
         model: no interleaving is supported.
@@ -231,8 +234,18 @@ def generate_tokens_probs_and_return_on_first_stage(
             # Check if all the sequences have hit the termination_id.
             done = None
             if mpu.is_pipeline_last_stage():
-                done_token = (new_sample == termination_id).byte() & \
-                    started.byte()
+                if stop_on_double_eol:
+                    hit_double_eol = (new_sample == 628).byte() & started.byte()
+                    hit_two_eols = (new_sample == 198).byte() & (tokens[:, context_length-1] == 198).byte() & started.byte()
+                    done_token = hit_double_eol | hit_two_eols
+                elif stop_on_eol:
+                    hit_double_eol = (new_sample == 628).byte() & started.byte()
+                    hit_eol = (new_sample == 198).byte() & started.byte()
+                    done_token = hit_double_eol | hit_eol
+                else:
+                    done_token = (new_sample == termination_id).byte() & \
+                        started.byte()
+
                 just_finished = (done_token & ~is_generation_done).bool()
                 generated_sequence_lengths[just_finished.view(-1)] = \
                     context_length + 1
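The hard-coded IDs above come from the GPT-2 BPE vocabulary used by Megatron's GPT models: 198 is the single-newline token and 628 is the double-newline token, so "stop on double EOL" fires either on one 628 token or on a 198 that immediately follows another 198. A small self-contained sketch of the same masking logic on made-up batch values:

import torch

# Illustrative batch of 3 sequences, each proposing one new token this step.
new_sample = torch.tensor([628, 198, 50256])   # '\n\n', '\n', <|endoftext|>
prev_token = torch.tensor([7454, 198, 11])     # last token already in each sequence
started = torch.tensor([1, 1, 1], dtype=torch.uint8)   # all past their prompt

# stop_on_double_eol: finish on a '\n\n' token, or on a '\n' right after a '\n'.
hit_double_eol = (new_sample == 628).byte() & started
hit_two_eols = (new_sample == 198).byte() & (prev_token == 198).byte() & started
done_token = hit_double_eol | hit_two_eols
print(done_token)   # tensor([1, 1, 0], dtype=torch.uint8)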
@@ -98,6 +98,24 @@ class MegatronGenerate(Resource):
             add_BOS = request.get_json()["add_BOS"]
             if not isinstance(add_BOS, bool):
                 return "add_BOS must be a boolean value"
 
+        if any([len(prompt) == 0 for prompt in prompts]) and not add_BOS:
+            return "Empty prompts require add_BOS=true"
+
+        stop_on_double_eol = False
+        if "stop_on_double_eol" in request.get_json():
+            stop_on_double_eol = request.get_json()["stop_on_double_eol"]
+            if not isinstance(stop_on_double_eol, bool):
+                return "stop_on_double_eol must be a boolean value"
+
+        stop_on_eol = False
+        if "stop_on_eol" in request.get_json():
+            stop_on_eol = request.get_json()["stop_on_eol"]
+            if not isinstance(stop_on_eol, bool):
+                return "stop_on_eol must be a boolean value"
+
+        if str(request.remote_addr) == "10.14.68.146":
+            return "Too many tokens requested from this IP address. Contact Ryan Prenger rprenger@nvidia.com"
+
         with lock:  # Need to get lock to keep multiple threads from hitting code
             print("request IP: " + str(request.remote_addr))
@@ -115,7 +133,9 @@ class MegatronGenerate(Resource):
                     top_p_sampling=top_p,
                     temperature=temperature,
                     add_BOS=add_BOS,
-                    use_eod_token_for_early_termination=True)
+                    use_eod_token_for_early_termination=True,
+                    stop_on_double_eol=stop_on_double_eol,
+                    stop_on_eol=stop_on_eol)
             except ValueError as ve:
                 return "Length of prompt + tokens_to_generate longer than allowed"
             print("end time: ", datetime.datetime.now())
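On the client side the new fields simply ride along in the JSON body of the generation request. A hedged example, assuming the server is reachable at a placeholder host/port and exposes the PUT /api endpoint of Megatron's text generation server:

import requests

# Host, port, and prompt text are placeholders; "stop_on_eol" and
# "stop_on_double_eol" are the two optional flags added by this commit.
resp = requests.put(
    "http://localhost:5000/api",
    json={
        "prompts": ["Q: Name a prime number.\nA:"],
        "tokens_to_generate": 32,
        "stop_on_eol": True,
    })
print(resp.json())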