"tests/vscode:/vscode.git/clone" did not exist on "8a7a332539809adcad88546f492945c4e752ff49"
test_hooks.py 16.1 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
"""Tests the hooks with runners.

CommandLine:
    pytest tests/test_runner/test_hooks.py
    xdoctest tests/test_runner/test_hooks.py zero
"""
import logging
import os.path as osp
import re
import shutil
import sys
import tempfile
from unittest.mock import MagicMock, call

import pytest
import torch
import torch.nn as nn
from torch.nn.init import constant_
from torch.utils.data import DataLoader

from mmcv.runner import (CheckpointHook, EMAHook, IterTimerHook,
                         MlflowLoggerHook, PaviLoggerHook, WandbLoggerHook,
                         build_runner)
from mmcv.runner.hooks.lr_updater import (CosineRestartLrUpdaterHook,
                                          OneCycleLrUpdaterHook)


def test_checkpoint_hook():
    """xdoctest -m tests/test_runner/test_hooks.py test_checkpoint_hook."""

    # test epoch based runner
    loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner('EpochBasedRunner', max_epochs=1)
    runner.meta = dict()
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(checkpointhook)
    runner.run([loader], [('train', 1)])
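    # CheckpointHook(by_epoch=True) should save epoch_1.pth and record its
    # path in runner.meta['hook_msgs']['last_ckpt']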
    assert runner.meta['hook_msgs']['last_ckpt'] == osp.join(
        runner.work_dir, 'epoch_1.pth')
    shutil.rmtree(runner.work_dir)

    # test iter based runner
    runner = _build_demo_runner(
        'IterBasedRunner', max_iters=1, max_epochs=None)
    runner.meta = dict()
    checkpointhook = CheckpointHook(interval=1, by_epoch=False)
    runner.register_hook(checkpointhook)
    runner.run([loader], [('train', 1)])
    assert runner.meta['hook_msgs']['last_ckpt'] == osp.join(
        runner.work_dir, 'iter_1.pth')
    shutil.rmtree(runner.work_dir)


def test_ema_hook():
    """xdoctest -m tests/test_hooks.py test_ema_hook."""

    class DemoModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(
                in_channels=1,
                out_channels=2,
                kernel_size=1,
                padding=1,
                bias=True)
            self._init_weight()

        def _init_weight(self):
            constant_(self.conv.weight, 0)
            constant_(self.conv.bias, 0)

        def forward(self, x):
            return self.conv(x).sum()

        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

        def val_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

    loader = DataLoader(torch.ones((1, 1, 1, 1)))
    runner = _build_demo_runner()
    demo_model = DemoModel()
    runner.model = demo_model
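    # EMAHook keeps exponential-moving-average copies of the parameters as
    # extra 'ema'-prefixed buffers, so they show up in the checkpoint's
    # state_dict alongside the raw weights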
    emahook = EMAHook(momentum=0.1, interval=2, warm_up=100, resume_from=None)
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(emahook, priority='HIGHEST')
    runner.register_hook(checkpointhook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])
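    # all weights start at zero and _build_demo_runner registers no optimizer
    # hook, so both the raw parameters and their EMA buffers should still sum
    # to zero; the EMA buffers are then filled with ones so the resumed run
    # below is distinguishable from a fresh one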
    checkpoint = torch.load(f'{runner.work_dir}/epoch_1.pth')
    contain_ema_buffer = False
    for name, value in checkpoint['state_dict'].items():
        if 'ema' in name:
            contain_ema_buffer = True
            assert value.sum() == 0
            value.fill_(1)
        else:
            assert value.sum() == 0
    assert contain_ema_buffer
    torch.save(checkpoint, f'{runner.work_dir}/epoch_1.pth')
    work_dir = runner.work_dir
    resume_ema_hook = EMAHook(
        momentum=0.5, warm_up=0, resume_from=f'{work_dir}/epoch_1.pth')
    runner = _build_demo_runner(max_epochs=2)
    runner.model = demo_model
    runner.register_hook(resume_ema_hook, priority='HIGHEST')
    checkpointhook = CheckpointHook(interval=1, by_epoch=True)
    runner.register_hook(checkpointhook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])
    checkpoint = torch.load(f'{runner.work_dir}/epoch_2.pth')
    contain_ema_buffer = False
    for name, value in checkpoint['state_dict'].items():
        if 'ema' in name:
            contain_ema_buffer = True
            assert value.sum() == 2
        else:
            assert value.sum() == 1
    assert contain_ema_buffer
    shutil.rmtree(runner.work_dir)
    shutil.rmtree(work_dir)


def test_pavi_hook():
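    # planting a MagicMock under sys.modules lets the hook "import pavi"
    # without the real package being installed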
    sys.modules['pavi'] = MagicMock()

    loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner()
    runner.meta = dict(config_dict=dict(lr=0.02, gpu_ids=range(1)))
    hook = PaviLoggerHook(add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])
    shutil.rmtree(runner.work_dir)

    assert hasattr(hook, 'writer')
    hook.writer.add_scalars.assert_called_with('val', {
        'learning_rate': 0.02,
        'momentum': 0.95
    }, 1)
    hook.writer.add_snapshot_file.assert_called_with(
        tag=runner.work_dir.split('/')[-1],
        snapshot_file_path=osp.join(runner.work_dir, 'epoch_1.pth'),
        iteration=1)


def test_sync_buffers_hook():
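    # smoke test: registering SyncBuffersHook from a cfg dict and running a
    # train/val cycle should not raise (no values are asserted)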
    loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner()
    runner.register_hook_from_cfg(dict(type='SyncBuffersHook'))
    runner.run([loader, loader], [('train', 1), ('val', 1)])
    shutil.rmtree(runner.work_dir)


def test_momentum_runner_hook():
    """xdoctest -m tests/test_hooks.py test_momentum_runner_hook."""
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()

    # add momentum scheduler
    hook_cfg = dict(
        type='CyclicMomentumUpdaterHook',
        by_epoch=False,
        target_ratio=(0.85 / 0.95, 1),
        cyclic_times=1,
        step_ratio_up=0.4)
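    # target_ratio is relative to the base momentum (0.95): the momentum
    # moves towards 0.85 over the first 40% of the cycle (step_ratio_up=0.4)
    # and back to 0.95 afterwards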
    runner.register_hook_from_cfg(hook_cfg)

    # add LR scheduler
    hook_cfg = dict(
        type='CyclicLrUpdaterHook',
        by_epoch=False,
        target_ratio=(10, 1),
        cyclic_times=1,
        step_ratio_up=0.4)
    runner.register_hook_from_cfg(hook_cfg)
    runner.register_hook_from_cfg(dict(type='IterTimerHook'))

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    calls = [
        call('train', {
            'learning_rate': 0.01999999999999999,
            'momentum': 0.95
        }, 1),
        call('train', {
            'learning_rate': 0.2,
            'momentum': 0.85
        }, 5),
        call('train', {
            'learning_rate': 0.155,
            'momentum': 0.875
        }, 7),
    ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


def test_cosine_runner_hook():
    """xdoctest -m tests/test_hooks.py test_cosine_runner_hook."""
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()

    # add momentum scheduler
    hook_cfg = dict(
        type='CosineAnnealingMomentumUpdaterHook',
        min_momentum_ratio=0.99 / 0.95,
        by_epoch=False,
        warmup_iters=2,
        warmup_ratio=0.9 / 0.95)
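    # min_momentum_ratio and warmup_ratio are both expressed relative to the
    # base momentum (0.95); two warmup iterations precede the cosine annealing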
    runner.register_hook_from_cfg(hook_cfg)

    # add LR scheduler
    hook_cfg = dict(
        type='CosineAnnealingLrUpdaterHook',
        by_epoch=False,
        min_lr_ratio=0,
        warmup_iters=2,
        warmup_ratio=0.9)
    runner.register_hook_from_cfg(hook_cfg)
    runner.register_hook_from_cfg(dict(type='IterTimerHook'))
    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    calls = [
        call('train', {
            'learning_rate': 0.02,
            'momentum': 0.95
        }, 1),
        call('train', {
            'learning_rate': 0.01,
            'momentum': 0.97
        }, 6),
        call('train', {
            'learning_rate': 0.0004894348370484647,
            'momentum': 0.9890211303259032
        }, 10)
    ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


def test_one_cycle_runner_hook():
    """Test OneCycleLrUpdaterHook and OneCycleMomentumUpdaterHook."""
    with pytest.raises(AssertionError):
        # by_epoch should be False
        OneCycleLrUpdaterHook(max_lr=0.1, by_epoch=True)

    with pytest.raises(ValueError):
        # expected float between 0 and 1
        OneCycleLrUpdaterHook(max_lr=0.1, pct_start=-0.1)

    with pytest.raises(ValueError):
        # anneal_strategy should be either 'cos' or 'linear'
        OneCycleLrUpdaterHook(max_lr=0.1, anneal_strategy='sin')

    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()

    # add momentum scheduler
    hook_cfg = dict(
        type='OneCycleMomentumUpdaterHook',
        base_momentum=0.85,
        max_momentum=0.95,
        pct_start=0.5,
        anneal_strategy='cos',
        three_phase=False)
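    # in a one-cycle schedule momentum is cycled inversely to the lr: with
    # pct_start=0.5 it anneals from max_momentum (0.95) down to base_momentum
    # (0.85) over the first half of training and back up over the second half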
    runner.register_hook_from_cfg(hook_cfg)

    # add momentum LR scheduler
    hook_cfg = dict(
        type='OneCycleLrUpdaterHook',
        max_lr=0.01,
        pct_start=0.5,
        anneal_strategy='cos',
        div_factor=25,
        final_div_factor=1e4,
        three_phase=False)
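    # the lr starts at max_lr / div_factor = 0.01 / 25 = 4e-4 and ends at
    # 4e-4 / final_div_factor = 4e-8, matching the values asserted below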
    runner.register_hook_from_cfg(hook_cfg)
    runner.register_hook_from_cfg(dict(type='IterTimerHook'))
    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    calls = [
        call('train', {
            'learning_rate': 0.0003999999999999993,
            'momentum': 0.95
        }, 1),
        call('train', {
            'learning_rate': 0.00904508879153485,
            'momentum': 0.8595491502812526
        }, 6),
        call('train', {
            'learning_rate': 4e-08,
            'momentum': 0.95
        }, 10)
    ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


def test_cosine_restart_lr_update_hook():
    """Test CosineRestartLrUpdaterHook."""
    with pytest.raises(AssertionError):
        # either `min_lr` or `min_lr_ratio` should be specified
        CosineRestartLrUpdaterHook(
            by_epoch=False,
            periods=[2, 10],
            restart_weights=[0.5, 0.5],
            min_lr=0.1,
            min_lr_ratio=0)

    with pytest.raises(AssertionError):
        # periods and restart_weights should have the same length
        CosineRestartLrUpdaterHook(
            by_epoch=False,
            periods=[2, 10],
            restart_weights=[0.5],
            min_lr_ratio=0)

    with pytest.raises(ValueError):
        # the last cumulative period (5 + 2 = 7) must be >= max iterations (10)
        sys.modules['pavi'] = MagicMock()
        loader = DataLoader(torch.ones((10, 2)))
        runner = _build_demo_runner()

        # add cosine restart LR scheduler
        hook = CosineRestartLrUpdaterHook(
            by_epoch=False,
            periods=[5, 2],  # cumulative_periods [5, 7 (5 + 2)]
            restart_weights=[0.5, 0.5],
            min_lr=0.0001)
        runner.register_hook(hook)
        runner.register_hook(IterTimerHook())

        # add pavi hook
        hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
        runner.register_hook(hook)
        runner.run([loader], [('train', 1)])
        shutil.rmtree(runner.work_dir)

    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()

    # add cosine restart LR scheduler
    hook = CosineRestartLrUpdaterHook(
        by_epoch=False,
        periods=[5, 5],
        restart_weights=[0.5, 0.5],
        min_lr_ratio=0)
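    # both 5-iteration periods restart from restart_weight * base lr
    # (0.5 * 0.02 = 0.01), which is why iterations 1 and 6 log the same lr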
    runner.register_hook(hook)
    runner.register_hook(IterTimerHook())

    # add pavi hook
    hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
    runner.register_hook(hook)
    runner.run([loader], [('train', 1)])
    shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    calls = [
        call('train', {
            'learning_rate': 0.01,
            'momentum': 0.95
        }, 1),
        call('train', {
            'learning_rate': 0.01,
            'momentum': 0.95
        }, 6),
        call('train', {
            'learning_rate': 0.0009549150281252633,
            'momentum': 0.95
        }, 10)
    ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


@pytest.mark.parametrize('log_model', (True, False))
def test_mlflow_hook(log_model):
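    # both mlflow and mlflow.pytorch are mocked: the hook logs metrics through
    # mlflow and (when log_model is True) the model through mlflow.pytorch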
    sys.modules['mlflow'] = MagicMock()
    sys.modules['mlflow.pytorch'] = MagicMock()

    runner = _build_demo_runner()
    loader = DataLoader(torch.ones((5, 2)))

    hook = MlflowLoggerHook(exp_name='test', log_model=log_model)
    runner.register_hook(hook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])
    shutil.rmtree(runner.work_dir)

    hook.mlflow.set_experiment.assert_called_with('test')
    hook.mlflow.log_metrics.assert_called_with(
        {
            'learning_rate': 0.02,
            'momentum': 0.95
        }, step=6)
    if log_model:
        hook.mlflow_pytorch.log_model.assert_called_with(
            runner.model, 'models')
    else:
        assert not hook.mlflow_pytorch.log_model.called


def test_wandb_hook():
    sys.modules['wandb'] = MagicMock()
    runner = _build_demo_runner()
    hook = WandbLoggerHook()
    loader = DataLoader(torch.ones((5, 2)))

    runner.register_hook(hook)
    runner.run([loader, loader], [('train', 1), ('val', 1)])
    shutil.rmtree(runner.work_dir)

    hook.wandb.init.assert_called_with()
    hook.wandb.log.assert_called_with({
        'learning_rate': 0.02,
        'momentum': 0.95
    },
                                      step=6,
                                      commit=True)
    hook.wandb.join.assert_called_with()


def _build_demo_runner(runner_type='EpochBasedRunner',
                       max_epochs=1,
                       max_iters=None):

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 1)

        def forward(self, x):
            return self.linear(x)

        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

        def val_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))

    model = Model()

    optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)
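    # lr=0.02 and momentum=0.95 are the base values that the scheduler tests
    # above assert against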

    log_config = dict(
        interval=1, hooks=[
            dict(type='TextLoggerHook'),
        ])

    tmp_dir = tempfile.mkdtemp()
    runner = build_runner(
        dict(type=runner_type),
        default_args=dict(
            model=model,
            work_dir=tmp_dir,
            optimizer=optimizer,
            logger=logging.getLogger(),
            max_epochs=max_epochs,
            max_iters=max_iters))
    runner.register_checkpoint_hook(dict(interval=1))
    runner.register_logger_hooks(log_config)
    return runner


def test_runner_with_revise_keys():

    import os

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 3, 1)

    class PrefixModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.backbone = Model()

    pmodel = PrefixModel()
    model = Model()
    checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
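    # revise_keys is a list of (regex pattern, replacement) pairs that
    # load_checkpoint applies via re.sub to every key of the loaded state_dict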

    # add prefix
    torch.save(model.state_dict(), checkpoint_path)
    runner = _build_demo_runner(runner_type='EpochBasedRunner')
    runner.model = pmodel
    state_dict = runner.load_checkpoint(
        checkpoint_path, revise_keys=[(r'^', 'backbone.')])
    for key in pmodel.backbone.state_dict().keys():
        assert torch.equal(pmodel.backbone.state_dict()[key], state_dict[key])
    # strip prefix
    torch.save(pmodel.state_dict(), checkpoint_path)
    runner.model = model
    state_dict = runner.load_checkpoint(
        checkpoint_path, revise_keys=[(r'^backbone\.', '')])
    for key in state_dict.keys():
        key_stripped = re.sub(r'^backbone\.', '', key)
        assert torch.equal(model.state_dict()[key_stripped], state_dict[key])
    os.remove(checkpoint_path)