Unverified Commit 276883f1 authored by Rui Xu's avatar Rui Xu Committed by GitHub
Browse files

[Feature] Allow register multi-name for a module simultaneously (#775)

* allow register multi-name for a module simultaneously

* add assertion for name type

* use isintance intead of is_str

* fix bug in unit test

* fix unit test
parent 0de9e149
......@@ -2,7 +2,7 @@ import inspect
import warnings
from functools import partial
from .misc import is_str
from .misc import is_seq_of
class Registry:
......@@ -54,10 +54,18 @@ class Registry:
if module_name is None:
module_name = module_class.__name__
if not force and module_name in self._module_dict:
raise KeyError(f'{module_name} is already registered '
if isinstance(module_name, str):
module_name = [module_name]
else:
assert is_seq_of(
module_name,
str), ('module_name should be either of None, an '
f'instance of str or list, but got {type(module_name)}')
for name in module_name:
if not force and name in self._module_dict:
raise KeyError(f'{name} is already registered '
f'in {self.name}')
self._module_dict[module_name] = module_class
self._module_dict[name] = module_class
def deprecated_register_module(self, cls=None, force=False):
warnings.warn(
......@@ -157,7 +165,7 @@ def build_from_cfg(cfg, registry, default_args=None):
args.setdefault(name, value)
obj_type = args.pop('type')
if is_str(obj_type):
if isinstance(obj_type, str):
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError(
......
......@@ -58,6 +58,9 @@ def test_registry():
CATS.register_module(name='Sphynx', module=SphynxCat)
assert CATS.get('Sphynx') is SphynxCat
CATS.register_module(name=['Sphynx1', 'Sphynx2'], module=SphynxCat)
assert CATS.get('Sphynx2') is SphynxCat
repr_str = 'Registry(name=cat, items={'
repr_str += ("'BritishShorthair': <class 'test_registry.test_registry."
"<locals>.BritishShorthair'>, ")
......@@ -66,10 +69,18 @@ def test_registry():
repr_str += ("'Siamese': <class 'test_registry.test_registry."
"<locals>.SiameseCat'>, ")
repr_str += ("'Sphynx': <class 'test_registry.test_registry."
"<locals>.SphynxCat'>, ")
repr_str += ("'Sphynx1': <class 'test_registry.test_registry."
"<locals>.SphynxCat'>, ")
repr_str += ("'Sphynx2': <class 'test_registry.test_registry."
"<locals>.SphynxCat'>")
repr_str += '})'
assert repr(CATS) == repr_str
# name type
with pytest.raises(AssertionError):
CATS.register_module(name=7474741, module=SphynxCat)
# the registered module should be a class
with pytest.raises(TypeError):
CATS.register_module(0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment