Unverified Commit 19a02415 authored by Mashiro's avatar Mashiro Committed by GitHub
Browse files

[Refactor] Use MODELS registry in mmengine and delete basemodule (#2172)

* change MODELS to mmengine, delete basemodule

* fix unit test

* remove build from cfg

* add comment and rename TARGET_MODELS to registry

* refine cnn docs

* remove unnecessary check

* refine as comment

* refine build_xxx_conv error message

* fix lint

* fix import registry from mmcv

* remove unused file
parent f6fd6c21
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
from importlib import import_module
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmengine.registry import MODELS
from mmcv.cnn.bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS, from mmcv.cnn.bricks import (build_activation_layer, build_conv_layer,
PADDING_LAYERS, PLUGIN_LAYERS,
build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer, build_norm_layer, build_padding_layer,
build_plugin_layer, build_upsample_layer, is_norm) build_plugin_layer, build_upsample_layer, is_norm)
from mmcv.cnn.bricks.norm import infer_abbr as infer_norm_abbr from mmcv.cnn.bricks.norm import infer_abbr as infer_norm_abbr
...@@ -63,18 +64,19 @@ def test_build_conv_layer(): ...@@ -63,18 +64,19 @@ def test_build_conv_layer():
# sparse convs cannot support the case when groups>1 # sparse convs cannot support the case when groups>1
kwargs.pop('groups') kwargs.pop('groups')
for type_name, module in CONV_LAYERS.module_dict.items(): for type_name, module in MODELS.module_dict.items():
cfg = dict(type=type_name) cfg = dict(type=type_name)
# SparseInverseConv2d and SparseInverseConv3d do not have the argument # SparseInverseConv2d and SparseInverseConv3d do not have the argument
# 'dilation' # 'dilation'
if type_name == 'SparseInverseConv2d' or type_name == \ if type_name == 'SparseInverseConv2d' or type_name == \
'SparseInverseConv3d': 'SparseInverseConv3d':
kwargs.pop('dilation') kwargs.pop('dilation')
layer = build_conv_layer(cfg, **kwargs) if 'conv' in type_name.lower():
assert isinstance(layer, module) layer = build_conv_layer(cfg, **kwargs)
assert layer.in_channels == kwargs['in_channels'] assert isinstance(layer, module)
assert layer.out_channels == kwargs['out_channels'] assert layer.in_channels == kwargs['in_channels']
kwargs['dilation'] = 2 # recover the key assert layer.out_channels == kwargs['out_channels']
kwargs['dilation'] = 2 # recover the key
def test_infer_norm_abbr(): def test_infer_norm_abbr():
...@@ -154,7 +156,9 @@ def test_build_norm_layer(): ...@@ -154,7 +156,9 @@ def test_build_norm_layer():
'IN2d': 'in', 'IN2d': 'in',
'IN3d': 'in', 'IN3d': 'in',
} }
for type_name, module in NORM_LAYERS.module_dict.items(): for type_name, module in MODELS.module_dict.items():
if type_name not in abbr_mapping:
continue
if type_name == 'MMSyncBN': # skip MMSyncBN if type_name == 'MMSyncBN': # skip MMSyncBN
continue continue
for postfix in ['_test', 1]: for postfix in ['_test', 1]:
...@@ -172,6 +176,17 @@ def test_build_norm_layer(): ...@@ -172,6 +176,17 @@ def test_build_norm_layer():
def test_build_activation_layer(): def test_build_activation_layer():
act_names = [
'ReLU', 'LeakyReLU', 'PReLU', 'RReLU', 'ReLU6', 'ELU', 'Sigmoid',
'Tanh'
]
for module_name in ['activation', 'hsigmoid', 'hswish', 'swish']:
act_module = import_module(f'mmcv.cnn.bricks.{module_name}')
for key, value in act_module.__dict__.items():
if isinstance(value, type) and issubclass(value, nn.Module):
act_names.append(key)
with pytest.raises(TypeError): with pytest.raises(TypeError):
# cfg must be a dict # cfg must be a dict
cfg = 'ReLU' cfg = 'ReLU'
...@@ -188,10 +203,11 @@ def test_build_activation_layer(): ...@@ -188,10 +203,11 @@ def test_build_activation_layer():
build_activation_layer(cfg) build_activation_layer(cfg)
# test each type of activation layer in activation_cfg # test each type of activation layer in activation_cfg
for type_name, module in ACTIVATION_LAYERS.module_dict.items(): for type_name, module in MODELS.module_dict.items():
cfg['type'] = type_name if type_name in act_names:
layer = build_activation_layer(cfg) cfg['type'] = type_name
assert isinstance(layer, module) layer = build_activation_layer(cfg)
assert isinstance(layer, module)
# sanity check for Clamp # sanity check for Clamp
act = build_activation_layer(dict(type='Clamp')) act = build_activation_layer(dict(type='Clamp'))
...@@ -207,6 +223,13 @@ def test_build_activation_layer(): ...@@ -207,6 +223,13 @@ def test_build_activation_layer():
def test_build_padding_layer(): def test_build_padding_layer():
pad_names = ['zero', 'reflect', 'replicate']
for module_name in ['padding']:
pad_module = import_module(f'mmcv.cnn.bricks.{module_name}')
for key, value in pad_module.__dict__.items():
if isinstance(value, type) and issubclass(value, nn.Module):
pad_names.append(key)
with pytest.raises(TypeError): with pytest.raises(TypeError):
# cfg must be a dict # cfg must be a dict
cfg = 'reflect' cfg = 'reflect'
...@@ -222,10 +245,11 @@ def test_build_padding_layer(): ...@@ -222,10 +245,11 @@ def test_build_padding_layer():
cfg = dict(type='FancyPad') cfg = dict(type='FancyPad')
build_padding_layer(cfg) build_padding_layer(cfg)
for type_name, module in PADDING_LAYERS.module_dict.items(): for type_name, module in MODELS.module_dict.items():
cfg['type'] = type_name if type_name in pad_names:
layer = build_padding_layer(cfg, 2) cfg['type'] = type_name
assert isinstance(layer, module) layer = build_padding_layer(cfg, 2)
assert isinstance(layer, module)
input_x = torch.randn(1, 2, 5, 5) input_x = torch.randn(1, 2, 5, 5)
cfg = dict(type='reflect') cfg = dict(type='reflect')
...@@ -377,22 +401,21 @@ def test_build_plugin_layer(): ...@@ -377,22 +401,21 @@ def test_build_plugin_layer():
name, layer = build_plugin_layer( name, layer = build_plugin_layer(
cfg, postfix=postfix, in_channels=16, ratio=1. / 4) cfg, postfix=postfix, in_channels=16, ratio=1. / 4)
assert name == 'context_block' + str(postfix) assert name == 'context_block' + str(postfix)
assert isinstance(layer, PLUGIN_LAYERS.module_dict['ContextBlock']) assert isinstance(layer, MODELS.module_dict['ContextBlock'])
# test GeneralizedAttention # test GeneralizedAttention
for postfix in ['', '_test', 1]: for postfix in ['', '_test', 1]:
cfg = dict(type='GeneralizedAttention') cfg = dict(type='GeneralizedAttention')
name, layer = build_plugin_layer(cfg, postfix=postfix, in_channels=16) name, layer = build_plugin_layer(cfg, postfix=postfix, in_channels=16)
assert name == 'gen_attention_block' + str(postfix) assert name == 'gen_attention_block' + str(postfix)
assert isinstance(layer, assert isinstance(layer, MODELS.module_dict['GeneralizedAttention'])
PLUGIN_LAYERS.module_dict['GeneralizedAttention'])
# test NonLocal2d # test NonLocal2d
for postfix in ['', '_test', 1]: for postfix in ['', '_test', 1]:
cfg = dict(type='NonLocal2d') cfg = dict(type='NonLocal2d')
name, layer = build_plugin_layer(cfg, postfix=postfix, in_channels=16) name, layer = build_plugin_layer(cfg, postfix=postfix, in_channels=16)
assert name == 'nonlocal_block' + str(postfix) assert name == 'nonlocal_block' + str(postfix)
assert isinstance(layer, PLUGIN_LAYERS.module_dict['NonLocal2d']) assert isinstance(layer, MODELS.module_dict['NonLocal2d'])
# test ConvModule # test ConvModule
for postfix in ['', '_test', 1]: for postfix in ['', '_test', 1]:
...@@ -404,4 +427,4 @@ def test_build_plugin_layer(): ...@@ -404,4 +427,4 @@ def test_build_plugin_layer():
out_channels=4, out_channels=4,
kernel_size=3) kernel_size=3)
assert name == 'conv_block' + str(postfix) assert name == 'conv_block' + str(postfix)
assert isinstance(layer, PLUGIN_LAYERS.module_dict['ConvModule']) assert isinstance(layer, MODELS.module_dict['ConvModule'])
...@@ -5,12 +5,13 @@ from unittest.mock import patch ...@@ -5,12 +5,13 @@ from unittest.mock import patch
import pytest import pytest
import torch import torch
import torch.nn as nn import torch.nn as nn
from mmengine.registry import MODELS
from mmcv.cnn.bricks import CONV_LAYERS, ConvModule, HSigmoid, HSwish from mmcv.cnn.bricks import ConvModule, HSigmoid, HSwish
from mmcv.utils import TORCH_VERSION, digit_version from mmcv.utils import TORCH_VERSION, digit_version
@CONV_LAYERS.register_module() @MODELS.register_module()
class ExampleConv(nn.Module): class ExampleConv(nn.Module):
def __init__(self, def __init__(self,
......
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import mmcv
from mmcv.cnn import MODELS, build_model_from_cfg
def test_build_model_from_cfg():
BACKBONES = mmcv.Registry('backbone', build_func=build_model_from_cfg)
@BACKBONES.register_module()
class ResNet(nn.Module):
def __init__(self, depth, stages=4):
super().__init__()
self.depth = depth
self.stages = stages
def forward(self, x):
return x
@BACKBONES.register_module()
class ResNeXt(nn.Module):
def __init__(self, depth, stages=4):
super().__init__()
self.depth = depth
self.stages = stages
def forward(self, x):
return x
cfg = dict(type='ResNet', depth=50)
model = BACKBONES.build(cfg)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
cfg = dict(type='ResNeXt', depth=50, stages=3)
model = BACKBONES.build(cfg)
assert isinstance(model, ResNeXt)
assert model.depth == 50 and model.stages == 3
cfg = [
dict(type='ResNet', depth=50),
dict(type='ResNeXt', depth=50, stages=3)
]
model = BACKBONES.build(cfg)
assert isinstance(model, nn.Sequential)
assert isinstance(model[0], ResNet)
assert model[0].depth == 50 and model[0].stages == 4
assert isinstance(model[1], ResNeXt)
assert model[1].depth == 50 and model[1].stages == 3
# test inherit `build_func` from parent
NEW_MODELS = mmcv.Registry('models', parent=MODELS, scope='new')
assert NEW_MODELS.build_func is build_model_from_cfg
# test specify `build_func`
def pseudo_build(cfg):
return cfg
NEW_MODELS = mmcv.Registry(
'models', parent=MODELS, build_func=pseudo_build)
assert NEW_MODELS.build_func is pseudo_build
...@@ -3,6 +3,7 @@ import copy ...@@ -3,6 +3,7 @@ import copy
import pytest import pytest
import torch import torch
from mmengine.model import ModuleList
from mmcv.cnn.bricks.drop import DropPath from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.bricks.transformer import (FFN, AdaptivePadding, from mmcv.cnn.bricks.transformer import (FFN, AdaptivePadding,
...@@ -10,7 +11,6 @@ from mmcv.cnn.bricks.transformer import (FFN, AdaptivePadding, ...@@ -10,7 +11,6 @@ from mmcv.cnn.bricks.transformer import (FFN, AdaptivePadding,
MultiheadAttention, PatchEmbed, MultiheadAttention, PatchEmbed,
PatchMerging, PatchMerging,
TransformerLayerSequence) TransformerLayerSequence)
from mmcv.runner import ModuleList
def test_adaptive_padding(): def test_adaptive_padding():
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment