Unverified Commit 19a02415 authored by Mashiro, committed by GitHub
Browse files

[Refactor] Use MODELS registry in mmengine and delete basemodule (#2172)

* change MODELS to mmengine, delete basemodule

* fix unit test

* remove build from cfg

* add comment and rename TARGET_MODELS to registry

* refine cnn docs

* remove unnecessary check

* refine as comment

* refine build_xxx_conv error message

* fix lint

* fix import registry from mmcv

* remove unused file
parent f6fd6c21
# Copyright (c) OpenMMLab. All rights reserved.
from importlib import import_module
import numpy as np
import pytest
import torch
import torch.nn as nn
from mmengine.registry import MODELS
from mmcv.cnn.bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
PADDING_LAYERS, PLUGIN_LAYERS,
build_activation_layer, build_conv_layer,
from mmcv.cnn.bricks import (build_activation_layer, build_conv_layer,
build_norm_layer, build_padding_layer,
build_plugin_layer, build_upsample_layer, is_norm)
from mmcv.cnn.bricks.norm import infer_abbr as infer_norm_abbr
......@@ -63,18 +64,19 @@ def test_build_conv_layer():
# sparse convs cannot support the case when groups>1
kwargs.pop('groups')
for type_name, module in CONV_LAYERS.module_dict.items():
for type_name, module in MODELS.module_dict.items():
cfg = dict(type=type_name)
# SparseInverseConv2d and SparseInverseConv3d do not have the argument
# 'dilation'
if type_name == 'SparseInverseConv2d' or type_name == \
'SparseInverseConv3d':
kwargs.pop('dilation')
layer = build_conv_layer(cfg, **kwargs)
assert isinstance(layer, module)
assert layer.in_channels == kwargs['in_channels']
assert layer.out_channels == kwargs['out_channels']
kwargs['dilation'] = 2 # recover the key
if 'conv' in type_name.lower():
layer = build_conv_layer(cfg, **kwargs)
assert isinstance(layer, module)
assert layer.in_channels == kwargs['in_channels']
assert layer.out_channels == kwargs['out_channels']
kwargs['dilation'] = 2 # recover the key
def test_infer_norm_abbr():
......@@ -154,7 +156,9 @@ def test_build_norm_layer():
'IN2d': 'in',
'IN3d': 'in',
}
for type_name, module in NORM_LAYERS.module_dict.items():
for type_name, module in MODELS.module_dict.items():
if type_name not in abbr_mapping:
continue
if type_name == 'MMSyncBN': # skip MMSyncBN
continue
for postfix in ['_test', 1]:
......@@ -172,6 +176,17 @@ def test_build_norm_layer():
def test_build_activation_layer():
act_names = [
'ReLU', 'LeakyReLU', 'PReLU', 'RReLU', 'ReLU6', 'ELU', 'Sigmoid',
'Tanh'
]
for module_name in ['activation', 'hsigmoid', 'hswish', 'swish']:
act_module = import_module(f'mmcv.cnn.bricks.{module_name}')
for key, value in act_module.__dict__.items():
if isinstance(value, type) and issubclass(value, nn.Module):
act_names.append(key)
with pytest.raises(TypeError):
# cfg must be a dict
cfg = 'ReLU'
......@@ -188,10 +203,11 @@ def test_build_activation_layer():
build_activation_layer(cfg)
# test each type of activation layer in activation_cfg
for type_name, module in ACTIVATION_LAYERS.module_dict.items():
cfg['type'] = type_name
layer = build_activation_layer(cfg)
assert isinstance(layer, module)
for type_name, module in MODELS.module_dict.items():
if type_name in act_names:
cfg['type'] = type_name
layer = build_activation_layer(cfg)
assert isinstance(layer, module)
# sanity check for Clamp
act = build_activation_layer(dict(type='Clamp'))
......@@ -207,6 +223,13 @@ def test_build_activation_layer():
def test_build_padding_layer():
pad_names = ['zero', 'reflect', 'replicate']
for module_name in ['padding']:
pad_module = import_module(f'mmcv.cnn.bricks.{module_name}')
for key, value in pad_module.__dict__.items():
if isinstance(value, type) and issubclass(value, nn.Module):
pad_names.append(key)
with pytest.raises(TypeError):
# cfg must be a dict
cfg = 'reflect'
......@@ -222,10 +245,11 @@ def test_build_padding_layer():
cfg = dict(type='FancyPad')
build_padding_layer(cfg)
for type_name, module in PADDING_LAYERS.module_dict.items():
cfg['type'] = type_name
layer = build_padding_layer(cfg, 2)
assert isinstance(layer, module)
for type_name, module in MODELS.module_dict.items():
if type_name in pad_names:
cfg['type'] = type_name
layer = build_padding_layer(cfg, 2)
assert isinstance(layer, module)
input_x = torch.randn(1, 2, 5, 5)
cfg = dict(type='reflect')
......@@ -377,22 +401,21 @@ def test_build_plugin_layer():
name, layer = build_plugin_layer(
cfg, postfix=postfix, in_channels=16, ratio=1. / 4)
assert name == 'context_block' + str(postfix)
assert isinstance(layer, PLUGIN_LAYERS.module_dict['ContextBlock'])
assert isinstance(layer, MODELS.module_dict['ContextBlock'])
# test GeneralizedAttention
for postfix in ['', '_test', 1]:
cfg = dict(type='GeneralizedAttention')
name, layer = build_plugin_layer(cfg, postfix=postfix, in_channels=16)
assert name == 'gen_attention_block' + str(postfix)
assert isinstance(layer,
PLUGIN_LAYERS.module_dict['GeneralizedAttention'])
assert isinstance(layer, MODELS.module_dict['GeneralizedAttention'])
# test NonLocal2d
for postfix in ['', '_test', 1]:
cfg = dict(type='NonLocal2d')
name, layer = build_plugin_layer(cfg, postfix=postfix, in_channels=16)
assert name == 'nonlocal_block' + str(postfix)
assert isinstance(layer, PLUGIN_LAYERS.module_dict['NonLocal2d'])
assert isinstance(layer, MODELS.module_dict['NonLocal2d'])
# test ConvModule
for postfix in ['', '_test', 1]:
......@@ -404,4 +427,4 @@ def test_build_plugin_layer():
out_channels=4,
kernel_size=3)
assert name == 'conv_block' + str(postfix)
assert isinstance(layer, PLUGIN_LAYERS.module_dict['ConvModule'])
assert isinstance(layer, MODELS.module_dict['ConvModule'])
......@@ -5,12 +5,13 @@ from unittest.mock import patch
import pytest
import torch
import torch.nn as nn
from mmengine.registry import MODELS
from mmcv.cnn.bricks import CONV_LAYERS, ConvModule, HSigmoid, HSwish
from mmcv.cnn.bricks import ConvModule, HSigmoid, HSwish
from mmcv.utils import TORCH_VERSION, digit_version
@CONV_LAYERS.register_module()
@MODELS.register_module()
class ExampleConv(nn.Module):
def __init__(self,
......
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import mmcv
from mmcv.cnn import MODELS, build_model_from_cfg
def test_build_model_from_cfg():
    """Verify ``build_model_from_cfg`` when used as a registry ``build_func``.

    Covers three behaviours:

    * a single dict config builds one instance of the registered module,
    * a list of configs builds an ``nn.Sequential`` of the built modules,
    * a child registry inherits ``build_func`` from its parent unless an
      explicit ``build_func`` is supplied.
    """
    BACKBONES = mmcv.Registry('backbone', build_func=build_model_from_cfg)

    @BACKBONES.register_module()
    class ResNet(nn.Module):

        def __init__(self, depth, stages=4):
            super().__init__()
            self.depth = depth
            self.stages = stages

        def forward(self, x):
            return x

    @BACKBONES.register_module()
    class ResNeXt(nn.Module):

        def __init__(self, depth, stages=4):
            super().__init__()
            self.depth = depth
            self.stages = stages

        def forward(self, x):
            return x

    # A dict config yields a single instance of the registered class.
    net = BACKBONES.build(dict(type='ResNet', depth=50))
    assert isinstance(net, ResNet)
    assert net.depth == 50 and net.stages == 4

    net = BACKBONES.build(dict(type='ResNeXt', depth=50, stages=3))
    assert isinstance(net, ResNeXt)
    assert net.depth == 50 and net.stages == 3

    # A list of configs yields an ``nn.Sequential`` holding every module.
    seq = BACKBONES.build([
        dict(type='ResNet', depth=50),
        dict(type='ResNeXt', depth=50, stages=3),
    ])
    assert isinstance(seq, nn.Sequential)
    first, second = seq[0], seq[1]
    assert isinstance(first, ResNet)
    assert first.depth == 50 and first.stages == 4
    assert isinstance(second, ResNeXt)
    assert second.depth == 50 and second.stages == 3

    # A child registry inherits ``build_func`` from its parent ...
    NEW_MODELS = mmcv.Registry('models', parent=MODELS, scope='new')
    assert NEW_MODELS.build_func is build_model_from_cfg

    # ... unless an explicit ``build_func`` overrides the inherited one.
    def pseudo_build(cfg):
        return cfg

    NEW_MODELS = mmcv.Registry(
        'models', parent=MODELS, build_func=pseudo_build)
    assert NEW_MODELS.build_func is pseudo_build
......@@ -3,6 +3,7 @@ import copy
import pytest
import torch
from mmengine.model import ModuleList
from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.bricks.transformer import (FFN, AdaptivePadding,
......@@ -10,7 +11,6 @@ from mmcv.cnn.bricks.transformer import (FFN, AdaptivePadding,
MultiheadAttention, PatchEmbed,
PatchMerging,
TransformerLayerSequence)
from mmcv.runner import ModuleList
def test_adaptive_padding():
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment