test_optimizer.py 10.5 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


5
import random
facebook-github-bot's avatar
facebook-github-bot committed
6
import unittest
Yanghan Wang's avatar
Yanghan Wang committed
7

facebook-github-bot's avatar
facebook-github-bot committed
8
import d2go.runner.default_runner as default_runner
Yanghan Wang's avatar
Yanghan Wang committed
9
import torch
10
11
12
13
14
15
16
from d2go.optimizer import (
    build_optimizer_mapper,
)
from d2go.optimizer.build import (
    expand_optimizer_param_groups,
    regroup_optimizer_param_groups,
)
17
from d2go.utils.testing import helper
Yanghan Wang's avatar
Yanghan Wang committed
18

facebook-github-bot's avatar
facebook-github-bot committed
19
20
21
22

class TestArch(torch.nn.Module):
    """Tiny conv -> BN -> ReLU -> pool -> linear network emitting one logit.

    Submodule attribute names (``conv``, ``bn``, ``linear``, ...) are kept
    stable: name-based LR multiplier configs used with this model in this
    file (e.g. ``{"conv": 0.1}``) refer to them.
    """

    def __init__(self):
        super().__init__()
        # Registration order determines parameter order and the order in
        # which default initializers draw from the RNG.
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=4, kernel_size=5, stride=1, padding=1
        )
        self.bn = torch.nn.BatchNorm2d(num_features=4)
        self.relu = torch.nn.ReLU(inplace=True)
        self.avgpool = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.linear = torch.nn.Linear(in_features=4, out_features=1)

    def forward(self, x):
        pooled = self.avgpool(self.relu(self.bn(self.conv(x))))
        # (N, C, 1, 1) -> (N, 1, 1, C) so the Linear layer maps the channels
        # to a single logit.
        channels_last = pooled.transpose(1, 3)
        return self.linear(channels_last)

Yanghan Wang's avatar
Yanghan Wang committed
38

facebook-github-bot's avatar
facebook-github-bot committed
39
def _test_each_optimizer(cfg, num_train_iters=2500, num_eval_samples=200):
    """Train ``TestArch`` on a toy binary task using the optimizer from ``cfg``.

    Inputs are uniform noise of shape (1, 3, 16, 16); samples with target 1
    are shifted by +2 so the two classes are separable. After training, the
    prediction accuracy on freshly sampled inputs is printed.

    Args:
        cfg: d2go config consumed by ``build_optimizer_mapper``;
            ``cfg.SOLVER.OPTIMIZER`` is printed for logging.
        num_train_iters: number of single-sample optimizer steps.
        num_eval_samples: number of samples used to measure accuracy.

    Returns:
        float: fraction of evaluation samples classified correctly (0.0 when
        ``num_eval_samples`` is 0). Callers may ignore this value.
    """
    print("Solver: " + str(cfg.SOLVER.OPTIMIZER))

    # Seed torch as well (model init and torch.rand) so runs are reproducible;
    # Python's ``random`` (used for the targets) is seeded below.
    torch.manual_seed(20210912)
    model = TestArch()
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = build_optimizer_mapper(cfg, model)

    random.seed(20210912)
    for _ in range(num_train_iters):
        target = torch.empty(1, 1, 1, 1).fill_(random.randint(0, 1))
        x = torch.add(torch.rand(1, 3, 16, 16), 2 * target)
        # Reset gradients every step: backward() accumulates into .grad, so
        # zeroing only once would make each step use the running sum of all
        # previous gradients.
        optimizer.zero_grad()
        loss = criterion(model(x), target)
        loss.backward()
        optimizer.step()

    n_correct = 0
    # No autograd graph is needed for evaluation. The model deliberately stays
    # in training mode (BatchNorm keeps using per-batch statistics), matching
    # the conditions it was trained under.
    with torch.no_grad():
        for _ in range(num_eval_samples):
            target = torch.empty(1, 1, 1, 1).fill_(random.randint(0, 1))
            x = torch.add(torch.rand(1, 3, 16, 16), 2 * target)
            y_pred = torch.round(torch.sigmoid(model(x)))
            if y_pred == target:
                n_correct += 1

    accuracy = n_correct / num_eval_samples if num_eval_samples else 0.0
    print("Correct prediction rate {0}.".format(accuracy))
    return accuracy

facebook-github-bot's avatar
facebook-github-bot committed
66

67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
def _check_param_group(self, group, num_params=None, **kwargs):
    if num_params is not None:
        self.assertEqual(len(group["params"]), num_params)
    for key, val in kwargs.items():
        self.assertEqual(group[key], val)


def get_optimizer_cfg(
    lr,
    weight_decay=None,
    weight_decay_norm=None,
    weight_decay_bias=None,
    lr_mult=None,
):
    """Build a default d2go config with selected ``SOLVER`` fields overridden.

    Any argument left as ``None`` keeps the runner's default value.

    Args:
        lr: value for ``SOLVER.BASE_LR``.
        weight_decay: value for ``SOLVER.WEIGHT_DECAY``.
        weight_decay_norm: value for ``SOLVER.WEIGHT_DECAY_NORM``.
        weight_decay_bias: value for ``SOLVER.WEIGHT_DECAY_BIAS``.
        lr_mult: a single multiplier mapping (e.g. ``{"conv1": 3.0}``); it is
            wrapped in a one-element list and stored as
            ``SOLVER.LR_MULTIPLIER_OVERWRITE``.

    Returns:
        The config from ``Detectron2GoRunner().get_default_cfg()`` with the
        overrides applied.
    """
    cfg = default_runner.Detectron2GoRunner().get_default_cfg()
    solver_overrides = (
        ("BASE_LR", lr),
        ("WEIGHT_DECAY", weight_decay),
        ("WEIGHT_DECAY_NORM", weight_decay_norm),
        ("WEIGHT_DECAY_BIAS", weight_decay_bias),
        ("LR_MULTIPLIER_OVERWRITE", None if lr_mult is None else [lr_mult]),
    )
    for field_name, field_value in solver_overrides:
        if field_value is not None:
            setattr(cfg.SOLVER, field_name, field_value)
    return cfg


Yanghan Wang's avatar
Yanghan Wang committed
96
class TestOptimizer(unittest.TestCase):
    """Tests for optimizer construction and param-group expand/regroup helpers."""

    # Optimizer names exercised by the end-to-end smoke tests below.
    _ALL_OPTIMIZERS = ("SGD", "AdamW", "SGD_MT", "AdamW_MT")

    def test_expand_optimizer_param_groups(self):
        """Overlapping groups expand to one group per param with merged options.

        When several groups set the same key for a param, the later group wins
        (e.g. ``p1`` ends up with ``weight_decay=4.0`` from the last group).
        """
        groups = [
            {
                "params": ["p1", "p2", "p3", "p4"],
                "lr": 1.0,
                "weight_decay": 3.0,
            },
            {
                "params": ["p2", "p3", "p5"],
                "lr": 2.0,
                "momentum": 2.0,
            },
            {
                "params": ["p1"],
                "weight_decay": 4.0,
            },
        ]
        gt_groups = [
            dict(params=["p1"], lr=1.0, weight_decay=4.0),  # noqa
            dict(params=["p2"], lr=2.0, weight_decay=3.0, momentum=2.0),  # noqa
            dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0),  # noqa
            dict(params=["p4"], lr=1.0, weight_decay=3.0),  # noqa
            dict(params=["p5"], lr=2.0, momentum=2.0),  # noqa
        ]
        out = expand_optimizer_param_groups(groups)
        self.assertEqual(out, gt_groups)

    def test_regroup_optimizer_param_groups(self):
        """Single-param groups with identical options are merged back together."""
        expanded_groups = [
            dict(params=["p1"], lr=1.0, weight_decay=4.0),  # noqa
            dict(params=["p2"], lr=2.0, weight_decay=3.0, momentum=2.0),  # noqa
            dict(params=["p3"], lr=2.0, weight_decay=3.0, momentum=2.0),  # noqa
            dict(params=["p4"], lr=1.0, weight_decay=3.0),  # noqa
            dict(params=["p5"], lr=2.0, momentum=2.0),  # noqa
        ]
        gt_groups = [
            {
                "lr": 1.0,
                "weight_decay": 4.0,
                "params": ["p1"],
            },
            {
                "lr": 2.0,
                "weight_decay": 3.0,
                "momentum": 2.0,
                "params": ["p2", "p3"],
            },
            {
                "lr": 1.0,
                "weight_decay": 3.0,
                "params": ["p4"],
            },
            {
                "lr": 2.0,
                "momentum": 2.0,
                "params": ["p5"],
            },
        ]
        out = regroup_optimizer_param_groups(expanded_groups)
        self.assertEqual(out, gt_groups)

    def test_create_optimizer_default(self):
        """With equal weight decays everywhere, all params share one group."""

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 1)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                return self.bn(self.conv(x))

        model = Model()
        cfg = get_optimizer_cfg(
            lr=1.0, weight_decay=1.0, weight_decay_norm=1.0, weight_decay_bias=1.0
        )
        optimizer = build_optimizer_mapper(cfg, model)
        self.assertEqual(len(optimizer.param_groups), 1)
        _check_param_group(
            self, optimizer.param_groups[0], num_params=4, weight_decay=1.0, lr=1.0
        )

    def test_create_optimizer_lr(self):
        """Params matched by an LR multiplier are grouped with the scaled LR."""

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1 = torch.nn.Conv2d(3, 3, 1)
                self.conv2 = torch.nn.Conv2d(3, 3, 1)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                return self.bn(self.conv2(self.conv1(x)))

        model = Model()
        cfg = get_optimizer_cfg(
            lr=1.0,
            lr_mult={"conv1": 3.0, "conv2": 3.0},
            weight_decay=2.0,
            weight_decay_norm=2.0,
        )
        optimizer = build_optimizer_mapper(cfg, model)

        self.assertEqual(len(optimizer.param_groups), 2)

        # conv1 + conv2 weights and biases (4 params) get 1.0 * 3.0; the two
        # BN params keep the base LR.
        _check_param_group(self, optimizer.param_groups[0], num_params=4, lr=3.0)
        _check_param_group(self, optimizer.param_groups[1], num_params=2, lr=1.0)

    def test_create_optimizer_weight_decay_norm(self):
        """Norm params form their own group when WEIGHT_DECAY_NORM differs."""

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 1)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                return self.bn(self.conv(x))

        model = Model()
        cfg = get_optimizer_cfg(
            lr=1.0, weight_decay=1.0, weight_decay_norm=2.0, weight_decay_bias=1.0
        )
        optimizer = build_optimizer_mapper(cfg, model)

        self.assertEqual(len(optimizer.param_groups), 2)

        _check_param_group(
            self, optimizer.param_groups[0], num_params=2, lr=1.0, weight_decay=1.0
        )
        _check_param_group(
            self, optimizer.param_groups[1], num_params=2, lr=1.0, weight_decay=2.0
        )

    def test_all_optimizers(self):
        """Smoke-test every optimizer with and without an LR multiplier."""
        runner = default_runner.Detectron2GoRunner()
        cfg = runner.get_default_cfg()
        # LR multipliers are configured through SOLVER.LR_MULTIPLIER_OVERWRITE
        # (the key ``get_optimizer_cfg`` uses and test_create_optimizer_lr
        # relies on). The "no multiplier" case reuses the config's own default
        # rather than ``None``.
        default_mult = cfg.SOLVER.get("LR_MULTIPLIER_OVERWRITE", [])
        multipliers = [default_mult, [{"conv": 0.1}]]
        cfg.SOLVER.BASE_LR = 0.01

        for optimizer_name in self._ALL_OPTIMIZERS:
            for mult in multipliers:
                cfg.SOLVER.OPTIMIZER = optimizer_name
                cfg.SOLVER.LR_MULTIPLIER_OVERWRITE = mult
                _test_each_optimizer(cfg)

    def test_full_model_grad_clipping(self):
        """Smoke-test every optimizer with full-model gradient clipping enabled."""
        runner = default_runner.Detectron2GoRunner()
        cfg = runner.get_default_cfg()
        cfg.SOLVER.BASE_LR = 0.02
        cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 0.2
        cfg.SOLVER.CLIP_GRADIENTS.ENABLED = True
        cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "full_model"

        for optimizer_name in self._ALL_OPTIMIZERS:
            cfg.SOLVER.OPTIMIZER = optimizer_name
            _test_each_optimizer(cfg)

    @staticmethod
    def _build_custom_model():
        """Return a conv+BN model whose own param groups set ``conv.weight`` LR to 10."""

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 3, 1)
                self.bn = torch.nn.BatchNorm2d(3)

            def forward(self, x):
                return self.bn(self.conv(x))

            def get_optimizer_param_groups(self, _opts):
                ret = [
                    {
                        "params": [self.conv.weight],
                        "lr": 10.0,
                    }
                ]
                return ret

        return Model()

    def _check_custom_param_groups(self, optimizer):
        """Assert the three groups expected for ``_build_custom_model``.

        Group 0 is ``conv.weight`` with the model-provided LR. Group 1 holds the
        remaining conv param (1 param, presumably the bias). Group 2 holds the
        two BN params, which use WEIGHT_DECAY_NORM (0.0).
        """
        self.assertEqual(len(optimizer.param_groups), 3)

        _check_param_group(
            self, optimizer.param_groups[0], num_params=1, lr=10.0, weight_decay=1.0
        )
        _check_param_group(
            self, optimizer.param_groups[1], num_params=1, lr=1.0, weight_decay=1.0
        )
        _check_param_group(
            self, optimizer.param_groups[2], num_params=2, lr=1.0, weight_decay=0.0
        )

    def test_create_optimizer_custom(self):
        """A model's ``get_optimizer_param_groups`` overrides per-param options."""
        model = self._build_custom_model()
        cfg = get_optimizer_cfg(lr=1.0, weight_decay=1.0, weight_decay_norm=0.0)
        optimizer = build_optimizer_mapper(cfg, model)

        self._check_custom_param_groups(optimizer)

    @helper.enable_ddp_env
    def test_create_optimizer_custom_ddp(self):
        """Custom param groups are still honored when the model is DDP-wrapped."""
        model = self._build_custom_model()
        model = torch.nn.parallel.DistributedDataParallel(model)
        cfg = get_optimizer_cfg(lr=1.0, weight_decay=1.0, weight_decay_norm=0.0)
        optimizer = build_optimizer_mapper(cfg, model)

        self._check_custom_param_groups(optimizer)