# test_adam.py
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from copy import deepcopy
import functools

import pytest
import torch

try:
    from fairscale.optim.adam import Adam, Precision

    imported_adam = True
except ImportError:
    imported_adam = False

skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
skip_if_no_adam = pytest.mark.skipif(not imported_adam, reason="Fairscale Adam not available")


@pytest.fixture(autouse=True)
def set_torch_seed():
    torch.manual_seed(1)
    yield


def make_full_precision_params():
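    # Tiny FP32 model on the GPU: a 2x1 weight, a 2-element bias, and a 1-element input.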
    weight = torch.randn(2, 1).cuda().requires_grad_()
    bias = torch.randn(2).cuda().requires_grad_()
    input = torch.randn(1).cuda()

    return weight, bias, input


def make_half_precision_params():
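    # Same tiny model, but with FP16 parameters and an FP16 input.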
    weight = torch.randn(2, 1).cuda().half().requires_grad_()
    bias = torch.randn(2).cuda().half().requires_grad_()
    input = torch.randn(1).half().cuda()

    return weight, bias, input


def step_test(optimizer, weight, bias, input):
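    # Run the loss closure through optimizer.step five times and check that the loss decreases.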
    # Check that the optimizer can be rendered as a string (exercises __repr__).
    optimizer.__repr__()

    def fn():
        optimizer.zero_grad()
        y = weight.mv(input)
        if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
            y = y.cuda(bias.get_device())
        loss = (y + bias).pow(2).sum()
        loss.backward()
        return loss

    initial_value = fn().item()
    for _i in range(5):
        optimizer.step(fn)
    assert fn().item() < initial_value


def state_dict_test(optimizer, weight, bias, input):
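    # Prime an optimizer, build a second optimizer from clones of the parameters, load the
    # first optimizer's state_dict into it, then step both in lockstep and check that the
    # parameters stay identical.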
    def fn_base(optimizer, weight, bias, input):
        optimizer.zero_grad()
        loss = (weight.mv(input) + bias).pow(2).sum()
        loss.backward()
        return loss

    fn = functools.partial(fn_base, optimizer, weight, bias, input)

    # Prime the optimizer
    for _i in range(5):
        optimizer.step(fn)
    # Clone the weights and construct new optimizer for them
    weight_c = weight.data.clone().requires_grad_()
    bias_c = bias.data.clone().requires_grad_()
    optimizer_c = Adam([weight_c, bias_c], lr=1e-3, precision=optimizer.precision)
    fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c, input)
    # Load state dict
    state_dict = deepcopy(optimizer.state_dict())
    optimizer_c.load_state_dict(state_dict)

    for group, group_c in zip(optimizer.param_groups, optimizer_c.param_groups):
        for p, p_c in zip(group["params"], group_c["params"]):
            assert torch.equal(optimizer.state[p]["exp_avg"], optimizer_c.state[p_c]["exp_avg"])
            assert torch.equal(optimizer.state[p]["exp_avg_sq"], optimizer_c.state[p_c]["exp_avg_sq"])

    if optimizer.fp32_param_groups:
        # When using mixed precision, fp32_param_groups are built from the FP16 params rather
        # than copied via state_dict, which introduces small differences between the original
        # optimizer and the copy. Because this test requires them to be exactly the same, copy
        # the FP32 params from the original optimizer into the copy.
        optimizer_c.fp32_param_groups = deepcopy(optimizer.fp32_param_groups)

    # Run both optimizations in parallel
    for _i in range(5):
        optimizer.step(fn)
        optimizer_c.step(fn_c)

        assert torch.equal(weight, weight_c)
        assert torch.equal(bias, bias_c)


def assert_almost_zero(x):
    assert abs(x) < 1e-3
    return 1.0  # Tensor.apply_ replaces each element with the callable's return value


@skip_if_no_cuda
@skip_if_no_adam
def test_step_full_precision_inferred():
    weight, bias, input = make_full_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3)

    step_test(optimizer, weight, bias, input)

    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.requires_grad:
                assert p.dtype == torch.float32
    assert not optimizer.fp32_param_groups

    assert optimizer.state[weight]["exp_avg"].dtype == torch.float32
    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float32
    assert optimizer.state[bias]["exp_avg"].dtype == torch.float32
    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float32


@skip_if_no_cuda
@skip_if_no_adam
def test_step_mixed_precision_inferred():
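    # Half-precision parameters with no explicit precision argument are expected to select
    # mixed precision: FP32 master weights (fp32_param_groups) plus FP32 optimizer state.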
    weight, bias, input = make_half_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3)
    step_test(optimizer, weight, bias, input)

    assert len(optimizer.fp32_param_groups) == len(optimizer.param_groups)

    for fp32_group, fp16_group in zip(optimizer.fp32_param_groups, optimizer.param_groups):
        for fp32_p, fp16_p in zip(fp32_group["params"], fp16_group["params"]):
            assert fp32_p.dtype == torch.float32
            if fp16_p.requires_grad:
                assert fp16_p.dtype == torch.float16
                (fp32_p - fp16_p).to("cpu").detach().apply_(assert_almost_zero)

    assert optimizer.state[weight]["exp_avg"].dtype == torch.float32
    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float32
    assert optimizer.state[bias]["exp_avg"].dtype == torch.float32
    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float32


@skip_if_no_cuda
@skip_if_no_adam
def test_step_memory_efficient():
    weight, bias, input = make_half_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)
    step_test(optimizer, weight, bias, input)

    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.requires_grad:
                assert p.dtype == torch.float16

    assert not optimizer.fp32_param_groups

    assert optimizer.state[weight]["exp_avg"].dtype == torch.float32
    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float32
    assert optimizer.state[bias]["exp_avg"].dtype == torch.float32
    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float32


@skip_if_no_cuda
@skip_if_no_adam
def test_step_pure_fp16():
    weight, bias, input = make_half_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
    step_test(optimizer, weight, bias, input)

    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.requires_grad:
                assert p.dtype == torch.float16

    assert optimizer.state[weight]["exp_avg"].dtype == torch.float16
    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float16
    assert optimizer.state[bias]["exp_avg"].dtype == torch.float16
    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float16

    assert not optimizer.fp32_param_groups


@skip_if_no_cuda
@skip_if_no_adam
def test_step_multigpu():
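    # Needs at least two GPUs: weight and bias are placed on different devices so that
    # step_test exercises its cross-device branch.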
    if torch.cuda.device_count() < 2:
        return
    weight = torch.randn(10, 5).cuda(0).requires_grad_()
    bias = torch.randn(10).cuda(1).requires_grad_()
    input = torch.randn(5).cuda(0)
    optimizer = Adam([weight, bias], lr=1e-3)

    step_test(optimizer, weight, bias, input)


@skip_if_no_cuda
@skip_if_no_adam
def test_step_multigpu_mixed_precision():
    if torch.cuda.device_count() < 2:
        return
    weight = torch.randn(10, 5).cuda(0).half().requires_grad_()
    bias = torch.randn(10).cuda(1).half().requires_grad_()
    input = torch.randn(5).cuda(0).half()
    optimizer = Adam([weight, bias], lr=1e-3)

    step_test(optimizer, weight, bias, input)


@skip_if_no_cuda
@skip_if_no_adam
def test_step_pure_fp16_multigpu():
    if torch.cuda.device_count() < 2:
        return
    weight = torch.randn(10, 5).half().cuda(0).requires_grad_()
    bias = torch.randn(10).half().cuda(1).requires_grad_()
    input = torch.randn(5).half().cuda(0)
    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)

    step_test(optimizer, weight, bias, input)

    assert optimizer.state[weight]["exp_avg"].dtype == torch.float16
    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float16
    assert optimizer.state[bias]["exp_avg"].dtype == torch.float16
    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float16


@skip_if_no_cuda
@skip_if_no_adam
def test_state_dict_full_precision():
    weight, bias, input = make_full_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3)

    state_dict_test(optimizer, weight, bias, input)


@skip_if_no_cuda
@skip_if_no_adam
@pytest.mark.xfail
def test_state_dict_mixed_precision():
    # TODO: Optimizer state gets cast to FP16 and back to FP32 for
    # mixed-precision and memory-efficient mixed-precision, resulting
    # in a potential loss of precision. Thus, as training proceeds, we don't
    # necessarily expect the parameters to remain the exact same.
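    # (Illustration of the round-trip loss: torch.tensor(0.1).half().float() != torch.tensor(0.1).)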
    weight, bias, input = make_half_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MIXED_PRECISION)

    state_dict_test(optimizer, weight, bias, input)


@skip_if_no_cuda
@skip_if_no_adam
@pytest.mark.xfail
def test_state_dict_memory_efficient():
    # TODO: Optimizer state gets cast to FP16 and back to FP32 for
    # mixed-precision and memory-efficient mixed-precision, resulting
    # in a potential loss of precision. Thus, as training proceeds, we don't
    # necessarily expect the parameters to remain the exact same.
    weight, bias, input = make_half_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)

    state_dict_test(optimizer, weight, bias, input)


@skip_if_no_cuda
@skip_if_no_adam
def test_state_dict_pure_fp16():
    weight, bias, input = make_half_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)

    state_dict_test(optimizer, weight, bias, input)


@skip_if_no_cuda
@skip_if_no_adam
def test_build_fp32_params():
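    # _build_fp32_params should produce FP32 master copies that match the FP16 parameters
    # to within FP16 rounding error (checked element-wise via assert_almost_zero).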
    weight = torch.randn(10, 5).cuda().half().requires_grad_()
    bias = torch.randn(10).cuda().half().requires_grad_()
    optimizer = Adam([weight, bias], lr=1e-3)
    optimizer._build_fp32_params([weight, bias])
    for fp32_group, fp16_group in zip(optimizer.fp32_param_groups, optimizer.param_groups):
        for fp32_p, fp16_p in zip(fp32_group["params"], fp16_group["params"]):
            assert fp32_p.dtype == torch.float32
            if fp16_p.requires_grad:
                assert fp16_p.dtype == torch.float16
                (fp32_p - fp16_p).to("cpu").detach().apply_(assert_almost_zero)


@skip_if_no_cuda
@skip_if_no_adam
def test_invalid_beta():
    weight = torch.randn(10, 5, requires_grad=True).float().cuda()
    bias = torch.randn(10, requires_grad=True).float().cuda()
    with pytest.raises(ValueError):
        Adam([weight, bias], lr=1e-2, betas=(1.0, 0.0))


@skip_if_no_cuda
@skip_if_no_adam
def test_invalid_weight_decay():
    weight = torch.randn(10, 5, requires_grad=True).float().cuda()
    bias = torch.randn(10, requires_grad=True).float().cuda()
    with pytest.raises(ValueError):
        Adam([weight, bias], lr=1e-2, weight_decay=-1)


@skip_if_no_cuda
@skip_if_no_adam
def test_amsgrad():
    weight = torch.randn(10, 5, requires_grad=True).float().cuda()
    bias = torch.randn(10, requires_grad=True).float().cuda()
    with pytest.raises(RuntimeError):
        Adam([weight, bias], lr=1e-2, amsgrad=True)


@skip_if_no_cuda
@skip_if_no_adam
def test_mixed_precision_with_full_precision_parameters():
    weight = torch.randn(10, 5, requires_grad=True).float().cuda()
    bias = torch.randn(10, requires_grad=True).float().cuda()
    with pytest.raises(AssertionError):
        Adam([weight, bias], lr=1e-2, precision=Precision.MIXED_PRECISION)


@skip_if_no_cuda
@skip_if_no_adam
def test_memory_efficient_with_full_precision_parameters():
    weight = torch.randn(10, 5, requires_grad=True).float().cuda()
    bias = torch.randn(10, requires_grad=True).float().cuda()
    with pytest.raises(AssertionError):
        Adam([weight, bias], lr=1e-2, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)


@skip_if_no_cuda
@skip_if_no_adam
def test_pure_fp16_with_full_precision_parameters():
    weight = torch.randn(10, 5, requires_grad=True).float().cuda()
    bias = torch.randn(10, requires_grad=True).float().cuda()
    with pytest.raises(AssertionError):
        Adam([weight, bias], lr=1e-2, precision=Precision.PURE_FP16)