Commit fdeee889 authored by limm

release v1.6.1 of mmcv

parent df465820
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Union
+
 import torch
 from torch import nn
@@ -6,7 +8,7 @@ from ..utils import constant_init, kaiming_init
 from .registry import PLUGIN_LAYERS
-def last_zero_init(m):
+def last_zero_init(m: Union[nn.Module, nn.Sequential]) -> None:
     if isinstance(m, nn.Sequential):
         constant_init(m[-1], val=0)
     else:
@@ -34,11 +36,11 @@ class ContextBlock(nn.Module):
     _abbr_ = 'context_block'
     def __init__(self,
-                 in_channels,
-                 ratio,
-                 pooling_type='att',
-                 fusion_types=('channel_add', )):
-        super(ContextBlock, self).__init__()
+                 in_channels: int,
+                 ratio: float,
+                 pooling_type: str = 'att',
+                 fusion_types: tuple = ('channel_add', )):
+        super().__init__()
         assert pooling_type in ['avg', 'att']
         assert isinstance(fusion_types, (list, tuple))
         valid_fusion_types = ['channel_add', 'channel_mul']
@@ -82,7 +84,7 @@ class ContextBlock(nn.Module):
         if self.channel_mul_conv is not None:
             last_zero_init(self.channel_mul_conv)
-    def spatial_pool(self, x):
+    def spatial_pool(self, x: torch.Tensor) -> torch.Tensor:
         batch, channel, height, width = x.size()
         if self.pooling_type == 'att':
             input_x = x
@@ -108,7 +110,7 @@ class ContextBlock(nn.Module):
         return context
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         # [N, C, 1, 1]
         context = self.spatial_pool(x)
...
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional
+
 from torch import nn
 from .registry import CONV_LAYERS
@@ -9,7 +11,7 @@ CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
 CONV_LAYERS.register_module('Conv', module=nn.Conv2d)
-def build_conv_layer(cfg, *args, **kwargs):
+def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module:
     """Build convolution layer.
     Args:
@@ -35,7 +37,7 @@ def build_conv_layer(cfg, *args, **kwargs):
     layer_type = cfg_.pop('type')
     if layer_type not in CONV_LAYERS:
-        raise KeyError(f'Unrecognized norm type {layer_type}')
+        raise KeyError(f'Unrecognized layer type {layer_type}')
     else:
         conv_layer = CONV_LAYERS.get(layer_type)
...
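A quick usage sketch of this builder (illustrative, not part of the commit): the registry key is cfg['type'] and every other argument is forwarded to the chosen layer's constructor; cfg=None falls back to nn.Conv2d.

import torch
from mmcv.cnn import build_conv_layer

# Build an nn.Conv2d through the registry; extra args go to its constructor.
conv = build_conv_layer(dict(type='Conv2d'), 3, 8, kernel_size=3, padding=1)
assert conv(torch.rand(1, 3, 16, 16)).shape == (1, 8, 16, 16)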
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
+from typing import Tuple, Union
+
+import torch
 from torch import nn
 from torch.nn import functional as F
@@ -31,18 +33,18 @@ class Conv2dAdaptivePadding(nn.Conv2d):
     """
     def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 bias=True):
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, int]],
+                 stride: Union[int, Tuple[int, int]] = 1,
+                 padding: Union[int, Tuple[int, int]] = 0,
+                 dilation: Union[int, Tuple[int, int]] = 1,
+                 groups: int = 1,
+                 bias: bool = True):
         super().__init__(in_channels, out_channels, kernel_size, stride, 0,
                          dilation, groups, bias)
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         img_h, img_w = x.size()[-2:]
         kernel_h, kernel_w = self.weight.size()[-2:]
         stride_h, stride_w = self.stride
...
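A sketch of the TensorFlow-"same"-style behaviour (assuming the class is registered in CONV_LAYERS under its own name, as in mmcv 1.x): the output spatial size is ceil(input / stride) regardless of input parity.

import torch
from mmcv.cnn import build_conv_layer

conv = build_conv_layer(
    dict(type='Conv2dAdaptivePadding'), 3, 8, kernel_size=3, stride=2)
# 15 -> ceil(15 / 2) = 8; the padding is computed on the fly in forward().
assert conv(torch.rand(1, 3, 15, 15)).shape == (1, 8, 8, 8)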
 # Copyright (c) OpenMMLab. All rights reserved.
 import warnings
+from typing import Dict, Optional, Tuple, Union
+
+import torch
 import torch.nn as nn
 from mmcv.utils import _BatchNorm, _InstanceNorm
@@ -68,22 +70,22 @@ class ConvModule(nn.Module):
     _abbr_ = 'conv_block'
     def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 bias='auto',
-                 conv_cfg=None,
-                 norm_cfg=None,
-                 act_cfg=dict(type='ReLU'),
-                 inplace=True,
-                 with_spectral_norm=False,
-                 padding_mode='zeros',
-                 order=('conv', 'norm', 'act')):
-        super(ConvModule, self).__init__()
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, int]],
+                 stride: Union[int, Tuple[int, int]] = 1,
+                 padding: Union[int, Tuple[int, int]] = 0,
+                 dilation: Union[int, Tuple[int, int]] = 1,
+                 groups: int = 1,
+                 bias: Union[bool, str] = 'auto',
+                 conv_cfg: Optional[Dict] = None,
+                 norm_cfg: Optional[Dict] = None,
+                 act_cfg: Optional[Dict] = dict(type='ReLU'),
+                 inplace: bool = True,
+                 with_spectral_norm: bool = False,
+                 padding_mode: str = 'zeros',
+                 order: tuple = ('conv', 'norm', 'act')):
+        super().__init__()
         assert conv_cfg is None or isinstance(conv_cfg, dict)
         assert norm_cfg is None or isinstance(norm_cfg, dict)
         assert act_cfg is None or isinstance(act_cfg, dict)
@@ -96,7 +98,7 @@ class ConvModule(nn.Module):
         self.with_explicit_padding = padding_mode not in official_padding_mode
         self.order = order
         assert isinstance(self.order, tuple) and len(self.order) == 3
-        assert set(order) == set(['conv', 'norm', 'act'])
+        assert set(order) == {'conv', 'norm', 'act'}
         self.with_norm = norm_cfg is not None
         self.with_activation = act_cfg is not None
@@ -143,21 +145,22 @@ class ConvModule(nn.Module):
                 norm_channels = out_channels
             else:
                 norm_channels = in_channels
-            self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
+            self.norm_name, norm = build_norm_layer(
+                norm_cfg, norm_channels)  # type: ignore
             self.add_module(self.norm_name, norm)
             if self.with_bias:
                 if isinstance(norm, (_BatchNorm, _InstanceNorm)):
                     warnings.warn(
                         'Unnecessary conv bias before batch/instance norm')
         else:
-            self.norm_name = None
+            self.norm_name = None  # type: ignore
         # build activation layer
         if self.with_activation:
-            act_cfg_ = act_cfg.copy()
+            act_cfg_ = act_cfg.copy()  # type: ignore
             # nn.Tanh has no 'inplace' argument
             if act_cfg_['type'] not in [
-                    'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish'
+                    'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish', 'GELU'
             ]:
                 act_cfg_.setdefault('inplace', inplace)
             self.activate = build_activation_layer(act_cfg_)
@@ -193,7 +196,10 @@ class ConvModule(nn.Module):
         if self.with_norm:
             constant_init(self.norm, 1, bias=0)
-    def forward(self, x, activate=True, norm=True):
+    def forward(self,
+                x: torch.Tensor,
+                activate: bool = True,
+                norm: bool = True) -> torch.Tensor:
         for layer in self.order:
             if layer == 'conv':
                 if self.with_explicit_padding:
...
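Usage sketch for ConvModule (illustrative): the `order` tuple controls how conv, norm and activation are chained.

import torch
from mmcv.cnn import ConvModule

block = ConvModule(
    3, 16, 3, padding=1, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'))
out = block(torch.rand(2, 3, 32, 32))  # conv -> BN -> ReLU per default order
assert out.shape == (2, 16, 32, 32)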
 # Copyright (c) OpenMMLab. All rights reserved.
+from collections import OrderedDict
+from typing import Dict, List, Optional, Tuple, Union
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -6,14 +9,14 @@ import torch.nn.functional as F
 from .registry import CONV_LAYERS
-def conv_ws_2d(input,
-               weight,
-               bias=None,
-               stride=1,
-               padding=0,
-               dilation=1,
-               groups=1,
-               eps=1e-5):
+def conv_ws_2d(input: torch.Tensor,
+               weight: torch.Tensor,
+               bias: Optional[torch.Tensor] = None,
+               stride: Union[int, Tuple[int, int]] = 1,
+               padding: Union[int, Tuple[int, int]] = 0,
+               dilation: Union[int, Tuple[int, int]] = 1,
+               groups: int = 1,
+               eps: float = 1e-5) -> torch.Tensor:
     c_in = weight.size(0)
     weight_flat = weight.view(c_in, -1)
     mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
@@ -26,16 +29,16 @@ def conv_ws_2d(input,
 class ConvWS2d(nn.Conv2d):
     def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 bias=True,
-                 eps=1e-5):
-        super(ConvWS2d, self).__init__(
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, int]],
+                 stride: Union[int, Tuple[int, int]] = 1,
+                 padding: Union[int, Tuple[int, int]] = 0,
+                 dilation: Union[int, Tuple[int, int]] = 1,
+                 groups: int = 1,
+                 bias: bool = True,
+                 eps: float = 1e-5):
+        super().__init__(
             in_channels,
             out_channels,
             kernel_size,
@@ -46,7 +49,7 @@ class ConvWS2d(nn.Conv2d):
             bias=bias)
         self.eps = eps
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
                           self.dilation, self.groups, self.eps)
@@ -76,14 +79,14 @@ class ConvAWS2d(nn.Conv2d):
     """
     def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 groups=1,
-                 bias=True):
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, int]],
+                 stride: Union[int, Tuple[int, int]] = 1,
+                 padding: Union[int, Tuple[int, int]] = 0,
+                 dilation: Union[int, Tuple[int, int]] = 1,
+                 groups: int = 1,
+                 bias: bool = True):
         super().__init__(
             in_channels,
             out_channels,
@@ -98,7 +101,7 @@ class ConvAWS2d(nn.Conv2d):
         self.register_buffer('weight_beta',
                              torch.zeros(self.out_channels, 1, 1, 1))
-    def _get_weight(self, weight):
+    def _get_weight(self, weight: torch.Tensor) -> torch.Tensor:
         weight_flat = weight.view(weight.size(0), -1)
         mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
         std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
@@ -106,13 +109,16 @@ class ConvAWS2d(nn.Conv2d):
         weight = self.weight_gamma * weight + self.weight_beta
         return weight
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         weight = self._get_weight(self.weight)
         return F.conv2d(x, weight, self.bias, self.stride, self.padding,
                         self.dilation, self.groups)
-    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
-                              missing_keys, unexpected_keys, error_msgs):
+    def _load_from_state_dict(self, state_dict: OrderedDict, prefix: str,
+                              local_metadata: Dict, strict: bool,
+                              missing_keys: List[str],
+                              unexpected_keys: List[str],
+                              error_msgs: List[str]) -> None:
         """Override default load function.
         AWS overrides the function _load_from_state_dict to recover
@@ -124,7 +130,7 @@ class ConvAWS2d(nn.Conv2d):
         """
         self.weight_gamma.data.fill_(-1)
-        local_missing_keys = []
+        local_missing_keys: List = []
         super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                       strict, local_missing_keys,
                                       unexpected_keys, error_msgs)
...
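Numeric sketch of what conv_ws_2d computes (illustrative): weights are standardized per output channel before the convolution, i.e. w -> (w - mean) / (std + eps).

import torch
import torch.nn.functional as F
from mmcv.cnn.bricks.conv_ws import conv_ws_2d

w = torch.randn(8, 3, 3, 3)
x = torch.rand(1, 3, 16, 16)
flat = w.view(8, -1)
w_std = (w - flat.mean(1).view(-1, 1, 1, 1)) / (
    flat.std(1).view(-1, 1, 1, 1) + 1e-5)
# Same result as standardizing the weight by hand, then running F.conv2d.
assert torch.allclose(conv_ws_2d(x, w, padding=1), F.conv2d(x, w_std, padding=1))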
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, Optional, Tuple, Union
+
+import torch
 import torch.nn as nn
 from .conv_module import ConvModule
@@ -46,27 +49,27 @@ class DepthwiseSeparableConvModule(nn.Module):
     """
     def __init__(self,
-                 in_channels,
-                 out_channels,
-                 kernel_size,
-                 stride=1,
-                 padding=0,
-                 dilation=1,
-                 norm_cfg=None,
-                 act_cfg=dict(type='ReLU'),
-                 dw_norm_cfg='default',
-                 dw_act_cfg='default',
-                 pw_norm_cfg='default',
-                 pw_act_cfg='default',
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, int]],
+                 stride: Union[int, Tuple[int, int]] = 1,
+                 padding: Union[int, Tuple[int, int]] = 0,
+                 dilation: Union[int, Tuple[int, int]] = 1,
+                 norm_cfg: Optional[Dict] = None,
+                 act_cfg: Dict = dict(type='ReLU'),
+                 dw_norm_cfg: Union[Dict, str] = 'default',
+                 dw_act_cfg: Union[Dict, str] = 'default',
+                 pw_norm_cfg: Union[Dict, str] = 'default',
+                 pw_act_cfg: Union[Dict, str] = 'default',
                  **kwargs):
-        super(DepthwiseSeparableConvModule, self).__init__()
+        super().__init__()
         assert 'groups' not in kwargs, 'groups should not be specified'
         # if norm/activation config of depthwise/pointwise ConvModule is not
         # specified, use default config.
-        dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
+        dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg  # type: ignore # noqa E501
         dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
-        pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
+        pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg  # type: ignore # noqa E501
         pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
         # depthwise convolution
@@ -78,19 +81,19 @@ class DepthwiseSeparableConvModule(nn.Module):
             padding=padding,
             dilation=dilation,
             groups=in_channels,
-            norm_cfg=dw_norm_cfg,
-            act_cfg=dw_act_cfg,
+            norm_cfg=dw_norm_cfg,  # type: ignore
+            act_cfg=dw_act_cfg,  # type: ignore
             **kwargs)
         self.pointwise_conv = ConvModule(
             in_channels,
             out_channels,
             1,
-            norm_cfg=pw_norm_cfg,
-            act_cfg=pw_act_cfg,
+            norm_cfg=pw_norm_cfg,  # type: ignore
+            act_cfg=pw_act_cfg,  # type: ignore
             **kwargs)
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.depthwise_conv(x)
         x = self.pointwise_conv(x)
         return x
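Usage sketch (illustrative): a 3x3 depthwise conv (groups=in_channels) followed by a 1x1 pointwise conv.

import torch
from mmcv.cnn import DepthwiseSeparableConvModule

m = DepthwiseSeparableConvModule(16, 32, 3, padding=1)
assert m(torch.rand(1, 16, 8, 8)).shape == (1, 32, 8, 8)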
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Dict, Optional
+
 import torch
 import torch.nn as nn
@@ -6,7 +8,9 @@ from mmcv import build_from_cfg
 from .registry import DROPOUT_LAYERS
-def drop_path(x, drop_prob=0., training=False):
+def drop_path(x: torch.Tensor,
+              drop_prob: float = 0.,
+              training: bool = False) -> torch.Tensor:
     """Drop paths (Stochastic Depth) per sample (when applied in main path of
     residual blocks).
@@ -36,11 +40,11 @@ class DropPath(nn.Module):
         drop_prob (float): Probability of the path to be zeroed. Default: 0.1
     """
-    def __init__(self, drop_prob=0.1):
-        super(DropPath, self).__init__()
+    def __init__(self, drop_prob: float = 0.1):
+        super().__init__()
         self.drop_prob = drop_prob
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return drop_path(x, self.drop_prob, self.training)
@@ -56,10 +60,10 @@ class Dropout(nn.Dropout):
         inplace (bool): Do the operation inplace or not. Default: False.
     """
-    def __init__(self, drop_prob=0.5, inplace=False):
+    def __init__(self, drop_prob: float = 0.5, inplace: bool = False):
         super().__init__(p=drop_prob, inplace=inplace)
-def build_dropout(cfg, default_args=None):
+def build_dropout(cfg: Dict, default_args: Optional[Dict] = None) -> Any:
     """Builder for drop out layers."""
     return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
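Sketch of stochastic depth built through this registry (illustrative): surviving samples are rescaled by 1 / (1 - drop_prob), and eval mode is the identity.

import torch
from mmcv.cnn.bricks.drop import build_dropout

dp = build_dropout(dict(type='DropPath', drop_prob=0.5))
dp.train()
out = dp(torch.ones(8, 4))  # each row is all zeros or scaled by 1 / 0.5 = 2
assert set(out.unique().tolist()) <= {0.0, 2.0}
dp.eval()
assert torch.equal(dp(torch.ones(8, 4)), torch.ones(8, 4))  # identity in eval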
@@ -45,16 +45,16 @@ class GeneralizedAttention(nn.Module):
     _abbr_ = 'gen_attention_block'
     def __init__(self,
-                 in_channels,
-                 spatial_range=-1,
-                 num_heads=9,
-                 position_embedding_dim=-1,
-                 position_magnitude=1,
-                 kv_stride=2,
-                 q_stride=1,
-                 attention_type='1111'):
-        super(GeneralizedAttention, self).__init__()
+                 in_channels: int,
+                 spatial_range: int = -1,
+                 num_heads: int = 9,
+                 position_embedding_dim: int = -1,
+                 position_magnitude: int = 1,
+                 kv_stride: int = 2,
+                 q_stride: int = 1,
+                 attention_type: str = '1111'):
+        super().__init__()
         # hard range means local range for non-local operation
         self.position_embedding_dim = (
@@ -131,7 +131,7 @@ class GeneralizedAttention(nn.Module):
             max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
             local_constraint_map = np.ones(
-                (max_len, max_len, max_len_kv, max_len_kv), dtype=np.int)
+                (max_len, max_len, max_len_kv, max_len_kv), dtype=int)
             for iy in range(max_len):
                 for ix in range(max_len):
                     local_constraint_map[
@@ -213,7 +213,7 @@ class GeneralizedAttention(nn.Module):
         return embedding_x, embedding_y
-    def forward(self, x_input):
+    def forward(self, x_input: torch.Tensor) -> torch.Tensor:
         num_heads = self.num_heads
         # use empirical_attention
@@ -351,7 +351,7 @@ class GeneralizedAttention(nn.Module):
                 repeat(n, 1, 1, 1)
             position_feat_x_reshape = position_feat_x.\
-                view(n, num_heads, w*w_kv, self.qk_embed_dim)
+                view(n, num_heads, w * w_kv, self.qk_embed_dim)
             position_feat_y_reshape = position_feat_y.\
                 view(n, num_heads, h * h_kv, self.qk_embed_dim)
...
 # Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import torch
 import torch.nn as nn
 from .registry import ACTIVATION_LAYERS
@@ -8,11 +11,15 @@ from .registry import ACTIVATION_LAYERS
 class HSigmoid(nn.Module):
     """Hard Sigmoid Module. Apply the hard sigmoid function:
     Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
-    Default: Hsigmoid(x) = min(max((x + 1) / 2, 0), 1)
+    Default: Hsigmoid(x) = min(max((x + 3) / 6, 0), 1)
+    Note:
+        In MMCV v1.4.4, we modified the default value of args to align with
+        PyTorch official.
     Args:
-        bias (float): Bias of the input feature map. Default: 1.0.
-        divisor (float): Divisor of the input feature map. Default: 2.0.
+        bias (float): Bias of the input feature map. Default: 3.0.
+        divisor (float): Divisor of the input feature map. Default: 6.0.
         min_value (float): Lower bound value. Default: 0.0.
         max_value (float): Upper bound value. Default: 1.0.
@@ -20,15 +27,25 @@ class HSigmoid(nn.Module):
         Tensor: The output tensor.
     """
-    def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0):
-        super(HSigmoid, self).__init__()
+    def __init__(self,
+                 bias: float = 3.0,
+                 divisor: float = 6.0,
+                 min_value: float = 0.0,
+                 max_value: float = 1.0):
+        super().__init__()
+        warnings.warn(
+            'In MMCV v1.4.4, we modified the default value of args to align '
+            'with PyTorch official. Previous Implementation: '
+            'Hsigmoid(x) = min(max((x + 1) / 2, 0), 1). '
+            'Current Implementation: '
+            'Hsigmoid(x) = min(max((x + 3) / 6, 0), 1).')
         self.bias = bias
         self.divisor = divisor
         assert self.divisor != 0
         self.min_value = min_value
         self.max_value = max_value
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = (x + self.bias) / self.divisor
         return x.clamp_(self.min_value, self.max_value)
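Sketch of the new defaults (illustrative): they match PyTorch's nn.Hardsigmoid, clamp((x + 3) / 6, 0, 1); the pre-1.4.4 behaviour stays reachable by passing bias=1.0, divisor=2.0 explicitly.

import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer

x = torch.randn(4)
hsig = build_activation_layer(dict(type='HSigmoid'))  # emits the warning above
assert torch.allclose(hsig(x), nn.Hardsigmoid()(x))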
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
 import torch.nn as nn
+from mmcv.utils import TORCH_VERSION, digit_version
 from .registry import ACTIVATION_LAYERS
-@ACTIVATION_LAYERS.register_module()
 class HSwish(nn.Module):
     """Hard Swish Module.
@@ -21,9 +22,18 @@ class HSwish(nn.Module):
         Tensor: The output tensor.
     """
-    def __init__(self, inplace=False):
-        super(HSwish, self).__init__()
+    def __init__(self, inplace: bool = False):
+        super().__init__()
         self.act = nn.ReLU6(inplace)
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x * self.act(x + 3) / 6
+if (TORCH_VERSION == 'parrots'
+        or digit_version(TORCH_VERSION) < digit_version('1.7')):
+    # Hardswish is not supported when PyTorch version < 1.6.
+    # And Hardswish in PyTorch 1.6 does not support inplace.
+    ACTIVATION_LAYERS.register_module(module=HSwish)
+else:
+    ACTIVATION_LAYERS.register_module(module=nn.Hardswish, name='HSwish')
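Sketch (illustrative): on PyTorch >= 1.7 building 'HSwish' now yields nn.Hardswish, which computes the same x * relu6(x + 3) / 6.

import torch
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer

x = torch.randn(4)
act = build_activation_layer(dict(type='HSwish'))
assert torch.allclose(act(x), x * F.relu6(x + 3) / 6)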
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABCMeta
+from typing import Dict, Optional
+
 import torch
 import torch.nn as nn
@@ -33,14 +34,14 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
     """
     def __init__(self,
-                 in_channels,
-                 reduction=2,
-                 use_scale=True,
-                 conv_cfg=None,
-                 norm_cfg=None,
-                 mode='embedded_gaussian',
+                 in_channels: int,
+                 reduction: int = 2,
+                 use_scale: bool = True,
+                 conv_cfg: Optional[Dict] = None,
+                 norm_cfg: Optional[Dict] = None,
+                 mode: str = 'embedded_gaussian',
                  **kwargs):
-        super(_NonLocalNd, self).__init__()
+        super().__init__()
         self.in_channels = in_channels
         self.reduction = reduction
         self.use_scale = use_scale
@@ -61,7 +62,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
             self.inter_channels,
             kernel_size=1,
             conv_cfg=conv_cfg,
-            act_cfg=None)
+            act_cfg=None)  # type: ignore
         self.conv_out = ConvModule(
             self.inter_channels,
             self.in_channels,
@@ -96,7 +97,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
         self.init_weights(**kwargs)
-    def init_weights(self, std=0.01, zeros_init=True):
+    def init_weights(self, std: float = 0.01, zeros_init: bool = True) -> None:
         if self.mode != 'gaussian':
             for m in [self.g, self.theta, self.phi]:
                 normal_init(m.conv, std=std)
@@ -113,7 +114,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
         else:
             normal_init(self.conv_out.norm, std=std)
-    def gaussian(self, theta_x, phi_x):
+    def gaussian(self, theta_x: torch.Tensor,
+                 phi_x: torch.Tensor) -> torch.Tensor:
         # NonLocal1d pairwise_weight: [N, H, H]
         # NonLocal2d pairwise_weight: [N, HxW, HxW]
         # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@@ -121,7 +123,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
         pairwise_weight = pairwise_weight.softmax(dim=-1)
         return pairwise_weight
-    def embedded_gaussian(self, theta_x, phi_x):
+    def embedded_gaussian(self, theta_x: torch.Tensor,
+                          phi_x: torch.Tensor) -> torch.Tensor:
         # NonLocal1d pairwise_weight: [N, H, H]
         # NonLocal2d pairwise_weight: [N, HxW, HxW]
         # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@@ -132,7 +135,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
         pairwise_weight = pairwise_weight.softmax(dim=-1)
         return pairwise_weight
-    def dot_product(self, theta_x, phi_x):
+    def dot_product(self, theta_x: torch.Tensor,
+                    phi_x: torch.Tensor) -> torch.Tensor:
         # NonLocal1d pairwise_weight: [N, H, H]
         # NonLocal2d pairwise_weight: [N, HxW, HxW]
         # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@@ -140,7 +144,8 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
         pairwise_weight /= pairwise_weight.shape[-1]
         return pairwise_weight
-    def concatenation(self, theta_x, phi_x):
+    def concatenation(self, theta_x: torch.Tensor,
+                      phi_x: torch.Tensor) -> torch.Tensor:
         # NonLocal1d pairwise_weight: [N, H, H]
         # NonLocal2d pairwise_weight: [N, HxW, HxW]
         # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
@@ -157,7 +162,7 @@ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
         return pairwise_weight
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Assume `reduction = 1`, then `inter_channels = C`
         # or `inter_channels = C` when `mode="gaussian"`
@@ -224,12 +229,11 @@ class NonLocal1d(_NonLocalNd):
     """
     def __init__(self,
-                 in_channels,
-                 sub_sample=False,
-                 conv_cfg=dict(type='Conv1d'),
+                 in_channels: int,
+                 sub_sample: bool = False,
+                 conv_cfg: Dict = dict(type='Conv1d'),
                  **kwargs):
-        super(NonLocal1d, self).__init__(
-            in_channels, conv_cfg=conv_cfg, **kwargs)
+        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
         self.sub_sample = sub_sample
@@ -258,12 +262,11 @@ class NonLocal2d(_NonLocalNd):
     _abbr_ = 'nonlocal_block'
     def __init__(self,
-                 in_channels,
-                 sub_sample=False,
-                 conv_cfg=dict(type='Conv2d'),
+                 in_channels: int,
+                 sub_sample: bool = False,
+                 conv_cfg: Dict = dict(type='Conv2d'),
                  **kwargs):
-        super(NonLocal2d, self).__init__(
-            in_channels, conv_cfg=conv_cfg, **kwargs)
+        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
         self.sub_sample = sub_sample
@@ -289,12 +292,11 @@ class NonLocal3d(_NonLocalNd):
     """
    def __init__(self,
-                 in_channels,
-                 sub_sample=False,
-                 conv_cfg=dict(type='Conv3d'),
+                 in_channels: int,
+                 sub_sample: bool = False,
+                 conv_cfg: Dict = dict(type='Conv3d'),
                  **kwargs):
-        super(NonLocal3d, self).__init__(
-            in_channels, conv_cfg=conv_cfg, **kwargs)
+        super().__init__(in_channels, conv_cfg=conv_cfg, **kwargs)
         self.sub_sample = sub_sample
         if sub_sample:
...
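Usage sketch (illustrative): the block is residual, so the output keeps the input shape for any of the four pairwise modes.

import torch
from mmcv.cnn import NonLocal2d

block = NonLocal2d(in_channels=16, reduction=2, mode='embedded_gaussian')
x = torch.rand(1, 16, 8, 8)
assert block(x).shape == x.shape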
 # Copyright (c) OpenMMLab. All rights reserved.
 import inspect
+from typing import Dict, Tuple, Union
+
 import torch.nn as nn
@@ -69,7 +70,9 @@ def infer_abbr(class_type):
     return 'norm_layer'
-def build_norm_layer(cfg, num_features, postfix=''):
+def build_norm_layer(cfg: Dict,
+                     num_features: int,
+                     postfix: Union[int, str] = '') -> Tuple[str, nn.Module]:
     """Build normalization layer.
     Args:
@@ -83,9 +86,9 @@ def build_norm_layer(cfg, num_features, postfix=''):
             to create named layer.
     Returns:
-        (str, nn.Module): The first element is the layer name consisting of
-            abbreviation and postfix, e.g., bn1, gn. The second element is the
+        tuple[str, nn.Module]: The first element is the layer name consisting
+            of abbreviation and postfix, e.g., bn1, gn. The second element is the
             created norm layer.
     """
     if not isinstance(cfg, dict):
         raise TypeError('cfg must be a dict')
@@ -119,7 +122,8 @@ def build_norm_layer(cfg, num_features, postfix=''):
     return name, layer
-def is_norm(layer, exclude=None):
+def is_norm(layer: nn.Module,
+            exclude: Union[type, tuple, None] = None) -> bool:
     """Check if a layer is a normalization layer.
     Args:
...
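Usage sketch (illustrative): the returned name is the inferred abbreviation plus the postfix, handy for add_module.

from mmcv.cnn import build_norm_layer

name, layer = build_norm_layer(dict(type='BN'), 64, postfix=1)
assert name == 'bn1' and layer.num_features == 64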
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict
+
 import torch.nn as nn
 from .registry import PADDING_LAYERS
@@ -8,11 +10,11 @@ PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
 PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)
-def build_padding_layer(cfg, *args, **kwargs):
+def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
     """Build padding layer.
     Args:
-        cfg (None or dict): The padding layer config, which should contain:
+        cfg (dict): The padding layer config, which should contain:
             - type (str): Layer type.
             - layer args: Args needed to instantiate a padding layer.
...
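Usage sketch (illustrative): 'zero', 'reflect' and 'replicate' are the pre-registered types; positional args go to the padding layer's constructor.

import torch
from mmcv.cnn.bricks.padding import build_padding_layer

pad = build_padding_layer(dict(type='reflect'), 2)  # nn.ReflectionPad2d(2)
assert pad(torch.rand(1, 3, 8, 8)).shape == (1, 3, 12, 12)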
 # Copyright (c) OpenMMLab. All rights reserved.
 import inspect
 import platform
+from typing import Dict, Tuple, Union
+
+import torch.nn as nn
 from .registry import PLUGIN_LAYERS
 if platform.system() == 'Windows':
-    import regex as re
+    import regex as re  # type: ignore
 else:
-    import re
+    import re  # type: ignore
-def infer_abbr(class_type):
+def infer_abbr(class_type: type) -> str:
     """Infer abbreviation from the class name.
     This method will infer the abbreviation to map class types to
@@ -47,25 +51,27 @@ def infer_abbr(class_type):
         raise TypeError(
            f'class_type must be a type, but got {type(class_type)}')
     if hasattr(class_type, '_abbr_'):
-        return class_type._abbr_
+        return class_type._abbr_  # type: ignore
     else:
         return camel2snack(class_type.__name__)
-def build_plugin_layer(cfg, postfix='', **kwargs):
+def build_plugin_layer(cfg: Dict,
+                       postfix: Union[int, str] = '',
+                       **kwargs) -> Tuple[str, nn.Module]:
     """Build plugin layer.
     Args:
-        cfg (None or dict): cfg should contain:
-            type (str): identify plugin layer type.
-            layer args: args needed to instantiate a plugin layer.
+        cfg (dict): cfg should contain:
+
+            - type (str): identify plugin layer type.
+            - layer args: args needed to instantiate a plugin layer.
         postfix (int, str): appended into norm abbreviation to
             create named layer. Default: ''.
     Returns:
-        tuple[str, nn.Module]:
-            name (str): abbreviation + postfix
-            layer (nn.Module): created plugin layer
+        tuple[str, nn.Module]: The first one is the concatenation of
+        abbreviation and postfix. The second is the created plugin layer.
     """
     if not isinstance(cfg, dict):
         raise TypeError('cfg must be a dict')
...
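Usage sketch (illustrative): ContextBlock from the first file of this diff is one such registered plugin, and its _abbr_ determines the returned name.

from mmcv.cnn import build_plugin_layer

name, layer = build_plugin_layer(
    dict(type='ContextBlock', in_channels=16, ratio=0.25), postfix=1)
assert name == 'context_block1'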
@@ -13,9 +13,9 @@ class Scale(nn.Module):
         scale (float): Initial value of scale factor. Default: 1.0
     """
-    def __init__(self, scale=1.0):
-        super(Scale, self).__init__()
+    def __init__(self, scale: float = 1.0):
+        super().__init__()
         self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x * self.scale
@@ -19,7 +19,7 @@ class Swish(nn.Module):
     """
     def __init__(self):
-        super(Swish, self).__init__()
+        super().__init__()
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x * torch.sigmoid(x)
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
+import math
 import warnings
+from typing import Sequence
+
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
-from mmcv import ConfigDict, deprecated_api_warning
-from mmcv.cnn import Linear, build_activation_layer, build_norm_layer
+from mmcv.cnn import (Linear, build_activation_layer, build_conv_layer,
+                      build_norm_layer)
 from mmcv.runner.base_module import BaseModule, ModuleList, Sequential
-from mmcv.utils import build_from_cfg
+from mmcv.utils import (ConfigDict, build_from_cfg, deprecated_api_warning,
+                        to_2tuple)
 from .drop import build_dropout
 from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
                        TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
 # Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
 try:
-    from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention  # noqa F401
+    from mmcv.ops.multi_scale_deform_attn import \
+        MultiScaleDeformableAttention  # noqa F401
     warnings.warn(
         ImportWarning(
             '``MultiScaleDeformableAttention`` has been moved to '
@@ -55,6 +60,349 @@ def build_transformer_layer_sequence(cfg, default_args=None):
     return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)
+class AdaptivePadding(nn.Module):
+    """Applies padding adaptively to the input.
+
+    This module can make input get fully covered by filter
+    you specified. It support two modes "same" and "corner". The
+    "same" mode is same with "SAME" padding mode in TensorFlow, pad
+    zero around input. The "corner" mode would pad zero
+    to bottom right.
+
+    Args:
+        kernel_size (int | tuple): Size of the kernel. Default: 1.
+        stride (int | tuple): Stride of the filter. Default: 1.
+        dilation (int | tuple): Spacing between kernel elements.
+            Default: 1.
+        padding (str): Support "same" and "corner", "corner" mode
+            would pad zero to bottom right, and "same" mode would
+            pad zero around input. Default: "corner".
+
+    Example:
+        >>> kernel_size = 16
+        >>> stride = 16
+        >>> dilation = 1
+        >>> input = torch.rand(1, 1, 15, 17)
+        >>> adap_pad = AdaptivePadding(
+        >>>     kernel_size=kernel_size,
+        >>>     stride=stride,
+        >>>     dilation=dilation,
+        >>>     padding="corner")
+        >>> out = adap_pad(input)
+        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+        >>> input = torch.rand(1, 1, 16, 17)
+        >>> out = adap_pad(input)
+        >>> assert (out.shape[2], out.shape[3]) == (16, 32)
+    """
+
+    def __init__(self, kernel_size=1, stride=1, dilation=1, padding='corner'):
+        super().__init__()
+        assert padding in ('same', 'corner')
+
+        kernel_size = to_2tuple(kernel_size)
+        stride = to_2tuple(stride)
+        dilation = to_2tuple(dilation)
+
+        self.padding = padding
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.dilation = dilation
+
+    def get_pad_shape(self, input_shape):
+        """Calculate the padding size of input.
+
+        Args:
+            input_shape (:obj:`torch.Size`): arrange as (H, W).
+
+        Returns:
+            Tuple[int]: The padding size along the
+            original H and W directions
+        """
+        input_h, input_w = input_shape
+        kernel_h, kernel_w = self.kernel_size
+        stride_h, stride_w = self.stride
+        output_h = math.ceil(input_h / stride_h)
+        output_w = math.ceil(input_w / stride_w)
+        pad_h = max((output_h - 1) * stride_h +
+                    (kernel_h - 1) * self.dilation[0] + 1 - input_h, 0)
+        pad_w = max((output_w - 1) * stride_w +
+                    (kernel_w - 1) * self.dilation[1] + 1 - input_w, 0)
+        return pad_h, pad_w
+
+    def forward(self, x):
+        """Add padding to `x`
+
+        Args:
+            x (Tensor): Input tensor has shape (B, C, H, W).
+
+        Returns:
+            Tensor: The tensor with adaptive padding
+        """
+        pad_h, pad_w = self.get_pad_shape(x.size()[-2:])
+        if pad_h > 0 or pad_w > 0:
+            if self.padding == 'corner':
+                x = F.pad(x, [0, pad_w, 0, pad_h])
+            elif self.padding == 'same':
+                x = F.pad(x, [
+                    pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
+                    pad_h - pad_h // 2
+                ])
+        return x
+
+
+class PatchEmbed(BaseModule):
+    """Image to Patch Embedding.
+
+    We use a conv layer to implement PatchEmbed.
+
+    Args:
+        in_channels (int): The num of input channels. Default: 3
+        embed_dims (int): The dimensions of embedding. Default: 768
+        conv_type (str): The type of convolution
+            to generate patch embedding. Default: "Conv2d".
+        kernel_size (int): The kernel_size of embedding conv. Default: 16.
+        stride (int): The slide stride of embedding conv.
+            Default: 16.
+        padding (int | tuple | string): The padding length of
+            embedding conv. When it is a string, it means the mode
+            of adaptive padding, support "same" and "corner" now.
+            Default: "corner".
+        dilation (int): The dilation rate of embedding conv. Default: 1.
+        bias (bool): Bias of embed conv. Default: True.
+        norm_cfg (dict, optional): Config dict for normalization layer.
+            Default: None.
+        input_size (int | tuple | None): The size of input, which will be
+            used to calculate the out size. Only works when `dynamic_size`
+            is False. Default: None.
+        init_cfg (`mmcv.ConfigDict`, optional): The Config for initialization.
+            Default: None.
+    """
+
+    def __init__(self,
+                 in_channels=3,
+                 embed_dims=768,
+                 conv_type='Conv2d',
+                 kernel_size=16,
+                 stride=16,
+                 padding='corner',
+                 dilation=1,
+                 bias=True,
+                 norm_cfg=None,
+                 input_size=None,
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
+
+        self.embed_dims = embed_dims
+        if stride is None:
+            stride = kernel_size
+
+        kernel_size = to_2tuple(kernel_size)
+        stride = to_2tuple(stride)
+        dilation = to_2tuple(dilation)
+
+        if isinstance(padding, str):
+            self.adaptive_padding = AdaptivePadding(
+                kernel_size=kernel_size,
+                stride=stride,
+                dilation=dilation,
+                padding=padding)
+            # disable the padding of conv
+            padding = 0
+        else:
+            self.adaptive_padding = None
+        padding = to_2tuple(padding)
+
+        self.projection = build_conv_layer(
+            dict(type=conv_type),
+            in_channels=in_channels,
+            out_channels=embed_dims,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            bias=bias)
+
+        if norm_cfg is not None:
+            self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
+        else:
+            self.norm = None
+
+        if input_size:
+            input_size = to_2tuple(input_size)
+            # `init_out_size` would be used outside to
+            # calculate the num_patches
+            # e.g. when `use_abs_pos_embed` outside
+            self.init_input_size = input_size
+            if self.adaptive_padding:
+                pad_h, pad_w = self.adaptive_padding.get_pad_shape(input_size)
+                input_h, input_w = input_size
+                input_h = input_h + pad_h
+                input_w = input_w + pad_w
+                input_size = (input_h, input_w)
+
+            # https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+            h_out = (input_size[0] + 2 * padding[0] - dilation[0] *
+                     (kernel_size[0] - 1) - 1) // stride[0] + 1
+            w_out = (input_size[1] + 2 * padding[1] - dilation[1] *
+                     (kernel_size[1] - 1) - 1) // stride[1] + 1
+            self.init_out_size = (h_out, w_out)
+        else:
+            self.init_input_size = None
+            self.init_out_size = None
+
+    def forward(self, x):
+        """
+        Args:
+            x (Tensor): Has shape (B, C, H, W). In most case, C is 3.
+
+        Returns:
+            tuple: Contains merged results and its spatial shape.
+
+            - x (Tensor): Has shape (B, out_h * out_w, embed_dims)
+            - out_size (tuple[int]): Spatial shape of x, arrange as
+              (out_h, out_w).
+        """
+        if self.adaptive_padding:
+            x = self.adaptive_padding(x)
+
+        x = self.projection(x)
+        out_size = (x.shape[2], x.shape[3])
+        x = x.flatten(2).transpose(1, 2)
+        if self.norm is not None:
+            x = self.norm(x)
+        return x, out_size
+
+
+class PatchMerging(BaseModule):
+    """Merge patch feature map.
+
+    This layer groups feature map by kernel_size, and applies norm and linear
+    layers to the grouped feature map ((used in Swin Transformer)).
+    Our implementation uses `nn.Unfold` to
+    merge patches, which is about 25% faster than the original
+    implementation. However, we need to modify pretrained
+    models for compatibility.
+
+    Args:
+        in_channels (int): The num of input channels.
+            to gets fully covered by filter and stride you specified.
+        out_channels (int): The num of output channels.
+        kernel_size (int | tuple, optional): the kernel size in the unfold
+            layer. Defaults to 2.
+        stride (int | tuple, optional): the stride of the sliding blocks in the
+            unfold layer. Default: None. (Would be set as `kernel_size`)
+        padding (int | tuple | string ): The padding length of
+            embedding conv. When it is a string, it means the mode
+            of adaptive padding, support "same" and "corner" now.
+            Default: "corner".
+        dilation (int | tuple, optional): dilation parameter in the unfold
+            layer. Default: 1.
+        bias (bool, optional): Whether to add bias in linear layer or not.
+            Defaults: False.
+        norm_cfg (dict, optional): Config dict for normalization layer.
+            Default: dict(type='LN').
+        init_cfg (dict, optional): The extra config for initialization.
+            Default: None.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size=2,
+                 stride=None,
+                 padding='corner',
+                 dilation=1,
+                 bias=False,
+                 norm_cfg=dict(type='LN'),
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        if stride:
+            stride = stride
+        else:
+            stride = kernel_size
+
+        kernel_size = to_2tuple(kernel_size)
+        stride = to_2tuple(stride)
+        dilation = to_2tuple(dilation)
+
+        if isinstance(padding, str):
+            self.adaptive_padding = AdaptivePadding(
+                kernel_size=kernel_size,
+                stride=stride,
+                dilation=dilation,
+                padding=padding)
+            # disable the padding of unfold
+            padding = 0
+        else:
+            self.adaptive_padding = None
+
+        padding = to_2tuple(padding)
+        self.sampler = nn.Unfold(
+            kernel_size=kernel_size,
+            dilation=dilation,
+            padding=padding,
+            stride=stride)
+
+        sample_dim = kernel_size[0] * kernel_size[1] * in_channels
+
+        if norm_cfg is not None:
+            self.norm = build_norm_layer(norm_cfg, sample_dim)[1]
+        else:
+            self.norm = None
+
+        self.reduction = nn.Linear(sample_dim, out_channels, bias=bias)
+
+    def forward(self, x, input_size):
+        """
+        Args:
+            x (Tensor): Has shape (B, H*W, C_in).
+            input_size (tuple[int]): The spatial shape of x, arrange as (H, W).
+                Default: None.
+
+        Returns:
+            tuple: Contains merged results and its spatial shape.
+
+            - x (Tensor): Has shape (B, Merged_H * Merged_W, C_out)
+            - out_size (tuple[int]): Spatial shape of x, arrange as
+              (Merged_H, Merged_W).
+        """
+        B, L, C = x.shape
+        assert isinstance(input_size, Sequence), f'Expect ' \
+            f'input_size is ' \
+            f'`Sequence` ' \
+            f'but get {input_size}'
+
+        H, W = input_size
+        assert L == H * W, 'input feature has wrong size'
+
+        x = x.view(B, H, W, C).permute([0, 3, 1, 2])  # B, C, H, W
+
+        if self.adaptive_padding:
+            x = self.adaptive_padding(x)
+            H, W = x.shape[-2:]
+
+        # Use nn.Unfold to merge patch. About 25% faster than original method,
+        # but need to modify pretrained model for compatibility
+        # if kernel_size=2 and stride=2, x should has shape (B, 4*C, H/2*W/2)
+        x = self.sampler(x)
+
+        out_h = (H + 2 * self.sampler.padding[0] - self.sampler.dilation[0] *
+                 (self.sampler.kernel_size[0] - 1) -
+                 1) // self.sampler.stride[0] + 1
+        out_w = (W + 2 * self.sampler.padding[1] - self.sampler.dilation[1] *
+                 (self.sampler.kernel_size[1] - 1) -
+                 1) // self.sampler.stride[1] + 1
+
+        output_size = (out_h, out_w)
+        x = x.transpose(1, 2)  # B, H/2*W/2, 4*C
+        x = self.norm(x) if self.norm else x
+        x = self.reduction(x)
+        return x, output_size
 @ATTENTION.register_module()
 class MultiheadAttention(BaseModule):
     """A wrapper for ``torch.nn.MultiheadAttention``.
@@ -87,12 +435,13 @@ class MultiheadAttention(BaseModule):
                  init_cfg=None,
                  batch_first=False,
                  **kwargs):
-        super(MultiheadAttention, self).__init__(init_cfg)
+        super().__init__(init_cfg)
         if 'dropout' in kwargs:
-            warnings.warn('The arguments `dropout` in MultiheadAttention '
-                          'has been deprecated, now you can separately '
-                          'set `attn_drop`(float), proj_drop(float), '
-                          'and `dropout_layer`(dict) ')
+            warnings.warn(
+                'The arguments `dropout` in MultiheadAttention '
+                'has been deprecated, now you can separately '
+                'set `attn_drop`(float), proj_drop(float), '
+                'and `dropout_layer`(dict) ', DeprecationWarning)
             attn_drop = kwargs['dropout']
             dropout_layer['drop_prob'] = kwargs.pop('dropout')
@@ -154,9 +503,9 @@ class MultiheadAttention(BaseModule):
         Returns:
             Tensor: forwarded results with shape
-                [num_queries, bs, embed_dims]
-                if self.batch_first is False, else
-                [bs, num_queries embed_dims].
+            [num_queries, bs, embed_dims]
+            if self.batch_first is False, else
+            [bs, num_queries embed_dims].
         """
         if key is None:
@@ -241,7 +590,7 @@ class FFN(BaseModule):
                  add_identity=True,
                  init_cfg=None,
                  **kwargs):
-        super(FFN, self).__init__(init_cfg)
+        super().__init__(init_cfg)
         assert num_fcs >= 2, 'num_fcs should be no less ' \
             f'than 2. got {num_fcs}.'
         self.embed_dims = embed_dims
@@ -342,15 +691,15 @@ class BaseTransformerLayer(BaseModule):
                     f'The arguments `{ori_name}` in BaseTransformerLayer '
                     f'has been deprecated, now you should set `{new_name}` '
                     f'and other FFN related arguments '
-                    f'to a dict named `ffn_cfgs`. ')
+                    f'to a dict named `ffn_cfgs`. ', DeprecationWarning)
                 ffn_cfgs[new_name] = kwargs[ori_name]
-        super(BaseTransformerLayer, self).__init__(init_cfg)
+        super().__init__(init_cfg)
         self.batch_first = batch_first
-        assert set(operation_order) & set(
-            ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
+        assert set(operation_order) & {
+            'self_attn', 'norm', 'ffn', 'cross_attn'} == \
             set(operation_order), f'The operation_order of' \
             f' {self.__class__.__name__} should ' \
            f'contains all four operation type ' \
@@ -397,7 +746,7 @@ class BaseTransformerLayer(BaseModule):
         assert len(ffn_cfgs) == num_ffns
         for ffn_index in range(num_ffns):
             if 'embed_dims' not in ffn_cfgs[ffn_index]:
-                ffn_cfgs['embed_dims'] = self.embed_dims
+                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
             else:
                 assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
             self.ffns.append(
@@ -531,7 +880,7 @@ class TransformerLayerSequence(BaseModule):
     """
     def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
-        super(TransformerLayerSequence, self).__init__(init_cfg)
+        super().__init__(init_cfg)
         if isinstance(transformerlayers, dict):
             transformerlayers = [
                 copy.deepcopy(transformerlayers) for _ in range(num_layers)
...
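Shape sketch for the newly added PatchEmbed/PatchMerging (illustrative; assumes the mmcv.cnn.bricks.transformer import path of mmcv 1.x):

import torch
from mmcv.cnn.bricks.transformer import PatchEmbed, PatchMerging

embed = PatchEmbed(in_channels=3, embed_dims=96, kernel_size=4, stride=4)
tokens, hw = embed(torch.rand(1, 3, 224, 224))
assert tokens.shape == (1, 56 * 56, 96) and hw == (56, 56)

merge = PatchMerging(in_channels=96, out_channels=192)  # 2x2 kernel
tokens, hw = merge(tokens, hw)  # unfold to 4*96 dims, LN, then Linear
assert tokens.shape == (1, 28 * 28, 192) and hw == (28, 28)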
 # Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict
+
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -24,9 +27,9 @@ class PixelShufflePack(nn.Module):
             channels.
     """
-    def __init__(self, in_channels, out_channels, scale_factor,
-                 upsample_kernel):
-        super(PixelShufflePack, self).__init__()
+    def __init__(self, in_channels: int, out_channels: int, scale_factor: int,
+                 upsample_kernel: int):
+        super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.scale_factor = scale_factor
@@ -41,13 +44,13 @@ class PixelShufflePack(nn.Module):
     def init_weights(self):
         xavier_init(self.upsample_conv, distribution='uniform')
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.upsample_conv(x)
         x = F.pixel_shuffle(x, self.scale_factor)
         return x
-def build_upsample_layer(cfg, *args, **kwargs):
+def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
     """Build upsample layer.
     Args:
@@ -55,7 +58,7 @@ def build_upsample_layer(cfg, *args, **kwargs):
             - type (str): Layer type.
             - scale_factor (int): Upsample ratio, which is not applicable to
               deconv.
             - layer args: Args needed to instantiate a upsample layer.
         args (argument list): Arguments passed to the ``__init__``
             method of the corresponding conv layer.
...
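Usage sketch (illustrative): PixelShufflePack first convolves to scale_factor**2 * out_channels, then rearranges with F.pixel_shuffle.

import torch
from mmcv.cnn import build_upsample_layer

up = build_upsample_layer(
    dict(type='pixel_shuffle', in_channels=8, out_channels=8,
         scale_factor=2, upsample_kernel=3))
assert up(torch.rand(1, 8, 16, 16)).shape == (1, 8, 32, 32)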
...@@ -21,19 +21,19 @@ else: ...@@ -21,19 +21,19 @@ else:
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2]) TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
def obsolete_torch_version(torch_version, version_threshold): def obsolete_torch_version(torch_version, version_threshold) -> bool:
return torch_version == 'parrots' or torch_version <= version_threshold return torch_version == 'parrots' or torch_version <= version_threshold
class NewEmptyTensorOp(torch.autograd.Function): class NewEmptyTensorOp(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, x, new_shape): def forward(ctx, x: torch.Tensor, new_shape: tuple) -> torch.Tensor:
ctx.shape = x.shape ctx.shape = x.shape
return x.new_empty(new_shape) return x.new_empty(new_shape)
@staticmethod @staticmethod
def backward(ctx, grad): def backward(ctx, grad: torch.Tensor) -> tuple:
shape = ctx.shape shape = ctx.shape
return NewEmptyTensorOp.apply(grad, shape), None return NewEmptyTensorOp.apply(grad, shape), None
...@@ -41,7 +41,7 @@ class NewEmptyTensorOp(torch.autograd.Function): ...@@ -41,7 +41,7 @@ class NewEmptyTensorOp(torch.autograd.Function):
@CONV_LAYERS.register_module('Conv', force=True) @CONV_LAYERS.register_module('Conv', force=True)
class Conv2d(nn.Conv2d): class Conv2d(nn.Conv2d):
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
out_shape = [x.shape[0], self.out_channels] out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size, for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
...@@ -62,7 +62,7 @@ class Conv2d(nn.Conv2d): ...@@ -62,7 +62,7 @@ class Conv2d(nn.Conv2d):
@CONV_LAYERS.register_module('Conv3d', force=True) @CONV_LAYERS.register_module('Conv3d', force=True)
class Conv3d(nn.Conv3d): class Conv3d(nn.Conv3d):
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
out_shape = [x.shape[0], self.out_channels] out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size, for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
...@@ -85,7 +85,7 @@ class Conv3d(nn.Conv3d): ...@@ -85,7 +85,7 @@ class Conv3d(nn.Conv3d):
@UPSAMPLE_LAYERS.register_module('deconv', force=True) @UPSAMPLE_LAYERS.register_module('deconv', force=True)
class ConvTranspose2d(nn.ConvTranspose2d): class ConvTranspose2d(nn.ConvTranspose2d):
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
out_shape = [x.shape[0], self.out_channels] out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size, for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
...@@ -108,7 +108,7 @@ class ConvTranspose2d(nn.ConvTranspose2d): ...@@ -108,7 +108,7 @@ class ConvTranspose2d(nn.ConvTranspose2d):
@UPSAMPLE_LAYERS.register_module('deconv3d', force=True) @UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
class ConvTranspose3d(nn.ConvTranspose3d): class ConvTranspose3d(nn.ConvTranspose3d):
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)): if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
out_shape = [x.shape[0], self.out_channels] out_shape = [x.shape[0], self.out_channels]
for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size, for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
...@@ -128,7 +128,7 @@ class ConvTranspose3d(nn.ConvTranspose3d): ...@@ -128,7 +128,7 @@ class ConvTranspose3d(nn.ConvTranspose3d):
class MaxPool2d(nn.MaxPool2d): class MaxPool2d(nn.MaxPool2d):
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
# PyTorch 1.9 does not support empty tensor inference yet # PyTorch 1.9 does not support empty tensor inference yet
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
out_shape = list(x.shape[:2]) out_shape = list(x.shape[:2])
...@@ -146,7 +146,7 @@ class MaxPool2d(nn.MaxPool2d): ...@@ -146,7 +146,7 @@ class MaxPool2d(nn.MaxPool2d):
class MaxPool3d(nn.MaxPool3d): class MaxPool3d(nn.MaxPool3d):
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
# PyTorch 1.9 does not support empty tensor inference yet # PyTorch 1.9 does not support empty tensor inference yet
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)): if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
out_shape = list(x.shape[:2]) out_shape = list(x.shape[:2])
...@@ -165,7 +165,7 @@ class MaxPool3d(nn.MaxPool3d): ...@@ -165,7 +165,7 @@ class MaxPool3d(nn.MaxPool3d):
class Linear(torch.nn.Linear): class Linear(torch.nn.Linear):
def forward(self, x): def forward(self, x: torch.Tensor) -> torch.Tensor:
# empty tensor forward of Linear layer is supported since PyTorch 1.6 # empty tensor forward of Linear layer is supported since PyTorch 1.6
if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)): if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
out_shape = [x.shape[0], self.out_features] out_shape = [x.shape[0], self.out_features]
......
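The `out_shape` loops in the wrappers above all evaluate the standard convolution output-size arithmetic per spatial dimension; a worked sketch of the 2-d case (values are illustrative):

```python
# Standard conv output size: o = (i + 2*p - d*(k - 1) - 1) // s + 1.
# For an empty input of shape (0, 3, 32, 32) through a 3x3 conv with
# stride 2, padding 1, dilation 1 and 16 output channels:
i, k, p, s, d = 32, 3, 1, 2, 1
o = (i + 2 * p - d * (k - 1) - 1) // s + 1
assert o == 16  # so out_shape == [0, 16, 16, 16]
```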
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import logging import logging
from typing import Optional, Sequence, Tuple, Union
import torch.nn as nn import torch.nn as nn
import torch.utils.checkpoint as cp import torch.utils.checkpoint as cp
from torch import Tensor
from .utils import constant_init, kaiming_init from .utils import constant_init, kaiming_init
def conv3x3(in_planes, out_planes, stride=1, dilation=1): def conv3x3(in_planes: int,
out_planes: int,
stride: int = 1,
dilation: int = 1):
"""3x3 convolution with padding.""" """3x3 convolution with padding."""
return nn.Conv2d( return nn.Conv2d(
in_planes, in_planes,
...@@ -23,14 +28,14 @@ class BasicBlock(nn.Module): ...@@ -23,14 +28,14 @@ class BasicBlock(nn.Module):
expansion = 1 expansion = 1
def __init__(self, def __init__(self,
inplanes, inplanes: int,
planes, planes: int,
stride=1, stride: int = 1,
dilation=1, dilation: int = 1,
downsample=None, downsample: Optional[nn.Module] = None,
style='pytorch', style: str = 'pytorch',
with_cp=False): with_cp: bool = False):
super(BasicBlock, self).__init__() super().__init__()
assert style in ['pytorch', 'caffe'] assert style in ['pytorch', 'caffe']
self.conv1 = conv3x3(inplanes, planes, stride, dilation) self.conv1 = conv3x3(inplanes, planes, stride, dilation)
self.bn1 = nn.BatchNorm2d(planes) self.bn1 = nn.BatchNorm2d(planes)
...@@ -42,7 +47,7 @@ class BasicBlock(nn.Module): ...@@ -42,7 +47,7 @@ class BasicBlock(nn.Module):
self.dilation = dilation self.dilation = dilation
assert not with_cp assert not with_cp
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
residual = x residual = x
out = self.conv1(x) out = self.conv1(x)
...@@ -65,19 +70,19 @@ class Bottleneck(nn.Module): ...@@ -65,19 +70,19 @@ class Bottleneck(nn.Module):
expansion = 4 expansion = 4
def __init__(self, def __init__(self,
inplanes, inplanes: int,
planes, planes: int,
stride=1, stride: int = 1,
dilation=1, dilation: int = 1,
downsample=None, downsample: Optional[nn.Module] = None,
style='pytorch', style: str = 'pytorch',
with_cp=False): with_cp: bool = False):
"""Bottleneck block. """Bottleneck block.
If style is "pytorch", the stride-two layer is the 3x3 conv layer; if If style is "pytorch", the stride-two layer is the 3x3 conv layer; if
it is "caffe", the stride-two layer is the first 1x1 conv layer. it is "caffe", the stride-two layer is the first 1x1 conv layer.
""" """
super(Bottleneck, self).__init__() super().__init__()
assert style in ['pytorch', 'caffe'] assert style in ['pytorch', 'caffe']
if style == 'pytorch': if style == 'pytorch':
conv1_stride = 1 conv1_stride = 1
...@@ -107,7 +112,7 @@ class Bottleneck(nn.Module): ...@@ -107,7 +112,7 @@ class Bottleneck(nn.Module):
self.dilation = dilation self.dilation = dilation
self.with_cp = with_cp self.with_cp = with_cp
def forward(self, x): def forward(self, x: Tensor) -> Tensor:
def _inner_forward(x): def _inner_forward(x):
residual = x residual = x
...@@ -140,14 +145,14 @@ class Bottleneck(nn.Module): ...@@ -140,14 +145,14 @@ class Bottleneck(nn.Module):
return out return out
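The `with_cp` path elided in this hunk routes `_inner_forward` through `torch.utils.checkpoint`, recomputing activations in backward instead of storing them. A hedged, self-contained stand-in for the pattern (not the verbatim Bottleneck body):

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

class CheckpointedBlock(nn.Module):
    """Minimal stand-in for the with_cp pattern used in Bottleneck."""

    def __init__(self, channels: int, with_cp: bool = True):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.with_cp = with_cp

    def forward(self, x):
        def _inner_forward(x):
            return x + self.conv(x)  # residual add

        if self.with_cp and x.requires_grad:
            # Trade compute for memory: recompute during backward.
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)
        return self.relu(out)
```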
def make_res_layer(block, def make_res_layer(block: nn.Module,
inplanes, inplanes: int,
planes, planes: int,
blocks, blocks: int,
stride=1, stride: int = 1,
dilation=1, dilation: int = 1,
style='pytorch', style: str = 'pytorch',
with_cp=False): with_cp: bool = False) -> nn.Module:
downsample = None downsample = None
if stride != 1 or inplanes != planes * block.expansion: if stride != 1 or inplanes != planes * block.expansion:
downsample = nn.Sequential( downsample = nn.Sequential(
...@@ -208,22 +213,22 @@ class ResNet(nn.Module): ...@@ -208,22 +213,22 @@ class ResNet(nn.Module):
} }
def __init__(self, def __init__(self,
depth, depth: int,
num_stages=4, num_stages: int = 4,
strides=(1, 2, 2, 2), strides: Sequence[int] = (1, 2, 2, 2),
dilations=(1, 1, 1, 1), dilations: Sequence[int] = (1, 1, 1, 1),
out_indices=(0, 1, 2, 3), out_indices: Sequence[int] = (0, 1, 2, 3),
style='pytorch', style: str = 'pytorch',
frozen_stages=-1, frozen_stages: int = -1,
bn_eval=True, bn_eval: bool = True,
bn_frozen=False, bn_frozen: bool = False,
with_cp=False): with_cp: bool = False):
super(ResNet, self).__init__() super().__init__()
if depth not in self.arch_settings: if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for resnet') raise KeyError(f'invalid depth {depth} for resnet')
assert num_stages >= 1 and num_stages <= 4 assert num_stages >= 1 and num_stages <= 4
block, stage_blocks = self.arch_settings[depth] block, stage_blocks = self.arch_settings[depth]
stage_blocks = stage_blocks[:num_stages] stage_blocks = stage_blocks[:num_stages] # type: ignore
assert len(strides) == len(dilations) == num_stages assert len(strides) == len(dilations) == num_stages
assert max(out_indices) < num_stages assert max(out_indices) < num_stages
...@@ -234,7 +239,7 @@ class ResNet(nn.Module): ...@@ -234,7 +239,7 @@ class ResNet(nn.Module):
self.bn_frozen = bn_frozen self.bn_frozen = bn_frozen
self.with_cp = with_cp self.with_cp = with_cp
self.inplanes = 64 self.inplanes: int = 64
self.conv1 = nn.Conv2d( self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False) 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64) self.bn1 = nn.BatchNorm2d(64)
...@@ -255,14 +260,15 @@ class ResNet(nn.Module): ...@@ -255,14 +260,15 @@ class ResNet(nn.Module):
dilation=dilation, dilation=dilation,
style=self.style, style=self.style,
with_cp=with_cp) with_cp=with_cp)
self.inplanes = planes * block.expansion self.inplanes = planes * block.expansion # type: ignore
layer_name = f'layer{i + 1}' layer_name = f'layer{i + 1}'
self.add_module(layer_name, res_layer) self.add_module(layer_name, res_layer)
self.res_layers.append(layer_name) self.res_layers.append(layer_name)
self.feat_dim = block.expansion * 64 * 2**(len(stage_blocks) - 1) self.feat_dim = block.expansion * 64 * 2**( # type: ignore
len(stage_blocks) - 1)
def init_weights(self, pretrained=None): def init_weights(self, pretrained: Optional[str] = None) -> None:
if isinstance(pretrained, str): if isinstance(pretrained, str):
logger = logging.getLogger() logger = logging.getLogger()
from ..runner import load_checkpoint from ..runner import load_checkpoint
...@@ -276,7 +282,7 @@ class ResNet(nn.Module): ...@@ -276,7 +282,7 @@ class ResNet(nn.Module):
else: else:
raise TypeError('pretrained must be a str or None') raise TypeError('pretrained must be a str or None')
def forward(self, x): def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]:
x = self.conv1(x) x = self.conv1(x)
x = self.bn1(x) x = self.bn1(x)
x = self.relu(x) x = self.relu(x)
...@@ -292,8 +298,8 @@ class ResNet(nn.Module): ...@@ -292,8 +298,8 @@ class ResNet(nn.Module):
else: else:
return tuple(outs) return tuple(outs)
def train(self, mode=True): def train(self, mode: bool = True) -> None:
super(ResNet, self).train(mode) super().train(mode)
if self.bn_eval: if self.bn_eval:
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.BatchNorm2d): if isinstance(m, nn.BatchNorm2d):
......
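Taken together, a hedged sketch of driving this backbone; the depth and `out_indices` values are illustrative and follow the `arch_settings` table referenced above. For ResNet-50, `Bottleneck.expansion == 4`, so `feat_dim` works out to `4 * 64 * 2**(num_stages - 1) = 2048` with the default four stages.

```python
import torch
from mmcv.cnn import ResNet  # assumed import path

model = ResNet(depth=50, out_indices=(0, 1, 2, 3))
model.init_weights(pretrained=None)
model.train()  # with bn_eval=True, BatchNorm layers stay in eval mode

feats = model(torch.rand(1, 3, 224, 224))
for f in feats:
    print(f.shape)  # strides 4, 8, 16, 32 relative to the input
```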