Unverified Commit 0be04104 authored by Rui Xu's avatar Rui Xu Committed by GitHub
Browse files

[Bug] fix raise error bug in registering multiple names (#949)

* fix raise error bug in registering multiple names

* fix bug in checking the type of name

* fix lint

* fix unit test for registry

* fix bug in unit test
parent 47856388
...@@ -239,11 +239,6 @@ class Registry: ...@@ -239,11 +239,6 @@ class Registry:
module_name = module_class.__name__ module_name = module_class.__name__
if isinstance(module_name, str): if isinstance(module_name, str):
module_name = [module_name] module_name = [module_name]
else:
assert is_seq_of(
module_name,
str), ('module_name should be either of None, an '
f'instance of str or list, but got {type(module_name)}')
for name in module_name: for name in module_name:
if not force and name in self._module_dict: if not force and name in self._module_dict:
raise KeyError(f'{name} is already registered ' raise KeyError(f'{name} is already registered '
...@@ -297,16 +292,18 @@ class Registry: ...@@ -297,16 +292,18 @@ class Registry:
if isinstance(name, type): if isinstance(name, type):
return self.deprecated_register_module(name, force=force) return self.deprecated_register_module(name, force=force)
# raise the error ahead of time
if not (name is None or isinstance(name, str) or is_seq_of(name, str)):
raise TypeError(
'name must be either of None, an instance of str or a sequence'
f' of str, but got {type(name)}')
# use it as a normal method: x.register_module(module=SomeClass) # use it as a normal method: x.register_module(module=SomeClass)
if module is not None: if module is not None:
self._register_module( self._register_module(
module_class=module, module_name=name, force=force) module_class=module, module_name=name, force=force)
return module return module
# raise the error ahead of time
if not (name is None or isinstance(name, str)):
raise TypeError(f'name must be a str, but got {type(name)}')
# use it as a decorator: @x.register_module() # use it as a decorator: @x.register_module()
def _register(cls): def _register(cls):
self._register_module( self._register_module(
......
...@@ -46,11 +46,12 @@ def test_registry(): ...@@ -46,11 +46,12 @@ def test_registry():
assert CATS.get('PersianCat') is None assert CATS.get('PersianCat') is None
assert 'PersianCat' not in CATS assert 'PersianCat' not in CATS
@CATS.register_module(name='Siamese') @CATS.register_module(name=['Siamese', 'Siamese2'])
class SiameseCat: class SiameseCat:
pass pass
assert CATS.get('Siamese').__name__ == 'SiameseCat' assert CATS.get('Siamese').__name__ == 'SiameseCat'
assert CATS.get('Siamese2').__name__ == 'SiameseCat'
class SphynxCat: class SphynxCat:
pass pass
...@@ -68,6 +69,8 @@ def test_registry(): ...@@ -68,6 +69,8 @@ def test_registry():
"<locals>.Munchkin'>, ") "<locals>.Munchkin'>, ")
repr_str += ("'Siamese': <class 'test_registry.test_registry." repr_str += ("'Siamese': <class 'test_registry.test_registry."
"<locals>.SiameseCat'>, ") "<locals>.SiameseCat'>, ")
repr_str += ("'Siamese2': <class 'test_registry.test_registry."
"<locals>.SiameseCat'>, ")
repr_str += ("'Sphynx': <class 'test_registry.test_registry." repr_str += ("'Sphynx': <class 'test_registry.test_registry."
"<locals>.SphynxCat'>, ") "<locals>.SphynxCat'>, ")
repr_str += ("'Sphynx1': <class 'test_registry.test_registry." repr_str += ("'Sphynx1': <class 'test_registry.test_registry."
...@@ -78,7 +81,7 @@ def test_registry(): ...@@ -78,7 +81,7 @@ def test_registry():
assert repr(CATS) == repr_str assert repr(CATS) == repr_str
# name type # name type
with pytest.raises(AssertionError): with pytest.raises(TypeError):
CATS.register_module(name=7474741, module=SphynxCat) CATS.register_module(name=7474741, module=SphynxCat)
# the registered module should be a class # the registered module should be a class
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment