Unverified Commit f7de63fd authored by Kai Chen's avatar Kai Chen Committed by GitHub
Browse files

Support specifying names for registry.register_module() (#251)

* support specifying names for registry.register_module()

* minor fix

* add more unittests
parent c203419f
import inspect
import warnings
from functools import partial
from .misc import is_str
......@@ -24,7 +25,7 @@ class Registry(object):
def __repr__(self):
format_str = self.__class__.__name__ + \
f'(name={self._name}, ' \
f'items={list(self._module_dict.keys())})'
f'items={self._module_dict})'
return format_str
@property
......@@ -46,44 +47,80 @@ class Registry(object):
"""
return self._module_dict.get(key, None)
def _register_module(self, module_class, force=False):
def _register_module(self, module_class, module_name=None, force=False):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, '
f'but got {type(module_class)}')
module_name = module_class.__name__
if module_name is None:
module_name = module_class.__name__
if not force and module_name in self._module_dict:
raise KeyError(f'{module_name} is already registered '
f'in {self.name}')
self._module_dict[module_name] = module_class
def register_module(self, cls=None, force=False):
def deprecated_register_module(self, cls=None, force=False):
warnings.warn(
'The old API of register_module(module, force=False) '
'is deprecated and will be removed, please use the new API '
'register_module(name=None, force=False, module=None) instead.',
DeprecationWarning)
if cls is None:
return partial(self.deprecated_register_module, force=force)
self._register_module(cls, force=force)
return cls
def register_module(self, name=None, force=False, module=None):
"""Register a module.
A record will be added to `self._module_dict`, whose key is the class
name and value is the class itself.
name or the specified name, and value is the class itself.
It can be used as a decorator or a normal function.
Example:
>>> backbones = Registry('backbone')
>>> @backbones.register_module
>>> class ResNet(object):
>>> @backbones.register_module()
>>> class ResNet:
>>> pass
Example:
>>> backbones = Registry('backbone')
>>> class ResNet(object):
>>> @backbones.register_module(name='mnet')
>>> class MobileNet:
>>> pass
>>> backbones = Registry('backbone')
>>> class ResNet:
>>> pass
>>> backbones.register_module(ResNet)
Args:
module (:obj:`nn.Module`): Module to be registered.
name (str | None): The module name to be registered. If not
specified, the class name will be used.
force (bool, optional): Whether to override an existing class with
the same name. Default: False.
module (type): Module class to be registered.
"""
if cls is None:
return partial(self.register_module, force=force)
self._register_module(cls, force=force)
return cls
# NOTE: This is a walkaround to be compatible with the old api,
# while it may introduce unexpected bugs.
if isinstance(name, type):
return self.deprecated_register_module(name, force=force)
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(
module_class=module, module_name=name, force=force)
return
# raise the error ahead of time
if not (name is None or isinstance(name, str)):
raise TypeError(f'name must be a str, but got {type(name)}')
# use it as a decorator: @x.register_module()
def _register(cls):
self._register_module(
module_class=cls, module_name=name, force=force)
return _register
def build_from_cfg(cfg, registry, default_args=None):
......
......@@ -4,9 +4,8 @@ import mmcv
def test_registry():
reg_name = 'cat'
CATS = mmcv.Registry(reg_name)
assert CATS.name == reg_name
CATS = mmcv.Registry('cat')
assert CATS.name == 'cat'
assert CATS.module_dict == {}
assert len(CATS) == 0
......@@ -31,9 +30,10 @@ def test_registry():
CATS.register_module(Munchkin, force=True)
assert len(CATS) == 2
# force=False
with pytest.raises(KeyError):
@CATS.register_module
@CATS.register_module()
class BritishShorthair:
pass
......@@ -46,16 +46,80 @@ def test_registry():
assert CATS.get('PersianCat') is None
assert 'PersianCat' not in CATS
# The order of dict keys are not preserved in python 3.5
assert repr(CATS) in [
"Registry(name=cat, items=['BritishShorthair', 'Munchkin'])",
"Registry(name=cat, items=['Munchkin', 'BritishShorthair'])"
]
@CATS.register_module(name='Siamese')
class SiameseCat:
pass
assert CATS.get('Siamese').__name__ == 'SiameseCat'
class SphynxCat:
pass
CATS.register_module(name='Sphynx', module=SphynxCat)
assert CATS.get('Sphynx') is SphynxCat
repr_str = 'Registry(name=cat, items={'
repr_str += ("'BritishShorthair': <class 'test_registry.test_registry."
"<locals>.BritishShorthair'>, ")
repr_str += ("'Munchkin': <class 'test_registry.test_registry."
"<locals>.Munchkin'>, ")
repr_str += ("'Siamese': <class 'test_registry.test_registry."
"<locals>.SiameseCat'>, ")
repr_str += ("'Sphynx': <class 'test_registry.test_registry."
"<locals>.SphynxCat'>")
repr_str += '})'
assert repr(CATS) == repr_str
# the registered module should be a class
with pytest.raises(TypeError):
CATS.register_module(0)
# can only decorate a class
with pytest.raises(TypeError):
@CATS.register_module()
def some_method():
pass
# begin: test old APIs
with pytest.warns(DeprecationWarning):
CATS.register_module(SphynxCat)
assert CATS.get('SphynxCat').__name__ == 'SphynxCat'
with pytest.warns(DeprecationWarning):
CATS.register_module(SphynxCat, force=True)
assert CATS.get('SphynxCat').__name__ == 'SphynxCat'
with pytest.warns(DeprecationWarning):
@CATS.register_module
class NewCat:
pass
assert CATS.get('NewCat').__name__ == 'NewCat'
with pytest.warns(DeprecationWarning):
CATS.deprecated_register_module(SphynxCat, force=True)
assert CATS.get('SphynxCat').__name__ == 'SphynxCat'
with pytest.warns(DeprecationWarning):
@CATS.deprecated_register_module
class CuteCat:
pass
assert CATS.get('CuteCat').__name__ == 'CuteCat'
with pytest.warns(DeprecationWarning):
@CATS.deprecated_register_module(force=True)
class NewCat2:
pass
assert CATS.get('NewCat2').__name__ == 'NewCat2'
# end: test old APIs
def test_build_from_cfg():
BACKBONES = mmcv.Registry('backbone')
......@@ -94,11 +158,21 @@ def test_build_from_cfg():
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# not a registry
with pytest.raises(TypeError):
cfg = dict(type='VGG')
model = mmcv.build_from_cfg(cfg, 'BACKBONES')
# non-registered class
with pytest.raises(KeyError):
cfg = dict(type='VGG')
model = mmcv.build_from_cfg(cfg, BACKBONES)
# default_args must be a dict or None
with pytest.raises(TypeError):
cfg = dict(type='ResNet', depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=1)
# cfg['type'] should be a str or class
with pytest.raises(TypeError):
cfg = dict(type=1000)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment