# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

from copy import deepcopy
import functools

import pytest
import torch

try:
    from fairscale.optim.adam import Adam, Precision

    imported_adam = True
except ImportError:
    imported_adam = False

skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
skip_if_no_adam = pytest.mark.skipif(not imported_adam, reason="Fairscale Adam not available")


def make_full_precision_params():
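    """Build a small fp32 (weight, bias, input) triple on the default CUDA device."""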
    weight = torch.randn(2, 1).cuda().requires_grad_()
    bias = torch.randn(2).cuda().requires_grad_()
    input = torch.randn(1).cuda()

    return weight, bias, input


def make_half_precision_params():
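    """Build a small fp16 (weight, bias, input) triple on the default CUDA device."""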
    weight = torch.randn(2, 1).cuda().half().requires_grad_()
    bias = torch.randn(2).cuda().half().requires_grad_()
    input = torch.randn(1).half().cuda()

    return weight, bias, input


def step_test(optimizer, weight, bias, input):
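    """Take a few optimizer steps on a tiny least-squares objective and check that the loss decreases."""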
    # check that the optimizer can be printed as a string (repr should not raise)
    optimizer.__repr__()

    def fn():
        optimizer.zero_grad()
        y = weight.mv(input)
        if y.is_cuda and bias.is_cuda and y.get_device() != bias.get_device():
            y = y.cuda(bias.get_device())
        loss = (y + bias).pow(2).sum()
        loss.backward()
        return loss

    initial_value = fn().item()
    for _i in range(5):
        optimizer.step(fn)
    assert fn().item() < initial_value


def state_dict_test(optimizer, weight, bias, input):
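    """Prime the optimizer, load its state dict into a fresh optimizer over cloned
    parameters, then step both in lockstep and check the parameters stay close.
    """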
    def fn_base(optimizer, weight, bias, input):
        optimizer.zero_grad()
        loss = (weight.mv(input) + bias).pow(2).sum()
        loss.backward()
        return loss

    fn = functools.partial(fn_base, optimizer, weight, bias, input)

    # Prime the optimizer
    for _i in range(5):
        optimizer.step(fn)
    # Clone the weights and construct new optimizer for them
    weight_c = weight.data.clone().requires_grad_()
    bias_c = bias.data.clone().requires_grad_()
    optimizer_c = Adam([weight_c, bias_c], lr=1e-3, precision=optimizer.precision)
    fn_c = functools.partial(fn_base, optimizer_c, weight_c, bias_c, input)
    # Load state dict
    state_dict = deepcopy(optimizer.state_dict())
    optimizer_c.load_state_dict(state_dict)
    # Run both optimizations in parallel
    for _i in range(5):
        optimizer.step(fn)
        optimizer_c.step(fn_c)
        (weight - weight_c).to("cpu").detach().apply_(assert_almost_zero)
        (bias - bias_c).to("cpu").detach().apply_(assert_almost_zero)


def assert_almost_zero(x):
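    """Element-wise check used with Tensor.apply_, which expects a numeric return value."""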
    assert abs(x) < 1e-3
    return 1.0


@skip_if_no_cuda
@skip_if_no_adam
def test_step_full_precision_inferred():
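    # fp32 parameters with no explicit precision: expect fp32 optimizer state
    # and no separate fp32 master parameter copies.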
    weight, bias, input = make_full_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3)

    step_test(optimizer, weight, bias, input)

    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.requires_grad:
                assert p.dtype == torch.float32
    assert not optimizer.fp32_param_groups

    assert optimizer.state[weight]["exp_avg"].dtype == torch.float32
    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float32
    assert optimizer.state[bias]["exp_avg"].dtype == torch.float32
    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float32


@skip_if_no_cuda
@skip_if_no_adam
def test_step_mixed_precision_inferred():
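    # fp16 parameters with no explicit precision: expect mixed precision, i.e. an
    # fp32 master copy for each parameter and fp32 optimizer state.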
    weight, bias, input = make_half_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3)
    step_test(optimizer, weight, bias, input)

    assert len(optimizer.fp32_param_groups) == len(optimizer.param_groups)

    for fp32_group, fp16_group in zip(optimizer.fp32_param_groups, optimizer.param_groups):
        for fp32_p, fp16_p in zip(fp32_group["params"], fp16_group["params"]):
            assert fp32_p.dtype == torch.float32
            if fp16_p.requires_grad:
                assert fp16_p.dtype == torch.float16
                (fp32_p - fp16_p).to("cpu").detach().apply_(assert_almost_zero)

    assert optimizer.state[weight]["exp_avg"].dtype == torch.float32
    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float32
    assert optimizer.state[bias]["exp_avg"].dtype == torch.float32
    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float32


@skip_if_no_cuda
@skip_if_no_adam
def test_step_memory_efficient():
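    # MEMORY_EFFICIENT_MIXED_PRECISION: fp16 parameters and no fp32 master copies,
    # but optimizer state is still kept in fp32.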
    weight, bias, input = make_half_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)
    step_test(optimizer, weight, bias, input)

    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.requires_grad:
                assert p.dtype == torch.float16

    assert not optimizer.fp32_param_groups

    assert optimizer.state[weight]["exp_avg"].dtype == torch.float32
    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float32
    assert optimizer.state[bias]["exp_avg"].dtype == torch.float32
    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float32


@skip_if_no_cuda
@skip_if_no_adam
def test_step_pure_fp16():
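    # PURE_FP16: parameters and optimizer state stay in fp16, no fp32 master copies.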
    weight, bias, input = make_half_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)
    step_test(optimizer, weight, bias, input)

    for group in optimizer.param_groups:
        for p in group["params"]:
            if p.requires_grad:
                assert p.dtype == torch.float16

    assert optimizer.state[weight]["exp_avg"].dtype == torch.float16
    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float16
    assert optimizer.state[bias]["exp_avg"].dtype == torch.float16
    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float16

    assert not optimizer.fp32_param_groups


@skip_if_no_cuda
@skip_if_no_adam
def test_step_multigpu():
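    # Parameters placed on two different GPUs; requires at least two CUDA devices.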
    if not torch.cuda.device_count() > 1:
        return
    weight = torch.randn(10, 5).cuda(0).requires_grad_()
    bias = torch.randn(10).cuda(1).requires_grad_()
    input = torch.randn(5).cuda(0)
    optimizer = Adam([weight, bias], lr=1e-3)

    step_test(optimizer, weight, bias, input)


@skip_if_no_cuda
@skip_if_no_adam
def test_step_multigpu_mixed_precision():
    if not torch.cuda.device_count() > 1:
        return
    weight = torch.randn(10, 5).cuda(0).half().requires_grad_()
    bias = torch.randn(10).cuda(1).half().requires_grad_()
    input = torch.randn(5).cuda(0).half()
    optimizer = Adam([weight, bias], lr=1e-3)

    step_test(optimizer, weight, bias, input)


@skip_if_no_cuda
@skip_if_no_adam
def test_step_pure_fp16_multigpu():
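    # PURE_FP16 with parameters split across two GPUs; optimizer state stays fp16.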
    if not torch.cuda.device_count() > 1:
        return
    weight = torch.randn(10, 5).half().cuda(0).requires_grad_()
    bias = torch.randn(10).half().cuda(1).requires_grad_()
    input = torch.randn(5).half().cuda(0)
    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)

    step_test(optimizer, weight, bias, input)

    assert optimizer.state[weight]["exp_avg"].dtype == torch.float16
    assert optimizer.state[weight]["exp_avg_sq"].dtype == torch.float16
    assert optimizer.state[bias]["exp_avg"].dtype == torch.float16
    assert optimizer.state[bias]["exp_avg_sq"].dtype == torch.float16


@skip_if_no_cuda
@skip_if_no_adam
def test_state_dict_full_precision():
    weight, bias, input = make_full_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3)

    state_dict_test(optimizer, weight, bias, input)


@skip_if_no_cuda
@skip_if_no_adam
def test_state_dict_mixed_precision():
    weight, bias, input = make_half_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MIXED_PRECISION)

    state_dict_test(optimizer, weight, bias, input)


@skip_if_no_cuda
@skip_if_no_adam
def test_state_dict_memory_efficient():
    weight, bias, input = make_half_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)

    state_dict_test(optimizer, weight, bias, input)


@skip_if_no_cuda
@skip_if_no_adam
def test_state_dict_pure_fp16():
    weight, bias, input = make_half_precision_params()
    optimizer = Adam([weight, bias], lr=1e-3, precision=Precision.PURE_FP16)

    state_dict_test(optimizer, weight, bias, input)


@skip_if_no_cuda
@skip_if_no_adam
def test_build_fp32_params():
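    # _build_fp32_params should create fp32 copies that closely match the fp16 parameters.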
    weight = torch.randn(10, 5).cuda().half().requires_grad_()
    bias = torch.randn(10).cuda().half().requires_grad_()
    optimizer = Adam([weight, bias], lr=1e-3)
    optimizer._build_fp32_params([weight, bias])
    for fp32_group, fp16_group in zip(optimizer.fp32_param_groups, optimizer.param_groups):
        for fp32_p, fp16_p in zip(fp32_group["params"], fp16_group["params"]):
            assert fp32_p.dtype == torch.float32
            if fp16_p.requires_grad:
                assert fp16_p.dtype == torch.float16
                (fp32_p - fp16_p).to("cpu").detach().apply_(assert_almost_zero)


@skip_if_no_cuda
@skip_if_no_adam
def test_invalid_beta():
    weight = torch.randn(10, 5, requires_grad=True).float().cuda()
    bias = torch.randn(10, requires_grad=True).float().cuda()
    with pytest.raises(ValueError):
        Adam([weight, bias], lr=1e-2, betas=(1.0, 0.0))


@skip_if_no_cuda
@skip_if_no_adam
def test_invalid_weight_decay():
    weight = torch.randn(10, 5, requires_grad=True).float().cuda()
    bias = torch.randn(10, requires_grad=True).float().cuda()
    with pytest.raises(ValueError):
        Adam([weight, bias], lr=1e-2, weight_decay=-1)


@skip_if_no_cuda
@skip_if_no_adam
def test_amsgrad():
    weight = torch.randn(10, 5, requires_grad=True).float().cuda()
    bias = torch.randn(10, requires_grad=True).float().cuda()
    with pytest.raises(RuntimeError):
        Adam([weight, bias], lr=1e-2, amsgrad=True)


@skip_if_no_cuda
@skip_if_no_adam
def test_mixed_precision_with_full_precision_parameters():
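    # Constructing a half-precision-mode optimizer over fp32 parameters should fail.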
    weight = torch.randn(10, 5, requires_grad=True).float().cuda()
    bias = torch.randn(10, requires_grad=True).float().cuda()
    with pytest.raises(AssertionError):
        Adam([weight, bias], lr=1e-2, precision=Precision.MIXED_PRECISION)


@skip_if_no_cuda
@skip_if_no_adam
def test_memory_efficient_with_full_precision_parameters():
    weight = torch.randn(10, 5, requires_grad=True).float().cuda()
    bias = torch.randn(10, requires_grad=True).float().cuda()
    with pytest.raises(AssertionError):
        Adam([weight, bias], lr=1e-2, precision=Precision.MEMORY_EFFICIENT_MIXED_PRECISION)


@skip_if_no_cuda
@skip_if_no_adam
def test_pure_fp16_with_full_precision_parameters():
    weight = torch.randn(10, 5, requires_grad=True).float().cuda()
    bias = torch.randn(10, requires_grad=True).float().cuda()
    with pytest.raises(AssertionError):
        Adam([weight, bias], lr=1e-2, precision=Precision.PURE_FP16)