model.py
import torch
import nni.retiarii.nn.pytorch as nn
import math

import ops
import putils
from nni.retiarii.nn.pytorch import LayerChoice

class SearchMobileNet(nn.Module):
    def __init__(self,
                 width_stages=[24, 40, 80, 96, 192, 320],
                 n_cell_stages=[4, 4, 4, 4, 4, 1],
                 stride_stages=[2, 2, 2, 1, 2, 1],
                 width_mult=1, n_classes=1000,
                 dropout_rate=0, bn_param=(0.1, 1e-3)):
        """
        Parameters
        ----------
        width_stages: str
            width (output channels) of each cell stage in the block
        n_cell_stages: str
            number of cells in each cell stage
        stride_strages: str
            stride of each cell stage in the block
        width_mult : int
            the scale factor of width
        """
        super(SearchMobileNet, self).__init__()

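        # putils.make_divisible (a repository helper) is assumed to round a channel
        # count to the nearest multiple of 8, the usual MobileNet trick for keeping
        # widths hardware friendly.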
        input_channel = putils.make_divisible(32 * width_mult, 8)
        first_cell_width = putils.make_divisible(16 * width_mult, 8)
        # scale stage widths; build a new list instead of mutating the default argument in place
        width_stages = [putils.make_divisible(w * width_mult, 8) for w in width_stages]
        # first conv
        first_conv = ops.ConvLayer(3, input_channel, kernel_size=3, stride=2, use_bn=True, act_func='relu6', ops_order='weight_bn_act')
        # first block
        first_block_conv = ops.OPS['3x3_MBConv1'](input_channel, first_cell_width, 1)
        first_block = first_block_conv

        input_channel = first_cell_width

        blocks = [first_block]

        stage_cnt = 0
        for width, n_cell, s in zip(width_stages, n_cell_stages, stride_stages):
            for i in range(n_cell):
                if i == 0:
                    stride = s
                else:
                    stride = 1
                op_candidates = [ops.OPS['3x3_MBConv3'](input_channel, width, stride),
                                 ops.OPS['3x3_MBConv6'](input_channel, width, stride),
                                 ops.OPS['5x5_MBConv3'](input_channel, width, stride),
                                 ops.OPS['5x5_MBConv6'](input_channel, width, stride),
                                 ops.OPS['7x7_MBConv3'](input_channel, width, stride),
                                 ops.OPS['7x7_MBConv6'](input_channel, width, stride)]
                if stride == 1 and input_channel == width:
                    # not the first cell of the stage: shapes match, so a Zero op
                    # lets the search skip this block entirely
                    op_candidates += [ops.OPS['Zero'](input_channel, width, stride)]
                conv_op = LayerChoice(op_candidates, label="s{}_c{}".format(stage_cnt, i))
                # shortcut
                if stride == 1 and input_channel == width:
                    # identity shortcut only when input and output shapes match
                    shortcut = ops.IdentityLayer(input_channel, input_channel)
                else:
                    shortcut = None
                inverted_residual_block = ops.MobileInvertedResidualBlock(conv_op, shortcut, op_candidates)
                blocks.append(inverted_residual_block)
                input_channel = width
            stage_cnt += 1

        # feature mix layer
        last_channel = putils.make_divisible(1280 * width_mult, 8) if width_mult > 1.0 else 1280
        feature_mix_layer = ops.ConvLayer(input_channel, last_channel, kernel_size=1, use_bn=True, act_func='relu6', ops_order='weight_bn_act')
        classifier = ops.LinearLayer(last_channel, n_classes, dropout_rate=dropout_rate)

        self.first_conv = first_conv
        self.blocks = nn.ModuleList(blocks)
        self.feature_mix_layer = feature_mix_layer
        self.global_avg_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = classifier

        # set bn param
        self.set_bn_param(momentum=bn_param[0], eps=bn_param[1])

    def forward(self, x):
        x = self.first_conv(x)
        for block in self.blocks:
            x = block(x)
        x = self.feature_mix_layer(x)
        x = self.global_avg_pooling(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def set_bn_param(self, momentum, eps):
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.momentum = momentum
                m.eps = eps

    def init_model(self, model_init='he_fout', init_div_groups=False):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if model_init == 'he_fout':
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    if init_div_groups:
                        n /= m.groups
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif model_init == 'he_fin':
                    n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
                    if init_div_groups:
                        n /= m.groups
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                else:
                    raise NotImplementedError
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                stdv = 1. / math.sqrt(m.weight.size(1))
                m.weight.data.uniform_(-stdv, stdv)
                if m.bias is not None:
                    m.bias.data.zero_()
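

# --- minimal usage sketch (added; not part of the original file) ---
# Assumes the sibling 'ops' and 'putils' modules from this repository are importable
# and an ImageNet-style 224x224 input. SearchMobileNet defines an NNI Retiarii search
# space; calling forward on it directly is only a smoke test, since each LayerChoice
# typically falls back to its first candidate outside of an exploration run.
if __name__ == '__main__':
    model = SearchMobileNet(n_classes=1000)
    model.init_model(model_init='he_fout')
    dummy_input = torch.randn(1, 3, 224, 224)
    logits = model(dummy_input)
    print(logits.shape)  # expected: torch.Size([1, 1000])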