Unverified commit 06e782da, authored by Younes Belkada and committed by GitHub

[`core`] Refactor of `gradient_checkpointing` (#27020)

* v1

* fix

* remove `create_custom_forward`

* fixup

* fixup

* add test and fix all failing GC tests

* remove all remaining `create_custom_forward` methods

* fix idefics bug

* fixup

* replace with `__call__`

* add comment

* quality
parent 9286f0ac
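At a glance, the pattern removed and the pattern introduced by this refactor look like the sketch below. It is illustrative rather than taken from any single model: `decoder_layer`, `past_key_value` and `output_attentions` stand in for whatever a given model's layer loop passes, and only the structure matters.

# Old pattern: each model built a throwaway closure so that extra arguments
# (here past_key_value and output_attentions) could be captured, then called
# torch.utils.checkpoint.checkpoint directly.
if self.gradient_checkpointing and self.training:

    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs, past_key_value, output_attentions)

        return custom_forward

    layer_outputs = torch.utils.checkpoint.checkpoint(
        create_custom_forward(decoder_layer),
        hidden_states,
        attention_mask,
    )

# New pattern: _set_gradient_checkpointing attaches a gradient_checkpointing_func
# to the module, and every argument is passed positionally, so no closure is needed.
if self.gradient_checkpointing and self.training:
    layer_outputs = self.gradient_checkpointing_func(
        decoder_layer.__call__,
        hidden_states,
        attention_mask,
        past_key_value,
        output_attentions,
    )

The per-model diffs follow.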
@@ -333,9 +333,10 @@ class MgpstrPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-    def _set_gradient_checkpointing(self, module: MgpstrEncoder, value: bool = False) -> None:
+    def _set_gradient_checkpointing(self, module: MgpstrEncoder, gradient_checkpointing_func=None) -> None:
         if isinstance(module, MgpstrEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 MGP_STR_START_DOCSTRING = r"""
@@ -816,9 +816,10 @@ class MistralPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, MistralModel):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 MISTRAL_INPUTS_DOCSTRING = r"""
@@ -1020,19 +1021,14 @@ class MistralModel(MistralPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, past_key_value, output_attentions)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     position_ids,
+                    past_key_value,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -626,15 +626,8 @@ class MobileViTEncoder(nn.Module):
         for i, layer_module in enumerate(self.layer):
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-                    return custom_forward
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                hidden_states = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                 )
             else:
@@ -672,9 +665,10 @@ class MobileViTPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, MobileViTEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 MOBILEVIT_START_DOCSTRING = r"""
@@ -582,15 +582,8 @@ class MobileViTV2Encoder(nn.Module):
         for i, layer_module in enumerate(self.layer):
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-                    return custom_forward
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                hidden_states = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                 )
             else:
@@ -629,9 +622,10 @@ class MobileViTV2PreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, MobileViTV2Encoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 MOBILEVITV2_START_DOCSTRING = r"""
@@ -294,9 +294,10 @@ class MptPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-    def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
+    def _set_gradient_checkpointing(self, module: nn.Module, gradient_checkpointing_func=None):
         if isinstance(module, MptModel):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
     @staticmethod
     def _convert_to_mpt_cache(
@@ -523,20 +524,14 @@ class MptModel(MptPreTrainedModel):
                 all_hidden_states = all_hidden_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
-                    return custom_forward
-                outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                outputs = self.gradient_checkpointing_func(
+                    block.__call__,
                     hidden_states,
                     alibi,
                     causal_mask,
                     layer_past,
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 outputs = block(
@@ -766,15 +766,8 @@ class MraEncoder(nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
                 )
@@ -871,9 +864,10 @@ class MraPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, MraEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 MRA_START_DOCSTRING = r"""
@@ -845,9 +845,10 @@ class MT5PreTrainedModel(PreTrainedModel):
             if module.has_relative_attention_bias:
                 module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (MT5Attention, MT5Stack)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
     def _shift_right(self, input_ids):
         decoder_start_token_id = self.config.decoder_start_token_id
@@ -1073,15 +1074,8 @@ class MT5Stack(MT5PreTrainedModel):
                 all_hidden_states = all_hidden_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return tuple(module(*inputs, use_cache, output_attentions))
-                    return custom_forward
-                layer_outputs = checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.forward,
                     hidden_states,
                     extended_attention_mask,
                     position_bias,
@@ -1091,6 +1085,8 @@ class MT5Stack(MT5PreTrainedModel):
                     layer_head_mask,
                     cross_attn_layer_head_mask,
                     None,  # past_key_value is always None with gradient checkpointing
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(
@@ -475,9 +475,10 @@ class MusicgenPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, MusicgenDecoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 MUSICGEN_START_DOCSTRING = r"""
@@ -826,16 +827,8 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-                    return custom_forward
-                layer_outputs = checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.forward,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -843,6 +836,8 @@ class MusicgenDecoder(MusicgenPreTrainedModel):
                     head_mask[idx] if head_mask is not None else None,
                     cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -1562,10 +1557,10 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
                 self.text_encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
             )
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         # call both encoder and decoder function on gradient checkpointing
-        self.text_encoder._set_gradient_checkpointing(module, value=value)
-        self.decoder._set_gradient_checkpointing(module, value=value)
+        self.text_encoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func)
+        self.decoder._set_gradient_checkpointing(module, gradient_checkpointing_func=gradient_checkpointing_func)
     def get_audio_encoder(self):
         return self.audio_encoder
@@ -563,9 +563,10 @@ class MvpPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (MvpDecoder, MvpEncoder, MvpPrompt)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
     @property
     def dummy_inputs(self):
@@ -949,19 +950,13 @@ class MvpEncoder(MvpPreTrainedModel):
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        encoder_layer.__call__,
                         hidden_states,
                         attention_mask,
                         (head_mask[idx] if head_mask is not None else None),
                         (self_attn_prompt[idx] if self.use_prompt else None),
+                        output_attentions,
                     )
                 else:
                     layer_outputs = encoder_layer(
@@ -1227,16 +1222,8 @@ class MvpDecoder(MvpPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -1246,6 +1233,8 @@ class MvpDecoder(MvpPreTrainedModel):
                     self_attn_prompt[idx] if self.use_prompt else None,
                     cross_attn_prompt[idx] if self.use_prompt else None,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -639,7 +639,7 @@ class NatPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-    def _set_gradient_checkpointing(self, module: NatEncoder, value: bool = False) -> None:
+    def _set_gradient_checkpointing(self, module: NatEncoder, gradient_checkpointing_func=None) -> None:
         pass
@@ -577,20 +577,15 @@ class NezhaEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, past_key_value, output_attentions)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
                     layer_head_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(
@@ -752,9 +747,10 @@ class NezhaPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, NezhaEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 @dataclass
@@ -874,9 +874,10 @@ class NllbMoePreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (NllbMoeDecoder, NllbMoeEncoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 NLLB_MOE_START_DOCSTRING = r"""
@@ -1153,18 +1154,12 @@ class NllbMoeEncoder(NllbMoePreTrainedModel):
                 layer_outputs = (None, None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-                    # create gradient checkpointing function
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        encoder_layer.__call__,
                         hidden_states,
                         attention_mask,
                         (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
                     )
                 else:
                     layer_outputs = encoder_layer(
@@ -1426,15 +1421,8 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
                         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return tuple(module(*inputs, use_cache, output_attentions))
-                    return custom_forward
-                layer_outputs = checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.forward,
                     hidden_states,
                     combined_attention_mask,
                     encoder_hidden_states,
@@ -1442,6 +1430,8 @@ class NllbMoeDecoder(NllbMoePreTrainedModel):
                     layer_head_mask,
                     cross_attn_layer_head_mask,
                     None,  # past_key_value is always None with gradient checkpointing
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -370,17 +370,11 @@ class NystromformerEncoder(nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(hidden_states, attention_mask, output_attentions)
@@ -477,9 +471,10 @@ class NystromformerPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, NystromformerEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 NYSTROMFORMER_START_DOCSTRING = r"""
@@ -2616,7 +2616,7 @@ class OneFormerTextTransformer(nn.Module):
     def forward(self, hidden_states: torch.Tensor):
         for layer in self.layers:
             if self.use_checkpoint:
-                hidden_states = torch.utils.checkpoint.checkpoint(layer, hidden_states)
+                hidden_states = self.gradient_checkpointing_func(layer, hidden_states)
             else:
                 hidden_states = layer(hidden_states)
         return hidden_states
@@ -411,9 +411,10 @@ class OPTPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (OPTDecoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 OPT_INPUTS_DOCSTRING = r"""
@@ -691,20 +692,14 @@ class OPTDecoder(OPTPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, None)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     causal_attention_mask,
                     head_mask[idx] if head_mask is not None else None,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -584,9 +584,10 @@ class Owlv2PreTrainedModel(PreTrainedModel):
         if isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, Owlv2Encoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 OWLV2_START_DOCSTRING = r"""
@@ -764,18 +765,12 @@ class Owlv2Encoder(nn.Module):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(encoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    encoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     causal_attention_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = encoder_layer(
@@ -1378,6 +1373,7 @@ class Owlv2ForObjectDetection(Owlv2PreTrainedModel):
     def objectness_predictor(self, image_features: torch.FloatTensor) -> torch.FloatTensor:
         """Predicts the probability that each image feature token is an object.
         Args:
             image_features (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_dim)`)):
                 Features extracted from the image.
@@ -576,9 +576,10 @@ class OwlViTPreTrainedModel(PreTrainedModel):
         if isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, OwlViTEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 OWLVIT_START_DOCSTRING = r"""
@@ -753,18 +754,12 @@ class OwlViTEncoder(nn.Module):
             if output_hidden_states:
                 encoder_states = encoder_states + (hidden_states,)
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(encoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    encoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     causal_attention_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = encoder_layer(
@@ -500,9 +500,10 @@ class PegasusPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (PegasusDecoder, PegasusEncoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 PEGASUS_START_DOCSTRING = r"""
@@ -803,18 +804,12 @@ class PegasusEncoder(PegasusPreTrainedModel):
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        encoder_layer.__call__,
                         hidden_states,
                         attention_mask,
                         (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
                     )
                 else:
                     layer_outputs = encoder_layer(
@@ -1087,16 +1082,8 @@ class PegasusDecoder(PegasusPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -1104,6 +1091,8 @@ class PegasusDecoder(PegasusPreTrainedModel):
                     head_mask[idx] if head_mask is not None else None,
                     cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -780,9 +780,10 @@ class PegasusXPreTrainedModel(PreTrainedModel):
         elif isinstance(module, nn.Embedding):
             module.weight.data.normal_(mean=0.0, std=std)
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (PegasusXDecoder, PegasusXEncoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 PEGASUS_X_START_DOCSTRING = r"""
@@ -1071,18 +1072,12 @@ class PegasusXEncoder(PegasusXPreTrainedModel):
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-                        return custom_forward
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        encoder_layer.__call__,
                         hidden_states,
                         global_hidden_states,
                         attention_mask,
+                        output_attentions,
                     )
                 else:
                     layer_outputs = encoder_layer(
@@ -1331,21 +1326,15 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
                 layer_outputs = decoder_layer(
@@ -467,9 +467,10 @@ class PersimmonPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, PersimmonModel):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None
 PERSIMMON_INPUTS_DOCSTRING = r"""
@@ -668,19 +669,13 @@ class PersimmonModel(PersimmonPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
             if self.gradient_checkpointing and self.training:
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, past_key_value, output_attentions)
-                    return custom_forward
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     position_ids,
+                    past_key_value,
+                    output_attentions,
                 )
             else:
                 layer_outputs = decoder_layer(
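From the user's side nothing changes: gradient checkpointing is still switched on through the existing `gradient_checkpointing_enable()` API, and with this refactor that call hands the chosen checkpoint function down to each supported submodule via `_set_gradient_checkpointing`. A minimal usage sketch (the model checkpoint name is only an example):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
model.gradient_checkpointing_enable()  # attaches gradient_checkpointing_func to the checkpointed submodules
model.train()
# forward/backward as usual; activations inside the checkpointed blocks are recomputed during backward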