Unverified commit 4b5b4d3d, authored by Min Xu and committed by GitHub

[test]: add peak mem in checkpoint test (#415)

* [test]: add peak mem in checkpoint test

* more debugging

* new test

* more fix

* better collection of debug in case of future failures

* update the comment

* typo

* comment

* clarify

* better wording
parent d64ff250
@@ -30,11 +30,32 @@ def checkpoint_wrapper(module: nn.Module, offload_to_cpu: bool = False) -> nn.Module:
    checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
    a, b = checkpointed_module(x, y=3, z=torch.Tensor([1]))
To understand the benefits of checkpointing and the `offload_to_cpu` flag,
let's divide activations into 2 types: inner activations and outer
activations w.r.t. the checkpointed modules. The inner ones are saved
by activation checkpointing, the outer ones are saved by offload_to_cpu.
In terms of GPU memory savings:
- When inner ones are large in size and outer ones are small,
checkpointing helps a lot, offload_to_cpu may help a little.
- When inner ones are small and outer ones are large,
checkpointing helps little, offload_to_cpu helps a lot.
- When both inner and outer are large, both help and the
benefit is additive.
.. Note::
    The first and last layers are not likely to benefit from the `offload_to_cpu` flag
    because (1) there are typically other references to the first layer's input, so
    the GPU memory won't be freed; (2) the input to the last layer is immediately
    used by the backward pass and won't result in memory savings.
Args:
    module (nn.Module):
        The module to be wrapped
    offload_to_cpu (Optional, bool):
        Whether to offload activations to CPU.
Returns:
    (nn.Module):
...
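To make the inner/outer split concrete, here is a minimal sketch (illustrative only; the three-block layout, sizes, and the `blocks` name are made up for this example and simply mirror the CPU-offload test model added below). Checkpointing each block avoids storing its inner activations, while `offload_to_cpu` on the middle block moves that block's input, an outer activation, off the GPU until the backward pass needs it:

import torch
import torch.nn as nn

from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper

# Three blocks; only the middle block's input is an "outer" activation worth
# offloading (the first block's input is still referenced by the caller, and the
# last block's input is needed immediately by the backward pass).
blocks = nn.Sequential(
    nn.Sequential(nn.Linear(4, 8), nn.ReLU()),
    nn.Sequential(nn.Linear(8, 8), nn.ReLU()),
    nn.Sequential(nn.Linear(8, 2), nn.ReLU()),
)
for i, block in enumerate(blocks):
    # Checkpoint every block; offload only the middle block's input to CPU.
    blocks[i] = checkpoint_wrapper(block, offload_to_cpu=(i == 1))

x = torch.rand(16, 4, requires_grad=True)
blocks(x).sum().backward()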
@@ -5,13 +5,13 @@
"""Test fairscale.nn.misc.checkpoint_activations API."""

import pytest
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint_wrapper

from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper
from fairscale.utils.testing import skip_if_no_cuda, torch_version


def get_cuda_mem_allocated():
@@ -22,7 +22,38 @@ def get_cuda_mem_allocated():
    return 0


def get_loss_and_gnorm(model, input):
"""Helper to run a forward/backward pass and return results in a dict."""
ret = {}
ret["mem_0"] = get_cuda_mem_allocated()
ret["mem_peak"] = 0
if ret["mem_0"] > 0:
torch.cuda.reset_peak_memory_stats()
model.zero_grad()
loss = model(input).sum()
ret["mem_after_fwd"] = get_cuda_mem_allocated()
loss.backward()
ret["mem_after_bwd"] = get_cuda_mem_allocated()
gnorm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in model.parameters()]))
ret["loss"] = loss.item()
ret["gnorm"] = gnorm.item()
if ret["mem_0"] > 0:
ret["mem_peak"] = torch.cuda.max_memory_allocated()
return ret
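
# Illustrative only (not part of this commit): a minimal way the helper above can
# be used, assuming a CUDA device and a hypothetical toy nn.Linear model. The
# "mem_*" entries are byte counts from the CUDA allocator and stay 0 on CPU-only runs.
def _example_get_loss_and_gnorm_usage():
    model = nn.Linear(8, 8).cuda()
    stats = get_loss_and_gnorm(model, torch.rand(4, 8).cuda())
    # stats has "loss", "gnorm", "mem_0", "mem_after_fwd", "mem_after_bwd", "mem_peak".
    return stats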
class BasicModel(nn.Module):
    """Basic model with a single FFN being checkpointed.

    Used for extensive checks: equivalence with non-checkpoint and torch-checkpoint runs, etc.
    """

    def __init__(self, use_pytorch_checkpoint=False, use_fairscale_checkpoint=False, **kwargs):
        super().__init__()
        torch.manual_seed(0)  # make sure weights are deterministic.
@@ -42,72 +73,125 @@ class Model(nn.Module):
    def forward(self, x):
        if self.use_pytorch_checkpoint:
            x = torch_checkpoint_wrapper(self.ffn, x)
        else:
            x = self.ffn(x)
        return self.out(x)
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_basic(device):
    if "cuda" in device and not torch.cuda.is_available():
        pytest.skip("test requires a GPU")

    input = torch.rand(2, 16, 32).requires_grad_(True)
    model = BasicModel().to(device)
    no_cpt = get_loss_and_gnorm(model, input.to(device))

    model = BasicModel(use_pytorch_checkpoint=True).to(device)
    pyt_cpt = get_loss_and_gnorm(model, input.to(device))

    model = BasicModel(use_fairscale_checkpoint=True).to(device)
    fairscale_cpt = get_loss_and_gnorm(model, input.to(device))

    model = BasicModel(use_fairscale_checkpoint=True, offload_to_cpu=True).to(device)
    fairscale_cpt_offload = get_loss_and_gnorm(model, input.to(device))

    # Check for correctness.
    for key in "loss", "gnorm":
        if not (no_cpt[key] == pyt_cpt[key] == fairscale_cpt[key] == fairscale_cpt_offload[key]):
            print(no_cpt, pyt_cpt, fairscale_cpt, fairscale_cpt_offload)
            assert 0
        del no_cpt[key]
        del pyt_cpt[key]
        del fairscale_cpt[key]
        del fairscale_cpt_offload[key]

    # Check for memory usage for cuda only.
    if "cpu" in device:
        return

    mem_peaks = [98816, 103424, 103424, 107520]
    if torch_version() < (1, 7, 0):
        # Older torch behaves slightly differently
        mem_peaks = [102400, 103424, 103424, 107520]

    assert no_cpt == {"mem_0": 38912, "mem_peak": mem_peaks[0], "mem_after_fwd": 64000, "mem_after_bwd": 74240}, no_cpt
    assert pyt_cpt == {
        "mem_0": 38912,
        "mem_peak": mem_peaks[1],
        "mem_after_fwd": 43520,
        "mem_after_bwd": 74240,
    }, pyt_cpt
    assert fairscale_cpt == {
        "mem_0": 38912,
        "mem_peak": mem_peaks[2],
        "mem_after_fwd": 43520,
        "mem_after_bwd": 74240,
    }, fairscale_cpt
    assert fairscale_cpt_offload == {
        "mem_0": 38912,
        "mem_peak": mem_peaks[3],
        "mem_after_fwd": 43520,
        "mem_after_bwd": 74240,
    }, fairscale_cpt_offload


class CpuOffloadModel(nn.Module):
    """Model used to check cpu offload memory saving"""

    def __init__(self, enable_checkpoint=False, cpu_offload=False):
        super().__init__()
        torch.manual_seed(0)  # make sure weights are deterministic.

        # These numbers are picked to show cpu_offload memory saving.
        # Inner (recomputed) activation sizes need to be just right
        # to show the benefit.
        self.layers = nn.Sequential(
            nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 8)),
            nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 4), nn.Linear(4, 4)),
            nn.Sequential(nn.Linear(4, 6), nn.Linear(6, 8), nn.Linear(8, 2)),
        )

        if enable_checkpoint:
            for i, layer in enumerate(self.layers):
                # Only middle layer needs to have offloading
                self.layers[i] = checkpoint_wrapper(layer, cpu_offload if i == 1 else False)

    def forward(self, x):
        return self.layers(x)
@skip_if_no_cuda
def test_offload_memory():
    device = "cuda"

    input = torch.rand(60, 24, 4).requires_grad_(True)

    model = CpuOffloadModel().to(device)
    base = get_loss_and_gnorm(model, input.to(device))

    model = CpuOffloadModel(True).to(device)
    cpt = get_loss_and_gnorm(model, input.to(device))

    model = CpuOffloadModel(True, True).to(device)
    offload = get_loss_and_gnorm(model, input.to(device))

    for key in "loss", "gnorm":
        if not (base[key] == cpt[key] == offload[key]):
            # Use print to collect all debugging info.
            print(base, cpt, offload)
            assert 0
        del base[key]
        del cpt[key]
        del offload[key]

    ref_base = {"mem_0": 32256, "mem_peak": 334336, "mem_after_fwd": 274944, "mem_after_bwd": 41984}
    ref_cpt = {"mem_0": 32256, "mem_peak": 253952, "mem_after_fwd": 101888, "mem_after_bwd": 41984}
    ref_offload = {"mem_0": 32256, "mem_peak": 207872, "mem_after_fwd": 55808, "mem_after_bwd": 41984}

    if not (base == ref_base and cpt == ref_cpt and offload == ref_offload):
        # Use print to collect all debugging info.
        print(base, cpt, offload)
        assert 0
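
# Back-of-the-envelope check (illustrative only, not collected or executed by the test):
# the saving cpu_offload buys above should be roughly the size of the checkpointed
# middle block's fp32 input of shape (60, 24, 8), which matches the gaps between the
# reference numbers for cpt and offload.
def _expected_offload_saving_bytes():
    middle_input_bytes = 60 * 24 * 8 * 4  # 46080 bytes
    assert 101888 - 55808 == middle_input_bytes  # ref_cpt vs. ref_offload, mem_after_fwd
    assert 253952 - 207872 == middle_input_bytes  # ref_cpt vs. ref_offload, mem_peak
    return middle_input_bytes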