Unverified commit a0bd4647, authored by Joao Gante and committed by GitHub
Browse files

Generate: correct beam search length on score calculation for multi batch generation (#23127)

parent ce31e3c8
...@@ -212,7 +212,7 @@ class BeamSearchScorer(BeamScorer): ...@@ -212,7 +212,7 @@ class BeamSearchScorer(BeamScorer):
eos_token_id: Optional[Union[int, List[int]]] = None, eos_token_id: Optional[Union[int, List[int]]] = None,
beam_indices: Optional[torch.LongTensor] = None, beam_indices: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor]: ) -> Tuple[torch.Tensor]:
cur_len = input_ids.shape[-1] cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
batch_size = len(self._beam_hyps) batch_size = len(self._beam_hyps)
if not (batch_size == (input_ids.shape[0] // self.group_size)): if not (batch_size == (input_ids.shape[0] // self.group_size)):
if self.num_beam_groups > 1: if self.num_beam_groups > 1:
...@@ -287,7 +287,6 @@ class BeamSearchScorer(BeamScorer): ...@@ -287,7 +287,6 @@ class BeamSearchScorer(BeamScorer):
) )
# Check if we are done so that we can save a pad step if all(done) # Check if we are done so that we can save a pad step if all(done)
cur_len += 1 # add up to the length which the next_scores is calculated on
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
next_scores[batch_idx].max().item(), cur_len next_scores[batch_idx].max().item(), cur_len
) )
...@@ -532,7 +531,7 @@ class ConstrainedBeamSearchScorer(BeamScorer): ...@@ -532,7 +531,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
indicating to which beam the next tokens shall be added. indicating to which beam the next tokens shall be added.
""" """
cur_len = input_ids.shape[-1] cur_len = input_ids.shape[-1] + 1 # add up to the length which the next_scores is calculated on
batch_size = len(self._beam_hyps) batch_size = len(self._beam_hyps)
if not (batch_size == (input_ids.shape[0] // self.group_size)): if not (batch_size == (input_ids.shape[0] // self.group_size)):
if self.num_beam_groups > 1: if self.num_beam_groups > 1:
...@@ -617,7 +616,6 @@ class ConstrainedBeamSearchScorer(BeamScorer): ...@@ -617,7 +616,6 @@ class ConstrainedBeamSearchScorer(BeamScorer):
) )
# Check if we are done so that we can save a pad step if all(done) # Check if we are done so that we can save a pad step if all(done)
cur_len += 1 # add up to the length which the next_scores is calculated on
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done( self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
next_scores[batch_idx].max().item(), cur_len next_scores[batch_idx].max().item(), cur_len
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment