"tests/models/albert/test_modeling_tf_albert.py" did not exist on "ba8b1f4754257e140ddabbe04a7f3e493e33802d"
Unverified Commit a6c850e4 authored by Joao Gante, committed by GitHub

Generate: TF uses `GenerationConfig` as the basis for `.generate()` parametrization (#20994)

parent 3b309818
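In short, this commit makes `GenerationConfig` the single source of generation parameters for TF models, mirroring the equivalent PyTorch change. A minimal sketch of the resulting usage; the `gpt2` checkpoint and the prompt are only illustrative:

```python
from transformers import AutoTokenizer, TFAutoModelForCausalLM, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="tf")

# An explicit config becomes the base parametrization of the call
generation_config = GenerationConfig(max_new_tokens=16, do_sample=True, top_p=0.9)
outputs = model.generate(**inputs, generation_config=generation_config)

# Ad hoc kwargs still work: any kwarg matching a GenerationConfig attribute overrides it,
# everything else is forwarded to the model as a model kwarg
outputs = model.generate(**inputs, max_new_tokens=16, num_beams=4, do_sample=False)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```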
......@@ -14,10 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
......@@ -32,6 +33,7 @@ from ..models.auto import (
)
from ..tf_utils import shape_list, stable_softmax
from ..utils import ModelOutput, logging
from .configuration_utils import GenerationConfig
from .tf_logits_process import (
TFForcedBOSTokenLogitsProcessor,
TFForcedEOSTokenLogitsProcessor,
......@@ -449,6 +451,11 @@ class TFGenerationMixin:
supports_xla_generation = True
def prepare_inputs_for_generation(self, *args, **kwargs):
raise NotImplementedError(
"A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
)
def adjust_logits_during_generation(
self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
):
......@@ -475,7 +482,7 @@ class TFGenerationMixin:
Confirms that the model class is compatible with generation. If not, raises an exception that points to the
right class to use.
"""
if not hasattr(self, "prepare_inputs_for_generation"):
if not self.can_generate():
generate_compatible_mappings = [
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
......@@ -520,153 +527,43 @@ class TFGenerationMixin:
def generate(
self,
input_ids=None,
max_length=None,
max_new_tokens=None,
min_length=None,
do_sample=None,
early_stopping=None,
num_beams=None,
temperature=None,
penalty_alpha=None,
top_k=None,
top_p=None,
repetition_penalty=None,
bad_words_ids=None,
bos_token_id=None,
pad_token_id=None,
eos_token_id=None,
length_penalty=None,
no_repeat_ngram_size=None,
num_return_sequences=None,
attention_mask=None,
decoder_start_token_id=None,
use_cache=None,
input_ids: Optional[tf.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
seed=None,
output_scores=None,
output_attentions=None,
output_hidden_states=None,
return_dict_in_generate=None,
forced_bos_token_id=None,
forced_eos_token_id=None,
suppress_tokens=None,
begin_suppress_tokens=None,
forced_decoder_ids=None,
**model_kwargs,
**kwargs,
) -> Union[TFGenerateOutput, tf.Tensor]:
r"""
Generates sequences of token ids for models with a language modeling head. The method supports the following
generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
- *greedy decoding* by calling [`~generation.TFGenerationMixin.greedy_search`] if `num_beams=1` and
`do_sample=False`.
- *contrastive search* by calling [`~generation.TFGenerationMixin.contrastive_search`] if `penalty_alpha>0`
and `top_k>1`
- *multinomial sampling* by calling [`~generation.TFGenerationMixin.sample`] if `num_beams=1` and
`do_sample=True`.
- *beam-search decoding* by calling [`~generation.TFGenerationMixin.beam_search`] if `num_beams>1` and
`do_sample=False`.
Generates sequences of token ids for models with a language modeling head.
Adapted in part from [Facebook's XLM beam search
code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529).
<Tip warning={true}>
Apart from `input_ids` and `attention_mask`, all the arguments below will default to the value of the attribute
of the same name inside the [`PretrainedConfig`] of the model. The default values indicated are the default
values of those config.
Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
model's default generation configuration. You can override any `generation_config` by passing the corresponding
parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
Most of these parameters are explained in more detail in [this blog
post](https://huggingface.co/blog/how-to-generate).
For a complete overview of generate, check the [following
guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
</Tip>
Parameters:
input_ids (`tf.Tensor` of `dtype=tf.int32` and shape `(batch_size, sequence_length)`, *optional*):
The sequence used as a prompt for the generation. If `None` the method initializes it with
`bos_token_id` and a batch size of 1.
max_length (`int`, *optional*, defaults to `model.config.max_length`):
The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
`max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in
the prompt.
max_new_tokens (`int`, *optional*):
The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
min_length (`int`, *optional*, defaults to 10):
The minimum length of the sequence to be generated.
do_sample (`bool`, *optional*, defaults to `False`):
Whether or not to use sampling; use greedy decoding otherwise.
early_stopping (`bool`, *optional*, defaults to `False`):
Whether or not to stop the beam search when at least `num_beams` sentences are finished per batch.
num_beams (`int`, *optional*, defaults to 1):
Number of beams for beam search. 1 means no beam search.
temperature (`float`, *optional*, defaults to 1.0):
The value used to modulate the next token probabilities.
penalty_alpha (`float`, *optional*):
The value that balances the model confidence and the degeneration penalty in contrastive search decoding.
top_k (`int`, *optional*, defaults to 50):
The number of highest probability vocabulary tokens to keep for top-k-filtering.
top_p (`float`, *optional*, defaults to 1.0):
If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
are kept for generation.
repetition_penalty (`float`, *optional*, defaults to 1.0):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
bos_token_id (`int`, *optional*):
The id of the *beginning-of-sequence* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
length_penalty (`float`, *optional*, defaults to 1.0):
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent
to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
while `length_penalty` < 0.0 encourages shorter sequences.
no_repeat_ngram_size (`int`, *optional*, defaults to 0):
If set to int > 0, all ngrams of that size can only occur once.
bad_words_ids(`List[int]`, *optional*):
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
num_return_sequences(`int`, *optional*, defaults to 1):
The number of independently computed returned sequences for each element in the batch.
attention_mask (`tf.Tensor` of `dtype=tf.int32` and shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens
that are not masked, and 0 for masked tokens.
If not provided, will default to a tensor the same shape as `input_ids` that masks the pad token.
[What are attention masks?](../glossary#attention-mask)
decoder_start_token_id (`int`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
seed (`List[int]`, *optional*):
Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
`seed` argument from stateless functions in `tf.random`.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
forced_bos_token_id (`int`, *optional*):
The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful
for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be
the target language token.
forced_eos_token_id (`int`, *optional*):
The id of the token to force as the last generated token when `max_length` is reached.
suppress_tokens (`List[int]`, *optional*, defaults to `model.config.suppress_tokens`):
A list of tokens that will be suppressed at generation. The `SuppressTokens` logit processor will set
their log probs to `-inf` so that they are not sampled.
begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
A list of tokens that will be suppressed at the beginning of the generation. The `SuppressTokensAtBegin`
logit processor will set their log probs to `-inf` so that they are not sampled.
forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
A list of pairs of integers which indicates a mapping from generation indices to token indices that
will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
be a token of index 123.
model_kwargs:
Additional model specific kwargs will be forwarded to the `call` function of the model.
kwargs:
Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
forwarded to the `call` function of the model. If the model is an encoder-decoder model, encoder
specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
Return:
[`~utils.ModelOutput`] or `tf.Tensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when
......@@ -690,59 +587,92 @@ class TFGenerationMixin:
Examples:
Greedy decoding, using the default generation configuration and ad hoc modifications:
```python
tokenizer = AutoTokenizer.from_pretrained("distilgpt2") # Initialize tokenizer
model = TFAutoModelWithLMHead.from_pretrained("distilgpt2")
# Greedy decoding
outputs = model.generate(max_length=40)
print(f"Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
tokenizer = AutoTokenizer.from_pretrained("openai-gpt")
model = TFAutoModelWithLMHead.from_pretrained("openai-gpt")
input_context = "The dog"
input_ids = tokenizer.encode(input_context, return_tensors="tf") # encode input context
# Generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5)
# 3 output sequences were generated
for i in range(3):
print(f"Generated {i}: {tokenizer.decode(outputs[i], skip_special_tokens=True)}")
tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
model = TFAutoModelWithLMHead.from_pretrained("distilgpt2")
input_context = "The dog"
input_ids = tokenizer.encode(input_context, return_tensors="tf")
# Generate 3 candidates using sampling
outputs = model.generate(
input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True
)
# 3 output sequences were generated
for i in range(3):
print(f"Generated {i}: {tokenizer.decode(outputs[i], skip_special_tokens=True)}")
tokenizer = AutoTokenizer.from_pretrained("ctrl")
model = TFAutoModelWithLMHead.from_pretrained("ctrl")
# "Legal" is one of the control codes for ctrl
input_context = "Legal My neighbor is"
input_ids = tokenizer.encode(input_context, return_tensors="tf")
outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)
print(f"Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelWithLMHead.from_pretrained("gpt2")
input_context = "My cute dog"
bad_words_ids = [
tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ["idiot", "stupid", "shut up"]
]
input_ids = tokenizer.encode(input_context, return_tensors="tf")
# generate sequences without allowing bad_words to be generated
outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)
>>> from transformers import AutoTokenizer, TFAutoModelForCausalLM
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = TFAutoModelForCausalLM.from_pretrained("gpt2")
>>> prompt = "Today I believe we can finally"
>>> input_ids = tokenizer(prompt, return_tensors="tf").input_ids
>>> # Generate up to 30 tokens
>>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
```
Multinomial sampling, modifying an existing generation configuration:
```python
>>> from transformers import AutoTokenizer, TFAutoModelForCausalLM, GenerationConfig
>>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
>>> model = TFAutoModelForCausalLM.from_pretrained("gpt2")
>>> prompt = "Today I believe we can finally"
>>> input_ids = tokenizer(prompt, return_tensors="tf").input_ids
>>> # Sample up to 30 tokens
>>> generation_config = GenerationConfig.from_pretrained("gpt2")
>>> generation_config.max_length = 30
>>> generation_config.do_sample = True
>>> outputs = model.generate(input_ids, generation_config=generation_config, seed=[0, 0])
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
["Today I believe we can finally start taking a bold stand against climate change and climate change mitigation efforts such as President Obama's climate ban and President Trump's"]
```
Beam-search decoding, using a freshly initialized generation configuration:
```python
>>> from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM, GenerationConfig
>>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
>>> model = TFAutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
>>> sentence = "Paris is one of the densest populated areas in Europe."
>>> input_ids = tokenizer(sentence, return_tensors="tf").input_ids
>>> generation_config = GenerationConfig(
... max_length=64,
... num_beams=5,
... bos_token_id=0,
... eos_token_id=0,
... decoder_start_token_id=58100,
... pad_token_id=58100,
... bad_words_ids=[[58100]],
... )
>>> outputs = model.generate(input_ids, generation_config=generation_config)
>>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
```"""
# 0. Validate the `.generate()` call
# 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
self._validate_model_class()
# priority: `generation_config` argument > `model.generation_config` (the default generation config)
if generation_config is None:
# legacy: users may modify the model configuration to control generation -- update the generation config
# model attribute accordingly, if it was created from the model config
if self.generation_config._from_model_config:
new_generation_config = GenerationConfig.from_model_config(self.config)
if new_generation_config != self.generation_config:
warnings.warn(
"You have modified the pretrained model configuration to control generation. This is a"
" deprecated strategy to control generation and will be removed soon, in a future version."
" Please use a generation configuration file (see"
" https://huggingface.co/docs/transformers/main_classes/text_generation)"
)
self.generation_config = new_generation_config
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
self._validate_model_kwargs(model_kwargs.copy())
# 1. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)
# 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)
if input_ids is not None:
if isinstance(input_ids, tf.Tensor) and input_ids.dtype.is_floating:
pass
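As a note on step 1 above: `generation_config.update(**kwargs)` is what splits the caller's keyword arguments into generation parameters and model kwargs ("All unused kwargs must be model kwargs"). A rough sketch of that behaviour; the specific kwargs are arbitrary examples:

```python
from transformers import GenerationConfig

generation_config = GenerationConfig(max_length=20)

# Kwargs matching GenerationConfig attributes are absorbed into the config;
# anything unknown is returned so it can be forwarded to the model call.
unused = generation_config.update(max_new_tokens=10, output_attentions=True, token_type_ids=None)
print(generation_config.max_new_tokens)  # 10
print(unused)  # {'token_type_ids': None} -- becomes a model kwarg
```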
......@@ -750,8 +680,8 @@ class TFGenerationMixin:
pass
else:
input_ids = tf.cast(input_ids, tf.int32)
if attention_mask is not None:
attention_mask = tf.cast(attention_mask, tf.int32)
if model_kwargs.get("attention_mask") is not None:
model_kwargs["attention_mask"] = tf.cast(model_kwargs["attention_mask"], tf.int32)
if "decoder_input_ids" in model_kwargs:
if (
isinstance(model_kwargs["decoder_input_ids"], tf.Tensor)
......@@ -765,44 +695,18 @@ class TFGenerationMixin:
else:
model_kwargs["decoder_input_ids"] = tf.cast(model_kwargs["decoder_input_ids"], tf.int32)
# 2. Set generation parameters if not already defined
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
forced_bos_token_id = (
forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
)
forced_eos_token_id = (
forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
)
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
num_beams = num_beams if num_beams is not None else self.config.num_beams
do_sample = do_sample if do_sample is not None else self.config.do_sample
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
if pad_token_id is None and eos_token_id is not None:
if attention_mask is None:
# 3. Set generation parameters if not already defined
if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
if model_kwargs.get("attention_mask") is None:
logger.warning(
"The attention mask and the pad token id were not set. As a consequence, you may observe "
"unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
)
logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence")
pad_token_id = eos_token_id
logger.warning(
f"Setting `pad_token_id` to {generation_config.eos_token_id} (first `eos_token_id`) to generate"
" sequence"
)
generation_config.pad_token_id = generation_config.eos_token_id
use_xla = not tf.executing_eagerly()
if use_xla and not self.supports_xla_generation:
......@@ -810,241 +714,242 @@ class TFGenerationMixin:
"The selected model does not support Graph mode nor XLA generation (e.g. from tf.function())"
)
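Since graph/XLA execution is detected via `tf.executing_eagerly()` and gated on `supports_xla_generation`, here is a hedged sketch of how XLA generation is typically triggered, by wrapping `generate` in `tf.function`; the checkpoint, prompt, and padding length are illustrative choices:

```python
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
model = TFAutoModelForCausalLM.from_pretrained("gpt2")

# Inside tf.function, tf.executing_eagerly() is False, so generate() takes the XLA/graph path
xla_generate = tf.function(model.generate, jit_compile=True)

# Padding to a fixed length avoids retracing the compiled function for every prompt length
inputs = tokenizer(["TensorFlow is"], return_tensors="tf", padding="max_length", max_length=8)
outputs = xla_generate(**inputs, max_new_tokens=16)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```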
# 3. Define model inputs
input_ids = self._prepare_model_inputs(input_ids, bos_token_id)
# 4. Define model inputs
input_ids = self._prepare_model_inputs(input_ids, generation_config.bos_token_id)
# inputs_ids now has to be defined and cannot be None anymore
batch_size = shape_list(input_ids)[0]
# 4. Prepare other model kwargs
if output_attentions is not None:
model_kwargs["output_attentions"] = output_attentions
if output_hidden_states is not None:
model_kwargs["output_hidden_states"] = output_hidden_states
if use_cache is not None:
model_kwargs["use_cache"] = use_cache
if attention_mask is not None:
model_kwargs["attention_mask"] = attention_mask
# 5. Prepare other model kwargs
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
model_kwargs["use_cache"] = generation_config.use_cache
accepts_attention_mask = "attention_mask" in set(inspect.signature(self.call).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs
if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
input_ids, pad_token_id, eos_token_id
input_ids, generation_config.pad_token_id, generation_config.eos_token_id
)
# decoder-only models should use left-padding for generation
if not self.config.is_encoder_decoder:
if pad_token_id is not None and tf.math.reduce_any(input_ids[:, -1] == pad_token_id):
if generation_config.pad_token_id is not None and tf.math.reduce_any(
input_ids[:, -1] == generation_config.pad_token_id
):
logger.warning(
"A decoder-only architecture is being used, but right-padding was detected! For correct "
"generation results, please set `padding_side='left'` when initializing the tokenizer."
)
# 5. Prepare model inputs which will be used for auto-regressive generation
# 6. Prepare model inputs which will be used for auto-regressive generation
if self.config.is_encoder_decoder:
# if encoder-decoder, we create encoder_outputs and add to `model_kwargs`
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
# if encoder-decoder then `input_ids` come from `decoder_start_token_id`
input_ids = self._prepare_decoder_input_ids_for_generation(
batch_size,
decoder_start_token_id=decoder_start_token_id,
bos_token_id=bos_token_id,
decoder_start_token_id=generation_config.decoder_start_token_id,
bos_token_id=generation_config.bos_token_id,
model_kwargs=model_kwargs,
)
# 6. Prepare `max_length` depending on other stopping criteria.
# 7. Prepare `max_length` depending on other stopping criteria.
input_ids_seq_length = input_ids.shape[-1]
if max_length is None and max_new_tokens is None:
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
"Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to "
f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is "
"deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend "
"using `max_new_tokens` to control the maximum length of the generation.",
"Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to"
f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
" config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
" recommend using `max_new_tokens` to control the maximum length of the generation.",
UserWarning,
)
elif max_length is None and max_new_tokens is not None:
max_length = max_new_tokens + input_ids_seq_length
elif max_length is not None and max_new_tokens is not None:
elif has_default_max_length and generation_config.max_new_tokens is not None:
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
elif not has_default_max_length and generation_config.max_new_tokens is not None:
raise ValueError(
"Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
" limit to the generated output length. Remove one of those arguments. Please refer to the"
" documentation for more information. "
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
)
# default to config if still None
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
if min_length is not None and min_length > max_length:
if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
raise ValueError(
f"Unfeasable length constraints: the minimum length ({min_length}) is larger than the maximum "
f"length ({max_length})"
f"Unfeasable length constraints: the minimum length ({generation_config.min_length}) is larger than"
f" the maximum length ({generation_config.max_length})"
)
if input_ids_seq_length >= max_length:
if input_ids_seq_length >= generation_config.max_length:
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
f" {max_length}. This can lead to unexpected behavior. You should consider increasing"
"`max_new_tokens`."
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
" increasing`max_new_tokens`."
)
# 7. determine generation mode
# 8. determine generation mode
is_contrastive_search_gen_mode = (
top_k is not None and top_k > 1 and do_sample is False and penalty_alpha is not None and penalty_alpha > 0
)
is_greedy_gen_mode = not is_contrastive_search_gen_mode and (num_beams == 1) and do_sample is False
is_beam_gen_mode = not is_contrastive_search_gen_mode and (num_beams > 1) and do_sample is False
is_sample_gen_mode = (num_beams == 1) and do_sample is True
is_beam_sample_gen_mode = (num_beams > 1) and do_sample is True
# 8. prepare distribution pre_processing samplers
generation_config.top_k is not None
and generation_config.top_k > 1
and generation_config.do_sample is False
and generation_config.penalty_alpha is not None
and generation_config.penalty_alpha > 0
)
is_greedy_gen_mode = (
not is_contrastive_search_gen_mode
and (generation_config.num_beams == 1)
and generation_config.do_sample is False
)
is_beam_gen_mode = (
not is_contrastive_search_gen_mode
and (generation_config.num_beams > 1)
and generation_config.do_sample is False
)
is_sample_gen_mode = (generation_config.num_beams == 1) and generation_config.do_sample is True
is_beam_sample_gen_mode = (generation_config.num_beams > 1) and generation_config.do_sample is True
# 9. prepare distribution pre_processing samplers
logits_processor = self._get_logits_processor(
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
bad_words_ids=bad_words_ids,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
forced_bos_token_id=forced_bos_token_id,
forced_eos_token_id=forced_eos_token_id,
suppress_tokens=suppress_tokens,
begin_suppress_tokens=begin_suppress_tokens,
forced_decoder_ids=forced_decoder_ids,
)
# 9. go into different generation modes
# 10. go into different generation modes
if is_greedy_gen_mode:
if num_return_sequences > 1:
if generation_config.num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
" greedy search."
)
# 10. run greedy search
# 11. run greedy search
return self.greedy_search(
input_ids,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=generation_config.max_length,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
logits_processor=logits_processor,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
**model_kwargs,
)
elif is_contrastive_search_gen_mode:
if num_return_sequences > 1:
if generation_config.num_return_sequences > 1:
raise ValueError(
f"num_return_sequences has to be 1, but is {num_return_sequences} when doing contrastive search."
f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
" contrastive search."
)
# 10. run contrastive search
# 11. run contrastive search
return self.contrastive_search(
input_ids,
top_k=top_k,
penalty_alpha=penalty_alpha,
top_k=generation_config.top_k,
penalty_alpha=generation_config.penalty_alpha,
logits_processor=logits_processor,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
max_length=generation_config.max_length,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
**model_kwargs,
)
elif is_sample_gen_mode:
# 10. prepare logits warper
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config=generation_config)
# 11. expand input_ids with `num_return_sequences` additional sequences per batch
# 12. expand input_ids with `num_return_sequences` additional sequences per batch
input_ids, model_kwargs = self._expand_inputs_for_generation(
input_ids=input_ids,
expand_size=num_return_sequences,
expand_size=generation_config.num_return_sequences,
is_encoder_decoder=self.config.is_encoder_decoder,
**model_kwargs,
)
# 12. run sample
# 13. run sample
return self.sample(
input_ids,
logits_processor=logits_processor,
logits_warper=logits_warper,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=generation_config.max_length,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
seed=seed,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
**model_kwargs,
)
elif is_beam_gen_mode:
if num_beams < num_return_sequences:
if generation_config.num_beams < generation_config.num_return_sequences:
raise ValueError(
"Beam search decoding cannot return more sequences than it has beams. Please set "
f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectivelly)"
"Beam search decoding cannot return more sequences than it has beams. Please set num_beams >="
f" num_return_sequences, got {generation_config.num_beams} and"
f" {generation_config.num_return_sequences} (respectivelly)"
)
# 10. broadcast inputs to the desired number of beams
input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
# 11. broadcast inputs to the desired number of beams
input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
if "encoder_outputs" in model_kwargs:
model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
)
if "attention_mask" in model_kwargs:
model_kwargs["attention_mask"] = self._expand_to_num_beams(
model_kwargs["attention_mask"], num_beams=num_beams
model_kwargs["attention_mask"], num_beams=generation_config.num_beams
)
# 11. run beam search
# 12. run beam search
return self.beam_search(
input_ids,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
length_penalty=length_penalty,
early_stopping=early_stopping,
max_length=generation_config.max_length,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
length_penalty=generation_config.length_penalty,
early_stopping=generation_config.early_stopping,
logits_processor=logits_processor,
return_dict_in_generate=return_dict_in_generate,
num_return_sequences=num_return_sequences,
return_dict_in_generate=generation_config.return_dict_in_generate,
num_return_sequences=generation_config.num_return_sequences,
**model_kwargs,
)
elif is_beam_sample_gen_mode:
if num_beams < num_return_sequences:
if generation_config.num_beams < generation_config.num_return_sequences:
raise ValueError(
"Beam search decoding cannot return more sequences than it has beams. Please set "
f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectivelly)"
"Beam search decoding cannot return more sequences than it has beams. Please set num_beams >="
f" num_return_sequences, got {generation_config.num_beams} and"
f" {generation_config.num_return_sequences} (respectivelly)"
)
# 10. prepare logits warper
logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
# 11. prepare logits warper
logits_warper = self._get_logits_warper(generation_config=generation_config)
# 11. broadcast inputs to the desired number of beams
input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
# 12. broadcast inputs to the desired number of beams
input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
if "encoder_outputs" in model_kwargs:
model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
)
if "attention_mask" in model_kwargs:
model_kwargs["attention_mask"] = self._expand_to_num_beams(
model_kwargs["attention_mask"], num_beams=num_beams
model_kwargs["attention_mask"], num_beams=generation_config.num_beams
)
# 12. run beam sample (beam search with sampling)
# 13. run beam sample (beam search with sampling)
return self.beam_search(
input_ids,
do_sample=True,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
length_penalty=length_penalty,
early_stopping=early_stopping,
max_length=generation_config.max_length,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
length_penalty=generation_config.length_penalty,
early_stopping=generation_config.early_stopping,
logits_processor=logits_processor,
logits_warper=logits_warper,
return_dict_in_generate=return_dict_in_generate,
num_return_sequences=num_return_sequences,
return_dict_in_generate=generation_config.return_dict_in_generate,
num_return_sequences=generation_config.num_return_sequences,
**model_kwargs,
)
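All decoding modes are now selected from `generation_config` attributes alone. For example, a hedged sketch of reaching the contrastive-search branch above (`penalty_alpha > 0`, `top_k > 1`, `do_sample=False`); the checkpoint, prompt, and hyperparameter values are illustrative:

```python
from transformers import AutoTokenizer, TFAutoModelForCausalLM, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("DeepMind Company is", return_tensors="tf").input_ids

# penalty_alpha > 0 together with top_k > 1 and do_sample=False routes generate()
# into contrastive_search (see is_contrastive_search_gen_mode above)
generation_config = GenerationConfig(penalty_alpha=0.6, top_k=4, max_new_tokens=32)
outputs = model.generate(input_ids, generation_config=generation_config)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
```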
......@@ -1108,26 +1013,16 @@ class TFGenerationMixin:
# retrieve decoder_start_token_id for encoder-decoder models
# fall back to bos_token_id if necessary
decoder_start_token_id = (
decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
decoder_start_token_id
if decoder_start_token_id is not None
else self.generation_config.decoder_start_token_id
)
bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
if decoder_start_token_id is not None:
return decoder_start_token_id
elif (
hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "decoder_start_token_id")
and self.config.decoder.decoder_start_token_id is not None
):
return self.config.decoder.decoder_start_token_id
elif bos_token_id is not None:
return bos_token_id
elif (
hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "bos_token_id")
and self.config.decoder.bos_token_id is not None
):
return self.config.decoder.bos_token_id
raise ValueError(
"`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
)
......@@ -1332,46 +1227,30 @@ class TFGenerationMixin:
def _get_logits_warper(
self,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
generation_config: GenerationConfig,
) -> TFLogitsProcessorList:
"""
This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsWarper`]
instances used for multinomial sampling.
"""
# init warp parameters
top_k = top_k if top_k is not None else self.config.top_k
top_p = top_p if top_p is not None else self.config.top_p
temperature = temperature if temperature is not None else self.config.temperature
# instantiate warpers list
warpers = TFLogitsProcessorList()
# the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
# all samplers can be found in `generation_utils_samplers.py`
if temperature is not None and temperature != 1.0:
warpers.append(TFTemperatureLogitsWarper(temperature))
if top_k is not None and top_k != 0:
warpers.append(TFTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
if top_p is not None and top_p < 1.0:
warpers.append(TFTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
if generation_config.temperature is not None and generation_config.temperature != 1.0:
warpers.append(TFTemperatureLogitsWarper(generation_config.temperature))
if generation_config.top_k is not None and generation_config.top_k != 0:
warpers.append(TFTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
if generation_config.top_p is not None and generation_config.top_p < 1.0:
warpers.append(TFTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))
return warpers
def _get_logits_processor(
self,
repetition_penalty: float,
no_repeat_ngram_size: int,
generation_config: GenerationConfig,
input_ids_seq_length: int,
bad_words_ids: List[List[int]],
min_length: int,
max_length: int,
eos_token_id: int,
forced_bos_token_id: int,
forced_eos_token_id: int,
suppress_tokens: Optional[List[int]] = None,
begin_suppress_tokens: Optional[List[int]] = None,
forced_decoder_ids: Optional[List[List[int]]] = None,
) -> TFLogitsProcessorList:
"""
This class returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`]
......@@ -1379,42 +1258,45 @@ class TFGenerationMixin:
"""
processors = TFLogitsProcessorList()
repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
suppress_tokens = suppress_tokens if suppress_tokens is not None else self.config.suppress_tokens
begin_suppress_tokens = (
begin_suppress_tokens if begin_suppress_tokens is not None else self.config.begin_suppress_tokens
)
if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"):
forced_decoder_ids = self.config.forced_decoder_ids
# instantiate processors list
if repetition_penalty is not None and repetition_penalty != 1.0:
processors.append(TFRepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
processors.append(TFNoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
if bad_words_ids is not None:
processors.append(TFNoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
if min_length is not None and eos_token_id is not None and min_length > 0:
processors.append(TFMinLengthLogitsProcessor(min_length, eos_token_id))
if forced_bos_token_id is not None:
processors.append(TFForcedBOSTokenLogitsProcessor(forced_bos_token_id))
if forced_eos_token_id is not None:
processors.append(TFForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
if suppress_tokens is not None:
processors.append(TFSuppressTokensLogitsProcessor(suppress_tokens))
if begin_suppress_tokens is not None:
if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
processors.append(TFRepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
processors.append(TFNoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
if generation_config.bad_words_ids is not None:
processors.append(
TFNoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)
)
if (
generation_config.min_length is not None
and generation_config.eos_token_id is not None
and generation_config.min_length > 0
):
processors.append(TFMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id))
if generation_config.forced_bos_token_id is not None:
processors.append(TFForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
if generation_config.forced_eos_token_id is not None:
processors.append(
TFForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
)
if generation_config.suppress_tokens is not None:
processors.append(TFSuppressTokensLogitsProcessor(generation_config.suppress_tokens))
if generation_config.begin_suppress_tokens is not None:
begin_index = input_ids_seq_length
begin_index = begin_index if (input_ids_seq_length > 1 or forced_bos_token_id is None) else begin_index + 1
if forced_decoder_ids is not None:
begin_index += forced_decoder_ids[-1][0] # generation starts after the last token that is forced
processors.append(TFSuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index))
if forced_decoder_ids is not None:
processors.append(TFForceTokensLogitsProcessor(forced_decoder_ids))
begin_index = (
begin_index
if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
else begin_index + 1
)
if generation_config.forced_decoder_ids is not None:
begin_index += generation_config.forced_decoder_ids[-1][
0
] # generation starts after the last token that is forced
processors.append(
TFSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
)
if generation_config.forced_decoder_ids is not None:
processors.append(TFForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
return processors
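For reference, a rough sketch of the kind of processor list `_get_logits_processor` assembles from a config; only two of the processors handled above are shown, and the `eos_token_id` value (GPT-2's 50256) is just an example:

```python
from transformers import (
    GenerationConfig,
    TFLogitsProcessorList,
    TFMinLengthLogitsProcessor,
    TFRepetitionPenaltyLogitsProcessor,
)

generation_config = GenerationConfig(min_length=10, eos_token_id=50256, repetition_penalty=1.2)

# Each non-default generation_config attribute maps to one TF logits processor
processors = TFLogitsProcessorList()
processors.append(TFRepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
processors.append(TFMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id))
```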
def greedy_search(
......@@ -1500,17 +1382,22 @@ class TFGenerationMixin:
# 1. init greedy_search values
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
max_length = max_length if max_length is not None else self.generation_config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
return_dict_in_generate
if return_dict_in_generate is not None
else self.generation_config.return_dict_in_generate
)
use_cache = model_kwargs.pop("use_cache", self.generation_config.use_cache)
use_xla = not tf.executing_eagerly()
# TODO (Joao): fix cache format or find programmatic way to detect cache index
# GPT2 and other models have a slightly different cache structure, with a different batch axis
......@@ -1546,7 +1433,7 @@ class TFGenerationMixin:
input_ids = generated[:, :cur_len]
else:
input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
model_inputs = self.prepare_inputs_for_generation(input_ids, use_cache=use_cache, **model_kwargs)
# forward pass to get next token logits
model_outputs = self(
**model_inputs,
......@@ -1772,17 +1659,22 @@ class TFGenerationMixin:
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
max_length = max_length if max_length is not None else self.generation_config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
return_dict_in_generate
if return_dict_in_generate is not None
else self.generation_config.return_dict_in_generate
)
use_cache = model_kwargs.pop("use_cache", self.generation_config.use_cache)
use_xla = not tf.executing_eagerly()
# TODO (Joao): fix cache format or find programmatic way to detect cache index
# GPT2 and other models have a slightly different cache structure, with a different batch axis
......@@ -1814,7 +1706,7 @@ class TFGenerationMixin:
input_ids = generated[:, :cur_len]
else:
input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
model_inputs = self.prepare_inputs_for_generation(input_ids, use_cache=use_cache, **model_kwargs)
# forward pass to get next token logits
model_outputs = self(
**model_inputs,
......@@ -2091,25 +1983,30 @@ class TFGenerationMixin:
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
max_length = max_length if max_length is not None else self.generation_config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
num_return_sequences if num_return_sequences is not None else self.generation_config.num_return_sequences
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
return_dict_in_generate
if return_dict_in_generate is not None
else self.generation_config.return_dict_in_generate
)
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty
early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping
use_cache = model_kwargs.pop("use_cache", self.generation_config.use_cache)
use_xla = not tf.executing_eagerly()
# TODO (Joao): fix cache format or find programmatic way to detect cache index
# GPT2 and other models have a slightly different cache structure, with a different batch axis
......@@ -2199,7 +2096,9 @@ class TFGenerationMixin:
input_ids = running_sequences[:, :, :cur_len]
else:
input_ids = tf.expand_dims(running_sequences[:, :, cur_len - 1], -1)
model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), **model_kwargs)
model_inputs = self.prepare_inputs_for_generation(
flatten_beam_dim(input_ids), use_cache=use_cache, **model_kwargs
)
model_outputs = self(
**model_inputs,
return_dict=True,
......@@ -2521,17 +2420,22 @@ class TFGenerationMixin:
# 1. init greedy_search values
logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
max_length = max_length if max_length is not None else self.config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
max_length = max_length if max_length is not None else self.generation_config.max_length
pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
output_attentions = (
output_attentions if output_attentions is not None else self.generation_config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
return_dict_in_generate
if return_dict_in_generate is not None
else self.generation_config.return_dict_in_generate
)
use_cache = True # In contrastive search, we always use cache
use_xla = not tf.executing_eagerly()
# TODO (Joao): fix cache format or find programmatic way to detect cache index
# GPT2 and other models have a slightly different cache structure, with a different batch axis
......@@ -2571,8 +2475,9 @@ class TFGenerationMixin:
if model_kwargs.get("past") is None:
# prepare inputs
model_inputs = self.prepare_inputs_for_generation(generated[:, :cur_len], **model_kwargs)
model_inputs["use_cache"] = True
model_inputs = self.prepare_inputs_for_generation(
generated[:, :cur_len], use_cache=use_cache, **model_kwargs
)
# encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
# the `encoder_outputs`
......@@ -2662,8 +2567,9 @@ class TFGenerationMixin:
)
# compute the candidate tokens by the language model and collects their hidden_states
next_model_inputs = self.prepare_inputs_for_generation(tf.reshape(top_k_ids, [-1, 1]), **model_kwargs)
next_model_inputs["use_cache"] = True
next_model_inputs = self.prepare_inputs_for_generation(
tf.reshape(top_k_ids, [-1, 1]), use_cache=use_cache, **model_kwargs
)
outputs = self(
**next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
)
......
......@@ -39,7 +39,7 @@ from . import DataCollatorWithPadding, DefaultDataCollator
from .activations_tf import get_tf_activation
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
from .generation import TFGenerationMixin
from .generation import GenerationConfig, TFGenerationMixin
from .tf_utils import shape_list
from .utils import (
DUMMY_INPUTS,
......@@ -1137,6 +1137,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# Save config and origin of the pretrained weights if given in model
self.config = config
self.name_or_path = config.name_or_path
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
# Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
self._set_save_spec(self.serving.input_signature[0])
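With this line, every generate-capable TF model now carries a `generation_config` attribute from the moment it is built, derived from the model config when no dedicated generation config is available. A small sketch; `gpt2` is an arbitrary example checkpoint:

```python
from transformers import TFAutoModelForCausalLM

model = TFAutoModelForCausalLM.from_pretrained("gpt2")

# Defaults derived from the model config (or from generation_config.json, if the checkpoint has one)
print(model.generation_config)
print(model.generation_config.max_length, model.generation_config.do_sample)
```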
......@@ -1200,6 +1201,18 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
"""
raise NotImplementedError
def can_generate(self) -> bool:
"""
Returns whether this model can generate sequences with `.generate()`.
Returns:
`bool`: Whether this model can generate sequences with `.generate()`.
"""
# Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
if "GenerationMixin" in str(self.prepare_inputs_for_generation):
return False
return True
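A quick illustration of the `can_generate()` check (a sketch; the checkpoints are examples): base models inherit `prepare_inputs_for_generation` from the mixin and report `False`, while LM-headed models override it and report `True`:

```python
from transformers import TFAutoModel, TFAutoModelForCausalLM

# TFGPT2Model keeps the mixin's prepare_inputs_for_generation -> cannot generate
assert not TFAutoModel.from_pretrained("gpt2").can_generate()

# TFGPT2LMHeadModel overrides it -> .generate() is available
assert TFAutoModelForCausalLM.from_pretrained("gpt2").can_generate()
```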
def get_input_embeddings(self) -> tf.keras.layers.Layer:
"""
Returns the model's input embeddings layer.
......@@ -2832,6 +2845,29 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
" to use it for predictions and inference."
)
# If it is a model with generation capabilities, attempt to load the generation config
if model.can_generate():
try:
model.generation_config = GenerationConfig.from_pretrained(
pretrained_model_name_or_path,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
**kwargs,
)
except OSError:
logger.info(
"Generation config file not found, using a generation config created from the model config."
)
pass
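The `generation_config.json` file that this block tries to load can be produced with `GenerationConfig.save_pretrained`; a minimal sketch, where the local path is hypothetical:

```python
from transformers import GenerationConfig

# Writes ./my-checkpoint/generation_config.json; if that file is missing, the
# OSError branch above falls back to a config derived from the model config.
gen_config = GenerationConfig(max_new_tokens=32, do_sample=True, top_p=0.9)
gen_config.save_pretrained("./my-checkpoint")

reloaded = GenerationConfig.from_pretrained("./my-checkpoint")
print(reloaded.top_p)
```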
if output_loading_info:
loading_info = {
"missing_keys": missing_keys,
......
......@@ -15,6 +15,7 @@
"""TFRAG model implementation."""
import copy
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
......@@ -999,25 +1000,9 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
context_input_ids=None,
context_attention_mask=None,
doc_scores=None,
max_length=None,
min_length=None,
early_stopping=None,
use_cache=None,
num_beams=None,
bos_token_id=None,
pad_token_id=None,
eos_token_id=None,
length_penalty=None,
no_repeat_ngram_size=None,
bad_words_ids=None,
num_return_sequences=None,
decoder_start_token_id=None,
n_docs=None,
output_scores=None,
output_attentions=None,
output_hidden_states=None,
return_dict_in_generate=None,
**model_kwargs
generation_config=None,
**kwargs
):
"""
Implements TFRAG token decoding.
......@@ -1051,91 +1036,32 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
If the model is not initialized with a `retriever`, `context_input_ids` has to be provided to the
forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
max_length (`int`, *optional*, defaults to 20):
The maximum length of the sequence to be generated.
min_length (`int`, *optional*, defaults to 10):
The minimum length of the sequence to be generated.
early_stopping (`bool`, *optional*, defaults to `False`):
Whether or not to stop the beam search when at least `num_beams` sentences are finished per batch or
not.
use_cache: (`bool`, *optional*, defaults to `True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
pad_token_id (`int`, *optional*):
The id of the *padding* token.
bos_token_id (`int`, *optional*):
The id of the *beginning-of-sequence* token.
eos_token_id (`int`, *optional*):
The id of the *end-of-sequence* token.
length_penalty (`float`, *optional*, defaults to 1.0):
Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent
to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
while `length_penalty` < 0.0 encourages shorter sequences.
no_repeat_ngram_size (`int`, *optional*, defaults to 0):
If set to int > 0, all ngrams of that size can only occur once.
bad_words_ids(`List[int]`, *optional*):
List of token ids that are not allowed to be generated. In order to get the tokens of the words that
should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
num_beams (`int`, *optional*, defaults to 1):
Number of beams for beam search. 1 means no beam search.
num_return_sequences (`int`, *optional*, defaults to 1):
The number of independently computed returned sequences for each element in the batch. Note that this
is not the value we pass to the `generator`'s [`~generation.GenerationMixin.generate`] function, where
we set `num_return_sequences` to `num_beams`.
decoder_start_token_id (`int`, *optional*):
If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
n_docs (`int`, *optional*, defaults to `config.n_docs`):
Number of documents to retrieve and/or number of documents for which to generate an answer.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more details.
output_hidden_states (`bool`, *optional*, defaults to `False`):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more details.
output_scores (`bool`, *optional*, defaults to `False`):
Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
return_dict_in_generate (`bool`, *optional*, defaults to `False`):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
model_specific_kwargs:
Additional model specific kwargs will be forwarded to the `forward` function of the model.
generation_config (`~generation.GenerationConfig`, *optional*):
The generation configuration to be used as base parametrization for the generation call. `**kwargs`
passed to generate matching the attributes of `generation_config` will override them. If
`generation_config` is not provided, the default will be used, which has the following loading
priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
default values, whose documentation should be checked to parameterize generation.
kwargs:
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model.
Return:
`tf.Tensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The
second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
due to the `eos_token_id`.
"""
# Handle `generation_config` and kwargs that might update it
if generation_config is None:
generation_config = self.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
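A rough illustration of the update step above, as it is assumed to behave: kwargs whose names match `GenerationConfig` attributes override the (deep-copied) config, and everything else is returned so it can be forwarded to the model:

```python
import copy

from transformers import GenerationConfig

base = GenerationConfig(num_beams=4, max_length=32)
generation_config = copy.deepcopy(base)

# `num_beams` matches a config attribute and overrides it; `n_docs` does not,
# so it is returned and treated as a model kwarg.
model_kwargs = generation_config.update(num_beams=2, n_docs=5)
print(generation_config.num_beams)  # 2
print(model_kwargs)                 # {'n_docs': 5}
```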
# set default parameters
n_docs = n_docs if n_docs is not None else self.config.n_docs
max_length = max_length if max_length is not None else self.config.max_length
min_length = min_length if min_length is not None else self.config.min_length
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
use_cache = use_cache if use_cache is not None else self.config.use_cache
num_beams = num_beams if num_beams is not None else self.config.num_beams
bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id
pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id
eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id
length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
no_repeat_ngram_size = (
no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
)
bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
num_return_sequences = (
num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
)
decoder_start_token_id = (
decoder_start_token_id
if decoder_start_token_id is not None
else self.config.generator.decoder_start_token_id
)
output_scores = output_scores if output_scores is not None else self.config.output_scores
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict_in_generate = (
return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
)
# retrieve docs
if self.retriever is not None and context_input_ids is None:
......@@ -1174,14 +1100,14 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
encoder_outputs = encoder(
input_ids=context_input_ids,
attention_mask=context_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_attentions=generation_config.output_attentions,
output_hidden_states=generation_config.output_hidden_states,
return_dict=True,
)
decoder_input_ids = tf.fill(
(batch_size * num_beams, 1),
tf.cast(decoder_start_token_id, tf.int32),
(batch_size * generation_config.num_beams, 1),
tf.cast(generation_config.decoder_start_token_id, tf.int32),
)
last_hidden_state = encoder_outputs["last_hidden_state"]
......@@ -1207,10 +1133,12 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
return tf.reshape(tensor, new_shape)
# correctly extend last_hidden_state and attention mask
context_attention_mask = extend_enc_output(context_attention_mask, num_beams=num_beams)
encoder_outputs["last_hidden_state"] = extend_enc_output(last_hidden_state, num_beams=num_beams)
context_attention_mask = extend_enc_output(context_attention_mask, num_beams=generation_config.num_beams)
encoder_outputs["last_hidden_state"] = extend_enc_output(
last_hidden_state, num_beams=generation_config.num_beams
)
doc_scores = tf.repeat(doc_scores, num_beams, axis=0)
doc_scores = tf.repeat(doc_scores, generation_config.num_beams, axis=0)
# define start_len & additional parameters
model_kwargs["doc_scores"] = doc_scores
......@@ -1219,41 +1147,35 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
model_kwargs["n_docs"] = n_docs
pre_processor = self._get_logits_processor(
repetition_penalty=self.config.repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
bad_words_ids=bad_words_ids,
min_length=min_length,
max_length=max_length,
eos_token_id=eos_token_id,
forced_bos_token_id=self.config.generator.forced_bos_token_id,
forced_eos_token_id=self.config.generator.forced_eos_token_id,
generation_config=generation_config,
input_ids_seq_length=tf.shape(decoder_input_ids)[-1],
)
if num_beams == 1:
if generation_config.num_beams == 1:
return self.greedy_search(
input_ids=decoder_input_ids,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=generation_config.max_length,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
logits_processor=pre_processor,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
output_attentions=generation_config.output_attentions,
output_hidden_states=generation_config.output_hidden_states,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
**model_kwargs,
)
elif num_beams > 1:
if num_beams < num_return_sequences:
elif generation_config.num_beams > 1:
if generation_config.num_beams < generation_config.num_return_sequences:
raise ValueError(
"Beam search decoding cannot return more sequences than it has beams. Please set "
f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectivelly)"
"Beam search decoding cannot return more sequences than it has beams. Please set num_beams >="
f" num_return_sequences, got {generation_config.num_beams} and"
f" {generation_config.num_return_sequences} (respectivelly)"
)
def unflatten_beam_dim(tensor):
"""Unflattens the first, flat batch*beam dimension of a non-scalar array."""
shape = shape_list(tensor)
return tf.reshape(tensor, [-1, num_beams] + shape[1:])
return tf.reshape(tensor, [-1, generation_config.num_beams] + shape[1:])
decoder_input_ids = unflatten_beam_dim(decoder_input_ids)
model_kwargs["attention_mask"] = unflatten_beam_dim(model_kwargs["attention_mask"])
......@@ -1263,18 +1185,20 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
return self.beam_search(
input_ids=decoder_input_ids,
max_length=max_length,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
max_length=generation_config.max_length,
pad_token_id=generation_config.pad_token_id,
eos_token_id=generation_config.eos_token_id,
logits_processor=pre_processor,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_scores=output_scores,
return_dict_in_generate=return_dict_in_generate,
output_attentions=generation_config.output_attentions,
output_hidden_states=generation_config.output_hidden_states,
output_scores=generation_config.output_scores,
return_dict_in_generate=generation_config.return_dict_in_generate,
**model_kwargs,
)
else:
raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}")
raise ValueError(
f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {generation_config.num_beams}"
)
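For callers of `TFRagTokenForGeneration.generate`, the net effect of the rewrite above is that generation options travel either in a `GenerationConfig` or as ad hoc kwargs, while anything unknown to the config (such as `n_docs`) is forwarded to the model. A hedged sketch; `model` and `input_ids` are placeholders, and a real call also needs a retriever or precomputed context tensors:

```python
from transformers import GenerationConfig

gen_cfg = GenerationConfig(num_beams=4, num_return_sequences=2, max_length=64)

# outputs = model.generate(input_ids=input_ids, generation_config=gen_cfg, n_docs=5)
# ...or, equivalently, override the model's own generation_config ad hoc:
# outputs = model.generate(input_ids=input_ids, num_beams=4, num_return_sequences=2, n_docs=5)
```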
def get_input_embeddings(self):
return self.rag.generator.get_input_embeddings()
......
......@@ -1824,18 +1824,18 @@ class TFModelTesterMixin:
model.train_on_batch(test_batch, test_batch_labels)
def _test_xla_generate(self, **generate_kwargs):
def _generate_and_check_results(model, config, inputs_dict):
def _generate_and_check_results(model, inputs_dict):
if "input_ids" in inputs_dict:
inputs = inputs_dict["input_ids"]
# make sure there are no pad tokens in prompt, which may trigger unwanted behavior
if config.pad_token_id is not None:
if model.generation_config.pad_token_id is not None:
if config.pad_token_id == 0:
new_pad_token = config.pad_token_id + 1
new_pad_token = model.generation_config.pad_token_id + 1
else:
new_pad_token = config.pad_token_id - 1
new_pad_token = model.generation_config.pad_token_id - 1
else:
new_pad_token = None
inputs = tf.where(inputs != config.pad_token_id, inputs, new_pad_token)
inputs = tf.where(inputs != model.generation_config.pad_token_id, inputs, new_pad_token)
elif "input_features" in inputs_dict:
inputs = inputs_dict["input_features"]
else:
......@@ -1854,10 +1854,10 @@ class TFModelTesterMixin:
model = model_class(config)
if model.supports_xla_generation:
_generate_and_check_results(model, config, inputs_dict)
_generate_and_check_results(model, inputs_dict)
else:
with self.assertRaises(ValueError):
_generate_and_check_results(model, config, inputs_dict)
_generate_and_check_results(model, inputs_dict)
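The pad-token handling in the test helper above picks a neighbouring id so that prompts contain no pad tokens; in isolation, the replacement looks like this (values are made up):

```python
import tensorflow as tf

pad_token_id = 0
new_pad_token = pad_token_id + 1 if pad_token_id == 0 else pad_token_id - 1

inputs = tf.constant([[0, 5, 6], [7, 0, 8]])
inputs = tf.where(inputs != pad_token_id, inputs, new_pad_token)
print(inputs.numpy())  # [[1 5 6], [7 1 8]]
```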
def test_xla_generate_fast(self):
"""
......