Unverified commit 06e782da, authored by Younes Belkada and committed by GitHub
Browse files

[`core`] Refactor of `gradient_checkpointing` (#27020)

* v1

* fix

* remove `create_custom_forward`

* fixup

* fixup

* add test and fix all failing GC tests

* remove all remaining `create_custom_forward` methods

* fix idefics bug

* fixup

* replace with `__call__`

* add comment

* quality
parent 9286f0ac
...@@ -804,20 +804,15 @@ class BridgeTowerTextEncoder(nn.Module): ...@@ -804,20 +804,15 @@ class BridgeTowerTextEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
past_key_value,
output_attentions,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
......
...@@ -651,21 +651,15 @@ class BrosEncoder(nn.Module): ...@@ -651,21 +651,15 @@ class BrosEncoder(nn.Module):
"`use_cache=False`..." "`use_cache=False`..."
) )
use_cache = False use_cache = False
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
bbox_pos_emb, bbox_pos_emb,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
output_attentions,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
......
...@@ -524,20 +524,15 @@ class CamembertEncoder(nn.Module): ...@@ -524,20 +524,15 @@ class CamembertEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
past_key_value,
output_attentions,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
...@@ -625,9 +620,10 @@ class CamembertPreTrainedModel(PreTrainedModel): ...@@ -625,9 +620,10 @@ class CamembertPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, CamembertEncoder): if isinstance(module, CamembertEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CAMEMBERT_INPUTS_DOCSTRING = r""" CAMEMBERT_INPUTS_DOCSTRING = r"""
......
...@@ -795,18 +795,12 @@ class CanineEncoder(nn.Module): ...@@ -795,18 +795,12 @@ class CanineEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
output_attentions,
) )
else: else:
layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions) layer_outputs = layer_module(hidden_states, attention_mask, layer_head_mask, output_attentions)
...@@ -919,9 +913,10 @@ class CaninePreTrainedModel(PreTrainedModel): ...@@ -919,9 +913,10 @@ class CaninePreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, CanineEncoder): if isinstance(module, CanineEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CANINE_START_DOCSTRING = r""" CANINE_START_DOCSTRING = r"""
......
...@@ -742,9 +742,10 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel): ...@@ -742,9 +742,10 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
if module.bias is not None: if module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ChineseCLIPVisionEncoder) or isinstance(module, ChineseCLIPTextEncoder): if isinstance(module, ChineseCLIPVisionEncoder) or isinstance(module, ChineseCLIPTextEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CHINESE_CLIP_START_DOCSTRING = r""" CHINESE_CLIP_START_DOCSTRING = r"""
...@@ -909,20 +910,15 @@ class ChineseCLIPTextEncoder(nn.Module): ...@@ -909,20 +910,15 @@ class ChineseCLIPTextEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
past_key_value,
output_attentions,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
...@@ -1018,16 +1014,10 @@ class ChineseCLIPVisionEncoder(nn.Module): ...@@ -1018,16 +1014,10 @@ class ChineseCLIPVisionEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): encoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states, hidden_states,
output_attentions,
) )
else: else:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
......
...@@ -939,15 +939,8 @@ class ClapAudioEncoder(nn.Module): ...@@ -939,15 +939,8 @@ class ClapAudioEncoder(nn.Module):
input_dimensions = self.input_resolutions[i] input_dimensions = self.input_resolutions[i]
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__, hidden_states, input_dimensions, layer_head_mask, output_attentions
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
...@@ -1595,20 +1588,15 @@ class ClapTextEncoder(nn.Module): ...@@ -1595,20 +1588,15 @@ class ClapTextEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
past_key_value,
output_attentions,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
...@@ -1701,9 +1689,10 @@ class ClapPreTrainedModel(PreTrainedModel): ...@@ -1701,9 +1689,10 @@ class ClapPreTrainedModel(PreTrainedModel):
if module.bias is not None: if module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ClapTextEncoder): if isinstance(module, ClapTextEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class ClapAudioModel(ClapPreTrainedModel): class ClapAudioModel(ClapPreTrainedModel):
......
...@@ -467,9 +467,10 @@ class CLIPPreTrainedModel(PreTrainedModel): ...@@ -467,9 +467,10 @@ class CLIPPreTrainedModel(PreTrainedModel):
if isinstance(module, nn.Linear) and module.bias is not None: if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, CLIPEncoder): if isinstance(module, CLIPEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CLIP_START_DOCSTRING = r""" CLIP_START_DOCSTRING = r"""
...@@ -639,18 +640,12 @@ class CLIPEncoder(nn.Module): ...@@ -639,18 +640,12 @@ class CLIPEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): encoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
causal_attention_mask, causal_attention_mask,
output_attentions,
) )
else: else:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
......
...@@ -479,9 +479,10 @@ class CLIPSegPreTrainedModel(PreTrainedModel): ...@@ -479,9 +479,10 @@ class CLIPSegPreTrainedModel(PreTrainedModel):
if isinstance(module, nn.Linear) and module.bias is not None: if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, CLIPSegEncoder): if isinstance(module, CLIPSegEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CLIPSEG_START_DOCSTRING = r""" CLIPSEG_START_DOCSTRING = r"""
...@@ -648,18 +649,12 @@ class CLIPSegEncoder(nn.Module): ...@@ -648,18 +649,12 @@ class CLIPSegEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): encoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
causal_attention_mask, causal_attention_mask,
output_attentions,
) )
else: else:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
......
...@@ -339,9 +339,10 @@ class CodeGenPreTrainedModel(PreTrainedModel): ...@@ -339,9 +339,10 @@ class CodeGenPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, CodeGenModel): if isinstance(module, CodeGenModel):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CODEGEN_START_DOCSTRING = r""" CODEGEN_START_DOCSTRING = r"""
...@@ -542,21 +543,15 @@ class CodeGenModel(CodeGenPreTrainedModel): ...@@ -542,21 +543,15 @@ class CodeGenModel(CodeGenPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): block.__call__,
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, hidden_states,
None, None,
attention_mask, attention_mask,
position_ids, position_ids,
head_mask[i], head_mask[i],
use_cache,
output_attentions,
) )
else: else:
outputs = block( outputs = block(
......
...@@ -1171,9 +1171,10 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel): ...@@ -1171,9 +1171,10 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ConditionalDetrDecoder): if isinstance(module, ConditionalDetrDecoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CONDITIONAL_DETR_START_DOCSTRING = r""" CONDITIONAL_DETR_START_DOCSTRING = r"""
...@@ -1518,15 +1519,8 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel): ...@@ -1518,15 +1519,8 @@ class ConditionalDetrDecoder(ConditionalDetrPreTrainedModel):
# apply transformation # apply transformation
query_sine_embed = query_sine_embed_before_transformation * pos_transformation query_sine_embed = query_sine_embed_before_transformation * pos_transformation
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): decoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states, hidden_states,
combined_attention_mask, combined_attention_mask,
object_queries, object_queries,
......
...@@ -264,9 +264,10 @@ class ConvBertPreTrainedModel(PreTrainedModel): ...@@ -264,9 +264,10 @@ class ConvBertPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ConvBertEncoder): if isinstance(module, ConvBertEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class SeparableConv1D(nn.Module): class SeparableConv1D(nn.Module):
...@@ -632,20 +633,14 @@ class ConvBertEncoder(nn.Module): ...@@ -632,20 +633,14 @@ class ConvBertEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
output_attentions,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
......
...@@ -296,9 +296,10 @@ class ConvNextPreTrainedModel(PreTrainedModel): ...@@ -296,9 +296,10 @@ class ConvNextPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ConvNextEncoder): if isinstance(module, ConvNextEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CONVNEXT_START_DOCSTRING = r""" CONVNEXT_START_DOCSTRING = r"""
......
...@@ -317,9 +317,10 @@ class ConvNextV2PreTrainedModel(PreTrainedModel): ...@@ -317,9 +317,10 @@ class ConvNextV2PreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, ConvNextV2Encoder): if isinstance(module, ConvNextV2Encoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CONVNEXTV2_START_DOCSTRING = r""" CONVNEXTV2_START_DOCSTRING = r"""
......
...@@ -556,9 +556,10 @@ class CpmAntPreTrainedModel(PreTrainedModel): ...@@ -556,9 +556,10 @@ class CpmAntPreTrainedModel(PreTrainedModel):
elif isinstance(module, CpmAntSegmentPositionEmbedding): elif isinstance(module, CpmAntSegmentPositionEmbedding):
module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std) module.relative_attention_bias.data.normal_(mean=0.0, std=self.config.init_std)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, CpmAntEncoder): if isinstance(module, CpmAntEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
CPMANT_START_DOCSTRING = r""" CPMANT_START_DOCSTRING = r"""
......
...@@ -293,15 +293,8 @@ class Data2VecAudioFeatureEncoder(nn.Module): ...@@ -293,15 +293,8 @@ class Data2VecAudioFeatureEncoder(nn.Module):
for conv_layer in self.conv_layers: for conv_layer in self.conv_layers:
if self._requires_grad and self.gradient_checkpointing and self.training: if self._requires_grad and self.gradient_checkpointing and self.training:
hidden_states = self.gradient_checkpointing_func(
def create_custom_forward(module): conv_layer.__call__,
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(conv_layer),
hidden_states, hidden_states,
) )
else: else:
...@@ -593,17 +586,11 @@ class Data2VecAudioEncoder(nn.Module): ...@@ -593,17 +586,11 @@ class Data2VecAudioEncoder(nn.Module):
if not skip_the_layer or deepspeed_zero3_is_enabled: if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync # under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
# create gradient checkpointing function layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
output_attentions,
) )
else: else:
layer_outputs = layer( layer_outputs = layer(
...@@ -761,9 +748,10 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel): ...@@ -761,9 +748,10 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
return attention_mask return attention_mask
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (Data2VecAudioEncoder, Data2VecAudioFeatureEncoder)): if isinstance(module, (Data2VecAudioEncoder, Data2VecAudioFeatureEncoder)):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DATA2VEC_AUDIO_START_DOCSTRING = r""" DATA2VEC_AUDIO_START_DOCSTRING = r"""
......
...@@ -510,20 +510,15 @@ class Data2VecTextEncoder(nn.Module): ...@@ -510,20 +510,15 @@ class Data2VecTextEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
past_key_value,
output_attentions,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
...@@ -613,9 +608,10 @@ class Data2VecTextPreTrainedModel(PreTrainedModel): ...@@ -613,9 +608,10 @@ class Data2VecTextPreTrainedModel(PreTrainedModel):
if hasattr(module, "weight") and module.weight is not None: if hasattr(module, "weight") and module.weight is not None:
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, Data2VecTextEncoder): if isinstance(module, Data2VecTextEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DATA2VECTEXT_START_DOCSTRING = r""" DATA2VECTEXT_START_DOCSTRING = r"""
......
...@@ -522,17 +522,11 @@ class Data2VecVisionEncoder(nn.Module): ...@@ -522,17 +522,11 @@ class Data2VecVisionEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
layer_head_mask, layer_head_mask,
output_attentions,
) )
else: else:
relative_position_bias = ( relative_position_bias = (
...@@ -585,9 +579,10 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel): ...@@ -585,9 +579,10 @@ class Data2VecVisionPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, Data2VecVisionEncoder): if isinstance(module, Data2VecVisionEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DATA2VEC_VISION_START_DOCSTRING = r""" DATA2VEC_VISION_START_DOCSTRING = r"""
......
...@@ -457,20 +457,14 @@ class DebertaEncoder(nn.Module): ...@@ -457,20 +457,14 @@ class DebertaEncoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
hidden_states = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
next_kv, next_kv,
attention_mask, attention_mask,
query_states, query_states,
relative_pos, relative_pos,
rel_embeddings, rel_embeddings,
output_attentions,
) )
else: else:
hidden_states = layer_module( hidden_states = layer_module(
...@@ -839,9 +833,10 @@ class DebertaPreTrainedModel(PreTrainedModel): ...@@ -839,9 +833,10 @@ class DebertaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DebertaEncoder): if isinstance(module, DebertaEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DEBERTA_START_DOCSTRING = r""" DEBERTA_START_DOCSTRING = r"""
......
...@@ -501,20 +501,14 @@ class DebertaV2Encoder(nn.Module): ...@@ -501,20 +501,14 @@ class DebertaV2Encoder(nn.Module):
all_hidden_states = all_hidden_states + (output_states,) all_hidden_states = all_hidden_states + (output_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
output_states = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
output_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
next_kv, next_kv,
attention_mask, attention_mask,
query_states, query_states,
relative_pos, relative_pos,
rel_embeddings, rel_embeddings,
output_attentions,
) )
else: else:
output_states = layer_module( output_states = layer_module(
...@@ -938,9 +932,10 @@ class DebertaV2PreTrainedModel(PreTrainedModel): ...@@ -938,9 +932,10 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DebertaV2Encoder): if isinstance(module, DebertaV2Encoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
DEBERTA_START_DOCSTRING = r""" DEBERTA_START_DOCSTRING = r"""
......
...@@ -469,9 +469,10 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel): ...@@ -469,9 +469,10 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer))) p.data.normal_(mean=0.0, std=(self.config.initializer_range / math.sqrt(2 * self.config.n_layer)))
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, DecisionTransformerGPT2Model): if isinstance(module, DecisionTransformerGPT2Model):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
...@@ -631,22 +632,16 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): ...@@ -631,22 +632,16 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): block.__call__,
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states, hidden_states,
None, None,
attention_mask, attention_mask,
head_mask[i], head_mask[i],
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
use_cache,
output_attentions,
) )
else: else:
outputs = block( outputs = block(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment