"git@developer.sourcefind.cn:wuxk1/megatron-lm.git" did not exist on "b4b0d739857f40bb42d7b4c16acd8bbcd92321ab"
Unverified Commit fa49b9af authored by Patrick von Platen, committed by GitHub

Clean Encoder-Decoder models with Bart/T5-like API and add generate possibility (#3383)

* change encoder decoder style to bart & t5 style

* make encoder decoder generation dummy work for bert

* make style

* clean init config in encoder decoder

* add tests for encoder decoder models

* refactor and add last tests

* refactor and add last tests

* fix attn masks for bert encoder decoder

* make style

* refactor prepare inputs for Bert

* refactor

* finish encoder decoder

* correct typo

* add docstring to config

* finish

* add tests

* better naming

* make style

* fix flake8

* clean docstring

* make style

* rename
parent 18058574
......@@ -89,6 +89,7 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train
:caption: Package Reference
model_doc/auto
model_doc/encoderdecoder
model_doc/bert
model_doc/gpt
model_doc/transformerxl
......
Encoder Decoder Models
----------------------
This class can wrap an encoder model, such as ``BertModel``, and a decoder model with a language modeling head, such as ``BertForMaskedLM``, into a single encoder-decoder model.
The ``EncoderDecoderModel`` class allows you to instantiate an encoder-decoder model via the ``from_encoder_decoder_pretrained`` class method, which takes a pretrained encoder and a pretrained decoder model as input.
The ``EncoderDecoderModel`` is saved with the standard ``save_pretrained()`` method and can be loaded again with the standard ``from_pretrained()`` method.
An application of this architecture is *summarization* with two pretrained Bert models, as shown in the paper `Text Summarization with Pretrained Encoders <https://arxiv.org/abs/1908.08345>`_ by Yang Liu and Mirella Lapata.
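A minimal usage sketch (the checkpoint names and the input sentence are illustrative)::

    from transformers import BertTokenizer, EncoderDecoderModel

    # build an encoder-decoder model from two pretrained BERT checkpoints
    model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt")

    # Bert has no bos token, so the decoder start token must be set explicitly
    generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)

    # saving and re-loading uses the standard methods
    model.save_pretrained("bert2bert")
    model = EncoderDecoderModel.from_pretrained("bert2bert")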
``EncoderDecoderConfig``
~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.EncoderDecoderConfig
:members:
``EncoderDecoderModel``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.EncoderDecoderModel
:members:
......@@ -41,6 +41,7 @@ from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, Ca
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_mmbt import MMBTConfig
......@@ -267,7 +268,7 @@ if is_torch_available():
CamembertForQuestionAnswering,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_encoder_decoder import PreTrainedEncoderDecoder
from .modeling_encoder_decoder import EncoderDecoderModel
from .modeling_t5 import (
T5PreTrainedModel,
T5Model,
......
......@@ -25,6 +25,7 @@ from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, Ca
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
......@@ -82,6 +83,7 @@ CONFIG_MAPPING = OrderedDict(
("xlm", XLMConfig,),
("ctrl", CTRLConfig,),
("electra", ElectraConfig,),
("encoder-decoder", EncoderDecoderConfig,),
]
)
......
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
class EncoderDecoderConfig(PretrainedConfig):
r"""
:class:`~transformers.EncoderDecoderConfig` is the configuration class to store the configuration of an `EncoderDecoderModel`.
It is used to instantiate an Encoder Decoder model according to the specified arguments, defining the encoder and decoder configs.
Configuration objects inherit from :class:`~transformers.PretrainedConfig`
and can be used to control the model outputs.
See the documentation for :class:`~transformers.PretrainedConfig` for more information.
Arguments:
kwargs: (`optional`) Remaining dictionary of keyword arguments. Notably:
encoder (:class:`PretrainedConfig`, optional, defaults to `None`):
An instance of a configuration object that defines the encoder config.
decoder (:class:`PretrainedConfig`, optional, defaults to `None`):
An instance of a configuration object that defines the decoder config.
"""
model_type = "encoder_decoder"
def __init__(self, **kwargs):
super().__init__(**kwargs)
assert (
"encoder" in kwargs and "decoder" in kwargs
), "Config has to be initialized with encoder and decoder config"
encoder_config = kwargs.pop("encoder")
encoder_model_type = encoder_config.pop("model_type")
decoder_config = kwargs.pop("decoder")
decoder_model_type = decoder_config.pop("model_type")
from transformers import AutoConfig
self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)
self.is_encoder_decoder = True
@classmethod
def from_encoder_decoder_configs(
cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig
) -> PretrainedConfig:
r"""
Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model configuration and decoder model configuration.
Returns:
:class:`EncoderDecoderConfig`: An instance of a configuration object
"""
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig`.
Returns:
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["encoder"] = self.encoder.to_dict()
output["decoder"] = self.decoder.to_dict()
output["model_type"] = self.__class__.model_type
return output
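For orientation, a short sketch of composing and serializing such a config (``BertConfig`` is used here purely as an example; any ``PretrainedConfig`` subclass with a ``model_type`` should work the same way):

from transformers import BertConfig, EncoderDecoderConfig

encoder_config = BertConfig()
decoder_config = BertConfig(is_decoder=True)

# compose a joint config from the two standalone configs
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)

# to_dict() nests both sub-configs and records the joint model type
config_dict = config.to_dict()
assert config_dict["model_type"] == "encoder_decoder"
assert config_dict["encoder"]["model_type"] == "bert"
assert config_dict["decoder"]["is_decoder"] is True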
......@@ -27,6 +27,7 @@ from .configuration_auto import (
CTRLConfig,
DistilBertConfig,
ElectraConfig,
EncoderDecoderConfig,
FlaubertConfig,
GPT2Config,
OpenAIGPTConfig,
......@@ -86,6 +87,7 @@ from .modeling_electra import (
ElectraForTokenClassification,
ElectraModel,
)
from .modeling_encoder_decoder import EncoderDecoderModel
from .modeling_flaubert import (
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
FlaubertForQuestionAnsweringSimple,
......@@ -219,6 +221,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
(XLMConfig, XLMWithLMHeadModel),
(CTRLConfig, CTRLLMHeadModel),
(ElectraConfig, ElectraForMaskedLM),
(EncoderDecoderConfig, EncoderDecoderModel),
]
)
......
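The effect of the new mapping entry, as a minimal sketch (the ``transformers.modeling_auto`` import path is an assumption about the internal module layout):

from transformers import EncoderDecoderConfig, EncoderDecoderModel
from transformers.modeling_auto import MODEL_WITH_LM_HEAD_MAPPING

# auto classes dispatch from config class to model class via this ordered mapping
assert MODEL_WITH_LM_HEAD_MAPPING[EncoderDecoderConfig] is EncoderDecoderModel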
......@@ -959,6 +959,28 @@ class BertForMaskedLM(BertPreTrainedModel):
return outputs # (ltr_lm_loss), (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
effective_batch_size = input_shape[0]
# if the model is used as a decoder in an encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# if the model does not use a causal mask, add a dummy token
if self.config.is_decoder is False:
assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
attention_mask = torch.cat(
[attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1
)
dummy_token = torch.full(
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
)
input_ids = torch.cat([input_ids, dummy_token], dim=1)
return {"input_ids": input_ids, "attention_mask": attention_mask}
@add_start_docstrings(
"""Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
......
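To make the dummy-token logic above concrete, here is a hedged sketch of what ``prepare_inputs_for_generation`` produces for a non-decoder Bert (the ``pad_token_id`` value and the token ids are assumptions for the example):

import torch
from transformers import BertConfig, BertForMaskedLM

config = BertConfig(pad_token_id=0, is_decoder=False)  # assumption: pad token id is 0
model = BertForMaskedLM(config)

input_ids = torch.tensor([[101, 7592]])  # two arbitrary token ids
inputs = model.prepare_inputs_for_generation(input_ids)

# a pad token is appended as the slot whose logits are used for the next prediction,
# and the attention mask receives a zero there so that position is excluded from attention
assert inputs["input_ids"].shape == (1, 3)
assert inputs["input_ids"][0, -1].item() == 0
assert inputs["attention_mask"][0, -1].item() == 0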
......@@ -1011,7 +1011,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
pad_token_id = eos_token_id
# current position and vocab size
if hasattr(self.config, "vocab_size"):
vocab_size = self.config.vocab_size
elif (
self.config.is_encoder_decoder
and hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "vocab_size")
):
vocab_size = self.config.decoder.vocab_size
# set effective batch size and effective batch multiplier according to do_sample
if do_sample:
......
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Classes to support Encoder-Decoder architectures """
def prepare_encoder_decoder_model_kwargs(**kwargs):
""" Prepare the encoder and decoder's keyword arguments.
Keyword arguments come in 3 flavors:
- encoder-specific (prefixed by `encoder_`)
- decoder-specific (prefixed by `decoder_`)
- those that apply to the model as a whole.
We let the specific kwargs override the common ones in case of
conflict.
"""
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
}
if "input_ids" in kwargs_common:
kwargs["encoder_input_ids"] = kwargs_common.pop("input_ids")
decoder_kwargs = kwargs_common.copy()
encoder_kwargs = kwargs_common.copy()
encoder_kwargs.update(
{argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")}
)
decoder_kwargs.update(
{argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")}
)
decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None)
return encoder_kwargs, decoder_kwargs
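An illustrative sketch of the splitting behaviour (placeholder string values stand in for tensors):

encoder_kwargs, decoder_kwargs = prepare_encoder_decoder_model_kwargs(
    input_ids="enc_ids",
    attention_mask="shared_mask",
    decoder_input_ids="dec_ids",
    decoder_attention_mask="dec_mask",
)

# `input_ids` is routed to the encoder only; prefixed kwargs override common ones,
# and the encoder attention mask is forwarded to the decoder for cross-attention
assert encoder_kwargs == {"input_ids": "enc_ids", "attention_mask": "shared_mask"}
assert decoder_kwargs == {
    "input_ids": "dec_ids",
    "attention_mask": "dec_mask",
    "encoder_attention_mask": "shared_mask",
}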
# coding=utf-8
# Copyright 2020 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from transformers import is_torch_available
# TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented
# for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
from .test_modeling_bert import BertModelTest
from .utils import require_torch, slow, torch_device
if is_torch_available():
from transformers import BertModel, BertForMaskedLM, EncoderDecoderModel
import numpy as np
import torch
@require_torch
class EncoderDecoderModelTest(unittest.TestCase):
def prepare_config_and_inputs_bert(self):
bert_model_tester = BertModelTest.BertModelTester(self)
encoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs()
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_token_type_ids,
decoder_input_mask,
decoder_sequence_labels,
decoder_token_labels,
decoder_choice_labels,
encoder_hidden_states,
encoder_attention_mask,
) = decoder_config_and_inputs
return {
"config": config,
"input_ids": input_ids,
"attention_mask": input_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_token_type_ids": decoder_token_type_ids,
"decoder_attention_mask": decoder_input_mask,
"decoder_sequence_labels": decoder_sequence_labels,
"decoder_token_labels": decoder_token_labels,
"decoder_choice_labels": decoder_choice_labels,
"encoder_hidden_states": encoder_hidden_states,
"lm_labels": decoder_token_labels,
"masked_lm_labels": decoder_token_labels,
}
def create_and_check_bert_encoder_decoder_model(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,)))
encoder_outputs = (encoder_hidden_states,)
outputs_encoder_decoder = enc_dec_model(
encoder_outputs=encoder_outputs,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,)))
def create_and_check_bert_encoder_decoder_model_from_pretrained(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,)))
def create_and_check_save_and_load(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
enc_dec_model.eval()
with torch.no_grad():
outputs = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
out_2 = outputs[0].cpu().numpy()
out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as tmpdirname:
enc_dec_model.save_pretrained(tmpdirname)
enc_dec_model = EncoderDecoderModel.from_pretrained(tmpdirname)
enc_dec_model.to(torch_device)
after_outputs = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
out_1 = after_outputs[0].cpu().numpy()
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def create_and_check_save_and_load_encoder_decoder_model(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
enc_dec_model.eval()
with torch.no_grad():
outputs = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
out_2 = outputs[0].cpu().numpy()
out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
enc_dec_model.encoder.save_pretrained(encoder_tmp_dirname)
enc_dec_model.decoder.save_pretrained(decoder_tmp_dirname)
enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
encoder_pretrained_model_name_or_path=encoder_tmp_dirname,
decoder_pretrained_model_name_or_path=decoder_tmp_dirname,
)
enc_dec_model.to(torch_device)
after_outputs = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
out_1 = after_outputs[0].cpu().numpy()
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def check_loss_output(self, loss):
self.assertEqual(loss.size(), ())
def create_and_check_bert_encoder_decoder_model_mlm_labels(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
masked_lm_labels,
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
masked_lm_labels=masked_lm_labels,
)
mlm_loss = outputs_encoder_decoder[0]
self.check_loss_output(mlm_loss)
# check that backprop works
mlm_loss.backward()
self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))
def create_and_check_bert_encoder_decoder_model_lm_labels(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
lm_labels,
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
lm_labels=lm_labels,
)
lm_loss = outputs_encoder_decoder[0]
self.check_loss_output(lm_loss)
# check that backprop works
lm_loss.backward()
self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))
def create_and_check_bert_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
)
self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
def test_bert_encoder_decoder_model(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model(**input_ids_dict)
def test_bert_encoder_decoder_model_from_pretrained(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_from_pretrained(**input_ids_dict)
def test_save_and_load_from_pretrained(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_save_and_load(**input_ids_dict)
def test_save_and_load_from_encoder_decoder_pretrained(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_save_and_load_encoder_decoder_model(**input_ids_dict)
def test_bert_encoder_decoder_model_mlm_labels(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_mlm_labels(**input_ids_dict)
def test_bert_encoder_decoder_model_lm_labels(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_lm_labels(**input_ids_dict)
def test_bert_encoder_decoder_model_generate(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_generate(**input_ids_dict)
@slow
def test_real_bert_model_from_pretrained(self):
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
self.assertIsNotNone(model)