config.py 15.6 KB
Newer Older
Kai Chen's avatar
Kai Chen committed
1
# Copyright (c) Open-MMLab. All rights reserved.
2
import ast
Kai Chen's avatar
Kai Chen committed
3
import os.path as osp
4
import platform
5
import re
lizz's avatar
lizz committed
6
import shutil
Kai Chen's avatar
Kai Chen committed
7
import sys
lizz's avatar
lizz committed
8
import tempfile
9
from argparse import Action, ArgumentParser
10
from collections import abc
Kai Chen's avatar
Kai Chen committed
11
12
13
from importlib import import_module

from addict import Dict
14
from yapf.yapflib.yapf_api import FormatCode
Kai Chen's avatar
Kai Chen committed
15

16
17
from .path import check_file_exist

Jerry Jiarui XU's avatar
Jerry Jiarui XU committed
18
19
# Key whose value names the base config file(s) a config inherits from; it is
# popped from the loaded dict and the bases are merged underneath.
BASE_KEY = '_base_'
# When a child dict carries this key with a truthy value, it replaces the
# base config's dict instead of being merged into it.
DELETE_KEY = '_delete_'
# Top-level names rejected by ``Config.__init__`` because ``Config`` exposes
# them as its own attributes.
RESERVED_KEYS = ['filename', 'text', 'pretty_text']
Jerry Jiarui XU's avatar
Jerry Jiarui XU committed
21

22
23
24
25
26
27
28
29
30
31

class ConfigDict(Dict):
    """Attribute-accessible dict that does not tolerate missing keys.

    Item lookup of an absent key raises ``KeyError`` (via ``__missing__``),
    and attribute lookup of an absent key raises ``AttributeError`` so that
    ``getattr``/``hasattr`` behave as they do for ordinary objects.
    """

    def __missing__(self, name):
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            return super(ConfigDict, self).__getattr__(name)
        except KeyError:
            message = (f"'{self.__class__.__name__}' object has no "
                       f"attribute '{name}'")
        # Raised outside the ``except`` clause so the KeyError is not
        # attached as context.
        raise AttributeError(message)

Kai Chen's avatar
Kai Chen committed
40
41
42
43
44
45
46
47
48
49
50
51

def add_args(parser, cfg, prefix=''):
    """Register a command-line option for every leaf value of ``cfg``.

    Option names are ``--<prefix><key>``. Nested dicts are recursed into with
    the prefix extended by ``<key>.``, so ``cfg['a']['b']`` becomes ``--a.b``.
    The kind of option is derived from the current value:

    * ``str`` -> plain string option
    * ``bool`` -> ``store_true`` flag
    * ``int`` / ``float`` -> option converted with that type
    * other iterables -> ``nargs='+'`` option typed by the first element
      (untyped when the iterable is empty or starts with ``None``)

    Values of any other type are reported on stdout and skipped.

    Args:
        parser (ArgumentParser): Parser to add the options to.
        cfg (dict): Mapping whose items drive the generated options.
        prefix (str): Dotted prefix prepended to every option name.

    Returns:
        ArgumentParser: The same ``parser``, for chaining.
    """
    for k, v in cfg.items():
        option = '--' + prefix + k
        if isinstance(v, str):
            parser.add_argument(option)
        # ``bool`` must be checked before ``int``: ``bool`` subclasses ``int``,
        # so testing ``int`` first made the flag branch unreachable.
        elif isinstance(v, bool):
            parser.add_argument(option, action='store_true')
        elif isinstance(v, int):
            parser.add_argument(option, type=int)
        elif isinstance(v, float):
            parser.add_argument(option, type=float)
        elif isinstance(v, dict):
            add_args(parser, v, prefix + k + '.')
        elif isinstance(v, abc.Iterable):
            # ``next(iter(...))`` works for any iterable (sets, tuples, ...)
            # and does not crash on an empty one, unlike ``v[0]``.
            first = next(iter(v), None)
            if first is None:
                parser.add_argument(option, nargs='+')
            else:
                parser.add_argument(option, type=type(first), nargs='+')
        else:
            print(f'cannot parse key {prefix + k} of type {type(v)}')
    return parser


lizz's avatar
lizz committed
60
class Config:
    """A facility for config and config files.

    It supports common file formats as configs: python/json/yaml. The interface
    is the same as a dict object and also allows access config values as
    attributes.

    Example:
        >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]
        >>> cfg = Config.fromfile('tests/data/config/a.py')
        >>> cfg.filename
        'tests/data/config/a.py'
        >>> cfg.item4
        'test'
        >>> cfg
        "Config (path: tests/data/config/a.py): "
        "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
    """

    @staticmethod
    def _validate_py_syntax(filename):
        """Raise ``SyntaxError`` if ``filename`` is not valid Python source."""
        # Parse the raw bytes so Python applies its own source decoding rules
        # (UTF-8 default, BOM, PEP 263 coding cookie), exactly as the import
        # of the config will, rather than the platform's locale encoding.
        with open(filename, 'rb') as f:
            content = f.read()
        try:
            ast.parse(content)
        except SyntaxError as e:
            raise SyntaxError('There are syntax errors in config '
                              f'file {filename}: {e}') from e

    @staticmethod
    def _substitute_predefined_vars(filename, temp_config_name):
        """Copy ``filename`` to ``temp_config_name``, expanding templates.

        Occurrences of ``{{ fileDirname }}``, ``{{ fileBasename }}``,
        ``{{ fileBasenameNoExtension }}`` and ``{{ fileExtname }}`` (inner
        whitespace optional) are replaced with the corresponding part of
        ``filename``.
        """
        file_dirname = osp.dirname(filename)
        file_basename = osp.basename(filename)
        file_basename_no_extension = osp.splitext(file_basename)[0]
        file_extname = osp.splitext(filename)[1]
        support_templates = dict(
            fileDirname=file_dirname,
            fileBasename=file_basename,
            fileBasenameNoExtension=file_basename_no_extension,
            fileExtname=file_extname)
        # 'surrogateescape' round-trips any byte sequence unchanged, so the
        # temp copy is byte-identical to the source apart from substitutions,
        # whatever encoding the config file uses.
        with open(filename, 'r', encoding='utf-8',
                  errors='surrogateescape') as f:
            config_file = f.read()
        for key, value in support_templates.items():
            regexp = r'\{\{\s*' + str(key) + r'\s*\}\}'
            if platform.system() == 'Windows':
                # Backslashes would become escape sequences inside the
                # Python/YAML string literals the value is substituted into.
                value = value.replace('\\', '/')
            # A callable replacement inserts ``value`` literally; a plain
            # string would have its backslashes parsed as regex escapes.
            config_file = re.sub(regexp, lambda _m, v=value: v, config_file)
        with open(temp_config_name, 'w', encoding='utf-8',
                  errors='surrogateescape') as tmp_config_file:
            tmp_config_file.write(config_file)

    @staticmethod
    def _file2dict(filename, use_predefined_variables=True):
        """Load a config file (and its ``_base_`` files) into a plain dict.

        Args:
            filename (str): Path to a ``.py``, ``.json``, ``.yaml`` or
                ``.yml`` config file.
            use_predefined_variables (bool): Whether to expand the
                ``{{ fileDirname }}``-style templates before loading.

        Returns:
            tuple[dict, str]: The merged config dict and the concatenated
            source text of every file involved (bases first).

        Raises:
            IOError: If the file extension is not supported.
            KeyError: If two base files define the same top-level key.
        """
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        fileExtname = osp.splitext(filename)[1]
        if fileExtname not in ['.py', '.json', '.yaml', '.yml']:
            raise IOError('Only py/yml/yaml/json type are supported now!')

        with tempfile.TemporaryDirectory() as temp_config_dir:
            temp_config_file = tempfile.NamedTemporaryFile(
                dir=temp_config_dir, suffix=fileExtname)
            # Windows cannot reopen a file by name while it is still open.
            if platform.system() == 'Windows':
                temp_config_file.close()
            try:
                temp_config_name = osp.basename(temp_config_file.name)
                # Substitute predefined variables
                if use_predefined_variables:
                    Config._substitute_predefined_vars(filename,
                                                       temp_config_file.name)
                else:
                    shutil.copyfile(filename, temp_config_file.name)

                if fileExtname == '.py':
                    temp_module_name = osp.splitext(temp_config_name)[0]
                    Config._validate_py_syntax(filename)
                    sys.path.insert(0, temp_config_dir)
                    try:
                        mod = import_module(temp_module_name)
                    finally:
                        # Always undo the sys.path change, even if the
                        # config raises while being executed.
                        sys.path.remove(temp_config_dir)
                    cfg_dict = {
                        name: value
                        for name, value in mod.__dict__.items()
                        if not name.startswith('__')
                    }
                    # delete imported module
                    sys.modules.pop(temp_module_name, None)
                else:
                    import mmcv
                    cfg_dict = mmcv.load(temp_config_file.name)
            finally:
                # close temp file (idempotent if already closed on Windows)
                temp_config_file.close()

        cfg_text = filename + '\n'
        with open(filename, 'r') as f:
            cfg_text += f.read()

        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = base_filename if isinstance(
                base_filename, (list, tuple)) else [base_filename]

            cfg_dict_list = []
            cfg_text_list = []
            for base_name in base_filename:
                _cfg_dict, _cfg_text = Config._file2dict(
                    osp.join(cfg_dir, base_name))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                if len(base_cfg_dict.keys() & c.keys()) > 0:
                    raise KeyError('Duplicate key is not allowed among bases')
                base_cfg_dict.update(c)

            # The child's own values take precedence over its bases.
            cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)

            # merge cfg_text
            cfg_text_list.append(cfg_text)
            cfg_text = '\n'.join(cfg_text_list)

        return cfg_dict, cfg_text

    @staticmethod
    def _merge_a_into_b(a, b):
        """Recursively merge dict ``a`` into a copy of dict ``b``.

        Values in ``a`` overwrite those in ``b``. Nested dicts are merged key
        by key unless the dict in ``a`` sets ``DELETE_KEY`` to a truthy
        value, in which case it replaces ``b``'s entry wholesale. The
        ``DELETE_KEY`` marker is stripped from the result either way.
        Neither ``a`` nor ``b`` is modified.

        Raises:
            TypeError: If ``a`` has a dict where ``b`` has a non-dict value
                and ``DELETE_KEY`` is not set.
        """
        # copy first to avoid inplace modification
        b = b.copy()
        for k, v in a.items():
            if isinstance(v, dict):
                # Work on a copy so popping the marker does not mutate the
                # caller's dict.
                v = v.copy()
                delete_base = v.pop(DELETE_KEY, False)
                if k in b and not delete_base:
                    if not isinstance(b[k], dict):
                        raise TypeError(
                            f'{k}={v} in child config cannot inherit from '
                            f'base because {k} is a dict in the child config '
                            f'but is of type {type(b[k])} in base config. '
                            f'You may set `{DELETE_KEY}=True` to ignore the '
                            f'base config')
                    b[k] = Config._merge_a_into_b(v, b[k])
                    continue
            b[k] = v
        return b

    @staticmethod
    def fromfile(filename, use_predefined_variables=True):
        """Build a ``Config`` from a config file (see ``_file2dict``)."""
        cfg_dict, cfg_text = Config._file2dict(filename,
                                               use_predefined_variables)
        return Config(cfg_dict, cfg_text=cfg_text, filename=filename)

    @staticmethod
    def auto_argparser(description=None):
        """Generate argparser from config file automatically (experimental).

        Returns:
            tuple[ArgumentParser, Config]: A parser with a positional
            ``config`` argument plus one option per config value, and the
            config loaded from the path given on the command line.
        """
        partial_parser = ArgumentParser(description=description)
        partial_parser.add_argument('config', help='config file path')
        cfg_file = partial_parser.parse_known_args()[0].config
        cfg = Config.fromfile(cfg_file)
        parser = ArgumentParser(description=description)
        parser.add_argument('config', help='config file path')
        add_args(parser, cfg)
        return parser, cfg

    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError('cfg_dict must be a dict, but '
                            f'got {type(cfg_dict)}')
        for key in cfg_dict:
            if key in RESERVED_KEYS:
                raise KeyError(f'{key} is reserved for config file')

        # Bypass our own __setattr__, which would store these in _cfg_dict.
        super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
        super(Config, self).__setattr__('_filename', filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, 'r') as f:
                text = f.read()
        else:
            text = ''
        super(Config, self).__setattr__('_text', text)

    @property
    def filename(self):
        """str | None: The filename the config was created from, if any."""
        return self._filename

    @property
    def text(self):
        """str: Source text of the config (empty if built from a dict)."""
        return self._text

    @property
    def pretty_text(self):
        """str: The config rendered as formatted Python source."""
        import keyword

        indent = 4

        def _indent(s_, num_spaces):
            # Indent every line except the first by ``num_spaces`` spaces.
            s = s_.split('\n')
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * ' ') + line for line in s]
            s = '\n'.join(s)
            s = first + '\n' + s
            return s

        def _format_key(k):
            # Key as it must appear inside a ``{...}`` mapping literal.
            return repr(k) if isinstance(k, str) else str(k)

        def _format_basic_types(k, v, use_mapping=False):
            # repr() quotes and escapes strings, so values containing quotes,
            # backslashes or newlines still produce valid Python.
            v_str = repr(v) if isinstance(v, str) else str(v)
            if use_mapping:
                attr_str = f'{_format_key(k)}: {v_str}'
            else:
                attr_str = f'{str(k)}={v_str}'
            return _indent(attr_str, indent)

        def _format_list(k, v, use_mapping=False):
            # Lists made only of dicts are expanded one ``dict(...)`` per
            # line; any other list is rendered inline.
            if all(isinstance(_, dict) for _ in v):
                v_str = '[\n'
                v_str += '\n'.join(
                    f'dict({_indent(_format_dict(v_), indent)}),'
                    for v_ in v).rstrip(',')
                if use_mapping:
                    attr_str = f'{_format_key(k)}: {v_str}'
                else:
                    attr_str = f'{str(k)}={v_str}'
                attr_str = _indent(attr_str, indent) + ']'
            else:
                attr_str = _format_basic_types(k, v, use_mapping)
            return attr_str

        def _contain_invalid_identifier(dict_str):
            # Keys that are not identifiers, or are reserved words such as
            # ``class``, cannot be written as ``key=value`` keyword arguments.
            return any(
                not str(key_name).isidentifier()
                or keyword.iskeyword(str(key_name))
                for key_name in dict_str)

        def _format_dict(input_dict, outest_level=False):
            r = ''
            s = []
            use_mapping = _contain_invalid_identifier(input_dict)
            if use_mapping:
                r += '{'
            for idx, (k, v) in enumerate(input_dict.items()):
                is_last = idx >= len(input_dict) - 1
                # Top-level entries are separate assignments: no commas.
                end = '' if outest_level or is_last else ','
                if isinstance(v, dict):
                    v_str = '\n' + _format_dict(v)
                    if use_mapping:
                        attr_str = f'{_format_key(k)}: dict({v_str}'
                    else:
                        attr_str = f'{str(k)}=dict({v_str}'
                    attr_str = _indent(attr_str, indent) + ')' + end
                elif isinstance(v, list):
                    attr_str = _format_list(k, v, use_mapping) + end
                else:
                    attr_str = _format_basic_types(k, v, use_mapping) + end
                s.append(attr_str)
            r += '\n'.join(s)
            if use_mapping:
                r += '}'
            return r

        cfg_dict = self._cfg_dict.to_dict()
        text = _format_dict(cfg_dict, outest_level=True)
        # copied from setup.cfg
        yapf_style = dict(
            based_on_style='pep8',
            blank_line_before_nested_class_or_def=True,
            split_before_expression_after_opening_paren=True)
        text, _ = FormatCode(text, style_config=yapf_style, verify=True)
        return text

    def __repr__(self):
        return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        # ``_cfg_dict`` is only missing on a half-built instance (e.g. one
        # created through ``__new__`` while copying/unpickling). Fail with
        # AttributeError instead of recursing forever via ``self._cfg_dict``.
        if name == '_cfg_dict':
            raise AttributeError(name)
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)

    # Explicit copy/pickle hooks: without them the lookups of these dunders
    # fall through ``__getattr__`` to the wrapped ConfigDict, so copying a
    # Config yielded a ConfigDict or recursed infinitely.
    def __getstate__(self):
        return (self._cfg_dict, self._filename, self._text)

    def __setstate__(self, state):
        _cfg_dict, _filename, _text = state
        super(Config, self).__setattr__('_cfg_dict', _cfg_dict)
        super(Config, self).__setattr__('_filename', _filename)
        super(Config, self).__setattr__('_text', _text)

    def __copy__(self):
        cls = self.__class__
        other = cls.__new__(cls)
        other.__dict__.update(self.__dict__)
        return other

    def __deepcopy__(self, memo):
        import copy
        cls = self.__class__
        other = cls.__new__(cls)
        memo[id(self)] = other
        for key, value in self.__dict__.items():
            super(Config, other).__setattr__(key, copy.deepcopy(value, memo))
        return other

    def dump(self, file=None):
        """Serialize the config to a string or a file.

        Args:
            file (str, optional): Output path. A ``.py`` target gets the
                ``pretty_text`` rendering, a ``.json``/``.yaml``/``.yml``
                target is written with ``mmcv.dump``. Any other extension
                follows the format of the source ``filename``.

        Returns:
            str | None: The serialized config when ``file`` is None. Configs
            without a filename (built from a dict) or loaded from ``.py``
            are rendered as Python; others use their source file format.
        """
        cfg_dict = super(Config, self).__getattribute__('_cfg_dict').to_dict()
        source_is_python = (
            self.filename is None or self.filename.endswith('.py'))
        if file is None:
            if source_is_python:
                return self.pretty_text
            import mmcv
            file_format = self.filename.split('.')[-1]
            return mmcv.dump(cfg_dict, file_format=file_format)

        if file.endswith('.py'):
            write_python = True
        elif file.endswith(('.json', '.yaml', '.yml')):
            write_python = False
        else:
            write_python = source_is_python
        if write_python:
            with open(file, 'w') as f:
                f.write(self.pretty_text)
        else:
            import mmcv
            mmcv.dump(cfg_dict, file)

    def merge_from_dict(self, options):
        """Merge a dict of dotted keys into this config.

        Merge the dict parsed by ``DictAction`` into this cfg. Each dotted key
        addresses a nested entry; intermediate dicts are created as needed.

        Examples:
            >>> options = {'model.backbone.depth': 50,
            ...            'model.backbone.with_cp':True}
            >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
            >>> cfg.merge_from_dict(options)
            >>> cfg_dict = super(Config, cfg).__getattribute__('_cfg_dict')
            >>> assert cfg_dict == dict(model=dict(
            ...     backbone=dict(type='ResNet', depth=50, with_cp=True)))

        Args:
            options (dict): dict of configs to merge from.
        """
        option_cfg_dict = {}
        for full_key, v in options.items():
            d = option_cfg_dict
            key_list = full_key.split('.')
            for subkey in key_list[:-1]:
                d.setdefault(subkey, ConfigDict())
                d = d[subkey]
            subkey = key_list[-1]
            d[subkey] = v

        cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
        super(Config, self).__setattr__(
            '_cfg_dict', Config._merge_a_into_b(option_cfg_dict, cfg_dict))
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447


class DictAction(Action):
    """Argparse action that collects ``KEY=VALUE`` pairs into a dict.

    Each argument is split on the first ``=``. Values are converted to int,
    float or bool when possible. Comma separated values become a list,
    e.g. ``KEY=V1,V2,V3``. The resulting dict replaces the destination
    attribute on the namespace.
    """

    @staticmethod
    def _parse_int_float_bool(val):
        """Convert ``val`` to int, float or bool if possible, else return it.

        Booleans are matched case-insensitively against ``true``/``false``.
        """
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        if val.lower() in ['true', 'false']:
            return val.lower() == 'true'
        return val

    def __call__(self, parser, namespace, values, option_string=None):
        from argparse import ArgumentError

        options = {}
        for kv in values:
            key, sep, val = kv.partition('=')
            if not sep:
                # ArgumentError is turned by argparse into a usage message
                # instead of an opaque tuple-unpacking ValueError.
                raise ArgumentError(
                    self, f"expected KEY=VALUE, got '{kv}'")
            parsed = [self._parse_int_float_bool(v) for v in val.split(',')]
            options[key] = parsed[0] if len(parsed) == 1 else parsed
        setattr(namespace, self.dest, options)