Unverified Commit b31905d1 authored by Xin Qiu's avatar Xin Qiu Committed by GitHub
Browse files

Fix remaining issues in beam score calculation (#27808)

* Fix issues in add and is_done for BeamHypotheses

* make newly added arguments optional for better compatibility

* Directly use cur_len as generated_len, add note for retrocompatibility

* update test expectation

* make cur_len represents the length of the entire sequence including the decoder prompt

* remove redundant if/else in testing
parent 3ac9945e
......@@ -224,8 +224,8 @@ class BeamSearchScorer(BeamScorer):
group_index: Optional[int] = 0,
decoder_prompt_len: Optional[int] = 0,
) -> Dict[str, torch.Tensor]:
# add up to the length which the next_scores is calculated on
cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
# add up to the length which the next_scores is calculated on (including decoder prompt)
cur_len = input_ids.shape[-1] + 1
batch_size = len(self._beam_hyps) // self.num_beam_groups
if not (batch_size == (input_ids.shape[0] // self.group_size)):
......@@ -279,15 +279,11 @@ class BeamSearchScorer(BeamScorer):
else:
beam_index = None
# skip the corner case where the very first generated token is eos_token
if decoder_prompt_len == input_ids.shape[-1]:
continue
self._beam_hyps[batch_group_idx].add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
decoder_prompt_len=decoder_prompt_len,
generated_len=cur_len - decoder_prompt_len,
)
else:
# add next predicted token since it is not eos_token
......@@ -308,7 +304,7 @@ class BeamSearchScorer(BeamScorer):
# Check if we are done so that we can save a pad step if all(done)
self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
next_scores[batch_idx].max().item(), cur_len
next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
)
return UserDict(
......@@ -348,7 +344,8 @@ class BeamSearchScorer(BeamScorer):
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len)
generated_len = final_tokens.shape[-1] - decoder_prompt_len
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
# select the best hypotheses
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
......@@ -560,8 +557,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
indicating to which beam the next tokens shall be added.
"""
# add up to the length which the next_scores is calculated on
cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
# add up to the length which the next_scores is calculated on (including decoder prompt)
cur_len = input_ids.shape[-1] + 1
batch_size = len(self._beam_hyps)
if not (batch_size == (input_ids.shape[0] // self.group_size)):
if self.num_beam_groups > 1:
......@@ -617,16 +614,11 @@ class ConstrainedBeamSearchScorer(BeamScorer):
else:
beam_index = None
# skip the corner case where the only constraint token is
# eos_token and the very first generated token is eos_token
if decoder_prompt_len == input_ids.shape[-1]:
continue
beam_hyp.add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
decoder_prompt_len=decoder_prompt_len,
generated_len=cur_len - decoder_prompt_len,
)
else:
# add next predicted token since it is not eos_token
......@@ -660,7 +652,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
# Check if we are done so that we can save a pad step if all(done)
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
next_scores[batch_idx].max().item(), cur_len
next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
)
return UserDict(
......@@ -846,9 +838,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
if completes_constraint:
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
beam_hyp.add(
final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len
)
generated_len = final_tokens.shape[-1] - decoder_prompt_len
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
ids_collect.append(beam_id)
# due to overly complex constraints or other factors, sometimes we can't gaurantee a successful
......@@ -859,7 +850,8 @@ class ConstrainedBeamSearchScorer(BeamScorer):
batch_beam_idx = batch_idx * self.num_beams + beam_id
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_hyp.add(final_tokens, final_score, decoder_prompt_len=decoder_prompt_len)
generated_len = final_tokens.shape[-1] - decoder_prompt_len
beam_hyp.add(final_tokens, final_score, generated_len=generated_len)
if len(ids_collect) >= self.num_beam_hyps_to_keep:
break
......@@ -956,12 +948,17 @@ class BeamHypotheses:
hyp: torch.LongTensor,
sum_logprobs: float,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
generated_len: Optional[int] = None,
):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / ((hyp.shape[-1] - decoder_prompt_len) ** self.length_penalty)
if generated_len is not None:
score = sum_logprobs / (generated_len**self.length_penalty)
# This 'else' case exists for retrocompatibility
else:
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp, beam_indices))
if len(self) > self.num_beams:
......@@ -971,7 +968,7 @@ class BeamHypotheses:
else:
self.worst_score = min(score, self.worst_score)
def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
"""
If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
one in the heap, then we are done with this sentence.
......@@ -987,7 +984,7 @@ class BeamHypotheses:
# when `length_penalty` is positive. See the discussion below for more details.
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
elif self.early_stopping is False:
highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
ret = self.worst_score >= highest_attainable_score
return ret
# `"never"`: compute the best possible score, depending on the signal of `length_penalty`
......@@ -996,9 +993,13 @@ class BeamHypotheses:
# abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
# its max this way
if self.length_penalty > 0.0:
highest_attainable_score = best_sum_logprobs / self.max_length**self.length_penalty
if self.max_length <= decoder_prompt_len:
raise ValueError("max_length is not larger than decoder prompt length")
highest_attainable_score = (
best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
)
# the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
else:
highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
ret = self.worst_score >= highest_attainable_score
return ret
......@@ -633,10 +633,6 @@ class GenerationIntegrationTestsMixin:
"do_sample": False,
"num_beams": 3,
}
if is_pt:
expectation = 20
else:
# TODO (joao): fix me
expectation = 13
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
......
......@@ -800,7 +800,7 @@ class ViT2GPT2ModelIntegrationTest(unittest.TestCase):
preds, scores = generate_step(pixel_values)
EXPECTED_SCORES = np.array([-0.64145195])
EXPECTED_SCORES = np.array([-0.5956343])
max_diff = np.amax(np.abs(scores - EXPECTED_SCORES))
self.assertLessEqual(max_diff, 1e-4)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment