Unverified Commit d965808e authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Fix k8s and hybrid config (#3563)

parent 9c1f5344
......@@ -35,6 +35,7 @@ Local Mode
trialCommand: python mnist.py
trialCodeDirectory: .
trialGpuNumber: 1
trialConcurrency: 2
maxExperimentDuration: 24h
maxTrialNumber: 100
tuner:
......@@ -59,6 +60,7 @@ Local Mode (Inline Search Space)
_value: [0.0001, 0.1]
trialCommand: python mnist.py
trialGpuNumber: 1
trialConcurrency: 2
tuner:
name: TPE
classArgs:
......@@ -77,6 +79,7 @@ Remote Mode
trialCommand: python mnist.py
trialCodeDirectory: .
trialGpuNumber: 1
trialConcurrency: 2
maxExperimentDuration: 24h
maxTrialNumber: 100
tuner:
......
......@@ -32,6 +32,12 @@ def main():
if exp_params.get('deprecated', {}).get('multiThread'):
enable_multi_thread()
if 'trainingServicePlatform' in exp_params: # config schema is v1
from .experiment.config.convert import convert_algo
for algo_type in ['tuner', 'assessor', 'advisor']:
if algo_type in exp_params:
convert_algo(algo_type, exp_params, exp_params)
if exp_params.get('advisor') is not None:
# advisor is enabled and starts to run
_run_advisor(exp_params)
......
......@@ -82,6 +82,7 @@ class ConfigBase:
Convert config to JSON object.
The keys of returned object will be camelCase.
"""
self.validate()
return dataclasses.asdict(
self.canonical(),
dict_factory=lambda items: dict((util.camel_case(k), v) for k, v in items if v is not None)
......
......@@ -98,6 +98,13 @@ class ExperimentConfig(ConfigBase):
if isinstance(kwargs.get(algo_type), dict):
setattr(self, algo_type, _AlgorithmConfig(**kwargs.pop(algo_type)))
def canonical(self):
    """Return the canonical form of this config.

    Starts from the parent class's ``canonical()`` result. When its
    ``training_service`` is a list, every entry is replaced in place by
    that entry's own ``canonical()`` result.
    """
    result = super().canonical()
    services = result.training_service
    if not isinstance(services, list):
        # Nothing list-specific to do; the parent's result is returned as is.
        return result
    for index in range(len(services)):
        services[index] = services[index].canonical()
    return result
def validate(self, initialized_tuner: bool = False) -> None:
super().validate()
if initialized_tuner:
......
......@@ -45,31 +45,8 @@ def to_v2(v1) -> ExperimentConfig:
_move_field(v1_trial, v2, 'gpuNum', 'trial_gpu_number')
for algo_type in ['tuner', 'assessor', 'advisor']:
if algo_type not in v1:
continue
v1_algo = v1.pop(algo_type)
builtin_name = v1_algo.pop(f'builtin{algo_type.title()}Name', None)
class_args = v1_algo.pop('classArgs', None)
if builtin_name is not None:
v2_algo = AlgorithmConfig(name=builtin_name, class_args=class_args)
else:
class_directory = util.canonical_path(v1_algo.pop('codeDir'))
class_file_name = v1_algo.pop('classFileName')
assert class_file_name.endswith('.py')
class_name = class_file_name[:-3] + '.' + v1_algo.pop('className')
v2_algo = CustomAlgorithmConfig(
class_name=class_name,
class_directory=class_directory,
class_args=class_args
)
setattr(v2, algo_type, v2_algo)
_deprecate(v1_algo, v2, 'includeIntermediateResults')
_move_field(v1_algo, v2, 'gpuIndices', 'tuner_gpu_indices')
assert not v1_algo, v1_algo
if algo_type in v1:
convert_algo(algo_type, v1, v2)
ts = v2.training_service
......@@ -259,3 +236,31 @@ def _deprecate(v1, v2, key):
if v2._deprecated is None:
v2._deprecated = {}
v2._deprecated[key] = v1.pop(key)
def convert_algo(algo_type, v1, v2):
    """Convert one v1 algorithm section and attach it to ``v2``.

    ``algo_type`` is the section key (callers in this file pass 'tuner',
    'assessor' or 'advisor'). The section is popped from ``v1`` and its
    fields are consumed one by one; anything left over at the end trips
    the final assertion.

    Returns ``None`` when ``v1`` has no such section, otherwise the
    algorithm config object that was set on ``v2`` as attribute
    ``algo_type``.
    """
    if algo_type not in v1:
        return None

    legacy = v1.pop(algo_type)
    builtin_key = 'builtin' + algo_type.title() + 'Name'
    builtin = legacy.pop(builtin_key, None)
    args = legacy.pop('classArgs', None)

    if builtin is None:
        # Custom algorithm: the class name is derived from the file name
        # (minus '.py') joined to the class name with a dot.
        directory = util.canonical_path(legacy.pop('codeDir'))
        file_name = legacy.pop('classFileName')
        assert file_name.endswith('.py')
        module_name = file_name[:-3]
        qualified_name = module_name + '.' + legacy.pop('className')
        converted = CustomAlgorithmConfig(
            class_name=qualified_name,
            class_directory=directory,
            class_args=args,
        )
    else:
        converted = AlgorithmConfig(name=builtin, class_args=args)

    setattr(v2, algo_type, converted)

    # Consume the remaining known keys via the module helpers
    # (presumably no-ops when the key is absent -- confirm in the helpers).
    _deprecate(legacy, v2, 'includeIntermediateResults')
    _move_field(legacy, v2, 'gpuIndices', 'tuner_gpu_indices')

    # Every v1 field must have been consumed by now.
    assert not legacy, legacy
    return converted
......@@ -333,6 +333,8 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
# start rest server
if config_version == 1:
platform = experiment_config['trainingServicePlatform']
elif isinstance(experiment_config['trainingService'], list):
platform = 'hybrid'
else:
platform = experiment_config['trainingService']['platform']
......
......@@ -409,7 +409,17 @@ class NNIManager implements Manager {
private async initTrainingService(config: ExperimentConfig): Promise<TrainingService> {
this.config = config;
const platform = Array.isArray(config.trainingService) ? 'hybrid' : config.trainingService.platform;
let platform: string;
if (Array.isArray(config.trainingService)) {
platform = 'hybrid';
} else if (config.trainingService.platform) {
platform = config.trainingService.platform;
} else {
platform = (config as any).trainingServicePlatform;
}
if (!platform) {
throw new Error('Cannot detect training service platform');
}
if (['remote', 'pai', 'aml', 'hybrid'].includes(platform)) {
const module_ = await import('../training_service/reusable/routerTrainingService');
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment