Unverified commit 9094abe8, authored by Arthur, committed by GitHub
Browse files

[`gradient_checkpointing`] default to use it for torch 2.3 (#28538)

* default to use it

* style
parent 49c0b293
@@ -2092,7 +2092,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
         if gradient_checkpointing_kwargs is None:
-            gradient_checkpointing_kwargs = {}
+            gradient_checkpointing_kwargs = {"use_reentrant": True}
         gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment