Unverified Commit 06e782da authored by Younes Belkada, committed by GitHub

[`core`] Refactor of `gradient_checkpointing` (#27020)

* v1

* fix

* remove `create_custom_forward`

* fixup

* fixup

* add test and fix all failing GC tests

* remove all remaining `create_custom_forward` methods

* fix idefics bug

* fixup

* replace with `__call__`

* add comment

* quality
parent 9286f0ac
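
The diffs below apply one and the same refactor to every model: the per-layer `create_custom_forward` closures wrapped around `torch.utils.checkpoint.checkpoint` are removed, and each call site instead invokes a single `self.gradient_checkpointing_func` that is installed on the module, passing the layer's bound `__call__` plus all arguments explicitly. A minimal, self-contained sketch of the new call-site pattern follows; the `ToyEncoder` class and the `partial(checkpoint, use_reentrant=False)` wiring are illustrative assumptions, not the exact library code.

```python
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyEncoder(nn.Module):
    """Hypothetical encoder, only meant to illustrate the refactored call site."""

    def __init__(self, num_layers: int = 2, hidden: int = 16):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden, hidden) for _ in range(num_layers))
        # Both attributes are set from the outside (in transformers, by the base model).
        self.gradient_checkpointing = False
        self.gradient_checkpointing_func = None

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # New pattern: no per-layer closure, just call the configurable
                # checkpointing function with the layer's bound __call__.
                hidden_states = self.gradient_checkpointing_func(layer.__call__, hidden_states)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


encoder = ToyEncoder().train()
encoder.gradient_checkpointing = True
# Assumed wiring: any callable with checkpoint's signature works here
# (use_reentrant requires PyTorch >= 1.11).
encoder.gradient_checkpointing_func = partial(checkpoint, use_reentrant=False)

out = encoder(torch.randn(4, 16, requires_grad=True))
out.sum().backward()  # activations are recomputed during this backward pass
```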
......@@ -249,10 +249,10 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
# call both encoder and decoder function on gradient checkpointing
self.encoder._set_gradient_checkpointing(module, value=value)
self.decoder._set_gradient_checkpointing(module, value=value)
self.encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func)
self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func)
def get_encoder(self):
return self.encoder
......
......@@ -559,9 +559,10 @@ class Speech2TextPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (Speech2TextDecoder, Speech2TextEncoder)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
"""
......@@ -817,18 +818,12 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
......@@ -1065,16 +1060,8 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
......@@ -1082,6 +1069,8 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
......
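
Throughout these hunks the checkpointed callable is the layer's bound `__call__`, not `forward` (see the commit bullet "replace with `__call__`"). The distinction matters because `nn.Module` hooks run inside `__call__`; calling `forward` directly bypasses them. A small demonstration of that behaviour, with an illustrative layer and hook rather than anything from the diff:

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

layer = nn.Linear(8, 8)
calls = []
# Forward hooks are dispatched by nn.Module.__call__, not by forward itself.
layer.register_forward_hook(lambda mod, inp, out: calls.append("hook fired"))

x = torch.randn(2, 8, requires_grad=True)

checkpoint(layer.__call__, x, use_reentrant=False)  # hook fires
checkpoint(layer.forward, x, use_reentrant=False)   # hook does NOT fire

print(calls)  # ['hook fired'] - only the __call__ variant triggered the hook
```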
......@@ -437,9 +437,10 @@ class Speech2Text2PreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, Speech2Text2Decoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
SPEECH_TO_TEXT_2_START_DOCSTRING = r"""
......@@ -669,16 +670,8 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
......
......@@ -520,15 +520,8 @@ class SpeechT5FeatureEncoder(nn.Module):
for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(conv_layer),
hidden_states = self.gradient_checkpointing_func(
conv_layer.__call__,
hidden_states,
)
else:
......@@ -1281,9 +1274,10 @@ class SpeechT5PreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (SpeechT5Encoder, SpeechT5Decoder, SpeechT5FeatureEncoder)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class SpeechT5Encoder(SpeechT5PreTrainedModel):
......@@ -1386,19 +1380,13 @@ class SpeechT5Encoder(SpeechT5PreTrainedModel):
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
# create gradient checkpointing function
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
position_bias,
output_attentions,
)
else:
layer_outputs = encoder_layer(
......@@ -1439,7 +1427,6 @@ class SpeechT5EncoderWithSpeechPrenet(SpeechT5PreTrainedModel):
super().__init__(config)
self.prenet = SpeechT5SpeechEncoderPrenet(config)
self.wrapped_encoder = SpeechT5Encoder(config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
......@@ -1476,7 +1463,6 @@ class SpeechT5EncoderWithTextPrenet(SpeechT5PreTrainedModel):
super().__init__(config)
self.prenet = SpeechT5TextEncoderPrenet(config)
self.wrapped_encoder = SpeechT5Encoder(config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
......@@ -1519,7 +1505,6 @@ class SpeechT5EncoderWithoutPrenet(SpeechT5PreTrainedModel):
def __init__(self, config: SpeechT5Config):
super().__init__(config)
self.wrapped_encoder = SpeechT5Encoder(config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
......@@ -1715,16 +1700,8 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
......@@ -1732,6 +1709,8 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel):
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
......@@ -1788,7 +1767,6 @@ class SpeechT5DecoderWithSpeechPrenet(SpeechT5PreTrainedModel):
super().__init__(config)
self.prenet = SpeechT5SpeechDecoderPrenet(config)
self.wrapped_decoder = SpeechT5Decoder(config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
......@@ -1836,7 +1814,6 @@ class SpeechT5DecoderWithTextPrenet(SpeechT5PreTrainedModel):
super().__init__(config)
self.prenet = SpeechT5TextDecoderPrenet(config)
self.wrapped_decoder = SpeechT5Decoder(config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
......@@ -1889,7 +1866,6 @@ class SpeechT5DecoderWithoutPrenet(SpeechT5PreTrainedModel):
def __init__(self, config: SpeechT5Config):
super().__init__(config)
self.wrapped_decoder = SpeechT5Decoder(config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
......
......@@ -459,20 +459,15 @@ class SplinterEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
......@@ -544,9 +539,10 @@ class SplinterPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, SplinterEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
SPLINTER_START_DOCSTRING = r"""
......
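
As the hunks above show, every `_set_gradient_checkpointing` override now receives the checkpointing function itself instead of a boolean, stores it on the module, and derives the on/off flag from its presence. A sketch of how enabling and disabling might drive such an override; the `enable_gc`/`disable_gc` helpers are hypothetical stand-ins for the base-model plumbing:

```python
from functools import partial
from torch.utils.checkpoint import checkpoint


class ToyEncoderModule:
    """Stands in for an encoder such as SplinterEncoder; only these two attributes matter."""
    gradient_checkpointing = False
    gradient_checkpointing_func = None


def _set_gradient_checkpointing(module, gradient_checkpointing_func=None):
    # New contract: keep the callable and derive the boolean flag from it.
    module.gradient_checkpointing_func = gradient_checkpointing_func
    module.gradient_checkpointing = gradient_checkpointing_func is not None


def enable_gc(module):   # hypothetical helper
    _set_gradient_checkpointing(module, partial(checkpoint, use_reentrant=False))


def disable_gc(module):  # hypothetical helper
    _set_gradient_checkpointing(module, None)


enc = ToyEncoderModule()
enable_gc(enc)
assert enc.gradient_checkpointing is True
disable_gc(enc)
assert enc.gradient_checkpointing is False
```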
......@@ -442,9 +442,10 @@ class SwiftFormerPreTrainedModel(PreTrainedModel):
nn.init.constant_(module.bias, 0)
nn.init.constant_(module.weight, 1.0)
def _set_gradient_checkpointing(self, module: SwiftFormerEncoder, value: bool = False) -> None:
def _set_gradient_checkpointing(self, module: SwiftFormerEncoder, gradient_checkpointing_func=None) -> None:
if isinstance(module, SwiftFormerEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
SWIFTFORMER_START_DOCSTRING = r"""
......
......@@ -825,15 +825,8 @@ class SwinEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
)
else:
layer_outputs = layer_module(
......@@ -901,9 +894,10 @@ class SwinPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, SwinEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
SWIN_START_DOCSTRING = r"""
......
......@@ -951,11 +951,6 @@ class TFSwinPreTrainedModel(TFPreTrainedModel):
config_class = SwinConfig
base_model_prefix = "swin"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
def _set_gradient_checkpointing(self, module, value=False) -> None:
if isinstance(module, TFSwinEncoder):
module.gradient_checkpointing = value
SWIN_START_DOCSTRING = r"""
......
......@@ -746,15 +746,8 @@ class Swin2SREncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(stage_module), hidden_states, input_dimensions, layer_head_mask
layer_outputs = self.gradient_checkpointing_func(
stage_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
)
else:
layer_outputs = stage_module(hidden_states, input_dimensions, layer_head_mask, output_attentions)
......@@ -802,9 +795,10 @@ class Swin2SRPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, Swin2SREncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
SWIN2SR_START_DOCSTRING = r"""
......
......@@ -906,15 +906,8 @@ class Swinv2Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
)
else:
layer_outputs = layer_module(
......@@ -983,9 +976,10 @@ class Swinv2PreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, Swinv2Encoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
SWINV2_START_DOCSTRING = r"""
......
......@@ -865,9 +865,10 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel):
module.experts[f"expert_{idx}"].wi.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
module.experts[f"expert_{idx}"].wo.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (SwitchTransformersAttention, SwitchTransformersStack)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
......@@ -1039,15 +1040,8 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return tuple(module(*inputs, use_cache, output_attentions))
return custom_forward
layer_outputs = checkpoint(
create_custom_forward(layer_module),
layer_module.forward,
hidden_states,
extended_attention_mask,
position_bias,
......@@ -1057,6 +1051,8 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
layer_head_mask,
cross_attn_layer_head_mask,
None, # past_key_value is always None with gradient checkpointing
use_cache,
output_attentions,
)
else:
layer_outputs = layer_module(
......
......@@ -873,9 +873,10 @@ class T5PreTrainedModel(PreTrainedModel):
if module.has_relative_attention_bias:
module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (T5Attention, T5Stack)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
......@@ -1100,15 +1101,8 @@ class T5Stack(T5PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return tuple(module(*inputs, use_cache, output_attentions))
return custom_forward
layer_outputs = checkpoint(
create_custom_forward(layer_module),
layer_module.forward,
hidden_states,
extended_attention_mask,
position_bias,
......@@ -1118,6 +1112,8 @@ class T5Stack(T5PreTrainedModel):
layer_head_mask,
cross_attn_layer_head_mask,
None, # past_key_value is always None with gradient checkpointing
use_cache,
output_attentions,
)
else:
layer_outputs = layer_module(
......
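
In the T5-family stacks (T5, SwitchTransformers, UMT5) the flags that the old closure captured, `use_cache` and `output_attentions`, are now forwarded explicitly through the checkpointing function, with `None` filling the `past_key_value` slot since caching is disabled under checkpointing. A compressed before/after of that argument plumbing on a toy layer; the layer is hypothetical, only the argument order mirrors the diff:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyT5Block(nn.Module):
    """Toy layer with a T5Block-like positional signature (illustrative only)."""

    def __init__(self, hidden: int = 16):
        super().__init__()
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, hidden_states, attention_mask=None, past_key_value=None,
                use_cache=False, output_attentions=False):
        return (self.proj(hidden_states),)


layer, x = ToyT5Block(), torch.randn(2, 16, requires_grad=True)
use_cache, output_attentions = False, False

# Before: a closure captured the flags, so checkpoint only ever saw the tensors.
def create_custom_forward(module):
    def custom_forward(*inputs):
        return tuple(module(*inputs, use_cache, output_attentions))
    return custom_forward

old_out = checkpoint(create_custom_forward(layer), x, None, None, use_reentrant=False)

# After: every argument is passed explicitly to the checkpoint call, with None
# standing in for past_key_value (always None under gradient checkpointing).
new_out = checkpoint(layer.forward, x, None, None, use_cache, output_attentions,
                     use_reentrant=False)

assert torch.allclose(old_out[0], new_out[0])  # same computation either way
```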
......@@ -837,9 +837,10 @@ class TableTransformerPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, TableTransformerDecoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
TABLE_TRANSFORMER_START_DOCSTRING = r"""
......@@ -1149,15 +1150,8 @@ class TableTransformerDecoder(TableTransformerPreTrainedModel):
continue
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,
encoder_hidden_states,
......
......@@ -646,20 +646,15 @@ class TapasEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_values, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_values,
output_attentions,
)
else:
layer_outputs = layer_module(
......@@ -778,9 +773,10 @@ class TapasPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, TapasEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
TAPAS_START_DOCSTRING = r"""
......
......@@ -663,9 +663,10 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (TimeSeriesTransformerDecoder, TimeSeriesTransformerEncoder)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
TIME_SERIES_TRANSFORMER_START_DOCSTRING = r"""
......@@ -946,18 +947,12 @@ class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
......@@ -1163,16 +1158,8 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
......@@ -1180,6 +1167,8 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel):
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
......
......@@ -439,16 +439,10 @@ class TimesformerEncoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, output_attentions)
......@@ -494,9 +488,10 @@ class TimesformerPreTrainedModel(PreTrainedModel):
nn.init.trunc_normal_(module.position_embeddings, std=self.config.initializer_range)
module.patch_embeddings.apply(self._init_weights)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, TimesformerEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
TIMESFORMER_START_DOCSTRING = r"""
......
......@@ -454,9 +454,10 @@ class TrOCRPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, TrOCRDecoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
TROCR_START_DOCSTRING = r"""
......@@ -701,16 +702,8 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
......@@ -718,6 +711,8 @@ class TrOCRDecoder(TrOCRPreTrainedModel):
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
......
......@@ -560,18 +560,12 @@ class TvltEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
......@@ -616,9 +610,10 @@ class TvltPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, TvltEncoder):
module.gradient_checkpointing = value
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (TvltEncoder, TvltDecoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
TVLT_START_DOCSTRING = r"""
......@@ -877,17 +872,11 @@ class TvltDecoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
None,
output_attentions,
)
else:
layer_outputs = layer_module(hidden_states, output_attentions=output_attentions)
......
......@@ -556,9 +556,10 @@ class UMT5PreTrainedModel(PreTrainedModel):
if module.has_relative_attention_bias:
module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (UMT5Attention, UMT5Stack)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
def _shift_right(self, input_ids):
decoder_start_token_id = self.config.decoder_start_token_id
......@@ -709,15 +710,8 @@ class UMT5Stack(UMT5PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return tuple(module(*inputs, use_cache, output_attentions))
return custom_forward
layer_outputs = checkpoint(
create_custom_forward(layer_module),
layer_module.forward,
hidden_states,
extended_attention_mask,
encoder_hidden_states,
......@@ -725,6 +719,8 @@ class UMT5Stack(UMT5PreTrainedModel):
layer_head_mask,
cross_attn_layer_head_mask,
None, # past_key_value is always None with gradient checkpointing
use_cache,
output_attentions,
)
else:
layer_outputs = layer_module(
......
......@@ -384,15 +384,8 @@ class UniSpeechFeatureEncoder(nn.Module):
for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(conv_layer),
hidden_states = self.gradient_checkpointing_func(
conv_layer.__call__,
hidden_states,
)
else:
......@@ -767,17 +760,11 @@ class UniSpeechEncoder(nn.Module):
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
# create gradient checkpointing function
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
layer_outputs = self.gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = layer(
......@@ -857,17 +844,11 @@ class UniSpeechEncoderStableLayerNorm(nn.Module):
# under deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if self.gradient_checkpointing and self.training:
# create gradient checkpointing function
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
layer_outputs = self.gradient_checkpointing_func(
layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = layer(
......@@ -1039,9 +1020,10 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (UniSpeechEncoder, UniSpeechEncoderStableLayerNorm, UniSpeechFeatureEncoder)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
UNISPEECH_START_DOCSTRING = r"""
......
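
From the user side the entry point is unchanged: `model.gradient_checkpointing_enable()` still switches checkpointing on, but after this refactor it installs `gradient_checkpointing_func` on the supporting sub-modules rather than flipping a bare boolean. A short usage sketch; the `t5-small` checkpoint and the `gradient_checkpointing_kwargs` argument are assumptions about the surrounding API in this release, so check the version you actually run:

```python
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")  # assumed example checkpoint
model.train()

# Same public entry point as before; forwarding kwargs to torch.utils.checkpoint
# is an assumption about how the base model builds gradient_checkpointing_func.
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

# Per the T5Stack hunk above, the stacks now carry both the flag and the function.
print(model.encoder.gradient_checkpointing)       # True
print(model.encoder.gradient_checkpointing_func)  # functools.partial(<checkpoint>, use_reentrant=False)
```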