Unverified Commit 3d869726 authored by Zehan Li, committed by GitHub

add gradient checkpointing for distilbert (#24719)

* add gradient checkpointing for distilbert

* reformatted
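
With this patch, gradient checkpointing on DistilBERT is driven through the standard PreTrainedModel switch. A minimal usage sketch (the checkpoint name, batch shape, and label count are illustrative, not part of this commit):

import torch
from transformers import DistilBertForSequenceClassification

model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")

# Flip the flag added in this commit: each TransformerBlock's activations are
# discarded in the forward pass and recomputed during backward, trading
# compute for memory.
model.gradient_checkpointing_enable()
model.train()  # checkpointing is gated on self.training, so eval runs are unaffected

input_ids = torch.randint(0, model.config.vocab_size, (8, 128))  # dummy batch
labels = torch.randint(0, 2, (8,))

loss = model(input_ids=input_ids, labels=labels).loss
loss.backward()  # layers are re-run here to rebuild the dropped activations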
parent 2642d8d0
@@ -324,6 +324,7 @@ class Transformer(nn.Module):
         super().__init__()
         self.n_layers = config.n_layers
         self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
+        self.gradient_checkpointing = False
 
     def forward(
         self,
@@ -356,9 +357,28 @@ class Transformer(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_state,)
 
-            layer_outputs = layer_module(
-                x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i], output_attentions=output_attentions
-            )
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_state,
+                    attn_mask,
+                    head_mask[i],
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_state,
+                    attn_mask,
+                    head_mask[i],
+                    output_attentions,
+                )
+
             hidden_state = layer_outputs[-1]
 
             if output_attentions:
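
A note on the create_custom_forward indirection in this hunk: torch.utils.checkpoint.checkpoint only relays positional inputs to the wrapped callable and only tracks gradients through tensor arguments, so the bool output_attentions is bound in a closure rather than passed through checkpoint. A self-contained sketch of the same pattern (ToyBlock, its flag, and all shapes are made up for illustration):

import torch
import torch.utils.checkpoint


class ToyBlock(torch.nn.Module):
    # Stand-in for a TransformerBlock: takes a tensor plus a bool flag.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x, return_double=False):
        out = self.linear(x).relu()
        return (out, out * 2) if return_double else (out,)


def create_custom_forward(module):
    def custom_forward(*inputs):
        # The flag is captured from the enclosing scope instead of being
        # passed through checkpoint(), which only forwards positional inputs.
        return module(*inputs, return_double)

    return custom_forward


block = ToyBlock()
return_double = True  # plays the role of output_attentions

x = torch.randn(4, 16, requires_grad=True)
# Nothing computed inside ToyBlock is saved for backward; it is recomputed
# when .backward() runs, which is the memory/compute trade of checkpointing.
outputs = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
outputs[0].sum().backward()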
@@ -392,6 +412,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
     config_class = DistilBertConfig
     load_tf_weights = None
     base_model_prefix = "distilbert"
+    supports_gradient_checkpointing = True
 
     def _init_weights(self, module: nn.Module):
         """Initialize the weights."""
@@ -409,6 +430,10 @@ class DistilBertPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, Transformer):
+            module.gradient_checkpointing = value
+
 
 DISTILBERT_START_DOCSTRING = r"""
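
For context on the two class-level hooks above: supports_gradient_checkpointing advertises the feature, and _set_gradient_checkpointing is the callback that the base PreTrainedModel.gradient_checkpointing_enable() applies over the module tree, which is how the flag on Transformer gets flipped. A rough sketch of that dispatch, simplified from the transformers base class of this era (names and error text approximate, not verbatim):

from functools import partial

import torch


class PreTrainedModelSketch(torch.nn.Module):
    supports_gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        pass  # overridden per model, e.g. by DistilBertPreTrainedModel above

    def gradient_checkpointing_enable(self):
        if not self.supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        # self.apply() visits every submodule, so the Transformer instance
        # nested inside the model gets its gradient_checkpointing flag set.
        self.apply(partial(self._set_gradient_checkpointing, value=True))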