# macro.py
1
2
import torch.nn as nn

3
from nni.nas.pytorch import mutables
4
5
6
from ops import FactorizedReduce, ConvBranch, PoolBranch


7
class ENASLayer(mutables.MutableScope):
    """A single searchable layer: one operation choice plus optional skip inputs.

    The layer applies a ``LayerChoice`` over six candidate branches (3x3 and
    5x5 convolutions, each plain or separable, plus average and max pooling)
    to the most recent feature map. When earlier layers exist, an
    ``InputChoice`` over their labels can contribute a summed skip connection.
    The result is passed through a non-affine batch normalization.
    """

    def __init__(self, key, prev_labels, in_filters, out_filters):
        """Build the candidate operations and the optional skip-connection choice.

        Args:
            key: Key forwarded to ``mutables.MutableScope``.
            prev_labels: Labels of the preceding layers, used as the
                ``choose_from`` candidates of the skip connection. When empty,
                no skip-connection choice is created.
            in_filters: Input channel count handed to every candidate branch.
            out_filters: Output channel count handed to every candidate branch
                and to the final batch normalization.
        """
        super().__init__(key)
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.mutable = mutables.LayerChoice([
            ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False),
            ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True),
            ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False),
            ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True),
            PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
            PoolBranch('max', in_filters, out_filters, 3, 1, 1)
        ])
        if prev_labels:
            self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None, reduction="sum")
        else:
            # The first layer has no predecessors to connect to.
            self.skipconnect = None
        self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)

    def forward(self, prev_layers):
        """Apply the chosen operation to the latest feature map and add skips.

        Args:
            prev_layers: Sequence of feature maps produced so far. The last
                entry feeds the operation choice; all entries before it are
                offered to the skip-connection choice.

        Returns:
            The batch-normalized output of this layer.
        """
        out = self.mutable(prev_layers[-1])
        if self.skipconnect is not None:
            connection = self.skipconnect(prev_layers[:-1])
            # The input choice may yield None (presumably when no input is
            # selected -- confirm against the mutator in use).
            if connection is not None:
                # Out-of-place add: leaves the tensor returned by the layer
                # choice untouched instead of modifying it in place.
                out = out + connection
        return self.batch_norm(out)
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55


class GeneralNetwork(nn.Module):
    """Macro search-space network: a conv stem, stacked ``ENASLayer`` blocks,
    two reduction points, global average pooling and a linear classifier.
    """

    def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10,
                 dropout_rate=0.0):
        """Create the stem, the searchable layers and the classifier head.

        Args:
            num_layers: Number of ``ENASLayer`` blocks to stack.
            out_filters: Channel width used by the stem and every layer.
            in_channels: Channel count of the input images.
            num_classes: Size of the output logits.
            dropout_rate: Dropout probability applied before the classifier.
        """
        super().__init__()
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.out_filters = out_filters

        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, out_filters, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_filters)
        )

        # Reductions happen after the last layer of the first and second
        # thirds of the stack.
        pool_distance = self.num_layers // 3
        self.pool_layers_idx = [pool_distance - 1, 2 * pool_distance - 1]
        self.dropout_rate = dropout_rate
        self.dropout = nn.Dropout(self.dropout_rate)

        self.layers = nn.ModuleList()
        self.pool_layers = nn.ModuleList()
        labels = []
        for layer_id in range(self.num_layers):
            labels.append("layer_{}".format(layer_id))
            if layer_id in self.pool_layers_idx:
                # Appended in ascending layer order, so position k in
                # ``pool_layers`` belongs to ``pool_layers_idx[k]``.
                self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters))
            # Each layer may skip-connect to any of the layers before it.
            self.layers.append(ENASLayer(labels[-1], labels[:-1], self.out_filters, self.out_filters))

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.dense = nn.Linear(self.out_filters, self.num_classes)

    def forward(self, x):
        """Compute class logits for a batch of inputs.

        Args:
            x: Input batch; its first dimension is the batch size.

        Returns:
            The output of the final linear layer, one row per batch element.
        """
        bs = x.size(0)
        cur = self.stem(x)

        # Every feature map produced so far; layers read from it for their
        # main input (last entry) and skip connections (earlier entries).
        layers = [cur]

        for layer_id in range(self.num_layers):
            cur = self.layers[layer_id](layers)
            layers.append(cur)
            if layer_id in self.pool_layers_idx:
                # Look the reduction module up once, then apply it to every
                # stored feature map so later skip connections see the
                # reduced versions.
                pool = self.pool_layers[self.pool_layers_idx.index(layer_id)]
                layers = [pool(layer) for layer in layers]
                cur = layers[-1]

        cur = self.gap(cur).view(bs, -1)
        cur = self.dropout(cur)
        logits = self.dense(cur)
        return logits