Unverified Commit f51bcf50 authored by ShawnHu's avatar ShawnHu Committed by GitHub
Browse files

Add type hints in mmcv/cnn/resnet.py (#1982)

* Add type hints in resnet.py

* using lint while commit

* using lint while commit

* using lint while commit

* reslove typehints

* add pre-defined type hints

* Add type hints for other methods in mmcv/cnn/resnet.py

* Fix type hints
parent aea2bb28
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import logging import logging
from typing import Optional, Sequence, Tuple, Union
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint as cp import torch.utils.checkpoint as cp
from torch import Tensor
from .utils import constant_init, kaiming_init from .utils import constant_init, kaiming_init
def conv3x3(in_planes, out_planes, stride=1, dilation=1): def conv3x3(in_planes: int,
out_planes: int,
stride: int = 1,
dilation: int = 1):
"""3x3 convolution with padding.""" """3x3 convolution with padding."""
return nn.Conv2d( return nn.Conv2d(
in_planes, in_planes,
...@@ -23,13 +28,13 @@ class BasicBlock(nn.Module): ...@@ -23,13 +28,13 @@ class BasicBlock(nn.Module):
expansion = 1 expansion = 1
def __init__(self, def __init__(self,
inplanes, inplanes: int,
planes, planes: int,
stride=1, stride: int = 1,
dilation=1, dilation: int = 1,
downsample=None, downsample: Optional[nn.Module] = None,
style='pytorch', style: str = 'pytorch',
with_cp=False): with_cp: bool = False):
super().__init__() super().__init__()
assert style in ['pytorch', 'caffe'] assert style in ['pytorch', 'caffe']
self.conv1 = conv3x3(inplanes, planes, stride, dilation) self.conv1 = conv3x3(inplanes, planes, stride, dilation)
...@@ -42,7 +47,7 @@ class BasicBlock(nn.Module): ...@@ -42,7 +47,7 @@ class BasicBlock(nn.Module):
self.dilation = dilation self.dilation = dilation
assert not with_cp assert not with_cp
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
residual = x residual = x
out = self.conv1(x) out = self.conv1(x)
...@@ -65,13 +70,13 @@ class Bottleneck(nn.Module): ...@@ -65,13 +70,13 @@ class Bottleneck(nn.Module):
expansion = 4 expansion = 4
def __init__(self, def __init__(self,
inplanes, inplanes: int,
planes, planes: int,
stride=1, stride: int = 1,
dilation=1, dilation: int = 1,
downsample=None, downsample: Optional[nn.Module] = None,
style='pytorch', style: str = 'pytorch',
with_cp=False): with_cp: bool = False):
"""Bottleneck block. """Bottleneck block.
If style is "pytorch", the stride-two layer is the 3x3 conv layer, if If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
...@@ -107,7 +112,7 @@ class Bottleneck(nn.Module): ...@@ -107,7 +112,7 @@ class Bottleneck(nn.Module):
self.dilation = dilation self.dilation = dilation
self.with_cp = with_cp self.with_cp = with_cp
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
def _inner_forward(x): def _inner_forward(x):
residual = x residual = x
...@@ -140,14 +145,14 @@ class Bottleneck(nn.Module): ...@@ -140,14 +145,14 @@ class Bottleneck(nn.Module):
return out return out
def make_res_layer(block, def make_res_layer(block: nn.Module,
inplanes, inplanes: int,
planes, planes: int,
blocks, blocks: int,
stride=1, stride: int = 1,
dilation=1, dilation: int = 1,
style='pytorch', style: str = 'pytorch',
with_cp=False): with_cp: bool = False) -> nn.Module:
downsample = None downsample = None
if stride != 1 or inplanes != planes * block.expansion: if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential( downsample = nn.Sequential(
...@@ -208,22 +213,22 @@ class ResNet(nn.Module): ...@@ -208,22 +213,22 @@ class ResNet(nn.Module):
} }
def __init__(self, def __init__(self,
depth, depth: int,
num_stages=4, num_stages: int = 4,
strides=(1, 2, 2, 2), strides: Sequence[int] = (1, 2, 2, 2),
dilations=(1, 1, 1, 1), dilations: Sequence[int] = (1, 1, 1, 1),
out_indices=(0, 1, 2, 3), out_indices: Sequence[int] = (0, 1, 2, 3),
style='pytorch', style: str = 'pytorch',
frozen_stages=-1, frozen_stages: int = -1,
bn_eval=True, bn_eval: bool = True,
bn_frozen=False, bn_frozen: bool = False,
with_cp=False): with_cp: bool = False):
super().__init__() super().__init__()
if depth not in self.arch_settings: if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet') raise KeyError(f'invalid depth {depth} for resnet')
assert num_stages >= 1 and num_stages <= 4 assert num_stages >= 1 and num_stages <= 4
block, stage_blocks = self.arch_settings[depth] block, stage_blocks = self.arch_settings[depth]
stage_blocks = stage_blocks[:num_stages] stage_blocks = stage_blocks[:num_stages] # type: ignore
assert len(strides) == len(dilations) == num_stages assert len(strides) == len(dilations) == num_stages
assert max(out_indices) < num_stages assert max(out_indices) < num_stages
...@@ -234,7 +239,7 @@ class ResNet(nn.Module): ...@@ -234,7 +239,7 @@ class ResNet(nn.Module):
self.bn_frozen = bn_frozen self.bn_frozen = bn_frozen
self.with_cp = with_cp self.with_cp = with_cp
self.inplanes = 64 self.inplanes: int = 64
self.conv1 = nn.Conv2d( self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False) 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64) self.bn1 = nn.BatchNorm2d(64)
...@@ -255,14 +260,15 @@ class ResNet(nn.Module): ...@@ -255,14 +260,15 @@ class ResNet(nn.Module):
dilation=dilation, dilation=dilation,
style=self.style, style=self.style,
with_cp=with_cp) with_cp=with_cp)
self.inplanes = planes * block.expansion self.inplanes = planes * block.expansion # type: ignore
layer_name = f'layer{i + 1}' layer_name = f'layer{i + 1}'
self.add_module(layer_name, res_layer) self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name) self.res_layers.append(layer_name)
self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1) self.feat_dim = block.expansion * 64 * 2**( # type: ignore
len(stage_blocks) - 1)
def init_weights(self, pretrained=None): def init_weights(self, pretrained: Optional[str] = None) -> None:
if isinstance(pretrained, str): if isinstance(pretrained, str):
logger = logging.getLogger() logger = logging.getLogger()
from ..runner import load_checkpoint from ..runner import load_checkpoint
...@@ -276,7 +282,7 @@ class ResNet(nn.Module): ...@@ -276,7 +282,7 @@ class ResNet(nn.Module):
else: else:
raise TypeError('pretrained must be a str or None') raise TypeError('pretrained must be a str or None')
def forward(self, x): def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]:
x = self.conv1(x) x = self.conv1(x)
x = self.bn1(x) x = self.bn1(x)
x = self.relu(x) x = self.relu(x)
...@@ -292,7 +298,7 @@ class ResNet(nn.Module): ...@@ -292,7 +298,7 @@ class ResNet(nn.Module):
else: else:
return tuple(outs) return tuple(outs)
def train(self, mode=True): def train(self, mode: bool = True) -> None:
super().train(mode) super().train(mode)
if self.bn_eval: if self.bn_eval:
for m in self.modules(): for m in self.modules():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment