# TODO: add tests with model parallelism for activation partitioning and other features.

from copy import deepcopy

import pytest

import torch

import deepspeed
# Short alias for DeepSpeed's activation-checkpointing entry point, used by _compute().
ckpt = deepspeed.checkpointing.checkpoint

from common import distributed_test


def _compute(module, *inputs, do_checkpoint=False):
    """Run ``module`` forward and backward and collect the results.

    Args:
        module: the ``torch.nn.Module`` to run.
        *inputs: positional arguments passed to ``module`` (tensors or not).
        do_checkpoint: if True, run the forward pass through ``ckpt``
            (``deepspeed.checkpointing.checkpoint``) instead of calling
            ``module`` directly.

    Returns:
        dict with keys ``'outputs'`` (tuple of forward outputs),
        ``'module_grads'`` (``p.grad`` for every parameter of ``module``) and
        ``'input_grads'`` (``.grad`` of every top-level tensor in ``inputs``).

    Note:
        If no output is a tensor requiring grad, ``sum`` of nothing is the int
        ``0`` and the ``.backward()`` call raises ``AttributeError``.
    """
    if do_checkpoint:
        outputs = ckpt(module, *inputs)
    else:
        outputs = module(*inputs)

    # Normalize a lone tensor to a 1-tuple so the code below can always iterate.
    if torch.is_tensor(outputs):
        outputs = (outputs, )

    # Backprop one scalar built from every differentiable output; non-tensor
    # outputs and tensors that do not require grad are skipped.
    sum(o.sum() for o in outputs if torch.is_tensor(o) and o.requires_grad).backward()

    grads = [p.grad for p in module.parameters()]
    # PyTorch only populates ``.grad`` on leaf tensors, so callers must pass
    # leaf inputs for these to be meaningful.
    input_grads = [inp.grad for inp in inputs if torch.is_tensor(inp)]

    return {
        'outputs': outputs,
        'module_grads': grads,
        'input_grads': input_grads,
    }


def _prep_inputs(*inputs, device='cuda'):
    """Return deep copies of ``inputs`` with top-level tensors moved to ``device``.

    Tensors that required grad are re-created as *leaf* tensors on the target
    device. ``Tensor.to`` / ``Tensor.cuda`` is itself a differentiable op, so
    without this the moved tensor would be a non-leaf. Its ``.grad`` would stay
    ``None`` after ``backward()``, making any input-gradient comparison vacuous.

    Non-tensor inputs, and tensors nested inside containers, are deep-copied
    but otherwise left untouched.

    Args:
        *inputs: positional inputs to prepare.
        device: target device for top-level tensors (default ``'cuda'``).

    Returns:
        tuple of prepared inputs, in the original order.
    """
    prepared = []
    for inp in inputs:
        inp = deepcopy(inp)
        if torch.is_tensor(inp):
            requires_grad = inp.requires_grad
            inp = inp.to(device).detach().requires_grad_(requires_grad)
        prepared.append(inp)
    return tuple(prepared)


def _match_outputs(ref, tgt):
    """Recursively assert that ``tgt`` matches ``ref``.

    Both values must have exactly the same type. Lists and tuples must also
    have the same length and are compared element-wise. Floating-point tensors
    must have the same shape and be ``torch.allclose``. Other tensors must be
    ``torch.equal``, and non-tensors must compare ``==``.

    Raises:
        AssertionError: on the first mismatch found.
    """
    assert type(ref) == type(tgt)
    if type(ref) in [list, tuple]:
        # zip() alone would silently drop trailing items of the longer side.
        assert len(ref) == len(tgt)
        for x, y in zip(ref, tgt):
            _match_outputs(x, y)
    elif not torch.is_tensor(ref):
        assert ref == tgt
    elif ref.is_floating_point():
        # allclose broadcasts its operands, so check the shape explicitly.
        assert ref.shape == tgt.shape
        assert torch.allclose(ref, tgt)
    else:
        assert torch.equal(ref, tgt)


# This is distributed because checkpoint() assumes that torch.distributed is initialized.
# torch.distributed is used with activation partitioning, but not for these simple cases.
@distributed_test(world_size=1)
def _test_activation_checkpoint(module, *inputs):
    """Assert that checkpointed and plain execution of ``module`` agree.

    Moves ``module`` to CUDA and switches it to eval mode. It then runs the
    module twice, each time on a fresh deep copy of the module and of
    ``inputs``: once directly and once through ``ckpt``. Outputs, parameter
    gradients and input gradients are compared group by group with
    ``_match_outputs``.
    """
    # Move to device
    module.cuda()

    # Get rid of dropouts until we fork the RNG between tests.
    module.eval()

    module_ = deepcopy(module)
    inputs_ = _prep_inputs(*inputs)
    base = _compute(module_, *inputs_, do_checkpoint=False)

    module_ = deepcopy(module)
    inputs_ = _prep_inputs(*inputs)
    test = _compute(module_, *inputs_, do_checkpoint=True)

    for group in base:
        # zip() below would silently ignore extra entries on either side.
        assert len(base[group]) == len(test[group]), group
        for b, t in zip(base[group], test[group]):
            _match_outputs(b, t)


# This is distributed because checkpoint() assumes that torch.distributed is initialized.
# torch.distributed is used with activation partitioning, but not for these simple cases.
@distributed_test(world_size=1)
def _test_activation_checkpoint_ordering(module, expected_ordering, *inputs):
    """Assert which checkpointed outputs are tensors, position by position.

    Runs a deep copy of ``module`` through ``ckpt`` and flattens the outputs
    one level, expanding list/tuple entries. Each flattened entry becomes True
    if it is a tensor and False otherwise. The resulting list must equal
    ``expected_ordering``.
    """
    # Move to device
    module.cuda()

    # Get rid of dropouts until we fork the RNG between tests.
    module.eval()

    module_ = deepcopy(module)
    inputs_ = _prep_inputs(*inputs)
    test = _compute(module_, *inputs_, do_checkpoint=True)

    outputs = test['outputs']
    test_ordering = []
    for item in outputs:
        if type(item) in [list, tuple]:
            test_ordering += [torch.is_tensor(t) for t in item]
        else:
            test_ordering += [torch.is_tensor(item)]

    assert expected_ordering == test_ordering, (expected_ordering, test_ordering)


#
# Helpers
#


class MaskedLinear(torch.nn.Linear):
    """Linear layer whose output is multiplied element-wise by ``mask``."""
    def forward(self, x, mask):
        projected = super().forward(x)
        # A BoolTensor must be cast before multiplying in older torch versions.
        scale = mask if mask.is_floating_point() else mask.type_as(projected)
        return projected * scale


class MaskedLinearSeq(MaskedLinear):
    """MaskedLinear that also returns its mask, as a pipeline-style stage would."""
    def forward(self, x, mask):
        masked_out = super().forward(x, mask)
        return (masked_out, mask)


class MaskedLinearSeqDup(MaskedLinearSeq):
    """MaskedLinearSeq with more outputs than inputs, returned in a different order."""
    def forward(self, x, mask):
        # Detached copy of the input, scaled by an arbitrary constant.
        scaled_copy = x.clone().detach() * 1.38
        masked, returned_mask = super().forward(x, mask)
        return scaled_copy, masked, returned_mask


# Width of every test layer and input vector in this file.
HIDDEN_DIM = 20


def _mixed_mask(size=HIDDEN_DIM):
    """Return a random BoolTensor of length ``size`` (True where randn > 0)."""
    return torch.randn(size) > 0


def _bool_to_float(btensor, dtype=torch.float32):
    """Convert a torch.BoolTensor into a ``dtype`` tensor of ones and zeros."""
    shape = btensor.size()
    return torch.where(btensor,
                       torch.ones(size=shape, dtype=dtype),
                       torch.zeros(size=shape, dtype=dtype))


#
# Tests
#


def test_ckpt_inputs1_outputs1():
    """One tensor in, one tensor out: checkpointed and plain runs must agree."""
    layer = torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM)
    x = torch.rand(HIDDEN_DIM)
    x.requires_grad_(True)
    _test_activation_checkpoint(layer, x)


# both bool and float masks are important, as bool is not differentiable
@pytest.mark.parametrize('mask', [_mixed_mask(), _bool_to_float(_mixed_mask())])
def test_ckpt_inputs2_outputs1(mask):
    """Tensor + mask in, one tensor out; the mask is either bool or float."""
    layer = MaskedLinear(HIDDEN_DIM, HIDDEN_DIM)
    x = torch.rand(HIDDEN_DIM)
    x.requires_grad_(True)
    _test_activation_checkpoint(layer, x, mask)


@pytest.mark.parametrize('mask', [_mixed_mask(), _bool_to_float(_mixed_mask())])
def test_ckpt_inputs2_outputs2(mask):
    """Tensor + mask in, (output, mask) out; the mask is either bool or float."""
    layer = MaskedLinearSeq(HIDDEN_DIM, HIDDEN_DIM)
    x = torch.rand(HIDDEN_DIM)
    x.requires_grad_(True)
    _test_activation_checkpoint(layer, x, mask)


@pytest.mark.parametrize('mask',
                         [
                             _mixed_mask(),
                             _bool_to_float(_mixed_mask()),
                         ])
def test_ckpt_inputs2_outputs3(mask):
    """Tensor + mask in, three outputs out (more outputs than inputs, reordered)."""
    module = MaskedLinearSeqDup(HIDDEN_DIM, HIDDEN_DIM)
    inputs = torch.rand(HIDDEN_DIM)
    inputs.requires_grad = True
    _test_activation_checkpoint(module, inputs, mask)


class DropMaskLinear(torch.nn.Linear):
    """Linear layer that accepts a ``mask`` argument but ignores it."""
    def forward(self, x, mask):
        del mask  # accepted only to exercise argument passing; deliberately unused
        return super().forward(x)


def test_ckpt_arg_none():
    """A ``None`` positional argument: checkpointed and plain runs must agree."""
    module = DropMaskLinear(HIDDEN_DIM, HIDDEN_DIM)
    inputs = (torch.rand(HIDDEN_DIM), None)
    inputs[0].requires_grad = True
    _test_activation_checkpoint(module, *inputs)


class LinearNonTensorInput(torch.nn.Linear):
    """Linear layer taking an extra (possibly non-tensor) argument it ignores."""
    def forward(self, x, non_tensor_input):
        del non_tensor_input  # only present to be passed through checkpoint()
        return super().forward(x)


@pytest.mark.parametrize('non_tensor_input', [
    None,
    2,
    True,
    (None, 2.5),
    (None, True, torch.randn(HIDDEN_DIM)),
])
def test_ckpt_non_tensor_input(non_tensor_input):
    """Extra non-tensor argument: checkpointed and plain runs must agree."""
    layer = LinearNonTensorInput(HIDDEN_DIM, HIDDEN_DIM)
    x = torch.rand(HIDDEN_DIM)
    x.requires_grad_(True)
    _test_activation_checkpoint(layer, x, non_tensor_input)


class LinearNonTensorOutput(torch.nn.Linear):
    """HIDDEN_DIM -> HIDDEN_DIM linear layer that also returns a fixed extra value."""
    def __init__(self, non_tensor_output):
        super().__init__(in_features=HIDDEN_DIM, out_features=HIDDEN_DIM)
        # Returned verbatim as the second output of every forward() call.
        self.non_tensor_output = non_tensor_output

    def forward(self, x):
        return super().forward(x), self.non_tensor_output


@pytest.mark.parametrize('non_tensor_output', [
    None,
    2,
    True,
    (None, 2.5),
    (None, True, torch.randn(HIDDEN_DIM)),
])
def test_ckpt_non_tensor_output(non_tensor_output):
    """Extra non-tensor output: checkpointed and plain runs must agree."""
    layer = LinearNonTensorOutput(non_tensor_output)
    x = torch.rand(HIDDEN_DIM)
    x.requires_grad_(True)
    _test_activation_checkpoint(layer, x)


@pytest.mark.parametrize('non_tensor_output', [
    None,
    (torch.randn(HIDDEN_DIM), 2.5),
    (None, torch.randn(HIDDEN_DIM), True),
    (None, True, torch.randn(HIDDEN_DIM)),
])
def test_ckpt_non_tensor_output_ordering(non_tensor_output):
    """Checkpointed outputs keep tensors and non-tensors in their original slots."""
    layer = LinearNonTensorOutput(non_tensor_output)
    x = torch.rand(HIDDEN_DIM)
    x.requires_grad_(True)

    # The linear projection comes first and is always a tensor; the extra
    # value is expanded one level when it is a list/tuple.
    if type(non_tensor_output) in (list, tuple):
        extras = non_tensor_output
    else:
        extras = (non_tensor_output, )
    expected = [True] + [torch.is_tensor(e) for e in extras]
    _test_activation_checkpoint_ordering(layer, expected, x)