Unverified commit 0866669e, authored by Patrick von Platen and committed by GitHub

[EncoderDecoder] Fix initialization and save/load bug (#4680)

* fix bug

* add more tests
parent 6f82aea6
src/transformers/modeling_encoder_decoder.py

@@ -35,6 +35,7 @@ class EncoderDecoderModel(PreTrainedModel):
     class method for the encoder and `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` class method for the decoder.
     """

     config_class = EncoderDecoderConfig
+    base_model_prefix = "encoder_decoder"

     def __init__(
         self,
@@ -158,12 +159,26 @@ class EncoderDecoderModel(PreTrainedModel):
             ), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
             from .modeling_auto import AutoModelWithLMHead

-            decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
-            decoder.config.is_decoder = True
+            if "config" not in kwargs_decoder:
+                from transformers import AutoConfig
+
+                decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
+                if decoder_config.is_decoder is False:
+                    logger.info(
+                        f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. Cross attention layers are added to {decoder_pretrained_model_name_or_path} and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for cross attention layers."
+                    )
+                    decoder_config.is_decoder = True
+                kwargs_decoder["config"] = decoder_config
+
+            if kwargs_decoder["config"].is_decoder is False:
+                logger.warning(
+                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, make sure that the attribute `is_decoder` of `decoder_config` passed to `.from_encoder_decoder_pretrained(...)` is set to `True` or do not pass a `decoder_config` to `.from_encoder_decoder_pretrained(...)`"
+                )
+
+            decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

-        model = cls(encoder=encoder, decoder=decoder)
-
-        return model
+        return cls(encoder=encoder, decoder=decoder)

     def forward(
         self,
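For context (not part of the diff), here is a minimal sketch of the two decoder-initialization paths the new code takes. The checkpoint names are illustrative assumptions, and the `decoder_config` keyword relies on the `decoder_`-prefixed kwarg routing that the warning message above refers to.

# Sketch (not part of the diff): the two decoder-initialization paths.
# The checkpoint names below are illustrative assumptions.
from transformers import AutoConfig, EncoderDecoderModel

# Path 1: no decoder config passed. The decoder config is loaded automatically,
# `is_decoder` is set to True, and cross attention layers are added.
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
assert model.decoder.config.is_decoder

# Path 2: an explicit decoder config is passed via the `decoder_` kwarg prefix.
# It is used as-is; leaving `is_decoder` as False triggers the new warning above.
decoder_config = AutoConfig.from_pretrained("bert-base-uncased")
decoder_config.is_decoder = True  # set explicitly to avoid the warning
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased", decoder_config=decoder_config
)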
tests/test_modeling_encoder_decoder.py

@@ -22,6 +22,7 @@ from transformers import is_torch_available
 # TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented
 # for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
 from .test_modeling_bert import BertModelTester
+from .test_modeling_common import ids_tensor
 from .utils import require_torch, slow, torch_device
@@ -331,3 +332,33 @@ class EncoderDecoderModelTest(unittest.TestCase):
     def test_real_bert_model_from_pretrained(self):
         model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
         self.assertIsNotNone(model)
+
+    @slow
+    def test_real_bert_model_from_pretrained_has_cross_attention(self):
+        model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
+        self.assertTrue(hasattr(model.decoder.bert.encoder.layer[0], "crossattention"))
+
+    @slow
+    def test_real_bert_model_save_load_from_pretrained(self):
+        model_2 = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
+        model_2.to(torch_device)
+        input_ids = ids_tensor([13, 5], model_2.config.encoder.vocab_size)
+        decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size)
+        attention_mask = ids_tensor([13, 5], vocab_size=2)
+        with torch.no_grad():
+            outputs = model_2(input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask,)
+        out_2 = outputs[0].cpu().numpy()
+        out_2[np.isnan(out_2)] = 0
+
+        with tempfile.TemporaryDirectory() as tmp_dirname:
+            model_2.save_pretrained(tmp_dirname)
+            model_1 = EncoderDecoderModel.from_pretrained(tmp_dirname)
+            model_1.to(torch_device)
+
+            after_outputs = model_1(
+                input_ids=input_ids, decoder_input_ids=decoder_input_ids, attention_mask=attention_mask,
+            )
+            out_1 = after_outputs[0].cpu().numpy()
+            out_1[np.isnan(out_1)] = 0
+            max_diff = np.amax(np.abs(out_1 - out_2))
+            self.assertLessEqual(max_diff, 1e-5)
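Outside the test harness, the save/load round trip exercised by the new test can be sketched at the user level as below; the checkpoint name and the toy token ids are assumptions, not part of the PR.

# Sketch (not part of the diff): a user-level save/load round trip.
# Checkpoint name and toy token ids are illustrative assumptions.
import tempfile

import torch
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
input_ids = torch.tensor([[101, 7592, 2088, 102]])  # toy encoder input ids
decoder_input_ids = torch.tensor([[101]])           # toy decoder start token

with torch.no_grad():
    before = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)[0]

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)                  # save the composite model
    reloaded = EncoderDecoderModel.from_pretrained(tmp_dir)

with torch.no_grad():
    after = reloaded(input_ids=input_ids, decoder_input_ids=decoder_input_ids)[0]

# With the fix, the reloaded model reproduces the original outputs.
assert torch.allclose(before, after, atol=1e-5)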