# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

"""Test fairscale.nn.checkpoint.checkpoint_activations API (and its deprecated fairscale.nn.misc path)."""

import pytest
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as torch_checkpoint_wrapper
from fairscale.nn.checkpoint.checkpoint_activations import checkpoint_wrapper, disable_checkpointing
from fairscale.nn.misc import checkpoint_wrapper as deprecated_checkpoint_wrapper
from fairscale.utils import torch_version
from fairscale.utils.testing import skip_if_no_cuda


def get_cuda_mem_allocated():
    """Return ``torch.cuda.memory_allocated()``, or 0 when CUDA is unavailable.

    CUDA is only queried after ``torch.cuda.is_available()`` succeeds, so this
    helper is safe to call on CPU-only machines.
    """
    if not torch.cuda.is_available():
        return 0
    return torch.cuda.memory_allocated()


def get_loss_and_gnorm(model, input):
    """Run one forward/backward pass of ``model`` on ``input`` and collect stats.

    Gradients are zeroed before the forward pass, and the loss is
    ``model(input).sum()``. Every parameter must receive a gradient, because
    ``p.grad.detach()`` is called for each one.

    Returns:
        dict with keys
            ``mem_0``: ``get_cuda_mem_allocated()`` before the pass.
            ``mem_peak``: ``torch.cuda.max_memory_allocated()`` after the pass,
                with peak stats reset at the start; 0 when ``mem_0`` is 0.
            ``mem_after_fwd`` / ``mem_after_bwd``: allocated memory after the
                forward / backward pass.
            ``loss``: the loss as a Python float.
            ``gnorm``: norm of the stacked per-parameter gradient norms, as a float.
    """
    start_mem = get_cuda_mem_allocated()
    # Peak tracking only makes sense when CUDA memory is actually in use.
    track_peak = start_mem > 0
    if track_peak:
        torch.cuda.reset_peak_memory_stats()

    model.zero_grad()
    loss = model(input).sum()
    fwd_mem = get_cuda_mem_allocated()

    loss.backward()
    bwd_mem = get_cuda_mem_allocated()

    # Computed before reading the peak so its allocations are counted, as before.
    grad_norm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in model.parameters()]))
    loss_value = loss.item()
    grad_norm_value = grad_norm.item()

    peak_mem = torch.cuda.max_memory_allocated() if track_peak else 0

    return {
        "mem_0": start_mem,
        "mem_peak": peak_mem,
        "mem_after_fwd": fwd_mem,
        "mem_after_bwd": bwd_mem,
        "loss": loss_value,
        "gnorm": grad_norm_value,
    }


class BasicModel(nn.Module):
    """Basic model with a single FFN being checkpointed.

    Used for extensive checks: equivalency with non-checkpoint, torch-checkpoint, etc.

    Args:
        use_pytorch_checkpoint: in ``forward``, run ``self.ffn`` through
            ``torch.utils.checkpoint.checkpoint``.
        use_fairscale_checkpoint: wrap ``self.ffn`` with fairscale's
            ``checkpoint_wrapper`` at construction time.
        **kwargs: forwarded to ``checkpoint_wrapper`` (e.g. ``offload_to_cpu``);
            only used when ``use_fairscale_checkpoint`` is True.
    """

    def __init__(self, use_pytorch_checkpoint=False, use_fairscale_checkpoint=False, **kwargs):
        super().__init__()
        torch.manual_seed(0)  # make sure weights are deterministic.
        assert not (
            use_pytorch_checkpoint and use_fairscale_checkpoint
        ), "Cannot use both pytorch and fairscale checkpointing mechanisms."
        self.use_pytorch_checkpoint = use_pytorch_checkpoint
        self.ffn = nn.Sequential(
            nn.Linear(32, 128),
            # add a Dropout layer to test RNG save/restore
            nn.Dropout(p=0.5),
            nn.Linear(128, 32),
        )
        if use_fairscale_checkpoint:
            self.ffn = checkpoint_wrapper(self.ffn, **kwargs)
        self.out = nn.Linear(32, 1)

    def forward(self, x):
        """Apply the (optionally checkpointed) FFN, then project to a single output feature."""
        if self.use_pytorch_checkpoint:
            # torch.utils.checkpoint re-runs self.ffn during backward instead of keeping its activations.
            x = torch_checkpoint_wrapper(self.ffn, x)
        else:
            x = self.ffn(x)
        return self.out(x)


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_basic(device):
    """Checkpointed variants of BasicModel must match the plain model exactly.

    Compares no checkpointing, torch.utils.checkpoint, fairscale checkpointing,
    and fairscale checkpointing with ``offload_to_cpu=True``. Loss and gradient
    norm must be equal across all four. On CUDA, each variant's memory
    statistics are also checked against recorded reference values.
    """
    if "cuda" in device and not torch.cuda.is_available():
        pytest.skip("test requires a GPU")

    inputs = torch.rand(2, 16, 32).requires_grad_(True)
    model = BasicModel().to(device)
    no_cpt = get_loss_and_gnorm(model, inputs.to(device))

    model = BasicModel(use_pytorch_checkpoint=True).to(device)
    pyt_cpt = get_loss_and_gnorm(model, inputs.to(device))

    model = BasicModel(use_fairscale_checkpoint=True).to(device)
    fairscale_cpt = get_loss_and_gnorm(model, inputs.to(device))

    model = BasicModel(use_fairscale_checkpoint=True, offload_to_cpu=True).to(device)
    fairscale_cpt_offload = get_loss_and_gnorm(model, inputs.to(device))

    # Check for correctness. Exact equality is intended: BasicModel reseeds the
    # RNG on construction, so checkpointing must not change the numerics.
    all_results = (no_cpt, pyt_cpt, fairscale_cpt, fairscale_cpt_offload)
    for key in "loss", "gnorm":
        assert no_cpt[key] == pyt_cpt[key] == fairscale_cpt[key] == fairscale_cpt_offload[key], all_results
        # Drop the compared entry so only memory statistics remain below.
        for result in all_results:
            del result[key]

    # Check for memory usage for cuda only.
    if "cpu" in device:
        return

    mem_peaks = [98816, 103424, 103424, 107520]
    if torch_version() < (1, 7, 0):
        # Older torch behaves slightly differently
        mem_peaks = [102400, 103424, 103424, 107520]

    assert no_cpt == {
        "mem_0": 38912,
        "mem_peak": mem_peaks[0],
        "mem_after_fwd": 64000,
        "mem_after_bwd": 74240,
    }, no_cpt
    assert pyt_cpt == {
        "mem_0": 38912,
        "mem_peak": mem_peaks[1],
        "mem_after_fwd": 43520,
        "mem_after_bwd": 74240,
    }, pyt_cpt
    assert fairscale_cpt == {
        "mem_0": 38912,
        "mem_peak": mem_peaks[2],
        "mem_after_fwd": 43520,
        "mem_after_bwd": 74240,
    }, fairscale_cpt
    assert fairscale_cpt_offload == {
        "mem_0": 38912,
        "mem_peak": mem_peaks[3],
        "mem_after_fwd": 43520,
        "mem_after_bwd": 74240,
    }, fairscale_cpt_offload


class CpuOffloadModel(nn.Module):
    """Model used to check cpu offload memory saving.

    Args:
        enable_checkpoint: wrap each of the three ``nn.Sequential`` segments in
            ``self.layers`` with ``checkpoint_wrapper``.
        cpu_offload: when checkpointing, request ``offload_to_cpu`` for the
            middle segment only; has no effect if ``enable_checkpoint`` is False.
    """

    def __init__(self, enable_checkpoint=False, cpu_offload=False):
        super().__init__()

        torch.manual_seed(0)  # make sure weights are deterministic.

        # These numbers are picked to show cpu_offload memory saving.
        # Inner (recomputed) activation sizes need to be just right
        # to show the benefit.
        self.layers = nn.Sequential(
            nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 8)),
            nn.Sequential(nn.Linear(8, 4), nn.Linear(4, 4), nn.Linear(4, 4)),
            nn.Sequential(nn.Linear(4, 6), nn.Linear(6, 8), nn.Linear(8, 2)),
        )

        if enable_checkpoint:
            # Iterate over a snapshot, because the loop replaces entries of self.layers.
            for i, layer in enumerate(list(self.layers)):
                # Only middle layer needs to have offloading
                self.layers[i] = checkpoint_wrapper(layer, offload_to_cpu=(cpu_offload if i == 1 else False))

    def forward(self, x):
        """Run the three segments in sequence."""
        return self.layers(x)


@skip_if_no_cuda
def test_offload_memory():
    """CPU offload of the middle checkpointed segment must lower CUDA memory further.

    Runs CpuOffloadModel without checkpointing, with checkpointing, and with
    checkpointing plus CPU offload of the middle segment. Loss and gradient
    norm must be identical, and memory statistics must match the recorded
    reference values.
    """
    device = "cuda"

    inputs = torch.rand(60, 24, 4).requires_grad_(True)

    model = CpuOffloadModel().to(device)
    base = get_loss_and_gnorm(model, inputs.to(device))

    model = CpuOffloadModel(enable_checkpoint=True).to(device)
    cpt = get_loss_and_gnorm(model, inputs.to(device))

    model = CpuOffloadModel(enable_checkpoint=True, cpu_offload=True).to(device)
    offload = get_loss_and_gnorm(model, inputs.to(device))

    results = (base, cpt, offload)
    for key in "loss", "gnorm":
        assert base[key] == cpt[key] == offload[key], results
        # Drop the compared entry so only memory statistics remain below.
        for result in results:
            del result[key]

    ref_base = {"mem_0": 32256, "mem_peak": 334336, "mem_after_fwd": 274944, "mem_after_bwd": 41984}
    ref_cpt = {"mem_0": 32256, "mem_peak": 253952, "mem_after_fwd": 101888, "mem_after_bwd": 41984}
    ref_offload = {"mem_0": 32256, "mem_peak": 207872, "mem_after_fwd": 55808, "mem_after_bwd": 41984}

    assert results == (ref_base, ref_cpt, ref_offload), results


class MultiinMultioutModel(nn.Module):
    """Two-branch conv model used to check checkpointing with multiple inputs/outputs.

    Args:
        multiout: if True, ``forward`` returns both branch outputs as a tuple;
            otherwise it returns their element-wise sum.
        checkpoint_config: bitmask in [0, 3]. Bit 0 wraps ``conv1`` and bit 1
            wraps ``conv2`` with ``checkpoint_wrapper``.
    """

    def __init__(self, multiout=False, checkpoint_config=0):
        super().__init__()
        torch.manual_seed(0)  # identical weights for every instance

        self.multiout = multiout

        def make_branch(in_channels):
            # conv -> ReLU -> conv, each conv 3x3 without padding, 5 output channels.
            return nn.Sequential(nn.Conv2d(in_channels, 5, 3), nn.ReLU(), nn.Conv2d(5, 5, 3))

        # Built in this order so the seeded RNG yields the same weights as before.
        self.conv1 = make_branch(1)
        self.conv2 = make_branch(3)

        assert 0 <= checkpoint_config <= 3
        wrap_first = bool(checkpoint_config & 0b01)
        wrap_second = bool(checkpoint_config & 0b10)
        if wrap_first:
            self.conv1 = checkpoint_wrapper(self.conv1)
        if wrap_second:
            self.conv2 = checkpoint_wrapper(self.conv2)

    def forward(self, x1, x2=None):
        """Apply ``conv1`` to ``x1`` and ``conv2`` to ``x2``, then combine per ``self.multiout``.

        NOTE(review): ``x2`` defaults to None but is passed straight to ``conv2``,
        so callers presumably must always supply it.
        """
        branch_outputs = (self.conv1(x1), self.conv2(x2))
        return branch_outputs if self.multiout else branch_outputs[0] + branch_outputs[1]


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.parametrize("multiout", [True, False])
@pytest.mark.parametrize("checkpoint_config", [1, 2, 3])
def test_multiin_multiout(device, multiout, checkpoint_config):
    """Checkpointing either or both branches must not change the loss or gradient norm."""
    if "cuda" in device and not torch.cuda.is_available():
        pytest.skip("test requires a GPU")

    def train(model, in1, in2):
        """Run one forward/backward pass and return the loss and overall gradient norm."""
        out = model(in1, x2=in2)
        if isinstance(out, tuple):
            # multiout=True: concatenate both branch outputs into one tensor.
            out = torch.cat(out)
        loss = out.sum()
        loss.backward()
        gnorm = torch.norm(torch.stack([torch.norm(p.grad.detach()) for p in model.parameters()]))
        return {"loss": loss.item(), "gnorm": gnorm.item()}

    in1 = torch.rand(4, 1, 32, 32).requires_grad_(True)
    in2 = torch.rand(4, 3, 32, 32).requires_grad_(True)

    model = MultiinMultioutModel(multiout, 0).to(device)
    no_cpt = train(model, in1.to(device), in2.to(device))

    model = MultiinMultioutModel(multiout, checkpoint_config).to(device)
    cpt = train(model, in1.to(device), in2.to(device))

    # Both dicts hold exactly the "loss" and "gnorm" keys.
    assert no_cpt == cpt, (no_cpt, cpt)


def test_deprecated_path():
    """Check that checkpoint_wrapper is still importable and callable from its older locations."""

    def make_ffn():
        return nn.Sequential(nn.Linear(32, 128), nn.Dropout(p=0.5), nn.Linear(128, 32))

    # Package-level import; this code historically lived in fairscale.nn.misc.checkpoint_activations.
    from fairscale.nn import checkpoint_wrapper

    # NOTE(review): `{}` is passed as the second positional argument, which is
    # presumably `offload_to_cpu` (falsy here); confirm this is intentional.
    checkpoint_wrapper(make_ffn(), {})

    # Module-level import of the deprecated alias still works.
    deprecated_checkpoint_wrapper(make_ffn(), {})


@skip_if_no_cuda
def test_list_input():
    """Test checkpointing with a list as the input argument.

    Note: Testing shows that PyTorch's torch.utils.checkpoint function does not pass this test.
    """
    forward_calls = 0

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Linear(2, 2)

        def forward(self, x):
            # Count every forward call, including recomputation during backward.
            nonlocal forward_calls
            forward_calls += 1
            return [self.conv(item) for item in x]

    model = nn.Sequential(checkpoint_wrapper(Model()), Model()).cuda()
    batch = [torch.rand(4, 2).cuda(), torch.rand(4, 2).cuda()]

    # Forward: each of the two modules runs once.
    loss = sum(t.sum() for t in model(batch))
    assert forward_calls == 2, f"Incorrect count {forward_calls}"

    # Backward: the checkpointed module re-runs its forward once more.
    loss.backward()
    assert forward_calls == 3, f"Incorrect count {forward_calls}"


def test_checkpoint_disabling():
    """Test the disable_checkpointing() context manager used with checkpoint_wrapper."""

    class TestModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.cnt = 0
            self.linear = nn.Linear(2, 2)

        def forward(self, x):
            # Count forward calls so recomputation during backward is observable.
            self.cnt += 1
            return [self.linear(row) for row in x]

    x = torch.rand(4, 2)
    model1 = checkpoint_wrapper(TestModel())
    model2 = checkpoint_wrapper(TestModel())

    # Checkpointing active: one forward call, plus one recompute during backward.
    sum(out.sum() for out in model1(x)).backward()
    assert model1.cnt == 2

    with disable_checkpointing():
        # Checkpointing disabled: backward must not re-run forward.
        sum(out.sum() for out in model2(x)).backward()
    assert model2.cnt == 1