Unverified commit 1d9c26a4, authored by Joao Gante, committed by GitHub

Generate: TF `compute_transition_scores` (#21341)

parent d3046dad
@@ -50,6 +50,7 @@ and how to create and save a customized generation configuration, refer to the
[[autodoc]] generation.TFGenerationMixin
- generate
- compute_transition_scores

## FlaxGenerationMixin
......
@@ -210,6 +210,9 @@ class TFBeamSearchDecoderOnlyOutput(ModelOutput):
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this
beam. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each generated token),
with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
beam_indices (`tf.Tensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of generated token id at each generation step. `tf.Tensor` of shape
`(batch_size*num_return_sequences, sequence_length)`.
attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`tf.Tensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
@@ -221,6 +224,7 @@ class TFBeamSearchDecoderOnlyOutput(ModelOutput):
sequences: tf.Tensor = None
sequences_scores: Optional[tf.Tensor] = None
scores: Optional[Tuple[tf.Tensor]] = None
beam_indices: Optional[tf.Tensor] = None
attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@@ -243,7 +247,9 @@ class TFBeamSearchEncoderDecoderOutput(ModelOutput):
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this
beam. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each generated token),
with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
beam_indices (`tf.Tensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of generated token id at each generation step. `tf.Tensor` of shape
`(batch_size*num_return_sequences, sequence_length)`.
encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
@@ -265,6 +271,7 @@ class TFBeamSearchEncoderDecoderOutput(ModelOutput):
sequences: tf.Tensor = None
sequences_scores: Optional[tf.Tensor] = None
scores: Optional[Tuple[tf.Tensor]] = None
beam_indices: Optional[tf.Tensor] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
@@ -288,6 +295,9 @@ class TFBeamSampleDecoderOnlyOutput(ModelOutput):
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this
beam. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each generated token),
with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
beam_indices (`tf.Tensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of generated token id at each generation step. `tf.Tensor` of shape
`(batch_size*num_return_sequences, sequence_length)`.
attentions (`tuple(tuple(tf.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`tf.Tensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
@@ -299,6 +309,7 @@ class TFBeamSampleDecoderOnlyOutput(ModelOutput):
sequences: tf.Tensor = None
sequences_scores: Optional[tf.Tensor] = None
scores: Optional[Tuple[tf.Tensor]] = None
beam_indices: Optional[tf.Tensor] = None
attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
hidden_states: Optional[Tuple[Tuple[tf.Tensor]]] = None
@@ -321,6 +332,9 @@ class TFBeamSampleEncoderDecoderOutput(ModelOutput):
softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this
beam. Tuple of `tf.Tensor` with up to `max_new_tokens` elements (one element for each generated token),
with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
beam_indices (`tf.Tensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of generated token id at each generation step. `tf.Tensor` of shape
`(batch_size*num_return_sequences, sequence_length)`.
encoder_attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple of `tf.Tensor` (one for each layer of the decoder) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
@@ -341,6 +355,7 @@ class TFBeamSampleEncoderDecoderOutput(ModelOutput):
sequences: tf.Tensor = None
sequences_scores: Optional[tf.Tensor] = None
scores: Optional[Tuple[tf.Tensor]] = None
beam_indices: Optional[tf.Tensor] = None
encoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[Tuple[tf.Tensor]]] = None
@@ -480,6 +495,126 @@ class TFGenerationMixin:
else:
return logits
def compute_transition_scores(
self,
sequences: tf.Tensor,
scores: Tuple[tf.Tensor],
beam_indices: Optional[tf.Tensor] = None,
normalize_logits: bool = False,
) -> tf.Tensor:
"""
Computes the transition scores of sequences given the generation scores (and beam indices, if beam search was
used). This is a convenient method to quickly obtain the scores of the selected tokens at generation time.
Parameters:
sequences (`tf.Tensor`):
The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or
shorter if all batches finished early due to the `eos_token_id`.
scores (`tuple(tf.Tensor)`):
Transition scores for each vocabulary token at each generation step. Beam transition scores consisting
of log probabilities of tokens conditioned on log softmax of previously generated tokens. Tuple of
`tf.Tensor` with up to `max_new_tokens` elements (one element for each generated token), with each
tensor of shape `(batch_size*num_beams, config.vocab_size)`.
beam_indices (`tf.Tensor`, *optional*):
Beam indices of generated token id at each generation step. `tf.Tensor` of shape
`(batch_size*num_return_sequences, sequence_length)`. Only required if `num_beams>1` was used at
generate time.
normalize_logits (`bool`, *optional*, defaults to `False`):
Whether to normalize the logits (which, for legacy reasons, may be unnormalized).
Return:
`tf.Tensor`: A `tf.Tensor` of shape `(batch_size*num_return_sequences, sequence_length)` containing
the transition scores (logits).
Examples:
```python
>>> from transformers import GPT2Tokenizer, TFAutoModelForCausalLM
>>> import numpy as np
>>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
>>> model = TFAutoModelForCausalLM.from_pretrained("gpt2")
>>> tokenizer.pad_token_id = tokenizer.eos_token_id
>>> inputs = tokenizer(["Today is"], return_tensors="tf")
>>> # Example 1: Print the scores for each token generated with Greedy Search
>>> outputs = model.generate(**inputs, max_new_tokens=5, return_dict_in_generate=True, output_scores=True)
>>> transition_scores = model.compute_transition_scores(
... outputs.sequences, outputs.scores, normalize_logits=True
... )
>>> # input_length is the length of the input prompt for decoder-only models, like the GPT family, and 1 for
>>> # encoder-decoder models, like BART or T5.
>>> input_length = 1 if model.config.is_encoder_decoder else inputs.input_ids.shape[1]
>>> generated_tokens = outputs.sequences[:, input_length:]
>>> for tok, score in zip(generated_tokens[0], transition_scores[0]):
... # | token | token string | logits | probability
... print(f"| {tok:5d} | {tokenizer.decode(tok):8s} | {score.numpy():.3f} | {np.exp(score.numpy()):.2%}")
| 262 | the | -1.413 | 24.33%
| 1110 | day | -2.609 | 7.36%
| 618 | when | -2.009 | 13.41%
| 356 | we | -1.859 | 15.58%
| 460 | can | -2.508 | 8.14%
>>> # Example 2: Reconstruct the sequence scores from Beam Search
>>> outputs = model.generate(
... **inputs,
... max_new_tokens=5,
... num_beams=4,
... num_return_sequences=4,
... return_dict_in_generate=True,
... output_scores=True,
... )
>>> transition_scores = model.compute_transition_scores(
... outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
... )
>>> # If you sum the generated tokens' scores and apply the length penalty, you'll get the sequence scores.
>>> # Tip: recomputing the scores is only guaranteed to match with `normalize_logits=False`. Depending on the
>>> # use case, you might want to recompute it with `normalize_logits=True`.
>>> output_length = input_length + np.sum(transition_scores.numpy() < 0, axis=1)
>>> length_penalty = model.generation_config.length_penalty
>>> reconstructed_scores = np.sum(transition_scores, axis=1) / (output_length**length_penalty)
>>> print(np.allclose(outputs.sequences_scores, reconstructed_scores))
True
```"""
# 1. In absence of `beam_indices`, we can assume that we come from e.g. greedy search, which is equivalent
# to a beam search approach where the first (and only) beam is always selected
if beam_indices is None:
beam_indices = tf.tile(tf.expand_dims(tf.range(scores[0].shape[0]), axis=1), [1, len(scores)])
# 2. reshape scores as [batch_size, vocab_size, # generation steps] with # generation steps being
# seq_len - input_length
scores = tf.transpose(tf.reshape(tf.stack(scores), (len(scores), -1)), (1, 0))
scores = tf.reshape(scores, (-1, self.config.vocab_size, scores.shape[-1]))
# 3. Optionally normalize the logits (across the vocab dimension)
if normalize_logits:
scores = tf.nn.log_softmax(scores, axis=1)
# 4. cut beam_indices to longest beam length
beam_indices_mask = beam_indices < 0
max_beam_length = tf.math.reduce_max(
tf.math.reduce_sum((1 - tf.cast(beam_indices_mask, dtype=tf.int32)), axis=-1)
)
beam_indices = beam_indices[:, -max_beam_length:]
beam_indices_mask = beam_indices_mask[:, -max_beam_length:]
# 5. Set indices of beams that finished early to 0; such indices will be masked correctly afterwards
beam_indices = tf.where(beam_indices_mask, 0, beam_indices)
# 6. Define which indices contributed to scores
cut_idx = sequences.shape[-1] - max_beam_length
token_indices = sequences[:, cut_idx:]
gen_step_idx = tf.broadcast_to(tf.range(scores.shape[-1]), token_indices.shape)
indices = tf.stack([beam_indices, token_indices, gen_step_idx], axis=-1)
# 7. Compute scores
transition_scores = tf.gather_nd(scores, indices)
# 8. Mask out transition_scores of beams that stopped early
transition_scores = tf.where(beam_indices_mask, 0, transition_scores)
return transition_scores
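For intuition, here is a minimal sketch of the lookup performed by steps 1-8 above, written with NumPy and toy values (illustrative only, not part of the diff). It assumes a greedy-search call, i.e. a single implicit beam, so every `beam_indices` entry is 0; with beam search the same lookup simply reads from the beam each returned token actually came from.

```python
import numpy as np

# Toy setup: batch_size=1, vocab_size=3, two generation steps.
# scores[step][beam, token] mirrors the tuple returned by generate(..., output_scores=True).
scores = (np.array([[0.1, 0.7, 0.2]]), np.array([[0.5, 0.3, 0.2]]))
generated_tokens = np.array([[1, 0]])            # token picked at each step
beam_indices = np.zeros_like(generated_tokens)   # greedy search: always beam 0

# Equivalent of the (beam, token, step) gather above:
transition_scores = np.array(
    [
        [scores[step][beam_indices[seq, step], generated_tokens[seq, step]] for step in range(len(scores))]
        for seq in range(generated_tokens.shape[0])
    ]
)
print(transition_scores)  # [[0.7 0.5]]
```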
def _validate_model_class(self):
"""
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
@@ -866,6 +1001,7 @@ class TFGenerationMixin:
length_penalty=generation_config.length_penalty,
early_stopping=generation_config.early_stopping,
logits_processor=logits_processor,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
num_return_sequences=generation_config.num_return_sequences,
**model_kwargs,
@@ -906,6 +1042,7 @@ class TFGenerationMixin:
early_stopping=generation_config.early_stopping,
logits_processor=logits_processor,
logits_warper=logits_warper,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
num_return_sequences=generation_config.num_return_sequences,
**model_kwargs,
@@ -1489,10 +1626,13 @@ class TFGenerationMixin:
)
next_token_logits = model_outputs.logits[:, -1]
# pre-process distribution
next_tokens_scores = logits_processor(generated, next_token_logits, cur_len)
# Store scores, attentions and hidden_states when required
if not use_xla and return_dict_in_generate:
if output_scores:
scores.append(next_tokens_scores)
if output_attentions and self.config.is_encoder_decoder:
decoder_attentions.append(model_outputs.decoder_attentions)
elif output_attentions and not self.config.is_encoder_decoder:
@@ -1505,9 +1645,6 @@ class TFGenerationMixin:
elif output_hidden_states and self.config.is_encoder_decoder:
decoder_hidden_states.append(model_outputs.hidden_states)
# pre-process distribution
next_tokens_scores = logits_processor(generated, next_token_logits, cur_len)
# argmax
next_tokens = tf.argmax(next_tokens_scores, axis=-1, output_type=tf.int32)
@@ -1762,10 +1899,14 @@ class TFGenerationMixin:
)
next_token_logits = model_outputs.logits[:, -1]
# pre-process distribution
next_tokens_scores = logits_processor(generated, next_token_logits, cur_len)
next_tokens_scores = logits_warper(generated, next_tokens_scores, cur_len)
# Store scores, attentions and hidden_states when required
if not use_xla and return_dict_in_generate:
if output_scores:
scores.append(next_tokens_scores)
if output_attentions and self.config.is_encoder_decoder:
decoder_attentions.append(model_outputs.decoder_attentions)
elif output_attentions and not self.config.is_encoder_decoder:
@@ -1778,10 +1919,6 @@ class TFGenerationMixin:
elif output_hidden_states and self.config.is_encoder_decoder:
decoder_hidden_states.append(model_outputs.hidden_states)
# pre-process distribution
next_tokens_scores = logits_processor(generated, next_token_logits, cur_len)
next_tokens_scores = logits_warper(generated, next_tokens_scores, cur_len)
# sample
if seed is not None:
sample_seed = seed
@@ -2066,7 +2203,7 @@ class TFGenerationMixin:
needs_full_input = "use_mems" in set(inspect.signature(self.prepare_inputs_for_generation).parameters.keys())
# 2. init `attentions`, `hidden_states`, and `scores` tuples
all_scores = [] if (return_dict_in_generate and output_scores) else None
decoder_attentions = [] if (return_dict_in_generate and output_attentions) else None
cross_attentions = [] if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = [] if (return_dict_in_generate and output_hidden_states) else None
@@ -2090,6 +2227,10 @@ class TFGenerationMixin:
)
scores = tf.ones((batch_size, num_beams)) * -1.0e9
# per batch beam indices
running_beam_indices = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * -1
beam_indices = tf.ones((batch_size, num_beams, max_length), dtype=tf.int32) * -1
# flatten beam dim
if "encoder_outputs" in model_kwargs:
model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
@@ -2104,8 +2245,10 @@ class TFGenerationMixin:
cur_len,
running_sequences,
running_scores,
running_beam_indices,
sequences,
scores,
beam_indices,
is_sent_finished,
model_kwargs,
):
@@ -2140,8 +2283,10 @@ class TFGenerationMixin:
cur_len,
running_sequences,
running_scores,
running_beam_indices,
sequences,
scores,
beam_indices,
is_sent_finished,
model_kwargs,
):
@@ -2165,10 +2310,31 @@ class TFGenerationMixin:
)
logits = unflatten_beam_dim(model_outputs.logits[:, -1], num_beams)
# 2. Compute log probs
# get log probabilities from logits, process logits with processors (*e.g.* min_length, ...), and
# add new logprobs to existing running logprobs scores.
log_probs = tf.nn.log_softmax(logits)
log_probs = logits_processor(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
log_probs = unflatten_beam_dim(log_probs, num_beams)
log_probs_processed = log_probs
log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
if do_sample:
# Note: logits warpers are intentionally applied after adding running beam scores. On some logits
# warpers (like top_p) this is indifferent, but on others (like temperature) it is not. For reference,
# see https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
log_probs = logits_warper(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
log_probs = unflatten_beam_dim(log_probs, num_beams)
vocab_size = log_probs.shape[2]
log_probs = tf.reshape(log_probs, (batch_size, num_beams * vocab_size))
# Store scores, attentions and hidden_states when required
if not use_xla and return_dict_in_generate:
if output_scores:
all_scores.append(
logits_warper(
flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs_processed), cur_len
)
)
if output_attentions and self.config.is_encoder_decoder:
decoder_attentions.append(model_outputs.decoder_attentions)
elif output_attentions and not self.config.is_encoder_decoder:
@@ -2181,19 +2347,6 @@ class TFGenerationMixin:
elif output_hidden_states and self.config.is_encoder_decoder:
decoder_hidden_states.append(model_outputs.hidden_states)
# 2. Compute log probs
# get log probabilities from logits, process logits with processors (*e.g.* min_length, ...), and
# add new logprobs to existing running logprobs scores.
log_probs = tf.nn.log_softmax(logits)
log_probs = logits_processor(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
log_probs = unflatten_beam_dim(log_probs, num_beams)
log_probs = log_probs + tf.expand_dims(running_scores, axis=2)
if do_sample:
log_probs = logits_warper(flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), cur_len)
log_probs = unflatten_beam_dim(log_probs, num_beams)
vocab_size = log_probs.shape[2]
log_probs = tf.reshape(log_probs, (batch_size, num_beams * vocab_size))
# 3. Retrieve top-K
# Each item in batch has num_beams * vocab_size candidate sequences. For each item, get the top 2*k
# candidates with the highest log-probabilities. We gather the top 2*K beams here so that even if the
@@ -2210,8 +2363,9 @@ class TFGenerationMixin:
topk_log_probs = tf.gather(log_probs, topk_indices, axis=1, batch_dims=1)
else:
topk_log_probs, topk_indices = tf.math.top_k(log_probs, k=beams_to_keep)
topk_current_beam_indices = topk_indices // vocab_size
topk_running_beam_indices = self._gather_beams(running_beam_indices, topk_current_beam_indices)
topk_running_sequences = self._gather_beams(running_sequences, topk_current_beam_indices)
topk_ids = topk_indices % vocab_size
# writes the new token
@@ -2226,6 +2380,16 @@ class TFGenerationMixin:
updates=tf.reshape(topk_ids, [batch_size * beams_to_keep]),
)
# we want to store the beam indices with batch information -> real beam index = beam index % num beams
batch_modified_indices = topk_current_beam_indices + tf.broadcast_to(
tf.expand_dims(tf.range(batch_size) * num_beams, axis=1), topk_current_beam_indices.shape
)
topk_beam_indices = tf.tensor_scatter_nd_update(
tensor=topk_running_beam_indices,
indices=update_indices,
updates=tf.reshape(batch_modified_indices, [batch_size * beams_to_keep]),
)
# 4. Check which sequences have ended
# Update current sequences: Did the top `num_beams` sequences reach an end marker?
# To prevent these just finished sequences from being added to the current sequences
@@ -2246,8 +2410,8 @@ class TFGenerationMixin:
# Determine the top k beam indices (from top 2*k beams) from log probs and gather top k beams
# (from top 2*k beams).
next_topk_indices = tf.math.top_k(running_topk_log_probs, k=num_beams)[1]
next_running_sequences, next_running_scores, next_running_beam_indices = self._gather_beams(
[topk_sequences, running_topk_log_probs, topk_beam_indices], next_topk_indices
)
# 6. Process topk logits
@@ -2267,10 +2431,11 @@ class TFGenerationMixin:
# to existing finished scores and select the best from the new set of beams
merged_sequences = tf.concat([sequences, topk_sequences], axis=1)
merged_scores = tf.concat([scores, topk_log_probs], axis=1)
merged_beams = tf.concat([beam_indices, topk_beam_indices], axis=1)
merged_is_sent_finished = tf.concat([is_sent_finished, did_topk_just_finished], axis=1)
topk_merged_indices = tf.math.top_k(merged_scores, k=num_beams)[1]
next_sequences, next_scores, next_beam_indices, next_is_sent_finished = self._gather_beams(
[merged_sequences, merged_scores, merged_beams, merged_is_sent_finished], topk_merged_indices
)
# 8. Prepare data for the next iteration
@@ -2282,7 +2447,7 @@ class TFGenerationMixin:
lambda tensor: unflatten_beam_dim(tensor, num_beams, batch_axis=cache_batch_axis),
model_outputs.past_key_values,
)
next_running_indices = self._gather_beams(topk_current_beam_indices, next_topk_indices)
next_cache = self._gather_beams(cache, next_running_indices, batch_axis=cache_batch_axis)
model_outputs["past_key_values"] = tf.nest.map_structure(
lambda tensor: flatten_beam_dim(tensor, batch_axis=cache_batch_axis), next_cache
@@ -2312,8 +2477,10 @@ class TFGenerationMixin:
cur_len,
next_running_sequences,
next_running_scores,
next_running_beam_indices,
next_sequences,
next_scores,
next_beam_indices,
next_is_sent_finished,
next_model_kwargs,
)
@@ -2324,24 +2491,62 @@ class TFGenerationMixin:
cur_len,
running_sequences,
running_scores,
running_beam_indices,
sequences,
scores,
beam_indices,
is_sent_finished,
model_kwargs,
) = beam_search_body_fn(
cur_len,
running_sequences,
running_scores,
running_beam_indices,
sequences,
scores,
beam_indices,
is_sent_finished,
model_kwargs,
) )
# 2-to-n generation steps can then be run in autoregressive fashion (only in case 1st generation step does
# NOT yield EOS token though)
if beam_search_cond_fn(
cur_len,
running_sequences,
running_scores,
running_beam_indices,
sequences,
scores,
beam_indices,
is_sent_finished,
model_kwargs,
): ):
maximum_iterations = max_length - cur_len
(
cur_len,
running_sequences,
running_scores,
running_beam_indices,
sequences,
scores,
beam_indices,
is_sent_finished,
_,
) = tf.while_loop(
beam_search_cond_fn,
beam_search_body_fn,
(
cur_len,
running_sequences,
running_scores,
running_beam_indices,
sequences,
scores,
beam_indices,
is_sent_finished,
model_kwargs,
),
maximum_iterations=maximum_iterations,
)
@@ -2350,15 +2555,21 @@ class TFGenerationMixin:
# running sequences for that batch item.
none_finished = tf.math.reduce_any(is_sent_finished, axis=1)
sequences = tf.where(none_finished[:, None, None], sequences, running_sequences)
beam_indices = tf.where(none_finished[:, None, None], beam_indices, running_beam_indices)
# Apply the length penalty so that running scores match the finalized scores if they are used
running_scores = running_scores / (tf.cast(cur_len, dtype=tf.float32) ** length_penalty)
scores = tf.where(none_finished[:, None], scores, running_scores)
# Take best beams for each batch (the score is sorted in descending order)
sequences = flatten_beam_dim(sequences[:, :num_return_sequences, :])
scores = flatten_beam_dim(scores[:, :num_return_sequences])
beam_indices = flatten_beam_dim(beam_indices[:, :num_return_sequences, :])
if not use_xla:
# Cut for backward compatibility
sequences = sequences[:, :cur_len]
beam_indices = beam_indices[:, :cur_len]
if return_dict_in_generate:
if self.config.is_encoder_decoder:
@@ -2371,7 +2582,9 @@ class TFGenerationMixin:
output_cls = TFBeamSampleEncoderDecoderOutput if do_sample else TFBeamSearchEncoderDecoderOutput
return output_cls(
sequences=sequences,
sequences_scores=scores,
scores=all_scores,
beam_indices=beam_indices,
encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions,
@@ -2382,7 +2595,9 @@ class TFGenerationMixin:
output_cls = TFBeamSampleDecoderOnlyOutput if do_sample else TFBeamSearchDecoderOnlyOutput
return output_cls(
sequences=sequences,
sequences_scores=scores,
scores=all_scores,
beam_indices=beam_indices,
attentions=decoder_attentions,
hidden_states=decoder_hidden_states,
)
@@ -2607,7 +2822,7 @@ class TFGenerationMixin:
# Store scores, attentions and hidden_states when required
if not use_xla and return_dict_in_generate:
if output_scores:
scores.append(logit_for_next_step)
if output_attentions and self.config.is_encoder_decoder:
decoder_attentions.append(outputs.decoder_attentions)
elif output_attentions and not self.config.is_encoder_decoder:
......
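The new comment in the beam-search body above points out that logits warpers are intentionally applied only after the running beam scores have been added. The snippet below is a standalone illustration of why the order matters for a temperature warper; it is not part of this diff, and the simplified warper and the numbers are made-up assumptions.

```python
import numpy as np

def temperature_warp(scores, temperature=2.0):
    # Simplified stand-in for a temperature logits warper.
    return scores / temperature

log_probs = np.array([-1.0, -2.0, -3.0])  # per-token log-probs at the current step
running_beam_score = -5.0                 # cumulative score of the beam so far

warp_after_add = temperature_warp(log_probs + running_beam_score)   # order used here
add_after_warp = temperature_warp(log_probs) + running_beam_score   # the alternative order

print(warp_after_add)  # [-3.  -3.5 -4. ]
print(add_after_warp)  # [-5.5 -6.  -6.5]
```

For a top-p style warper the two orders select the same candidates (adding a constant to every logit leaves the softmax unchanged), but for temperature the candidate scores that are later compared across beams end up on different scales, which is the distinction the comment references.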
@@ -301,9 +301,9 @@ class BeamSearchDecoderOnlyOutput(ModelOutput):
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
`(batch_size*num_return_sequences, sequence_length)`.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
@@ -338,10 +338,9 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
`(batch_size*num_return_sequences, sequence_length)`.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
sequence_length, sequence_length)`.
@@ -387,9 +386,9 @@ class BeamSampleDecoderOnlyOutput(ModelOutput):
of log probabilities of tokens conditioned on log softmax of previously generated tokens in this beam.
Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token),
with each tensor of shape `(batch_size*num_beams*num_return_sequences, config.vocab_size)`.
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
`(batch_size*num_return_sequences, sequence_length)`.
attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
`torch.FloatTensor` of shape `(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
@@ -426,7 +425,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
with each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
beam_indices (`torch.LongTensor`, *optional*, returned when `output_scores=True` is passed or when `config.output_scores=True`):
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
`(batch_size*num_return_sequences, sequence_length)`.
encoder_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer of the decoder) of shape `(batch_size, num_heads,
sequence_length, sequence_length)`.
@@ -937,9 +936,9 @@ class GenerationMixin:
of log probabilities of tokens conditioned on log softmax of previously generated tokens. Tuple of
`torch.FloatTensor` with up to `max_new_tokens` elements (one element for each generated token), with
each tensor of shape `(batch_size*num_beams, config.vocab_size)`.
beam_indices (`torch.LongTensor`, *optional*):
Beam indices of generated token id at each generation step. `torch.LongTensor` of shape
`(batch_size*num_return_sequences, sequence_length)`. Only required if `num_beams>1` was used at
generate time.
normalize_logits (`bool`, *optional*, defaults to `False`):
Whether to normalize the logits (which, for legacy reasons, may be unnormalized).
@@ -1017,11 +1016,10 @@ class GenerationMixin:
# 4. cut beam_indices to longest beam length
beam_indices_mask = beam_indices < 0
max_beam_length = (1 - beam_indices_mask.long()).sum(-1).max()
beam_indices = beam_indices.clone()[:, :max_beam_length]
beam_indices_mask = beam_indices_mask[:, :max_beam_length]
# 5. Set indices of beams that finished early to 0; such indices will be masked correctly afterwards
beam_indices[beam_indices_mask] = 0
# 6. multiply beam_indices with vocab size to gather correctly from scores
@@ -3067,6 +3065,9 @@ class GenerationMixin:
next_token_scores_processed = logits_processor(input_ids, next_token_scores)
next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
# Note: logits warpers are intentionally applied after adding running beam scores. On some logits warpers
# (like top_p) this is indifferent, but on others (like temperature) it is not. For reference, see
# https://github.com/huggingface/transformers/pull/5420#discussion_r449779867
next_token_scores = logits_warper(input_ids, next_token_scores)
# Store scores, attentions and hidden_states when required
......
@@ -5,7 +5,7 @@ Framework agnostic tests for generate()-related methods.
import numpy as np
from transformers import AutoTokenizer
from transformers.testing_utils import slow, torch_device
class GenerationIntegrationTestsMixin:
@@ -133,16 +133,12 @@ class GenerationIntegrationTestsMixin:
def test_encoder_decoder_generate_with_inputs_embeds(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
is_pt = not model_cls.__name__.startswith("TF")
article = """Justin Timberlake and Jessica Biel, welcome to parenthood.""" article = """Justin Timberlake and Jessica Biel, welcome to parenthood."""
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart") tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5) model = model_cls.from_pretrained("hf-internal-testing/tiny-random-bart", max_length=5)
model.config.eos_token_id = None model.config.eos_token_id = None
input_ids = tokenizer(article, return_tensors=return_tensors).input_ids input_ids = tokenizer(article, return_tensors=return_tensors).input_ids
if is_pt:
model = model.to(torch_device)
input_ids = input_ids.to(torch_device)
inputs_embeds = model.get_input_embeddings()(input_ids)
@@ -150,3 +146,253 @@ class GenerationIntegrationTestsMixin:
# make sure model generated correctly until `max_length`
self.assertEqual(output_sequences.shape, (1, 5))
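The tests added below are framework agnostic: everything framework specific (model classes, tensor construction, `return_tensors`) is read from `self.framework_dependent_parameters`, which the PyTorch and TensorFlow test classes populate. As a rough sketch of what the TF side of that mapping might look like (the exact contents live in the framework-specific test files, so the values here are assumptions):

```python
import tensorflow as tf
from transformers import TFAutoModelForCausalLM, TFAutoModelForSeq2SeqLM

# Hypothetical illustration of the mapping consumed by the mixin tests below.
framework_dependent_parameters = {
    "AutoModelForCausalLM": TFAutoModelForCausalLM,
    "AutoModelForSeq2SeqLM": TFAutoModelForSeq2SeqLM,
    "create_tensor_fn": tf.constant,  # used by tests that build input ids by hand
    "return_tensors": "tf",           # passed to the tokenizer
}
```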
def test_transition_scores_greedy_search(self):
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
is_pt = not model_cls.__name__.startswith("TF")
articles = ["Justin Timberlake", "Michael Phelps"]
tokenizer = AutoTokenizer.from_pretrained("distilgpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = model_cls.from_pretrained("distilgpt2")
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
if is_pt:
model = model.to(torch_device)
input_ids = input_ids.to(torch_device)
outputs = model.generate(
input_ids=input_ids,
max_new_tokens=5,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores)
if is_pt:
transition_scores = transition_scores.cpu().numpy()
expected_scores = np.array(
[
[-57.8844, -60.45698, -70.16364, -65.50791, -66.35648],
[-54.417572, -60.216614, -62.661243, -58.621933, -58.298683],
]
)
self.assertTrue(np.allclose(transition_scores, expected_scores, atol=1e-3))
def test_transition_scores_greedy_search_normalized(self):
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
is_pt = not model_cls.__name__.startswith("TF")
articles = ["Justin Timberlake", "Michael Phelps"]
tokenizer = AutoTokenizer.from_pretrained("distilgpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = model_cls.from_pretrained("distilgpt2")
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
if is_pt:
model = model.to(torch_device)
input_ids = input_ids.to(torch_device)
outputs = model.generate(
input_ids=input_ids,
max_new_tokens=5,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
if is_pt:
transition_scores = transition_scores.cpu().numpy()
expected_scores = np.array(
[
[-2.538938, -2.2694316, -2.1580915, -1.572299, -2.6719835],
[-1.8826028, -2.2461371, -1.7556462, -2.9644494, -1.7996008],
]
)
self.assertTrue(np.allclose(transition_scores, expected_scores, atol=1e-3))
def test_transition_scores_beam_search_encoder_decoder(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
is_pt = not model_cls.__name__.startswith("TF")
articles = [
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
"Michael Phelps is arguably the most decorated Olympian of all time.",
]
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = model_cls.from_pretrained(
"hf-internal-testing/tiny-random-bart",
max_length=10,
num_beams=4,
num_return_sequences=2,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
length_penalty=0.0,
)
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
if is_pt:
model = model.to(torch_device)
input_ids = input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
if is_pt:
transition_scores = transition_scores.cpu().numpy()
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy()
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
def test_transition_scores_beam_search_encoder_decoder_with_eos(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
is_pt = not model_cls.__name__.startswith("TF")
articles = [
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
"Michael Phelps is arguably the most decorated Olympian of all time.",
]
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = model_cls.from_pretrained(
"hf-internal-testing/tiny-random-bart",
max_length=10,
num_beams=4,
num_return_sequences=2,
return_dict_in_generate=True,
output_scores=True,
length_penalty=0.0,
)
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
if is_pt:
model = model.to(torch_device)
input_ids = input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
if is_pt:
transition_scores = transition_scores.cpu().numpy()
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy()
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
def test_transition_scores_beam_search_decoder_only(self):
model_cls = self.framework_dependent_parameters["AutoModelForCausalLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
is_pt = not model_cls.__name__.startswith("TF")
articles = [
"Justin Timberlake",
"Michael Phelps",
]
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = model_cls.from_pretrained(
"hf-internal-testing/tiny-random-gpt2",
max_length=10,
num_beams=4,
num_return_sequences=2,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
length_penalty=0.0,
)
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
if is_pt:
model = model.to(torch_device)
input_ids = input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
if is_pt:
transition_scores = transition_scores.cpu().numpy()
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy()
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
def test_transition_scores_beam_sample_encoder_decoder(self):
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
return_tensors = self.framework_dependent_parameters["return_tensors"]
is_pt = not model_cls.__name__.startswith("TF")
articles = [
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
"Michael Phelps is arguably the most decorated Olympian of all time.",
]
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = model_cls.from_pretrained(
"hf-internal-testing/tiny-random-bart",
do_sample=True,
max_length=10,
num_beams=4,
num_return_sequences=2,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
length_penalty=0.0,
)
input_ids = tokenizer(articles, return_tensors=return_tensors, padding=True).input_ids
if is_pt:
model = model.to(torch_device)
input_ids = input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
if is_pt:
transition_scores = transition_scores.cpu().numpy()
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy()
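# Beam sample draws tokens stochastically, but the scores recorded at each step still correspond
# to the selected tokens, so the same consistency check applies.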
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores, atol=1e-3))
@slow
def test_transition_scores_early_stopping(self):
# This is an aggressive test that makes sure that `beam_search`'s transition scores are computed
# correctly for varying `num_return_sequences`, `num_beams`, and `batch_size > 1`
# 2 x input_ids for "question: How are you? \n context: I had a long day, "
model_cls = self.framework_dependent_parameters["AutoModelForSeq2SeqLM"]
create_tensor_fn = self.framework_dependent_parameters["create_tensor_fn"]
is_pt = not model_cls.__name__.startswith("TF")
input_ids = create_tensor_fn(2 * [[822, 10, 571, 33, 25, 58, 2625, 10, 27, 141, 3, 9, 307, 239, 6, 1]])
model = model_cls.from_pretrained("t5-small")
if is_pt:
model = model.to(torch_device)
input_ids = input_ids.to(torch_device)
outputs = model.generate(
input_ids,
max_length=10,
return_dict_in_generate=True,
output_scores=True,
forced_eos_token_id=model.config.eos_token_id,
num_beams=4,
do_sample=False,
num_return_sequences=3,
length_penalty=0.0,
)
transition_scores = model.compute_transition_scores(
sequences=outputs.sequences, scores=outputs.scores, beam_indices=outputs.beam_indices
)
if is_pt:
transition_scores = transition_scores.cpu().numpy()
outputs.sequences_scores = outputs.sequences_scores.cpu().numpy()
self.assertTrue(np.allclose(np.sum(transition_scores, axis=-1), outputs.sequences_scores))
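# Illustrative sketch (not part of this diff): how the quantities checked in the tests above can
# be inspected per token. It reuses the tiny test checkpoint and arguments named above and
# assumes the PyTorch backend.
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

inputs = tokenizer(["Justin Timberlake"], return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=5,
    pad_token_id=tokenizer.eos_token_id,
    return_dict_in_generate=True,
    output_scores=True,
)
# Greedy decoding has no beams, so `beam_indices` can be omitted.
transition_scores = model.compute_transition_scores(
    outputs.sequences, outputs.scores, normalize_logits=True
)
# For decoder-only models the generated tokens start right after the prompt.
generated_tokens = outputs.sequences[:, inputs.input_ids.shape[1] :]
for token, score in zip(generated_tokens[0], transition_scores[0]):
    # log-probability (and probability) assigned to each generated token
    print(f"{tokenizer.decode(token)!r}: logp={score:.3f}, p={np.exp(score.numpy()):.2%}")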
...@@ -17,8 +17,6 @@
import inspect
import unittest
import numpy as np
from transformers import is_torch_available, pipeline
from transformers.testing_utils import require_torch, slow, torch_device
...@@ -2220,165 +2218,6 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi
self.assertListEqual(output_sequences_no_mask.tolist(), output_sequences_with_mask.tolist())
def test_transition_scores_greedy_search(self):
articles = ["Justin Timberlake", "Michael Phelps"]
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(
input_ids=input_ids,
max_new_tokens=5,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores)
expected_scores = np.array(
[
[0.3596273, 0.39646253, 0.46157718, 0.4594633, 0.44866616],
[0.34934354, 0.4935004, 0.6373219, 0.5173545, 0.57517034],
]
)
self.assertTrue(np.allclose(transition_scores.cpu().numpy(), expected_scores))
def test_transition_scores_greedy_search_normalized(self):
articles = ["Justin Timberlake", "Michael Phelps"]
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device)
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(
input_ids=input_ids,
max_new_tokens=5,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=True)
expected_scores = np.array(
[
[-6.5532393, -6.5158753, -6.451863, -6.4527144, -6.459402],
[-6.5685124, -6.4277077, -6.282607, -6.399295, -6.340927],
]
)
self.assertTrue(np.allclose(transition_scores.cpu().numpy(), expected_scores))
def test_transition_scores_beam_search_encoder_decoder(self):
articles = [
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
"Michael Phelps is arguably the most decorated Olympian of all time.",
]
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained(
"hf-internal-testing/tiny-random-bart",
max_length=10,
num_beams=4,
num_return_sequences=2,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
length_penalty=0.0,
)
model = model.to(torch_device)
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
transition_scores_sum = transition_scores.sum(-1)
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
def test_transition_scores_beam_search_encoder_decoder_with_eos(self):
articles = [
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
"Michael Phelps is arguably the most decorated Olympian of all time.",
]
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained(
"hf-internal-testing/tiny-random-bart",
max_length=10,
num_beams=4,
num_return_sequences=2,
return_dict_in_generate=True,
output_scores=True,
length_penalty=0.0,
)
model = model.to(torch_device)
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
transition_scores_sum = transition_scores.sum(-1)
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
def test_transition_scores_beam_search_decoder_only(self):
articles = [
"Justin Timberlake",
"Michael Phelps",
]
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained(
"hf-internal-testing/tiny-random-gpt2",
max_length=10,
num_beams=4,
num_return_sequences=2,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
length_penalty=0.0,
)
model = model.to(torch_device)
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
transition_scores_sum = transition_scores.sum(-1)
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
def test_transition_scores_beam_sample_encoder_decoder(self):
articles = [
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
"Michael Phelps is arguably the most decorated Olympian of all time.",
]
tokenizer = BartTokenizer.from_pretrained("hf-internal-testing/tiny-random-bart")
model = BartForConditionalGeneration.from_pretrained(
"hf-internal-testing/tiny-random-bart",
do_sample=True,
max_length=10,
num_beams=4,
num_return_sequences=2,
eos_token_id=None,
return_dict_in_generate=True,
output_scores=True,
length_penalty=0.0,
)
model = model.to(torch_device)
input_ids = tokenizer(articles, return_tensors="pt", padding=True).input_ids.to(torch_device)
outputs = model.generate(input_ids=input_ids)
transition_scores = model.compute_transition_scores(outputs.sequences, outputs.scores, outputs.beam_indices)
transition_scores_sum = transition_scores.sum(-1)
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
def test_transition_scores_group_beam_search_encoder_decoder(self):
articles = [
"Justin Timberlake and Jessica Biel, welcome to parenthood.",
...@@ -2406,38 +2245,6 @@
self.assertTrue(torch.allclose(transition_scores_sum, outputs.sequences_scores, atol=1e-3))
@slow
def test_transition_scores_early_stopping(self):
# This is an aggressive test that makes sure that `beam_search`'s transition scores are computed
# correctly for varying `num_return_sequences`, `num_beams`, and `batch_size > 1`
# 2 x input_ids for "question: How are you? \n context: I had a long day, "
input_ids = torch.tensor(2 * [[822, 10, 571, 33, 25, 58, 2625, 10, 27, 141, 3, 9, 307, 239, 6, 1]]).to(
torch_device
)
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small").to(torch_device)
result = model.generate(
input_ids,
max_length=10,
return_dict_in_generate=True,
output_scores=True,
forced_eos_token_id=model.config.eos_token_id,
num_beams=4,
do_sample=False,
num_return_sequences=3,
length_penalty=0.0,
)
transition_scores = model.compute_transition_scores(
sequences=result.sequences, scores=result.scores, beam_indices=result.beam_indices
)
sum_transition_scores = torch.sum(transition_scores, dim=1)
self.assertListEqual(sum_transition_scores.cpu().tolist(), result.sequences_scores.cpu().tolist())
def test_log_scores_sample_decoder_only(self):
articles = ["I need input_ids to generate", "Short and"]
tokenizer = GPT2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
...