Unverified Commit 1e8a2121 authored by Jerry Jiarui XU's avatar Jerry Jiarui XU Committed by GitHub
Browse files

Add DictAction and docs for config (#243)

* fixed merge_from_dict, add DictAction

* add config docs

* fixed format type

* change to easy example

* update docs

* update docs
parent 76a064ee
...@@ -10,7 +10,7 @@ Here is an example of the config file `test.py`. ...@@ -10,7 +10,7 @@ Here is an example of the config file `test.py`.
```python ```python
a = 1 a = 1
b = {'b1': [0, 1, 2], 'b2': None} b = dict(b1=[0, 1, 2], b2=None)
c = (1, 2) c = (1, 2)
d = 'string' d = 'string'
``` ```
...@@ -18,11 +18,108 @@ d = 'string' ...@@ -18,11 +18,108 @@ d = 'string'
To load and use configs To load and use configs
```python ```python
cfg = Config.fromfile('test.py') >>> cfg = Config.fromfile('test.py')
assert cfg.a == 1 >>> print(cfg)
assert cfg.b.b1 == [0, 1, 2] >>> dict(a=1,
cfg.c = None ... b=dict(b1=[0, 1, 2], b2=None),
assert cfg.c == None ... c=(1, 2),
... d='string')
```
For all format configs, inheritance is supported. To reuse fields in other config files,
specify `_base_='./config_a.py'` or a list of configs `_base_=['./config_a.py', './config_b.py']`.
Here are 4 examples of config inheritance.
`config_a.py`
```python
a = 1
b = dict(b1=[0, 1, 2], b2=None)
```
#### Inherit from base config without overlaped keys.
`config_b.py`
```python
_base_ = './config_a.py'
c = (1, 2)
d = 'string'
```
```python
>>> cfg = Config.fromfile('./config_b.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b1=[0, 1, 2], b2=None),
... c=(1, 2),
... d='string')
```
New fields in `config_b.py` are combined with old fields in `config_a.py`
#### Inherit from base config with overlaped keys.
`config_c.py`
```python
_base_ = './config_a.py'
b = dict(b2=1)
c = (1, 2)
```
```python
>>> cfg = Config.fromfile('./config_c.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b1=[0, 1, 2], b2=1),
... c=(1, 2))
```
`b.b2=None` in `config_a` is replaced with `b.b2=1` in `config_c.py`.
#### Inherit from base config with ignored fields.
`config_d.py`
```python
_base_ = './config_a.py'
b = dict(_delete_=True, b2=None, b3=0.1)
c = (1, 2)
```
```python
>>> cfg = Config.fromfile('./config_d.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b2=None, b3=0.1),
... c=(1, 2))
```
You may also set `_delete_=True` to ignore some fields in base configs. All old keys `b1, b2, b3` in `b` are replaced with new keys `b2, b3`.
#### Inherit from multiple base configs (the base configs should not contain the same keys).
`config_e.py`
```python
c = (1, 2)
d = 'string'
```
`config_f.py`
```python
_base_ = ['./config_a.py', './config_e.py']
```
```python
>>> cfg = Config.fromfile('./config_f.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b1=[0, 1, 2], b2=None),
... c=(1, 2),
... d='string')
``` ```
### ProgressBar ### ProgressBar
...@@ -72,8 +169,6 @@ for i, task in enumerate(mmcv.track_iter_progress(tasks)): ...@@ -72,8 +169,6 @@ for i, task in enumerate(mmcv.track_iter_progress(tasks)):
print(task) print(task)
``` ```
### Timer ### Timer
It is convinient to compute the runtime of a code block with `Timer`. It is convinient to compute the runtime of a code block with `Timer`.
......
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
from .config import Config, ConfigDict from .config import Config, ConfigDict, DictAction
from .logging import get_logger, print_log from .logging import get_logger, print_log
from .misc import (check_prerequisites, concat_list, is_list_of, is_seq_of, from .misc import (check_prerequisites, concat_list, is_list_of, is_seq_of,
is_str, is_tuple_of, iter_cast, list_cast, is_str, is_tuple_of, iter_cast, list_cast,
...@@ -13,11 +13,11 @@ from .registry import Registry, build_from_cfg ...@@ -13,11 +13,11 @@ from .registry import Registry, build_from_cfg
from .timer import Timer, TimerError, check_time from .timer import Timer, TimerError, check_time
__all__ = [ __all__ = [
'Config', 'ConfigDict', 'get_logger', 'print_log', 'is_str', 'iter_cast', 'Config', 'ConfigDict', 'DictAction', 'get_logger', 'print_log', 'is_str',
'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of', 'is_tuple_of', 'iter_cast', 'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of',
'slice_list', 'concat_list', 'check_prerequisites', 'requires_package', 'is_tuple_of', 'slice_list', 'concat_list', 'check_prerequisites',
'requires_executable', 'is_filepath', 'fopen', 'check_file_exist', 'requires_package', 'requires_executable', 'is_filepath', 'fopen',
'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar', 'track_progress', 'check_file_exist', 'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar',
'track_iter_progress', 'track_parallel_progress', 'Registry', 'track_progress', 'track_iter_progress', 'track_parallel_progress',
'build_from_cfg', 'Timer', 'TimerError', 'check_time' 'Registry', 'build_from_cfg', 'Timer', 'TimerError', 'check_time'
] ]
...@@ -4,7 +4,7 @@ import os.path as osp ...@@ -4,7 +4,7 @@ import os.path as osp
import shutil import shutil
import sys import sys
import tempfile import tempfile
from argparse import ArgumentParser from argparse import Action, ArgumentParser
from collections import abc from collections import abc
from importlib import import_module from importlib import import_module
...@@ -107,9 +107,9 @@ class Config(object): ...@@ -107,9 +107,9 @@ class Config(object):
with open(filename, 'r') as f: with open(filename, 'r') as f:
cfg_text += f.read() cfg_text += f.read()
if '_base_' in cfg_dict: if BASE_KEY in cfg_dict:
cfg_dir = osp.dirname(filename) cfg_dir = osp.dirname(filename)
base_filename = cfg_dict.pop('_base_') base_filename = cfg_dict.pop(BASE_KEY)
base_filename = base_filename if isinstance( base_filename = base_filename if isinstance(
base_filename, list) else [base_filename] base_filename, list) else [base_filename]
...@@ -137,12 +137,14 @@ class Config(object): ...@@ -137,12 +137,14 @@ class Config(object):
@staticmethod @staticmethod
def _merge_a_into_b(a, b): def _merge_a_into_b(a, b):
# merge dict a into dict b. values in a will overwrite b. # merge dict `a` into dict `b`. values in `a` will overwrite `b`.
for k, v in a.items(): for k, v in a.items():
if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False): if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
if not isinstance(b[k], dict): if not isinstance(b[k], dict):
raise TypeError( raise TypeError(
'Cannot inherit key {} from base!'.format(k)) '{}={} cannot be inherited from base because {} is a '
'dict in the child config. You may set `{}=True` to '
'ignore the base config'.format(k, v, k, DELETE_KEY))
Config._merge_a_into_b(v, b[k]) Config._merge_a_into_b(v, b[k])
else: else:
b[k] = v b[k] = v
...@@ -289,9 +291,13 @@ class Config(object): ...@@ -289,9 +291,13 @@ class Config(object):
Merge the dict parsed by MultipleKVAction into this cfg. Merge the dict parsed by MultipleKVAction into this cfg.
Example, Example,
>>> options = {'model.backbone.depth': 50} >>> options = {'model.backbone.depth': 50,
... 'model.backbone.with_cp':True}
>>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet')))) >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
>>> cfg.merge_from_dict(options) >>> cfg.merge_from_dict(options)
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
>>> assert cfg_dict == dict(
... model=dict(backbone=dict(depth=50, with_cp=True)))
Args: Args:
options (dict): dict of configs to merge from. options (dict): dict of configs to merge from.
...@@ -301,10 +307,42 @@ class Config(object): ...@@ -301,10 +307,42 @@ class Config(object):
d = option_cfg_dict d = option_cfg_dict
key_list = full_key.split('.') key_list = full_key.split('.')
for subkey in key_list[:-1]: for subkey in key_list[:-1]:
d[subkey] = ConfigDict() d.setdefault(subkey, ConfigDict())
d = d[subkey] d = d[subkey]
subkey = key_list[-1] subkey = key_list[-1]
d[subkey] = v d[subkey] = v
cfg_dict = super(Config, self).__getattribute__('_cfg_dict') cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
Config._merge_a_into_b(option_cfg_dict, cfg_dict) Config._merge_a_into_b(option_cfg_dict, cfg_dict)
class DictAction(Action):
"""
argparse action to split an argument into KEY=VALUE form
on the first = and append to a dictionary. List options should
be passed as comma separated values, i.e KEY=V1,V2,V3
"""
@staticmethod
def _parse_int_float_bool(val):
try:
return int(val)
except ValueError:
pass
try:
return float(val)
except ValueError:
pass
if val.lower() in ['true', 'false']:
return True if val.lower() == 'true' else False
return val
def __call__(self, parser, namespace, values, option_string=None):
options = {}
for kv in values:
key, val = kv.split('=', maxsplit=1)
val = [self._parse_int_float_bool(v) for v in val.split(',')]
if len(val) == 1:
val = val[0]
options[key] = val
setattr(namespace, self.dest, options)
# Copyright (c) Open-MMLab. All rights reserved. # Copyright (c) Open-MMLab. All rights reserved.
import argparse
import json import json
import os.path as osp import os.path as osp
import sys import sys
...@@ -6,7 +7,7 @@ import tempfile ...@@ -6,7 +7,7 @@ import tempfile
import pytest import pytest
from mmcv import Config from mmcv import Config, DictAction
def test_construct(): def test_construct():
...@@ -102,9 +103,9 @@ def test_merge_recursive_bases(): ...@@ -102,9 +103,9 @@ def test_merge_recursive_bases():
def test_merge_from_dict(): def test_merge_from_dict():
cfg_file = osp.join(osp.dirname(__file__), 'data/config/a.py') cfg_file = osp.join(osp.dirname(__file__), 'data/config/a.py')
cfg = Config.fromfile(cfg_file) cfg = Config.fromfile(cfg_file)
input_options = {'item2.a': 1, 'item3': False} input_options = {'item2.a': 1, 'item2.b': 0.1, 'item3': False}
cfg.merge_from_dict(input_options) cfg.merge_from_dict(input_options)
assert cfg.item2 == dict(a=1) assert cfg.item2 == dict(a=1, b=0.1)
assert cfg.item3 is False assert cfg.item3 is False
...@@ -186,3 +187,19 @@ def test_pretty_text(): ...@@ -186,3 +187,19 @@ def test_pretty_text():
f.write(cfg.pretty_text) f.write(cfg.pretty_text)
text_cfg = Config.fromfile(text_cfg_filename) text_cfg = Config.fromfile(text_cfg_filename)
assert text_cfg._cfg_dict == cfg._cfg_dict assert text_cfg._cfg_dict == cfg._cfg_dict
def test_dict_action():
parser = argparse.ArgumentParser(description='Train a detector')
parser.add_argument(
'--options', nargs='+', action=DictAction, help='custom options')
args = parser.parse_args(
['--options', 'item2.a=1', 'item2.b=0.1', 'item2.c=x', 'item3=false'])
out_dict = {'item2.a': 1, 'item2.b': 0.1, 'item2.c': 'x', 'item3': False}
assert args.options == out_dict
cfg_file = osp.join(osp.dirname(__file__), 'data/config/a.py')
cfg = Config.fromfile(cfg_file)
cfg.merge_from_dict(args.options)
assert cfg.item2 == dict(a=1, b=0.1, c='x')
assert cfg.item3 is False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment