"""
Tests the hooks with runners.

CommandLine:
    pytest tests/test_hooks.py
    xdoctest tests/test_hooks.py zero

"""
import logging
import os.path as osp
import shutil
import sys
import tempfile
from unittest.mock import MagicMock, call

import pytest
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import mmcv.runner


def test_pavi_hook():
    """Check that PaviLoggerHook logs scalars and uploads the last checkpoint.

    ``pavi`` is replaced by a ``MagicMock`` in ``sys.modules`` so the calls
    made on ``hook.writer`` can be inspected without a real pavi install.
    """
    sys.modules['pavi'] = MagicMock()

    loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner()
    try:
        hook = mmcv.runner.hooks.PaviLoggerHook(
            add_graph=False, add_last_ckpt=True)
        runner.register_hook(hook)
        runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
    finally:
        # Always remove the temporary work_dir, even if the run fails.
        shutil.rmtree(runner.work_dir)

    assert hasattr(hook, 'writer')
    # lr/momentum are the optimizer settings from _build_demo_runner; the
    # expected step 5 matches the 5 training samples (DataLoader's default
    # batch_size is 1).
    hook.writer.add_scalars.assert_called_with('val', {
        'learning_rate': 0.02,
        'momentum': 0.95
    }, 5)
    hook.writer.add_snapshot_file.assert_called_with(
        tag=runner.work_dir.split('/')[-1],
        snapshot_file_path=osp.join(runner.work_dir, 'latest.pth'),
        iteration=5)


def test_momentum_runner_hook():
    """Check cyclic LR and momentum schedules as reported to the pavi writer.

    xdoctest -m tests/test_hooks.py test_momentum_runner_hook
    """
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()
    try:
        # add momentum scheduler (base momentum 0.95, cycles down towards
        # 0.95 * (0.85 / 0.95) = 0.85)
        hook = mmcv.runner.hooks.momentum_updater.CyclicMomentumUpdaterHook(
            by_epoch=False,
            target_ratio=(0.85 / 0.95, 1),
            cyclic_times=1,
            step_ratio_up=0.4)
        runner.register_hook(hook)

        # add LR scheduler (base lr 0.02, cycles up towards 0.02 * 10)
        hook = mmcv.runner.hooks.lr_updater.CyclicLrUpdaterHook(
            by_epoch=False,
            target_ratio=(10, 1),
            cyclic_times=1,
            step_ratio_up=0.4)
        runner.register_hook(hook)
        runner.register_hook(mmcv.runner.hooks.IterTimerHook())

        # add pavi hook; ``hook`` now refers to the logger inspected below
        hook = mmcv.runner.hooks.PaviLoggerHook(
            interval=1, add_graph=False, add_last_ckpt=True)
        runner.register_hook(hook)
        runner.run([loader], [('train', 1)], 1)
    finally:
        # Always remove the temporary work_dir, even if the run fails.
        shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    calls = [
        call('train', {
            'learning_rate': 0.01999999999999999,
            'momentum': 0.95
        }, 0),
        call('train', {
            'learning_rate': 0.2,
            'momentum': 0.85
        }, 4),
        call('train', {
            'learning_rate': 0.155,
            'momentum': 0.875
        }, 6),
    ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


def test_cosine_runner_hook():
    """Check cosine-annealed LR and momentum (with warmup) as logged to pavi.

    xdoctest -m tests/test_hooks.py test_cosine_runner_hook
    """
    sys.modules['pavi'] = MagicMock()
    loader = DataLoader(torch.ones((10, 2)))
    runner = _build_demo_runner()
    try:
        # add momentum scheduler (anneals from the base 0.95 towards
        # 0.95 * (0.99 / 0.95) = 0.99)
        hook = mmcv.runner.hooks.momentum_updater \
            .CosineAnealingMomentumUpdaterHook(
                min_momentum_ratio=0.99 / 0.95,
                by_epoch=False,
                warmup_iters=2,
                warmup_ratio=0.9 / 0.95)
        runner.register_hook(hook)

        # add LR scheduler (anneals from the base 0.02 towards 0)
        hook = mmcv.runner.hooks.lr_updater.CosineAnealingLrUpdaterHook(
            by_epoch=False, min_lr_ratio=0, warmup_iters=2, warmup_ratio=0.9)
        runner.register_hook(hook)
        runner.register_hook(mmcv.runner.hooks.IterTimerHook())

        # add pavi hook; ``hook`` now refers to the logger inspected below
        hook = mmcv.runner.hooks.PaviLoggerHook(
            interval=1, add_graph=False, add_last_ckpt=True)
        runner.register_hook(hook)
        runner.run([loader], [('train', 1)], 1)
    finally:
        # Always remove the temporary work_dir, even if the run fails.
        shutil.rmtree(runner.work_dir)

    # TODO: use a more elegant way to check values
    assert hasattr(hook, 'writer')
    calls = [
        call('train', {
            'learning_rate': 0.02,
            'momentum': 0.95
        }, 0),
        call('train', {
            'learning_rate': 0.01,
            'momentum': 0.97
        }, 5),
        call('train', {
            'learning_rate': 0.0004894348370484647,
            'momentum': 0.9890211303259032
        }, 9)
    ]
    hook.writer.add_scalars.assert_has_calls(calls, any_order=True)


@pytest.mark.parametrize('log_model', (True, False))
def test_mlflow_hook(log_model):
    """Check MlflowLoggerHook experiment setup, metric logging and model upload.

    ``mlflow`` and ``mlflow.pytorch`` are replaced by ``MagicMock`` objects in
    ``sys.modules`` so the hook's calls can be inspected.

    Args:
        log_model (bool): forwarded to the hook; when True the model is
            expected to be passed to ``mlflow_pytorch.log_model``, otherwise
            that function must not be called.
    """
    sys.modules['mlflow'] = MagicMock()
    sys.modules['mlflow.pytorch'] = MagicMock()

    loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner()
    try:
        hook = mmcv.runner.hooks.MlflowLoggerHook(
            exp_name='test', log_model=log_model)
        runner.register_hook(hook)
        runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
    finally:
        # Always remove the temporary work_dir, even if the run fails.
        shutil.rmtree(runner.work_dir)

    hook.mlflow.set_experiment.assert_called_with('test')
    hook.mlflow.log_metrics.assert_called_with(
        {
            'learning_rate': 0.02,
            'momentum': 0.95
        }, step=5)
    if log_model:
        hook.mlflow_pytorch.log_model.assert_called_with(
            runner.model, 'models')
    else:
        assert not hook.mlflow_pytorch.log_model.called


def test_wandb_hook():
    """Check WandbLoggerHook initialises, logs metrics and finishes the run.

    ``wandb`` is replaced by a ``MagicMock`` in ``sys.modules`` so the calls
    made on ``hook.wandb`` can be inspected.
    """
    sys.modules['wandb'] = MagicMock()
    loader = DataLoader(torch.ones((5, 2)))
    runner = _build_demo_runner()
    try:
        hook = mmcv.runner.hooks.WandbLoggerHook()
        runner.register_hook(hook)
        runner.run([loader, loader], [('train', 1), ('val', 1)], 1)
    finally:
        # Always remove the temporary work_dir, even if the run fails.
        shutil.rmtree(runner.work_dir)

    hook.wandb.init.assert_called_with()
    hook.wandb.log.assert_called_with(
        {
            'learning_rate': 0.02,
            'momentum': 0.95
        }, step=5)
    hook.wandb.join.assert_called_with()


def _build_demo_runner():
    """Build a minimal ``mmcv.runner.Runner`` around ``nn.Linear(2, 1)``.

    The runner is configured with:

    * SGD with ``lr=0.02`` and ``momentum=0.95``;
    * a fresh ``work_dir`` created by :func:`tempfile.mkdtemp` -- the caller
      is responsible for deleting it (e.g. with ``shutil.rmtree``);
    * the root logger;
    * a ``TextLoggerHook`` registered with ``interval=1``.

    Returns:
        mmcv.runner.Runner: the configured runner.
    """
    model = nn.Linear(2, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.02, momentum=0.95)

    log_config = dict(
        interval=1, hooks=[
            dict(type='TextLoggerHook'),
        ])

    tmp_dir = tempfile.mkdtemp()
    runner = mmcv.runner.Runner(
        model=model,
        work_dir=tmp_dir,
        # The raw model output is reported as the "loss"; it only needs to
        # drive the training loop, not be meaningful.
        batch_processor=lambda model, x, **kwargs: {'loss': model(x) - 0},
        optimizer=optimizer,
        logger=logging.getLogger())

    runner.register_logger_hooks(log_config)
    return runner