Unverified commit 27d46397 authored by Sylvain Gugger, committed by GitHub

Make gradient_checkpointing a training argument (#13657)



* Make gradient_checkpointing a training argument

* Update src/transformers/modeling_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/configuration_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Fix tests

* Style

* document Gradient Checkpointing as a performance feature

* Small rename

* PoC for not using the config

* Adapt BC to new PoC

* Forgot to save

* Rollout changes to all other models

* Fix typo
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
parent 75f6641e
......@@ -46,8 +46,8 @@ Tips:
- LED makes use of *global attention* by means of the ``global_attention_mask`` (see
:class:`~transformers.LongformerModel`). For summarization, it is advised to put *global attention* only on the first
``<s>`` token. For question answering, it is advised to put *global attention* on all tokens of the question.
- To fine-tune LED on all 16384, it is necessary to enable *gradient checkpointing* by setting
``config.gradient_checkpointing = True``.
- To fine-tune LED on all 16384 tokens, it is necessary to enable *gradient checkpointing* by executing
  ``model.gradient_checkpointing_enable()``.
- A notebook showing how to evaluate LED can be accessed `here
  <https://colab.research.google.com/drive/12INTTR6n64TzS4RrXZxMSXfrOd9Xzamo?usp=sharing>`__.
- A notebook showing how to fine-tune LED can be accessed `here
......
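In practice the new call replaces the old config flag with a single line on the model. A minimal sketch, assuming the `allenai/led-base-16384` checkpoint (any LED checkpoint works the same way):

```python
# Minimal sketch: enable gradient checkpointing before fine-tuning LED on long inputs.
# The checkpoint name is only an example; substitute your own model.
from transformers import LEDForConditionalGeneration

model = LEDForConditionalGeneration.from_pretrained("allenai/led-base-16384")
model.gradient_checkpointing_enable()  # activations are recomputed during the backward pass
# ... build your Trainer / training loop as usual ...
```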
......@@ -53,6 +53,7 @@ Software:
- Tensor Parallelism
- Low-memory Optimizers
- fp16/bf16 (smaller data)
- Gradient checkpointing
......@@ -226,6 +227,21 @@ pytorch `autocast` which performs AMP include a caching feature, which speed thi
Autocast maintains a cache of the FP16 casts of model params (leaves). This helps streamline parameter reuse: if the same FP32 param is used in several different FP16 ops, like several matmuls, instead of re-casting the param to FP16 on entering each matmul, the cast will occur on the first matmul, the casted FP16 copy will be cached, and for all later matmuls the FP16 copy will be reused. The cache is maintained only within a particular outermost autocast context. When you exit the autocast context the cache is dropped. For recommended usage, in which autocast wraps the forward pass, and then you exit the context before calling backward(), this means the cache only lasts the duration of the forward pass each iteration, and will be rebuilt next iteration. (The cache of FP16-casted copies MUST be rebuilt each iteration. The FP32 params get updated by the optimizer, so the FP16 copies must be recreated, otherwise the FP16 values will be stale.)
### Gradient Checkpointing
One way to use significantly less GPU memory is to enable "Gradient Checkpointing" (also known as "activation checkpointing"). When enabled, a lot of memory can be freed at the cost of a small decrease in training speed, since parts of the graph are recomputed during back-propagation.
This technique was first shared in the paper [Training Deep Nets with Sublinear Memory Cost](https://arxiv.org/abs/1604.06174). The paper gives the exact details on the savings, but memory usage is in the ballpark of `O(sqrt(n))`, where `n` is the number of feed-forward layers.
To activate this feature in 🤗 Transformers for models that support it, use:
```python
model.gradient_checkpointing_enable()
```
or add `--gradient_checkpointing` to the `Trainer` arguments (equivalently, set `gradient_checkpointing=True` in your `TrainingArguments`), as in the sketch below.
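A minimal sketch of the `Trainer` route, under the assumption that the `Trainer` calls `gradient_checkpointing_enable()` on the model for you when the flag is set:

```python
# Minimal sketch: the flag now lives in TrainingArguments rather than in the model config.
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",             # required; any scratch directory works
    gradient_checkpointing=True,  # same effect as --gradient_checkpointing on the command line
)
print(args.gradient_checkpointing)  # True
```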
### Batch sizes
One gets the most efficient performance when batch sizes and input/output neuron counts are divisible by a certain number, which typically starts at 8, but can be much higher as well. That number varies a lot depending on the specific hardware being used and the dtype of the model.
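As a rough illustration, a hypothetical helper (not part of the library) that rounds a size up to the preferred granularity:

```python
# Hypothetical helper: round a batch size or layer width up to the nearest multiple
# of the granularity your hardware/dtype prefers (8 is a common starting point).
def round_up(value: int, multiple: int = 8) -> int:
    return ((value + multiple - 1) // multiple) * multiple

print(round_up(50))  # 56
print(round_up(64))  # 64
```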
......
......@@ -174,8 +174,3 @@ python run_clm.py --model_type gpt2 --tokenizer_name gpt2 \ --config_overrides="
```
This feature is only available in `run_clm.py`, `run_plm.py` and `run_mlm.py`.
This feature can also be used to activate gradient checkpointing by passing:
```
--config_overrides "gradient_checkpointing=true,use_cache=False"
```
......@@ -19,6 +19,7 @@
import copy
import json
import os
import warnings
from typing import Any, Dict, Tuple, Union
from . import __version__
......@@ -330,6 +331,14 @@ class PretrainedConfig(PushToHubMixin):
# Drop the transformers version info
self.transformers_version = kwargs.pop("transformers_version", None)
# Deal with gradient checkpointing
if "gradient_checkpointing" in kwargs:
warnings.warn(
"Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
"Transformers. Using `model.gradient_checkpointing_enable()` instead, or if you are using the "
"`Trainer` API, pass `gradient_checkpointing=True` in your `TrainingArguments`."
)
# Additional attributes without default values
for key, value in kwargs.items():
try:
......
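The backward-compatibility path above keeps old code working while new code moves to the model-level API. A minimal sketch with BERT (the behaviour should be the same for any architecture that supports the feature):

```python
# Deprecated path: the config kwarg still works but triggers the warning above, and the
# attribute is consumed by the model constructor instead of being kept in the config.
from transformers import BertConfig, BertModel

model = BertModel(BertConfig(gradient_checkpointing=True))  # emits a deprecation warning

# Preferred path going forward: enable it on the model itself.
model = BertModel(BertConfig())
model.gradient_checkpointing_enable()
```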
......@@ -20,6 +20,7 @@ import re
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
import torch
......@@ -450,6 +451,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
_keys_to_ignore_on_save = None
is_parallelizable = False
supports_gradient_checkpointing = False
@property
def dummy_inputs(self) -> Dict[str, torch.Tensor]:
......@@ -469,6 +471,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Save config and origin of the pretrained weights if given in model
self.config = config
self.name_or_path = config.name_or_path
if getattr(self.config, "gradient_checkpointing", False):
self.gradient_checkpointing_enable()
# Remove the attribute now that it has been consumed, so it's not saved in the config.
delattr(self.config, "gradient_checkpointing")
@classmethod
def _from_config(cls, config, **kwargs):
......@@ -932,6 +938,27 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
self.base_model._prune_heads(heads_to_prune)
def gradient_checkpointing_enable(self):
"""
Activates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
"""
if not self.supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
self.apply(partial(self._set_gradient_checkpointing, value=True))
def gradient_checkpointing_disable(self):
"""
Deactivates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
"""
if self.supports_gradient_checkpointing:
self.apply(partial(self._set_gradient_checkpointing, value=False))
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
......
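The enable/disable methods rely on `nn.Module.apply` plus `functools.partial` to push the flag down to the submodules that own it. A toy illustration of that mechanism (not the library code itself):

```python
# Toy illustration of the propagation used by gradient_checkpointing_enable():
# `nn.Module.apply` visits every submodule, and `functools.partial` pins `value`,
# so `_set_gradient_checkpointing` can flip the flag on the modules that define it.
from functools import partial

import torch.nn as nn


class ToyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)
        self.gradient_checkpointing = False  # the flag a forward pass would check


class ToyModel(nn.Module):
    supports_gradient_checkpointing = True

    def __init__(self):
        super().__init__()
        self.encoder = ToyEncoder()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, ToyEncoder):
            module.gradient_checkpointing = value

    def gradient_checkpointing_enable(self):
        self.apply(partial(self._set_gradient_checkpointing, value=True))


model = ToyModel()
model.gradient_checkpointing_enable()
print(model.encoder.gradient_checkpointing)  # True
```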
......@@ -82,8 +82,6 @@ class BartConfig(PretrainedConfig):
decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the decoder. See the `LayerDrop paper
<https://arxiv.org/abs/1909.11556>`__ for more details.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
Scale embeddings by dividing by sqrt(d_model).
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
......@@ -131,7 +129,6 @@ class BartConfig(PretrainedConfig):
init_std=0.02,
classifier_dropout=0.0,
scale_embedding=False,
gradient_checkpointing=False,
use_cache=True,
num_labels=3,
pad_token_id=1,
......@@ -161,7 +158,6 @@ class BartConfig(PretrainedConfig):
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
super().__init__(
......
......@@ -471,6 +471,7 @@ class BartClassificationHead(nn.Module):
class BartPretrainedModel(PreTrainedModel):
config_class = BartConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]
def _init_weights(self, module):
......@@ -484,6 +485,10 @@ class BartPretrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (BartDecoder, BartEncoder)):
module.gradient_checkpointing = value
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
......@@ -687,6 +692,7 @@ class BartEncoder(BartPretrainedModel):
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.init_weights()
self.gradient_checkpointing = False
def forward(
self,
......@@ -782,7 +788,7 @@ class BartEncoder(BartPretrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -849,6 +855,7 @@ class BartDecoder(BartPretrainedModel):
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.init_weights()
self.gradient_checkpointing = False
def get_input_embeddings(self):
return self.embed_tokens
......@@ -1020,12 +1027,11 @@ class BartDecoder(BartPretrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
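A minimal end-to-end sketch with BART, assuming the `facebook/bart-base` checkpoint; `use_cache` is turned off explicitly, matching the warning the decoder emits while checkpointing is active:

```python
# Minimal sketch: gradient checkpointing on a seq2seq model during training.
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
model.gradient_checkpointing_enable()
model.train()

batch = tokenizer(["a very long input document ..."], return_tensors="pt")
labels = tokenizer(["a short summary"], return_tensors="pt").input_ids
# Caching is incompatible with checkpointing, so disable it for the training forward pass.
loss = model(**batch, labels=labels, use_cache=False).loss
loss.backward()
```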
......@@ -57,8 +57,6 @@ class BeitConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
image_size (:obj:`int`, `optional`, defaults to :obj:`224`):
The size (resolution) of each image.
patch_size (:obj:`int`, `optional`, defaults to :obj:`16`):
......
......@@ -432,6 +432,7 @@ class BeitEncoder(nn.Module):
for i in range(config.num_hidden_layers)
]
)
self.gradient_checkpointing = False
def forward(
self,
......@@ -450,7 +451,7 @@ class BeitEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -494,6 +495,7 @@ class BeitPreTrainedModel(PreTrainedModel):
config_class = BeitConfig
base_model_prefix = "beit"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize the weights"""
......@@ -511,6 +513,10 @@ class BeitPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BeitEncoder):
module.gradient_checkpointing = value
BEIT_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use
......
......@@ -92,8 +92,6 @@ class BertConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
......@@ -137,7 +135,6 @@ class BertConfig(PretrainedConfig):
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
position_embedding_type="absolute",
use_cache=True,
classifier_dropout=None,
......@@ -157,7 +154,6 @@ class BertConfig(PretrainedConfig):
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
self.classifier_dropout = classifier_dropout
......
......@@ -529,6 +529,7 @@ class BertEncoder(nn.Module):
super().__init__()
self.config = config
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
......@@ -555,12 +556,11 @@ class BertEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......@@ -714,6 +714,7 @@ class BertPreTrainedModel(PreTrainedModel):
config_class = BertConfig
load_tf_weights = load_tf_weights_in_bert
base_model_prefix = "bert"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
......@@ -732,6 +733,10 @@ class BertPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BertEncoder):
module.gradient_checkpointing = value
@dataclass
class BertForPreTrainingOutput(ModelOutput):
......
......@@ -52,8 +52,6 @@ class BertGenerationConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
......@@ -96,7 +94,6 @@ class BertGenerationConfig(PretrainedConfig):
pad_token_id=0,
bos_token_id=2,
eos_token_id=1,
gradient_checkpointing=False,
position_embedding_type="absolute",
use_cache=True,
**kwargs
......@@ -114,6 +111,5 @@ class BertGenerationConfig(PretrainedConfig):
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
......@@ -82,8 +82,6 @@ class BigBirdConfig(PretrainedConfig):
num_random_blocks (:obj:`int`, `optional`, defaults to 3):
Each query is going to attend this many random blocks. Useful only when :obj:`attention_type ==
"block_sparse"`.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
classifier_dropout (:obj:`float`, `optional`):
The dropout ratio for the classification head.
......@@ -127,7 +125,6 @@ class BigBirdConfig(PretrainedConfig):
rescale_embeddings=False,
block_size=64,
num_random_blocks=3,
gradient_checkpointing=False,
classifier_dropout=None,
**kwargs
):
......@@ -153,7 +150,6 @@ class BigBirdConfig(PretrainedConfig):
self.layer_norm_eps = layer_norm_eps
self.use_cache = use_cache
self.is_encoder_decoder = is_encoder_decoder
self.gradient_checkpointing = gradient_checkpointing
self.rescale_embeddings = rescale_embeddings
self.attention_type = attention_type
......
......@@ -1555,6 +1555,7 @@ class BigBirdEncoder(nn.Module):
self.layer = nn.ModuleList(
[BigBirdLayer(config, seed=layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
def set_attention_type(self, value: str):
if value not in ["original_full", "block_sparse"]:
......@@ -1598,12 +1599,11 @@ class BigBirdEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......@@ -1756,6 +1756,7 @@ class BigBirdPreTrainedModel(PreTrainedModel):
config_class = BigBirdConfig
load_tf_weights = load_tf_weights_in_big_bird
base_model_prefix = "bert"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
......@@ -1774,6 +1775,10 @@ class BigBirdPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, BigBirdEncoder):
module.gradient_checkpointing = value
BIG_BIRD_START_DOCSTRING = r"""
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use
......
......@@ -94,8 +94,6 @@ class BigBirdPegasusConfig(PretrainedConfig):
"block_sparse"`.
scale_embeddings (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to rescale embeddings with (hidden_size ** 0.5).
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
Example::
......@@ -141,7 +139,6 @@ class BigBirdPegasusConfig(PretrainedConfig):
decoder_start_token_id=2,
classifier_dropout=0.0,
scale_embedding=True,
gradient_checkpointing=False,
pad_token_id=0,
bos_token_id=2,
eos_token_id=1,
......@@ -170,7 +167,6 @@ class BigBirdPegasusConfig(PretrainedConfig):
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
# extra config
......
......@@ -1567,6 +1567,7 @@ class BigBirdPegasusClassificationHead(nn.Module):
class BigBirdPegasusPreTrainedModel(PreTrainedModel):
config_class = BigBirdPegasusConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
def _init_weights(self, module):
std = self.config.init_std
......@@ -1579,6 +1580,10 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (BigBirdPegasusDecoder, BigBirdPegasusEncoder)):
module.gradient_checkpointing = value
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
......@@ -1764,6 +1769,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.init_weights()
self.gradient_checkpointing = False
def forward(
self,
......@@ -1894,7 +1900,7 @@ class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -2054,6 +2060,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.init_weights()
self.gradient_checkpointing = False
def get_input_embeddings(self):
return self.embed_tokens
......@@ -2225,12 +2232,11 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -78,8 +78,6 @@ class BlenderbotConfig(PretrainedConfig):
decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the decoder. See the `LayerDrop paper
<https://arxiv.org/abs/1909.11556>`__ for more details.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
Scale embeddings by dividing by sqrt(d_model).
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
......@@ -128,7 +126,6 @@ class BlenderbotConfig(PretrainedConfig):
decoder_start_token_id=1,
classifier_dropout=0.0,
scale_embedding=False,
gradient_checkpointing=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
......@@ -155,7 +152,6 @@ class BlenderbotConfig(PretrainedConfig):
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
super().__init__(
......
......@@ -451,6 +451,7 @@ class BlenderbotDecoderLayer(nn.Module):
class BlenderbotPreTrainedModel(PreTrainedModel):
config_class = BlenderbotConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
def _init_weights(self, module):
std = self.config.init_std
......@@ -463,6 +464,10 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (BlenderbotDecoder, BlenderbotEncoder)):
module.gradient_checkpointing = value
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
......@@ -644,6 +649,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
self.layer_norm = nn.LayerNorm(config.d_model)
self.init_weights()
self.gradient_checkpointing = False
def forward(
self,
......@@ -738,7 +744,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -808,6 +814,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
self.layer_norm = nn.LayerNorm(config.d_model)
self.init_weights()
self.gradient_checkpointing = False
def get_input_embeddings(self):
return self.embed_tokens
......@@ -980,12 +987,11 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -78,8 +78,6 @@ class BlenderbotSmallConfig(PretrainedConfig):
decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the decoder. See the `LayerDrop paper
<https://arxiv.org/abs/1909.11556>`__ for more details.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
Scale embeddings by dividing by sqrt(d_model).
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
......@@ -128,7 +126,6 @@ class BlenderbotSmallConfig(PretrainedConfig):
decoder_start_token_id=1,
classifier_dropout=0.0,
scale_embedding=False,
gradient_checkpointing=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
......@@ -154,7 +151,6 @@ class BlenderbotSmallConfig(PretrainedConfig):
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
super().__init__(
......
......@@ -449,6 +449,7 @@ class BlenderbotSmallDecoderLayer(nn.Module):
class BlenderbotSmallPreTrainedModel(PreTrainedModel):
config_class = BlenderbotSmallConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
def _init_weights(self, module):
std = self.config.init_std
......@@ -461,6 +462,10 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (BlenderbotSmallDecoder, BlenderbotSmallEncoder)):
module.gradient_checkpointing = value
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
......@@ -645,6 +650,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.init_weights()
self.gradient_checkpointing = False
def forward(
self,
......@@ -740,7 +746,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -808,6 +814,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.init_weights()
self.gradient_checkpointing = False
def get_input_embeddings(self):
return self.embed_tokens
......@@ -981,12 +988,11 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......