# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""Test fairscale.nn.misc.checkpoint_activations API."""

import unittest

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint

from fairscale.nn.misc.checkpoint_activations import checkpoint_wrapper


def get_cuda_mem_allocated():
    """Helper to get cuda memory allocated if possible."""
    if torch.cuda.is_available():
        return torch.cuda.memory_allocated()
    else:
        return 0


class Model(nn.Module):
    def __init__(self, use_pytorch_checkpoint=False, use_fairscale_checkpoint=False, **kwargs):
        super().__init__()
        torch.manual_seed(0)  # make sure weights are deterministic.
        assert not (
            use_pytorch_checkpoint and use_fairscale_checkpoint
        ), "Cannot use both pytorch and fairscale checkpointing mechanisms."
        self.use_pytorch_checkpoint = use_pytorch_checkpoint
        self.ffn = nn.Sequential(
            nn.Linear(32, 128),
            # add a Dropout layer to test RNG save/restore
            nn.Dropout(p=0.5),
            nn.Linear(128, 32),
        )
        if use_fairscale_checkpoint:
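            # Wrap the FFN with fairscale's checkpoint_wrapper; extra kwargs
            # (e.g. offload_to_cpu) are passed straight through to the wrapper.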
            self.ffn = checkpoint_wrapper(self.ffn, **kwargs)
        self.out = nn.Linear(32, 1)

    def forward(self, x):
        if self.use_pytorch_checkpoint:
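            # Recompute the FFN activations during backward instead of storing them.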
            x = torch_checkpoint(self.ffn, x)
        else:
            x = self.ffn(x)
        return self.out(x)


class TestComparisonToPyTorch(unittest.TestCase):
    def _test_checkpoint_wrapper(self, device):
        def get_loss_and_gnorm(model, input):
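            # Run a single forward/backward pass and record the loss, the overall
            # gradient norm, and CUDA memory before, after forward, and after backward.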
            ret = {}
            ret["mem_0"] = get_cuda_mem_allocated()
            model.zero_grad()
            loss = model(input).sum()
            ret["mem_after_fwd"] = get_cuda_mem_allocated()
            loss.backward()
            ret["mem_after_bwd"] = get_cuda_mem_allocated()
            gnorm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in model.parameters()]))
            ret["loss"] = loss.item()
            ret["gnorm"] = gnorm.item()
            return ret

        input = torch.rand(2, 16, 32).requires_grad_(True)
        model = Model().to(device)
        no_cpt = get_loss_and_gnorm(model, input.to(device))

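        # Rebuild the same (seeded) model with each checkpointing variant so the
        # losses and gradient norms below are directly comparable.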
        model = Model(use_pytorch_checkpoint=True).to(device)
        pyt_cpt = get_loss_and_gnorm(model, input.to(device))

        model = Model(use_fairscale_checkpoint=True).to(device)
        fairscale_cpt = get_loss_and_gnorm(model, input.to(device))

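        # offload_to_cpu is forwarded via **kwargs to checkpoint_wrapper, which
        # should keep the checkpointed activations on CPU between forward and backward.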
        model = Model(use_fairscale_checkpoint=True, offload_to_cpu=True).to(device)
        fairscale_cpt_offload = get_loss_and_gnorm(model, input.to(device))

        # Check for correctness.
        torch.testing.assert_allclose(no_cpt["loss"], pyt_cpt["loss"])
        torch.testing.assert_allclose(no_cpt["gnorm"], pyt_cpt["gnorm"])

        torch.testing.assert_allclose(no_cpt["loss"], fairscale_cpt["loss"])
        torch.testing.assert_allclose(no_cpt["gnorm"], fairscale_cpt["gnorm"])

        torch.testing.assert_allclose(no_cpt["loss"], fairscale_cpt_offload["loss"])
        torch.testing.assert_allclose(no_cpt["gnorm"], fairscale_cpt_offload["gnorm"])

        # Check for memory usage for cuda only.
        if device == torch.device("cpu"):
            return
        for d in [no_cpt, pyt_cpt, fairscale_cpt, fairscale_cpt_offload]:
            del d["loss"]
            del d["gnorm"]
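        # Exact byte counts depend on the layer sizes above: both checkpointed
        # variants hold fewer activation bytes after forward (43520 vs. 64000),
        # and every variant ends up at the same footprint after backward.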
        assert no_cpt == {"mem_0": 38912, "mem_after_fwd": 64000, "mem_after_bwd": 74240}, no_cpt
        assert pyt_cpt == {"mem_0": 38912, "mem_after_fwd": 43520, "mem_after_bwd": 74240}, pyt_cpt
        assert fairscale_cpt == {"mem_0": 38912, "mem_after_fwd": 43520, "mem_after_bwd": 74240}, fairscale_cpt
        assert fairscale_cpt_offload == {
            "mem_0": 38912,
            "mem_after_fwd": 43520,
            "mem_after_bwd": 74240,
        }, fairscale_cpt_offload

    def test_checkpoint_wrapper_cpu(self):
        self._test_checkpoint_wrapper(device=torch.device("cpu"))

    @unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
    def test_checkpoint_wrapper_cuda(self):
        self._test_checkpoint_wrapper(device=torch.device("cuda"))


if __name__ == "__main__":
    unittest.main()