Unverified Commit 305c2a30 authored by tripleMu, committed by GitHub

Add type hints for mmcv/cnn/bricks (#1993)



* Add type hint

* Add typehint in mmcv/cnn/bricks*

* Deal conflict0

* Fix

* fix

* minor fix

* minor fix
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent 2d3e42fc
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -28,12 +30,12 @@ class Clamp(nn.Module):
             Default to 1.
     """
-    def __init__(self, min=-1., max=1.):
+    def __init__(self, min: float = -1., max: float = 1.):
         super().__init__()
         self.min = min
         self.max = max
-    def forward(self, x):
+    def forward(self, x) -> torch.Tensor:
         """Forward function.
         Args:
@@ -67,7 +69,7 @@ class GELU(nn.Module):
         >>> output = m(input)
     """
-    def forward(self, input):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
         return F.gelu(input)
@@ -78,7 +80,7 @@ else:
     ACTIVATION_LAYERS.register_module(module=nn.GELU)
-def build_activation_layer(cfg):
+def build_activation_layer(cfg: Dict) -> nn.Module:
     """Build activation layer.
     Args:
...
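A minimal usage sketch of the typed builder: `build_activation_layer` takes a config dict whose `type` key names a layer registered in `ACTIVATION_LAYERS` (such as `ReLU` or `GELU`) and returns an `nn.Module`.

```python
import torch
from mmcv.cnn import build_activation_layer

# The `type` key selects the registered activation; remaining keys are
# forwarded to the layer's constructor.
act = build_activation_layer(dict(type='ReLU', inplace=True))
out = act(torch.randn(2, 3))
```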
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
 import torch
 from torch import nn
@@ -6,7 +8,7 @@ from ..utils import constant_init, kaiming_init
 from .registry import PLUGIN_LAYERS
-def last_zero_init(m):
+def last_zero_init(m: Union[nn.Module, nn.Sequential]) -> None:
     if isinstance(m, nn.Sequential):
         constant_init(m[-1], val=0)
     else:
@@ -34,10 +36,10 @@ class ContextBlock(nn.Module):
     _abbr_ = 'context_block'
     def __init__(self,
-                 in_channels,
-                 ratio,
-                 pooling_type='att',
-                 fusion_types=('channel_add', )):
+                 in_channels: int,
+                 ratio: float,
+                 pooling_type: str = 'att',
+                 fusion_types: tuple = ('channel_add', )):
         super().__init__()
         assert pooling_type in ['avg', 'att']
         assert isinstance(fusion_types, (list, tuple))
@@ -82,7 +84,7 @@ class ContextBlock(nn.Module):
         if self.channel_mul_conv is not None:
             last_zero_init(self.channel_mul_conv)
-    def spatial_pool(self, x):
+    def spatial_pool(self, x: torch.Tensor) -> torch.Tensor:
         batch, channel, height, width = x.size()
         if self.pooling_type == 'att':
             input_x = x
@@ -108,7 +110,7 @@ class ContextBlock(nn.Module):
         return context
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         # [N, C, 1, 1]
         context = self.spatial_pool(x)
...
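For reference, a small sketch of how the now-typed `ContextBlock` is constructed and called (values are illustrative):

```python
import torch
from mmcv.cnn import ContextBlock

# `ratio` sets the bottleneck width of the channel add/mul branches.
gc_block = ContextBlock(in_channels=64, ratio=0.25, pooling_type='att')
x = torch.randn(2, 64, 32, 32)
y = gc_block(x)  # residual output with the same shape as x
```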
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional
 from torch import nn
 from .registry import CONV_LAYERS
@@ -9,7 +11,7 @@ CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
 CONV_LAYERS.register_module('Conv', module=nn.Conv2d)
-def build_conv_layer(cfg, *args, **kwargs):
+def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module:
     """Build convolution layer.
     Args:
...
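The `Optional[Dict]` hint reflects that `build_conv_layer` accepts `None` and then falls back to a plain `nn.Conv2d`; a quick sketch:

```python
from mmcv.cnn import build_conv_layer

# cfg=None -> default nn.Conv2d; positional args and kwargs go to the conv.
conv_default = build_conv_layer(None, 3, 16, kernel_size=3, padding=1)

# An explicit cfg selects another registered conv type instead.
conv_3d = build_conv_layer(dict(type='Conv3d'), 3, 16, kernel_size=3)
```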
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
+from typing import Tuple, Union
+import torch
 from torch import nn
 from torch.nn import functional as F
@@ -31,18 +33,18 @@ class Conv2dAdaptivePadding(nn.Conv2d):
     """
     def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 bias=True):
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, int]],
+                 stride: Union[int, Tuple[int, int]] = 1,
+                 padding: Union[int, Tuple[int, int]] = 0,
+                 dilation: Union[int, Tuple[int, int]] = 1,
+                 groups: int = 1,
+                 bias: bool = True):
         super().__init__(in_channels, out_channels, kernel_size, stride, 0,
                          dilation, groups, bias)
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         img_h, img_w = x.size()[-2:]
         kernel_h, kernel_w = self.weight.size()[-2:]
         stride_h, stride_w = self.stride
...
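`Conv2dAdaptivePadding` computes its padding inside `forward` so the output spatial size is `ceil(input / stride)`, similar to TensorFlow's 'SAME' padding. A sketch (import path assumed to be `mmcv.cnn.bricks`):

```python
import torch
from mmcv.cnn.bricks import Conv2dAdaptivePadding

conv = Conv2dAdaptivePadding(3, 8, kernel_size=3, stride=2)
x = torch.randn(1, 3, 15, 15)
y = conv(x)  # padding is chosen at runtime: output is ceil(15 / 2) = 8
assert y.shape[-2:] == (8, 8)
```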
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
+from typing import Dict, Optional, Tuple, Union
+import torch
 import torch.nn as nn
 from mmcv.utils import _BatchNorm, _InstanceNorm
@@ -68,21 +70,21 @@ class ConvModule(nn.Module):
     _abbr_ = 'conv_block'
     def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 bias='auto',
-                 conv_cfg=None,
-                 norm_cfg=None,
-                 act_cfg=dict(type='ReLU'),
-                 inplace=True,
-                 with_spectral_norm=False,
-                 padding_mode='zeros',
-                 order=('conv', 'norm', 'act')):
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, int]],
+                 stride: Union[int, Tuple[int, int]] = 1,
+                 padding: Union[int, Tuple[int, int]] = 0,
+                 dilation: Union[int, Tuple[int, int]] = 1,
+                 groups: int = 1,
+                 bias: Union[bool, str] = 'auto',
+                 conv_cfg: Optional[Dict] = None,
+                 norm_cfg: Optional[Dict] = None,
+                 act_cfg: Optional[Dict] = dict(type='ReLU'),
+                 inplace: bool = True,
+                 with_spectral_norm: bool = False,
+                 padding_mode: str = 'zeros',
+                 order: tuple = ('conv', 'norm', 'act')):
         super().__init__()
         assert conv_cfg is None or isinstance(conv_cfg, dict)
         assert norm_cfg is None or isinstance(norm_cfg, dict)
@@ -143,18 +145,19 @@ class ConvModule(nn.Module):
                 norm_channels = out_channels
             else:
                 norm_channels = in_channels
-            self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
+            self.norm_name, norm = build_norm_layer(
+                norm_cfg, norm_channels)  # type: ignore
             self.add_module(self.norm_name, norm)
             if self.with_bias:
                 if isinstance(norm, (_BatchNorm, _InstanceNorm)):
                     warnings.warn(
                         'Unnecessary conv bias before batch/instance norm')
         else:
-            self.norm_name = None
+            self.norm_name = None  # type: ignore
         # build activation layer
         if self.with_activation:
-            act_cfg_ = act_cfg.copy()
+            act_cfg_ = act_cfg.copy()  # type: ignore
             # nn.Tanh has no 'inplace' argument
             if act_cfg_['type'] not in [
                     'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish', 'GELU'
@@ -193,7 +196,10 @@ class ConvModule(nn.Module):
         if self.with_norm:
             constant_init(self.norm, 1, bias=0)
-    def forward(self, x, activate=True, norm=True):
+    def forward(self,
+                x: torch.Tensor,
+                activate: bool = True,
+                norm: bool = True) -> torch.Tensor:
         for layer in self.order:
             if layer == 'conv':
                 if self.with_explicit_padding:
...
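A minimal sketch of the typed `ConvModule` interface: the conv/norm/act configs are plain dicts, and `forward` can optionally skip the norm or activation step.

```python
import torch
from mmcv.cnn import ConvModule

block = ConvModule(
    in_channels=3,
    out_channels=16,
    kernel_size=3,
    padding=1,
    norm_cfg=dict(type='BN'),
    act_cfg=dict(type='ReLU'))
x = torch.randn(2, 3, 32, 32)
y = block(x)                         # conv -> norm -> act
y_linear = block(x, activate=False)  # skip the activation step
```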
 # Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -6,14 +9,14 @@ import torch.nn.functional as F
 from .registry import CONV_LAYERS
-def conv_ws_2d(input,
-               weight,
-               bias=None,
-               stride=1,
-               padding=0,
-               dilation=1,
-               groups=1,
-               eps=1e-5):
+def conv_ws_2d(input: torch.Tensor,
+               weight: torch.Tensor,
+               bias: Optional[torch.Tensor] = None,
+               stride: Union[int, Tuple[int, int]] = 1,
+               padding: Union[int, Tuple[int, int]] = 0,
+               dilation: Union[int, Tuple[int, int]] = 1,
+               groups: int = 1,
+               eps: float = 1e-5) -> torch.Tensor:
     c_in = weight.size(0)
     weight_flat = weight.view(c_in, -1)
     mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
@@ -26,15 +29,15 @@ def conv_ws_2d(input,
 class ConvWS2d(nn.Conv2d):
     def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 bias=True,
-                 eps=1e-5):
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, int]],
+                 stride: Union[int, Tuple[int, int]] = 1,
+                 padding: Union[int, Tuple[int, int]] = 0,
+                 dilation: Union[int, Tuple[int, int]] = 1,
+                 groups: int = 1,
+                 bias: bool = True,
+                 eps: float = 1e-5):
         super().__init__(
             in_channels,
             out_channels,
@@ -46,7 +49,7 @@ class ConvWS2d(nn.Conv2d):
             bias=bias)
         self.eps = eps
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
                           self.dilation, self.groups, self.eps)
@@ -76,14 +79,14 @@ class ConvAWS2d(nn.Conv2d):
     """
     def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 bias=True):
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, int]],
+                 stride: Union[int, Tuple[int, int]] = 1,
+                 padding: Union[int, Tuple[int, int]] = 0,
+                 dilation: Union[int, Tuple[int, int]] = 1,
+                 groups: int = 1,
+                 bias: bool = True):
         super().__init__(
             in_channels,
             out_channels,
@@ -98,7 +101,7 @@ class ConvAWS2d(nn.Conv2d):
         self.register_buffer('weight_beta',
                              torch.zeros(self.out_channels, 1, 1, 1))
-    def _get_weight(self, weight):
+    def _get_weight(self, weight: torch.Tensor) -> torch.Tensor:
         weight_flat = weight.view(weight.size(0), -1)
         mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
         std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
@@ -106,13 +109,16 @@ class ConvAWS2d(nn.Conv2d):
         weight = self.weight_gamma * weight + self.weight_beta
         return weight
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         weight = self._get_weight(self.weight)
         return F.conv2d(x, weight, self.bias, self.stride, self.padding,
                         self.dilation, self.groups)
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
+    def _load_from_state_dict(self, state_dict: OrderedDict, prefix: str,
+                              local_metadata: Dict, strict: bool,
+                              missing_keys: List[str],
+                              unexpected_keys: List[str],
+                              error_msgs: List[str]) -> None:
         """Override default load function.
         AWS overrides the function _load_from_state_dict to recover
@@ -124,7 +130,7 @@ class ConvAWS2d(nn.Conv2d):
         """
         self.weight_gamma.data.fill_(-1)
-        local_missing_keys = []
+        local_missing_keys: List = []
         super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                       strict, local_missing_keys,
                                       unexpected_keys, error_msgs)
...
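As a sanity check of the new functional signature, `conv_ws_2d` standardizes the weight (zero mean, unit variance per output channel) before calling `F.conv2d`. A small sketch, assuming the module path `mmcv.cnn.bricks.conv_ws`:

```python
import torch
from mmcv.cnn.bricks.conv_ws import conv_ws_2d

x = torch.randn(1, 3, 8, 8)
weight = torch.randn(4, 3, 3, 3)  # [out_channels, in_channels, kH, kW]
out = conv_ws_2d(x, weight, bias=None, stride=1, padding=1)
print(out.shape)  # torch.Size([1, 4, 8, 8])
```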
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional, Tuple, Union
+import torch
 import torch.nn as nn
 from .conv_module import ConvModule
@@ -46,27 +49,27 @@ class DepthwiseSeparableConvModule(nn.Module):
     """
     def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 norm_cfg=None,
-                 act_cfg=dict(type='ReLU'),
-                 dw_norm_cfg='default',
-                 dw_act_cfg='default',
-                 pw_norm_cfg='default',
-                 pw_act_cfg='default',
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, int]],
+                 stride: Union[int, Tuple[int, int]] = 1,
+                 padding: Union[int, Tuple[int, int]] = 0,
+                 dilation: Union[int, Tuple[int, int]] = 1,
+                 norm_cfg: Optional[Dict] = None,
+                 act_cfg: Dict = dict(type='ReLU'),
+                 dw_norm_cfg: Union[Dict, str] = 'default',
+                 dw_act_cfg: Union[Dict, str] = 'default',
+                 pw_norm_cfg: Union[Dict, str] = 'default',
+                 pw_act_cfg: Union[Dict, str] = 'default',
                  **kwargs):
         super().__init__()
         assert 'groups' not in kwargs, 'groups should not be specified'
         # if norm/activation config of depthwise/pointwise ConvModule is not
         # specified, use default config.
-        dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
+        dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg  # type: ignore # noqa E501
         dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
-        pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
+        pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg  # type: ignore # noqa E501
         pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
         # depthwise convolution
@@ -78,19 +81,19 @@ class DepthwiseSeparableConvModule(nn.Module):
             padding=padding,
             dilation=dilation,
             groups=in_channels,
-            norm_cfg=dw_norm_cfg,
-            act_cfg=dw_act_cfg,
+            norm_cfg=dw_norm_cfg,  # type: ignore
+            act_cfg=dw_act_cfg,  # type: ignore
             **kwargs)
         self.pointwise_conv = ConvModule(
             in_channels,
             out_channels,
             1,
-            norm_cfg=pw_norm_cfg,
-            act_cfg=pw_act_cfg,
+            norm_cfg=pw_norm_cfg,  # type: ignore
+            act_cfg=pw_act_cfg,  # type: ignore
             **kwargs)
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.depthwise_conv(x)
         x = self.pointwise_conv(x)
         return x
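The `'default'` sentinel that forces the `Union[Dict, str]` hints above means "inherit from the shared `norm_cfg`/`act_cfg`"; a usage sketch:

```python
import torch
from mmcv.cnn import DepthwiseSeparableConvModule

# dw_*/pw_* configs fall back to norm_cfg/act_cfg unless overridden.
conv = DepthwiseSeparableConvModule(
    in_channels=3,
    out_channels=16,
    kernel_size=3,
    padding=1,
    norm_cfg=dict(type='BN'),
    pw_act_cfg=dict(type='ReLU6'))
y = conv(torch.randn(2, 3, 32, 32))  # [2, 16, 32, 32]
```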
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Dict, Optional
 import torch
 import torch.nn as nn
@@ -6,7 +8,9 @@ from mmcv import build_from_cfg
 from .registry import DROPOUT_LAYERS
-def drop_path(x, drop_prob=0., training=False):
+def drop_path(x: torch.Tensor,
+              drop_prob: float = 0.,
+              training: bool = False) -> torch.Tensor:
     """Drop paths (Stochastic Depth) per sample (when applied in main path of
     residual blocks).
@@ -36,11 +40,11 @@ class DropPath(nn.Module):
         drop_prob (float): Probability of the path to be zeroed. Default: 0.1
     """
-    def __init__(self, drop_prob=0.1):
+    def __init__(self, drop_prob: float = 0.1):
         super().__init__()
         self.drop_prob = drop_prob
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return drop_path(x, self.drop_prob, self.training)
@@ -56,10 +60,10 @@ class Dropout(nn.Dropout):
         inplace (bool): Do the operation inplace or not. Default: False.
     """
-    def __init__(self, drop_prob=0.5, inplace=False):
+    def __init__(self, drop_prob: float = 0.5, inplace: bool = False):
         super().__init__(p=drop_prob, inplace=inplace)
-def build_dropout(cfg, default_args=None):
+def build_dropout(cfg: Dict, default_args: Optional[Dict] = None) -> Any:
     """Builder for drop out layers."""
     return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
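`build_dropout` returns whatever the registry builds, hence the `Any` return hint. A sketch, assuming the import path `mmcv.cnn.bricks.drop`:

```python
import torch
from mmcv.cnn.bricks.drop import build_dropout

# 'Dropout' and 'DropPath' are both registered in DROPOUT_LAYERS.
drop = build_dropout(dict(type='DropPath', drop_prob=0.1))
x = torch.randn(4, 16, 8, 8)
y = drop(x)  # stochastic depth only takes effect when drop.training is True
```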
@@ -45,14 +45,14 @@ class GeneralizedAttention(nn.Module):
     _abbr_ = 'gen_attention_block'
     def __init__(self,
-                 in_channels,
-                 spatial_range=-1,
-                 num_heads=9,
-                 position_embedding_dim=-1,
-                 position_magnitude=1,
-                 kv_stride=2,
-                 q_stride=1,
-                 attention_type='1111'):
+                 in_channels: int,
+                 spatial_range: int = -1,
+                 num_heads: int = 9,
+                 position_embedding_dim: int = -1,
+                 position_magnitude: int = 1,
+                 kv_stride: int = 2,
+                 q_stride: int = 1,
+                 attention_type: str = '1111'):
         super().__init__()
@@ -213,7 +213,7 @@ class GeneralizedAttention(nn.Module):
         return embedding_x, embedding_y
-    def forward(self, x_input):
+    def forward(self, x_input: torch.Tensor) -> torch.Tensor:
         num_heads = self.num_heads
         # use empirical_attention
...
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
+import torch
 import torch.nn as nn
 from .registry import ACTIVATION_LAYERS
@@ -26,7 +27,11 @@ class HSigmoid(nn.Module):
         Tensor: The output tensor.
     """
-    def __init__(self, bias=3.0, divisor=6.0, min_value=0.0, max_value=1.0):
+    def __init__(self,
+                 bias: float = 3.0,
+                 divisor: float = 6.0,
+                 min_value: float = 0.0,
+                 max_value: float = 1.0):
         super().__init__()
         warnings.warn(
             'In MMCV v1.4.4, we modified the default value of args to align '
@@ -40,7 +45,7 @@ class HSigmoid(nn.Module):
         self.min_value = min_value
         self.max_value = max_value
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = (x + self.bias) / self.divisor
         return x.clamp_(self.min_value, self.max_value)
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
 import torch.nn as nn
 from mmcv.utils import TORCH_VERSION, digit_version
@@ -21,11 +22,11 @@ class HSwish(nn.Module):
         Tensor: The output tensor.
     """
-    def __init__(self, inplace=False):
+    def __init__(self, inplace: bool = False):
         super().__init__()
         self.act = nn.ReLU6(inplace)
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x * self.act(x + 3) / 6
...
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABCMeta
+from typing import Dict, Optional
 import torch
 import torch.nn as nn
@@ -33,12 +34,12 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
     """
     def __init__(self,
-                 in_channels,
-                 reduction=2,
-                 use_scale=True,
-                 conv_cfg=None,
-                 norm_cfg=None,
-                 mode='embedded_gaussian',
+                 in_channels: int,
+                 reduction: int = 2,
+                 use_scale: bool = True,
+                 conv_cfg: Optional[Dict] = None,
+                 norm_cfg: Optional[Dict] = None,
+                 mode: str = 'embedded_gaussian',
                  **kwargs):
         super().__init__()
         self.in_channels = in_channels
@@ -61,7 +62,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
             self.inter_channels,
             kernel_size=1,
             conv_cfg=conv_cfg,
-            act_cfg=None)
+            act_cfg=None)  # type: ignore
         self.conv_out = ConvModule(
             self.inter_channels,
             self.in_channels,
@@ -96,7 +97,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
         self.init_weights(**kwargs)
-    def init_weights(self, std=0.01, zeros_init=True):
+    def init_weights(self, std: float = 0.01, zeros_init: bool = True) -> None:
         if self.mode != 'gaussian':
             for m in [self.g, self.theta, self.phi]:
                 normal_init(m.conv, std=std)
@@ -113,7 +114,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
         else:
             normal_init(self.conv_out.norm, std=std)
-    def gaussian(self, theta_x, phi_x):
+    def gaussian(self, theta_x: torch.Tensor,
+                 phi_x: torch.Tensor) -> torch.Tensor:
         # NonLocal1d pairwise_weight: [N, H, H]
         # NonLocal2d pairwise_weight: [N, HxW, HxW]
         # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@@ -121,7 +123,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
         pairwise_weight = pairwise_weight.softmax(dim=-1)
         return pairwise_weight
-    def embedded_gaussian(self, theta_x, phi_x):
+    def embedded_gaussian(self, theta_x: torch.Tensor,
+                          phi_x: torch.Tensor) -> torch.Tensor:
         # NonLocal1d pairwise_weight: [N, H, H]
         # NonLocal2d pairwise_weight: [N, HxW, HxW]
         # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@@ -132,7 +135,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
         pairwise_weight = pairwise_weight.softmax(dim=-1)
         return pairwise_weight
-    def dot_product(self, theta_x, phi_x):
+    def dot_product(self, theta_x: torch.Tensor,
+                    phi_x: torch.Tensor) -> torch.Tensor:
         # NonLocal1d pairwise_weight: [N, H, H]
         # NonLocal2d pairwise_weight: [N, HxW, HxW]
         # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@@ -140,7 +144,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
         pairwise_weight /= pairwise_weight.shape[-1]
         return pairwise_weight
-    def concatenation(self, theta_x, phi_x):
+    def concatenation(self, theta_x: torch.Tensor,
+                      phi_x: torch.Tensor) -> torch.Tensor:
         # NonLocal1d pairwise_weight: [N, H, H]
         # NonLocal2d pairwise_weight: [N, HxW, HxW]
         # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@@ -157,7 +162,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
         return pairwise_weight
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Assume `reduction = 1`, then `inter_channels = C`
         # or `inter_channels = C` when `mode="gaussian"`
@@ -224,9 +229,9 @@ class NonLocal1d(_NonLocalNd):
     """
     def __init__(self,
-                 in_channels,
-                 sub_sample=False,
-                 conv_cfg=dict(type='Conv1d'),
+                 in_channels: int,
+                 sub_sample: bool = False,
+                 conv_cfg: Dict = dict(type='Conv1d'),
                  **kwargs):
         super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
@@ -257,9 +262,9 @@ class NonLocal2d(_NonLocalNd):
     _abbr_ = 'nonlocal_block'
     def __init__(self,
-                 in_channels,
-                 sub_sample=False,
-                 conv_cfg=dict(type='Conv2d'),
+                 in_channels: int,
+                 sub_sample: bool = False,
+                 conv_cfg: Dict = dict(type='Conv2d'),
                  **kwargs):
         super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
@@ -287,9 +292,9 @@ class NonLocal3d(_NonLocalNd):
     """
     def __init__(self,
-                 in_channels,
-                 sub_sample=False,
-                 conv_cfg=dict(type='Conv3d'),
+                 in_channels: int,
+                 sub_sample: bool = False,
+                 conv_cfg: Dict = dict(type='Conv3d'),
                  **kwargs):
         super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
         self.sub_sample = sub_sample
...
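A usage sketch for the typed non-local blocks (values are illustrative):

```python
import torch
from mmcv.cnn import NonLocal2d

block = NonLocal2d(
    in_channels=64,
    reduction=2,
    mode='embedded_gaussian',
    norm_cfg=dict(type='BN'))
x = torch.randn(2, 64, 20, 20)
y = block(x)  # residual output, same shape as x
```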
 # Copyright (c) OpenMMLab. All rights reserved.
 import inspect
+from typing import Dict, Tuple, Union
 import torch.nn as nn
@@ -69,7 +70,9 @@ def infer_abbr(class_type):
     return 'norm_layer'
-def build_norm_layer(cfg, num_features, postfix=''):
+def build_norm_layer(cfg: Dict,
+                     num_features: int,
+                     postfix: Union[int, str] = '') -> Tuple[str, nn.Module]:
     """Build normalization layer.
     Args:
@@ -119,7 +122,8 @@ def build_norm_layer(cfg, num_features, postfix=''):
     return name, layer
-def is_norm(layer, exclude=None):
+def is_norm(layer: nn.Module,
+            exclude: Union[type, tuple, None] = None) -> bool:
     """Check if a layer is a normalization layer.
     Args:
...
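`build_norm_layer` now advertises its `Tuple[str, nn.Module]` return explicitly; the string is the abbreviation plus postfix used when registering the layer as a submodule. A quick sketch:

```python
import torch.nn as nn
from mmcv.cnn import build_norm_layer, is_norm

name, layer = build_norm_layer(dict(type='BN'), num_features=64, postfix=1)
print(name)                         # 'bn1'
print(is_norm(layer))               # True
print(is_norm(nn.Conv2d(3, 3, 1)))  # False
```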
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict
 import torch.nn as nn
 from .registry import PADDING_LAYERS
@@ -8,11 +10,11 @@ PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
 PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)
-def build_padding_layer(cfg, *args, **kwargs):
+def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
     """Build padding layer.
     Args:
-        cfg (None or dict): The padding layer config, which should contain:
+        cfg (dict): The padding layer config, which should contain:
             - type (str): Layer type.
             - layer args: Args needed to instantiate a padding layer.
...
 # Copyright (c) OpenMMLab. All rights reserved.
 import inspect
 import platform
+from typing import Dict, Tuple, Union
+import torch.nn as nn
 from .registry import PLUGIN_LAYERS
@@ -10,7 +13,7 @@ else:
     import re  # type: ignore
-def infer_abbr(class_type):
+def infer_abbr(class_type: type) -> str:
     """Infer abbreviation from the class name.
     This method will infer the abbreviation to map class types to
@@ -48,16 +51,18 @@ def infer_abbr(class_type):
         raise TypeError(
             f'class_type must be a type, but got {type(class_type)}')
     if hasattr(class_type, '_abbr_'):
-        return class_type._abbr_
+        return class_type._abbr_  # type: ignore
     else:
         return camel2snack(class_type.__name__)
-def build_plugin_layer(cfg, postfix='', **kwargs):
+def build_plugin_layer(cfg: Dict,
+                       postfix: Union[int, str] = '',
+                       **kwargs) -> Tuple[str, nn.Module]:
     """Build plugin layer.
     Args:
-        cfg (None or dict): cfg should contain:
+        cfg (dict): cfg should contain:
             - type (str): identify plugin layer type.
             - layer args: args needed to instantiate a plugin layer.
...
@@ -13,9 +13,9 @@ class Scale(nn.Module):
         scale (float): Initial value of scale factor. Default: 1.0
     """
-    def __init__(self, scale=1.0):
+    def __init__(self, scale: float = 1.0):
         super().__init__()
         self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x * self.scale
@@ -21,5 +21,5 @@ class Swish(nn.Module):
     def __init__(self):
         super().__init__()
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x * torch.sigmoid(x)
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -24,8 +27,8 @@ class PixelShufflePack(nn.Module):
             channels.
     """
-    def __init__(self, in_channels, out_channels, scale_factor,
-                 upsample_kernel):
+    def __init__(self, in_channels: int, out_channels: int, scale_factor: int,
+                 upsample_kernel: int):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
@@ -41,13 +44,13 @@ class PixelShufflePack(nn.Module):
     def init_weights(self):
         xavier_init(self.upsample_conv, distribution='uniform')
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.upsample_conv(x)
         x = F.pixel_shuffle(x, self.scale_factor)
         return x
-def build_upsample_layer(cfg, *args, **kwargs):
+def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
     """Build upsample layer.
     Args:
...
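A sketch of `build_upsample_layer`, which dispatches on the `type` key ('nearest'/'bilinear' map to `nn.Upsample`, 'pixel_shuffle' to `PixelShufflePack`, assuming the default registrations):

```python
import torch
from mmcv.cnn import build_upsample_layer

up_nearest = build_upsample_layer(dict(type='nearest', scale_factor=2))

up_ps = build_upsample_layer(
    dict(
        type='pixel_shuffle',
        in_channels=16,
        out_channels=16,
        scale_factor=2,
        upsample_kernel=3))
y = up_ps(torch.randn(1, 16, 8, 8))  # [1, 16, 16, 16]
```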
@@ -21,19 +21,19 @@ else:
     TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
-def obsolete_torch_version(torch_version, version_threshold):
+def obsolete_torch_version(torch_version, version_threshold) -> bool:
     return torch_version == 'parrots' or torch_version <= version_threshold
 class NewEmptyTensorOp(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x, new_shape):
+    def forward(ctx, x: torch.Tensor, new_shape: tuple) -> torch.Tensor:
         ctx.shape = x.shape
         return x.new_empty(new_shape)
     @staticmethod
-    def backward(ctx, grad):
+    def backward(ctx, grad: torch.Tensor) -> tuple:
         shape = ctx.shape
         return NewEmptyTensorOp.apply(grad, shape), None
@@ -41,7 +41,7 @@ class NewEmptyTensorOp(torch.autograd.Function):
 @CONV_LAYERS.register_module('Conv', force=True)
 class Conv2d(nn.Conv2d):
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
             out_shape = [x.shape[0], self.out_channels]
             for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
@@ -62,7 +62,7 @@ class Conv2d(nn.Conv2d):
 @CONV_LAYERS.register_module('Conv3d', force=True)
 class Conv3d(nn.Conv3d):
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
             out_shape = [x.shape[0], self.out_channels]
             for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
@@ -85,7 +85,7 @@ class Conv3d(nn.Conv3d):
 @UPSAMPLE_LAYERS.register_module('deconv', force=True)
 class ConvTranspose2d(nn.ConvTranspose2d):
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
             out_shape = [x.shape[0], self.out_channels]
             for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
@@ -108,7 +108,7 @@ class ConvTranspose2d(nn.ConvTranspose2d):
 @UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
 class ConvTranspose3d(nn.ConvTranspose3d):
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
             out_shape = [x.shape[0], self.out_channels]
             for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
@@ -128,7 +128,7 @@ class ConvTranspose3d(nn.ConvTranspose3d):
 class MaxPool2d(nn.MaxPool2d):
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         # PyTorch 1.9 does not support empty tensor inference yet
         if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
             out_shape = list(x.shape[:2])
@@ -146,7 +146,7 @@ class MaxPool2d(nn.MaxPool2d):
 class MaxPool3d(nn.MaxPool3d):
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         # PyTorch 1.9 does not support empty tensor inference yet
         if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
             out_shape = list(x.shape[:2])
@@ -165,7 +165,7 @@ class MaxPool3d(nn.MaxPool3d):
 class Linear(torch.nn.Linear):
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         # empty tensor forward of Linear layer is supported in Pytorch 1.6
         if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
             out_shape = [x.shape[0], self.out_features]
...
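These wrappers exist so that zero-sample batches can flow through heads on older PyTorch versions; on newer versions they defer to the parent class. A sketch, assuming the import path `mmcv.cnn.bricks.wrappers`:

```python
import torch
from mmcv.cnn.bricks.wrappers import Conv2d, Linear

conv = Conv2d(3, 8, kernel_size=3, padding=1)
empty = torch.zeros(0, 3, 32, 32)
print(conv(empty).shape)  # torch.Size([0, 8, 32, 32])

fc = Linear(16, 4)
print(fc(torch.zeros(0, 16)).shape)  # torch.Size([0, 4])
```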