Unverified Commit e3035933 authored by Min Xu, committed by GitHub

[bug]: not all CUDA memory is freed when model is deleted (#412)

* [bug]: not all CUDA memory is freed when model is deleted

* fixed memory leak

- without this fix, peak memory is higher when more than one model
  is trained (i.e., the first model leaves tensors around that push up
  the peak memory while the second model runs; see the repro sketch below)

* addressed comments

* fix

* changelog
parent 2b15720b
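
To make the scenario in the commit message concrete, here is an illustrative repro sketch. It is not part of this change: the model, sizes, and loop are assumptions chosen only for illustration, and it assumes a CUDA device. It trains two throwaway models back to back and prints the peak allocation, which is where objects left behind by the first model's patched `forward` used to show up.

```python
# Hypothetical repro sketch (not from this commit): train two models back to back
# and watch the CUDA peak. Leftover references from the first model inflate the
# peak seen while the second model runs.
import gc

import torch
import torch.nn as nn

from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper


def train_once(device: torch.device) -> None:
    # A small checkpointed model; shapes are arbitrary.
    model = checkpoint_wrapper(
        nn.Sequential(nn.Linear(32, 128), nn.ReLU(), nn.Linear(128, 32))
    ).to(device)
    x = torch.rand(64, 32, device=device, requires_grad=True)
    model(x).sum().backward()
    del model, x              # drop all references to the model and activations
    gc.collect()              # break any leftover reference cycles
    torch.cuda.empty_cache()  # return freed blocks to the driver


if torch.cuda.is_available():
    device = torch.device("cuda")
    for i in range(2):
        train_once(device)
        # Cumulative peak so far; with the fix it should not keep climbing run over run.
        print(f"run {i}: peak bytes so far:", torch.cuda.max_memory_allocated(device))
```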
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Fixed
 - Catch corner case when the model is too small with respect to the world size, and shards are empty ([#406] (https://github.com/facebookresearch/fairscale/pull/406))
+- Memory leak in checkpoint_wrapper ([#413] (https://github.com/facebookresearch/fairscale/pull/413))
 
 ## [0.1.7] - 2021-02-19
 ### Fixed
...
@@ -4,13 +4,12 @@
 # LICENSE file in the root directory of this source tree.
 
 from contextlib import contextmanager
-import functools
 from typing import Any, Dict, Generator, Optional, Tuple
 
 import torch
 from torch import Tensor
 import torch.nn as nn
-import torch.utils.checkpoint as checkpoint
+import torch.utils.checkpoint as torch_checkpoint
 
 from fairscale.utils.containers import pack_kwargs, split_non_tensors, unpack_kwargs, unpack_non_tensors
@@ -41,8 +40,23 @@ def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False) -> nn.Mo
         (nn.Module):
             wrapped module
     """
-    module.forward = functools.partial(_checkpointed_forward, module.forward, offload_to_cpu)  # type: ignore
-    return module
+    # Do not use functools.partial like:
+    #
+    #     module.forward = functools.partial(_checkpointed_forward, module.forward, offload_to_cpu)
+    #
+    # It causes the backward pass to hold on to tensor memory even when the
+    # model is freed.
+    # Use a wrapper module around the original module instead.
+    class CheckpointWrapper(nn.Module):
+        def __init__(self, module: nn.Module):
+            super().__init__()
+            self.module = module
+
+        def forward(self, *args: Any, **kwargs: Any) -> Any:
+            return _checkpointed_forward(self.module, offload_to_cpu, *args, **kwargs)
+
+    return CheckpointWrapper(module)
 
 
 def _checkpointed_forward(original_forward: Any, offload_to_cpu: bool, *args: Any, **kwargs: Any) -> Any:
@@ -52,13 +66,11 @@ def _checkpointed_forward(original_forward: Any, offload_to_cpu: bool, *args: An
     kwarg_keys, flat_args = pack_kwargs(*args, **kwargs)
     parent_ctx_dict: Dict[str, Any] = {"offload": offload_to_cpu}
     output = CheckpointFunction.apply(original_forward, parent_ctx_dict, kwarg_keys, *flat_args)
-    if isinstance(output, torch.Tensor):
-        return output
-    else:
+    if not isinstance(output, torch.Tensor):
         packed_non_tensor_outputs = parent_ctx_dict["packed_non_tensor_outputs"]
         if packed_non_tensor_outputs:
             output = unpack_non_tensors(output, packed_non_tensor_outputs)
-        return output
+    return output
 
 
 def get_rng_state() -> Dict[str, Any]:
@@ -109,7 +121,7 @@ class CheckpointFunction(torch.autograd.Function):
         **kwargs: Any
     ) -> Any:
         if torch.is_grad_enabled():  # grad may be disabled, e.g., during validation
-            checkpoint.check_backward_validity(args)
+            torch_checkpoint.check_backward_validity(args)
 
         ctx.run_function = run_function
         ctx.kwarg_keys = kwarg_keys
@@ -131,15 +143,13 @@ class CheckpointFunction(torch.autograd.Function):
             unpacked_args, unpacked_kwargs = unpack_kwargs(kwarg_keys, args)
             outputs = run_function(*unpacked_args, **unpacked_kwargs)
 
-        if isinstance(outputs, torch.Tensor):
-            return outputs
-        else:
+        if not isinstance(outputs, torch.Tensor):
             # Autograd Functions don't like non-Tensor outputs. We can split the
             # non-Tensor and Tensor outputs, returning the former by reference
             # through *parent_ctx_dict* and returning the latter directly.
             outputs, packed_non_tensor_outputs = split_non_tensors(outputs)
             parent_ctx_dict["packed_non_tensor_outputs"] = packed_non_tensor_outputs
-            return outputs
+        return outputs
 
     @staticmethod
     def backward(ctx: Any, *args: Any) -> Tuple[Optional[Tensor], ...]:
@@ -147,7 +157,7 @@ class CheckpointFunction(torch.autograd.Function):
             raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
 
         tensor_inputs: Tuple = ctx.saved_tensors
-        tensor_inputs = checkpoint.detach_variable(tensor_inputs)
+        tensor_inputs = torch_checkpoint.detach_variable(tensor_inputs)
         if ctx.fwd_device is not None:
             tensor_inputs = tuple(t.to(ctx.fwd_device[i]) for i, t in enumerate(tensor_inputs))
         for i, need_grad in enumerate(ctx.grad_requirements):
...
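
As an aside, the `functools.partial` warning in the comment above can be illustrated with a small, self-contained example that is not part of this patch. Assigning a partial that captures `module.forward` back onto the module creates a reference cycle, so the instance is only reclaimed by the cyclic garbage collector, whereas the wrapper-module approach keeps ownership acyclic. This is only part of the picture for the autograd case, but it shows the object-lifetime hazard.

```python
# Illustrative only: object lifetime when patching forward with functools.partial
# versus wrapping the module in another nn.Module.
import functools
import gc
import weakref

import torch.nn as nn


def make_patched() -> "weakref.ref":
    m = nn.Linear(4, 4)
    # The partial stores the bound method m.forward, which holds m; storing the
    # partial back on m closes a reference cycle: m -> partial -> m.forward -> m.
    m.forward = functools.partial(lambda fwd, x: fwd(x), m.forward)
    return weakref.ref(m)


def make_wrapped() -> "weakref.ref":
    class Wrapper(nn.Module):
        def __init__(self, module: nn.Module):
            super().__init__()
            self.module = module  # plain child reference, no cycle

    return weakref.ref(Wrapper(nn.Linear(4, 4)))


patched, wrapped = make_patched(), make_wrapped()
print(patched() is None, wrapped() is None)  # expected on CPython: False True
gc.collect()                                 # only the cycle collector reclaims the patched module
print(patched() is None)                     # expected: True
```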
@@ -3,23 +3,32 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.
 
-"""
-Test fairscale.nn.misc.checkpoint_activations
-"""
+"""Test fairscale.nn.misc.checkpoint_activations API."""
 
 import unittest
 
 import torch
 import torch.nn as nn
-from torch.utils.checkpoint import checkpoint
+from torch.utils.checkpoint import checkpoint as torch_checkpoint
 
 from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
 
 
+def get_cuda_mem_allocated():
+    """Helper to get cuda memory allocated if possible."""
+    if torch.cuda.is_available():
+        return torch.cuda.memory_allocated()
+    else:
+        return 0
+
+
 class Model(nn.Module):
-    def __init__(self, use_pytorch_checkpoint=False, use_fairseq_checkpoint=False, **kwargs):
+    def __init__(self, use_pytorch_checkpoint=False, use_fairscale_checkpoint=False, **kwargs):
         super().__init__()
-        torch.manual_seed(0)
+        torch.manual_seed(0)  # make sure weights are deterministic.
+        assert not (
+            use_pytorch_checkpoint and use_fairscale_checkpoint
+        ), "Cannot use both pytorch and fairscale checkpointing mechanisms."
         self.use_pytorch_checkpoint = use_pytorch_checkpoint
         self.ffn = nn.Sequential(
             nn.Linear(32, 128),
@@ -27,46 +36,70 @@ class Model(nn.Module):
             nn.Dropout(p=0.5),
             nn.Linear(128, 32),
         )
-        if use_fairseq_checkpoint:
+        if use_fairscale_checkpoint:
             self.ffn = checkpoint_wrapper(self.ffn, **kwargs)
         self.out = nn.Linear(32, 1)
 
     def forward(self, x):
         if self.use_pytorch_checkpoint:
-            x = checkpoint(self.ffn, x)
+            x = torch_checkpoint(self.ffn, x)
         else:
             x = self.ffn(x)
         return self.out(x)
 
 
 class TestComparisonToPyTorch(unittest.TestCase):
-    def _test_checkpoint_wrapper(self, device, log_memory_usage=False):
-        def get_loss_and_gnorm(model):
-            torch.manual_seed(1)
-            input = torch.rand(2, 16, 32).requires_grad_(True).to(device)
+    def _test_checkpoint_wrapper(self, device):
+        def get_loss_and_gnorm(model, input):
+            ret = {}
+            ret["mem_0"] = get_cuda_mem_allocated()
             model.zero_grad()
             loss = model(input).sum()
+            ret["mem_after_fwd"] = get_cuda_mem_allocated()
             loss.backward()
+            ret["mem_after_bwd"] = get_cuda_mem_allocated()
             gnorm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in model.parameters()]))
-            return {"loss": loss, "gnorm": gnorm}
+            ret["loss"] = loss.item()
+            ret["gnorm"] = gnorm.item()
+            return ret
+
+        input = torch.rand(2, 16, 32).requires_grad_(True)
 
         model = Model().to(device)
-        no_cpt = get_loss_and_gnorm(model)
+        no_cpt = get_loss_and_gnorm(model, input.to(device))
 
         model = Model(use_pytorch_checkpoint=True).to(device)
-        pyt_cpt = get_loss_and_gnorm(model)
+        pyt_cpt = get_loss_and_gnorm(model, input.to(device))
+
+        model = Model(use_fairscale_checkpoint=True).to(device)
+        fairscale_cpt = get_loss_and_gnorm(model, input.to(device))
+
+        model = Model(use_fairscale_checkpoint=True, offload_to_cpu=True).to(device)
+        fairscale_cpt_offload = get_loss_and_gnorm(model, input.to(device))
+
+        # Check for correctness.
         torch.testing.assert_allclose(no_cpt["loss"], pyt_cpt["loss"])
         torch.testing.assert_allclose(no_cpt["gnorm"], pyt_cpt["gnorm"])
 
-        model = Model(use_fairseq_checkpoint=True).to(device)
-        fairseq_cpt = get_loss_and_gnorm(model)
-        torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt["loss"])
-        torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt["gnorm"])
+        torch.testing.assert_allclose(no_cpt["loss"], fairscale_cpt["loss"])
+        torch.testing.assert_allclose(no_cpt["gnorm"], fairscale_cpt["gnorm"])
 
-        model = Model(use_fairseq_checkpoint=True, offload_to_cpu=True).to(device)
-        fairseq_cpt_offload = get_loss_and_gnorm(model)
-        torch.testing.assert_allclose(no_cpt["loss"], fairseq_cpt_offload["loss"])
-        torch.testing.assert_allclose(no_cpt["gnorm"], fairseq_cpt_offload["gnorm"])
+        torch.testing.assert_allclose(no_cpt["loss"], fairscale_cpt_offload["loss"])
+        torch.testing.assert_allclose(no_cpt["gnorm"], fairscale_cpt_offload["gnorm"])
+
+        # Check for memory usage for cuda only.
+        if device == torch.device("cpu"):
+            return
+
+        for d in [no_cpt, pyt_cpt, fairscale_cpt, fairscale_cpt_offload]:
+            del d["loss"]
+            del d["gnorm"]
+
+        assert no_cpt == {"mem_0": 38912, "mem_after_fwd": 64000, "mem_after_bwd": 74240}, no_cpt
+        assert pyt_cpt == {"mem_0": 38912, "mem_after_fwd": 43520, "mem_after_bwd": 74240}, pyt_cpt
+        assert fairscale_cpt == {"mem_0": 38912, "mem_after_fwd": 43520, "mem_after_bwd": 74240}, fairscale_cpt
+        assert fairscale_cpt_offload == {
+            "mem_0": 38912,
+            "mem_after_fwd": 43520,
+            "mem_after_bwd": 74240,
+        }, fairscale_cpt_offload
 
     def test_checkpoint_wrapper_cpu(self):
         self._test_checkpoint_wrapper(device=torch.device("cpu"))
...
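
The hard-coded byte counts in the new GPU assertions come from this exact model and the CUDA caching allocator's block rounding, so they are tied to the test setup above. For a quick manual check of the same three measurements on a GPU, a sketch along these lines could be used; only the input and layer shapes mirror the test, everything else here is an assumption and not part of the patch.

```python
# Illustrative only: record allocated CUDA bytes before the forward pass, after
# it, and after backward -- the same three points the updated test stores in
# mem_0 / mem_after_fwd / mem_after_bwd.
import torch
import torch.nn as nn

from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper

if torch.cuda.is_available():
    device = torch.device("cuda")
    ffn = checkpoint_wrapper(
        nn.Sequential(nn.Linear(32, 128), nn.Dropout(p=0.5), nn.Linear(128, 32)),
        offload_to_cpu=True,
    ).to(device)
    x = torch.rand(2, 16, 32, device=device, requires_grad=True)

    mem_0 = torch.cuda.memory_allocated()
    loss = ffn(x).sum()
    # With activation checkpointing, this should be noticeably lower than an
    # un-checkpointed forward, since intermediate activations are not kept.
    mem_after_fwd = torch.cuda.memory_allocated()
    loss.backward()
    mem_after_bwd = torch.cuda.memory_allocated()
    print(mem_0, mem_after_fwd, mem_after_bwd)
```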