Unverified commit 8d6acc6c authored by Patrick von Platen, committed by GitHub

[Beam Search] Correct returned beam scores (#14654)

* better

* save intermediate

* finish code

* up

* docs

* Apply suggestions from code review

* up

* add compute transition beam scores function to model and make sure scores are correct with eos

* apply Nico's comments

* Apply suggestions from code review

* another fix
parent e239fc3b
@@ -208,10 +208,13 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Final beam scores of the generated `sequences`.
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam transition scores for each vocabulary token at each generation step, i.e. the log probability of each
token conditioned on the previously generated tokens of its beam.
`(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
`(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of the generated token ids at each generation step. `(batch_size*num_return_sequences)`-shaped
tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
@@ -223,6 +226,7 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
sequences: torch.LongTensor = None
sequences_scores: Optional[torch.FloatTensor] = None
scores: Optional[Tuple[torch.FloatTensor]] = None
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
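
For orientation, a minimal sketch of how the new `beam_indices` field lines up with `scores` (toy sizes, illustrative only, not part of this diff):

import torch

# toy sizes: 4 beam rows per step, 3 generation steps, vocab of 5 (all hypothetical)
num_rows, gen_steps, vocab_size = 4, 3, 5

# scores: one tensor per generation step, each of shape (num_rows, vocab_size)
scores = tuple(torch.log_softmax(torch.randn(num_rows, vocab_size), -1) for _ in range(gen_steps))

# beam_indices: per returned sequence, the beam row its token was taken from at each step
beam_indices = ((0, 1, 1), (2, 2, 3))

# the transition score of the token chosen at step t for sequence s is then
# scores[t][beam_indices[s][t], token_id]
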
@@ -241,10 +245,13 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Final beam scores of the generated `sequences`.
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam transition scores for each vocabulary token at each generation step, i.e. the log probability of each
token conditioned on the previously generated tokens of its beam.
`(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
config.vocab_size)`.
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of the generated token ids at each generation step. `(batch_size*num_return_sequences)`-shaped
tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
@@ -267,6 +274,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
sequences: torch.LongTensor = None
sequences_scores: Optional[torch.FloatTensor] = None
scores: Optional[Tuple[torch.FloatTensor]] = None
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -286,10 +294,13 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Final beam scores of the generated `sequences`.
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam transition scores for each vocabulary token at each generation step, i.e. the log probability of each
token conditioned on the previously generated tokens of its beam.
`(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
`(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of the generated token ids at each generation step. `(batch_size*num_return_sequences)`-shaped
tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
@@ -301,6 +312,7 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
sequences: torch.LongTensor = None
sequences_scores: Optional[torch.FloatTensor] = None
scores: Optional[Tuple[torch.FloatTensor]] = None
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -319,10 +331,13 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
sequences_scores (`torch.FloatTensor` of shape `(batch_size*num_return_sequences)`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Final beam scores of the generated `sequences`.
scores (`tuple(torch.FloatTensor)` *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam transition scores for each vocabulary token at each generation step, i.e. the log probability of each
token conditioned on the previously generated tokens of its beam.
`(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
config.vocab_size)`.
beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of the generated token ids at each generation step. `(batch_size*num_return_sequences)`-shaped
tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer of the encoder) of shape `(batch_size, num_heads,
sequence_length, sequence_length)`.
@@ -343,6 +358,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
sequences: torch.LongTensor = None
sequences_scores: Optional[torch.FloatTensor] = None
scores: Optional[Tuple[torch.FloatTensor]] = None
beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -743,6 +759,45 @@ class GenerationMixin:
default_list.extend(custom_list)
return default_list
def compute_transition_beam_scores(
self,
sequences: torch.Tensor,
scores: Tuple[torch.Tensor],
beam_indices: torch.Tensor,
eos_token_id: int = None,
):
"""compute the transition probabilities of sequences given generation
scores and beam indices"""
# reshape scores as [vocab_size * batch_size, # generation steps]
# with batch_size being 2 * vocab_size and # generation steps being
# seq_len - input_length
scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)
# start of generated tokens
cut_idx = sequences.shape[-1] - scores.shape[-1]
# offset each token id by its beam's row in the flattened score matrix
beam_sequence_indices = torch.tensor(beam_indices, device=sequences.device) * self.config.vocab_size
# compute flat indices into the reshaped scores
indices = sequences[:, cut_idx:] + beam_sequence_indices
# gather the transition score of each generated token
transition_scores = scores.gather(0, indices)
# if an EOS token was generated before position `sequences.shape[-1]`, zero out all
# transition scores after its first occurrence
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
if eos_token_id is not None:
is_eos_token_id = sequences[:, cut_idx:] == eos_token_id
# make sure first eos token still contributes to transition probs
is_eos_token_id[:, -1] = False
is_eos_token_id = is_eos_token_id.roll(1, -1)
# all indices after eos should be masked
zero_transition_prob_mask = is_eos_token_id.cumsum(-1).bool()
# zero out padded probs
transition_scores.masked_fill_(zero_transition_prob_mask, 0.0)
return transition_scores
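
A self-contained toy run of the reshape-and-gather arithmetic above, including the EOS masking trick (all sizes and values below are made up; no model involved):

import torch

vocab_size, num_seqs, gen_steps, input_length = 5, 2, 3, 1

# per-step scores, as beam search stores them: one (num_seqs, vocab_size) tensor per step
step_scores = [torch.log_softmax(torch.randn(num_seqs, vocab_size), -1) for _ in range(gen_steps)]

# stack and flatten to (num_seqs * vocab_size, gen_steps), as in the method above
flat = torch.stack(step_scores).reshape(gen_steps, -1).transpose(0, 1)

# hypothetical sequences (first column is the prompt) and beam ancestry
sequences = torch.tensor([[0, 3, 1, 4], [0, 2, 2, 0]])
beam_indices = torch.tensor([[0, 0, 1], [1, 1, 0]])

# row offset selects the beam, column offset selects the token
indices = sequences[:, input_length:] + beam_indices * vocab_size
transition_scores = flat.gather(0, indices)

# same result computed entry by entry
for s in range(num_seqs):
    for t in range(gen_steps):
        assert torch.isclose(
            transition_scores[s, t], step_scores[t][beam_indices[s, t], sequences[s, t + 1]]
        )

# the EOS trick: roll the EOS hits right by one so the EOS step itself still counts,
# then cumsum marks every position after the first occurrence
is_eos = torch.tensor([[False, True, False], [False, False, False]])
is_eos[:, -1] = False
mask = is_eos.roll(1, -1).cumsum(-1).bool()  # [[False, False, True], [False, False, False]]
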
@torch.no_grad()
def generate(
self,
@@ -1871,8 +1926,21 @@ class GenerationMixin:
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape
if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
beam_indices = (
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
)
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
@@ -1884,16 +1952,6 @@ class GenerationMixin:
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape
if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
)
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores[:, 1:] = -1e9
beam_scores = beam_scores.view((batch_size * num_beams,))
@@ -1932,13 +1990,13 @@ class GenerationMixin:
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (next_token_scores_processed,)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
@@ -1973,6 +2031,7 @@ class GenerationMixin:
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
)
beam_scores = beam_outputs["next_beam_scores"]
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
@@ -1985,6 +2044,9 @@ class GenerationMixin:
if model_kwargs["past"] is not None:
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
# increase cur_len
cur_len = cur_len + 1
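
The `beam_indices` update in this hunk grows each history by one entry per step, reordered through `beam_idx` so that row i always carries the full ancestry of the beam it now continues. A standalone sketch with made-up indices:

# hypothetical reordering over two steps for 4 beam rows
beam_indices = tuple(() for _ in range(4))

beam_idx = [0, 0, 2, 3]  # step 1: new row i came from old row beam_idx[i]
beam_indices = tuple(beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(4))
# -> ((0,), (0,), (2,), (3,))

beam_idx = [1, 0, 3, 2]  # step 2: rows are shuffled again, histories follow them
beam_indices = tuple(beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(4))
# -> ((0, 1), (0, 0), (3, 3), (2, 2))
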
@@ -2007,11 +2069,20 @@ class GenerationMixin:
if return_dict_in_generate:
if not output_scores:
sequence_outputs["sequence_scores"] = None
else:
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
# return only as many indices as sequences
beam_indices = tuple(
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
)
beam_indices = sum(beam_indices, ())
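# illustration (hypothetical values): with batch_size=2, num_beams=3, num_return_sequences=2 and
# beam_indices = ((0, 1), (1, 1), (2, 0), (3, 4), (4, 4), (5, 3)),
# the slice keeps the first two histories of each batch element,
# -> (((0, 1), (1, 1)), ((3, 4), (4, 4)))
# and sum(..., ()) flattens the per-batch groups back into one tuple,
# -> ((0, 1), (1, 1), (3, 4), (4, 4))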
if self.config.is_encoder_decoder:
return BeamSearchEncoderDecoderOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
beam_indices=beam_indices,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
@@ -2023,6 +2094,7 @@ class GenerationMixin:
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
beam_indices=beam_indices,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
@@ -2175,8 +2247,16 @@ class GenerationMixin:
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
beam_indices = (
tuple(() for _ in range(batch_beam_size)) if (return_dict_in_generate and output_scores) else None
)
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
@@ -2188,11 +2268,6 @@ class GenerationMixin:
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
batch_beam_size, cur_len = input_ids.shape
beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
beam_scores = beam_scores.view((batch_size * num_beams,))
@@ -2231,14 +2306,14 @@ class GenerationMixin:
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
next_token_scores = logits_warper(input_ids, next_token_scores)
# Store scores, attentions and hidden_states when required
if return_dict_in_generate:
if output_scores:
scores += (logits_warper(input_ids, next_token_scores_processed),)
if output_attentions:
decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
@@ -2289,6 +2364,9 @@ class GenerationMixin:
if model_kwargs["past"] is not None:
model_kwargs["past"] = self._reorder_cache(model_kwargs["past"], beam_idx)
if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
# increase cur_len
cur_len = cur_len + 1
@@ -2311,11 +2389,20 @@ class GenerationMixin:
if return_dict_in_generate:
if not output_scores:
sequence_outputs["sequence_scores"] = None
else:
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
# return only as many indices as sequences
beam_indices = tuple(
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
)
beam_indices = sum(beam_indices, ())
if self.config.is_encoder_decoder:
return BeamSampleEncoderDecoderOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
beam_indices=beam_indices,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
@@ -2327,6 +2414,7 @@ class GenerationMixin:
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
beam_indices=beam_indices,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
@@ -2472,19 +2560,6 @@ class GenerationMixin:
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
batch_size = len(beam_scorer._beam_hyps)
num_beams = beam_scorer.num_beams
num_beam_groups = beam_scorer.num_beam_groups
@@ -2493,11 +2568,29 @@ class GenerationMixin:
batch_beam_size, cur_len = input_ids.shape
if return_dict_in_generate and output_scores:
beam_indices = [tuple(() for _ in range(num_sub_beams * batch_size)) for _ in range(num_beam_groups)]
else:
beam_indices = None
if num_beams * batch_size != batch_beam_size:
raise ValueError(
f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
)
# init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states
if return_dict_in_generate and self.config.is_encoder_decoder:
encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
encoder_hidden_states = (
model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
)
beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
# initialise the score of the first beam of each group with 0 and the rest with -1e9. This ensures that the
# beams in the same group don't produce the same tokens every time.
@@ -2564,15 +2657,14 @@ class GenerationMixin:
) # (batch_size * group_size, vocab_size)
vocab_size = next_token_scores.shape[-1]
next_token_scores_processed = logits_processor(
group_input_ids, next_token_scores, current_tokens=current_tokens, beam_group_idx=beam_group_idx
)
next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
next_token_scores = next_token_scores.expand_as(next_token_scores_processed)
if output_scores:
processed_score[batch_group_indices] = next_token_scores_processed
# reshape for beam search
next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
@@ -2597,6 +2689,11 @@ class GenerationMixin:
beam_next_tokens = beam_outputs["next_beam_tokens"]
beam_idx = beam_outputs["next_beam_indices"]
if return_dict_in_generate and output_scores:
beam_indices[beam_group_idx] = tuple(
beam_indices[beam_group_idx][beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices[0]))
)
input_ids[batch_group_indices] = group_input_ids[beam_idx]
group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
current_tokens[batch_group_indices] = group_input_ids[:, -1]
@@ -2655,11 +2752,21 @@ class GenerationMixin:
if return_dict_in_generate:
if not output_scores:
sequence_outputs["sequence_scores"] = None
else:
beam_indices = sum(beam_indices, ())
num_return_sequences = beam_scorer.num_beam_hyps_to_keep
# return only as many indices as sequences
beam_indices = tuple(
(beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
)
beam_indices = sum(beam_indices, ())
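# illustration (hypothetical values): beam_indices starts as one tuple of histories per beam
# group, e.g. [((0, 0), (1, 1)), ((2, 2), (3, 3))]; the first sum(..., ()) above concatenates
# the groups into ((0, 0), (1, 1), (2, 2), (3, 3)), and the slice plus second flatten then
# keep num_return_sequences histories per batch element, as in `beam_search`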
if self.config.is_encoder_decoder:
return BeamSearchEncoderDecoderOutput(
sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"],
scores=scores,
beam_indices=beam_indices,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
...
@@ -1903,3 +1903,147 @@ class GenerationIntegrationTests(unittest.TestCase):
output_sequences_with_mask = output_sequences_with_mask.cpu()
self.assertListEqual(output_sequences_no_mask.tolist(), output_sequences_with_mask.tolist())
def test_transition_scores_beam_search_encoder_decoder(self):
articles = [
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
"Michael Phelps is arguably the most decorated Olympian of all time.",
]
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained(
"hf-internal-testing/tiny-random-bart",
max_length=10,
num_beams=4,
num_return_sequences=2,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
length_penalty=0.0,
)
model = model.to(torch_device)
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_beam_scores(
outputs.sequences, outputs.scores, outputs.beam_indices
)
transition_scores_sum = transition_scores.sum(-1)
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
def test_transition_scores_beam_search_encoder_decoder_with_eos(self):
articles = [
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
"Michael Phelps is arguably the most decorated Olympian of all time.",
]
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained(
"hf-internal-testing/tiny-random-bart",
max_length=10,
num_beams=4,
num_return_sequences=2,
return_dict_in_generate=True,
output_scores=True,
length_penalty=0.0,
)
model = model.to(torch_device)
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_beam_scores(
outputs.sequences, outputs.scores, outputs.beam_indices
)
transition_scores_sum = transition_scores.sum(-1)
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
def test_transition_scores_beam_search_decoder_only(self):
articles = [
"Justin Timberlake",
"Michael Phelps",
]
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained(
"hf-internal-testing/tiny-random-gpt2",
max_length=10,
num_beams=4,
num_return_sequences=2,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
length_penalty=0.0,
)
model = model.to(torch_device)
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_beam_scores(
outputs.sequences, outputs.scores, outputs.beam_indices
)
transition_scores_sum = transition_scores.sum(-1)
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
def test_transition_scores_beam_sample_encoder_decoder(self):
articles = [
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
"Michael Phelps is arguably the most decorated Olympian of all time.",
]
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained(
"hf-internal-testing/tiny-random-bart",
do_sample=True,
max_length=10,
num_beams=4,
num_return_sequences=2,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
length_penalty=0.0,
)
model = model.to(torch_device)
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_beam_scores(
outputs.sequences, outputs.scores, outputs.beam_indices
)
transition_scores_sum = transition_scores.sum(-1)
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
def test_transition_scores_group_beam_search_encoder_decoder(self):
articles = [
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
"Michael Phelps is arguably the most decorated Olympian of all time.",
]
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained(
"hf-internal-testing/tiny-random-bart",
max_length=10,
num_beams=2,
num_beam_groups=2,
num_return_sequences=2,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
length_penalty=0.0,
)
model = model.to(torch_device)
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_beam_scores(
outputs.sequences, outputs.scores, outputs.beam_indices
)
transition_scores_sum = transition_scores.sum(-1)
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))