Unverified commit dca67968 authored by Patrick von Platen, committed by GitHub

[Gradient checkpointing] Correct disabling `find_unused_parameters` in Trainer when gradient checkpointing is enabled (#13961)

* up

* correct test
parent 4a18337b
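For context, a minimal usage sketch (not part of the diff) of the API this commit settles on. The checkpoint name and training arguments below are placeholders: gradient checkpointing is toggled with the argument-free enable/disable methods and reported through the new `is_gradient_checkpointing` property, which the Trainer consults before wrapping the model in DistributedDataParallel.

```python
# Illustrative sketch only -- not part of this diff. The checkpoint name and
# training arguments are placeholders.
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

# Enable/disable take no arguments.
model.gradient_checkpointing_enable()
assert model.is_gradient_checkpointing  # property added in this PR

# Under torch.distributed, the Trainer now reads `model.is_gradient_checkpointing`
# (instead of the old `config._gradient_checkpointing` attribute) and passes
# find_unused_parameters=False to DistributedDataParallel, which would otherwise
# break gradient checkpointing.
trainer = Trainer(model=model, args=TrainingArguments(output_dir="out"))

model.gradient_checkpointing_disable()
assert not model.is_gradient_checkpointing
```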
@@ -946,7 +946,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         self.base_model._prune_heads(heads_to_prune)

-    def gradient_checkpointing_enable(self, flag: bool = True):
+    def gradient_checkpointing_enable(self):
         """
         Activates gradient checkpointing for the current model.
@@ -957,7 +957,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
         self.apply(partial(self._set_gradient_checkpointing, value=True))

-    def gradient_checkpointing_disable(self, flag: bool = True):
+    def gradient_checkpointing_disable(self):
         """
         Deactivates gradient checkpointing for the current model.
@@ -967,6 +967,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         if self.supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))

+    @property
+    def is_gradient_checkpointing(self) -> bool:
+        """
+        Whether gradient checkpointing is activated for this model or not.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
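The new property simply scans the module tree for any submodule whose `gradient_checkpointing` flag is set. Below is a standalone sketch of that logic using toy `nn.Module` classes; `TinyBlock` and `TinyModel` are made-up names, not transformers code.

```python
# Toy illustration of the `any(...)` check used by `is_gradient_checkpointing`.
import torch.nn as nn


class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.gradient_checkpointing = False  # flag toggled by _set_gradient_checkpointing


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([TinyBlock() for _ in range(2)])

    @property
    def is_gradient_checkpointing(self) -> bool:
        # Same pattern as the property added above: True if any submodule
        # carries an enabled `gradient_checkpointing` attribute.
        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())


model = TinyModel()
print(model.is_gradient_checkpointing)  # False
model.blocks[0].gradient_checkpointing = True
print(model.is_gradient_checkpointing)  # True
```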
@@ -996,7 +996,7 @@ class Trainer:
             elif isinstance(model, PreTrainedModel):
                 # find_unused_parameters breaks checkpointing as per
                 # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
-                find_unused_parameters = not getattr(model.config, "_gradient_checkpointing", False)
+                find_unused_parameters = not model.is_gradient_checkpointing
             else:
                 find_unused_parameters = True
             model = nn.parallel.DistributedDataParallel(
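Why this matters: with `find_unused_parameters=True`, DDP marks parameters as ready during the forward pass, and the re-run of the forward graph that gradient checkpointing performs during backward then tends to trip the "Expected to mark a variable ready only once" error. A minimal sketch of the same decision outside the Trainer; `wrap_for_ddp` is a hypothetical helper and assumes the process group has already been initialized (e.g. via torchrun).

```python
# Hypothetical helper mirroring the decision the Trainer now makes.
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


def wrap_for_ddp(model: nn.Module, device_id: int) -> DistributedDataParallel:
    # With gradient checkpointing, parts of the forward graph are recomputed
    # during backward, so DDP must not be told to look for unused parameters.
    find_unused_parameters = not getattr(model, "is_gradient_checkpointing", False)
    return DistributedDataParallel(
        model,
        device_ids=[device_id],
        output_device=device_id,
        find_unused_parameters=find_unused_parameters,
    )
```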
@@ -197,6 +197,25 @@ class ModelTesterMixin:
             )
             self.assertTrue(len(load_result.unexpected_keys) == 0)

+    def test_gradient_checkpointing_enable_disable(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            if not model_class.supports_gradient_checkpointing:
+                continue
+
+            # at init model should have gradient checkpointing disabled
+            model = model_class(config)
+            self.assertFalse(model.is_gradient_checkpointing)
+
+            # check enable works
+            model.gradient_checkpointing_enable()
+            self.assertTrue(model.is_gradient_checkpointing)
+
+            # check disable works
+            model.gradient_checkpointing_disable()
+            self.assertFalse(model.is_gradient_checkpointing)
+
     def _mock_init_weights(self, module):
         if hasattr(module, "weight") and module.weight is not None:
             module.weight.data.fill_(3)