Unverified Commit c324b1fc authored by Jiazhen Wang's avatar Jiazhen Wang Committed by GitHub
Browse files

[Fix] Ensure the type of filename parameter in Config is str (#1725)

* ensure type of filename is str

* check filename for func: fromfile

* add ut for fromfile
parent de0c1039
......@@ -13,6 +13,7 @@ import warnings
from argparse import Action, ArgumentParser
from collections import abc
from importlib import import_module
from pathlib import Path
from addict import Dict
from yapf.yapflib.yapf_api import FormatCode
......@@ -334,6 +335,8 @@ class Config:
def fromfile(filename,
use_predefined_variables=True,
import_custom_modules=True):
if isinstance(filename, Path):
filename = str(filename)
cfg_dict, cfg_text = Config._file2dict(filename,
use_predefined_variables)
if import_custom_modules and cfg_dict.get('custom_imports', None):
......@@ -390,6 +393,9 @@ class Config:
if key in RESERVED_KEYS:
raise KeyError(f'{key} is reserved for config file')
if isinstance(filename, Path):
filename = str(filename)
super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
super(Config, self).__setattr__('_filename', filename)
if cfg_text:
......
......@@ -29,10 +29,13 @@ def test_construct():
cfg_dict = dict(item1=[1, 2], item2=dict(a=0), item3=True, item4='test')
# test a.py
cfg_file = osp.join(data_path, 'config/a.py')
cfg = Config(cfg_dict, filename=cfg_file)
cfg_file_path = Path(cfg_file)
file_list = [cfg_file, cfg_file_path]
for item in file_list:
cfg = Config(cfg_dict, filename=item)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
assert cfg.text == open(cfg_file, 'r').read()
assert isinstance(cfg.filename, str) and cfg.filename == str(item)
assert cfg.text == open(item, 'r').read()
assert cfg.dump() == cfg.pretty_text
with tempfile.TemporaryDirectory() as temp_config_dir:
dump_file = osp.join(temp_config_dir, 'a.py')
......@@ -142,11 +145,14 @@ def test_construct():
def test_fromfile():
for filename in ['a.py', 'a.b.py', 'b.json', 'c.yaml']:
cfg_file = osp.join(data_path, 'config', filename)
cfg = Config.fromfile(cfg_file)
cfg_file_path = Path(cfg_file)
file_list = [cfg_file, cfg_file_path]
for item in file_list:
cfg = Config.fromfile(item)
assert isinstance(cfg, Config)
assert cfg.filename == cfg_file
assert cfg.text == osp.abspath(osp.expanduser(cfg_file)) + '\n' + \
open(cfg_file, 'r').read()
assert isinstance(cfg.filename, str) and cfg.filename == str(item)
assert cfg.text == osp.abspath(osp.expanduser(item)) + '\n' + \
open(item, 'r').read()
# test custom_imports for Config.fromfile
cfg_file = osp.join(data_path, 'config', 'q.py')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment