Unverified Commit 59c1418e authored by Mashiro, committed by GitHub

[Enhancement] Make build_xxx_layer allow accepting a class type (#2782)

parent b4dee63c
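
In short, the `type` field of a layer config may now be either a registered string or the layer class itself. A minimal usage sketch (using the public `mmcv.cnn` entry point; the example itself is not part of this diff):

import torch.nn as nn

from mmcv.cnn import build_conv_layer

# A registered string works as before: the type is looked up in MODELS.
conv = build_conv_layer(dict(type='Conv2d'), 3, 8, kernel_size=3)

# New: the class itself is accepted and instantiated directly,
# skipping the registry lookup.
conv = build_conv_layer(dict(type=nn.Conv2d), 3, 8, kernel_size=3)
assert isinstance(conv, nn.Conv2d)
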
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict, Optional
from mmengine.registry import MODELS
@@ -35,7 +36,8 @@ def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module:
cfg_ = cfg.copy()
layer_type = cfg_.pop('type')
if inspect.isclass(layer_type):
return layer_type(*args, **kwargs, **cfg_) # type: ignore
# Switch registry to the target scope. If `conv_layer` cannot be found
# in the registry, fall back to searching for `conv_layer` in
# mmengine.MODELS.
......
@@ -98,14 +98,17 @@ def build_norm_layer(cfg: Dict,
layer_type = cfg_.pop('type')
if inspect.isclass(layer_type):
norm_layer = layer_type
else:
# Switch registry to the target scope. If `norm_layer` cannot be found
# in the registry, fall back to searching for `norm_layer` in
# mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry:
norm_layer = registry.get(layer_type)
if norm_layer is None:
raise KeyError(f'Cannot find {layer_type} in registry under '
f'scope name {registry.scope}')
abbr = infer_abbr(norm_layer)
assert isinstance(postfix, (int, str))
@@ -113,7 +116,7 @@ def build_norm_layer(cfg: Dict,
requires_grad = cfg_.pop('requires_grad', True)
cfg_.setdefault('eps', 1e-5)
if norm_layer is not nn.GroupNorm:
layer = norm_layer(num_features, **cfg_)
if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
layer._specify_ddp_gpu_num(1)
......
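
A hedged sketch of what the build_norm_layer change enables (return values follow the existing `(name, layer)` convention and `infer_abbr` shown above):

import torch.nn as nn

from mmcv.cnn import build_norm_layer

# The norm class is accepted directly; infer_abbr() still derives the
# returned name from the class.
name, layer = build_norm_layer(dict(type=nn.BatchNorm2d), 8, postfix=1)
assert name == 'bn1' and isinstance(layer, nn.BatchNorm2d)

# The GN special case now keys on the class, so class-type configs
# also take the GroupNorm path (num_channels instead of num_features).
name, layer = build_norm_layer(dict(type=nn.GroupNorm, num_groups=2), 8)
assert name == 'gn' and isinstance(layer, nn.GroupNorm)
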
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict
import torch.nn as nn
@@ -27,7 +28,8 @@ def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
cfg_ = cfg.copy()
padding_type = cfg_.pop('type')
if inspect.isclass(padding_type):
return padding_type(*args, **kwargs, **cfg_)
# Switch registry to the target scope. If `padding_layer` cannot be found
# in the registry, fall back to searching for `padding_layer` in
# mmengine.MODELS.
......
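
The same applies to build_padding_layer; a minimal sketch mirroring the test added below:

import torch
import torch.nn as nn

from mmcv.cnn.bricks import build_padding_layer

# The padding class stands in for the registered string 'reflect'.
pad = build_padding_layer(dict(type=nn.ReflectionPad2d), 2)
assert pad(torch.randn(1, 2, 5, 5)).shape == (1, 2, 9, 9)
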
@@ -79,14 +79,17 @@ def build_plugin_layer(cfg: Dict,
cfg_ = cfg.copy()
layer_type = cfg_.pop('type')
if inspect.isclass(layer_type):
plugin_layer = layer_type
else:
# Switch registry to the target scope. If `plugin_layer` cannot be
# found in the registry, fall back to searching for `plugin_layer`
# in mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry:
plugin_layer = registry.get(layer_type)
if plugin_layer is None:
raise KeyError(
f'Cannot find {layer_type} in registry under scope '
f'name {registry.scope}')
abbr = infer_abbr(plugin_layer)
......
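
A sketch for build_plugin_layer with a class-type config; ContextBlock is an existing mmcv brick whose `_abbr_` is 'context_block':

from mmcv.cnn.bricks import ContextBlock, build_plugin_layer

# The plugin class is accepted directly; infer_abbr() reads the class's
# _abbr_ attribute to build the returned name.
name, layer = build_plugin_layer(
    dict(type=ContextBlock), postfix='_test', in_channels=16, ratio=1. / 4)
assert name == 'context_block_test' and isinstance(layer, ContextBlock)
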
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict
import torch
@@ -76,9 +77,12 @@ def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
layer_type = cfg_.pop('type')
if inspect.isclass(layer_type):
upsample = layer_type
else:
# Switch registry to the target scope. If `upsample` cannot be found
# in the registry, fall back to searching for `upsample` in
# mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry:
upsample = registry.get(layer_type)
if upsample is None:
......
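
And for build_upsample_layer, where nn.ConvTranspose2d stands in for the registered 'deconv' type:

import torch.nn as nn

from mmcv.cnn import build_upsample_layer

# The upsample class is instantiated directly with the given arguments.
layer = build_upsample_layer(
    dict(type=nn.ConvTranspose2d), in_channels=3, out_channels=3,
    kernel_size=3, stride=2)
assert isinstance(layer, nn.ConvTranspose2d)
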
@@ -293,8 +293,9 @@ def batched_nms(boxes: Tensor,
max_coordinate + torch.tensor(1).to(boxes))
boxes_for_nms = boxes + offsets[:, None]
nms_op = nms_cfg_.pop('type', 'nms')
if isinstance(nms_op, str):
nms_op = eval(nms_op)
split_thr = nms_cfg_.pop('split_thr', 10000)
# Won't split to multiple nms nodes when exporting to onnx
......
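
batched_nms follows the same pattern: `type` may now be the NMS callable itself, and eval() runs only for strings. A minimal sketch with random boxes (assumes the compiled mmcv.ops are available):

import torch

from mmcv.ops import batched_nms, nms

boxes = torch.rand(10, 4) * 100
boxes[:, 2:] += boxes[:, :2]  # make x2 > x1 and y2 > y1
scores = torch.rand(10)
idxs = torch.randint(0, 3, (10,))

# Pass the NMS function directly instead of the string 'nms'.
dets, keep = batched_nms(boxes, scores, idxs,
                         dict(type=nms, iou_threshold=0.5))
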
# Copyright (c) OpenMMLab. All rights reserved.
import inspect
from importlib import import_module
import numpy as np
@@ -7,10 +8,14 @@ import torch
import torch.nn as nn
from mmengine.registry import MODELS
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from torch.nn import ReflectionPad2d, Upsample
from mmcv.cnn.bricks import (ContextBlock, ConvModule, ConvTranspose2d,
GeneralizedAttention, NonLocal2d,
build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer,
build_plugin_layer, build_upsample_layer, is_norm)
from mmcv.cnn.bricks.activation import Clamp
from mmcv.cnn.bricks.norm import infer_abbr as infer_norm_abbr
from mmcv.cnn.bricks.plugin import infer_abbr as infer_plugin_abbr
from mmcv.cnn.bricks.upsample import PixelShufflePack
@@ -65,9 +70,10 @@ def test_build_conv_layer():
kwargs.pop('groups')
for type_name, module in MODELS.module_dict.items():
for type_name_ in (type_name, module):
cfg = dict(type=type_name_)
# SparseInverseConv2d and SparseInverseConv3d do not have the
# argument 'dilation'
if type_name in ('SparseInverseConv2d', 'SparseInverseConv3d'):
kwargs.pop('dilation')
@@ -162,7 +168,8 @@ def test_build_norm_layer():
if type_name == 'MMSyncBN': # skip MMSyncBN
continue
for postfix in ['_test', 1]:
for type_name_ in (type_name, module):
cfg = dict(type=type_name_)
if type_name == 'GN':
cfg['num_groups'] = 3
name, layer = build_norm_layer(cfg, 3, postfix=postfix)
@@ -184,7 +191,7 @@ def test_build_activation_layer():
for module_name in ['activation', 'hsigmoid', 'hswish', 'swish']:
act_module = import_module(f'mmcv.cnn.bricks.{module_name}')
for key, value in act_module.__dict__.items():
if inspect.isclass(value) and issubclass(value, nn.Module):
act_names.append(key)
with pytest.raises(TypeError):
@@ -210,10 +217,12 @@ def test_build_activation_layer():
assert isinstance(layer, module)
# sanity check for Clamp
for type_name in ('Clamp', Clamp):
act = build_activation_layer(dict(type=type_name))
x = torch.randn(10) * 1000
y = act(x)
assert np.logical_and((y >= -1).numpy(), (y <= 1).numpy()).all()
act = build_activation_layer(dict(type='Clip', min=0))
y = act(x)
assert np.logical_and((y >= 0).numpy(), (y <= 1).numpy()).all()
@@ -227,7 +236,7 @@ def test_build_padding_layer():
for module_name in ['padding']:
pad_module = import_module(f'mmcv.cnn.bricks.{module_name}')
for key, value in pad_module.__dict__.items():
if inspect.isclass(value) and issubclass(value, nn.Module):
pad_names.append(key)
with pytest.raises(TypeError):
@@ -250,9 +259,9 @@ def test_build_padding_layer():
cfg['type'] = type_name
layer = build_padding_layer(cfg, 2)
assert isinstance(layer, module)
for type_name in (ReflectionPad2d, 'reflect'):
input_x = torch.randn(1, 2, 5, 5)
cfg = dict(type=type_name)
padding_layer = build_padding_layer(cfg, 2)
res = padding_layer(input_x)
assert res.shape == (1, 2, 9, 9)
@@ -280,29 +289,39 @@ def test_upsample_layer():
assert isinstance(layer, nn.Upsample)
assert layer.mode == type_name
cfg = dict()
cfg['type'] = Upsample
layer_from_cls = build_upsample_layer(cfg)
assert isinstance(layer_from_cls, nn.Upsample)
assert layer_from_cls.mode == 'nearest'
cfg = dict(
type='deconv', in_channels=3, out_channels=3, kernel_size=3, stride=2)
layer = build_upsample_layer(cfg)
assert isinstance(layer, nn.ConvTranspose2d)
for type_name in ('deconv', ConvTranspose2d):
cfg = dict(type=type_name)
kwargs = dict(in_channels=3, out_channels=3, kernel_size=3, stride=2)
layer = build_upsample_layer(cfg, **kwargs)
assert isinstance(layer, nn.ConvTranspose2d)
assert layer.in_channels == kwargs['in_channels']
assert layer.out_channels == kwargs['out_channels']
assert layer.kernel_size == (kwargs['kernel_size'],
kwargs['kernel_size'])
assert layer.stride == (kwargs['stride'], kwargs['stride'])
layer = build_upsample_layer(cfg, 3, 3, 3, 2)
assert isinstance(layer, nn.ConvTranspose2d)
assert layer.in_channels == kwargs['in_channels']
assert layer.out_channels == kwargs['out_channels']
assert layer.kernel_size == (kwargs['kernel_size'],
kwargs['kernel_size'])
assert layer.stride == (kwargs['stride'], kwargs['stride'])
for type_name in ('pixel_shuffle', PixelShufflePack):
cfg = dict(
type=type_name,
in_channels=3,
out_channels=3,
scale_factor=2,
@@ -396,30 +415,37 @@ def test_build_plugin_layer():
build_plugin_layer(cfg, postfix=[1, 2])
# test ContextBlock
for type_name in ('ContextBlock', ContextBlock):
for postfix in ['', '_test', 1]:
cfg = dict(type=type_name)
name, layer = build_plugin_layer(
cfg, postfix=postfix, in_channels=16, ratio=1. / 4)
assert name == 'context_block' + str(postfix)
assert isinstance(layer, MODELS.module_dict['ContextBlock'])
# test GeneralizedAttention
for type_name in ('GeneralizedAttention', GeneralizedAttention):
for postfix in ['', '_test', 1]:
cfg = dict(type=type_name)
name, layer = build_plugin_layer(
cfg, postfix=postfix, in_channels=16)
assert name == 'gen_attention_block' + str(postfix)
assert isinstance(layer,
MODELS.module_dict['GeneralizedAttention'])
# test NonLocal2d
for type_name in ('NonLocal2d', NonLocal2d):
for postfix in ['', '_test', 1]:
cfg = dict(type=type_name)
name, layer = build_plugin_layer(
cfg, postfix=postfix, in_channels=16)
assert name == 'nonlocal_block' + str(postfix)
assert isinstance(layer, MODELS.module_dict['NonLocal2d'])
# test ConvModule
for type_name in ('ConvModule', ConvModule):
for postfix in ['', '_test', 1]:
cfg = dict(type=type_name)
name, layer = build_plugin_layer(
cfg,
postfix=postfix,
......