Unverified commit 06e782da, authored by Younes Belkada and committed by GitHub
Browse files

[`core`] Refactor of `gradient_checkpointing` (#27020)

* v1

* fix

* remove `create_custom_forward`

* fixup

* fixup

* add test and fix all failing GC tests

* remove all remaining `create_custom_forward` methods

* fix idefics bug

* fixup

* replace with `__call__`

* add comment

* quality
parent 9286f0ac
...@@ -1088,9 +1088,10 @@ class DeformableDetrPreTrainedModel(PreTrainedModel): ...@@ -1088,9 +1088,10 @@ class DeformableDetrPreTrainedModel(PreTrainedModel):
if hasattr(module, "level_embed"): if hasattr(module, "level_embed"):
nn.init.normal_(module.level_embed) nn.init.normal_(module.level_embed)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DeformableDetrDecoder): if isinstance(module, DeformableDetrDecoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DEFORMABLE_DETR_START_DOCSTRING = r""" DEFORMABLE_DETR_START_DOCSTRING = r"""
...@@ -1383,15 +1384,8 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel): ...@@ -1383,15 +1384,8 @@ class DeformableDetrDecoder(DeformableDetrPreTrainedModel):
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): decoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states, hidden_states,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
......
...@@ -357,17 +357,11 @@ class DeiTEncoder(nn.Module): ...@@ -357,17 +357,11 @@ class DeiTEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
layer_head_mask, layer_head_mask,
output_attentions,
) )
else: else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
...@@ -415,9 +409,10 @@ class DeiTPreTrainedModel(PreTrainedModel): ...@@ -415,9 +409,10 @@ class DeiTPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module: DeiTEncoder, value: bool = False) -> None: def _set_gradient_checkpointing(self, module: DeiTEncoder, gradient_checkpointing_func=None) -> None:
if isinstance(module, DeiTEncoder): if isinstance(module, DeiTEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DEIT_START_DOCSTRING = r""" DEIT_START_DOCSTRING = r"""
......
...@@ -504,9 +504,10 @@ class MCTCTPreTrainedModel(PreTrainedModel): ...@@ -504,9 +504,10 @@ class MCTCTPreTrainedModel(PreTrainedModel):
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long() attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).long()
return attention_mask return attention_mask
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (MCTCTEncoder)): if isinstance(module, (MCTCTEncoder)):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
MCTCT_START_DOCSTRING = r""" MCTCT_START_DOCSTRING = r"""
...@@ -616,18 +617,12 @@ class MCTCTEncoder(MCTCTPreTrainedModel): ...@@ -616,18 +617,12 @@ class MCTCTEncoder(MCTCTPreTrainedModel):
if not skip_the_layer or deepspeed_zero3_is_enabled: if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync # under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): encoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
(head_mask[idx] if head_mask is not None else None), (head_mask[idx] if head_mask is not None else None),
output_attentions,
) )
else: else:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
......
...@@ -456,9 +456,10 @@ class OpenLlamaPreTrainedModel(PreTrainedModel): ...@@ -456,9 +456,10 @@ class OpenLlamaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, OpenLlamaModel): if isinstance(module, OpenLlamaModel):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
OPEN_LLAMA_INPUTS_DOCSTRING = r""" OPEN_LLAMA_INPUTS_DOCSTRING = r"""
...@@ -665,20 +666,14 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel): ...@@ -665,20 +666,14 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): decoder_layer.__call__,
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, None)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
position_ids, position_ids,
None, None,
output_attentions,
None,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
......
...@@ -163,9 +163,10 @@ class TrajectoryTransformerPreTrainedModel(PreTrainedModel): ...@@ -163,9 +163,10 @@ class TrajectoryTransformerPreTrainedModel(PreTrainedModel):
main_input_name = "trajectories" main_input_name = "trajectories"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, TrajectoryTransformerModel): if isinstance(module, TrajectoryTransformerModel):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
def _init_weights(self, module): def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)): if isinstance(module, (nn.Linear, nn.Embedding)):
...@@ -550,15 +551,8 @@ class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel): ...@@ -550,15 +551,8 @@ class TrajectoryTransformerModel(TrajectoryTransformerPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): block.__call__,
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, hidden_states,
layer_past, layer_past,
use_cache, use_cache,
......
...@@ -387,9 +387,10 @@ class VanPreTrainedModel(PreTrainedModel): ...@@ -387,9 +387,10 @@ class VanPreTrainedModel(PreTrainedModel):
if module.bias is not None: if module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, VanModel): if isinstance(module, VanModel):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
VAN_START_DOCSTRING = r""" VAN_START_DOCSTRING = r"""
......
...@@ -979,9 +979,10 @@ class DetaPreTrainedModel(PreTrainedModel): ...@@ -979,9 +979,10 @@ class DetaPreTrainedModel(PreTrainedModel):
if hasattr(module, "level_embed"): if hasattr(module, "level_embed"):
nn.init.normal_(module.level_embed) nn.init.normal_(module.level_embed)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DetaDecoder): if isinstance(module, DetaDecoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DETA_START_DOCSTRING = r""" DETA_START_DOCSTRING = r"""
...@@ -1275,15 +1276,8 @@ class DetaDecoder(DetaPreTrainedModel): ...@@ -1275,15 +1276,8 @@ class DetaDecoder(DetaPreTrainedModel):
all_hidden_states += (hidden_states,) all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): decoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states, hidden_states,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
......
...@@ -927,9 +927,10 @@ class DetrPreTrainedModel(PreTrainedModel): ...@@ -927,9 +927,10 @@ class DetrPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DetrDecoder): if isinstance(module, DetrDecoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DETR_START_DOCSTRING = r""" DETR_START_DOCSTRING = r"""
...@@ -1253,15 +1254,8 @@ class DetrDecoder(DetrPreTrainedModel): ...@@ -1253,15 +1254,8 @@ class DetrDecoder(DetrPreTrainedModel):
continue continue
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): decoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states, hidden_states,
combined_attention_mask, combined_attention_mask,
encoder_hidden_states, encoder_hidden_states,
......
...@@ -660,7 +660,7 @@ class DinatPreTrainedModel(PreTrainedModel): ...@@ -660,7 +660,7 @@ class DinatPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module: DinatEncoder, value: bool = False) -> None: def _set_gradient_checkpointing(self, module: DinatEncoder, gradient_checkpointing_func=None) -> None:
pass pass
......
...@@ -447,17 +447,11 @@ class Dinov2Encoder(nn.Module): ...@@ -447,17 +447,11 @@ class Dinov2Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
layer_head_mask, layer_head_mask,
output_attentions,
) )
else: else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
...@@ -516,9 +510,10 @@ class Dinov2PreTrainedModel(PreTrainedModel): ...@@ -516,9 +510,10 @@ class Dinov2PreTrainedModel(PreTrainedModel):
std=self.config.initializer_range, std=self.config.initializer_range,
).to(module.cls_token.dtype) ).to(module.cls_token.dtype)
def _set_gradient_checkpointing(self, module: Dinov2Encoder, value: bool = False) -> None: def _set_gradient_checkpointing(self, module: Dinov2Encoder, gradient_checkpointing_func=None) -> None:
if isinstance(module, Dinov2Encoder): if isinstance(module, Dinov2Encoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DINOV2_START_DOCSTRING = r""" DINOV2_START_DOCSTRING = r"""
......
...@@ -358,18 +358,12 @@ class Transformer(nn.Module): ...@@ -358,18 +358,12 @@ class Transformer(nn.Module):
all_hidden_states = all_hidden_states + (hidden_state,) all_hidden_states = all_hidden_states + (hidden_state,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_state, hidden_state,
attn_mask, attn_mask,
head_mask[i], head_mask[i],
output_attentions,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
...@@ -430,9 +424,10 @@ class DistilBertPreTrainedModel(PreTrainedModel): ...@@ -430,9 +424,10 @@ class DistilBertPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, Transformer): if isinstance(module, Transformer):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DISTILBERT_START_DOCSTRING = r""" DISTILBERT_START_DOCSTRING = r"""
......
...@@ -749,15 +749,8 @@ class DonutSwinEncoder(nn.Module): ...@@ -749,15 +749,8 @@ class DonutSwinEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
...@@ -826,9 +819,10 @@ class DonutSwinPreTrainedModel(PreTrainedModel): ...@@ -826,9 +819,10 @@ class DonutSwinPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DonutSwinEncoder): if isinstance(module, DonutSwinEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
SWIN_START_DOCSTRING = r""" SWIN_START_DOCSTRING = r"""
......
...@@ -164,9 +164,10 @@ class DPRPreTrainedModel(PreTrainedModel): ...@@ -164,9 +164,10 @@ class DPRPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BertEncoder): if isinstance(module, BertEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class DPREncoder(DPRPreTrainedModel): class DPREncoder(DPRPreTrainedModel):
......
...@@ -528,17 +528,11 @@ class DPTViTEncoder(nn.Module): ...@@ -528,17 +528,11 @@ class DPTViTEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
layer_head_mask, layer_head_mask,
output_attentions,
) )
else: else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
...@@ -818,9 +812,10 @@ class DPTPreTrainedModel(PreTrainedModel): ...@@ -818,9 +812,10 @@ class DPTPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DPTViTEncoder): if isinstance(module, DPTViTEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DPT_START_DOCSTRING = r""" DPT_START_DOCSTRING = r"""
......
...@@ -500,9 +500,10 @@ class EfficientNetPreTrainedModel(PreTrainedModel): ...@@ -500,9 +500,10 @@ class EfficientNetPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, EfficientNetBlock): if isinstance(module, EfficientNetBlock):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@add_start_docstrings( @add_start_docstrings(
......
...@@ -571,20 +571,15 @@ class ElectraEncoder(nn.Module): ...@@ -571,20 +571,15 @@ class ElectraEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
past_key_value,
output_attentions,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
...@@ -692,9 +687,10 @@ class ElectraPreTrainedModel(PreTrainedModel): ...@@ -692,9 +687,10 @@ class ElectraPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ElectraEncoder): if isinstance(module, ElectraEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass @dataclass
......
...@@ -473,9 +473,10 @@ class EncodecPreTrainedModel(PreTrainedModel): ...@@ -473,9 +473,10 @@ class EncodecPreTrainedModel(PreTrainedModel):
elif "bias" in name: elif "bias" in name:
nn.init.constant_(param, 0.0) nn.init.constant_(param, 0.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (EncodecEncoder, EncodecDecoder)): if isinstance(module, (EncodecEncoder, EncodecDecoder)):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
ENCODEC_START_DOCSTRING = r""" ENCODEC_START_DOCSTRING = r"""
......
...@@ -265,10 +265,10 @@ class EncoderDecoderModel(PreTrainedModel): ...@@ -265,10 +265,10 @@ class EncoderDecoderModel(PreTrainedModel):
self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
) )
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
# call both encoder and decoder function on gradient checkpointing # call both encoder and decoder function on gradient checkpointing
self.encoder._set_gradient_checkpointing(module, value=value) self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func)
self.decoder._set_gradient_checkpointing(module, value=value) self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func)
def get_encoder(self): def get_encoder(self):
return self.encoder return self.encoder
......
...@@ -506,20 +506,15 @@ class ErnieEncoder(nn.Module): ...@@ -506,20 +506,15 @@ class ErnieEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
past_key_value,
output_attentions,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
...@@ -680,9 +675,10 @@ class ErniePreTrainedModel(PreTrainedModel): ...@@ -680,9 +675,10 @@ class ErniePreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ErnieEncoder): if isinstance(module, ErnieEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass @dataclass
......
...@@ -429,9 +429,10 @@ class ErnieMPreTrainedModel(PreTrainedModel): ...@@ -429,9 +429,10 @@ class ErnieMPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ErnieMEncoder): if isinstance(module, ErnieMEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
ERNIE_M_START_DOCSTRING = r""" ERNIE_M_START_DOCSTRING = r"""
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment