Commit 96816d3d authored by peng xu

rename hyp and allow returning multiple samples

parent 4f579b55
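
This change threads two new keyword arguments, stop_token and num_return_gen, from beam_search_and_post_process through beam_search into beam_search_and_return_on_first_stage, and renames the local hypothesis container hyp to beam_hyp. As a minimal caller-side sketch (not part of the commit) of how the new parameters could be used; the import path, prompt text, and pre-initialized model are assumptions about the surrounding Megatron-LM setup:

# Hypothetical usage sketch, assuming Megatron-LM has been initialized and `model`
# is the usual GPT model wrapper passed to the text-generation API.
from megatron.text_generation import beam_search_and_post_process

result = beam_search_and_post_process(
    model,
    prompts=["Megatron-LM is"],   # placeholder prompt
    tokens_to_generate=64,
    beam_size=4,
    add_BOS=False,
    stop_token=50256,             # GPT-2 <|endoftext|> id, the default kept by this commit
    num_return_gen=2)             # new: return the 2 best hypotheses instead of only one
# Ranks that are not the first pipeline stage get None back (see the diff below).
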
@@ -144,7 +144,9 @@ def beam_search_and_post_process(model,
                                  prompts=None,
                                  tokens_to_generate=0,
                                  beam_size=0,
-                                 add_BOS=False):
+                                 add_BOS=False,
+                                 stop_token=50256,
+                                 num_return_gen=1):
     """Run beam search and post-process outputs, i.e., detokenize,
     move to cpu and convert to list."""
 
@@ -153,7 +155,9 @@ def beam_search_and_post_process(model,
                                  prompts=prompts,
                                  tokens_to_generate=tokens_to_generate,
                                  beam_size=beam_size,
-                                 add_BOS=add_BOS)
+                                 add_BOS=add_BOS,
+                                 stop_token=stop_token,
+                                 num_return_gen=num_return_gen)
     # Only post-process on first stage.
     if mpu.is_pipeline_first_stage():
         lengths = tokens.size(1)*torch.ones(beam_size, dtype=torch.int64, device=torch.cuda.current_device())
@@ -163,7 +167,7 @@ def beam_search_and_post_process(model,
     return None
 
 
-def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256):
+def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256, num_return_gen=1):
     # Make sure input params are avaialble to all ranks.
     values = [tokens_to_generate,
               beam_size,
@@ -176,4 +180,5 @@ def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=
     context_tokens_tensor, context_length_tensor = tokenize_prompts(
         prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
 
-    return beam_search_and_return_on_first_stage(model, context_tokens_tensor, context_length_tensor, beam_size, stop_token=stop_token)
+    return beam_search_and_return_on_first_stage(model, context_tokens_tensor, context_length_tensor,
+                                                 beam_size, stop_token=stop_token, num_return_gen=num_return_gen)
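
The hunks below are from the generation code itself. The object being renamed, beam_hyp, is an instance of the BeamHypotheses container named in the next hunk header. For orientation, here is a rough sketch of such a container, reconstructed only from how the diff uses it (.add(tokens, score, length), .is_done(best_score, length), .beams, .worst_score); the class actually in the repository may differ in its scoring and eviction details:

class BeamHypotheses(object):
    """Sketch only: a container for finished beams, kept consistent with the
    usage in the diff below; length penalty and eviction policy are assumptions."""

    def __init__(self, num_beams, length_penalty=1.0):
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.beams = []            # list of (score, tokens) pairs; the caller sorts by score
        self.worst_score = 1e9

    def add(self, hyp, sum_logprobs, length):
        # Length-normalized score; keep at most num_beams hypotheses.
        score = sum_logprobs / (length ** self.length_penalty)
        if len(self.beams) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self.beams) > self.num_beams:
                # Evict the current worst hypothesis.
                worst_idx = min(range(len(self.beams)), key=lambda i: self.beams[i][0])
                del self.beams[worst_idx]
                self.worst_score = min(s for s, _ in self.beams)
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs, cur_len):
        # Done once num_beams hypotheses are kept and no open beam can beat the worst one.
        if len(self.beams) < self.num_beams:
            return False
        cur_score = best_sum_logprobs / (cur_len ** self.length_penalty)
        return self.worst_score >= cur_score
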
@@ -328,7 +328,7 @@ class BeamHypotheses(object):
         ret = self.worst_score >= cur_score
         return ret
 
 
-def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token):
+def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token, num_return_gen=1):
     args = get_args()
     tokenizer = get_tokenizer()
@@ -345,7 +345,7 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
     # forward step.
     forward_step = ForwardStep(model, beam_size, final_sequence_length)
 
-    hyp = BeamHypotheses(beam_size)
+    beam_hyp = BeamHypotheses(beam_size)
     done = False
     if mpu.is_pipeline_last_stage():
         scores = torch.zeros(beam_size,
@@ -392,7 +392,7 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
                         is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
                         if is_beam_token_worse_than_top_num_beams:
                             continue
-                        hyp.add(
+                        beam_hyp.add(
                             tokens[beam_id].clone(),
                             beam_score,
                             context_length + 1 - prompt_length
@@ -404,7 +404,7 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
                     if len(next_beams) == beam_size:
                         break
 
-            if hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
+            if beam_hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
                 done = True
                 break
@@ -430,13 +430,14 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, sto
         # if cannot find stop token, add open beams to hyps
         if not done:
             for beam_id in range(beam_size):
-                hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)
+                beam_hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)
 
         # rank based on scores
-        sorted_hyps = sorted(hyp.beams, key=lambda x: x[0], reverse=True)
-        scores, tokens = sorted_hyps[0]
-        scores = scores.unsqueeze(0)
-        tokens = tokens.unsqueeze(0)
+        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0], reverse=True)
+        scores = [sorted_hyps[i][0] for i in range(num_return_gen)]
+        tokens = [sorted_hyps[i][1] for i in range(num_return_gen)]
+        scores = torch.stack(scores, dim=0)
+        tokens = torch.stack(tokens, dim=0)
 
         return tokens, scores
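
With the stacking above, the last stage now returns the top num_return_gen hypotheses rather than unsqueezing a single winner: tokens becomes a (num_return_gen, sequence_length) tensor of token ids and scores a (num_return_gen,) tensor of beam scores. A small consumption sketch (shapes inferred from the torch.stack calls above, not stated in the commit):

# Illustration only: iterate over the stacked beams returned by the function above.
for sample_tokens, sample_score in zip(tokens, scores):
    print(f"score={sample_score.item():.4f}", sample_tokens[:16].tolist())
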