import torch.nn as nn
from torch.nn import init

__all__ = ["mobile_v1"]


class MobileNetV1(nn.Module):
    def __init__(self, scale=1.0, num_classes=1000, bn_group=None):
        super(MobileNetV1, self).__init__()
        # bn_group is accepted for interface compatibility but is not used here;
        # plain BatchNorm2d is applied throughout.
        BN = nn.BatchNorm2d
        self.scale = scale

        def conv_bn(inp, oup, stride):
            # Standard 3x3 convolution -> BN -> ReLU (used only for the stem).
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
                BN(oup),
                nn.ReLU(inplace=True),
            )

        def conv_dw(inp, oup, stride):
            # Depthwise-separable block: 3x3 depthwise conv followed by a
            # 1x1 pointwise conv, each with BN and ReLU. Channel counts are
            # scaled by the width multiplier.
            inp = int(inp * scale)
            oup = int(oup * scale)
            return nn.Sequential(
                nn.Conv2d(inp, inp, 3, stride, 1, groups=inp, bias=False),
                BN(inp),
                nn.ReLU(inplace=True),
                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                BN(oup),
                nn.ReLU(inplace=True),
            )

        self.model = nn.Sequential(
            conv_bn(3, int(32 * scale), 2),
            conv_dw(32, 64, 1),
            conv_dw(64, 128, 2),
            conv_dw(128, 128, 1),
            conv_dw(128, 256, 2),
            conv_dw(256, 256, 1),
            conv_dw(256, 512, 2),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 512, 1),
            conv_dw(512, 1024, 2),
            conv_dw(1024, 1024, 1),
            nn.AvgPool2d(7),
        )
        self.fc = nn.Linear(int(1024 * scale), num_classes)
        self.init_params()

    def init_params(self):
        # Kaiming init for convolutions, unit/zero init for BatchNorm, and a
        # small normal init for the classifier.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.model(x)
        x = x.view(-1, int(1024 * self.scale))
        x = self.fc(x)
        return x


def mobile_v1(**kwargs):
    model = MobileNetV1(**kwargs)
    return model
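

# A minimal usage sketch (assumes a 224x224 input, since the final
# AvgPool2d(7) expects a 7x7 feature map; the width multiplier and batch
# size below are arbitrary illustration values):
if __name__ == "__main__":
    import torch

    net = mobile_v1(scale=1.0, num_classes=1000)
    dummy = torch.randn(2, 3, 224, 224)  # batch of two RGB images
    logits = net(dummy)
    print(logits.shape)  # expected: torch.Size([2, 1000])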