import torch
import torch.nn as nn

import ops
from nni.nas import pytorch as nas


class SearchCell(nn.Module):
    """
    Cell for search.

    A cell is a DAG with ``n_nodes`` intermediate nodes. Node ``d`` receives
    the two cell inputs plus the outputs of all earlier nodes, each routed
    through a ``LayerChoice`` over the candidate operations; its output is the
    sum of the chosen ops' outputs. The cell output concatenates all
    intermediate node outputs along the channel dimension.
    """

    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
        """
        Initialization a search cell.

        Parameters
        ----------
        n_nodes: int
            Number of nodes in current DAG.
        channels_pp: int
            Number of output channels from previous previous cell.
        channels_p: int
            Number of output channels from previous cell.
        channels: int
            Number of channels that will be used in the current DAG.
        reduction_p: bool
            Flag for whether the previous cell is reduction cell or not.
        reduction: bool
            Flag for whether the current cell is reduction cell or not.
        """
        super().__init__()
        self.reduction = reduction
        self.n_nodes = n_nodes

        # If previous cell is reduction cell, current input size does not match with
        # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
        if reduction_p:
            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag: one LayerChoice per (node, predecessor) edge
        self.mutable_ops = nn.ModuleList()
        for depth in range(self.n_nodes):
            self.mutable_ops.append(nn.ModuleList())
            for i in range(2 + depth):  # include 2 input nodes
                # reduction should be used only for input node
                stride = 2 if reduction and i < 2 else 1
                op = nas.mutables.LayerChoice([ops.PoolBN('max', channels, 3, stride, 1, affine=False),
                                               ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
                                               ops.Identity() if stride == 1 else
                                               ops.FactorizedReduce(channels, channels, affine=False),
                                               ops.SepConv(channels, channels, 3, stride, 1, affine=False),
                                               ops.SepConv(channels, channels, 5, stride, 2, affine=False),
                                               ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
                                               ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False),
                                               ops.Zero(stride)],
                                              key="r{}_d{}_i{}".format(reduction, depth, i))
                self.mutable_ops[depth].append(op)

    def forward(self, s0, s1):
        """
        Run the cell.

        Parameters
        ----------
        s0 : torch.Tensor
            Output of the previous previous cell.
        s1 : torch.Tensor
            Output of the previous cell.

        Returns
        -------
        torch.Tensor
            Concatenation (along dim 1) of all intermediate node outputs.
        """
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        # NOTE: the loop variable is named `node_ops` (not `ops`) so it does
        # not shadow the module-level `import ops`.
        for node_ops in self.mutable_ops:
            assert len(node_ops) == len(tensors)
            # Each node sums the outputs of one chosen op per predecessor.
            cur_tensor = sum(op(tensor) for op, tensor in zip(node_ops, tensors))
            tensors.append(cur_tensor)

        output = torch.cat(tensors[2:], dim=1)
        return output


class SearchCNN(nn.Module):
    """
    CNN used during architecture search.

    A stem convolution feeds a stack of ``SearchCell`` modules; the cells at
    1/3 and 2/3 of the depth are reduction cells (downsample and double the
    channel count). Global average pooling plus a linear layer produce the
    class logits.
    """

    def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3):
        """
        Build the search CNN.

        Parameters
        ----------
        in_channels: int
            Number of channels in images.
        channels: int
            Number of channels used in the network.
        n_classes: int
            Number of classes.
        n_layers: int
            Number of cells in the whole network.
        n_nodes: int
            Number of nodes in a cell.
        stem_multiplier: int
            Multiplier of channels in STEM.
        """
        super().__init__()
        self.in_channels = in_channels
        self.channels = channels
        self.n_classes = n_classes
        self.n_layers = n_layers

        stem_channels = stem_multiplier * self.channels
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, stem_channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(stem_channels),
        )

        # The stem output serves as both cell inputs for the first cell.
        # channels_pp / channels_p track the *output* widths of the two
        # preceding cells; cur_channels is the working width inside a cell.
        channels_pp, channels_p, cur_channels = stem_channels, stem_channels, channels

        # Downsample (and double channels) at 1/3 and 2/3 of the depth.
        reduction_layers = {n_layers // 3, 2 * n_layers // 3}

        self.cells = nn.ModuleList()
        prev_reduction = reduction = False
        for layer_idx in range(n_layers):
            prev_reduction, reduction = reduction, layer_idx in reduction_layers
            if reduction:
                cur_channels *= 2

            cell = SearchCell(n_nodes, channels_pp, channels_p, cur_channels,
                              prev_reduction, reduction)
            self.cells.append(cell)
            # A cell concatenates its n_nodes intermediate outputs.
            channels_pp, channels_p = channels_p, cur_channels * n_nodes

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(channels_p, n_classes)

    def forward(self, x):
        """
        Compute class logits for a batch of images.

        Parameters
        ----------
        x : torch.Tensor
            Input image batch.

        Returns
        -------
        torch.Tensor
            Logits of shape ``(batch, n_classes)``.
        """
        prev_prev = prev = self.stem(x)

        for cell in self.cells:
            prev_prev, prev = prev, cell(prev_prev, prev)

        pooled = self.gap(prev)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)