Unverified commit 06e782da, authored by Younes Belkada and committed by GitHub

[`core`] Refactor of `gradient_checkpointing` (#27020)

* v1

* fix

* remove `create_custom_forward`

* fixup

* fixup

* add test and fix all failing GC tests

* remove all remaining `create_custom_forward` methods

* fix idefics bug

* fixup

* replace with `__call__`

* add comment

* quality
parent 9286f0ac
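
Orientation for the per-model hunks below: every encoder/decoder loses its local `create_custom_forward` closure and instead calls a `gradient_checkpointing_func` stored on the module, passing the layer's `__call__` plus every per-layer argument positionally. The following sketch is illustrative only (it is not code from this diff) and condenses the before/after pattern:

```python
import torch
import torch.utils.checkpoint


# Before the refactor: a per-loop closure captured extra arguments and was
# handed to torch.utils.checkpoint.checkpoint directly.
def old_pattern(encoder_layer, hidden_states, attention_mask, output_attentions):
    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs, output_attentions)

        return custom_forward

    return torch.utils.checkpoint.checkpoint(
        create_custom_forward(encoder_layer),
        hidden_states,
        attention_mask,
    )


# After the refactor: the module owns a configurable checkpointing callable and
# forwards the layer's __call__ together with all arguments, explicitly.
def new_pattern(module, encoder_layer, hidden_states, attention_mask, output_attentions):
    return module.gradient_checkpointing_func(
        encoder_layer.__call__,
        hidden_states,
        attention_mask,
        output_attentions,
    )
```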
@@ -570,9 +570,10 @@ class XLMProphetNetPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, (XLMProphetNetDecoder, XLMProphetNetEncoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

     def _shift_right(self, input_ids):
         decoder_start_token_id = self.config.decoder_start_token_id
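
The hook now receives a callable instead of a boolean, so whoever enables checkpointing decides which implementation actually runs. The `modeling_utils.py` side of that wiring is not part of the hunks shown here; the sketch below is a self-contained illustration under stated assumptions (the `TinyEncoder` class and the standalone `_set_gradient_checkpointing` helper are hypothetical stand-ins, and the `functools.partial` construction is an assumption about how a caller could build the function):

```python
from functools import partial

import torch
import torch.utils.checkpoint


class TinyEncoder(torch.nn.Module):
    """Hypothetical stand-in for the model encoders touched in this diff."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 8)
        self.gradient_checkpointing = False
        self.gradient_checkpointing_func = None

    def forward(self, hidden_states):
        if self.gradient_checkpointing and self.training:
            # Same call shape as the refactored encoders below.
            return self.gradient_checkpointing_func(self.layer.__call__, hidden_states)
        return self.layer(hidden_states)


def _set_gradient_checkpointing(module, gradient_checkpointing_func=None):
    # Same two-attribute logic as the new hook in the diff.
    module.gradient_checkpointing_func = gradient_checkpointing_func
    module.gradient_checkpointing = gradient_checkpointing_func is not None


encoder = TinyEncoder().train()
# Inject the stock PyTorch implementation with explicit kwargs.
_set_gradient_checkpointing(encoder, partial(torch.utils.checkpoint.checkpoint, use_reentrant=False))

x = torch.randn(2, 8, requires_grad=True)
encoder(x).sum().backward()

# Passing None turns checkpointing back off.
_set_gradient_checkpointing(encoder, None)
assert encoder.gradient_checkpointing is False
```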
@@ -1349,18 +1350,12 @@ class XLMProphetNetEncoder(XLMProphetNetPreTrainedModel):
                 encoder_hidden_states = encoder_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(encoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    encoder_layer.__call__,
                     hidden_states,
                     extended_attention_mask,
                     (head_mask[idx] if head_mask is not None else None),
+                    output_attentions,
                 )
             else:
                 layer_outputs = encoder_layer(
@@ -1592,16 +1587,8 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, use_cache, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     extended_attention_mask,
                     encoder_hidden_states,
@@ -1613,6 +1600,8 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
                     predict_relative_position_buckets,
                     position_ids,
                     None,
+                    use_cache,
+                    output_attentions,
                 )
             else:
                 layer_outputs = decoder_layer(
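
The commit note "replace with `__call__`" shows up in every hunk: the checkpointed function is the layer's `__call__`, not its bare `forward`, so `nn.Module` hooks still run inside the checkpointed segment. A small standalone demonstration of that point, not taken from the diff:

```python
import torch
import torch.utils.checkpoint

layer = torch.nn.Linear(4, 4)
fired = []
layer.register_forward_hook(lambda mod, args, out: fired.append("hook ran"))

x = torch.randn(2, 4, requires_grad=True)

# Going through __call__ keeps the forward hook active inside the checkpointed call.
out = torch.utils.checkpoint.checkpoint(layer.__call__, x, use_reentrant=False)
print(fired)  # ['hook ran']

# Checkpointing layer.forward directly would bypass the hook machinery entirely.
out.sum().backward()
```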
......
@@ -511,20 +511,15 @@ class XLMRobertaEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, past_key_value, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
                     layer_head_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(
@@ -614,9 +609,10 @@ class XLMRobertaPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, XLMRobertaEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None


 XLM_ROBERTA_START_DOCSTRING = r"""
......
@@ -499,20 +499,15 @@ class XLMRobertaXLEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, past_key_value, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
                     layer_head_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(
......
@@ -573,21 +573,16 @@ class XmodEncoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, past_key_value, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     lang_ids,
                     attention_mask,
                     layer_head_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(
@@ -680,9 +675,10 @@ class XmodPreTrainedModel(PreTrainedModel):
             module.weight.data.fill_(1.0)

     # Copied from transformers.models.roberta.modeling_roberta.RobertaPreTrainedModel._set_gradient_checkpointing with Roberta->Xmod
-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, XmodEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None

     def set_default_language(self, language: str):
         """
......
@@ -492,17 +492,11 @@ class YolosEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     layer_head_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
@@ -551,9 +545,10 @@ class YolosPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

-    def _set_gradient_checkpointing(self, module: YolosEncoder, value: bool = False) -> None:
+    def _set_gradient_checkpointing(self, module: YolosEncoder, gradient_checkpointing_func=None) -> None:
         if isinstance(module, YolosEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None


 YOLOS_START_DOCSTRING = r"""
......
@@ -561,17 +561,11 @@ class YosoEncoder(nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(hidden_states, attention_mask, output_attentions)
@@ -668,9 +662,10 @@ class YosoPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, YosoEncoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None


 YOSO_START_DOCSTRING = r"""
......
@@ -544,19 +544,15 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
             past_key_value = past_key_values[i] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        return module(*inputs, past_key_value, output_attentions)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(layer_module),
+                layer_outputs = self.gradient_checkpointing_func(
+                    layer_module.__call__,
                     hidden_states,
                     attention_mask,
                     layer_head_mask,
                     encoder_hidden_states,
                     encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
                 )
             else:
                 layer_outputs = layer_module(
@@ -679,9 +675,10 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, {{cookiecutter.camelcase_modelname}}Encoder):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None


 {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
@@ -2024,9 +2021,10 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()

-    def _set_gradient_checkpointing(self, module, value=False):
+    def _set_gradient_checkpointing(self, module, gradient_checkpointing_func=None):
         if isinstance(module, ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}Encoder)):
-            module.gradient_checkpointing = value
+            module.gradient_checkpointing_func = gradient_checkpointing_func
+            module.gradient_checkpointing = gradient_checkpointing_func is not None


 {{cookiecutter.uppercase_modelname}}_START_DOCSTRING = r"""
@@ -2312,18 +2310,12 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_modelname}}PreTrainedModel):
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-
-                    def create_custom_forward(module):
-                        def custom_forward(*inputs):
-                            return module(*inputs, output_attentions)
-
-                        return custom_forward
-
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(encoder_layer),
+                    layer_outputs = self.gradient_checkpointing_func(
+                        encoder_layer.__call__,
                         hidden_states,
                         attention_mask,
                         (head_mask[idx] if head_mask is not None else None),
+                        output_attentions,
                     )
                 else:
                     layer_outputs = encoder_layer(
@@ -2551,15 +2543,8 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_modelname}}PreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None

             if self.gradient_checkpointing and self.training:
-
-                def create_custom_forward(module):
-                    def custom_forward(*inputs):
-                        # None for past_key_value
-                        return module(*inputs, output_attentions, use_cache)
-
-                    return custom_forward
-
-                layer_outputs = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(decoder_layer),
+                layer_outputs = self.gradient_checkpointing_func(
+                    decoder_layer.__call__,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -2567,6 +2552,8 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_modelname}}PreTrainedModel):
                     head_mask[idx] if head_mask is not None else None,
                     cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None,
                     None,
+                    output_attentions,
+                    use_cache,
                 )
             else:
......
@@ -349,10 +349,24 @@ class ModelTesterMixin:
             model.gradient_checkpointing_enable()
             self.assertTrue(model.is_gradient_checkpointing)

+            # Loop over all modules and check that relevant modules have gradient_checkpointing set to True
+            for n, m in model.named_modules():
+                if hasattr(m, "gradient_checkpointing"):
+                    self.assertTrue(
+                        m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to True"
+                    )
+
             # check disable works
             model.gradient_checkpointing_disable()
             self.assertFalse(model.is_gradient_checkpointing)

+            # Loop over all modules and check that relevant modules have gradient_checkpointing set to False
+            for n, m in model.named_modules():
+                if hasattr(m, "gradient_checkpointing"):
+                    self.assertFalse(
+                        m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False"
+                    )
+
     def test_save_load_fast_init_from_base(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if config.__class__ not in MODEL_MAPPING:
@@ -569,6 +583,13 @@ class ModelTesterMixin:
             loss = model(**inputs).loss
             loss.backward()

+            model.gradient_checkpointing_disable()
+            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
+            model.train()
+            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            loss = model(**inputs).loss
+            loss.backward()
+
     def test_attention_outputs(self):
         if not self.has_attentions:
             self.skipTest(reason="Model does not output attentions")
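
The new test exercises the `gradient_checkpointing_kwargs` path end to end. From user code the flow looks roughly like the sketch below, which requires a transformers version that includes this refactor; the checkpoint name is illustrative and the kwargs are forwarded to the underlying checkpoint implementation:

```python
import torch
from transformers import AutoModelForSequenceClassification

# Illustrative checkpoint; most PyTorch models in the library behave the same way.
model = AutoModelForSequenceClassification.from_pretrained("xlm-roberta-base", num_labels=2)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
model.train()

inputs = {
    "input_ids": torch.randint(0, model.config.vocab_size, (2, 16)),
    "attention_mask": torch.ones(2, 16, dtype=torch.long),
    "labels": torch.tensor([0, 1]),
}
loss = model(**inputs).loss
loss.backward()

# Turning it off restores the plain forward path on every submodule.
model.gradient_checkpointing_disable()
assert not model.is_gradient_checkpointing
```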
......