Unverified Commit a4e3c240 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

fix config v2 relative path (#3439)

parent e457047c
......@@ -47,6 +47,8 @@ class ConfigBase:
They will be converted to snake_case automatically.
If a field is missing and don't have default value, it will be set to `dataclasses.MISSING`.
"""
if 'basepath' in kwargs:
_base_path = kwargs.pop('basepath')
kwargs = {util.case_insensitive(key): value for key, value in kwargs.items()}
if _base_path is None:
_base_path = Path()
......
......@@ -68,17 +68,24 @@ class ExperimentConfig(ConfigBase):
training_service: Union[TrainingServiceConfig, List[TrainingServiceConfig]]
def __init__(self, training_service_platform: Optional[Union[str, List[str]]] = None, **kwargs):
base_path = kwargs.pop('_base_path', None)
kwargs = util.case_insensitive(kwargs)
if training_service_platform is not None:
assert 'trainingservice' not in kwargs
kwargs['trainingservice'] = util.training_service_config_factory(platform = training_service_platform)
kwargs['trainingservice'] = util.training_service_config_factory(
platform=training_service_platform,
base_path=base_path
)
elif isinstance(kwargs.get('trainingservice'), (dict, list)):
# dict means a single training service
# list means hybrid training service
kwargs['trainingservice'] = util.training_service_config_factory(config = kwargs['trainingservice'])
kwargs['trainingservice'] = util.training_service_config_factory(
config=kwargs['trainingservice'],
base_path=base_path
)
else:
raise RuntimeError('Unsupported Training service configuration!')
super().__init__(**kwargs)
super().__init__(_base_path=base_path, **kwargs)
for algo_type in ['tuner', 'assessor', 'advisor']:
if isinstance(kwargs.get(algo_type), dict):
setattr(self, algo_type, _AlgorithmConfig(**kwargs.pop(algo_type)))
......
......@@ -29,7 +29,10 @@ def canonical_path(path: Optional[PathLike]) -> Optional[str]:
def count(*values) -> int:
return sum(value is not None and value is not False for value in values)
def training_service_config_factory(platform: Union[str, List[str]] = None, config: Union[List, Dict] = None): # -> TrainingServiceConfig
def training_service_config_factory(
platform: Union[str, List[str]] = None,
config: Union[List, Dict] = None,
base_path: Optional[Path] = None): # -> TrainingServiceConfig
from .common import TrainingServiceConfig
ts_configs = []
if platform is not None:
......@@ -47,7 +50,7 @@ def training_service_config_factory(platform: Union[str, List[str]] = None, conf
for conf in configs:
if conf['platform'] not in supported_platforms:
raise RuntimeError(f'Unrecognized platform {conf["platform"]}')
ts_configs.append(supported_platforms[conf['platform']](**conf))
ts_configs.append(supported_platforms[conf['platform']](_base_path=base_path, **conf))
return ts_configs if len(ts_configs) > 1 else ts_configs[0]
def load_config(Type, value):
......
......@@ -3,6 +3,7 @@
import json
import os
from pathlib import Path
import sys
import string
import random
......@@ -590,7 +591,7 @@ def create_experiment(args):
except Exception:
print_warning('Validation with V1 schema failed. Trying to convert from V2 format...')
try:
config = ExperimentConfig(**experiment_config)
config = ExperimentConfig(_base_path=Path(config_path).parent, **experiment_config)
experiment_config = convert.to_v1_yaml(config)
except Exception as e:
print_error(f'Config in v2 format validation failed, the config error in v2 format is: {repr(e)}')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment