"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "192b3b3ceba3b4eec6a729c0b171ddd6cdc10025"
Unverified Commit 276883f1 authored by Rui Xu's avatar Rui Xu Committed by GitHub
Browse files

[Feature] Allow register multi-name for a module simultaneously (#775)

* allow register multi-name for a module simultaneously

* add assertion for name type

* use isintance intead of is_str

* fix bug in unit test

* fix unit test
parent 0de9e149
...@@ -2,7 +2,7 @@ import inspect ...@@ -2,7 +2,7 @@ import inspect
import warnings import warnings
from functools import partial from functools import partial
from .misc import is_str from .misc import is_seq_of
class Registry: class Registry:
...@@ -54,10 +54,18 @@ class Registry: ...@@ -54,10 +54,18 @@ class Registry:
if module_name is None: if module_name is None:
module_name = module_class.__name__ module_name = module_class.__name__
if not force and module_name in self._module_dict: if isinstance(module_name, str):
raise KeyError(f'{module_name} is already registered ' module_name = [module_name]
f'in {self.name}') else:
self._module_dict[module_name] = module_class assert is_seq_of(
module_name,
str), ('module_name should be either of None, an '
f'instance of str or list, but got {type(module_name)}')
for name in module_name:
if not force and name in self._module_dict:
raise KeyError(f'{name} is already registered '
f'in {self.name}')
self._module_dict[name] = module_class
def deprecated_register_module(self, cls=None, force=False): def deprecated_register_module(self, cls=None, force=False):
warnings.warn( warnings.warn(
...@@ -157,7 +165,7 @@ def build_from_cfg(cfg, registry, default_args=None): ...@@ -157,7 +165,7 @@ def build_from_cfg(cfg, registry, default_args=None):
args.setdefault(name, value) args.setdefault(name, value)
obj_type = args.pop('type') obj_type = args.pop('type')
if is_str(obj_type): if isinstance(obj_type, str):
obj_cls = registry.get(obj_type) obj_cls = registry.get(obj_type)
if obj_cls is None: if obj_cls is None:
raise KeyError( raise KeyError(
......
...@@ -58,6 +58,9 @@ def test_registry(): ...@@ -58,6 +58,9 @@ def test_registry():
CATS.register_module(name='Sphynx', module=SphynxCat) CATS.register_module(name='Sphynx', module=SphynxCat)
assert CATS.get('Sphynx') is SphynxCat assert CATS.get('Sphynx') is SphynxCat
CATS.register_module(name=['Sphynx1', 'Sphynx2'], module=SphynxCat)
assert CATS.get('Sphynx2') is SphynxCat
repr_str = 'Registry(name=cat, items={' repr_str = 'Registry(name=cat, items={'
repr_str += ("'BritishShorthair': <class 'test_registry.test_registry." repr_str += ("'BritishShorthair': <class 'test_registry.test_registry."
"<locals>.BritishShorthair'>, ") "<locals>.BritishShorthair'>, ")
...@@ -66,10 +69,18 @@ def test_registry(): ...@@ -66,10 +69,18 @@ def test_registry():
repr_str += ("'Siamese': <class 'test_registry.test_registry." repr_str += ("'Siamese': <class 'test_registry.test_registry."
"<locals>.SiameseCat'>, ") "<locals>.SiameseCat'>, ")
repr_str += ("'Sphynx': <class 'test_registry.test_registry." repr_str += ("'Sphynx': <class 'test_registry.test_registry."
"<locals>.SphynxCat'>, ")
repr_str += ("'Sphynx1': <class 'test_registry.test_registry."
"<locals>.SphynxCat'>, ")
repr_str += ("'Sphynx2': <class 'test_registry.test_registry."
"<locals>.SphynxCat'>") "<locals>.SphynxCat'>")
repr_str += '})' repr_str += '})'
assert repr(CATS) == repr_str assert repr(CATS) == repr_str
# name type
with pytest.raises(AssertionError):
CATS.register_module(name=7474741, module=SphynxCat)
# the registered module should be a class # the registered module should be a class
with pytest.raises(TypeError): with pytest.raises(TypeError):
CATS.register_module(0) CATS.register_module(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment