Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
a574de30
Unverified
Commit
a574de30
authored
May 17, 2023
by
Joao Gante
Committed by
GitHub
May 17, 2023
Browse files
Encoder-Decoder: add informative exception when the decoder is not compatible (#23426)
parent
939a65ab
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
1 deletion
+16
-1
src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
...ormers/models/encoder_decoder/modeling_encoder_decoder.py
+8
-0
src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
...ers/models/encoder_decoder/modeling_tf_encoder_decoder.py
+8
-1
No files found.
src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
View file @
a574de30
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
import
gc
import
gc
import
inspect
import
os
import
os
import
tempfile
import
tempfile
import
warnings
import
warnings
...
@@ -245,6 +246,13 @@ class EncoderDecoderModel(PreTrainedModel):
...
@@ -245,6 +246,13 @@ class EncoderDecoderModel(PreTrainedModel):
f
"The encoder
{
self
.
encoder
}
should not have a LM Head. Please use a model without LM Head"
f
"The encoder
{
self
.
encoder
}
should not have a LM Head. Please use a model without LM Head"
)
)
decoder_signature
=
set
(
inspect
.
signature
(
self
.
decoder
.
forward
).
parameters
.
keys
())
if
"encoder_hidden_states"
not
in
decoder_signature
:
raise
ValueError
(
"The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
"following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
)
# tie encoder, decoder weights if config set accordingly
# tie encoder, decoder weights if config set accordingly
self
.
tie_weights
()
self
.
tie_weights
()
...
...
src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py
View file @
a574de30
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
# limitations under the License.
# limitations under the License.
""" Classes to support TF Encoder-Decoder architectures"""
""" Classes to support TF Encoder-Decoder architectures"""
import
inspect
import
re
import
re
import
warnings
import
warnings
from
typing
import
Optional
,
Tuple
,
Union
from
typing
import
Optional
,
Tuple
,
Union
...
@@ -266,6 +266,13 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
...
@@ -266,6 +266,13 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
f
"The encoder
{
self
.
encoder
}
should not have a LM Head. Please use a model without LM Head"
f
"The encoder
{
self
.
encoder
}
should not have a LM Head. Please use a model without LM Head"
)
)
decoder_signature
=
set
(
inspect
.
signature
(
self
.
decoder
.
call
).
parameters
.
keys
())
if
"encoder_hidden_states"
not
in
decoder_signature
:
raise
ValueError
(
"The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
"following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
)
@
property
@
property
def
dummy_inputs
(
self
):
def
dummy_inputs
(
self
):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment