config.py 11.9 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
# Copyright (c) Open-MMLab. All rights reserved.
Jerry Jiarui XU's avatar
Jerry Jiarui XU committed
2
import json
Kai Chen's avatar
Kai Chen committed
3
import os.path as osp
lizz's avatar
lizz committed
4
import shutil
Kai Chen's avatar
Kai Chen committed
5
import sys
lizz's avatar
lizz committed
6
import tempfile
7
from argparse import Action, ArgumentParser
8
from collections import abc
Kai Chen's avatar
Kai Chen committed
9
10
11
12
from importlib import import_module

from addict import Dict

13
14
from .path import check_file_exist

Jerry Jiarui XU's avatar
Jerry Jiarui XU committed
15
16
17
# BASE_KEY: top-level entry of a config file naming the base file(s) it
# inherits from (a single path or a list of paths, relative to the file).
# DELETE_KEY: entry inside a child dict which, when truthy, makes that dict
# replace the base dict of the same name instead of being merged into it.
BASE_KEY, DELETE_KEY = '_base_', '_delete_'

18
19
20
21
22
23
24
25
26
27

class ConfigDict(Dict):
    """``Dict`` subclass whose missing keys are errors.

    ``__missing__`` is overridden to raise ``KeyError``, and attribute access
    turns that ``KeyError`` into ``AttributeError``. As a result,
    ``hasattr(obj, name)`` and ``getattr(obj, name, default)`` work as they
    do for ordinary objects.
    """

    def __missing__(self, name):
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            value = super(ConfigDict, self).__getattr__(name)
        except KeyError:
            # Attribute access must signal absence with AttributeError;
            # ``from None`` keeps the internal KeyError out of the traceback.
            raise AttributeError(f"'{self.__class__.__name__}' object has no "
                                 f"attribute '{name}'") from None
        return value

Kai Chen's avatar
Kai Chen committed
36
37
38
39
40
41
42
43
44
45
46
47

def add_args(parser, cfg, prefix=''):
    """Add one ``--<prefix><key>`` option to ``parser`` for each entry of ``cfg``.

    How each option is parsed depends on the type of the existing value:

    - strings are taken as-is;
    - ints and floats are converted to that type;
    - bools become ``store_true`` flags;
    - nested dicts recurse with a dotted prefix (``--a.b``);
    - other iterables accept one or more values (``nargs='+'``), converted
      to the type of their first element, or kept as strings when empty.

    Values of any other type are reported with ``print`` and skipped.

    Args:
        parser (ArgumentParser): Parser to add the options to.
        cfg (dict): Mapping whose keys become option names.
        prefix (str): Prepended to every option name; used for recursion.

    Returns:
        ArgumentParser: The same ``parser``.
    """
    for k, v in cfg.items():
        name = '--' + prefix + k
        if isinstance(v, str):
            parser.add_argument(name)
        # bool must be tested before int: bool is a subclass of int, so an
        # int check placed first would swallow every bool value.
        elif isinstance(v, bool):
            parser.add_argument(name, action='store_true')
        elif isinstance(v, int):
            parser.add_argument(name, type=int)
        elif isinstance(v, float):
            parser.add_argument(name, type=float)
        elif isinstance(v, dict):
            add_args(parser, v, prefix + k + '.')
        elif isinstance(v, abc.Iterable):
            if v:
                parser.add_argument(name, type=type(v[0]), nargs='+')
            else:
                # An empty sequence gives no element type to convert to.
                parser.add_argument(name, nargs='+')
        else:
            print(f'cannot parse key {prefix + k} of type {type(v)}')
    return parser


class Config(object):
    """A facility for config and config files.

    It supports common file formats as configs: python/json/yaml. The interface
    is the same as a dict object and also allows access config values as
    attributes.

    Example:
        >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]
        >>> cfg = Config.fromfile('tests/data/config/a.py')
        >>> cfg.filename
        "/home/kchen/projects/mmcv/tests/data/config/a.py"
        >>> cfg.item4
        'test'
        >>> cfg
        "Config (path: /home/kchen/projects/mmcv/tests/data/config/a.py): "
        "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"

    """

    @staticmethod
    def _file2dict(filename):
        """Load a config file into a plain dict plus its source text.

        ``.py`` files are imported as a module from a temporary copy, and
        every module-level name not starting with ``__`` becomes an entry.
        ``.yml``/``.yaml``/``.json`` files are read with ``mmcv.load``. If the
        result contains ``_base_``, the referenced file(s) are loaded
        recursively (paths relative to this file's directory), and this
        file's values are merged on top of them.

        Args:
            filename (str): Path to the config file; ``~`` is expanded.

        Returns:
            tuple[dict, str]: The merged config dict, and the text of every
            file involved joined by newlines (base files first, each prefixed
            by its path).

        Raises:
            IOError: If the file extension is not supported.
            KeyError: If two base files define the same top-level key.
        """
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        if filename.endswith('.py'):
            with tempfile.TemporaryDirectory() as temp_config_dir:
                # Import a copy under a random module name so repeated loads
                # never collide in sys.modules.
                with tempfile.NamedTemporaryFile(
                        dir=temp_config_dir, suffix='.py') as temp_config_file:
                    temp_config_name = osp.basename(temp_config_file.name)
                    shutil.copyfile(filename,
                                    osp.join(temp_config_dir, temp_config_name))
                    temp_module_name = osp.splitext(temp_config_name)[0]
                    sys.path.insert(0, temp_config_dir)
                    try:
                        mod = import_module(temp_module_name)
                    finally:
                        # Undo the global import-state changes even when the
                        # config itself fails to import.
                        sys.path.pop(0)
                        sys.modules.pop(temp_module_name, None)
                    cfg_dict = {
                        name: value
                        for name, value in mod.__dict__.items()
                        if not name.startswith('__')
                    }
        elif filename.endswith(('.yml', '.yaml', '.json')):
            import mmcv
            cfg_dict = mmcv.load(filename)
        else:
            raise IOError('Only py/yml/yaml/json type are supported now!')

        cfg_text = filename + '\n'
        with open(filename, 'r', encoding='utf-8') as f:
            cfg_text += f.read()

        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = base_filename if isinstance(
                base_filename, list) else [base_filename]

            cfg_dict_list = list()
            cfg_text_list = list()
            for base_file in base_filename:
                _cfg_dict, _cfg_text = Config._file2dict(
                    osp.join(cfg_dir, base_file))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                if len(base_cfg_dict.keys() & c.keys()) > 0:
                    raise KeyError('Duplicate key is not allowed among bases')
                base_cfg_dict.update(c)

            # Values from this file override those inherited from the bases.
            cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)

            # merge cfg_text
            cfg_text_list.append(cfg_text)
            cfg_text = '\n'.join(cfg_text_list)

        return cfg_dict, cfg_text

    @staticmethod
    def _merge_a_into_b(a, b):
        """Recursively merge dict ``a`` into a copy of dict ``b``.

        Values in ``a`` override those in ``b``. A dict in ``a`` whose key
        also holds a dict in ``b`` is merged key by key, unless it contains a
        truthy ``_delete_`` entry, in which case it replaces ``b``'s value
        outright. The ``_delete_`` entry never appears in the result, and
        neither ``a`` nor ``b`` is modified.

        Args:
            a (dict): Overriding (child) values.
            b (dict): Base values.

        Returns:
            dict: The merged copy of ``b``.

        Raises:
            TypeError: If ``a`` holds a dict where ``b`` holds a non-dict
                value and ``_delete_`` is not set.
        """
        # copy first to avoid inplace modification
        b = b.copy()
        for k, v in a.items():
            if not isinstance(v, dict):
                b[k] = v
                continue
            # Shallow-copy before popping so the caller's dict is untouched.
            v = v.copy()
            replace_base = v.pop(DELETE_KEY, False)
            if k in b and not replace_base:
                if not isinstance(b[k], dict):
                    raise TypeError(
                        f'{k}={v} in child config cannot inherit from base '
                        f'because {k} is a dict in the child config but is of '
                        f'type {type(b[k])} in base config. You may set '
                        f'`{DELETE_KEY}=True` to ignore the base config')
                b[k] = Config._merge_a_into_b(v, b[k])
            else:
                # Nothing to inherit: rebuild the dict (keeping its type) so
                # that nested ``_delete_`` entries are stripped as well.
                b[k] = Config._merge_a_into_b(v, type(v)())
        return b

    @staticmethod
    def fromfile(filename):
        """Build a ``Config`` from a py/json/yaml file, resolving ``_base_``."""
        cfg_dict, cfg_text = Config._file2dict(filename)
        return Config(cfg_dict, cfg_text=cfg_text, filename=filename)

    @staticmethod
    def auto_argparser(description=None):
        """Generate argparser from config file automatically (experimental)

        Returns:
            tuple[ArgumentParser, Config]: A parser with one option per config
            entry (see ``add_args``), and the loaded config.
        """
        # First pass: only extract the positional config path; every other
        # argument is ignored here.
        partial_parser = ArgumentParser(description=description)
        partial_parser.add_argument('config', help='config file path')
        cfg_file = partial_parser.parse_known_args()[0].config
        cfg = Config.fromfile(cfg_file)
        parser = ArgumentParser(description=description)
        parser.add_argument('config', help='config file path')
        add_args(parser, cfg)
        return parser, cfg

    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
        """
        Args:
            cfg_dict (dict, optional): Config values; defaults to empty.
            cfg_text (str, optional): Source text of the config. If not
                given, the contents of ``filename`` are used when available.
            filename (str, optional): Path the config came from.

        Raises:
            TypeError: If ``cfg_dict`` is given but is not a dict.
        """
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError('cfg_dict must be a dict, but '
                            f'got {type(cfg_dict)}')

        # Bypass our own __setattr__, which would store into _cfg_dict.
        super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
        super(Config, self).__setattr__('_filename', filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, 'r', encoding='utf-8') as f:
                text = f.read()
        else:
            text = ''
        super(Config, self).__setattr__('_text', text)

    @property
    def filename(self):
        """The path this config was created from, or ``None``."""
        return self._filename

    @property
    def text(self):
        """The raw source text of the config (may be empty)."""
        return self._text

    @property
    def pretty_text(self):
        """The config rendered as Python-like ``key=value`` text.

        Nested dicts are written as ``dict(...)``, and non-empty lists made
        entirely of dicts as a multi-line ``[dict(...), ...]``. Strings use
        ``repr``; all other values use ``str``.
        """
        indent = 4

        def _indent(s_, num_spaces):
            # Indent every line except the first by ``num_spaces`` spaces.
            s = s_.split('\n')
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * ' ') + line for line in s]
            s = '\n'.join(s)
            s = first + '\n' + s
            return s

        def _format_basic_types(k, v):
            if isinstance(v, str):
                # repr() quotes and escapes correctly even when the string
                # contains quotes, backslashes or newlines.
                v_str = repr(v)
            else:
                v_str = str(v)
            attr_str = f'{str(k)}={v_str}'
            attr_str = _indent(attr_str, indent)

            return attr_str

        def _format_list(k, v):
            # Only non-empty lists made entirely of dicts get the multi-line
            # layout; an empty list falls through and renders as ``k=[]``.
            if v and all(isinstance(_, dict) for _ in v):
                v_str = '[\n'
                v_str += '\n'.join(
                    f'dict({_indent(_format_dict(v_), indent)}),'
                    for v_ in v).rstrip(',')
                attr_str = f'{str(k)}={v_str}'
                attr_str = _indent(attr_str, indent) + ']'
            else:
                attr_str = _format_basic_types(k, v)
            return attr_str

        def _format_dict(d, outest_level=False):
            s = []
            for idx, (k, v) in enumerate(d.items()):
                is_last = idx >= len(d) - 1
                # Top-level entries are separate statements, so they take no
                # trailing comma; nested entries are call arguments and do.
                end = '' if outest_level or is_last else ','
                if isinstance(v, dict):
                    v_str = '\n' + _format_dict(v)
                    attr_str = f'{str(k)}=dict({v_str}'
                    attr_str = _indent(attr_str, indent) + ')' + end
                elif isinstance(v, list):
                    attr_str = _format_list(k, v) + end
                else:
                    attr_str = _format_basic_types(k, v) + end

                s.append(attr_str)
            return '\n'.join(s)

        cfg_dict = self._cfg_dict.to_dict()
        text = _format_dict(cfg_dict, outest_level=True)

        return text

    def __repr__(self):
        return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        # Fetch _cfg_dict without re-entering __getattr__. On an instance
        # created without __init__ (e.g. by copy.deepcopy or pickle), the
        # attribute is missing, and ``self._cfg_dict`` would recurse until
        # RecursionError instead of raising AttributeError.
        try:
            cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
        except AttributeError:
            raise AttributeError(name) from None
        return getattr(cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        # Store dicts as ConfigDict so nested values support attribute access.
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)

    def dump(self):
        """Return the config values serialized as JSON (2-space indent)."""
        cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
        format_text = json.dumps(cfg_dict, indent=2)
        return format_text

    def merge_from_dict(self, options):
        """Merge a dict of dotted keys into this config.

        Merge the dict parsed by MultipleKVAction into this cfg. Each key is
        a dot-separated path, such as ``'model.backbone.depth'``. Intermediate
        levels are created as needed, and the given values override existing
        ones.

        Examples:
            >>> options = {'model.backbone.depth': 50,
            ...            'model.backbone.with_cp':True}
            >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
            >>> cfg.merge_from_dict(options)
            >>> assert cfg._cfg_dict == dict(
            ...     model=dict(backbone=dict(
            ...         type='ResNet', depth=50, with_cp=True)))

        Args:
            options (dict): dict of configs to merge from.
        """
        # Expand dotted keys into a nested ConfigDict tree.
        option_cfg_dict = {}
        for full_key, v in options.items():
            d = option_cfg_dict
            key_list = full_key.split('.')
            for subkey in key_list[:-1]:
                d.setdefault(subkey, ConfigDict())
                d = d[subkey]
            subkey = key_list[-1]
            d[subkey] = v

        cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
        super(Config, self).__setattr__(
            '_cfg_dict', Config._merge_a_into_b(option_cfg_dict, cfg_dict))
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360


class DictAction(Action):
    """
    argparse action to split an argument into KEY=VALUE form
    on the first = and append to a dictionary. List options should
    be passed as comma separated values, e.g. KEY=V1,V2,V3
    """

    @staticmethod
    def _parse_int_float_bool(val):
        """Convert ``val`` to int, float or bool when it looks like one.

        Conversions are tried in that order. ``'true'``/``'false'`` (any
        case) become bools, and anything else is returned unchanged.
        """
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        if val.lower() in ('true', 'false'):
            return val.lower() == 'true'
        return val

    def __call__(self, parser, namespace, values, option_string=None):
        # Without nargs argparse passes a single string; treat it as one
        # item rather than iterating over its characters.
        if isinstance(values, str):
            values = [values]
        options = {}
        for kv in values:
            key, sep, val = kv.partition('=')
            if not sep:
                # parser.error reports the problem with usage and exits.
                parser.error(f'{option_string or self.dest}: expected '
                             f"KEY=VALUE but got '{kv}'")
            val = [self._parse_int_float_bool(v) for v in val.split(',')]
            if len(val) == 1:
                val = val[0]
            options[key] = val
        setattr(namespace, self.dest, options)