Unverified commit bd346816, authored by SparkSnail, committed by GitHub
Browse files

Fix localTrainingService stream (#885)

parent 892c9c4d
......@@ -103,7 +103,7 @@ class LocalTrainingService implements TrainingService {
protected log: Logger;
protected localTrailConfig?: TrialConfig;
private isMultiPhase: boolean = false;
private streams: Array<ts.Stream>;
protected jobStreamMap: Map<string, ts.Stream>;
constructor() {
this.eventEmitter = new EventEmitter();
......@@ -113,7 +113,7 @@ class LocalTrainingService implements TrainingService {
this.stopping = false;
this.log = getLogger();
this.trialSequenceId = -1;
this.streams = new Array<ts.Stream>();
this.jobStreamMap = new Map<string, ts.Stream>();
this.log.info('Construct local machine training service.');
}
......@@ -307,14 +307,24 @@ class LocalTrainingService implements TrainingService {
/**
 * Shuts down the local training service.
 *
 * Logs the shutdown, raises the `stopping` flag, destroys the streams held by
 * the service and returns an already-resolved promise.
 *
 * @returns a promise that is already resolved when it is returned
 */
public cleanUp(): Promise<void> {
this.log.info('Stopping local machine training service...');
// Raise the stopping flag before the streams are destroyed.
this.stopping = true;
// NOTE(review): the next two lines are two variants of the same loop header (one iterating
// `this.streams`, one iterating `this.jobStreamMap.values()`) and only one closing brace
// follows; this looks like a removed/added pair from a diff — confirm which one should remain.
for (const stream of this.streams) {
for (const stream of this.jobStreamMap.values()) {
stream.destroy();
}
return Promise.resolve();
}
/**
 * Hook invoked when a trial job's status changes.
 *
 * When the job's current status is one of SUCCEEDED, FAILED, USER_CANCELED,
 * SYS_CANCELED or EARLY_STOPPED, the stream registered for the job in
 * `jobStreamMap` is destroyed and its entry is removed from the map. A job
 * with no registered stream is ignored.
 *
 * @param trialJob - the job whose status changed; its `id` is the key into `jobStreamMap`
 * @param oldStatus - the status the job had before the change; not used by this implementation
 */
protected onTrialJobStatusChanged(trialJob: TrialJobDetail, oldStatus: TrialJobStatus): void {
    const finishedStatuses: string[] = ['SUCCEEDED', 'FAILED', 'USER_CANCELED', 'SYS_CANCELED', 'EARLY_STOPPED'];
    if (!finishedStatuses.includes(trialJob.status)) {
        // The job is not finished yet, so its stream must stay open.
        return;
    }

    // A single lookup replaces the former has()/get() pair; `undefined` means no stream
    // is registered for this job, in which case there is nothing to release.
    const stream = this.jobStreamMap.get(trialJob.id);
    if (stream !== undefined) {
        stream.destroy();
        this.jobStreamMap.delete(trialJob.id);
    }
}
protected getEnvironmentVariables(trialJobDetail: TrialJobDetail, _: {}): { key: string; value: string }[] {
......@@ -396,7 +406,8 @@ class LocalTrainingService implements TrainingService {
buffer = remain;
}
});
this.streams.push(stream);
this.jobStreamMap.set(trialJobDetail.id, stream);
}
private async runHostJob(form: HostJobApplicationForm): Promise<TrialJobDetail> {
......
......@@ -78,6 +78,7 @@ class LocalTrainingServiceForGPU extends LocalTrainingService {
}
protected onTrialJobStatusChanged(trialJob: LocalTrialJobDetailForGPU, oldStatus: TrialJobStatus): void {
super.onTrialJobStatusChanged(trialJob, oldStatus);
if (trialJob.gpuIndices !== undefined && trialJob.gpuIndices.length !== 0 && this.gpuScheduler !== undefined) {
if (oldStatus === 'RUNNING' && trialJob.status !== 'RUNNING') {
for (const index of trialJob.gpuIndices) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment