Unverified Commit 34197c5c authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Add registry and build_from_cfg (#195)

* add registry and build_from_cfg

* add some corner cases

* add some corner cases

* fix the unittest for python 3.5

* minor fix
parent 0863073c
......@@ -8,14 +8,16 @@ from .path import (FileNotFoundError, check_file_exist, fopen, is_filepath,
mkdir_or_exist, scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress,
track_parallel_progress, track_progress)
from .registry import Registry, build_from_cfg
from .timer import Timer, TimerError, check_time
__all__ = [
'ConfigDict', 'Config', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
'is_seq_of', 'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list',
'check_prerequisites', 'requires_package', 'requires_executable',
'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', 'symlink',
'scandir', 'FileNotFoundError', 'ProgressBar', 'track_progress',
'ConfigDict', 'Config', 'Registry', 'build_from_cfg', 'is_str',
'iter_cast', 'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of',
'is_tuple_of', 'slice_list', 'concat_list', 'check_prerequisites',
'requires_package', 'requires_executable', 'is_filepath', 'fopen',
'check_file_exist', 'mkdir_or_exist', 'symlink', 'scandir',
'FileNotFoundError', 'ProgressBar', 'track_progress',
'track_iter_progress', 'track_parallel_progress', 'Timer', 'TimerError',
'check_time'
]
import inspect
from functools import partial
from .misc import is_str
class Registry(object):
"""A registry to map strings to classes.
Args:
name (str): Registry name.
"""
def __init__(self, name):
self._name = name
self._module_dict = dict()
def __len__(self):
return len(self._module_dict)
def __repr__(self):
format_str = self.__class__.__name__ + '(name={}, items={})'.format(
self._name, list(self._module_dict.keys()))
return format_str
@property
def name(self):
return self._name
@property
def module_dict(self):
return self._module_dict
def get(self, key):
"""Get the registry record.
Args:
key (str): The class name in string format.
Returns:
class: The corresponding class.
"""
return self._module_dict.get(key, None)
def _register_module(self, module_class, force=False):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, but got {}'.format(
type(module_class)))
module_name = module_class.__name__
if not force and module_name in self._module_dict:
raise KeyError('{} is already registered in {}'.format(
module_name, self.name))
self._module_dict[module_name] = module_class
def register_module(self, cls=None, force=False):
"""Register a module.
A record will be added to `self._module_dict`, whose key is the class
name and value is the class itself.
It can be used as a decorator or a normal function.
Example:
>>> backbones = Registry('backbone')
>>> @backbones.register_module
>>> class ResNet(object):
>>> pass
Example:
>>> backbones = Registry('backbone')
>>> class ResNet(object):
>>> pass
>>> backbones.register_module(ResNet)
Args:
module (:obj:`nn.Module`): Module to be registered.
force (bool, optional): Whether to override an existing class with
the same name. Default: False.
"""
if cls is None:
return partial(self.register_module, force=force)
self._register_module(cls, force=force)
return cls
def build_from_cfg(cfg, registry, default_args=None):
"""Build a module from config dict.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.
Returns:
obj: The constructed object.
"""
if not (isinstance(cfg, dict) and 'type' in cfg):
raise TypeError('cfg must be a dict containing the key "type"')
if not isinstance(registry, Registry):
raise TypeError(
'registry must be an mmcv.Registry object, but got {}'.format(
type(registry)))
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError(
'default_args must be a dict or None, but got {}'.format(
type(default_args)))
args = cfg.copy()
obj_type = args.pop('type')
if is_str(obj_type):
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError('{} is not in the {} registry'.format(
obj_type, registry.name))
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError('type must be a str or valid type, but got {}'.format(
type(obj_type)))
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
return obj_cls(**args)
import pytest
import mmcv
def test_registry():
reg_name = 'cat'
CATS = mmcv.Registry(reg_name)
assert CATS.name == reg_name
assert CATS.module_dict == {}
assert len(CATS) == 0
@CATS.register_module
class BritishShorthair:
pass
assert len(CATS) == 1
assert CATS.get('BritishShorthair') is BritishShorthair
class Munchkin:
pass
CATS.register_module(Munchkin)
assert len(CATS) == 2
assert CATS.get('Munchkin') is Munchkin
with pytest.raises(KeyError):
CATS.register_module(Munchkin)
CATS.register_module(Munchkin, force=True)
assert len(CATS) == 2
with pytest.raises(KeyError):
@CATS.register_module
class BritishShorthair:
pass
@CATS.register_module(force=True)
class BritishShorthair:
pass
assert len(CATS) == 2
assert CATS.get('PersianCat') is None
# The order of dict keys are not preserved in python 3.5
assert repr(CATS) in [
"Registry(name=cat, items=['BritishShorthair', 'Munchkin'])",
"Registry(name=cat, items=['Munchkin', 'BritishShorthair'])"
]
# the registered module should be a class
with pytest.raises(TypeError):
CATS.register_module(0)
def test_build_from_cfg():
BACKBONES = mmcv.Registry('backbone')
@BACKBONES.register_module
class ResNet:
def __init__(self, depth, stages=4):
self.depth = depth
self.stages = stages
@BACKBONES.register_module
class ResNeXt:
def __init__(self, depth, stages=4):
self.depth = depth
self.stages = stages
cfg = dict(type='ResNet', depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
cfg = dict(type='ResNet', depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES, default_args={'stages': 3})
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 3
cfg = dict(type='ResNeXt', depth=50, stages=3)
model = mmcv.build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNeXt)
assert model.depth == 50 and model.stages == 3
cfg = dict(type=ResNet, depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# non-registered class
with pytest.raises(KeyError):
cfg = dict(type='VGG')
model = mmcv.build_from_cfg(cfg, BACKBONES)
# cfg['type'] should be a str or class
with pytest.raises(TypeError):
cfg = dict(type=1000)
model = mmcv.build_from_cfg(cfg, BACKBONES)
# cfg should contain the key "type"
with pytest.raises(TypeError):
cfg = dict(depth=50, stages=4)
model = mmcv.build_from_cfg(cfg, BACKBONES)
# incorrect registry type
with pytest.raises(TypeError):
dict(type='ResNet', depth=50)
model = mmcv.build_from_cfg(cfg, 'BACKBONES')
# incorrect default_args type
with pytest.raises(TypeError):
dict(type='ResNet', depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=0)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment