test_modeling_ema.py 6.95 KB
Newer Older
facebook-github-bot's avatar
facebook-github-bot committed
1
2
3
4
5
6
7
8
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved


import copy
import itertools
import unittest

Yanghan Wang's avatar
Yanghan Wang committed
9
import d2go.runner.default_runner as default_runner
facebook-github-bot's avatar
facebook-github-bot committed
10
import torch
11
from d2go.modeling import ema
Yanghan Wang's avatar
Yanghan Wang committed
12
from d2go.utils.testing import helper
facebook-github-bot's avatar
facebook-github-bot committed
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60


class TestArch(torch.nn.Module):
    """Tiny conv -> bn -> relu -> avgpool model used as an EMA test fixture.

    Args:
        value: if given, every floating-point parameter/buffer is filled with
            this value right after construction (see ``set_const_weights``).
        int_value: fill value for non-floating-point tensors (e.g.
            ``bn.num_batches_tracked``); defaults to ``int(value)``.
    """

    def __init__(self, value=None, int_value=None):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 4, kernel_size=3, stride=1, padding=1)
        self.bn = torch.nn.BatchNorm2d(4)
        self.relu = torch.nn.ReLU(inplace=True)
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        if value is not None:
            self.set_const_weights(value, int_value)

    def forward(self, x):
        """Map an ``(N, 3, H, W)`` input to a pooled ``(N, 4, 1, 1)`` output."""
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.avgpool(out)
        return out

    def set_const_weights(self, value, int_value=None):
        """Fill every parameter and buffer with a constant, in place.

        Floating-point tensors receive ``value``; all other tensors receive
        ``int_value`` (``int(value)`` when omitted).
        """
        if int_value is None:
            int_value = int(value)
        # no_grad allows in-place writes to leaf parameters that require grad.
        with torch.no_grad():
            for tensor in itertools.chain(self.parameters(), self.buffers()):
                # Any floating dtype (not only float32) takes the float value,
                # so e.g. a ``.double()`` model is not silently truncated.
                tensor.fill_(value if tensor.is_floating_point() else int_value)


def _compare_state_dict(model1, model2, abs_error=1e-3):
    sd1 = model1.state_dict()
    sd2 = model2.state_dict()
    if len(sd1) != len(sd2):
        return False
    if set(sd1.keys()) != set(sd2.keys()):
        return False
    for name in sd1:
        if sd1[name].dtype == torch.float32:
            if torch.abs((sd1[name] - sd2[name])).max() > abs_error:
                return False
        elif (sd1[name] != sd2[name]).any():
            return False
    return True


class TestModelingModelEMA(unittest.TestCase):
    """Unit tests for ``ema.EMAState`` and ``ema.EMAUpdater``."""

    def test_emastate(self):
        model = TestArch()
        state = ema.EMAState.FromModel(model)
        # two for conv (conv.weight, conv.bias),
        # five for bn (bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.num_batches_tracked)
        full_state = {
            "conv.weight",
            "conv.bias",
            "bn.weight",
            "bn.bias",
            "bn.running_mean",
            "bn.running_var",
            "bn.num_batches_tracked",
        }
        self.assertEqual(len(state.state), 7)
        self.assertSetEqual(set(state.state), full_state)

        # The stored EMA tensors must not require gradients.
        for val in state.state.values():
            self.assertFalse(val.requires_grad)

        # A freshly initialized model differs until the state is applied to it.
        model1 = TestArch()
        self.assertFalse(_compare_state_dict(model, model1))

        state.apply_to(model1)
        self.assertTrue(_compare_state_dict(model, model1))

        # test ema state that excludes buffers and frozen parameters
        model.conv.weight.requires_grad = False
        state1 = ema.EMAState.FromModel(model, include_frozen=False)
        # should exclude frozen parameter: conv.weight
        self.assertSetEqual(full_state - set(state1.state), {"conv.weight"})

        state2 = ema.EMAState.FromModel(model, include_buffer=False)
        # should exclude buffers: bn.running_mean, bn.running_var, bn.num_batches_tracked
        self.assertSetEqual(
            full_state - set(state2.state),
            {"bn.running_mean", "bn.running_var", "bn.num_batches_tracked"},
        )

        state3 = ema.EMAState.FromModel(
            model, include_frozen=False, include_buffer=False
        )
        # should exclude frozen param + buffers: conv.weight, bn.running_mean, bn.running_var, bn.num_batches_tracked
        self.assertSetEqual(set(state3.state), {"conv.bias", "bn.weight", "bn.bias"})

    def test_emastate_saveload(self):
        model = TestArch()
        state = ema.EMAState.FromModel(model)

        model1 = TestArch()
        self.assertFalse(_compare_state_dict(model, model1))

        # Round-trip through state_dict()/load_state_dict() into an empty state.
        state1 = ema.EMAState()
        state1.load_state_dict(state.state_dict())
        state1.apply_to(model1)
        self.assertTrue(_compare_state_dict(model, model1))

    @helper.skip_if_no_gpu
    def test_emastate_crossdevice(self):
        model = TestArch()
        model.cuda()
        # state on gpu
        state = ema.EMAState.FromModel(model)
        self.assertEqual(state.device, torch.device("cuda:0"))
        # target model on cpu
        model1 = TestArch()
        state.apply_to(model1)
        self.assertEqual(next(model1.parameters()).device, torch.device("cpu"))
        self.assertTrue(_compare_state_dict(copy.deepcopy(model).cpu(), model1))

        # state on cpu
        state1 = ema.EMAState.FromModel(model, device="cpu")
        self.assertEqual(state1.device, torch.device("cpu"))
        # target model on gpu
        model2 = TestArch()
        model2.cuda()
        state1.apply_to(model2)
        self.assertEqual(next(model2.parameters()).device, torch.device("cuda:0"))
        self.assertTrue(_compare_state_dict(model, model2))

    def test_ema_updater(self):
        model = TestArch()
        state = ema.EMAState()

        updated_model = TestArch()

        updater = ema.EMAUpdater(state, decay=0.0)
        updater.init_state(model)
        for _ in range(3):
            cur = TestArch()
            updater.update(cur)
            state.apply_to(updated_model)
            # EMA decay == 0.0: the state always equals the newest model
            self.assertTrue(_compare_state_dict(updated_model, cur))

        updater = ema.EMAUpdater(state, decay=1.0)
        updater.init_state(model)
        for _ in range(3):
            cur = TestArch()
            updater.update(cur)
            state.apply_to(updated_model)
            # EMA decay == 1.0: the state always equals the initial model
            self.assertTrue(_compare_state_dict(updated_model, model))

    def test_ema_updater_decay(self):
        state = ema.EMAState()
        updater = ema.EMAUpdater(state, decay=0.7)
        updater.init_state(TestArch(1.0))
        gt_val = 1.0
        gt_val_int = 1
        for idx in range(3):
            updater.update(TestArch(float(idx)))
            updated_model = state.get_ema_model(TestArch())
            # Expected running average: 0.7 * old + 0.3 * new; the integer
            # buffer is expected to be truncated back to an int each step.
            gt_val = gt_val * 0.7 + float(idx) * 0.3
            gt_val_int = int(gt_val_int * 0.7 + float(idx) * 0.3)
            self.assertTrue(
                _compare_state_dict(updated_model, TestArch(gt_val, gt_val_int))
            )


class TestModelingModelEMAHook(unittest.TestCase):
    """Test ``ema.EMAHook`` together with the default runner config."""

    def test_ema_hook(self):
        runner = default_runner.Detectron2GoRunner()
        cfg = runner.get_default_cfg()
        cfg.MODEL.DEVICE = "cpu"
        cfg.MODEL_EMA.ENABLED = True
        # use new model weights
        cfg.MODEL_EMA.DECAY = 0.0

        model = TestArch()
        ema.may_build_model_ema(cfg, model)
        self.assertTrue(hasattr(model, "ema_state"))

        ema_hook = ema.EMAHook(cfg, model)
        ema_hook.before_train()
        ema_hook.before_step()
        # Stand-in for a training step: change the weights between
        # before_step() and after_step().
        model.set_const_weights(2.0)
        ema_hook.after_step()
        ema_hook.after_train()

        ema_checkpointers = ema.may_get_ema_checkpointer(cfg, model)
        self.assertEqual(len(ema_checkpointers), 1)

        # With decay 0.0 the EMA state is expected to match the latest weights.
        out_model = TestArch()
        ema_checkpointers["ema_state"].apply_to(out_model)
        self.assertTrue(_compare_state_dict(out_model, model))