Commit 5ab28fd1 authored by rprenger

getting rid of unnecessary just_score

parent a3770921
@@ -27,8 +27,6 @@ from .tokenization import (
     tokenize_prompts,
     detokenize_generations)
 
 
 def generate_and_post_process(model,
                               prompts=None,
                               tokens_to_generate=0,
@@ -37,8 +35,7 @@ def generate_and_post_process(model,
                               top_p_sampling=0.0,
                               temperature=1.0,
                               add_BOS=False,
-                              use_eod_token_for_early_termination=True,
-                              just_score=False):
+                              use_eod_token_for_early_termination=True):
     """Run inference and post-process outputs, i.e., detokenize,
     move to cpu and convert to list."""
@@ -52,8 +49,7 @@ def generate_and_post_process(model,
         top_p_sampling=top_p_sampling,
         temperature=temperature,
         add_BOS=add_BOS,
-        use_eod_token_for_early_termination=use_eod_token_for_early_termination,
-        just_score=just_score)
+        use_eod_token_for_early_termination=use_eod_token_for_early_termination)
 
     # Only post-process on first stage.
     if mpu.is_pipeline_first_stage():
@@ -70,8 +66,6 @@ def generate_and_post_process(model,
     return None
 
 
 def generate(model,
              prompts=None,
              tokens_to_generate=0,
@@ -80,8 +74,7 @@ def generate(model,
              top_p_sampling=0.0,
              temperature=1.0,
              add_BOS=False,
-             use_eod_token_for_early_termination=True,
-             just_score=False):
+             use_eod_token_for_early_termination=True):
     """Given prompts and input parameters, run inference and return:
        tokens: prompts plus the generated tokens.
        lengths: length of the prompt + generations. Note that we can
@@ -94,8 +87,8 @@ def generate(model,
     values = [tokens_to_generate,
               return_output_log_probs,
               top_k_sampling, top_p_sampling,
-              temperature, add_BOS, use_eod_token_for_early_termination, just_score]
-    values_float_tensor = broadcast_float_list(8, float_list=values)
+              temperature, add_BOS, use_eod_token_for_early_termination]
+    values_float_tensor = broadcast_float_list(7, float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     return_output_log_probs = bool(values_float_tensor[1].item())
     top_k_sampling = int(values_float_tensor[2].item())
@@ -103,7 +96,6 @@ def generate(model,
     temperature = values_float_tensor[4].item()
     add_BOS = bool(values_float_tensor[5].item())
     use_eod_token_for_early_termination = bool(values_float_tensor[6].item())
-    just_score = bool(values_float_tensor[7].item())
 
     # Tokenize prompts and get the batch.
     # Note that these tensors are broadcasted to all ranks.
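The hunks above are why the change touches broadcast_float_list: every sampling parameter rides to the other ranks packed into a single float tensor, and dropping just_score shrinks that payload from 8 entries to 7. For readers unfamiliar with the pattern, here is a minimal sketch of how such a helper can be written — an illustrative re-implementation under assumed conventions (source rank 0, one CUDA device per rank, torch.distributed already initialized), not Megatron's actual code:

    import torch

    def broadcast_float_list(size, float_list=None, src=0):
        # Sketch only: rank `src` packs the heterogeneous parameters
        # (ints, bools, floats) into one float32 tensor; every other rank
        # allocates an empty buffer of the same size and receives it.
        if torch.distributed.get_rank() == src:
            tensor = torch.tensor(float_list, dtype=torch.float32, device="cuda")
        else:
            tensor = torch.empty(size, dtype=torch.float32, device="cuda")
        torch.distributed.broadcast(tensor, src)
        return tensor

On arrival each value is cast back to its real type, as the int(...) and bool(...) calls above show; removing one flag therefore means changing both the packed list and the literal 8 to 7.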
@@ -113,7 +105,7 @@ def generate(model,
     context_tokens_tensor, context_length_tensor = tokenize_prompts(
         prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
 
-    if just_score:
+    if tokens_to_generate == 0:
         return score_and_return_on_first_stage(
             model, context_tokens_tensor, context_length_tensor)
...
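With the flag gone, scoring is simply the tokens_to_generate == 0 case of the ordinary entry point. A minimal sketch of the new calling convention — the import path, the return shape, and the model variable are assumptions, not something this diff shows:

    # Hypothetical import path; adjust to wherever api.py lives in your checkout.
    from megatron.text_generation.api import generate_and_post_process

    # With tokens_to_generate=0, generate() now dispatches to
    # score_and_return_on_first_stage() internally; no extra flag is needed.
    response, response_seg, response_logprobs, _ = generate_and_post_process(
        model,                            # an already-initialized Megatron model
        prompts=["The quick brown fox"],
        tokens_to_generate=0,
        return_output_log_probs=True)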
@@ -54,15 +54,12 @@ class MegatronGenerate(Resource):
             return "Maximum number of prompts is 128", 400
 
         tokens_to_generate = 64  # Choosing hopefully sane default. Full sequence is slow
-        just_score=False
         if "tokens_to_generate" in request.get_json():
             tokens_to_generate = request.get_json()["tokens_to_generate"]
             if not isinstance(tokens_to_generate, int):
                 return "tokens_to_generate must be an integer greater than 0"
             if tokens_to_generate < 0:
                 return "tokens_to_generate must be an integer greater than or equal to 0"
-            if tokens_to_generate == 0:
-                just_score = True
 
         logprobs = False
         if "logprobs" in request.get_json():
@@ -70,8 +67,8 @@ class MegatronGenerate(Resource):
             if not isinstance(logprobs, bool):
                 return "logprobs must be a boolean value"
 
-        if just_score and not logprobs:
-            return "tokens_to_generate=0 implies logprobs=True"
+        if tokens_to_generate == 0 and not logprobs:
+            return "tokens_to_generate=0 implies logprobs should be True"
 
         temperature = 1.0
         if "temperature" in request.get_json():
@@ -117,8 +114,7 @@ class MegatronGenerate(Resource):
             top_p_sampling=top_p,
             temperature=temperature,
             add_BOS=add_BOS,
-            use_eod_token_for_early_termination=True,
-            just_score=just_score)
+            use_eod_token_for_early_termination=True)
 
         return jsonify({"text": response,
                         "segments": response_seg,
...
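For clients of the REST server the contract is unchanged: send tokens_to_generate=0 together with logprobs to request scoring. A hypothetical client call — the URL, port, and HTTP verb are assumptions; check how MegatronGenerate is registered in your server script:

    import requests

    # tokens_to_generate=0 must be paired with "logprobs": true, or the
    # validation above rejects the request with the error string shown.
    payload = {
        "prompts": ["The quick brown fox jumps over the lazy dog."],
        "tokens_to_generate": 0,
        "logprobs": True,
    }
    resp = requests.put("http://localhost:5000/api", json=payload)
    print(resp.json()["text"], resp.json()["segments"])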