"docs/vscode:/vscode.git/clone" did not exist on "32b85dfa8d4a5fa54469ddc72be89d827c1ee9d6"
Unverified Commit 59c1418e authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Enhancement] Make build_xxx_layer allow accepting a class type (#2782)

parent b4dee63c
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict, Optional from typing import Dict, Optional
from mmengine.registry import MODELS from mmengine.registry import MODELS
...@@ -35,7 +36,8 @@ def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module: ...@@ -35,7 +36,8 @@ def build_conv_layer(cfg: Optional[Dict], *args, **kwargs) -> nn.Module:
cfg_ = cfg.copy() cfg_ = cfg.copy()
layer_type = cfg_.pop('type') layer_type = cfg_.pop('type')
if inspect.isclass(layer_type):
return layer_type(*args, **kwargs, **cfg_) # type: ignore
# Switch registry to the target scope. If `conv_layer` cannot be found # Switch registry to the target scope. If `conv_layer` cannot be found
# in the registry, fallback to search `conv_layer` in the # in the registry, fallback to search `conv_layer` in the
# mmengine.MODELS. # mmengine.MODELS.
......
...@@ -98,14 +98,17 @@ def build_norm_layer(cfg: Dict, ...@@ -98,14 +98,17 @@ def build_norm_layer(cfg: Dict,
layer_type = cfg_.pop('type') layer_type = cfg_.pop('type')
if inspect.isclass(layer_type):
norm_layer = layer_type
else:
# Switch registry to the target scope. If `norm_layer` cannot be found # Switch registry to the target scope. If `norm_layer` cannot be found
# in the registry, fallback to search `norm_layer` in the # in the registry, fallback to search `norm_layer` in the
# mmengine.MODELS. # mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry: with MODELS.switch_scope_and_registry(None) as registry:
norm_layer = registry.get(layer_type) norm_layer = registry.get(layer_type)
if norm_layer is None: if norm_layer is None:
raise KeyError(f'Cannot find {norm_layer} in registry under scope ' raise KeyError(f'Cannot find {norm_layer} in registry under '
f'name {registry.scope}') f'scope name {registry.scope}')
abbr = infer_abbr(norm_layer) abbr = infer_abbr(norm_layer)
assert isinstance(postfix, (int, str)) assert isinstance(postfix, (int, str))
...@@ -113,7 +116,7 @@ def build_norm_layer(cfg: Dict, ...@@ -113,7 +116,7 @@ def build_norm_layer(cfg: Dict,
requires_grad = cfg_.pop('requires_grad', True) requires_grad = cfg_.pop('requires_grad', True)
cfg_.setdefault('eps', 1e-5) cfg_.setdefault('eps', 1e-5)
if layer_type != 'GN': if norm_layer is not nn.GroupNorm:
layer = norm_layer(num_features, **cfg_) layer = norm_layer(num_features, **cfg_)
if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'): if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
layer._specify_ddp_gpu_num(1) layer._specify_ddp_gpu_num(1)
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict from typing import Dict
import torch.nn as nn import torch.nn as nn
...@@ -27,7 +28,8 @@ def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module: ...@@ -27,7 +28,8 @@ def build_padding_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
cfg_ = cfg.copy() cfg_ = cfg.copy()
padding_type = cfg_.pop('type') padding_type = cfg_.pop('type')
if inspect.isclass(padding_type):
return padding_type(*args, **kwargs, **cfg_)
# Switch registry to the target scope. If `padding_layer` cannot be found # Switch registry to the target scope. If `padding_layer` cannot be found
# in the registry, fallback to search `padding_layer` in the # in the registry, fallback to search `padding_layer` in the
# mmengine.MODELS. # mmengine.MODELS.
......
...@@ -79,14 +79,17 @@ def build_plugin_layer(cfg: Dict, ...@@ -79,14 +79,17 @@ def build_plugin_layer(cfg: Dict,
cfg_ = cfg.copy() cfg_ = cfg.copy()
layer_type = cfg_.pop('type') layer_type = cfg_.pop('type')
if inspect.isclass(layer_type):
# Switch registry to the target scope. If `plugin_layer` cannot be found plugin_layer = layer_type
# in the registry, fallback to search `plugin_layer` in the else:
# Switch registry to the target scope. If `plugin_layer` cannot be
# found in the registry, fallback to search `plugin_layer` in the
# mmengine.MODELS. # mmengine.MODELS.
with MODELS.switch_scope_and_registry(None) as registry: with MODELS.switch_scope_and_registry(None) as registry:
plugin_layer = registry.get(layer_type) plugin_layer = registry.get(layer_type)
if plugin_layer is None: if plugin_layer is None:
raise KeyError(f'Cannot find {plugin_layer} in registry under scope ' raise KeyError(
f'Cannot find {plugin_layer} in registry under scope '
f'name {registry.scope}') f'name {registry.scope}')
abbr = infer_abbr(plugin_layer) abbr = infer_abbr(plugin_layer)
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import inspect
from typing import Dict from typing import Dict
import torch import torch
...@@ -76,9 +77,12 @@ def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module: ...@@ -76,9 +77,12 @@ def build_upsample_layer(cfg: Dict, *args, **kwargs) -> nn.Module:
layer_type = cfg_.pop('type') layer_type = cfg_.pop('type')
if inspect.isclass(layer_type):
upsample = layer_type
# Switch registry to the target scope. If `upsample` cannot be found # Switch registry to the target scope. If `upsample` cannot be found
# in the registry, fallback to search `upsample` in the # in the registry, fallback to search `upsample` in the
# mmengine.MODELS. # mmengine.MODELS.
else:
with MODELS.switch_scope_and_registry(None) as registry: with MODELS.switch_scope_and_registry(None) as registry:
upsample = registry.get(layer_type) upsample = registry.get(layer_type)
if upsample is None: if upsample is None:
......
...@@ -293,8 +293,9 @@ def batched_nms(boxes: Tensor, ...@@ -293,8 +293,9 @@ def batched_nms(boxes: Tensor,
max_coordinate + torch.tensor(1).to(boxes)) max_coordinate + torch.tensor(1).to(boxes))
boxes_for_nms = boxes + offsets[:, None] boxes_for_nms = boxes + offsets[:, None]
nms_type = nms_cfg_.pop('type', 'nms') nms_op = nms_cfg_.pop('type', 'nms')
nms_op = eval(nms_type) if isinstance(nms_op, str):
nms_op = eval(nms_op)
split_thr = nms_cfg_.pop('split_thr', 10000) split_thr = nms_cfg_.pop('split_thr', 10000)
# Won't split to multiple nms nodes when exporting to onnx # Won't split to multiple nms nodes when exporting to onnx
......
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import inspect
from importlib import import_module from importlib import import_module
import numpy as np import numpy as np
...@@ -7,10 +8,14 @@ import torch ...@@ -7,10 +8,14 @@ import torch
import torch.nn as nn import torch.nn as nn
from mmengine.registry import MODELS from mmengine.registry import MODELS
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from torch.nn import ReflectionPad2d, Upsample
from mmcv.cnn.bricks import (build_activation_layer, build_conv_layer, from mmcv.cnn.bricks import (ContextBlock, ConvModule, ConvTranspose2d,
GeneralizedAttention, NonLocal2d,
build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer, build_norm_layer, build_padding_layer,
build_plugin_layer, build_upsample_layer, is_norm) build_plugin_layer, build_upsample_layer, is_norm)
from mmcv.cnn.bricks.activation import Clamp
from mmcv.cnn.bricks.norm import infer_abbr as infer_norm_abbr from mmcv.cnn.bricks.norm import infer_abbr as infer_norm_abbr
from mmcv.cnn.bricks.plugin import infer_abbr as infer_plugin_abbr from mmcv.cnn.bricks.plugin import infer_abbr as infer_plugin_abbr
from mmcv.cnn.bricks.upsample import PixelShufflePack from mmcv.cnn.bricks.upsample import PixelShufflePack
...@@ -65,9 +70,10 @@ def test_build_conv_layer(): ...@@ -65,9 +70,10 @@ def test_build_conv_layer():
kwargs.pop('groups') kwargs.pop('groups')
for type_name, module in MODELS.module_dict.items(): for type_name, module in MODELS.module_dict.items():
cfg = dict(type=type_name) for type_name_ in (type_name, module):
# SparseInverseConv2d and SparseInverseConv3d do not have the argument cfg = dict(type=type_name_)
# 'dilation' # SparseInverseConv2d and SparseInverseConv3d do not have the
# argument 'dilation'
if type_name == 'SparseInverseConv2d' or type_name == \ if type_name == 'SparseInverseConv2d' or type_name == \
'SparseInverseConv3d': 'SparseInverseConv3d':
kwargs.pop('dilation') kwargs.pop('dilation')
...@@ -162,7 +168,8 @@ def test_build_norm_layer(): ...@@ -162,7 +168,8 @@ def test_build_norm_layer():
if type_name == 'MMSyncBN': # skip MMSyncBN if type_name == 'MMSyncBN': # skip MMSyncBN
continue continue
for postfix in ['_test', 1]: for postfix in ['_test', 1]:
cfg = dict(type=type_name) for type_name_ in (type_name, module):
cfg = dict(type=type_name_)
if type_name == 'GN': if type_name == 'GN':
cfg['num_groups'] = 3 cfg['num_groups'] = 3
name, layer = build_norm_layer(cfg, 3, postfix=postfix) name, layer = build_norm_layer(cfg, 3, postfix=postfix)
...@@ -184,7 +191,7 @@ def test_build_activation_layer(): ...@@ -184,7 +191,7 @@ def test_build_activation_layer():
for module_name in ['activation', 'hsigmoid', 'hswish', 'swish']: for module_name in ['activation', 'hsigmoid', 'hswish', 'swish']:
act_module = import_module(f'mmcv.cnn.bricks.{module_name}') act_module = import_module(f'mmcv.cnn.bricks.{module_name}')
for key, value in act_module.__dict__.items(): for key, value in act_module.__dict__.items():
if isinstance(value, type) and issubclass(value, nn.Module): if inspect.isclass(value) and issubclass(value, nn.Module):
act_names.append(key) act_names.append(key)
with pytest.raises(TypeError): with pytest.raises(TypeError):
...@@ -210,10 +217,12 @@ def test_build_activation_layer(): ...@@ -210,10 +217,12 @@ def test_build_activation_layer():
assert isinstance(layer, module) assert isinstance(layer, module)
# sanity check for Clamp # sanity check for Clamp
for type_name in ('Clamp', Clamp):
act = build_activation_layer(dict(type='Clamp')) act = build_activation_layer(dict(type='Clamp'))
x = torch.randn(10) * 1000 x = torch.randn(10) * 1000
y = act(x) y = act(x)
assert np.logical_and((y >= -1).numpy(), (y <= 1).numpy()).all() assert np.logical_and((y >= -1).numpy(), (y <= 1).numpy()).all()
act = build_activation_layer(dict(type='Clip', min=0)) act = build_activation_layer(dict(type='Clip', min=0))
y = act(x) y = act(x)
assert np.logical_and((y >= 0).numpy(), (y <= 1).numpy()).all() assert np.logical_and((y >= 0).numpy(), (y <= 1).numpy()).all()
...@@ -227,7 +236,7 @@ def test_build_padding_layer(): ...@@ -227,7 +236,7 @@ def test_build_padding_layer():
for module_name in ['padding']: for module_name in ['padding']:
pad_module = import_module(f'mmcv.cnn.bricks.{module_name}') pad_module = import_module(f'mmcv.cnn.bricks.{module_name}')
for key, value in pad_module.__dict__.items(): for key, value in pad_module.__dict__.items():
if isinstance(value, type) and issubclass(value, nn.Module): if inspect.isclass(value) and issubclass(value, nn.Module):
pad_names.append(key) pad_names.append(key)
with pytest.raises(TypeError): with pytest.raises(TypeError):
...@@ -250,9 +259,9 @@ def test_build_padding_layer(): ...@@ -250,9 +259,9 @@ def test_build_padding_layer():
cfg['type'] = type_name cfg['type'] = type_name
layer = build_padding_layer(cfg, 2) layer = build_padding_layer(cfg, 2)
assert isinstance(layer, module) assert isinstance(layer, module)
for type_name in (ReflectionPad2d, 'reflect'):
input_x = torch.randn(1, 2, 5, 5) input_x = torch.randn(1, 2, 5, 5)
cfg = dict(type='reflect') cfg = dict(type=type_name)
padding_layer = build_padding_layer(cfg, 2) padding_layer = build_padding_layer(cfg, 2)
res = padding_layer(input_x) res = padding_layer(input_x)
assert res.shape == (1, 2, 9, 9) assert res.shape == (1, 2, 9, 9)
...@@ -280,29 +289,39 @@ def test_upsample_layer(): ...@@ -280,29 +289,39 @@ def test_upsample_layer():
assert isinstance(layer, nn.Upsample) assert isinstance(layer, nn.Upsample)
assert layer.mode == type_name assert layer.mode == type_name
cfg = dict()
cfg['type'] = Upsample
layer_from_cls = build_upsample_layer(cfg)
assert isinstance(layer_from_cls, nn.Upsample)
assert layer_from_cls.mode == 'nearest'
cfg = dict( cfg = dict(
type='deconv', in_channels=3, out_channels=3, kernel_size=3, stride=2) type='deconv', in_channels=3, out_channels=3, kernel_size=3, stride=2)
layer = build_upsample_layer(cfg) layer = build_upsample_layer(cfg)
assert isinstance(layer, nn.ConvTranspose2d) assert isinstance(layer, nn.ConvTranspose2d)
cfg = dict(type='deconv') for type_name in ('deconv', ConvTranspose2d):
cfg = dict(type=ConvTranspose2d)
kwargs = dict(in_channels=3, out_channels=3, kernel_size=3, stride=2) kwargs = dict(in_channels=3, out_channels=3, kernel_size=3, stride=2)
layer = build_upsample_layer(cfg, **kwargs) layer = build_upsample_layer(cfg, **kwargs)
assert isinstance(layer, nn.ConvTranspose2d) assert isinstance(layer, nn.ConvTranspose2d)
assert layer.in_channels == kwargs['in_channels'] assert layer.in_channels == kwargs['in_channels']
assert layer.out_channels == kwargs['out_channels'] assert layer.out_channels == kwargs['out_channels']
assert layer.kernel_size == (kwargs['kernel_size'], kwargs['kernel_size']) assert layer.kernel_size == (kwargs['kernel_size'],
kwargs['kernel_size'])
assert layer.stride == (kwargs['stride'], kwargs['stride']) assert layer.stride == (kwargs['stride'], kwargs['stride'])
layer = build_upsample_layer(cfg, 3, 3, 3, 2) layer = build_upsample_layer(cfg, 3, 3, 3, 2)
assert isinstance(layer, nn.ConvTranspose2d) assert isinstance(layer, nn.ConvTranspose2d)
assert layer.in_channels == kwargs['in_channels'] assert layer.in_channels == kwargs['in_channels']
assert layer.out_channels == kwargs['out_channels'] assert layer.out_channels == kwargs['out_channels']
assert layer.kernel_size == (kwargs['kernel_size'], kwargs['kernel_size']) assert layer.kernel_size == (kwargs['kernel_size'],
kwargs['kernel_size'])
assert layer.stride == (kwargs['stride'], kwargs['stride']) assert layer.stride == (kwargs['stride'], kwargs['stride'])
for type_name in ('pixel_shuffle', PixelShufflePack):
cfg = dict( cfg = dict(
type='pixel_shuffle', type=type_name,
in_channels=3, in_channels=3,
out_channels=3, out_channels=3,
scale_factor=2, scale_factor=2,
...@@ -396,30 +415,37 @@ def test_build_plugin_layer(): ...@@ -396,30 +415,37 @@ def test_build_plugin_layer():
build_plugin_layer(cfg, postfix=[1, 2]) build_plugin_layer(cfg, postfix=[1, 2])
# test ContextBlock # test ContextBlock
for type_name in ('ContextBlock', ContextBlock):
for postfix in ['', '_test', 1]: for postfix in ['', '_test', 1]:
cfg = dict(type='ContextBlock') cfg = dict(type=type_name)
name, layer = build_plugin_layer( name, layer = build_plugin_layer(
cfg, postfix=postfix, in_channels=16, ratio=1. / 4) cfg, postfix=postfix, in_channels=16, ratio=1. / 4)
assert name == 'context_block' + str(postfix) assert name == 'context_block' + str(postfix)
assert isinstance(layer, MODELS.module_dict['ContextBlock']) assert isinstance(layer, MODELS.module_dict['ContextBlock'])
# test GeneralizedAttention # test GeneralizedAttention
for type_name in ('GeneralizedAttention', GeneralizedAttention):
for postfix in ['', '_test', 1]: for postfix in ['', '_test', 1]:
cfg = dict(type='GeneralizedAttention') cfg = dict(type=type_name)
name, layer = build_plugin_layer(cfg, postfix=postfix, in_channels=16) name, layer = build_plugin_layer(
cfg, postfix=postfix, in_channels=16)
assert name == 'gen_attention_block' + str(postfix) assert name == 'gen_attention_block' + str(postfix)
assert isinstance(layer, MODELS.module_dict['GeneralizedAttention']) assert isinstance(layer,
MODELS.module_dict['GeneralizedAttention'])
# test NonLocal2d # test NonLocal2d
for type_name in ('NonLocal2d', NonLocal2d):
for postfix in ['', '_test', 1]: for postfix in ['', '_test', 1]:
cfg = dict(type='NonLocal2d') cfg = dict(type='NonLocal2d')
name, layer = build_plugin_layer(cfg, postfix=postfix, in_channels=16) name, layer = build_plugin_layer(
cfg, postfix=postfix, in_channels=16)
assert name == 'nonlocal_block' + str(postfix) assert name == 'nonlocal_block' + str(postfix)
assert isinstance(layer, MODELS.module_dict['NonLocal2d']) assert isinstance(layer, MODELS.module_dict['NonLocal2d'])
# test ConvModule # test ConvModule
for postfix in ['', '_test', 1]: for postfix in ['', '_test', 1]:
cfg = dict(type='ConvModule') for type_name in ('ConvModule', ConvModule):
cfg = dict(type=type_name)
name, layer = build_plugin_layer( name, layer = build_plugin_layer(
cfg, cfg,
postfix=postfix, postfix=postfix,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment