Unverified Commit c324b1fc authored by Jiazhen Wang's avatar Jiazhen Wang Committed by GitHub
Browse files

[Fix] Ensure the type of filename parameter in Config is str (#1725)

* ensure type of filename is str

* check filename for func: fromfile

* add ut for fromfile
parent de0c1039
...@@ -13,6 +13,7 @@ import warnings ...@@ -13,6 +13,7 @@ import warnings
from argparse import Action, ArgumentParser from argparse import Action, ArgumentParser
from collections import abc from collections import abc
from importlib import import_module from importlib import import_module
from pathlib import Path
from addict import Dict from addict import Dict
from yapf.yapflib.yapf_api import FormatCode from yapf.yapflib.yapf_api import FormatCode
...@@ -334,6 +335,8 @@ class Config: ...@@ -334,6 +335,8 @@ class Config:
def fromfile(filename, def fromfile(filename,
use_predefined_variables=True, use_predefined_variables=True,
import_custom_modules=True): import_custom_modules=True):
if isinstance(filename, Path):
filename = str(filename)
cfg_dict, cfg_text = Config._file2dict(filename, cfg_dict, cfg_text = Config._file2dict(filename,
use_predefined_variables) use_predefined_variables)
if import_custom_modules and cfg_dict.get('custom_imports', None): if import_custom_modules and cfg_dict.get('custom_imports', None):
...@@ -390,6 +393,9 @@ class Config: ...@@ -390,6 +393,9 @@ class Config:
if key in RESERVED_KEYS: if key in RESERVED_KEYS:
raise KeyError(f'{key} is reserved for config file') raise KeyError(f'{key} is reserved for config file')
if isinstance(filename, Path):
filename = str(filename)
super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict)) super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
super(Config, self).__setattr__('_filename', filename) super(Config, self).__setattr__('_filename', filename)
if cfg_text: if cfg_text:
......
...@@ -29,16 +29,19 @@ def test_construct(): ...@@ -29,16 +29,19 @@ def test_construct():
cfg_dict = dict(item1=[1, 2], item2=dict(a=0), item3=True, item4='test') cfg_dict = dict(item1=[1, 2], item2=dict(a=0), item3=True, item4='test')
# test a.py # test a.py
cfg_file = osp.join(data_path, 'config/a.py') cfg_file = osp.join(data_path, 'config/a.py')
cfg = Config(cfg_dict, filename=cfg_file) cfg_file_path = Path(cfg_file)
assert isinstance(cfg, Config) file_list = [cfg_file, cfg_file_path]
assert cfg.filename == cfg_file for item in file_list:
assert cfg.text == open(cfg_file, 'r').read() cfg = Config(cfg_dict, filename=item)
assert cfg.dump() == cfg.pretty_text assert isinstance(cfg, Config)
with tempfile.TemporaryDirectory() as temp_config_dir: assert isinstance(cfg.filename, str) and cfg.filename == str(item)
dump_file = osp.join(temp_config_dir, 'a.py') assert cfg.text == open(item, 'r').read()
cfg.dump(dump_file) assert cfg.dump() == cfg.pretty_text
assert cfg.dump() == open(dump_file, 'r').read() with tempfile.TemporaryDirectory() as temp_config_dir:
assert Config.fromfile(dump_file) dump_file = osp.join(temp_config_dir, 'a.py')
cfg.dump(dump_file)
assert cfg.dump() == open(dump_file, 'r').read()
assert Config.fromfile(dump_file)
# test b.json # test b.json
cfg_file = osp.join(data_path, 'config/b.json') cfg_file = osp.join(data_path, 'config/b.json')
...@@ -142,11 +145,14 @@ def test_construct(): ...@@ -142,11 +145,14 @@ def test_construct():
def test_fromfile(): def test_fromfile():
for filename in ['a.py', 'a.b.py', 'b.json', 'c.yaml']: for filename in ['a.py', 'a.b.py', 'b.json', 'c.yaml']:
cfg_file = osp.join(data_path, 'config', filename) cfg_file = osp.join(data_path, 'config', filename)
cfg = Config.fromfile(cfg_file) cfg_file_path = Path(cfg_file)
assert isinstance(cfg, Config) file_list = [cfg_file, cfg_file_path]
assert cfg.filename == cfg_file for item in file_list:
assert cfg.text == osp.abspath(osp.expanduser(cfg_file)) + '\n' + \ cfg = Config.fromfile(item)
open(cfg_file, 'r').read() assert isinstance(cfg, Config)
assert isinstance(cfg.filename, str) and cfg.filename == str(item)
assert cfg.text == osp.abspath(osp.expanduser(item)) + '\n' + \
open(item, 'r').read()
# test custom_imports for Config.fromfile # test custom_imports for Config.fromfile
cfg_file = osp.join(data_path, 'config', 'q.py') cfg_file = osp.join(data_path, 'config', 'q.py')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment