Unverified Commit d4fac3a6 authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Use parrots wrapper (#272)

* use parrots wrapper for norms

* add unittests for is_norm()

* add a test case
parent 6aa51315
......@@ -4,7 +4,7 @@ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, UPSAMPLE_LAYERS, ConvModule, Scale,
build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer,
build_upsample_layer)
build_upsample_layer, is_norm)
from .resnet import ResNet, make_res_layer
from .vgg import VGG, make_vgg_layer
from .weight_init import (bias_init_with_prob, caffe2_xavier_init,
......@@ -16,6 +16,7 @@ __all__ = [
'constant_init', 'xavier_init', 'normal_init', 'uniform_init',
'kaiming_init', 'caffe2_xavier_init', 'bias_init_with_prob', 'ConvModule',
'build_activation_layer', 'build_conv_layer', 'build_norm_layer',
'build_padding_layer', 'build_upsample_layer', 'ACTIVATION_LAYERS',
'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'Scale'
'build_padding_layer', 'build_upsample_layer', 'is_norm',
'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
'UPSAMPLE_LAYERS', 'Scale'
]
from .activation import build_activation_layer
from .conv import build_conv_layer
from .conv_module import ConvModule
from .norm import build_norm_layer
from .norm import build_norm_layer, is_norm
from .padding import build_padding_layer
from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, UPSAMPLE_LAYERS)
......@@ -11,6 +11,6 @@ from .upsample import build_upsample_layer
__all__ = [
'ConvModule', 'build_activation_layer', 'build_conv_layer',
'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
'UPSAMPLE_LAYERS', 'Scale'
'is_norm', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'Scale'
]
import inspect
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm
from mmcv.utils import is_tuple_of
from mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm
from .registry import NORM_LAYERS
NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d)
NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d)
NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d)
NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d)
NORM_LAYERS.register_module('SyncBN', module=nn.SyncBatchNorm)
NORM_LAYERS.register_module('SyncBN', module=SyncBatchNorm)
NORM_LAYERS.register_module('GN', module=nn.GroupNorm)
NORM_LAYERS.register_module('LN', module=nn.LayerNorm)
NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d)
......@@ -116,3 +116,31 @@ def build_norm_layer(cfg, num_features, postfix=''):
param.requires_grad = requires_grad
return name, layer
def is_norm(layer, exclude=None):
"""Check if a layer is a normalization layer.
Args:
layer (nn.Module): The layer to be checked.
exclude (type | tuple[type]): Types to be excluded.
Returns:
bool: Whether the layer is a norm layer.
"""
if exclude is not None:
if not isinstance(exclude, tuple):
exclude = (exclude, )
if not is_tuple_of(exclude, type):
raise TypeError(
f'"exclude" must be either None or type or a tuple of types, '
f'but got {type(exclude)}: {exclude}')
if exclude and isinstance(layer, exclude):
return False
all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
if isinstance(layer, all_norm_bases):
return True
else:
return False
......@@ -5,9 +5,11 @@ import torch.nn as nn
from mmcv.cnn.bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, build_activation_layer,
build_conv_layer, build_norm_layer,
build_padding_layer, build_upsample_layer)
build_padding_layer, build_upsample_layer,
is_norm)
from mmcv.cnn.bricks.norm import infer_abbr
from mmcv.cnn.bricks.upsample import PixelShufflePack
from mmcv.utils.parrots_wrapper import _BatchNorm
def test_build_conv_layer():
......@@ -243,3 +245,38 @@ def test_pixel_shuffle_pack():
assert pixel_shuffle.upsample_conv.kernel_size == (3, 3)
x_out = pixel_shuffle(x_in)
assert x_out.shape == (2, 3, 20, 20)
def test_is_norm():
norm_set1 = [
nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.InstanceNorm1d,
nn.InstanceNorm2d, nn.InstanceNorm3d, nn.LayerNorm
]
norm_set2 = [nn.GroupNorm]
for norm_type in norm_set1:
layer = norm_type(3)
assert is_norm(layer)
assert not is_norm(layer, exclude=(norm_type, ))
for norm_type in norm_set2:
layer = norm_type(3, 6)
assert is_norm(layer)
assert not is_norm(layer, exclude=(norm_type, ))
class MyNorm(nn.BatchNorm2d):
pass
layer = MyNorm(3)
assert is_norm(layer)
assert not is_norm(layer, exclude=_BatchNorm)
assert not is_norm(layer, exclude=(_BatchNorm, ))
layer = nn.Conv2d(3, 8, 1)
assert not is_norm(layer)
with pytest.raises(TypeError):
layer = nn.BatchNorm1d(3)
is_norm(layer, exclude='BN')
with pytest.raises(TypeError):
layer = nn.BatchNorm1d(3)
is_norm(layer, exclude=('BN', ))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment