Unverified Commit 9bb479bb authored by fishyds's avatar fishyds Committed by GitHub
Browse files

Add idompotent support for get_parameters() in nni sdk (#216)

* Updated based on comments

* Fix bug, make get_parameters() idompotent

* Add idompotent support for get_parameters() in LocalTrainingService
parent 14fac162
...@@ -116,6 +116,11 @@ class NNIManager implements Manager { ...@@ -116,6 +116,11 @@ class NNIManager implements Manager {
await this.storeExperimentProfile(); await this.storeExperimentProfile();
this.log.debug('Setup tuner...'); this.log.debug('Setup tuner...');
// Set up multiphase config
if(expParams.multiPhase && this.trainingService.isMultiPhaseJobSupported) {
this.trainingService.setClusterMetadata('multiPhase', expParams.multiPhase.toString());
}
const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase); const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase);
console.log(`dispatcher command: ${dispatcherCommand}`); console.log(`dispatcher command: ${dispatcherCommand}`);
this.setupTuner( this.setupTuner(
...@@ -140,6 +145,11 @@ class NNIManager implements Manager { ...@@ -140,6 +145,11 @@ class NNIManager implements Manager {
this.experimentProfile = await this.dataStore.getExperimentProfile(experimentId); this.experimentProfile = await this.dataStore.getExperimentProfile(experimentId);
const expParams: ExperimentParams = this.experimentProfile.params; const expParams: ExperimentParams = this.experimentProfile.params;
// Set up multiphase config
if(expParams.multiPhase && this.trainingService.isMultiPhaseJobSupported) {
this.trainingService.setClusterMetadata('multiPhase', expParams.multiPhase.toString());
}
const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase); const dispatcherCommand: string = getMsgDispatcherCommand(expParams.tuner, expParams.assessor, expParams.multiPhase);
console.log(`dispatcher command: ${dispatcherCommand}`); console.log(`dispatcher command: ${dispatcherCommand}`);
this.setupTuner( this.setupTuner(
......
...@@ -26,6 +26,7 @@ export enum TrialConfigMetadataKey { ...@@ -26,6 +26,7 @@ export enum TrialConfigMetadataKey {
MACHINE_LIST = 'machine_list', MACHINE_LIST = 'machine_list',
TRIAL_CONFIG = 'trial_config', TRIAL_CONFIG = 'trial_config',
EXPERIMENT_ID = 'experimentId', EXPERIMENT_ID = 'experimentId',
MULTI_PHASE = 'multiPhase',
RANDOM_SCHEDULER = 'random_scheduler', RANDOM_SCHEDULER = 'random_scheduler',
PAI_CLUSTER_CONFIG = 'pai_config' PAI_CLUSTER_CONFIG = 'pai_config'
} }
...@@ -102,6 +102,7 @@ class LocalTrainingService implements TrainingService { ...@@ -102,6 +102,7 @@ class LocalTrainingService implements TrainingService {
private trialSequenceId: number; private trialSequenceId: number;
protected log: Logger; protected log: Logger;
protected localTrailConfig?: TrialConfig; protected localTrailConfig?: TrialConfig;
private isMultiPhase: boolean = false;
constructor() { constructor() {
this.eventEmitter = new EventEmitter(); this.eventEmitter = new EventEmitter();
...@@ -237,7 +238,7 @@ class LocalTrainingService implements TrainingService { ...@@ -237,7 +238,7 @@ class LocalTrainingService implements TrainingService {
* Is multiphase job supported in current training service * Is multiphase job supported in current training service
*/ */
public get isMultiPhaseJobSupported(): boolean { public get isMultiPhaseJobSupported(): boolean {
return false; return true;
} }
public async cancelTrialJob(trialJobId: string): Promise<void> { public async cancelTrialJob(trialJobId: string): Promise<void> {
...@@ -270,6 +271,9 @@ class LocalTrainingService implements TrainingService { ...@@ -270,6 +271,9 @@ class LocalTrainingService implements TrainingService {
throw new Error('trial config parsed failed'); throw new Error('trial config parsed failed');
} }
break; break;
case TrialConfigMetadataKey.MULTI_PHASE:
this.isMultiPhase = (value === 'true' || value === 'True');
break;
default: default:
} }
} }
...@@ -304,7 +308,8 @@ class LocalTrainingService implements TrainingService { ...@@ -304,7 +308,8 @@ class LocalTrainingService implements TrainingService {
{ key: 'NNI_PLATFORM', value: 'local' }, { key: 'NNI_PLATFORM', value: 'local' },
{ key: 'NNI_SYS_DIR', value: trialJobDetail.workingDirectory }, { key: 'NNI_SYS_DIR', value: trialJobDetail.workingDirectory },
{ key: 'NNI_TRIAL_JOB_ID', value: trialJobDetail.id }, { key: 'NNI_TRIAL_JOB_ID', value: trialJobDetail.id },
{ key: 'NNI_OUTPUT_DIR', value: trialJobDetail.workingDirectory } { key: 'NNI_OUTPUT_DIR', value: trialJobDetail.workingDirectory },
{ key: 'MULTI_PHASE', value: this.isMultiPhase.toString() }
]; ];
} }
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
'use strict'; 'use strict';
import { Client } from 'ssh2';
import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService'; import { JobApplicationForm, TrialJobDetail, TrialJobStatus } from '../../common/trainingService';
import { GPUSummary } from '../common/gpuData'; import { GPUSummary } from '../common/gpuData';
...@@ -112,6 +111,7 @@ export enum ScheduleResultType { ...@@ -112,6 +111,7 @@ export enum ScheduleResultType {
export const REMOTEMACHINE_RUN_SHELL_FORMAT: string = export const REMOTEMACHINE_RUN_SHELL_FORMAT: string =
`#!/bin/bash `#!/bin/bash
export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_TRIAL_JOB_ID={1} NNI_OUTPUT_DIR={0} export NNI_PLATFORM=remote NNI_SYS_DIR={0} NNI_TRIAL_JOB_ID={1} NNI_OUTPUT_DIR={0}
export MULTI_PHASE={7}
cd $NNI_SYS_DIR cd $NNI_SYS_DIR
echo $$ >{2} echo $$ >{2}
eval {3}{4} 2>{5} eval {3}{4} 2>{5}
......
...@@ -64,6 +64,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -64,6 +64,7 @@ class RemoteMachineTrainingService implements TrainingService {
private stopping: boolean = false; private stopping: boolean = false;
private metricsEmitter: EventEmitter; private metricsEmitter: EventEmitter;
private log: Logger; private log: Logger;
private isMultiPhase: boolean = false;
private trialSequenceId: number; private trialSequenceId: number;
constructor(@component.Inject timer: ObservableTimer) { constructor(@component.Inject timer: ObservableTimer) {
...@@ -226,7 +227,7 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -226,7 +227,7 @@ class RemoteMachineTrainingService implements TrainingService {
* Is multiphase job supported in current training service * Is multiphase job supported in current training service
*/ */
public get isMultiPhaseJobSupported(): boolean { public get isMultiPhaseJobSupported(): boolean {
return false; return true;
} }
/** /**
...@@ -295,6 +296,9 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -295,6 +296,9 @@ class RemoteMachineTrainingService implements TrainingService {
} }
this.trialConfig = remoteMachineTrailConfig; this.trialConfig = remoteMachineTrailConfig;
break; break;
case TrialConfigMetadataKey.MULTI_PHASE:
this.isMultiPhase = (value === 'true' || value === 'True');
break;
default: default:
//Reject for unknown keys //Reject for unknown keys
throw new Error(`Uknown key: ${key}`); throw new Error(`Uknown key: ${key}`);
...@@ -457,7 +461,9 @@ class RemoteMachineTrainingService implements TrainingService { ...@@ -457,7 +461,9 @@ class RemoteMachineTrainingService implements TrainingService {
`CUDA_VISIBLE_DEVICES=${cuda_visible_device} ` : `CUDA_VISIBLE_DEVICES=" " `, `CUDA_VISIBLE_DEVICES=${cuda_visible_device} ` : `CUDA_VISIBLE_DEVICES=" " `,
this.trialConfig.command, this.trialConfig.command,
path.join(trialWorkingFolder, 'stderr'), path.join(trialWorkingFolder, 'stderr'),
path.join(trialWorkingFolder, '.nni', 'code')); path.join(trialWorkingFolder, '.nni', 'code'),
/** Mark if the trial is multi-phase job */
this.isMultiPhase);
//create tmp trial working folder locally. //create tmp trial working folder locally.
await cpp.exec(`mkdir -p ${path.join(trialLocalTempFolder, '.nni')}`); await cpp.exec(`mkdir -p ${path.join(trialLocalTempFolder, '.nni')}`);
......
...@@ -36,6 +36,8 @@ if not os.path.exists(_outputdir): ...@@ -36,6 +36,8 @@ if not os.path.exists(_outputdir):
_log_file_path = os.path.join(_outputdir, 'trial.log') _log_file_path = os.path.join(_outputdir, 'trial.log')
init_logger(_log_file_path) init_logger(_log_file_path)
_multiphase = os.environ.get('MULTI_PHASE')
_param_index = 0 _param_index = 0
def request_next_parameter(): def request_next_parameter():
...@@ -49,7 +51,13 @@ def request_next_parameter(): ...@@ -49,7 +51,13 @@ def request_next_parameter():
def get_parameters(): def get_parameters():
global _param_index global _param_index
params_filepath = os.path.join(_sysdir, ('parameter_{}.cfg'.format(_param_index), 'parameter.cfg')[_param_index == 0]) params_file_name = ''
if _multiphase and (_multiphase == 'true' or _multiphase == 'True'):
params_file_name = ('parameter_{}.cfg'.format(_param_index), 'parameter.cfg')[_param_index == 0]
else:
params_file_name = 'parameter.cfg'
params_filepath = os.path.join(_sysdir, params_file_name)
if not os.path.isfile(params_filepath): if not os.path.isfile(params_filepath):
request_next_parameter() request_next_parameter()
while not os.path.isfile(params_filepath): while not os.path.isfile(params_filepath):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment