Unverified commit 9094abe8, authored by Arthur and committed by GitHub
Browse files

[`gradient_checkpointing`] default to use it for torch 2.3 (#28538)

* default to use it

* style
parent 49c0b293
......@@ -2092,7 +2092,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
if gradient_checkpointing_kwargs is None:
- gradient_checkpointing_kwargs = {}
+ gradient_checkpointing_kwargs = {"use_reentrant": True}
gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment