"docs/source/vscode:/vscode.git/clone" did not exist on "570b3f9cdd18a9e6d075fac561cf2eb11dfec9ce"
Unverified Commit dca67968 authored by Patrick von Platen, committed by GitHub

[Gradient checkpointing] Correct disabling `find_unused_parameters` in Trainer when gradient checkpointing is enabled (#13961)

* up

* correct test
parent 4a18337b
@@ -946,7 +946,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         self.base_model._prune_heads(heads_to_prune)
 
-    def gradient_checkpointing_enable(self, flag: bool = True):
+    def gradient_checkpointing_enable(self):
         """
         Activates gradient checkpointing for the current model.
@@ -957,7 +957,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
         self.apply(partial(self._set_gradient_checkpointing, value=True))
 
-    def gradient_checkpointing_disable(self, flag: bool = True):
+    def gradient_checkpointing_disable(self):
         """
         Deactivates gradient checkpointing for the current model.
@@ -967,6 +967,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         if self.supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))
 
+    @property
+    def is_gradient_checkpointing(self) -> bool:
+        """
+        Whether gradient checkpointing is activated for this model or not.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
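The new property just scans `self.modules()` for any submodule whose `gradient_checkpointing` flag is set, so it reflects whatever `gradient_checkpointing_enable()` / `gradient_checkpointing_disable()` last did. A minimal usage sketch (the `bert-base-cased` checkpoint is only an illustrative choice; any model class with `supports_gradient_checkpointing = True` behaves the same):

```python
from transformers import BertModel

# bert-base-cased is just an example checkpoint; any architecture that
# supports gradient checkpointing exposes the same three members.
model = BertModel.from_pretrained("bert-base-cased")

assert not model.is_gradient_checkpointing  # disabled right after loading

model.gradient_checkpointing_enable()       # flips gradient_checkpointing on submodules
assert model.is_gradient_checkpointing

model.gradient_checkpointing_disable()
assert not model.is_gradient_checkpointing
```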
@@ -996,7 +996,7 @@ class Trainer:
             elif isinstance(model, PreTrainedModel):
                 # find_unused_parameters breaks checkpointing as per
                 # https://github.com/huggingface/transformers/pull/4659#issuecomment-643356021
-                find_unused_parameters = not getattr(model.config, "_gradient_checkpointing", False)
+                find_unused_parameters = not model.is_gradient_checkpointing
             else:
                 find_unused_parameters = True
             model = nn.parallel.DistributedDataParallel(
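The reason for the flip: DDP's `find_unused_parameters=True` bookkeeping conflicts with the backward-pass recomputation that gradient checkpointing performs (per the issue linked in the comment above), so the Trainer now derives the flag from the model itself rather than from a config attribute that no longer exists. A stripped-down sketch of the same decision outside the Trainer (the `wrap_for_ddp` helper is hypothetical, and `torch.distributed` must already be initialized before constructing DDP):

```python
import torch.nn as nn
from transformers import PreTrainedModel

def wrap_for_ddp(model, local_rank):
    # Mirror the Trainer logic: when checkpointing is on, unused-parameter
    # tracking must be off, otherwise DDP breaks the recomputation pass.
    if isinstance(model, PreTrainedModel):
        find_unused_parameters = not model.is_gradient_checkpointing
    else:
        find_unused_parameters = True
    return nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=find_unused_parameters,
    )
```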
@@ -197,6 +197,25 @@ class ModelTesterMixin:
             )
             self.assertTrue(len(load_result.unexpected_keys) == 0)
 
+    def test_gradient_checkpointing_enable_disable(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            if not model_class.supports_gradient_checkpointing:
+                continue
+
+            # at init model should have gradient checkpointing disabled
+            model = model_class(config)
+            self.assertFalse(model.is_gradient_checkpointing)
+
+            # check enable works
+            model.gradient_checkpointing_enable()
+            self.assertTrue(model.is_gradient_checkpointing)
+
+            # check disable works
+            model.gradient_checkpointing_disable()
+            self.assertFalse(model.is_gradient_checkpointing)
+
     def _mock_init_weights(self, module):
         if hasattr(module, "weight") and module.weight is not None:
             module.weight.data.fill_(3)
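Both toggles in the first hunk go through `self.apply(partial(self._set_gradient_checkpointing, value=...))`, which visits every submodule and flips a per-module flag on the ones that support it. A self-contained toy sketch of that pattern (the `ToyEncoder` / `ToyModel` classes are hypothetical stand-ins, not the transformers implementation):

```python
from functools import partial

import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyEncoder(nn.Module):
    """Hypothetical block that honors a per-module gradient_checkpointing flag."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        self.gradient_checkpointing = False

    def forward(self, x):
        if self.gradient_checkpointing and self.training:
            # Recompute this block in the backward pass instead of storing activations.
            return checkpoint(self.linear, x)
        return self.linear(x)

class ToyModel(nn.Module):
    """Hypothetical model mirroring the apply(partial(...)) toggle pattern."""

    def __init__(self):
        super().__init__()
        self.encoder = ToyEncoder()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, ToyEncoder):
            module.gradient_checkpointing = value

    def gradient_checkpointing_enable(self):
        self.apply(partial(self._set_gradient_checkpointing, value=True))

    def gradient_checkpointing_disable(self):
        self.apply(partial(self._set_gradient_checkpointing, value=False))

    @property
    def is_gradient_checkpointing(self):
        return any(getattr(m, "gradient_checkpointing", False) for m in self.modules())
```

Calling `ToyModel().gradient_checkpointing_enable()` and then reading `is_gradient_checkpointing` mirrors the assertions in the new test above.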