Commit 3b0d2fa3 authored by Rémi Louf

rename seq2seq to encoder_decoder

parent 9c1bdb5b
@@ -10,7 +10,7 @@ similar API between the different models.
| [GLUE](#glue) | Examples running BERT/XLM/XLNet/RoBERTa on the 9 GLUE tasks. Examples feature distributed training as well as half-precision. |
| [SQuAD](#squad) | Using BERT for question answering, examples with distributed training. |
| [Multiple Choice](#multiple choice) | Examples running BERT/XLNet/RoBERTa on the SWAG/RACE/ARC tasks. |
-| [Seq2seq Model fine-tuning](#seq2seq-model-fine-tuning) | Fine-tuning the library models for seq2seq tasks on the CNN/Daily Mail dataset. |
+| [Abstractive summarization](#abstractive-summarization) | Fine-tuning the library models for abstractive summarization tasks on the CNN/Daily Mail dataset. |
## Language model fine-tuning
@@ -391,7 +391,7 @@ exact_match = 86.91
This fine-tuned model is available as a checkpoint under the reference
`bert-large-uncased-whole-word-masking-finetuned-squad`.
-## Seq2seq model fine-tuning
+## Abstractive summarization
Based on the script
[`run_summarization_finetuning.py`](https://github.com/huggingface/transformers/blob/master/examples/run_summarization_finetuning.py).
@@ -408,8 +408,6 @@ note that the finetuning script **will not work** if you do not download both
datasets. We will refer as `$DATA_PATH` the path to where you uncompressed both
archive.
-## Bert2Bert and abstractive summarization
```bash
export DATA_PATH=/path/to/dataset/
...
@@ -32,7 +32,7 @@ from transformers import (
    AutoTokenizer,
    BertForMaskedLM,
    BertConfig,
-    PreTrainedSeq2seq,
+    PreTrainedEncoderDecoder,
    Model2Model,
)
@@ -475,7 +475,7 @@ def main():
    for checkpoint in checkpoints:
        encoder_checkpoint = os.path.join(checkpoint, "encoder")
        decoder_checkpoint = os.path.join(checkpoint, "decoder")
-        model = PreTrainedSeq2seq.from_pretrained(
+        model = PreTrainedEncoderDecoder.from_pretrained(
            encoder_checkpoint, decoder_checkpoint
        )
        model.to(args.device)
...
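The reloading pattern in the hunk above is the one users hit when evaluating their own runs, so here is a minimal standalone sketch of it; the checkpoint path and device choice are placeholders, not values taken from the script.

```python
import os

import torch
from transformers import PreTrainedEncoderDecoder

# Hypothetical checkpoint directory produced during fine-tuning; it is expected
# to contain an "encoder/" and a "decoder/" sub-directory.
checkpoint = "/path/to/checkpoint"

encoder_checkpoint = os.path.join(checkpoint, "encoder")
decoder_checkpoint = os.path.join(checkpoint, "decoder")

# Reload the encoder/decoder pair under the renamed class, as in the loop above.
model = PreTrainedEncoderDecoder.from_pretrained(encoder_checkpoint, decoder_checkpoint)
model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
```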
@@ -87,7 +87,7 @@ if is_torch_available():
    from .modeling_distilbert import (DistilBertForMaskedLM, DistilBertModel,
                                      DistilBertForSequenceClassification, DistilBertForQuestionAnswering,
                                      DISTILBERT_PRETRAINED_MODEL_ARCHIVE_MAP)
-    from .modeling_seq2seq import PreTrainedSeq2seq, Model2Model
+    from .modeling_encoder_decoder import PreTrainedEncoderDecoder, Model2Model
    # Optimization
    from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
...
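Downstream code that imported these classes from the top level of the package needs the matching one-line change; a small sketch of the updated import follows, with the relationship between the two exported classes taken from the class definitions later in this commit.

```python
# Old import, removed by this commit:
#   from transformers import PreTrainedSeq2seq, Model2Model
# New import after the rename:
from transformers import Model2Model, PreTrainedEncoderDecoder

# Model2Model remains a convenience subclass of PreTrainedEncoderDecoder for the
# case where encoder and decoder come from the same model family.
assert issubclass(Model2Model, PreTrainedEncoderDecoder)
```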
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-""" Auto Model class. """
+""" Classes to support Encoder-Decoder architectures """
from __future__ import absolute_import, division, print_function, unicode_literals
@@ -27,9 +27,9 @@ from .modeling_auto import AutoModel, AutoModelWithLMHead
logger = logging.getLogger(__name__)
-class PreTrainedSeq2seq(nn.Module):
+class PreTrainedEncoderDecoder(nn.Module):
    r"""
-    :class:`~transformers.PreTrainedSeq2seq` is a generic model class that will be
+    :class:`~transformers.PreTrainedEncoderDecoder` is a generic model class that will be
    instantiated as a transformer architecture with one of the base model
    classes of the library as encoder and (optionally) another one as
    decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
@@ -37,7 +37,7 @@ class PreTrainedSeq2seq(nn.Module):
    """
    def __init__(self, encoder, decoder):
-        super(PreTrainedSeq2seq, self).__init__()
+        super(PreTrainedEncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
@@ -107,7 +107,7 @@ class PreTrainedSeq2seq(nn.Module):
        Examples::
-            model = PreTrainedSeq2seq.from_pretained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
+            model = PreTrainedEncoderDecoder.from_pretained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
        """
        # keyword arguments come in 3 flavors: encoder-specific (prefixed by
@@ -155,7 +155,7 @@ class PreTrainedSeq2seq(nn.Module):
    def save_pretrained(self, save_directory):
        """ Save a Seq2Seq model and its configuration file in a format such
-        that it can be loaded using `:func:`~transformers.PreTrainedSeq2seq.from_pretrained`
+        that it can be loaded using `:func:`~transformers.PreTrainedEncoderDecoder.from_pretrained`
        We save the encoder' and decoder's parameters in two separate directories.
        """
@@ -219,7 +219,7 @@ class PreTrainedSeq2seq(nn.Module):
        return decoder_outputs + encoder_outputs
-class Model2Model(PreTrainedSeq2seq):
+class Model2Model(PreTrainedEncoderDecoder):
    r"""
    :class:`~transformers.Model2Model` instantiates a Seq2Seq2 model
    where both of the encoder and decoder are of the same family. If the
@@ -277,14 +277,14 @@ class Model2Model(PreTrainedSeq2seq):
        return model
-class Model2LSTM(PreTrainedSeq2seq):
+class Model2LSTM(PreTrainedEncoderDecoder):
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        if kwargs.get("decoder_model", None) is None:
            # We will create a randomly initilized LSTM model as decoder
            if "decoder_config" not in kwargs:
                raise ValueError(
-                    "To load an LSTM in Seq2seq model, please supply either: "
+                    "To load an LSTM in Encoder-Decoder model, please supply either: "
                    " - a torch.nn.LSTM model as `decoder_model` parameter (`decoder_model=lstm_model`), or"
                    " - a dictionary of configuration parameters that will be used to initialize a"
                    " torch.nn.LSTM model as `decoder_config` keyword argument. "
...
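To see the renamed API end to end, here is a short usage sketch assuming the behaviour described in the docstrings above (Bert2Bert instantiation, the same-family shortcut, and saving the encoder's and decoder's parameters in two separate directories); the checkpoint names come from the docstring example and the save path is a placeholder.

```python
import os

from transformers import Model2Model, PreTrainedEncoderDecoder

# Bert2Bert: a BERT encoder paired with a BERT decoder, as in the docstring example.
model = PreTrainedEncoderDecoder.from_pretrained("bert-base-uncased", "bert-base-uncased")

# Model2Model covers the same-family case, so a single checkpoint name suffices
# (assumption: it forwards that name to both the encoder and the decoder).
bert2bert = Model2Model.from_pretrained("bert-base-uncased")

# The parameters are saved in two separate sub-directories, one for the encoder
# and one for the decoder; create both up front so the per-model save succeeds.
save_dir = "/path/to/bert2bert"  # placeholder path
for part in ("encoder", "decoder"):
    os.makedirs(os.path.join(save_dir, part), exist_ok=True)
model.save_pretrained(save_dir)
```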