Unverified Commit 3d869726 authored by Zehan Li, committed by GitHub

add gradient checkpointing for distilbert (#24719)

* add gradient checkpointing for distilbert

* reformatted
parent 2642d8d0
@@ -324,6 +324,7 @@ class Transformer(nn.Module):
         super().__init__()
         self.n_layers = config.n_layers
         self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
+        self.gradient_checkpointing = False

     def forward(
         self,
@@ -356,9 +357,28 @@ class Transformer(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_state,)

-            layer_outputs = layer_module(
-                x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i], output_attentions=output_attentions
-            )
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_state,
+                    attn_mask,
+                    head_mask[i],
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_state,
+                    attn_mask,
+                    head_mask[i],
+                    output_attentions,
+                )

             hidden_state = layer_outputs[-1]

             if output_attentions:
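For context (not part of the commit): the branch added above relies on torch.utils.checkpoint.checkpoint, which discards the wrapped layer's activations during the forward pass and recomputes them during backward, trading extra compute for lower memory. A minimal standalone sketch of that mechanism, assuming only PyTorch (the nn.Linear toy module and tensor shapes are illustrative, not taken from the diff):

    # Minimal activation-checkpointing sketch; module and shapes are illustrative.
    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    layer = nn.Linear(16, 16)
    x = torch.randn(4, 16, requires_grad=True)

    def create_custom_forward(module):
        # Same wrapper idea as in the diff: bind non-tensor arguments via a closure
        # so that checkpoint() only sees tensor inputs.
        def custom_forward(*inputs):
            return module(*inputs)

        return custom_forward

    # Forward runs without storing intermediate activations; they are recomputed
    # when backward() replays the wrapped function.
    out = checkpoint(create_custom_forward(layer), x)
    out.sum().backward()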
@@ -392,6 +412,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
     config_class = DistilBertConfig
     load_tf_weights = None
     base_model_prefix = "distilbert"
+    supports_gradient_checkpointing = True

     def _init_weights(self, module: nn.Module):
         """Initialize the weights."""
@@ -409,6 +430,10 @@ class DistilBertPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, Transformer):
+            module.gradient_checkpointing = value


 DISTILBERT_START_DOCSTRING = r"""
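With supports_gradient_checkpointing and _set_gradient_checkpointing in place, the feature is toggled through the standard PreTrainedModel helper. A usage sketch (the checkpoint name and task head are illustrative, not part of this commit):

    # Usage sketch: gradient_checkpointing_enable() routes through
    # _set_gradient_checkpointing and flips Transformer.gradient_checkpointing to True.
    from transformers import DistilBertForSequenceClassification

    model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased")
    model.gradient_checkpointing_enable()
    model.train()  # checkpointing is only active in training mode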