Unverified Commit b2926eac authored by chenbohua3's avatar chenbohua3 Committed by GitHub
Browse files

[Feature] support customize config path (#423)

* support customize config path

* support customize config path

* support customize config path
parent 524579b5
import os
from typing import List, Union from typing import List, Union
import tabulate import tabulate
...@@ -60,7 +61,8 @@ def get_config_from_arg(args) -> Config: ...@@ -60,7 +61,8 @@ def get_config_from_arg(args) -> Config:
raise ValueError('You must specify "--datasets" if you do not specify ' raise ValueError('You must specify "--datasets" if you do not specify '
'a config file path.') 'a config file path.')
datasets = [] datasets = []
for dataset in match_cfg_file('configs/datasets/', args.datasets): datasets_dir = os.path.join(args.config_dir, 'datasets')
for dataset in match_cfg_file(datasets_dir, args.datasets):
get_logger().info(f'Loading {dataset[0]}: {dataset[1]}') get_logger().info(f'Loading {dataset[0]}: {dataset[1]}')
cfg = Config.fromfile(dataset[1]) cfg = Config.fromfile(dataset[1])
for k in cfg.keys(): for k in cfg.keys():
...@@ -73,7 +75,8 @@ def get_config_from_arg(args) -> Config: ...@@ -73,7 +75,8 @@ def get_config_from_arg(args) -> Config:
'--datasets.') '--datasets.')
models = [] models = []
if args.models: if args.models:
for model in match_cfg_file('configs/models/', args.models): model_dir = os.path.join(args.config_dir, 'models')
for model in match_cfg_file(model_dir, args.models):
get_logger().info(f'Loading {model[0]}: {model[1]}') get_logger().info(f'Loading {model[0]}: {model[1]}')
cfg = Config.fromfile(model[1]) cfg = Config.fromfile(model[1])
if 'models' not in cfg: if 'models' not in cfg:
...@@ -98,7 +101,8 @@ def get_config_from_arg(args) -> Config: ...@@ -98,7 +101,8 @@ def get_config_from_arg(args) -> Config:
summarizer = None summarizer = None
if args.summarizer: if args.summarizer:
s = match_cfg_file('configs/summarizers/', [args.summarizer])[0] summarizers_dir = os.path.join(args.config_dir, 'summarizers')
s = match_cfg_file(summarizers_dir, [args.summarizer])[0]
get_logger().info(f'Loading {s[0]}: {s[1]}') get_logger().info(f'Loading {s[0]}: {s[1]}')
cfg = Config.fromfile(s[1]) cfg = Config.fromfile(s[1])
summarizer = cfg['summarizer'] summarizer = cfg['summarizer']
......
...@@ -83,6 +83,12 @@ def parse_args(): ...@@ -83,6 +83,12 @@ def parse_args():
'./outputs/default.', './outputs/default.',
default=None, default=None,
type=str) type=str)
parser.add_argument(
'--config-dir',
default='configs',
help='Use the custom config directory instead of config/ to '
'search the configs for datasets, models and summarizers',
type=str)
parser.add_argument('-l', parser.add_argument('-l',
'--lark', '--lark',
help='Report the running status to lark bot', help='Report the running status to lark bot',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment