Unverified Commit 27d46397 authored by Sylvain Gugger, committed by GitHub

Make gradient_checkpointing a training argument (#13657)



* Make gradient_checkpointing a training argument

* Update src/transformers/modeling_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/configuration_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Fix tests

* Style

* document Gradient Checkpointing as a performance feature

* Small rename

* PoC for not using the config

* Adapt BC to new PoC

* Forgot to save

* Rollout changes to all other models

* Fix typo
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
parent 75f6641e
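
For readers landing on this commit, here is a minimal usage sketch of the API the diff below wires up. It assumes the gradient_checkpointing_enable()/gradient_checkpointing_disable() methods and the TrainingArguments.gradient_checkpointing flag introduced alongside these per-model changes; the checkpoint name is just an example.

    from transformers import AutoModelForMaskedLM, TrainingArguments

    # Option 1: flip the switch directly on the model instead of through the config.
    model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
    model.gradient_checkpointing_enable()   # sets encoder.gradient_checkpointing = True
    model.gradient_checkpointing_disable()  # and turns it back off

    # Option 2: let the Trainer handle it through the new training argument.
    args = TrainingArguments(output_dir="out", gradient_checkpointing=True)
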
......@@ -61,8 +61,6 @@ class CanineConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
downsampling_rate (:obj:`int`, `optional`, defaults to 4):
The rate at which to downsample the original character sequence length before applying the deep Transformer
encoder.
......
......@@ -772,6 +772,7 @@ class CanineEncoder(nn.Module):
for _ in range(config.num_hidden_layers)
]
)
self.gradient_checkpointing = False
def forward(
self,
......@@ -791,7 +792,7 @@ class CanineEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -895,6 +896,7 @@ class CaninePreTrainedModel(PreTrainedModel):
config_class = CanineConfig
load_tf_weights = load_tf_weights_in_canine
base_model_prefix = "canine"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
......@@ -913,6 +915,10 @@ class CaninePreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, CanineEncoder):
module.gradient_checkpointing = value
CANINE_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use
......
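
The supports_gradient_checkpointing flag and _set_gradient_checkpointing hook added above only do something because of the base-class plumbing this PR adds in src/transformers/modeling_utils.py. Below is a simplified sketch of that side, under the assumption that it walks submodules with nn.Module.apply; it is not the verbatim implementation.

    from functools import partial
    from torch import nn

    class PreTrainedModelSketch(nn.Module):
        # Models that opt in override this, as the diff above does for CANINE.
        supports_gradient_checkpointing = False

        def _set_gradient_checkpointing(self, module, value=False):
            # Per-model hook: flips the flag on the module(s) that call torch.utils.checkpoint.
            pass

        def gradient_checkpointing_enable(self):
            if not self.supports_gradient_checkpointing:
                raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
            # Visit every submodule and let the model-specific hook decide which ones to toggle.
            self.apply(partial(self._set_gradient_checkpointing, value=True))

        def gradient_checkpointing_disable(self):
            if self.supports_gradient_checkpointing:
                self.apply(partial(self._set_gradient_checkpointing, value=False))
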
......@@ -68,8 +68,6 @@ class CLIPTextConfig(PretrainedConfig):
initializer_factor (:obj:`float`, `optional`, defaults to 1):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
Example::
......@@ -103,7 +101,6 @@ class CLIPTextConfig(PretrainedConfig):
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
gradient_checkpointing=False,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
......@@ -120,7 +117,6 @@ class CLIPTextConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.initializer_factor = initializer_factor
self.attention_dropout = attention_dropout
self.gradient_checkpointing = gradient_checkpointing
class CLIPVisionConfig(PretrainedConfig):
......@@ -161,8 +157,6 @@ class CLIPVisionConfig(PretrainedConfig):
initializer_factor (:obj:`float`, `optional`, defaults to 1):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
Example::
......@@ -194,7 +188,6 @@ class CLIPVisionConfig(PretrainedConfig):
attention_dropout=0.0,
initializer_range=0.02,
initializer_factor=1.0,
gradient_checkpointing=False,
**kwargs
):
super().__init__(**kwargs)
......@@ -211,7 +204,6 @@ class CLIPVisionConfig(PretrainedConfig):
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.gradient_checkpointing = gradient_checkpointing
class CLIPConfig(PretrainedConfig):
......
......@@ -338,6 +338,7 @@ class CLIPPreTrainedModel(PreTrainedModel):
config_class = CLIPConfig
base_model_prefix = "clip"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
......@@ -383,6 +384,10 @@ class CLIPPreTrainedModel(PreTrainedModel):
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, CLIPEncoder):
module.gradient_checkpointing = value
CLIP_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use
......@@ -499,6 +504,7 @@ class CLIPEncoder(nn.Module):
super().__init__()
self.config = config
self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
......@@ -551,7 +557,7 @@ class CLIPEncoder(nn.Module):
for idx, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......
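
What the truncated create_custom_forward blocks in these hunks boil down to: when the module-level flag is on and the model is in training mode, each layer call is routed through torch.utils.checkpoint, so activations are recomputed during the backward pass instead of being stored. A self-contained toy version of the pattern (illustrative only, not CLIP's actual layers):

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    class ToyEncoder(nn.Module):
        def __init__(self, hidden_size=32, num_layers=4):
            super().__init__()
            self.layer = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
            self.gradient_checkpointing = False  # what _set_gradient_checkpointing toggles

        def forward(self, hidden_states):
            for layer_module in self.layer:
                if self.gradient_checkpointing and self.training:
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)
                        return custom_forward

                    # This layer's activations are recomputed during backward.
                    hidden_states = checkpoint(create_custom_forward(layer_module), hidden_states)
                else:
                    hidden_states = layer_module(hidden_states)
            return hidden_states

    encoder = ToyEncoder()
    encoder.gradient_checkpointing = True
    encoder.train()
    x = torch.randn(8, 32, requires_grad=True)
    encoder(x).sum().backward()
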
......@@ -248,6 +248,7 @@ class ConvBertPreTrainedModel(PreTrainedModel):
config_class = ConvBertConfig
load_tf_weights = load_tf_weights_in_convbert
base_model_prefix = "convbert"
supports_gradient_checkpointing = True
authorized_missing_keys = [r"position_ids"]
authorized_unexpected_keys = [r"convbert\.embeddings_project\.weight", r"convbert\.embeddings_project\.bias"]
......@@ -267,6 +268,10 @@ class ConvBertPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, ConvBertEncoder):
module.gradient_checkpointing = value
class SeparableConv1D(nn.Module):
"""This class implements separable convolution, i.e. a depthwise and a pointwise layer"""
......@@ -603,6 +608,7 @@ class ConvBertEncoder(nn.Module):
super().__init__()
self.config = config
self.layer = nn.ModuleList([ConvBertLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
......@@ -624,7 +630,7 @@ class ConvBertEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -58,8 +58,6 @@ class DeiTConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
image_size (:obj:`int`, `optional`, defaults to :obj:`224`):
The size (resolution) of each image.
patch_size (:obj:`int`, `optional`, defaults to :obj:`16`):
......
......@@ -324,6 +324,7 @@ class DeiTEncoder(nn.Module):
super().__init__()
self.config = config
self.layer = nn.ModuleList([DeiTLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
......@@ -342,7 +343,7 @@ class DeiTEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -384,6 +385,7 @@ class DeiTPreTrainedModel(PreTrainedModel):
config_class = DeiTConfig
base_model_prefix = "deit"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -401,6 +403,10 @@ class DeiTPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, DeiTEncoder):
module.gradient_checkpointing = value
DEIT_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use
......
......@@ -783,6 +783,7 @@ class DetrClassificationHead(nn.Module):
class DetrPreTrainedModel(PreTrainedModel):
config_class = DetrConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
def _init_weights(self, module):
std = self.config.init_std
......@@ -807,6 +808,10 @@ class DetrPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, DetrDecoder):
module.gradient_checkpointing = value
DETR_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
......@@ -997,6 +1002,7 @@ class DetrDecoder(DetrPreTrainedModel):
self.layernorm = nn.LayerNorm(config.d_model)
self.init_weights()
self.gradient_checkpointing = False
def forward(
self,
......@@ -1084,7 +1090,7 @@ class DetrDecoder(DetrPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop):
continue
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -69,8 +69,6 @@ class DPRConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
......@@ -99,7 +97,6 @@ class DPRConfig(PretrainedConfig):
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
position_embedding_type="absolute",
projection_dim: int = 0,
**kwargs
......@@ -118,6 +115,5 @@ class DPRConfig(PretrainedConfig):
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.projection_dim = projection_dim
self.position_embedding_type = position_embedding_type
......@@ -30,7 +30,7 @@ from ...file_utils import (
from ...modeling_outputs import BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...utils import logging
from ..bert.modeling_bert import BertModel
from ..bert.modeling_bert import BertEncoder, BertModel
from .configuration_dpr import DPRConfig
......@@ -300,6 +300,10 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
def init_weights(self):
self.question_encoder.init_weights()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BertEncoder):
module.gradient_checkpointing = value
class DPRPretrainedReader(PreTrainedModel):
"""
......@@ -317,6 +321,10 @@ class DPRPretrainedReader(PreTrainedModel):
self.span_predictor.qa_classifier.apply(self.span_predictor.encoder.bert_model._init_weights)
self.span_predictor.qa_outputs.apply(self.span_predictor.encoder.bert_model._init_weights)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BertEncoder):
module.gradient_checkpointing = value
###############
# Actual Models
......
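
On backward compatibility (the "PoC for not using the config" and "Adapt BC to new PoC" bullets above): configs saved before this change may still carry a gradient_checkpointing attribute, and the intent is that such checkpoints keep working by translating the stale attribute into a call to the new API. A hypothetical sketch of that translation; the function name and placement are illustrative, not the actual code.

    def apply_legacy_gradient_checkpointing(model, config):
        # Old-style configs stored the flag; new-style models keep it on the modules instead.
        if getattr(config, "gradient_checkpointing", False):
            if model.supports_gradient_checkpointing:
                model.gradient_checkpointing_enable()
            # Drop the stale attribute so it is not written back out with the config.
            delattr(config, "gradient_checkpointing")
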
......@@ -527,6 +527,7 @@ class ElectraEncoder(nn.Module):
super().__init__()
self.config = config
self.layer = nn.ModuleList([ElectraLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
......@@ -553,12 +554,11 @@ class ElectraEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......@@ -663,6 +663,7 @@ class ElectraPreTrainedModel(PreTrainedModel):
config_class = ElectraConfig
load_tf_weights = load_tf_weights_in_electra
base_model_prefix = "electra"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
_keys_to_ignore_on_load_unexpected = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]
......@@ -683,6 +684,10 @@ class ElectraPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, ElectraEncoder):
module.gradient_checkpointing = value
@dataclass
class ElectraForPreTrainingOutput(ModelOutput):
......
......@@ -64,8 +64,6 @@ class FNetConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
use_tpu_fourier_optimizations (:obj:`bool`, `optional`, defaults to :obj:`False`):
Determines whether to use TPU optimized FFTs. If :obj:`True`, the model will favor axis-wise FFTs
transforms. Set to :obj:`False` for GPU/CPU hardware, in which case n-dimensional FFTs are used.
......
......@@ -284,6 +284,7 @@ class FNetEncoder(nn.Module):
super().__init__()
self.config = config
self.layer = nn.ModuleList([FNetLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(self, hidden_states, output_hidden_states=False, return_dict=True):
all_hidden_states = () if output_hidden_states else None
......@@ -292,7 +293,7 @@ class FNetEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -413,6 +414,7 @@ class FNetPreTrainedModel(PreTrainedModel):
config_class = FNetConfig
base_model_prefix = "fnet"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
......@@ -432,6 +434,10 @@ class FNetPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, FNetEncoder):
module.gradient_checkpointing = value
@dataclass
class FNetForPreTrainingOutput(ModelOutput):
......
......@@ -108,9 +108,7 @@ class GPT2Config(PretrainedConfig):
The dropout ratio to be used after the projection and activation.
scale_attn_weights (:obj:`bool`, `optional`, defaults to :obj:`True`):
Scale attention weights by dividing by sqrt(hidden_size).
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.
Scale attention weights by dividing by sqrt(hidden_size)..
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models).
......@@ -158,7 +156,6 @@ class GPT2Config(PretrainedConfig):
summary_proj_to_labels=True,
summary_first_dropout=0.1,
scale_attn_weights=True,
gradient_checkpointing=False,
use_cache=True,
bos_token_id=50256,
eos_token_id=50256,
......@@ -182,7 +179,6 @@ class GPT2Config(PretrainedConfig):
self.summary_activation = summary_activation
self.summary_first_dropout = summary_first_dropout
self.summary_proj_to_labels = summary_proj_to_labels
self.gradient_checkpointing = gradient_checkpointing
self.scale_attn_weights = scale_attn_weights
self.use_cache = use_cache
......
......@@ -374,6 +374,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
load_tf_weights = load_tf_weights_in_gpt2
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
......@@ -394,6 +395,10 @@ class GPT2PreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, GPT2Model):
module.gradient_checkpointing = value
@dataclass
class GPT2DoubleHeadsModelOutput(ModelOutput):
......@@ -589,6 +594,7 @@ class GPT2Model(GPT2PreTrainedModel):
# Model parallel
self.model_parallel = False
self.device_map = None
self.gradient_checkpointing = False
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
......@@ -764,12 +770,11 @@ class GPT2Model(GPT2PreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
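
The warning rewritten in the hunk above also reflects a behavioral point: with checkpointing enabled and the model in training mode, use_cache is forced off, so no past key/values are returned. A quick check of that interaction, using a small randomly initialized GPT-2 purely for illustration:

    import torch
    from transformers import GPT2Config, GPT2Model

    model = GPT2Model(GPT2Config(n_layer=2, n_embd=64, n_head=2, vocab_size=100))
    model.gradient_checkpointing_enable()
    model.train()

    outputs = model(torch.randint(0, 100, (1, 8)), use_cache=True)
    print(outputs.past_key_values)  # expected None: use_cache was overridden to False
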
......@@ -79,8 +79,6 @@ class GPTNeoConfig(PretrainedConfig):
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if ``config.is_decoder=True``.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
Example::
......@@ -120,7 +118,6 @@ class GPTNeoConfig(PretrainedConfig):
summary_activation=None,
summary_proj_to_labels=True,
summary_first_dropout=0.1,
gradient_checkpointing=False,
use_cache=True,
bos_token_id=50256,
eos_token_id=50256,
......@@ -144,7 +141,6 @@ class GPTNeoConfig(PretrainedConfig):
self.summary_activation = summary_activation
self.summary_first_dropout = summary_first_dropout
self.summary_proj_to_labels = summary_proj_to_labels
self.gradient_checkpointing = gradient_checkpointing
self.use_cache = use_cache
self.bos_token_id = bos_token_id
......
......@@ -361,6 +361,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
config_class = GPTNeoConfig
load_tf_weights = load_tf_weights_in_gpt_neo
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
......@@ -381,6 +382,10 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, GPTNeoModel):
module.gradient_checkpointing = value
GPT_NEO_START_DOCSTRING = r"""
......@@ -482,6 +487,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.init_weights()
self.gradient_checkpointing = False
def get_input_embeddings(self):
return self.wte
......@@ -592,12 +598,11 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -68,8 +68,6 @@ class GPTJConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
scale_attn_weights (:obj:`bool`, `optional`, defaults to :obj:`True`):
Scale attention weights by dividing by sqrt(hidden_size).
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models).
......@@ -111,7 +109,6 @@ class GPTJConfig(PretrainedConfig):
layer_norm_epsilon=1e-5,
initializer_range=0.02,
scale_attn_weights=True,
gradient_checkpointing=False,
use_cache=True,
bos_token_id=50256,
eos_token_id=50256,
......@@ -131,7 +128,6 @@ class GPTJConfig(PretrainedConfig):
self.attn_pdrop = attn_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.gradient_checkpointing = gradient_checkpointing
self.scale_attn_weights = scale_attn_weights
self.use_cache = use_cache
......
......@@ -303,6 +303,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
config_class = GPTJConfig
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
......@@ -323,6 +324,10 @@ class GPTJPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, GPTJModel):
module.gradient_checkpointing = value
GPTJ_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use
......@@ -445,6 +450,7 @@ class GPTJModel(GPTJPreTrainedModel):
# Model parallel
self.model_parallel = False
self.device_map = None
self.gradient_checkpointing = False
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
......@@ -598,12 +604,11 @@ class GPTJModel(GPTJPreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -120,8 +120,6 @@ class HubertConfig(PretrainedConfig):
instance of :class:`~transformers.HubertForSequenceClassification`.
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
Dimensionality of the projection before token mean-pooling for classification.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
Example::
......@@ -172,7 +170,6 @@ class HubertConfig(PretrainedConfig):
ctc_zero_infinity=False,
use_weighted_layer_sum=False,
classifier_proj_size=256,
gradient_checkpointing=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
......@@ -203,7 +200,6 @@ class HubertConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.gradient_checkpointing = gradient_checkpointing
self.use_weighted_layer_sum = use_weighted_layer_sum
self.classifier_proj_size = classifier_proj_size
......