# -*- coding: utf-8 -*- # @Time : 2019/11/1 15:31 # @Author : zhoujun import torch import torch.nn as nn #from torchvision.models.utils import load_state_dict_from_url from torch.hub import load_state_dict_from_url __all__ = [ 'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0' ] model_urls = { 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth', 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth', 'shufflenetv2_x1.5': None, 'shufflenetv2_x2.0': None, } def channel_shuffle(x, groups): batchsize, num_channels, height, width = x.data.size() channels_per_group = num_channels // groups # reshape x = x.view(batchsize, groups, channels_per_group, height, width) x = torch.transpose(x, 1, 2).contiguous() # flatten x = x.view(batchsize, -1, height, width) return x class InvertedResidual(nn.Module): def __init__(self, inp, oup, stride): super(InvertedResidual, self).__init__() if not (1 <= stride <= 3): raise ValueError('illegal stride value') self.stride = stride branch_features = oup // 2 assert (self.stride != 1) or (inp == branch_features << 1) if self.stride > 1: self.branch1 = nn.Sequential( self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), nn.BatchNorm2d(inp), nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), nn.BatchNorm2d(branch_features), nn.ReLU(inplace=True), ) self.branch2 = nn.Sequential( nn.Conv2d(inp if (self.stride > 1) else branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), nn.BatchNorm2d(branch_features), nn.ReLU(inplace=True), self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), nn.BatchNorm2d(branch_features), nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), nn.BatchNorm2d(branch_features), nn.ReLU(inplace=True), ) @staticmethod def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) def forward(self, x): if self.stride == 1: x1, x2 = x.chunk(2, dim=1) out = torch.cat((x1, self.branch2(x2)), dim=1) else: out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) out = channel_shuffle(out, 2) return out class ShuffleNetV2(nn.Module): def __init__(self, stages_repeats, stages_out_channels, in_channels=3, **kwargs): super(ShuffleNetV2, self).__init__() self.out_channels = [] if len(stages_repeats) != 3: raise ValueError('expected stages_repeats as list of 3 positive ints') if len(stages_out_channels) != 5: raise ValueError('expected stages_out_channels as list of 5 positive ints') self._stage_out_channels = stages_out_channels output_channels = self._stage_out_channels[0] self.conv1 = nn.Sequential( nn.Conv2d(in_channels, output_channels, 3, 2, 1, bias=False), nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True), ) input_channels = output_channels self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.out_channels.append(input_channels) stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] for name, repeats, output_channels in zip( stage_names, stages_repeats, self._stage_out_channels[1:]): seq = [InvertedResidual(input_channels, output_channels, 2)] for i in range(repeats - 1): seq.append(InvertedResidual(output_channels, output_channels, 1)) setattr(self, name, nn.Sequential(*seq)) input_channels = output_channels self.out_channels.append(input_channels) output_channels = self._stage_out_channels[-1] self.conv5 = nn.Sequential( nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), nn.BatchNorm2d(output_channels), nn.ReLU(inplace=True), ) def forward(self, x): x = self.conv1(x) c2 = self.maxpool(x) c3 = self.stage2(c2) c4 = self.stage3(c3) c5 = self.stage4(c4) # c5 = self.conv5(c5) return c2, c3, c4, c5 def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): model = ShuffleNetV2(*args, **kwargs) if pretrained: model_url = model_urls[arch] if model_url is None: raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) else: assert kwargs['in_channels'] == 3, 'in_channels must be 3 whem pretrained is True' state_dict = load_state_dict_from_url(model_url, progress=progress) model.load_state_dict(state_dict, strict=False) return model def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs): """ Constructs a ShuffleNetV2 with 0.5x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs): """ Constructs a ShuffleNetV2 with 1.0x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs): """ Constructs a ShuffleNetV2 with 1.5x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs): """ Constructs a ShuffleNetV2 with 2.0x output channels, as described in `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `_. Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr """ return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)