Unverified Commit c2a4ce6c authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Add nniManagerIp in nnictl and trainingService (#393)

Add nniManager Ip in nnictl, pai TrainingService and kubeflow TrainingService.
If users set nniManagerIp, pai and kubeflow will use this ip instead of using getIPV4() function.
Web UI will also use this nniManagerIp.
parent cb7c7ff0
...@@ -120,8 +120,20 @@ abstract class TrainingService { ...@@ -120,8 +120,20 @@ abstract class TrainingService {
public abstract run(): Promise<void>; public abstract run(): Promise<void>;
} }
/**
* the ip of nni manager
*/
class NNIManagerIpConfig {
public readonly nniManagerIp: string;
constructor(nniManagerIp: string){
this.nniManagerIp = nniManagerIp;
}
}
export { export {
TrainingService, TrainingServiceError, TrialJobStatus, TrialJobApplicationForm, TrainingService, TrainingServiceError, TrialJobStatus, TrialJobApplicationForm,
TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters, TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters,
HostJobApplicationForm, JobApplicationForm, JobType HostJobApplicationForm, JobApplicationForm, JobType, NNIManagerIpConfig
}; };
...@@ -72,6 +72,9 @@ export namespace ValidationSchemas { ...@@ -72,6 +72,9 @@ export namespace ValidationSchemas {
path: joi.string().min(1).required() path: joi.string().min(1).required()
}).required(), }).required(),
kubernetesServer: joi.string().min(1).required() kubernetesServer: joi.string().min(1).required()
}),
nni_manager_ip: joi.object({
nniManagerIp: joi.string().min(1)
}) })
} }
}; };
......
...@@ -29,5 +29,6 @@ export enum TrialConfigMetadataKey { ...@@ -29,5 +29,6 @@ export enum TrialConfigMetadataKey {
MULTI_PHASE = 'multiPhase', MULTI_PHASE = 'multiPhase',
RANDOM_SCHEDULER = 'random_scheduler', RANDOM_SCHEDULER = 'random_scheduler',
PAI_CLUSTER_CONFIG = 'pai_config', PAI_CLUSTER_CONFIG = 'pai_config',
KUBEFLOW_CLUSTER_CONFIG = 'kubeflow_config' KUBEFLOW_CLUSTER_CONFIG = 'kubeflow_config',
NNI_MANAGER_IP = 'nni_manager_ip'
} }
...@@ -119,4 +119,5 @@ export class KubeflowTrialConfig { ...@@ -119,4 +119,5 @@ export class KubeflowTrialConfig {
this.worker = worker; this.worker = worker;
this.ps = ps; this.ps = ps;
} }
} }
\ No newline at end of file
...@@ -33,7 +33,7 @@ import { MethodNotImplementedError } from '../../common/errors'; ...@@ -33,7 +33,7 @@ import { MethodNotImplementedError } from '../../common/errors';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { import {
JobApplicationForm, TrainingService, TrialJobApplicationForm, JobApplicationForm, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric TrialJobDetail, TrialJobMetric, NNIManagerIpConfig
} from '../../common/trainingService'; } from '../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, getIPV4Address, uniqueString } from '../../common/utils'; import { delay, generateParamFileName, getExperimentRootDir, getIPV4Address, uniqueString } from '../../common/utils';
import { KubeflowClusterConfig, kubeflowOperatorMap, KubeflowTrialConfig, NFSConfig } from './kubeflowConfig'; import { KubeflowClusterConfig, kubeflowOperatorMap, KubeflowTrialConfig, NFSConfig } from './kubeflowConfig';
...@@ -65,7 +65,8 @@ class KubeflowTrainingService implements TrainingService { ...@@ -65,7 +65,8 @@ class KubeflowTrainingService implements TrainingService {
private kubeflowJobInfoCollector: KubeflowJobInfoCollector; private kubeflowJobInfoCollector: KubeflowJobInfoCollector;
private kubeflowRestServerPort?: number; private kubeflowRestServerPort?: number;
private kubeflowJobPlural?: string; private kubeflowJobPlural?: string;
private readonly CONTAINER_MOUNT_PATH: string; private readonly CONTAINER_MOUNT_PATH: string;
private nniManagerIpConfig?: NNIManagerIpConfig;
constructor() { constructor() {
this.log = getLogger(); this.log = getLogger();
...@@ -271,6 +272,10 @@ class KubeflowTrainingService implements TrainingService { ...@@ -271,6 +272,10 @@ class KubeflowTrainingService implements TrainingService {
public async setClusterMetadata(key: string, value: string): Promise<void> { public async setClusterMetadata(key: string, value: string): Promise<void> {
switch (key) { switch (key) {
case TrialConfigMetadataKey.NNI_MANAGER_IP:
this.nniManagerIpConfig = <NNIManagerIpConfig>JSON.parse(value);
break;
case TrialConfigMetadataKey.KUBEFLOW_CLUSTER_CONFIG: case TrialConfigMetadataKey.KUBEFLOW_CLUSTER_CONFIG:
this.kubeflowClusterConfig = <KubeflowClusterConfig>JSON.parse(value); this.kubeflowClusterConfig = <KubeflowClusterConfig>JSON.parse(value);
...@@ -493,13 +498,13 @@ class KubeflowTrainingService implements TrainingService { ...@@ -493,13 +498,13 @@ class KubeflowTrainingService implements TrainingService {
break; break;
} }
} }
const nniManagerIp = this.nniManagerIpConfig?this.nniManagerIpConfig.nniManagerIp:getIPV4Address();
runScriptLines.push('mkdir -p $NNI_SYS_DIR'); runScriptLines.push('mkdir -p $NNI_SYS_DIR');
runScriptLines.push('mkdir -p $NNI_OUTPUT_DIR'); runScriptLines.push('mkdir -p $NNI_OUTPUT_DIR');
runScriptLines.push('cp -rT $NNI_CODE_DIR $NNI_SYS_DIR'); runScriptLines.push('cp -rT $NNI_CODE_DIR $NNI_SYS_DIR');
runScriptLines.push('cd $NNI_SYS_DIR'); runScriptLines.push('cd $NNI_SYS_DIR');
runScriptLines.push('sh install_nni.sh # Check and install NNI pkg'); runScriptLines.push('sh install_nni.sh # Check and install NNI pkg');
runScriptLines.push(`python3 -m nni_trial_tool.trial_keeper --trial_command '${command}' --nnimanager_ip '${getIPV4Address()}' --nnimanager_port '${this.kubeflowRestServerPort}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr`); runScriptLines.push(`python3 -m nni_trial_tool.trial_keeper --trial_command '${command}' --nnimanager_ip '${nniManagerIp}' --nnimanager_port '${this.kubeflowRestServerPort}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr`);
return runScriptLines.join('\n'); return runScriptLines.join('\n');
} }
......
...@@ -120,4 +120,5 @@ export class NNIPAITrialConfig extends TrialConfig{ ...@@ -120,4 +120,5 @@ export class NNIPAITrialConfig extends TrialConfig{
this.dataDir = dataDir; this.dataDir = dataDir;
this.outputDir = outputDir; this.outputDir = outputDir;
} }
} }
\ No newline at end of file
...@@ -36,7 +36,7 @@ import { getLogger, Logger } from '../../common/log'; ...@@ -36,7 +36,7 @@ import { getLogger, Logger } from '../../common/log';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { import {
JobApplicationForm, TrainingService, TrialJobApplicationForm, JobApplicationForm, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric TrialJobDetail, TrialJobMetric, NNIManagerIpConfig
} from '../../common/trainingService'; } from '../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, getIPV4Address, uniqueString } from '../../common/utils'; import { delay, generateParamFileName, getExperimentRootDir, getIPV4Address, uniqueString } from '../../common/utils';
import { PAIJobRestServer } from './paiJobRestServer' import { PAIJobRestServer } from './paiJobRestServer'
...@@ -69,6 +69,7 @@ class PAITrainingService implements TrainingService { ...@@ -69,6 +69,7 @@ class PAITrainingService implements TrainingService {
private hdfsOutputHost: string | undefined; private hdfsOutputHost: string | undefined;
private nextTrialSequenceId: number; private nextTrialSequenceId: number;
private paiRestServerPort?: number; private paiRestServerPort?: number;
private nniManagerIpConfig?: NNIManagerIpConfig;
constructor() { constructor() {
this.log = getLogger(); this.log = getLogger();
...@@ -194,7 +195,7 @@ class PAITrainingService implements TrainingService { ...@@ -194,7 +195,7 @@ class PAITrainingService implements TrainingService {
trialSequenceId, trialSequenceId,
hdfsLogPath); hdfsLogPath);
this.trialJobsMap.set(trialJobId, trialJobDetail); this.trialJobsMap.set(trialJobId, trialJobDetail);
const nniManagerIp = this.nniManagerIpConfig?this.nniManagerIpConfig.nniManagerIp:getIPV4Address();
const nniPaiTrialCommand : string = String.Format( const nniPaiTrialCommand : string = String.Format(
PAI_TRIAL_COMMAND_FORMAT, PAI_TRIAL_COMMAND_FORMAT,
// PAI will copy job's codeDir into /root directory // PAI will copy job's codeDir into /root directory
...@@ -204,7 +205,7 @@ class PAITrainingService implements TrainingService { ...@@ -204,7 +205,7 @@ class PAITrainingService implements TrainingService {
this.experimentId, this.experimentId,
trialSequenceId, trialSequenceId,
this.paiTrialConfig.command, this.paiTrialConfig.command,
getIPV4Address(), nniManagerIp,
this.paiRestServerPort, this.paiRestServerPort,
hdfsOutputDir, hdfsOutputDir,
this.hdfsOutputHost, this.hdfsOutputHost,
...@@ -322,6 +323,11 @@ class PAITrainingService implements TrainingService { ...@@ -322,6 +323,11 @@ class PAITrainingService implements TrainingService {
const deferred : Deferred<void> = new Deferred<void>(); const deferred : Deferred<void> = new Deferred<void>();
switch (key) { switch (key) {
case TrialConfigMetadataKey.NNI_MANAGER_IP:
this.nniManagerIpConfig = <NNIManagerIpConfig>JSON.parse(value);
deferred.resolve();
break;
case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG: case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG:
//TODO: try catch exception when setting up HDFS client and get PAI token //TODO: try catch exception when setting up HDFS client and get PAI token
this.paiClusterConfig = <PAIClusterConfig>JSON.parse(value); this.paiClusterConfig = <PAIClusterConfig>JSON.parse(value);
......
...@@ -32,6 +32,7 @@ Optional('maxTrialNum'): And(int, lambda x: 1 <= x <= 99999), ...@@ -32,6 +32,7 @@ Optional('maxTrialNum'): And(int, lambda x: 1 <= x <= 99999),
Optional('searchSpacePath'): os.path.exists, Optional('searchSpacePath'): os.path.exists,
Optional('multiPhase'): bool, Optional('multiPhase'): bool,
Optional('multiThread'): bool, Optional('multiThread'): bool,
Optional('nniManagerIp'): str,
'useAnnotation': bool, 'useAnnotation': bool,
'tuner': Or({ 'tuner': Or({
'builtinTunerName': Or('TPE', 'Random', 'Anneal', 'SMAC', 'Evolution'), 'builtinTunerName': Or('TPE', 'Random', 'Anneal', 'SMAC', 'Evolution'),
......
...@@ -133,6 +133,23 @@ def set_remote_config(experiment_config, port, config_file_name): ...@@ -133,6 +133,23 @@ def set_remote_config(experiment_config, port, config_file_name):
#set trial_config #set trial_config
return set_trial_config(experiment_config, port, config_file_name), err_message return set_trial_config(experiment_config, port, config_file_name), err_message
def setNNIManagerIp(experiment_config, port, config_file_name):
'''set nniManagerIp'''
if experiment_config.get('nniManagerIp') is None:
return True, None
ip_config_dict = dict()
ip_config_dict['nni_manager_ip'] = { 'nniManagerIp' : experiment_config['nniManagerIp'] }
response = rest_put(cluster_metadata_url(port), json.dumps(ip_config_dict), 20)
err_message = None
if not response or not response.status_code == 200:
if response is not None:
err_message = response.text
_, stderr_full_path = get_log_path(config_file_name)
with open(stderr_full_path, 'a+') as fout:
fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
return False, err_message
return True, None
def set_pai_config(experiment_config, port, config_file_name): def set_pai_config(experiment_config, port, config_file_name):
'''set pai configuration''' '''set pai configuration'''
pai_config_data = dict() pai_config_data = dict()
...@@ -146,7 +163,9 @@ def set_pai_config(experiment_config, port, config_file_name): ...@@ -146,7 +163,9 @@ def set_pai_config(experiment_config, port, config_file_name):
with open(stderr_full_path, 'a+') as fout: with open(stderr_full_path, 'a+') as fout:
fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))) fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
return False, err_message return False, err_message
result, message = setNNIManagerIp(experiment_config, port, config_file_name)
if not result:
return result, message
#set trial_config #set trial_config
return set_trial_config(experiment_config, port, config_file_name), err_message return set_trial_config(experiment_config, port, config_file_name), err_message
...@@ -163,7 +182,9 @@ def set_kubeflow_config(experiment_config, port, config_file_name): ...@@ -163,7 +182,9 @@ def set_kubeflow_config(experiment_config, port, config_file_name):
with open(stderr_full_path, 'a+') as fout: with open(stderr_full_path, 'a+') as fout:
fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))) fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
return False, err_message return False, err_message
result, message = setNNIManagerIp(experiment_config, port, config_file_name)
if not result:
return result, message
#set trial_config #set trial_config
return set_trial_config(experiment_config, port, config_file_name), err_message return set_trial_config(experiment_config, port, config_file_name), err_message
...@@ -182,8 +203,6 @@ def set_experiment(experiment_config, mode, port, config_file_name): ...@@ -182,8 +203,6 @@ def set_experiment(experiment_config, mode, port, config_file_name):
request_data['description'] = experiment_config['description'] request_data['description'] = experiment_config['description']
if experiment_config.get('multiPhase'): if experiment_config.get('multiPhase'):
request_data['multiPhase'] = experiment_config.get('multiPhase') request_data['multiPhase'] = experiment_config.get('multiPhase')
if experiment_config.get('multiThread'):
request_data['multiThread'] = experiment_config.get('multiThread')
request_data['tuner'] = experiment_config['tuner'] request_data['tuner'] = experiment_config['tuner']
if 'assessor' in experiment_config: if 'assessor' in experiment_config:
request_data['assessor'] = experiment_config['assessor'] request_data['assessor'] = experiment_config['assessor']
...@@ -335,7 +354,10 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -335,7 +354,10 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
except Exception: except Exception:
raise Exception(ERROR_INFO % 'Restful server stopped!') raise Exception(ERROR_INFO % 'Restful server stopped!')
exit(1) exit(1)
web_ui_url_list = get_local_urls(args.port) if experiment_config.get('nniManagerIp'):
web_ui_url_list = ['{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
else:
web_ui_url_list = get_local_urls(args.port)
nni_config.set_config('webuiUrl', web_ui_url_list) nni_config.set_config('webuiUrl', web_ui_url_list)
#save experiment information #save experiment information
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment