Unverified Commit 42f03d84 authored by Jerry Jiarui XU's avatar Jerry Jiarui XU Committed by GitHub
Browse files

Refactor plugins: move from ops to cnn (#380)

* Refactor plugins: move from ops to cnn

* minor update

* minor update
parent 82211b40
from .activation import build_activation_layer from .activation import build_activation_layer
from .context_block import ContextBlock
from .conv import build_conv_layer from .conv import build_conv_layer
from .conv_module import ConvModule from .conv_module import ConvModule
from .generalized_attention import GeneralizedAttention
from .hsigmoid import HSigmoid from .hsigmoid import HSigmoid
from .hswish import HSwish from .hswish import HSwish
from .non_local import NonLocal1d, NonLocal2d, NonLocal3d from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
from .norm import build_norm_layer, is_norm from .norm import build_norm_layer, is_norm
from .padding import build_padding_layer from .padding import build_padding_layer
from .plugin import build_plugin_layer
from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS, from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, UPSAMPLE_LAYERS) PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS)
from .scale import Scale from .scale import Scale
from .upsample import build_upsample_layer from .upsample import build_upsample_layer
__all__ = [ __all__ = [
'ConvModule', 'build_activation_layer', 'build_conv_layer', 'ConvModule', 'build_activation_layer', 'build_conv_layer',
'build_norm_layer', 'build_padding_layer', 'build_upsample_layer', 'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d', 'NonLocal2d', 'NonLocal3d', 'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d',
'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
'UPSAMPLE_LAYERS', 'Scale' 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale'
] ]
import torch import torch
from torch import nn from torch import nn
from ..cnn import constant_init, kaiming_init from ..utils import constant_init, kaiming_init
from .registry import PLUGIN_LAYERS
def last_zero_init(m): def last_zero_init(m):
...@@ -11,6 +12,7 @@ def last_zero_init(m): ...@@ -11,6 +12,7 @@ def last_zero_init(m):
constant_init(m, val=0) constant_init(m, val=0)
@PLUGIN_LAYERS.register_module()
class ContextBlock(nn.Module): class ContextBlock(nn.Module):
"""ContextBlock module in GCNet. """ContextBlock module in GCNet.
...@@ -20,11 +22,16 @@ class ContextBlock(nn.Module): ...@@ -20,11 +22,16 @@ class ContextBlock(nn.Module):
Args: Args:
in_channels (int): Channels of the input feature map. in_channels (int): Channels of the input feature map.
ratio (float): Ratio of channels of transform bottleneck ratio (float): Ratio of channels of transform bottleneck
pooling_type (str): Pooling method for context modeling pooling_type (str): Pooling method for context modeling.
fusion_types (list[str]|tuple[str]): Fusion method for feature fusion, Options are 'att' and 'avg', stand for attention pooling and
options: 'channels_add', 'channel_mul' average pooling respectively. Default: 'att'.
fusion_types (Sequence[str]): Fusion method for feature fusion,
Options are 'channels_add', 'channel_mul', stand for channelwise
addition and multiplication respectively. Default: ('channel_add',)
""" """
_abbr_ = 'context_block'
def __init__(self, def __init__(self,
in_channels, in_channels,
ratio, ratio,
......
...@@ -7,8 +7,10 @@ from .activation import build_activation_layer ...@@ -7,8 +7,10 @@ from .activation import build_activation_layer
from .conv import build_conv_layer from .conv import build_conv_layer
from .norm import build_norm_layer from .norm import build_norm_layer
from .padding import build_padding_layer from .padding import build_padding_layer
from .registry import PLUGIN_LAYERS
@PLUGIN_LAYERS.register_module()
class ConvModule(nn.Module): class ConvModule(nn.Module):
"""A conv block that bundles conv/norm/activation layers. """A conv block that bundles conv/norm/activation layers.
...@@ -54,6 +56,8 @@ class ConvModule(nn.Module): ...@@ -54,6 +56,8 @@ class ConvModule(nn.Module):
Default: ('conv', 'norm', 'act'). Default: ('conv', 'norm', 'act').
""" """
_abbr_ = 'conv_block'
def __init__(self, def __init__(self,
in_channels, in_channels,
out_channels, out_channels,
......
...@@ -5,9 +5,11 @@ import torch ...@@ -5,9 +5,11 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from ..cnn import kaiming_init from ..utils import kaiming_init
from .registry import PLUGIN_LAYERS
@PLUGIN_LAYERS.register_module()
class GeneralizedAttention(nn.Module): class GeneralizedAttention(nn.Module):
"""GeneralizedAttention module. """GeneralizedAttention module.
...@@ -16,13 +18,18 @@ class GeneralizedAttention(nn.Module): ...@@ -16,13 +18,18 @@ class GeneralizedAttention(nn.Module):
Args: Args:
in_channels (int): Channels of the input feature map. in_channels (int): Channels of the input feature map.
spatial_range (int): The spatial range. spatial_range (int): The spatial range. -1 indicates no spatial range
-1 indicates no spatial range constraint. constraint. Default: -1.
num_heads (int): The head number of empirical_attention module. num_heads (int): The head number of empirical_attention module.
Default: 9.
position_embedding_dim (int): The position embedding dimension. position_embedding_dim (int): The position embedding dimension.
Default: -1.
position_magnitude (int): A multiplier acting on coord difference. position_magnitude (int): A multiplier acting on coord difference.
Default: 1.
kv_stride (int): The feature stride acting on key/value feature map. kv_stride (int): The feature stride acting on key/value feature map.
Default: 2.
q_stride (int): The feature stride acting on query feature map. q_stride (int): The feature stride acting on query feature map.
Default: 1.
attention_type (str): A binary indicator string for indicating which attention_type (str): A binary indicator string for indicating which
items in generalized empirical_attention module are used. items in generalized empirical_attention module are used.
'1000' indicates 'query and key content' (appr - appr) item, '1000' indicates 'query and key content' (appr - appr) item,
...@@ -30,8 +37,11 @@ class GeneralizedAttention(nn.Module): ...@@ -30,8 +37,11 @@ class GeneralizedAttention(nn.Module):
(appr - position) item, (appr - position) item,
'0010' indicates 'key content only' (bias - appr) item, '0010' indicates 'key content only' (bias - appr) item,
'0001' indicates 'relative position only' (bias - position) item. '0001' indicates 'relative position only' (bias - position) item.
Default: '1111'.
""" """
_abbr_ = 'gen_attention_block'
def __init__(self, def __init__(self,
in_channels, in_channels,
spatial_range=-1, spatial_range=-1,
...@@ -161,16 +171,16 @@ class GeneralizedAttention(nn.Module): ...@@ -161,16 +171,16 @@ class GeneralizedAttention(nn.Module):
device, device,
feat_dim, feat_dim,
wave_length=1000): wave_length=1000):
h_idxs = torch.linspace(0, h - 1, h).cuda(device) h_idxs = torch.linspace(0, h - 1, h).to(device)
h_idxs = h_idxs.view((h, 1)) * q_stride h_idxs = h_idxs.view((h, 1)) * q_stride
w_idxs = torch.linspace(0, w - 1, w).cuda(device) w_idxs = torch.linspace(0, w - 1, w).to(device)
w_idxs = w_idxs.view((w, 1)) * q_stride w_idxs = w_idxs.view((w, 1)) * q_stride
h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).cuda(device) h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(device)
h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride
w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).cuda(device) w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(device)
w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride
# (h, h_kv, 1) # (h, h_kv, 1)
...@@ -181,9 +191,9 @@ class GeneralizedAttention(nn.Module): ...@@ -181,9 +191,9 @@ class GeneralizedAttention(nn.Module):
w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0) w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
w_diff *= self.position_magnitude w_diff *= self.position_magnitude
feat_range = torch.arange(0, feat_dim / 4).cuda(device) feat_range = torch.arange(0, feat_dim / 4).to(device)
dim_mat = torch.Tensor([wave_length]).cuda(device) dim_mat = torch.Tensor([wave_length]).to(device)
dim_mat = dim_mat**((4. / feat_dim) * feat_range) dim_mat = dim_mat**((4. / feat_dim) * feat_range)
dim_mat = dim_mat.view((1, 1, -1)) dim_mat = dim_mat.view((1, 1, -1))
...@@ -370,6 +380,15 @@ class GeneralizedAttention(nn.Module): ...@@ -370,6 +380,15 @@ class GeneralizedAttention(nn.Module):
view(n, self.v_dim * self.num_heads, h, w) view(n, self.v_dim * self.num_heads, h, w)
out = self.proj_conv(out) out = self.proj_conv(out)
# output is downsampled, upsample back to input size
if self.q_downsample is not None:
out = F.interpolate(
out,
size=x_input.shape[2:],
mode='bilinear',
align_corners=False)
out = self.gamma * out + x_input out = self.gamma * out + x_input
return out return out
......
...@@ -5,6 +5,7 @@ import torch.nn as nn ...@@ -5,6 +5,7 @@ import torch.nn as nn
from ..utils import constant_init, normal_init from ..utils import constant_init, normal_init
from .conv_module import ConvModule from .conv_module import ConvModule
from .registry import PLUGIN_LAYERS
class _NonLocalNd(nn.Module, metaclass=ABCMeta): class _NonLocalNd(nn.Module, metaclass=ABCMeta):
...@@ -185,6 +186,7 @@ class NonLocal1d(_NonLocalNd): ...@@ -185,6 +186,7 @@ class NonLocal1d(_NonLocalNd):
self.phi = nn.Sequential(self.phi, max_pool_layer) self.phi = nn.Sequential(self.phi, max_pool_layer)
@PLUGIN_LAYERS.register_module()
class NonLocal2d(_NonLocalNd): class NonLocal2d(_NonLocalNd):
"""2D Non-local module. """2D Non-local module.
...@@ -197,6 +199,8 @@ class NonLocal2d(_NonLocalNd): ...@@ -197,6 +199,8 @@ class NonLocal2d(_NonLocalNd):
Default: dict(type='Conv2d'). Default: dict(type='Conv2d').
""" """
_abbr_ = 'nonlocal_block'
def __init__(self, def __init__(self,
in_channels, in_channels,
sub_sample=False, sub_sample=False,
......
...@@ -26,7 +26,7 @@ def infer_abbr(class_type): ...@@ -26,7 +26,7 @@ def infer_abbr(class_type):
the norm type in variable names, e.g, self.bn1, self.gn. This method will the norm type in variable names, e.g, self.bn1, self.gn. This method will
infer the abbreviation to map class types to abbreviations. infer the abbreviation to map class types to abbreviations.
Rule 1: If the class has the property "abbr", return the property. Rule 1: If the class has the property "_abbr_", return the property.
Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or
InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and
"in" respectively. "in" respectively.
...@@ -44,8 +44,8 @@ def infer_abbr(class_type): ...@@ -44,8 +44,8 @@ def infer_abbr(class_type):
if not inspect.isclass(class_type): if not inspect.isclass(class_type):
raise TypeError( raise TypeError(
f'class_type must be a type, but got {type(class_type)}') f'class_type must be a type, but got {type(class_type)}')
if hasattr(class_type, 'abbr'): if hasattr(class_type, '_abbr_'):
return class_type.abbr return class_type._abbr_
if issubclass(class_type, _InstanceNorm): # IN is a subclass of BN if issubclass(class_type, _InstanceNorm): # IN is a subclass of BN
return 'in' return 'in'
elif issubclass(class_type, _BatchNorm): elif issubclass(class_type, _BatchNorm):
......
import inspect
import re
from .registry import PLUGIN_LAYERS
def infer_abbr(class_type):
    """Infer abbreviation from the class name.

    This method will infer the abbreviation to map class types to
    abbreviations.

    Rule 1: If the class has the property "_abbr_", return the property.
    Rule 2: Otherwise, the abbreviation falls back to snake case of class
    name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``.

    Args:
        class_type (type): The plugin layer type.

    Returns:
        str: The inferred abbreviation.

    Raises:
        TypeError: If ``class_type`` is not a class.
    """

    def camel2snake(word):
        """Convert a camel case word into snake case.

        Modified from `inflection lib
        <https://inflection.readthedocs.io/en/latest/#inflection.underscore>`_.

        Example::

            >>> camel2snake("FancyBlock")
            'fancy_block'
        """
        # Split acronym boundaries (e.g. "HTTPBlock" -> "HTTP_Block") ...
        word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
        # ... then ordinary lower-to-upper transitions ("FancyBlock" ->
        # "Fancy_Block").
        word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
        word = word.replace('-', '_')
        return word.lower()

    if not inspect.isclass(class_type):
        raise TypeError(
            f'class_type must be a type, but got {type(class_type)}')
    if hasattr(class_type, '_abbr_'):
        return class_type._abbr_
    return camel2snake(class_type.__name__)
def build_plugin_layer(cfg, postfix='', **kwargs):
    """Build plugin layer.

    Args:
        cfg (None or dict): cfg should contain:
            type (str): identify plugin layer type.
            layer args: args needed to instantiate a plugin layer.
        postfix (int, str): appended into norm abbreviation to
            create named layer. Default: ''.

    Returns:
        tuple[str, nn.Module]:
            name (str): abbreviation + postfix
            layer (nn.Module): created plugin layer

    Raises:
        TypeError: If ``cfg`` is not a dict.
        KeyError: If ``cfg`` has no ``type`` key or the type is not
            registered in ``PLUGIN_LAYERS``.
    """
    if not isinstance(cfg, dict):
        raise TypeError('cfg must be a dict')
    if 'type' not in cfg:
        raise KeyError('the cfg dict must contain the key "type"')

    # Work on a shallow copy so the caller's cfg is left untouched.
    args = dict(cfg)
    layer_type = args.pop('type')
    if layer_type not in PLUGIN_LAYERS:
        raise KeyError(f'Unrecognized plugin type {layer_type}')
    plugin_cls = PLUGIN_LAYERS.get(layer_type)

    assert isinstance(postfix, (int, str))
    layer_name = infer_abbr(plugin_cls) + str(postfix)
    return layer_name, plugin_cls(**kwargs, **args)
...@@ -5,3 +5,4 @@ NORM_LAYERS = Registry('norm layer') ...@@ -5,3 +5,4 @@ NORM_LAYERS = Registry('norm layer')
ACTIVATION_LAYERS = Registry('activation layer') ACTIVATION_LAYERS = Registry('activation layer')
PADDING_LAYERS = Registry('padding layer') PADDING_LAYERS = Registry('padding layer')
UPSAMPLE_LAYERS = Registry('upsample layer') UPSAMPLE_LAYERS = Registry('upsample layer')
PLUGIN_LAYERS = Registry('plugin layer')
from .bbox import bbox_overlaps from .bbox import bbox_overlaps
from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive from .carafe import CARAFE, CARAFENaive, CARAFEPack, carafe, carafe_naive
from .cc_attention import CrissCrossAttention from .cc_attention import CrissCrossAttention
from .context_block import ContextBlock
from .conv_ws import ConvWS2d, conv_ws_2d from .conv_ws import ConvWS2d, conv_ws_2d
from .corner_pool import CornerPool from .corner_pool import CornerPool
from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d from .deform_conv import DeformConv2d, DeformConv2dPack, deform_conv2d
...@@ -9,14 +8,12 @@ from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack, ...@@ -9,14 +8,12 @@ from .deform_roi_pool import (DeformRoIPool, DeformRoIPoolPack,
ModulatedDeformRoIPoolPack, deform_roi_pool) ModulatedDeformRoIPoolPack, deform_roi_pool)
from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss, from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
sigmoid_focal_loss, softmax_focal_loss) sigmoid_focal_loss, softmax_focal_loss)
from .generalized_attention import GeneralizedAttention
from .info import get_compiler_version, get_compiling_cuda_version from .info import get_compiler_version, get_compiling_cuda_version
from .masked_conv import MaskedConv2d, masked_conv2d from .masked_conv import MaskedConv2d, masked_conv2d
from .modulated_deform_conv import (ModulatedDeformConv2d, from .modulated_deform_conv import (ModulatedDeformConv2d,
ModulatedDeformConv2dPack, ModulatedDeformConv2dPack,
modulated_deform_conv2d) modulated_deform_conv2d)
from .nms import batched_nms, nms, nms_match, soft_nms from .nms import batched_nms, nms, nms_match, soft_nms
from .plugin import build_plugin_layer
from .point_sample import (SimpleRoIAlign, point_sample, from .point_sample import (SimpleRoIAlign, point_sample,
rel_roi_point_to_rel_img_point) rel_roi_point_to_rel_img_point)
from .psa_mask import PSAMask from .psa_mask import PSAMask
...@@ -27,16 +24,14 @@ from .wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d ...@@ -27,16 +24,14 @@ from .wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
__all__ = [ __all__ = [
'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe', 'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe',
'carafe_naive', 'ContextBlock', 'ConvWS2d', 'conv_ws_2d', 'CornerPool', 'carafe_naive', 'ConvWS2d', 'conv_ws_2d', 'CornerPool', 'DeformConv2d',
'DeformConv2d', 'DeformConv2dPack', 'deform_conv2d', 'DeformRoIPool', 'DeformConv2dPack', 'deform_conv2d', 'DeformRoIPool', 'DeformRoIPoolPack',
'DeformRoIPoolPack', 'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'ModulatedDeformRoIPoolPack', 'deform_roi_pool', 'SigmoidFocalLoss',
'SigmoidFocalLoss', 'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'SoftmaxFocalLoss', 'sigmoid_focal_loss', 'softmax_focal_loss',
'softmax_focal_loss', 'GeneralizedAttention', 'get_compiler_version', 'get_compiler_version', 'get_compiling_cuda_version', 'MaskedConv2d',
'get_compiling_cuda_version', 'MaskedConv2d', 'masked_conv2d', 'masked_conv2d', 'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
'ModulatedDeformConv2d', 'ModulatedDeformConv2dPack',
'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match', 'modulated_deform_conv2d', 'batched_nms', 'nms', 'soft_nms', 'nms_match',
'build_plugin_layer', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 'SyncBatchNorm', 'Conv2d',
'SyncBatchNorm', 'Conv2d', 'ConvTranspose2d', 'Linear', 'MaxPool2d', 'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'CrissCrossAttention', 'PSAMask', 'point_sample', 'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign'
'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign'
] ]
from ..cnn import ConvModule, NonLocal2d
from .context_block import ContextBlock
from .generalized_attention import GeneralizedAttention
# Static table of the plugin layers that can be built by
# ``build_plugin_layer`` below; each entry maps the cfg ``type`` string to
# the abbreviation used as the layer-name prefix and the class to
# instantiate.
plugin_cfg = {
    # format: layer_type: (abbreviation, module)
    'ContextBlock': ('context_block', ContextBlock),
    'GeneralizedAttention': ('gen_attention_block', GeneralizedAttention),
    'NonLocal2d': ('nonlocal_block', NonLocal2d),
    'ConvModule': ('conv_block', ConvModule),
}
def build_plugin_layer(cfg, postfix='', **kwargs):
    """Build plugin layer.

    Args:
        cfg (None or dict): cfg should contain:
            type (str): identify plugin layer type.
            layer args: args needed to instantiate a plugin layer.
        postfix (int, str): appended into norm abbreviation to
            create named layer.

    Returns:
        name (str): abbreviation + postfix
        layer (nn.Module): created plugin layer
    """
    assert isinstance(cfg, dict) and 'type' in cfg
    # Pop ``type`` from a copy so the caller's cfg is not mutated.
    args = cfg.copy()
    layer_type = args.pop('type')
    if layer_type not in plugin_cfg:
        raise KeyError(f'Unrecognized plugin type {layer_type}')
    abbr, plugin_cls = plugin_cfg[layer_type]

    assert isinstance(postfix, (int, str))
    name = abbr + str(postfix)
    return name, plugin_cls(**kwargs, **args)
...@@ -3,11 +3,12 @@ import torch ...@@ -3,11 +3,12 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmcv.cnn.bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS, from mmcv.cnn.bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, build_activation_layer, PADDING_LAYERS, PLUGIN_LAYERS,
build_conv_layer, build_norm_layer, build_activation_layer, build_conv_layer,
build_padding_layer, build_upsample_layer, build_norm_layer, build_padding_layer,
is_norm) build_plugin_layer, build_upsample_layer, is_norm)
from mmcv.cnn.bricks.norm import infer_abbr from mmcv.cnn.bricks.norm import infer_abbr as infer_norm_abbr
from mmcv.cnn.bricks.plugin import infer_abbr as infer_plugin_abbr
from mmcv.cnn.bricks.upsample import PixelShufflePack from mmcv.cnn.bricks.upsample import PixelShufflePack
from mmcv.utils.parrots_wrapper import _BatchNorm from mmcv.utils.parrots_wrapper import _BatchNorm
...@@ -56,41 +57,41 @@ def test_build_conv_layer(): ...@@ -56,41 +57,41 @@ def test_build_conv_layer():
assert layer.out_channels == kwargs['out_channels'] assert layer.out_channels == kwargs['out_channels']
def test_infer_abbr(): def test_infer_norm_abbr():
with pytest.raises(TypeError): with pytest.raises(TypeError):
# class_type must be a class # class_type must be a class
infer_abbr(0) infer_norm_abbr(0)
class MyNorm: class MyNorm:
abbr = 'mn' _abbr_ = 'mn'
assert infer_abbr(MyNorm) == 'mn' assert infer_norm_abbr(MyNorm) == 'mn'
class FancyBatchNorm: class FancyBatchNorm:
pass pass
assert infer_abbr(FancyBatchNorm) == 'bn' assert infer_norm_abbr(FancyBatchNorm) == 'bn'
class FancyInstanceNorm: class FancyInstanceNorm:
pass pass
assert infer_abbr(FancyInstanceNorm) == 'in' assert infer_norm_abbr(FancyInstanceNorm) == 'in'
class FancyLayerNorm: class FancyLayerNorm:
pass pass
assert infer_abbr(FancyLayerNorm) == 'ln' assert infer_norm_abbr(FancyLayerNorm) == 'ln'
class FancyGroupNorm: class FancyGroupNorm:
pass pass
assert infer_abbr(FancyGroupNorm) == 'gn' assert infer_norm_abbr(FancyGroupNorm) == 'gn'
class FancyNorm: class FancyNorm:
pass pass
assert infer_abbr(FancyNorm) == 'norm' assert infer_norm_abbr(FancyNorm) == 'norm'
def test_build_norm_layer(): def test_build_norm_layer():
...@@ -296,3 +297,77 @@ def test_is_norm(): ...@@ -296,3 +297,77 @@ def test_is_norm():
with pytest.raises(TypeError): with pytest.raises(TypeError):
layer = nn.BatchNorm1d(3) layer = nn.BatchNorm1d(3)
is_norm(layer, exclude=('BN', )) is_norm(layer, exclude=('BN', ))
def test_infer_plugin_abbr():
    """``infer_abbr`` honours ``_abbr_`` and falls back to snake case."""
    # non-class inputs are rejected
    with pytest.raises(TypeError):
        infer_plugin_abbr(0)

    class MyPlugin:
        _abbr_ = 'mp'

    class FancyPlugin:
        pass

    # an explicit ``_abbr_`` attribute takes precedence
    assert infer_plugin_abbr(MyPlugin) == 'mp'
    # otherwise the abbreviation is the snake-cased class name
    assert infer_plugin_abbr(FancyPlugin) == 'fancy_plugin'
def test_build_plugin_layer():
    """Error handling and construction of every registered plugin type."""
    # cfg must be a dict
    with pytest.raises(TypeError):
        build_plugin_layer('Plugin')

    # `type` must be in cfg
    with pytest.raises(KeyError):
        build_plugin_layer(dict())

    # unsupported plugin type
    with pytest.raises(KeyError):
        build_plugin_layer(dict(type='FancyPlugin'))

    # postfix must be int or str
    with pytest.raises(AssertionError):
        build_plugin_layer(dict(type='ConvModule'), postfix=[1, 2])

    # every plugin type is exercised with each postfix flavour
    for suffix in ['', '_test', 1]:
        # ContextBlock
        name, layer = build_plugin_layer(
            dict(type='ContextBlock'),
            postfix=suffix,
            in_channels=16,
            ratio=1. / 4)
        assert name == 'context_block' + str(suffix)
        assert isinstance(layer, PLUGIN_LAYERS.module_dict['ContextBlock'])

        # GeneralizedAttention
        name, layer = build_plugin_layer(
            dict(type='GeneralizedAttention'), postfix=suffix, in_channels=16)
        assert name == 'gen_attention_block' + str(suffix)
        assert isinstance(layer,
                          PLUGIN_LAYERS.module_dict['GeneralizedAttention'])

        # NonLocal2d
        name, layer = build_plugin_layer(
            dict(type='NonLocal2d'), postfix=suffix, in_channels=16)
        assert name == 'nonlocal_block' + str(suffix)
        assert isinstance(layer, PLUGIN_LAYERS.module_dict['NonLocal2d'])

        # ConvModule
        name, layer = build_plugin_layer(
            dict(type='ConvModule'),
            postfix=suffix,
            in_channels=16,
            out_channels=4,
            kernel_size=3)
        assert name == 'conv_block' + str(suffix)
        assert isinstance(layer, PLUGIN_LAYERS.module_dict['ConvModule'])
import pytest
import torch
from mmcv.cnn.bricks import ContextBlock
def test_context_block():
    """Argument validation and output shapes of ``ContextBlock``."""
    with pytest.raises(AssertionError):
        # pooling_type should be in ['att', 'avg']
        ContextBlock(16, 1. / 4, pooling_type='unsupport_type')

    with pytest.raises(AssertionError):
        # fusion_types should be of type list or tuple
        ContextBlock(16, 1. / 4, fusion_types='unsupport_type')

    with pytest.raises(AssertionError):
        # fusion_types should be in ['channel_add', 'channel_mul']
        ContextBlock(16, 1. / 4, fusion_types=('unsupport_type', ))

    # shared input fixture; every variant must preserve the input shape
    imgs = torch.randn(2, 16, 20, 20)

    # test pooling_type='att'
    block = ContextBlock(16, 1. / 4, pooling_type='att')
    out = block(imgs)
    assert block.conv_mask.in_channels == 16
    assert block.conv_mask.out_channels == 1
    assert out.shape == imgs.shape

    # test pooling_type='avg'
    block = ContextBlock(16, 1. / 4, pooling_type='avg')
    out = block(imgs)
    assert hasattr(block, 'avg_pool')
    assert out.shape == imgs.shape

    # test fusion_types=('channel_add',)
    block = ContextBlock(16, 1. / 4, fusion_types=('channel_add', ))
    out = block(imgs)
    assert block.channel_add_conv is not None
    assert block.channel_mul_conv is None
    assert out.shape == imgs.shape

    # test fusion_types=('channel_mul',)
    block = ContextBlock(16, 1. / 4, fusion_types=('channel_mul', ))
    out = block(imgs)
    assert block.channel_add_conv is None
    assert block.channel_mul_conv is not None
    assert out.shape == imgs.shape

    # test fusion_types=('channel_add', 'channel_mul')
    block = ContextBlock(
        16, 1. / 4, fusion_types=('channel_add', 'channel_mul'))
    out = block(imgs)
    assert block.channel_add_conv is not None
    assert block.channel_mul_conv is not None
    assert out.shape == imgs.shape
import torch
from mmcv.cnn.bricks import GeneralizedAttention
def test_generalized_attention():
    """Smoke-test ``GeneralizedAttention`` for each attention_type and for
    the spatial_range / q_stride / kv_stride options.

    NOTE(review): renamed from ``test_context_block`` — the old name was a
    copy-paste from the ContextBlock test, did not describe what is tested,
    and would shadow the real ContextBlock test if both lived in one module.
    """
    # test attention_type='1000' (query and key content)
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='1000')
    assert gen_attention_block.query_conv.in_channels == 16
    assert gen_attention_block.key_conv.in_channels == 16
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test attention_type='0100' (query content and relative position)
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0100')
    assert gen_attention_block.query_conv.in_channels == 16
    assert gen_attention_block.appr_geom_fc_x.in_features == 8
    assert gen_attention_block.appr_geom_fc_y.in_features == 8
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test attention_type='0010' (key content only)
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0010')
    assert gen_attention_block.key_conv.in_channels == 16
    assert hasattr(gen_attention_block, 'appr_bias')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test attention_type='0001' (relative position only)
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0001')
    assert gen_attention_block.appr_geom_fc_x.in_features == 8
    assert gen_attention_block.appr_geom_fc_y.in_features == 8
    assert hasattr(gen_attention_block, 'geom_bias')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test spatial_range >= 0
    imgs = torch.randn(2, 256, 20, 20)
    gen_attention_block = GeneralizedAttention(256, spatial_range=10)
    assert hasattr(gen_attention_block, 'local_constraint_map')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test q_stride > 1
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, q_stride=2)
    assert gen_attention_block.q_downsample is not None
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test kv_stride > 1
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, kv_stride=2)
    assert gen_attention_block.kv_downsample is not None
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment