Unverified Commit 90f4b245 authored by Iz Beltagy, committed by GitHub

Add support for gradient checkpointing in BERT (#4659)



* add support for gradient checkpointing in BERT

* fix unit tests

* isort

* black

* workaround for `torch.utils.checkpoint.checkpoint` not accepting bool

* Revert "workaround for `torch.utils.checkpoint.checkpoint` not accepting bool"

This reverts commit 5eb68bb804f5ffbfc7ba13c45a47717f72d04574.

* workaround for `torch.utils.checkpoint.checkpoint` not accepting bool
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
parent f4e1f022
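
For reference, a minimal usage sketch of the flag this commit adds. The `BertConfig`/`BertModel` names are the existing transformers classes; the only new piece is the `gradient_checkpointing` argument, and the sketch assumes the library as of this commit:

```python
# Minimal sketch: enable the new flag on a BERT model.
from transformers import BertConfig, BertModel

config = BertConfig(gradient_checkpointing=True)  # flag introduced by this commit
model = BertModel(config)

# During training this trades compute for memory: activations inside each
# BertLayer are recomputed in the backward pass instead of being stored.
```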
@@ -90,6 +90,8 @@ class BertConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
+        gradient_checkpointing (:obj:`bool`, optional, defaults to False):
+            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
 
     Example::
@@ -121,6 +123,7 @@ class BertConfig(PretrainedConfig):
         initializer_range=0.02,
         layer_norm_eps=1e-12,
         pad_token_id=0,
+        gradient_checkpointing=False,
         **kwargs
     ):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -137,3 +140,4 @@ class BertConfig(PretrainedConfig):
         self.type_vocab_size = type_vocab_size
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
+        self.gradient_checkpointing = gradient_checkpointing
@@ -22,6 +22,7 @@ import os
 import warnings
 
 import torch
+import torch.utils.checkpoint
 from torch import nn
 from torch.nn import CrossEntropyLoss, MSELoss
@@ -391,6 +392,7 @@ class BertLayer(nn.Module):
 class BertEncoder(nn.Module):
     def __init__(self, config):
         super().__init__()
+        self.config = config
         self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
 
     def forward(
@@ -409,14 +411,31 @@ class BertEncoder(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            layer_outputs = layer_module(
-                hidden_states,
-                attention_mask,
-                head_mask[i],
-                encoder_hidden_states,
-                encoder_attention_mask,
-                output_attentions,
-            )
+            if getattr(self.config, "gradient_checkpointing", False):
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask,
+                    head_mask[i],
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    output_attentions,
+                )
             hidden_states = layer_outputs[0]
             if output_attentions:
...
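
The `create_custom_forward` closure above is the workaround named in the commit message: `torch.utils.checkpoint.checkpoint` forwards its positional arguments into the checkpointed function, and passing the `output_attentions` bool that way caused problems, so the bool is captured in a closure and only tensors are handed to `checkpoint`. Below is a standalone sketch of the same pattern; `ToyLayer`, the flag, and the shapes are made up for illustration and are not part of the diff:

```python
import torch
import torch.utils.checkpoint
from torch import nn


# Hypothetical stand-in for BertLayer: takes tensor input plus a bool flag.
class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, hidden_states, flag):
        out = self.linear(hidden_states)
        return torch.relu(out) if flag else out


layer = ToyLayer()
flag = True  # plays the role of `output_attentions`


# Capture the non-tensor argument in a closure so that only tensors are
# passed through torch.utils.checkpoint.checkpoint.
def create_custom_forward(module):
    def custom_forward(*tensor_inputs):
        return module(*tensor_inputs, flag)

    return custom_forward


x = torch.randn(2, 8, requires_grad=True)
out = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x)
out.sum().backward()  # activations inside ToyLayer are recomputed here
```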