Unverified Commit d5936d02 authored by Jerry Jiarui XU's avatar Jerry Jiarui XU Committed by GitHub
Browse files

Make Config.dump() output in original file (#275)

* add dump() from for different file type

* remove exception
parent d4fac3a6
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
import json
import os.path as osp import os.path as osp
import shutil import shutil
import sys import sys
...@@ -291,10 +290,21 @@ class Config(object): ...@@ -291,10 +290,21 @@ class Config(object):
def __iter__(self): def __iter__(self):
return iter(self._cfg_dict) return iter(self._cfg_dict)
def dump(self): def dump(self, file=None):
cfg_dict = super(Config, self).__getattribute__('_cfg_dict') cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
format_text = json.dumps(cfg_dict, indent=2) if self.filename.endswith('.py'):
return format_text if file is None:
return self.pretty_text
else:
with open(file, 'w') as f:
f.write(self.pretty_text)
else:
import mmcv
if file is None:
file_format = self.filename.split('.')[-1]
return mmcv.dump(cfg_dict, file_format=file_format)
else:
mmcv.dump(cfg_dict, file)
def merge_from_dict(self, options): def merge_from_dict(self, options):
"""Merge list into cfg_dict """Merge list into cfg_dict
......
...@@ -2,10 +2,10 @@ ...@@ -2,10 +2,10 @@
import argparse import argparse
import json import json
import os.path as osp import os.path as osp
import sys
import tempfile import tempfile
import pytest import pytest
import yaml
from mmcv import Config, DictAction from mmcv import Config, DictAction
...@@ -21,18 +21,44 @@ def test_construct(): ...@@ -21,18 +21,44 @@ def test_construct():
Config([0, 1]) Config([0, 1])
cfg_dict = dict(item1=[1, 2], item2=dict(a=0), item3=True, item4='test') cfg_dict = dict(item1=[1, 2], item2=dict(a=0), item3=True, item4='test')
format_text = json.dumps(cfg_dict, indent=2) # test a.py
for filename in ['a.py', 'b.json', 'c.yaml']: cfg_file = osp.join(osp.dirname(__file__), 'data/config/a.py')
cfg_file = osp.join(osp.dirname(__file__), 'data/config', filename) cfg = Config(cfg_dict, filename=cfg_file)
cfg = Config(cfg_dict, filename=cfg_file) assert isinstance(cfg, Config)
assert isinstance(cfg, Config) assert cfg.filename == cfg_file
assert cfg.filename == cfg_file assert cfg.text == open(cfg_file, 'r').read()
assert cfg.text == open(cfg_file, 'r').read() assert cfg.dump() == cfg.pretty_text
if sys.version_info >= (3, 6): with tempfile.TemporaryDirectory() as temp_config_dir:
assert cfg.dump() == format_text dump_file = osp.join(temp_config_dir, 'a.py')
else: cfg.dump(dump_file)
loaded = json.loads(cfg.dump()) assert cfg.dump() == open(dump_file, 'r').read()
assert set(loaded.keys()) == set(cfg_dict) assert Config.fromfile(dump_file)
# test b.json
cfg_file = osp.join(osp.dirname(__file__), 'data/config/b.json')
cfg = Config(cfg_dict, filename=cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
assert cfg.text == open(cfg_file, 'r').read()
assert cfg.dump() == json.dumps(cfg_dict)
with tempfile.TemporaryDirectory() as temp_config_dir:
dump_file = osp.join(temp_config_dir, 'b.json')
cfg.dump(dump_file)
assert cfg.dump() == open(dump_file, 'r').read()
assert Config.fromfile(dump_file)
# test c.yaml
cfg_file = osp.join(osp.dirname(__file__), 'data/config/c.yaml')
cfg = Config(cfg_dict, filename=cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
assert cfg.text == open(cfg_file, 'r').read()
assert cfg.dump() == yaml.dump(cfg_dict)
with tempfile.TemporaryDirectory() as temp_config_dir:
dump_file = osp.join(temp_config_dir, 'c.yaml')
cfg.dump(dump_file)
assert cfg.dump() == open(dump_file, 'r').read()
assert Config.fromfile(dump_file)
def test_fromfile(): def test_fromfile():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment