Unverified Commit 779f47ba authored by David de la Iglesia Castro's avatar David de la Iglesia Castro Committed by GitHub
Browse files

Allow type to be default arg (#558)

* Add test case for type defined using default_args

* Refactor build_from_cfg

* Update exception of missing type

* pre-commit

* Fix default_args is None

* pre-commit

* Bring back test

* Update exception raising
parent 7ef3a5e9
......@@ -139,8 +139,10 @@ def build_from_cfg(cfg, registry, default_args=None):
if not isinstance(cfg, dict):
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
if 'type' not in cfg:
raise KeyError(
f'the cfg dict must contain the key "type", but got {cfg}')
if default_args is None or 'type' not in default_args:
raise KeyError(
'`cfg` or `default_args` must contain the key "type", '
f'but got {cfg}\n{default_args}')
if not isinstance(registry, Registry):
raise TypeError('registry must be an mmcv.Registry object, '
f'but got {type(registry)}')
......@@ -149,6 +151,11 @@ def build_from_cfg(cfg, registry, default_args=None):
f'but got {type(default_args)}')
args = cfg.copy()
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
obj_type = args.pop('type')
if is_str(obj_type):
obj_cls = registry.get(obj_type)
......@@ -161,7 +168,4 @@ def build_from_cfg(cfg, registry, default_args=None):
raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}')
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
return obj_cls(**args)
......@@ -158,6 +158,18 @@ def test_build_from_cfg():
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# type defined using default_args
cfg = dict(depth=50)
model = mmcv.build_from_cfg(
cfg, BACKBONES, default_args=dict(type='ResNet'))
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
cfg = dict(depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=dict(type=ResNet))
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# not a registry
with pytest.raises(TypeError):
cfg = dict(type='VGG')
......@@ -179,10 +191,16 @@ def test_build_from_cfg():
model = mmcv.build_from_cfg(cfg, BACKBONES)
# cfg should contain the key "type"
with pytest.raises(KeyError):
with pytest.raises(KeyError, match='must contain the key "type"'):
cfg = dict(depth=50, stages=4)
model = mmcv.build_from_cfg(cfg, BACKBONES)
# cfg or default_args should contain the key "type"
with pytest.raises(KeyError, match='must contain the key "type"'):
cfg = dict(depth=50)
model = mmcv.build_from_cfg(
cfg, BACKBONES, default_args=dict(stages=4))
# incorrect registry type
with pytest.raises(TypeError):
cfg = dict(type='ResNet', depth=50)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment