Commit 96816d3d authored by peng xu

rename hyp and allow returning multiple samples

parent 4f579b55
@@ -144,7 +144,9 @@ def beam_search_and_post_process(model,
                                  prompts=None,
                                  tokens_to_generate=0,
                                  beam_size=0,
-                                 add_BOS=False):
+                                 add_BOS=False,
+                                 stop_token=50256,
+                                 num_return_gen=1):
     """Run beam search and post-process outputs, i.e., detokenize,
     move to cpu and convert to list."""
@@ -153,7 +155,9 @@ def beam_search_and_post_process(model,
                                  prompts=prompts,
                                  tokens_to_generate=tokens_to_generate,
                                  beam_size=beam_size,
-                                 add_BOS=add_BOS)
+                                 add_BOS=add_BOS,
+                                 stop_token=stop_token,
+                                 num_return_gen=num_return_gen)
     # Only post-process on first stage.
     if mpu.is_pipeline_first_stage():
         lengths = tokens.size(1)*torch.ones(beam_size, dtype=torch.int64, device=torch.cuda.current_device())
@@ -163,7 +167,7 @@ def beam_search_and_post_process(model,
     return None
-def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256):
+def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256, num_return_gen=1):
     # Make sure input params are available to all ranks.
     values = [tokens_to_generate,
               beam_size,
@@ -176,4 +180,5 @@ def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=
     context_tokens_tensor, context_length_tensor = tokenize_prompts(
         prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
-    return beam_search_and_return_on_first_stage(model, context_tokens_tensor, context_length_tensor, beam_size, stop_token=stop_token)
+    return beam_search_and_return_on_first_stage(model, context_tokens_tensor, context_length_tensor,
+                                                 beam_size, stop_token=stop_token, num_return_gen=num_return_gen)
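
For callers, both entry points now accept two extra keyword arguments: stop_token (defaulting to 50256, the GPT-2 end-of-text id) and num_return_gen (how many of the highest-scoring beams to return). A minimal usage sketch follows; the prompt, the parameter values, and the surrounding Megatron-LM setup (model construction, distributed and pipeline initialization) are assumptions for illustration, not part of this commit:

    # Illustrative call only; assumes a GPT model already built and initialized
    # for inference, as in the usual Megatron-LM text-generation server flow.
    result = beam_search_and_post_process(
        model,
        prompts=["some prompt text"],
        tokens_to_generate=64,
        beam_size=4,
        add_BOS=False,
        stop_token=50256,    # GPT-2 <|endoftext|> id; a beam finishes when it emits this token
        num_return_gen=2)    # keep the two best beams instead of only the top one
    # Per the docstring and the first-stage branch above, only the first pipeline
    # stage returns the post-processed (detokenized, CPU, list) outputs; other stages get None.
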
@@ -328,7 +328,7 @@ class BeamHypotheses(object):
         ret = self.worst_score >= cur_score
         return ret
-def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token):
+def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen=1):
     args = get_args()
     tokenizer = get_tokenizer()
@@ -345,7 +345,7 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
     # forward step.
     forward_step = ForwardStep(model, beam_size, final_sequence_length)
-    hyp = BeamHypotheses(beam_size)
+    beam_hyp = BeamHypotheses(beam_size)
     done = False
     if mpu.is_pipeline_last_stage():
         scores = torch.zeros(beam_size,
@@ -392,7 +392,7 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
                     is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
                     if is_beam_token_worse_than_top_num_beams:
                         continue
-                    hyp.add(
+                    beam_hyp.add(
                         tokens[beam_id].clone(),
                         beam_score,
                         context_length + 1 - prompt_length
@@ -404,7 +404,7 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
                 if len(next_beams) == beam_size:
                     break
-            if hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
+            if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
                 done = True
                 break
@@ -430,13 +430,14 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
         # if cannot find stop token, add open beams to hyps
         if not done:
             for beam_id in range(beam_size):
-                hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)
+                beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)
         # rank based on scores
-        sorted_hyps = sorted(hyp.beams, key=lambda x: x[0], reverse=True)
-        scores, tokens = sorted_hyps[0]
-        scores = scores.unsqueeze(0)
-        tokens = tokens.unsqueeze(0)
+        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
+        scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
+        tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
+        scores = torch.stack(scores, dim=0)
+        tokens = torch.stack(tokens, dim=0)
     return tokens, scores
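
The final selection step above replaces the previous hard-coded top-1 pick: finished hypotheses are sorted by score and the best num_return_gen are stacked along a new leading dimension, so tokens comes out with shape (num_return_gen, sequence_length) and scores with one entry per returned beam. A self-contained sketch of just that step; the function name is made up for illustration, and the (score, tokens) pair layout with 0-dim score tensors is assumed from the code above:

    import torch

    def select_top_hypotheses(beams, num_return_gen=1):
        # beams: list of (score, tokens) pairs as accumulated by BeamHypotheses.add(),
        # where score is assumed to be a 0-dim tensor and tokens a 1-D token-id tensor.
        sorted_hyps = sorted(beams, key=lambda x: x[0], reverse=True)
        # Guard not present in the diff: avoid indexing past the available hypotheses.
        num_return_gen = min(num_return_gen, len(sorted_hyps))
        scores = torch.stack([sorted_hyps[i][0] for i in range(num_return_gen)], dim=0)
        tokens = torch.stack([sorted_hyps[i][1] for i in range(num_return_gen)], dim=0)
        return tokens, scores

    # e.g., mirroring the end of beam_search_and_return_on_first_stage:
    # tokens, scores = select_top_hypotheses(beam_hyp.beams, num_return_gen)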