Unverified Commit a4e3c240 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

fix config v2 relative path (#3439)

parent e457047c
...@@ -47,6 +47,8 @@ class ConfigBase: ...@@ -47,6 +47,8 @@ class ConfigBase:
They will be converted to snake_case automatically. They will be converted to snake_case automatically.
If a field is missing and don't have default value, it will be set to `dataclasses.MISSING`. If a field is missing and don't have default value, it will be set to `dataclasses.MISSING`.
""" """
if 'basepath' in kwargs:
_base_path = kwargs.pop('basepath')
kwargs = {util.case_insensitive(key): value for key, value in kwargs.items()} kwargs = {util.case_insensitive(key): value for key, value in kwargs.items()}
if _base_path is None: if _base_path is None:
_base_path = Path() _base_path = Path()
......
...@@ -68,17 +68,24 @@ class ExperimentConfig(ConfigBase): ...@@ -68,17 +68,24 @@ class ExperimentConfig(ConfigBase):
training_service: Union[TrainingServiceConfig, List[TrainingServiceConfig]] training_service: Union[TrainingServiceConfig, List[TrainingServiceConfig]]
def __init__(self, training_service_platform: Optional[Union[str, List[str]]] = None, **kwargs): def __init__(self, training_service_platform: Optional[Union[str, List[str]]] = None, **kwargs):
base_path = kwargs.pop('_base_path', None)
kwargs = util.case_insensitive(kwargs) kwargs = util.case_insensitive(kwargs)
if training_service_platform is not None: if training_service_platform is not None:
assert 'trainingservice' not in kwargs assert 'trainingservice' not in kwargs
kwargs['trainingservice'] = util.training_service_config_factory(platform = training_service_platform) kwargs['trainingservice'] = util.training_service_config_factory(
platform=training_service_platform,
base_path=base_path
)
elif isinstance(kwargs.get('trainingservice'), (dict, list)): elif isinstance(kwargs.get('trainingservice'), (dict, list)):
# dict means a single training service # dict means a single training service
# list means hybrid training service # list means hybrid training service
kwargs['trainingservice'] = util.training_service_config_factory(config = kwargs['trainingservice']) kwargs['trainingservice'] = util.training_service_config_factory(
config=kwargs['trainingservice'],
base_path=base_path
)
else: else:
raise RuntimeError('Unsupported Training service configuration!') raise RuntimeError('Unsupported Training service configuration!')
super().__init__(**kwargs) super().__init__(_base_path=base_path, **kwargs)
for algo_type in ['tuner', 'assessor', 'advisor']: for algo_type in ['tuner', 'assessor', 'advisor']:
if isinstance(kwargs.get(algo_type), dict): if isinstance(kwargs.get(algo_type), dict):
setattr(self, algo_type, _AlgorithmConfig(**kwargs.pop(algo_type))) setattr(self, algo_type, _AlgorithmConfig(**kwargs.pop(algo_type)))
......
...@@ -29,7 +29,10 @@ def canonical_path(path: Optional[PathLike]) -> Optional[str]: ...@@ -29,7 +29,10 @@ def canonical_path(path: Optional[PathLike]) -> Optional[str]:
def count(*values) -> int: def count(*values) -> int:
return sum(value is not None and value is not False for value in values) return sum(value is not None and value is not False for value in values)
def training_service_config_factory(platform: Union[str, List[str]] = None, config: Union[List, Dict] = None): # -> TrainingServiceConfig def training_service_config_factory(
platform: Union[str, List[str]] = None,
config: Union[List, Dict] = None,
base_path: Optional[Path] = None): # -> TrainingServiceConfig
from .common import TrainingServiceConfig from .common import TrainingServiceConfig
ts_configs = [] ts_configs = []
if platform is not None: if platform is not None:
...@@ -47,7 +50,7 @@ def training_service_config_factory(platform: Union[str, List[str]] = None, conf ...@@ -47,7 +50,7 @@ def training_service_config_factory(platform: Union[str, List[str]] = None, conf
for conf in configs: for conf in configs:
if conf['platform'] not in supported_platforms: if conf['platform'] not in supported_platforms:
raise RuntimeError(f'Unrecognized platform {conf["platform"]}') raise RuntimeError(f'Unrecognized platform {conf["platform"]}')
ts_configs.append(supported_platforms[conf['platform']](**conf)) ts_configs.append(supported_platforms[conf['platform']](_base_path=base_path, **conf))
return ts_configs if len(ts_configs) > 1 else ts_configs[0] return ts_configs if len(ts_configs) > 1 else ts_configs[0]
def load_config(Type, value): def load_config(Type, value):
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import json import json
import os import os
from pathlib import Path
import sys import sys
import string import string
import random import random
...@@ -590,7 +591,7 @@ def create_experiment(args): ...@@ -590,7 +591,7 @@ def create_experiment(args):
except Exception: except Exception:
print_warning('Validation with V1 schema failed. Trying to convert from V2 format...') print_warning('Validation with V1 schema failed. Trying to convert from V2 format...')
try: try:
config = ExperimentConfig(**experiment_config) config = ExperimentConfig(_base_path=Path(config_path).parent, **experiment_config)
experiment_config = convert.to_v1_yaml(config) experiment_config = convert.to_v1_yaml(config)
except Exception as e: except Exception as e:
print_error(f'Config in v2 format validation failed, the config error in v2 format is: {repr(e)}') print_error(f'Config in v2 format validation failed, the config error in v2 format is: {repr(e)}')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment