Unverified Commit 2ce6b189 authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added annotation typing to resnet (#2863)

* style: Added annotation typing for resnet

* fix: Fixed annotation to pass classes

* fix: Fixed annotation typing

* fix: Fixed annotation typing

* fix: Fixed annotation typing for resnet

* refactor: Removed un-necessary import

* fix: Fixed constructor typing

* style: Added black formatting on _resnet
parent 65591f14
import torch import torch
from torch import Tensor
import torch.nn as nn import torch.nn as nn
from .utils import load_state_dict_from_url from .utils import load_state_dict_from_url
from typing import Type, Any, Callable, Union, List, Optional
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
...@@ -21,22 +23,31 @@ model_urls = { ...@@ -21,22 +23,31 @@ model_urls = {
} }
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        groups: number of blocked connections from input to output channels.
        dilation: spacing between kernel elements; padding is set equal to
            dilation so the spatial size is preserved when stride is 1.

    Returns:
        An ``nn.Conv2d`` module with a 3x3 kernel and no bias (a BatchNorm
        layer is expected to follow, making the bias redundant).
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """1x1 convolution.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.

    Returns:
        An ``nn.Conv2d`` module with a 1x1 kernel and no bias (a BatchNorm
        layer is expected to follow, making the bias redundant).
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module): class BasicBlock(nn.Module):
expansion = 1 expansion: int = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, def __init__(
base_width=64, dilation=1, norm_layer=None): self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(BasicBlock, self).__init__() super(BasicBlock, self).__init__()
if norm_layer is None: if norm_layer is None:
norm_layer = nn.BatchNorm2d norm_layer = nn.BatchNorm2d
...@@ -53,7 +64,7 @@ class BasicBlock(nn.Module): ...@@ -53,7 +64,7 @@ class BasicBlock(nn.Module):
self.downsample = downsample self.downsample = downsample
self.stride = stride self.stride = stride
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
identity = x identity = x
out = self.conv1(x) out = self.conv1(x)
...@@ -79,10 +90,19 @@ class Bottleneck(nn.Module): ...@@ -79,10 +90,19 @@ class Bottleneck(nn.Module):
# This variant is also known as ResNet V1.5 and improves accuracy according to # This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion = 4 expansion: int = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, def __init__(
base_width=64, dilation=1, norm_layer=None): self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(Bottleneck, self).__init__() super(Bottleneck, self).__init__()
if norm_layer is None: if norm_layer is None:
norm_layer = nn.BatchNorm2d norm_layer = nn.BatchNorm2d
...@@ -98,7 +118,7 @@ class Bottleneck(nn.Module): ...@@ -98,7 +118,7 @@ class Bottleneck(nn.Module):
self.downsample = downsample self.downsample = downsample
self.stride = stride self.stride = stride
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
identity = x identity = x
out = self.conv1(x) out = self.conv1(x)
...@@ -123,9 +143,17 @@ class Bottleneck(nn.Module): ...@@ -123,9 +143,17 @@ class Bottleneck(nn.Module):
class ResNet(nn.Module): class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, def __init__(
groups=1, width_per_group=64, replace_stride_with_dilation=None, self,
norm_layer=None): block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
num_classes: int = 1000,
zero_init_residual: bool = False,
groups: int = 1,
width_per_group: int = 64,
replace_stride_with_dilation: Optional[List[bool]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(ResNet, self).__init__() super(ResNet, self).__init__()
if norm_layer is None: if norm_layer is None:
norm_layer = nn.BatchNorm2d norm_layer = nn.BatchNorm2d
...@@ -170,11 +198,12 @@ class ResNet(nn.Module): ...@@ -170,11 +198,12 @@ class ResNet(nn.Module):
if zero_init_residual: if zero_init_residual:
for m in self.modules(): for m in self.modules():
if isinstance(m, Bottleneck): if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0) nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
elif isinstance(m, BasicBlock): elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0) nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
def _make_layer(self, block, planes, blocks, stride=1, dilate=False): def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
stride: int = 1, dilate: bool = False) -> nn.Sequential:
norm_layer = self._norm_layer norm_layer = self._norm_layer
downsample = None downsample = None
previous_dilation = self.dilation previous_dilation = self.dilation
...@@ -198,7 +227,7 @@ class ResNet(nn.Module): ...@@ -198,7 +227,7 @@ class ResNet(nn.Module):
return nn.Sequential(*layers) return nn.Sequential(*layers)
def _forward_impl(self, x): def _forward_impl(self, x: Tensor) -> Tensor:
# See note [TorchScript super()] # See note [TorchScript super()]
x = self.conv1(x) x = self.conv1(x)
x = self.bn1(x) x = self.bn1(x)
...@@ -216,11 +245,18 @@ class ResNet(nn.Module): ...@@ -216,11 +245,18 @@ class ResNet(nn.Module):
return x return x
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x) return self._forward_impl(x)
def _resnet(arch, block, layers, pretrained, progress, **kwargs): def _resnet(
arch: str,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
pretrained: bool,
progress: bool,
**kwargs: Any
) -> ResNet:
model = ResNet(block, layers, **kwargs) model = ResNet(block, layers, **kwargs)
if pretrained: if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch], state_dict = load_state_dict_from_url(model_urls[arch],
...@@ -229,7 +265,7 @@ def _resnet(arch, block, layers, pretrained, progress, **kwargs): ...@@ -229,7 +265,7 @@ def _resnet(arch, block, layers, pretrained, progress, **kwargs):
return model return model
def resnet18(pretrained=False, progress=True, **kwargs): def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-18 model from r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
...@@ -241,7 +277,7 @@ def resnet18(pretrained=False, progress=True, **kwargs): ...@@ -241,7 +277,7 @@ def resnet18(pretrained=False, progress=True, **kwargs):
**kwargs) **kwargs)
def resnet34(pretrained=False, progress=True, **kwargs): def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-34 model from r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
...@@ -253,7 +289,7 @@ def resnet34(pretrained=False, progress=True, **kwargs): ...@@ -253,7 +289,7 @@ def resnet34(pretrained=False, progress=True, **kwargs):
**kwargs) **kwargs)
def resnet50(pretrained=False, progress=True, **kwargs): def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-50 model from r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
...@@ -265,7 +301,7 @@ def resnet50(pretrained=False, progress=True, **kwargs): ...@@ -265,7 +301,7 @@ def resnet50(pretrained=False, progress=True, **kwargs):
**kwargs) **kwargs)
def resnet101(pretrained=False, progress=True, **kwargs): def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-101 model from r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
...@@ -277,7 +313,7 @@ def resnet101(pretrained=False, progress=True, **kwargs): ...@@ -277,7 +313,7 @@ def resnet101(pretrained=False, progress=True, **kwargs):
**kwargs) **kwargs)
def resnet152(pretrained=False, progress=True, **kwargs): def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-152 model from r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
...@@ -289,7 +325,7 @@ def resnet152(pretrained=False, progress=True, **kwargs): ...@@ -289,7 +325,7 @@ def resnet152(pretrained=False, progress=True, **kwargs):
**kwargs) **kwargs)
def resnext50_32x4d(pretrained=False, progress=True, **kwargs): def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNeXt-50 32x4d model from r"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_ `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
...@@ -303,7 +339,7 @@ def resnext50_32x4d(pretrained=False, progress=True, **kwargs): ...@@ -303,7 +339,7 @@ def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
pretrained, progress, **kwargs) pretrained, progress, **kwargs)
def resnext101_32x8d(pretrained=False, progress=True, **kwargs): def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNeXt-101 32x8d model from r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_ `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
...@@ -317,7 +353,7 @@ def resnext101_32x8d(pretrained=False, progress=True, **kwargs): ...@@ -317,7 +353,7 @@ def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
pretrained, progress, **kwargs) pretrained, progress, **kwargs)
def wide_resnet50_2(pretrained=False, progress=True, **kwargs): def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""Wide ResNet-50-2 model from r"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
...@@ -335,7 +371,7 @@ def wide_resnet50_2(pretrained=False, progress=True, **kwargs): ...@@ -335,7 +371,7 @@ def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
pretrained, progress, **kwargs) pretrained, progress, **kwargs)
def wide_resnet101_2(pretrained=False, progress=True, **kwargs): def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""Wide ResNet-101-2 model from r"""Wide ResNet-101-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_ `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment