import torch

from functools import partial
from torch import nn, Tensor
from torch.nn import functional as F
from typing import Any, Callable, List, Optional, Sequence

from torchvision.models.utils import load_state_dict_from_url
from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation


__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"]


model_urls = {
    "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
    "mobilenet_v3_small": None,
}


class SqueezeExcitation(nn.Module):

    def __init__(self, input_channels: int, squeeze_factor: int = 4):
        super().__init__()
        squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
        self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1)
        self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1)

    def forward(self, input: Tensor) -> Tensor:
        # Squeeze: global average pool to a 1x1 map; excite: bottleneck MLP
        # with hardsigmoid gates that rescale the input channel-wise.
        scale = F.adaptive_avg_pool2d(input, 1)
        scale = self.fc1(scale)
        scale = F.relu(scale, inplace=True)
        scale = self.fc2(scale)
        scale = F.hardsigmoid(scale, inplace=True)
        return scale * input


class InvertedResidualConfig:

    def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int,
                 use_se: bool, activation: str, stride: int, width_mult: float):
        self.input_channels = self.adjust_channels(input_channels, width_mult)
        self.kernel = kernel
        self.expanded_channels = self.adjust_channels(expanded_channels, width_mult)
        self.out_channels = self.adjust_channels(out_channels, width_mult)
        self.use_se = use_se
        self.use_hs = activation == "HS"
        self.stride = stride

    @staticmethod
    def adjust_channels(channels: int, width_mult: float):
        return _make_divisible(channels * width_mult, 8)


class InvertedResidual(nn.Module):

    def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module]):
        super().__init__()
        if not (1 <= cnf.stride <= 2):
            raise ValueError('illegal stride value')

        # The residual shortcut is only valid when the block keeps both the
        # spatial resolution and the channel count.
        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        layers: List[nn.Module] = []
        activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU

        # expand
        if cnf.expanded_channels != cnf.input_channels:
            layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1,
                                           norm_layer=norm_layer, activation_layer=activation_layer))

        # depthwise
        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel,
                                       stride=cnf.stride, groups=cnf.expanded_channels,
                                       norm_layer=norm_layer, activation_layer=activation_layer))
        if cnf.use_se:
            layers.append(SqueezeExcitation(cnf.expanded_channels))

        # project
        layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1,
                                       norm_layer=norm_layer, activation_layer=nn.Identity))

        self.block = nn.Sequential(*layers)
        self.out_channels = cnf.out_channels
        self.is_strided = cnf.stride > 1

    def forward(self, input: Tensor) -> Tensor:
        result = self.block(input)
        if self.use_res_connect:
            result += input
        return result
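
# Illustrative sketch (not part of the upstream file): how a single
# InvertedResidual block can be exercised in isolation. The config row mirrors
# the first bneck entry of the large model below; the input size
# (1, 16, 56, 56) is an arbitrary choice for the example.
#
#   norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)
#   cnf = InvertedResidualConfig(16, 3, 16, 16, False, "RE", stride=1, width_mult=1.0)
#   block = InvertedResidual(cnf, norm_layer)
#   out = block(torch.rand(1, 16, 56, 56))
#   # stride == 1 and input_channels == out_channels, so the residual shortcut
#   # is taken and the shape is preserved:
#   assert out.shape == (1, 16, 56, 56)
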
class MobileNetV3(nn.Module):

    def __init__(
            self,
            inverted_residual_setting: List[InvertedResidualConfig],
            last_channel: int,
            num_classes: int = 1000,
            block: Optional[Callable[..., nn.Module]] = None,
            norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        """
        MobileNet V3 main class

        Args:
            inverted_residual_setting (List[InvertedResidualConfig]): Network structure
            last_channel (int): The number of channels on the penultimate layer
            num_classes (int): Number of classes
            block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet
            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use
        """
        super().__init__()

        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        elif not (isinstance(inverted_residual_setting, Sequence) and
                  all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])):
            raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]")

        if block is None:
            block = InvertedResidual

        if norm_layer is None:
            norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)

        layers: List[nn.Module] = []

        # building first layer
        firstconv_output_channels = inverted_residual_setting[0].input_channels
        layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2,
                                       norm_layer=norm_layer, activation_layer=nn.Hardswish))

        # building inverted residual blocks
        for cnf in inverted_residual_setting:
            layers.append(block(cnf, norm_layer))

        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = 6 * lastconv_input_channels
        layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
                                       norm_layer=norm_layer, activation_layer=nn.Hardswish))

        self.features = nn.Sequential(*layers)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(lastconv_output_channels, last_channel),
            nn.Hardswish(inplace=True),
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(last_channel, num_classes),
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        x = self.features(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        x = self.classifier(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def _mobilenet_v3(
        arch: str,
        inverted_residual_setting: List[InvertedResidualConfig],
        last_channel: int,
        pretrained: bool,
        progress: bool,
        **kwargs: Any
):
    model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
    if pretrained:
        # Not every arch has published weights (model_urls maps
        # "mobilenet_v3_small" to None above), so fail loudly here.
        if model_urls.get(arch, None) is None:
            raise ValueError("No checkpoint is available for model type {}".format(arch))
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        model.load_state_dict(state_dict)
    return model


def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
                       **kwargs: Any) -> MobileNetV3:
    """
    Constructs a large MobileNetV3 architecture from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        reduced_tail (bool): If True, reduces the channel counts of all feature layers
            between C4 and C5 by 2. It is used to reduce the channel redundancy in the
            backbone for Detection and Segmentation.
    """
    width_mult = 1.0
    bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)

    reduce_divider = 2 if reduced_tail else 1

    inverted_residual_setting = [
        bneck_conf(16, 3, 16, 16, False, "RE", 1),
        bneck_conf(16, 3, 64, 24, False, "RE", 2),  # C1
        bneck_conf(24, 3, 72, 24, False, "RE", 1),
        bneck_conf(24, 5, 72, 40, True, "RE", 2),  # C2
        bneck_conf(40, 5, 120, 40, True, "RE", 1),
        bneck_conf(40, 5, 120, 40, True, "RE", 1),
        bneck_conf(40, 3, 240, 80, False, "HS", 2),  # C3
        bneck_conf(80, 3, 200, 80, False, "HS", 1),
        bneck_conf(80, 3, 184, 80, False, "HS", 1),
        bneck_conf(80, 3, 184, 80, False, "HS", 1),
        bneck_conf(80, 3, 480, 112, True, "HS", 1),
        bneck_conf(112, 3, 672, 112, True, "HS", 1),
        bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2),  # C4
        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
    ]
    last_channel = adjust_channels(1280 // reduce_divider)  # C5

    return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress,
                         **kwargs)


def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
                       **kwargs: Any) -> MobileNetV3:
    """
    Constructs a small MobileNetV3 architecture from
    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        reduced_tail (bool): If True, reduces the channel counts of all feature layers
            between C4 and C5 by 2. It is used to reduce the channel redundancy in the
            backbone for Detection and Segmentation.
    """
    width_mult = 1.0
    bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)

    reduce_divider = 2 if reduced_tail else 1

    inverted_residual_setting = [
        bneck_conf(16, 3, 16, 16, True, "RE", 2),  # C1
        bneck_conf(16, 3, 72, 24, False, "RE", 2),  # C2
        bneck_conf(24, 3, 88, 24, False, "RE", 1),
        bneck_conf(24, 5, 96, 40, True, "HS", 2),  # C3
        bneck_conf(40, 5, 240, 40, True, "HS", 1),
        bneck_conf(40, 5, 240, 40, True, "HS", 1),
        bneck_conf(40, 5, 120, 48, True, "HS", 1),
        bneck_conf(48, 5, 144, 48, True, "HS", 1),
        bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2),  # C4
        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1),
        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1),
    ]
    last_channel = adjust_channels(1024 // reduce_divider)  # C5

    return _mobilenet_v3("mobilenet_v3_small", inverted_residual_setting, last_channel, pretrained, progress,
                         **kwargs)