"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "a0f867430347bcf939f71d186409b9ca138c3b34"
Unverified commit 8cca8755, authored by Patrick von Platen, committed by GitHub

[EncoderDecoderConfig] automatically set decoder config to decoder (#4809)

* automatically set decoder config to decoder

* add more tests
parent f1fe1846
@@ -85,6 +85,9 @@ class EncoderDecoderConfig(PretrainedConfig):
         Returns:
             :class:`EncoderDecoderConfig`: An instance of a configuration object
         """
+        logger.info("Set `config.is_decoder=True` for decoder_config")
+        decoder_config.is_decoder = True
+
         return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())

     def to_dict(self):
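For readers following along, here is a minimal sketch of the behavior this change introduces. The `BertConfig` instances below are placeholders used purely for illustration; any pair of encoder/decoder configs would do.

```python
from transformers import BertConfig, EncoderDecoderConfig

# Placeholder encoder/decoder configs, for illustration only.
encoder_config = BertConfig()
decoder_config = BertConfig()

# After this commit, from_encoder_decoder_configs flips the decoder's
# is_decoder flag automatically before combining the two configs.
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)

assert config.decoder.is_decoder
assert config.is_encoder_decoder
```

Note that the decoder config passed in is mutated in place (its `is_decoder` attribute is set to `True`), which is what the added `logger.info` line announces.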
@@ -27,7 +27,7 @@ from .utils import require_torch, slow, torch_device
 if is_torch_available():
-    from transformers import BertModel, BertForMaskedLM, EncoderDecoderModel
+    from transformers import BertModel, BertForMaskedLM, EncoderDecoderModel, EncoderDecoderConfig
     import numpy as np
     import torch
@@ -74,6 +74,36 @@ class EncoderDecoderModelTest(unittest.TestCase):
             "labels": decoder_token_labels,
         }

+    def create_and_check_bert_encoder_decoder_model_from_pretrained_configs(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        **kwargs
+    ):
+        encoder_decoder_config = EncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
+        self.assertTrue(encoder_decoder_config.decoder.is_decoder)
+
+        enc_dec_model = EncoderDecoderModel(encoder_decoder_config)
+        enc_dec_model.to(torch_device)
+        enc_dec_model.eval()
+
+        self.assertTrue(enc_dec_model.config.is_encoder_decoder)
+
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            decoder_input_ids=decoder_input_ids,
+            attention_mask=attention_mask,
+            decoder_attention_mask=decoder_attention_mask,
+        )
+
+        self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
+        self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,)))
+
     def create_and_check_bert_encoder_decoder_model(
         self,
         config,
...@@ -88,6 +118,8 @@ class EncoderDecoderModelTest(unittest.TestCase): ...@@ -88,6 +118,8 @@ class EncoderDecoderModelTest(unittest.TestCase):
encoder_model = BertModel(config) encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config) decoder_model = BertForMaskedLM(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
enc_dec_model.to(torch_device) enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model( outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids, input_ids=input_ids,
@@ -304,6 +336,10 @@ class EncoderDecoderModelTest(unittest.TestCase):
         input_ids_dict = self.prepare_config_and_inputs_bert()
         self.create_and_check_bert_encoder_decoder_model(**input_ids_dict)

+    def test_bert_encoder_decoder_model_from_pretrained_configs(self):
+        input_ids_dict = self.prepare_config_and_inputs_bert()
+        self.create_and_check_bert_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
+
     def test_bert_encoder_decoder_model_from_pretrained(self):
         input_ids_dict = self.prepare_config_and_inputs_bert()
         self.create_and_check_bert_encoder_decoder_model_from_pretrained(**input_ids_dict)
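The new test above effectively documents the end-to-end usage this commit enables: build a combined config from two pretrained-style configs, instantiate `EncoderDecoderModel` from it, and run a forward pass. A rough standalone sketch of the same flow follows; the small config sizes are arbitrary choices to keep it fast, and the tuple-indexed outputs reflect the library version this commit targets.

```python
import torch
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

# Arbitrary small sizes, for illustration only.
encoder_config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
decoder_config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)

# The decoder config gets is_decoder=True set automatically here.
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)

model = EncoderDecoderModel(config)
model.eval()

input_ids = torch.randint(0, encoder_config.vocab_size, (2, 8))
decoder_input_ids = torch.randint(0, decoder_config.vocab_size, (2, 8))

with torch.no_grad():
    outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

# As asserted in the new test: decoder LM logits first, encoder hidden states second.
print(outputs[0].shape)  # (2, 8, decoder vocab_size)
print(outputs[1].shape)  # (2, 8, encoder hidden_size)
```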