Unverified Commit a8d84961 authored by F-G Fernandez's avatar F-G Fernandez Committed by GitHub
Browse files

Added annotation typing to densenet (#2860)

* style: Added annotation typing for densenet

* fix: Fixed import

* refactor: Removed un-necessary import

* fix: Fixed constructor typing

* chore: Updated mypy.ini

* fix: Fixed tuple typing

* style: Ignored some mypy errors

* style: Fixed typing

* fix: Added missing constructor typing
parent f655e6a7
...@@ -16,10 +16,6 @@ ignore_errors = True ...@@ -16,10 +16,6 @@ ignore_errors = True
ignore_errors = True ignore_errors = True
[mypy-torchvision.models.densenet.*]
ignore_errors = True
[mypy-torchvision.models.quantization.*] [mypy-torchvision.models.quantization.*]
ignore_errors = True ignore_errors = True
......
...@@ -6,7 +6,7 @@ import torch.utils.checkpoint as cp ...@@ -6,7 +6,7 @@ import torch.utils.checkpoint as cp
from collections import OrderedDict from collections import OrderedDict
from .utils import load_state_dict_from_url from .utils import load_state_dict_from_url
from torch import Tensor from torch import Tensor
from torch.jit.annotations import List from typing import Any, List, Tuple
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
...@@ -20,56 +20,64 @@ model_urls = { ...@@ -20,56 +20,64 @@ model_urls = {
class _DenseLayer(nn.Module): class _DenseLayer(nn.Module):
def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient=False): def __init__(
self,
num_input_features: int,
growth_rate: int,
bn_size: int,
drop_rate: float,
memory_efficient: bool = False
) -> None:
super(_DenseLayer, self).__init__() super(_DenseLayer, self).__init__()
self.add_module('norm1', nn.BatchNorm2d(num_input_features)), self.norm1: nn.BatchNorm2d
self.add_module('relu1', nn.ReLU(inplace=True)), self.add_module('norm1', nn.BatchNorm2d(num_input_features))
self.relu1: nn.ReLU
self.add_module('relu1', nn.ReLU(inplace=True))
self.conv1: nn.Conv2d
self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
growth_rate, kernel_size=1, stride=1, growth_rate, kernel_size=1, stride=1,
bias=False)), bias=False))
self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), self.norm2: nn.BatchNorm2d
self.add_module('relu2', nn.ReLU(inplace=True)), self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
self.relu2: nn.ReLU
self.add_module('relu2', nn.ReLU(inplace=True))
self.conv2: nn.Conv2d
self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
kernel_size=3, stride=1, padding=1, kernel_size=3, stride=1, padding=1,
bias=False)), bias=False))
self.drop_rate = float(drop_rate) self.drop_rate = float(drop_rate)
self.memory_efficient = memory_efficient self.memory_efficient = memory_efficient
def bn_function(self, inputs): def bn_function(self, inputs: List[Tensor]) -> Tensor:
# type: (List[Tensor]) -> Tensor
concated_features = torch.cat(inputs, 1) concated_features = torch.cat(inputs, 1)
bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484 bottleneck_output = self.conv1(self.relu1(self.norm1(concated_features))) # noqa: T484
return bottleneck_output return bottleneck_output
# todo: rewrite when torchscript supports any # todo: rewrite when torchscript supports any
def any_requires_grad(self, input): def any_requires_grad(self, input: List[Tensor]) -> bool:
# type: (List[Tensor]) -> bool
for tensor in input: for tensor in input:
if tensor.requires_grad: if tensor.requires_grad:
return True return True
return False return False
@torch.jit.unused # noqa: T484 @torch.jit.unused # noqa: T484
def call_checkpoint_bottleneck(self, input): def call_checkpoint_bottleneck(self, input: List[Tensor]) -> Tensor:
# type: (List[Tensor]) -> Tensor
def closure(*inputs): def closure(*inputs):
return self.bn_function(inputs) return self.bn_function(inputs)
return cp.checkpoint(closure, *input) return cp.checkpoint(closure, *input)
@torch.jit._overload_method # noqa: F811 @torch.jit._overload_method # noqa: F811
def forward(self, input): def forward(self, input: List[Tensor]) -> Tensor:
# type: (List[Tensor]) -> (Tensor)
pass pass
@torch.jit._overload_method # noqa: F811 @torch.jit._overload_method # type: ignore[no-redef] # noqa: F811
def forward(self, input): def forward(self, input: Tensor) -> Tensor:
# type: (Tensor) -> (Tensor)
pass pass
# torchscript does not yet support *args, so we overload method # torchscript does not yet support *args, so we overload method
# allowing it to take either a List[Tensor] or single Tensor # allowing it to take either a List[Tensor] or single Tensor
def forward(self, input): # noqa: F811 def forward(self, input: Tensor) -> Tensor: # type: ignore[no-redef] # noqa: F811
if isinstance(input, Tensor): if isinstance(input, Tensor):
prev_features = [input] prev_features = [input]
else: else:
...@@ -93,7 +101,15 @@ class _DenseLayer(nn.Module): ...@@ -93,7 +101,15 @@ class _DenseLayer(nn.Module):
class _DenseBlock(nn.ModuleDict): class _DenseBlock(nn.ModuleDict):
_version = 2 _version = 2
def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, memory_efficient=False): def __init__(
self,
num_layers: int,
num_input_features: int,
bn_size: int,
growth_rate: int,
drop_rate: float,
memory_efficient: bool = False
) -> None:
super(_DenseBlock, self).__init__() super(_DenseBlock, self).__init__()
for i in range(num_layers): for i in range(num_layers):
layer = _DenseLayer( layer = _DenseLayer(
...@@ -105,7 +121,7 @@ class _DenseBlock(nn.ModuleDict): ...@@ -105,7 +121,7 @@ class _DenseBlock(nn.ModuleDict):
) )
self.add_module('denselayer%d' % (i + 1), layer) self.add_module('denselayer%d' % (i + 1), layer)
def forward(self, init_features): def forward(self, init_features: Tensor) -> Tensor: # type: ignore[override]
features = [init_features] features = [init_features]
for name, layer in self.items(): for name, layer in self.items():
new_features = layer(features) new_features = layer(features)
...@@ -114,7 +130,7 @@ class _DenseBlock(nn.ModuleDict): ...@@ -114,7 +130,7 @@ class _DenseBlock(nn.ModuleDict):
class _Transition(nn.Sequential): class _Transition(nn.Sequential):
def __init__(self, num_input_features, num_output_features): def __init__(self, num_input_features: int, num_output_features: int) -> None:
super(_Transition, self).__init__() super(_Transition, self).__init__()
self.add_module('norm', nn.BatchNorm2d(num_input_features)) self.add_module('norm', nn.BatchNorm2d(num_input_features))
self.add_module('relu', nn.ReLU(inplace=True)) self.add_module('relu', nn.ReLU(inplace=True))
...@@ -139,8 +155,16 @@ class DenseNet(nn.Module): ...@@ -139,8 +155,16 @@ class DenseNet(nn.Module):
but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_ but slower. Default: *False*. See `"paper" <https://arxiv.org/pdf/1707.06990.pdf>`_
""" """
def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), def __init__(
num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, memory_efficient=False): self,
growth_rate: int = 32,
block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
num_init_features: int = 64,
bn_size: int = 4,
drop_rate: float = 0,
num_classes: int = 1000,
memory_efficient: bool = False
) -> None:
super(DenseNet, self).__init__() super(DenseNet, self).__init__()
...@@ -188,7 +212,7 @@ class DenseNet(nn.Module): ...@@ -188,7 +212,7 @@ class DenseNet(nn.Module):
elif isinstance(m, nn.Linear): elif isinstance(m, nn.Linear):
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
features = self.features(x) features = self.features(x)
out = F.relu(features, inplace=True) out = F.relu(features, inplace=True)
out = F.adaptive_avg_pool2d(out, (1, 1)) out = F.adaptive_avg_pool2d(out, (1, 1))
...@@ -197,7 +221,7 @@ class DenseNet(nn.Module): ...@@ -197,7 +221,7 @@ class DenseNet(nn.Module):
return out return out
def _load_state_dict(model, model_url, progress): def _load_state_dict(model: nn.Module, model_url: str, progress: bool) -> None:
# '.'s are no longer allowed in module names, but previous _DenseLayer # '.'s are no longer allowed in module names, but previous _DenseLayer
# has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
# They are also in the checkpoints in model_urls. This pattern is used # They are also in the checkpoints in model_urls. This pattern is used
...@@ -215,15 +239,22 @@ def _load_state_dict(model, model_url, progress): ...@@ -215,15 +239,22 @@ def _load_state_dict(model, model_url, progress):
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
def _densenet(arch, growth_rate, block_config, num_init_features, pretrained, progress, def _densenet(
**kwargs): arch: str,
growth_rate: int,
block_config: Tuple[int, int, int, int],
num_init_features: int,
pretrained: bool,
progress: bool,
**kwargs: Any
) -> DenseNet:
model = DenseNet(growth_rate, block_config, num_init_features, **kwargs) model = DenseNet(growth_rate, block_config, num_init_features, **kwargs)
if pretrained: if pretrained:
_load_state_dict(model, model_urls[arch], progress) _load_state_dict(model, model_urls[arch], progress)
return model return model
def densenet121(pretrained=False, progress=True, **kwargs): def densenet121(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-121 model from r"""Densenet-121 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
...@@ -237,7 +268,7 @@ def densenet121(pretrained=False, progress=True, **kwargs): ...@@ -237,7 +268,7 @@ def densenet121(pretrained=False, progress=True, **kwargs):
**kwargs) **kwargs)
def densenet161(pretrained=False, progress=True, **kwargs): def densenet161(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-161 model from r"""Densenet-161 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
...@@ -251,7 +282,7 @@ def densenet161(pretrained=False, progress=True, **kwargs): ...@@ -251,7 +282,7 @@ def densenet161(pretrained=False, progress=True, **kwargs):
**kwargs) **kwargs)
def densenet169(pretrained=False, progress=True, **kwargs): def densenet169(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-169 model from r"""Densenet-169 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
...@@ -265,7 +296,7 @@ def densenet169(pretrained=False, progress=True, **kwargs): ...@@ -265,7 +296,7 @@ def densenet169(pretrained=False, progress=True, **kwargs):
**kwargs) **kwargs)
def densenet201(pretrained=False, progress=True, **kwargs): def densenet201(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> DenseNet:
r"""Densenet-201 model from r"""Densenet-201 model from
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment