Unverified Commit 59c1418e authored by Mashiro, committed by GitHub

[Enhancement] Make build_xxx_layer allow accepting a class type (#2782)

parent b4dee63c
mmcv/cnn/bricks/conv.py
 # Copyright (c) OpenMMLab. All rights reserved.
+import inspect
 from typing import Dict, Optional
 from mmengine.registry import MODELS
@@ -35,7 +36,8 @@ def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module:
     cfg_ = cfg.copy()
     layer_type = cfg_.pop('type')
+    if inspect.isclass(layer_type):
+        return layer_type(*args, **kwargs, **cfg_)  # type: ignore
     # Switch registry to the target scope. If `conv_layer` cannot be found
     # in the registry, fallback to search `conv_layer` in the
     # mmengine.MODELS.
...
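A minimal usage sketch of what this hunk enables (illustrative, not part of the diff; it assumes the default MODELS registry has 'Conv2d' registered, as mmcv does):

import torch.nn as nn
from mmcv.cnn import build_conv_layer

# Two equivalent ways to build an nn.Conv2d after this commit: by registry
# string, or by passing the class itself, which now bypasses the registry.
layer_from_str = build_conv_layer(dict(type='Conv2d'), 3, 8, kernel_size=3)
layer_from_cls = build_conv_layer(dict(type=nn.Conv2d), 3, 8, kernel_size=3)
assert isinstance(layer_from_cls, nn.Conv2d)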
mmcv/cnn/bricks/norm.py
@@ -98,14 +98,17 @@ def build_norm_layer(cfg: Dict,
     layer_type = cfg_.pop('type')
-    # Switch registry to the target scope. If `norm_layer` cannot be found
-    # in the registry, fallback to search `norm_layer` in the
-    # mmengine.MODELS.
-    with MODELS.switch_scope_and_registry(None) as registry:
-        norm_layer = registry.get(layer_type)
-    if norm_layer is None:
-        raise KeyError(f'Cannot find {norm_layer} in registry under scope '
-                       f'name {registry.scope}')
+    if inspect.isclass(layer_type):
+        norm_layer = layer_type
+    else:
+        # Switch registry to the target scope. If `norm_layer` cannot be found
+        # in the registry, fallback to search `norm_layer` in the
+        # mmengine.MODELS.
+        with MODELS.switch_scope_and_registry(None) as registry:
+            norm_layer = registry.get(layer_type)
+        if norm_layer is None:
+            raise KeyError(f'Cannot find {norm_layer} in registry under '
+                           f'scope name {registry.scope}')
     abbr = infer_abbr(norm_layer)
     assert isinstance(postfix, (int, str))
@@ -113,7 +116,7 @@ def build_norm_layer(cfg: Dict,
     requires_grad = cfg_.pop('requires_grad', True)
     cfg_.setdefault('eps', 1e-5)
-    if layer_type != 'GN':
+    if norm_layer is not nn.GroupNorm:
         layer = norm_layer(num_features, **cfg_)
         if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
             layer._specify_ddp_gpu_num(1)
...
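A sketch of the corresponding usage (illustrative): because the GroupNorm special case is now detected by class identity rather than by comparing the type string, a class-typed cfg behaves the same as its string form.

import torch.nn as nn
from mmcv.cnn import build_norm_layer

# `type` may be the class itself; infer_abbr() still derives the name prefix.
name, bn = build_norm_layer(dict(type=nn.BatchNorm2d), 4)
assert name == 'bn' and isinstance(bn, nn.BatchNorm2d)
# nn.GroupNorm takes num_channels, which the GN branch handles.
_, gn = build_norm_layer(dict(type=nn.GroupNorm, num_groups=2), 4)
assert isinstance(gn, nn.GroupNorm)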
mmcv/cnn/bricks/padding.py
 # Copyright (c) OpenMMLab. All rights reserved.
+import inspect
 from typing import Dict
 import torch.nn as nn
@@ -27,7 +28,8 @@ def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
     cfg_ = cfg.copy()
     padding_type = cfg_.pop('type')
+    if inspect.isclass(padding_type):
+        return padding_type(*args, **kwargs, **cfg_)
     # Switch registry to the target scope. If `padding_layer` cannot be found
     # in the registry, fallback to search `padding_layer` in the
     # mmengine.MODELS.
...
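A sketch (illustrative): with a class type, build_padding_layer simply instantiates the class with the remaining arguments.

import torch
import torch.nn as nn
from mmcv.cnn import build_padding_layer

pad = build_padding_layer(dict(type=nn.ReflectionPad2d), 2)
out = pad(torch.randn(1, 2, 5, 5))
assert out.shape == (1, 2, 9, 9)  # reflection padding of 2 on each side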
mmcv/cnn/bricks/plugin.py
@@ -79,15 +79,18 @@ def build_plugin_layer(cfg: Dict,
     cfg_ = cfg.copy()
     layer_type = cfg_.pop('type')
-    # Switch registry to the target scope. If `plugin_layer` cannot be found
-    # in the registry, fallback to search `plugin_layer` in the
-    # mmengine.MODELS.
-    with MODELS.switch_scope_and_registry(None) as registry:
-        plugin_layer = registry.get(layer_type)
-    if plugin_layer is None:
-        raise KeyError(f'Cannot find {plugin_layer} in registry under scope '
-                       f'name {registry.scope}')
+    if inspect.isclass(layer_type):
+        plugin_layer = layer_type
+    else:
+        # Switch registry to the target scope. If `plugin_layer` cannot be
+        # found in the registry, fallback to search `plugin_layer` in the
+        # mmengine.MODELS.
+        with MODELS.switch_scope_and_registry(None) as registry:
+            plugin_layer = registry.get(layer_type)
+        if plugin_layer is None:
+            raise KeyError(
+                f'Cannot find {plugin_layer} in registry under scope '
+                f'name {registry.scope}')
     abbr = infer_abbr(plugin_layer)
     assert isinstance(postfix, (int, str))
...
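A sketch (illustrative), mirroring the tests further down: the returned name still comes from infer_abbr() on the resolved class.

from mmcv.cnn import build_plugin_layer
from mmcv.cnn.bricks import ContextBlock

name, block = build_plugin_layer(
    dict(type=ContextBlock), in_channels=16, ratio=1. / 4)
assert name == 'context_block' and isinstance(block, ContextBlock)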
mmcv/cnn/bricks/upsample.py
 # Copyright (c) OpenMMLab. All rights reserved.
+import inspect
 from typing import Dict
 import torch
@@ -76,15 +77,18 @@ def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
     layer_type = cfg_.pop('type')
+    if inspect.isclass(layer_type):
+        upsample = layer_type
     # Switch registry to the target scope. If `upsample` cannot be found
     # in the registry, fallback to search `upsample` in the
     # mmengine.MODELS.
-    with MODELS.switch_scope_and_registry(None) as registry:
-        upsample = registry.get(layer_type)
-    if upsample is None:
-        raise KeyError(f'Cannot find {upsample} in registry under scope '
-                       f'name {registry.scope}')
-    if upsample is nn.Upsample:
-        cfg_['mode'] = layer_type
+    else:
+        with MODELS.switch_scope_and_registry(None) as registry:
+            upsample = registry.get(layer_type)
+        if upsample is None:
+            raise KeyError(f'Cannot find {upsample} in registry under scope '
+                           f'name {registry.scope}')
+        if upsample is nn.Upsample:
+            cfg_['mode'] = layer_type
     layer = upsample(*args, **kwargs, **cfg_)
     return layer
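A sketch (illustrative): with a class type the string-to-mode shortcut for nn.Upsample no longer applies, so the module keeps its own default mode, as the test below also asserts.

import torch.nn as nn
from mmcv.cnn import build_upsample_layer

up = build_upsample_layer(dict(type=nn.Upsample, scale_factor=2))
assert isinstance(up, nn.Upsample) and up.mode == 'nearest'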
mmcv/ops/nms.py
@@ -293,8 +293,9 @@ def batched_nms(boxes: Tensor,
             max_coordinate + torch.tensor(1).to(boxes))
     boxes_for_nms = boxes + offsets[:, None]
-    nms_type = nms_cfg_.pop('type', 'nms')
-    nms_op = eval(nms_type)
+    nms_op = nms_cfg_.pop('type', 'nms')
+    if isinstance(nms_op, str):
+        nms_op = eval(nms_op)
     split_thr = nms_cfg_.pop('split_thr', 10000)
     # Won't split to multiple nms nodes when exporting to onnx
...
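A sketch (illustrative): batched_nms now also accepts the NMS callable directly, while a string type is still eval'd as before.

import torch
from mmcv.ops import batched_nms, nms

boxes = torch.rand(4, 4) * 10
boxes[:, 2:] += boxes[:, :2]  # ensure x2 > x1 and y2 > y1
scores = torch.rand(4)
idxs = torch.tensor([0, 1, 0, 1])
# `type` can be the nms function object instead of the string 'nms'.
dets, keep = batched_nms(boxes, scores, idxs,
                         dict(type=nms, iou_threshold=0.5))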
tests/test_cnn/test_build_layers.py
 # Copyright (c) OpenMMLab. All rights reserved.
+import inspect
 from importlib import import_module
 import numpy as np
@@ -7,10 +8,14 @@ import torch
 import torch.nn as nn
 from mmengine.registry import MODELS
 from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
+from torch.nn import ReflectionPad2d, Upsample
-from mmcv.cnn.bricks import (build_activation_layer, build_conv_layer,
+from mmcv.cnn.bricks import (ContextBlock, ConvModule, ConvTranspose2d,
+                             GeneralizedAttention, NonLocal2d,
+                             build_activation_layer, build_conv_layer,
                              build_norm_layer, build_padding_layer,
                              build_plugin_layer, build_upsample_layer, is_norm)
+from mmcv.cnn.bricks.activation import Clamp
 from mmcv.cnn.bricks.norm import infer_abbr as infer_norm_abbr
 from mmcv.cnn.bricks.plugin import infer_abbr as infer_plugin_abbr
 from mmcv.cnn.bricks.upsample import PixelShufflePack
@@ -65,18 +70,19 @@ def test_build_conv_layer():
     kwargs.pop('groups')
     for type_name, module in MODELS.module_dict.items():
-        cfg = dict(type=type_name)
-        # SparseInverseConv2d and SparseInverseConv3d do not have the argument
-        # 'dilation'
-        if type_name == 'SparseInverseConv2d' or type_name == \
-                'SparseInverseConv3d':
-            kwargs.pop('dilation')
-        if 'conv' in type_name.lower():
-            layer = build_conv_layer(cfg, **kwargs)
-            assert isinstance(layer, module)
-            assert layer.in_channels == kwargs['in_channels']
-            assert layer.out_channels == kwargs['out_channels']
-            kwargs['dilation'] = 2  # recover the key
+        for type_name_ in (type_name, module):
+            cfg = dict(type=type_name_)
+            # SparseInverseConv2d and SparseInverseConv3d do not have the
+            # argument 'dilation'
+            if type_name == 'SparseInverseConv2d' or type_name == \
+                    'SparseInverseConv3d':
+                kwargs.pop('dilation')
+            if 'conv' in type_name.lower():
+                layer = build_conv_layer(cfg, **kwargs)
+                assert isinstance(layer, module)
+                assert layer.in_channels == kwargs['in_channels']
+                assert layer.out_channels == kwargs['out_channels']
+                kwargs['dilation'] = 2  # recover the key

 def test_infer_norm_abbr():
@@ -162,17 +168,18 @@ def test_build_norm_layer():
         if type_name == 'MMSyncBN':  # skip MMSyncBN
             continue
         for postfix in ['_test', 1]:
-            cfg = dict(type=type_name)
-            if type_name == 'GN':
-                cfg['num_groups'] = 3
-            name, layer = build_norm_layer(cfg, 3, postfix=postfix)
-            assert name == abbr_mapping[type_name] + str(postfix)
-            assert isinstance(layer, module)
-            if type_name == 'GN':
-                assert layer.num_channels == 3
-                assert layer.num_groups == cfg['num_groups']
-            elif type_name != 'LN':
-                assert layer.num_features == 3
+            for type_name_ in (type_name, module):
+                cfg = dict(type=type_name_)
+                if type_name == 'GN':
+                    cfg['num_groups'] = 3
+                name, layer = build_norm_layer(cfg, 3, postfix=postfix)
+                assert name == abbr_mapping[type_name] + str(postfix)
+                assert isinstance(layer, module)
+                if type_name == 'GN':
+                    assert layer.num_channels == 3
+                    assert layer.num_groups == cfg['num_groups']
+                elif type_name != 'LN':
+                    assert layer.num_features == 3

 def test_build_activation_layer():
@@ -184,7 +191,7 @@ def test_build_activation_layer():
     for module_name in ['activation', 'hsigmoid', 'hswish', 'swish']:
         act_module = import_module(f'mmcv.cnn.bricks.{module_name}')
         for key, value in act_module.__dict__.items():
-            if isinstance(value, type) and issubclass(value, nn.Module):
+            if inspect.isclass(value) and issubclass(value, nn.Module):
                 act_names.append(key)
     with pytest.raises(TypeError):
@@ -210,10 +217,12 @@
         assert isinstance(layer, module)
     # sanity check for Clamp
-    act = build_activation_layer(dict(type='Clamp'))
-    x = torch.randn(10) * 1000
-    y = act(x)
-    assert np.logical_and((y >= -1).numpy(), (y <= 1).numpy()).all()
+    for type_name in ('Clamp', Clamp):
+        act = build_activation_layer(dict(type='Clamp'))
+        x = torch.randn(10) * 1000
+        y = act(x)
+        assert np.logical_and((y >= -1).numpy(), (y <= 1).numpy()).all()
     act = build_activation_layer(dict(type='Clip', min=0))
     y = act(x)
     assert np.logical_and((y >= 0).numpy(), (y <= 1).numpy()).all()
@@ -227,7 +236,7 @@ def test_build_padding_layer():
     for module_name in ['padding']:
         pad_module = import_module(f'mmcv.cnn.bricks.{module_name}')
        for key, value in pad_module.__dict__.items():
-            if isinstance(value, type) and issubclass(value, nn.Module):
+            if inspect.isclass(value) and issubclass(value, nn.Module):
                 pad_names.append(key)
     with pytest.raises(TypeError):
@@ -250,12 +259,12 @@
         cfg['type'] = type_name
         layer = build_padding_layer(cfg, 2)
         assert isinstance(layer, module)
-    input_x = torch.randn(1, 2, 5, 5)
-    cfg = dict(type='reflect')
-    padding_layer = build_padding_layer(cfg, 2)
-    res = padding_layer(input_x)
-    assert res.shape == (1, 2, 9, 9)
+    for type_name in (ReflectionPad2d, 'reflect'):
+        input_x = torch.randn(1, 2, 5, 5)
+        cfg = dict(type=type_name)
+        padding_layer = build_padding_layer(cfg, 2)
+        res = padding_layer(input_x)
+        assert res.shape == (1, 2, 9, 9)

 def test_upsample_layer():
@@ -280,38 +289,48 @@ def test_upsample_layer():
         assert isinstance(layer, nn.Upsample)
         assert layer.mode == type_name
+    cfg = dict()
+    cfg['type'] = Upsample
+    layer_from_cls = build_upsample_layer(cfg)
+    assert isinstance(layer_from_cls, nn.Upsample)
+    assert layer_from_cls.mode == 'nearest'
     cfg = dict(
         type='deconv', in_channels=3, out_channels=3, kernel_size=3, stride=2)
     layer = build_upsample_layer(cfg)
     assert isinstance(layer, nn.ConvTranspose2d)
-    cfg = dict(type='deconv')
-    kwargs = dict(in_channels=3, out_channels=3, kernel_size=3, stride=2)
-    layer = build_upsample_layer(cfg, **kwargs)
-    assert isinstance(layer, nn.ConvTranspose2d)
-    assert layer.in_channels == kwargs['in_channels']
-    assert layer.out_channels == kwargs['out_channels']
-    assert layer.kernel_size == (kwargs['kernel_size'], kwargs['kernel_size'])
-    assert layer.stride == (kwargs['stride'], kwargs['stride'])
-    layer = build_upsample_layer(cfg, 3, 3, 3, 2)
-    assert isinstance(layer, nn.ConvTranspose2d)
-    assert layer.in_channels == kwargs['in_channels']
-    assert layer.out_channels == kwargs['out_channels']
-    assert layer.kernel_size == (kwargs['kernel_size'], kwargs['kernel_size'])
-    assert layer.stride == (kwargs['stride'], kwargs['stride'])
-    cfg = dict(
-        type='pixel_shuffle',
-        in_channels=3,
-        out_channels=3,
-        scale_factor=2,
-        upsample_kernel=3)
-    layer = build_upsample_layer(cfg)
-    assert isinstance(layer, PixelShufflePack)
-    assert layer.scale_factor == 2
-    assert layer.upsample_kernel == 3
+    for type_name in ('deconv', ConvTranspose2d):
+        cfg = dict(type=ConvTranspose2d)
+        kwargs = dict(in_channels=3, out_channels=3, kernel_size=3, stride=2)
+        layer = build_upsample_layer(cfg, **kwargs)
+        assert isinstance(layer, nn.ConvTranspose2d)
+        assert layer.in_channels == kwargs['in_channels']
+        assert layer.out_channels == kwargs['out_channels']
+        assert layer.kernel_size == (kwargs['kernel_size'],
+                                     kwargs['kernel_size'])
+        assert layer.stride == (kwargs['stride'], kwargs['stride'])
+        layer = build_upsample_layer(cfg, 3, 3, 3, 2)
+        assert isinstance(layer, nn.ConvTranspose2d)
+        assert layer.in_channels == kwargs['in_channels']
+        assert layer.out_channels == kwargs['out_channels']
+        assert layer.kernel_size == (kwargs['kernel_size'],
+                                     kwargs['kernel_size'])
+        assert layer.stride == (kwargs['stride'], kwargs['stride'])
+    for type_name in ('pixel_shuffle', PixelShufflePack):
+        cfg = dict(
+            type=type_name,
+            in_channels=3,
+            out_channels=3,
+            scale_factor=2,
+            upsample_kernel=3)
+        layer = build_upsample_layer(cfg)
+        assert isinstance(layer, PixelShufflePack)
+        assert layer.scale_factor == 2
+        assert layer.upsample_kernel == 3

 def test_pixel_shuffle_pack():
@@ -396,35 +415,42 @@ def test_build_plugin_layer():
         build_plugin_layer(cfg, postfix=[1, 2])
     # test ContextBlock
-    for postfix in ['', '_test', 1]:
-        cfg = dict(type='ContextBlock')
-        name, layer = build_plugin_layer(
-            cfg, postfix=postfix, in_channels=16, ratio=1. / 4)
-        assert name == 'context_block' + str(postfix)
-        assert isinstance(layer, MODELS.module_dict['ContextBlock'])
+    for type_name in ('ContextBlock', ContextBlock):
+        for postfix in ['', '_test', 1]:
+            cfg = dict(type=type_name)
+            name, layer = build_plugin_layer(
+                cfg, postfix=postfix, in_channels=16, ratio=1. / 4)
+            assert name == 'context_block' + str(postfix)
+            assert isinstance(layer, MODELS.module_dict['ContextBlock'])
     # test GeneralizedAttention
-    for postfix in ['', '_test', 1]:
-        cfg = dict(type='GeneralizedAttention')
-        name, layer = build_plugin_layer(cfg, postfix=postfix, in_channels=16)
-        assert name == 'gen_attention_block' + str(postfix)
-        assert isinstance(layer, MODELS.module_dict['GeneralizedAttention'])
+    for type_name in ('GeneralizedAttention', GeneralizedAttention):
+        for postfix in ['', '_test', 1]:
+            cfg = dict(type=type_name)
+            name, layer = build_plugin_layer(
+                cfg, postfix=postfix, in_channels=16)
+            assert name == 'gen_attention_block' + str(postfix)
+            assert isinstance(layer,
+                              MODELS.module_dict['GeneralizedAttention'])
    # test NonLocal2d
-    for postfix in ['', '_test', 1]:
-        cfg = dict(type='NonLocal2d')
-        name, layer = build_plugin_layer(cfg, postfix=postfix, in_channels=16)
-        assert name == 'nonlocal_block' + str(postfix)
-        assert isinstance(layer, MODELS.module_dict['NonLocal2d'])
+    for type_name in ('NonLocal2d', NonLocal2d):
+        for postfix in ['', '_test', 1]:
+            cfg = dict(type='NonLocal2d')
+            name, layer = build_plugin_layer(
+                cfg, postfix=postfix, in_channels=16)
+            assert name == 'nonlocal_block' + str(postfix)
+            assert isinstance(layer, MODELS.module_dict['NonLocal2d'])
     # test ConvModule
     for postfix in ['', '_test', 1]:
-        cfg = dict(type='ConvModule')
-        name, layer = build_plugin_layer(
-            cfg,
-            postfix=postfix,
-            in_channels=16,
-            out_channels=4,
-            kernel_size=3)
-        assert name == 'conv_block' + str(postfix)
-        assert isinstance(layer, MODELS.module_dict['ConvModule'])
+        for type_name in ('ConvModule', ConvModule):
+            cfg = dict(type=type_name)
+            name, layer = build_plugin_layer(
+                cfg,
+                postfix=postfix,
+                in_channels=16,
+                out_channels=4,
+                kernel_size=3)
+            assert name == 'conv_block' + str(postfix)
+            assert isinstance(layer, MODELS.module_dict['ConvModule'])