Unverified commit 90f4b245, authored by Iz Beltagy and committed by GitHub
Browse files

Add support for gradient checkpointing in BERT (#4659)



* add support for gradient checkpointing in BERT

* fix unit tests

* isort

* black

* workaround for `torch.utils.checkpoint.checkpoint` not accepting bool

* Revert "workaround for `torch.utils.checkpoint.checkpoint` not accepting bool"

This reverts commit 5eb68bb804f5ffbfc7ba13c45a47717f72d04574.

* workaround for `torch.utils.checkpoint.checkpoint` not accepting bool
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent f4e1f022
......@@ -90,6 +90,8 @@ class BertConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, optional, defaults to :obj:`False`):
If :obj:`True`, use gradient checkpointing to save memory at the expense of a slower backward pass.
Example::
......@@ -121,6 +123,7 @@ class BertConfig(PretrainedConfig):
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
**kwargs
):
super().__init__(pad_token_id=pad_token_id, **kwargs)
......@@ -137,3 +140,4 @@ class BertConfig(PretrainedConfig):
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
......@@ -22,6 +22,7 @@ import os
import warnings
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
......@@ -391,6 +392,7 @@ class BertLayer(nn.Module):
class BertEncoder(nn.Module):
def __init__(self, config):
    """Build the stack of Transformer encoder layers.

    Args:
        config: model configuration; ``config.num_hidden_layers`` sets the
            stack depth, and the whole config is kept on ``self`` so
            ``forward`` can check ``config.gradient_checkpointing``.
    """
    super().__init__()
    # Retained so forward() can consult runtime flags on the config.
    self.config = config
    layers = [BertLayer(config) for _ in range(config.num_hidden_layers)]
    self.layer = nn.ModuleList(layers)
def forward(
......@@ -409,6 +411,23 @@ class BertEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if getattr(self.config, "gradient_checkpointing", False):
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.