Unverified commit b31905d1, authored by Xin Qiu, committed by GitHub

Fix remaining issues in beam score calculation (#27808)

* Fix issues in add and is_done for BeamHypotheses

* make newly added arguments optional for better compatibility

* Directly use cur_len as generated_len, add note for retrocompatibility

* update test expectation

* make cur_len represent the length of the entire sequence including the decoder prompt

* remove redundant if/else in testing
parent 3ac9945e
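
In short, this commit settles on one convention: `cur_len` counts the whole sequence including the decoder prompt, and length-penalty normalization uses only the tokens the model generated. A minimal sketch of that arithmetic (illustrative numbers, not taken from the diff):

# Illustrative numbers, not from the diff: an 8-token decoder prompt
# followed by 5 generated tokens.
decoder_prompt_len = 8
input_len = 13                                 # input_ids.shape[-1]

cur_len = input_len + 1                        # whole sequence, incl. prompt (new convention)
generated_len = cur_len - decoder_prompt_len   # tokens the model actually produced

sum_logprobs = -9.0
length_penalty = 1.0
score = sum_logprobs / generated_len**length_penalty
print(cur_len, generated_len, score)           # 14 6 -1.5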
@@ -224,8 +224,8 @@ class BeamSearchScorer(BeamScorer):
         group_index: Optional[int] = 0,
         decoder_prompt_len: Optional[int] = 0,
     ) -> Dict[str, torch.Tensor]:
-        # add up to the length which the next_scores is calculated on
-        cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
+        # add up to the length which the next_scores is calculated on (including decoder prompt)
+        cur_len = input_ids.shape[-1] + 1
         batch_size = len(self._beam_hyps) // self.num_beam_groups
         if not (batch_size == (input_ids.shape[0] // self.group_size)):
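
My reading of this hunk: `cur_len` previously counted generated tokens only, while the new convention counts the entire sequence, and later hunks subtract `decoder_prompt_len` back out wherever the generated length is needed. A small check with made-up lengths:

# Illustrative values, not from the commit.
input_len = 13          # input_ids.shape[-1]: 8 prompt tokens + 5 generated
decoder_prompt_len = 8

cur_len_before = input_len - decoder_prompt_len + 1  # old: generated tokens only
cur_len_after = input_len + 1                        # new: entire sequence incl. prompt

# The two conventions agree once the prompt is subtracted back out:
assert cur_len_after - decoder_prompt_len == cur_len_before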
@@ -279,15 +279,11 @@ class BeamSearchScorer(BeamScorer):
                     else:
                         beam_index = None

-                    # skip the corner case where the very first generated token is eos_token
-                    if decoder_prompt_len == input_ids.shape[-1]:
-                        continue
-
                     self._beam_hyps[batch_group_idx].add(
                         input_ids[batch_beam_idx].clone(),
                         next_score.item(),
                         beam_indices=beam_index,
-                        decoder_prompt_len=decoder_prompt_len,
+                        generated_len=cur_len - decoder_prompt_len,
                     )
                 else:
                     # add next predicted token since it is not eos_token
@@ -308,7 +304,7 @@ class BeamSearchScorer(BeamScorer):

             # Check if we are done so that we can save a pad step if all(done)
             self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
-                next_scores[batch_idx].max().item(), cur_len
+                next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
             )

         return UserDict(
@@ -348,7 +344,8 @@ class BeamSearchScorer(BeamScorer):
                 final_score = final_beam_scores[batch_beam_idx].item()
                 final_tokens = input_ids[batch_beam_idx]
                 beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
-                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len)
+                generated_len = final_tokens.shape[-1] - decoder_prompt_len
+                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)

         # select the best hypotheses
         sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
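
At finalization the generated length is derived from the finished hypothesis itself rather than from `cur_len`. A hedged sketch with a made-up token tensor:

import torch

decoder_prompt_len = 2
# Made-up token ids: 2 prompt tokens followed by 4 generated tokens.
final_tokens = torch.tensor([0, 50256, 11, 22, 33, 44])

generated_len = final_tokens.shape[-1] - decoder_prompt_len
assert generated_len == 4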
@@ -560,8 +557,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
                 indicating to which beam the next tokens shall be added.
         """

-        # add up to the length which the next_scores is calculated on
-        cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
+        # add up to the length which the next_scores is calculated on (including decoder prompt)
+        cur_len = input_ids.shape[-1] + 1
         batch_size = len(self._beam_hyps)
         if not (batch_size == (input_ids.shape[0] // self.group_size)):
             if self.num_beam_groups > 1:
@@ -617,16 +614,11 @@ class ConstrainedBeamSearchScorer(BeamScorer):
                     else:
                         beam_index = None

-                    # skip the corner case where the only constraint token is
-                    # eos_token and the very first generated token is eos_token
-                    if decoder_prompt_len == input_ids.shape[-1]:
-                        continue
-
                     beam_hyp.add(
                         input_ids[batch_beam_idx].clone(),
                         next_score.item(),
                         beam_indices=beam_index,
-                        decoder_prompt_len=decoder_prompt_len,
+                        generated_len=cur_len - decoder_prompt_len,
                     )
                 else:
                     # add next predicted token since it is not eos_token
@@ -660,7 +652,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):

             # Check if we are done so that we can save a pad step if all(done)
             self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
-                next_scores[batch_idx].max().item(), cur_len
+                next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
             )

         return UserDict(
@@ -846,9 +838,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
                 completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
                 if completes_constraint:
                     beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
-                    beam_hyp.add(
-                        final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len
-                    )
+                    generated_len = final_tokens.shape[-1] - decoder_prompt_len
+                    beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
                     ids_collect.append(beam_id)

             # due to overly complex constraints or other factors, sometimes we can't guarantee a successful
@@ -859,7 +850,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
                         batch_beam_idx = batch_idx * self.num_beams + beam_id
                         final_score = final_beam_scores[batch_beam_idx].item()
                         final_tokens = input_ids[batch_beam_idx]
-                        beam_hyp.add(final_tokens, final_score, decoder_prompt_len=decoder_prompt_len)
+                        generated_len = final_tokens.shape[-1] - decoder_prompt_len
+                        beam_hyp.add(final_tokens, final_score, generated_len=generated_len)

                         if len(ids_collect) >= self.num_beam_hyps_to_keep:
                             break
@@ -956,12 +948,17 @@ class BeamHypotheses:
         hyp: torch.LongTensor,
         sum_logprobs: float,
         beam_indices: Optional[torch.LongTensor] = None,
-        decoder_prompt_len: Optional[int] = 0,
+        generated_len: Optional[int] = None,
     ):
         """
         Add a new hypothesis to the list.
         """
-        score = sum_logprobs / ((hyp.shape[-1] - decoder_prompt_len) ** self.length_penalty)
+        if generated_len is not None:
+            score = sum_logprobs / (generated_len**self.length_penalty)
+        # This 'else' case exists for retrocompatibility
+        else:
+            score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
+
         if len(self) < self.num_beams or score > self.worst_score:
             self.beams.append((score, hyp, beam_indices))
             if len(self) > self.num_beams:
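
The hunk above is the heart of the change to `BeamHypotheses.add`. A self-contained, simplified sketch of the same scoring logic, including the retrocompatibility fallback (the helper name is mine, not from the library):

from typing import Optional

def hypothesis_score(
    sum_logprobs: float,
    hyp_len: int,
    length_penalty: float,
    generated_len: Optional[int] = None,
) -> float:
    """Length-penalized score of a finished hypothesis (simplified)."""
    if generated_len is not None:
        # New path: normalize by the number of generated tokens only.
        return sum_logprobs / (generated_len**length_penalty)
    # Legacy path kept for retrocompatibility: full hypothesis length.
    return sum_logprobs / (hyp_len**length_penalty)

# 10-token hypothesis whose first 4 tokens are the decoder prompt:
print(hypothesis_score(-12.0, hyp_len=10, length_penalty=1.0, generated_len=6))  # -2.0
print(hypothesis_score(-12.0, hyp_len=10, length_penalty=1.0))                   # -1.2 (legacy)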
@@ -971,7 +968,7 @@ class BeamHypotheses:
         else:
             self.worst_score = min(score, self.worst_score)

-    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
+    def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
         """
         If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
         one in the heap, then we are done with this sentence.
@@ -987,7 +984,7 @@ class BeamHypotheses:
         #  when `length_penalty` is positive. See the discussion below for more details.
         # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
         elif self.early_stopping is False:
-            highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
+            highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
             ret = self.worst_score >= highest_attainable_score
             return ret
         # `"never"`: compute the best possible score, depending on the signal of `length_penalty`
@@ -996,9 +993,13 @@ class BeamHypotheses:
             # abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
             # its max this way
             if self.length_penalty > 0.0:
-                highest_attainable_score = best_sum_logprobs / self.max_length**self.length_penalty
+                if self.max_length <= decoder_prompt_len:
+                    raise ValueError("max_length is not larger than decoder prompt length")
+                highest_attainable_score = (
+                    best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
+                )
             # the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
             else:
-                highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
+                highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
             ret = self.worst_score >= highest_attainable_score
             return ret
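
The substantive `is_done` fix is in the `early_stopping="never"` branch with a positive `length_penalty`: `max_length` counts the whole sequence, so the decoder prompt must be subtracted before computing the best attainable score. A numeric sketch of why (values are made up):

best_sum_logprobs = -10.0   # hypothetical best running score
max_length = 20             # counts the whole sequence, prompt included
decoder_prompt_len = 8
length_penalty = 1.0

old_bound = best_sum_logprobs / max_length**length_penalty                           # -0.5
new_bound = best_sum_logprobs / (max_length - decoder_prompt_len) ** length_penalty  # -0.833...

# The old bound was too optimistic (less negative), so `is_done` could keep
# returning False even when no beam could still beat the current worst score.
print(old_bound, new_bound)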
@@ -633,10 +633,6 @@ class GenerationIntegrationTestsMixin:
             "do_sample": False,
             "num_beams": 3,
         }
-        if is_pt:
-            expectation = 20
-        else:
-            # TODO (joao): fix me
-            expectation = 13
+        expectation = 13

         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
@@ -800,7 +800,7 @@ class ViT2GPT2ModelIntegrationTest(unittest.TestCase):
         preds, scores = generate_step(pixel_values)

-        EXPECTED_SCORES = np.array([-0.64145195])
+        EXPECTED_SCORES = np.array([-0.5956343])
         max_diff = np.amax(np.abs(scores - EXPECTED_SCORES))
         self.assertLessEqual(max_diff, 1e-4)