Unverified Commit d4dbd7ca authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Generate: get the correct beam index on eos token (#18851)

parent c6d3daba
......@@ -259,7 +259,7 @@ class BeamSearchScorer(BeamScorer):
continue
if beam_indices is not None:
beam_index = beam_indices[batch_beam_idx]
beam_index = beam_index + (next_index,)
beam_index = beam_index + (batch_beam_idx,)
else:
beam_index = None
......
......@@ -172,7 +172,7 @@ class BeamSearchTester:
input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
)
self.parent.assertListEqual(
expected_beam_indices + [next_indices[batch_idx, 1].item()],
expected_beam_indices + [correct_idx],
torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(),
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment