Unverified Commit 6738cd30 authored by Jerry Jiarui XU's avatar Jerry Jiarui XU Committed by GitHub
Browse files

add base for config (#194)

* add base for config

* fixed format

* rm terminal width

* support multiple & recursive base

* add test case

* fix format

* add test construct

* minor fix

* add more test, rewrite merge from opt

* avoid depulicate keys

* delete imported config as module

* rename merge_from_dict
parent 7dac84d0
# Copyright (c) Open-MMLab. All rights reserved.
import json
import os.path as osp
import shutil
import sys
......@@ -11,6 +12,9 @@ from addict import Dict
from .misc import collections_abc
from .path import check_file_exist
BASE_KEY = '_base_'
DELETE_KEY = '_delete_'
class ConfigDict(Dict):
......@@ -41,11 +45,11 @@ def add_args(parser, cfg, prefix=''):
elif isinstance(v, bool):
parser.add_argument('--' + prefix + k, action='store_true')
elif isinstance(v, dict):
add_args(parser, v, k + '.')
add_args(parser, v, prefix + k + '.')
elif isinstance(v, collections_abc.Iterable):
parser.add_argument('--' + prefix + k, type=type(v[0]), nargs='+')
else:
print('connot parse key {} of type {}'.format(prefix + k, type(v)))
print('cannot parse key {} of type {}'.format(prefix + k, type(v)))
return parser
......@@ -76,7 +80,7 @@ class Config(object):
"""
@staticmethod
def fromfile(filename):
def _file2dict(filename):
filename = osp.abspath(osp.expanduser(filename))
check_file_exist(filename)
if filename.endswith('.py'):
......@@ -91,12 +95,62 @@ class Config(object):
for name, value in mod.__dict__.items()
if not name.startswith('__')
}
# delete imported module
del sys.modules['_tempconfig']
elif filename.endswith(('.yml', '.yaml', '.json')):
import mmcv
cfg_dict = mmcv.load(filename)
else:
raise IOError('Only py/yml/yaml/json type are supported now!')
return Config(cfg_dict, filename=filename)
cfg_text = filename + '\n'
with open(filename, 'r') as f:
cfg_text += f.read()
if '_base_' in cfg_dict:
cfg_dir = osp.dirname(filename)
base_filename = cfg_dict.pop('_base_')
base_filename = base_filename if isinstance(
base_filename, list) else [base_filename]
cfg_dict_list = list()
cfg_text_list = list()
for f in base_filename:
_cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
cfg_dict_list.append(_cfg_dict)
cfg_text_list.append(_cfg_text)
base_cfg_dict = dict()
for c in cfg_dict_list:
if len(base_cfg_dict.keys() & c.keys()) > 0:
raise KeyError('Duplicate key is not allowed among bases')
base_cfg_dict.update(c)
Config._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = base_cfg_dict
# merge cfg_text
cfg_text_list.append(cfg_text)
cfg_text = '\n'.join(cfg_text_list)
return cfg_dict, cfg_text
@staticmethod
def _merge_a_into_b(a, b):
# merge dict a into dict b. values in a will overwrite b.
for k, v in a.items():
if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
if not isinstance(b[k], dict):
raise TypeError(
'Cannot inherit key {} from base!'.format(k))
Config._merge_a_into_b(v, b[k])
else:
b[k] = v
@staticmethod
def fromfile(filename):
cfg_dict, cfg_text = Config._file2dict(filename)
return Config(cfg_dict, cfg_text=cfg_text, filename=filename)
@staticmethod
def auto_argparser(description=None):
......@@ -111,7 +165,7 @@ class Config(object):
add_args(parser, cfg)
return parser, cfg
def __init__(self, cfg_dict=None, filename=None):
def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
if cfg_dict is None:
cfg_dict = dict()
elif not isinstance(cfg_dict, dict):
......@@ -120,11 +174,14 @@ class Config(object):
super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
super(Config, self).__setattr__('_filename', filename)
if filename:
if cfg_text:
text = cfg_text
elif filename:
with open(filename, 'r') as f:
super(Config, self).__setattr__('_text', f.read())
text = f.read()
else:
super(Config, self).__setattr__('_text', '')
text = ''
super(Config, self).__setattr__('_text', text)
@property
def filename(self):
......@@ -159,3 +216,33 @@ class Config(object):
def __iter__(self):
return iter(self._cfg_dict)
def dump(self):
cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
format_text = json.dumps(cfg_dict, indent=2)
return format_text
def merge_from_dict(self, options):
""" Merge list into cfg_dict
Merge the dict parsed by MultipleKVAction into this cfg.
Example,
>>> options = {'model.backbone.depth': 50}
>>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
>>> cfg.merge_from_dict(options)
Args:
options (dict): dict of configs to merge from.
"""
option_cfg_dict = {}
for full_key, v in options.items():
d = option_cfg_dict
key_list = full_key.split('.')
for subkey in key_list[:-1]:
d[subkey] = ConfigDict()
d = d[subkey]
subkey = key_list[-1]
d[subkey] = v
cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
Config._merge_a_into_b(option_cfg_dict, cfg_dict)
item1 = [1, 2]
item2 = {'a': 0}
item3 = True
item4 = 'test'
_base_ = './base.py'
item1 = [2, 3]
item2 = {'a': 1}
item3 = False
item4 = 'test_base'
_base_ = './base.py'
item2 = {'b': 0, '_delete_': True}
_base_ = './base.py'
item3 = {'a': 1}
_base_ = './d.py'
item4 = 'test_recursive_bases'
_base_ = ['./l1.py', './l2.yaml', './l3.json']
item3 = False
item4 = 'test'
_base_ = ['./l1.py', './l2.yaml', './l3.json', 'a.py']
item3 = False
item4 = 'test'
# Copyright (c) Open-MMLab. All rights reserved.
import json
import os.path as osp
import pytest
......@@ -16,6 +17,16 @@ def test_construct():
with pytest.raises(TypeError):
Config([0, 1])
cfg_dict = dict(item1=[1, 2], item2=dict(a=0), item3=True, item4='test')
format_text = json.dumps(cfg_dict, indent=2)
for filename in ['a.py', 'b.json', 'c.yaml']:
cfg_file = osp.join(osp.dirname(__file__), 'data/config', filename)
cfg = Config(cfg_dict, filename=cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
assert cfg.text == open(cfg_file, 'r').read()
assert cfg.dump() == format_text
def test_fromfile():
for filename in ['a.py', 'a.b.py', 'b.json', 'c.yaml']:
......@@ -23,7 +34,8 @@ def test_fromfile():
cfg = Config.fromfile(cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
assert cfg.text == open(cfg_file, 'r').read()
assert cfg.text == osp.abspath(osp.expanduser(cfg_file)) + '\n' + \
open(cfg_file, 'r').read()
with pytest.raises(FileNotFoundError):
Config.fromfile('no_such_file.py')
......@@ -31,6 +43,73 @@ def test_fromfile():
Config.fromfile(osp.join(osp.dirname(__file__), 'data/color.jpg'))
def test_merge_from_base():
cfg_file = osp.join(osp.dirname(__file__), 'data/config/d.py')
cfg = Config.fromfile(cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
base_cfg_file = osp.join(osp.dirname(__file__), 'data/config/base.py')
merge_text = osp.abspath(osp.expanduser(base_cfg_file)) + '\n' + \
open(base_cfg_file, 'r').read()
merge_text += '\n' + osp.abspath(osp.expanduser(cfg_file)) + '\n' + \
open(cfg_file, 'r').read()
assert cfg.text == merge_text
assert cfg.item1 == [2, 3]
assert cfg.item2.a == 1
assert cfg.item3 is False
assert cfg.item4 == 'test_base'
with pytest.raises(TypeError):
Config.fromfile(osp.join(osp.dirname(__file__), 'data/config/e.py'))
def test_merge_from_multiple_bases():
cfg_file = osp.join(osp.dirname(__file__), 'data/config/l.py')
cfg = Config.fromfile(cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
# cfg.field
assert cfg.item1 == [1, 2]
assert cfg.item2.a == 0
assert cfg.item3 is False
assert cfg.item4 == 'test'
with pytest.raises(KeyError):
Config.fromfile(osp.join(osp.dirname(__file__), 'data/config/m.py'))
def test_merge_recursive_bases():
cfg_file = osp.join(osp.dirname(__file__), 'data/config/f.py')
cfg = Config.fromfile(cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
# cfg.field
assert cfg.item1 == [2, 3]
assert cfg.item2.a == 1
assert cfg.item3 is False
assert cfg.item4 == 'test_recursive_bases'
def test_merge_from_dict():
cfg_file = osp.join(osp.dirname(__file__), 'data/config/a.py')
cfg = Config.fromfile(cfg_file)
input_options = {'item2.a': 1, 'item3': False}
cfg.merge_from_dict(input_options)
assert cfg.item2 == dict(a=1)
assert cfg.item3 is False
def test_merge_delete():
cfg_file = osp.join(osp.dirname(__file__), 'data/config/delete.py')
cfg = Config.fromfile(cfg_file)
# cfg.field
assert cfg.item1 == [1, 2]
assert cfg.item2 == dict(b=0)
assert cfg.item3 is True
assert cfg.item4 == 'test'
assert '_delete_' not in cfg.item2
def test_dict():
cfg_dict = dict(item1=[1, 2], item2=dict(a=0), item3=True, item4='test')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment