model.py 6.49 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
import torch.nn as nn
import torch.nn.functional as F

import ops
import numpy as np
from nni.nas.pytorch import mutables
from utils import parse_results
from aux_head import DistillHeadCIFAR, DistillHeadImagenet, AuxiliaryHeadCIFAR, AuxiliaryHeadImageNet


class Node(nn.Module):
    def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
        super().__init__()
        self.ops = nn.ModuleList()
        choice_keys = []
        for i in range(num_prev_nodes):
            stride = 2 if i < num_downsample_connect else 1
            choice_keys.append("{}_p{}".format(node_id, i))
            self.ops.append(mutables.LayerChoice([ops.OPS[k](channels, stride, False) for k in ops.PRIMITIVES],
                                                 key=choice_keys[-1]))
        self.drop_path = ops.DropPath()
        self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))

    def forward(self, prev_nodes):
        assert len(self.ops) == len(prev_nodes)
        out = [op(node) for op, node in zip(self.ops, prev_nodes)]
        out = [self.drop_path(o) if o is not None else None for o in out]
        return self.input_switch(out)


class Cell(nn.Module):

    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
        super().__init__()
        self.reduction = reduction
        self.n_nodes = n_nodes

        # If previous cell is reduction cell, current input size does not match with
        # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
        if reduction_p:
            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag
        self.mutable_ops = nn.ModuleList()
        for depth in range(2, self.n_nodes + 2):
            self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
                                         depth, channels, 2 if reduction else 0))

    def forward(self, s0, s1):
        # s0, s1 are the outputs of previous previous cell and previous cell, respectively.
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        for node in self.mutable_ops:
            cur_tensor = node(tensors)
            tensors.append(cur_tensor)

        output = torch.cat(tensors[2:], dim=1)
        return output


class Model(nn.Module):

    def __init__(self, dataset, n_layers, in_channels=3, channels=16, n_nodes=4, retrain=False, shared_modules=None):
        super().__init__()
        assert dataset in ["cifar10", "imagenet"]
        self.dataset = dataset
        self.input_size = 32 if dataset == "cifar" else 224
        self.in_channels = in_channels
        self.channels = channels
        self.n_nodes = n_nodes
        self.aux_size = {2 * n_layers // 3: self.input_size // 4}
        if dataset == "cifar10":
            self.n_classes = 10
            self.aux_head_class = AuxiliaryHeadCIFAR if retrain else DistillHeadCIFAR
            if not retrain:
                self.aux_size = {n_layers // 3: 6, 2 * n_layers // 3: 6}
        elif dataset == "imagenet":
            self.n_classes = 1000
            self.aux_head_class = AuxiliaryHeadImageNet if retrain else DistillHeadImagenet
            if not retrain:
                self.aux_size = {n_layers // 3: 6, 2 * n_layers // 3: 5}
        self.n_layers = n_layers
        self.aux_head = nn.ModuleDict()
        self.ensemble_param = nn.Parameter(torch.rand(len(self.aux_size) + 1) / (len(self.aux_size) + 1)) \
            if not retrain else None

        stem_multiplier = 3 if dataset == "cifar" else 1
        c_cur = stem_multiplier * self.channels
        self.shared_modules = {}  # do not wrap with ModuleDict
        if shared_modules is not None:
            self.stem = shared_modules["stem"]
        else:
            self.stem = nn.Sequential(
                nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
                nn.BatchNorm2d(c_cur)
            )
            self.shared_modules["stem"] = self.stem

        # for the first cell, stem is used for both s0 and s1
        # [!] channels_pp and channels_p is output channel size, but c_cur is input channel size.
        channels_pp, channels_p, c_cur = c_cur, c_cur, channels

        self.cells = nn.ModuleList()
        reduction_p, reduction = False, False
        aux_head_count = 0
        for i in range(n_layers):
            reduction_p, reduction = reduction, False
            if i in [n_layers // 3, 2 * n_layers // 3]:
                c_cur *= 2
                reduction = True

            cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
            self.cells.append(cell)
            c_cur_out = c_cur * n_nodes
            if i in self.aux_size:
                if shared_modules is not None:
                    self.aux_head[str(i)] = shared_modules["aux" + str(aux_head_count)]
                else:
                    self.aux_head[str(i)] = self.aux_head_class(c_cur_out, self.aux_size[i], self.n_classes)
                    self.shared_modules["aux" + str(aux_head_count)] = self.aux_head[str(i)]
                aux_head_count += 1
            channels_pp, channels_p = channels_p, c_cur_out

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(channels_p, self.n_classes)

    def forward(self, x):
        s0 = s1 = self.stem(x)
        outputs = []

        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1)
            if str(i) in self.aux_head:
                outputs.append(self.aux_head[str(i)](s1))

        out = self.gap(s1)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)
        outputs.append(logits)

        if self.ensemble_param is None:
            assert len(outputs) == 2
            return outputs[1], outputs[0]
        else:
            em_output = torch.cat([(e * o) for e, o in zip(F.softmax(self.ensemble_param, dim=0), outputs)], 0)
            return logits, em_output

    def drop_path_prob(self, p):
        for module in self.modules():
            if isinstance(module, ops.DropPath):
                module.p = p

    def plot_genotype(self, results, logger):
        genotypes = parse_results(results, self.n_nodes)
        logger.info(genotypes)
        return genotypes