Unverified Commit a574de30 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

Encoder-Decoder: add informative exception when the decoder is not compatible (#23426)

parent 939a65ab
......@@ -16,6 +16,7 @@
import gc
import inspect
import os
import tempfile
import warnings
......@@ -245,6 +246,13 @@ class EncoderDecoderModel(PreTrainedModel):
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
)
decoder_signature = set(inspect.signature(self.decoder.forward).parameters.keys())
if "encoder_hidden_states" not in decoder_signature:
raise ValueError(
"The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
"following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
)
# tie encoder, decoder weights if config set accordingly
self.tie_weights()
......
......@@ -14,7 +14,7 @@
# limitations under the License.
""" Classes to support TF Encoder-Decoder architectures"""
import inspect
import re
import warnings
from typing import Optional, Tuple, Union
......@@ -266,6 +266,13 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
)
decoder_signature = set(inspect.signature(self.decoder.call).parameters.keys())
if "encoder_hidden_states" not in decoder_signature:
raise ValueError(
"The selected decoder is not prepared for the encoder hidden states to be passed. Please see the "
"following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
)
@property
def dummy_inputs(self):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment