Unverified Commit 052edcec authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added annotation typing to shufflenet (#2864)

* style: Added annotation typing for shufflenet

* fix: Removed duplicate type hint

* refactor: Removed unnecessary import

* fix: Fixed constructor typing

* style: Added black formatting on depthwise_conv

* style: Fixed stage typing in shufflenet
parent 3852b419
import torch import torch
from torch import Tensor
import torch.nn as nn import torch.nn as nn
from .utils import load_state_dict_from_url from .utils import load_state_dict_from_url
from typing import Callable, Any, List
__all__ = [ __all__ = [
...@@ -16,8 +18,7 @@ model_urls = { ...@@ -16,8 +18,7 @@ model_urls = {
} }
def channel_shuffle(x, groups): def channel_shuffle(x: Tensor, groups: int) -> Tensor:
# type: (torch.Tensor, int) -> torch.Tensor
batchsize, num_channels, height, width = x.data.size() batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups channels_per_group = num_channels // groups
...@@ -34,7 +35,12 @@ def channel_shuffle(x, groups): ...@@ -34,7 +35,12 @@ def channel_shuffle(x, groups):
class InvertedResidual(nn.Module): class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride): def __init__(
self,
inp: int,
oup: int,
stride: int
) -> None:
super(InvertedResidual, self).__init__() super(InvertedResidual, self).__init__()
if not (1 <= stride <= 3): if not (1 <= stride <= 3):
...@@ -68,10 +74,17 @@ class InvertedResidual(nn.Module): ...@@ -68,10 +74,17 @@ class InvertedResidual(nn.Module):
) )
@staticmethod @staticmethod
def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): def depthwise_conv(
i: int,
o: int,
kernel_size: int,
stride: int = 1,
padding: int = 0,
bias: bool = False
) -> nn.Conv2d:
return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
if self.stride == 1: if self.stride == 1:
x1, x2 = x.chunk(2, dim=1) x1, x2 = x.chunk(2, dim=1)
out = torch.cat((x1, self.branch2(x2)), dim=1) out = torch.cat((x1, self.branch2(x2)), dim=1)
...@@ -84,7 +97,13 @@ class InvertedResidual(nn.Module): ...@@ -84,7 +97,13 @@ class InvertedResidual(nn.Module):
class ShuffleNetV2(nn.Module): class ShuffleNetV2(nn.Module):
def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual): def __init__(
self,
stages_repeats: List[int],
stages_out_channels: List[int],
num_classes: int = 1000,
inverted_residual: Callable[..., nn.Module] = InvertedResidual
) -> None:
super(ShuffleNetV2, self).__init__() super(ShuffleNetV2, self).__init__()
if len(stages_repeats) != 3: if len(stages_repeats) != 3:
...@@ -104,6 +123,10 @@ class ShuffleNetV2(nn.Module): ...@@ -104,6 +123,10 @@ class ShuffleNetV2(nn.Module):
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# Static annotations for mypy
self.stage2: nn.Sequential
self.stage3: nn.Sequential
self.stage4: nn.Sequential
stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
for name, repeats, output_channels in zip( for name, repeats, output_channels in zip(
stage_names, stages_repeats, self._stage_out_channels[1:]): stage_names, stages_repeats, self._stage_out_channels[1:]):
...@@ -122,7 +145,7 @@ class ShuffleNetV2(nn.Module): ...@@ -122,7 +145,7 @@ class ShuffleNetV2(nn.Module):
self.fc = nn.Linear(output_channels, num_classes) self.fc = nn.Linear(output_channels, num_classes)
def _forward_impl(self, x): def _forward_impl(self, x: Tensor) -> Tensor:
# See note [TorchScript super()] # See note [TorchScript super()]
x = self.conv1(x) x = self.conv1(x)
x = self.maxpool(x) x = self.maxpool(x)
...@@ -134,11 +157,11 @@ class ShuffleNetV2(nn.Module): ...@@ -134,11 +157,11 @@ class ShuffleNetV2(nn.Module):
x = self.fc(x) x = self.fc(x)
return x return x
def forward(self, x: Tensor) -> Tensor:
    """Compute the network output for input ``x``.

    Delegates to ``_forward_impl`` (kept separate for the TorchScript
    ``super()`` note referenced there).
    """
    return self._forward_impl(x)
def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwargs: Any) -> ShuffleNetV2:
model = ShuffleNetV2(*args, **kwargs) model = ShuffleNetV2(*args, **kwargs)
if pretrained: if pretrained:
...@@ -152,7 +175,7 @@ def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): ...@@ -152,7 +175,7 @@ def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
return model return model
def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs): def shufflenet_v2_x0_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
""" """
Constructs a ShuffleNetV2 with 0.5x output channels, as described in Constructs a ShuffleNetV2 with 0.5x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
...@@ -166,7 +189,7 @@ def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs): ...@@ -166,7 +189,7 @@ def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
[4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs): def shufflenet_v2_x1_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
""" """
Constructs a ShuffleNetV2 with 1.0x output channels, as described in Constructs a ShuffleNetV2 with 1.0x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
...@@ -180,7 +203,7 @@ def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs): ...@@ -180,7 +203,7 @@ def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
[4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs): def shufflenet_v2_x1_5(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
""" """
Constructs a ShuffleNetV2 with 1.5x output channels, as described in Constructs a ShuffleNetV2 with 1.5x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
...@@ -194,7 +217,7 @@ def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs): ...@@ -194,7 +217,7 @@ def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
[4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs): def shufflenet_v2_x2_0(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ShuffleNetV2:
""" """
Constructs a ShuffleNetV2 with 2.0x output channels, as described in Constructs a ShuffleNetV2 with 2.0x output channels, as described in
`"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment