chenpangpang / transformers · Commit a6c850e4 (unverified)

Generate: TF uses `GenerationConfig` as the basis for `.generate()` parametrization (#20994)

Authored Jan 04, 2023 by Joao Gante, committed via GitHub Jan 04, 2023. Parent: 3b309818
Changes: 4 changed files, with 447 additions and 581 deletions (+447 −581)

- src/transformers/generation/tf_utils.py (+350 −444)
- src/transformers/modeling_tf_utils.py (+37 −1)
- src/transformers/models/rag/modeling_tf_rag.py (+53 −129)
- tests/test_modeling_tf_common.py (+7 −7)
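In practice, the commit replaces the long list of flat `.generate()` keyword arguments with a single `GenerationConfig` object plus ad hoc keyword overrides. A minimal sketch of the calling convention this diff enables (checkpoint and prompt are illustrative):

```python
from transformers import AutoTokenizer, TFAutoModelForCausalLM, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("The dog", return_tensors="tf").input_ids

# Flat kwargs keep working: they are folded into the generation config internally
outputs = model.generate(input_ids, num_beams=4, do_sample=True, max_new_tokens=20)

# New with this commit: the same knobs packaged in a reusable GenerationConfig
generation_config = GenerationConfig(num_beams=4, do_sample=True, max_new_tokens=20)
outputs = model.generate(input_ids, generation_config=generation_config)
```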
src/transformers/generation/tf_utils.py (view file @ a6c850e4)
...
@@ -14,10 +14,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import inspect
 import warnings
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union
 
 import numpy as np
 import tensorflow as tf
...
@@ -32,6 +33,7 @@ from ..models.auto import (
 )
 from ..tf_utils import shape_list, stable_softmax
 from ..utils import ModelOutput, logging
+from .configuration_utils import GenerationConfig
 from .tf_logits_process import (
     TFForcedBOSTokenLogitsProcessor,
     TFForcedEOSTokenLogitsProcessor,
...
@@ -449,6 +451,11 @@ class TFGenerationMixin:
     supports_xla_generation = True
 
+    def prepare_inputs_for_generation(self, *args, **kwargs):
+        raise NotImplementedError(
+            "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
+        )
+
     def adjust_logits_during_generation(
         self, logits, cur_len, max_length, forced_bos_token_id, forced_eos_token_id, **kwargs
     ):
...
@@ -475,7 +482,7 @@ class TFGenerationMixin:
         Confirms that the model class is compatible with generation. If not, raises an exception that points to the
         right class to use.
         """
-        if not hasattr(self, "prepare_inputs_for_generation"):
+        if not self.can_generate():
             generate_compatible_mappings = [
                 TF_MODEL_FOR_CAUSAL_LM_MAPPING,
                 TF_MODEL_FOR_VISION_2_SEQ_MAPPING,
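The `can_generate()` check used here is added to `src/transformers/modeling_tf_utils.py` in this same commit (the +37 −1 file above). Since `prepare_inputs_for_generation` now always exists on the mixin, `hasattr` is no longer informative; roughly, the new helper instead detects whether the model class overrides that method. A sketch of that check:

```python
def can_generate(self) -> bool:
    """Whether this model can be used with `.generate()` (sketch of the companion change)."""
    # The base `TFGenerationMixin.prepare_inputs_for_generation` only raises
    # NotImplementedError, so a model is generate-compatible exactly when it
    # provides its own override of that method.
    if "GenerationMixin" in str(self.prepare_inputs_for_generation):
        return False
    return True
```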
...
@@ -520,153 +527,43 @@ class TFGenerationMixin:
     def generate(
         self,
-        input_ids=None,
-        max_length=None,
-        max_new_tokens=None,
-        min_length=None,
-        do_sample=None,
-        early_stopping=None,
-        num_beams=None,
-        temperature=None,
-        penalty_alpha=None,
-        top_k=None,
-        top_p=None,
-        repetition_penalty=None,
-        bad_words_ids=None,
-        bos_token_id=None,
-        pad_token_id=None,
-        eos_token_id=None,
-        length_penalty=None,
-        no_repeat_ngram_size=None,
-        num_return_sequences=None,
-        attention_mask=None,
-        decoder_start_token_id=None,
-        use_cache=None,
-        seed=None,
-        output_scores=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict_in_generate=None,
-        forced_bos_token_id=None,
-        forced_eos_token_id=None,
-        suppress_tokens=None,
-        begin_suppress_tokens=None,
-        forced_decoder_ids=None,
-        **model_kwargs,
+        input_ids: Optional[tf.Tensor] = None,
+        generation_config: Optional[GenerationConfig] = None,
+        seed=None,
+        **kwargs,
     ) -> Union[TFGenerateOutput, tf.Tensor]:
         r"""
-        Generates sequences of token ids for models with a language modeling head. The method supports the following
-        generation methods for text-decoder, text-to-text, speech-to-text, and vision-to-text models:
-
-            - *greedy decoding* by calling [`~generation.TFGenerationMixin.greedy_search`] if `num_beams=1` and
-              `do_sample=False`.
-            - *contrastive search* by calling [`~generation.TFGenerationMixin.contrastive_search`] if `penalty_alpha>0`
-              and `top_k>1`
-            - *multinomial sampling* by calling [`~generation.TFGenerationMixin.sample`] if `num_beams=1` and
-              `do_sample=True`.
-            - *beam-search decoding* by calling [`~generation.TFGenerationMixin.beam_search`] if `num_beams>1` and
-              `do_sample=False`.
-
-        Adapted in part from [Facebook's XLM beam search
-        code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529).
-
-        Apart from `input_ids` and `attention_mask`, all the arguments below will default to the value of the attribute
-        of the same name inside the [`PretrainedConfig`] of the model. The default values indicated are the default
-        values of those configs.
-
-        Most of these parameters are explained in more detail in [this blog
-        post](https://huggingface.co/blog/how-to-generate).
+        Generates sequences of token ids for models with a language modeling head.
+
+        <Tip warning={true}>
+
+        Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+        model's default generation configuration. You can override any `generation_config` by passing the corresponding
+        parameters to generate, e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+        For a complete overview of generate, check the [following
+        guide](https://huggingface.co/docs/transformers/en/main_classes/text_generation).
+
+        </Tip>
 
         Parameters:
             input_ids (`tf.Tensor` of `dtype=tf.int32` and shape `(batch_size, sequence_length)`, *optional*):
                 The sequence used as a prompt for the generation. If `None` the method initializes it with
                 `bos_token_id` and a batch size of 1.
-            max_length (`int`, *optional*, defaults to `model.config.max_length`):
-                The maximum length the generated tokens can have. Corresponds to the length of the input prompt +
-                `max_new_tokens`. In general, prefer the use of `max_new_tokens`, which ignores the number of tokens in
-                the prompt.
-            max_new_tokens (`int`, *optional*):
-                The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt.
-            min_length (`int`, *optional*, defaults to 10):
-                The minimum length of the sequence to be generated.
-            do_sample (`bool`, *optional*, defaults to `False`):
-                Whether or not to use sampling; use greedy decoding otherwise.
-            early_stopping (`bool`, *optional*, defaults to `False`):
-                Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.
-            num_beams (`int`, *optional*, defaults to 1):
-                Number of beams for beam search. 1 means no beam search.
-            temperature (`float`, *optional*, defaults to 1.0):
-                The value used to modulate the next token probabilities.
-            penalty_alpha (`float`, *optional*):
-                The value balances the model confidence and the degeneration penalty in contrastive search decoding.
-            top_k (`int`, *optional*, defaults to 50):
-                The number of highest probability vocabulary tokens to keep for top-k-filtering.
-            top_p (`float`, *optional*, defaults to 1.0):
-                If set to float < 1, only the most probable tokens with probabilities that add up to `top_p` or higher
-                are kept for generation.
-            repetition_penalty (`float`, *optional*, defaults to 1.0):
-                The parameter for repetition penalty. 1.0 means no penalty. See [this
-                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
-            pad_token_id (`int`, *optional*):
-                The id of the *padding* token.
-            bos_token_id (`int`, *optional*):
-                The id of the *beginning-of-sequence* token.
-            eos_token_id (`int`, *optional*):
-                The id of the *end-of-sequence* token.
-            length_penalty (`float`, *optional*, defaults to 1.0):
-                Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent
-                to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
-                the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
-                while `length_penalty` < 0.0 encourages shorter sequences.
-            no_repeat_ngram_size (`int`, *optional*, defaults to 0):
-                If set to int > 0, all ngrams of that size can only occur once.
-            bad_words_ids (`List[List[int]]`, *optional*):
-                List of token ids that are not allowed to be generated. In order to get the tokens of the words that
-                should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
-            num_return_sequences (`int`, *optional*, defaults to 1):
-                The number of independently computed returned sequences for each element in the batch.
-            attention_mask (`tf.Tensor` of `dtype=tf.int32` and shape `(batch_size, sequence_length)`, *optional*):
-                Mask to avoid performing attention on padding token indices. Mask values are in `[0, 1]`, 1 for tokens
-                that are not masked, and 0 for masked tokens.
-
-                If not provided, will default to a tensor the same shape as `input_ids` that masks the pad token.
-
-                [What are attention masks?](../glossary#attention-mask)
-            decoder_start_token_id (`int`, *optional*):
-                If an encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
-            use_cache (`bool`, *optional*, defaults to `True`):
-                Whether or not the model should use the past last key/values attentions (if applicable to the model) to
-                speed up decoding.
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which has the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
             seed (`List[int]`, *optional*):
                 Random seed to control sampling, containing two integers, used when `do_sample` is `True`. See the
                 `seed` argument from stateless functions in `tf.random`.
-            output_attentions (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more details.
-            output_hidden_states (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more details.
-            output_scores (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
-            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-            forced_bos_token_id (`int`, *optional*):
-                The id of the token to force as the first generated token after the `decoder_start_token_id`. Useful
-                for multilingual models like [mBART](../model_doc/mbart) where the first generated token needs to be
-                the target language token.
-            forced_eos_token_id (`int`, *optional*):
-                The id of the token to force as the last generated token when `max_length` is reached.
-            suppress_tokens (`List[int]`, *optional*, defaults to `model.config.suppress_tokens`):
-                A list of tokens that will be suppressed at generation. The `SuppressTokens` logit processor will set
-                their log probs to `-inf` so that they are not sampled.
-            begin_suppress_tokens (`List[int]`, *optional*, defaults to `model.config.begin_suppress_tokens`):
-                A list of tokens that will be suppressed at the beginning of the generation. The `SuppressBeginTokens`
-                logit processor will set their log probs to `-inf` so that they are not sampled.
-            forced_decoder_ids (`List[List[int]]`, *optional*, defaults to `model.config.forced_decoder_ids`):
-                A list of pairs of integers which indicates a mapping from generation indices to token indices that
-                will be forced before sampling. For example, `[[1, 123]]` means the second generated token will always
-                be a token of index 123.
-            model_kwargs:
-                Additional model specific kwargs will be forwarded to the `call` function of the model.
+            kwargs:
+                Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+                specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
 
         Return:
             [`~utils.ModelOutput`] or `tf.Tensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True` or when
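The loading-priority rules described above compose: an explicit `generation_config` argument takes precedence over `model.generation_config`, and any matching keyword argument overrides both. A short usage sketch (checkpoint and prompt illustrative):

```python
from transformers import AutoTokenizer, TFAutoModelForCausalLM, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = TFAutoModelForCausalLM.from_pretrained("gpt2")
input_ids = tokenizer("The dog", return_tensors="tf").input_ids

base = GenerationConfig(max_new_tokens=16, do_sample=True, top_k=50)

# `do_sample=False` overrides the attribute of the same name in `base`; kwargs that do
# not name a GenerationConfig attribute fall through to the model call as model kwargs.
outputs = model.generate(input_ids, generation_config=base, do_sample=False)
```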
...
@@ -690,59 +587,92 @@ class TFGenerationMixin:
         Examples:
 
+        Greedy decoding, using the default generation configuration and ad hoc modifications:
+
         ```python
-        tokenizer = AutoTokenizer.from_pretrained("distilgpt2")  # Initialize tokenizer
-        model = TFAutoModelWithLMHead.from_pretrained("distilgpt2")
-        # Greedy decoding
-        outputs = model.generate(max_length=40)
-        print(f"Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
-
-        tokenizer = AutoTokenizer.from_pretrained("openai-gpt")
-        model = TFAutoModelWithLMHead.from_pretrained("openai-gpt")
-        input_context = "The dog"
-        input_ids = tokenizer.encode(input_context, return_tensors="tf")  # encode input context
-        # Generate 3 independent sequences using beam search decoding (5 beams) with sampling from initial context 'The dog'
-        outputs = model.generate(input_ids=input_ids, num_beams=5, num_return_sequences=3, temperature=1.5)
-        # 3 output sequences were generated
-        for i in range(3):
-            print(f"Generated {i}: {tokenizer.decode(outputs[i], skip_special_tokens=True)}")
-
-        tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
-        model = TFAutoModelWithLMHead.from_pretrained("distilgpt2")
-        input_context = "The dog"
-        input_ids = tokenizer.encode(input_context, return_tensors="tf")
-        # Generate 3 candidates using sampling
-        outputs = model.generate(
-            input_ids=input_ids, max_length=40, temperature=0.7, num_return_sequences=3, do_sample=True
-        )
-        # 3 output sequences were generated
-        for i in range(3):
-            print(f"Generated {i}: {tokenizer.decode(outputs[i], skip_special_tokens=True)}")
-
-        tokenizer = AutoTokenizer.from_pretrained("ctrl")
-        model = TFAutoModelWithLMHead.from_pretrained("ctrl")
-        # "Legal" is one of the control codes for ctrl
-        input_context = "Legal My neighbor is"
-        input_ids = tokenizer.encode(input_context, return_tensors="tf")
-        outputs = model.generate(input_ids=input_ids, max_length=50, temperature=0.7, repetition_penalty=1.2)
-        print(f"Generated: {tokenizer.decode(outputs[0], skip_special_tokens=True)}")
-
-        tokenizer = AutoTokenizer.from_pretrained("gpt2")
-        model = TFAutoModelWithLMHead.from_pretrained("gpt2")
-        input_context = "My cute dog"
-        bad_words_ids = [
-            tokenizer.encode(bad_word, add_prefix_space=True) for bad_word in ["idiot", "stupid", "shut up"]
-        ]
-        input_ids = tokenizer.encode(input_context, return_tensors="tf")
-        # generate sequences without allowing bad_words to be generated
-        outputs = model.generate(input_ids=input_ids, max_length=100, do_sample=True, bad_words_ids=bad_words_ids)
+        >>> from transformers import AutoTokenizer, TFAutoModelForCausalLM
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = TFAutoModelForCausalLM.from_pretrained("gpt2")
+
+        >>> prompt = "Today I believe we can finally"
+        >>> input_ids = tokenizer(prompt, return_tensors="tf").input_ids
+
+        >>> # Generate up to 30 tokens
+        >>> outputs = model.generate(input_ids, do_sample=False, max_length=30)
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Today I believe we can finally get to the point where we can make a difference in the lives of the people of the United States of America.\n']
+        ```
+
+        Multinomial sampling, modifying an existing generation configuration:
+
+        ```python
+        >>> from transformers import AutoTokenizer, TFAutoModelForCausalLM, GenerationConfig
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+        >>> model = TFAutoModelForCausalLM.from_pretrained("gpt2")
+
+        >>> prompt = "Today I believe we can finally"
+        >>> input_ids = tokenizer(prompt, return_tensors="tf").input_ids
+
+        >>> # Sample up to 30 tokens
+        >>> generation_config = GenerationConfig.from_pretrained("gpt2")
+        >>> generation_config.max_length = 30
+        >>> generation_config.do_sample = True
+        >>> outputs = model.generate(input_ids, generation_config=generation_config, seed=[0, 0])
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ["Today I believe we can finally start taking a bold stand against climate change and climate change mitigation efforts such as President Obama's climate ban and President Trump's"]
+        ```
+
+        Beam-search decoding, using a freshly initialized generation configuration:
+
+        ```python
+        >>> from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM, GenerationConfig
+
+        >>> tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
+        >>> model = TFAutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-de")
+
+        >>> sentence = "Paris is one of the densest populated areas in Europe."
+        >>> input_ids = tokenizer(sentence, return_tensors="tf").input_ids
+
+        >>> generation_config = GenerationConfig(
+        ...     max_length=64,
+        ...     num_beams=5,
+        ...     bos_token_id=0,
+        ...     eos_token_id=0,
+        ...     decoder_start_token_id=58100,
+        ...     pad_token_id=58100,
+        ...     bad_words_ids=[[58100]],
+        ... )
+        >>> outputs = model.generate(input_ids, generation_config=generation_config)
+        >>> tokenizer.batch_decode(outputs, skip_special_tokens=True)
+        ['Paris ist eines der dichtesten besiedelten Gebiete Europas.']
         ```"""
 
-        # 0. Validate the `.generate()` call
+        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
         self._validate_model_class()
+
+        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
+        if generation_config is None:
+            # legacy: users may modify the model configuration to control generation -- update the generation config
+            # model attribute accordingly, if it was created from the model config
+            if self.generation_config._from_model_config:
+                new_generation_config = GenerationConfig.from_model_config(self.config)
+                if new_generation_config != self.generation_config:
+                    warnings.warn(
+                        "You have modified the pretrained model configuration to control generation. This is a"
+                        " deprecated strategy to control generation and will be removed soon, in a future version."
+                        " Please use a generation configuration file (see"
+                        " https://huggingface.co/docs/transformers/main_classes/text_generation)"
+                    )
+                    self.generation_config = new_generation_config
+            generation_config = self.generation_config
+
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
         self._validate_model_kwargs(model_kwargs.copy())
 
-        # 1. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)
+        # 2. Cast input dtypes to tf.int32 unless they're floats (which happens for some image models)
         if input_ids is not None:
             if isinstance(input_ids, tf.Tensor) and input_ids.dtype.is_floating:
                 pass
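The `generation_config.update(**kwargs)` line above does double duty: it copies matching kwargs onto the (deep-copied) config and returns the leftovers, which become `model_kwargs`. A standalone sketch of that contract, assuming the behavior the inline comment describes:

```python
from transformers import GenerationConfig

config = GenerationConfig(max_new_tokens=8)

# kwargs naming GenerationConfig attributes are consumed; the rest are returned
# (`my_model_kwarg` is a hypothetical model-specific argument)
unused = config.update(max_new_tokens=32, my_model_kwarg=0)
assert config.max_new_tokens == 32
assert unused == {"my_model_kwarg": 0}
```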
...
@@ -750,8 +680,8 @@ class TFGenerationMixin:
                 pass
             else:
                 input_ids = tf.cast(input_ids, tf.int32)
-        if attention_mask is not None:
-            attention_mask = tf.cast(attention_mask, tf.int32)
+        if model_kwargs.get("attention_mask") is not None:
+            model_kwargs["attention_mask"] = tf.cast(model_kwargs["attention_mask"], tf.int32)
         if "decoder_input_ids" in model_kwargs:
             if (
                 isinstance(model_kwargs["decoder_input_ids"], tf.Tensor)
...
@@ -765,44 +695,18 @@ class TFGenerationMixin:
             else:
                 model_kwargs["decoder_input_ids"] = tf.cast(model_kwargs["decoder_input_ids"], tf.int32)
 
-        # 2. Set generation parameters if not already defined
-        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
-        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
-
-        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
-
-        forced_bos_token_id = (
-            forced_bos_token_id if forced_bos_token_id is not None else self.config.forced_bos_token_id
-        )
-        forced_eos_token_id = (
-            forced_eos_token_id if forced_eos_token_id is not None else self.config.forced_eos_token_id
-        )
-
-        output_scores = output_scores if output_scores is not None else self.config.output_scores
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict_in_generate = (
-            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
-        )
-
-        num_beams = num_beams if num_beams is not None else self.config.num_beams
-        do_sample = do_sample if do_sample is not None else self.config.do_sample
-        num_return_sequences = (
-            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
-        )
-
-        if pad_token_id is None and eos_token_id is not None:
-            if attention_mask is None:
+        # 3. Set generation parameters if not already defined
+        if generation_config.pad_token_id is None and generation_config.eos_token_id is not None:
+            if model_kwargs.get("attention_mask") is None:
                 logger.warning(
                     "The attention mask and the pad token id were not set. As a consequence, you may observe "
                     "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
                 )
-            logger.warning(f"Setting `pad_token_id` to {eos_token_id} (first `eos_token_id`) to generate sequence")
-            pad_token_id = eos_token_id
+            logger.warning(
+                f"Setting `pad_token_id` to {generation_config.eos_token_id} (first `eos_token_id`) to generate"
+                " sequence"
+            )
+            generation_config.pad_token_id = generation_config.eos_token_id
 
         use_xla = not tf.executing_eagerly()
         if use_xla and not self.supports_xla_generation:
...
@@ -810,241 +714,242 @@ class TFGenerationMixin:
                 "The selected model does not support Graph mode nor XLA generation (e.g. from tf.function())"
             )
 
-        # 3. Define model inputs
-        input_ids = self._prepare_model_inputs(input_ids, bos_token_id)
+        # 4. Define model inputs
+        input_ids = self._prepare_model_inputs(input_ids, generation_config.bos_token_id)
         # input_ids now has to be defined and cannot be None anymore
         batch_size = shape_list(input_ids)[0]
 
-        # 4. Prepare other model kwargs
-        if output_attentions is not None:
-            model_kwargs["output_attentions"] = output_attentions
-        if output_hidden_states is not None:
-            model_kwargs["output_hidden_states"] = output_hidden_states
-        if use_cache is not None:
-            model_kwargs["use_cache"] = use_cache
-        if attention_mask is not None:
-            model_kwargs["attention_mask"] = attention_mask
+        # 5. Prepare other model kwargs
+        model_kwargs["output_attentions"] = generation_config.output_attentions
+        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+        model_kwargs["use_cache"] = generation_config.use_cache
 
         accepts_attention_mask = "attention_mask" in set(inspect.signature(self.call).parameters.keys())
         requires_attention_mask = "encoder_outputs" not in model_kwargs
 
         if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
             model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
-                input_ids, pad_token_id, eos_token_id
+                input_ids, generation_config.pad_token_id, generation_config.eos_token_id
             )
 
         # decoder-only models should use left-padding for generation
         if not self.config.is_encoder_decoder:
-            if pad_token_id is not None and tf.math.reduce_any(input_ids[:, -1] == pad_token_id):
+            if generation_config.pad_token_id is not None and tf.math.reduce_any(
+                input_ids[:, -1] == generation_config.pad_token_id
+            ):
                 logger.warning(
                     "A decoder-only architecture is being used, but right-padding was detected! For correct "
                     "generation results, please set `padding_side='left'` when initializing the tokenizer."
                 )
 
-        # 5. Prepare model inputs which will be used for auto-regressive generation
+        # 6. Prepare model inputs which will be used for auto-regressive generation
         if self.config.is_encoder_decoder:
             # if encoder-decoder, we create encoder_outputs and add to `model_kwargs`
             model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs)
             # if encoder-decoder then `input_ids` come from `decoder_start_token_id`
             input_ids = self._prepare_decoder_input_ids_for_generation(
                 batch_size,
-                decoder_start_token_id=decoder_start_token_id,
-                bos_token_id=bos_token_id,
+                decoder_start_token_id=generation_config.decoder_start_token_id,
+                bos_token_id=generation_config.bos_token_id,
                 model_kwargs=model_kwargs,
             )
 
-        # 6. Prepare `max_length` depending on other stopping criteria.
+        # 7. Prepare `max_length` depending on other stopping criteria.
         input_ids_seq_length = input_ids.shape[-1]
-        if max_length is None and max_new_tokens is None:
+        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+        if has_default_max_length and generation_config.max_new_tokens is None:
             warnings.warn(
-                "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to "
-                f"{self.config.max_length} (`self.config.max_length`). Controlling `max_length` via the config is "
-                "deprecated and `max_length` will be removed from the config in v5 of Transformers -- we recommend "
-                "using `max_new_tokens` to control the maximum length of the generation.",
+                "Neither `max_length` nor `max_new_tokens` have been set, `max_length` will default to"
+                f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
+                " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
+                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                 UserWarning,
             )
-        elif max_length is None and max_new_tokens is not None:
-            max_length = max_new_tokens + input_ids_seq_length
-        elif max_length is not None and max_new_tokens is not None:
+        elif has_default_max_length and generation_config.max_new_tokens is not None:
+            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
+        elif not has_default_max_length and generation_config.max_new_tokens is not None:
             raise ValueError(
                 "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
                 " limit to the generated output length. Remove one of those arguments. Please refer to the"
                 " documentation for more information. "
                 "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
            )
-        # default to config if still None
-        max_length = max_length if max_length is not None else self.config.max_length
-        min_length = min_length if min_length is not None else self.config.min_length
 
-        if min_length is not None and min_length > max_length:
+        if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
             raise ValueError(
-                f"Unfeasible length constraints: the minimum length ({min_length}) is larger than the maximum "
-                f"length ({max_length})"
+                f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
+                f" the maximum length ({generation_config.max_length})"
             )
-        if input_ids_seq_length >= max_length:
+        if input_ids_seq_length >= generation_config.max_length:
             input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
             logger.warning(
                 f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
-                f" {max_length}. This can lead to unexpected behavior. You should consider increasing "
-                "`max_new_tokens`."
+                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+                " increasing `max_new_tokens`."
            )
 
-        # 7. determine generation mode
+        # 8. determine generation mode
         is_contrastive_search_gen_mode = (
-            top_k is not None and top_k > 1 and do_sample is False and penalty_alpha is not None and penalty_alpha > 0
+            generation_config.top_k is not None
+            and generation_config.top_k > 1
+            and generation_config.do_sample is False
+            and generation_config.penalty_alpha is not None
+            and generation_config.penalty_alpha > 0
        )
-        is_greedy_gen_mode = not is_contrastive_search_gen_mode and (num_beams == 1) and do_sample is False
-        is_beam_gen_mode = not is_contrastive_search_gen_mode and (num_beams > 1) and do_sample is False
-        is_sample_gen_mode = (num_beams == 1) and do_sample is True
-        is_beam_sample_gen_mode = (num_beams > 1) and do_sample is True
+        is_greedy_gen_mode = (
+            not is_contrastive_search_gen_mode
+            and (generation_config.num_beams == 1)
+            and generation_config.do_sample is False
+        )
+        is_beam_gen_mode = (
+            not is_contrastive_search_gen_mode
+            and (generation_config.num_beams > 1)
+            and generation_config.do_sample is False
+        )
+        is_sample_gen_mode = (generation_config.num_beams == 1) and generation_config.do_sample is True
+        is_beam_sample_gen_mode = (generation_config.num_beams > 1) and generation_config.do_sample is True
 
-        # 8. prepare distribution pre_processing samplers
+        # 9. prepare distribution pre_processing samplers
         logits_processor = self._get_logits_processor(
-            repetition_penalty=repetition_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
+            generation_config=generation_config,
             input_ids_seq_length=input_ids_seq_length,
-            bad_words_ids=bad_words_ids,
-            min_length=min_length,
-            max_length=max_length,
-            eos_token_id=eos_token_id,
-            forced_bos_token_id=forced_bos_token_id,
-            forced_eos_token_id=forced_eos_token_id,
-            suppress_tokens=suppress_tokens,
-            begin_suppress_tokens=begin_suppress_tokens,
-            forced_decoder_ids=forced_decoder_ids,
        )
 
-        # 9. go into different generation modes
+        # 10. go into different generation modes
        if is_greedy_gen_mode:
-            if num_return_sequences > 1:
+            if generation_config.num_return_sequences > 1:
                raise ValueError(
-                    f"num_return_sequences has to be 1, but is {num_return_sequences} when doing greedy search."
+                    f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
+                    " greedy search."
                )
-            # 10. run greedy search
+            # 11. run greedy search
            return self.greedy_search(
                input_ids,
-                max_length=max_length,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
+                max_length=generation_config.max_length,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
                logits_processor=logits_processor,
-                output_scores=output_scores,
-                return_dict_in_generate=return_dict_in_generate,
+                output_scores=generation_config.output_scores,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
                **model_kwargs,
            )
        elif is_contrastive_search_gen_mode:
-            if num_return_sequences > 1:
+            if generation_config.num_return_sequences > 1:
                raise ValueError(
-                    f"num_return_sequences has to be 1, but is {num_return_sequences} when doing contrastive search."
+                    f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
+                    " contrastive search."
                )
-            # 10. run contrastive search
+            # 11. run contrastive search
            return self.contrastive_search(
                input_ids,
-                top_k=top_k,
-                penalty_alpha=penalty_alpha,
+                top_k=generation_config.top_k,
+                penalty_alpha=generation_config.penalty_alpha,
                logits_processor=logits_processor,
-                max_length=max_length,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
-                output_scores=output_scores,
-                return_dict_in_generate=return_dict_in_generate,
+                max_length=generation_config.max_length,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
+                output_scores=generation_config.output_scores,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
                **model_kwargs,
            )
        elif is_sample_gen_mode:
-            # 10. prepare logits warper
-            logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
+            # 11. prepare logits warper
+            logits_warper = self._get_logits_warper(generation_config=generation_config)
 
-            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
+            # 12. expand input_ids with `num_return_sequences` additional sequences per batch
            input_ids, model_kwargs = self._expand_inputs_for_generation(
                input_ids=input_ids,
-                expand_size=num_return_sequences,
+                expand_size=generation_config.num_return_sequences,
                is_encoder_decoder=self.config.is_encoder_decoder,
                **model_kwargs,
            )
 
-            # 12. run sample
+            # 13. run sample
            return self.sample(
                input_ids,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
-                max_length=max_length,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
+                max_length=generation_config.max_length,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
                seed=seed,
-                output_scores=output_scores,
-                return_dict_in_generate=return_dict_in_generate,
+                output_scores=generation_config.output_scores,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
                **model_kwargs,
            )
        elif is_beam_gen_mode:
-            if num_beams < num_return_sequences:
+            if generation_config.num_beams < generation_config.num_return_sequences:
                raise ValueError(
-                    "Beam search decoding cannot return more sequences than it has beams. Please set "
-                    f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectively)"
+                    "Beam search decoding cannot return more sequences than it has beams. Please set num_beams >="
+                    f" num_return_sequences, got {generation_config.num_beams} and"
+                    f" {generation_config.num_return_sequences} (respectively)"
                )
 
-            # 10. broadcast inputs to the desired number of beams
-            input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
+            # 11. broadcast inputs to the desired number of beams
+            input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
 
            if "encoder_outputs" in model_kwargs:
                model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
-                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
+                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
                )
 
            if "attention_mask" in model_kwargs:
                model_kwargs["attention_mask"] = self._expand_to_num_beams(
-                    model_kwargs["attention_mask"], num_beams=num_beams
+                    model_kwargs["attention_mask"], num_beams=generation_config.num_beams
                )
 
-            # 11. run beam search
+            # 12. run beam search
            return self.beam_search(
                input_ids,
-                max_length=max_length,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
-                length_penalty=length_penalty,
-                early_stopping=early_stopping,
+                max_length=generation_config.max_length,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
+                length_penalty=generation_config.length_penalty,
+                early_stopping=generation_config.early_stopping,
                logits_processor=logits_processor,
-                return_dict_in_generate=return_dict_in_generate,
-                num_return_sequences=num_return_sequences,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
+                num_return_sequences=generation_config.num_return_sequences,
                **model_kwargs,
            )
        elif is_beam_sample_gen_mode:
-            if num_beams < num_return_sequences:
+            if generation_config.num_beams < generation_config.num_return_sequences:
                raise ValueError(
-                    "Beam search decoding cannot return more sequences than it has beams. Please set "
-                    f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectively)"
+                    "Beam search decoding cannot return more sequences than it has beams. Please set num_beams >="
+                    f" num_return_sequences, got {generation_config.num_beams} and"
+                    f" {generation_config.num_return_sequences} (respectively)"
                )
 
-            # 10. prepare logits warper
-            logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
+            # 11. prepare logits warper
+            logits_warper = self._get_logits_warper(generation_config=generation_config)
 
-            # 11. broadcast inputs to the desired number of beams
-            input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)
+            # 12. broadcast inputs to the desired number of beams
+            input_ids = self._expand_to_num_beams(input_ids, num_beams=generation_config.num_beams)
 
            if "encoder_outputs" in model_kwargs:
                model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
-                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=num_beams
+                    model_kwargs["encoder_outputs"]["last_hidden_state"], num_beams=generation_config.num_beams
                )
 
            if "attention_mask" in model_kwargs:
                model_kwargs["attention_mask"] = self._expand_to_num_beams(
-                    model_kwargs["attention_mask"], num_beams=num_beams
+                    model_kwargs["attention_mask"], num_beams=generation_config.num_beams
                )
 
-            # 12. run beam sample (beam search with sampling)
+            # 13. run beam sample (beam search with sampling)
            return self.beam_search(
                input_ids,
                do_sample=True,
-                max_length=max_length,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
-                length_penalty=length_penalty,
-                early_stopping=early_stopping,
+                max_length=generation_config.max_length,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
+                length_penalty=generation_config.length_penalty,
+                early_stopping=generation_config.early_stopping,
                logits_processor=logits_processor,
                logits_warper=logits_warper,
-                return_dict_in_generate=return_dict_in_generate,
-                num_return_sequences=num_return_sequences,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
+                num_return_sequences=generation_config.num_return_sequences,
                **model_kwargs,
            )
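The five mode flags above form a fixed decision table. The same dispatch, condensed into a pure function for reference (an illustrative restatement, not code from the commit):

```python
def generation_mode(gc) -> str:
    """Map GenerationConfig flags to the decoding strategy chosen in step 8 above."""
    is_contrastive = (
        gc.top_k is not None and gc.top_k > 1
        and gc.do_sample is False
        and gc.penalty_alpha is not None and gc.penalty_alpha > 0
    )
    if is_contrastive:
        return "contrastive_search"
    if gc.num_beams == 1:
        return "sample" if gc.do_sample else "greedy_search"
    return "beam_sample" if gc.do_sample else "beam_search"
```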
...
@@ -1108,26 +1013,16 @@ class TFGenerationMixin:
         # retrieve decoder_start_token_id for encoder-decoder models
         # fall back to bos_token_id if necessary
         decoder_start_token_id = (
-            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
+            decoder_start_token_id
+            if decoder_start_token_id is not None
+            else self.generation_config.decoder_start_token_id
         )
-        bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
+        bos_token_id = bos_token_id if bos_token_id is not None else self.generation_config.bos_token_id
 
         if decoder_start_token_id is not None:
             return decoder_start_token_id
-        elif (
-            hasattr(self.config, "decoder")
-            and hasattr(self.config.decoder, "decoder_start_token_id")
-            and self.config.decoder.decoder_start_token_id is not None
-        ):
-            return self.config.decoder.decoder_start_token_id
         elif bos_token_id is not None:
             return bos_token_id
-        elif (
-            hasattr(self.config, "decoder")
-            and hasattr(self.config.decoder, "bos_token_id")
-            and self.config.decoder.bos_token_id is not None
-        ):
-            return self.config.decoder.bos_token_id
         raise ValueError(
             "`decoder_start_token_id` or `bos_token_id` has to be defined for encoder-decoder generation."
         )
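The fallback chain is now short: explicit argument, then the matching `self.generation_config` attribute, with `bos_token_id` as the last resort (the old probes into `self.config.decoder` are gone). Condensed for reference (an illustrative restatement):

```python
def resolve_decoder_start(decoder_start_token_id, bos_token_id, generation_config):
    """Resolution order for the decoder start token, mirroring the hunk above."""
    if decoder_start_token_id is None:
        decoder_start_token_id = generation_config.decoder_start_token_id
    if bos_token_id is None:
        bos_token_id = generation_config.bos_token_id
    if decoder_start_token_id is not None:
        return decoder_start_token_id
    if bos_token_id is not None:
        return bos_token_id
    raise ValueError("`decoder_start_token_id` or `bos_token_id` has to be defined.")
```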
...
@@ -1332,46 +1227,30 @@ class TFGenerationMixin:
     def _get_logits_warper(
         self,
-        top_k: Optional[int] = None,
-        top_p: Optional[float] = None,
-        temperature: Optional[float] = None,
+        generation_config: GenerationConfig,
     ) -> TFLogitsProcessorList:
         """
         This method returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsWarper`]
         instances used for multinomial sampling.
         """
 
-        # init warp parameters
-        top_k = top_k if top_k is not None else self.config.top_k
-        top_p = top_p if top_p is not None else self.config.top_p
-        temperature = temperature if temperature is not None else self.config.temperature
         # instantiate warpers list
         warpers = TFLogitsProcessorList()
 
         # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
         # all samplers can be found in `generation_utils_samplers.py`
-        if temperature is not None and temperature != 1.0:
-            warpers.append(TFTemperatureLogitsWarper(temperature))
-        if top_k is not None and top_k != 0:
-            warpers.append(TFTopKLogitsWarper(top_k=top_k, min_tokens_to_keep=1))
-        if top_p is not None and top_p < 1.0:
-            warpers.append(TFTopPLogitsWarper(top_p=top_p, min_tokens_to_keep=1))
+        if generation_config.temperature is not None and generation_config.temperature != 1.0:
+            warpers.append(TFTemperatureLogitsWarper(generation_config.temperature))
+        if generation_config.top_k is not None and generation_config.top_k != 0:
+            warpers.append(TFTopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=1))
+        if generation_config.top_p is not None and generation_config.top_p < 1.0:
+            warpers.append(TFTopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=1))
         return warpers
 
     def _get_logits_processor(
         self,
-        repetition_penalty: float,
-        no_repeat_ngram_size: int,
+        generation_config: GenerationConfig,
         input_ids_seq_length: int,
-        bad_words_ids: List[List[int]],
-        min_length: int,
-        max_length: int,
-        eos_token_id: int,
-        forced_bos_token_id: int,
-        forced_eos_token_id: int,
-        suppress_tokens: Optional[List[int]] = None,
-        begin_suppress_tokens: Optional[List[int]] = None,
-        forced_decoder_ids: Optional[List[List[int]]] = None,
     ) -> TFLogitsProcessorList:
         """
         This method returns a [`TFLogitsProcessorList`] list object that contains all relevant [`TFLogitsProcessor`]
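With the new signature, the warper list is derived from a `GenerationConfig` alone, which makes a sampling setup easy to inspect in isolation. A hedged usage sketch (`_get_logits_warper` is a private helper; checkpoint illustrative):

```python
from transformers import TFAutoModelForCausalLM, GenerationConfig

model = TFAutoModelForCausalLM.from_pretrained("gpt2")
gc = GenerationConfig(do_sample=True, temperature=0.7, top_k=50, top_p=0.9)

# temperature != 1.0, top_k != 0 and top_p < 1.0 should each contribute one warper:
# TFTemperatureLogitsWarper, TFTopKLogitsWarper, TFTopPLogitsWarper
warpers = model._get_logits_warper(generation_config=gc)
print([type(w).__name__ for w in warpers])
```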
...
@@ -1379,42 +1258,45 @@ class TFGenerationMixin:
         """
         processors = TFLogitsProcessorList()
 
-        repetition_penalty = repetition_penalty if repetition_penalty is not None else self.config.repetition_penalty
-        no_repeat_ngram_size = (
-            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
-        )
-        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
-        suppress_tokens = suppress_tokens if suppress_tokens is not None else self.config.suppress_tokens
-        begin_suppress_tokens = (
-            begin_suppress_tokens if begin_suppress_tokens is not None else self.config.begin_suppress_tokens
-        )
-        if forced_decoder_ids is None and hasattr(self.config, "forced_decoder_ids"):
-            forced_decoder_ids = self.config.forced_decoder_ids
-
         # instantiate processors list
-        if repetition_penalty is not None and repetition_penalty != 1.0:
-            processors.append(TFRepetitionPenaltyLogitsProcessor(penalty=repetition_penalty))
-        if no_repeat_ngram_size is not None and no_repeat_ngram_size > 0:
-            processors.append(TFNoRepeatNGramLogitsProcessor(no_repeat_ngram_size))
-        if bad_words_ids is not None:
-            processors.append(TFNoBadWordsLogitsProcessor(bad_words_ids, eos_token_id))
-        if min_length is not None and eos_token_id is not None and min_length > 0:
-            processors.append(TFMinLengthLogitsProcessor(min_length, eos_token_id))
-        if forced_bos_token_id is not None:
-            processors.append(TFForcedBOSTokenLogitsProcessor(forced_bos_token_id))
-        if forced_eos_token_id is not None:
-            processors.append(TFForcedEOSTokenLogitsProcessor(max_length, forced_eos_token_id))
-        if suppress_tokens is not None:
-            processors.append(TFSuppressTokensLogitsProcessor(suppress_tokens))
-        if begin_suppress_tokens is not None:
+        if generation_config.repetition_penalty is not None and generation_config.repetition_penalty != 1.0:
+            processors.append(TFRepetitionPenaltyLogitsProcessor(penalty=generation_config.repetition_penalty))
+        if generation_config.no_repeat_ngram_size is not None and generation_config.no_repeat_ngram_size > 0:
+            processors.append(TFNoRepeatNGramLogitsProcessor(generation_config.no_repeat_ngram_size))
+        if generation_config.bad_words_ids is not None:
+            processors.append(
+                TFNoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)
+            )
+        if (
+            generation_config.min_length is not None
+            and generation_config.eos_token_id is not None
+            and generation_config.min_length > 0
+        ):
+            processors.append(
+                TFMinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)
+            )
+        if generation_config.forced_bos_token_id is not None:
+            processors.append(TFForcedBOSTokenLogitsProcessor(generation_config.forced_bos_token_id))
+        if generation_config.forced_eos_token_id is not None:
+            processors.append(
+                TFForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
+            )
+        if generation_config.suppress_tokens is not None:
+            processors.append(TFSuppressTokensLogitsProcessor(generation_config.suppress_tokens))
+        if generation_config.begin_suppress_tokens is not None:
             begin_index = input_ids_seq_length
-            begin_index = begin_index if (input_ids_seq_length > 1 or forced_bos_token_id is None) else begin_index + 1
-            if forced_decoder_ids is not None:
-                begin_index += forced_decoder_ids[-1][0]  # generation starts after the last token that is forced
-            processors.append(TFSuppressTokensAtBeginLogitsProcessor(begin_suppress_tokens, begin_index))
-        if forced_decoder_ids is not None:
-            processors.append(TFForceTokensLogitsProcessor(forced_decoder_ids))
+            begin_index = (
+                begin_index
+                if (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None)
+                else begin_index + 1
+            )
+            if generation_config.forced_decoder_ids is not None:
+                # generation starts after the last token that is forced
+                begin_index += generation_config.forced_decoder_ids[-1][0]
+            processors.append(
+                TFSuppressTokensAtBeginLogitsProcessor(generation_config.begin_suppress_tokens, begin_index)
+            )
+        if generation_config.forced_decoder_ids is not None:
+            processors.append(TFForceTokensLogitsProcessor(generation_config.forced_decoder_ids))
         return processors
 
     def greedy_search(
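The `begin_index` arithmetic above decides where `begin_suppress_tokens` starts applying. Mirrored as a small helper with a worked example (values illustrative, in the style of Whisper-like configurations):

```python
def begin_index_for(input_ids_seq_length, generation_config):
    """begin_index computation from `_get_logits_processor`, restated."""
    begin_index = input_ids_seq_length
    # a forced BOS token occupies position 1 when the prompt is a single token
    if not (input_ids_seq_length > 1 or generation_config.forced_bos_token_id is None):
        begin_index += 1
    if generation_config.forced_decoder_ids is not None:
        # generation starts after the last forced token
        begin_index += generation_config.forced_decoder_ids[-1][0]
    return begin_index

# e.g. seq_len 1, no forced BOS, forced ids at positions 1 and 2: 1 + 2 = 3
```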
...
@@ -1500,17 +1382,22 @@ class TFGenerationMixin:
         # 1. init greedy_search values
         logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
 
-        max_length = max_length if max_length is not None else self.config.max_length
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
-        output_scores = output_scores if output_scores is not None else self.config.output_scores
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        max_length = max_length if max_length is not None else self.generation_config.max_length
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+        output_attentions = (
+            output_attentions if output_attentions is not None else self.generation_config.output_attentions
+        )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
         )
         return_dict_in_generate = (
-            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
+            return_dict_in_generate
+            if return_dict_in_generate is not None
+            else self.generation_config.return_dict_in_generate
         )
+        use_cache = model_kwargs.pop("use_cache", self.generation_config.use_cache)
         use_xla = not tf.executing_eagerly()
         # TODO (Joao): fix cache format or find programmatic way to detect cache index
         # GPT2 and other models have a slightly different cache structure, with a different batch axis
...
@@ -1546,7 +1433,7 @@ class TFGenerationMixin:
                 input_ids = generated[:, :cur_len]
             else:
                 input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, use_cache=use_cache, **model_kwargs)
             # forward pass to get next token logits
             model_outputs = self(
                 **model_inputs,
...
@@ -1772,17 +1659,22 @@ class TFGenerationMixin:
        logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
        logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
-        max_length = max_length if max_length is not None else self.config.max_length
+        max_length = max_length if max_length is not None else self.generation_config.max_length
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
-        output_scores = output_scores if output_scores is not None else self.config.output_scores
+        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions if output_attentions is not None else self.generation_config.output_attentions
+        )
        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )
        return_dict_in_generate = (
-            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
+            return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate
        )
+        use_cache = model_kwargs.pop("use_cache", self.generation_config.use_cache)
        use_xla = not tf.executing_eagerly()
        # TODO (Joao): fix cache format or find programatic way to detect cache index
        # GPT2 and other models has a slightly different cache structure, with a different batch axis
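`sample` follows the same migration as `greedy_search`. Since the logits warpers applied here are also built from the generation config, sampling behaviour can be steered through it; a sketch (all field names are real `GenerationConfig` attributes):

```python
from transformers import GenerationConfig

gen_config = GenerationConfig(do_sample=True, temperature=0.7, top_k=50, top_p=0.9)
# outputs = model.generate(input_ids, generation_config=gen_config)
```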
...
@@ -1814,7 +1706,7 @@ class TFGenerationMixin:
                input_ids = generated[:, :cur_len]
            else:
                input_ids = tf.expand_dims(generated[:, cur_len - 1], -1)
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+            model_inputs = self.prepare_inputs_for_generation(input_ids, use_cache=use_cache, **model_kwargs)
            # forward pass to get next token logits
            model_outputs = self(
                **model_inputs,
...
@@ -2091,25 +1983,30 @@ class TFGenerationMixin:
        logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
        logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
-        max_length = max_length if max_length is not None else self.config.max_length
+        max_length = max_length if max_length is not None else self.generation_config.max_length
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
        num_return_sequences = (
-            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
+            num_return_sequences if num_return_sequences is not None else self.generation_config.num_return_sequences
        )
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions if output_attentions is not None else self.generation_config.output_attentions
+        )
        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )
-        output_scores = output_scores if output_scores is not None else self.config.output_scores
+        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
        return_dict_in_generate = (
-            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
+            return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate
        )
-        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
+        length_penalty = length_penalty if length_penalty is not None else self.generation_config.length_penalty
-        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
+        early_stopping = early_stopping if early_stopping is not None else self.generation_config.early_stopping
+        use_cache = model_kwargs.pop("use_cache", self.generation_config.use_cache)
        use_xla = not tf.executing_eagerly()
        # TODO (Joao): fix cache format or find programatic way to detect cache index
        # GPT2 and other models has a slightly different cache structure, with a different batch axis
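The beam-specific knobs (`num_return_sequences`, `length_penalty`, `early_stopping`) move to `self.generation_config` as well, so overriding them per call now goes through generate kwargs or an explicit config. A sketch:

```python
from transformers import GenerationConfig

gen_config = GenerationConfig(num_beams=4, length_penalty=1.2, early_stopping=True, num_return_sequences=2)
# outputs = model.generate(input_ids, generation_config=gen_config)
```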
...
@@ -2199,7 +2096,9 @@ class TFGenerationMixin:
                input_ids = running_sequences[:, :, :cur_len]
            else:
                input_ids = tf.expand_dims(running_sequences[:, :, cur_len - 1], -1)
-            model_inputs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), **model_kwargs)
+            model_inputs = self.prepare_inputs_for_generation(
+                flatten_beam_dim(input_ids), use_cache=use_cache, **model_kwargs
+            )
            model_outputs = self(
                **model_inputs,
                return_dict=True,
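The beam loop keeps sequences in a `(batch, beams, ...)` layout and flattens them before each forward pass, which is why `prepare_inputs_for_generation` receives `flatten_beam_dim(input_ids)`. A sketch of the shape transformation (the helper is re-implemented here for illustration, not copied from the library):

```python
import tensorflow as tf

def flatten_beam_dim(tensor):
    # (batch, num_beams, ...) -> (batch * num_beams, ...)
    shape = tensor.shape.as_list()
    return tf.reshape(tensor, [-1] + shape[2:])

print(flatten_beam_dim(tf.zeros([2, 3, 7])).shape)  # (6, 7)
```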
...
@@ -2521,17 +2420,22 @@ class TFGenerationMixin:
        # 1. init greedy_search values

        logits_processor = logits_processor if logits_processor is not None else TFLogitsProcessorList()
        logits_warper = logits_warper if logits_warper is not None else TFLogitsProcessorList()
-        max_length = max_length if max_length is not None else self.config.max_length
+        max_length = max_length if max_length is not None else self.generation_config.max_length
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
+        eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
-        output_scores = output_scores if output_scores is not None else self.config.output_scores
+        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions if output_attentions is not None else self.generation_config.output_attentions
+        )
        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
        )
        return_dict_in_generate = (
-            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
+            return_dict_in_generate if return_dict_in_generate is not None else self.generation_config.return_dict_in_generate
        )
+        use_cache = True  # In contrastive search, we always use cache
        use_xla = not tf.executing_eagerly()
        # TODO (Joao): fix cache format or find programatic way to detect cache index
        # GPT2 and other models has a slightly different cache structure, with a different batch axis
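Unlike the other decoding loops, contrastive search hard-codes `use_cache = True`: it re-ranks top-k candidates against cached hidden states, so running without a cache is not an option. Triggering it is unchanged; a sketch of the usual knobs:

```python
# penalty_alpha > 0 together with top_k > 1 selects contrastive search
# outputs = model.generate(input_ids, penalty_alpha=0.6, top_k=4)
```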
...
@@ -2571,8 +2475,9 @@ class TFGenerationMixin:
        if model_kwargs.get("past") is None:
            # prepare inputs
-            model_inputs = self.prepare_inputs_for_generation(generated[:, :cur_len], **model_kwargs)
-            model_inputs["use_cache"] = True
+            model_inputs = self.prepare_inputs_for_generation(
+                generated[:, :cur_len], use_cache=use_cache, **model_kwargs
+            )
            # encode the given prefix and prepare model inputs; encoder-decoder model process the prefix and save
            # the `encoder_outputs`
...
@@ -2662,8 +2567,9 @@ class TFGenerationMixin:
            )

            # compute the candidate tokens by the language model and collects their hidden_states
-            next_model_inputs = self.prepare_inputs_for_generation(tf.reshape(top_k_ids, [-1, 1]), **model_kwargs)
-            next_model_inputs["use_cache"] = True
+            next_model_inputs = self.prepare_inputs_for_generation(
+                tf.reshape(top_k_ids, [-1, 1]), use_cache=use_cache, **model_kwargs
+            )
            outputs = self(
                **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
            )
...

src/transformers/modeling_tf_utils.py
View file @ a6c850e4

...
@@ -39,7 +39,7 @@ from . import DataCollatorWithPadding, DefaultDataCollator
from .activations_tf import get_tf_activation
from .configuration_utils import PretrainedConfig
from .dynamic_module_utils import custom_object_save
-from .generation import TFGenerationMixin
+from .generation import GenerationConfig, TFGenerationMixin
from .tf_utils import shape_list
from .utils import (
    DUMMY_INPUTS,
...
@@ -1137,6 +1137,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
        # Save config and origin of the pretrained weights if given in model
        self.config = config
        self.name_or_path = config.name_or_path
+        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
        # Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
        self._set_save_spec(self.serving.input_signature[0])
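`GenerationConfig.from_model_config` backfills a generation config from the legacy model config, so checkpoints that predate `generation_config.json` keep their old defaults. A sketch:

```python
from transformers import AutoConfig, GenerationConfig

model_config = AutoConfig.from_pretrained("t5-small")
gen_config = GenerationConfig.from_model_config(model_config)
print(gen_config.max_length)  # mirrors the generation fields of the model config
```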
...
@@ -1200,6 +1201,18 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
        """
        raise NotImplementedError

+    def can_generate(self) -> bool:
+        """
+        Returns whether this model can generate sequences with `.generate()`.
+
+        Returns:
+            `bool`: Whether this model can generate sequences with `.generate()`.
+        """
+        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
+        if "GenerationMixin" in str(self.prepare_inputs_for_generation):
+            return False
+        return True
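The `can_generate` check works because the repr of a bound method carries the qualified name of the function that actually defines it: if a subclass never overrides `prepare_inputs_for_generation`, the inherited mixin stub still shows up in `str(...)`. A standalone sketch of the idea:

```python
class Mixin:
    def prepare_inputs_for_generation(self): ...

class PlainModel(Mixin):
    pass  # does not override -> cannot generate

class GenModel(Mixin):
    def prepare_inputs_for_generation(self): ...

print("Mixin" in str(PlainModel().prepare_inputs_for_generation))  # True
print("Mixin" in str(GenModel().prepare_inputs_for_generation))    # False
```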

    def get_input_embeddings(self) -> tf.keras.layers.Layer:
        """
        Returns the model's input embeddings layer.
...
@@ -2832,6 +2845,29 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                " to use it for predictions and inference."
            )

+        # If it is a model with generation capabilities, attempt to load the generation config
+        if model.can_generate():
+            try:
+                model.generation_config = GenerationConfig.from_pretrained(
+                    pretrained_model_name_or_path,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _from_auto=from_auto_class,
+                    _from_pipeline=from_pipeline,
+                    **kwargs,
+                )
+            except OSError:
+                logger.info(
+                    "Generation config file not found, using a generation config created from the model config."
+                )
+                pass
+
        if output_loading_info:
            loading_info = {
                "missing_keys": missing_keys,
...
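With this fallback in place, checkpoints that ship a `generation_config.json` get it loaded automatically, while everything else silently keeps the config-derived default. The file can also be managed directly; a sketch (assumes the checkpoint actually ships the file, otherwise `from_pretrained` raises the `OSError` handled above):

```python
from transformers import GenerationConfig

gen_config = GenerationConfig(max_length=64, num_beams=4)
gen_config.save_pretrained("my_model_dir")       # writes my_model_dir/generation_config.json
reloaded = GenerationConfig.from_pretrained("my_model_dir")
```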
src/transformers/models/rag/modeling_tf_rag.py
View file @ a6c850e4

...
@@ -15,6 +15,7 @@
"""TFRAG model implementation."""
"""TFRAG model implementation."""
import
copy
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
...
@@ -999,25 +1000,9 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
        context_input_ids=None,
        context_attention_mask=None,
        doc_scores=None,
-        max_length=None,
-        min_length=None,
-        early_stopping=None,
-        use_cache=None,
-        num_beams=None,
-        bos_token_id=None,
-        pad_token_id=None,
-        eos_token_id=None,
-        length_penalty=None,
-        no_repeat_ngram_size=None,
-        bad_words_ids=None,
-        num_return_sequences=None,
-        decoder_start_token_id=None,
        n_docs=None,
-        output_scores=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict_in_generate=None,
-        **model_kwargs
+        generation_config=None,
+        **kwargs
    ):
        """
        Implements TFRAG token decoding.
...
@@ -1051,91 +1036,32 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
                If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the
                forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
-            max_length (`int`, *optional*, defaults to 20):
-                The maximum length of the sequence to be generated.
-            min_length (`int`, *optional*, defaults to 10):
-                The minimum length of the sequence to be generated.
-            early_stopping (`bool`, *optional*, defaults to `False`):
-                Whether or not to stop the beam search when at least `num_beams` sentences are finished per batch or
-                not.
-            use_cache: (`bool`, *optional*, defaults to `True`):
-                Whether or not the model should use the past last key/values attentions (if applicable to the model) to
-                speed up decoding.
-            pad_token_id (`int`, *optional*):
-                The id of the *padding* token.
-            bos_token_id (`int`, *optional*):
-                The id of the *beginning-of-sequence* token.
-            eos_token_id (`int`, *optional*):
-                The id of the *end-of-sequence* token.
-            length_penalty (`float`, *optional*, defaults to 1.0):
-                Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent
-                to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
-                the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
-                while `length_penalty` < 0.0 encourages shorter sequences.
-            no_repeat_ngram_size (`int`, *optional*, defaults to 0):
-                If set to int > 0, all ngrams of that size can only occur once.
-            bad_words_ids(`List[int]`, *optional*):
-                List of token ids that are not allowed to be generated. In order to get the tokens of the words that
-                should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
-            num_beams (`int`, *optional*, defaults to 1):
-                Number of beams for beam search. 1 means no beam search.
-            num_return_sequences(`int`, *optional*, defaults to 1):
-                The number of independently computed returned sequences for each element in the batch. Note that this
-                is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`] function, where
-                we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an
-                encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
            n_docs (`int`, *optional*, defaults to `config.n_docs`)
                Number of documents to retrieve and/or number of documents for which to generate an answer.
-            output_attentions (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more details.
-            output_hidden_states (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more details.
-            output_scores (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
-            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-            model_specific_kwargs:
-                Additional model specific kwargs will be forwarded to the `forward` function of the model.
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which had the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
+            kwargs:
+                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model.

        Return:
            `tf.Tensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The
            second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
            due to the `eos_token_id`.
        """
+        # Handle `generation_config` and kwargs that might update it
+        if generation_config is None:
+            generation_config = self.generation_config
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
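`GenerationConfig.update(**kwargs)` mutates the (deep-copied) config with every kwarg that matches one of its attributes and returns the rest, which is how RAG's `**kwargs` get split into generation parameters and model kwargs. A sketch of the semantics:

```python
import copy
from transformers import GenerationConfig

gen_config = copy.deepcopy(GenerationConfig(num_beams=1))
model_kwargs = gen_config.update(num_beams=4, doc_scores=None)
print(gen_config.num_beams)  # 4 -- matching attribute was consumed
print(model_kwargs)          # {'doc_scores': None} -- unmatched kwargs are returned
```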

        # set default parameters
        n_docs = n_docs if n_docs is not None else self.config.n_docs
-        max_length = max_length if max_length is not None else self.config.max_length
-        min_length = min_length if min_length is not None else self.config.min_length
-        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-        num_beams = num_beams if num_beams is not None else self.config.num_beams
-        bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id
-        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
-        no_repeat_ngram_size = (
-            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
-        )
-        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
-        num_return_sequences = (
-            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
-        )
-        decoder_start_token_id = (
-            decoder_start_token_id
-            if decoder_start_token_id is not None
-            else self.config.generator.decoder_start_token_id
-        )
-        output_scores = output_scores if output_scores is not None else self.config.output_scores
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict_in_generate = (
-            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
-        )

        # retrieve docs
        if self.retriever is not None and context_input_ids is None:
...
@@ -1174,14 +1100,14 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
        encoder_outputs = encoder(
            input_ids=context_input_ids,
            attention_mask=context_attention_mask,
-            output_attentions=output_attentions,
+            output_attentions=generation_config.output_attentions,
-            output_hidden_states=output_hidden_states,
+            output_hidden_states=generation_config.output_hidden_states,
            return_dict=True,
        )

        decoder_input_ids = tf.fill(
-            (batch_size * num_beams, 1),
+            (batch_size * generation_config.num_beams, 1),
-            tf.cast(decoder_start_token_id, tf.int32),
+            tf.cast(generation_config.decoder_start_token_id, tf.int32),
        )
        last_hidden_state = encoder_outputs["last_hidden_state"]
...
@@ -1207,10 +1133,12 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
            return tf.reshape(tensor, new_shape)

        # correctly extend last_hidden_state and attention mask
-        context_attention_mask = extend_enc_output(context_attention_mask, num_beams=num_beams)
+        context_attention_mask = extend_enc_output(context_attention_mask, num_beams=generation_config.num_beams)
-        encoder_outputs["last_hidden_state"] = extend_enc_output(last_hidden_state, num_beams=num_beams)
+        encoder_outputs["last_hidden_state"] = extend_enc_output(
+            last_hidden_state, num_beams=generation_config.num_beams
+        )

-        doc_scores = tf.repeat(doc_scores, num_beams, axis=0)
+        doc_scores = tf.repeat(doc_scores, generation_config.num_beams, axis=0)

        # define start_len & additional parameters
        model_kwargs["doc_scores"] = doc_scores
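For beam search, the retrieval scores have to be tiled so that every beam of a batch entry sees the same `doc_scores` row; `tf.repeat(..., axis=0)` repeats each row `num_beams` times consecutively, matching the flattened `(batch * num_beams, ...)` layout. A sketch:

```python
import tensorflow as tf

doc_scores = tf.constant([[0.1, 0.9], [0.7, 0.3]])  # (batch=2, n_docs=2)
print(tf.repeat(doc_scores, 3, axis=0).shape)       # (6, 2): row order 0,0,0,1,1,1 for num_beams=3
```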
...
@@ -1219,41 +1147,35 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
        model_kwargs["n_docs"] = n_docs

        pre_processor = self._get_logits_processor(
-            repetition_penalty=self.config.repetition_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            bad_words_ids=bad_words_ids,
-            min_length=min_length,
-            max_length=max_length,
-            eos_token_id=eos_token_id,
-            forced_bos_token_id=self.config.generator.forced_bos_token_id,
-            forced_eos_token_id=self.config.generator.forced_eos_token_id,
+            generation_config=generation_config,
            input_ids_seq_length=tf.shape(decoder_input_ids)[-1],
        )

-        if num_beams == 1:
+        if generation_config.num_beams == 1:
            return self.greedy_search(
                input_ids=decoder_input_ids,
-                max_length=max_length,
+                max_length=generation_config.max_length,
-                pad_token_id=pad_token_id,
+                pad_token_id=generation_config.pad_token_id,
-                eos_token_id=eos_token_id,
+                eos_token_id=generation_config.eos_token_id,
                logits_processor=pre_processor,
-                output_attentions=output_attentions,
+                output_attentions=generation_config.output_attentions,
-                output_hidden_states=output_hidden_states,
+                output_hidden_states=generation_config.output_hidden_states,
-                output_scores=output_scores,
+                output_scores=generation_config.output_scores,
-                return_dict_in_generate=return_dict_in_generate,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
                **model_kwargs,
            )
-        elif num_beams > 1:
+        elif generation_config.num_beams > 1:
-            if num_beams < num_return_sequences:
+            if generation_config.num_beams < generation_config.num_return_sequences:
                raise ValueError(
-                    "Beam search decoding cannot return more sequences than it has beams. Please set "
-                    f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectivelly)"
+                    "Beam search decoding cannot return more sequences than it has beams. Please set num_beams >="
+                    f" num_return_sequences, got {generation_config.num_beams} and"
+                    f" {generation_config.num_return_sequences} (respectivelly)"
                )

            def unflatten_beam_dim(tensor):
                """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
                shape = shape_list(tensor)
-                return tf.reshape(tensor, [-1, num_beams] + shape[1:])
+                return tf.reshape(tensor, [-1, generation_config.num_beams] + shape[1:])

            decoder_input_ids = unflatten_beam_dim(decoder_input_ids)
            model_kwargs["attention_mask"] = unflatten_beam_dim(model_kwargs["attention_mask"])
...
@@ -1263,18 +1185,20 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
            return self.beam_search(
                input_ids=decoder_input_ids,
-                max_length=max_length,
+                max_length=generation_config.max_length,
-                pad_token_id=pad_token_id,
+                pad_token_id=generation_config.pad_token_id,
-                eos_token_id=eos_token_id,
+                eos_token_id=generation_config.eos_token_id,
                logits_processor=pre_processor,
-                output_attentions=output_attentions,
+                output_attentions=generation_config.output_attentions,
-                output_hidden_states=output_hidden_states,
+                output_hidden_states=generation_config.output_hidden_states,
-                output_scores=output_scores,
+                output_scores=generation_config.output_scores,
-                return_dict_in_generate=return_dict_in_generate,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
                **model_kwargs,
            )
        else:
-            raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}")
+            raise ValueError(
+                f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {generation_config.num_beams}"
+            )
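The net effect for RAG users: `generate` is parameterized exactly like the other TF models now, either through an explicit `GenerationConfig` or through ad hoc kwargs that update a copy of `model.generation_config`. A sketch (assumes `model` is a `TFRagTokenForGeneration` with its retriever set up):

```python
from transformers import GenerationConfig

gen_config = GenerationConfig(num_beams=4, max_length=32, num_return_sequences=2)
outputs = model.generate(input_ids, generation_config=gen_config)
# equivalent ad hoc form:
outputs = model.generate(input_ids, num_beams=4, max_length=32, num_return_sequences=2)
```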

    def get_input_embeddings(self):
        return self.rag.generator.get_input_embeddings()

...
tests/test_modeling_tf_common.py
View file @ a6c850e4

...
@@ -1824,18 +1824,18 @@ class TFModelTesterMixin:
            model.train_on_batch(test_batch, test_batch_labels)

    def _test_xla_generate(self, **generate_kwargs):
-        def _generate_and_check_results(model, config, inputs_dict):
+        def _generate_and_check_results(model, inputs_dict):
            if "input_ids" in inputs_dict:
                inputs = inputs_dict["input_ids"]
                # make sure there are no pad tokens in prompt, which may trigger unwanted behavior
-                if config.pad_token_id is not None:
+                if model.generation_config.pad_token_id is not None:
                    if config.pad_token_id == 0:
-                        new_pad_token = config.pad_token_id + 1
+                        new_pad_token = model.generation_config.pad_token_id + 1
                    else:
-                        new_pad_token = config.pad_token_id - 1
+                        new_pad_token = model.generation_config.pad_token_id - 1
                else:
                    new_pad_token = None
-                inputs = tf.where(inputs != config.pad_token_id, inputs, new_pad_token)
+                inputs = tf.where(inputs != model.generation_config.pad_token_id, inputs, new_pad_token)
            elif "input_features" in inputs_dict:
                inputs = inputs_dict["input_features"]
            else:
...
@@ -1854,10 +1854,10 @@ class TFModelTesterMixin:
            model = model_class(config)

            if model.supports_xla_generation:
-                _generate_and_check_results(model, config, inputs_dict)
+                _generate_and_check_results(model, inputs_dict)
            else:
                with self.assertRaises(ValueError):
-                    _generate_and_check_results(model, config, inputs_dict)
+                    _generate_and_check_results(model, inputs_dict)

    def test_xla_generate_fast(self):
        """
...