Commit 4f579b55 authored by Peng Xu

fix beam search

parent de84b2af
@@ -163,7 +163,7 @@ def beam_search_and_post_process(model,
     return None
 
 
-def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False):
+def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=False, stop_token=50256):
     # Make sure input params are available to all ranks.
     values = [tokens_to_generate,
               beam_size,
@@ -176,4 +176,4 @@ def beam_search(model, prompts=None, tokens_to_generate=0, beam_size=0, add_BOS=
     context_tokens_tensor, context_length_tensor = tokenize_prompts(
         prompts=prompts, tokens_to_generate=tokens_to_generate, add_BOS=add_BOS)
 
-    return beam_search_and_return_on_first_stage(model, context_tokens_tensor, context_length_tensor, beam_size)
+    return beam_search_and_return_on_first_stage(model, context_tokens_tensor, context_length_tensor, beam_size, stop_token=stop_token)
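For context, a minimal usage sketch of the extended entry point. It assumes an already-initialized Megatron GPT model and distributed setup (not shown), and that beam_search is importable from the text-generation API module this hunk edits; the prompt and generation settings are illustrative only. 50256 is GPT-2's <|endoftext|> id, matching the new stop_token default.

# Hypothetical usage sketch; `model` and the Megatron runtime are assumed to be
# set up by the surrounding program and are not constructed here.
from megatron.text_generation.api import beam_search  # path assumed from this diff

beam_tokens, beam_scores = beam_search(
    model,
    prompts=["Deep learning is"],
    tokens_to_generate=64,
    beam_size=4,
    add_BOS=False,
    stop_token=50256,   # GPT-2 <|endoftext|>; pass a different id for other tokenizers
)
# On the first pipeline stage this yields the best beam's tokens and its score.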
@@ -42,7 +42,18 @@ class InferenceParams:
         self.batch_size_offset = 0
         self.key_value_memory_dict = {}
 
+    def swap_key_value_dict(self, batch_idx):
+        "swap between batches"
+        if len(self.key_value_memory_dict) == 0:
+            raise ValueError("should not swap when dict is empty")
+        for layer_number in self.key_value_memory_dict.keys():
+            inference_key_memory, inference_value_memory = self.key_value_memory_dict[layer_number]
+            assert len(batch_idx) == inference_key_memory.shape[1]  ## make sure batch size is the same
+            new_inference_key_memory = inference_key_memory[:, batch_idx]
+            new_inference_value_memory = inference_value_memory[:, batch_idx]
+            self.key_value_memory_dict[layer_number] = (
+                    new_inference_key_memory, new_inference_value_memory)
 
 class ForwardStep:
     """Forward step function with all the communications.
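For illustration, a self-contained sketch (plain PyTorch, invented shapes) of the per-layer reindexing that swap_key_value_dict performs: the assert above implies the cached keys and values keep the batch (beam) dimension at index 1, so advanced indexing with the winning beam ids reorders the cache to follow the surviving beams.

import torch

# Toy per-layer KV cache with the batch (beam) dimension at index 1,
# matching the check `inference_key_memory.shape[1] == len(batch_idx)` above.
max_seq_len, beam_size, num_heads, head_dim = 16, 4, 2, 8
key_cache = torch.randn(max_seq_len, beam_size, num_heads, head_dim)
value_cache = torch.randn(max_seq_len, beam_size, num_heads, head_dim)

# Suppose beam 2 produced the two best continuations and beam 0 the next two.
best_batches = torch.tensor([2, 2, 0, 0])

# Indexing dim 1 with the winning beam ids is the same gather that
# swap_key_value_dict applies to every layer's (key, value) pair.
key_cache = key_cache[:, best_batches]
value_cache = value_cache[:, best_batches]

assert key_cache.shape == (max_seq_len, beam_size, num_heads, head_dim)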
@@ -258,7 +258,7 @@ def generate_tokens_probs_and_return_on_first_stage(
                                                       tensor=done)
             if use_eod_token_for_early_termination and done:
                 break
 
     # ===================================================
     # Update the lengths based on the max generated length.
     # ===================================================
@@ -281,8 +281,54 @@ def generate_tokens_probs_and_return_on_first_stage(
     return tokens, generated_sequence_lengths, output_log_probs
 
 
-def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size):
+## from huggingface beam search
+class BeamHypotheses(object):
+    def __init__(self, num_beams, length_penalty=1.0, early_stopping=False):
+        """
+        Initialize n-best list of hypotheses.
+        """
+        self.length_penalty = length_penalty
+        self.early_stopping = early_stopping
+        self.num_beams = num_beams
+        self.beams = []
+        self.worst_score = 1e9
+
+    def __len__(self):
+        """
+        Number of hypotheses in the list.
+        """
+        return len(self.beams)
+
+    def add(self, hyp, sum_logprobs, length):
+        """
+        Add a new hypothesis to the list.
+        """
+        score = sum_logprobs / length ** self.length_penalty
+        if len(self) < self.num_beams or score > self.worst_score:
+            self.beams.append((score, hyp))
+            if len(self) > self.num_beams:
+                sorted_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
+                del self.beams[sorted_scores[0][1]]
+                self.worst_score = sorted_scores[1][0]
+            else:
+                self.worst_score = min(score, self.worst_score)
+
+    def is_done(self, best_sum_logprobs, cur_len):
+        """
+        If there are enough hypotheses and none of the hypotheses being generated
+        can become better than the worst one in the heap, then we are done with this sentence.
+        """
+        if len(self) < self.num_beams:
+            return False
+        elif self.early_stopping:
+            return True
+        else:
+            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
+            ret = self.worst_score >= cur_score
+            return ret
+
+
+def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size, stop_token):
     args = get_args()
     tokenizer = get_tokenizer()
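As a rough illustration of how this HuggingFace-derived container behaves, here is a toy driver with plain Python values standing in for token tensors; it assumes the BeamHypotheses class above is in scope. Each hypothesis is kept with a length-normalized score, sum_logprobs / length ** length_penalty, and only the num_beams best survive.

hyp = BeamHypotheses(num_beams=2, length_penalty=1.0, early_stopping=False)

hyp.add("beam A", sum_logprobs=-4.0, length=4)   # normalized score -1.0
hyp.add("beam B", sum_logprobs=-9.0, length=3)   # normalized score -3.0 (current worst)
hyp.add("beam C", sum_logprobs=-2.0, length=4)   # normalized score -0.5, evicts beam B

print(sorted(hyp.beams, reverse=True))
# [(-0.5, 'beam C'), (-1.0, 'beam A')]

# With the list full, is_done() reports True once even the best achievable
# normalized score (-8.0 / 5 = -1.6) cannot beat the worst kept score (-1.0).
print(hyp.is_done(best_sum_logprobs=-8.0, cur_len=5))   # True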
@@ -299,6 +345,8 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size):
     # forward step.
     forward_step = ForwardStep(model, beam_size, final_sequence_length)
 
+    hyp = BeamHypotheses(beam_size)
+    done = False
     if mpu.is_pipeline_last_stage():
         scores = torch.zeros(beam_size,
                              dtype=torch.float32,
@@ -331,13 +379,43 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size):
                 else:
                     sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)
 
-                best_batches = torch.div(indices[:beam_size], vocab_size, rounding_mode='floor')
-                best_words = indices[:beam_size] % vocab_size
+                best_beam_ids = torch.div(indices[: 2 * beam_size], vocab_size).trunc().long()
+                best_words = indices[:2 * beam_size] % vocab_size
+                best_scores = sorted_scores[: 2 * beam_size]
+
+                next_beams = []
+                for beam_token_rank, (token_id, beam_score, beam_id) in enumerate(
+                    zip(best_words, best_scores, best_beam_ids)
+                ):
+                    if token_id.item() == stop_token:
+                        # if beam_token does not belong to top num_beams tokens, it should not be added
+                        is_beam_token_worse_than_top_num_beams = beam_token_rank >= beam_size
+                        if is_beam_token_worse_than_top_num_beams:
+                            continue
+                        hyp.add(
+                            tokens[beam_id].clone(),
+                            beam_score,
+                            context_length + 1 - prompt_length
+                        )
+                    else:
+                        # add next predicted token since it is not eos_token
+                        next_beams.append((token_id, beam_score, beam_id))
+
+                    if len(next_beams) == beam_size:
+                        break
+
+                if hyp.is_done(best_scores.max().item(), context_length + 1 - prompt_length):
+                    done = True
+                    break
+
+                best_batches = tokens.new([item[2] for item in next_beams])
                 tokens = tokens[best_batches,:]
-                tokens[:, context_length] = best_words
-                scores = sorted_scores[:beam_size].unsqueeze(1)
+                tokens[:, context_length] = tokens.new([item[0] for item in next_beams])
+                scores = scores.new([item[1] for item in next_beams]).unsqueeze(1)
+
+                # set inference key values to make it consistent with best beam index
+                forward_step.inference_params.swap_key_value_dict(best_batches)
 
             # Update the tokens on the first stage so the next input to
             # the network is correct.
             copy_from_last_to_first_pipeline_stage(batch_size, torch.int64,
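A self-contained sketch (toy scores, plain PyTorch) of why the fix scans 2 * beam_size candidates instead of beam_size: any candidate ending in the stop token finishes a hypothesis rather than continuing a beam, so up to beam_size extra candidates may be needed to keep beam_size live beams. The rank check for low-scoring stop-token hits is omitted here for brevity.

import torch

beam_size, vocab_size, stop_token = 2, 5, 4

# Toy cumulative log-prob scores per (beam, vocab-token) entry.
new_scores = torch.tensor([[-0.1, -2.0, -3.0, -4.0, -0.2],   # beam 0 (stop token scores well)
                           [-5.0, -0.3, -6.0, -7.0, -8.0]])  # beam 1
sorted_scores, indices = torch.sort(new_scores.view(-1), descending=True)

# Consider 2 * beam_size candidates so stop-token hits do not starve the beams.
best_beam_ids = torch.div(indices[:2 * beam_size], vocab_size, rounding_mode='floor')
best_words = indices[:2 * beam_size] % vocab_size
best_scores = sorted_scores[:2 * beam_size]

finished, next_beams = [], []
for token_id, score, beam_id in zip(best_words, best_scores, best_beam_ids):
    if token_id.item() == stop_token:
        finished.append((beam_id.item(), score.item()))               # hypothesis ends here
    else:
        next_beams.append((token_id.item(), score.item(), beam_id.item()))
    if len(next_beams) == beam_size:
        break

print(finished)    # beam 0 finished on the stop token (score ~ -0.2)
print(next_beams)  # two live beams remain: token 0 on beam 0, token 1 on beam 1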
@@ -348,6 +426,18 @@ def beam_search_and_return_on_first_stage(model, tokens, lengths, beam_size):
             copy_from_last_to_first_pipeline_stage(scores.size(0), torch.float32,
                                                    scores[:,0])
 
+    # if cannot find stop token, add open beams to hyps
+    if not done:
+        for beam_id in range(beam_size):
+            hyp.add(tokens[beam_id].clone(), scores[beam_id], context_length + 1 - prompt_length)
+
+    # rank based on scores
+    sorted_hyps = sorted(hyp.beams, key=lambda x: x[0], reverse=True)
+    scores, tokens = sorted_hyps[0]
+    scores = scores.unsqueeze(0)
+    tokens = tokens.unsqueeze(0)
+
     return tokens, scores
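Finally, a toy sketch of the selection step above: any still-open beams are pushed into the hypothesis list, the hypotheses are ranked by their length-normalized score, and the single best sequence is returned with a leading batch dimension. The values below are invented for illustration.

import torch

# Stand-ins for hyp.beams at the end of generation: (score, token tensor) pairs.
beams = [(-1.2, torch.tensor([5, 17, 9, 50256])),
         (-0.7, torch.tensor([5, 17, 42, 50256]))]

# rank based on scores and keep the best hypothesis
sorted_hyps = sorted(beams, key=lambda x: x[0], reverse=True)
best_score, best_tokens = sorted_hyps[0]

# add a leading batch dimension, mirroring the unsqueeze(0) calls above
scores = torch.tensor([best_score])
tokens = best_tokens.unsqueeze(0)
print(scores, tokens.shape)   # tensor([-0.7000]) torch.Size([1, 4])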