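"""MobileNetV2-style search space (ProxylessNAS) for NNI Retiarii.

``SearchMobileNet`` is a supernet in which every cell selects its operator
(3x3/5x5/7x7 MBConv with expansion ratio 3 or 6, plus an optional Zero/skip
op) through a Retiarii ``LayerChoice``.
"""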
import math

import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii.nn.pytorch import LayerChoice

import ops
import putils

class SearchMobileNet(nn.Module):
    def __init__(self,
                 width_stages=[24,40,80,96,192,320],
                 n_cell_stages=[4,4,4,4,4,1],
                 stride_stages=[2,2,2,1,2,1],
                 width_mult=1, n_classes=1000,
                 dropout_rate=0, bn_param=(0.1, 1e-3)):
        """
        Parameters
        ----------
        width_stages: str
            width (output channels) of each cell stage in the block
        n_cell_stages: str
            number of cells in each cell stage
        stride_strages: str
            stride of each cell stage in the block
        width_mult : int
            the scale factor of width
        """
        super(SearchMobileNet, self).__init__()

        input_channel = putils.make_divisible(32 * width_mult, 8)
        first_cell_width = putils.make_divisible(16 * width_mult, 8)
        # rebind rather than mutate in place, so the shared mutable default
        # list is never modified across instantiations
        width_stages = [putils.make_divisible(w * width_mult, 8) for w in width_stages]
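        # e.g., assuming the usual MobileNet make_divisible rounding (nearest
        # multiple of the divisor, never below 90% of the input), width_mult=1.0
        # leaves [24, 40, 80, 96, 192, 320] unchanged and width_mult=0.5 yields
        # [16, 24, 40, 48, 96, 160]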
        # first conv
        first_conv = ops.ConvLayer(3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='relu6', ops_order='weight_bn_act')
        # first block
        first_block = ops.OPS['3x3_MBConv1'](input_channel, first_cell_width, 1)

        input_channel = first_cell_width

        blocks = [first_block]

        stage_cnt = 0
        for width, n_cell, s in zip(width_stages, n_cell_stages, stride_stages):
            for i in range(n_cell):
                stride = s if i == 0 else 1
                op_candidates = [ops.OPS['3x3_MBConv3'](input_channel, width, stride),
                                 ops.OPS['3x3_MBConv6'](input_channel, width, stride),
                                 ops.OPS['5x5_MBConv3'](input_channel, width, stride),
                                 ops.OPS['5x5_MBConv6'](input_channel, width, stride),
                                 ops.OPS['7x7_MBConv3'](input_channel, width, stride),
                                 ops.OPS['7x7_MBConv6'](input_channel, width, stride)]
                if stride == 1 and input_channel == width:
                    # cells after the first in a stage may also pick the Zero
                    # op, i.e. skip themselves entirely
                    op_candidates.append(ops.OPS['Zero'](input_channel, width, stride))
                conv_op = LayerChoice(op_candidates, label="s{}_c{}".format(stage_cnt, i))
                # shortcut
                if stride == 1 and input_channel == width:
                    # identity shortcut only when the cell preserves shape
                    # (i.e. it is not the first cell of a stage)
                    shortcut = ops.IdentityLayer(input_channel, input_channel)
                else:
                    shortcut = None
                inverted_residual_block = ops.MobileInvertedResidualBlock(conv_op, shortcut, op_candidates)
                blocks.append(inverted_residual_block)
                input_channel = width
            stage_cnt += 1

        # feature mix layer
        last_channel = putils.make_divisible(1280 * width_mult, 8) if width_mult > 1.0 else 1280
        feature_mix_layer = ops.ConvLayer(input_channel, last_channel, kernel_size=1, use_bn=True, act_func='relu6', ops_order='weight_bn_act')
        classifier = ops.LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        self.first_conv = first_conv
        self.blocks = nn.ModuleList(blocks)
        self.feature_mix_layer = feature_mix_layer
        self.global_avg_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = classifier

        # set bn param
        self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])

    def forward(self, x):
        x = self.first_conv(x)
        for block in self.blocks:
            x = block(x)
        x = self.feature_mix_layer(x)
        x = self.global_avg_pooling(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def set_bn_param(self, momentum, eps):
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.momentum = momentum
                m.eps = eps

    def init_model(self, model_init='he_fout', init_div_groups=False):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if model_init == 'he_fout':
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    if init_div_groups:
                        n /= m.groups
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif model_init == 'he_fin':
                    n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
                    if init_div_groups:
                        n /= m.groups
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                else:
                    raise NotImplementedError('unsupported model_init: {}'.format(model_init))
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                stdv = 1. / math.sqrt(m.weight.size(1))
                m.weight.data.uniform_(-stdv, stdv)
                if m.bias is not None:
                    m.bias.data.zero_()
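

# A minimal smoke test, assuming the `ops` and `putils` helper modules from
# this example directory are importable and NNI Retiarii is installed. Calling
# forward() on a raw search space like this only exercises a default candidate
# of each LayerChoice; an actual search would hand the model to a NAS strategy.
if __name__ == '__main__':
    net = SearchMobileNet(n_classes=10)
    net.init_model(model_init='he_fout')
    dummy = torch.zeros(2, 3, 224, 224)  # NCHW batch of ImageNet-sized inputs
    logits = net(dummy)
    print(logits.shape)  # expected: torch.Size([2, 10])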