:class:`~transformers.PreTrainedSeq2seq` is a generic model class that will be
instantiated as a transformer architecture with one of the base model
classes of the library as encoder and (optionally) another one as
decoder when created with the ``AutoModel.from_pretrained(pretrained_model_name_or_path)``
class method.
"""
def __init__(self, encoder, decoder):
...
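For orientation, a minimal sketch of what composing two base models by hand could look like, given the ``__init__(self, encoder, decoder)`` signature above; the import path and the choice of ``bert-base-uncased`` for both halves are assumptions, not part of this diff::

    from transformers import BertModel, PreTrainedSeq2seq

    # Load two independently pre-trained base models (shortcut names resolve
    # to cached or downloaded checkpoints).
    encoder = BertModel.from_pretrained("bert-base-uncased")
    decoder = BertModel.from_pretrained("bert-base-uncased")

    # Wrap the pair as a single encoder-decoder nn.Module.
    model = PreTrainedSeq2seq(encoder, decoder)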
@@ -59,13 +59,13 @@ class PreTrainedSeq2seq(nn.Module):
encoder_pretrained_model_name_or_path: information necessary to instantiate the encoder. Either:
    - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
    - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/encoder``.
    - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
decoder_pretrained_model_name_or_path: information necessary to instantiate the decoder. Either:
    - a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
    - a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/decoder``.
    - a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint into a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
model_args: (`optional`) Sequence of positional arguments:
...
@@ -103,7 +103,7 @@ class PreTrainedSeq2seq(nn.Module):
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done).
- If a configuration is not provided, ``kwargs`` will first be passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
You can specify kwargs specific to the encoder and decoder by prefixing the key with `encoder_` and `decoder_` respectively (e.g. ``decoder_output_attention=True``); see the sketch below. The remaining kwargs will be passed to both the encoder and the decoder.
Examples::
...
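A minimal sketch of the prefixed-kwargs convention documented above; the checkpoint names are illustrative and ``decoder_output_attention`` simply reuses the docstring's example key::

    # Un-prefixed kwargs are forwarded to both halves; `encoder_`/`decoder_`-
    # prefixed kwargs are routed only to the corresponding model.
    model = PreTrainedSeq2seq.from_pretrained(
        "bert-base-uncased",            # encoder checkpoint
        "bert-base-uncased",            # decoder checkpoint
        output_attention=False,         # shared key (assumed): both halves
        decoder_output_attention=True,  # decoder only
    )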
@@ -154,8 +154,11 @@ class PreTrainedSeq2seq(nn.Module):
return model
def save_pretrained(self, save_directory):
""" Save a Seq2Seq model and its configuration file in a format
""" Save a Seq2Seq model and its configuration file in a format such
such that it can be loaded using `:func:`~transformers.PreTrainedSeq2seq.from_pretrained` """
that it can be loaded using `:func:`~transformers.PreTrainedSeq2seq.from_pretrained`
We save the encoder' and decoder's parameters in two separate directories.
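Given the two-directory layout noted above, a round trip could look like the following sketch, assuming the subdirectories are named ``encoder`` and ``decoder`` as in the argument examples; the path itself is illustrative::

    # Writes the encoder's and decoder's weights under separate subdirectories.
    model.save_pretrained("./my_model_directory/")

    # Reload by pointing each positional argument at its own directory.
    model = PreTrainedSeq2seq.from_pretrained(
        "./my_model_directory/encoder",
        "./my_model_directory/decoder",
    )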