Unverified Commit 1e8a2121 authored by Jerry Jiarui XU's avatar Jerry Jiarui XU Committed by GitHub
Browse files

Add DictAction and docs for config (#243)

* fixed merge_from_dict, add DictAction

* add config docs

* fixed format type

* change to easy example

* update docs

* update docs
parent 76a064ee
......@@ -10,7 +10,7 @@ Here is an example of the config file `test.py`.
```python
a = 1
b = {'b1': [0, 1, 2], 'b2': None}
b = dict(b1=[0, 1, 2], b2=None)
c = (1, 2)
d = 'string'
```
......@@ -18,11 +18,108 @@ d = 'string'
To load and use configs
```python
cfg = Config.fromfile('test.py')
assert cfg.a == 1
assert cfg.b.b1 == [0, 1, 2]
cfg.c = None
assert cfg.c == None
>>> cfg = Config.fromfile('test.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b1=[0, 1, 2], b2=None),
... c=(1, 2),
... d='string')
```
For all format configs, inheritance is supported. To reuse fields in other config files,
specify `_base_='./config_a.py'` or a list of configs `_base_=['./config_a.py', './config_b.py']`.
Here are 4 examples of config inheritance.
`config_a.py`
```python
a = 1
b = dict(b1=[0, 1, 2], b2=None)
```
#### Inherit from base config without overlaped keys.
`config_b.py`
```python
_base_ = './config_a.py'
c = (1, 2)
d = 'string'
```
```python
>>> cfg = Config.fromfile('./config_b.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b1=[0, 1, 2], b2=None),
... c=(1, 2),
... d='string')
```
New fields in `config_b.py` are combined with old fields in `config_a.py`
#### Inherit from base config with overlaped keys.
`config_c.py`
```python
_base_ = './config_a.py'
b = dict(b2=1)
c = (1, 2)
```
```python
>>> cfg = Config.fromfile('./config_c.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b1=[0, 1, 2], b2=1),
... c=(1, 2))
```
`b.b2=None` in `config_a` is replaced with `b.b2=1` in `config_c.py`.
#### Inherit from base config with ignored fields.
`config_d.py`
```python
_base_ = './config_a.py'
b = dict(_delete_=True, b2=None, b3=0.1)
c = (1, 2)
```
```python
>>> cfg = Config.fromfile('./config_d.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b2=None, b3=0.1),
... c=(1, 2))
```
You may also set `_delete_=True` to ignore some fields in base configs. All old keys `b1, b2, b3` in `b` are replaced with new keys `b2, b3`.
#### Inherit from multiple base configs (the base configs should not contain the same keys).
`config_e.py`
```python
c = (1, 2)
d = 'string'
```
`config_f.py`
```python
_base_ = ['./config_a.py', './config_e.py']
```
```python
>>> cfg = Config.fromfile('./config_f.py')
>>> print(cfg)
>>> dict(a=1,
... b=dict(b1=[0, 1, 2], b2=None),
... c=(1, 2),
... d='string')
```
### ProgressBar
......@@ -72,8 +169,6 @@ for i, task in enumerate(mmcv.track_iter_progress(tasks)):
print(task)
```
### Timer
It is convinient to compute the runtime of a code block with `Timer`.
......
# Copyright (c) Open-MMLab. All rights reserved.
from .config import Config, ConfigDict
from .config import Config, ConfigDict, DictAction
from .logging import get_logger, print_log
from .misc import (check_prerequisites, concat_list, is_list_of, is_seq_of,
is_str, is_tuple_of, iter_cast, list_cast,
......@@ -13,11 +13,11 @@ from .registry import Registry, build_from_cfg
from .timer import Timer, TimerError, check_time
__all__ = [
'Config', 'ConfigDict', 'get_logger', 'print_log', 'is_str', 'iter_cast',
'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of', 'is_tuple_of',
'slice_list', 'concat_list', 'check_prerequisites', 'requires_package',
'requires_executable', 'is_filepath', 'fopen', 'check_file_exist',
'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar', 'track_progress',
'track_iter_progress', 'track_parallel_progress', 'Registry',
'build_from_cfg', 'Timer', 'TimerError', 'check_time'
'Config', 'ConfigDict', 'DictAction', 'get_logger', 'print_log', 'is_str',
'iter_cast', 'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of',
'is_tuple_of', 'slice_list', 'concat_list', 'check_prerequisites',
'requires_package', 'requires_executable', 'is_filepath', 'fopen',
'check_file_exist', 'mkdir_or_exist', 'symlink', 'scandir', 'ProgressBar',
'track_progress', 'track_iter_progress', 'track_parallel_progress',
'Registry', 'build_from_cfg', 'Timer', 'TimerError', 'check_time'
]
......@@ -4,7 +4,7 @@ import os.path as osp
import shutil
import sys
import tempfile
from argparse import ArgumentParser
from argparse import Action, ArgumentParser
from collections import abc
from importlib import import_module
......@@ -107,9 +107,9 @@ class Config(object):
with open(filename, 'r') as f:
cfg_text += f.read()
if '_base_' in cfg_dict:
if BASE_KEY in cfg_dict:
cfg_dir = osp.dirname(filename)
base_filename = cfg_dict.pop('_base_')
base_filename = cfg_dict.pop(BASE_KEY)
base_filename = base_filename if isinstance(
base_filename, list) else [base_filename]
......@@ -137,12 +137,14 @@ class Config(object):
@staticmethod
def _merge_a_into_b(a, b):
# merge dict a into dict b. values in a will overwrite b.
# merge dict `a` into dict `b`. values in `a` will overwrite `b`.
for k, v in a.items():
if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
if not isinstance(b[k], dict):
raise TypeError(
'Cannot inherit key {} from base!'.format(k))
'{}={} cannot be inherited from base because {} is a '
'dict in the child config. You may set `{}=True` to '
'ignore the base config'.format(k, v, k, DELETE_KEY))
Config._merge_a_into_b(v, b[k])
else:
b[k] = v
......@@ -289,9 +291,13 @@ class Config(object):
Merge the dict parsed by MultipleKVAction into this cfg.
Example,
>>> options = {'model.backbone.depth': 50}
>>> options = {'model.backbone.depth': 50,
... 'model.backbone.with_cp':True}
>>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
>>> cfg.merge_from_dict(options)
>>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
>>> assert cfg_dict == dict(
... model=dict(backbone=dict(depth=50, with_cp=True)))
Args:
options (dict): dict of configs to merge from.
......@@ -301,10 +307,42 @@ class Config(object):
d = option_cfg_dict
key_list = full_key.split('.')
for subkey in key_list[:-1]:
d[subkey] = ConfigDict()
d.setdefault(subkey, ConfigDict())
d = d[subkey]
subkey = key_list[-1]
d[subkey] = v
cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
Config._merge_a_into_b(option_cfg_dict, cfg_dict)
class DictAction(Action):
"""
argparse action to split an argument into KEY=VALUE form
on the first = and append to a dictionary. List options should
be passed as comma separated values, i.e KEY=V1,V2,V3
"""
@staticmethod
def _parse_int_float_bool(val):
try:
return int(val)
except ValueError:
pass
try:
return float(val)
except ValueError:
pass
if val.lower() in ['true', 'false']:
return True if val.lower() == 'true' else False
return val
def __call__(self, parser, namespace, values, option_string=None):
options = {}
for kv in values:
key, val = kv.split('=', maxsplit=1)
val = [self._parse_int_float_bool(v) for v in val.split(',')]
if len(val) == 1:
val = val[0]
options[key] = val
setattr(namespace, self.dest, options)
# Copyright (c) Open-MMLab. All rights reserved.
import argparse
import json
import os.path as osp
import sys
......@@ -6,7 +7,7 @@ import tempfile
import pytest
from mmcv import Config
from mmcv import Config, DictAction
def test_construct():
......@@ -102,9 +103,9 @@ def test_merge_recursive_bases():
def test_merge_from_dict():
cfg_file = osp.join(osp.dirname(__file__), 'data/config/a.py')
cfg = Config.fromfile(cfg_file)
input_options = {'item2.a': 1, 'item3': False}
input_options = {'item2.a': 1, 'item2.b': 0.1, 'item3': False}
cfg.merge_from_dict(input_options)
assert cfg.item2 == dict(a=1)
assert cfg.item2 == dict(a=1, b=0.1)
assert cfg.item3 is False
......@@ -186,3 +187,19 @@ def test_pretty_text():
f.write(cfg.pretty_text)
text_cfg = Config.fromfile(text_cfg_filename)
assert text_cfg._cfg_dict == cfg._cfg_dict
def test_dict_action():
parser = argparse.ArgumentParser(description='Train a detector')
parser.add_argument(
'--options', nargs='+', action=DictAction, help='custom options')
args = parser.parse_args(
['--options', 'item2.a=1', 'item2.b=0.1', 'item2.c=x', 'item3=false'])
out_dict = {'item2.a': 1, 'item2.b': 0.1, 'item2.c': 'x', 'item3': False}
assert args.options == out_dict
cfg_file = osp.join(osp.dirname(__file__), 'data/config/a.py')
cfg = Config.fromfile(cfg_file)
cfg.merge_from_dict(args.options)
assert cfg.item2 == dict(a=1, b=0.1, c='x')
assert cfg.item3 is False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment