Commit 761638d8 authored by Gems Guo, committed by goooxu
Browse files

Refactor close experiment implementation

parent 5a2721be
...@@ -70,17 +70,23 @@ class Logger { ...@@ -70,17 +70,23 @@ class Logger {
private DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log'); private DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log');
private level: number = DEBUG; private level: number = DEBUG;
private bufferSerialEmitter: BufferSerialEmitter; private bufferSerialEmitter: BufferSerialEmitter;
private writble: Writable;
constructor(fileName?: string) { constructor(fileName?: string) {
let logFile: string | undefined = fileName; let logFile: string | undefined = fileName;
if (logFile === undefined) { if (logFile === undefined) {
logFile = this.DEFAULT_LOGFILE; logFile = this.DEFAULT_LOGFILE;
} }
this.bufferSerialEmitter = new BufferSerialEmitter(fs.createWriteStream(logFile, { this.writble = fs.createWriteStream(logFile, {
flags: 'a+', flags: 'a+',
encoding: 'utf8', encoding: 'utf8',
autoClose: true autoClose: true
})); });
this.bufferSerialEmitter = new BufferSerialEmitter(this.writble);
}
/**
 * Closes the logger by destroying the stream stored in `this.writble`
 * (the file write stream created in the constructor).
 */
public close() {
this.writble.destroy();
} }
public debug(...param: any[]): void { public debug(...param: any[]): void {
......
...@@ -35,7 +35,7 @@ import { ...@@ -35,7 +35,7 @@ import {
import { import {
TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus
} from '../common/trainingService'; } from '../common/trainingService';
import { delay , getLogDir, getMsgDispatcherCommand} from '../common/utils'; import { delay, getLogDir, getMsgDispatcherCommand } from '../common/utils';
import { import {
ADD_CUSTOMIZED_TRIAL_JOB, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS, REPORT_METRIC_DATA, ADD_CUSTOMIZED_TRIAL_JOB, KILL_TRIAL_JOB, NEW_TRIAL_JOB, NO_MORE_TRIAL_JOBS, REPORT_METRIC_DATA,
REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE REQUEST_TRIAL_JOBS, SEND_TRIAL_JOB_PARAMETER, TERMINATE, TRIAL_END, UPDATE_SEARCH_SPACE
...@@ -123,7 +123,7 @@ class NNIManager implements Manager { ...@@ -123,7 +123,7 @@ class NNIManager implements Manager {
this.log.debug('Setup tuner...'); this.log.debug('Setup tuner...');
// Set up multiphase config // Set up multiphase config
if(expParams.multiPhase && this.trainingService.isMultiPhaseJobSupported) { if (expParams.multiPhase && this.trainingService.isMultiPhaseJobSupported) {
this.trainingService.setClusterMetadata('multiPhase', expParams.multiPhase.toString()); this.trainingService.setClusterMetadata('multiPhase', expParams.multiPhase.toString());
} }
...@@ -217,10 +217,9 @@ class NNIManager implements Manager { ...@@ -217,10 +217,9 @@ class NNIManager implements Manager {
return this.dataStore.getTrialJobStatistics(); return this.dataStore.getTrialJobStatistics();
} }
public stopExperiment(): Promise<void> { public async stopExperiment(): Promise<void> {
this.status.status = 'STOPPING'; this.status.status = 'STOPPING';
await this.experimentDoneCleanUp();
return Promise.resolve();
} }
public async getMetricData(trialJobId?: string, metricType?: MetricType): Promise<MetricDataRecord[]> { public async getMetricData(trialJobId?: string, metricType?: MetricType): Promise<MetricDataRecord[]> {
...@@ -342,7 +341,7 @@ class NNIManager implements Manager { ...@@ -342,7 +341,7 @@ class NNIManager implements Manager {
private async periodicallyUpdateExecDuration(): Promise<void> { private async periodicallyUpdateExecDuration(): Promise<void> {
let count: number = 1; let count: number = 1;
for (; ;) { while (this.status.status !== 'STOPPING') {
await delay(1000 * 1); // 1 seconds await delay(1000 * 1); // 1 seconds
if (this.status.status === 'EXPERIMENT_RUNNING') { if (this.status.status === 'EXPERIMENT_RUNNING') {
this.experimentProfile.execDuration += 1; this.experimentProfile.execDuration += 1;
...@@ -396,10 +395,7 @@ class NNIManager implements Manager { ...@@ -396,10 +395,7 @@ class NNIManager implements Manager {
throw new Error('Error: tuner has not been setup'); throw new Error('Error: tuner has not been setup');
} }
let allFinishedTrialJobNum: number = 0; let allFinishedTrialJobNum: number = 0;
for (; ;) { while (this.status.status !== 'STOPPING') {
if (this.status.status === 'STOPPING') {
break;
}
const finishedTrialJobNum: number = await this.requestTrialJobsStatus(); const finishedTrialJobNum: number = await this.requestTrialJobsStatus();
allFinishedTrialJobNum += finishedTrialJobNum; allFinishedTrialJobNum += finishedTrialJobNum;
......
...@@ -50,7 +50,7 @@ async function initContainer(platformMode: string): Promise<void> { ...@@ -50,7 +50,7 @@ async function initContainer(platformMode: string): Promise<void> {
Container.bind(TrainingService).to(LocalTrainingServiceForGPU).scope(Scope.Singleton); Container.bind(TrainingService).to(LocalTrainingServiceForGPU).scope(Scope.Singleton);
} else if (platformMode === 'remote') { } else if (platformMode === 'remote') {
Container.bind(TrainingService).to(RemoteMachineTrainingService).scope(Scope.Singleton); Container.bind(TrainingService).to(RemoteMachineTrainingService).scope(Scope.Singleton);
} else if (platformMode === 'pai'){ } else if (platformMode === 'pai') {
Container.bind(TrainingService).to(PAITrainingService).scope(Scope.Singleton); Container.bind(TrainingService).to(PAITrainingService).scope(Scope.Singleton);
} else { } else {
throw new Error(`Error: unsupported mode: ${mode}`); throw new Error(`Error: unsupported mode: ${mode}`);
...@@ -108,3 +108,12 @@ mkDirP(getLogDir()).then(async () => { ...@@ -108,3 +108,12 @@ mkDirP(getLogDir()).then(async () => {
}).catch((err: Error) => { }).catch((err: Error) => {
console.error(`Failed to create log dir: ${err.stack}`); console.error(`Failed to create log dir: ${err.stack}`);
}); });
/**
 * Graceful shutdown on SIGTERM: closes the data store, stops the REST server,
 * then closes the logger, in that order.
 *
 * Installing this listener replaces Node's default "terminate on SIGTERM"
 * behaviour, so every step must be attempted even if an earlier one fails;
 * otherwise a single rejected promise would leave the remaining resources
 * open and surface only as an unhandled rejection.
 */
process.on('SIGTERM', async () => {
    const shutdownSteps: [string, () => Promise<void>][] = [
        ['data store', async (): Promise<void> => {
            const ds: DataStore = component.get(DataStore);
            await ds.close();
        }],
        ['REST server', async (): Promise<void> => {
            const restServer: NNIRestServer = component.get(NNIRestServer);
            await restServer.stop();
        }],
        // The logger goes last so that earlier steps can still write to it.
        ['logger', async (): Promise<void> => {
            const log: Logger = getLogger();
            log.close();
        }]
    ];

    for (const [name, step] of shutdownSteps) {
        try {
            await step();
        } catch (err) {
            // Report on stderr: the logger may already be closed or may itself be the failing step.
            console.error(`Failed to close ${name} during shutdown: ${err}`);
            // Signal the failure to the parent without cutting off the remaining steps.
            process.exitCode = 1;
        }
    }
});
\ No newline at end of file
...@@ -164,8 +164,6 @@ class NNIRestHandler { ...@@ -164,8 +164,6 @@ class NNIRestHandler {
await this.tb.cleanUp(); await this.tb.cleanUp();
await this.nniManager.stopExperiment(); await this.nniManager.stopExperiment();
res.send(); res.send();
this.log.debug('Stopping rest server');
await this.restServer.stop();
} catch (err) { } catch (err) {
this.handle_error(err, res); this.handle_error(err, res);
} }
......
...@@ -26,7 +26,7 @@ import { EventEmitter } from 'events'; ...@@ -26,7 +26,7 @@ import { EventEmitter } from 'events';
import * as fs from 'fs'; import * as fs from 'fs';
import * as path from 'path'; import * as path from 'path';
import * as ts from 'tail-stream'; import * as ts from 'tail-stream';
import { MethodNotImplementedError, NNIError, NNIErrorNames } from '../../common/errors'; import { NNIError, NNIErrorNames } from '../../common/errors';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { TrialConfig } from '../common/trialConfig'; import { TrialConfig } from '../common/trialConfig';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
...@@ -103,6 +103,7 @@ class LocalTrainingService implements TrainingService { ...@@ -103,6 +103,7 @@ class LocalTrainingService implements TrainingService {
protected log: Logger; protected log: Logger;
protected localTrailConfig?: TrialConfig; protected localTrailConfig?: TrialConfig;
private isMultiPhase: boolean = false; private isMultiPhase: boolean = false;
private streams: Array<ts.Stream>;
constructor() { constructor() {
this.eventEmitter = new EventEmitter(); this.eventEmitter = new EventEmitter();
...@@ -112,6 +113,7 @@ class LocalTrainingService implements TrainingService { ...@@ -112,6 +113,7 @@ class LocalTrainingService implements TrainingService {
this.stopping = false; this.stopping = false;
this.log = getLogger(); this.log = getLogger();
this.trialSequenceId = -1; this.trialSequenceId = -1;
this.streams = new Array<ts.Stream>();
} }
public async run(): Promise<void> { public async run(): Promise<void> {
...@@ -295,7 +297,9 @@ class LocalTrainingService implements TrainingService { ...@@ -295,7 +297,9 @@ class LocalTrainingService implements TrainingService {
/**
 * Flags the service as stopping and destroys every stream held in `this.streams`.
 *
 * @returns an already-resolved promise
 */
public cleanUp(): Promise<void> { public cleanUp(): Promise<void> {
this.stopping = true; this.stopping = true;
// Destroy each stream that was pushed onto this.streams.
for(const stream of this.streams) {
stream.destroy();
}
return Promise.resolve(); return Promise.resolve();
} }
...@@ -382,6 +386,7 @@ class LocalTrainingService implements TrainingService { ...@@ -382,6 +386,7 @@ class LocalTrainingService implements TrainingService {
buffer = remain; buffer = remain;
} }
}); });
this.streams.push(stream);
} }
private async runHostJob(form: HostJobApplicationForm): Promise<TrialJobDetail> { private async runHostJob(form: HostJobApplicationForm): Promise<TrialJobDetail> {
......
...@@ -136,7 +136,7 @@ export namespace HDFSClientUtility { ...@@ -136,7 +136,7 @@ export namespace HDFSClientUtility {
let timeoutId : NodeJS.Timer let timeoutId : NodeJS.Timer
const delayTimeout : Promise<boolean> = new Promise<boolean>((resolve : Function, reject : Function) : void => { const delayTimeout : Promise<boolean> = new Promise<boolean>((resolve : Function, reject : Function) : void => {
// Set timeout and reject the promise once reach timeout (5 seconds) // Set timeout and reject the promise once reach timeout (5 seconds)
setTimeout(() => deferred.reject(`Check HDFS path ${hdfsPath} exists timeout`), 5000); timeoutId = setTimeout(() => deferred.reject(`Check HDFS path ${hdfsPath} exists timeout`), 5000);
}); });
return Promise.race([deferred.promise, delayTimeout]).finally(() => clearTimeout(timeoutId)); return Promise.race([deferred.promise, delayTimeout]).finally(() => clearTimeout(timeoutId));
......
// Ambient declaration for the 'tail-stream' module: a Stream that emits 'data'
// buffers and can be destroyed, plus the createReadStream factory.
declare module 'tail-stream' { declare module 'tail-stream' {
export interface Stream { export interface Stream {
on(type: 'data', callback: (data: Buffer) => void): void; on(type: 'data', callback: (data: Buffer) => void): void;
// Declared so callers can tear a stream down explicitly.
destroy(): void;
} }
export function createReadStream(path: string): Stream; export function createReadStream(path: string): Stream;
} }
\ No newline at end of file
...@@ -190,7 +190,7 @@ def stop_experiment(args): ...@@ -190,7 +190,7 @@ def stop_experiment(args):
time.sleep(3) time.sleep(3)
rest_pid = nni_config.get_config('restServerPid') rest_pid = nni_config.get_config('restServerPid')
if rest_pid: if rest_pid:
stop_rest_cmds = ['pkill', '-P', str(rest_pid)] stop_rest_cmds = ['kill', str(rest_pid)]
call(stop_rest_cmds) call(stop_rest_cmds)
tensorboard_pid_list = nni_config.get_config('tensorboardPidList') tensorboard_pid_list = nni_config.get_config('tensorboardPidList')
if tensorboard_pid_list: if tensorboard_pid_list:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment