The base class ``PreTrainedModel`` implements the common methods for loading/saving a model either from a local file or directory, or from a pretrained model configuration provided by the library (downloaded from HuggingFace's AWS S3 repository).
The base class :class:`~transformers.PreTrainedModel` implements the common methods for loading/saving a model either
from a local file or directory, or from a pretrained model configuration provided by the library (downloaded from
HuggingFace's AWS S3 repository).
``PreTrainedModel`` also implements a few methods which are common among all the models to:
:class:`~transformers.PreTrainedModel` also implements a few methods which are common among all the models to:
- resize the input token embeddings when new tokens are added to the vocabulary
- resize the input token embeddings when new tokens are added to the vocabulary
- prune the attention heads of the model.
- prune the attention heads of the model.
...
@@ -19,7 +21,6 @@ The base class ``PreTrainedModel`` implements the common methods for loading/sav
:class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
:class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods
as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
for loading, downloading and saving models as well as a few methods common to all models to:
Class attributes (overridden by derived classes):
* resize the input embeddings,
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
* prune heads in the self-attention layers.
- ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
- ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
Class attributes (overridden by derived classes):
- ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
- **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
- ``path``: a path (string) to the TensorFlow checkpoint.
:class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
- ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
derived classes of the same architecture adding modules on top of the base model.
"""
"""
config_class = None
config_class = None
base_model_prefix = ""
base_model_prefix = ""
@property
@property
def dummy_inputs(self):
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
""" Dummy inputs to build the network.
"""
Dummy inputs to build the network.
Returns:
Returns:
tf.Tensor with dummy inputs
:obj:`Dict[str, tf.Tensor]`: The dummy inputs.
"""
"""
return {"input_ids": tf.constant(DUMMY_INPUTS)}
return {"input_ids": tf.constant(DUMMY_INPUTS)}
...
@@ -207,13 +209,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
"""
Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.
Arguments:
new_num_tokens: (`optional`) int:
Takes care of tying weights embeddings afterwards if the model class has a :obj:`tie_weights()` method.
New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
If not provided or None: does nothing and just returns a pointer to the input tokens ``tf.Variable`` Module of the model.
Return: ``tf.Variable``
Arguments:
Pointer to the input tokens Embeddings Module of the model
new_num_tokens (:obj:`int`, `optional`):
The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
vectors at the end. Reducing the size will remove vectors from the end. If not provided or :obj:`None`,
just returns a pointer to the input tokens :obj:`tf.Variable` module of the model without doing
anything.
Return:
:obj:`tf.Variable`: Pointer to the input tokens Embeddings Module of the model.
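The resizing behaviour described in this docstring can be sketched in plain Python. This is only an illustration under stated assumptions: `resize_embedding_rows` is a hypothetical helper that works on a list of lists rather than a `tf.Variable`, and the Gaussian initialisation with standard deviation 0.02 is an assumption, not the library's initializer.

```python
import random


def resize_embedding_rows(rows, new_num_tokens=None, seed=0):
    """Return an embedding table with `new_num_tokens` rows (hypothetical helper)."""
    if new_num_tokens is None or new_num_tokens == len(rows):
        # Nothing to do: hand back the very same object.
        return rows
    if new_num_tokens < len(rows):
        # Shrinking removes vectors from the end.
        return rows[:new_num_tokens]
    # Growing keeps the existing vectors and appends newly initialized ones.
    dim = len(rows[0])
    rng = random.Random(seed)
    extra = [
        [rng.gauss(0.0, 0.02) for _ in range(dim)]  # assumed initialisation
        for _ in range(new_num_tokens - len(rows))
    ]
    return rows + extra


table = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
grown = resize_embedding_rows(table, 5)
shrunk = resize_embedding_rows(table, 2)
same = resize_embedding_rows(table)
```

In the sketch, the first three rows of `grown` are the original vectors, `shrunk` keeps only the first two, and `same` is the original table itself.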
@@ -325,17 +333,25 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
...
return new_embeddings
return new_embeddings
def prune_heads(self, heads_to_prune):
def prune_heads(self, heads_to_prune):
""" Prunes heads of the base model.
"""
Prunes heads of the base model.
Arguments:
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
Arguments:
heads_to_prune (:obj:`Dict[int, List[int]]`):
Dictionary with keys being selected layer indices (:obj:`int`) and associated values being the list
of heads to prune in said layer (list of :obj:`int`). For instance {1: [0, 2], 2: [2, 3]} will
prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
"""
"""
raise NotImplementedError
raise NotImplementedError
def save_pretrained(self, save_directory):
def save_pretrained(self, save_directory):
""" Save a model and its configuration file to a directory, so that it
"""
can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.
Save a model and its configuration file to a directory, so that it can be re-loaded using the
:func:`~transformers.TFPreTrainedModel.from_pretrained` class method.
Arguments:
save_directory (:obj:`str`):
Directory to which to save. Will be created if it doesn't exist.
"""
"""
if os.path.isfile(save_directory):
if os.path.isfile(save_directory):
logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
...
@@ -352,68 +368,101 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
r"""Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.
r"""
Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.
The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
It is up to you to train those weights with a downstream fine-tuning task.
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
task.
The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.
The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
Can be either:
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
- A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
- a path or url to a `PyTorch state_dict save file` (e.g. `./pt_model/pytorch_model.bin`). In this case, ``from_pt`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in a TensorFlow model using the provided conversion scripts and loading the TensorFlow model afterwards.
``bert-base-uncased``.
- A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
model_args: (`optional`) Sequence of positional arguments:
``dbmdz/bert-base-german-cased``.
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
- A path to a `directory` containing model weights saved using
- A path or url to a `PyTorch state_dict save file` (e.g., `./pt_model/pytorch_model.bin`). In
- an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
- a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
as ``config`` argument. This loading path is slower than converting the PyTorch model in a
TensorFlow model using the provided conversion scripts and loading the TensorFlow model
Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
afterwards.
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
- :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
- the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
arguments ``config`` and ``state_dict``).
- the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
model_args (sequence of positional arguments, `optional`):
All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
Do not delete incompletely received files. Attempt to resume the download if such a file exists.
- The model was saved using :func:`~transformers.TFPreTrainedModel.save_pretrained` and is reloaded
by supplying the save directory.
proxies: (`optional`) dict, default None:
- The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
configuration JSON file named `config.json` is found in the directory.
The proxies are used on each request.
from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
Load the model weights from a PyTorch state_dict save file (see docstring of
output_loading_info: (`optional`) boolean:
``pretrained_model_name_or_path`` argument).
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
cache_dir (:obj:`str`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
kwargs: (`optional`) Remaining dictionary of keyword arguments:
standard cache should not be used.
Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
cached versions if they exist.
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g.,
:obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each
request.
output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error
messages.
local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_cdn (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
our S3 (faster).
kwargs (remaining dictionary of keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attention=True`). Behaves differently depending on whether a ``config`` is provided or
automatically loaded:
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
attribute will be passed to the underlying model's ``__init__`` function.
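The two warnings quoted in this docstring, and the dictionary returned when ``output_loading_info`` is set, both come down to comparing the weight names a model expects with those found in a checkpoint. The following is a minimal sketch of that comparison, not the library's implementation: `compute_loading_info` and the weight names are hypothetical, and error messages are left out.

```python
def compute_loading_info(model_keys, checkpoint_keys):
    """Compare expected weight names with those found in a checkpoint (hypothetical helper)."""
    model_keys = set(model_keys)
    checkpoint_keys = set(checkpoint_keys)
    return {
        # Expected by the model but absent from the checkpoint:
        # these weights are not initialized from the pretrained model.
        "missing_keys": sorted(model_keys - checkpoint_keys),
        # Present in the checkpoint but not used by the model:
        # these weights are discarded.
        "unexpected_keys": sorted(checkpoint_keys - model_keys),
    }


info = compute_loading_info(
    model_keys=["encoder.weight", "classifier.weight"],
    checkpoint_keys=["encoder.weight", "pooler.weight"],
)
```

Here `classifier.weight` would have to be trained on a downstream task, while `pooler.weight` is simply dropped.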
Examples::
Examples::
# For example purposes. Not runnable.
from transformers import BertConfig, TFBertModel
model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
# Download model and configuration from S3 and cache.
model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
model = TFBertModel.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
# Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
model = TFBertModel.from_pretrained('./test/saved_model/')
# Update configuration during loading.
model = TFBertModel.from_pretrained('bert-base-uncased', output_attention=True)
assert model.config.output_attention == True
assert model.config.output_attention == True
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
# Loading from a PyTorch model file instead of a TensorFlow checkpoint (slower, for example purposes, not runnable).
:class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
:class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods
as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
for loading, downloading and saving models as well as a few methods common to all models to:
Class attributes (overridden by derived classes):
* resize the input embeddings,
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
* prune heads in the self-attention layers.
- ``load_tf_weights``: a python ``method`` for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments:
- ``model``: an instance of the relevant subclass of :class:`~transformers.PreTrainedModel`,
Class attributes (overridden by derived classes):
- ``config``: an instance of the relevant subclass of :class:`~transformers.PretrainedConfig`,
- **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
- ``path``: a path (string) to the TensorFlow checkpoint.
:class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
- **load_tf_weights** (:obj:`Callable`) -- A python `method` for loading a TensorFlow checkpoint in a
PyTorch model, taking as arguments:
- ``base_model_prefix``: a string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model.
- **model** (:class:`~transformers.PreTrainedModel`) -- An instance of the model on which to load the
TensorFlow checkpoint.
- **config** (:class:`~transformers.PretrainedConfig`) -- An instance of the configuration associated
to the model.
- **path** (:obj:`str`) -- A path to the TensorFlow checkpoint.
- **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
derived classes of the same architecture adding modules on top of the base model.
"""
"""
config_class = None
config_class = None
base_model_prefix = ""
base_model_prefix = ""
@property
@property
def dummy_inputs(self):
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
""" Dummy inputs to do a forward pass in the network.
""" Dummy inputs to do a forward pass in the network.
""" Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
"""
Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.
Arguments:
new_num_tokens: (`optional`) int:
Takes care of tying weights embeddings afterwards if the model class has a :obj:`tie_weights()` method.
New number of tokens in the embedding matrix. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end.
If not provided or None: does nothing and just returns a pointer to the input tokens ``torch.nn.Embeddings`` Module of the model.
Return: ``torch.nn.Embeddings``
Arguments:
Pointer to the input tokens Embeddings Module of the model
new_num_tokens (:obj:`int`, `optional`):
The number of new tokens in the embedding matrix. Increasing the size will add newly initialized
vectors at the end. Reducing the size will remove vectors from the end. If not provided or :obj:`None`,
just returns a pointer to the input tokens :obj:`torch.nn.Embedding` module of the model without doing
anything.
Return:
:obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
"""
"""
base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
Arguments:
E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
heads_to_prune (:obj:`Dict[int, List[int]]`):
Dictionary with keys being selected layer indices (:obj:`int`) and associated values being the list
of heads to prune in said layer (list of :obj:`int`). For instance {1: [0, 2], 2: [2, 3]} will
prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
"""
"""
# save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
# save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
for layer, heads in heads_to_prune.items():
for layer, heads in heads_to_prune.items():
...
@@ -475,11 +492,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
self.base_model._prune_heads(heads_to_prune)
self.base_model._prune_heads(heads_to_prune)
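The comment above describes storing pruned heads as the union of previously pruned and newly pruned heads. A minimal standalone sketch of that bookkeeping follows; `merge_pruned_heads` is a hypothetical helper, and keeping the record in a plain dictionary (rather than on the model configuration) is an assumption made for illustration.

```python
def merge_pruned_heads(stored, heads_to_prune):
    """Return a new record of pruned heads per layer (hypothetical helper)."""
    # Copy the stored record so the caller's dictionary is left untouched.
    merged = {layer: sorted(set(heads)) for layer, heads in stored.items()}
    for layer, heads in heads_to_prune.items():
        # Union of the heads already pruned on this layer and the new ones.
        merged[layer] = sorted(set(merged.get(layer, [])) | set(heads))
    return merged


stored = {1: [0, 2]}
merged = merge_pruned_heads(stored, {1: [2, 3], 2: [2, 3]})
```

With these inputs, layer 1 ends up with heads 0, 2 and 3 recorded, and layer 2 with heads 2 and 3.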
def save_pretrained(self, save_directory):
def save_pretrained(self, save_directory):
""" Save a model and its configuration file to a directory, so that it
"""
can be re-loaded using the :func:`~transformers.PreTrainedModel.from_pretrained` class method.
Save a model and its configuration file to a directory, so that it can be re-loaded using the
:func:`~transformers.PreTrainedModel.from_pretrained` class method.
Arguments:
Arguments:
save_directory: directory to which to save.
save_directory (:obj:`str`):
Directory to which to save. Will be created if it doesn't exist.
"""
"""
if os.path.isfile(save_directory):
if os.path.isfile(save_directory):
logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
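The directory handling described in the docstring (reject a file path, create the directory if it does not exist) can be sketched on its own. `prepare_save_directory` is a hypothetical helper, and raising `ValueError` is an assumption of this sketch; the code in the diff logs an error instead.

```python
import os
import tempfile


def prepare_save_directory(save_directory):
    """Validate and create the target directory (hypothetical helper)."""
    if os.path.isfile(save_directory):
        # Assumption: raise instead of logging, so the caller sees the problem.
        raise ValueError(
            "Provided path ({}) should be a directory, not a file".format(save_directory)
        )
    # Create the directory, including parents, if it doesn't exist yet.
    os.makedirs(save_directory, exist_ok=True)
    return save_directory


with tempfile.TemporaryDirectory() as tmp:
    target = prepare_save_directory(os.path.join(tmp, "saved_model"))
    created = os.path.isdir(target)

    file_path = os.path.join(tmp, "a_file")
    open(file_path, "w").close()
    try:
        prepare_save_directory(file_path)
        rejected = False
    except ValueError:
        rejected = True
```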
...
@@ -511,75 +530,110 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
r"""Instantiate a pretrained pytorch model from a pre-trained model configuration.
r"""
Instantiate a pretrained pytorch model from a pre-trained model configuration.
The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated)
The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated).
To train the model, you should first set it back in training mode with ``model.train()``
To train the model, you should first set it back in training mode with ``model.train()``.
The warning ``Weights from XXX not initialized from pretrained model`` means that the weights of XXX do not come pre-trained with the rest of the model.
The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
It is up to you to train those weights with a downstream fine-tuning task.
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
task.
The warning ``Weights from XXX not used in YYY`` means that the layer XXX is not used by YYY, therefore those weights are discarded.
The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
- a string with the `shortcut name` of a pre-trained model to load from cache or download, e.g.: ``bert-base-uncased``.
Can be either:
- a string with the `identifier name` of a pre-trained model that was user-uploaded to our S3, e.g.: ``dbmdz/bert-base-german-cased``.
- a path to a `directory` containing model weights saved using :func:`~transformers.PreTrainedModel.save_pretrained`, e.g.: ``./my_model_directory/``.
- A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
- a path or url to a `tensorflow index checkpoint file` (e.g. `./tf_model/model.ckpt.index`). In this case, ``from_tf`` should be set to True and a configuration object should be provided as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
``bert-base-uncased``.
- None if you are both providing the configuration and state dictionary (resp. with keyword arguments ``config`` and ``state_dict``)
- A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
``dbmdz/bert-base-german-cased``.
model_args: (`optional`) Sequence of positional arguments:
- A path to a `directory` containing model weights saved using
All remaining positional arguments will be passed to the underlying model's ``__init__`` method
- A path or url to a `tensorflow index checkpoint file` (e.g., `./tf_model/model.ckpt.index`). In
config: (`optional`) one of:
this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
- an instance of a class derived from :class:`~transformers.PretrainedConfig`, or
as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
- a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained()`
a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
- :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when:
arguments ``config`` and ``state_dict``).
- the model is a model provided by the library (loaded with the ``shortcut-name`` string of a pretrained model), or
model_args (sequence of positional arguments, `optional`):
- the model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded by supplying the save directory.
All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
- the model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a configuration JSON file named `config.json` is found in the directory.
an optional state dictionary for the model to use instead of a state dictionary loaded from saved weights file.
- an instance of a class derived from :class:`~transformers.PretrainedConfig`,
This option can be used if you want to create a model from a pretrained configuration but load your own weights.
- a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.
In this case though, you should check if using :func:`~transformers.PreTrainedModel.save_pretrained` and :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
cache_dir: (`optional`) string:
be automatically loaded when:
Path to a directory in which a downloaded pre-trained model
configuration should be cached if the standard cache should not be used.
- The model is a model provided by the library (loaded with the `shortcut name` string of a
A state dictionary to use instead of a state dictionary loaded from saved weights file.
proxies: (`optional`) dict, default None:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}.
This option can be used if you want to create a model from a pretrained configuration but load your own
The proxies are used on each request.
weights. In this case though, you should check if using
:func:`~transformers.PreTrainedModel.save_pretrained` and
output_loading_info: (`optional`) boolean:
:func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
Set to ``True`` to also return a dictionary containing missing keys, unexpected keys and error messages.
cache_dir (:obj:`str`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
kwargs: (`optional`) Remaining dictionary of keyword arguments:
standard cache should not be used.
Can be used to update the configuration object (after it being loaded) and initiate the model. (e.g. ``output_attention=True``). Behave differently depending on whether a `config` is provided or automatically loaded:
from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
Load the model weights from a TensorFlow checkpoint save file (see docstring of
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the underlying model's ``__init__`` method (we assume all relevant updates to the configuration have already been done)
``pretrained_model_name_or_path`` argument).
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's ``__init__`` function.
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (:obj:`Dict[str, str]`, `optional`):
A dictionary of proxy servers to use by protocol or endpoint, e.g.,
:obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each
request.
output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error
messages.
local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_cdn (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model on
our S3 (faster).
kwargs (remaining dictionary of keyword arguments, `optional`):
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
:obj:`output_attention=True`). Behaves differently depending on whether a ``config`` is provided or
automatically loaded:
- If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
attribute will be passed to the underlying model's ``__init__`` function.
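The second case above, where no configuration is provided, amounts to splitting ``kwargs`` into configuration overrides and leftover model arguments. A minimal sketch of that split follows; `split_kwargs`, the dictionary standing in for a configuration object, and the `foo` argument are all hypothetical.

```python
def split_kwargs(config_attributes, kwargs):
    """Separate configuration overrides from remaining model kwargs (hypothetical helper)."""
    # Keys matching a configuration attribute override that attribute.
    config_updates = {k: v for k, v in kwargs.items() if k in config_attributes}
    # Everything else is passed on to the model's __init__.
    model_kwargs = {k: v for k, v in kwargs.items() if k not in config_attributes}
    return config_updates, model_kwargs


# A plain dictionary stands in for the configuration object in this sketch.
config = {"output_attention": False, "hidden_size": 768}
updates, model_kwargs = split_kwargs(config, {"output_attention": True, "foo": 1})
config.update(updates)
```

After the update, `config["output_attention"]` is `True` and only `foo` remains in `model_kwargs`.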
Examples::
Examples::
# For example purposes. Not runnable.
from transformers import BertConfig, BertModel
model = BertModel.from_pretrained('bert-base-uncased') # Download model and configuration from S3 and cache.
# Download model and configuration from S3 and cache.
model = BertModel.from_pretrained('./test/saved_model/') # E.g. model was saved using `save_pretrained('./test/saved_model/')`
model = BertModel.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attention=True) # Update configuration during loading
# Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
model = BertModel.from_pretrained('./test/saved_model/')
# Update configuration during loading.
model = BertModel.from_pretrained('bert-base-uncased', output_attention=True)
assert model.config.output_attention == True
assert model.config.output_attention == True
# Loading from a TF checkpoint file instead of a PyTorch model (slower)
# Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).