test_search.py 4.18 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Copyright (c) OpenMMLab. All rights reserved.

import torch.nn as nn

from mmcv.cnn.rfsearch import Conv2dRFSearchOp, RFSearchHook


def test_rfsearchhook():

    def conv(in_channels, out_channels, kernel_size, stride, padding,
             dilation):
        return nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation)

    class Model(nn.Module):

        def __init__(self):
            super().__init__()
            self.stem = conv(1, 2, 3, 1, 1, 1)
            self.conv0 = conv(2, 2, 3, 1, 1, 1)
            self.layer0 = nn.Sequential(
                conv(2, 2, 3, 1, 1, 1), conv(2, 2, 3, 1, 1, 1))
            self.conv1 = conv(2, 2, 1, 1, 0, 1)
            self.conv2 = conv(2, 2, 3, 1, 1, 1)
            self.conv3 = conv(2, 2, (1, 3), 1, (0, 1), 1)

        def forward(self, x):
            x1 = self.stem(x)
            x2 = self.layer0(x1)
            x3 = self.conv0(x2)
            x4 = self.conv1(x3)
            x5 = self.conv2(x4)
            x6 = self.conv3(x5)
            return x6

        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x).mean(), num_samples=x.shape[0])

    rfsearch_cfg = dict(
        mode='search',
        rfstructure_file=None,
        config=dict(
            search=dict(
                step=0,
                max_step=12,
                search_interval=1,
                exp_rate=0.5,
                init_alphas=0.01,
                mmin=1,
                mmax=24,
                num_branches=2,
                skip_layer=['stem', 'conv0', 'layer0.1'])),
    )

    # hook for search
    rfsearchhook_search = RFSearchHook(
        'search', rfsearch_cfg['config'], by_epoch=True, verbose=True)
    rfsearchhook_search.config['structure'] = {
        'module.layer0.0': [1, 1],
        'module.conv2': [2, 2],
        'module.conv3': [1, 1]
    }
    # hook for fixed_single_branch
    rfsearchhook_fixed_single_branch = RFSearchHook(
        'fixed_single_branch',
        rfsearch_cfg['config'],
        by_epoch=True,
        verbose=True)
    rfsearchhook_fixed_single_branch.config['structure'] = {
        'module.layer0.0': [1, 1],
        'module.conv2': [2, 2],
        'module.conv3': [1, 1]
    }
    # hook for fixed_multi_branch
    rfsearchhook_fixed_multi_branch = RFSearchHook(
        'fixed_multi_branch',
        rfsearch_cfg['config'],
        by_epoch=True,
        verbose=True)
    rfsearchhook_fixed_multi_branch.config['structure'] = {
        'module.layer0.0': [1, 1],
        'module.conv2': [2, 2],
        'module.conv3': [1, 1]
    }

    def test_skip_layer():
        assert not isinstance(model.stem, Conv2dRFSearchOp)
        assert not isinstance(model.conv0, Conv2dRFSearchOp)
        assert isinstance(model.layer0[0], Conv2dRFSearchOp)
        assert not isinstance(model.layer0[1], Conv2dRFSearchOp)

    # 1. test init_model() with mode of search
    model = Model()
    rfsearchhook_search.init_model(model)

    test_skip_layer()
    assert not isinstance(model.conv1, Conv2dRFSearchOp)
    assert isinstance(model.conv2, Conv2dRFSearchOp)
    assert isinstance(model.conv3, Conv2dRFSearchOp)
    assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
    assert model.conv3.dilation_rates == [(1, 1), (1, 2)]

    # 2. test init_model() with mode of fixed_single_branch
    model = Model()
    rfsearchhook_fixed_single_branch.init_model(model)

    assert not isinstance(model.conv1, Conv2dRFSearchOp)
    assert not isinstance(model.conv2, Conv2dRFSearchOp)
    assert not isinstance(model.conv3, Conv2dRFSearchOp)
    assert model.conv1.dilation == (1, 1)
    assert model.conv2.dilation == (2, 2)
    assert model.conv3.dilation == (1, 1)

    # 3. test init_model() with mode of fixed_multi_branch
    model = Model()
    rfsearchhook_fixed_multi_branch.init_model(model)

    test_skip_layer()
    assert not isinstance(model.conv1, Conv2dRFSearchOp)
    assert isinstance(model.conv2, Conv2dRFSearchOp)
    assert isinstance(model.conv3, Conv2dRFSearchOp)
    assert model.conv2.dilation_rates == [(1, 1), (3, 3)]
    assert model.conv3.dilation_rates == [(1, 1), (1, 2)]