Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1fc4b2a1
Unverified
Commit
1fc4b2a1
authored
Jul 22, 2022
by
Joao Gante
Committed by
GitHub
Jul 22, 2022
Browse files
TF: use the correct config with `(...)EncoderDecoder` models (#18097)
parent
49354097
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
199 additions
and
74 deletions
+199
-74
src/transformers/modeling_tf_utils.py
src/transformers/modeling_tf_utils.py
+18
-12
src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
...ers/models/encoder_decoder/modeling_tf_encoder_decoder.py
+7
-7
src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
...ion_encoder_decoder/modeling_tf_vision_encoder_decoder.py
+7
-7
tests/models/encoder_decoder/test_modeling_encoder_decoder.py
...s/models/encoder_decoder/test_modeling_encoder_decoder.py
+84
-24
tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py
...odels/encoder_decoder/test_modeling_tf_encoder_decoder.py
+83
-24
No files found.
src/transformers/modeling_tf_utils.py
View file @
1fc4b2a1
...
...
@@ -403,8 +403,13 @@ def unpack_inputs(func):
# move any arg into kwargs, if they exist
fn_args_and_kwargs
.
update
(
dict
(
zip
(
func
.
__code__
.
co_varnames
[
1
:],
args
)))
# process the inputs and call the wrapped function
unpacked_inputs
=
input_processing
(
func
,
self
.
config
,
**
fn_args_and_kwargs
)
# Encoder Decoder models delegate the application of the configuration options to their inner models.
if
"encoder_decoder"
in
str
(
self
).
lower
():
config
=
None
else
:
config
=
self
.
config
unpacked_inputs
=
input_processing
(
func
,
config
,
**
fn_args_and_kwargs
)
return
func
(
self
,
**
unpacked_inputs
)
# Keras enforces the first layer argument to be passed, and checks it through `inspect.getfullargspec()`. This
...
...
@@ -559,18 +564,19 @@ def input_processing(func, config, **kwargs):
if
"kwargs"
in
output
:
del
output
[
"kwargs"
]
boolean_dict
=
{
k
:
v
for
k
,
v
in
output
.
items
()
if
k
in
[
"return_dict"
,
"output_attentions"
,
"output_hidden_states"
,
"use_cache"
]
}
if
config
is
not
None
:
boolean_dict
=
{
k
:
v
for
k
,
v
in
output
.
items
()
if
k
in
[
"return_dict"
,
"output_attentions"
,
"output_hidden_states"
,
"use_cache"
]
}
output
.
update
(
booleans_processing
(
config
=
config
,
**
boolean_dict
,
output
.
update
(
booleans_processing
(
config
=
config
,
**
boolean_dict
,
)
)
)
return
output
...
...
src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
View file @
1fc4b2a1
...
...
@@ -630,13 +630,13 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
warnings
.
warn
(
DEPRECATION_WARNING
,
FutureWarning
)
loss
=
self
.
hf_compute_loss
(
labels
,
logits
)
past_key_values
=
None
if
decoder_inputs
[
"use_cache"
]:
past_key_values
=
decoder_outputs
[
1
]
# The starting index of the remaining elements in `decoder_outputs`
start_index
=
sum
([
1
if
x
is
not
None
else
0
for
x
in
(
loss
,
logits
,
past_key_values
)])
if
not
return_dict
:
past_key_values
=
None
if
use_cache
:
past_key_values
=
decoder_outputs
[
1
]
# The starting index of the remaining elements in `decoder_outputs`
start_index
=
sum
([
1
if
x
is
not
None
else
0
for
x
in
(
loss
,
logits
,
past_key_values
)])
if
not
decoder_inputs
[
"return_dict"
]:
if
not
isinstance
(
encoder_outputs
,
tuple
):
encoder_outputs
=
encoder_outputs
.
to_tuple
()
output
=
(
loss
,
logits
,
past_key_values
)
+
decoder_outputs
[
start_index
:]
+
encoder_outputs
...
...
@@ -646,7 +646,7 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
return
TFSeq2SeqLMOutput
(
loss
=
loss
,
logits
=
decoder_outputs
.
logits
,
past_key_values
=
past_key_values
,
past_key_values
=
decoder_outputs
.
past_key_values
,
decoder_hidden_states
=
decoder_outputs
.
hidden_states
,
decoder_attentions
=
decoder_outputs
.
attentions
,
cross_attentions
=
decoder_outputs
.
cross_attentions
,
...
...
src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py
View file @
1fc4b2a1
...
...
@@ -663,13 +663,13 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
warnings
.
warn
(
DEPRECATION_WARNING
,
FutureWarning
)
loss
=
self
.
hf_compute_loss
(
labels
,
logits
)
past_key_values
=
None
if
decoder_inputs
[
"use_cache"
]:
past_key_values
=
decoder_outputs
[
1
]
# The starting index of the remaining elements in `decoder_outputs`
start_index
=
sum
([
1
if
x
is
not
None
else
0
for
x
in
(
loss
,
logits
,
past_key_values
)])
if
not
return_dict
:
past_key_values
=
None
if
use_cache
:
past_key_values
=
decoder_outputs
[
1
]
# The starting index of the remaining elements in `decoder_outputs`
start_index
=
sum
([
1
if
x
is
not
None
else
0
for
x
in
(
loss
,
logits
,
past_key_values
)])
if
not
decoder_inputs
[
"return_dict"
]:
if
not
isinstance
(
encoder_outputs
,
tuple
):
encoder_outputs
=
encoder_outputs
.
to_tuple
()
output
=
(
loss
,
logits
,
past_key_values
)
+
decoder_outputs
[
start_index
:]
+
encoder_outputs
...
...
@@ -679,7 +679,7 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
return
TFSeq2SeqLMOutput
(
loss
=
loss
,
logits
=
decoder_outputs
.
logits
,
past_key_values
=
past_key_values
,
past_key_values
=
decoder_outputs
.
past_key_values
,
decoder_hidden_states
=
decoder_outputs
.
hidden_states
,
decoder_attentions
=
decoder_outputs
.
attentions
,
cross_attentions
=
decoder_outputs
.
cross_attentions
,
...
...
tests/models/encoder_decoder/test_modeling_encoder_decoder.py
View file @
1fc4b2a1
...
...
@@ -351,6 +351,40 @@ class EncoderDecoderMixin:
outputs_encoder_decoder
[
"encoder_last_hidden_state"
].
shape
,
(
input_ids
.
shape
+
(
config
.
hidden_size
,))
)
def
_check_output_with_attentions
(
self
,
outputs_encoder_decoder
,
config
,
input_ids
,
decoder_config
,
decoder_input_ids
):
encoder_attentions
=
outputs_encoder_decoder
[
"encoder_attentions"
]
self
.
assertEqual
(
len
(
encoder_attentions
),
config
.
num_hidden_layers
)
self
.
assertEqual
(
encoder_attentions
[
0
].
shape
[
-
3
:],
(
config
.
num_attention_heads
,
input_ids
.
shape
[
-
1
],
input_ids
.
shape
[
-
1
])
)
decoder_attentions
=
outputs_encoder_decoder
[
"decoder_attentions"
]
num_decoder_layers
=
(
decoder_config
.
num_decoder_layers
if
hasattr
(
decoder_config
,
"num_decoder_layers"
)
else
decoder_config
.
num_hidden_layers
)
self
.
assertEqual
(
len
(
decoder_attentions
),
num_decoder_layers
)
self
.
assertEqual
(
decoder_attentions
[
0
].
shape
[
-
3
:],
(
decoder_config
.
num_attention_heads
,
decoder_input_ids
.
shape
[
-
1
],
decoder_input_ids
.
shape
[
-
1
]),
)
cross_attentions
=
outputs_encoder_decoder
[
"cross_attentions"
]
self
.
assertEqual
(
len
(
cross_attentions
),
num_decoder_layers
)
cross_attention_input_seq_len
=
decoder_input_ids
.
shape
[
-
1
]
*
(
1
+
(
decoder_config
.
ngram
if
hasattr
(
decoder_config
,
"ngram"
)
else
0
)
)
self
.
assertEqual
(
cross_attentions
[
0
].
shape
[
-
3
:],
(
decoder_config
.
num_attention_heads
,
cross_attention_input_seq_len
,
input_ids
.
shape
[
-
1
]),
)
def
check_encoder_decoder_model_output_attentions
(
self
,
config
,
...
...
@@ -376,36 +410,58 @@ class EncoderDecoderMixin:
decoder_attention_mask
=
decoder_attention_mask
,
output_attentions
=
True
,
)
encoder_attentions
=
outputs_encoder_decoder
[
"encoder_attentions"
]
self
.
assertEqual
(
len
(
encoder_attentions
),
config
.
num_hidden_layers
)
self
.
assertEqual
(
encoder_attentions
[
0
].
shape
[
-
3
:],
(
config
.
num_attention_heads
,
input_ids
.
shape
[
-
1
],
input_ids
.
shape
[
-
1
])
self
.
_check_output_with_attentions
(
outputs_encoder_decoder
,
config
,
input_ids
,
decoder_config
,
decoder_input_ids
)
decoder_attentions
=
outputs_encoder_decoder
[
"decoder_attentions"
]
num_decoder_layers
=
(
decoder_config
.
num_decoder_layers
if
hasattr
(
decoder_config
,
"num_decoder_layers"
)
else
decoder_config
.
num_hidden_layers
)
self
.
assertEqual
(
len
(
decoder_attentions
),
num_decoder_layers
)
def
check_encoder_decoder_model_output_attentions_from_config
(
self
,
config
,
input_ids
,
attention_mask
,
encoder_hidden_states
,
decoder_config
,
decoder_input_ids
,
decoder_attention_mask
,
labels
,
**
kwargs
):
# Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the
# config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded
# from the inner models' configurations.
self
.
assertEqual
(
decoder_attentions
[
0
].
shape
[
-
3
:],
(
decoder_config
.
num_attention_heads
,
decoder_input_ids
.
shape
[
-
1
],
decoder_input_ids
.
shape
[
-
1
]),
decoder_input_ids
=
decoder_input_ids
[:,
:
-
1
]
decoder_attention_mask
=
decoder_attention_mask
[:,
:
-
1
]
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
enc_dec_model
=
EncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
enc_dec_model
.
config
.
output_attentions
=
True
# model config -> won't work
enc_dec_model
.
to
(
torch_device
)
outputs_encoder_decoder
=
enc_dec_model
(
input_ids
=
input_ids
,
decoder_input_ids
=
decoder_input_ids
,
attention_mask
=
attention_mask
,
decoder_attention_mask
=
decoder_attention_mask
,
)
self
.
assertTrue
(
all
(
key
not
in
outputs_encoder_decoder
for
key
in
[
"encoder_attentions"
,
"decoder_attentions"
,
"cross_attentions"
]
)
)
cross_attentions
=
outputs_encoder_decoder
[
"cross_attentions"
]
self
.
assertEqual
(
len
(
cross_attentions
),
num_decoder_layers
)
cross_attention_input_seq_len
=
decoder_input_ids
.
shape
[
-
1
]
*
(
1
+
(
decoder_config
.
ngram
if
hasattr
(
decoder_config
,
"ngram"
)
else
0
)
config
.
output_attentions
=
True
# inner model config -> will work
decoder_config
.
output_attentions
=
True
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
enc_dec_model
=
EncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
enc_dec_model
.
to
(
torch_device
)
outputs_encoder_decoder
=
enc_dec_model
(
input_ids
=
input_ids
,
decoder_input_ids
=
decoder_input_ids
,
attention_mask
=
attention_mask
,
decoder_attention_mask
=
decoder_attention_mask
,
)
self
.
assertEqual
(
cross_attentions
[
0
].
shape
[
-
3
:],
(
decoder_config
.
num_attention_heads
,
cross_attention_input_seq_len
,
input_ids
.
shape
[
-
1
]),
self
.
_check_output_with_attentions
(
outputs_encoder_decoder
,
config
,
input_ids
,
decoder_config
,
decoder_input_ids
)
def
check_encoder_decoder_model_generate
(
self
,
input_ids
,
config
,
decoder_config
,
**
kwargs
):
...
...
@@ -543,6 +599,10 @@ class EncoderDecoderMixin:
input_ids_dict
=
self
.
prepare_config_and_inputs
()
self
.
check_encoder_decoder_model_output_attentions
(
**
input_ids_dict
)
def
test_encoder_decoder_model_output_attentions_from_config
(
self
):
input_ids_dict
=
self
.
prepare_config_and_inputs
()
self
.
check_encoder_decoder_model_output_attentions_from_config
(
**
input_ids_dict
)
def
test_encoder_decoder_model_generate
(
self
):
input_ids_dict
=
self
.
prepare_config_and_inputs
()
self
.
check_encoder_decoder_model_generate
(
**
input_ids_dict
)
...
...
tests/models/encoder_decoder/test_modeling_tf_encoder_decoder.py
View file @
1fc4b2a1
...
...
@@ -255,31 +255,9 @@ class TFEncoderDecoderMixin:
outputs_encoder_decoder
[
"encoder_last_hidden_state"
].
shape
,
(
input_ids
.
shape
+
(
config
.
hidden_size
,))
)
def
check_encoder_decoder_model_output_attentions
(
self
,
config
,
input_ids
,
attention_mask
,
encoder_hidden_states
,
decoder_config
,
decoder_input_ids
,
decoder_attention_mask
,
**
kwargs
def
_check_output_with_attentions
(
self
,
outputs_encoder_decoder
,
config
,
input_ids
,
decoder_config
,
decoder_input_ids
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids
=
decoder_input_ids
[:,
:
-
1
]
decoder_attention_mask
=
decoder_attention_mask
[:,
:
-
1
]
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
enc_dec_model
=
TFEncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
outputs_encoder_decoder
=
enc_dec_model
(
input_ids
=
input_ids
,
decoder_input_ids
=
decoder_input_ids
,
attention_mask
=
attention_mask
,
decoder_attention_mask
=
decoder_attention_mask
,
output_attentions
=
True
,
kwargs
=
kwargs
,
)
encoder_attentions
=
outputs_encoder_decoder
[
"encoder_attentions"
]
self
.
assertEqual
(
len
(
encoder_attentions
),
config
.
num_hidden_layers
)
...
...
@@ -311,6 +289,83 @@ class TFEncoderDecoderMixin:
(
decoder_config
.
num_attention_heads
,
cross_attention_input_seq_len
,
input_ids
.
shape
[
-
1
]),
)
def
check_encoder_decoder_model_output_attentions
(
self
,
config
,
input_ids
,
attention_mask
,
encoder_hidden_states
,
decoder_config
,
decoder_input_ids
,
decoder_attention_mask
,
**
kwargs
):
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids
=
decoder_input_ids
[:,
:
-
1
]
decoder_attention_mask
=
decoder_attention_mask
[:,
:
-
1
]
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
enc_dec_model
=
TFEncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
outputs_encoder_decoder
=
enc_dec_model
(
input_ids
=
input_ids
,
decoder_input_ids
=
decoder_input_ids
,
attention_mask
=
attention_mask
,
decoder_attention_mask
=
decoder_attention_mask
,
output_attentions
=
True
,
kwargs
=
kwargs
,
)
self
.
_check_output_with_attentions
(
outputs_encoder_decoder
,
config
,
input_ids
,
decoder_config
,
decoder_input_ids
)
def
check_encoder_decoder_model_output_attentions_from_config
(
self
,
config
,
input_ids
,
attention_mask
,
encoder_hidden_states
,
decoder_config
,
decoder_input_ids
,
decoder_attention_mask
,
**
kwargs
):
# Similar to `check_encoder_decoder_model_output_attentions`, but with `output_attentions` triggered from the
# config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded
# from the inner models' configurations.
decoder_input_ids
=
decoder_input_ids
[:,
:
-
1
]
decoder_attention_mask
=
decoder_attention_mask
[:,
:
-
1
]
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
enc_dec_model
=
TFEncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
enc_dec_model
.
config
.
output_attentions
=
True
# model config -> won't work
outputs_encoder_decoder
=
enc_dec_model
(
input_ids
=
input_ids
,
decoder_input_ids
=
decoder_input_ids
,
attention_mask
=
attention_mask
,
decoder_attention_mask
=
decoder_attention_mask
,
kwargs
=
kwargs
,
)
self
.
assertTrue
(
all
(
key
not
in
outputs_encoder_decoder
for
key
in
[
"encoder_attentions"
,
"decoder_attentions"
,
"cross_attentions"
]
)
)
config
.
output_attentions
=
True
# inner model config -> will work
decoder_config
.
output_attentions
=
True
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
enc_dec_model
=
TFEncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
outputs_encoder_decoder
=
enc_dec_model
(
input_ids
=
input_ids
,
decoder_input_ids
=
decoder_input_ids
,
attention_mask
=
attention_mask
,
decoder_attention_mask
=
decoder_attention_mask
,
kwargs
=
kwargs
,
)
self
.
_check_output_with_attentions
(
outputs_encoder_decoder
,
config
,
input_ids
,
decoder_config
,
decoder_input_ids
)
def
check_encoder_decoder_model_generate
(
self
,
input_ids
,
config
,
decoder_config
,
**
kwargs
):
encoder_model
,
decoder_model
=
self
.
get_encoder_decoder_model
(
config
,
decoder_config
)
enc_dec_model
=
TFEncoderDecoderModel
(
encoder
=
encoder_model
,
decoder
=
decoder_model
)
...
...
@@ -570,6 +625,10 @@ class TFEncoderDecoderMixin:
input_ids_dict
=
self
.
prepare_config_and_inputs
()
self
.
check_encoder_decoder_model_output_attentions
(
**
input_ids_dict
)
def
test_encoder_decoder_model_output_attentions_from_config
(
self
):
input_ids_dict
=
self
.
prepare_config_and_inputs
()
self
.
check_encoder_decoder_model_output_attentions_from_config
(
**
input_ids_dict
)
def
test_encoder_decoder_model_generate
(
self
):
input_ids_dict
=
self
.
prepare_config_and_inputs
()
self
.
check_encoder_decoder_model_generate
(
**
input_ids_dict
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment