# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch.nn as nn
import math


def conv_bn(inp, oup, stride):
    """Standard 3x3 convolution, followed by batch norm and ReLU."""
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True)
    )


def conv_dw(inp, oup, stride):
    """Depthwise-separable block: 3x3 depthwise conv + 1x1 pointwise conv."""
    return nn.Sequential(
        nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.ReLU(inplace=True),

        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.BatchNorm2d(oup),
        nn.ReLU(inplace=True),
    )


class MobileNet(nn.Module):
    def __init__(self, n_class, profile='normal'):
        super(MobileNet, self).__init__()

        # original
        if profile == 'normal':
            in_planes = 32
            cfg = [64, (128, 2), 128, (256, 2), 256, (512, 2),
                   512, 512, 512, 512, 512, (1024, 2), 1024]
        # 0.5 AMC
        elif profile == '0.5flops':
            in_planes = 24
            cfg = [48, (96, 2), 80, (192, 2), 200, (328, 2),
                   352, 368, 360, 328, 400, (736, 2), 752]
        else:
            raise NotImplementedError

        self.conv1 = conv_bn(3, in_planes, stride=2)
        self.features = self._make_layers(in_planes, cfg, conv_dw)
        self.classifier = nn.Sequential(
            nn.Linear(cfg[-1], n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.conv1(x)
        x = self.features(x)
        x = x.mean([2, 3])  # global average pooling
        x = self.classifier(x)
        return x

    def _make_layers(self, in_planes, cfg, layer):
        # Each cfg entry is either `out_planes` (stride 1) or `(out_planes, stride)`.
        layers = []
        for x in cfg:
            out_planes = x if isinstance(x, int) else x[0]
            stride = 1 if isinstance(x, int) else x[1]
            layers.append(layer(in_planes, out_planes, stride))
            in_planes = out_planes
        return nn.Sequential(*layers)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
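

# --- Illustrative usage (not part of the original file) ---
# A minimal smoke-test sketch showing how the model might be instantiated and
# run on a dummy batch. The 1000-class count and 224x224 input size are
# assumptions matching typical ImageNet settings, not requirements of this file.
if __name__ == '__main__':
    import torch

    model = MobileNet(n_class=1000, profile='0.5flops')
    dummy = torch.randn(1, 3, 224, 224)  # batch of one 224x224 RGB image
    with torch.no_grad():
        logits = model(dummy)
    print(logits.shape)  # expected: torch.Size([1, 1000])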