Unverified Commit 441267d1 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Hotfix k8s setClusterMetadata (#3567)

parent 22c185c5
...@@ -119,6 +119,18 @@ def set_trial_config(experiment_config, port, config_file_name): ...@@ -119,6 +119,18 @@ def set_trial_config(experiment_config, port, config_file_name):
def set_adl_config(experiment_config, port, config_file_name): def set_adl_config(experiment_config, port, config_file_name):
'''set adl configuration''' '''set adl configuration'''
adl_config_data = dict()
# hack for supporting v2 config, need refactor
adl_config_data['adl_config'] = {}
response = rest_put(cluster_metadata_url(port), json.dumps(adl_config_data), REST_TIME_OUT)
err_message = None
if not response or not response.status_code == 200:
if response is not None:
err_message = response.text
_, stderr_full_path = get_log_path(config_file_name)
with open(stderr_full_path, 'a+') as fout:
fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
return False, err_message
result, message = setNNIManagerIp(experiment_config, port, config_file_name) result, message = setNNIManagerIp(experiment_config, port, config_file_name)
if not result: if not result:
return result, message return result, message
...@@ -377,6 +389,10 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi ...@@ -377,6 +389,10 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
except Exception: except Exception:
raise Exception(ERROR_INFO % 'Rest server stopped!') raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(1) exit(1)
if config_version == 1 and mode != 'view':
# set platform configuration
set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,\
experiment_id, rest_process)
# start a new experiment # start a new experiment
print_normal('Starting experiment...') print_normal('Starting experiment...')
...@@ -398,10 +414,6 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi ...@@ -398,10 +414,6 @@ def launch_experiment(args, experiment_config, mode, experiment_id, config_versi
except Exception: except Exception:
raise Exception(ERROR_INFO % 'Restful server stopped!') raise Exception(ERROR_INFO % 'Restful server stopped!')
exit(1) exit(1)
if config_version == 1 and mode != 'view':
# set platform configuration
set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,\
experiment_id, rest_process)
if experiment_config.get('nniManagerIp'): if experiment_config.get('nniManagerIp'):
web_ui_url_list = ['http://{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))] web_ui_url_list = ['http://{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
else: else:
......
...@@ -175,12 +175,14 @@ class NNIManager implements Manager { ...@@ -175,12 +175,14 @@ class NNIManager implements Manager {
nextSequenceId: 0, nextSequenceId: 0,
revision: 0 revision: 0
}; };
this.config = config;
this.log.info(`Starting experiment: ${this.experimentProfile.id}`); this.log.info(`Starting experiment: ${this.experimentProfile.id}`);
await this.storeExperimentProfile(); await this.storeExperimentProfile();
this.log.info('Setup training service...'); if (this.trainingService === undefined) {
this.trainingService = await this.initTrainingService(config); this.log.info('Setup training service...');
this.trainingService = await this.initTrainingService(config);
}
this.log.info('Setup tuner...'); this.log.info('Setup tuner...');
const dispatcherCommand: string = getMsgDispatcherCommand(config); const dispatcherCommand: string = getMsgDispatcherCommand(config);
...@@ -256,10 +258,30 @@ class NNIManager implements Manager { ...@@ -256,10 +258,30 @@ class NNIManager implements Manager {
} }
public async setClusterMetadata(key: string, value: string): Promise<void> { public async setClusterMetadata(key: string, value: string): Promise<void> {
while (this.trainingService === undefined) { // Hack for supporting v2 config, need refactor
await delay(1000); if (this.trainingService === undefined) {
this.log.info('Setup training service...');
switch (key) {
case 'kubeflow_config': {
const kubeflowModule = await import('../training_service/kubernetes/kubeflow/kubeflowTrainingService');
this.trainingService = new kubeflowModule.KubeflowTrainingService();
break;
}
case 'frameworkcontroller_config': {
const fcModule = await import('../training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService');
this.trainingService = new fcModule.FrameworkControllerTrainingService();
break;
}
case 'adl_config': {
const adlModule = await import('../training_service/kubernetes/adl/adlTrainingService');
this.trainingService = new adlModule.AdlTrainingService();
break;
}
default:
throw new Error("Setup training service failed.");
}
} }
this.trainingService.setClusterMetadata(key, value); await this.trainingService.setClusterMetadata(key, value);
} }
public getClusterMetadata(key: string): Promise<string> { public getClusterMetadata(key: string): Promise<string> {
...@@ -408,7 +430,6 @@ class NNIManager implements Manager { ...@@ -408,7 +430,6 @@ class NNIManager implements Manager {
} }
private async initTrainingService(config: ExperimentConfig): Promise<TrainingService> { private async initTrainingService(config: ExperimentConfig): Promise<TrainingService> {
this.config = config;
let platform: string; let platform: string;
if (Array.isArray(config.trainingService)) { if (Array.isArray(config.trainingService)) {
platform = 'hybrid'; platform = 'hybrid';
......
...@@ -131,6 +131,9 @@ export namespace ValidationSchemas { ...@@ -131,6 +131,9 @@ export namespace ValidationSchemas {
maxTrialNumPerGpu: joi.number(), maxTrialNumPerGpu: joi.number(),
useActiveGpu: joi.boolean(), useActiveGpu: joi.boolean(),
}), }),
adl_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
// hack for v2 configuration
}),
kubeflow_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase kubeflow_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
operator: joi.string().min(1).required(), operator: joi.string().min(1).required(),
storage: joi.string().min(1), storage: joi.string().min(1),
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment