"vscode:/vscode.git/clone" did not exist on "55f5fc68acc84027a275649d2cfb320448ed95d7"
# Copyright (c) OpenMMLab. All rights reserved.
"""Tests the rfsearch with runners.

CommandLine:
    pytest tests/test_runner/test_hooks.py
    xdoctest tests/test_hooks.py zero
"""

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from mmcv.cnn.rfsearch import Conv2dRFSearchOp, RFSearchHook
from tests.test_runner.test_hooks import _build_demo_runner


def test_rfsearchhook():
    """Exercise ``RFSearchHook`` in its three modes on a tiny conv model.

    For each of ``'search'``, ``'fixed_single_branch'`` and
    ``'fixed_multi_branch'`` the test builds a fresh ``Model``, calls
    ``init_model()`` on it and asserts which convolutions became
    ``Conv2dRFSearchOp`` instances and which dilations they carry. It then
    runs the ``[('train', 1)]`` workflow through the demo runner with the
    hook registered and repeats the assertions.
    """

    def conv(in_channels, out_channels, kernel_size, stride, padding,
             dilation):
        """Build an ``nn.Conv2d`` from positional hyper-parameters."""
        return nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation)

    class Model(nn.Module):
        """Small stack of convolutions used as the search target."""

        def __init__(self):
            super().__init__()
            self.stem = conv(1, 2, 3, 1, 1, 1)
            self.conv0 = conv(2, 2, 3, 1, 1, 1)
            self.layer0 = nn.Sequential(
                conv(2, 2, 3, 1, 1, 1), conv(2, 2, 3, 1, 1, 1))
            # 1x1 kernel: asserted below to never become a search op.
            self.conv1 = conv(2, 2, 1, 1, 0, 1)
            self.conv2 = conv(2, 2, 3, 1, 1, 1)
            # Non-square (1, 3) kernel.
            self.conv3 = conv(2, 2, (1, 3), 1, (0, 1), 1)

        def forward(self, x):
            x1 = self.stem(x)
            x2 = self.layer0(x1)
            x3 = self.conv0(x2)
            x4 = self.conv1(x3)
            x5 = self.conv2(x4)
            x6 = self.conv3(x5)
            return x6

        def train_step(self, x, optimizer, **kwargs):
            """Return the mean output as loss plus the batch size."""
            return dict(loss=self(x).mean(), num_samples=x.shape[0])

    rfsearch_cfg = dict(
        mode='search',
        rfstructure_file=None,
        config=dict(
            search=dict(
                step=0,
                max_step=12,
                search_interval=1,
                exp_rate=0.5,
                init_alphas=0.01,
                mmin=1,
                mmax=24,
                num_branches=2,
                skip_layer=['stem', 'conv0', 'layer0.1'])),
    )

    def fixed_structure():
        """Return a fresh copy of the structure assigned to every hook."""
        return {
            'module.layer0.0': [1, 1],
            'module.conv2': [2, 2],
            'module.conv3': [1, 1]
        }

    def check_skip_layer(model):
        """Assert that the layers named in ``skip_layer`` stay untouched.

        ``layer0[0]`` is not in ``skip_layer`` and must be a search op.
        """
        assert not isinstance(model.stem, Conv2dRFSearchOp)
        assert not isinstance(model.conv0, Conv2dRFSearchOp)
        assert isinstance(model.layer0[0], Conv2dRFSearchOp)
        assert not isinstance(model.layer0[1], Conv2dRFSearchOp)

    # NOTE: every hook is given the same ``rfsearch_cfg['config']`` object,
    # so the creation/assignment order below is kept exactly as written.

    # hook for search
    rfsearchhook_search = RFSearchHook(
        'search', rfsearch_cfg['config'], by_epoch=True, verbose=True)
    rfsearchhook_search.config['structure'] = fixed_structure()
    # hook for fixed_single_branch
    rfsearchhook_fixed_single_branch = RFSearchHook(
        'fixed_single_branch',
        rfsearch_cfg['config'],
        by_epoch=True,
        verbose=True)
    rfsearchhook_fixed_single_branch.config['structure'] = fixed_structure()
    # hook for fixed_multi_branch
    rfsearchhook_fixed_multi_branch = RFSearchHook(
        'fixed_multi_branch',
        rfsearch_cfg['config'],
        by_epoch=True,
        verbose=True)
    rfsearchhook_fixed_multi_branch.config['structure'] = fixed_structure()

    # 1. test init_model() with mode of search
    model = Model()
    rfsearchhook_search.init_model(model)

    check_skip_layer(model)
    assert not isinstance(model.conv1, Conv2dRFSearchOp)
    assert isinstance(model.conv2, Conv2dRFSearchOp)
    assert isinstance(model.conv3, Conv2dRFSearchOp)
    assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
    assert model.conv3.dilation_rates == [(1, 1), (1, 2)]

    # 1. test step() with mode of search
    loader = DataLoader(torch.ones((1, 1, 1, 1)))
    runner = _build_demo_runner()
    runner.model = model
    runner.register_hook(rfsearchhook_search)
    runner.run([loader], [('train', 1)])

    check_skip_layer(model)
    assert not isinstance(model.conv1, Conv2dRFSearchOp)
    assert isinstance(model.conv2, Conv2dRFSearchOp)
    assert isinstance(model.conv3, Conv2dRFSearchOp)
    assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
    # Unlike right after init_model(), conv3's second branch is (1, 3) here.
    assert model.conv3.dilation_rates == [(1, 1), (1, 3)]

    # 2. test init_model() with mode of fixed_single_branch
    model = Model()
    rfsearchhook_fixed_single_branch.init_model(model)

    # In this mode no layer is expected to be a search op, so the skip-layer
    # check (which requires layer0[0] to be one) is not applied.
    assert not isinstance(model.conv1, Conv2dRFSearchOp)
    assert not isinstance(model.conv2, Conv2dRFSearchOp)
    assert not isinstance(model.conv3, Conv2dRFSearchOp)
    assert model.conv1.dilation == (1, 1)
    assert model.conv2.dilation == (2, 2)
    assert model.conv3.dilation == (1, 1)

    # 2. test step() with mode of fixed_single_branch
    runner = _build_demo_runner()
    runner.model = model
    runner.register_hook(rfsearchhook_fixed_single_branch)
    runner.run([loader], [('train', 1)])

    assert not isinstance(model.conv1, Conv2dRFSearchOp)
    assert not isinstance(model.conv2, Conv2dRFSearchOp)
    assert not isinstance(model.conv3, Conv2dRFSearchOp)
    assert model.conv1.dilation == (1, 1)
    assert model.conv2.dilation == (2, 2)
    assert model.conv3.dilation == (1, 1)

    # 3. test init_model() with mode of fixed_multi_branch
    model = Model()
    rfsearchhook_fixed_multi_branch.init_model(model)

    check_skip_layer(model)
    assert not isinstance(model.conv1, Conv2dRFSearchOp)
    assert isinstance(model.conv2, Conv2dRFSearchOp)
    assert isinstance(model.conv3, Conv2dRFSearchOp)
    assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
    assert model.conv3.dilation_rates == [(1, 1), (1, 2)]

    # 3. test step() with mode of fixed_multi_branch
    runner = _build_demo_runner()
    runner.model = model
    runner.register_hook(rfsearchhook_fixed_multi_branch)
    runner.run([loader], [('train', 1)])

    check_skip_layer(model)
    assert not isinstance(model.conv1, Conv2dRFSearchOp)
    assert isinstance(model.conv2, Conv2dRFSearchOp)
    assert isinstance(model.conv3, Conv2dRFSearchOp)
    assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
    assert model.conv3.dilation_rates == [(1, 1), (1, 2)]