Unverified commit d39352d1, authored by NielsRogge and committed by GitHub
Browse files

Fix import of torch.utils.checkpoint (#27155)



* Fix import

* Apply suggestions from code review
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent e971486d
......@@ -33,6 +33,7 @@ import torch
from packaging import version
from torch import Tensor, nn
from torch.nn import CrossEntropyLoss, Identity
from torch.utils.checkpoint import checkpoint
from .activations import get_activation
from .configuration_utils import PretrainedConfig
......@@ -1869,9 +1870,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if gradient_checkpointing_kwargs is None:
gradient_checkpointing_kwargs = {}
gradient_checkpointing_func = functools.partial(
torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs
)
gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
......@@ -1882,9 +1881,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# the gradients to make sure the gradient flows.
self.enable_input_require_grads()
def _set_gradient_checkpointing(
self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
):
def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
is_gradient_checkpointing_set = False
# Apply it on the top-level module in case the top-level modules supports it
......
......@@ -1813,7 +1813,7 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel):
layer_outputs = (None, None)
else:
if self.gradient_checkpointing and self.training:
layer_outputs = torch.utils.checkpoint.checkpoint(
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.forward,
hidden_states,
attention_mask,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment