Unverified commit 817ec68b, authored by liuzhe-lz and committed by GitHub
Browse files

Add native support for v2 config (#3466)

parent 6aaca5f7
...@@ -80,7 +80,6 @@ abstract class TrainingService { ...@@ -80,7 +80,6 @@ abstract class TrainingService {
public abstract removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void; public abstract removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void;
public abstract submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail>; public abstract submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail>;
public abstract updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail>; public abstract updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail>;
public abstract get isMultiPhaseJobSupported(): boolean;
public abstract cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean): Promise<void>; public abstract cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean): Promise<void>;
public abstract getTrialLog(trialJobId: string, logType: LogType): Promise<string>; public abstract getTrialLog(trialJobId: string, logType: LogType): Promise<string>;
public abstract setClusterMetadata(key: string, value: string): Promise<void>; public abstract setClusterMetadata(key: string, value: string): Promise<void>;
......
...@@ -20,7 +20,7 @@ import * as glob from 'glob'; ...@@ -20,7 +20,7 @@ import * as glob from 'glob';
import { Database, DataStore } from './datastore'; import { Database, DataStore } from './datastore';
import { ExperimentStartupInfo, getExperimentStartupInfo, setExperimentStartupInfo } from './experimentStartupInfo'; import { ExperimentStartupInfo, getExperimentStartupInfo, setExperimentStartupInfo } from './experimentStartupInfo';
import { ExperimentParams, Manager } from './manager'; import { ExperimentConfig, Manager } from './manager';
import { ExperimentManager } from './experimentManager'; import { ExperimentManager } from './experimentManager';
import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService'; import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService';
import { logLevelNameMap } from './log'; import { logLevelNameMap } from './log';
...@@ -159,7 +159,7 @@ function getCmdPy(): string { ...@@ -159,7 +159,7 @@ function getCmdPy(): string {
* @param expParams: experiment startup parameters * @param expParams: experiment startup parameters
* *
*/ */
function getMsgDispatcherCommand(expParams: ExperimentParams): string { function getMsgDispatcherCommand(expParams: ExperimentConfig): string {
const clonedParams = Object.assign({}, expParams); const clonedParams = Object.assign({}, expParams);
delete clonedParams.searchSpace; delete clonedParams.searchSpace;
return `${getCmdPy()} -m nni --exp_params ${Buffer.from(JSON.stringify(clonedParams)).toString('base64')}`; return `${getCmdPy()} -m nni --exp_params ${Buffer.from(JSON.stringify(clonedParams)).toString('base64')}`;
...@@ -332,8 +332,8 @@ async function getVersion(): Promise<string> { ...@@ -332,8 +332,8 @@ async function getVersion(): Promise<string> {
const deferred: Deferred<string> = new Deferred<string>(); const deferred: Deferred<string> = new Deferred<string>();
import(path.join(__dirname, '..', 'package.json')).then((pkg) => { import(path.join(__dirname, '..', 'package.json')).then((pkg) => {
deferred.resolve(pkg.version); deferred.resolve(pkg.version);
}).catch((error) => { }).catch(() => {
deferred.reject(error); deferred.resolve('999.0.0-developing');
}); });
return deferred.promise; return deferred.promise;
} }
......
...@@ -11,7 +11,7 @@ import { Database, DataStore, MetricData, MetricDataRecord, MetricType, ...@@ -11,7 +11,7 @@ import { Database, DataStore, MetricData, MetricDataRecord, MetricType,
TrialJobEvent, TrialJobEventRecord, TrialJobInfo, HyperParameterFormat, TrialJobEvent, TrialJobEventRecord, TrialJobInfo, HyperParameterFormat,
ExportedDataFormat } from '../common/datastore'; ExportedDataFormat } from '../common/datastore';
import { NNIError } from '../common/errors'; import { NNIError } from '../common/errors';
import { getExperimentId, isNewExperiment } from '../common/experimentStartupInfo'; import { isNewExperiment } from '../common/experimentStartupInfo';
import { getLogger, Logger } from '../common/log'; import { getLogger, Logger } from '../common/log';
import { ExperimentProfile, TrialJobStatistics } from '../common/manager'; import { ExperimentProfile, TrialJobStatistics } from '../common/manager';
import { TrialJobDetail, TrialJobStatus } from '../common/trainingService'; import { TrialJobDetail, TrialJobStatus } from '../common/trainingService';
...@@ -21,7 +21,6 @@ class NNIDataStore implements DataStore { ...@@ -21,7 +21,6 @@ class NNIDataStore implements DataStore {
private db: Database = component.get(Database); private db: Database = component.get(Database);
private log: Logger = getLogger(); private log: Logger = getLogger();
private initTask!: Deferred<void>; private initTask!: Deferred<void>;
private multiPhase: boolean | undefined;
public init(): Promise<void> { public init(): Promise<void> {
if (this.initTask !== undefined) { if (this.initTask !== undefined) {
...@@ -241,16 +240,10 @@ class NNIDataStore implements DataStore { ...@@ -241,16 +240,10 @@ class NNIDataStore implements DataStore {
const map: Map<string, MetricDataRecord[]> = new Map(); const map: Map<string, MetricDataRecord[]> = new Map();
const metrics: MetricDataRecord[] = await this.getMetricData(trialJobId, 'FINAL'); const metrics: MetricDataRecord[] = await this.getMetricData(trialJobId, 'FINAL');
const multiPhase: boolean = await this.isMultiPhase();
for (const metric of metrics) { for (const metric of metrics) {
const existMetrics: MetricDataRecord[] | undefined = map.get(metric.trialJobId); const existMetrics: MetricDataRecord[] | undefined = map.get(metric.trialJobId);
if (existMetrics !== undefined) { if (existMetrics !== undefined) {
if (!multiPhase) { this.log.error(`Found multiple FINAL results for trial job ${trialJobId}, metrics: ${JSON.stringify(metrics)}`);
this.log.error(`Found multiple FINAL results for trial job ${trialJobId}, metrics: ${JSON.stringify(metrics)}`);
} else {
existMetrics.push(metric);
}
} else { } else {
map.set(metric.trialJobId, [metric]); map.set(metric.trialJobId, [metric]);
} }
...@@ -259,23 +252,6 @@ class NNIDataStore implements DataStore { ...@@ -259,23 +252,6 @@ class NNIDataStore implements DataStore {
return map; return map;
} }
/**
 * Reports whether the experiment profile marks this experiment as multi-phase.
 *
 * The flag is read from the stored profile of the current experiment and kept in
 * `this.multiPhase` once it has a defined value. When the profile or the flag is
 * missing, the answer is `false`.
 *
 * @returns the stored `multiPhase` flag, or `false` when it is unavailable
 */
private async isMultiPhase(): Promise<boolean> {
    if (this.multiPhase === undefined) {
        const profile: ExperimentProfile = await this.getExperimentProfile(getExperimentId());
        if (profile === undefined) {
            // No profile available: answer false and leave the cached flag unset.
            return false;
        }
        this.multiPhase = profile.params.multiPhase;
    }
    return this.multiPhase !== undefined ? this.multiPhase : false;
}
private getJobStatusByLatestEvent(oldStatus: TrialJobStatus, event: TrialJobEvent): TrialJobStatus { private getJobStatusByLatestEvent(oldStatus: TrialJobStatus, event: TrialJobEvent): TrialJobStatus {
switch (event) { switch (event) {
case 'USER_TO_CANCEL': case 'USER_TO_CANCEL':
......
...@@ -12,9 +12,10 @@ import { NNIError } from '../common/errors'; ...@@ -12,9 +12,10 @@ import { NNIError } from '../common/errors';
import { getExperimentId, getDispatcherPipe } from '../common/experimentStartupInfo'; import { getExperimentId, getDispatcherPipe } from '../common/experimentStartupInfo';
import { getLogger, Logger } from '../common/log'; import { getLogger, Logger } from '../common/log';
import { import {
ExperimentParams, ExperimentProfile, Manager, ExperimentStatus, ExperimentProfile, Manager, ExperimentStatus,
NNIManagerStatus, ProfileUpdateType, TrialJobStatistics NNIManagerStatus, ProfileUpdateType, TrialJobStatistics
} from '../common/manager'; } from '../common/manager';
import { ExperimentConfig, toSeconds, toCudaVisibleDevices } from '../common/experimentConfig';
import { ExperimentManager } from '../common/experimentManager'; import { ExperimentManager } from '../common/experimentManager';
import { TensorboardManager } from '../common/tensorboardManager'; import { TensorboardManager } from '../common/tensorboardManager';
import { import {
...@@ -32,29 +33,28 @@ import { NNIRestServer } from '../rest_server/nniRestServer'; ...@@ -32,29 +33,28 @@ import { NNIRestServer } from '../rest_server/nniRestServer';
* NNIManager which implements Manager interface * NNIManager which implements Manager interface
*/ */
class NNIManager implements Manager { class NNIManager implements Manager {
private trainingService: TrainingService; private trainingService!: TrainingService;
private dispatcher: IpcInterface | undefined; private dispatcher: IpcInterface | undefined;
private experimentManager: ExperimentManager; private experimentManager: ExperimentManager;
private currSubmittedTrialNum: number; // need to be recovered private currSubmittedTrialNum: number; // need to be recovered
private trialConcurrencyChange: number; // >0: increase, <0: decrease private trialConcurrencyChange: number; // >0: increase, <0: decrease
private log: Logger; private log: Logger;
private dataStore: DataStore; private dataStore: DataStore;
private experimentProfile: ExperimentProfile; private experimentProfile!: ExperimentProfile;
private dispatcherPid: number; private dispatcherPid: number;
private status: NNIManagerStatus; private status: NNIManagerStatus;
private waitingTrials: TrialJobApplicationForm[]; private waitingTrials: TrialJobApplicationForm[];
private trialJobs: Map<string, TrialJobDetail>; private trialJobs: Map<string, TrialJobDetail>;
private trialDataForTuner: string; private trialDataForTuner: string;
private readonly: boolean; private readonly: boolean;
private config!: ExperimentConfig;
private trialJobMetricListener: (metric: TrialJobMetric) => void; private trialJobMetricListener: (metric: TrialJobMetric) => void;
constructor() { constructor() {
this.currSubmittedTrialNum = 0; this.currSubmittedTrialNum = 0;
this.trialConcurrencyChange = 0; this.trialConcurrencyChange = 0;
this.trainingService = component.get(TrainingService);
this.experimentManager = component.get(ExperimentManager); this.experimentManager = component.get(ExperimentManager);
assert(this.trainingService);
this.dispatcherPid = 0; this.dispatcherPid = 0;
this.waitingTrials = []; this.waitingTrials = [];
this.trialJobs = new Map<string, TrialJobDetail>(); this.trialJobs = new Map<string, TrialJobDetail>();
...@@ -63,7 +63,6 @@ class NNIManager implements Manager { ...@@ -63,7 +63,6 @@ class NNIManager implements Manager {
this.log = getLogger(); this.log = getLogger();
this.dataStore = component.get(DataStore); this.dataStore = component.get(DataStore);
this.experimentProfile = this.createEmptyExperimentProfile();
this.status = { this.status = {
status: 'INITIALIZED', status: 'INITIALIZED',
errors: [] errors: []
...@@ -89,13 +88,13 @@ class NNIManager implements Manager { ...@@ -89,13 +88,13 @@ class NNIManager implements Manager {
this.updateTrialConcurrency(experimentProfile.params.trialConcurrency); this.updateTrialConcurrency(experimentProfile.params.trialConcurrency);
break; break;
case 'MAX_EXEC_DURATION': case 'MAX_EXEC_DURATION':
this.updateMaxExecDuration(experimentProfile.params.maxExecDuration); this.experimentProfile.params.maxExperimentDuration = experimentProfile.params.maxExperimentDuration;
break; break;
case 'SEARCH_SPACE': case 'SEARCH_SPACE':
this.updateSearchSpace(experimentProfile.params.searchSpace); this.updateSearchSpace(experimentProfile.params.searchSpace);
break; break;
case 'MAX_TRIAL_NUM': case 'MAX_TRIAL_NUM':
this.updateMaxTrialNum(experimentProfile.params.maxTrialNum); this.experimentProfile.params.maxTrialNumber = experimentProfile.params.maxTrialNumber;
break; break;
default: default:
throw new Error('Error: unrecognized updateType'); throw new Error('Error: unrecognized updateType');
...@@ -130,7 +129,7 @@ class NNIManager implements Manager { ...@@ -130,7 +129,7 @@ class NNIManager implements Manager {
if (this.readonly) { if (this.readonly) {
return Promise.reject(new Error('Error: can not add customized trial job in readonly mode!')); return Promise.reject(new Error('Error: can not add customized trial job in readonly mode!'));
} }
if (this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) { if (this.currSubmittedTrialNum >= this.maxTrialNum) {
return Promise.reject(new Error('reach maxTrialNum')); return Promise.reject(new Error('reach maxTrialNum'));
} }
...@@ -165,35 +164,30 @@ class NNIManager implements Manager { ...@@ -165,35 +164,30 @@ class NNIManager implements Manager {
await this.dataStore.storeTrialJobEvent('USER_TO_CANCEL', trialJobId, ''); await this.dataStore.storeTrialJobEvent('USER_TO_CANCEL', trialJobId, '');
} }
public async startExperiment(expParams: ExperimentParams): Promise<string> { public async startExperiment(config: ExperimentConfig): Promise<string> {
this.experimentProfile = {
params: config,
id: getExperimentId(),
execDuration: 0,
logDir: getExperimentRootDir(),
startTime: Date.now(),
endTime: undefined,
nextSequenceId: 0,
revision: 0
};
this.log.info(`Starting experiment: ${this.experimentProfile.id}`); this.log.info(`Starting experiment: ${this.experimentProfile.id}`);
this.experimentProfile.params = expParams;
await this.storeExperimentProfile(); await this.storeExperimentProfile();
this.log.debug('Setup tuner...');
// Set up multiphase config this.log.info('Setup training service...');
if (expParams.multiPhase && this.trainingService.isMultiPhaseJobSupported) { this.trainingService = await this.initTrainingService(config);
this.trainingService.setClusterMetadata('multiPhase', expParams.multiPhase.toString());
}
// Set up versionCheck config
if (expParams.versionCheck !== undefined) {
this.trainingService.setClusterMetadata('version_check', expParams.versionCheck.toString());
}
// Set up logCollection config
if (expParams.logCollection !== undefined) {
this.trainingService.setClusterMetadata('log_collection', expParams.logCollection.toString());
}
const dispatcherCommand: string = getMsgDispatcherCommand(expParams); this.log.info('Setup tuner...');
const dispatcherCommand: string = getMsgDispatcherCommand(config);
this.log.debug(`dispatcher command: ${dispatcherCommand}`); this.log.debug(`dispatcher command: ${dispatcherCommand}`);
const checkpointDir: string = await this.createCheckpointDir(); const checkpointDir: string = await this.createCheckpointDir();
this.setupTuner( this.setupTuner(dispatcherCommand, undefined, 'start', checkpointDir);
dispatcherCommand,
undefined,
'start',
checkpointDir);
this.experimentProfile.startTime = Date.now();
this.setStatus('RUNNING'); this.setStatus('RUNNING');
await this.storeExperimentProfile(); await this.storeExperimentProfile();
this.run().catch((err: Error) => { this.run().catch((err: Error) => {
...@@ -212,26 +206,16 @@ class NNIManager implements Manager { ...@@ -212,26 +206,16 @@ class NNIManager implements Manager {
if (readonly) { if (readonly) {
return Promise.resolve(); return Promise.resolve();
} }
const expParams: ExperimentParams = this.experimentProfile.params;
// Set up multiphase config
if (expParams.multiPhase && this.trainingService.isMultiPhaseJobSupported) {
this.trainingService.setClusterMetadata('multiPhase', expParams.multiPhase.toString());
}
// Set up versionCheck config this.log.info('Setup training service...');
if (expParams.versionCheck !== undefined) { const config: ExperimentConfig = this.experimentProfile.params;
this.trainingService.setClusterMetadata('version_check', expParams.versionCheck.toString()); this.trainingService = await this.initTrainingService(config);
}
const dispatcherCommand: string = getMsgDispatcherCommand(expParams); this.log.info('Setup tuner...');
const dispatcherCommand: string = getMsgDispatcherCommand(config);
this.log.debug(`dispatcher command: ${dispatcherCommand}`); this.log.debug(`dispatcher command: ${dispatcherCommand}`);
const checkpointDir: string = await this.createCheckpointDir(); const checkpointDir: string = await this.createCheckpointDir();
this.setupTuner( this.setupTuner(dispatcherCommand, undefined, 'resume', checkpointDir);
dispatcherCommand,
undefined,
'resume',
checkpointDir);
const allTrialJobs: TrialJobInfo[] = await this.dataStore.listTrialJobs(); const allTrialJobs: TrialJobInfo[] = await this.dataStore.listTrialJobs();
...@@ -253,8 +237,8 @@ class NNIManager implements Manager { ...@@ -253,8 +237,8 @@ class NNIManager implements Manager {
} }
this.trialDataForTuner = JSON.stringify(trialData); this.trialDataForTuner = JSON.stringify(trialData);
if (this.experimentProfile.execDuration < this.experimentProfile.params.maxExecDuration && if (this.experimentProfile.execDuration < this.maxDuration &&
this.currSubmittedTrialNum < this.experimentProfile.params.maxTrialNum && this.currSubmittedTrialNum < this.maxTrialNum &&
this.experimentProfile.endTime) { this.experimentProfile.endTime) {
delete this.experimentProfile.endTime; delete this.experimentProfile.endTime;
} }
...@@ -270,27 +254,12 @@ class NNIManager implements Manager { ...@@ -270,27 +254,12 @@ class NNIManager implements Manager {
return this.dataStore.getTrialJob(trialJobId); return this.dataStore.getTrialJob(trialJobId);
} }
public async setClusterMetadata(key: string, value: string): Promise<void> { public async setClusterMetadata(_key: string, _value: string): Promise<void> {
if (this.readonly) { throw new Error('Calling removed API setClusterMetadata');
return Promise.reject(new Error('Error: can not set cluster metadata in readonly mode!'));
}
this.log.info(`NNIManager setClusterMetadata, key: ${key}, value: ${value}`);
let timeoutId: NodeJS.Timer;
// TO DO: move timeout value to constants file
const delay1: Promise<{}> = new Promise((resolve: Function, reject: Function): void => {
timeoutId = setTimeout(
() => { reject(new Error('TrainingService setClusterMetadata timeout. Please check your config file.')); },
30000);
});
await Promise.race([delay1, this.trainingService.setClusterMetadata(key, value)]).finally(() => {
clearTimeout(timeoutId);
});
} }
public getClusterMetadata(key: string): Promise<string> { public getClusterMetadata(_key: string): Promise<string> {
return Promise.resolve( throw new Error('Calling removed API getClusterMetadata');
this.trainingService.getClusterMetadata(key)
);
} }
public async getTrialJobStatistics(): Promise<TrialJobStatistics[]> { public async getTrialJobStatistics(): Promise<TrialJobStatistics[]> {
...@@ -424,6 +393,40 @@ class NNIManager implements Manager { ...@@ -424,6 +393,40 @@ class NNIManager implements Manager {
return this.dataStore.listTrialJobs(status); return this.dataStore.listTrialJobs(status);
} }
/**
 * Experiment duration limit: `params.maxExperimentDuration` converted by `toSeconds`,
 * or `Infinity` when the profile does not set one.
 */
private get maxDuration(): number {
    const limit = this.experimentProfile.params.maxExperimentDuration;
    if (limit === undefined) {
        return Infinity;
    }
    return toSeconds(limit);
}
/**
 * Trial count limit: `params.maxTrialNumber`, or `Infinity` when the profile
 * does not set one.
 */
private get maxTrialNum(): number {
    const limit = this.experimentProfile.params.maxTrialNumber;
    if (limit === undefined) {
        return Infinity;
    }
    return limit;
}
/**
 * Creates the training service that matches the experiment configuration.
 *
 * The configuration is first remembered on `this.config`. The platform is taken as
 * 'hybrid' when `config.trainingService` is an array, otherwise from its `platform`
 * field. The implementing module is imported on demand, so only the selected
 * training service is loaded.
 *
 * @param config experiment configuration describing the training service
 * @returns the newly constructed training service
 * @throws Error when the platform name is not one of the handled values
 */
private async initTrainingService(config: ExperimentConfig): Promise<TrainingService> {
    this.config = config;
    const platform = Array.isArray(config.trainingService) ? 'hybrid' : config.trainingService.platform;

    switch (platform) {
        // These platforms are all served by the router training service.
        case 'remote':
        case 'pai':
        case 'aml':
        case 'hybrid': {
            const routerModule = await import('../training_service/reusable/routerTrainingService');
            return new routerModule.RouterTrainingService(config);
        }
        case 'local': {
            const localModule = await import('../training_service/local/localTrainingService');
            return new localModule.LocalTrainingService(config);
        }
        // The kubernetes-based services below are constructed without the config object.
        case 'kubeflow': {
            const kubeflowModule = await import('../training_service/kubernetes/kubeflow/kubeflowTrainingService');
            return new kubeflowModule.KubeflowTrainingService();
        }
        case 'frameworkcontroller': {
            const fcModule = await import('../training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService');
            return new fcModule.FrameworkControllerTrainingService();
        }
        case 'adl': {
            const adlModule = await import('../training_service/kubernetes/adl/adlTrainingService');
            return new adlModule.AdlTrainingService();
        }
        default:
            throw new Error(`Unsupported training service platform "${platform}"`);
    }
}
private setupTuner(command: string, cwd: string | undefined, mode: 'start' | 'resume', dataDirectory: string): void { private setupTuner(command: string, cwd: string | undefined, mode: 'start' | 'resume', dataDirectory: string): void {
if (this.dispatcher !== undefined) { if (this.dispatcher !== undefined) {
return; return;
...@@ -436,10 +439,7 @@ class NNIManager implements Manager { ...@@ -436,10 +439,7 @@ class NNIManager implements Manager {
newCwd = cwd; newCwd = cwd;
} }
// TO DO: add CUDA_VISIBLE_DEVICES // TO DO: add CUDA_VISIBLE_DEVICES
let includeIntermediateResultsEnv: boolean | undefined = false; const includeIntermediateResultsEnv = !!(this.config.deprecated && this.config.deprecated.includeIntermediateResults);
if (this.experimentProfile.params.tuner !== undefined) {
includeIntermediateResultsEnv = this.experimentProfile.params.tuner.includeIntermediateResults;
}
const nniEnv = { const nniEnv = {
SDK_PROCESS: 'dispatcher', SDK_PROCESS: 'dispatcher',
...@@ -448,7 +448,7 @@ class NNIManager implements Manager { ...@@ -448,7 +448,7 @@ class NNIManager implements Manager {
NNI_LOG_DIRECTORY: getLogDir(), NNI_LOG_DIRECTORY: getLogDir(),
NNI_LOG_LEVEL: getLogLevel(), NNI_LOG_LEVEL: getLogLevel(),
NNI_INCLUDE_INTERMEDIATE_RESULTS: includeIntermediateResultsEnv, NNI_INCLUDE_INTERMEDIATE_RESULTS: includeIntermediateResultsEnv,
CUDA_VISIBLE_DEVICES: this.getGpuEnvvarValue() CUDA_VISIBLE_DEVICES: toCudaVisibleDevices(this.experimentProfile.params.tunerGpuIndices)
}; };
const newEnv = Object.assign({}, process.env, nniEnv); const newEnv = Object.assign({}, process.env, nniEnv);
const tunerProc: ChildProcess = getTunerProc(command, stdio, newCwd, newEnv); const tunerProc: ChildProcess = getTunerProc(command, stdio, newCwd, newEnv);
...@@ -458,22 +458,6 @@ class NNIManager implements Manager { ...@@ -458,22 +458,6 @@ class NNIManager implements Manager {
return; return;
} }
/**
 * Picks the GPU index string configured for the dispatcher process.
 *
 * The advisor's `gpuIndices` is used when an advisor is configured, otherwise the
 * tuner's. An empty string is returned when neither supplies a value.
 */
private getGpuEnvvarValue(): string {
    const params = this.experimentProfile.params;
    // A configured advisor wins over the tuner, even if its gpuIndices is unset.
    const algo = params.advisor !== undefined ? params.advisor : params.tuner;
    const gpuIndices: string | undefined = algo !== undefined ? algo.gpuIndices : undefined;
    return gpuIndices === undefined ? '' : gpuIndices;
}
private updateTrialConcurrency(trialConcurrency: number): void { private updateTrialConcurrency(trialConcurrency: number): void {
// we assume trialConcurrency >= 0, which is checked by restserver // we assume trialConcurrency >= 0, which is checked by restserver
this.trialConcurrencyChange += (trialConcurrency - this.experimentProfile.params.trialConcurrency); this.trialConcurrencyChange += (trialConcurrency - this.experimentProfile.params.trialConcurrency);
...@@ -482,12 +466,6 @@ class NNIManager implements Manager { ...@@ -482,12 +466,6 @@ class NNIManager implements Manager {
return; return;
} }
/**
 * Overwrites the experiment profile's `maxExecDuration` limit.
 *
 * @param duration new value stored in `params.maxExecDuration`
 */
private updateMaxExecDuration(duration: number): void {
    const params = this.experimentProfile.params;
    params.maxExecDuration = duration;
}
private updateSearchSpace(searchSpace: string): void { private updateSearchSpace(searchSpace: string): void {
if (this.dispatcher === undefined) { if (this.dispatcher === undefined) {
throw new Error('Error: tuner has not been setup'); throw new Error('Error: tuner has not been setup');
...@@ -498,12 +476,6 @@ class NNIManager implements Manager { ...@@ -498,12 +476,6 @@ class NNIManager implements Manager {
return; return;
} }
/**
 * Overwrites the experiment profile's `maxTrialNum` limit.
 *
 * @param maxTrialNum new value stored in `params.maxTrialNum`
 */
private updateMaxTrialNum(maxTrialNum: number): void {
    const params = this.experimentProfile.params;
    params.maxTrialNum = maxTrialNum;
}
private async periodicallyUpdateExecDuration(): Promise<void> { private async periodicallyUpdateExecDuration(): Promise<void> {
let count: number = 1; let count: number = 1;
while (!['ERROR', 'STOPPING', 'STOPPED'].includes(this.status.status)) { while (!['ERROR', 'STOPPING', 'STOPPED'].includes(this.status.status)) {
...@@ -619,8 +591,8 @@ class NNIManager implements Manager { ...@@ -619,8 +591,8 @@ class NNIManager implements Manager {
this.status.status === 'DONE' || this.status.status === 'DONE' ||
this.status.status === 'NO_MORE_TRIAL' || this.status.status === 'NO_MORE_TRIAL' ||
this.status.status === 'TUNER_NO_MORE_TRIAL', `Actual status: ${this.status.status}`); this.status.status === 'TUNER_NO_MORE_TRIAL', `Actual status: ${this.status.status}`);
if (this.experimentProfile.execDuration > this.experimentProfile.params.maxExecDuration || if (this.experimentProfile.execDuration > this.maxDuration ||
this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) { this.currSubmittedTrialNum >= this.maxTrialNum) {
if (this.status.status !== 'DONE') { if (this.status.status !== 'DONE') {
this.setStatus('NO_MORE_TRIAL'); this.setStatus('NO_MORE_TRIAL');
waitSubmittedToFinish = this.currSubmittedTrialNum; waitSubmittedToFinish = this.currSubmittedTrialNum;
...@@ -644,7 +616,7 @@ class NNIManager implements Manager { ...@@ -644,7 +616,7 @@ class NNIManager implements Manager {
} }
for (let i: number = this.trialJobs.size; i < this.experimentProfile.params.trialConcurrency; i++) { for (let i: number = this.trialJobs.size; i < this.experimentProfile.params.trialConcurrency; i++) {
if (this.waitingTrials.length === 0 || if (this.waitingTrials.length === 0 ||
this.currSubmittedTrialNum >= this.experimentProfile.params.maxTrialNum) { this.currSubmittedTrialNum >= this.maxTrialNum) {
break; break;
} }
const form = this.waitingTrials.shift() as TrialJobApplicationForm; const form = this.waitingTrials.shift() as TrialJobApplicationForm;
...@@ -718,7 +690,7 @@ class NNIManager implements Manager { ...@@ -718,7 +690,7 @@ class NNIManager implements Manager {
} }
this.log.debug(`Send tuner command: INITIALIZE: ${this.experimentProfile.params.searchSpace}`); this.log.debug(`Send tuner command: INITIALIZE: ${this.experimentProfile.params.searchSpace}`);
// Tuner need to be initialized with search space before generating any hyper parameters // Tuner need to be initialized with search space before generating any hyper parameters
this.dispatcher.sendCommand(INITIALIZE, this.experimentProfile.params.searchSpace); this.dispatcher.sendCommand(INITIALIZE, JSON.stringify(this.experimentProfile.params.searchSpace));
} }
private async onTrialJobMetrics(metric: TrialJobMetric): Promise<void> { private async onTrialJobMetrics(metric: TrialJobMetric): Promise<void> {
...@@ -741,7 +713,7 @@ class NNIManager implements Manager { ...@@ -741,7 +713,7 @@ class NNIManager implements Manager {
if (this.dispatcher === undefined) { if (this.dispatcher === undefined) {
throw new Error('Dispatcher error: tuner has not been setup'); throw new Error('Dispatcher error: tuner has not been setup');
} }
if (this.experimentProfile.params.multiThread) { if (this.config.deprecated && this.config.deprecated.multiThread) {
// Send multiple requests to ensure multiple hyper parameters are generated in non-blocking way. // Send multiple requests to ensure multiple hyper parameters are generated in non-blocking way.
// For a single REQUEST_TRIAL_JOBS request, hyper parameters are generated one by one // For a single REQUEST_TRIAL_JOBS request, hyper parameters are generated one by one
// sequentially. // sequentially.
...@@ -846,42 +818,11 @@ class NNIManager implements Manager { ...@@ -846,42 +818,11 @@ class NNIManager implements Manager {
this.experimentManager.setExperimentInfo(this.experimentProfile.id, 'endTime', this.experimentProfile.endTime); this.experimentManager.setExperimentInfo(this.experimentProfile.id, 'endTime', this.experimentProfile.endTime);
} }
/**
 * Builds a placeholder experiment profile for the current experiment.
 *
 * The id and log directory come from `getExperimentId()` and `getExperimentRootDir()`.
 * Every counter is zero and every experiment parameter is an empty or zero placeholder.
 */
private createEmptyExperimentProfile(): ExperimentProfile {
    const placeholderParams: ExperimentProfile['params'] = {
        authorName: '',
        experimentName: '',
        trialConcurrency: 0,
        // unit: second
        maxExecDuration: 0,
        // maxTrialNum includes all the submitted trial jobs
        maxTrialNum: 0,
        trainingServicePlatform: '',
        searchSpace: ''
    };
    return {
        id: getExperimentId(),
        revision: 0,
        execDuration: 0,
        logDir: getExperimentRootDir(),
        nextSequenceId: 0,
        params: placeholderParams
    };
}
private async createCheckpointDir(): Promise<string> { private async createCheckpointDir(): Promise<string> {
// TODO: test // TODO: test
const chkpDir: string = getCheckpointDir(); const chkpDir: string = getCheckpointDir();
// create checkpoint directory
await mkDirP(chkpDir); await mkDirP(chkpDir);
// assign this directory to exp profile's checkpointDir return chkpDir;
if (this.experimentProfile.params.advisor) {
this.experimentProfile.params.advisor.checkpointDir = chkpDir;
}
if (this.experimentProfile.params.tuner) {
this.experimentProfile.params.tuner.checkpointDir = chkpDir;
}
if (this.experimentProfile.params.assessor) {
this.experimentProfile.params.assessor.checkpointDir = chkpDir;
}
return Promise.resolve(chkpDir);
} }
public async getTrialOutputLocalPath(trialJobId: string): Promise<string> { public async getTrialOutputLocalPath(trialJobId: string): Promise<string> {
......
...@@ -38,12 +38,13 @@ describe('Unit test for dataStore', () => { ...@@ -38,12 +38,13 @@ describe('Unit test for dataStore', () => {
it('test experiment profiles CRUD', async () => { it('test experiment profiles CRUD', async () => {
const profile: ExperimentProfile = { const profile: ExperimentProfile = {
params: { params: {
authorName: 'test1',
experimentName: 'exp1', experimentName: 'exp1',
trialConcurrency: 2, trialConcurrency: 2,
maxExecDuration: 10, maxExperimentDuration: '10s',
maxTrialNum: 5, maxTrialNumber: 5,
trainingServicePlatform: 'local', trainingService: {
platform: 'local'
},
searchSpace: `{ searchSpace: `{
"dropout_rate": { "dropout_rate": {
"_type": "uniform", "_type": "uniform",
...@@ -55,12 +56,15 @@ describe('Unit test for dataStore', () => { ...@@ -55,12 +56,15 @@ describe('Unit test for dataStore', () => {
} }
}`, }`,
tuner: { tuner: {
className: 'testTuner', className: 'testTuner'
checkpointDir: '/tmp/cp' },
} trialCommand: '',
trialCodeDirectory: '',
debug: true
}, },
id: 'exp123', id: 'exp123',
execDuration: 0, execDuration: 0,
logDir: '',
startTime: Date.now(), startTime: Date.now(),
endTime: Date.now(), endTime: Date.now(),
nextSequenceId: 0, nextSequenceId: 0,
......
...@@ -6,7 +6,7 @@ import * as glob from 'glob'; ...@@ -6,7 +6,7 @@ import * as glob from 'glob';
glob.sync('**/*.ts').forEach((file) => { glob.sync('**/*.ts').forEach((file) => {
if (file.indexOf('node_modules/') < 0 && file.indexOf('types/') < 0 if (file.indexOf('node_modules/') < 0 && file.indexOf('types/') < 0
&& file.indexOf('.test.ts') < 0 && file.indexOf('main.ts')) { && file.indexOf('.test.ts') < 0 && file.indexOf('dlts') < 0 && file.indexOf('main.ts')) {
try { try {
import('../../' + file); import('../../' + file);
} catch(err) { } catch(err) {
......
...@@ -22,24 +22,24 @@ function startProcess(): void { ...@@ -22,24 +22,24 @@ function startProcess(): void {
// Mock tuner config // Mock tuner config
{ {
experimentName: 'exp1', experimentName: 'exp1',
maxExecDuration: 3600, maxExperimentDuration: '1h',
searchSpace: '', searchSpace: '',
trainingServicePlatform: 'local', trainingService: {
authorName: '', platform: 'local'
},
trialConcurrency: 1, trialConcurrency: 1,
maxTrialNum: 5, maxTrialNumber: 5,
tuner: { tuner: {
className: 'DummyTuner', className: 'dummy_tuner.DummyTuner',
codeDir: './', codeDirectory: '.'
classFileName: 'dummy_tuner.py',
checkpointDir: './'
}, },
assessor: { assessor: {
className: 'DummyAssessor', className: 'dummy_assessor.DummyAssessor',
codeDir: './', codeDirectory: '.'
classFileName: 'dummy_assessor.py', },
checkpointDir: './' trialCommand: '',
} trialCodeDirectory: '',
debug: true
} }
); );
const proc: ChildProcess = getTunerProc(dispatcherCmd, stdio, 'core/test', process.env); const proc: ChildProcess = getTunerProc(dispatcherCmd, stdio, 'core/test', process.env);
......
...@@ -25,7 +25,6 @@ import * as path from 'path'; ...@@ -25,7 +25,6 @@ import * as path from 'path';
async function initContainer(): Promise<void> { async function initContainer(): Promise<void> {
prepareUnitTest(); prepareUnitTest();
Container.bind(TrainingService).to(MockedTrainingService).scope(Scope.Singleton);
Container.bind(Manager).to(NNIManager).scope(Scope.Singleton); Container.bind(Manager).to(NNIManager).scope(Scope.Singleton);
Container.bind(Database).to(SqlDB).scope(Scope.Singleton); Container.bind(Database).to(SqlDB).scope(Scope.Singleton);
Container.bind(DataStore).to(MockedDataStore).scope(Scope.Singleton); Container.bind(DataStore).to(MockedDataStore).scope(Scope.Singleton);
...@@ -37,58 +36,62 @@ async function initContainer(): Promise<void> { ...@@ -37,58 +36,62 @@ async function initContainer(): Promise<void> {
describe('Unit test for nnimanager', function () { describe('Unit test for nnimanager', function () {
this.timeout(10000); this.timeout(10000);
let nniManager: Manager; let nniManager: NNIManager;
let ClusterMetadataKey = 'mockedMetadataKey'; let ClusterMetadataKey = 'mockedMetadataKey';
let experimentParams = { let experimentParams = {
authorName: 'zql',
experimentName: 'naive_experiment', experimentName: 'naive_experiment',
trialConcurrency: 3, trialConcurrency: 3,
maxExecDuration: 5, maxExperimentDuration: '5s',
maxTrialNum: 3, maxTrialNumber: 3,
trainingServicePlatform: 'local', trainingService: {
searchSpace: '{"lr": {"_type": "choice", "_value": [0.01,0.001]}}', platform: 'local'
},
searchSpace: {'lr': {'_type': 'choice', '_value': [0.01,0.001]}},
tuner: { tuner: {
builtinTunerName: 'TPE', name: 'TPE',
classArgs: { classArgs: {
optimize_mode: 'maximize' optimize_mode: 'maximize'
}, }
checkpointDir: '',
}, },
assessor: { assessor: {
builtinAssessorName: 'Medianstop', name: 'Medianstop'
checkpointDir: '', },
} trialCommand: 'sleep 2',
trialCodeDirectory: '',
debug: true
} }
let updateExperimentParams = { let updateExperimentParams = {
authorName: '',
experimentName: 'another_experiment', experimentName: 'another_experiment',
trialConcurrency: 2, trialConcurrency: 2,
maxExecDuration: 6, maxExperimentDuration: '6s',
maxTrialNum: 2, maxTrialNumber: 2,
trainingServicePlatform: 'local', trainingService: {
platform: 'local'
},
searchSpace: '{"lr": {"_type": "choice", "_value": [0.01,0.001]}}', searchSpace: '{"lr": {"_type": "choice", "_value": [0.01,0.001]}}',
tuner: { tuner: {
builtinTunerName: 'TPE', name: 'TPE',
classArgs: { classArgs: {
optimize_mode: 'maximize' optimize_mode: 'maximize'
}, }
checkpointDir: '',
gpuNum: 0
}, },
assessor: { assessor: {
builtinAssessorName: 'Medianstop', name: 'Medianstop'
checkpointDir: '', },
gpuNum: 1 trialCommand: 'sleep 2',
} trialCodeDirectory: '',
debug: true
} }
let experimentProfile = { let experimentProfile = {
params: updateExperimentParams, params: updateExperimentParams,
id: 'test', id: 'test',
execDuration: 0, execDuration: 0,
logDir: '',
startTime: 0,
nextSequenceId: 0, nextSequenceId: 0,
revision: 0 revision: 0
} }
...@@ -114,8 +117,20 @@ describe('Unit test for nnimanager', function () { ...@@ -114,8 +117,20 @@ describe('Unit test for nnimanager', function () {
const experimentsManager: ExperimentManager = component.get(ExperimentManager); const experimentsManager: ExperimentManager = component.get(ExperimentManager);
experimentsManager.setExperimentPath('.experiment.test'); experimentsManager.setExperimentPath('.experiment.test');
nniManager = component.get(Manager); nniManager = component.get(Manager);
const expId: string = await nniManager.startExperiment(experimentParams); const expId: string = await nniManager.startExperiment(experimentParams);
assert.strictEqual(expId, 'unittest'); assert.strictEqual(expId, 'unittest');
// TODO:
// In the current architecture we cannot prevent NNI manager from creating a training service.
// The training service must be stopped manually here, or its callbacks will block exit.
// I'm planning a custom training service registration system similar to custom tuners;
// once that is done, we can let NNI manager use MockedTrainingService through config.
const manager = nniManager as any;
manager.trainingService.removeTrialJobMetricListener(manager.trialJobMetricListener);
manager.trainingService.cleanUp();
manager.trainingService = new MockedTrainingService();
}) })
after(async () => { after(async () => {
...@@ -160,28 +175,11 @@ describe('Unit test for nnimanager', function () { ...@@ -160,28 +175,11 @@ describe('Unit test for nnimanager', function () {
}) })
}) })
it('test getClusterMetadata', () => {
//default value is "default"
return nniManager.getClusterMetadata(ClusterMetadataKey).then(function (value) {
expect(value).to.equal("default");
});
})
it('test setClusterMetadata and getClusterMetadata', () => {
//set a valid key
return nniManager.setClusterMetadata(ClusterMetadataKey, "newdata").then(() => {
return nniManager.getClusterMetadata(ClusterMetadataKey).then(function (value) {
expect(value).to.equal("newdata");
});
}).catch((error) => {
console.log(error);
})
})
// Cancelling a trial job by ID must resolve; any rejection is turned into a test failure.
it('test cancelTrialJobByUser', async () => {
    await nniManager.cancelTrialJobByUser('1234').catch((error) => {
        assert.fail(error);
    });
})
...@@ -209,7 +207,7 @@ describe('Unit test for nnimanager', function () { ...@@ -209,7 +207,7 @@ describe('Unit test for nnimanager', function () {
// Updating MAX_EXEC_DURATION must be reflected in the profile returned by the manager.
// Both calls are awaited so that a failed expectation (or a rejection) actually fails
// the test instead of surfacing as a detached, unhandled promise rejection.
it('test updateExperimentProfile MAX_EXEC_DURATION', async () => {
    await nniManager.updateExperimentProfile(experimentProfile, 'MAX_EXEC_DURATION');
    const updateProfile = await nniManager.getExperimentProfile();
    expect(updateProfile.params.maxExperimentDuration).to.be.equal('6s');
})
...@@ -229,9 +227,9 @@ describe('Unit test for nnimanager', function () { ...@@ -229,9 +227,9 @@ describe('Unit test for nnimanager', function () {
// Updating MAX_TRIAL_NUM must be reflected in the profile returned by the manager.
// Both calls are awaited so that a failed expectation (or a rejection) actually fails
// the test instead of surfacing as a detached, unhandled promise rejection.
it('test updateExperimentProfile MAX_TRIAL_NUM', async () => {
    await nniManager.updateExperimentProfile(experimentProfile, 'MAX_TRIAL_NUM');
    const updateProfile = await nniManager.getExperimentProfile();
    expect(updateProfile.params.maxTrialNumber).to.be.equal(2);
})
...@@ -276,8 +274,8 @@ describe('Unit test for nnimanager', function () { ...@@ -276,8 +274,8 @@ describe('Unit test for nnimanager', function () {
}) })
}) })
it('test addCustomizedTrialJob reach maxTrialNum', () => { it('test addCustomizedTrialJob reach maxTrialNumber', () => {
// test currSubmittedTrialNum reach maxTrialNum // test currSubmittedTrialNum reach maxTrialNumber
return nniManager.addCustomizedTrialJob('"hyperParam"').then(() => { return nniManager.addCustomizedTrialJob('"hyperParam"').then(() => {
nniManager.getTrialJobStatistics().then(function (trialJobStatistics) { nniManager.getTrialJobStatistics().then(function (trialJobStatistics) {
if (trialJobStatistics[0].trialJobStatus === 'WAITING') if (trialJobStatistics[0].trialJobStatus === 'WAITING')
......
...@@ -10,40 +10,45 @@ import { Container } from 'typescript-ioc'; ...@@ -10,40 +10,45 @@ import { Container } from 'typescript-ioc';
import * as component from '../../common/component'; import * as component from '../../common/component';
import { Database, MetricDataRecord, TrialJobEvent, TrialJobEventRecord } from '../../common/datastore'; import { Database, MetricDataRecord, TrialJobEvent, TrialJobEventRecord } from '../../common/datastore';
import { setExperimentStartupInfo } from '../../common/experimentStartupInfo'; import { setExperimentStartupInfo } from '../../common/experimentStartupInfo';
import { ExperimentParams, ExperimentProfile } from '../../common/manager'; import { ExperimentConfig, ExperimentProfile } from '../../common/manager';
import { cleanupUnitTest, getDefaultDatabaseDir, mkDirP, prepareUnitTest } from '../../common/utils'; import { cleanupUnitTest, getDefaultDatabaseDir, mkDirP, prepareUnitTest } from '../../common/utils';
import { SqlDB } from '../sqlDatabase'; import { SqlDB } from '../sqlDatabase';
const expParams1: ExperimentParams = { const expParams1: ExperimentConfig = {
authorName: 'ZhangSan',
experimentName: 'Exp1', experimentName: 'Exp1',
trialConcurrency: 3, trialConcurrency: 3,
maxExecDuration: 100, maxExperimentDuration: '100s',
maxTrialNum: 5, maxTrialNumber: 5,
trainingServicePlatform: 'local', trainingService: {
platform: 'local'
},
searchSpace: 'SS', searchSpace: 'SS',
tuner: { tuner: {
className: 'testTuner', className: 'testTuner'
checkpointDir: '/tmp' },
} trialCommand: '',
trialCodeDirectory: '',
debug: true
}; };
const expParams2: ExperimentParams = { const expParams2: ExperimentConfig = {
authorName: 'LiSi',
experimentName: 'Exp2', experimentName: 'Exp2',
trialConcurrency: 5, trialConcurrency: 5,
maxExecDuration: 1000, maxExperimentDuration: '1000s',
maxTrialNum: 5, maxTrialNumber: 5,
trainingServicePlatform: 'local', trainingService: {
platform: 'local'
},
searchSpace: '', searchSpace: '',
tuner: { tuner: {
className: 'testTuner', className: 'testTuner'
checkpointDir: '/tmp'
}, },
assessor: { assessor: {
className: 'testAssessor', className: 'testAssessor'
checkpointDir: '/tmp' },
} trialCommand: '',
trialCodeDirectory: '',
debug: true
}; };
const profiles: ExperimentProfile[] = [ const profiles: ExperimentProfile[] = [
......
...@@ -14,7 +14,6 @@ import { getLogger, Logger, logLevelNameMap } from './common/log'; ...@@ -14,7 +14,6 @@ import { getLogger, Logger, logLevelNameMap } from './common/log';
import { Manager, ExperimentStartUpMode } from './common/manager'; import { Manager, ExperimentStartUpMode } from './common/manager';
import { ExperimentManager } from './common/experimentManager'; import { ExperimentManager } from './common/experimentManager';
import { TensorboardManager } from './common/tensorboardManager'; import { TensorboardManager } from './common/tensorboardManager';
import { TrainingService } from './common/trainingService';
import { getLogDir, mkDirP, parseArg } from './common/utils'; import { getLogDir, mkDirP, parseArg } from './common/utils';
import { NNIDataStore } from './core/nniDataStore'; import { NNIDataStore } from './core/nniDataStore';
import { NNIManager } from './core/nnimanager'; import { NNIManager } from './core/nnimanager';
...@@ -22,12 +21,6 @@ import { SqlDB } from './core/sqlDatabase'; ...@@ -22,12 +21,6 @@ import { SqlDB } from './core/sqlDatabase';
import { NNIExperimentsManager } from './core/nniExperimentsManager'; import { NNIExperimentsManager } from './core/nniExperimentsManager';
import { NNITensorboardManager } from './core/nniTensorboardManager'; import { NNITensorboardManager } from './core/nniTensorboardManager';
import { NNIRestServer } from './rest_server/nniRestServer'; import { NNIRestServer } from './rest_server/nniRestServer';
import { FrameworkControllerTrainingService } from './training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService';
import { AdlTrainingService } from './training_service/kubernetes/adl/adlTrainingService';
import { KubeflowTrainingService } from './training_service/kubernetes/kubeflow/kubeflowTrainingService';
import { LocalTrainingService } from './training_service/local/localTrainingService';
import { RouterTrainingService } from './training_service/reusable/routerTrainingService';
import { DLTSTrainingService } from './training_service/dlts/dltsTrainingService';
function initStartupInfo( function initStartupInfo(
...@@ -38,34 +31,6 @@ function initStartupInfo( ...@@ -38,34 +31,6 @@ function initStartupInfo(
} }
async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> { async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> {
const routerPlatformMode = ['remote', 'pai', 'aml', 'hybrid'];
if (routerPlatformMode.includes(platformMode)) {
Container.bind(TrainingService)
.to(RouterTrainingService)
.scope(Scope.Singleton);
} else if (platformMode === 'local') {
Container.bind(TrainingService)
.to(LocalTrainingService)
.scope(Scope.Singleton);
} else if (platformMode === 'kubeflow') {
Container.bind(TrainingService)
.to(KubeflowTrainingService)
.scope(Scope.Singleton);
} else if (platformMode === 'frameworkcontroller') {
Container.bind(TrainingService)
.to(FrameworkControllerTrainingService)
.scope(Scope.Singleton);
} else if (platformMode === 'dlts') {
Container.bind(TrainingService)
.to(DLTSTrainingService)
.scope(Scope.Singleton);
} else if (platformMode === 'adl') {
Container.bind(TrainingService)
.to(AdlTrainingService)
.scope(Scope.Singleton);
} else {
throw new Error(`Error: unsupported mode: ${platformMode}`);
}
Container.bind(Manager) Container.bind(Manager)
.to(NNIManager) .to(NNIManager)
.scope(Scope.Singleton); .scope(Scope.Singleton);
......
...@@ -40,7 +40,6 @@ class NNIRestHandler { ...@@ -40,7 +40,6 @@ class NNIRestHandler {
router.use((req: Request, res: Response, next) => { router.use((req: Request, res: Response, next) => {
this.log.debug(`${req.method}: ${req.url}: body:\n${JSON.stringify(req.body, undefined, 4)}`); this.log.debug(`${req.method}: ${req.url}: body:\n${JSON.stringify(req.body, undefined, 4)}`);
res.header('Access-Control-Allow-Origin', '*');
res.header('Access-Control-Allow-Headers', 'Origin, X-Requested-With, Content-Type, Accept'); res.header('Access-Control-Allow-Headers', 'Origin, X-Requested-With, Content-Type, Accept');
res.header('Access-Control-Allow-Methods', 'PUT,POST,GET,DELETE,OPTIONS'); res.header('Access-Control-Allow-Methods', 'PUT,POST,GET,DELETE,OPTIONS');
...@@ -139,7 +138,7 @@ class NNIRestHandler { ...@@ -139,7 +138,7 @@ class NNIRestHandler {
} }
private updateExperimentProfile(router: Router): void { private updateExperimentProfile(router: Router): void {
router.put('/experiment', expressJoi(ValidationSchemas.UPDATEEXPERIMENT), (req: Request, res: Response) => { router.put('/experiment', (req: Request, res: Response) => {
this.nniManager.updateExperimentProfile(req.body, req.query.update_type).then(() => { this.nniManager.updateExperimentProfile(req.body, req.query.update_type).then(() => {
res.send(); res.send();
}).catch((err: Error) => { }).catch((err: Error) => {
...@@ -169,7 +168,7 @@ class NNIRestHandler { ...@@ -169,7 +168,7 @@ class NNIRestHandler {
} }
private startExperiment(router: Router): void { private startExperiment(router: Router): void {
router.post('/experiment', expressJoi(ValidationSchemas.STARTEXPERIMENT), (req: Request, res: Response) => { router.post('/experiment', (req: Request, res: Response) => {
if (isNewExperiment()) { if (isNewExperiment()) {
this.nniManager.startExperiment(req.body).then((eid: string) => { this.nniManager.startExperiment(req.body).then((eid: string) => {
res.send({ res.send({
......
...@@ -9,7 +9,7 @@ import { Provider } from 'typescript-ioc'; ...@@ -9,7 +9,7 @@ import { Provider } from 'typescript-ioc';
import { MetricDataRecord, MetricType, TrialJobInfo } from '../../common/datastore'; import { MetricDataRecord, MetricType, TrialJobInfo } from '../../common/datastore';
import { MethodNotImplementedError } from '../../common/errors'; import { MethodNotImplementedError } from '../../common/errors';
import { import {
ExperimentParams, ExperimentProfile, Manager, ProfileUpdateType, ExperimentConfig, ExperimentProfile, Manager, ProfileUpdateType,
TrialJobStatistics, NNIManagerStatus TrialJobStatistics, NNIManagerStatus
} from '../../common/manager'; } from '../../common/manager';
import { import {
...@@ -90,7 +90,7 @@ export class MockedNNIManager extends Manager { ...@@ -90,7 +90,7 @@ export class MockedNNIManager extends Manager {
return Promise.resolve('METAVALUE1'); return Promise.resolve('METAVALUE1');
} }
/**
 * Mocked experiment start.
 *
 * @param experimentParams experiment configuration; not read by this mock
 * @returns a promise resolving to the fixed experiment ID 'id-1234'
 */
public async startExperiment(experimentParams: ExperimentConfig): Promise<string> {
    return 'id-1234';
}
...@@ -135,20 +135,24 @@ export class MockedNNIManager extends Manager { ...@@ -135,20 +135,24 @@ export class MockedNNIManager extends Manager {
public getExperimentProfile(): Promise<ExperimentProfile> { public getExperimentProfile(): Promise<ExperimentProfile> {
const profile: ExperimentProfile = { const profile: ExperimentProfile = {
params: { params: {
authorName: 'test',
experimentName: 'exp1', experimentName: 'exp1',
trialConcurrency: 2, trialConcurrency: 2,
maxExecDuration: 30, maxExperimentDuration: '30s',
maxTrialNum: 3, maxTrialNumber: 3,
trainingServicePlatform: 'local', trainingService: {
platform: 'local'
},
searchSpace: '{lr: 0.01}', searchSpace: '{lr: 0.01}',
tuner: { tuner: {
className: 'testTuner', className: 'testTuner',
checkpointDir: '' },
} trialCommand: '',
trialCodeDirectory: '',
debug: true
}, },
id: '2345', id: '2345',
execDuration: 0, execDuration: 0,
logDir: '',
startTime: Date.now(), startTime: Date.now(),
endTime: Date.now(), endTime: Date.now(),
nextSequenceId: 0, nextSequenceId: 0,
......
...@@ -356,5 +356,9 @@ python3 -m nni.tools.trial_tool.trial_keeper --trial_command '{8}' \ ...@@ -356,5 +356,9 @@ python3 -m nni.tools.trial_tool.trial_keeper --trial_command '{8}' \
return Promise.resolve(result); return Promise.resolve(result);
} }
/**
 * Updating a trial job is not supported by this training service.
 *
 * @returns a promise that always rejects with Error('not supported')
 */
public updateTrialJob(_1: any, _2: any): Promise<TrialJobDetail> {
    return Promise.reject(new Error('not supported'));
}
} }
export { AdlTrainingService }; export { AdlTrainingService };
...@@ -563,6 +563,10 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -563,6 +563,10 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
} }
}; };
} }
/**
 * Updating a trial job is not supported by this training service.
 *
 * @returns a promise that always rejects with Error('not supported')
 */
public updateTrialJob(_1: any, _2: any): Promise<TrialJobDetail> {
    return Promise.reject(new Error('not supported'));
}
} }
export {FrameworkControllerTrainingService}; export {FrameworkControllerTrainingService};
...@@ -463,5 +463,9 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -463,5 +463,9 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
} }
} }
} }
/**
 * Updating a trial job is not supported by this training service.
 *
 * @returns a promise that always rejects with Error('not supported')
 */
public updateTrialJob(_1: any, _2: any): Promise<TrialJobDetail> {
    return Promise.reject(new Error('not supported'));
}
} }
export { KubeflowTrainingService }; export { KubeflowTrainingService };
...@@ -43,7 +43,7 @@ class GPUScheduler { ...@@ -43,7 +43,7 @@ class GPUScheduler {
} }
} }
public getAvailableGPUIndices(useActiveGpu: boolean, occupiedGpuIndexNumMap: Map<number, number>): number[] { public getAvailableGPUIndices(useActiveGpu: boolean | undefined, occupiedGpuIndexNumMap: Map<number, number>): number[] {
if (this.gpuSummary !== undefined) { if (this.gpuSummary !== undefined) {
if (process.platform === 'win32' || useActiveGpu) { if (process.platform === 'win32' || useActiveGpu) {
return this.gpuSummary.gpuInfos.map((info: GPUInfo) => info.index); return this.gpuSummary.gpuInfos.map((info: GPUInfo) => info.index);
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
// Licensed under the MIT license. // Licensed under the MIT license.
'use strict'; 'use strict';
import * as cpp from 'child-process-promise';
import * as cp from 'child_process'; import * as cp from 'child_process';
import { EventEmitter } from 'events'; import { EventEmitter } from 'events';
import * as fs from 'fs'; import * as fs from 'fs';
...@@ -19,8 +18,7 @@ import { ...@@ -19,8 +18,7 @@ import {
import { import {
delay, generateParamFileName, getExperimentRootDir, getJobCancelStatus, getNewLine, isAlive, uniqueString delay, generateParamFileName, getExperimentRootDir, getJobCancelStatus, getNewLine, isAlive, uniqueString
} from '../../common/utils'; } from '../../common/utils';
import { TrialConfig } from '../common/trialConfig'; import { ExperimentConfig, LocalConfig, flattenConfig } from '../../common/experimentConfig';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { execMkdir, execNewFile, getScriptName, runScript, setEnvironmentVariable } from '../common/util'; import { execMkdir, execNewFile, getScriptName, runScript, setEnvironmentVariable } from '../common/util';
import { GPUScheduler } from './gpuScheduler'; import { GPUScheduler } from './gpuScheduler';
...@@ -75,30 +73,13 @@ class LocalTrialJobDetail implements TrialJobDetail { ...@@ -75,30 +73,13 @@ class LocalTrialJobDetail implements TrialJobDetail {
} }
} }
/** interface FlattenLocalConfig extends ExperimentConfig, LocalConfig { }
* Local training service config
*/
export class LocalConfig {
public maxTrialNumPerGpu?: number;
public gpuIndices?: string;
public useActiveGpu?: boolean;
constructor(gpuIndices?: string, maxTrialNumPerGpu?: number, useActiveGpu?: boolean) {
if (gpuIndices !== undefined) {
this.gpuIndices = gpuIndices;
}
if (maxTrialNumPerGpu !== undefined) {
this.maxTrialNumPerGpu = maxTrialNumPerGpu;
}
if (useActiveGpu !== undefined) {
this.useActiveGpu = useActiveGpu;
}
}
}
/** /**
* Local machine training service * Local machine training service
*/ */
class LocalTrainingService implements TrainingService { class LocalTrainingService implements TrainingService {
private readonly config: FlattenLocalConfig;
private readonly eventEmitter: EventEmitter; private readonly eventEmitter: EventEmitter;
private readonly jobMap: Map<string, LocalTrialJobDetail>; private readonly jobMap: Map<string, LocalTrialJobDetail>;
private readonly jobQueue: string[]; private readonly jobQueue: string[];
...@@ -108,29 +89,34 @@ class LocalTrainingService implements TrainingService { ...@@ -108,29 +89,34 @@ class LocalTrainingService implements TrainingService {
private readonly experimentId!: string; private readonly experimentId!: string;
private gpuScheduler!: GPUScheduler; private gpuScheduler!: GPUScheduler;
private readonly occupiedGpuIndexNumMap: Map<number, number>; private readonly occupiedGpuIndexNumMap: Map<number, number>;
private designatedGpuIndices!: Set<number>;
private readonly log: Logger; private readonly log: Logger;
private localTrialConfig?: TrialConfig;
private localConfig?: LocalConfig;
private isMultiPhase: boolean;
private readonly jobStreamMap: Map<string, ts.Stream>; private readonly jobStreamMap: Map<string, ts.Stream>;
private maxTrialNumPerGpu: number;
private useActiveGpu: boolean;
/**
 * Construct the local training service from an experiment config.
 *
 * The config is flattened for the 'local' platform, the in-memory job bookkeeping
 * is initialised, and a GPU scheduler is created only when trials request GPUs.
 *
 * @param config experiment configuration
 * @throws Error if `gpuIndices` is specified as an empty list
 * @throws Error if the experiment root directory does not exist
 */
constructor(config: ExperimentConfig) {
    this.config = flattenConfig<FlattenLocalConfig>(config, 'local');
    this.eventEmitter = new EventEmitter();
    this.jobMap = new Map<string, LocalTrialJobDetail>();
    this.jobQueue = [];
    this.stopping = false;
    this.log = getLogger();
    this.experimentId = getExperimentId();
    this.jobStreamMap = new Map<string, ts.Stream>();
    this.log.info('Construct local machine training service.');
    this.occupiedGpuIndexNumMap = new Map<number, number>();

    // A GPU scheduler is only needed when trials actually request GPUs.
    if (this.config.trialGpuNumber !== undefined && this.config.trialGpuNumber > 0) {
        this.gpuScheduler = new GPUScheduler();
    }

    // `gpuIndices === []` compares references and is always false, so the
    // emptiness check must inspect the array's length instead.
    if (Array.isArray(this.config.gpuIndices) && this.config.gpuIndices.length === 0) {
        throw new Error('gpuIndices cannot be empty when specified.');
    }

    // The root directory is not created here; a missing one is treated as an error.
    this.rootDir = getExperimentRootDir();
    if (!fs.existsSync(this.rootDir)) {
        throw new Error('root dir not created');
    }
    this.initialized = true;
}
public async run(): Promise<void> { public async run(): Promise<void> {
...@@ -236,13 +222,6 @@ class LocalTrainingService implements TrainingService { ...@@ -236,13 +222,6 @@ class LocalTrainingService implements TrainingService {
return trialJobDetail; return trialJobDetail;
} }
/**
* Is multiphase job supported in current training service
*/
public get isMultiPhaseJobSupported(): boolean {
return true;
}
public async cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> { public async cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
const trialJob: LocalTrialJobDetail | undefined = this.jobMap.get(trialJobId); const trialJob: LocalTrialJobDetail | undefined = this.jobMap.get(trialJobId);
if (trialJob === undefined) { if (trialJob === undefined) {
...@@ -272,69 +251,8 @@ class LocalTrainingService implements TrainingService { ...@@ -272,69 +251,8 @@ class LocalTrainingService implements TrainingService {
return Promise.resolve(); return Promise.resolve();
} }
public async setClusterMetadata(key: string, value: string): Promise<void> { public async setClusterMetadata(_key: string, _value: string): Promise<void> { return; }
if (!this.initialized) { public async getClusterMetadata(_key: string): Promise<string> { return ''; }
this.rootDir = getExperimentRootDir();
if (!fs.existsSync(this.rootDir)) {
await cpp.exec(`powershell.exe mkdir ${this.rootDir}`);
}
this.initialized = true;
}
switch (key) {
case TrialConfigMetadataKey.TRIAL_CONFIG:
this.localTrialConfig = <TrialConfig>JSON.parse(value);
// Parse trial config failed, throw Error
if (this.localTrialConfig === undefined) {
throw new Error('trial config parsed failed');
}
if (this.localTrialConfig.gpuNum !== undefined) {
this.log.info(`required GPU number is ${this.localTrialConfig.gpuNum}`);
if (this.gpuScheduler === undefined && this.localTrialConfig.gpuNum > 0) {
this.gpuScheduler = new GPUScheduler();
}
}
break;
case TrialConfigMetadataKey.LOCAL_CONFIG:
this.localConfig = <LocalConfig>JSON.parse(value);
this.log.info(`Specified GPU indices: ${this.localConfig.gpuIndices}`);
if (this.localConfig.gpuIndices !== undefined) {
this.designatedGpuIndices = new Set(this.localConfig.gpuIndices.split(',')
.map((x: string) => parseInt(x, 10)));
if (this.designatedGpuIndices.size === 0) {
throw new Error('gpuIndices can not be empty if specified.');
}
}
if (this.localConfig.maxTrialNumPerGpu !== undefined) {
this.maxTrialNumPerGpu = this.localConfig.maxTrialNumPerGpu;
}
if (this.localConfig.useActiveGpu !== undefined) {
this.useActiveGpu = this.localConfig.useActiveGpu;
}
break;
case TrialConfigMetadataKey.MULTI_PHASE:
this.isMultiPhase = (value === 'true' || value === 'True');
break;
default:
}
}
public getClusterMetadata(key: string): Promise<string> {
switch (key) {
case TrialConfigMetadataKey.TRIAL_CONFIG: {
let getResult: Promise<string>;
if (this.localTrialConfig === undefined) {
getResult = Promise.reject(new NNIError(NNIErrorNames.NOT_FOUND, `${key} is never set yet`));
} else {
getResult = Promise.resolve(JSON.stringify(this.localTrialConfig));
}
return getResult;
}
default:
return Promise.reject(new NNIError(NNIErrorNames.NOT_FOUND, 'Key not found'));
}
}
public async cleanUp(): Promise<void> { public async cleanUp(): Promise<void> {
this.log.info('Stopping local machine training service...'); this.log.info('Stopping local machine training service...');
...@@ -386,9 +304,6 @@ class LocalTrainingService implements TrainingService { ...@@ -386,9 +304,6 @@ class LocalTrainingService implements TrainingService {
trialJobDetail: TrialJobDetail, trialJobDetail: TrialJobDetail,
resource: { gpuIndices: number[] }, resource: { gpuIndices: number[] },
gpuNum: number | undefined): { key: string; value: string }[] { gpuNum: number | undefined): { key: string; value: string }[] {
if (this.localTrialConfig === undefined) {
throw new Error('localTrialConfig is not initialized!');
}
const envVariables: { key: string; value: string }[] = [ const envVariables: { key: string; value: string }[] = [
{ key: 'NNI_PLATFORM', value: 'local' }, { key: 'NNI_PLATFORM', value: 'local' },
{ key: 'NNI_EXP_ID', value: this.experimentId }, { key: 'NNI_EXP_ID', value: this.experimentId },
...@@ -396,8 +311,7 @@ class LocalTrainingService implements TrainingService { ...@@ -396,8 +311,7 @@ class LocalTrainingService implements TrainingService {
{ key: 'NNI_TRIAL_JOB_ID', value: trialJobDetail.id }, { key: 'NNI_TRIAL_JOB_ID', value: trialJobDetail.id },
{ key: 'NNI_OUTPUT_DIR', value: trialJobDetail.workingDirectory }, { key: 'NNI_OUTPUT_DIR', value: trialJobDetail.workingDirectory },
{ key: 'NNI_TRIAL_SEQ_ID', value: trialJobDetail.form.sequenceId.toString() }, { key: 'NNI_TRIAL_SEQ_ID', value: trialJobDetail.form.sequenceId.toString() },
{ key: 'MULTI_PHASE', value: this.isMultiPhase.toString() }, { key: 'NNI_CODE_DIR', value: this.config.trialCodeDirectory}
{ key: 'NNI_CODE_DIR', value: this.localTrialConfig.codeDir}
]; ];
if (gpuNum !== undefined) { if (gpuNum !== undefined) {
envVariables.push({ envVariables.push({
...@@ -414,34 +328,30 @@ class LocalTrainingService implements TrainingService { ...@@ -414,34 +328,30 @@ class LocalTrainingService implements TrainingService {
} }
private tryGetAvailableResource(): [boolean, { gpuIndices: number[]}] { private tryGetAvailableResource(): [boolean, { gpuIndices: number[]}] {
if (this.localTrialConfig === undefined) {
throw new Error('localTrialConfig is not initialized!');
}
const resource: { gpuIndices: number[] } = { gpuIndices: [] }; const resource: { gpuIndices: number[] } = { gpuIndices: [] };
if (this.gpuScheduler === undefined) { if (this.gpuScheduler === undefined) {
return [true, resource]; return [true, resource];
} }
let selectedGPUIndices: number[] = []; let selectedGPUIndices: number[] = [];
const availableGpuIndices: number[] = this.gpuScheduler.getAvailableGPUIndices(this.useActiveGpu, this.occupiedGpuIndexNumMap); const availableGpuIndices: number[] = this.gpuScheduler.getAvailableGPUIndices(this.config.useActiveGpu, this.occupiedGpuIndexNumMap);
for (const index of availableGpuIndices) { for (const index of availableGpuIndices) {
const num: number | undefined = this.occupiedGpuIndexNumMap.get(index); const num: number | undefined = this.occupiedGpuIndexNumMap.get(index);
if (num === undefined || num < this.maxTrialNumPerGpu) { if (num === undefined || num < this.config.maxTrialNumberPerGpu) {
selectedGPUIndices.push(index); selectedGPUIndices.push(index);
} }
} }
if (this.designatedGpuIndices !== undefined) { if (this.config.gpuIndices !== undefined) {
this.checkSpecifiedGpuIndices(); this.checkSpecifiedGpuIndices();
selectedGPUIndices = selectedGPUIndices.filter((index: number) => this.designatedGpuIndices.has(index)); selectedGPUIndices = selectedGPUIndices.filter((index: number) => this.config.gpuIndices!.includes(index));
} }
if (selectedGPUIndices.length < this.localTrialConfig.gpuNum) { if (selectedGPUIndices.length < this.config.trialGpuNumber!) {
return [false, resource]; return [false, resource];
} }
selectedGPUIndices.splice(this.localTrialConfig.gpuNum); selectedGPUIndices.splice(this.config.trialGpuNumber!);
Object.assign(resource, { gpuIndices: selectedGPUIndices }); Object.assign(resource, { gpuIndices: selectedGPUIndices });
return [true, resource]; return [true, resource];
...@@ -449,8 +359,8 @@ class LocalTrainingService implements TrainingService { ...@@ -449,8 +359,8 @@ class LocalTrainingService implements TrainingService {
private checkSpecifiedGpuIndices(): void { private checkSpecifiedGpuIndices(): void {
const gpuCount: number | undefined = this.gpuScheduler.getSystemGpuCount(); const gpuCount: number | undefined = this.gpuScheduler.getSystemGpuCount();
if (this.designatedGpuIndices !== undefined && gpuCount !== undefined) { if (this.config.gpuIndices !== undefined && gpuCount !== undefined) {
for (const index of this.designatedGpuIndices) { for (const index of this.config.gpuIndices) {
if (index >= gpuCount) { if (index >= gpuCount) {
throw new Error(`Specified GPU index not found: ${index}`); throw new Error(`Specified GPU index not found: ${index}`);
} }
...@@ -499,18 +409,18 @@ class LocalTrainingService implements TrainingService { ...@@ -499,18 +409,18 @@ class LocalTrainingService implements TrainingService {
} }
} }
private getScript(localTrialConfig: TrialConfig, workingDirectory: string): string[] { private getScript(workingDirectory: string): string[] {
const script: string[] = []; const script: string[] = [];
if (process.platform === 'win32') { if (process.platform === 'win32') {
script.push(`cd $env:NNI_CODE_DIR`); script.push(`cd $env:NNI_CODE_DIR`);
script.push( script.push(
`cmd.exe /c ${localTrialConfig.command} 2>&1 | Out-File "${path.join(workingDirectory, 'stderr')}" -encoding utf8`, `cmd.exe /c ${this.config.trialCommand} 2>&1 | Out-File "${path.join(workingDirectory, 'stderr')}" -encoding utf8`,
`$NOW_DATE = [int64](([datetime]::UtcNow)-(get-date "1/1/1970")).TotalSeconds`, `$NOW_DATE = [int64](([datetime]::UtcNow)-(get-date "1/1/1970")).TotalSeconds`,
`$NOW_DATE = "$NOW_DATE" + (Get-Date -Format fff).ToString()`, `$NOW_DATE = "$NOW_DATE" + (Get-Date -Format fff).ToString()`,
`Write $LASTEXITCODE " " $NOW_DATE | Out-File "${path.join(workingDirectory, '.nni', 'state')}" -NoNewline -encoding utf8`); `Write $LASTEXITCODE " " $NOW_DATE | Out-File "${path.join(workingDirectory, '.nni', 'state')}" -NoNewline -encoding utf8`);
} else { } else {
script.push(`cd $NNI_CODE_DIR`); script.push(`cd $NNI_CODE_DIR`);
script.push(`eval ${localTrialConfig.command} 2>"${path.join(workingDirectory, 'stderr')}"`); script.push(`eval ${this.config.trialCommand} 2>"${path.join(workingDirectory, 'stderr')}"`);
if (process.platform === 'darwin') { if (process.platform === 'darwin') {
// https://superuser.com/questions/599072/how-to-get-bash-execution-time-in-milliseconds-under-mac-os-x // https://superuser.com/questions/599072/how-to-get-bash-execution-time-in-milliseconds-under-mac-os-x
// Considering the worst case, write 999 to avoid negative duration // Considering the worst case, write 999 to avoid negative duration
...@@ -525,14 +435,8 @@ class LocalTrainingService implements TrainingService { ...@@ -525,14 +435,8 @@ class LocalTrainingService implements TrainingService {
private async runTrialJob(trialJobId: string, resource: {gpuIndices: number[]}): Promise<void> { private async runTrialJob(trialJobId: string, resource: {gpuIndices: number[]}): Promise<void> {
const trialJobDetail: LocalTrialJobDetail = <LocalTrialJobDetail>this.jobMap.get(trialJobId); const trialJobDetail: LocalTrialJobDetail = <LocalTrialJobDetail>this.jobMap.get(trialJobId);
if (this.localTrialConfig === undefined) { const variables: { key: string; value: string }[] = this.getEnvironmentVariables(trialJobDetail, resource, this.config.trialGpuNumber);
throw new Error(`localTrialConfig not initialized!`);
}
const variables: { key: string; value: string }[] = this.getEnvironmentVariables(trialJobDetail, resource, this.localTrialConfig.gpuNum);
if (this.localTrialConfig === undefined) {
throw new Error('trial config is not initialized');
}
const runScriptContent: string[] = []; const runScriptContent: string[] = [];
if (process.platform !== 'win32') { if (process.platform !== 'win32') {
runScriptContent.push('#!/bin/bash'); runScriptContent.push('#!/bin/bash');
...@@ -542,7 +446,7 @@ class LocalTrainingService implements TrainingService { ...@@ -542,7 +446,7 @@ class LocalTrainingService implements TrainingService {
for (const variable of variables) { for (const variable of variables) {
runScriptContent.push(setEnvironmentVariable(variable)); runScriptContent.push(setEnvironmentVariable(variable));
} }
const scripts: string[] = this.getScript(this.localTrialConfig, trialJobDetail.workingDirectory); const scripts: string[] = this.getScript(trialJobDetail.workingDirectory);
scripts.forEach((script: string) => { scripts.forEach((script: string) => {
runScriptContent.push(script); runScriptContent.push(script);
}); });
......
...@@ -8,7 +8,10 @@ import { Deferred } from 'ts-deferred'; ...@@ -8,7 +8,10 @@ import { Deferred } from 'ts-deferred';
import { NNIError, NNIErrorNames } from '../../common/errors'; import { NNIError, NNIErrorNames } from '../../common/errors';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { TrialJobStatus } from '../../common/trainingService'; import { TrialJobStatus } from '../../common/trainingService';
import { PAIClusterConfig, PAITrialJobDetail } from './paiConfig'; import { ExperimentConfig, OpenpaiConfig } from '../../common/experimentConfig';
import { PAITrialJobDetail } from './paiConfig';
/** Combines the fields of ExperimentConfig and OpenpaiConfig in a single flat type. */
interface FlattenOpenpaiConfig extends ExperimentConfig, OpenpaiConfig { }
/** /**
* Collector PAI jobs info from PAI cluster, and update pai job status locally * Collector PAI jobs info from PAI cluster, and update pai job status locally
...@@ -25,8 +28,8 @@ export class PAIJobInfoCollector { ...@@ -25,8 +28,8 @@ export class PAIJobInfoCollector {
this.finalStatuses = ['SUCCEEDED', 'FAILED', 'USER_CANCELED', 'SYS_CANCELED', 'EARLY_STOPPED']; this.finalStatuses = ['SUCCEEDED', 'FAILED', 'USER_CANCELED', 'SYS_CANCELED', 'EARLY_STOPPED'];
} }
public async retrieveTrialStatus(protocol: string, token? : string, paiBaseClusterConfig?: PAIClusterConfig): Promise<void> { public async retrieveTrialStatus(protocol: string, token? : string, config?: FlattenOpenpaiConfig): Promise<void> {
if (paiBaseClusterConfig === undefined || token === undefined) { if (config === undefined || token === undefined) {
return Promise.resolve(); return Promise.resolve();
} }
...@@ -35,13 +38,13 @@ export class PAIJobInfoCollector { ...@@ -35,13 +38,13 @@ export class PAIJobInfoCollector {
if (paiTrialJob === undefined) { if (paiTrialJob === undefined) {
throw new NNIError(NNIErrorNames.NOT_FOUND, `trial job id ${trialJobId} not found`); throw new NNIError(NNIErrorNames.NOT_FOUND, `trial job id ${trialJobId} not found`);
} }
updatePaiTrialJobs.push(this.getSinglePAITrialJobInfo(protocol, paiTrialJob, token, paiBaseClusterConfig)); updatePaiTrialJobs.push(this.getSinglePAITrialJobInfo(protocol, paiTrialJob, token, config));
} }
await Promise.all(updatePaiTrialJobs); await Promise.all(updatePaiTrialJobs);
} }
private getSinglePAITrialJobInfo(protocol: string, paiTrialJob: PAITrialJobDetail, paiToken: string, paiClusterConfig: PAIClusterConfig): Promise<void> { private getSinglePAITrialJobInfo(protocol: string, paiTrialJob: PAITrialJobDetail, paiToken: string, config: FlattenOpenpaiConfig): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>(); const deferred: Deferred<void> = new Deferred<void>();
if (!this.statusesNeedToCheck.includes(paiTrialJob.status)) { if (!this.statusesNeedToCheck.includes(paiTrialJob.status)) {
deferred.resolve(); deferred.resolve();
...@@ -52,7 +55,7 @@ export class PAIJobInfoCollector { ...@@ -52,7 +55,7 @@ export class PAIJobInfoCollector {
// Rest call to get PAI job info and update status // Rest call to get PAI job info and update status
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API // Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
const getJobInfoRequest: request.Options = { const getJobInfoRequest: request.Options = {
uri: `${protocol}://${paiClusterConfig.host}/rest-server/api/v2/jobs/${paiClusterConfig.userName}~${paiTrialJob.paiJobName}`, uri: `${config.host}/rest-server/api/v2/jobs/${config.username}~${paiTrialJob.paiJobName}`,
method: 'GET', method: 'GET',
json: true, json: true,
headers: { headers: {
......
...@@ -18,20 +18,22 @@ import { ...@@ -18,20 +18,22 @@ import {
TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType
} from '../../common/trainingService'; } from '../../common/trainingService';
import { delay } from '../../common/utils'; import { delay } from '../../common/utils';
import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from '../../common/experimentConfig';
import { PAIJobInfoCollector } from './paiJobInfoCollector'; import { PAIJobInfoCollector } from './paiJobInfoCollector';
import { PAIJobRestServer } from './paiJobRestServer'; import { PAIJobRestServer } from './paiJobRestServer';
import { PAIClusterConfig, PAITrialJobDetail, PAI_TRIAL_COMMAND_FORMAT, NNIPAITrialConfig } from './paiConfig'; import { PAITrialJobDetail, PAI_TRIAL_COMMAND_FORMAT } from './paiConfig';
import { String } from 'typescript-string-operations'; import { String } from 'typescript-string-operations';
import { import {
generateParamFileName, generateParamFileName,
getIPV4Address, getVersion, uniqueString getIPV4Address, uniqueString
} from '../../common/utils'; } from '../../common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData'; import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { execMkdir, validateCodeDir, execCopydir } from '../common/util'; import { execMkdir, validateCodeDir, execCopydir } from '../common/util';
const yaml = require('js-yaml'); const yaml = require('js-yaml');
/** Combines the fields of ExperimentConfig and OpenpaiConfig in a single flat type. */
interface FlattenOpenpaiConfig extends ExperimentConfig, OpenpaiConfig { }
/** /**
* Training Service implementation for OpenPAI (Open Platform for AI) * Training Service implementation for OpenPAI (Open Platform for AI)
* Refer https://github.com/Microsoft/pai for more info about OpenPAI * Refer https://github.com/Microsoft/pai for more info about OpenPAI
...@@ -42,7 +44,6 @@ class PAITrainingService implements TrainingService { ...@@ -42,7 +44,6 @@ class PAITrainingService implements TrainingService {
private readonly metricsEmitter: EventEmitter; private readonly metricsEmitter: EventEmitter;
private readonly trialJobsMap: Map<string, PAITrialJobDetail>; private readonly trialJobsMap: Map<string, PAITrialJobDetail>;
private readonly expRootDir: string; private readonly expRootDir: string;
private paiClusterConfig?: PAIClusterConfig;
private readonly jobQueue: string[]; private readonly jobQueue: string[];
private stopping: boolean = false; private stopping: boolean = false;
private paiToken?: string; private paiToken?: string;
...@@ -53,16 +54,15 @@ class PAITrainingService implements TrainingService { ...@@ -53,16 +54,15 @@ class PAITrainingService implements TrainingService {
private paiRestServerPort?: number; private paiRestServerPort?: number;
private nniManagerIpConfig?: NNIManagerIpConfig; private nniManagerIpConfig?: NNIManagerIpConfig;
private versionCheck: boolean = true; private versionCheck: boolean = true;
private logCollection: string; private logCollection: string = 'none';
private isMultiPhase: boolean = false;
private paiJobRestServer?: PAIJobRestServer; private paiJobRestServer?: PAIJobRestServer;
private protocol: string = 'http'; private protocol: string;
private copyExpCodeDirPromise?: Promise<void>; private copyExpCodeDirPromise?: Promise<void>;
private paiJobConfig: any; private paiJobConfig: any;
private nniVersion: string | undefined; private nniVersion: string | undefined;
private paiTrialConfig: NNIPAITrialConfig | undefined; private config: FlattenOpenpaiConfig;
constructor() { constructor(config: ExperimentConfig) {
this.log = getLogger(); this.log = getLogger();
this.metricsEmitter = new EventEmitter(); this.metricsEmitter = new EventEmitter();
this.trialJobsMap = new Map<string, PAITrialJobDetail>(); this.trialJobsMap = new Map<string, PAITrialJobDetail>();
...@@ -71,8 +71,20 @@ class PAITrainingService implements TrainingService { ...@@ -71,8 +71,20 @@ class PAITrainingService implements TrainingService {
this.experimentId = getExperimentId(); this.experimentId = getExperimentId();
this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap); this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);
this.paiTokenUpdateInterval = 7200000; //2hours this.paiTokenUpdateInterval = 7200000; //2hours
this.logCollection = 'none';
this.log.info('Construct paiBase training service.'); this.log.info('Construct paiBase training service.');
this.config = flattenConfig(config, 'openpai');
this.paiJobRestServer = new PAIJobRestServer(this);
this.paiToken = this.config.token;
this.protocol = this.config.host.toLowerCase().startsWith('https://') ? 'https' : 'http';
this.copyExpCodeDirPromise = this.copyTrialCode();
}
/**
 * Validate the trial code directory and copy it into this experiment's
 * `nni-code` folder under the local storage mount point.
 *
 * The destination is `<localStorageMountPoint>/<experimentId>/nni-code`, which
 * mirrors the `<containerStorageMountPoint>/<experimentId>/nni-code` path the
 * trial command uses inside the container (presumably the same storage mounted
 * on both sides -- confirm against the OpenPAI storage setup).
 *
 * @returns a promise that settles once the copy has finished; it rejects if
 *          validation, directory creation or the copy fails.
 */
private async copyTrialCode(): Promise<void> {
    const codeDir = this.config.trialCodeDirectory;
    // Validate the code directory before doing any copying.
    await validateCodeDir(codeDir);
    // The target must live on the mounted storage, NOT inside the code directory
    // itself: building it from trialCodeDirectory would copy the directory into
    // its own subfolder and the container would never see the files.
    const nniManagerNFSExpCodeDir = path.join(this.config.localStorageMountPoint, this.experimentId, 'nni-code');
    await execMkdir(nniManagerNFSExpCodeDir);
    this.log.info(`Starting copy codeDir data from ${codeDir} to ${nniManagerNFSExpCodeDir}`);
    await execCopydir(codeDir, nniManagerNFSExpCodeDir);
}
public async run(): Promise<void> { public async run(): Promise<void> {
...@@ -120,10 +132,6 @@ class PAITrainingService implements TrainingService { ...@@ -120,10 +132,6 @@ class PAITrainingService implements TrainingService {
} }
public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> { public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
if (this.paiClusterConfig === undefined) {
throw new Error('PAI Cluster config is not initialized');
}
const paiTrialJob: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId); const paiTrialJob: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (paiTrialJob === undefined) { if (paiTrialJob === undefined) {
...@@ -141,30 +149,19 @@ class PAITrainingService implements TrainingService { ...@@ -141,30 +149,19 @@ class PAITrainingService implements TrainingService {
this.metricsEmitter.off('metric', listener); this.metricsEmitter.off('metric', listener);
} }
public get isMultiPhaseJobSupported(): boolean {
return true;
}
public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> { public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId); const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
if (trialJobDetail === undefined) { if (trialJobDetail === undefined) {
return Promise.reject(new Error(`cancelTrialJob: trial job id ${trialJobId} not found`)); return Promise.reject(new Error(`cancelTrialJob: trial job id ${trialJobId} not found`));
} }
if (this.paiClusterConfig === undefined) {
return Promise.reject(new Error('PAI Cluster config is not initialized'));
}
if (this.paiToken === undefined) {
return Promise.reject(new Error('PAI token is not initialized'));
}
if (trialJobDetail.status === 'UNKNOWN') { if (trialJobDetail.status === 'UNKNOWN') {
trialJobDetail.status = 'USER_CANCELED'; trialJobDetail.status = 'USER_CANCELED';
return Promise.resolve(); return Promise.resolve();
} }
const stopJobRequest: request.Options = { const stopJobRequest: request.Options = {
uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs/${this.paiClusterConfig.userName}~${trialJobDetail.paiJobName}/executionType`, uri: `${this.config.host}/rest-server/api/v2/jobs/${this.config.username}~${trialJobDetail.paiJobName}/executionType`,
method: 'PUT', method: 'PUT',
json: true, json: true,
body: { value: 'STOP' }, body: { value: 'STOP' },
...@@ -192,10 +189,6 @@ class PAITrainingService implements TrainingService { ...@@ -192,10 +189,6 @@ class PAITrainingService implements TrainingService {
return deferred.promise; return deferred.promise;
} }
public getClusterMetadata(_key: string): Promise<string> {
throw new Error('Not implemented!');
}
public async cleanUp(): Promise<void> { public async cleanUp(): Promise<void> {
this.log.info('Stopping PAI training service...'); this.log.info('Stopping PAI training service...');
this.stopping = true; this.stopping = true;
...@@ -232,18 +225,14 @@ class PAITrainingService implements TrainingService { ...@@ -232,18 +225,14 @@ class PAITrainingService implements TrainingService {
protected async statusCheckingLoop(): Promise<void> { protected async statusCheckingLoop(): Promise<void> {
while (!this.stopping) { while (!this.stopping) {
if (this.paiClusterConfig && this.paiClusterConfig.passWord) { if (this.config.deprecated && this.config.deprecated.password) {
try { try {
await this.updatePaiToken(); await this.updatePaiToken();
} catch (error) { } catch (error) {
this.log.error(`${error}`); this.log.error(`${error}`);
//only throw error when initlize paiToken first time
if (this.paiToken === undefined) {
throw new Error(error);
}
} }
} }
await this.paiJobCollector.retrieveTrialStatus(this.protocol, this.paiToken, this.paiClusterConfig); await this.paiJobCollector.retrieveTrialStatus(this.protocol, this.paiToken, this.config);
if (this.paiJobRestServer === undefined) { if (this.paiJobRestServer === undefined) {
throw new Error('paiBaseJobRestServer not implemented!'); throw new Error('paiBaseJobRestServer not implemented!');
} }
...@@ -266,19 +255,13 @@ class PAITrainingService implements TrainingService { ...@@ -266,19 +255,13 @@ class PAITrainingService implements TrainingService {
return Promise.resolve(); return Promise.resolve();
} }
if (this.paiClusterConfig === undefined) {
const paiClusterConfigError: string = `pai cluster config not initialized!`;
this.log.error(`${paiClusterConfigError}`);
throw Error(`${paiClusterConfigError}`);
}
const authenticationReq: request.Options = { const authenticationReq: request.Options = {
uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v1/token`, uri: `${this.config.host}/rest-server/api/v1/token`,
method: 'POST', method: 'POST',
json: true, json: true,
body: { body: {
username: this.paiClusterConfig.userName, username: this.config.username,
password: this.paiClusterConfig.passWord password: this.config.deprecated.password
} }
}; };
...@@ -309,52 +292,8 @@ class PAITrainingService implements TrainingService { ...@@ -309,52 +292,8 @@ class PAITrainingService implements TrainingService {
.finally(() => { clearTimeout(timeoutId); }); .finally(() => { clearTimeout(timeoutId); });
} }
public async setClusterMetadata(key: string, value: string): Promise<void> { public async setClusterMetadata(_key: string, _value: string): Promise<void> { return; }
switch (key) { public async getClusterMetadata(_key: string): Promise<string> { return ''; }
case TrialConfigMetadataKey.NNI_MANAGER_IP:
this.nniManagerIpConfig = <NNIManagerIpConfig>JSON.parse(value);
break;
case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG:
this.paiJobRestServer = new PAIJobRestServer(component.get(PAITrainingService));
this.paiClusterConfig = <PAIClusterConfig>JSON.parse(value);
this.paiClusterConfig.host = this.formatPAIHost(this.paiClusterConfig.host);
this.paiToken = this.paiClusterConfig.token;
break;
case TrialConfigMetadataKey.TRIAL_CONFIG: {
if (this.paiClusterConfig === undefined) {
this.log.error('pai cluster config is not initialized');
break;
}
this.paiTrialConfig = <NNIPAITrialConfig>JSON.parse(value);
// Validate to make sure codeDir doesn't have too many files
await validateCodeDir(this.paiTrialConfig.codeDir);
const nniManagerNFSExpCodeDir = path.join(this.paiTrialConfig.nniManagerNFSMountPath, this.experimentId, 'nni-code');
await execMkdir(nniManagerNFSExpCodeDir);
//Copy codeDir files to local working folder
this.log.info(`Starting copy codeDir data from ${this.paiTrialConfig.codeDir} to ${nniManagerNFSExpCodeDir}`);
this.copyExpCodeDirPromise = execCopydir(this.paiTrialConfig.codeDir, nniManagerNFSExpCodeDir);
if (this.paiTrialConfig.paiConfigPath) {
this.paiJobConfig = yaml.safeLoad(fs.readFileSync(this.paiTrialConfig.paiConfigPath, 'utf8'));
}
break;
}
case TrialConfigMetadataKey.VERSION_CHECK:
this.versionCheck = (value === 'true' || value === 'True');
this.nniVersion = this.versionCheck ? await getVersion() : '';
break;
case TrialConfigMetadataKey.LOG_COLLECTION:
this.logCollection = value;
break;
case TrialConfigMetadataKey.MULTI_PHASE:
this.isMultiPhase = (value === 'true' || value === 'True');
break;
default:
//Reject for unknown keys
this.log.error(`Uknown key: ${key}`);
}
}
// update trial parameters for multi-phase // update trial parameters for multi-phase
public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> { public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
...@@ -369,21 +308,14 @@ class PAITrainingService implements TrainingService { ...@@ -369,21 +308,14 @@ class PAITrainingService implements TrainingService {
} }
public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> { public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
if (this.paiClusterConfig === undefined) {
throw new Error(`paiClusterConfig not initialized!`);
}
if (this.paiTrialConfig === undefined) {
throw new Error(`paiTrialConfig not initialized!`);
}
this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`); this.log.info(`submitTrialJob: form: ${JSON.stringify(form)}`);
const trialJobId: string = uniqueString(5); const trialJobId: string = uniqueString(5);
//TODO: use HDFS working folder instead //TODO: use HDFS working folder instead
const trialWorkingFolder: string = path.join(this.expRootDir, 'trials', trialJobId); const trialWorkingFolder: string = path.join(this.expRootDir, 'trials', trialJobId);
const paiJobName: string = `nni_exp_${this.experimentId}_trial_${trialJobId}`; const paiJobName: string = `nni_exp_${this.experimentId}_trial_${trialJobId}`;
const logPath: string = path.join(this.paiTrialConfig.nniManagerNFSMountPath, this.experimentId, trialJobId); const logPath: string = path.join(this.config.localStorageMountPoint, this.experimentId, trialJobId);
const paiJobDetailUrl: string = `${this.protocol}://${this.paiClusterConfig.host}/job-detail.html?username=${this.paiClusterConfig.userName}&jobName=${paiJobName}`; const paiJobDetailUrl: string = `${this.config.host}/job-detail.html?username=${this.config.username}&jobName=${paiJobName}`;
const trialJobDetail: PAITrialJobDetail = new PAITrialJobDetail( const trialJobDetail: PAITrialJobDetail = new PAITrialJobDetail(
trialJobId, trialJobId,
'WAITING', 'WAITING',
...@@ -401,12 +333,8 @@ class PAITrainingService implements TrainingService { ...@@ -401,12 +333,8 @@ class PAITrainingService implements TrainingService {
} }
private generateNNITrialCommand(trialJobDetail: PAITrialJobDetail, command: string): string { private generateNNITrialCommand(trialJobDetail: PAITrialJobDetail, command: string): string {
if (this.paiTrialConfig === undefined) { const containerNFSExpCodeDir = `${this.config.containerStorageMountPoint}/${this.experimentId}/nni-code`;
throw new Error('trial config is not initialized'); const containerWorkingDir: string = `${this.config.containerStorageMountPoint}/${this.experimentId}/${trialJobDetail.id}`;
}
const containerNFSExpCodeDir = `${this.paiTrialConfig.containerNFSMountPath}/${this.experimentId}/nni-code`;
const containerWorkingDir: string = `${this.paiTrialConfig.containerNFSMountPath}/${this.experimentId}/${trialJobDetail.id}`;
const nniManagerIp: string = this.nniManagerIpConfig ? this.nniManagerIpConfig.nniManagerIp : getIPV4Address();
const nniPaiTrialCommand: string = String.Format( const nniPaiTrialCommand: string = String.Format(
PAI_TRIAL_COMMAND_FORMAT, PAI_TRIAL_COMMAND_FORMAT,
`${containerWorkingDir}`, `${containerWorkingDir}`,
...@@ -414,10 +342,10 @@ class PAITrainingService implements TrainingService { ...@@ -414,10 +342,10 @@ class PAITrainingService implements TrainingService {
trialJobDetail.id, trialJobDetail.id,
this.experimentId, this.experimentId,
trialJobDetail.form.sequenceId, trialJobDetail.form.sequenceId,
this.isMultiPhase, false, // multi-phase
containerNFSExpCodeDir, containerNFSExpCodeDir,
command, command,
nniManagerIp, this.config.nniManagerIp || getIPV4Address(),
this.paiRestServerPort, this.paiRestServerPort,
this.nniVersion, this.nniVersion,
this.logCollection this.logCollection
...@@ -429,14 +357,11 @@ class PAITrainingService implements TrainingService { ...@@ -429,14 +357,11 @@ class PAITrainingService implements TrainingService {
} }
private generateJobConfigInYamlFormat(trialJobDetail: PAITrialJobDetail): any { private generateJobConfigInYamlFormat(trialJobDetail: PAITrialJobDetail): any {
if (this.paiTrialConfig === undefined) {
throw new Error('trial config is not initialized');
}
const jobName = `nni_exp_${this.experimentId}_trial_${trialJobDetail.id}` const jobName = `nni_exp_${this.experimentId}_trial_${trialJobDetail.id}`
let nniJobConfig: any = undefined; let nniJobConfig: any = undefined;
if (this.paiTrialConfig.paiConfigPath) { if (this.config.openpaiConfig !== undefined) {
nniJobConfig = JSON.parse(JSON.stringify(this.paiJobConfig)); //Trick for deep clone in Typescript nniJobConfig = JSON.parse(JSON.stringify(this.config.openpaiConfig)); //Trick for deep clone in Typescript
nniJobConfig.name = jobName; nniJobConfig.name = jobName;
// Each taskRole will generate new command in NNI's command format // Each taskRole will generate new command in NNI's command format
// Each command will be formatted to NNI style // Each command will be formatted to NNI style
...@@ -455,7 +380,7 @@ class PAITrainingService implements TrainingService { ...@@ -455,7 +380,7 @@ class PAITrainingService implements TrainingService {
prerequisites: [ prerequisites: [
{ {
type: 'dockerimage', type: 'dockerimage',
uri: this.paiTrialConfig.image, uri: this.config.dockerImage,
name: 'docker_image_0' name: 'docker_image_0'
} }
], ],
...@@ -469,27 +394,27 @@ class PAITrainingService implements TrainingService { ...@@ -469,27 +394,27 @@ class PAITrainingService implements TrainingService {
taskRetryCount: 0, taskRetryCount: 0,
dockerImage: 'docker_image_0', dockerImage: 'docker_image_0',
resourcePerInstance: { resourcePerInstance: {
gpu: this.paiTrialConfig.gpuNum, gpu: this.config.trialGpuNumber,
cpu: this.paiTrialConfig.cpuNum, cpu: this.config.trialCpuNumber,
memoryMB: this.paiTrialConfig.memoryMB memoryMB: toMegaBytes(this.config.trialMemorySize)
}, },
commands: [ commands: [
this.generateNNITrialCommand(trialJobDetail, this.paiTrialConfig.command) this.generateNNITrialCommand(trialJobDetail, this.config.trialCommand)
] ]
} }
}, },
extras: { extras: {
'storages': [ 'storages': [
{ {
name: this.paiTrialConfig.paiStorageConfigName name: this.config.storageConfigName
} }
], ],
submitFrom: 'submit-job-v2' submitFrom: 'submit-job-v2'
} }
} }
if (this.paiTrialConfig.virtualCluster) { if (this.config.deprecated && this.config.deprecated.virtualCluster) {
nniJobConfig.defaults = { nniJobConfig.defaults = {
virtualCluster: this.paiTrialConfig.virtualCluster virtualCluster: this.config.deprecated.virtualCluster
} }
} }
} }
...@@ -504,16 +429,6 @@ class PAITrainingService implements TrainingService { ...@@ -504,16 +429,6 @@ class PAITrainingService implements TrainingService {
throw new Error(`Failed to find PAITrialJobDetail for job ${trialJobId}`); throw new Error(`Failed to find PAITrialJobDetail for job ${trialJobId}`);
} }
if (this.paiClusterConfig === undefined) {
throw new Error('PAI Cluster config is not initialized');
}
if (this.paiTrialConfig === undefined) {
throw new Error('trial config is not initialized');
}
if (this.paiToken === undefined) {
throw new Error('PAI token is not initialized');
}
if (this.paiJobRestServer === undefined) { if (this.paiJobRestServer === undefined) {
throw new Error('paiJobRestServer is not initialized'); throw new Error('paiJobRestServer is not initialized');
} }
...@@ -546,7 +461,7 @@ class PAITrainingService implements TrainingService { ...@@ -546,7 +461,7 @@ class PAITrainingService implements TrainingService {
// Step 2. Submit PAI job via Rest call // Step 2. Submit PAI job via Rest call
// Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API // Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
const submitJobRequest: request.Options = { const submitJobRequest: request.Options = {
uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs`, uri: `${this.config.host}/rest-server/api/v2/jobs`,
method: 'POST', method: 'POST',
body: paiJobConfig, body: paiJobConfig,
followAllRedirects: true, followAllRedirects: true,
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
import * as assert from 'assert'; import * as assert from 'assert';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { randomSelect } from '../../common/utils'; import { randomSelect } from '../../common/utils';
import { GPUInfo, parseGpuIndices, ScheduleResultType } from '../common/gpuData'; import { RemoteMachineConfig } from '../../common/experimentConfig';
import { GPUInfo, ScheduleResultType } from '../common/gpuData';
import { ExecutorManager, RemoteMachineMeta, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail } from './remoteMachineData'; import { ExecutorManager, RemoteMachineMeta, RemoteMachineScheduleResult, RemoteMachineTrialJobDetail } from './remoteMachineData';
type SCHEDULE_POLICY_NAME = 'random' | 'round-robin'; type SCHEDULE_POLICY_NAME = 'random' | 'round-robin';
...@@ -16,7 +17,7 @@ type SCHEDULE_POLICY_NAME = 'random' | 'round-robin'; ...@@ -16,7 +17,7 @@ type SCHEDULE_POLICY_NAME = 'random' | 'round-robin';
*/ */
export class GPUScheduler { export class GPUScheduler {
private readonly machineExecutorMap: Map<RemoteMachineMeta, ExecutorManager>; private readonly machineExecutorMap: Map<RemoteMachineConfig, ExecutorManager>;
private readonly log: Logger = getLogger(); private readonly log: Logger = getLogger();
private readonly policyName: SCHEDULE_POLICY_NAME = 'round-robin'; private readonly policyName: SCHEDULE_POLICY_NAME = 'round-robin';
private roundRobinIndex: number = 0; private roundRobinIndex: number = 0;
...@@ -26,10 +27,10 @@ export class GPUScheduler { ...@@ -26,10 +27,10 @@ export class GPUScheduler {
* Constructor * Constructor
* @param machineExecutorMap map from remote machine to executor * @param machineExecutorMap map from remote machine to executor
*/ */
constructor(machineExecutorMap: Map<RemoteMachineMeta, ExecutorManager>) { constructor(machineExecutorMap: Map<RemoteMachineConfig, ExecutorManager>) {
assert(machineExecutorMap.size > 0); assert(machineExecutorMap.size > 0);
this.machineExecutorMap = machineExecutorMap; this.machineExecutorMap = machineExecutorMap;
this.configuredRMs = Array.from(machineExecutorMap.keys()); this.configuredRMs = Array.from(machineExecutorMap.values(), manager => manager.rmMeta);
} }
/** /**
...@@ -41,7 +42,7 @@ export class GPUScheduler { ...@@ -41,7 +42,7 @@ export class GPUScheduler {
requiredGPUNum = 0; requiredGPUNum = 0;
} }
assert(requiredGPUNum >= 0); assert(requiredGPUNum >= 0);
const allRMs: RemoteMachineMeta[] = Array.from(this.machineExecutorMap.keys()); const allRMs: RemoteMachineMeta[] = Array.from(this.machineExecutorMap.values(), manager => manager.rmMeta);
assert(allRMs.length > 0); assert(allRMs.length > 0);
// Step 1: Check that the required GPU number does not exceed the total GPU number in all machines // Step 1: Check that the required GPU number does not exceed the total GPU number in all machines
...@@ -133,11 +134,12 @@ export class GPUScheduler { ...@@ -133,11 +134,12 @@ export class GPUScheduler {
*/ */
private gpuResourceDetection(): Map<RemoteMachineMeta, GPUInfo[]> { private gpuResourceDetection(): Map<RemoteMachineMeta, GPUInfo[]> {
const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = new Map<RemoteMachineMeta, GPUInfo[]>(); const totalResourceMap: Map<RemoteMachineMeta, GPUInfo[]> = new Map<RemoteMachineMeta, GPUInfo[]>();
this.machineExecutorMap.forEach((executorManager: ExecutorManager, rmMeta: RemoteMachineMeta) => { this.machineExecutorMap.forEach((executorManager: ExecutorManager, machineConfig: RemoteMachineConfig) => {
const rmMeta = executorManager.rmMeta;
// Assign total GPU count as initial available GPU number // Assign total GPU count as initial available GPU number
if (rmMeta.gpuSummary !== undefined) { if (rmMeta.gpuSummary !== undefined) {
const availableGPUs: GPUInfo[] = []; const availableGPUs: GPUInfo[] = [];
const designatedGpuIndices: Set<number> | undefined = parseGpuIndices(rmMeta.gpuIndices); const designatedGpuIndices: number[] | undefined = machineConfig.gpuIndices;
if (designatedGpuIndices !== undefined) { if (designatedGpuIndices !== undefined) {
for (const gpuIndex of designatedGpuIndices) { for (const gpuIndex of designatedGpuIndices) {
if (gpuIndex >= rmMeta.gpuSummary.gpuCount) { if (gpuIndex >= rmMeta.gpuSummary.gpuCount) {
...@@ -152,12 +154,11 @@ export class GPUScheduler { ...@@ -152,12 +154,11 @@ export class GPUScheduler {
// or trial number on a GPU reach max number, // or trial number on a GPU reach max number,
// We should NOT allocate this GPU // We should NOT allocate this GPU
// if users set useActiveGpu, use the GPU regardless of whether another process is active on it // if users set useActiveGpu, use the GPU regardless of whether another process is active on it
if (designatedGpuIndices === undefined || designatedGpuIndices.has(gpuInfo.index)) { if (designatedGpuIndices === undefined || designatedGpuIndices.includes(gpuInfo.index)) {
if (rmMeta.occupiedGpuIndexMap !== undefined) { if (rmMeta.occupiedGpuIndexMap !== undefined) {
const num: number | undefined = rmMeta.occupiedGpuIndexMap.get(gpuInfo.index); const num: number | undefined = rmMeta.occupiedGpuIndexMap.get(gpuInfo.index);
const maxTrialNumPerGpu: number = rmMeta.maxTrialNumPerGpu ? rmMeta.maxTrialNumPerGpu : 1; if ((num === undefined && (!machineConfig.useActiveGpu && gpuInfo.activeProcessNum === 0 || machineConfig.useActiveGpu)) ||
if ((num === undefined && (!rmMeta.useActiveGpu && gpuInfo.activeProcessNum === 0 || rmMeta.useActiveGpu)) || (num !== undefined && num < machineConfig.maxTrialNumberPerGpu)) {
(num !== undefined && num < maxTrialNumPerGpu)) {
availableGPUs.push(gpuInfo); availableGPUs.push(gpuInfo);
} }
} else { } else {
...@@ -209,7 +210,7 @@ export class GPUScheduler { ...@@ -209,7 +210,7 @@ export class GPUScheduler {
} }
rmMeta.occupiedGpuIndexMap.set(gpuInfo.index, num + 1); rmMeta.occupiedGpuIndexMap.set(gpuInfo.index, num + 1);
} else { } else {
throw new Error(`Machine ${rmMeta.ip} occupiedGpuIndexMap initialize error!`); throw new Error(`Machine ${rmMeta.config.host} occupiedGpuIndexMap initialize error!`);
} }
}); });
trialJobDetail.gpuIndices = allocatedGPUs; trialJobDetail.gpuIndices = allocatedGPUs;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment