Unverified Commit 779f47ba authored by David de la Iglesia Castro's avatar David de la Iglesia Castro Committed by GitHub
Browse files

Allow type to be default arg (#558)

* Add test case for type defined using default_args

* Refactor build_from_cfg

* Update exception of missing type

* pre-commit

* Fix default_args is None

* pre-commit

* Bring back test

* Update exception raising
parent 7ef3a5e9
...@@ -139,8 +139,10 @@ def build_from_cfg(cfg, registry, default_args=None): ...@@ -139,8 +139,10 @@ def build_from_cfg(cfg, registry, default_args=None):
if not isinstance(cfg, dict): if not isinstance(cfg, dict):
raise TypeError(f'cfg must be a dict, but got {type(cfg)}') raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
if 'type' not in cfg: if 'type' not in cfg:
raise KeyError( if default_args is None or 'type' not in default_args:
f'the cfg dict must contain the key "type", but got {cfg}') raise KeyError(
'`cfg` or `default_args` must contain the key "type", '
f'but got {cfg}\n{default_args}')
if not isinstance(registry, Registry): if not isinstance(registry, Registry):
raise TypeError('registry must be an mmcv.Registry object, ' raise TypeError('registry must be an mmcv.Registry object, '
f'but got {type(registry)}') f'but got {type(registry)}')
...@@ -149,6 +151,11 @@ def build_from_cfg(cfg, registry, default_args=None): ...@@ -149,6 +151,11 @@ def build_from_cfg(cfg, registry, default_args=None):
f'but got {type(default_args)}') f'but got {type(default_args)}')
args = cfg.copy() args = cfg.copy()
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
obj_type = args.pop('type') obj_type = args.pop('type')
if is_str(obj_type): if is_str(obj_type):
obj_cls = registry.get(obj_type) obj_cls = registry.get(obj_type)
...@@ -161,7 +168,4 @@ def build_from_cfg(cfg, registry, default_args=None): ...@@ -161,7 +168,4 @@ def build_from_cfg(cfg, registry, default_args=None):
raise TypeError( raise TypeError(
f'type must be a str or valid type, but got {type(obj_type)}') f'type must be a str or valid type, but got {type(obj_type)}')
if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
return obj_cls(**args) return obj_cls(**args)
...@@ -158,6 +158,18 @@ def test_build_from_cfg(): ...@@ -158,6 +158,18 @@ def test_build_from_cfg():
assert isinstance(model, ResNet) assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4 assert model.depth == 50 and model.stages == 4
# type defined using default_args
cfg = dict(depth=50)
model = mmcv.build_from_cfg(
cfg, BACKBONES, default_args=dict(type='ResNet'))
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
cfg = dict(depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=dict(type=ResNet))
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4
# not a registry # not a registry
with pytest.raises(TypeError): with pytest.raises(TypeError):
cfg = dict(type='VGG') cfg = dict(type='VGG')
...@@ -179,10 +191,16 @@ def test_build_from_cfg(): ...@@ -179,10 +191,16 @@ def test_build_from_cfg():
model = mmcv.build_from_cfg(cfg, BACKBONES) model = mmcv.build_from_cfg(cfg, BACKBONES)
# cfg should contain the key "type" # cfg should contain the key "type"
with pytest.raises(KeyError): with pytest.raises(KeyError, match='must contain the key "type"'):
cfg = dict(depth=50, stages=4) cfg = dict(depth=50, stages=4)
model = mmcv.build_from_cfg(cfg, BACKBONES) model = mmcv.build_from_cfg(cfg, BACKBONES)
# cfg or default_args should contain the key "type"
with pytest.raises(KeyError, match='must contain the key "type"'):
cfg = dict(depth=50)
model = mmcv.build_from_cfg(
cfg, BACKBONES, default_args=dict(stages=4))
# incorrect registry type # incorrect registry type
with pytest.raises(TypeError): with pytest.raises(TypeError):
cfg = dict(type='ResNet', depth=50) cfg = dict(type='ResNet', depth=50)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment