Unverified Commit d4dbd7ca authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: get the correct beam index on eos token (#18851)

parent c6d3daba
...@@ -259,7 +259,7 @@ class BeamSearchScorer(BeamScorer): ...@@ -259,7 +259,7 @@ class BeamSearchScorer(BeamScorer):
continue continue
if beam_indices is not None: if beam_indices is not None:
beam_index = beam_indices[batch_beam_idx] beam_index = beam_indices[batch_beam_idx]
beam_index = beam_index + (next_index,) beam_index = beam_index + (batch_beam_idx,)
else: else:
beam_index = None beam_index = None
......
...@@ -172,7 +172,7 @@ class BeamSearchTester: ...@@ -172,7 +172,7 @@ class BeamSearchTester:
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist() input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
) )
self.parent.assertListEqual( self.parent.assertListEqual(
expected_beam_indices + [next_indices[batch_idx, 1].item()], expected_beam_indices + [correct_idx],
torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(), torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(),
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment