Unverified Commit 3d3e605a authored by Sam Shleifer, committed by GitHub

[cleanup] generate_beam_search comments (#5115)

parent ca2d0f98
@@ -1219,9 +1219,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
                     if len(next_sent_beam) == num_beams:
                         break

-                # Check if were done so that we can save a pad step if all(done)
+                # Check if we are done so that we can save a pad step if all(done)
                 done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
-                    tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len=cur_len
+                    tf.reduce_max(next_scores[batch_idx]).numpy(), cur_len
                 )

                 # update next beam content
@@ -1509,7 +1509,7 @@ class BeamHypotheses(object):
         else:
             self.worst_score = min(score, self.worst_score)

-    def is_done(self, best_sum_logprobs, cur_len=None):
+    def is_done(self, best_sum_logprobs, cur_len):
         """
         If there are enough hypotheses and that none of the hypotheses being generated
         can become better than the worst one in the heap, then we are done with this sentence.
@@ -1520,8 +1520,6 @@ class BeamHypotheses(object):
         elif self.early_stopping:
             return True
         else:
-            if cur_len is None:
-                cur_len = self.max_length
             cur_score = best_sum_logprobs / cur_len ** self.length_penalty
             ret = self.worst_score >= cur_score
             return ret
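As an aside, not part of the commit: the comparison that is_done performs can be sketched numerically. The snippet below restates the length-penalty check with plain arguments; the names mirror the BeamHypotheses attributes above, but the values are hypothetical.

# Illustrative sketch of the is_done check, with hypothetical values.
def is_done_sketch(num_finished, num_beams, early_stopping, worst_score, best_sum_logprobs, cur_len, length_penalty):
    if num_finished < num_beams:
        return False  # not enough finished hypotheses yet
    if early_stopping:
        return True  # stop as soon as num_beams hypotheses are finished
    # score the best still-open beam would get if it ended at the current length
    cur_score = best_sum_logprobs / cur_len ** length_penalty
    return worst_score >= cur_score  # done if no open beam can beat the worst kept one

# Example: best_sum_logprobs=-6.0 at cur_len=4 with length_penalty=1.0 gives
# cur_score=-1.5, so the sentence is done only if the worst finished hypothesis
# already scores at least -1.5.
print(is_done_sketch(4, 4, False, -1.2, -6.0, 4, 1.0))  # True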
@@ -1462,7 +1462,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             # for each sentence
             for batch_idx in range(batch_size):

-                # if we are done with this sentence
+                # if we are done with this sentence, add a pad token
                 if done[batch_idx]:
                     assert (
                         len(generated_hyps[batch_idx]) >= num_beams
@@ -1473,7 +1473,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                     next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)  # pad the batch
                     continue

-                # next sentence beam content
+                # next sentence beam content, this will get added to next_batch_beam
                 next_sent_beam = []

                 # next tokens for this sentence
@@ -1485,7 +1485,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                     token_id = beam_token_id % vocab_size

                     effective_beam_id = batch_idx * num_beams + beam_id
-                    # add to generated hypotheses if end of sentence or last iteration
+                    # add to generated hypotheses if end of sentence
                     if (eos_token_id is not None) and (token_id.item() == eos_token_id):
                         # if beam_token does not belong to top num_beams tokens, it should not be added
                         is_beam_token_worse_than_top_num_beams = beam_token_rank >= num_beams
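As context, not part of the diff: the index arithmetic in this hunk flattens beams and vocabulary into one top-k dimension. The sketch below uses hypothetical sizes and assumes beam_id = beam_token_id // vocab_size, as computed in the surrounding code.

# Hypothetical numbers showing how a flattened beam-token index splits apart.
vocab_size = 50257   # assumed vocabulary size
num_beams = 4
batch_idx = 1

beam_token_id = 3 * vocab_size + 42                   # index into num_beams * vocab_size scores
beam_id = beam_token_id // vocab_size                 # 3: which beam this candidate extends
token_id = beam_token_id % vocab_size                 # 42: the actual vocabulary id
effective_beam_id = batch_idx * num_beams + beam_id   # 7: row in the flat (batch * beams) input_ids

assert (beam_id, token_id, effective_beam_id) == (3, 42, 7)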
@@ -1495,22 +1495,22 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                             input_ids[effective_beam_id].clone(), beam_token_score.item(),
                         )
                     else:
-                        # add next predicted token if it is not eos_token
+                        # add next predicted token since it is not eos_token
                         next_sent_beam.append((beam_token_score, token_id, effective_beam_id))

-                    # the beam for next step is full
+                    # once the beam for next step is full, don't add more tokens to it.
                     if len(next_sent_beam) == num_beams:
                         break

-                # Check if were done so that we can save a pad step if all(done)
+                # Check if we are done so that we can save a pad step if all(done)
                 done[batch_idx] = done[batch_idx] or generated_hyps[batch_idx].is_done(
-                    next_scores[batch_idx].max().item(), cur_len=cur_len
+                    next_scores[batch_idx].max().item(), cur_len
                 )

                 # update next beam content
                 assert len(next_sent_beam) == num_beams, "Beam should always be full"
                 next_batch_beam.extend(next_sent_beam)
-            assert len(next_batch_beam) == num_beams * (batch_idx + 1)
+            assert len(next_batch_beam) == num_beams * (batch_idx + 1), "We should have added num_beams each step"

             # stop when we are done with each sentence
             if all(done):
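A small illustration, not part of the commit, of the bookkeeping the rewritten comments describe: after the loop over sentences, next_batch_beam holds exactly num_beams (score, token_id, effective_beam_id) triples per sentence, with dummy pad triples for sentences that are already done. The values below are hypothetical.

# Toy example with num_beams = 2 and batch_size = 2.
num_beams, batch_size = 2, 2
pad_token_id = 0

next_batch_beam = []
# sentence 0 is already done: pad the batch with dummy entries
next_batch_beam.extend([(0, pad_token_id, 0)] * num_beams)
# sentence 1 is still open: keep its top num_beams candidates
next_batch_beam.extend([(-0.7, 42, 2), (-1.3, 17, 3)])

assert len(next_batch_beam) == num_beams * batch_size  # num_beams entries per sentence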
@@ -1537,7 +1537,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                     [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
                 )

-        # finalize all open beam hypotheses and end to generated hypotheses
+        # finalize all open beam hypotheses and add to generated hypotheses
         for batch_idx in range(batch_size):
             if done[batch_idx]:
                 continue
@@ -1576,7 +1576,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             sent_lengths[effective_batch_idx] = len(best_hyp)
             best.append(best_hyp)

-        # shorter batches are filled with pad_token
+        # shorter batches are padded
         if sent_lengths.min().item() != sent_lengths.max().item():
             assert pad_token_id is not None, "`Pad_token_id` has to be defined"
             sent_max_len = min(sent_lengths.max().item() + 1, max_length)
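For illustration, not part of the diff: the padding branch described by the rewritten comment fills hypotheses shorter than the longest one with pad_token_id up to a common length. This simplified sketch uses hypothetical tensors and omits the extra eos slot the real code reserves.

import torch

pad_token_id = 0
best = [torch.tensor([5, 8, 2]), torch.tensor([5, 9, 7, 3, 2])]   # two finished hypotheses
sent_lengths = torch.tensor([len(hypo) for hypo in best])

sent_max_len = sent_lengths.max().item()
decoded = torch.full((len(best), sent_max_len), pad_token_id, dtype=torch.long)
for i, hypo in enumerate(best):
    decoded[i, : sent_lengths[i]] = hypo   # shorter rows keep pad_token_id in the tail

# decoded:
# tensor([[5, 8, 2, 0, 0],
#         [5, 9, 7, 3, 2]])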
@@ -1731,7 +1731,7 @@ class BeamHypotheses(object):
         else:
             self.worst_score = min(score, self.worst_score)

-    def is_done(self, best_sum_logprobs, cur_len=None):
+    def is_done(self, best_sum_logprobs, cur_len):
         """
         If there are enough hypotheses and that none of the hypotheses being generated
         can become better than the worst one in the heap, then we are done with this sentence.
@@ -1742,8 +1742,6 @@ class BeamHypotheses(object):
         elif self.early_stopping:
             return True
         else:
-            if cur_len is None:
-                cur_len = self.max_length
             cur_score = best_sum_logprobs / cur_len ** self.length_penalty
             ret = self.worst_score >= cur_score
             return ret