Unverified Commit 2ce6b189 authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added annotation typing to resnet (#2863)

* style: Added annotation typing for resnet

* fix: Fixed annotation to pass classes

* fix: Fixed annotation typing

* fix: Fixed annotation typing

* fix: Fixed annotation typing for resnet

* refactor: Removed un-necessary import

* fix: Fixed constructor typing

* style: Added black formatting on _resnet
parent 65591f14
import torch
from torch import Tensor
import torch.nn as nn
from .utils import load_state_dict_from_url
from typing import Type, Any, Callable, Union, List, Optional
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
......@@ -21,22 +23,31 @@ model_urls = {
}
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
expansion: int = 1
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
......@@ -53,7 +64,7 @@ class BasicBlock(nn.Module):
self.downsample = downsample
self.stride = stride
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
......@@ -79,10 +90,19 @@ class Bottleneck(nn.Module):
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
expansion: int = 4
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
......@@ -98,7 +118,7 @@ class Bottleneck(nn.Module):
self.downsample = downsample
self.stride = stride
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
......@@ -123,9 +143,17 @@ class Bottleneck(nn.Module):
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
def __init__(
self,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
num_classes: int = 1000,
zero_init_residual: bool = False,
groups: int = 1,
width_per_group: int = 64,
replace_stride_with_dilation: Optional[List[bool]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
......@@ -170,11 +198,12 @@ class ResNet(nn.Module):
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
stride: int = 1, dilate: bool = False) -> nn.Sequential:
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
......@@ -198,7 +227,7 @@ class ResNet(nn.Module):
return nn.Sequential(*layers)
def _forward_impl(self, x):
def _forward_impl(self, x: Tensor) -> Tensor:
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
......@@ -216,11 +245,18 @@ class ResNet(nn.Module):
return x
def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
return self._forward_impl(x)
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
def _resnet(
arch: str,
block: Type[Union[BasicBlock, Bottleneck]],
layers: List[int],
pretrained: bool,
progress: bool,
**kwargs: Any
) -> ResNet:
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
......@@ -229,7 +265,7 @@ def _resnet(arch, block, layers, pretrained, progress, **kwargs):
return model
def resnet18(pretrained=False, progress=True, **kwargs):
def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
......@@ -241,7 +277,7 @@ def resnet18(pretrained=False, progress=True, **kwargs):
**kwargs)
def resnet34(pretrained=False, progress=True, **kwargs):
def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
......@@ -253,7 +289,7 @@ def resnet34(pretrained=False, progress=True, **kwargs):
**kwargs)
def resnet50(pretrained=False, progress=True, **kwargs):
def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
......@@ -265,7 +301,7 @@ def resnet50(pretrained=False, progress=True, **kwargs):
**kwargs)
def resnet101(pretrained=False, progress=True, **kwargs):
def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
......@@ -277,7 +313,7 @@ def resnet101(pretrained=False, progress=True, **kwargs):
**kwargs)
def resnet152(pretrained=False, progress=True, **kwargs):
def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
......@@ -289,7 +325,7 @@ def resnet152(pretrained=False, progress=True, **kwargs):
**kwargs)
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
......@@ -303,7 +339,7 @@ def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
pretrained, progress, **kwargs)
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
......@@ -317,7 +353,7 @@ def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
pretrained, progress, **kwargs)
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
......@@ -335,7 +371,7 @@ def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
pretrained, progress, **kwargs)
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
r"""Wide ResNet-101-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment