test_hooks.py 8.76 KB
Newer Older
Wenwei Zhang's avatar
Wenwei Zhang committed
1
2
3
4
5
6
7
8
"""
Tests the hooks with runners.

CommandLine:
    pytest tests/test_hooks.py
    xdoctest tests/test_hooks.py zero

"""
9
import logging
Jiangmiao Pang's avatar
Jiangmiao Pang committed
10
import os.path as osp
11
import shutil
Jiangmiao Pang's avatar
Jiangmiao Pang committed
12
import sys
13
import tempfile
Wenwei Zhang's avatar
Wenwei Zhang committed
14
from unittest.mock import MagicMock, call
Jiangmiao Pang's avatar
Jiangmiao Pang committed
15

16
17
18
19
20
import pytest
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

21
22
23
from mmcv.runner import (EpochBasedRunner, IterTimerHook, MlflowLoggerHook,
                         PaviLoggerHook, WandbLoggerHook)
from mmcv.runner.hooks.lr_updater import (CosineAnealingLrUpdaterHook,
Harry's avatar
Harry committed
24
                                          CosineRestartLrUpdaterHook,
25
26
27
                                          CyclicLrUpdaterHook)
from mmcv.runner.hooks.momentum_updater import (
    CosineAnealingMomentumUpdaterHook, CyclicMomentumUpdaterHook)
Jiangmiao Pang's avatar
Jiangmiao Pang committed
28
29
30
31
32


def test_pavi_hook():
    """Check PaviLoggerHook scalar logging and last-checkpoint registration.

    ``pavi`` is replaced by a ``MagicMock`` in ``sys.modules`` so no real
    client is used; the assertions inspect the calls recorded on
    ``hook.writer`` after one train + one val epoch of 5 iterations each.
    """
    sys.modules['pavi'] = MagicMock()

    loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner()
    hook = PaviLoggerHook(add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    try:
        runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
    finally:
        # Remove the temporary work dir even when the run fails.
        shutil.rmtree(runner.work_dir)

    assert hasattr(hook, 'writer')
    hook.writer.add_scalars.assert_called_with('val', {
        'learning_rate': 0.02,
        'momentum': 0.95
    }, 5)
    # Only path strings are compared here, so the removed work dir is fine.
    hook.writer.add_snapshot_file.assert_called_with(
        tag=runner.work_dir.split('/')[-1],
        snapshot_file_path=osp.join(runner.work_dir, 'latest.pth'),
        iteration=5)
49
50


Wenwei Zhang's avatar
Wenwei Zhang committed
51
52
53
54
55
56
57
58
59
def test_momentum_runner_hook():
    """Check cyclic momentum + cyclic LR schedules as logged by Pavi.

    CommandLine:
        xdoctest -m tests/test_hooks.py test_momentum_runner_hook
    """
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()

    # add cyclic momentum scheduler
    hook = CyclicMomentumUpdaterHook(
        by_epoch=False,
        target_ratio=(0.85 / 0.95, 1),
        cyclic_times=1,
        step_ratio_up=0.4)
    runner.register_hook(hook)

    # add cyclic LR scheduler
    hook = CyclicLrUpdaterHook(
        by_epoch=False,
        target_ratio=(10, 1),
        cyclic_times=1,
        step_ratio_up=0.4)
    runner.register_hook(hook)
    runner.register_hook(IterTimerHook())

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    try:
        runner.run([loader], [('train', 1)], 1)
    finally:
        # Remove the temporary work dir even when the run fails.
        shutil.rmtree(runner.work_dir)

    assert hasattr(hook, 'writer')

    def scalars(lr, momentum):
        # Tolerant match: the schedules are computed in floating point, so
        # exact equality (e.g. 0.01999999999999999 vs 0.02) is brittle.
        return pytest.approx({'learning_rate': lr, 'momentum': momentum})

    calls = [
        call('train', scalars(0.02, 0.95), 0),
        call('train', scalars(0.2, 0.85), 4),
        call('train', scalars(0.155, 0.875), 6),
    ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


def test_cosine_runner_hook():
    """Check cosine-annealing momentum + LR schedules as logged by Pavi.

    CommandLine:
        xdoctest -m tests/test_hooks.py test_cosine_runner_hook
    """
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()

    # add cosine-annealing momentum scheduler
    hook = CosineAnealingMomentumUpdaterHook(
        min_momentum_ratio=0.99 / 0.95,
        by_epoch=False,
        warmup_iters=2,
        warmup_ratio=0.9 / 0.95)
    runner.register_hook(hook)

    # add cosine-annealing LR scheduler
    hook = CosineAnealingLrUpdaterHook(
        by_epoch=False, min_lr_ratio=0, warmup_iters=2, warmup_ratio=0.9)
    runner.register_hook(hook)
    runner.register_hook(IterTimerHook())

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    try:
        runner.run([loader], [('train', 1)], 1)
    finally:
        # Remove the temporary work dir even when the run fails.
        shutil.rmtree(runner.work_dir)

    assert hasattr(hook, 'writer')

    def scalars(lr, momentum):
        # Tolerant match: cosine schedules are computed in floating point,
        # so exact equality on values like 0.0004894348370484647 is brittle.
        return pytest.approx({'learning_rate': lr, 'momentum': momentum})

    calls = [
        call('train', scalars(0.02, 0.95), 0),
        call('train', scalars(0.01, 0.97), 5),
        call('train', scalars(0.0004894348370484647, 0.9890211303259032), 9),
    ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


Harry's avatar
Harry committed
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
def test_cosine_restart_lr_update_hook():
    """Test CosineRestartLrUpdaterHook argument checks and logged schedule."""
    with pytest.raises(AssertionError):
        # exactly one of `min_lr` and `min_lr_ratio` should be specified;
        # passing both must be rejected
        CosineRestartLrUpdaterHook(
            by_epoch=False,
            periods=[2, 10],
            restart_weights=[0.5, 0.5],
            min_lr=0.1,
            min_lr_ratio=0)

    with pytest.raises(AssertionError):
        # periods and restart_weights should have the same length
        CosineRestartLrUpdaterHook(
            by_epoch=False,
            periods=[2, 10],
            restart_weights=[0.5],
            min_lr_ratio=0)

    # The last cumulative period 7 (of [5, 7]) should be >= 10, the number
    # of training iterations, so running must raise ValueError. Only
    # `runner.run` sits inside `pytest.raises`, and the work dir is cleaned
    # up in `finally` (previously `rmtree` followed the raising call and
    # never executed, leaking the temp dir).
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()

    # add cosine restart LR scheduler
    hook = CosineRestartLrUpdaterHook(
        by_epoch=False,
        periods=[5, 2],  # cumulative_periods [5, 7 (5 + 2)]
        restart_weights=[0.5, 0.5],
        min_lr=0.0001)
    runner.register_hook(hook)
    runner.register_hook(IterTimerHook())

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    try:
        with pytest.raises(ValueError):
            runner.run([loader], [('train', 1)], 1)
    finally:
        shutil.rmtree(runner.work_dir)

    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()

    # add cosine restart LR scheduler
    hook = CosineRestartLrUpdaterHook(
        by_epoch=False,
        periods=[5, 5],
        restart_weights=[0.5, 0.5],
        min_lr_ratio=0)
    runner.register_hook(hook)
    runner.register_hook(IterTimerHook())

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    try:
        runner.run([loader], [('train', 1)], 1)
    finally:
        # Remove the temporary work dir even when the run fails.
        shutil.rmtree(runner.work_dir)

    assert hasattr(hook, 'writer')

    def scalars(lr, momentum):
        # Tolerant match: the cosine schedule is computed in floating point.
        return pytest.approx({'learning_rate': lr, 'momentum': momentum})

    calls = [
        call('train', scalars(0.01, 0.95), 0),
        call('train', scalars(0.0, 0.95), 5),
        call('train', scalars(0.0009549150281252633, 0.95), 9),
    ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


226
227
228
229
230
@pytest.mark.parametrize('log_model', (True, False))
def test_mlflow_hook(log_model):
    """Check MlflowLoggerHook experiment setup, metric logging and model upload.

    ``mlflow`` and ``mlflow.pytorch`` are replaced by ``MagicMock`` objects
    in ``sys.modules``; the assertions inspect calls recorded on
    ``hook.mlflow`` and ``hook.mlflow_pytorch``.

    Args:
        log_model: Forwarded to the hook; when True the runner's model is
            expected to be passed to ``log_model``, otherwise it must not be.
    """
    sys.modules['mlflow'] = MagicMock()
    sys.modules['mlflow.pytorch'] = MagicMock()

    runner = _build_demo_runner()
    loader = DataLoader(torch.ones((5, 2)))

    hook = MlflowLoggerHook(exp_name='test', log_model=log_model)
    runner.register_hook(hook)
    try:
        runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
    finally:
        # Remove the temporary work dir even when the run fails.
        shutil.rmtree(runner.work_dir)

    hook.mlflow.set_experiment.assert_called_with('test')
    hook.mlflow.log_metrics.assert_called_with(
        {
            'learning_rate': 0.02,
            'momentum': 0.95
        }, step=5)
    if log_model:
        hook.mlflow_pytorch.log_model.assert_called_with(
            runner.model, 'models')
    else:
        assert not hook.mlflow_pytorch.log_model.called


def test_wandb_hook():
    """Check that WandbLoggerHook calls wandb.init, wandb.log and wandb.join.

    ``wandb`` is replaced by a ``MagicMock`` in ``sys.modules``; the
    assertions inspect the calls recorded on ``hook.wandb``.
    """
    sys.modules['wandb'] = MagicMock()
    runner = _build_demo_runner()
    hook = WandbLoggerHook()
    loader = DataLoader(torch.ones((5, 2)))

    runner.register_hook(hook)
    try:
        runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
    finally:
        # Remove the temporary work dir even when the run fails.
        shutil.rmtree(runner.work_dir)

    hook.wandb.init.assert_called_with()
    hook.wandb.log.assert_called_with(
        {
            'learning_rate': 0.02,
            'momentum': 0.95
        }, step=5)
    hook.wandb.join.assert_called_with()
Wenwei Zhang's avatar
Wenwei Zhang committed
269
270
271


def _build_demo_runner():
    """Create an ``EpochBasedRunner`` around a tiny linear model.

    The model is a single ``nn.Linear(2, 1)``; ``train_step`` and
    ``val_step`` both return ``dict(loss=model(x))``. Training uses SGD
    with ``lr=0.02`` and ``momentum=0.95``, and a ``TextLoggerHook`` is
    registered with ``interval=1``.

    The runner's ``work_dir`` is a fresh directory created by
    ``tempfile.mkdtemp()``; callers are responsible for removing it.
    """

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 1)

        def forward(self, x):
            return self.linear(x)

        def train_step(self, x, optimizer, **kwargs):
            return {'loss': self(x)}

        def val_step(self, x, optimizer, **kwargs):
            return {'loss': self(x)}

    net = Model()
    sgd = torch.optim.SGD(net.parameters(), lr=0.02, momentum=0.95)
    work_dir = tempfile.mkdtemp()

    demo_runner = EpochBasedRunner(
        model=net,
        work_dir=work_dir,
        optimizer=sgd,
        logger=logging.getLogger(),
    )
    text_logger_cfg = dict(type='TextLoggerHook')
    demo_runner.register_logger_hooks(
        dict(interval=1, hooks=[text_logger_cfg]))
    return demo_runner