# test_activation_checkpointing.py
# TODO: add tests with model parallelism for activation partitioning and other features.

from copy import deepcopy

import pytest

import torch

import deepspeed
ckpt = deepspeed.checkpointing.checkpoint

from common import distributed_test


def _compute(module, *inputs, do_checkpoint=False):
    if do_checkpoint:
        outputs = ckpt(module, *inputs)
    else:
        outputs = module(*inputs)

    if torch.is_tensor(outputs):
        outputs = (outputs, )

    sum(o.sum() for o in outputs if o.requires_grad).backward()
    grads = [p.grad for p in module.parameters()]
26
    input_grads = [inp.grad for inp in inputs if torch.is_tensor(inp)]
27
28
29
30
31
32
33
34

    return {
        'outputs': outputs,
        'module_grads': grads,
        'input_grads': input_grads,
    }


35
36
37
38
39
40
41
42
43
44
45
46
def _prep_inputs(*inputs):
    _inputs = []

    for inp in inputs:
        inp = deepcopy(inp)
        if torch.is_tensor(inp):
            inp = inp.cuda()
        _inputs.append(inp)

    return tuple(_inputs)


# This is distributed because checkpoint() assumes that torch.distributed is initialized.
# torch.distributed is used with activation partitioning, but not for these simple cases.
@distributed_test(world_size=1)
def _test_activation_checkpoint(module, *inputs):
    """Check that checkpointed and plain execution of ``module`` agree.

    The module is run twice, each time on a fresh deep copy of itself and of
    ``inputs``: once directly and once through ``ckpt``. Outputs, parameter
    gradients and input gradients from the two runs must match
    (``allclose`` for floating-point tensors, exact equality otherwise).
    """
    # Move to device
    module.cuda()

    # Get rid of dropouts until we fork the RNG between tests.
    module.eval()

    module_ = deepcopy(module)
    inputs_ = _prep_inputs(*inputs)
    base = _compute(module_, *inputs_, do_checkpoint=False)

    module_ = deepcopy(module)
    inputs_ = _prep_inputs(*inputs)
    test = _compute(module_, *inputs_, do_checkpoint=True)

    for group, base_vals in base.items():
        test_vals = test[group]
        # zip() would silently truncate; a checkpointed run that drops
        # outputs or gradients must fail instead of passing vacuously.
        assert len(base_vals) == len(test_vals), group
        for b, t in zip(base_vals, test_vals):
            # Catch grad `None`s, etc.
            if not torch.is_tensor(b):
                assert b == t
            elif b.is_floating_point():
                assert torch.allclose(b, t)
            else:
                assert torch.equal(b, t)


#
# Helpers
#


class MaskedLinear(torch.nn.Linear):
    """Linear layer whose result is multiplied elementwise by ``mask``."""
    def forward(self, x, mask):
        projected = super().forward(x)
        # Older torch versions cannot multiply by a BoolTensor directly, so a
        # non-floating mask is first cast to the output's tensor type.
        scale = mask if mask.is_floating_point() else mask.type_as(projected)
        return projected * scale


class MaskedLinearSeq(MaskedLinear):
    """MaskedLinear that returns ``(output, mask)``, passing its mask through.

    Stands in for a pipeline-style module: the output tuple matches the
    ``(x, mask)`` argument order.
    """
    def forward(self, x, mask):
        masked = super().forward(x, mask)
        return masked, mask


class MaskedLinearSeqDup(MaskedLinearSeq):
    """Like MaskedLinearSeq, but returns three outputs for two inputs, reordered.

    The extra leading output is a scaled copy of ``x`` that does not require grad.
    """
    def forward(self, x, mask):
        # Arbitrary scale factor; the copy is detached from the autograd graph.
        scaled_copy = x.clone().detach() * 1.38
        out, passthrough_mask = super().forward(x, mask)
        return scaled_copy, out, passthrough_mask


HIDDEN_DIM = 20


def _mixed_mask(size=HIDDEN_DIM):
    entries = torch.randn(size)
    mask = torch.where(entries > 0, torch.ones(size), torch.zeros(size))
    mask = mask.bool()
    return mask


def _bool_to_float(btensor, dtype=torch.float32):
    """Converts a torch.BoolTensor to an equivalent dtype. """
    ones = torch.ones(size=btensor.size(), dtype=dtype)
    zeros = torch.zeros(size=btensor.size(), dtype=dtype)
    return torch.where(btensor, ones, zeros)


#
# Tests
#


def test_ckpt_inputs1_outputs1():
    """One tensor in, one tensor out, through a plain Linear layer."""
    layer = torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
    x = torch.rand(HIDDEN_DIM, requires_grad=True)
    _test_activation_checkpoint(layer, x)


# Both bool and float masks are tested, as bool tensors are not differentiable.
@pytest.mark.parametrize('mask',
                         [
                             _mixed_mask(),
                             _bool_to_float(_mixed_mask()),
                         ])
def test_ckpt_inputs2_outputs1(mask):
    """Two inputs (tensor, mask), one output, through MaskedLinear."""
    module = MaskedLinear(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs, mask)


@pytest.mark.parametrize('mask', [_mixed_mask(), _bool_to_float(_mixed_mask())])
def test_ckpt_inputs2_outputs2(mask):
    """Two inputs, two outputs: MaskedLinearSeq also returns its mask."""
    layer = MaskedLinearSeq(HIDDEN_DIM, HIDDEN_DIM)
    x = torch.rand(HIDDEN_DIM, requires_grad=True)
    _test_activation_checkpoint(layer, x, mask)


@pytest.mark.parametrize('mask',
                         [
                             _mixed_mask(),
                             _bool_to_float(_mixed_mask()),
                         ])
def test_ckpt_inputs2_outputs3(mask):
    """Two inputs, three reordered outputs (one without grad), via MaskedLinearSeqDup."""
    module = MaskedLinearSeqDup(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs, mask)


class DropMaskLinear(torch.nn.Linear):
    """Linear layer that accepts a ``mask`` argument but ignores it."""
    def forward(self, x, mask):
        del mask  # present only to match the (x, mask) call signature
        result = super().forward(x)
        return result


def test_ckpt_arg_none():
    """Checkpoint a module call in which one positional argument is None."""
    layer = DropMaskLinear(HIDDEN_DIM, HIDDEN_DIM)
    x = torch.rand(HIDDEN_DIM, requires_grad=True)
    _test_activation_checkpoint(layer, x, None)