Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1fc4b2a1
Unverified
Commit
1fc4b2a1
authored
Jul 22, 2022
by
Joao Gante
Committed by
GitHub
Jul 22, 2022
Browse files
TF: use the correct config with `(...)EncoderDecoder` models (#18097)
parent
49354097
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
199 additions
and
74 deletions
+199
-74
src/transformers/modeling_tf_utils.py
src/transformers/modeling_tf_utils.py
+18
-12
src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
...ers/models/encoder_decoder/modeling_tf_encoder_decoder.py
+7
-7
src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
...ion_encoder_decoder/modeling_tf_vision_encoder_decoder.py
+7
-7
tests/models/encoder_decoder/test_modeling_encoder_decoder.py
...s/models/encoder_decoder/test_modeling_encoder_decoder.py
+84
-24
tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py
...odels/encoder_decoder/test_modeling_tf_encoder_decoder.py
+83
-24
No files found.
src/transformers/modeling_tf_utils.py
View file @
1fc4b2a1
...
@@ -403,8 +403,13 @@ def unpack_inputs(func):
...
@@ -403,8 +403,13 @@ def unpack_inputs(func):
# move any arg into kwargs, if they exist
# move any arg into kwargs, if they exist
fn_args_and_kwargs
.
update
(
dict
(
zip
(
func
.
__code__
.
co_varnames
[
1
:],
args
)))
fn_args_and_kwargs
.
update
(
dict
(
zip
(
func
.
__code__
.
co_varnames
[
1
:],
args
)))
# process the inputs and call the wrapped function
# Encoder Decoder models delegate the application of the configuration options to their inner models.
unpacked_inputs
=
input_processing
(
func
,
self
.
config
,
**
fn_args_and_kwargs
)
if
"encoder_decoder"
in
str
(
self
).
lower
():
config
=
None
else
:
config
=
self
.
config
unpacked_inputs
=
input_processing
(
func
,
config
,
**
fn_args_and_kwargs
)
return
func
(
self
,
**
unpacked_inputs
)
return
func
(
self
,
**
unpacked_inputs
)
# Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
# Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
...
@@ -559,18 +564,19 @@ def input_processing(func, config, **kwargs):
...
@@ -559,18 +564,19 @@ def input_processing(func, config, **kwargs):
if
"kwargs"
in
output
:
if
"kwargs"
in
output
:
del
output
[
"kwargs"
]
del
output
[
"kwargs"
]
boolean_dict
=
{
if
config
is
not
None
:
k
:
v
boolean_dict
=
{
for
k
,
v
in
output
.
items
()
k
:
v
if
k
in
[
"return_dict"
,
"output_attentions"
,
"output_hidden_states"
,
"use_cache"
]
for
k
,
v
in
output
.
items
()
}
if
k
in
[
"return_dict"
,
"output_attentions"
,
"output_hidden_states"
,
"use_cache"
]
}
output
.
update
(
output
.
update
(
booleans_processing
(
booleans_processing
(
config
=
config
,
config
=
config
,
**
boolean_dict
,
**
boolean_dict
,
)
)
)
)
return
output
return
output
...
...
src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
View file @
1fc4b2a1
...
@@ -630,13 +630,13 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -630,13 +630,13 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
warnings
.
warn
(
DEPRECATION_WARNING
,
FutureWarning
)
warnings
.
warn
(
DEPRECATION_WARNING
,
FutureWarning
)
loss
=
self
.
hf_compute_loss
(
labels
,
logits
)
loss
=
self
.
hf_compute_loss
(
labels
,
logits
)
past_key_values
=
None
if
not
return_dict
:
if
decoder_inputs
[
"use_cache"
]:
past_key_values
=
None
past_key_values
=
decoder_outputs
[
1
]
if
use_cache
:
# The starting index of the remaining elements in `decoder_outputs`
past_key_values
=
decoder_outputs
[
1
]
start_index
=
sum
([
1
if
x
is
not
None
else
0
for
x
in
(
loss
,
logits
,
past_key_values
)])
# The starting index of the remaining elements in `decoder_outputs`
start_index
=
sum
([
1
if
x
is
not
None
else
0
for
x
in
(
loss
,
logits
,
past_key_values
)])
if
not
decoder_inputs
[
"return_dict"
]:
if
not
isinstance
(
encoder_outputs
,
tuple
):
if
not
isinstance
(
encoder_outputs
,
tuple
):
encoder_outputs
=
encoder_outputs
.
to_tuple
()
encoder_outputs
=
encoder_outputs
.
to_tuple
()
output
=
(
loss
,
logits
,
past_key_values
)
+
decoder_outputs
[
start_index
:]
+
encoder_outputs
output
=
(
loss
,
logits
,
past_key_values
)
+
decoder_outputs
[
start_index
:]
+
encoder_outputs
...
@@ -646,7 +646,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -646,7 +646,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
return
TFSeq2SeqLMOutput
(
return
TFSeq2SeqLMOutput
(
loss
=
loss
,
loss
=
loss
,
logits
=
decoder_outputs
.
logits
,
logits
=
decoder_outputs
.
logits
,
past_key_values
=
past_key_values
,
past_key_values
=
decoder_outputs
.
past_key_values
,
decoder_hidden_states
=
decoder_outputs
.
hidden_states
,
decoder_hidden_states
=
decoder_outputs
.
hidden_states
,
decoder_attentions
=
decoder_outputs
.
attentions
,
decoder_attentions
=
decoder_outputs
.
attentions
,
cross_attentions
=
decoder_outputs
.
cross_attentions
,
cross_attentions
=
decoder_outputs
.
cross_attentions
,
...
...
src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
View file @
1fc4b2a1
...
@@ -663,13 +663,13 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
...
@@ -663,13 +663,13 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
warnings
.
warn
(
DEPRECATION_WARNING
,
FutureWarning
)
warnings
.
warn
(
DEPRECATION_WARNING
,
FutureWarning
)
loss
=
self
.
hf_compute_loss
(
labels
,
logits
)
loss
=
self
.
hf_compute_loss
(
labels
,
logits
)
past_key_values
=
None
if
not
return_dict
:
if
decoder_inputs
[
"use_cache"
]:
past_key_values
=
None
past_key_values
=
decoder_outputs
[
1
]
if
use_cache
:
# The starting index of the remaining elements in `decoder_outputs`
past_key_values
=
decoder_outputs
[
1
]
start_index
=
sum
([
1
if
x
is
not
None
else
0
for
x
in
(
loss
,
logits
,
past_key_values
)])
# The starting index of the remaining elements in `decoder_outputs`
start_index
=
sum
([
1
if
x
is
not
None
else
0
for
x
in
(
loss
,
logits
,
past_key_values
)])
if
not
decoder_inputs
[
"return_dict"
]:
if
not
isinstance
(
encoder_outputs
,
tuple
):
if
not
isinstance
(
encoder_outputs
,
tuple
):
encoder_outputs
=
encoder_outputs
.
to_tuple
()
encoder_outputs
=
encoder_outputs
.
to_tuple
()
output
=
(
loss
,
logits
,
past_key_values
)
+
decoder_outputs
[
start_index
:]
+
encoder_outputs
output
=
(
loss
,
logits
,
past_key_values
)
+
decoder_outputs
[
start_index
:]
+
encoder_outputs
...
@@ -679,7 +679,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
...
@@ -679,7 +679,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
return
TFSeq2SeqLMOutput
(
return
TFSeq2SeqLMOutput
(
loss
=
loss
,
loss
=
loss
,
logits
=
decoder_outputs
.
logits
,
logits
=
decoder_outputs
.
logits
,
past_key_values
=
past_key_values
,
past_key_values
=
decoder_outputs
.
past_key_values
,
decoder_hidden_states
=
decoder_outputs
.
hidden_states
,
decoder_hidden_states
=
decoder_outputs
.
hidden_states
,
decoder_attentions
=
decoder_outputs
.
attentions
,
decoder_attentions
=
decoder_outputs
.
attentions
,
cross_attentions
=
decoder_outputs
.
cross_attentions
,
cross_attentions
=
decoder_outputs
.
cross_attentions
,
...
...
tests/models/encoder_decoder/test_modeling_encoder_decoder.py
View file @
1fc4b2a1
...
@@ -351,6 +351,40 @@ class EncoderDecoderMixin:
...
@@ -351,6 +351,40 @@ class EncoderDecoderMixin:
outputs_encoder_decoder
[
"encoder_last_hidden_state"
].
shape
,
(
input_ids
.
shape
+
(
config
.
hidden_size
,))
outputs_encoder_decoder
[
"encoder_last_hidden_state"
].
shape
,
(
input_ids
.
shape
+
(
config
.
hidden_size
,))
)
)
def
_check_output_with_attentions
(
self
,
outputs_encoder_decoder
,
config
,
input_ids
,
decoder_config
,
decoder_input_ids
):
encoder_attentions
=
outputs_encoder_decoder
[
"encoder_attentions"
]
self
.
assertEqual
(
len
(
encoder_attentions
),
config
.
num_hidden_layers
)
self
.
assertEqual
(
encoder_attentions
[
0
].
shape
[
-
3
:],
(
config
.
num_attention_heads
,
input_ids
.
shape
[
-
1
],
input_ids
.
shape
[
-
1
])
)
decoder_attentions
=
outputs_encoder_decoder
[
"decoder_attentions"
]
num_decoder_layers
=
(
decoder_config
.
num_decoder_layers
if
hasattr
(
decoder_config
,
"num_decoder_layers"
)
else
decoder_config
.
num_hidden_layers
)
self
.
assertEqual
(
len
(
decoder_attentions
),
num_decoder_layers
)
self
.
assertEqual
(
decoder_attentions
[
0
].
shape
[
-
3
:],
(
decoder_config
.
num_attention_heads
,
decoder_input_ids
.
shape
[
-
1
],
decoder_input_ids
.
shape
[
-
1
]),
)
cross_attentions
=
outputs_encoder_decoder
[
"cross_attentions"
]
self
.
assertEqual
(
len
(
cross_attentions
),
num_decoder_layers
)
cross_attention_input_seq_len
=
decoder_input_ids
.
shape
[
-
1
]
*
(
1
+
(
decoder_config
.
ngram
if
hasattr
(
decoder_config
,
"ngram"
)
else
0
)
)
self
.
assertEqual
(
cross_attentions
[
0
].
shape
[
-
3
:],
(
decoder_config
.
num_attention_heads
,
cross_attention_input_seq_len
,
input_ids
.
shape
[
-
1
]),
)
def
check_encoder_decoder_model_output_attentions
(
def
check_encoder_decoder_model_output_attentions
(
self
,
self
,
config
,
config
,
...
@@ -376,36 +410,58 @@ class EncoderDecoderMixin:
...
@@ -376,36 +410,58 @@ class EncoderDecoderMixin:
decoder_attention_mask
=
decoder_attention_mask
,
decoder_attention_mask
=
decoder_attention_mask
,
output_attentions
=
True
,
output_attentions
=
True
,
)
)
self
.
_check_output_with_attentions
(
encoder_attentions
=
outputs_encoder_decoder
[
"encoder_attentions"
]
outputs_encoder_decoder
,
config
,
input_ids
,
decoder_config
,
decoder_input_ids
self
.
assertEqual
(
len
(
encoder_attentions
),
config
.
num_hidden_layers
)
self
.
assertEqual
(
encoder_attentions
[
0
].
shape
[
-
3
:],
(
config
.
num_attention_heads
,
input_ids
.
shape
[
-
1
],
input_ids
.
shape
[
-
1
])
)
)
decoder_attentions
=
outputs_encoder_decoder
[
"decoder_attentions"
]
def
check_encoder_decoder_model_output_attentions_from_config
(
num_decoder_layers
=
(
self
,
decoder_config
.
num_decoder_layers
config
,
if
hasattr
(
decoder_config
,
"num_decoder_layers"
)
input_ids
,
else
decoder_config
.
num_hidden_layers
attention_mask
,
)
encoder_hidden_states
,
self
.
assertEqual
(
len
(
decoder_attentions
),
num_decoder_layers
)
decoder_config
,
decoder_input_ids
,
decoder_attention_mask
,
labels
,
**
kwargs
):
# Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the
# config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded
# from the inner models' configurations.
self
.
assertEqual
(
decoder_input_ids
=
decoder_input_ids
[:,
:
-
1
]
decoder_attentions
[
0
].
shape
[
-
3
:],
decoder_attention_mask
=
decoder_attention_mask
[:,
:
-
1
]
(
decoder_config
.
num_attention_heads
,
decoder_input_ids
.
shape
[
-
1
],
decoder_input_ids
.
shape
[
-
1
]),
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
enc_dec_model
=
EncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
enc_dec_model
.
config
.
output_attentions
=
True
# model config -> won't work
enc_dec_model
.
to
(
torch_device
)
outputs_encoder_decoder
=
enc_dec_model
(
input_ids
=
input_ids
,
decoder_input_ids
=
decoder_input_ids
,
attention_mask
=
attention_mask
,
decoder_attention_mask
=
decoder_attention_mask
,
)
self
.
assertTrue
(
all
(
key
not
in
outputs_encoder_decoder
for
key
in
[
"encoder_attentions"
,
"decoder_attentions"
,
"cross_attentions"
]
)
)
)
cross_attentions
=
outputs_encoder_decoder
[
"cross_attentions"
]
config
.
output_attentions
=
True
# inner model config -> will work
self
.
assertEqual
(
len
(
cross_attentions
),
num_decoder_layers
)
decoder_config
.
output_attentions
=
True
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
cross_attention_input_seq_len
=
decoder_input_ids
.
shape
[
-
1
]
*
(
enc_dec_model
=
EncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
1
+
(
decoder_config
.
ngram
if
hasattr
(
decoder_config
,
"ngram"
)
else
0
)
enc_dec_model
.
to
(
torch_device
)
outputs_encoder_decoder
=
enc_dec_model
(
input_ids
=
input_ids
,
decoder_input_ids
=
decoder_input_ids
,
attention_mask
=
attention_mask
,
decoder_attention_mask
=
decoder_attention_mask
,
)
)
self
.
assertEqual
(
self
.
_check_output_with_attentions
(
cross_attentions
[
0
].
shape
[
-
3
:],
outputs_encoder_decoder
,
config
,
input_ids
,
decoder_config
,
decoder_input_ids
(
decoder_config
.
num_attention_heads
,
cross_attention_input_seq_len
,
input_ids
.
shape
[
-
1
]),
)
)
def
check_encoder_decoder_model_generate
(
self
,
input_ids
,
config
,
decoder_config
,
**
kwargs
):
def
check_encoder_decoder_model_generate
(
self
,
input_ids
,
config
,
decoder_config
,
**
kwargs
):
...
@@ -543,6 +599,10 @@ class EncoderDecoderMixin:
...
@@ -543,6 +599,10 @@ class EncoderDecoderMixin:
input_ids_dict
=
self
.
prepare_config_and_inputs
()
input_ids_dict
=
self
.
prepare_config_and_inputs
()
self
.
check_encoder_decoder_model_output_attentions
(
**
input_ids_dict
)
self
.
check_encoder_decoder_model_output_attentions
(
**
input_ids_dict
)
def
test_encoder_decoder_model_output_attentions_from_config
(
self
):
input_ids_dict
=
self
.
prepare_config_and_inputs
()
self
.
check_encoder_decoder_model_output_attentions_from_config
(
**
input_ids_dict
)
def
test_encoder_decoder_model_generate
(
self
):
def
test_encoder_decoder_model_generate
(
self
):
input_ids_dict
=
self
.
prepare_config_and_inputs
()
input_ids_dict
=
self
.
prepare_config_and_inputs
()
self
.
check_encoder_decoder_model_generate
(
**
input_ids_dict
)
self
.
check_encoder_decoder_model_generate
(
**
input_ids_dict
)
...
...
tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py
View file @
1fc4b2a1
...
@@ -255,31 +255,9 @@ class TFEncoderDecoderMixin:
...
@@ -255,31 +255,9 @@ class TFEncoderDecoderMixin:
outputs_encoder_decoder
[
"encoder_last_hidden_state"
].
shape
,
(
input_ids
.
shape
+
(
config
.
hidden_size
,))
outputs_encoder_decoder
[
"encoder_last_hidden_state"
].
shape
,
(
input_ids
.
shape
+
(
config
.
hidden_size
,))
)
)
def
check_encoder_decoder_model_output_attentions
(
def
_check_output_with_attentions
(
self
,
self
,
outputs_encoder_decoder
,
config
,
input_ids
,
decoder_config
,
decoder_input_ids
config
,
input_ids
,
attention_mask
,
encoder_hidden_states
,
decoder_config
,
decoder_input_ids
,
decoder_attention_mask
,
**
kwargs
):
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids
=
decoder_input_ids
[:,
:
-
1
]
decoder_attention_mask
=
decoder_attention_mask
[:,
:
-
1
]
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
enc_dec_model
=
TFEncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
outputs_encoder_decoder
=
enc_dec_model
(
input_ids
=
input_ids
,
decoder_input_ids
=
decoder_input_ids
,
attention_mask
=
attention_mask
,
decoder_attention_mask
=
decoder_attention_mask
,
output_attentions
=
True
,
kwargs
=
kwargs
,
)
encoder_attentions
=
outputs_encoder_decoder
[
"encoder_attentions"
]
encoder_attentions
=
outputs_encoder_decoder
[
"encoder_attentions"
]
self
.
assertEqual
(
len
(
encoder_attentions
),
config
.
num_hidden_layers
)
self
.
assertEqual
(
len
(
encoder_attentions
),
config
.
num_hidden_layers
)
...
@@ -311,6 +289,83 @@ class TFEncoderDecoderMixin:
...
@@ -311,6 +289,83 @@ class TFEncoderDecoderMixin:
(
decoder_config
.
num_attention_heads
,
cross_attention_input_seq_len
,
input_ids
.
shape
[
-
1
]),
(
decoder_config
.
num_attention_heads
,
cross_attention_input_seq_len
,
input_ids
.
shape
[
-
1
]),
)
)
def
check_encoder_decoder_model_output_attentions
(
self
,
config
,
input_ids
,
attention_mask
,
encoder_hidden_states
,
decoder_config
,
decoder_input_ids
,
decoder_attention_mask
,
**
kwargs
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids
=
decoder_input_ids
[:,
:
-
1
]
decoder_attention_mask
=
decoder_attention_mask
[:,
:
-
1
]
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
enc_dec_model
=
TFEncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
outputs_encoder_decoder
=
enc_dec_model
(
input_ids
=
input_ids
,
decoder_input_ids
=
decoder_input_ids
,
attention_mask
=
attention_mask
,
decoder_attention_mask
=
decoder_attention_mask
,
output_attentions
=
True
,
kwargs
=
kwargs
,
)
self
.
_check_output_with_attentions
(
outputs_encoder_decoder
,
config
,
input_ids
,
decoder_config
,
decoder_input_ids
)
def
check_encoder_decoder_model_output_attentions_from_config
(
self
,
config
,
input_ids
,
attention_mask
,
encoder_hidden_states
,
decoder_config
,
decoder_input_ids
,
decoder_attention_mask
,
**
kwargs
):
# Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the
# config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded
# from the inner models' configurations.
decoder_input_ids
=
decoder_input_ids
[:,
:
-
1
]
decoder_attention_mask
=
decoder_attention_mask
[:,
:
-
1
]
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
enc_dec_model
=
TFEncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
enc_dec_model
.
config
.
output_attentions
=
True
# model config -> won't work
outputs_encoder_decoder
=
enc_dec_model
(
input_ids
=
input_ids
,
decoder_input_ids
=
decoder_input_ids
,
attention_mask
=
attention_mask
,
decoder_attention_mask
=
decoder_attention_mask
,
kwargs
=
kwargs
,
)
self
.
assertTrue
(
all
(
key
not
in
outputs_encoder_decoder
for
key
in
[
"encoder_attentions"
,
"decoder_attentions"
,
"cross_attentions"
]
)
)
config
.
output_attentions
=
True
# inner model config -> will work
decoder_config
.
output_attentions
=
True
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
enc_dec_model
=
TFEncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
outputs_encoder_decoder
=
enc_dec_model
(
input_ids
=
input_ids
,
decoder_input_ids
=
decoder_input_ids
,
attention_mask
=
attention_mask
,
decoder_attention_mask
=
decoder_attention_mask
,
kwargs
=
kwargs
,
)
self
.
_check_output_with_attentions
(
outputs_encoder_decoder
,
config
,
input_ids
,
decoder_config
,
decoder_input_ids
)
def
check_encoder_decoder_model_generate
(
self
,
input_ids
,
config
,
decoder_config
,
**
kwargs
):
def
check_encoder_decoder_model_generate
(
self
,
input_ids
,
config
,
decoder_config
,
**
kwargs
):
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
enc_dec_model
=
TFEncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
enc_dec_model
=
TFEncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
...
@@ -570,6 +625,10 @@ class TFEncoderDecoderMixin:
...
@@ -570,6 +625,10 @@ class TFEncoderDecoderMixin:
input_ids_dict
=
self
.
prepare_config_and_inputs
()
input_ids_dict
=
self
.
prepare_config_and_inputs
()
self
.
check_encoder_decoder_model_output_attentions
(
**
input_ids_dict
)
self
.
check_encoder_decoder_model_output_attentions
(
**
input_ids_dict
)
def
test_encoder_decoder_model_output_attentions_from_config
(
self
):
input_ids_dict
=
self
.
prepare_config_and_inputs
()
self
.
check_encoder_decoder_model_output_attentions_from_config
(
**
input_ids_dict
)
def
test_encoder_decoder_model_generate
(
self
):
def
test_encoder_decoder_model_generate
(
self
):
input_ids_dict
=
self
.
prepare_config_and_inputs
()
input_ids_dict
=
self
.
prepare_config_and_inputs
()
self
.
check_encoder_decoder_model_generate
(
**
input_ids_dict
)
self
.
check_encoder_decoder_model_generate
(
**
input_ids_dict
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment