Unverified Commit fa49b9af authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Clean Encoder-Decoder models with Bart/T5-like API and add generate possibility (#3383)

* change encoder decoder style to bart & t5 style

* make encoder decoder generation dummy work for bert

* make style

* clean init config in encoder decoder

* add tests for encoder decoder models

* refactor and add last tests

* refactor and add last tests

* fix attn masks for bert encoder decoder

* make style

* refactor prepare inputs for Bert

* refactor

* finish encoder decoder

* correct typo

* add docstring to config

* finish

* add tests

* better naming

* make style

* fix flake8

* clean docstring

* make style

* rename
parent 18058574
......@@ -89,6 +89,7 @@ The library currently contains PyTorch and Tensorflow implementations, pre-train
:caption: Package Reference
model_doc/auto
model_doc/encoderdecoder
model_doc/bert
model_doc/gpt
model_doc/transformerxl
......
Encoder Decoder Models
-----------
This class can wrap an encoder model, such as ``BertModel`` and a decoder modeling with a language modeling head, such as ``BertForMaskedLM`` into a encoder-decoder model.
The ``EncoderDecoderModel`` class allows to instantiate a encoder decoder model using the ``from_encoder_decoder_pretrain`` class method taking a pretrained encoder and pretrained decoder model as an input.
The ``EncoderDecoderModel`` is saved using the standard ``save_pretrained()`` method and can also again be loaded using the standard ``from_pretrained()`` method.
An application of this architecture could be *summarization* using two pretrained Bert models as is shown in the paper: `Text Summarization with Pretrained Encoders <https://arxiv.org/abs/1910.13461>`_ by Yang Liu and Mirella Lapata.
``EncoderDecoderConfig``
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.EncoderDecoderConfig
:members:
``EncoderDecoderModel``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.EncoderDecoderModel
:members:
......@@ -41,6 +41,7 @@ from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, Ca
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_mmbt import MMBTConfig
......@@ -267,7 +268,7 @@ if is_torch_available():
CamembertForQuestionAnswering,
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
)
from .modeling_encoder_decoder import PreTrainedEncoderDecoder
from .modeling_encoder_decoder import EncoderDecoderModel
from .modeling_t5 import (
T5PreTrainedModel,
T5Model,
......
......@@ -25,6 +25,7 @@ from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, Ca
from .configuration_ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig
from .configuration_distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig
from .configuration_electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig
from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
from .configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
......@@ -82,6 +83,7 @@ CONFIG_MAPPING = OrderedDict(
("xlm", XLMConfig,),
("ctrl", CTRLConfig,),
("electra", ElectraConfig,),
("encoder-decoder", EncoderDecoderConfig,),
]
)
......
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from .configuration_utils import PretrainedConfig
logger = logging.getLogger(__name__)
class EncoderDecoderConfig(PretrainedConfig):
r"""
:class:`~transformers.EncoderDecoderConfig` is the configuration class to store the configuration of a `EncoderDecoderModel`.
It is used to instantiate an Encoder Decoder model according to the specified arguments, defining the encoder and decoder configs.
Configuration objects inherit from :class:`~transformers.PretrainedConfig`
and can be used to control the model outputs.
See the documentation for :class:`~transformers.PretrainedConfig` for more information.
Arguments:
kwargs: (`optional`) Remaining dictionary of keyword arguments. Notably:
encoder (:class:`PretrainedConfig`, optional, defaults to `None`):
An instance of a configuration object that defines the encoder config.
encoder (:class:`PretrainedConfig`, optional, defaults to `None`):
An instance of a configuration object that defines the decoder config.
"""
model_type = "encoder_decoder"
def __init__(self, **kwargs):
super().__init__(**kwargs)
assert (
"encoder" in kwargs and "decoder" in kwargs
), "Config has to be initialized with encoder and decoder config"
encoder_config = kwargs.pop("encoder")
encoder_model_type = encoder_config.pop("model_type")
decoder_config = kwargs.pop("decoder")
decoder_model_type = decoder_config.pop("model_type")
from transformers import AutoConfig
self.encoder = AutoConfig.for_model(encoder_model_type, **encoder_config)
self.decoder = AutoConfig.for_model(decoder_model_type, **decoder_config)
self.is_encoder_decoder = True
@classmethod
def from_encoder_decoder_configs(
cls, encoder_config: PretrainedConfig, decoder_config: PretrainedConfig
) -> PretrainedConfig:
r"""
Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model configuration and decoder model configuration.
Returns:
:class:`EncoderDecoderConfig`: An instance of a configuration object
"""
return cls(encoder=encoder_config.to_dict(), decoder=decoder_config.to_dict())
def to_dict(self):
"""
Serializes this instance to a Python dictionary. Override the default `to_dict()` from `PretrainedConfig`.
Returns:
:obj:`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
output["encoder"] = self.encoder.to_dict()
output["decoder"] = self.decoder.to_dict()
output["model_type"] = self.__class__.model_type
return output
......@@ -27,6 +27,7 @@ from .configuration_auto import (
CTRLConfig,
DistilBertConfig,
ElectraConfig,
EncoderDecoderConfig,
FlaubertConfig,
GPT2Config,
OpenAIGPTConfig,
......@@ -86,6 +87,7 @@ from .modeling_electra import (
ElectraForTokenClassification,
ElectraModel,
)
from .modeling_encoder_decoder import EncoderDecoderModel
from .modeling_flaubert import (
FLAUBERT_PRETRAINED_MODEL_ARCHIVE_MAP,
FlaubertForQuestionAnsweringSimple,
......@@ -219,6 +221,7 @@ MODEL_WITH_LM_HEAD_MAPPING = OrderedDict(
(XLMConfig, XLMWithLMHeadModel),
(CTRLConfig, CTRLLMHeadModel),
(ElectraConfig, ElectraForMaskedLM),
(EncoderDecoderConfig, EncoderDecoderModel),
]
)
......
......@@ -959,6 +959,28 @@ class BertForMaskedLM(BertPreTrainedModel):
return outputs # (ltr_lm_loss), (masked_lm_loss), prediction_scores, (hidden_states), (attentions)
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
effective_batch_size = input_shape[0]
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# if model is does not use a causal mask then add a dummy token
if self.config.is_decoder is False:
assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
attention_mask = torch.cat(
[attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1
)
dummy_token = torch.full(
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
)
input_ids = torch.cat([input_ids, dummy_token], dim=1)
return {"input_ids": input_ids, "attention_mask": attention_mask}
@add_start_docstrings(
"""Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
......
......@@ -16,53 +16,101 @@
import logging
import os
from typing import Optional
from torch import nn
from .modeling_auto import AutoModel, AutoModelWithLMHead
from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_utils import PretrainedConfig
from .modeling_utils import PreTrainedModel
logger = logging.getLogger(__name__)
class PreTrainedEncoderDecoder(nn.Module):
class EncoderDecoderModel(PreTrainedModel):
r"""
:class:`~transformers.PreTrainedEncoderDecoder` is a generic model class that will be
:class:`~transformers.EncoderDecoder` is a generic model class that will be
instantiated as a transformer architecture with one of the base model
classes of the library as encoder and (optionally) another one as
classes of the library as encoder and another one as
decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
class method.
class method for the encoder and `AutoModelWithLMHead.from_pretrained(pretrained_model_name_or_path)` class method for the decoder.
"""
config_class = EncoderDecoderConfig
def __init__(
self,
config: Optional[PretrainedConfig] = None,
encoder: Optional[PreTrainedModel] = None,
decoder: Optional[PreTrainedModel] = None,
):
assert config is not None or (
encoder is not None and decoder is not None
), "Either a configuration or an Encoder and a decoder has to be provided"
if config is None:
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
else:
assert isinstance(config, self.config_class), "config: {} has to be of type {}".format(
config, self.config_class
)
# initialize with config
super().__init__(config)
if encoder is None:
from transformers import AutoModel
encoder = AutoModel.from_config(config.encoder)
if decoder is None:
from transformers import AutoModelWithLMHead
decoder = AutoModelWithLMHead.from_config(config.decoder)
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
assert (
self.encoder.get_output_embeddings() is None
), "The encoder {} should not have a LM Head. Please use a model without LM Head"
def tie_weights(self):
# for now no weights tying in encoder-decoder
pass
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
def get_input_embeddings(self):
return self.encoder.get_input_embeddings()
def get_output_embeddings(self):
return self.decoder.get_output_embeddings()
@classmethod
def from_pretrained(
def from_encoder_decoder_pretrained(
cls,
encoder_pretrained_model_name_or_path=None,
decoder_pretrained_model_name_or_path=None,
encoder_pretrained_model_name_or_path: str = None,
decoder_pretrained_model_name_or_path: str = None,
*model_args,
**kwargs
):
) -> PreTrainedModel:
r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you need to first set it back in training mode with `model.train()`
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
To train the model, you need to first set it back in training mode with `model.train()`.
Params:
encoder_pretrained_model_name_or_path: information necessary to initiate the encoder. Either:
encoder_pretrained_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
information necessary to initiate the encoder. Either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/encoder``.
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
decoder_pretrained_model_name_or_path: information necessary to initiate the decoder. Either:
decoder_pretrained_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
information necessary to initiate the decoder. Either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
......@@ -72,165 +120,169 @@ class PreTrainedEncoderDecoder(nn.Module):
model_args: (`optional`) Sequence of positional arguments:
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
- the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
- the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
state_dict: (`optional`) dict:
an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file.
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
cache_dir: (`optional`) string:
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
force_download: (`optional`) boolean, default False:
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
output_loading_info: (`optional`) boolean:
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
kwargs: (`optional`) Remaining dictionary of keyword arguments.
Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
You can specify kwargs sepcific for the encoder and decoder by prefixing the key with `encoder_` and `decoder_` respectively. (e.g. ``decoder_output_attention=True``). The remaining kwargs will be passed to both encoders and decoders.
Examples::
# For example purposes. Not runnable.
model = PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
model = EncoderDecoder.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
"""
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as a whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
}
kwargs_decoder = kwargs_common.copy()
kwargs_encoder = kwargs_common.copy()
kwargs_encoder.update(
{
argument[len("encoder_") :]: value
for argument, value in kwargs.items()
if argument.startswith("encoder_")
kwargs_encoder = {
argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
}
)
kwargs_decoder.update(
{
argument[len("decoder_") :]: value
for argument, value in kwargs.items()
if argument.startswith("decoder_")
kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
)
# Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made
# by the value of the flag `is_decoder` that we need to set correctly.
encoder = kwargs_encoder.pop("model", None)
if encoder is None:
assert (
encoder_pretrained_model_name_or_path is not None
), "If `model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has to be defined"
from .modeling_auto import AutoModel
encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
encoder.config.is_decoder = False
decoder = kwargs_decoder.pop("model", None)
if decoder is None:
assert (
decoder_pretrained_model_name_or_path is not None
), "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has to be defined"
from .modeling_auto import AutoModelWithLMHead
decoder = AutoModelWithLMHead.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
decoder.config.is_decoder = True
model = cls(encoder, decoder)
model = cls(encoder=encoder, decoder=decoder)
return model
def save_pretrained(self, save_directory):
""" Save a Seq2Seq model and its configuration file in a format such
that it can be loaded using `:func:`~transformers.PreTrainedEncoderDecoder.from_pretrained`
def forward(
self,
input_ids=None,
inputs_embeds=None,
attention_mask=None,
head_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_head_mask=None,
decoder_inputs_embeds=None,
masked_lm_labels=None,
lm_labels=None,
**kwargs,
):
We save the encoder' and decoder's parameters in two separate directories.
"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary for the encoder.
Indices can be obtained using :class:`transformers.PretrainedTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices for the encoder.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
head_mask: (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
Mask to nullify selected heads of the self-attention modules for the encoder.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
Provide for sequence to sequence training to the decoder.
Indices can be obtained using :class:`transformers.PretrainedTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
decoder_head_mask: (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
Mask to nullify selected heads of the self-attention modules for the decoder.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
masked_lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss for the decoder.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the left-to-right language modeling loss (next word prediction) for the decoder.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
- Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
- With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function.
"""
# If the root output directory does not exist, create it
if not os.path.exists(save_directory):
os.mkdir(save_directory)
# Check whether the output directory is empty or not
sub_directories = [
directory
for directory in os.listdir(save_directory)
if os.path.isdir(os.path.join(save_directory, directory))
]
if len(sub_directories) > 0:
if "encoder" in sub_directories and "decoder" in sub_directories:
print(
"WARNING: there is an older version of encoder-decoder saved in"
+ " the output directory. The default behaviour is to overwrite them."
)
# Empty the output directory
for directory_to_remove in sub_directories:
# Remove all files into the subdirectory
files_to_remove = os.listdir(os.path.join(save_directory, directory_to_remove))
for file_to_remove in files_to_remove:
os.remove(os.path.join(save_directory, directory_to_remove, file_to_remove))
# Remove the subdirectory itself
os.rmdir(os.path.join(save_directory, directory_to_remove))
assert len(os.listdir(save_directory)) == 0 # sanity check
# Create the "encoder" directory inside the output directory and save the encoder into it
if not os.path.exists(os.path.join(save_directory, "encoder")):
os.mkdir(os.path.join(save_directory, "encoder"))
self.encoder.save_pretrained(os.path.join(save_directory, "encoder"))
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
# Create the "encoder" directory inside the output directory and save the decoder into it
if not os.path.exists(os.path.join(save_directory, "decoder")):
os.mkdir(os.path.join(save_directory, "decoder"))
self.decoder.save_pretrained(os.path.join(save_directory, "decoder"))
kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
def forward(self, encoder_input_ids, decoder_input_ids, **kwargs):
""" The forward pass on a seq2eq depends what we are performing:
if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
**kwargs_encoder,
)
- During training we perform one forward pass through both the encoder
and decoder;
- During prediction, we perform one forward pass through the encoder,
and then perform several forward passes with the encoder's hidden
state through the decoder to decode a full sequence.
hidden_states = encoder_outputs[0]
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
inputs_embeds=decoder_inputs_embeds,
attention_mask=decoder_attention_mask,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=decoder_head_mask,
lm_labels=lm_labels,
masked_lm_labels=masked_lm_labels,
**kwargs_decoder,
)
Therefore, we skip the forward pass on the encoder if an argument named
`encoder_hidden_state` is passed to this function.
return decoder_outputs + encoder_outputs
Params:
encoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
Indices of encoder input sequence tokens in the vocabulary.
decoder_input_ids: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``
Indices of decoder input sequence tokens in the vocabulary.
kwargs: (`optional`) Remaining dictionary of keyword arguments.
"""
kwargs_encoder, kwargs_decoder = self.prepare_model_kwargs(**kwargs)
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"
# Encode if needed (training, first prediction pass)
encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
if encoder_hidden_states is None:
encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
encoder_hidden_states = encoder_outputs[0]
# first step
if type(past) is tuple:
encoder_outputs = past
else:
encoder_outputs = ()
encoder_outputs = (past,)
kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
return decoder_outputs + encoder_outputs
return {
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_inputs["attention_mask"],
"decoder_input_ids": decoder_inputs["input_ids"],
"encoder_outputs": encoder_outputs,
}
def _reorder_cache(self, past, beam_idx):
# as a default encoder-decoder models do not re-order the past.
# TODO(PVP): might have to be updated, e.g. if GPT2 is to be used as a decoder
return past
......@@ -1011,7 +1011,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
pad_token_id = eos_token_id
# current position and vocab size
if hasattr(self.config, "vocab_size"):
vocab_size = self.config.vocab_size
elif (
self.config.is_encoder_decoder
and hasattr(self.config, "decoder")
and hasattr(self.config.decoder, "vocab_size")
):
vocab_size = self.config.decoder.vocab_size
# set effective batch size and effective batch multiplier according to do_sample
if do_sample:
......
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Classes to support Encoder-Decoder architectures """
def prepare_encoder_decoder_model_kwargs(**kwargs):
""" Prepare the encoder and decoder's keyword arguments.
Keyword arguments come in 3 flavors:
- encoder-specific (prefixed by `encoder_`)
- decoder-specific (prefixed by `decoder_`)
- those that apply to the model as whole.
We let the specific kwargs override the common ones in case of
conflict.
"""
kwargs_common = {
argument: value
for argument, value in kwargs.items()
if not argument.startswith("encoder_") and not argument.startswith("decoder_")
}
if "input_ids" in kwargs_common:
kwargs["encoder_input_ids"] = kwargs_common.pop("input_ids")
decoder_kwargs = kwargs_common.copy()
encoder_kwargs = kwargs_common.copy()
encoder_kwargs.update(
{argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")}
)
decoder_kwargs.update(
{argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")}
)
decoder_kwargs["encoder_attention_mask"] = encoder_kwargs.get("attention_mask", None)
return encoder_kwargs, decoder_kwargs
# coding=utf-8
# Copyright 2020 HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from transformers import is_torch_available
# TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented
# for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
from .test_modeling_bert import BertModelTest
from .utils import require_torch, slow, torch_device
if is_torch_available():
from transformers import BertModel, BertForMaskedLM, EncoderDecoderModel
import numpy as np
import torch
@require_torch
class EncoderDecoderModelTest(unittest.TestCase):
def prepare_config_and_inputs_bert(self):
bert_model_tester = BertModelTest.BertModelTester(self)
encoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs()
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_token_type_ids,
decoder_input_mask,
decoder_sequence_labels,
decoder_token_labels,
decoder_choice_labels,
encoder_hidden_states,
encoder_attention_mask,
) = decoder_config_and_inputs
return {
"config": config,
"input_ids": input_ids,
"attention_mask": input_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_token_type_ids": decoder_token_type_ids,
"decoder_attention_mask": decoder_input_mask,
"decoder_sequence_labels": decoder_sequence_labels,
"decoder_token_labels": decoder_token_labels,
"decoder_choice_labels": decoder_choice_labels,
"encoder_hidden_states": encoder_hidden_states,
"lm_labels": decoder_token_labels,
"masked_lm_labels": decoder_token_labels,
}
def create_and_check_bert_encoder_decoder_model(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,)))
encoder_outputs = (encoder_hidden_states,)
outputs_encoder_decoder = enc_dec_model(
encoder_outputs=encoder_outputs,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,)))
def create_and_check_bert_encoder_decoder_model_from_pretrained(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,)))
def create_and_check_save_and_load(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
enc_dec_model.eval()
with torch.no_grad():
outputs = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
out_2 = outputs[0].cpu().numpy()
out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as tmpdirname:
enc_dec_model.save_pretrained(tmpdirname)
EncoderDecoderModel.from_pretrained(tmpdirname)
after_outputs = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
out_1 = after_outputs[0].cpu().numpy()
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def create_and_check_save_and_load_encoder_decoder_model(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
enc_dec_model.eval()
with torch.no_grad():
outputs = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
out_2 = outputs[0].cpu().numpy()
out_2[np.isnan(out_2)] = 0
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
enc_dec_model.encoder.save_pretrained(encoder_tmp_dirname)
enc_dec_model.decoder.save_pretrained(decoder_tmp_dirname)
EncoderDecoderModel.from_encoder_decoder_pretrained(
encoder_pretrained_model_name_or_path=encoder_tmp_dirname,
decoder_pretrained_model_name_or_path=decoder_tmp_dirname,
)
after_outputs = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
out_1 = after_outputs[0].cpu().numpy()
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def check_loss_output(self, loss):
self.assertEqual(loss.size(), ())
def create_and_check_bert_encoder_decoder_model_mlm_labels(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
masked_lm_labels,
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
masked_lm_labels=masked_lm_labels,
)
mlm_loss = outputs_encoder_decoder[0]
self.check_loss_output(mlm_loss)
# check that backprop works
mlm_loss.backward()
self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))
def create_and_check_bert_encoder_decoder_model_lm_labels(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
lm_labels,
**kwargs
):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
lm_labels=lm_labels,
)
lm_loss = outputs_encoder_decoder[0]
self.check_loss_output(lm_loss)
# check that backprop works
lm_loss.backward()
self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))
def create_and_check_bert_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
encoder_model = BertModel(config)
decoder_model = BertForMaskedLM(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
# Bert does not have a bos token id, so use pad_token_id instead
generated_output = enc_dec_model.generate(
input_ids, decoder_start_token_id=enc_dec_model.config.decoder.pad_token_id
)
self.assertEqual(generated_output.shape, (input_ids.shape[0],) + (decoder_config.max_length,))
def test_bert_encoder_decoder_model(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model(**input_ids_dict)
def test_bert_encoder_decoder_model_from_pretrained(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_from_pretrained(**input_ids_dict)
def test_save_and_load_from_pretrained(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_save_and_load(**input_ids_dict)
def test_save_and_load_from_encoder_decoder_pretrained(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_save_and_load_encoder_decoder_model(**input_ids_dict)
def test_bert_encoder_decoder_model_mlm_labels(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_mlm_labels(**input_ids_dict)
def test_bert_encoder_decoder_model_lm_labels(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_lm_labels(**input_ids_dict)
def test_bert_encoder_decoder_model_generate(self):
input_ids_dict = self.prepare_config_and_inputs_bert()
self.create_and_check_bert_encoder_decoder_model_generate(**input_ids_dict)
@slow
def test_real_bert_model_from_pretrained(self):
model = EncoderDecoderModel.from_pretrained("bert-base-uncased", "bert-base-uncased")
self.assertIsNotNone(model)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment