"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ab7551cd7ff84cb5b7328bc37a06e06fa19f02bb"
Unverified Commit 51e0ebed authored by jsnfly's avatar jsnfly Committed by GitHub
Browse files

Allow passing encoder_outputs as tuple to EncoderDecoder Models (#16814)



* Add passing encoder_outputs as tuple to existing test

* Add check for tuple

* Add check for tuple also for speech and vision
Co-authored-by: jsnfly <jsnfly@gmx.de>
parent 51fa7191
...@@ -22,7 +22,7 @@ from torch import nn ...@@ -22,7 +22,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...modeling_outputs import Seq2SeqLMOutput from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ..auto.configuration_auto import AutoConfig from ..auto.configuration_auto import AutoConfig
...@@ -494,6 +494,8 @@ class EncoderDecoderModel(PreTrainedModel): ...@@ -494,6 +494,8 @@ class EncoderDecoderModel(PreTrainedModel):
return_dict=return_dict, return_dict=return_dict,
**kwargs_encoder, **kwargs_encoder,
) )
elif isinstance(encoder_outputs, tuple):
encoder_outputs = BaseModelOutput(*encoder_outputs)
encoder_hidden_states = encoder_outputs[0] encoder_hidden_states = encoder_outputs[0]
......
...@@ -22,7 +22,7 @@ from torch import nn ...@@ -22,7 +22,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...modeling_outputs import Seq2SeqLMOutput from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ..auto.configuration_auto import AutoConfig from ..auto.configuration_auto import AutoConfig
...@@ -514,6 +514,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel): ...@@ -514,6 +514,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
return_dict=return_dict, return_dict=return_dict,
**kwargs_encoder, **kwargs_encoder,
) )
elif isinstance(encoder_outputs, tuple):
encoder_outputs = BaseModelOutput(*encoder_outputs)
encoder_hidden_states = encoder_outputs[0] encoder_hidden_states = encoder_outputs[0]
......
...@@ -22,7 +22,7 @@ from torch import nn ...@@ -22,7 +22,7 @@ from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...modeling_outputs import Seq2SeqLMOutput from ...modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from ..auto.configuration_auto import AutoConfig from ..auto.configuration_auto import AutoConfig
...@@ -466,6 +466,8 @@ class VisionEncoderDecoderModel(PreTrainedModel): ...@@ -466,6 +466,8 @@ class VisionEncoderDecoderModel(PreTrainedModel):
return_dict=return_dict, return_dict=return_dict,
**kwargs_encoder, **kwargs_encoder,
) )
elif isinstance(encoder_outputs, tuple):
encoder_outputs = BaseModelOutput(*encoder_outputs)
encoder_hidden_states = encoder_outputs[0] encoder_hidden_states = encoder_outputs[0]
......
...@@ -142,6 +142,22 @@ class EncoderDecoderMixin: ...@@ -142,6 +142,22 @@ class EncoderDecoderMixin:
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,)) outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
) )
# Test passing encoder_outputs as tuple.
encoder_outputs = (encoder_hidden_states,)
outputs_encoder_decoder = enc_dec_model(
encoder_outputs=encoder_outputs,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
self.assertEqual(
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
)
def check_encoder_decoder_model_from_pretrained_using_model_paths( def check_encoder_decoder_model_from_pretrained_using_model_paths(
self, self,
config, config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment