Unverified Commit 0e6794ff authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`core` / `gradient_checkpointing`] add support for old GC method (#27610)

* add support for old GC method

* add also disable

* up

* oops
parent 8eb9e29d
...@@ -1876,7 +1876,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1876,7 +1876,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs) gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
# For the old GC format (transformers < 4.35.0), still used by some models that live on the Hub,
# we fall back to the overridden `_set_gradient_checkpointing` method
_is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
if not _is_using_old_format:
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
else:
self.apply(partial(self._set_gradient_checkpointing, value=True))
logger.warn(
"You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
)
if getattr(self, "_hf_peft_config_loaded", False): if getattr(self, "_hf_peft_config_loaded", False):
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
...@@ -1915,7 +1926,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -1915,7 +1926,17 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
activations". activations".
""" """
if self.supports_gradient_checkpointing: if self.supports_gradient_checkpointing:
# For the old GC format (transformers < 4.35.0), still used by some models that live on the Hub,
# we fall back to the overridden `_set_gradient_checkpointing` method
_is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
if not _is_using_old_format:
self._set_gradient_checkpointing(enable=False) self._set_gradient_checkpointing(enable=False)
else:
logger.warn(
"You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
)
self.apply(partial(self._set_gradient_checkpointing, value=False))
if getattr(self, "_hf_peft_config_loaded", False): if getattr(self, "_hf_peft_config_loaded", False):
self.disable_input_require_grads() self.disable_input_require_grads()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment