Unverified commit 06e782da, authored by Younes Belkada, committed by GitHub
Browse files

[`core`] Refactor of `gradient_checkpointing` (#27020)

* v1

* fix

* remove `create_custom_forward`

* fixup

* fixup

* add test and fix all failing GC tests

* remove all remaining `create_custom_forward` methods

* fix idefics bug

* fixup

* replace with `__call__`

* add comment

* quality
parent 9286f0ac
...@@ -398,15 +398,8 @@ class UniSpeechSatFeatureEncoder(nn.Module): ...@@ -398,15 +398,8 @@ class UniSpeechSatFeatureEncoder(nn.Module):
for conv_layer in self.conv_layers: for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training: if self._requires_grad and self.gradient_checkpointing and self.training:
hidden_states = self.gradient_checkpointing_func(
def create_custom_forward(module): conv_layer.__call__,
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(conv_layer),
hidden_states, hidden_states,
) )
else: else:
...@@ -781,17 +774,11 @@ class UniSpeechSatEncoder(nn.Module): ...@@ -781,17 +774,11 @@ class UniSpeechSatEncoder(nn.Module):
if not skip_the_layer or deepspeed_zero3_is_enabled: if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync # under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
# create gradient checkpointing function layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
output_attentions,
) )
else: else:
layer_outputs = layer( layer_outputs = layer(
...@@ -871,17 +858,11 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module): ...@@ -871,17 +858,11 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module):
# under deepspeed zero3 all gpus must run in sync # under deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
# create gradient checkpointing function layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
output_attentions,
) )
else: else:
layer_outputs = layer( layer_outputs = layer(
...@@ -1053,9 +1034,10 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): ...@@ -1053,9 +1034,10 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask return attention_mask
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (UniSpeechSatEncoder, UniSpeechSatEncoderStableLayerNorm, UniSpeechSatFeatureEncoder)): if isinstance(module, (UniSpeechSatEncoder, UniSpeechSatEncoderStableLayerNorm, UniSpeechSatFeatureEncoder)):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
UNISPEECH_SAT_START_DOCSTRING = r""" UNISPEECH_SAT_START_DOCSTRING = r"""
......
...@@ -315,9 +315,10 @@ class UperNetPreTrainedModel(PreTrainedModel): ...@@ -315,9 +315,10 @@ class UperNetPreTrainedModel(PreTrainedModel):
if self.auxiliary_head is not None: if self.auxiliary_head is not None:
self.auxiliary_head.init_weights() self.auxiliary_head.init_weights()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BackboneMixin): if isinstance(module, BackboneMixin):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
UPERNET_START_DOCSTRING = r""" UPERNET_START_DOCSTRING = r"""
......
...@@ -434,17 +434,11 @@ class VideoMAEEncoder(nn.Module): ...@@ -434,17 +434,11 @@ class VideoMAEEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
layer_head_mask, layer_head_mask,
output_attentions,
) )
else: else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
...@@ -489,9 +483,10 @@ class VideoMAEPreTrainedModel(PreTrainedModel): ...@@ -489,9 +483,10 @@ class VideoMAEPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, VideoMAEEncoder): if isinstance(module, (VideoMAEEncoder, VideoMAEDecoder)):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
VIDEOMAE_START_DOCSTRING = r""" VIDEOMAE_START_DOCSTRING = r"""
...@@ -726,17 +721,11 @@ class VideoMAEDecoder(nn.Module): ...@@ -726,17 +721,11 @@ class VideoMAEDecoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
None, None,
output_attentions,
) )
else: else:
layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions)
......
...@@ -531,18 +531,12 @@ class ViltEncoder(nn.Module): ...@@ -531,18 +531,12 @@ class ViltEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
output_attentions,
) )
else: else:
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
...@@ -591,9 +585,10 @@ class ViltPreTrainedModel(PreTrainedModel): ...@@ -591,9 +585,10 @@ class ViltPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ViltEncoder): if isinstance(module, ViltEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
VILT_START_DOCSTRING = r""" VILT_START_DOCSTRING = r"""
......
...@@ -225,10 +225,10 @@ class VisionEncoderDecoderModel(PreTrainedModel): ...@@ -225,10 +225,10 @@ class VisionEncoderDecoderModel(PreTrainedModel):
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head" f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
) )
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
# call both encoder and decoder function on gradient checkpointing # call both encoder and decoder function on gradient checkpointing
self.encoder._set_gradient_checkpointing(module, value=value) self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func)
self.decoder._set_gradient_checkpointing(module, value=value) self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func)
def get_encoder(self): def get_encoder(self):
return self.encoder return self.encoder
......
...@@ -418,18 +418,12 @@ class VisualBertEncoder(nn.Module): ...@@ -418,18 +418,12 @@ class VisualBertEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
output_attentions,
) )
else: else:
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
...@@ -547,9 +541,10 @@ class VisualBertPreTrainedModel(PreTrainedModel): ...@@ -547,9 +541,10 @@ class VisualBertPreTrainedModel(PreTrainedModel):
if isinstance(module, nn.Linear) and module.bias is not None: if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, VisualBertEncoder): if isinstance(module, VisualBertEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass @dataclass
......
...@@ -397,17 +397,11 @@ class ViTEncoder(nn.Module): ...@@ -397,17 +397,11 @@ class ViTEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
layer_head_mask, layer_head_mask,
output_attentions,
) )
else: else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
...@@ -467,9 +461,10 @@ class ViTPreTrainedModel(PreTrainedModel): ...@@ -467,9 +461,10 @@ class ViTPreTrainedModel(PreTrainedModel):
std=self.config.initializer_range, std=self.config.initializer_range,
).to(module.cls_token.dtype) ).to(module.cls_token.dtype)
def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None: def _set_gradient_checkpointing(self, module: ViTEncoder, gradient_checkpointing_func=None) -> None:
if isinstance(module, ViTEncoder): if isinstance(module, ViTEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
VIT_START_DOCSTRING = r""" VIT_START_DOCSTRING = r"""
......
...@@ -415,17 +415,11 @@ class ViTHybridEncoder(nn.Module): ...@@ -415,17 +415,11 @@ class ViTHybridEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
layer_head_mask, layer_head_mask,
output_attentions,
) )
else: else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
...@@ -486,9 +480,10 @@ class ViTHybridPreTrainedModel(PreTrainedModel): ...@@ -486,9 +480,10 @@ class ViTHybridPreTrainedModel(PreTrainedModel):
std=self.config.initializer_range, std=self.config.initializer_range,
).to(module.cls_token.dtype) ).to(module.cls_token.dtype)
def _set_gradient_checkpointing(self, module: ViTHybridEncoder, value: bool = False) -> None: def _set_gradient_checkpointing(self, module: ViTHybridEncoder, gradient_checkpointing_func=None) -> None:
if isinstance(module, ViTHybridEncoder): if isinstance(module, ViTHybridEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
VIT_START_DOCSTRING = r""" VIT_START_DOCSTRING = r"""
......
...@@ -536,17 +536,11 @@ class ViTMAEEncoder(nn.Module): ...@@ -536,17 +536,11 @@ class ViTMAEEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
layer_head_mask, layer_head_mask,
output_attentions,
) )
else: else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
...@@ -591,9 +585,10 @@ class ViTMAEPreTrainedModel(PreTrainedModel): ...@@ -591,9 +585,10 @@ class ViTMAEPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ViTMAEEncoder): if isinstance(module, (ViTMAEEncoder, ViTMAEDecoder)):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
VIT_MAE_START_DOCSTRING = r""" VIT_MAE_START_DOCSTRING = r"""
...@@ -793,17 +788,11 @@ class ViTMAEDecoder(nn.Module): ...@@ -793,17 +788,11 @@ class ViTMAEDecoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
None, None,
output_attentions,
) )
else: else:
layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions) layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions)
......
...@@ -387,17 +387,11 @@ class ViTMSNEncoder(nn.Module): ...@@ -387,17 +387,11 @@ class ViTMSNEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
layer_head_mask, layer_head_mask,
output_attentions,
) )
else: else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
...@@ -444,9 +438,10 @@ class ViTMSNPreTrainedModel(PreTrainedModel): ...@@ -444,9 +438,10 @@ class ViTMSNPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module: ViTMSNEncoder, value: bool = False) -> None: def _set_gradient_checkpointing(self, module: ViTMSNEncoder, gradient_checkpointing_func=None) -> None:
if isinstance(module, ViTMSNEncoder): if isinstance(module, ViTMSNEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
VIT_MSN_START_DOCSTRING = r""" VIT_MSN_START_DOCSTRING = r"""
......
...@@ -565,17 +565,11 @@ class VitDetEncoder(nn.Module): ...@@ -565,17 +565,11 @@ class VitDetEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
layer_head_mask, layer_head_mask,
output_attentions,
) )
else: else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
...@@ -666,9 +660,10 @@ class VitDetPreTrainedModel(PreTrainedModel): ...@@ -666,9 +660,10 @@ class VitDetPreTrainedModel(PreTrainedModel):
module.norm3.weight.data.zero_() module.norm3.weight.data.zero_()
module.norm3.bias.data.zero_() module.norm3.bias.data.zero_()
def _set_gradient_checkpointing(self, module: VitDetEncoder, value: bool = False) -> None: def _set_gradient_checkpointing(self, module: VitDetEncoder, gradient_checkpointing_func=None) -> None:
if isinstance(module, VitDetEncoder): if isinstance(module, VitDetEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
VITDET_START_DOCSTRING = r""" VITDET_START_DOCSTRING = r"""
......
...@@ -86,9 +86,15 @@ class VitMattePreTrainedModel(PreTrainedModel): ...@@ -86,9 +86,15 @@ class VitMattePreTrainedModel(PreTrainedModel):
if module.bias is not None: if module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, BackboneMixin): if isinstance(module, BackboneMixin):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
for backbone_module in module.modules():
if hasattr(backbone_module, "gradient_checkpointing"):
backbone_module.gradient_checkpointing_func = gradient_checkpointing_func
backbone_module.gradient_checkpointing = gradient_checkpointing_func is not None
class VitMatteBasicConv3x3(nn.Module): class VitMatteBasicConv3x3(nn.Module):
......
...@@ -1167,18 +1167,12 @@ class VitsEncoder(nn.Module): ...@@ -1167,18 +1167,12 @@ class VitsEncoder(nn.Module):
if not skip_the_layer or deepspeed_zero3_is_enabled: if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync # under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
# create gradient checkpointing function layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): encoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states, hidden_states,
padding_mask, padding_mask,
attention_mask, attention_mask,
output_attentions,
) )
else: else:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
...@@ -1296,9 +1290,10 @@ class VitsPreTrainedModel(PreTrainedModel): ...@@ -1296,9 +1290,10 @@ class VitsPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (VitsTextEncoder)): if isinstance(module, VitsEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
VITS_START_DOCSTRING = r""" VITS_START_DOCSTRING = r"""
......
...@@ -338,17 +338,11 @@ class VivitEncoder(nn.Module): ...@@ -338,17 +338,11 @@ class VivitEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
layer_head_mask, layer_head_mask,
output_attentions,
) )
else: else:
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
...@@ -414,9 +408,10 @@ class VivitPreTrainedModel(PreTrainedModel): ...@@ -414,9 +408,10 @@ class VivitPreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.Parameter): elif isinstance(module, nn.Parameter):
module.data.normal_(mean=0.0, std=self.config.initializer_range) module.data.normal_(mean=0.0, std=self.config.initializer_range)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, VivitEncoder): if isinstance(module, VivitEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
VIVIT_START_DOCSTRING = r""" VIVIT_START_DOCSTRING = r"""
......
...@@ -451,15 +451,8 @@ class Wav2Vec2FeatureEncoder(nn.Module): ...@@ -451,15 +451,8 @@ class Wav2Vec2FeatureEncoder(nn.Module):
for conv_layer in self.conv_layers: for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training: if self._requires_grad and self.gradient_checkpointing and self.training:
hidden_states = self.gradient_checkpointing_func(
def create_custom_forward(module): conv_layer.__call__,
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(conv_layer),
hidden_states, hidden_states,
) )
else: else:
...@@ -803,17 +796,11 @@ class Wav2Vec2Encoder(nn.Module): ...@@ -803,17 +796,11 @@ class Wav2Vec2Encoder(nn.Module):
if not skip_the_layer or deepspeed_zero3_is_enabled: if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync # under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
# create gradient checkpointing function layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
output_attentions,
) )
else: else:
layer_outputs = layer( layer_outputs = layer(
...@@ -892,17 +879,11 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module): ...@@ -892,17 +879,11 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
# under deepspeed zero3 all gpus must run in sync # under deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
# create gradient checkpointing function layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
output_attentions,
) )
else: else:
layer_outputs = layer( layer_outputs = layer(
...@@ -1173,9 +1154,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): ...@@ -1173,9 +1154,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask return attention_mask
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder)): if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder)):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
def _get_adapters(self): def _get_adapters(self):
if self.config.adapter_attn_dim is None: if self.config.adapter_attn_dim is None:
......
...@@ -518,15 +518,8 @@ class Wav2Vec2ConformerFeatureEncoder(nn.Module): ...@@ -518,15 +518,8 @@ class Wav2Vec2ConformerFeatureEncoder(nn.Module):
for conv_layer in self.conv_layers: for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training: if self._requires_grad and self.gradient_checkpointing and self.training:
hidden_states = self.gradient_checkpointing_func(
def create_custom_forward(module): conv_layer.__call__,
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(conv_layer),
hidden_states, hidden_states,
) )
else: else:
...@@ -911,18 +904,12 @@ class Wav2Vec2ConformerEncoder(nn.Module): ...@@ -911,18 +904,12 @@ class Wav2Vec2ConformerEncoder(nn.Module):
if not skip_the_layer or deepspeed_zero3_is_enabled: if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync # under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
# create gradient checkpointing function layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
relative_position_embeddings, relative_position_embeddings,
output_attentions,
) )
else: else:
layer_outputs = layer( layer_outputs = layer(
...@@ -1178,9 +1165,10 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel): ...@@ -1178,9 +1165,10 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask return attention_mask
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)): if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
WAV2VEC2_CONFORMER_START_DOCSTRING = r""" WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
......
...@@ -354,15 +354,8 @@ class WavLMFeatureEncoder(nn.Module): ...@@ -354,15 +354,8 @@ class WavLMFeatureEncoder(nn.Module):
for conv_layer in self.conv_layers: for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training: if self._requires_grad and self.gradient_checkpointing and self.training:
hidden_states = self.gradient_checkpointing_func(
def create_custom_forward(module): conv_layer.__call__,
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(conv_layer),
hidden_states, hidden_states,
) )
else: else:
...@@ -713,18 +706,12 @@ class WavLMEncoder(nn.Module): ...@@ -713,18 +706,12 @@ class WavLMEncoder(nn.Module):
if not skip_the_layer or deepspeed_zero3_is_enabled: if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync # under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
# create gradient checkpointing function layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
position_bias, position_bias,
output_attentions,
) )
else: else:
layer_outputs = layer( layer_outputs = layer(
...@@ -804,18 +791,12 @@ class WavLMEncoderStableLayerNorm(nn.Module): ...@@ -804,18 +791,12 @@ class WavLMEncoderStableLayerNorm(nn.Module):
# under deepspeed zero3 all gpus must run in sync # under deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
# create gradient checkpointing function layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
position_bias, position_bias,
output_attentions,
) )
else: else:
layer_outputs = layer( layer_outputs = layer(
...@@ -1052,9 +1033,10 @@ class WavLMPreTrainedModel(PreTrainedModel): ...@@ -1052,9 +1033,10 @@ class WavLMPreTrainedModel(PreTrainedModel):
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask return attention_mask
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (WavLMEncoder, WavLMEncoderStableLayerNorm, WavLMFeatureEncoder)): if isinstance(module, (WavLMEncoder, WavLMEncoderStableLayerNorm, WavLMFeatureEncoder)):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
WAVLM_START_DOCSTRING = r""" WAVLM_START_DOCSTRING = r"""
......
...@@ -685,9 +685,10 @@ class WhisperPreTrainedModel(PreTrainedModel): ...@@ -685,9 +685,10 @@ class WhisperPreTrainedModel(PreTrainedModel):
embed_positions = module.embed_positions.weight embed_positions = module.embed_positions.weight
embed_positions.copy_(sinusoids(*embed_positions.shape)) embed_positions.copy_(sinusoids(*embed_positions.shape))
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (WhisperDecoder, WhisperEncoder)): if isinstance(module, (WhisperDecoder, WhisperEncoder)):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor): def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
""" """
...@@ -942,18 +943,12 @@ class WhisperEncoder(WhisperPreTrainedModel): ...@@ -942,18 +943,12 @@ class WhisperEncoder(WhisperPreTrainedModel):
layer_outputs = (None, None) layer_outputs = (None, None)
else: else:
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): encoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states, hidden_states,
None, None,
(head_mask[idx] if head_mask is not None else None), (head_mask[idx] if head_mask is not None else None),
output_attentions,
) )
else: else:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
...@@ -1174,16 +1169,8 @@ class WhisperDecoder(WhisperPreTrainedModel): ...@@ -1174,16 +1169,8 @@ class WhisperDecoder(WhisperPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): decoder_layer.__call__,
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
encoder_hidden_states, encoder_hidden_states,
...@@ -1191,6 +1178,8 @@ class WhisperDecoder(WhisperPreTrainedModel): ...@@ -1191,6 +1178,8 @@ class WhisperDecoder(WhisperPreTrainedModel):
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, # past_key_value None, # past_key_value
output_attentions,
use_cache,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
......
...@@ -534,9 +534,10 @@ class XCLIPPreTrainedModel(PreTrainedModel): ...@@ -534,9 +534,10 @@ class XCLIPPreTrainedModel(PreTrainedModel):
if module.bias is not None: if module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (XCLIPEncoder, XCLIPVisionEncoder)): if isinstance(module, (XCLIPEncoder, XCLIPVisionEncoder)):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
X_CLIP_START_DOCSTRING = r""" X_CLIP_START_DOCSTRING = r"""
...@@ -703,18 +704,12 @@ class XCLIPEncoder(nn.Module): ...@@ -703,18 +704,12 @@ class XCLIPEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): encoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
causal_attention_mask, causal_attention_mask,
output_attentions,
) )
else: else:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
...@@ -950,18 +945,12 @@ class XCLIPVisionEncoder(nn.Module): ...@@ -950,18 +945,12 @@ class XCLIPVisionEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): encoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
causal_attention_mask, causal_attention_mask,
output_attentions,
) )
else: else:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
......
...@@ -503,9 +503,10 @@ class XGLMPreTrainedModel(PreTrainedModel): ...@@ -503,9 +503,10 @@ class XGLMPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, XGLMModel): if isinstance(module, XGLMModel):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@add_start_docstrings( @add_start_docstrings(
...@@ -674,16 +675,8 @@ class XGLMModel(XGLMPreTrainedModel): ...@@ -674,16 +675,8 @@ class XGLMModel(XGLMPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): decoder_layer.__call__,
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
encoder_hidden_states, encoder_hidden_states,
...@@ -691,6 +684,8 @@ class XGLMModel(XGLMPreTrainedModel): ...@@ -691,6 +684,8 @@ class XGLMModel(XGLMPreTrainedModel):
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
output_attentions,
use_cache,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
......
Markdown is supported
Attaching a file: 0%. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment