Unverified Commit c2a4ce6c authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Add nniManagerIp in nnictl and trainingService (#393)

Add nniManager Ip in nnictl, pai TrainingService and kubeflow TrainingService.
If users set nniManagerIp, pai and kubeflow will use this ip instead of using getIPV4() function.
Web UI will also use this nniManagerIp.
parent cb7c7ff0
......@@ -120,8 +120,20 @@ abstract class TrainingService {
public abstract run(): Promise<void>;
}
/**
* the ip of nni manager
*/
class NNIManagerIpConfig {
public readonly nniManagerIp: string;
constructor(nniManagerIp: string){
this.nniManagerIp = nniManagerIp;
}
}
export {
TrainingService, TrainingServiceError, TrialJobStatus, TrialJobApplicationForm,
TrainingServiceMetadata, TrialJobDetail, TrialJobMetric, HyperParameters,
HostJobApplicationForm, JobApplicationForm, JobType
HostJobApplicationForm, JobApplicationForm, JobType, NNIManagerIpConfig
};
......@@ -72,6 +72,9 @@ export namespace ValidationSchemas {
path: joi.string().min(1).required()
}).required(),
kubernetesServer: joi.string().min(1).required()
}),
nni_manager_ip: joi.object({
nniManagerIp: joi.string().min(1)
})
}
};
......
......@@ -29,5 +29,6 @@ export enum TrialConfigMetadataKey {
MULTI_PHASE = 'multiPhase',
RANDOM_SCHEDULER = 'random_scheduler',
PAI_CLUSTER_CONFIG = 'pai_config',
KUBEFLOW_CLUSTER_CONFIG = 'kubeflow_config'
KUBEFLOW_CLUSTER_CONFIG = 'kubeflow_config',
NNI_MANAGER_IP = 'nni_manager_ip'
}
......@@ -120,3 +120,4 @@ export class KubeflowTrialConfig {
this.ps = ps;
}
}
......@@ -33,7 +33,7 @@ import { MethodNotImplementedError } from '../../common/errors';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import {
JobApplicationForm, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric
TrialJobDetail, TrialJobMetric, NNIManagerIpConfig
} from '../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, getIPV4Address, uniqueString } from '../../common/utils';
import { KubeflowClusterConfig, kubeflowOperatorMap, KubeflowTrialConfig, NFSConfig } from './kubeflowConfig';
......@@ -66,6 +66,7 @@ class KubeflowTrainingService implements TrainingService {
private kubeflowRestServerPort?: number;
private kubeflowJobPlural?: string;
private readonly CONTAINER_MOUNT_PATH: string;
private nniManagerIpConfig?: NNIManagerIpConfig;
constructor() {
this.log = getLogger();
......@@ -271,6 +272,10 @@ class KubeflowTrainingService implements TrainingService {
public async setClusterMetadata(key: string, value: string): Promise<void> {
switch (key) {
case TrialConfigMetadataKey.NNI_MANAGER_IP:
this.nniManagerIpConfig = <NNIManagerIpConfig>JSON.parse(value);
break;
case TrialConfigMetadataKey.KUBEFLOW_CLUSTER_CONFIG:
this.kubeflowClusterConfig = <KubeflowClusterConfig>JSON.parse(value);
......@@ -493,13 +498,13 @@ class KubeflowTrainingService implements TrainingService {
break;
}
}
const nniManagerIp = this.nniManagerIpConfig?this.nniManagerIpConfig.nniManagerIp:getIPV4Address();
runScriptLines.push('mkdir -p $NNI_SYS_DIR');
runScriptLines.push('mkdir -p $NNI_OUTPUT_DIR');
runScriptLines.push('cp -rT $NNI_CODE_DIR $NNI_SYS_DIR');
runScriptLines.push('cd $NNI_SYS_DIR');
runScriptLines.push('sh install_nni.sh # Check and install NNI pkg');
runScriptLines.push(`python3 -m nni_trial_tool.trial_keeper --trial_command '${command}' --nnimanager_ip '${getIPV4Address()}' --nnimanager_port '${this.kubeflowRestServerPort}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr`);
runScriptLines.push(`python3 -m nni_trial_tool.trial_keeper --trial_command '${command}' --nnimanager_ip '${nniManagerIp}' --nnimanager_port '${this.kubeflowRestServerPort}' 1>$NNI_OUTPUT_DIR/trialkeeper_stdout 2>$NNI_OUTPUT_DIR/trialkeeper_stderr`);
return runScriptLines.join('\n');
}
......
......@@ -121,3 +121,4 @@ export class NNIPAITrialConfig extends TrialConfig{
this.outputDir = outputDir;
}
}
......@@ -36,7 +36,7 @@ import { getLogger, Logger } from '../../common/log';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import {
JobApplicationForm, TrainingService, TrialJobApplicationForm,
TrialJobDetail, TrialJobMetric
TrialJobDetail, TrialJobMetric, NNIManagerIpConfig
} from '../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, getIPV4Address, uniqueString } from '../../common/utils';
import { PAIJobRestServer } from './paiJobRestServer'
......@@ -69,6 +69,7 @@ class PAITrainingService implements TrainingService {
private hdfsOutputHost: string | undefined;
private nextTrialSequenceId: number;
private paiRestServerPort?: number;
private nniManagerIpConfig?: NNIManagerIpConfig;
constructor() {
this.log = getLogger();
......@@ -194,7 +195,7 @@ class PAITrainingService implements TrainingService {
trialSequenceId,
hdfsLogPath);
this.trialJobsMap.set(trialJobId, trialJobDetail);
const nniManagerIp = this.nniManagerIpConfig?this.nniManagerIpConfig.nniManagerIp:getIPV4Address();
const nniPaiTrialCommand : string = String.Format(
PAI_TRIAL_COMMAND_FORMAT,
// PAI will copy job's codeDir into /root directory
......@@ -204,7 +205,7 @@ class PAITrainingService implements TrainingService {
this.experimentId,
trialSequenceId,
this.paiTrialConfig.command,
getIPV4Address(),
nniManagerIp,
this.paiRestServerPort,
hdfsOutputDir,
this.hdfsOutputHost,
......@@ -322,6 +323,11 @@ class PAITrainingService implements TrainingService {
const deferred : Deferred<void> = new Deferred<void>();
switch (key) {
case TrialConfigMetadataKey.NNI_MANAGER_IP:
this.nniManagerIpConfig = <NNIManagerIpConfig>JSON.parse(value);
deferred.resolve();
break;
case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG:
//TODO: try catch exception when setting up HDFS client and get PAI token
this.paiClusterConfig = <PAIClusterConfig>JSON.parse(value);
......
......@@ -32,6 +32,7 @@ Optional('maxTrialNum'): And(int, lambda x: 1 <= x <= 99999),
Optional('searchSpacePath'): os.path.exists,
Optional('multiPhase'): bool,
Optional('multiThread'): bool,
Optional('nniManagerIp'): str,
'useAnnotation': bool,
'tuner': Or({
'builtinTunerName': Or('TPE', 'Random', 'Anneal', 'SMAC', 'Evolution'),
......
......@@ -133,6 +133,23 @@ def set_remote_config(experiment_config, port, config_file_name):
#set trial_config
return set_trial_config(experiment_config, port, config_file_name), err_message
def setNNIManagerIp(experiment_config, port, config_file_name):
'''set nniManagerIp'''
if experiment_config.get('nniManagerIp') is None:
return True, None
ip_config_dict = dict()
ip_config_dict['nni_manager_ip'] = { 'nniManagerIp' : experiment_config['nniManagerIp'] }
response = rest_put(cluster_metadata_url(port), json.dumps(ip_config_dict), 20)
err_message = None
if not response or not response.status_code == 200:
if response is not None:
err_message = response.text
_, stderr_full_path = get_log_path(config_file_name)
with open(stderr_full_path, 'a+') as fout:
fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
return False, err_message
return True, None
def set_pai_config(experiment_config, port, config_file_name):
'''set pai configuration'''
pai_config_data = dict()
......@@ -146,7 +163,9 @@ def set_pai_config(experiment_config, port, config_file_name):
with open(stderr_full_path, 'a+') as fout:
fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
return False, err_message
result, message = setNNIManagerIp(experiment_config, port, config_file_name)
if not result:
return result, message
#set trial_config
return set_trial_config(experiment_config, port, config_file_name), err_message
......@@ -163,7 +182,9 @@ def set_kubeflow_config(experiment_config, port, config_file_name):
with open(stderr_full_path, 'a+') as fout:
fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
return False, err_message
result, message = setNNIManagerIp(experiment_config, port, config_file_name)
if not result:
return result, message
#set trial_config
return set_trial_config(experiment_config, port, config_file_name), err_message
......@@ -182,8 +203,6 @@ def set_experiment(experiment_config, mode, port, config_file_name):
request_data['description'] = experiment_config['description']
if experiment_config.get('multiPhase'):
request_data['multiPhase'] = experiment_config.get('multiPhase')
if experiment_config.get('multiThread'):
request_data['multiThread'] = experiment_config.get('multiThread')
request_data['tuner'] = experiment_config['tuner']
if 'assessor' in experiment_config:
request_data['assessor'] = experiment_config['assessor']
......@@ -335,6 +354,9 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
except Exception:
raise Exception(ERROR_INFO % 'Restful server stopped!')
exit(1)
if experiment_config.get('nniManagerIp'):
web_ui_url_list = ['{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
else:
web_ui_url_list = get_local_urls(args.port)
nni_config.set_config('webuiUrl', web_ui_url_list)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment