chenpangpang/transformers, commit a6c850e4 (unverified)
Authored Jan 04, 2023 by Joao Gante; committed via GitHub on Jan 04, 2023
Parent: 3b309818

Commit message: Generate: TF uses `GenerationConfig` as the basis for `.generate()` parametrization (#20994)
Showing 4 changed files with 447 additions and 581 deletions (+447 −581):

    src/transformers/generation/tf_utils.py          +350  −444
    src/transformers/modeling_tf_utils.py             +37    −1
    src/transformers/models/rag/modeling_tf_rag.py    +53  −129
    tests/test_modeling_tf_common.py                   +7    −7
src/transformers/generation/tf_utils.py (+350 −444)

This diff is collapsed in the page view and is not reproduced here.
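The collapsed tf_utils.py diff carries the bulk of the rework: TF `.generate()` now takes its defaults from a `GenerationConfig` attached to the model instead of reading loose attributes off `model.config`. A hedged sketch of the end-user pattern this commit enables (the `t5-small` checkpoint is illustrative, any TF generative checkpoint works; assumes transformers >= 4.26 with TensorFlow installed):

    from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    model = TFAutoModelForSeq2SeqLM.from_pretrained("t5-small")

    # Since this commit, TF models carry a generation config: loaded from
    # generation_config.json when the checkpoint ships one, otherwise built
    # from the model config via GenerationConfig.from_model_config().
    print(model.generation_config)

    inputs = tokenizer("translate English to German: Hello world", return_tensors="tf")

    # Ad hoc kwargs override a copy of model.generation_config for this call only.
    outputs = model.generate(**inputs, num_beams=4, max_length=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))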
src/transformers/modeling_tf_utils.py (+37 −1)
@@ -39,7 +39,7 @@ from . import DataCollatorWithPadding, DefaultDataCollator
 from .activations_tf import get_tf_activation
 from .configuration_utils import PretrainedConfig
 from .dynamic_module_utils import custom_object_save
-from .generation import TFGenerationMixin
+from .generation import GenerationConfig, TFGenerationMixin
 from .tf_utils import shape_list
 from .utils import (
     DUMMY_INPUTS,
@@ -1137,6 +1137,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         # Save config and origin of the pretrained weights if given in model
         self.config = config
         self.name_or_path = config.name_or_path
+        self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
         # Set the serving spec quickly to ensure that Keras doesn't use the specific dummy input shapes as the spec
         self._set_save_spec(self.serving.input_signature[0])
@@ -1200,6 +1201,18 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
         """
         raise NotImplementedError

+    def can_generate(self) -> bool:
+        """
+        Returns whether this model can generate sequences with `.generate()`.
+
+        Returns:
+            `bool`: Whether this model can generate sequences with `.generate()`.
+        """
+        # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation
+        if "GenerationMixin" in str(self.prepare_inputs_for_generation):
+            return False
+        return True
+
     def get_input_embeddings(self) -> tf.keras.layers.Layer:
         """
         Returns the model's input embeddings layer.
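The override-detection trick above works because the repr of a bound method includes the qualified name of the class that defines it: if `prepare_inputs_for_generation` still resolves to the mixin's stub, the string contains "GenerationMixin". A minimal standalone sketch of the same pattern, with toy class names that are not from the diff:

    class GenerationMixin:
        def prepare_inputs_for_generation(self, *args, **kwargs):
            raise NotImplementedError

    class PlainModel(GenerationMixin):
        pass  # inherits the stub unchanged

    class GenerativeModel(GenerationMixin):
        def prepare_inputs_for_generation(self, input_ids, **kwargs):
            return {"input_ids": input_ids}

    def can_generate(model) -> bool:
        # str() of a bound method looks like
        # "<bound method GenerationMixin.prepare_inputs_for_generation of ...>",
        # so the defining class's name reveals whether the method was overridden.
        return "GenerationMixin" not in str(model.prepare_inputs_for_generation)

    print(can_generate(PlainModel()))       # False
    print(can_generate(GenerativeModel()))  # True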
@@ -2832,6 +2845,29 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
                 " to use it for predictions and inference."
             )

+        # If it is a model with generation capabilities, attempt to load the generation config
+        if model.can_generate():
+            try:
+                model.generation_config = GenerationConfig.from_pretrained(
+                    pretrained_model_name_or_path,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _from_auto=from_auto_class,
+                    _from_pipeline=from_pipeline,
+                    **kwargs,
+                )
+            except OSError:
+                logger.info(
+                    "Generation config file not found, using a generation config created from the model config."
+                )
+                pass
+
         if output_loading_info:
             loading_info = {
                 "missing_keys": missing_keys,
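The resulting loading priority for a model's generation defaults: 1) a `generation_config.json` file shipped with the weights, 2) otherwise the legacy generation attributes in the model config (the `from_model_config` fallback set in `__init__` above). A hedged sketch of the same fallback in user code (checkpoint name is illustrative):

    from transformers import AutoConfig, GenerationConfig

    checkpoint = "t5-small"  # illustrative
    try:
        # Priority 1: a dedicated generation_config.json next to the weights
        generation_config = GenerationConfig.from_pretrained(checkpoint)
    except OSError:
        # Priority 2: derive generation defaults from the model config
        generation_config = GenerationConfig.from_model_config(AutoConfig.from_pretrained(checkpoint))
    print(generation_config)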
src/transformers/models/rag/modeling_tf_rag.py (+53 −129)
@@ -15,6 +15,7 @@
 """TFRAG model implementation."""

+import copy
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union
@@ -999,25 +1000,9 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
         context_input_ids=None,
         context_attention_mask=None,
         doc_scores=None,
-        max_length=None,
-        min_length=None,
-        early_stopping=None,
-        use_cache=None,
-        num_beams=None,
-        bos_token_id=None,
-        pad_token_id=None,
-        eos_token_id=None,
-        length_penalty=None,
-        no_repeat_ngram_size=None,
-        bad_words_ids=None,
-        num_return_sequences=None,
-        decoder_start_token_id=None,
         n_docs=None,
-        output_scores=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict_in_generate=None,
-        **model_kwargs
+        generation_config=None,
+        **kwargs
     ):
         """
         Implements TFRAG token decoding.
@@ -1051,91 +1036,32 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
                 If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the
                 forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
-            max_length (`int`, *optional*, defaults to 20):
-                The maximum length of the sequence to be generated.
-            min_length (`int`, *optional*, defaults to 10):
-                The minimum length of the sequence to be generated.
-            early_stopping (`bool`, *optional*, defaults to `False`):
-                Whether or not to stop the beam search when at least `num_beams` sentences are finished per batch or
-                not.
-            use_cache: (`bool`, *optional*, defaults to `True`):
-                Whether or not the model should use the past last key/values attentions (if applicable to the model) to
-                speed up decoding.
-            pad_token_id (`int`, *optional*):
-                The id of the *padding* token.
-            bos_token_id (`int`, *optional*):
-                The id of the *beginning-of-sequence* token.
-            eos_token_id (`int`, *optional*):
-                The id of the *end-of-sequence* token.
-            length_penalty (`float`, *optional*, defaults to 1.0):
-                Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent
-                to the sequence length, which in turn is used to divide the score of the sequence. Since the score is
-                the log likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences,
-                while `length_penalty` < 0.0 encourages shorter sequences.
-            no_repeat_ngram_size (`int`, *optional*, defaults to 0):
-                If set to int > 0, all ngrams of that size can only occur once.
-            bad_words_ids(`List[int]`, *optional*):
-                List of token ids that are not allowed to be generated. In order to get the tokens of the words that
-                should not appear in the generated text, use `tokenizer.encode(bad_word, add_prefix_space=True)`.
-            num_beams (`int`, *optional*, defaults to 1):
-                Number of beams for beam search. 1 means no beam search.
-            num_return_sequences(`int`, *optional*, defaults to 1):
-                The number of independently computed returned sequences for each element in the batch. Note that this
-                is not the value we pass to the `generator`'s `[`~generation.GenerationMixin.generate`] function, where
-                we set `num_return_sequences` to `num_beams`. decoder_start_token_id (`int`, *optional*): If an
-                encoder-decoder model starts decoding with a different token than *bos*, the id of that token.
             n_docs (`int`, *optional*, defaults to `config.n_docs`)
                 Number of documents to retrieve and/or number of documents for which to generate an answer.
-            output_attentions (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
-                returned tensors for more details.
-            output_hidden_states (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
-                for more details.
-            output_scores (`bool`, *optional*, defaults to `False`):
-                Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
-            return_dict_in_generate (`bool`, *optional*, defaults to `False`):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-            model_specific_kwargs:
-                Additional model specific kwargs will be forwarded to the `forward` function of the model.
+            generation_config (`~generation.GenerationConfig`, *optional*):
+                The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+                passed to generate matching the attributes of `generation_config` will override them. If
+                `generation_config` is not provided, the default will be used, which had the following loading
+                priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+                configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+                default values, whose documentation should be checked to parameterize generation.
+            kwargs:
+                Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+                forwarded to the `forward` function of the model.

         Return:
             `tf.Tensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences. The
             second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
             due to the `eos_token_id`.
         """
+        # Handle `generation_config` and kwargs that might update it
+        if generation_config is None:
+            generation_config = self.generation_config
+        generation_config = copy.deepcopy(generation_config)
+        model_kwargs = generation_config.update(**kwargs)  # All unused kwargs must be model kwargs
+
         # set default parameters
         n_docs = n_docs if n_docs is not None else self.config.n_docs
-        max_length = max_length if max_length is not None else self.config.max_length
-        min_length = min_length if min_length is not None else self.config.min_length
-        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-        num_beams = num_beams if num_beams is not None else self.config.num_beams
-        bos_token_id = bos_token_id if bos_token_id is not None else self.config.generator.bos_token_id
-        pad_token_id = pad_token_id if pad_token_id is not None else self.config.generator.pad_token_id
-        eos_token_id = eos_token_id if eos_token_id is not None else self.config.generator.eos_token_id
-        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
-        no_repeat_ngram_size = (
-            no_repeat_ngram_size if no_repeat_ngram_size is not None else self.config.no_repeat_ngram_size
-        )
-        bad_words_ids = bad_words_ids if bad_words_ids is not None else self.config.bad_words_ids
-        num_return_sequences = (
-            num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
-        )
-        decoder_start_token_id = (
-            decoder_start_token_id
-            if decoder_start_token_id is not None
-            else self.config.generator.decoder_start_token_id
-        )
-        output_scores = output_scores if output_scores is not None else self.config.output_scores
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict_in_generate = (
-            return_dict_in_generate if return_dict_in_generate is not None else self.config.return_dict_in_generate
-        )

         # retrieve docs
         if self.retriever is not None and context_input_ids is None:
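The `deepcopy` + `update(**kwargs)` idiom added above is the heart of the new parametrization: the config stored on the model is never mutated, per-call kwargs override a copy of it, and anything `GenerationConfig` does not recognize is handed back so it can be forwarded as a model kwarg. A small sketch of that contract (the `doc_scores` kwarg here is just an example of an unrecognized key):

    import copy

    from transformers import GenerationConfig

    base = GenerationConfig(num_beams=4, max_length=32)

    # Work on a copy so the config stored on the model stays untouched.
    generation_config = copy.deepcopy(base)

    # Matching attributes are overridden in place; everything else is
    # returned as "unused" kwargs for the model's forward pass.
    unused = generation_config.update(num_beams=2, doc_scores=None)

    print(generation_config.num_beams)  # 2
    print(base.num_beams)               # 4
    print(unused)                       # {'doc_scores': None}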
@@ -1174,14 +1100,14 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
         encoder_outputs = encoder(
             input_ids=context_input_ids,
             attention_mask=context_attention_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
+            output_attentions=generation_config.output_attentions,
+            output_hidden_states=generation_config.output_hidden_states,
             return_dict=True,
         )

         decoder_input_ids = tf.fill(
-            (batch_size * num_beams, 1),
-            tf.cast(decoder_start_token_id, tf.int32),
+            (batch_size * generation_config.num_beams, 1),
+            tf.cast(generation_config.decoder_start_token_id, tf.int32),
         )
         last_hidden_state = encoder_outputs["last_hidden_state"]
@@ -1207,10 +1133,12 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
             return tf.reshape(tensor, new_shape)

         # correctly extend last_hidden_state and attention mask
-        context_attention_mask = extend_enc_output(context_attention_mask, num_beams=num_beams)
-        encoder_outputs["last_hidden_state"] = extend_enc_output(last_hidden_state, num_beams=num_beams)
-        doc_scores = tf.repeat(doc_scores, num_beams, axis=0)
+        context_attention_mask = extend_enc_output(context_attention_mask, num_beams=generation_config.num_beams)
+        encoder_outputs["last_hidden_state"] = extend_enc_output(
+            last_hidden_state, num_beams=generation_config.num_beams
+        )
+        doc_scores = tf.repeat(doc_scores, generation_config.num_beams, axis=0)

         # define start_len & additional parameters
         model_kwargs["doc_scores"] = doc_scores
@@ -1219,41 +1147,35 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
         model_kwargs["n_docs"] = n_docs

         pre_processor = self._get_logits_processor(
-            repetition_penalty=self.config.repetition_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            bad_words_ids=bad_words_ids,
-            min_length=min_length,
-            max_length=max_length,
-            eos_token_id=eos_token_id,
-            forced_bos_token_id=self.config.generator.forced_bos_token_id,
-            forced_eos_token_id=self.config.generator.forced_eos_token_id,
+            generation_config=generation_config,
             input_ids_seq_length=tf.shape(decoder_input_ids)[-1],
         )

-        if num_beams == 1:
+        if generation_config.num_beams == 1:
             return self.greedy_search(
                 input_ids=decoder_input_ids,
-                max_length=max_length,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
+                max_length=generation_config.max_length,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
                 logits_processor=pre_processor,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                output_scores=output_scores,
-                return_dict_in_generate=return_dict_in_generate,
+                output_attentions=generation_config.output_attentions,
+                output_hidden_states=generation_config.output_hidden_states,
+                output_scores=generation_config.output_scores,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
                 **model_kwargs,
             )
-        elif num_beams > 1:
-            if num_beams < num_return_sequences:
+        elif generation_config.num_beams > 1:
+            if generation_config.num_beams < generation_config.num_return_sequences:
                 raise ValueError(
-                    "Beam search decoding cannot return more sequences than it has beams. Please set "
-                    f"num_beams >= num_return_sequences, got {num_beams} and {num_return_sequences} (respectivelly)"
+                    "Beam search decoding cannot return more sequences than it has beams. Please set num_beams >="
+                    f" num_return_sequences, got {generation_config.num_beams} and"
+                    f" {generation_config.num_return_sequences} (respectivelly)"
                 )

             def unflatten_beam_dim(tensor):
                 """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
                 shape = shape_list(tensor)
-                return tf.reshape(tensor, [-1, num_beams] + shape[1:])
+                return tf.reshape(tensor, [-1, generation_config.num_beams] + shape[1:])

             decoder_input_ids = unflatten_beam_dim(decoder_input_ids)
             model_kwargs["attention_mask"] = unflatten_beam_dim(model_kwargs["attention_mask"])
@@ -1263,18 +1185,20 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
             return self.beam_search(
                 input_ids=decoder_input_ids,
-                max_length=max_length,
-                pad_token_id=pad_token_id,
-                eos_token_id=eos_token_id,
+                max_length=generation_config.max_length,
+                pad_token_id=generation_config.pad_token_id,
+                eos_token_id=generation_config.eos_token_id,
                 logits_processor=pre_processor,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                output_scores=output_scores,
-                return_dict_in_generate=return_dict_in_generate,
+                output_attentions=generation_config.output_attentions,
+                output_hidden_states=generation_config.output_hidden_states,
+                output_scores=generation_config.output_scores,
+                return_dict_in_generate=generation_config.return_dict_in_generate,
                 **model_kwargs,
             )
         else:
-            raise ValueError(f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {num_beams}")
+            raise ValueError(
+                f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {generation_config.num_beams}"
+            )

     def get_input_embeddings(self):
         return self.rag.generator.get_input_embeddings()
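After this change, `TFRagTokenForGeneration.generate()` no longer takes a dozen loose generation arguments in its signature; parametrization flows through a `GenerationConfig`, and loose kwargs still work because they are folded into a copy of it. A hedged usage sketch (model and input setup omitted for brevity; `model` is assumed to be a loaded `TFRagTokenForGeneration` with a retriever, and `input_ids` a tokenized question batch):

    from transformers import GenerationConfig

    # Explicit base parametrization via a GenerationConfig object...
    gen_config = GenerationConfig(num_beams=4, max_length=40, early_stopping=True)
    output_ids = model.generate(input_ids, generation_config=gen_config)

    # ...or ad hoc kwargs, which override a copy of model.generation_config.
    output_ids = model.generate(input_ids, num_beams=2, max_length=20)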
tests/test_modeling_tf_common.py (+7 −7)
@@ -1824,18 +1824,18 @@ class TFModelTesterMixin:
             model.train_on_batch(test_batch, test_batch_labels)

     def _test_xla_generate(self, **generate_kwargs):
-        def _generate_and_check_results(model, config, inputs_dict):
+        def _generate_and_check_results(model, inputs_dict):
             if "input_ids" in inputs_dict:
                 inputs = inputs_dict["input_ids"]
                 # make sure there are no pad tokens in prompt, which may trigger unwanted behavior
-                if config.pad_token_id is not None:
+                if model.generation_config.pad_token_id is not None:
                     if config.pad_token_id == 0:
-                        new_pad_token = config.pad_token_id + 1
+                        new_pad_token = model.generation_config.pad_token_id + 1
                     else:
-                        new_pad_token = config.pad_token_id - 1
+                        new_pad_token = model.generation_config.pad_token_id - 1
                 else:
                     new_pad_token = None
-                inputs = tf.where(inputs != config.pad_token_id, inputs, new_pad_token)
+                inputs = tf.where(inputs != model.generation_config.pad_token_id, inputs, new_pad_token)
             elif "input_features" in inputs_dict:
                 inputs = inputs_dict["input_features"]
             else:
@@ -1854,10 +1854,10 @@ class TFModelTesterMixin:
             model = model_class(config)

             if model.supports_xla_generation:
-                _generate_and_check_results(model, config, inputs_dict)
+                _generate_and_check_results(model, inputs_dict)
             else:
                 with self.assertRaises(ValueError):
-                    _generate_and_check_results(model, config, inputs_dict)
+                    _generate_and_check_results(model, inputs_dict)

     def test_xla_generate_fast(self):
         """