import torch
import torch.nn as nn
from .operations import *
from torch.autograd import Variable
from .utils import drop_path
from .genotypes import PRIMITIVES
from .protoc.genotype import protoc_pb2


class Cell(nn.Module):

    def __init__(self, genotype, C_prev_prev, C_prev, reduction_prev):
        """
        Basic building block of an architecture; takes the outputs of the
        previous two cells as input.

        Args:
            genotype(protoc_pb2.Cell): a protobuf object defining the cell structure.
                It defines the following:
                * genotype.channel(int): the channel number of intermediate states
                * genotype.type(int): 0 - NORMAL (stride == 1) or 1 - REDUCE (stride == 2)
                * genotype.num_steps(int): the number of intermediate states
                * genotype.concat(list[int]): indices of selected states (including the
                  two input states) used for the output. Should be in
                  `[0, num_steps + 2)`, where 0 and 1 stand for the two inputs, and
                  2, 3, ..., (num_steps + 1) stand for all intermediate states. The
                  channel number of the cell output will be `channel * len(concat)`.
                * genotype.op(list[protoc_pb2.Operation]): list of connections
                * genotype.auxiliary(bool): whether to attach an auxiliary
                  classification tower after this cell. When `True`, the current cell
                  must be the second reduction cell in the network.
            C_prev_prev(int): the output channel number of the cell before last
            C_prev(int): the output channel number of the previous cell
            reduction_prev(bool): `True` if the previous cell is a reduction cell
                (stride == 2)

        Raises:
            ValueError: if `genotype.num_steps` exceeds 6.
        """
        super(Cell, self).__init__()
        C = genotype.channel
        self.reduction = genotype.type == 1

        # If the previous cell halved the spatial size, input 0 (from two cells
        # back) must be downsampled to match; otherwise a 1x1 conv only adjusts
        # its channel count.
        if reduction_prev:
            self.preprocess0 = FactorizedReduce(C_prev_prev, C)
        else:
            self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
        self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)

        if genotype.num_steps > 6:
            # ValueError (was bare Exception): invalid genotype argument.
            raise ValueError('Number of intermediate states should not be greater than 6')
        self._steps = genotype.num_steps
        self._concat = genotype.concat
        self.multiplier = len(genotype.concat)

        self._ops = nn.ModuleList()
        # _indices[i] collects (source_state, op_index) pairs feeding
        # intermediate state i + 2 (states 0 and 1 are the cell inputs).
        self._indices = [[] for _ in range(self._steps)]
        for iop, op_pb in enumerate(genotype.op):
            # For each defined operation, put it in a bucket indexed by its
            # destination. In a reduction cell only ops reading the two cell
            # inputs (frm < 2) use stride 2; intermediate states are already
            # at the reduced resolution.
            stride = 2 if self.reduction and op_pb.frm < 2 else 1
            op = OPS[PRIMITIVES[op_pb.type]](C, stride, True)
            self._ops += [op]
            self._indices[op_pb.to - 2].append((op_pb.frm, iop))

    def forward(self, s0, s1, drop_prob):
        """Compute the cell output.

        Args:
            s0: output of the cell before last (prev_prev)
            s1: output of the previous cell (prev)
            drop_prob(float): drop-path probability applied during training to
                every non-Identity operation

        Returns:
            Channel-wise concatenation of the states selected by
            `genotype.concat`.
        """
        s0 = self.preprocess0(s0)  # prev_prev
        s1 = self.preprocess1(s1)  # prev

        states = [s0, s1]
        for indices in self._indices:  # iterate over intermediate states
            hs = []
            for frm, iop in indices:  # iterate over incoming connections
                h = states[frm]
                op = self._ops[iop]
                h = op(h)
                # Identity (skip) edges are never dropped: drop-path
                # regularizes transformed paths only.
                if self.training and drop_prob > 0.:
                    if not isinstance(op, Identity):
                        h = drop_path(h, drop_prob)
                hs += [h]
            # Connections towards an intermediate state are summed together.
            s = sum(hs)
            states += [s]
        # Selected states are concatenated together along the channel axis.
        return torch.cat([states[i] for i in self._concat], dim=1)


class AuxiliaryHeadImageNet(nn.Module):

    def __init__(self, C, num_classes):
        """Auxiliary classification tower.

        NOTE(review): the original docstring said "assuming input size 14x14",
        but the layer shapes only line up for a 7x7 input: AvgPool2d(5, stride=2)
        maps 7x7 -> 2x2, the 2x2 conv maps 2x2 -> 1x1, so the flattened feature
        vector matches Linear(768). A 14x14 input would flatten to 768*16
        features and fail in the classifier.

        Args:
            C(int): channel number of the input feature map
            num_classes(int): number of output classes
        """
        super(AuxiliaryHeadImageNet, self).__init__()
        self.features = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
            nn.Conv2d(C, 128, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 768, 2, bias=False),
            nn.BatchNorm2d(768),
            nn.ReLU(inplace=True)
        )
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, x):
        """Return classification logits for feature map `x`."""
        x = self.features(x)
        x = self.classifier(x.view(x.size(0), -1))
        return x


class NetworkImageNet(nn.Module):

    def __init__(self, genotype, num_classes=1000):
        """
        Args:
            genotype(protoc_pb2.Genotype): a protobuf object defining the architecture
            num_classes(int): number of classes, 1000 as default for ImageNet

        Raises:
            ValueError: if the genotype has more than 50 cells, if an auxiliary
                head is requested on a non-reduction cell, or if more than one
                auxiliary head is requested.
        """
        super(NetworkImageNet, self).__init__()
        C = genotype.init_channel

        # Two stems with three stride-2 convs overall: the input is
        # downsampled 8x before the first cell.
        self.stem0 = nn.Sequential(
            nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )
        self.stem1 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(C),
        )

        C_prev_prev, C_prev, C_curr = C, C, C
        self.cells = nn.ModuleList()
        self.auxiliary_head = None
        self.auxiliary_index = None
        # stem1 halves the spatial size, so the first cell treats its
        # prev-prev input as coming from a reduction.
        reduction_prev = True

        if len(genotype.cell) > 50:
            # ValueError (was bare Exception): invalid genotype argument.
            raise ValueError('Number of cells should not be greater than 50.')
        for i, cell_pb in enumerate(genotype.cell):
            C_curr = cell_pb.channel
            reduction = cell_pb.type == 1
            cell = Cell(cell_pb, C_prev_prev, C_prev, reduction_prev)
            reduction_prev = reduction
            self.cells += [cell]
            # Each cell outputs `multiplier` concatenated states of C_curr
            # channels each.
            C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
            if cell_pb.auxiliary:
                if cell_pb.type != 1:
                    raise ValueError('Auxiliary head should be attached to reduction cell.')
                if self.auxiliary_head is not None:
                    raise ValueError('Only one auxiliary head is allowed, got multiple.')
                C_to_auxiliary = C_prev
                self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
                self.auxiliary_index = i

        self.global_pooling = nn.AvgPool2d(7)
        self.classifier = nn.Linear(C_prev, num_classes)
        self.drop_path_prob = 0.
def forward(self, input): logits_aux = None s0 = self.stem0(input) s1 = self.stem1(s0) for i, cell in enumerate(self.cells): s0, s1 = s1, cell(s0, s1, self.drop_path_prob) if i == self.auxiliary_index: if self.training: logits_aux = self.auxiliary_head(s1) out = self.global_pooling(s1) logits = self.classifier(out.view(out.size(0), -1)) return logits, logits_aux