Unverified Commit 1dbb5d36 authored by jmercat's avatar jmercat Committed by GitHub
Browse files

[Fix] Fix dump method of Config (#1837)



* config dump fix
dump should not depend on how config was loaded only where it is dumped
changed checks for type to the given filename instead of the source filename

* defined test_fromdict and fixed dump function
it might be more convoluted than it needs to be but the formatting depends on the extension of the argument or the extension of the self.filename

* config dump defaults to returning pretty_text, test_fromdict renamed to test_dump_from_dict

* some reformatting in docstrings

* refine unittest

* fix unit test as comment

* Refine docstring

* minor refinement

* import mmcv

* fix lint
Co-authored-by: default avatarHAOCHENYE <21724054@zju.edu.cn>
Co-authored-by: default avatarZaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent a848ecfd
...@@ -561,20 +561,42 @@ class Config: ...@@ -561,20 +561,42 @@ class Config:
super(Config, self).__setattr__('_text', _text) super(Config, self).__setattr__('_text', _text)
def dump(self, file=None): def dump(self, file=None):
"""Dumps config into a file or returns a string representation of the
config.
If a file argument is given, saves the config to that file using the
format defined by the file argument extension.
Otherwise, returns a string representing the config. The formatting of
this returned string is defined by the extension of `self.filename`. If
`self.filename` is not defined, returns a string representation of a
dict (lowercased and using ' for strings).
Examples:
>>> cfg_dict = dict(item1=[1, 2], item2=dict(a=0),
... item3=True, item4='test')
>>> cfg = Config(cfg_dict=cfg_dict)
>>> dump_file = "a.py"
>>> cfg.dump(dump_file)
Args:
file (str, optional): Path of the output file where the config
will be dumped. Defaults to None.
"""
import mmcv
cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict() cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
if self.filename.endswith('.py'): if file is None:
if file is None: if self.filename is None or self.filename.endswith('.py'):
return self.pretty_text return self.pretty_text
else: else:
with open(file, 'w', encoding='utf-8') as f:
f.write(self.pretty_text)
else:
import mmcv
if file is None:
file_format = self.filename.split('.')[-1] file_format = self.filename.split('.')[-1]
return mmcv.dump(cfg_dict, file_format=file_format) return mmcv.dump(cfg_dict, file_format=file_format)
else: elif file.endswith('.py'):
mmcv.dump(cfg_dict, file) with open(file, 'w', encoding='utf-8') as f:
f.write(self.pretty_text)
else:
file_format = file.split('.')[-1]
return mmcv.dump(cfg_dict, file=file, file_format=file_format)
def merge_from_dict(self, options, allow_list_keys=True): def merge_from_dict(self, options, allow_list_keys=True):
"""Merge list into cfg_dict. """Merge list into cfg_dict.
......
...@@ -435,6 +435,51 @@ def test_dict(): ...@@ -435,6 +435,51 @@ def test_dict():
assert cfg.item2.a == 1 assert cfg.item2.a == 1
@pytest.mark.parametrize('file', ['a.json', 'b.py', 'c.yaml', 'd.yml', None])
def test_dump(file):
# config loaded from dict
cfg_dict = dict(item1=[1, 2], item2=dict(a=0), item3=True, item4='test')
cfg = Config(cfg_dict=cfg_dict)
assert cfg.item1 == cfg_dict['item1']
assert cfg.item2 == cfg_dict['item2']
assert cfg.item3 == cfg_dict['item3']
assert cfg.item4 == cfg_dict['item4']
assert cfg._filename is None
if file is not None:
# dump without a filename argument is only returning pretty_text.
with tempfile.TemporaryDirectory() as temp_config_dir:
cfg_file = osp.join(temp_config_dir, file)
cfg.dump(cfg_file)
dumped_cfg = Config.fromfile(cfg_file)
assert dumped_cfg._cfg_dict == cfg._cfg_dict
else:
assert cfg.dump() == cfg.pretty_text
# The key of json must be a string, so key `1` will be converted to `'1'`.
def compare_json_cfg(ori_cfg, dumped_json_cfg):
for key, value in ori_cfg.items():
assert str(key) in dumped_json_cfg
if not isinstance(value, dict):
assert ori_cfg[key] == dumped_json_cfg[str(key)]
else:
compare_json_cfg(value, dumped_json_cfg[str(key)])
# config loaded from file
cfg_file = osp.join(data_path, 'config/n.py')
cfg = Config.fromfile(cfg_file)
if file is not None:
with tempfile.TemporaryDirectory() as temp_config_dir:
cfg_file = osp.join(temp_config_dir, file)
cfg.dump(cfg_file)
dumped_cfg = Config.fromfile(cfg_file)
if not file.endswith('.json'):
assert dumped_cfg._cfg_dict == cfg._cfg_dict
else:
compare_json_cfg(cfg._cfg_dict, dumped_cfg._cfg_dict)
else:
assert cfg.dump() == cfg.pretty_text
def test_setattr(): def test_setattr():
cfg = Config() cfg = Config()
cfg.item1 = [1, 2] cfg.item1 = [1, 2]
...@@ -496,18 +541,6 @@ def test_dict_action(): ...@@ -496,18 +541,6 @@ def test_dict_action():
assert cfg.item3 is False assert cfg.item3 is False
def test_dump_mapping():
cfg_file = osp.join(data_path, 'config/n.py')
cfg = Config.fromfile(cfg_file)
with tempfile.TemporaryDirectory() as temp_config_dir:
text_cfg_filename = osp.join(temp_config_dir, '_text_config.py')
cfg.dump(text_cfg_filename)
text_cfg = Config.fromfile(text_cfg_filename)
assert text_cfg._cfg_dict == cfg._cfg_dict
def test_reserved_key(): def test_reserved_key():
cfg_file = osp.join(data_path, 'config/g.py') cfg_file = osp.join(data_path, 'config/g.py')
with pytest.raises(KeyError): with pytest.raises(KeyError):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment