Unverified Commit 33d7506e authored by Sylvain Gugger, committed by GitHub

Update doc of the model page (#5985)

parent c3206eef
Models
----------------------------------------------------

The base class :class:`~transformers.PreTrainedModel` implements the common methods for loading/saving a model either
from a local file or directory, or from a pretrained model configuration provided by the library (downloaded from
HuggingFace's AWS S3 repository).

:class:`~transformers.PreTrainedModel` also implements a few methods which are common among all the models to:

- resize the input token embeddings when new tokens are added to the vocabulary
- prune the attention heads of the model.
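For example, a minimal sketch using both methods (the checkpoint name is reused from the examples further down this
page; the sizes and head indices are illustrative)::

    from transformers import BertModel

    model = BertModel.from_pretrained('bert-base-uncased')
    # Grow the input token embedding matrix by two newly initialized rows at the end.
    model.resize_token_embeddings(model.config.vocab_size + 2)
    # Prune heads 0 and 2 on layer 1 (mapping of layer index to list of head indices).
    model.prune_heads({1: [0, 2]})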
...@@ -19,7 +21,6 @@ The base class ``PreTrainedModel`` implements the common methods for loading/sav
.. autofunction:: transformers.apply_chunking_to_forward

``TFPreTrainedModel``
~~~~~~~~~~~~~~~~~~~~~
...
...@@ -18,6 +18,7 @@ import functools
import logging
import os
import warnings
from typing import Dict

import h5py
import numpy as np
...@@ -167,30 +168,31 @@ TFMaskedLanguageModelingLoss = TFCausalLanguageModelingLoss
class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
    r"""
    Base class for all TF models.

    :class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods
    for loading, downloading and saving models as well as a few methods common to all models to:

        * resize the input embeddings,
        * prune heads in the self-attention heads.

    Class attributes (overridden by derived classes):
        - **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
          :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
        - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
          derived classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    base_model_prefix = ""

    @property
    def dummy_inputs(self) -> Dict[str, tf.Tensor]:
        """
        Dummy inputs to build the network.

        Returns:
            :obj:`Dict[str, tf.Tensor]`: The dummy inputs.
        """
        return {"input_ids": tf.constant(DUMMY_INPUTS)}
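# A minimal usage sketch (TFBertModel is one concrete subclass; the default-initialized config is
# illustrative). Keras creates variables lazily, so a single forward pass on the dummy inputs is
# enough to build all the weights of a freshly constructed model.
from transformers import BertConfig, TFBertModel

model = TFBertModel(BertConfig())
model(model.dummy_inputs)  # forward pass on the dummy inputs builds the network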
...@@ -207,13 +209,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
        # Save config in model
        self.config = config

    def get_input_embeddings(self) -> tf.keras.layers.Layer:
        """
        Returns the model's input embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`: The embedding layer mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
...@@ -223,7 +224,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
    def set_input_embeddings(self, value):
        """
        Set model's input embeddings.

        Args:
            value (:obj:`tf.keras.layers.Layer`):
...@@ -235,28 +236,30 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
        else:
            raise NotImplementedError
    def get_output_embeddings(self) -> tf.keras.layers.Layer:
        """
        Returns the model's output embeddings.

        Returns:
            :obj:`tf.keras.layers.Layer`: The layer mapping hidden states to vocabulary.
        """
        return None  # Overwrite for models with output embeddings
    def resize_token_embeddings(self, new_num_tokens=None) -> tf.Variable:
        """
        Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.

        Takes care of tying weights embeddings afterwards if the model class has a :obj:`tie_weights()` method.

        Arguments:
            new_num_tokens (:obj:`int`, `optional`):
                The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or
                :obj:`None`, just returns a pointer to the input tokens :obj:`tf.Variable` module of the model
                without doing anything.

        Return:
            :obj:`tf.Variable`: Pointer to the input tokens Embeddings Module of the model.
        """
        model_embeds = self._resize_token_embeddings(new_num_tokens)

        if new_num_tokens is None:
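# A minimal usage sketch (the checkpoint name matches the examples below; the added tokens are
# illustrative): resize after growing the matching tokenizer's vocabulary.
from transformers import BertTokenizer, TFBertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFBertModel.from_pretrained('bert-base-uncased')
tokenizer.add_tokens(['new_tok1', 'new_tok2'])
model.resize_token_embeddings(len(tokenizer))  # embedding matrix gains two newly initialized rows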
...@@ -285,19 +288,24 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
        else:
            raise ValueError("word embedding is not defined.")

    def _get_resized_embeddings(self, old_embeddings, new_num_tokens=None) -> tf.Variable:
        """
        Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
        initialized vectors at the end. Reducing the size will remove vectors from the end.

        Args:
            old_embeddings (:obj:`tf.Variable`):
                Old embeddings to be resized.
            new_num_tokens (:obj:`int`, `optional`):
                New number of tokens in the embedding matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens
                :obj:`tf.Variable` module of the model without doing anything.

        Return:
            :obj:`tf.Variable`: Pointer to the resized Embedding Module or the old Embedding Module if
            :obj:`new_num_tokens` is :obj:`None`.
        """
        word_embeddings = self._get_word_embeddings(old_embeddings)
        if new_num_tokens is None:
...@@ -325,17 +333,25 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
        return new_embeddings

    def prune_heads(self, heads_to_prune):
        """
        Prunes heads of the base model.

        Arguments:
            heads_to_prune (:obj:`Dict[int, List[int]]`):
                Dictionary with keys being selected layer indices (:obj:`int`) and associated values being the list
                of heads to prune in said layer (list of :obj:`int`). For instance {1: [0, 2], 2: [2, 3]} will
                prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        """
        raise NotImplementedError

    def save_pretrained(self, save_directory):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        :func:`~transformers.TFPreTrainedModel.from_pretrained` class method.

        Arguments:
            save_directory (:obj:`str`):
                Directory to which to save. Will be created if it doesn't exist.
        """
        if os.path.isfile(save_directory):
            logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
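# A minimal round-trip sketch (the model class and directory name are illustrative): saving writes
# the configuration plus the weights, and from_pretrained can then reload from that directory.
from transformers import TFBertModel

model = TFBertModel.from_pretrained('bert-base-uncased')
model.save_pretrained('./my_model_directory/')
reloaded = TFBertModel.from_pretrained('./my_model_directory/')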
...@@ -352,68 +368,101 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""
        Instantiate a pretrained TF 2.0 model from a pre-trained model configuration.

        The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (:obj:`str`, `optional`):
                Can be either:

                - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
                  ``bert-base-uncased``.
                - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
                  ``dbmdz/bert-base-german-cased``.
                - A path to a `directory` containing model weights saved using
                  :func:`~transformers.TFPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                - A path or url to a `PyTorch state_dict save file` (e.g., ``./pt_model/pytorch_model.bin``). In
                  this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
                  as ``config`` argument. This loading path is slower than converting the PyTorch model in a
                  TensorFlow model using the provided conversion scripts and loading the TensorFlow model
                  afterwards.
                - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
                  arguments ``config`` and ``state_dict``).
            model_args (sequence of positional arguments, `optional`):
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config (:obj:`Union[PretrainedConfig, str]`, `optional`):
                Can be either:

                - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

                Configuration for the model to use instead of an automatically loaded configuration. Configuration
                can be automatically loaded when:

                - The model is a model provided by the library (loaded with the `shortcut name` string of a
                  pretrained model).
                - The model was saved using :func:`~transformers.TFPreTrainedModel.save_pretrained` and is reloaded
                  by supplying the save directory.
                - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                  configuration JSON file named `config.json` is found in the directory.
            from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Load the model weights from a PyTorch state_dict save file (see docstring of
                ``pretrained_model_name_or_path`` argument).
            cache_dir (:obj:`str`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding
                the cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g.,
                :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each
                request.
            output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error
                messages.
            local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to only look at local files (e.g., not try downloading the model).
            use_cdn (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model
                on our S3 (faster).
            kwargs (remaining dictionary of keyword arguments, `optional`):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
                automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                  underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                  already been done).
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                  initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                  ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                  with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                  attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            from transformers import BertConfig, TFBertModel
            # Download model and configuration from S3 and cache.
            model = TFBertModel.from_pretrained('bert-base-uncased')
            # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
            model = TFBertModel.from_pretrained('./test/saved_model/')
            # Update configuration during loading.
            model = TFBertModel.from_pretrained('bert-base-uncased', output_attentions=True)
            assert model.config.output_attentions == True
            # Loading from a PyTorch model file instead of a TensorFlow checkpoint (slower, for example purposes,
            # not runnable).
            config = BertConfig.from_json_file('./pt_model/my_pt_model_config.json')
            model = TFBertModel.from_pretrained('./pt_model/my_pytorch_model.bin', from_pt=True, config=config)
        """
        config = kwargs.pop("config", None)
...
...@@ -266,34 +266,43 @@ class ModuleUtilsMixin:
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
    r"""
    Base class for all models.

    :class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods
    for loading, downloading and saving models as well as a few methods common to all models to:

        * resize the input embeddings,
        * prune heads in the self-attention heads.

    Class attributes (overridden by derived classes):
        - **config_class** (:class:`~transformers.PretrainedConfig`) -- A subclass of
          :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
        - **load_tf_weights** (:obj:`Callable`) -- A python `method` for loading a TensorFlow checkpoint in a
          PyTorch model, taking as arguments:

            - **model** (:class:`~transformers.PreTrainedModel`) -- An instance of the model on which to load the
              TensorFlow checkpoint.
            - **config** (:class:`~transformers.PretrainedConfig`) -- An instance of the configuration associated
              to the model.
            - **path** (:obj:`str`) -- A path to the TensorFlow checkpoint.

        - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
          derived classes of the same architecture adding modules on top of the base model.
    """
    config_class = None
    base_model_prefix = ""

    @property
    def dummy_inputs(self) -> Dict[str, torch.Tensor]:
        """
        Dummy inputs to do a forward pass in the network.

        Returns:
            :obj:`Dict[str, torch.Tensor]`: The dummy inputs.
        """
        return {"input_ids": torch.tensor(DUMMY_INPUTS)}

    def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
        super().__init__()
        if not isinstance(config, PretrainedConfig):
            raise ValueError(
...@@ -310,13 +319,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
    def base_model(self):
        return getattr(self, self.base_model_prefix, self)

    def get_input_embeddings(self) -> nn.Module:
        """
        Returns the model's input embeddings.

        Returns:
            :obj:`nn.Module`: A torch module mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
...@@ -329,8 +337,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
        Set model's input embeddings.

        Args:
            value (:obj:`nn.Module`): A module mapping vocabulary to hidden states.
        """
        base_model = getattr(self, self.base_model_prefix, self)
        if base_model is not self:
...@@ -338,20 +345,20 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
        else:
            raise NotImplementedError

    def get_output_embeddings(self) -> nn.Module:
        """
        Returns the model's output embeddings.

        Returns:
            :obj:`nn.Module`: A torch module mapping hidden states to vocabulary.
        """
        return None  # Overwrite for models with output embeddings

    def tie_weights(self):
        """
        Tie the weights between the input embeddings and the output embeddings.

        If the :obj:`torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning
        the weights instead.
        """
        output_embeddings = self.get_output_embeddings()
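# A minimal sketch (BertForMaskedLM is just one model with tied embeddings): after tying, the
# input embedding matrix and the output projection share the same parameter object.
from transformers import BertForMaskedLM

mlm = BertForMaskedLM.from_pretrained('bert-base-uncased')
assert mlm.get_input_embeddings().weight is mlm.get_output_embeddings().weight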
...@@ -376,18 +383,21 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
            output_embeddings.out_features = input_embeddings.num_embeddings

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> torch.nn.Embedding:
        """
        Resizes input token embeddings matrix of the model if :obj:`new_num_tokens != config.vocab_size`.

        Takes care of tying weights embeddings afterwards if the model class has a :obj:`tie_weights()` method.

        Arguments:
            new_num_tokens (:obj:`int`, `optional`):
                The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or
                :obj:`None`, just returns a pointer to the input tokens :obj:`torch.nn.Embedding` module of the
                model without doing anything.

        Return:
            :obj:`torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
        """
        base_model = getattr(self, self.base_model_prefix, self)  # get the base model if needed
        model_embeds = base_model._resize_token_embeddings(new_num_tokens)
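# A minimal usage sketch (the size delta is illustrative): the returned module is the live input
# embedding, and the configuration's vocab_size is updated to match.
from transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')
embeddings = model.resize_token_embeddings(model.config.vocab_size + 8)
assert embeddings.num_embeddings == model.config.vocab_size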
...@@ -412,20 +422,23 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
    def _get_resized_embeddings(
        self, old_embeddings: torch.nn.Embedding, new_num_tokens: Optional[int] = None
    ) -> torch.nn.Embedding:
        """
        Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
        initialized vectors at the end. Reducing the size will remove vectors from the end.

        Args:
            old_embeddings (:obj:`torch.nn.Embedding`):
                Old embeddings to be resized.
            new_num_tokens (:obj:`int`, `optional`):
                New number of tokens in the embedding matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or :obj:`None`, just returns a pointer to the input tokens
                :obj:`torch.nn.Embedding` module of the model without doing anything.

        Return:
            :obj:`torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if
            :obj:`new_num_tokens` is :obj:`None`.
        """
        if new_num_tokens is None:
            return old_embeddings
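# A simplified sketch of the copy step this helper performs (the real implementation also builds
# the new embedding on the right device and initializes it before copying):
import torch

old = torch.nn.Embedding(10, 4)
new = torch.nn.Embedding(12, 4)
num_tokens_to_copy = min(old.num_embeddings, new.num_embeddings)
new.weight.data[:num_tokens_to_copy, :] = old.weight.data[:num_tokens_to_copy, :]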
...@@ -448,7 +461,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
        return new_embeddings

    def init_weights(self):
        """
        Initializes and prunes weights if needed.
        """
        # Initialize weights
        self.apply(self._init_weights)
...@@ -459,13 +474,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
        # Tie weights if needed
        self.tie_weights()

    def prune_heads(self, heads_to_prune: Dict[int, List[int]]):
        """
        Prunes heads of the base model.

        Arguments:
            heads_to_prune (:obj:`Dict[int, List[int]]`):
                Dictionary with keys being selected layer indices (:obj:`int`) and associated values being the list
                of heads to prune in said layer (list of :obj:`int`). For instance {1: [0, 2], 2: [2, 3]} will
                prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
        """
        # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
        for layer, heads in heads_to_prune.items():
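# A minimal usage sketch (the head mapping mirrors the docstring example): pruned heads are also
# recorded in model.config.pruned_heads, so pruning survives a save/reload cycle.
from transformers import BertModel

model = BertModel.from_pretrained('bert-base-uncased')
model.prune_heads({1: [0, 2], 2: [2, 3]})  # heads 0 and 2 of layer 1, heads 2 and 3 of layer 2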
...@@ -475,11 +492,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
            self.base_model._prune_heads(heads_to_prune)

    def save_pretrained(self, save_directory):
        """
        Save a model and its configuration file to a directory, so that it can be re-loaded using the
        :func:`~transformers.PreTrainedModel.from_pretrained` class method.

        Arguments:
            save_directory (:obj:`str`):
                Directory to which to save. Will be created if it doesn't exist.
        """
        if os.path.isfile(save_directory):
            logger.error("Provided path ({}) should be a directory, not a file".format(save_directory))
...@@ -511,75 +530,110 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        r"""
        Instantiate a pretrained pytorch model from a pre-trained model configuration.

        The model is set in evaluation mode by default using ``model.eval()`` (Dropout modules are deactivated).
        To train the model, you should first set it back in training mode with ``model.train()``.

        The warning `Weights from XXX not initialized from pretrained model` means that the weights of XXX do not come
        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
        task.

        The warning `Weights from XXX not used in YYY` means that the layer XXX is not used by YYY, therefore those
        weights are discarded.

        Parameters:
            pretrained_model_name_or_path (:obj:`str`, `optional`):
                Can be either:

                - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
                  ``bert-base-uncased``.
                - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
                  ``dbmdz/bert-base-german-cased``.
                - A path to a `directory` containing model weights saved using
                  :func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
                - A path or url to a `tensorflow index checkpoint file` (e.g., ``./tf_model/model.ckpt.index``). In
                  this case, ``from_tf`` should be set to :obj:`True` and a configuration object should be provided
                  as ``config`` argument. This loading path is slower than converting the TensorFlow checkpoint in
                  a PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
                - :obj:`None` if you are both providing the configuration and state dictionary (resp. with keyword
                  arguments ``config`` and ``state_dict``).
            model_args (sequence of positional arguments, `optional`):
                All remaining positional arguments will be passed to the underlying model's ``__init__`` method.
            config (:obj:`Union[PretrainedConfig, str]`, `optional`):
                Can be either:

                - an instance of a class derived from :class:`~transformers.PretrainedConfig`,
                - a string valid as input to :func:`~transformers.PretrainedConfig.from_pretrained`.

                Configuration for the model to use instead of an automatically loaded configuration. Configuration
                can be automatically loaded when:

                - The model is a model provided by the library (loaded with the `shortcut name` string of a
                  pretrained model).
                - The model was saved using :func:`~transformers.PreTrainedModel.save_pretrained` and is reloaded
                  by supplying the save directory.
                - The model is loaded by supplying a local directory as ``pretrained_model_name_or_path`` and a
                  configuration JSON file named `config.json` is found in the directory.
            state_dict (:obj:`Dict[str, torch.Tensor]`, `optional`):
                A state dictionary to use instead of a state dictionary loaded from saved weights file.

                This option can be used if you want to create a model from a pretrained configuration but load your
                own weights. In this case though, you should check if using
                :func:`~transformers.PreTrainedModel.save_pretrained` and
                :func:`~transformers.PreTrainedModel.from_pretrained` is not a simpler option.
            cache_dir (:obj:`str`, `optional`):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            from_tf (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Load the model weights from a TensorFlow checkpoint save file (see docstring of
                ``pretrained_model_name_or_path`` argument).
            force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding
                the cached versions if they exist.
            resume_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (:obj:`Dict[str, str]`, `optional`):
                A dictionary of proxy servers to use by protocol or endpoint, e.g.,
                :obj:`{'http': 'foo.bar:3128', 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each
                request.
            output_loading_info (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error
                messages.
            local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to only look at local files (e.g., not try downloading the model).
            use_cdn (:obj:`bool`, `optional`, defaults to :obj:`True`):
                Whether or not to use Cloudfront (a Content Delivery Network, or CDN) when searching for the model
                on our S3 (faster).
            kwargs (remaining dictionary of keyword arguments, `optional`):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                :obj:`output_attentions=True`). Behaves differently depending on whether a ``config`` is provided or
                automatically loaded:

                - If a configuration is provided with ``config``, ``**kwargs`` will be directly passed to the
                  underlying model's ``__init__`` method (we assume all relevant updates to the configuration have
                  already been done).
                - If a configuration is not provided, ``kwargs`` will be first passed to the configuration class
                  initialization function (:func:`~transformers.PretrainedConfig.from_pretrained`). Each key of
                  ``kwargs`` that corresponds to a configuration attribute will be used to override said attribute
                  with the supplied ``kwargs`` value. Remaining keys that do not correspond to any configuration
                  attribute will be passed to the underlying model's ``__init__`` function.

        Examples::

            from transformers import BertConfig, BertModel
            # Download model and configuration from S3 and cache.
            model = BertModel.from_pretrained('bert-base-uncased')
            # Model was saved using `save_pretrained('./test/saved_model/')` (for example purposes, not runnable).
            model = BertModel.from_pretrained('./test/saved_model/')
            # Update configuration during loading.
            model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
            assert model.config.output_attentions == True
            # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes,
            # not runnable).
            config = BertConfig.from_json_file('./tf_model/my_tf_model_config.json')
            model = BertModel.from_pretrained('./tf_model/my_tf_checkpoint.ckpt.index', from_tf=True, config=config)
        """
        config = kwargs.pop("config", None)
        state_dict = kwargs.pop("state_dict", None)
...@@ -1242,18 +1296,23 @@ def apply_chunking_to_forward( ...
    chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
) -> torch.Tensor:
    """
    This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the
    dimension :obj:`chunk_dim`. It then applies a layer :obj:`forward_fn` to each chunk independently to save memory.

    If the :obj:`forward_fn` is independent across the :obj:`chunk_dim` this function will yield the same result as
    directly applying :obj:`forward_fn` to :obj:`input_tensors`.

    Args:
        chunk_size (:obj:`int`):
            The chunk size of a chunked tensor: :obj:`num_chunks = len(input_tensors[0]) / chunk_size`.
        chunk_dim (:obj:`int`):
            The dimension over which the :obj:`input_tensors` should be chunked.
        forward_fn (:obj:`Callable[..., torch.Tensor]`):
            The forward function of the model.
        input_tensors (:obj:`Tuple[torch.Tensor]`):
            The input tensors of ``forward_fn`` which will be chunked.

    Returns:
        :obj:`torch.Tensor`: A tensor with the same shape as the one :obj:`forward_fn` would have given if applied.

    Examples::
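        # A sketch of one possible call (the layer and tensor sizes are illustrative).
        import torch
        from transformers import apply_chunking_to_forward

        dense = torch.nn.Linear(4, 4)

        def forward_fn(hidden_states):
            # applied position-wise, so independent across the sequence dimension
            return dense(hidden_states)

        hidden_states = torch.rand(1, 8, 4)  # (batch, sequence, hidden)
        # chunk the sequence dimension (dim 1) into pieces of length 2
        output = apply_chunking_to_forward(2, 1, forward_fn, hidden_states)
        assert output.shape == (1, 8, 4)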
...