Unverified Commit 8876553e authored by Min Xu, committed by GitHub

[bug] use weakref in the wrapper (#424)



* use weakref in the wrapper

* comment

* comment

* Update fairscale/nn/misc/checkpoint_activations.py
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent 4b5b4d3d
@@ -4,7 +4,9 @@
 # LICENSE file in the root directory of this source tree.

 from contextlib import contextmanager
 import functools
 from typing import Any, Dict, Generator, Optional, Tuple
+import weakref

 import torch
 from torch import Tensor
@@ -61,29 +63,23 @@ def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False) -> nn.Module
         (nn.Module):
             wrapped module
     """
-    # Do not use functools.partial like:
+    # The use of weakref here is to prevent creating a ref cycle: m -> m.forward -> m.
+    # When such cycle exists, gc won't collect the module when the module is freed.
+    # That causes GPU memory to be leaked. See the unit test for how we catch that.
     #
-    #   module.forward = functools.partial(_checkpointed_forward, module.forward, offload_to_cpu)
-    #
-    # It causes the backward to hold-on to tensor memory even when model is
-    # freed.
-    # Use a wrapper to wrap the original module.
-    class CheckpointWrapper(nn.Module):
-        def __init__(self, module: nn.Module):
-            super().__init__()
-            self.module = module
-
-        def forward(self, *args: Any, **kwargs: Any) -> Any:
-            return _checkpointed_forward(self.module, offload_to_cpu, *args, **kwargs)
-
-    return CheckpointWrapper(module)
+    # We prefer this over a class wrapper since the class wrapper would have to
+    # proxy a lot of fields and methods.
+    module.forward = functools.partial(_checkpointed_forward, type(module).forward, weakref.ref(module), offload_to_cpu)  # type: ignore
+    return module


-def _checkpointed_forward(original_forward: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any) -> Any:
+def _checkpointed_forward(
+    original_forward: Any, weak_self: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any
+) -> Any:
     # Autograd Functions in PyTorch work best with positional args, since
     # the backward must return gradients (or None) for every input argument.
     # We can flatten keyword arguments to make this easier.
+    args = (weak_self(),) + args
     kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
     parent_ctx_dict: Dict[str, Any] = {"offload": offload_to_cpu}
     output = CheckpointFunction.apply(original_forward, parent_ctx_dict, kwarg_keys, *flat_args)
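Not part of the commit: a minimal, self-contained sketch of the ref cycle described in the new comments and of the weakref fix. The names ToyModule, checkpointed_forward, and checkpointed_forward_weak are illustrative stand-ins, not fairscale code.

# Illustrative sketch: binding a partial over the bound method keeps the module
# alive through a cycle, while binding over type(m).forward plus a weakref does not.
import functools
import gc
import weakref


class ToyModule:
    def forward(self, x):
        return x * 2


def checkpointed_forward(original_forward, x):
    # Leaky variant: the partial stores the bound method, which holds the instance.
    return original_forward(x)


def checkpointed_forward_weak(original_forward, weak_self, x):
    # Fixed variant: the partial stores only the plain function and a weakref,
    # so it keeps no strong reference to the instance.
    return original_forward(weak_self(), x)


gc.disable()  # make the refcounting behavior below deterministic

# Leaky variant: m -> m.__dict__["forward"] -> partial -> bound method -> m.
m = ToyModule()
m.forward = functools.partial(checkpointed_forward, m.forward)
probe = weakref.ref(m)
assert m.forward(3) == 6
del m
assert probe() is not None   # the cycle keeps the instance alive past `del`
gc.collect()                 # only the cyclic collector reclaims it
assert probe() is None

# Fixed variant, mirroring the commit: type(m).forward + weakref.ref(m).
m = ToyModule()
m.forward = functools.partial(checkpointed_forward_weak, type(m).forward, weakref.ref(m))
probe = weakref.ref(m)
assert m.forward(3) == 6
del m
assert probe() is None       # freed immediately by refcounting, no gc pass needed

gc.enable()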