Unverified commit d39352d1, authored by NielsRogge, committed by GitHub

Fix import of torch.utils.checkpoint (#27155)



* Fix import

* Apply suggestions from code review
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
parent e971486d
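
The likely failure mode, inferred from the commit title: `import torch` alone does not guarantee that the `torch.utils.checkpoint` submodule has been imported, so attribute access like `torch.utils.checkpoint.checkpoint` can raise an `AttributeError` depending on what else happens to be imported first. A minimal sketch of the new pattern, assuming a recent PyTorch where `checkpoint` accepts `use_reentrant` (the kwarg shown is illustrative, not part of this PR):

```python
import functools

# Explicit submodule import: this is what the diff below adds. Relying on
# `torch.utils.checkpoint.checkpoint` after a bare `import torch` can fail
# if nothing else has imported the submodule yet.
from torch.utils.checkpoint import checkpoint

# Mirrors the new code in modeling_utils.py: user-supplied kwargs are bound
# onto the checkpoint function once, then reused for every layer.
gradient_checkpointing_kwargs = {"use_reentrant": False}
gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
```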
@@ -33,6 +33,7 @@ import torch
 from packaging import version
 from torch import Tensor, nn
 from torch.nn import CrossEntropyLoss, Identity
+from torch.utils.checkpoint import checkpoint
 
 from .activations import get_activation
 from .configuration_utils import PretrainedConfig
@@ -1869,9 +1870,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if gradient_checkpointing_kwargs is None:
             gradient_checkpointing_kwargs = {}
 
-        gradient_checkpointing_func = functools.partial(
-            torch.utils.checkpoint.checkpoint, **gradient_checkpointing_kwargs
-        )
+        gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
 
         self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
@@ -1882,9 +1881,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             # the gradients to make sure the gradient flows.
             self.enable_input_require_grads()
 
-    def _set_gradient_checkpointing(
-        self, enable: bool = True, gradient_checkpointing_func: Callable = torch.utils.checkpoint.checkpoint
-    ):
+    def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
         is_gradient_checkpointing_set = False
 
         # Apply it on the top-level module in case the top-level modules supports it
...
@@ -1813,7 +1813,7 @@ class SeamlessM4TEncoder(SeamlessM4TPreTrainedModel):
                 layer_outputs = (None, None)
             else:
                 if self.gradient_checkpointing and self.training:
-                    layer_outputs = torch.utils.checkpoint.checkpoint(
+                    layer_outputs = self._gradient_checkpointing_func(
                         encoder_layer.forward,
                         hidden_states,
                         attention_mask,
...
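
For context, a hedged sketch of how a wrapper like `self._gradient_checkpointing_func` is typically invoked during training: each layer's forward pass runs under `checkpoint`, so its intermediate activations are recomputed in the backward pass instead of being stored. The toy layers and shapes below are illustrative, not taken from the PR:

```python
import functools

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# Toy stand-in for an encoder stack; real SeamlessM4T layers take more arguments.
layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
gradient_checkpointing_func = functools.partial(checkpoint, use_reentrant=False)

hidden_states = torch.randn(2, 16, requires_grad=True)
for layer in layers:
    # Same call shape as `self._gradient_checkpointing_func(encoder_layer.forward, hidden_states, ...)`
    hidden_states = gradient_checkpointing_func(layer.forward, hidden_states)

hidden_states.sum().backward()  # activations are recomputed here rather than kept in memory
```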