"vscode:/vscode.git/clone" did not exist on "b14bb54622460dfbe75edfc622ccb06b6b1813f8"
Unverified commit 27d46397, authored by Sylvain Gugger, committed by GitHub

Make gradient_checkpointing a training argument (#13657)



* Make gradient_checkpointing a training argument

* Update src/transformers/modeling_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/configuration_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Fix tests

* Style

* document Gradient Checkpointing as a performance feature

* Small rename

* PoC for not using the config

* Adapt BC to new PoC

* Forgot to save

* Rollout changes to all other models

* Fix typo
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
parent 75f6641e
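
This commit removes the `gradient_checkpointing` field from the individual model configs (diff below) and moves the switch into the training setup. Roughly the usage the change enables, sketched under the assumption that the `modeling_utils.py` update mentioned in the commit message adds a `gradient_checkpointing_enable()` helper and that the new training argument is simply named `gradient_checkpointing`:

```python
# Hedged usage sketch after this change: gradient checkpointing is no longer a
# config field but is toggled on the model or via the training arguments.
from transformers import GPT2LMHeadModel, TrainingArguments

model = GPT2LMHeadModel.from_pretrained("gpt2")

# Enable on the model itself; this fans out to the per-model
# `_set_gradient_checkpointing` hooks added throughout the diff below.
model.gradient_checkpointing_enable()

# Or let the Trainer do it through the new training argument.
args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,  # the "training argument" from the PR title
)
```

Either path ends up flipping the new `gradient_checkpointing` flag on the encoder/decoder modules instead of reading it from the config inside `forward`.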
@@ -61,8 +61,6 @@ class CanineConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
         downsampling_rate (:obj:`int`, `optional`, defaults to 4):
             The rate at which to downsample the original character sequence length before applying the deep Transformer
             encoder.
......
@@ -772,6 +772,7 @@ class CanineEncoder(nn.Module):
                 for _ in range(config.num_hidden_layers)
             ]
         )
+        self.gradient_checkpointing = False

     def forward(
         self,
@@ -791,7 +792,7 @@ class CanineEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -895,6 +896,7 @@ class CaninePreTrainedModel(PreTrainedModel):
     config_class = CanineConfig
     load_tf_weights = load_tf_weights_in_canine
     base_model_prefix = "canine"
+    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
@@ -913,6 +915,10 @@ class CaninePreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, CanineEncoder):
+            module.gradient_checkpointing = value
+

 CANINE_START_DOCSTRING = r"""
     This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use
......
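
Each model in this diff contributes the same three pieces: a `gradient_checkpointing` flag on the encoder/decoder module, `supports_gradient_checkpointing = True` on the pretrained base class, and a `_set_gradient_checkpointing` override that knows which sub-module owns the flag. A minimal sketch of how a base-class helper can fan the setting out to those hooks (the method name follows the commit message's `modeling_utils.py` update; the exact signature is an assumption):

```python
from functools import partial

import torch.nn as nn


class PreTrainedModelSketch(nn.Module):
    # Per-model classes (e.g. CaninePreTrainedModel above) set this to True.
    supports_gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Overridden per model: flip the flag only on the sub-modules that
        # actually implement checkpointed forwards.
        pass

    def gradient_checkpointing_enable(self):
        if not self.supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        # nn.Module.apply visits every sub-module, so the encoder that owns
        # the `gradient_checkpointing` attribute receives the new value.
        self.apply(partial(self._set_gradient_checkpointing, value=True))
```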
@@ -68,8 +68,6 @@ class CLIPTextConfig(PretrainedConfig):
         initializer_factor (:obj:`float`, `optional`, defaults to 1):
             A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
             testing).
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.

     Example::
@@ -103,7 +101,6 @@ class CLIPTextConfig(PretrainedConfig):
         pad_token_id=1,
         bos_token_id=0,
         eos_token_id=2,
-        gradient_checkpointing=False,
         **kwargs
     ):
         super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -120,7 +117,6 @@ class CLIPTextConfig(PretrainedConfig):
         self.initializer_range = initializer_range
         self.initializer_factor = initializer_factor
         self.attention_dropout = attention_dropout
-        self.gradient_checkpointing = gradient_checkpointing


 class CLIPVisionConfig(PretrainedConfig):
@@ -161,8 +157,6 @@ class CLIPVisionConfig(PretrainedConfig):
         initializer_factor (:obj:`float`, `optional`, defaults to 1):
             A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
             testing).
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.

     Example::
@@ -194,7 +188,6 @@ class CLIPVisionConfig(PretrainedConfig):
         attention_dropout=0.0,
         initializer_range=0.02,
         initializer_factor=1.0,
-        gradient_checkpointing=False,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -211,7 +204,6 @@ class CLIPVisionConfig(PretrainedConfig):
         self.attention_dropout = attention_dropout
         self.layer_norm_eps = layer_norm_eps
         self.hidden_act = hidden_act
-        self.gradient_checkpointing = gradient_checkpointing


 class CLIPConfig(PretrainedConfig):
......
@@ -338,6 +338,7 @@ class CLIPPreTrainedModel(PreTrainedModel):
     config_class = CLIPConfig
     base_model_prefix = "clip"
+    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
@@ -383,6 +384,10 @@ class CLIPPreTrainedModel(PreTrainedModel):
         if isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, CLIPEncoder):
+            module.gradient_checkpointing = value
+

 CLIP_START_DOCSTRING = r"""
     This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use
@@ -499,6 +504,7 @@ class CLIPEncoder(nn.Module):
         super().__init__()
         self.config = config
         self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False

     def forward(
         self,
@@ -551,7 +557,7 @@ class CLIPEncoder(nn.Module):
         for idx, encoder_layer in enumerate(self.layers):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
......
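
The `create_custom_forward` hunks are truncated in every file shown here; the pattern they belong to is the standard `torch.utils.checkpoint` idiom, sketched below on a toy encoder rather than the actual CLIP code:

```python
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class TinyEncoder(nn.Module):
    def __init__(self, num_layers=2, dim=16):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_layers)])
        self.gradient_checkpointing = False  # flipped by _set_gradient_checkpointing

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    # checkpoint() re-runs this wrapper during the backward
                    # pass instead of storing the layer's activations.
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = checkpoint(create_custom_forward(layer), hidden_states)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states
```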
@@ -248,6 +248,7 @@ class ConvBertPreTrainedModel(PreTrainedModel):
     config_class = ConvBertConfig
     load_tf_weights = load_tf_weights_in_convbert
     base_model_prefix = "convbert"
+    supports_gradient_checkpointing = True
     authorized_missing_keys = [r"position_ids"]
     authorized_unexpected_keys = [r"convbert\.embeddings_project\.weight", r"convbert\.embeddings_project\.bias"]
@@ -267,6 +268,10 @@ class ConvBertPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, ConvBertEncoder):
+            module.gradient_checkpointing = value
+

 class SeparableConv1D(nn.Module):
     """This class implements separable convolution, i.e. a depthwise and a pointwise layer"""
@@ -603,6 +608,7 @@ class ConvBertEncoder(nn.Module):
         super().__init__()
         self.config = config
         self.layer = nn.ModuleList([ConvBertLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False

     def forward(
         self,
@@ -624,7 +630,7 @@ class ConvBertEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
-            if getattr(self.config, "gradient_checkpointing", False):
+            if self.gradient_checkpointing and self.training:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
......
@@ -58,8 +58,6 @@ class DeiTConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
         image_size (:obj:`int`, `optional`, defaults to :obj:`224`):
             The size (resolution) of each image.
         patch_size (:obj:`int`, `optional`, defaults to :obj:`16`):
......
@@ -324,6 +324,7 @@ class DeiTEncoder(nn.Module):
         super().__init__()
         self.config = config
         self.layer = nn.ModuleList([DeiTLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False

     def forward(
         self,
@@ -342,7 +343,7 @@ class DeiTEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -384,6 +385,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
     config_class = DeiTConfig
     base_model_prefix = "deit"
+    supports_gradient_checkpointing = True

     def _init_weights(self, module):
         """Initialize the weights"""
@@ -401,6 +403,10 @@ class DeiTPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, DeiTEncoder):
+            module.gradient_checkpointing = value
+

 DEIT_START_DOCSTRING = r"""
     This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use
......
@@ -783,6 +783,7 @@ class DetrClassificationHead(nn.Module):
 class DetrPreTrainedModel(PreTrainedModel):
     config_class = DetrConfig
     base_model_prefix = "model"
+    supports_gradient_checkpointing = True

     def _init_weights(self, module):
         std = self.config.init_std
@@ -807,6 +808,10 @@ class DetrPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, DetrDecoder):
+            module.gradient_checkpointing = value
+

 DETR_START_DOCSTRING = r"""
     This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
@@ -997,6 +1002,7 @@ class DetrDecoder(DetrPreTrainedModel):
         self.layernorm = nn.LayerNorm(config.d_model)

         self.init_weights()
+        self.gradient_checkpointing = False

     def forward(
         self,
@@ -1084,7 +1090,7 @@ class DetrDecoder(DetrPreTrainedModel):
             if self.training and (dropout_probability < self.layerdrop):
                 continue
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
......
@@ -69,8 +69,6 @@ class DPRConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
         position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
             Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
             :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
@@ -99,7 +97,6 @@ class DPRConfig(PretrainedConfig):
         initializer_range=0.02,
         layer_norm_eps=1e-12,
         pad_token_id=0,
-        gradient_checkpointing=False,
         position_embedding_type="absolute",
         projection_dim: int = 0,
         **kwargs
@@ -118,6 +115,5 @@ class DPRConfig(PretrainedConfig):
         self.type_vocab_size = type_vocab_size
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
-        self.gradient_checkpointing = gradient_checkpointing
         self.projection_dim = projection_dim
         self.position_embedding_type = position_embedding_type
@@ -30,7 +30,7 @@ from ...file_utils import (
 from ...modeling_outputs import BaseModelOutputWithPooling
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
-from ..bert.modeling_bert import BertModel
+from ..bert.modeling_bert import BertEncoder, BertModel
 from .configuration_dpr import DPRConfig
@@ -300,6 +300,10 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
     def init_weights(self):
         self.question_encoder.init_weights()

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, BertEncoder):
+            module.gradient_checkpointing = value
+

 class DPRPretrainedReader(PreTrainedModel):
     """
@@ -317,6 +321,10 @@ class DPRPretrainedReader(PreTrainedModel):
         self.span_predictor.qa_classifier.apply(self.span_predictor.encoder.bert_model._init_weights)
         self.span_predictor.qa_outputs.apply(self.span_predictor.encoder.bert_model._init_weights)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, BertEncoder):
+            module.gradient_checkpointing = value
+

 ###############
 # Actual Models
......
@@ -527,6 +527,7 @@ class ElectraEncoder(nn.Module):
         super().__init__()
         self.config = config
         self.layer = nn.ModuleList([ElectraLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False

     def forward(
         self,
@@ -553,12 +554,11 @@ class ElectraEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:

                 if use_cache:
                     logger.warning(
-                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                        "`use_cache=False`..."
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
@@ -663,6 +663,7 @@ class ElectraPreTrainedModel(PreTrainedModel):
     config_class = ElectraConfig
     load_tf_weights = load_tf_weights_in_electra
     base_model_prefix = "electra"
+    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]
     _keys_to_ignore_on_load_unexpected = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]
@@ -683,6 +684,10 @@ class ElectraPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, ElectraEncoder):
+            module.gradient_checkpointing = value
+

 @dataclass
 class ElectraForPreTrainingOutput(ModelOutput):
......
@@ -64,8 +64,6 @@ class FNetConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
         use_tpu_fourier_optimizations (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Determines whether to use TPU optimized FFTs. If :obj:`True`, the model will favor axis-wise FFTs
             transforms. Set to :obj:`False` for GPU/CPU hardware, in which case n-dimensional FFTs are used.
......
@@ -284,6 +284,7 @@ class FNetEncoder(nn.Module):
         super().__init__()
         self.config = config
         self.layer = nn.ModuleList([FNetLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False

     def forward(self, hidden_states, output_hidden_states=False, return_dict=True):
         all_hidden_states = () if output_hidden_states else None
@@ -292,7 +293,7 @@ class FNetEncoder(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -413,6 +414,7 @@ class FNetPreTrainedModel(PreTrainedModel):
     config_class = FNetConfig
     base_model_prefix = "fnet"
+    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
@@ -432,6 +434,10 @@ class FNetPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, FNetEncoder):
+            module.gradient_checkpointing = value
+

 @dataclass
 class FNetForPreTrainingOutput(ModelOutput):
......
@@ -108,9 +108,7 @@ class GPT2Config(PretrainedConfig):
             The dropout ratio to be used after the projection and activation.
         scale_attn_weights (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Scale attention weights by dividing by sqrt(hidden_size).
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Whether or not the model should return the last key/values attentions (not used by all models).
@@ -158,7 +156,6 @@ class GPT2Config(PretrainedConfig):
         summary_proj_to_labels=True,
         summary_first_dropout=0.1,
         scale_attn_weights=True,
-        gradient_checkpointing=False,
         use_cache=True,
         bos_token_id=50256,
         eos_token_id=50256,
@@ -182,7 +179,6 @@ class GPT2Config(PretrainedConfig):
         self.summary_activation = summary_activation
         self.summary_first_dropout = summary_first_dropout
         self.summary_proj_to_labels = summary_proj_to_labels
-        self.gradient_checkpointing = gradient_checkpointing
         self.scale_attn_weights = scale_attn_weights
         self.use_cache = use_cache
......
@@ -374,6 +374,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_gpt2
     base_model_prefix = "transformer"
     is_parallelizable = True
+    supports_gradient_checkpointing = True

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
@@ -394,6 +395,10 @@ class GPT2PreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, GPT2Model):
+            module.gradient_checkpointing = value
+

 @dataclass
 class GPT2DoubleHeadsModelOutput(ModelOutput):
@@ -589,6 +594,7 @@ class GPT2Model(GPT2PreTrainedModel):
         # Model parallel
         self.model_parallel = False
         self.device_map = None
+        self.gradient_checkpointing = False

     @add_start_docstrings(PARALLELIZE_DOCSTRING)
     def parallelize(self, device_map=None):
@@ -764,12 +770,11 @@ class GPT2Model(GPT2PreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:

                 if use_cache:
                     logger.warning(
-                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                        "`use_cache=False`..."
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
......
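
For the decoder-style models (ELECTRA, GPT-2, GPT Neo, GPT-J below), checkpointing also interacts with the key/value cache: when the flag is on during training, `use_cache` is forced off, because the checkpointed segments are re-run in the backward pass and a cache would short-circuit that recomputation. A tiny stand-alone restatement of the guard used in the hunks above, with hypothetical helper and argument names:

```python
def resolve_use_cache(gradient_checkpointing: bool, training: bool, use_cache: bool) -> bool:
    """Mirror of the warning-and-disable guard inside the checkpointed forward loops."""
    if gradient_checkpointing and training and use_cache:
        print("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...")
        return False
    return use_cache


assert resolve_use_cache(True, True, True) is False   # cache disabled while checkpointing trains
assert resolve_use_cache(False, True, True) is True   # unchanged otherwise
```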
@@ -79,8 +79,6 @@ class GPTNeoConfig(PretrainedConfig):
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Whether or not the model should return the last key/values attentions (not used by all models). Only
             relevant if ``config.is_decoder=True``.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.

     Example::
@@ -120,7 +118,6 @@ class GPTNeoConfig(PretrainedConfig):
         summary_activation=None,
         summary_proj_to_labels=True,
         summary_first_dropout=0.1,
-        gradient_checkpointing=False,
         use_cache=True,
         bos_token_id=50256,
         eos_token_id=50256,
@@ -144,7 +141,6 @@ class GPTNeoConfig(PretrainedConfig):
         self.summary_activation = summary_activation
         self.summary_first_dropout = summary_first_dropout
         self.summary_proj_to_labels = summary_proj_to_labels
-        self.gradient_checkpointing = gradient_checkpointing
         self.use_cache = use_cache
         self.bos_token_id = bos_token_id
......
@@ -361,6 +361,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
     config_class = GPTNeoConfig
     load_tf_weights = load_tf_weights_in_gpt_neo
     base_model_prefix = "transformer"
+    supports_gradient_checkpointing = True

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
@@ -381,6 +382,10 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, GPTNeoModel):
+            module.gradient_checkpointing = value
+

 GPT_NEO_START_DOCSTRING = r"""
@@ -482,6 +487,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

         self.init_weights()
+        self.gradient_checkpointing = False

     def get_input_embeddings(self):
         return self.wte
@@ -592,12 +598,11 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:

                 if use_cache:
                     logger.warning(
-                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                        "`use_cache=False`..."
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
......
@@ -68,8 +68,6 @@ class GPTJConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         scale_attn_weights (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Scale attention weights by dividing by sqrt(hidden_size).
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Whether or not the model should return the last key/values attentions (not used by all models).
@@ -111,7 +109,6 @@ class GPTJConfig(PretrainedConfig):
         layer_norm_epsilon=1e-5,
         initializer_range=0.02,
         scale_attn_weights=True,
-        gradient_checkpointing=False,
         use_cache=True,
         bos_token_id=50256,
         eos_token_id=50256,
@@ -131,7 +128,6 @@ class GPTJConfig(PretrainedConfig):
         self.attn_pdrop = attn_pdrop
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_range = initializer_range
-        self.gradient_checkpointing = gradient_checkpointing
         self.scale_attn_weights = scale_attn_weights
         self.use_cache = use_cache
......
@@ -303,6 +303,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
     config_class = GPTJConfig
     base_model_prefix = "transformer"
     is_parallelizable = True
+    supports_gradient_checkpointing = True

     def __init__(self, *inputs, **kwargs):
         super().__init__(*inputs, **kwargs)
@@ -323,6 +324,10 @@ class GPTJPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, GPTJModel):
+            module.gradient_checkpointing = value
+

 GPTJ_START_DOCSTRING = r"""
     This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use
@@ -445,6 +450,7 @@ class GPTJModel(GPTJPreTrainedModel):
         # Model parallel
         self.model_parallel = False
         self.device_map = None
+        self.gradient_checkpointing = False

     @add_start_docstrings(PARALLELIZE_DOCSTRING)
     def parallelize(self, device_map=None):
@@ -598,12 +604,11 @@ class GPTJModel(GPTJPreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:

                 if use_cache:
                     logger.warning(
-                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                        "`use_cache=False`..."
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
......
@@ -120,8 +120,6 @@ class HubertConfig(PretrainedConfig):
             instance of :class:`~transformers.HubertForSequenceClassification`.
         classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
             Dimensionality of the projection before token mean-pooling for classification.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.

     Example::
@@ -172,7 +170,6 @@ class HubertConfig(PretrainedConfig):
         ctc_zero_infinity=False,
         use_weighted_layer_sum=False,
         classifier_proj_size=256,
-        gradient_checkpointing=False,
         pad_token_id=0,
         bos_token_id=1,
         eos_token_id=2,
@@ -203,7 +200,6 @@ class HubertConfig(PretrainedConfig):
         self.initializer_range = initializer_range
         self.vocab_size = vocab_size
         self.do_stable_layer_norm = do_stable_layer_norm
-        self.gradient_checkpointing = gradient_checkpointing
         self.use_weighted_layer_sum = use_weighted_layer_sum
         self.classifier_proj_size = classifier_proj_size
......
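
The commit message's "Adapt BC to new PoC" step is not shown in this excerpt. One plausible shape for that backward-compatibility shim (an assumption, not the verified `modeling_utils.py` change) is to honor a legacy `gradient_checkpointing=True` left in an old config by enabling the feature once at model-construction time:

```python
def apply_legacy_gradient_checkpointing(model, config):
    # Hypothetical helper: old configs that still carry the removed
    # `gradient_checkpointing` field keep working, since the new
    # `gradient_checkpointing_enable()` entry point is called on their behalf.
    if getattr(config, "gradient_checkpointing", False):
        model.gradient_checkpointing_enable()
```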