:class:`~transformers.PreTrainedEncoderDecoder` is a generic model class that will be
instantiated as a transformer architecture with one of the base model
classes of the library as encoder and (optionally) another one as
decoder when created with the `AutoModel.from_pretrained(pretrained_model_name_or_path)`
class method.
"""
def__init__(self,encoder,decoder):
super(PreTrainedEncoderDecoder,self).__init__()
self.encoder=encoder
self.decoder=decoder
@classmethod
deffrom_pretrained(
cls,
encoder_pretrained_model_name_or_path=None,
decoder_pretrained_model_name_or_path=None,
*model_args,
**kwargs
):
r""" Instantiates an encoder and a decoder from one or two base classes of the library from pre-trained model checkpoints.
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated)
To train the model, you need to first set it back in training mode with `model.train()`
Params:
encoder_pretrained_model_name_or_path: information necessary to initiate the encoder. Either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/encoder``.
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
decoder_pretrained_model_name_or_path: information necessary to initiate the decoder. Either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/decoder``.
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
model_args: (`optional`) Sequence of positional arguments:
All remaning positional arguments will be passed to the underlying model's ``__init__`` method
config: (`optional`) instance of a class derived from :class:`~transformers.PretrainedConfig`:
Configuration for the model to use instead of an automatically loaded configuation. Configuration can be automatically loaded when:
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
- the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by suppling the save directory.
- the model is loaded by suppling a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
state_dict: (`optional`) dict:
an optional state dictionnary for the model to use instead of a state dictionary loaded from saved weights file.
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
cache_dir: (`optional`) string:
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
Force to (re-)download the model weights and configuration files and override the cached versions if they exists.
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
output_loading_info: (`optional`) boolean:
Set to ``True`` to also return a dictionnary containing missing keys, unexpected keys and error messages.
kwargs: (`optional`) Remaining dictionary of keyword arguments.
Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
You can specify kwargs sepcific for the encoder and decoder by prefixing the key with `encoder_` and `decoder_` respectively. (e.g. ``decoder_output_attention=True``). The remaining kwargs will be passed to both encoders and decoders.
Examples::
model = PreTrainedEncoderDecoder.from_pretained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
"""
# keyword arguments come in 3 flavors: encoder-specific (prefixed by
# `encoder_`), decoder-specific (prefixed by `decoder_`) and those
# that apply to the model as a whole.
# We let the specific kwargs override the common ones in case of conflict.
kwargs_common={
argument:value
forargument,valueinkwargs.items()
ifnotargument.startswith("encoder_")
andnotargument.startswith("decoder_")
}
kwargs_decoder=kwargs_common.copy()
kwargs_encoder=kwargs_common.copy()
kwargs_encoder.update(
{
argument[len("encoder_"):]:value
forargument,valueinkwargs.items()
ifargument.startswith("encoder_")
}
)
kwargs_decoder.update(
{
argument[len("decoder_"):]:value
forargument,valueinkwargs.items()
ifargument.startswith("decoder_")
}
)
# Load and initialize the encoder and decoder
# The distinction between encoder and decoder at the model level is made
# by the value of the flag `is_decoder` that we need to set correctly.
- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional arguments.
This second option is usefull when using `tf.keras.Model.fit()` method which currently requires having all the tensors in the first argument of the model call function: `model(inputs)`.
If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the first positional argument :
- a single Tensor with input_ids only and nothing else: `model(inputs_ids)
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associaed to the input names given in the docstring:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
- a path or url to a `PyTorch, TF 1.X or TF 2.0 checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In the case of a PyTorch checkpoint, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument.
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
...
...
@@ -209,6 +213,7 @@ class TFAutoModelWithLMHead(object):
pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
- a path or url to a `PyTorch, TF 1.X or TF 2.0 checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In the case of a PyTorch checkpoint, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument.
...
...
@@ -237,6 +242,9 @@ class TFAutoModelWithLMHead(object):
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
...
...
@@ -332,6 +340,7 @@ class TFAutoModelForSequenceClassification(object):
pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
- a path or url to a `PyTorch, TF 1.X or TF 2.0 checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In the case of a PyTorch checkpoint, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument.
...
...
@@ -360,6 +369,9 @@ class TFAutoModelForSequenceClassification(object):
Do not delete incompletely recieved file. Attempt to resume the download if such a file exists.
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
The proxies are used on each request.
...
...
@@ -444,6 +456,7 @@ class TFAutoModelForQuestionAnswering(object):
pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
- a path or url to a `PyTorch, TF 1.X or TF 2.0 checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In the case of a PyTorch checkpoint, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument.
...
...
@@ -472,6 +485,9 @@ class TFAutoModelForQuestionAnswering(object):
@add_start_docstrings("""DistilBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
the hidden-states output to compute `span start logits` and `span end logits`). """,
@@ -35,7 +36,7 @@ class TFPreTrainedModel(tf.keras.Model):
r""" Base class for all TF models.
:class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
as well as a few methods commons to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
Class attributes (overridden by derived classes):
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
...
...
@@ -51,7 +52,15 @@ class TFPreTrainedModel(tf.keras.Model):
config_class=None
pretrained_model_archive_map={}
base_model_prefix=""
dummy_inputs=tf.constant(DUMMY_INPUTS)# dummy inputs to build the network
""" Build a resized Embedding Variable from a provided token Embedding Module.
Increasing the size will add newly initialized vectors at the end
...
...
@@ -153,6 +177,7 @@ class TFPreTrainedModel(tf.keras.Model):
pretrained_model_name_or_path: either:
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
- a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
...
...
@@ -176,6 +201,9 @@ class TFPreTrainedModel(tf.keras.Model):
# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
# the hidden-states output to compute `span start logits` and `span end logits`). """,