Unverified Commit d9effbd1 authored by lizz's avatar lizz Committed by GitHub
Browse files

Support variables in base files for configs (#1083)



* Support variables in base files for configs
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* Test json and yaml as well
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* Add test for recusive base
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* Test misleading values
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* Improve comments
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* Add doc
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* Improve doc
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* More tests
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* Harder test case
Signed-off-by: default avatarlizz <lizz@sensetime.com>

* use BASE_KEY instead of base
Signed-off-by: default avatarlizz <lizz@sensetime.com>
parent eb08835f
......@@ -154,6 +154,32 @@ _base_ = ['./config_a.py', './config_e.py']
... d='string')
```
#### Reference variables from base
You can reference variables defined in base using the following grammar.
`base.py`
```python
item1 = 'a'
item2 = dict(item3 = 'b')
```
`config_g.py`
```python
_base_ = ['./base.py']
item = dict(a = {{ _base_.item1 }}, b = {{ _base_.item2.item3 }})
```
```python
>>> cfg = Config.fromfile('./config_g.py')
>>> print(cfg.pretty_text)
item1 = 'a'
item2 = dict(item3='b')
item = dict(a='a', b='b')
```
### ProgressBar
If you want to apply a method to a list of items and track the progress, `track_progress`
......
# Copyright (c) Open-MMLab. All rights reserved.
import ast
import copy
import os
import os.path as osp
import platform
import shutil
import sys
import tempfile
import uuid
import warnings
from argparse import Action, ArgumentParser
from collections import abc
......@@ -121,6 +123,57 @@ class Config:
with open(temp_config_name, 'w') as tmp_config_file:
tmp_config_file.write(config_file)
@staticmethod
def _pre_substitute_base_vars(filename, temp_config_name):
"""Substitute base variable placehoders to string, so that parsing
would work."""
with open(filename, 'r', encoding='utf-8') as f:
# Setting encoding explicitly to resolve coding issue on windows
config_file = f.read()
base_var_dict = {}
regexp = r'\{\{\s*' + BASE_KEY + r'\.([\w\.]+)\s*\}\}'
base_vars = set(re.findall(regexp, config_file))
for base_var in base_vars:
randstr = f'_{base_var}_{uuid.uuid4().hex.lower()[:6]}'
base_var_dict[randstr] = base_var
regexp = r'\{\{\s*' + BASE_KEY + r'\.' + base_var + r'\s*\}\}'
config_file = re.sub(regexp, f'"{randstr}"', config_file)
with open(temp_config_name, 'w') as tmp_config_file:
tmp_config_file.write(config_file)
return base_var_dict
@staticmethod
def _substitute_base_vars(cfg, base_var_dict, base_cfg):
"""Substitute variable strings to their actual values."""
cfg = copy.deepcopy(cfg)
if isinstance(cfg, dict):
for k, v in cfg.items():
if isinstance(v, str) and v in base_var_dict:
new_v = base_cfg
for new_k in base_var_dict[v].split('.'):
new_v = new_v[new_k]
cfg[k] = new_v
elif isinstance(v, (list, tuple, dict)):
cfg[k] = Config._substitute_base_vars(
v, base_var_dict, base_cfg)
elif isinstance(cfg, tuple):
cfg = tuple(
Config._substitute_base_vars(c, base_var_dict, base_cfg)
for c in cfg)
elif isinstance(cfg, list):
cfg = [
Config._substitute_base_vars(c, base_var_dict, base_cfg)
for c in cfg
]
elif isinstance(cfg, str) and cfg in base_var_dict:
new_v = base_cfg
for new_k in base_var_dict[cfg].split('.'):
new_v = new_v[new_k]
cfg = new_v
return cfg
@staticmethod
def _file2dict(filename, use_predefined_variables=True):
filename = osp.abspath(osp.expanduser(filename))
......@@ -141,6 +194,9 @@ class Config:
temp_config_file.name)
else:
shutil.copyfile(filename, temp_config_file.name)
# Substitute base variables from placeholders to strings
base_var_dict = Config._pre_substitute_base_vars(
temp_config_file.name, temp_config_file.name)
if filename.endswith('.py'):
temp_module_name = osp.splitext(temp_config_name)[0]
......@@ -185,6 +241,10 @@ class Config:
raise KeyError('Duplicate key is not allowed among bases')
base_cfg_dict.update(c)
# Subtitute base variables from strings to their actual values
cfg_dict = Config._substitute_base_vars(cfg_dict, base_var_dict,
base_cfg_dict)
base_cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)
cfg_dict = base_cfg_dict
......
{
"_base_": [
"./l1.py",
"./l2.yaml",
"./l3.json",
"./l4.py"
],
"item3": false,
"item4": "test",
"item8": "{{fileBasename}}",
"item9": {{ _base_.item2 }},
"item10": {{ _base_.item7.b.c }}
}
_base_ = ['./l1.py', './l2.yaml', './l3.json', './l4.py']
item3 = False
item4 = 'test'
item8 = '{{fileBasename}}'
item9 = {{ _base_.item2 }}
item10 = {{ _base_.item7.b.c }}
_base_ : ['./l1.py', './l2.yaml', './l3.json', './l4.py']
item3 : False
item4 : 'test'
item8 : '{{fileBasename}}'
item9 : {{ _base_.item2 }}
item10 : {{ _base_.item7.b.c }}
{
"_base_": [
"./t.py"
],
"base": "_base_.item8",
"item11": {{ _base_.item8 }},
"item12": {{ _base_.item9 }},
"item13": {{ _base_.item10 }},
"item14": {{ _base_.item1 }},
"item15": {
"a": {
"b": {{ _base_.item2 }}
},
"b": [
{{ _base_.item3 }}
],
"c": [{{ _base_.item4 }}],
"d": [[
{
"e": {{ _base_.item5.a }}
}
],
{{ _base_.item6 }}],
"e": {{ _base_.item1 }}
}
}
_base_ = ['./t.py']
base = '_base_.item8'
item11 = {{ _base_.item8 }}
item12 = {{ _base_.item9 }}
item13 = {{ _base_.item10 }}
item14 = {{ _base_.item1 }}
item15 = dict(
a = dict( b = {{ _base_.item2 }} ),
b = [{{ _base_.item3 }}],
c = [{{ _base_.item4 }}],
d = [[dict(e = {{ _base_.item5.a }})],{{ _base_.item6 }}],
e = {{ _base_.item1 }}
)
_base_: ["./t.py"]
base: "_base_.item8"
item11: {{ _base_.item8 }}
item12: {{ _base_.item9 }}
item13: {{ _base_.item10 }}
item14: {{ _base_.item1 }}
item15:
a:
b: {{ _base_.item2 }}
b: [{{ _base_.item3 }}]
c: [{{ _base_.item4 }}]
d:
- [e: {{ _base_.item5.a }}]
- {{ _base_.item6 }}
e: {{ _base_.item1 }}
_base_ = ['./u.py']
item21 = {{ _base_.item11 }}
item22 = item21
item23 = {{ _base_.item10 }}
item24 = item23
item25 = dict(
a = dict( b = item24 ),
b = [item24],
c = [[dict(e = item22)],{{ _base_.item6 }}],
e = item21
)
......@@ -224,6 +224,81 @@ def test_merge_from_multiple_bases():
Config.fromfile(osp.join(data_path, 'config/m.py'))
def test_base_variables():
for file in ['t.py', 't.json', 't.yaml']:
cfg_file = osp.join(data_path, f'config/{file}')
cfg = Config.fromfile(cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
# cfg.field
assert cfg.item1 == [1, 2]
assert cfg.item2.a == 0
assert cfg.item3 is False
assert cfg.item4 == 'test'
assert cfg.item5 == dict(a=0, b=1)
assert cfg.item6 == [dict(a=0), dict(b=1)]
assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))
assert cfg.item8 == file
assert cfg.item9 == dict(a=0)
assert cfg.item10 == [3.1, 4.2, 5.3]
# test nested base
for file in ['u.py', 'u.json', 'u.yaml']:
cfg_file = osp.join(data_path, f'config/{file}')
cfg = Config.fromfile(cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
# cfg.field
assert cfg.base == '_base_.item8'
assert cfg.item1 == [1, 2]
assert cfg.item2.a == 0
assert cfg.item3 is False
assert cfg.item4 == 'test'
assert cfg.item5 == dict(a=0, b=1)
assert cfg.item6 == [dict(a=0), dict(b=1)]
assert cfg.item7 == dict(a=[0, 1, 2], b=dict(c=[3.1, 4.2, 5.3]))
assert cfg.item8 == 't.py'
assert cfg.item9 == dict(a=0)
assert cfg.item10 == [3.1, 4.2, 5.3]
assert cfg.item11 == 't.py'
assert cfg.item12 == dict(a=0)
assert cfg.item13 == [3.1, 4.2, 5.3]
assert cfg.item14 == [1, 2]
assert cfg.item15 == dict(
a=dict(b=dict(a=0)),
b=[False],
c=['test'],
d=[[{
'e': 0
}], [{
'a': 0
}, {
'b': 1
}]],
e=[1, 2])
# test reference assignment for py
cfg_file = osp.join(data_path, 'config/v.py')
cfg = Config.fromfile(cfg_file)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
assert cfg.item21 == 't.py'
assert cfg.item22 == 't.py'
assert cfg.item23 == [3.1, 4.2, 5.3]
assert cfg.item24 == [3.1, 4.2, 5.3]
assert cfg.item25 == dict(
a=dict(b=[3.1, 4.2, 5.3]),
b=[[3.1, 4.2, 5.3]],
c=[[{
'e': 't.py'
}], [{
'a': 0
}, {
'b': 1
}]],
e='t.py')
def test_merge_recursive_bases():
cfg_file = osp.join(data_path, 'config/f.py')
cfg = Config.fromfile(cfg_file)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment