Unverified Commit 06e782da authored by Younes Belkada, committed by GitHub

[`core`] Refactor of `gradient_checkpointing` (#27020)

* v1

* fix

* remove `create_custom_forward`

* fixup

* fixup

* add test and fix all failing GC tests

* remove all remaining `create_custom_forward` methods

* fix idefics bug

* fixup

* replace with `__call__`

* add comment

* quality
parent 9286f0ac
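
Before the per-model hunks, a minimal, self-contained sketch of the single pattern this refactor applies everywhere: the local `create_custom_forward` closure around `torch.utils.checkpoint.checkpoint` is dropped, and each layer loop instead calls a configurable `self.gradient_checkpointing_func` with the layer's `__call__` and all of its arguments (including flags such as `output_attentions`/`use_cache`) passed positionally. The `ToyLayer`/`ToyEncoder` names below are illustrative only and are not code from this PR.

```python
# Minimal sketch (not part of the PR) of the pattern the hunks below apply.
import functools

import torch
import torch.utils.checkpoint
from torch import nn


class ToyLayer(nn.Module):
    def __init__(self, hidden_size=8):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, output_attentions=False):
        hidden_states = torch.relu(self.linear(hidden_states))
        # Mirror the transformers convention of returning a tuple.
        return (hidden_states, None) if output_attentions else (hidden_states,)


class ToyEncoder(nn.Module):
    def __init__(self, num_layers=2, hidden_size=8):
        super().__init__()
        self.layers = nn.ModuleList(ToyLayer(hidden_size) for _ in range(num_layers))
        self.gradient_checkpointing = False
        # Populated by `_set_gradient_checkpointing` in the real models.
        self.gradient_checkpointing_func = None

    def forward(self, hidden_states, output_attentions=False):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # New pattern: no closure, the layer call and its arguments
                # are handed to the configurable checkpointing function.
                layer_outputs = self.gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    output_attentions,
                )
            else:
                layer_outputs = layer(hidden_states, output_attentions)
            hidden_states = layer_outputs[0]
        return hidden_states


# Example: enable checkpointing with a non-reentrant checkpoint function.
encoder = ToyEncoder().train()
encoder.gradient_checkpointing = True
encoder.gradient_checkpointing_func = functools.partial(
    torch.utils.checkpoint.checkpoint, use_reentrant=False
)
out = encoder(torch.randn(2, 4, 8))
out.sum().backward()
```
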
@@ -924,9 +924,10 @@ class InformerPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (InformerDecoder, InformerEncoder)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
INFORMER_START_DOCSTRING = r"""
@@ -1215,21 +1216,15 @@ class InformerEncoder(InformerPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
if conv_layer is not None:
output = torch.utils.checkpoint.checkpoint(conv_layer, layer_outputs[0])
output = self.gradient_checkpointing_func(conv_layer, layer_outputs[0])
layer_outputs = (output,) + layer_outputs[1:]
else:
layer_outputs = encoder_layer(
@@ -1438,16 +1433,8 @@ class InformerDecoder(InformerPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -1455,6 +1442,8 @@ class InformerDecoder(InformerPreTrainedModel):
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
@@ -304,9 +304,14 @@ class InstructBlipPreTrainedModel(PreTrainedModel):
elif isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, InstructBlipEncoder):
module.gradient_checkpointing = value
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (InstructBlipEncoder, InstructBlipQFormerEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
# Enable / disable GC for the language model as well
if hasattr(self, "language_model") and hasattr(self.language_model, "_set_gradient_checkpointing"):
self.language_model._set_gradient_checkpointing(module, gradient_checkpointing_func)
INSTRUCTBLIP_START_DOCSTRING = r"""
@@ -462,17 +467,11 @@ class InstructBlipEncoder(nn.Module):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
output_attentions,
)
else:
layer_outputs = encoder_layer(
@@ -939,15 +938,8 @@ class InstructBlipQFormerEncoder(nn.Module):
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions, query_length)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
@@ -487,20 +487,15 @@ class LayoutLMEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
@@ -638,9 +633,10 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LayoutLMEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LAYOUTLM_START_DOCSTRING = r"""
@@ -439,18 +439,12 @@ class LayoutLMv2Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
rel_pos=rel_pos,
rel_2d_pos=rel_2d_pos,
)
@@ -514,9 +508,10 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LayoutLMv2Encoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
def my_convert_sync_batchnorm(module, process_group=None):
@@ -657,19 +657,8 @@ class LayoutLMv3Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
# return module(*inputs, past_key_value, output_attentions, rel_pos, rel_2d_pos)
# The above line will cause error:
# RuntimeError: Trying to backward through the graph a second time
# (or directly access saved tensors after they have already been freed).
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
@@ -1155,9 +1155,10 @@ class LEDPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (LEDDecoder, LEDEncoder)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property
def dummy_inputs(self):
@@ -1876,20 +1877,15 @@ class LEDEncoder(LEDPreTrainedModel):
layer_outputs = (None, None, None)
else:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, is_global_attn, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
head_mask[idx] if head_mask is not None else None,
is_index_masked,
is_index_global_attn,
is_global_attn,
output_attentions,
)
else:
layer_outputs = encoder_layer(
@@ -2142,16 +2138,8 @@ class LEDDecoder(LEDPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,
encoder_hidden_states,
@@ -2159,6 +2147,8 @@ class LEDDecoder(LEDPreTrainedModel):
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
@@ -507,9 +507,10 @@ class LevitPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LevitModel):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LEVIT_START_DOCSTRING = r"""
@@ -514,19 +514,13 @@ class LiltEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layout_inputs,
attention_mask,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(
@@ -607,9 +601,10 @@ class LiltPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LiltEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LILT_START_DOCSTRING = r"""
@@ -827,9 +827,10 @@ class LlamaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LlamaModel):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LLAMA_INPUTS_DOCSTRING = r"""
@@ -1013,16 +1014,14 @@ class LlamaModel(LlamaPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
@@ -1304,20 +1304,15 @@ class LongformerEncoder(nn.Module):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, is_global_attn, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
head_mask[idx] if head_mask is not None else None,
is_index_masked,
is_index_global_attn,
is_global_attn,
output_attentions,
)
else:
layer_outputs = layer_module(
@@ -1439,9 +1434,10 @@ class LongformerPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LongformerEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LONGFORMER_START_DOCSTRING = r"""
@@ -775,7 +775,6 @@ class LongT5TransientGlobalAttention(nn.Module):
if self.has_relative_attention_bias:
self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
self.pruned_heads = set()
self.gradient_checkpointing = False
# Relativen attention bias & Layer norm for global attention
if self.has_relative_attention_bias:
@@ -1340,10 +1339,10 @@ class LongT5PreTrainedModel(PreTrainedModel):
mean=0.0, std=factor * ((d_model) ** -0.5)
)
# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._set_gradient_checkpointing with T5->LongT5
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (LongT5Attention, LongT5Stack)):
module.gradient_checkpointing = value
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (LongT5Attention, LongT5Stack, LongT5LocalAttention)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5
def _shift_right(self, input_ids):
@@ -1510,15 +1509,8 @@ class LongT5Stack(LongT5PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return tuple(module(*inputs, use_cache, output_attentions))
return custom_forward
layer_outputs = checkpoint(
create_custom_forward(layer_module),
layer_module.forward,
hidden_states,
extended_attention_mask,
position_bias,
@@ -1528,6 +1520,8 @@ class LongT5Stack(LongT5PreTrainedModel):
layer_head_mask,
cross_attn_layer_head_mask,
None, # past_key_value is always None with gradient checkpointing
use_cache,
output_attentions,
)
else:
layer_outputs = layer_module(
@@ -788,19 +788,13 @@ class LukeEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
word_hidden_states,
entity_hidden_states,
attention_mask,
layer_head_mask,
output_attentions,
)
else:
layer_outputs = layer_module(
@@ -920,9 +914,10 @@ class LukePreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, LukeEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
LUKE_START_DOCSTRING = r"""
@@ -552,9 +552,10 @@ class M2M100PreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (M2M100Decoder, M2M100Encoder)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
M2M_100_START_DOCSTRING = r"""
@@ -820,18 +821,12 @@ class M2M100Encoder(M2M100PreTrainedModel):
# under deepspeed zero3 all gpus must run in sync
if self.gradient_checkpointing and self.training:
# create gradient checkpointing function
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
@@ -1066,16 +1061,8 @@ class M2M100Decoder(M2M100PreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,
encoder_hidden_states,
@@ -1083,6 +1070,8 @@ class M2M100Decoder(M2M100PreTrainedModel):
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
@@ -500,9 +500,10 @@ class MarianPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (MarianDecoder, MarianEncoder)):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property
def dummy_inputs(self):
@@ -788,18 +789,12 @@ class MarianEncoder(MarianPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
@@ -1037,16 +1032,8 @@ class MarianDecoder(MarianPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -1054,6 +1041,8 @@ class MarianDecoder(MarianPreTrainedModel):
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
@@ -648,20 +648,15 @@ class MarkupLMEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
@@ -1864,20 +1864,14 @@ class Mask2FormerMaskedAttentionDecoder(nn.Module):
continue
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
None,
None,
output_attentions,
)
else:
@@ -848,20 +848,14 @@ class DetrDecoder(nn.Module):
continue
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
combined_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
None,
output_attentions,
)
else:
layer_outputs = decoder_layer(
@@ -1619,11 +1613,13 @@ class MaskFormerPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, MaskFormerPixelLevelModule):
module.encoder.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.encoder.gradient_checkpointing = gradient_checkpointing_func is not None
if isinstance(module, DetrDecoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@add_start_docstrings(
@@ -688,15 +688,11 @@ class MaskFormerSwinEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_hidden_states, output_dimensions, layer_all_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module), hidden_states, layer_head_mask
layer_hidden_states, output_dimensions, layer_all_hidden_states = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
layer_head_mask,
output_attentions,
)
else:
layer_hidden_states, output_dimensions, layer_all_hidden_states = layer_module(
@@ -752,9 +748,10 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, MaskFormerSwinEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
class MaskFormerSwinModel(MaskFormerSwinPreTrainedModel):
@@ -516,9 +516,10 @@ class MBartPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (MBartDecoder, MBartDecoder)):
module.gradient_checkpointing = value
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, (MBartDecoder, MBartEncoder)):
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@property
def dummy_inputs(self):
@@ -828,18 +829,12 @@ class MBartEncoder(MBartPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(encoder_layer),
layer_outputs = self.gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
(head_mask[idx] if head_mask is not None else None),
output_attentions,
)
else:
layer_outputs = encoder_layer(
@@ -1086,16 +1081,8 @@ class MBartDecoder(MBartPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, output_attentions, use_cache)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
layer_outputs = self.gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
encoder_hidden_states,
@@ -1103,6 +1090,8 @@ class MBartDecoder(MBartPreTrainedModel):
head_mask[idx] if head_mask is not None else None,
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
None,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
@@ -551,20 +551,15 @@ class MegatronBertEncoder(nn.Module):
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
layer_outputs = self.gradient_checkpointing_func(
layer_module.__call__,
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
else:
layer_outputs = layer_module(
@@ -728,9 +723,10 @@ class MegatronBertPreTrainedModel(PreTrainedModel):
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False):
def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
if isinstance(module, MegatronBertEncoder):
module.gradient_checkpointing = value
module.gradient_checkpointing_func = gradient_checkpointing_func
module.gradient_checkpointing = gradient_checkpointing_func is not None
@dataclass
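
For context on how the new hook is driven: each `_set_gradient_checkpointing(self, module, gradient_checkpointing_func=None)` override above stores the callable on the matching modules and derives the boolean flag from `gradient_checkpointing_func is not None`. The enable-side wiring lives in `PreTrainedModel` and is not part of these hunks, so the helpers below (`enable_gradient_checkpointing`, `disable_gradient_checkpointing`) are only a hedged sketch of the implied calling convention, not transformers APIs.

```python
# Hedged sketch of the enable/disable side implied by the new hook signature.
# These helpers are illustrative only; the real wiring is outside this diff.
import functools

import torch
import torch.utils.checkpoint


def enable_gradient_checkpointing(model, **checkpoint_kwargs):
    # Build the checkpointing callable once, then hand it to every submodule
    # through the model's `_set_gradient_checkpointing` hook; the hook itself
    # filters on isinstance, so unrelated modules are left untouched.
    func = functools.partial(torch.utils.checkpoint.checkpoint, **checkpoint_kwargs)
    for module in model.modules():
        model._set_gradient_checkpointing(module, gradient_checkpointing_func=func)


def disable_gradient_checkpointing(model):
    # Passing None turns checkpointing off, because the new setters compute
    # `module.gradient_checkpointing = gradient_checkpointing_func is not None`.
    for module in model.modules():
        model._set_gradient_checkpointing(module, gradient_checkpointing_func=None)
```
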