# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

# Copyright 2019 Kakao Brain
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

import pytest
import torch
from torch import nn
import torch.cuda

from fairscale.nn.pipe.checkpoint import Checkpointing, Function, TensorOrTensors, is_checkpointing, is_recomputing
from fairscale.nn.pipe.dependency import fork, join
from fairscale.nn.pipe.microbatch import Batch

# Devices that device-parametrized tests run on: CPU always, plus CUDA when a
# GPU is available in the current environment.
devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


def make_checkpoint(function: Function, input: TensorOrTensors, index: int) -> TensorOrTensors:
    """Run ``function`` on ``input`` under checkpointing and return the raw output.

    This is a test/debug convenience with an interface similar to
    :func:`torch.utils.checkpoint.checkpoint`. It lets :class:`Checkpoint` and
    :class:`Recompute` be exercised without pipeline boilerplate.

    ``input`` is wrapped in a :class:`Batch` tagged with ``index``. The
    checkpointed forward pass is then run, and ``recompute`` is called on its
    result before the result's tensor(s) are unwrapped.
    """
    checkpointing = Checkpointing(function, Batch(input, index))
    output = checkpointing.checkpoint()
    checkpointing.recompute(output)
    return output.tensor_or_tensors


@pytest.mark.parametrize("device", devices)
def test_serial_checkpoints(device):
    """Two checkpoints linked by a fork/join dependency recompute one after the other.

    Copied from https://github.com/pytorch/pytorch/pull/18568.
    """
    timeline = []

    class Log(torch.autograd.Function):
        # Identity autograd function that appends "<name>:forward" or
        # "<name>:backward" to ``timeline`` every time it runs.
        @staticmethod
        def forward(ctx, name, x):
            ctx.name = name
            timeline.append(f"{name}:forward")
            return x.detach()

        @staticmethod
        def backward(ctx, grad_output):
            name = ctx.name
            timeline.append(f"{name}:backward")
            # No gradient for ``name``; pass the gradient of ``x`` straight through.
            return None, grad_output

    a = torch.rand(1, device=device, requires_grad=True)
    b = torch.rand(1, device=device, requires_grad=True)

    # Increase the next function sequence number.
    # NOTE(review): the statement order below matters. Autograd node creation
    # order influences backward scheduling, so do not reorder these lines.
    _ = a + 1 + 2 + 3 + 4 + 5

    a = make_checkpoint(partial(Log.apply, "a"), a, 0)

    # Make ``b`` depend on ``a`` so that b's recomputation/backward must finish
    # before a's recomputation starts.
    a, phony = fork(a)
    b = join(b, phony)

    b = make_checkpoint(partial(Log.apply, "b"), b, 0)

    c = torch.cat((a, b))

    out = c.sum()

    #                        +--> {a} --Checkpoint(Log)--> {a}
    # {out} --Sum--> {c} --Cat     ^-----------------------------+
    #                        +--> {b} --Checkpoint(Log)--> {b} --First--> {b}
    out.backward()

    assert timeline == ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"]
    #    |----------------------|  |-----------------------|  |-----------------------|
    #          forward pass            Checkpoint(Log[b])         Checkpoint(Log[a])


def test_not_requires_grad():
    """A checkpointed output requires grad even when its input does not.

    The property must hold both after ``checkpoint()`` and after ``recompute()``,
    and backward through the result must succeed.
    """
    x = Batch(torch.rand(1, requires_grad=False), 0)
    assert not x[0].requires_grad

    def f(x):
        return x * 2

    chk = Checkpointing(f, x)
    x = chk.checkpoint()
    assert x[0].requires_grad

    chk.recompute(x)
    assert x[0].requires_grad

    x.tensor.backward()


def test_not_requires_grad_with_parameter():
    """Gradients reach a closed-over tensor even when the checkpoint input has none."""
    x = torch.rand(1, requires_grad=False)
    a = torch.rand(1, requires_grad=True)

    def f(x):
        # ``a`` is captured from the enclosing scope rather than passed as input.
        return x * a

    y = make_checkpoint(f, x, 0)
    y.backward()

    assert a.grad is not None


@pytest.mark.parametrize("device", devices)
def test_random_in_checkpoint(device):
    """Dropout inside a checkpoint yields the same gradient as dropout without one.

    Both runs start from the same seed. Matching gradients are only expected if
    the checkpointed recomputation applies the same random dropout mask as the
    original forward pass.
    """
    dropout = nn.Dropout(p=0.5)

    # Reference run: plain dropout, no checkpointing.
    torch.manual_seed(0)
    x = torch.randn(3, 3, device=device, requires_grad=True)
    y = dropout(x)
    y.norm().backward()

    # Same seed, same input, but the dropout runs inside a checkpoint.
    torch.manual_seed(0)
    chk_x = torch.randn(3, 3, device=device, requires_grad=True)
    chk_y = make_checkpoint(dropout, chk_x, 0)
    chk_y.norm().backward()

    assert torch.allclose(x.grad, chk_x.grad)


def test_detect_checkpointing_recomputing():
    """``is_checkpointing``/``is_recomputing`` report which phase a module runs in."""
    logs = []

    class Detect(nn.Module):
        def forward(self, input):
            logs.append((is_checkpointing(), is_recomputing()))
            return input

    model = Detect()
    x = torch.rand(1, requires_grad=True)

    output = make_checkpoint(model, x, 0)
    output.backward()

    # First call: the checkpointed forward pass.
    # Second call: the recomputation triggered by backward.
    assert logs == [(True, False), (False, True)]


def test_detect_checkpointing_recomputing_without_checkpoint():
    """Outside any checkpoint, neither ``is_checkpointing`` nor ``is_recomputing`` is set."""
    calls = []

    class Detect(nn.Module):
        def forward(self, input):
            calls.append((is_checkpointing(), is_recomputing()))
            return input

    detector = Detect()
    sample = torch.rand(1, requires_grad=True)

    # A plain forward + backward with no checkpointing involved.
    detector(sample).backward()

    assert calls == [(False, False)]


def test_non_grad_output():
    """Backward works when a checkpointed module also returns an input-independent tensor."""

    class ForkNonGrad(nn.Module):
        def forward(self, input):
            # The second output, ``torch.rand(1)``, does not depend on ``input``.
            return (input * 2, torch.rand(1))

    model = ForkNonGrad()
    x = torch.rand(1, requires_grad=True)

    output = make_checkpoint(model, x, 0)
    output[0].backward()