Unverified Commit a0bd4647 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: correct beam search length on score calculation for multi batch generation (#23127)

parent ce31e3c8
......@@ -212,7 +212,7 @@ class BeamSearchScorer(BeamScorer):
eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor]:
cur_len = input_ids.shape[-1]
cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
batch_size = len(self._beam_hyps)
if not (batch_size == (input_ids.shape[0] // self.group_size)):
if self.num_beam_groups > 1:
......@@ -287,7 +287,6 @@ class BeamSearchScorer(BeamScorer):
)
# Check if we are done so that we can save a pad step if all(done)
cur_len += 1 # add up to the length which the next_scores is calculated on
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
next_scores[batch_idx].max().item(), cur_len
)
......@@ -532,7 +531,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
indicating to which beam the next tokens shall be added.
"""
cur_len = input_ids.shape[-1]
cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
batch_size = len(self._beam_hyps)
if not (batch_size == (input_ids.shape[0] // self.group_size)):
if self.num_beam_groups > 1:
......@@ -617,7 +616,6 @@ class ConstrainedBeamSearchScorer(BeamScorer):
)
# Check if we are done so that we can save a pad step if all(done)
cur_len += 1 # add up to the length which the next_scores is calculated on
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
next_scores[batch_idx].max().item(), cur_len
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment