import torch
import torch.nn as nn

import ops  # local module defining candidate operations (PoolBN, SepConv, DilConv, ...)
from nni.nas.pytorch import mutables


class AuxiliaryHead(nn.Module):
    """ Auxiliary classifier attached at 2/3 of the network depth to improve gradient flow """

    def __init__(self, input_size, C, n_classes):
        """ assumes the input feature map is 7x7 or 8x8 """
        assert input_size in [7, 8]
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False),  # 2x2 out
            nn.Conv2d(C, 128, kernel_size=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, kernel_size=2, bias=False),  # 1x1 out
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.linear = nn.Linear(768, n_classes)

    def forward(self, x):
        out = self.net(x)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)
        return logits
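
# Shape walk-through for AuxiliaryHead (a sketch, assuming an 8x8 input):
#   AvgPool2d(5, stride=3)       -> C   x 2 x 2
#   Conv2d(C, 128, 1)            -> 128 x 2 x 2
#   Conv2d(128, 768, 2)          -> 768 x 1 x 1
#   flatten + Linear(768, ...)   -> n_classes logits
# For a 7x7 input the pooling stride is 2 and the shapes come out the same.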


class Node(nn.Module):
    def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
        """ A DAG node: one candidate-op choice per predecessor, plus an input switch that keeps 2 of them """
        super().__init__()
        self.ops = nn.ModuleList()
        choice_keys = []
        for i in range(num_prev_nodes):
            stride = 2 if i < num_downsample_connect else 1
            choice_keys.append("{}_p{}".format(node_id, i))
            self.ops.append(
                mutables.LayerChoice(
                    [
                        ops.PoolBN('max', channels, 3, stride, 1, affine=False),  # 3x3 max pool
                        ops.PoolBN('avg', channels, 3, stride, 1, affine=False),  # 3x3 avg pool
                        nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False),  # skip connect
                        ops.SepConv(channels, channels, 3, stride, 1, affine=False),  # 3x3 separable conv
                        ops.SepConv(channels, channels, 5, stride, 2, affine=False),  # 5x5 separable conv
                        ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),  # 3x3 dilated conv
                        ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)  # 5x5 dilated conv
                    ],
                    key=choice_keys[-1]))
        self.drop_path = ops.DropPath_()
        self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))

    def forward(self, prev_nodes):
        assert len(self.ops) == len(prev_nodes)
        out = [op(node) for op, node in zip(self.ops, prev_nodes)]
        out = [self.drop_path(o) if o is not None else None for o in out]
        return self.input_switch(out)
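
# Per-node search space (a sketch of what the mutables above expose): node i
# of a cell sees i + 2 predecessors (the two cell inputs plus every earlier
# node), picks one of the 7 candidate ops per predecessor, and the InputChoice
# keeps exactly 2 of the resulting edges, combined by its default reduction.
# E.g. the node keyed "normal_n2" exposes the op choices "normal_n2_p0" and
# "normal_n2_p1" plus the switch "normal_n2_switch".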


class Cell(nn.Module):

    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
        """
        n_nodes: number of intermediate nodes inside the cell
        channels_pp / channels_p: output channels of cell[k-2] / cell[k-1]
        channels: channels of each node in this cell
        reduction_p / reduction: whether cell[k-1] / this cell is a reduction cell
        """
        super().__init__()
        self.reduction = reduction
        self.n_nodes = n_nodes

        # If the previous cell is a reduction cell, the current input size does not match
        # the output size of cell[k-2], so output[k-2] must be downsampled in preprocessing.
        if reduction_p:
            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag
        self.mutable_ops = nn.ModuleList()
        for depth in range(2, self.n_nodes + 2):
            self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
                                         depth, channels, 2 if reduction else 0))

    def forward(self, s0, s1):
        # s0 and s1 are the outputs of cell[k-2] and cell[k-1], respectively.
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        for node in self.mutable_ops:
            cur_tensor = node(tensors)
            tensors.append(cur_tensor)

        output = torch.cat(tensors[2:], dim=1)
        return output
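
# Channel bookkeeping (a sketch, assuming channels == 16 and n_nodes == 4):
# every node emits a 16-channel map and the cell concatenates all node
# outputs, so the cell emits 4 * 16 = 64 channels. CNN below relies on this
# via c_cur_out = c_cur * n_nodes.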


class CNN(nn.Module):

    def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
                 stem_multiplier=3, auxiliary=False):
        super().__init__()
        self.in_channels = in_channels
        self.channels = channels
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.aux_pos = 2 * n_layers // 3 if auxiliary else -1

        c_cur = stem_multiplier * self.channels
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c_cur)
        )

        # for the first cell, the stem output is used as both s0 and s1
        # [!] channels_pp and channels_p are output channel sizes, while c_cur is the input channel size.
        channels_pp, channels_p, c_cur = c_cur, c_cur, channels

        self.cells = nn.ModuleList()
        reduction_p, reduction = False, False
        for i in range(n_layers):
            reduction_p, reduction = reduction, False
            # Reduce feature map size and double channels at 1/3 and 2/3 of the depth.
            if i in [n_layers // 3, 2 * n_layers // 3]:
                c_cur *= 2
                reduction = True

            cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
            self.cells.append(cell)
            c_cur_out = c_cur * n_nodes
            channels_pp, channels_p = channels_p, c_cur_out

            if i == self.aux_pos:
                self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes)

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(channels_p, n_classes)

    def forward(self, x):
        s0 = s1 = self.stem(x)

        aux_logits = None
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1)
            if i == self.aux_pos and self.training:
                aux_logits = self.aux_head(s1)

        out = self.gap(s1)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)

        if aux_logits is not None:
            return logits, aux_logits
        return logits

    def drop_path_prob(self, p):
        """ Set drop-path probability p on every DropPath_ module in the network """
        for module in self.modules():
            if isinstance(module, ops.DropPath_):
                module.p = p
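
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). Assumes a CIFAR-10
# style setup: 32x32 RGB inputs, 10 classes, 8 layers, 16 initial channels.
# Note that the LayerChoice/InputChoice mutables must be bound to a mutator
# (e.g. by DartsMutator / DartsTrainer from the legacy NNI NAS API) before
# forward() can resolve the candidate ops.
if __name__ == "__main__":
    model = CNN(input_size=32, in_channels=3, channels=16, n_classes=10,
                n_layers=8, n_nodes=4, stem_multiplier=3, auxiliary=True)
    print("parameter count:", sum(p.numel() for p in model.parameters()))

    # Hypothetical search-phase wiring; the import path may differ across NNI
    # versions, so this part is left as a commented sketch:
    # from nni.nas.pytorch.darts import DartsMutator
    # mutator = DartsMutator(model)
    # logits, aux_logits = model(torch.randn(2, 3, 32, 32))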