Unverified commit 518bd02c, authored by Patrick von Platen, committed by GitHub

[Generation] Fix Transition probs (#17311)

* [Draft] fix transition probs

* up

* up

* up

* make it work

* fix

* finish

* update
parent e8714c03
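
In short, this commit turns `beam_indices` from nested tuples into a single `torch.LongTensor` output of beam search generation (padded with `-1` for beams that finish early) and fixes `compute_transition_beam_scores` to use it. A minimal usage sketch, mirroring the new `test_transition_scores_early_stopping` added below; the model name and generation settings are illustrative, not prescribed by the patch:

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer(
    "question: How are you? \n context: I had a long day, ", return_tensors="pt"
)

result = model.generate(
    inputs.input_ids,
    max_length=10,
    num_beams=4,
    num_return_sequences=3,
    do_sample=False,
    length_penalty=0.0,
    return_dict_in_generate=True,
    output_scores=True,
)

# `beam_indices` is now a LongTensor with one row per returned sequence
transition_scores = model.compute_transition_beam_scores(
    sequences=result.sequences, scores=result.scores, beam_indices=result.beam_indices
)

# with length_penalty=0.0, the per-step scores sum to the sequence scores,
# which is exactly what the new integration test asserts
assert torch.allclose(transition_scores.sum(dim=1), result.sequences_scores, atol=1e-3)
```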
@@ -212,6 +212,7 @@ class BeamSearchScorer(BeamScorer):
         next_indices: torch.LongTensor,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
+        beam_indices: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.Tensor]:
         cur_len = input_ids.shape[-1]
         batch_size = len(self._beam_hyps)
@@ -256,9 +257,16 @@ class BeamSearchScorer(BeamScorer):
                     is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
                     if is_beam_token_worse_than_top_num_beams:
                         continue
+                    if beam_indices is not None:
+                        beam_index = beam_indices[batch_beam_idx]
+                        beam_index = beam_index + (next_index,)
+                    else:
+                        beam_index = None
+
                     beam_hyp.add(
                         input_ids[batch_beam_idx].clone(),
                         next_score.item(),
+                        beam_indices=beam_index,
                     )
                 else:
                     # add next predicted token since it is not eos_token
@@ -299,6 +307,7 @@ class BeamSearchScorer(BeamScorer):
         max_length: int,
         pad_token_id: Optional[int] = None,
         eos_token_id: Optional[int] = None,
+        beam_indices: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.LongTensor]:
         batch_size = len(self._beam_hyps)
@@ -313,11 +322,13 @@ class BeamSearchScorer(BeamScorer):
                 batch_beam_idx = batch_idx * self.num_beams + beam_id
                 final_score = final_beam_scores[batch_beam_idx].item()
                 final_tokens = input_ids[batch_beam_idx]
-                beam_hyp.add(final_tokens, final_score)
+                beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
+                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index)

         # select the best hypotheses
         sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
         best = []
+        best_indices = []
         best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

         # retrieve best hypotheses
@@ -327,23 +338,42 @@ class BeamSearchScorer(BeamScorer):
                 best_hyp_tuple = sorted_hyps.pop()
                 best_score = best_hyp_tuple[0]
                 best_hyp = best_hyp_tuple[1]
+                best_index = best_hyp_tuple[2]
                 sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)

-                # append to lists
+                # append hyp to lists
                 best.append(best_hyp)

+                # append indices to list
+                best_indices.append(best_index)
+
                 best_scores[i * self.num_beam_hyps_to_keep + j] = best_score

         # prepare for adding eos
         sent_lengths_max = sent_lengths.max().item() + 1
         sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
         decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
+
+        if len(best_indices) > 0 and best_indices[0] is not None:
+            indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
+        else:
+            indices = None

         # shorter batches are padded if needed
         if sent_lengths.min().item() != sent_lengths.max().item():
             assert pad_token_id is not None, "`pad_token_id` has to be defined"
             decoded.fill_(pad_token_id)

+        if indices is not None:
+            indices.fill_(-1)
+
         # fill with hypotheses and eos_token_id if the latter fits in
-        for i, hypo in enumerate(best):
+        for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
             decoded[i, : sent_lengths[i]] = hypo
+
+            if indices is not None:
+                indices[i, : len(best_idx)] = torch.tensor(best_idx)
+
             if sent_lengths[i] < sent_max_len:
                 decoded[i, sent_lengths[i]] = eos_token_id
@@ -351,6 +381,7 @@ class BeamSearchScorer(BeamScorer):
             {
                 "sequences": decoded,
                 "sequence_scores": best_scores,
+                "beam_indices": indices,
             }
         )
@@ -789,6 +820,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
         # prepare for adding eos
         sent_lengths_max = sent_lengths.max().item() + 1
         sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
+
         decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)

         # shorter batches are padded if needed
@@ -801,6 +833,7 @@ class ConstrainedBeamSearchScorer(BeamScorer):
             decoded[i, : sent_lengths[i]] = hypo
+
             if sent_lengths[i] < sent_max_len:
                 decoded[i, sent_lengths[i]] = eos_token_id

         return UserDict(
             {
                 "sequences": decoded,
@@ -826,15 +859,15 @@ class BeamHypotheses:
         """
         return len(self.beams)

-    def add(self, hyp: torch.LongTensor, sum_logprobs: float):
+    def add(self, hyp: torch.LongTensor, sum_logprobs: float, beam_indices: Optional[torch.LongTensor] = None):
         """
         Add a new hypothesis to the list.
         """
         score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
         if len(self) < self.num_beams or score > self.worst_score:
-            self.beams.append((score, hyp))
+            self.beams.append((score, hyp, beam_indices))
             if len(self) > self.num_beams:
-                sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
+                sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
                 del self.beams[sorted_next_scores[0][1]]
                 self.worst_score = sorted_next_scores[1][0]
             else:
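
Note on the `BeamHypotheses` change above: each stored hypothesis grows from a `(score, tokens)` pair to a `(score, tokens, beam_indices)` triple, which is why the eviction sort now destructures `(s, _, _)` and why the updated tests further down read `beams[0][1]` for tokens and `beams[0][2]` for indices. A toy sketch of that bookkeeping (illustrative only, not the library class):

```python
from typing import List, Optional, Tuple

# beams holds (score, tokens, beam_indices) triples, mirroring the new layout
beams: List[Tuple[float, list, Optional[tuple]]] = []
num_beams = 2

def add(score: float, tokens: list, beam_indices: Optional[tuple]) -> None:
    beams.append((score, tokens, beam_indices))
    if len(beams) > num_beams:
        # sort by score alone; the triple is destructured as (s, _, _)
        sorted_scores = sorted((s, idx) for idx, (s, _, _) in enumerate(beams))
        del beams[sorted_scores[0][1]]  # evict the worst-scoring hypothesis

add(-1.0, [5, 7], (0, 1))
add(-0.5, [5, 9], (0, 0))
add(-0.2, [5, 3], (1, 1))  # pushes the -1.0 hypothesis out
assert [entry[0] for entry in beams] == [-0.5, -0.2]
```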
...
@@ -217,8 +217,8 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
            `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
            `(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
-           Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
-           tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
+           Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
+           `(batch_size*num_return_sequences, input_ids.shape[-1])`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
@@ -230,7 +230,7 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
-   beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
+   beam_indices: Optional[torch.LongTensor] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -254,8 +254,8 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
            `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
            config.vocab_size)`).
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
-           Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
-           tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
+           Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
+           `(batch_size*num_return_sequences, max_length-1)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
@@ -278,7 +278,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
-   beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
+   beam_indices: Optional[torch.LongTensor] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -303,8 +303,8 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
            `(max_length-input_ids.shape[-1],)`-shaped tuple of `torch.FloatTensor` with each tensor of shape
            `(batch_size*num_beams*num_return_sequences, config.vocab_size)`).
        beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
-           Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
-           tuple of `(max_length-input_ids.shape[-1],)`-shaped tuples of scalar `torch.LongTensor` tensors.
+           Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
+           `(batch_size*num_return_sequences, input_ids.shape[-1])`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
@@ -316,7 +316,7 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
-   beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
+   beam_indices: Optional[torch.LongTensor] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
@@ -339,9 +339,9 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
            of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
            `(max_length-1,)`-shaped tuple of `torch.FloatTensor` with each tensor of shape `(batch_size*num_beams,
            config.vocab_size)`).
-       beam_indices (`tuple(tuple(torch.LongTensor))`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
-           Beam indices of generated token id at each generation step. `(batch_size*num_return_sequences)`-shaped
-           tuple of `(max_length-1,)`-shaped tuples of scalar `torch.LongTensor` tensors.
+       beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
+           Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
+           `(batch_size*num_return_sequences, max_length-1)`.
        encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
            sequence_length, sequence_length)`.
@@ -362,7 +362,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
    sequences: torch.LongTensor = None
    sequences_scores: Optional[torch.FloatTensor] = None
    scores: Optional[Tuple[torch.FloatTensor]] = None
-   beam_indices: Optional[Tuple[Tuple[torch.LongTensor]]] = None
+   beam_indices: Optional[torch.LongTensor] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
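
The docstring updates above are the user-visible contract: `beam_indices` is now a dense `torch.LongTensor` with one row per returned sequence, and positions after a beam finished early are filled with `-1` (see `indices.fill_(-1)` in `finalize` above). A hedged sketch of reading it; the tensor values are made up for illustration:

```python
import torch

# hypothetical beam_indices for 2 returned sequences over 4 generation steps
beam_indices = torch.tensor([
    [0, 1, 1, 0],    # sequence 0: the step-2 token was taken from beam 1, etc.
    [2, 0, -1, -1],  # sequence 1 finished after 2 steps; -1 pads the remainder
])

valid = beam_indices >= 0
num_generated = valid.sum(dim=-1)
print(num_generated)  # tensor([4, 2]): real generation length per sequence
```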
@@ -811,32 +811,33 @@ class GenerationMixin:
         """compute the transition probabilities of sequences given generation
         scores and beam indices"""
-        # reshape scores as [vocab_size * batch_size, # generation steps]
+        # 1. reshape scores as [vocab_size * batch_size, # generation steps]
         # with batch_size being 2 * vocab_size and # generation steps being
         # seq_len - input_length
         scores = torch.stack(scores).reshape(len(scores), -1).transpose(0, 1)
-        # start of generated tokens
-        cut_idx = sequences.shape[-1] - scores.shape[-1]
-        # adjust for beam indices
-        beam_sequence_indices = torch.tensor(beam_indices, device=sequences.device) * self.config.vocab_size
-        # compute real indices
+
+        # 2. cut beam_indices to longest beam length
+        beam_indices_mask = beam_indices < 0
+        max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max()
+        beam_indices = beam_indices[:, :max_beam_length]
+        beam_indices_mask = beam_indices_mask[:, :max_beam_length]
+
+        # 3. Set indices of beams that finished early to 0
+        # such indices will be masked correctly afterwards
+        beam_indices[beam_indices_mask] = 0
+
+        # 4. multiply beam_indices with vocab size to gather correctly from scores
+        beam_sequence_indices = beam_indices * self.config.vocab_size
+
+        # 5. Define which indices contributed to scores
+        cut_idx = sequences.shape[-1] - max_beam_length
         indices = sequences[:, cut_idx:] + beam_sequence_indices
-        # gather scores and run
+
+        # 6. Compute scores
         transition_scores = scores.gather(0, indices)
-        # make sure that if EOS token was used before length of sequence `sequence.shape[-1]`
-        # get first occurence of EOS token
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
-        if eos_token_id is not None:
-            is_eos_token_id = sequences[:, cut_idx:] == eos_token_id
-            # make sure first eos token still contributes to transition probs
-            is_eos_token_id[:, -1] = False
-            is_eos_token_id = is_eos_token_id.roll(1, -1)
-            # all indices after eos shoud be masked
-            zero_transition_prob_mask = is_eos_token_id.cumsum(-1).bool()
-            # zero out padded probs
-            transition_scores.masked_fill_(zero_transition_prob_mask, 0.0)
+
+        # 7. Mask out transition_scores of beams that stopped early
+        transition_scores[beam_indices_mask] = 0

         return transition_scores
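
Steps 4-6 of the rewritten helper rely on a flattening trick: after the stack/reshape/transpose, `scores` has one row per `(beam, token)` pair and one column per generation step, so the log-probability of token `t` taken from beam `b` sits at flat row `b * vocab_size + t`. A small numeric sketch of that gather, with made-up sizes and a batch of one:

```python
import torch

vocab_size, num_beams, steps = 5, 2, 3

# matches `scores` after stack/reshape/transpose: one row per (beam, token)
# pair, one column per generation step
scores = torch.randn(num_beams * vocab_size, steps)

beam_indices = torch.tensor([[0, 1, 1]])  # beam the token came from at each step
tokens = torch.tensor([[3, 0, 4]])        # token id emitted at each step

flat = beam_indices * vocab_size + tokens  # step 1 -> row 1 * 5 + 0 = 5
transition_scores = scores.gather(0, flat)

assert torch.equal(transition_scores[0, 1], scores[5, 1])
```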
@@ -2256,6 +2257,7 @@ class GenerationMixin:
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
+               beam_indices=beam_indices,
            )

            beam_scores = beam_outputs["next_beam_scores"]
@@ -2290,25 +2292,19 @@ class GenerationMixin:
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
+           beam_indices=beam_indices,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
-           else:
-               num_return_sequences = beam_scorer.num_beam_hyps_to_keep
-               # return only as many indices as sequences
-               beam_indices = tuple(
-                   (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
-               )
-               beam_indices = sum(beam_indices, ())

            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
-                   beam_indices=beam_indices,
+                   beam_indices=sequence_outputs["beam_indices"],
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
@@ -2320,7 +2316,7 @@ class GenerationMixin:
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
-                   beam_indices=beam_indices,
+                   beam_indices=sequence_outputs["beam_indices"],
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
@@ -2580,6 +2576,7 @@ class GenerationMixin:
                next_indices,
                pad_token_id=pad_token_id,
                eos_token_id=eos_token_id,
+               beam_indices=beam_indices,
            )
            beam_scores = beam_outputs["next_beam_scores"]
            beam_next_tokens = beam_outputs["next_beam_tokens"]
@@ -2613,25 +2610,19 @@ class GenerationMixin:
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
+           beam_indices=beam_indices,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
-           else:
-               num_return_sequences = beam_scorer.num_beam_hyps_to_keep
-               # return only as many indices as sequences
-               beam_indices = tuple(
-                   (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
-               )
-               beam_indices = sum(beam_indices, ())

            if self.config.is_encoder_decoder:
                return BeamSampleEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
-                   beam_indices=beam_indices,
+                   beam_indices=sequence_outputs["beam_indices"],
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
@@ -2643,7 +2634,7 @@ class GenerationMixin:
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
-                   beam_indices=beam_indices,
+                   beam_indices=sequence_outputs["beam_indices"],
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
@@ -2909,6 +2900,7 @@ class GenerationMixin:
                next_tokens = next_tokens % vocab_size

                # stateless
+               process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
@@ -2916,6 +2908,7 @@ class GenerationMixin:
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
+                   beam_indices=process_beam_indices,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
@@ -2971,6 +2964,7 @@ class GenerationMixin:
        else:
            this_peer_finished = True

+       final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
@@ -2979,26 +2973,19 @@ class GenerationMixin:
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
+           beam_indices=final_beam_indices,
        )

        if return_dict_in_generate:
            if not output_scores:
                sequence_outputs["sequence_scores"] = None
-           else:
-               beam_indices = sum(beam_indices, ())
-               num_return_sequences = beam_scorer.num_beam_hyps_to_keep
-               # return only as many indices as sequences
-               beam_indices = tuple(
-                   (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
-               )
-               beam_indices = sum(beam_indices, ())

            if self.config.is_encoder_decoder:
                return BeamSearchEncoderDecoderOutput(
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
-                   beam_indices=beam_indices,
+                   beam_indices=sequence_outputs["beam_indices"],
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
@@ -3010,6 +2997,7 @@ class GenerationMixin:
                    sequences=sequence_outputs["sequences"],
                    sequences_scores=sequence_outputs["sequence_scores"],
                    scores=scores,
+                   beam_indices=sequence_outputs["beam_indices"],
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                )
...
@@ -126,7 +126,11 @@ class BeamSearchTester:
        tokens = next_tokens.clone()
        tokens[:, : self.num_beams] = self.eos_token_id
-       beam_scorer.process(input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id)
+       beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
+       beam_indices = tuple(tuple(b) for b in beam_indices)
+       beam_scorer.process(
+           input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
+       )

        # beam scorer should be done
        self.parent.assertTrue(beam_scorer.is_done)
@@ -136,7 +140,7 @@ class BeamSearchTester:
        tokens = next_tokens.clone()
        tokens[:, 1] = self.eos_token_id
        beam_outputs = beam_scorer.process(
-           input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id
+           input_ids, next_scores, tokens, next_indices, eos_token_id=self.eos_token_id, beam_indices=beam_indices
        )
        output_scores = beam_outputs["next_beam_scores"]
        output_tokens = beam_outputs["next_beam_tokens"]
@@ -161,10 +165,15 @@ class BeamSearchTester:
        self.parent.assertTrue(torch.allclose(expected_output_scores, output_scores, atol=1e-3))

        # make sure ids of eos token are correctly saved in beam_hyps of beam scorer
+       expected_beam_indices = list(range(10))
        for batch_idx in range(self.batch_size):
            correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
            self.parent.assertListEqual(
-               input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
+               input_ids[correct_idx].tolist(), beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
+           )
+           self.parent.assertListEqual(
+               expected_beam_indices + [next_indices[batch_idx, 1].item()],
+               torch.tensor(beam_scorer._beam_hyps[batch_idx].beams[0][2]).tolist(),
            )

    def check_beam_scores_finalize(self, input_ids, next_tokens, next_indices, next_scores):
@@ -188,6 +197,8 @@ class BeamSearchTester:
        input_ids = torch.cat([input_ids[output_indices, :], output_tokens.unsqueeze(-1)], dim=-1)

        # finalize
+       beam_indices = torch.zeros_like(input_ids) + torch.arange(input_ids.shape[-1], device=input_ids.device)
+       beam_indices = tuple(tuple(b) for b in beam_indices)
        sequence_output = beam_scorer.finalize(
            input_ids,
            output_scores,
@@ -196,6 +207,7 @@ class BeamSearchTester:
            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_token_id,
            max_length=max_length,
+           beam_indices=beam_indices,
        )

        sequences = sequence_output["sequences"]
@@ -225,6 +237,7 @@ class BeamSearchTester:
            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_token_id,
            max_length=max_length,
+           beam_indices=beam_indices,
        )

        sequences = sequence_output["sequences"]
        sequence_scores = sequence_output["sequence_scores"]
@@ -394,7 +407,7 @@ class ConstrainedBeamSearchTester:
        for batch_idx in range(self.batch_size):
            correct_idx = batch_idx * self.num_beams + next_indices[batch_idx, 1]
            self.parent.assertListEqual(
-               input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][-1].tolist()
+               input_ids[correct_idx].tolist(), constrained_beam_scorer._beam_hyps[batch_idx].beams[0][1].tolist()
            )

    def check_constrained_beam_scorer_finalize(
...
@@ -2322,6 +2322,94 @@ class GenerationIntegrationTests(unittest.TestCase):
        self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))

+   @slow
+   def test_transition_scores_early_stopping(self):
+       # This is an aggressive test that makes sure that `beam_search's`
+       # transition scores are computed correctly for varying `num_return_sequences`,
+       # `num_beams` and `batch_size > 1`
+       # 2 x input_ids for "question: How are you? \n context: I had a long day, "
+       input_ids = torch.tensor(2 * [[822, 10, 571, 33, 25, 58, 2625, 10, 27, 141, 3, 9, 307, 239, 6, 1]]).to(
+           torch_device
+       )
+
+       model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(torch_device)
+
+       result = model.generate(
+           input_ids,
+           max_length=10,
+           return_dict_in_generate=True,
+           output_scores=True,
+           forced_eos_token_id=model.config.eos_token_id,
+           num_beams=4,
+           do_sample=False,
+           num_return_sequences=3,
+           length_penalty=0.0,
+       )
+
+       transition_scores = model.compute_transition_beam_scores(
+           sequences=result.sequences, scores=result.scores, beam_indices=result.beam_indices
+       )
+
+       sum_transition_scores = torch.sum(transition_scores, dim=1)
+
+       self.assertListEqual(sum_transition_scores.cpu().tolist(), result.sequences_scores.cpu().tolist())
+
+   def test_log_scores_sample_decoder_only(self):
+       articles = ["I need input_ids to generate", "Short and"]
+       tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
+       tokenizer.padding_side = "left"
+       tokenizer.pad_token = tokenizer.eos_token
+
+       model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
+
+       inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device)
+
+       result = model.generate(
+           **inputs,
+           max_length=15,
+           return_dict_in_generate=True,
+           do_sample=False,
+           output_scores=True,
+       )
+
+       # decoder-only starts generating from `input_ids`
+       begin_generation = inputs.input_ids.shape[-1]
+
+       gen_sequences = result.sequences[:, begin_generation:]
+       probs = torch.stack(result.scores, dim=1).softmax(-1)
+
+       gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
+       expected_probs = torch.tensor([[0.0014, 0.0015], [0.0014, 0.0014]])
+
+       self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3))
+
+   def test_log_scores_sample_encoder_decoder(self):
+       articles = ["I need input_ids to generate", "Short and"]
+       tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
+       model = BartForConditionalGeneration.from_pretrained("hf-internal-testing/tiny-random-bart").to(torch_device)
+
+       inputs = tokenizer(articles, return_tensors="pt", padding=True).to(torch_device)
+
+       result = model.generate(
+           **inputs,
+           max_length=3,
+           return_dict_in_generate=True,
+           do_sample=False,
+           num_beams=1,
+           output_scores=True,
+       )
+
+       # encoder-decoder has one decoder_start_token_id by default
+       begin_generation = 1
+
+       gen_sequences = result.sequences[:, begin_generation:]
+       probs = torch.stack(result.scores, dim=1).softmax(-1)
+
+       gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
+       expected_probs = torch.tensor([[0.0013, 1.0000], [0.0013, 1.0000]])
+
+       self.assertTrue(torch.allclose(gen_probs.cpu(), expected_probs, atol=1e-3))
+
    @slow
    def test_beam_search_example_integration(self):
        # exactly the example provided in the docstrings of beam search, which previously
@@ -2366,8 +2454,8 @@ class GenerationIntegrationTests(unittest.TestCase):
    @slow
    def test_constrained_beam_search(self):
-       model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
-       tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
+       model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
+       tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

        force_tokens = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
        force_tokens_2 = tokenizer("big weapons", add_prefix_space=True, add_special_tokens=False).input_ids
@@ -2403,8 +2491,8 @@ class GenerationIntegrationTests(unittest.TestCase):
    @slow
    def test_constrained_beam_search_mixed(self):
-       model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
-       tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
+       model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
+       tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

        force_phrase = tokenizer("scared", add_prefix_space=True, add_special_tokens=False).input_ids
        flexible_phrases = tokenizer(
@@ -2442,8 +2530,8 @@ class GenerationIntegrationTests(unittest.TestCase):
    @slow
    def test_constrained_beam_search_mixed_mixin(self):
-       model = GPT2LMHeadModel.from_pretrained("../gpt2").to(torch_device)
-       tokenizer = GPT2Tokenizer.from_pretrained("../gpt2")
+       model = GPT2LMHeadModel.from_pretrained("gpt2").to(torch_device)
+       tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

        force_word = "scared"
        force_flexible = ["scream", "screams", "screaming", "screamed"]
...