Unverified commit 06e782da authored by Younes Belkada, committed by GitHub

[`core`] Refactor of `gradient_checkpointing` (#27020)

* v1

* fix

* remove `create_custom_forward`

* fixup

* fixup

* add test and fix all failing GC tests

* remove all remaining `create_custom_forward` methods

* fix idefics bug

* fixup

* replace with `__call__`

* add comment

* quality
parent 9286f0ac
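
Every hunk below applies the same two changes. First, the per-module closure pattern (a local `create_custom_forward` wrapped around `torch.utils.checkpoint.checkpoint`) is replaced by a direct call to a `gradient_checkpointing_func` attribute, passing the layer's `__call__` plus all layer arguments explicitly. Second, `_set_gradient_checkpointing` now receives that function instead of a boolean and derives the `gradient_checkpointing` flag from it. Below is a minimal sketch of the new calling convention, assuming a toy `ToyEncoder` module and a `checkpoint_func` built with `functools.partial`; the names are illustrative, not the exact Transformers implementation.

```python
import functools

import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyEncoder(nn.Module):  # hypothetical stand-in for the encoders touched by this diff
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])
        self.gradient_checkpointing = False
        self.gradient_checkpointing_func = None

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # new pattern: call the installed checkpointing function directly,
                # no per-layer closure needed
                hidden_states = self.gradient_checkpointing_func(layer.__call__, hidden_states)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


encoder = ToyEncoder().train()
# roughly what _set_gradient_checkpointing now installs on each supported submodule
checkpoint_func = functools.partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
encoder.gradient_checkpointing_func = checkpoint_func
encoder.gradient_checkpointing = checkpoint_func is not None

out = encoder(torch.randn(2, 8, requires_grad=True))
out.sum().backward()
```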
@@ -398,15 +398,8 @@ class UniSpeechSatFeatureEncoder(nn.Module):
         for conv_layer in self.conv_layers:
             if self._requires_grad and self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-                    return custom_forward
-
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(conv_layer),
+                hidden_states = self.gradient_checkpointing_func(
+                    conv_layer.__call__,
                     hidden_states,
                 )
             else:
@@ -781,17 +774,11 @@ class UniSpeechSatEncoder(nn.Module):
             if not skip_the_layer or deepspeed_zero3_is_enabled:
                 # under deepspeed zero3 all gpus must run in sync
                 if self.gradient_checkpointing and self.training:
-                    # create gradient checkpointing function
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        layer.__call__,
                         hidden_states,
                         attention_mask,
+                        output_attentions,
                     )
                 else:
                     layer_outputs = layer(
@@ -871,17 +858,11 @@ class UniSpeechSatEncoderStableLayerNorm(nn.Module):
                 # under deepspeed zero3 all gpus must run in sync
                 # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
                 if self.gradient_checkpointing and self.training:
-                    # create gradient checkpointing function
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        layer.__call__,
                         hidden_states,
                         attention_mask,
+                        output_attentions,
                     )
                 else:
                     layer_outputs = layer(
@@ -1053,9 +1034,10 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
         attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
         return attention_mask

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (UniSpeechSatEncoder, UniSpeechSatEncoderStableLayerNorm, UniSpeechSatFeatureEncoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

 UNISPEECH_SAT_START_DOCSTRING = r"""
......
@@ -315,9 +315,10 @@ class UperNetPreTrainedModel(PreTrainedModel):
         if self.auxiliary_head is not None:
             self.auxiliary_head.init_weights()

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, BackboneMixin):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

 UPERNET_START_DOCSTRING = r"""
......
@@ -434,17 +434,11 @@ class VideoMAEEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
@@ -489,9 +483,10 @@ class VideoMAEPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, VideoMAEEncoder):
-            module.gradient_checkpointing = value
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
+        if isinstance(module, (VideoMAEEncoder, VideoMAEDecoder)):
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

 VIDEOMAE_START_DOCSTRING = r"""
@@ -726,17 +721,11 @@ class VideoMAEDecoder(nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     None,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions)
......
@@ -531,18 +531,12 @@ class ViltEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
                     layer_head_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
@@ -591,9 +585,10 @@ class ViltPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, ViltEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

 VILT_START_DOCSTRING = r"""
......
@@ -225,10 +225,10 @@ class VisionEncoderDecoderModel(PreTrainedModel):
                 f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
             )

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         # call both encoder and decoder function on gradient checkpointing
-        self.encoder._set_gradient_checkpointing(module, value=value)
-        self.decoder._set_gradient_checkpointing(module, value=value)
+        self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func)
+        self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func)

     def get_encoder(self):
         return self.encoder
......
@@ -418,18 +418,12 @@ class VisualBertEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
                     layer_head_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
@@ -547,9 +541,10 @@ class VisualBertPreTrainedModel(PreTrainedModel):
         if isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, VisualBertEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

 @dataclass
......
@@ -397,17 +397,11 @@ class ViTEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
@@ -467,9 +461,10 @@ class ViTPreTrainedModel(PreTrainedModel):
                 std=self.config.initializer_range,
             ).to(module.cls_token.dtype)

-    def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None:
+    def _set_gradient_checkpointing(self, module: ViTEncoder, gradient_checkpointing_func=None) -> None:
         if isinstance(module, ViTEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

 VIT_START_DOCSTRING = r"""
......
@@ -415,17 +415,11 @@ class ViTHybridEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
@@ -486,9 +480,10 @@ class ViTHybridPreTrainedModel(PreTrainedModel):
                 std=self.config.initializer_range,
             ).to(module.cls_token.dtype)

-    def _set_gradient_checkpointing(self, module: ViTHybridEncoder, value: bool = False) -> None:
+    def _set_gradient_checkpointing(self, module: ViTHybridEncoder, gradient_checkpointing_func=None) -> None:
         if isinstance(module, ViTHybridEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

 VIT_START_DOCSTRING = r"""
......
@@ -536,17 +536,11 @@ class ViTMAEEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
@@ -591,9 +585,10 @@ class ViTMAEPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, ViTMAEEncoder):
-            module.gradient_checkpointing = value
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
+        if isinstance(module, (ViTMAEEncoder, ViTMAEDecoder)):
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

 VIT_MAE_START_DOCSTRING = r"""
@@ -793,17 +788,11 @@ class ViTMAEDecoder(nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     None,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(hidden_states, head_mask=None, output_attentions=output_attentions)
......
@@ -387,17 +387,11 @@ class ViTMSNEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
@@ -444,9 +438,10 @@ class ViTMSNPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

-    def _set_gradient_checkpointing(self, module: ViTMSNEncoder, value: bool = False) -> None:
+    def _set_gradient_checkpointing(self, module: ViTMSNEncoder, gradient_checkpointing_func=None) -> None:
         if isinstance(module, ViTMSNEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

 VIT_MSN_START_DOCSTRING = r"""
......
@@ -565,17 +565,11 @@ class VitDetEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
@@ -666,9 +660,10 @@ class VitDetPreTrainedModel(PreTrainedModel):
             module.norm3.weight.data.zero_()
             module.norm3.bias.data.zero_()

-    def _set_gradient_checkpointing(self, module: VitDetEncoder, value: bool = False) -> None:
+    def _set_gradient_checkpointing(self, module: VitDetEncoder, gradient_checkpointing_func=None) -> None:
         if isinstance(module, VitDetEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

 VITDET_START_DOCSTRING = r"""
......
@@ -86,9 +86,15 @@ class VitMattePreTrainedModel(PreTrainedModel):
             if module.bias is not None:
                 module.bias.data.zero_()

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, BackboneMixin):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
+
+            for backbone_module in module.modules():
+                if hasattr(backbone_module, "gradient_checkpointing"):
+                    backbone_module.gradient_checkpointing_func = gradient_checkpointing_func
+                    backbone_module.gradient_checkpointing = gradient_checkpointing_func is not None

 class VitMatteBasicConv3x3(nn.Module):
......
@@ -1167,18 +1167,12 @@ class VitsEncoder(nn.Module):
             if not skip_the_layer or deepspeed_zero3_is_enabled:
                 # under deepspeed zero3 all gpus must run in sync
                 if self.gradient_checkpointing and self.training:
-                    # create gradient checkpointing function
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        encoder_layer.__call__,
                         hidden_states,
                         padding_mask,
                         attention_mask,
+                        output_attentions,
                     )
                 else:
                     layer_outputs = encoder_layer(
@@ -1296,9 +1290,10 @@ class VitsPreTrainedModel(PreTrainedModel):
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

-    def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (VitsTextEncoder)):
-            module.gradient_checkpointing = value
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
+        if isinstance(module, VitsEncoder):
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

 VITS_START_DOCSTRING = r"""
......
@@ -338,17 +338,11 @@ class VivitEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
@@ -414,9 +408,10 @@ class VivitPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.Parameter):
             module.data.normal_(mean=0.0, std=self.config.initializer_range)

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, VivitEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

 VIVIT_START_DOCSTRING = r"""
......
@@ -451,15 +451,8 @@ class Wav2Vec2FeatureEncoder(nn.Module):
         for conv_layer in self.conv_layers:
             if self._requires_grad and self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-                    return custom_forward
-
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(conv_layer),
+                hidden_states = self.gradient_checkpointing_func(
+                    conv_layer.__call__,
                     hidden_states,
                 )
             else:
@@ -803,17 +796,11 @@ class Wav2Vec2Encoder(nn.Module):
             if not skip_the_layer or deepspeed_zero3_is_enabled:
                 # under deepspeed zero3 all gpus must run in sync
                 if self.gradient_checkpointing and self.training:
-                    # create gradient checkpointing function
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        layer.__call__,
                         hidden_states,
                         attention_mask,
+                        output_attentions,
                     )
                 else:
                     layer_outputs = layer(
@@ -892,17 +879,11 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
                 # under deepspeed zero3 all gpus must run in sync
                 # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
                 if self.gradient_checkpointing and self.training:
-                    # create gradient checkpointing function
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        layer.__call__,
                         hidden_states,
                         attention_mask,
+                        output_attentions,
                     )
                 else:
                     layer_outputs = layer(
@@ -1173,9 +1154,10 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
         attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
         return attention_mask

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (Wav2Vec2Encoder, Wav2Vec2EncoderStableLayerNorm, Wav2Vec2FeatureEncoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

     def _get_adapters(self):
         if self.config.adapter_attn_dim is None:
......
@@ -518,15 +518,8 @@ class Wav2Vec2ConformerFeatureEncoder(nn.Module):
         for conv_layer in self.conv_layers:
             if self._requires_grad and self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-                    return custom_forward
-
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(conv_layer),
+                hidden_states = self.gradient_checkpointing_func(
+                    conv_layer.__call__,
                     hidden_states,
                 )
             else:
@@ -911,18 +904,12 @@ class Wav2Vec2ConformerEncoder(nn.Module):
             if not skip_the_layer or deepspeed_zero3_is_enabled:
                 # under deepspeed zero3 all gpus must run in sync
                 if self.gradient_checkpointing and self.training:
-                    # create gradient checkpointing function
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        layer.__call__,
                         hidden_states,
                         attention_mask,
                         relative_position_embeddings,
+                        output_attentions,
                     )
                 else:
                     layer_outputs = layer(
@@ -1178,9 +1165,10 @@ class Wav2Vec2ConformerPreTrainedModel(PreTrainedModel):
         attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
         return attention_mask

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (Wav2Vec2ConformerEncoder, Wav2Vec2ConformerFeatureEncoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

 WAV2VEC2_CONFORMER_START_DOCSTRING = r"""
......
@@ -354,15 +354,8 @@ class WavLMFeatureEncoder(nn.Module):
         for conv_layer in self.conv_layers:
             if self._requires_grad and self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-                    return custom_forward
-
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(conv_layer),
+                hidden_states = self.gradient_checkpointing_func(
+                    conv_layer.__call__,
                     hidden_states,
                 )
             else:
@@ -713,18 +706,12 @@ class WavLMEncoder(nn.Module):
             if not skip_the_layer or deepspeed_zero3_is_enabled:
                 # under deepspeed zero3 all gpus must run in sync
                 if self.gradient_checkpointing and self.training:
-                    # create gradient checkpointing function
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        layer.__call__,
                         hidden_states,
                         attention_mask,
                         position_bias,
+                        output_attentions,
                     )
                 else:
                     layer_outputs = layer(
@@ -804,18 +791,12 @@ class WavLMEncoderStableLayerNorm(nn.Module):
                 # under deepspeed zero3 all gpus must run in sync
                 # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
                 if self.gradient_checkpointing and self.training:
-                    # create gradient checkpointing function
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        layer.__call__,
                         hidden_states,
                         attention_mask,
                         position_bias,
+                        output_attentions,
                     )
                 else:
                     layer_outputs = layer(
@@ -1052,9 +1033,10 @@ class WavLMPreTrainedModel(PreTrainedModel):
         attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
         return attention_mask

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (WavLMEncoder, WavLMEncoderStableLayerNorm, WavLMFeatureEncoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

 WAVLM_START_DOCSTRING = r"""
......
@@ -685,9 +685,10 @@ class WhisperPreTrainedModel(PreTrainedModel):
            embed_positions = module.embed_positions.weight
            embed_positions.copy_(sinusoids(*embed_positions.shape))

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (WhisperDecoder, WhisperEncoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

     def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
         """
@@ -942,18 +943,12 @@ class WhisperEncoder(WhisperPreTrainedModel):
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        encoder_layer.__call__,
                         hidden_states,
                         None,
                         (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
                     )
                 else:
                     layer_outputs = encoder_layer(
@@ -1174,16 +1169,8 @@ class WhisperDecoder(WhisperPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -1191,6 +1178,8 @@ class WhisperDecoder(WhisperPreTrainedModel):
                     head_mask[idx] if head_mask is not None else None,
                     cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                     None,  # past_key_value
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
......
@@ -534,9 +534,10 @@ class XCLIPPreTrainedModel(PreTrainedModel):
            if module.bias is not None:
                module.bias.data.zero_()

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (XCLIPEncoder, XCLIPVisionEncoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

 X_CLIP_START_DOCSTRING = r"""
@@ -703,18 +704,12 @@ class XCLIPEncoder(nn.Module):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(encoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    encoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     causal_attention_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = encoder_layer(
@@ -950,18 +945,12 @@ class XCLIPVisionEncoder(nn.Module):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(encoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    encoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     causal_attention_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = encoder_layer(
......
@@ -503,9 +503,10 @@ class XGLMPreTrainedModel(PreTrainedModel):
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, XGLMModel):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

 @add_start_docstrings(
@@ -674,16 +675,8 @@ class XGLMModel(XGLMPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -691,6 +684,8 @@ class XGLMModel(XGLMPreTrainedModel):
                     head_mask[idx] if head_mask is not None else None,
                     cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
......
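
For completeness, a hedged usage sketch of how the refactored plumbing is driven from user code. `gradient_checkpointing_enable()` is the long-standing public entry point; the `gradient_checkpointing_kwargs` argument forwarded to `torch.utils.checkpoint.checkpoint` is assumed here to exist only on versions that include this refactor, and the model checkpoint name is purely illustrative.

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")  # illustrative checkpoint
model.train()

# Long-standing API: toggles gradient checkpointing on every supported submodule.
model.gradient_checkpointing_enable()

# Assumed to be available only on versions containing this refactor: extra kwargs are
# forwarded to torch.utils.checkpoint.checkpoint via the new gradient_checkpointing_func.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
```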