Unverified Commit 9f4acd05 authored by Ekagra Ranjan, committed by GitHub

Generate: add missing comments after refactoring of generate() (#18981)

parent 59407bbe
@@ -2240,6 +2240,8 @@ class GenerationMixin:
     model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
 )
+# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
+# of the first beam are considered to avoid sampling the exact same tokens across all beams.
 beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
 beam_scores[:, 1:] = -1e9
 beam_scores = beam_scores.view((batch_size * num_beams,))
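The new comment documents a standard beam-search initialisation trick, and it is easy to check in isolation. Below is a minimal, self-contained sketch (not part of the commit; the sizes and the `next_token_logprobs` tensor are invented for illustration) showing that when every beam starts from the same prompt and therefore sees identical log-probs, the -1e9 offset makes every top-k candidate come from beam 0:

import torch

batch_size, num_beams, vocab_size = 1, 3, 10

beam_scores = torch.zeros((batch_size, num_beams))
beam_scores[:, 1:] = -1e9  # only the first beam starts "alive"
beam_scores = beam_scores.view((batch_size * num_beams,))

# At step one every beam sees the same hypothetical next-token log-probs.
next_token_logprobs = torch.log_softmax(torch.randn(1, vocab_size), dim=-1).repeat(num_beams, 1)
next_token_scores = next_token_logprobs + beam_scores[:, None]

# Flatten beams x vocab and pick the top 2 * num_beams candidates, as generate() does.
flat = next_token_scores.view(batch_size, num_beams * vocab_size)
top_scores, top_tokens = torch.topk(flat, 2 * num_beams, dim=1, largest=True, sorted=True)
print(top_tokens // vocab_size)  # tensor([[0, 0, 0, 0, 0, 0]]): all candidates come from beam 0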
@@ -2303,6 +2305,7 @@ class GenerationMixin:
 vocab_size = next_token_scores.shape[-1]
 next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
+# Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
 next_token_scores, next_tokens = torch.topk(
     next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
 )
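The `2 * num_beams` in the comment above is worth spelling out: a candidate that turns out to be EOS finishes a hypothesis instead of continuing a beam, so spare candidates are kept to guarantee that `num_beams` live continuations remain. A hedged sketch with invented scores:

import torch

num_beams, vocab_size, eos_token_id = 2, 6, 0

# Flattened (num_beams * vocab_size) scores in which EOS (token 0) scores best in both beams.
next_token_scores = torch.tensor([[-0.1, -0.5, -0.7, -9.0, -9.0, -9.0,
                                   -0.2, -0.9, -9.0, -9.0, -9.0, -9.0]])

scores, flat_tokens = torch.topk(next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
token_ids = flat_tokens % vocab_size  # recover token ids within the vocab
print(token_ids)                      # tensor([[0, 0, 1, 2]]): two EOS hits among the top 4
live = token_ids[token_ids != eos_token_id]
print(live[:num_beams])               # tensor([1, 2]): still num_beams non-EOS continuations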
@@ -2873,9 +2876,9 @@ class GenerationMixin:
     model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
 )
-beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
-# initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in
+# initialise score of first beam of each group with 0 and the rest with -1e9. This ensures that the beams in
 # the same group don't produce same tokens everytime.
+beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
 beam_scores[:, ::num_sub_beams] = 0
 beam_scores = beam_scores.view((batch_size * num_beams,))
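The strided assignment `beam_scores[:, ::num_sub_beams]` is what distinguishes this group (diverse) beam search initialisation from the plain one: the first sub-beam of each group, not just beam 0, starts at score 0. A small sketch with invented sizes:

import torch

batch_size, num_beams, num_beam_groups = 1, 6, 3
num_sub_beams = num_beams // num_beam_groups  # beams per group, here 2

beam_scores = torch.full((batch_size, num_beams), -1e9)
beam_scores[:, ::num_sub_beams] = 0  # indices 0, 2, 4: the first beam of each group
print(beam_scores)  # [[0, -1e9, 0, -1e9, 0, -1e9]] (up to print formatting)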
@@ -2951,6 +2954,7 @@ class GenerationMixin:
 # reshape for beam search
 next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
+# Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
 next_token_scores, next_tokens = torch.topk(
     next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
 )
@@ -3235,6 +3239,8 @@ class GenerationMixin:
     f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
 )
+# initialise score of first beam with 0 and the rest with -1e9. This makes sure that only tokens
+# of the first beam are considered to avoid sampling the exact same tokens across all beams.
 beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
 beam_scores[:, 1:] = -1e9
 beam_scores = beam_scores.view((batch_size * num_beams,))
@@ -3301,6 +3307,7 @@ class GenerationMixin:
 vocab_size = next_token_scores.shape[-1]
 next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
+# Sample 2 next tokens for each beam (so we have some spare tokens and match output of beam search)
 next_token_scores, next_tokens = torch.topk(
     next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
 )
...