Unverified commit e3035933, authored by Min Xu, committed by GitHub

[bug]: not all CUDA memory is freed when model is deleted (#412)

* [bug]: not all CUDA memory is freed when model is deleted

* fixed memory leak

- without this, peak memory will be high when more than one model
  is trained (i.e. the first model leaves stuff around, pushing up
  the peak memory when the second model runs); a rough verification
  sketch follows below

* addressed comments

* fix

* changelog
parent 2b15720b
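To make the symptom in the commit title concrete, here is a minimal verification sketch (an editorial illustration, not part of this patch; the helper `run_one_model` is hypothetical): once a model and its autograd state go out of scope, `torch.cuda.memory_allocated()` should drop back to its baseline, and a leftover reference cycle is one way that can fail to happen until the garbage collector runs.

```python
# Hypothetical check (not from this patch): train one throwaway model and
# verify that dropping it returns CUDA allocated memory to the baseline.
import gc

import torch
import torch.nn as nn


def run_one_model() -> None:
    model = nn.Linear(1024, 1024).cuda()
    loss = model(torch.randn(8, 1024, device="cuda")).sum()
    loss.backward()
    # model, loss, gradients and the autograd graph all go out of scope here.


if torch.cuda.is_available():
    run_one_model()  # warm-up, so one-time CUDA/cuBLAS allocations don't skew the baseline
    gc.collect()
    baseline = torch.cuda.memory_allocated()
    run_one_model()
    gc.collect()  # reclaim anything kept alive only by reference cycles
    assert torch.cuda.memory_allocated() == baseline, "model left CUDA memory behind"
```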
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Catch corner case when the model is too small with respect to the world size, and shards are empty ([#406](https://github.com/facebookresearch/fairscale/pull/406))
- Memory leak in checkpoint_wrapper ([#413](https://github.com/facebookresearch/fairscale/pull/413))
## [0.1.7] - 2021-02-19
### Fixed
@@ -4,13 +4,12 @@
# LICENSE file in the root directory of this source tree.
from contextlib import contextmanager
import functools
from typing import Any, Dict, Generator, Optional, Tuple
import torch
from torch import Tensor
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
import torch.utils.checkpoint as torch_checkpoint
from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors
@@ -41,8 +40,23 @@ def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False) -> nn.Mo
(nn.Module):
wrapped module
"""
module.forward = functools.partial(_checkpointed_forward, module.forward, offload_to_cpu) # type: ignore
return module
# Do not use functools.partial like:
#
# module.forward = functools.partial(_checkpointed_forward, module.forward, offload_to_cpu)
#
# It causes the backward pass to hold on to tensor memory even when the
# model is freed.
# Instead, wrap the original module in a wrapper module.
class CheckpointWrapper(nn.Module):
def __init__(self, module: nn.Module):
super().__init__()
self.module = module
def forward(self, *args: Any, **kwargs: Any) -> Any:
return _checkpointed_forward(self.module, offload_to_cpu, *args, **kwargs)
return CheckpointWrapper(module)
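For context on the comment above, a minimal editorial sketch (not part of the patch) of one pitfall of the replaced pattern: storing a `functools.partial` over the bound `module.forward` back onto the module creates a reference cycle, so the module is only reclaimed by the cyclic garbage collector rather than promptly by reference counting. This illustrates the reference-cycle aspect only; it is not the full explanation of the CUDA memory behaviour described in the commit message.

```python
# Illustrative sketch (not from the patch): the old monkey-patching pattern
# makes the module reference itself through the partial's bound method.
import functools
import gc
import weakref

import torch.nn as nn


def passthrough_forward(orig_forward, *args, **kwargs):
    return orig_forward(*args, **kwargs)


model = nn.Linear(4, 4)
model.forward = functools.partial(passthrough_forward, model.forward)
ref = weakref.ref(model)

del model
assert ref() is not None  # still alive: model -> partial -> bound forward -> model
gc.collect()
assert ref() is None  # only the cycle collector reclaims it
```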
def _checkpointed_forward(original_forward: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any) -> Any:
@@ -52,13 +66,11 @@ def _checkpointed_forward(original_forward: Any, offload_to_cpu: bool, *args: An
kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
parent_ctx_dict: Dict[str, Any] = {"offload": offload_to_cpu}
output = CheckpointFunction.apply(original_forward, parent_ctx_dict, kwarg_keys, *flat_args)
if isinstance(output, torch.Tensor):
return output
else:
if not isinstance(output, torch.Tensor):
packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
if packed_non_tensor_outputs:
output = unpack_non_tensors(output, packed_non_tensor_outputs)
return output
return output
def get_rng_state() -> Dict[str, Any]:
@@ -109,7 +121,7 @@ class CheckpointFunction(torch.autograd.Function):
**kwargs: Any
) -> Any:
if torch.is_grad_enabled(): # grad may be disabled, e.g., during validation
checkpoint.check_backward_validity(args)
torch_checkpoint.check_backward_validity(args)
ctx.run_function = run_function
ctx.kwarg_keys = kwarg_keys
@@ -131,15 +143,13 @@
unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
outputs = run_function(*unpacked_args, **unpacked_kwargs)
if isinstance(outputs, torch.Tensor):
return outputs
else:
if not isinstance(outputs, torch.Tensor):
# Autograd Functions don't like non-Tensor outputs. We can split the
# non-Tensor and Tensor outputs, returning the former by reference
# through *parent_ctx_dict* and returning the latter directly.
outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs
return outputs
return outputs
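The comment above is the key idea of `CheckpointFunction.forward`: an `autograd.Function` should only return tensors, so non-tensor outputs are extracted and handed back out-of-band via `parent_ctx_dict`, then re-interleaved in `_checkpointed_forward`. Below is a simplified, self-contained sketch of that split/unpack round trip; `toy_split` and `toy_unpack` are hypothetical stand-ins, not fairscale's actual `split_non_tensors`/`unpack_non_tensors`, whose packing format may differ.

```python
# Simplified illustration with hypothetical helpers (not fairscale's API):
# tensors cross the autograd boundary, everything else travels out-of-band.
from typing import Any, Dict, Tuple

import torch


def toy_split(mixed: Tuple[Any, ...]) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]:
    tensors = tuple(x for x in mixed if isinstance(x, torch.Tensor))
    packed = {
        "is_tensor": [isinstance(x, torch.Tensor) for x in mixed],
        "objects": [x for x in mixed if not isinstance(x, torch.Tensor)],
    }
    return tensors, packed


def toy_unpack(tensors: Tuple[torch.Tensor, ...], packed: Dict[str, Any]) -> Tuple[Any, ...]:
    t_it, o_it = iter(tensors), iter(packed["objects"])
    return tuple(next(t_it) if is_t else next(o_it) for is_t in packed["is_tensor"])


mixed = (torch.ones(2), "attention_mask", torch.zeros(3), None)
tensors, packed = toy_split(mixed)      # only tensors would be returned from forward()
restored = toy_unpack(tensors, packed)  # caller re-interleaves the non-tensor outputs
assert all(a is b for a, b in zip(restored, mixed))
```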
@staticmethod
def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
@@ -147,7 +157,7 @@
raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
tensor_inputs: Tuple = ctx.saved_tensors
tensor_inputs = checkpoint.detach_variable(tensor_inputs)
tensor_inputs = torch_checkpoint.detach_variable(tensor_inputs)
if ctx.fwd_device is not None:
tensor_inputs = tuple(t.to(ctx.fwd_device[i]) for i, t in enumerate(tensor_inputs))
for i, need_grad in enumerate(ctx.grad_requirements):
@@ -3,23 +3,32 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""
Test fairscale.nn.misc.checkpoint_activations
"""
"""Test fairscale.nn.misc.checkpoint_activations API."""
import unittest
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from torch.utils.checkpoint import checkpoint as torch_checkpoint
from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
def get_cuda_mem_allocated():
"""Helper to get cuda memory allocated if possible."""
if torch.cuda.is_available():
return torch.cuda.memory_allocated()
else:
return 0
class Model(nn.Module):
def __init__(self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False, **kwargs):
def __init__(self, use_pytorch_checkpoint=False, use_fairscale_checkpoint=False, **kwargs):
super().__init__()
torch.manual_seed(0)
torch.manual_seed(0) # make sure weights are deterministic.
assert not (
use_pytorch_checkpoint and use_fairscale_checkpoint
), "Cannot use both pytorch and fairscale checkpointing mechanisms."
self.use_pytorch_checkpoint = use_pytorch_checkpoint
self.ffn = nn.Sequential(
nn.Linear(32, 128),
@@ -27,46 +36,70 @@ class Model(nn.Module):
nn.Dropout(p=0.5),
nn.Linear(128, 32),
)
if use_fairseq_checkpoint:
if use_fairscale_checkpoint:
self.ffn = checkpoint_wrapper(self.ffn, **kwargs)
self.out = nn.Linear(32, 1)
def forward(self, x):
if self.use_pytorch_checkpoint:
x = checkpoint(self.ffn, x)
x = torch_checkpoint(self.ffn, x)
else:
x = self.ffn(x)
return self.out(x)
class TestComparisonToPyTorch(unittest.TestCase):
def _test_checkpoint_wrapper(self, device, log_memory_usage=False):
def get_loss_and_gnorm(model):
torch.manual_seed(1)
input = torch.rand(2, 16, 32).requires_grad_(True).to(device)
def _test_checkpoint_wrapper(self, device):
def get_loss_and_gnorm(model, input):
ret = {}
ret["mem_0"] = get_cuda_mem_allocated()
model.zero_grad()
loss = model(input).sum()
ret["mem_after_fwd"] = get_cuda_mem_allocated()
loss.backward()
ret["mem_after_bwd"] = get_cuda_mem_allocated()
gnorm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in model.parameters()]))
return {"loss": loss, "gnorm": gnorm}
ret["loss"] = loss.item()
ret["gnorm"] = gnorm.item()
return ret
input = torch.rand(2, 16, 32).requires_grad_(True)
model = Model().to(device)
no_cpt = get_loss_and_gnorm(model)
no_cpt = get_loss_and_gnorm(model, input.to(device))
model = Model(use_pytorch_checkpoint=True).to(device)
pyt_cpt = get_loss_and_gnorm(model)
pyt_cpt = get_loss_and_gnorm(model, input.to(device))
model = Model(use_fairscale_checkpoint=True).to(device)
fairscale_cpt = get_loss_and_gnorm(model, input.to(device))
model = Model(use_fairscale_checkpoint=True, offload_to_cpu=True).to(device)
fairscale_cpt_offload = get_loss_and_gnorm(model, input.to(device))
# Check for correctness.
torch.testing.assert_allclose(no_cpt["loss"], pyt_cpt["loss"])
torch.testing.assert_allclose(no_cpt["gnorm"], pyt_cpt["gnorm"])
model = Model(use_fairseq_checkpoint=True).to(device)
fairseq_cpt = get_loss_and_gnorm(model)
torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt["loss"])
torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt["gnorm"])
model = Model(use_fairseq_checkpoint=True, offload_to_cpu=True).to(device)
fairseq_cpt_offload = get_loss_and_gnorm(model)
torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt_offload["loss"])
torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt_offload["gnorm"])
torch.testing.assert_allclose(no_cpt["loss"], fairscale_cpt["loss"])
torch.testing.assert_allclose(no_cpt["gnorm"], fairscale_cpt["gnorm"])
torch.testing.assert_allclose(no_cpt["loss"], fairscale_cpt_offload["loss"])
torch.testing.assert_allclose(no_cpt["gnorm"], fairscale_cpt_offload["gnorm"])
# Check for memory usage for cuda only.
if device == torch.device("cpu"):
return
for d in [no_cpt, pyt_cpt, fairscale_cpt, fairscale_cpt_offload]:
del d["loss"]
del d["gnorm"]
assert no_cpt == {"mem_0": 38912, "mem_after_fwd": 64000, "mem_after_bwd": 74240}, no_cpt
assert pyt_cpt == {"mem_0": 38912, "mem_after_fwd": 43520, "mem_after_bwd": 74240}, pyt_cpt
assert fairscale_cpt == {"mem_0": 38912, "mem_after_fwd": 43520, "mem_after_bwd": 74240}, fairscale_cpt
assert fairscale_cpt_offload == {
"mem_0": 38912,
"mem_after_fwd": 43520,
"mem_after_bwd": 74240,
}, fairscale_cpt_offload
def test_checkpoint_wrapper_cpu(self):
self._test_checkpoint_wrapper(device=torch.device("cpu"))