Commit fd176a90 authored by rprenger

Making the API talk to the server, fixed a bug where parameters weren't getting forwarded to node

parent 5a570bd8
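For context, below is a minimal sketch of how a client might exercise the new parameters once this commit is in place. The host, port, and PUT route are assumptions (Megatron-LM's Flask text-generation server registers the MegatronGenerate resource and handles PUT requests); the JSON keys mirror the fields validated in the server diff further down.

# Hedged client example: exercising the parameters added in this commit.
# URL and port are assumptions; the JSON keys match the handler's validation.
import json
import requests

url = "http://localhost:5000/api"          # assumed host/port for the server
headers = {"Content-Type": "application/json"}

payload = {
    "prompts": ["The quick brown fox"],    # beam search requires batch size 1
    "tokens_to_generate": 32,
    "beam_width": 4,                       # presence of this key triggers beam search
    "stop_token": 50256,                   # GPT-2 <|endoftext|>
    "length_penalty": 1.0,                 # must be a float per the new validation
}

resp = requests.put(url, data=json.dumps(payload), headers=headers)
print(resp.json()["text"])                 # whole beam is returned (num_return_gen=beam_width)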
@@ -146,7 +146,8 @@ def beam_search_and_post_process(model,
                                  beam_size=0,
                                  add_BOS=False,
                                  stop_token=50256,
-                                 num_return_gen=1):
+                                 num_return_gen=1,
+                                 length_penalty=1):
     """Run beam search and post-process outputs, i.e., detokenize,
     move to cpu and convert to list."""
@@ -157,7 +158,8 @@ def beam_search_and_post_process(model,
                                beam_size=beam_size,
                                add_BOS=add_BOS,
                                stop_token=stop_token,
-                               num_return_gen=num_return_gen)
+                               num_return_gen=num_return_gen,
+                               length_penalty=length_penalty)
     # Only post-process on first stage.
     if mpu.is_pipeline_first_stage():
         lengths = tokens.size(1)*torch.ones(beam_size, dtype=torch.int64, device=torch.cuda.current_device())
@@ -167,18 +169,24 @@ def beam_search_and_post_process(model,
     return None

-def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256, num_return_gen=1):
+def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256, num_return_gen=1, length_penalty=1):
     # Make sure input params are avaialble to all ranks.
     values = [tokens_to_generate,
               beam_size,
-              add_BOS]
-    values_float_tensor = broadcast_float_list(3, float_list=values)
+              add_BOS,
+              stop_token,
+              num_return_gen,
+              length_penalty]
+    values_float_tensor = broadcast_float_list(6, float_list=values)
     tokens_to_generate = int(values_float_tensor[0].item())
     beam_size = int(values_float_tensor[1].item())
     add_BOS = bool(values_float_tensor[2].item())
+    stop_token = int(values_float_tensor[3].item())
+    num_return_gen = int(values_float_tensor[4].item())
+    length_penalty = values_float_tensor[5].item()

     context_tokens_tensor, context_length_tensor = tokenize_prompts(
         prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)

     return beam_search_and_return_on_first_stage(model, context_tokens_tensor, context_length_tensor,
-            beam_size, stop_token=stop_token, num_return_gen=num_return_gen)
+            beam_size, stop_token=stop_token, num_return_gen=num_return_gen, length_penalty=length_penalty)
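The hunk above works because every rank reconstructs the same parameter values from a single broadcast float tensor. A minimal sketch of that pattern, assuming an initialized torch.distributed process group (illustrative only, not Megatron's broadcast_float_list implementation):

import torch
import torch.distributed as dist

def broadcast_scalar_params(values, size, src=0):
    # All ranks allocate a float tensor of the agreed-on size; only the source
    # rank fills it with real values, then everyone receives the broadcast.
    tensor = torch.zeros(size, dtype=torch.float32, device="cuda")
    if dist.get_rank() == src:
        tensor.copy_(torch.tensor(values, dtype=torch.float32))
    dist.broadcast(tensor, src)
    return tensor

# Usage mirroring the hunk above: six scalars packed in a fixed order, then
# cast back to their intended types on every rank. The size argument must
# match len(values), otherwise non-source ranks hit a shape mismatch.
# t = broadcast_scalar_params([tokens_to_generate, beam_size, add_BOS,
#                              stop_token, num_return_gen, length_penalty], 6)
# tokens_to_generate = int(t[0].item()); add_BOS = bool(t[2].item()); ...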
@@ -328,7 +328,7 @@ class BeamHypotheses(object):
         ret = self.worst_score >= cur_score
         return ret

-def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen=1):
+def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen, length_penalty):
     args = get_args()
     tokenizer = get_tokenizer()

@@ -345,7 +345,7 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
     # forward step.
     forward_step = ForwardStep(model, beam_size, final_sequence_length)

-    beam_hyp = BeamHypotheses(beam_size)
+    beam_hyp = BeamHypotheses(beam_size, length_penalty)
     done = False
     if mpu.is_pipeline_last_stage():
         scores = torch.zeros(beam_size,
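For readers unfamiliar with the new argument: a length penalty typically rescales a finished hypothesis's summed log-probability by its length. The sketch below shows the common formulation (an assumption here, since the body of BeamHypotheses.add is not part of this diff); values above 1 increasingly favor longer completions, values below 1 favor shorter ones.

class TinyBeamHypotheses:
    """Toy container illustrating length-penalized beam scoring."""

    def __init__(self, num_beams, length_penalty=1.0):
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.beams = []  # list of (score, token_ids)

    def add(self, token_ids, sum_logprobs):
        # Normalize the summed log-probability by length ** length_penalty so
        # long and short hypotheses can be compared on the same scale.
        score = sum_logprobs / (len(token_ids) ** self.length_penalty)
        self.beams.append((score, token_ids))
        self.beams.sort(key=lambda b: b[0], reverse=True)
        self.beams = self.beams[: self.num_beams]

# Example: with length_penalty=1, a 10-token hypothesis with sum_logprob -5.0
# scores -0.5; with length_penalty=2 it scores -0.05, so larger penalties make
# longer outputs look better relative to short ones.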
@@ -135,20 +135,28 @@ class MegatronGenerate(Resource):
         if not isinstance(no_log, bool):
             return "no_log must be a boolean value"

-        beam_search = False
-        if "beam_search" in request.get_json():
-            beam_search = request.get_json()["beam_search"]
-            if not isinstance(no_log, bool):
-                return "beam_search must be a boolean value"
-
-        beam_size = 4
-        if "beam_size" in request.get_json():
-            beam_size = request.get_json()["beam_size"]
-            if not isinstance(beam_size, int):
-                return "beam_size must be integer"
-            if beam_size < 1:
-                return "beam_size must be an integer > 1"
+        beam_width = None
+        if "beam_width" in request.get_json():
+            beam_width = request.get_json()["beam_width"]
+            if not isinstance(beam_width, int):
+                return "beam_width must be integer"
+            if beam_width < 1:
+                return "beam_width must be an integer > 1"
+            if len(prompts) > 1:
+                return "When doing beam_search, batch size must be 1"
+
+        stop_token=50256
+        if "stop_token" in request.get_json():
+            stop_token = request.get_json()["stop_token"]
+            if not isinstance(stop_token, int):
+                return "stop_token must be an integer"
+
+        length_penalty = 1
+        if "length_penalty" in request.get_json():
+            length_penalty = request.get_json()["length_penalty"]
+            if not isinstance(length_penalty, float):
+                return "length_penalty must be a float"

         with lock:  # Need to get lock to keep multiple threads from hitting code
             if not no_log:
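One practical consequence of the validation above: JSON integers deserialize to Python int, which does not satisfy isinstance(x, float), so a request carrying "length_penalty": 1 is rejected while 1.0 passes. A quick illustration of the checks involved:

# Behavior of the isinstance checks used in the handler above.
isinstance(1, float)    # False -> "length_penalty must be a float"
isinstance(1.0, float)  # True
isinstance(True, int)   # True  -> a JSON boolean would slip past the beam_width int check

Clients therefore need to send the penalty as a float, or the check could be relaxed to isinstance(length_penalty, (int, float)).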
@@ -157,15 +165,19 @@ class MegatronGenerate(Resource):
                 print("start time: ", datetime.datetime.now())

             try:
-                if beam_search:
+                if beam_width is not None:
                     MegatronGenerate.send_do_beam_search()  # Tell other ranks we're doing beam_search
                     response, response_seg, response_scores = \
                         beam_search_and_post_process(
                             self.model,
                             prompts=prompts,
                             tokens_to_generate=tokens_to_generate,
-                            beam_size = beam_size,
-                            add_BOS=add_BOS)
+                            beam_size = beam_width,
+                            add_BOS=add_BOS,
+                            stop_token=stop_token,
+                            num_return_gen=beam_width,  # Returning whole beam
+                            length_penalty=length_penalty
+                            )

                     return jsonify({"text": response,
                                     "segments": response_seg,