"comfy/vscode:/vscode.git/clone" did not exist on "9fccf4aa031a3e0698c03bd5424c78e56fba9f08"
model.py 6.88 KB
Newer Older
huchen's avatar
huchen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import torch
import torch.nn as nn
from .operations import *
from torch.autograd import Variable
from .utils import drop_path
from .genotypes import PRIMITIVES

from .protoc.genotype import protoc_pb2


class Cell(nn.Module):

    def __init__(self, genotype, C_prev_prev, C_prev, reduction_prev):
        """
        Basic building block of an architecture, takes the output of previous two cells as input
        Args:
            genotype(protoc_pb2.Cell): a protobuf object defining the cell structure
                it defines the following:
                * genotype.channel(int): the channel number of intermediate states
                * genotype.type(int): 0 - NORMAL (stride == 1) or 1 - REDUCE (stride == 2)
                * genotype.num_steps(int): the number of intermediate states
                * genotype.concat(list[int]): indices of selected states (including two input states) used for output.
                                              should be in `[0, num_steps + 2)`, where 0 and 1 stand for the two inputs,
                                              and 2, 3, ..., (num_steps + 1) stand for all intermediate states.
                                              The channel number of cell output will be `channel * len(concat)`.
                * genotype.op(list[protoc_pb2.Operation]): list of connections; each reads state `frm`
                                              and feeds intermediate state `to`, so `0 <= frm < to`
                                              and `2 <= to < num_steps + 2`.
                * genotype.auxiliary(bool): whether to attach an auxiliary classification tower after this cell.
                                            when `True`, current cell must be the second reduction cell in the network
            C_prev_prev(int): the output channel number of previous previous cell
            C_prev(int): the output channel number of previous cell
            reduction_prev(bool): `True` if previous cell is a reduction cell (stride == 2)
        Raises:
            ValueError: if the genotype is malformed: more than 6 intermediate states, a `concat`
                index outside `[0, num_steps + 2)`, an operation whose `to` is not an intermediate
                state or whose `frm` is not computed before `to`, or an intermediate state that is
                read or concatenated but has no incoming operation.
        """
        super(Cell, self).__init__()
        # Validate first so a malformed genotype fails before any submodule is built.
        self._check_genotype(genotype)

        C = genotype.channel
        self.reduction = genotype.type == 1
        if reduction_prev:
            # s0 has twice the spatial size of s1, so it must be downsampled to match.
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

        self._steps = genotype.num_steps
        self._concat = genotype.concat
        self.multiplier = len(genotype.concat)

        self._ops = nn.ModuleList()
        # _indices[k] lists (source state, op index) pairs feeding intermediate state k + 2.
        self._indices = [[] for _ in range(self._steps)]
        for iop, op_pb in enumerate(genotype.op):
            # In a reduction cell only the ops reading the two cell inputs downsample.
            stride = 2 if self.reduction and op_pb.frm < 2 else 1
            op = OPS[PRIMITIVES[op_pb.type]](C, stride, True)
            self._ops.append(op)
            self._indices[op_pb.to - 2].append((op_pb.frm, iop))

    @staticmethod
    def _check_genotype(genotype):
        """Raise ValueError if `genotype` describes a cell that cannot be built or evaluated."""
        num_steps = genotype.num_steps
        if num_steps > 6:
            raise ValueError('Number of intermediate states should not be greater than 6')
        # Two input states followed by the intermediate states.
        num_states = max(num_steps, 0) + 2

        used = set()
        for i in genotype.concat:
            if not 0 <= i < num_states:
                raise ValueError('concat index {} is out of range [0, {})'.format(i, num_states))
            used.add(i)

        fed = set()
        for iop, op_pb in enumerate(genotype.op):
            if not 2 <= op_pb.to < num_states:
                raise ValueError('operation {} targets state {}, expected a value in [2, {})'
                                 .format(iop, op_pb.to, num_states))
            if not 0 <= op_pb.frm < op_pb.to:
                raise ValueError('operation {} reads state {}, which is not computed before state {}'
                                 .format(iop, op_pb.frm, op_pb.to))
            fed.add(op_pb.to)
            used.add(op_pb.frm)

        # An intermediate state with no incoming op would evaluate to the integer 0 in forward(),
        # which breaks any op reading it and the final torch.cat.
        missing = sorted(s for s in used if s >= 2 and s not in fed)
        if missing:
            raise ValueError('intermediate states {} are used but have no incoming operation'
                             .format(missing))

    def forward(self, s0, s1, drop_prob):
        """
        Args:
            s0: output of the cell two steps back (prev_prev)
            s1: output of the previous cell (prev)
            drop_prob(float): drop-path probability applied, in training mode only, to the output
                of every connection whose op is not an `Identity`
        Returns:
            The states selected by `genotype.concat`, concatenated along dim 1.
        """
        s0 = self.preprocess0(s0)  # prev_prev
        s1 = self.preprocess1(s1)  # prev

        states = [s0, s1]
        for indices in self._indices:  # iter over intermediate states
            hs = []
            for frm, iop in indices:  # iter over connections
                op = self._ops[iop]
                h = op(states[frm])
                if self.training and drop_prob > 0.:
                    if not isinstance(op, Identity):
                        h = drop_path(h, drop_prob)
                hs.append(h)
            # connections towards the same intermediate state are summed together
            states.append(sum(hs))

        # selected states are concatenated together
        return torch.cat([states[i] for i in self._concat], dim=1)


class AuxiliaryHeadImageNet(nn.Module):

    def __init__(self, C, num_classes):
        """
        Auxiliary classification tower, intended to sit after the second reduction cell.

        The layer sizes only line up for a 7x7 spatial input (8x8 also works):
        AvgPool2d(5, stride=2) reduces it to 2x2 and the 2x2 convolution to 1x1, so exactly
        768 features reach the linear classifier. A 14x14 input would leave a 4x4 map
        (12288 features) and fail in the classifier.

        Args:
            C(int): channel number of the incoming feature map
            num_classes(int): number of output classes
        """
        super(AuxiliaryHeadImageNet, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        """Return auxiliary logits of shape (batch, num_classes) for a (batch, C, 7, 7) input."""
        x = self.features(x)
        return self.classifier(x.flatten(1))


class NetworkImageNet(nn.Module):

    def __init__(self, genotype, num_classes=1000):
        """
        ImageNet network assembled from a protobuf genotype: a convolutional stem with three
        stride-2 convolutions, the stack of cells listed in `genotype.cell`, 7x7 average pooling
        and a linear classifier.

        Args:
            genotype(protoc_pb2.Genotype): a protobuf object defining the architecture
                * genotype.init_channel(int): channel number produced by the stem
                * genotype.cell(list[protoc_pb2.Cell]): cells in network order
            num_classes(int): number of classes, 1000 as default for ImageNet
        Raises:
            ValueError: if there are more than 50 cells, if an auxiliary head is requested on a
                non-reduction cell, or if more than one auxiliary head is requested.
        """
        super(NetworkImageNet, self).__init__()
        C = genotype.init_channel

        self.stem0 = nn.Sequential(
            nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )
        self.stem1 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )

        self.cells = nn.ModuleList()
        self.auxiliary_head = None
        self.auxiliary_index = None

        if len(genotype.cell) > 50:
            raise ValueError('Number of cells should not be greater than 50.')

        # stem1 halves the resolution of stem0's output, so the first cell must treat its
        # s0 input as coming from before a reduction.
        reduction_prev = True
        C_prev_prev, C_prev = C, C
        for i, cell_pb in enumerate(genotype.cell):
            reduction = cell_pb.type == 1
            cell = Cell(cell_pb, C_prev_prev, C_prev, reduction_prev)
            self.cells.append(cell)
            reduction_prev = reduction
            C_prev_prev, C_prev = C_prev, cell.multiplier * cell_pb.channel
            if cell_pb.auxiliary:
                if not reduction:
                    raise ValueError('Auxiliary head should be attached to reduction cell.')
                if self.auxiliary_head is not None:
                    raise ValueError('Only one auxiliary head is allowed, got multiple.')
                self.auxiliary_head = AuxiliaryHeadImageNet(C_prev, num_classes)
                self.auxiliary_index = i

        # Fixed 7x7 window: assumes the last feature map is 7x7 (e.g. 224x224 images with two
        # reduction cells) — TODO confirm for other input sizes.
        self.global_pooling = nn.AvgPool2d(7)
        self.classifier = nn.Linear(C_prev, num_classes)
        self.drop_path_prob = 0.

    def forward(self, input):
        """
        Args:
            input: image batch fed to the stem
        Returns:
            (logits, logits_aux): `logits_aux` is the auxiliary head output in training mode when
            an auxiliary head exists, otherwise None.
        """
        logits_aux = None
        s0 = self.stem0(input)
        s1 = self.stem1(s0)
        for i, cell in enumerate(self.cells):
            s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
            # auxiliary_index is None when no head exists, so this never matches then.
            if self.training and i == self.auxiliary_index:
                logits_aux = self.auxiliary_head(s1)
        out = self.global_pooling(s1)
        logits = self.classifier(out.flatten(1))
        return logits, logits_aux