Unverified Commit 0be04104 authored by Rui Xu's avatar Rui Xu Committed by GitHub
Browse files

[Bug] fix raise error bug in registering multiple names (#949)

* fix raise error bug in registering multiple names

* fix bug in checking the type of name

* fix lint

* fix unit test for registry

* fix bug in unit test
parent 47856388
......@@ -239,11 +239,6 @@ class Registry:
module_name = module_class.__name__
if isinstance(module_name, str):
module_name = [module_name]
else:
assert is_seq_of(
module_name,
str), ('module_name should be either of None, an '
f'instance of str or list, but got {type(module_name)}')
for name in module_name:
if not force and name in self._module_dict:
raise KeyError(f'{name} is already registered '
......@@ -297,16 +292,18 @@ class Registry:
if isinstance(name, type):
return self.deprecated_register_module(name, force=force)
# raise the error ahead of time
if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
raise TypeError(
'name must be either of None, an instance of str or a sequence'
f' of str, but got {type(name)}')
# use it as a normal method: x.register_module(module=SomeClass)
if module is not None:
self._register_module(
module_class=module, module_name=name, force=force)
return module
# raise the error ahead of time
if not (name is None or isinstance(name, str)):
raise TypeError(f'name must be a str, but got {type(name)}')
# use it as a decorator: @x.register_module()
def _register(cls):
self._register_module(
......
......@@ -46,11 +46,12 @@ def test_registry():
assert CATS.get('PersianCat') is None
assert 'PersianCat' not in CATS
@CATS.register_module(name='Siamese')
@CATS.register_module(name=['Siamese', 'Siamese2'])
class SiameseCat:
pass
assert CATS.get('Siamese').__name__ == 'SiameseCat'
assert CATS.get('Siamese2').__name__ == 'SiameseCat'
class SphynxCat:
pass
......@@ -68,6 +69,8 @@ def test_registry():
"<locals>.Munchkin'>, ")
repr_str += ("'Siamese': <class 'test_registry.test_registry."
"<locals>.SiameseCat'>, ")
repr_str += ("'Siamese2': <class 'test_registry.test_registry."
"<locals>.SiameseCat'>, ")
repr_str += ("'Sphynx': <class 'test_registry.test_registry."
"<locals>.SphynxCat'>, ")
repr_str += ("'Sphynx1': <class 'test_registry.test_registry."
......@@ -78,7 +81,7 @@ def test_registry():
assert repr(CATS) == repr_str
# name type
with pytest.raises(AssertionError):
with pytest.raises(TypeError):
CATS.register_module(name=7474741, module=SphynxCat)
# the registered module should be a class
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment