Commit 06e782da (unverified), authored by Younes Belkada, committed by GitHub
Browse files

[`core`] Refactor of `gradient_checkpointing` (#27020)

* v1

* fix

* remove `create_custom_forward`

* fixup

* fixup

* add test and fix all failing GC tests

* remove all remaining `create_custom_forward` methods

* fix idefics bug

* fixup

* replace with `__call__`

* add comment

* quality
parent 9286f0ac
...@@ -924,9 +924,10 @@ class InformerPreTrainedModel(PreTrainedModel): ...@@ -924,9 +924,10 @@ class InformerPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (InformerDecoder, InformerEncoder)): if isinstance(module, (InformerDecoder, InformerEncoder)):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
INFORMER_START_DOCSTRING = r""" INFORMER_START_DOCSTRING = r"""
...@@ -1215,21 +1216,15 @@ class InformerEncoder(InformerPreTrainedModel): ...@@ -1215,21 +1216,15 @@ class InformerEncoder(InformerPreTrainedModel):
layer_outputs = (None, None) layer_outputs = (None, None)
else: else:
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): encoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
(head_mask[idx] if head_mask is not None else None), (head_mask[idx] if head_mask is not None else None),
output_attentions,
) )
if conv_layer is not None: if conv_layer is not None:
output = torch.utils.checkpoint.checkpoint(conv_layer, layer_outputs[0]) output = self.gradient_checkpointing_func(conv_layer, layer_outputs[0])
layer_outputs = (output,) + layer_outputs[1:] layer_outputs = (output,) + layer_outputs[1:]
else: else:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
...@@ -1438,16 +1433,8 @@ class InformerDecoder(InformerPreTrainedModel): ...@@ -1438,16 +1433,8 @@ class InformerDecoder(InformerPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): decoder_layer.__call__,
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
encoder_hidden_states, encoder_hidden_states,
...@@ -1455,6 +1442,8 @@ class InformerDecoder(InformerPreTrainedModel): ...@@ -1455,6 +1442,8 @@ class InformerDecoder(InformerPreTrainedModel):
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
output_attentions,
use_cache,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
......
...@@ -304,9 +304,14 @@ class InstructBlipPreTrainedModel(PreTrainedModel): ...@@ -304,9 +304,14 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.Linear) and module.bias is not None: elif isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, InstructBlipEncoder): if isinstance(module, (InstructBlipEncoder, InstructBlipQFormerEncoder)):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
# Enable / disable GC for the language model as well
if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"):
self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func)
INSTRUCTBLIP_START_DOCSTRING = r""" INSTRUCTBLIP_START_DOCSTRING = r"""
...@@ -462,17 +467,11 @@ class InstructBlipEncoder(nn.Module): ...@@ -462,17 +467,11 @@ class InstructBlipEncoder(nn.Module):
if output_hidden_states: if output_hidden_states:
encoder_states = encoder_states + (hidden_states,) encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): encoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
output_attentions,
) )
else: else:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
...@@ -939,15 +938,8 @@ class InstructBlipQFormerEncoder(nn.Module): ...@@ -939,15 +938,8 @@ class InstructBlipQFormerEncoder(nn.Module):
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
) )
use_cache = False use_cache = False
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions, query_length)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
......
...@@ -487,20 +487,15 @@ class LayoutLMEncoder(nn.Module): ...@@ -487,20 +487,15 @@ class LayoutLMEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
past_key_value,
output_attentions,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
...@@ -638,9 +633,10 @@ class LayoutLMPreTrainedModel(PreTrainedModel): ...@@ -638,9 +633,10 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LayoutLMEncoder): if isinstance(module, LayoutLMEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LAYOUTLM_START_DOCSTRING = r""" LAYOUTLM_START_DOCSTRING = r"""
......
...@@ -439,18 +439,12 @@ class LayoutLMv2Encoder(nn.Module): ...@@ -439,18 +439,12 @@ class LayoutLMv2Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
output_attentions,
rel_pos=rel_pos, rel_pos=rel_pos,
rel_2d_pos=rel_2d_pos, rel_2d_pos=rel_2d_pos,
) )
...@@ -514,9 +508,10 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel): ...@@ -514,9 +508,10 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LayoutLMv2Encoder): if isinstance(module, LayoutLMv2Encoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
def my_convert_sync_batchnorm(module, process_group=None): def my_convert_sync_batchnorm(module, process_group=None):
......
...@@ -657,19 +657,8 @@ class LayoutLMv3Encoder(nn.Module): ...@@ -657,19 +657,8 @@ class LayoutLMv3Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs)
# return module(*inputs, past_key_value, output_attentions, rel_pos, rel_2d_pos)
# The above line will cause error:
# RuntimeError: Trying to backward through the graph a second time
# (or directly access saved tensors after they have already been freed).
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
......
...@@ -1155,9 +1155,10 @@ class LEDPreTrainedModel(PreTrainedModel): ...@@ -1155,9 +1155,10 @@ class LEDPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (LEDDecoder, LEDEncoder)): if isinstance(module, (LEDDecoder, LEDEncoder)):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property @property
def dummy_inputs(self): def dummy_inputs(self):
...@@ -1876,20 +1877,15 @@ class LEDEncoder(LEDPreTrainedModel): ...@@ -1876,20 +1877,15 @@ class LEDEncoder(LEDPreTrainedModel):
layer_outputs = (None, None, None) layer_outputs = (None, None, None)
else: else:
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): encoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, is_global_attn, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
is_index_masked, is_index_masked,
is_index_global_attn, is_index_global_attn,
is_global_attn,
output_attentions,
) )
else: else:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
...@@ -2142,16 +2138,8 @@ class LEDDecoder(LEDPreTrainedModel): ...@@ -2142,16 +2138,8 @@ class LEDDecoder(LEDPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): decoder_layer.__call__,
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states, hidden_states,
combined_attention_mask, combined_attention_mask,
encoder_hidden_states, encoder_hidden_states,
...@@ -2159,6 +2147,8 @@ class LEDDecoder(LEDPreTrainedModel): ...@@ -2159,6 +2147,8 @@ class LEDDecoder(LEDPreTrainedModel):
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
output_attentions,
use_cache,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
......
...@@ -507,9 +507,10 @@ class LevitPreTrainedModel(PreTrainedModel): ...@@ -507,9 +507,10 @@ class LevitPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LevitModel): if isinstance(module, LevitModel):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LEVIT_START_DOCSTRING = r""" LEVIT_START_DOCSTRING = r"""
......
...@@ -514,19 +514,13 @@ class LiltEncoder(nn.Module): ...@@ -514,19 +514,13 @@ class LiltEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
layout_inputs, layout_inputs,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
output_attentions,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
...@@ -607,9 +601,10 @@ class LiltPreTrainedModel(PreTrainedModel): ...@@ -607,9 +601,10 @@ class LiltPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LiltEncoder): if isinstance(module, LiltEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LILT_START_DOCSTRING = r""" LILT_START_DOCSTRING = r"""
......
...@@ -827,9 +827,10 @@ class LlamaPreTrainedModel(PreTrainedModel): ...@@ -827,9 +827,10 @@ class LlamaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LlamaModel): if isinstance(module, LlamaModel):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LLAMA_INPUTS_DOCSTRING = r""" LLAMA_INPUTS_DOCSTRING = r"""
...@@ -1013,16 +1014,14 @@ class LlamaModel(LlamaPreTrainedModel): ...@@ -1013,16 +1014,14 @@ class LlamaModel(LlamaPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): decoder_layer.__call__,
def custom_forward(*inputs): hidden_states,
# None for past_key_value attention_mask,
return module(*inputs, past_key_value, output_attentions) position_ids,
past_key_value,
return custom_forward output_attentions,
use_cache,
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
......
...@@ -1304,20 +1304,15 @@ class LongformerEncoder(nn.Module): ...@@ -1304,20 +1304,15 @@ class LongformerEncoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, is_global_attn, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
is_index_masked, is_index_masked,
is_index_global_attn, is_index_global_attn,
is_global_attn,
output_attentions,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
...@@ -1439,9 +1434,10 @@ class LongformerPreTrainedModel(PreTrainedModel): ...@@ -1439,9 +1434,10 @@ class LongformerPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LongformerEncoder): if isinstance(module, LongformerEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LONGFORMER_START_DOCSTRING = r""" LONGFORMER_START_DOCSTRING = r"""
......
...@@ -775,7 +775,6 @@ class LongT5TransientGlobalAttention(nn.Module): ...@@ -775,7 +775,6 @@ class LongT5TransientGlobalAttention(nn.Module):
if self.has_relative_attention_bias: if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
self.pruned_heads = set() self.pruned_heads = set()
self.gradient_checkpointing = False
# Relative attention bias & Layer norm for global attention # Relative attention bias & Layer norm for global attention
if self.has_relative_attention_bias: if self.has_relative_attention_bias:
...@@ -1340,10 +1339,10 @@ class LongT5PreTrainedModel(PreTrainedModel): ...@@ -1340,10 +1339,10 @@ class LongT5PreTrainedModel(PreTrainedModel):
mean=0.0, std=factor * ((d_model) ** -0.5) mean=0.0, std=factor * ((d_model) ** -0.5)
) )
# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._set_gradient_checkpointing with T5->LongT5 def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
def _set_gradient_checkpointing(self, module, value=False): if isinstance(module, (LongT5Attention, LongT5Stack, LongT5LocalAttention)):
if isinstance(module, (LongT5Attention, LongT5Stack)): module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = value module.gradient_checkpointing = gradient_checkpointing_func is not None
# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5 # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5
def _shift_right(self, input_ids): def _shift_right(self, input_ids):
...@@ -1510,15 +1509,8 @@ class LongT5Stack(LongT5PreTrainedModel): ...@@ -1510,15 +1509,8 @@ class LongT5Stack(LongT5PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return tuple(module(*inputs, use_cache, output_attentions))
return custom_forward
layer_outputs = checkpoint( layer_outputs = checkpoint(
create_custom_forward(layer_module), layer_module.forward,
hidden_states, hidden_states,
extended_attention_mask, extended_attention_mask,
position_bias, position_bias,
...@@ -1528,6 +1520,8 @@ class LongT5Stack(LongT5PreTrainedModel): ...@@ -1528,6 +1520,8 @@ class LongT5Stack(LongT5PreTrainedModel):
layer_head_mask, layer_head_mask,
cross_attn_layer_head_mask, cross_attn_layer_head_mask,
None, # past_key_value is always None with gradient checkpointing None, # past_key_value is always None with gradient checkpointing
use_cache,
output_attentions,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
......
...@@ -788,19 +788,13 @@ class LukeEncoder(nn.Module): ...@@ -788,19 +788,13 @@ class LukeEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
word_hidden_states, word_hidden_states,
entity_hidden_states, entity_hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
output_attentions,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
...@@ -920,9 +914,10 @@ class LukePreTrainedModel(PreTrainedModel): ...@@ -920,9 +914,10 @@ class LukePreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LukeEncoder): if isinstance(module, LukeEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LUKE_START_DOCSTRING = r""" LUKE_START_DOCSTRING = r"""
......
...@@ -552,9 +552,10 @@ class M2M100PreTrainedModel(PreTrainedModel): ...@@ -552,9 +552,10 @@ class M2M100PreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (M2M100Decoder, M2M100Encoder)): if isinstance(module, (M2M100Decoder, M2M100Encoder)):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
M2M_100_START_DOCSTRING = r""" M2M_100_START_DOCSTRING = r"""
...@@ -820,18 +821,12 @@ class M2M100Encoder(M2M100PreTrainedModel): ...@@ -820,18 +821,12 @@ class M2M100Encoder(M2M100PreTrainedModel):
# under deepspeed zero3 all gpus must run in sync # under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
# create gradient checkpointing function layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): encoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
(head_mask[idx] if head_mask is not None else None), (head_mask[idx] if head_mask is not None else None),
output_attentions,
) )
else: else:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
...@@ -1066,16 +1061,8 @@ class M2M100Decoder(M2M100PreTrainedModel): ...@@ -1066,16 +1061,8 @@ class M2M100Decoder(M2M100PreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): decoder_layer.__call__,
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states, hidden_states,
combined_attention_mask, combined_attention_mask,
encoder_hidden_states, encoder_hidden_states,
...@@ -1083,6 +1070,8 @@ class M2M100Decoder(M2M100PreTrainedModel): ...@@ -1083,6 +1070,8 @@ class M2M100Decoder(M2M100PreTrainedModel):
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
output_attentions,
use_cache,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
......
...@@ -500,9 +500,10 @@ class MarianPreTrainedModel(PreTrainedModel): ...@@ -500,9 +500,10 @@ class MarianPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (MarianDecoder, MarianEncoder)): if isinstance(module, (MarianDecoder, MarianEncoder)):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property @property
def dummy_inputs(self): def dummy_inputs(self):
...@@ -788,18 +789,12 @@ class MarianEncoder(MarianPreTrainedModel): ...@@ -788,18 +789,12 @@ class MarianEncoder(MarianPreTrainedModel):
layer_outputs = (None, None) layer_outputs = (None, None)
else: else:
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): encoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
(head_mask[idx] if head_mask is not None else None), (head_mask[idx] if head_mask is not None else None),
output_attentions,
) )
else: else:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
...@@ -1037,16 +1032,8 @@ class MarianDecoder(MarianPreTrainedModel): ...@@ -1037,16 +1032,8 @@ class MarianDecoder(MarianPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): decoder_layer.__call__,
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
encoder_hidden_states, encoder_hidden_states,
...@@ -1054,6 +1041,8 @@ class MarianDecoder(MarianPreTrainedModel): ...@@ -1054,6 +1041,8 @@ class MarianDecoder(MarianPreTrainedModel):
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
output_attentions,
use_cache,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
......
...@@ -648,20 +648,15 @@ class MarkupLMEncoder(nn.Module): ...@@ -648,20 +648,15 @@ class MarkupLMEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
past_key_value,
output_attentions,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
......
...@@ -1864,20 +1864,14 @@ class Mask2FormerMaskedAttentionDecoder(nn.Module): ...@@ -1864,20 +1864,14 @@ class Mask2FormerMaskedAttentionDecoder(nn.Module):
continue continue
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): decoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
encoder_hidden_states, encoder_hidden_states,
None, None,
None, None,
output_attentions,
) )
else: else:
......
...@@ -848,20 +848,14 @@ class DetrDecoder(nn.Module): ...@@ -848,20 +848,14 @@ class DetrDecoder(nn.Module):
continue continue
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): decoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states, hidden_states,
combined_attention_mask, combined_attention_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
None, None,
output_attentions,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
...@@ -1619,11 +1613,13 @@ class MaskFormerPreTrainedModel(PreTrainedModel): ...@@ -1619,11 +1613,13 @@ class MaskFormerPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, MaskFormerPixelLevelModule): if isinstance(module, MaskFormerPixelLevelModule):
module.encoder.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.encoder.gradient_checkpointing = gradient_checkpointing_func is not None
if isinstance(module, DetrDecoder): if isinstance(module, DetrDecoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@add_start_docstrings( @add_start_docstrings(
......
...@@ -688,15 +688,11 @@ class MaskFormerSwinEncoder(nn.Module): ...@@ -688,15 +688,11 @@ class MaskFormerSwinEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_hidden_states, output_dimensions, layer_all_hidden_states = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs): hidden_states,
return module(*inputs, output_attentions) layer_head_mask,
output_attentions,
return custom_forward
layer_hidden_states, output_dimensions, layer_all_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module), hidden_states, layer_head_mask
) )
else: else:
layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module( layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module(
...@@ -752,9 +748,10 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel): ...@@ -752,9 +748,10 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel):
module.bias.data.zero_() module.bias.data.zero_()
module.weight.data.fill_(1.0) module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, MaskFormerSwinEncoder): if isinstance(module, MaskFormerSwinEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel): class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
......
...@@ -516,9 +516,10 @@ class MBartPreTrainedModel(PreTrainedModel): ...@@ -516,9 +516,10 @@ class MBartPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None: if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_() module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (MBartDecoder, MBartDecoder)): if isinstance(module, (MBartDecoder, MBartEncoder)):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property @property
def dummy_inputs(self): def dummy_inputs(self):
...@@ -828,18 +829,12 @@ class MBartEncoder(MBartPreTrainedModel): ...@@ -828,18 +829,12 @@ class MBartEncoder(MBartPreTrainedModel):
layer_outputs = (None, None) layer_outputs = (None, None)
else: else:
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): encoder_layer.__call__,
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
(head_mask[idx] if head_mask is not None else None), (head_mask[idx] if head_mask is not None else None),
output_attentions,
) )
else: else:
layer_outputs = encoder_layer( layer_outputs = encoder_layer(
...@@ -1086,16 +1081,8 @@ class MBartDecoder(MBartPreTrainedModel): ...@@ -1086,16 +1081,8 @@ class MBartDecoder(MBartPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): decoder_layer.__call__,
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states, hidden_states,
attention_mask, attention_mask,
encoder_hidden_states, encoder_hidden_states,
...@@ -1103,6 +1090,8 @@ class MBartDecoder(MBartPreTrainedModel): ...@@ -1103,6 +1090,8 @@ class MBartDecoder(MBartPreTrainedModel):
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None, cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None, None,
output_attentions,
use_cache,
) )
else: else:
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
......
...@@ -551,20 +551,15 @@ class MegatronBertEncoder(nn.Module): ...@@ -551,20 +551,15 @@ class MegatronBertEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
layer_outputs = self.gradient_checkpointing_func(
def create_custom_forward(module): layer_module.__call__,
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states, hidden_states,
attention_mask, attention_mask,
layer_head_mask, layer_head_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
past_key_value,
output_attentions,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
...@@ -728,9 +723,10 @@ class MegatronBertPreTrainedModel(PreTrainedModel): ...@@ -728,9 +723,10 @@ class MegatronBertPreTrainedModel(PreTrainedModel):
if isinstance(module, nn.Linear) and module.bias is not None: if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_() module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, MegatronBertEncoder): if isinstance(module, MegatronBertEncoder):
module.gradient_checkpointing = value module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass @dataclass
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment