Commit 3b0d2fa3 authored by Rémi Louf

rename seq2seq to encoder_decoder

parent 9c1bdb5b
@@ -10,7 +10,7 @@ similar API between the different models.
| [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. |
| [SQuAD](#squad) | Using BERT for question answering, examples with distributed training. |
| [Multiple Choice](#multiple-choice) | Examples running BERT/XLNet/RoBERTa on the SWAG/RACE/ARC tasks. |
| [Seq2seq Model fine-tuning](#seq2seq-model-fine-tuning) | Fine-tuning the library models for seq2seq tasks on the CNN/Daily Mail dataset. |
| [Abstractive summarization](#abstractive-summarization) | Fine-tuning the library models for abstractive summarization tasks on the CNN/Daily Mail dataset. |
## Language model fine-tuning
@@ -391,7 +391,7 @@ exact_match = 86.91
This fine-tuned model is available as a checkpoint under the reference
`bert-large-uncased-whole-word-masking-finetuned-squad`.
## Seq2seq model fine-tuning
## Abstractive summarization
Based on the script
[`run_summarization_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_summarization_finetuning.py).
@@ -408,8 +408,6 @@ note that the finetuning script **will not work** if you do not download both
datasets. We will refer to the path where you uncompressed both archives as
`$DATA_PATH`.
## Bert2Bert and abstractive summarization
```bash
export DATA_PATH=/path/to/dataset/
......
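The README section above introduces Bert2Bert fine-tuning for abstractive summarization. As a complement, here is a minimal sketch of what building such a model looks like with the classes renamed in this commit; the `Model2Model` usage mirrors the docstrings further down, and the `bert-base-uncased` checkpoint and the example texts are illustrations, not something fixed by the diff.

```python
import torch

from transformers import BertTokenizer, Model2Model

# A minimal Bert2Bert sketch: the same pretrained BERT checkpoint is used
# to initialize both the encoder and the decoder of the encoder-decoder model.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = Model2Model.from_pretrained("bert-base-uncased")

article = "The quick brown fox jumps over the lazy dog."
summary = "A fox jumps over a dog."

encoder_input_ids = torch.tensor([tokenizer.encode(article)])
decoder_input_ids = torch.tensor([tokenizer.encode(summary)])

# The forward pass takes the encoder and decoder inputs separately
# (a sketch; see the PreTrainedEncoderDecoder hunks below for the class API).
outputs = model(encoder_input_ids, decoder_input_ids)
```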
@@ -32,7 +32,7 @@ from transformers import (
AutoTokenizer,
BertForMaskedLM,
BertConfig,
PreTrainedSeq2seq,
PreTrainedEncoderDecoder,
Model2Model,
)
@@ -475,7 +475,7 @@ def main():
for checkpoint in checkpoints:
encoder_checkpoint = os.path.join(checkpoint, "encoder")
decoder_checkpoint = os.path.join(checkpoint, "decoder")
model = PreTrainedSeq2seq.from_pretrained(
model = PreTrainedEncoderDecoder.from_pretrained(
encoder_checkpoint, decoder_checkpoint
)
model.to(args.device)
......
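For context, the evaluation loop in the hunk above loads the fine-tuned encoder and decoder from two sub-directories of each checkpoint. A self-contained sketch of the same pattern; the checkpoint path is hypothetical.

```python
import os

import torch
from transformers import PreTrainedEncoderDecoder

# Hypothetical checkpoint directory written by the fine-tuning script;
# the encoder and decoder live in two sub-directories, as in the loop above.
checkpoint = "./output/checkpoint-1000"
encoder_checkpoint = os.path.join(checkpoint, "encoder")
decoder_checkpoint = os.path.join(checkpoint, "decoder")

model = PreTrainedEncoderDecoder.from_pretrained(encoder_checkpoint, decoder_checkpoint)
model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
```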
@@ -87,7 +87,7 @@ if is_torch_available():
from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_seq2seq import PreTrainedSeq2seq, Model2Model
from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
# Optimization
from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
......
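After this commit the classes are importable from the top level of the package under their new names, as exposed in the `__init__.py` hunk above; a one-line sketch (the old `PreTrainedSeq2seq` name no longer exists).

```python
# Post-rename import path, re-exported from transformers/__init__.py.
from transformers import PreTrainedEncoderDecoder, Model2Model
```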
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Auto Model class. """
""" Classes to support Encoder-Decoder architectures """
from __future__ import absolute_import, division, print_function, unicode_literals
@@ -27,9 +27,9 @@ from .modeling_auto import AutoModel, AutoModelWithLMHead
logger = logging.getLogger(__name__)
class PreTrainedSeq2seq(nn.Module):
class PreTrainedEncoderDecoder(nn.Module):
r"""
:class:`~transformers.PreTrainedSeq2seq` is a generic model class that will be
:class:`~transformers.PreTrainedEncoderDecoder` is a generic model class that will be
instantiated as a transformer architecture with one of the base model
classes of the library as encoder and (optionally) another one as
decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
@@ -37,7 +37,7 @@ class PreTrainedSeq2seq(nn.Module):
"""
def __init__(self, encoder, decoder):
super(PreTrainedSeq2seq, self).__init__()
super(PreTrainedEncoderDecoder, self).__init__()
self.encoder = encoder
self.decoder = decoder
@@ -107,7 +107,7 @@ class PreTrainedSeq2seq(nn.Module):
Examples::
model = PreTrainedSeq2seq.from_pretained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
model = PreTrainedEncoderDecoder.from_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
"""
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
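The comment above is cut off by the hunk boundary; the remaining keyword arguments are split into encoder-specific, decoder-specific, and shared ones. A hedged sketch of how such a call might look, assuming the `encoder_`/`decoder_` prefixes; the `output_attentions` flag is illustrative only, not taken from the diff.

```python
from transformers import PreTrainedEncoderDecoder

# Keyword arguments are routed by prefix: `encoder_*` reaches only the encoder,
# `decoder_*` only the decoder, and un-prefixed arguments reach both.
# (Sketch under that assumption.)
model = PreTrainedEncoderDecoder.from_pretrained(
    "bert-base-uncased",
    "bert-base-uncased",
    decoder_output_attentions=True,  # example of a decoder-only argument
)
```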
@@ -155,7 +155,7 @@ class PreTrainedSeq2seq(nn.Module):
def save_pretrained(self, save_directory):
""" Save a Seq2Seq model and its configuration file in a format such
that it can be loaded using `:func:`~transformers.PreTrainedSeq2seq.from_pretrained`
that it can be loaded using :func:`~transformers.PreTrainedEncoderDecoder.from_pretrained`
We save the encoder's and decoder's parameters in two separate directories.
"""
@@ -219,7 +219,7 @@ class PreTrainedSeq2seq(nn.Module):
return decoder_outputs + encoder_outputs
class Model2Model(PreTrainedSeq2seq):
class Model2Model(PreTrainedEncoderDecoder):
r"""
:class:`~transformers.Model2Model` instantiates a Seq2Seq model
where both the encoder and the decoder are of the same family. If the
@@ -277,14 +277,14 @@ class Model2Model(PreTrainedSeq2seq):
return model
class Model2LSTM(PreTrainedSeq2seq):
class Model2LSTM(PreTrainedEncoderDecoder):
@classmethod
def from_pretrained(cls, *args, **kwargs):
if kwargs.get("decoder_model", None) is None:
# We will create a randomly initialized LSTM model as the decoder
if "decoder_config" not in kwargs:
raise ValueError(
"To load an LSTM in Seq2seq model, please supply either: "
"To load an LSTM in Encoder-Decoder model, please supply either: "
" - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or"
" - a dictionary of configuration parameters that will be used to initialize a"
" torch.nn.LSTM model as `decoder_config` keyword argument. "
......
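The error message above lists the two ways to provide `Model2LSTM` with a decoder. A hedged sketch of both routes follows; the LSTM sizes (BERT's 768 hidden units) and the idea that `decoder_config` maps onto `torch.nn.LSTM` keyword arguments are assumptions, not taken from the diff, and `Model2LSTM` is imported from the renamed module directly since it is not re-exported in the `__init__.py` hunk above.

```python
import torch

# Model2LSTM is not re-exported at the package top level in the hunk above,
# hence the module-level import from the renamed file.
from transformers.modeling_encoder_decoder import Model2LSTM

# Option 1: supply a ready-made LSTM as the decoder (sizes are illustrative).
lstm_decoder = torch.nn.LSTM(input_size=768, hidden_size=768, num_layers=2, batch_first=True)
model = Model2LSTM.from_pretrained("bert-base-uncased", decoder_model=lstm_decoder)

# Option 2 (per the error message): a configuration dict for the LSTM instead,
# assuming it is unpacked into torch.nn.LSTM's constructor.
# model = Model2LSTM.from_pretrained(
#     "bert-base-uncased",
#     decoder_config={"input_size": 768, "hidden_size": 768, "num_layers": 2},
# )
```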