macro.py 3.19 KB
Newer Older
1
2
import torch.nn as nn

3
from nni.nas.pytorch import mutables
4
5
6
from ops import FactorizedReduce, ConvBranch, PoolBranch


7
class ENASLayer(mutables.MutableScope):
    """A single searchable layer: one operation choice plus optional skips.

    The layer holds a ``LayerChoice`` over six candidate branches (four
    ``ConvBranch`` variants and two ``PoolBranch`` variants), an optional
    ``InputChoice`` over the outputs of earlier layers, and a non-affine
    batch normalisation applied to the final result.
    """

    def __init__(self, key, num_prev_layers, in_filters, out_filters):
        """Build the candidate operations and the skip-connection selector.

        Args:
            key: Identifier forwarded to ``mutables.MutableScope``.
            num_prev_layers: Number of earlier layers that may be used as
                skip-connection inputs; when it is not positive no
                ``InputChoice`` is created.
            in_filters: Forwarded to every candidate branch as its second
                positional argument (presumably input channels -- confirm
                against ``ops``).
            out_filters: Forwarded to every candidate branch and used as the
                feature count of the trailing ``BatchNorm2d``.
        """
        super().__init__(key)
        self.in_filters = in_filters
        self.out_filters = out_filters

        # Candidate order matters: it fixes the index of each choice.
        # The (3, 1, 1) / (5, 1, 2) triples are presumably
        # (kernel_size, stride, padding) -- confirm against ``ops.ConvBranch``.
        conv_candidates = [
            ConvBranch(in_filters, out_filters, size, 1, pad, separable=is_separable)
            for size, pad in ((3, 1), (5, 2))
            for is_separable in (False, True)
        ]
        pool_candidates = [
            PoolBranch(pool_kind, in_filters, out_filters, 3, 1, 1)
            for pool_kind in ('avg', 'max')
        ]
        self.mutable = mutables.LayerChoice(conv_candidates + pool_candidates)

        # The first layer has no predecessors, hence nothing to skip from.
        self.skipconnect = (
            mutables.InputChoice(num_prev_layers, n_selected=None, reduction="sum")
            if num_prev_layers > 0
            else None
        )
        self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)

    def forward(self, prev_layers, prev_labels):
        """Apply the chosen operation, add selected skips, then normalise.

        Args:
            prev_layers: Sequence of earlier outputs; the last element feeds
                the operation choice, the remaining ones are offered to the
                skip-connection choice.
            prev_labels: Passed to the ``InputChoice`` as ``tags``.

        Returns:
            The batch-normalised output of this layer.
        """
        result = self.mutable(prev_layers[-1])
        if self.skipconnect is None:
            return self.batch_norm(result)

        skipped = self.skipconnect(prev_layers[:-1], tags=prev_labels)
        # The selector may return None, in which case nothing is added.
        if skipped is not None:
            result += skipped
        return self.batch_norm(result)
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58


class GeneralNetwork(nn.Module):
    """Stack of ``ENASLayer`` modules with a stem and a linear classifier.

    A ``FactorizedReduce`` module is created for the layers at indices
    ``num_layers // 3 - 1`` and ``2 * (num_layers // 3) - 1``. After such a
    layer runs, every output collected so far is passed through that module.
    """

    def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10,
                 dropout_rate=0.0):
        """Create the stem, the searchable layers and the classifier head.

        Args:
            num_layers: Number of ``ENASLayer`` modules to stack.
            out_filters: Channel count used by the stem, every layer and the
                input of the final linear layer.
            in_channels: Channel count of the network input.
            num_classes: Size of the output logits.
            dropout_rate: Probability given to ``nn.Dropout`` before the
                final linear layer.
        """
        super().__init__()
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.out_filters = out_filters

        stem_conv = nn.Conv2d(in_channels, out_filters, 3, 1, 1, bias=False)
        stem_norm = nn.BatchNorm2d(out_filters)
        self.stem = nn.Sequential(stem_conv, stem_norm)

        # Two reduction points, one at the end of each of the first two thirds.
        pool_distance = self.num_layers // 3
        self.pool_layers_idx = [multiple * pool_distance - 1 for multiple in (1, 2)]
        self.dropout_rate = dropout_rate
        self.dropout = nn.Dropout(self.dropout_rate)

        self.layers = nn.ModuleList()
        self.pool_layers = nn.ModuleList()
        for layer_id in range(self.num_layers):
            if layer_id in self.pool_layers_idx:
                reducer = FactorizedReduce(self.out_filters, self.out_filters)
                self.pool_layers.append(reducer)
            # The layer index doubles as the number of available predecessors.
            searchable = ENASLayer("layer_{}".format(layer_id), layer_id,
                                   self.out_filters, self.out_filters)
            self.layers.append(searchable)

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.dense = nn.Linear(self.out_filters, self.num_classes)

    def forward(self, x):
        """Compute class logits for a batch of inputs.

        Args:
            x: Input batch; its first dimension is used as the batch size
                when flattening the pooled features.

        Returns:
            The output of the final linear layer.
        """
        batch_size = x.size(0)
        features = self.stem(x)

        # ``history`` holds the stem output followed by every layer output;
        # ``keys`` holds the key of each layer that has already run.
        history, keys = [features], []

        for layer_id in range(self.num_layers):
            current_layer = self.layers[layer_id]
            features = current_layer(history, keys)
            history.append(features)
            keys.append(current_layer.key)

            if layer_id in self.pool_layers_idx:
                # Apply the same reduction to every stored output, in place.
                reduction = self.pool_layers[self.pool_layers_idx.index(layer_id)]
                history[:] = [reduction(tensor) for tensor in history]
                features = history[-1]

        flat = self.gap(features).view(batch_size, -1)
        return self.dense(self.dropout(flat))