- having all inputs as keyword arguments (like PyTorch models), or
- having all inputs as a list, tuple or dict in the first positional arguments.
This second option is usefull when using `tf.keras.Model.fit()` method which currently requires having all the tensors in the first argument of the model call function: `model(inputs)`.
If you choose this second option, there are three possibilities you can use to gather all the input Tensors in the first positional argument :
- a single Tensor with input_ids only and nothing else: `model(inputs_ids)
- a list of varying length with one or several input Tensors IN THE ORDER given in the docstring:
`model([input_ids, attention_mask])` or `model([input_ids, attention_mask, token_type_ids])`
- a dictionary with one or several input Tensors associaed to the input names given in the docstring:
@@ -35,7 +35,7 @@ class TFPreTrainedModel(tf.keras.Model):
r""" Base class for all TF models.
:class:`~transformers.TFPreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
as well as a few methods commons to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
as well as a few methods common to all models to (i) resize the input embeddings and (ii) prune heads in the self-attention heads.
Class attributes (overridden by derived classes):
- ``config_class``: a class derived from :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
...
...
@@ -65,6 +65,21 @@ class TFPreTrainedModel(tf.keras.Model):
@@ -170,6 +214,9 @@ class PreTrainedModel(nn.Module):
ifself.config.pruned_heads:
self.prune_heads(self.config.pruned_heads)
# Tie weights if needed
self.tie_weights()
defprune_heads(self,heads_to_prune):
""" Prunes heads of the base model.
...
...
@@ -178,14 +225,12 @@ class PreTrainedModel(nn.Module):
heads_to_prune: dict with keys being selected layer indices (`int`) and associated values being the list of heads to prune in said layer (list of `int`).
E.g. {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on layer 1 and heads 2 and 3 on layer 2.
"""
base_model=getattr(self,self.base_model_prefix,self)# get the base model if needed
# save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads