Unverified Commit c90e14fb authored by Yoni Gottesman, committed by GitHub

Fix beam search to sample at least 1 non eos token (#25103) (#25115)

parent 31f137c0
@@ -3068,9 +3068,10 @@ class GenerationMixin:
             vocab_size = next_token_scores.shape[-1]
             next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
 
-            # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
+            # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
+            n_eos_tokens = len(eos_token_id) if eos_token_id else 0
             next_token_scores, next_tokens = torch.topk(
-                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
+                next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
             )
 
             next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
@@ -3746,9 +3747,10 @@ class GenerationMixin:
                 # reshape for beam search
                 next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
 
-                # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
+                # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
+                n_eos_tokens = len(eos_token_id) if eos_token_id else 0
                 next_token_scores, next_tokens = torch.topk(
-                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
+                    next_token_scores, max(2, 1 + n_eos_tokens) * group_size, dim=1, largest=True, sorted=True
                 )
 
                 next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
@@ -4119,9 +4121,10 @@ class GenerationMixin:
             vocab_size = next_token_scores.shape[-1]
             next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
 
-            # Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
+            # Sample 1 + len(eos_token_id) next tokens for each beam so we have at least 1 non eos token per beam.
+            n_eos_tokens = len(eos_token_id) if eos_token_id else 0
             next_token_scores, next_tokens = torch.topk(
-                next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
+                next_token_scores, max(2, 1 + n_eos_tokens) * num_beams, dim=1, largest=True, sorted=True
             )
 
             next_indices = (next_tokens / vocab_size).long()
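
Why the change matters: when generate() is called with more than one eos_token_id, the old fixed width of 2 * num_beams top-k candidates can be filled entirely by EOS tokens, leaving no candidate to continue any beam. Widening the top-k to max(2, 1 + n_eos_tokens) * num_beams guarantees at least one non-EOS candidate per beam. Below is a minimal sketch of that edge case; the batch/beam/vocab sizes and EOS ids are toy values chosen for illustration, not taken from the patch.

import torch

# Toy setup (assumed values): 1 batch, 2 beams, vocab of 5, two EOS ids.
batch_size, num_beams, vocab_size = 1, 2, 5
eos_token_id = [3, 4]
n_eos_tokens = len(eos_token_id) if eos_token_id else 0

# Make both EOS ids the highest-scoring tokens on every beam.
scores = torch.full((batch_size, num_beams * vocab_size), -10.0)
for beam in range(num_beams):
    for eos in eos_token_id:
        scores[0, beam * vocab_size + eos] = 0.0

for k in (2 * num_beams, max(2, 1 + n_eos_tokens) * num_beams):
    _, next_tokens = torch.topk(scores, k, dim=1, largest=True, sorted=True)
    token_ids = next_tokens % vocab_size  # recover token id within the vocab
    # k = 4 (old width): every candidate is an EOS id, so no beam can continue.
    # k = 6 (new width): at least one non-EOS candidate per beam survives.
    print(k, sorted(token_ids[0].tolist()))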