paiTrainingService.ts 12 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
3

4
'use strict';
5
6
7

import * as path from 'path';
import * as request from 'request';
8
import * as component from '../../common/component';
9
10

import { EventEmitter } from 'events';
11
import { Deferred } from 'ts-deferred';
12
import { getExperimentId } from '../../common/experimentStartupInfo';
13
14
import { getLogger, Logger } from '../../common/log';
import {
15
    NNIManagerIpConfig, TrainingService,
16
    TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
17
} from '../../common/trainingService';
18
import { delay } from '../../common/utils';
19
import { PAIJobInfoCollector } from './paiJobInfoCollector';
20
import { PAIJobRestServer } from './paiJobRestServer';
21
import { PAIClusterConfig, PAITrialJobDetail } from './paiConfig';
22
23
24
25
26
27

/**
 * Training Service implementation for OpenPAI (Open Platform for AI)
 * Refer https://github.com/Microsoft/pai for more info about OpenPAI
 *
 * This abstract base class holds the state and the loops shared by the PAI
 * training services: the trial job map, the submission queue, the PAI token
 * refresh logic and the REST server lifecycle. Operations that throw
 * 'Not implemented!' here are expected to be provided by subclasses.
 */
@component.Singleton
abstract class PAITrainingService implements TrainingService {
    protected readonly log!: Logger;
    protected readonly metricsEmitter: EventEmitter;
    protected readonly trialJobsMap: Map<string, PAITrialJobDetail>;
    protected readonly expRootDir: string;
    protected paiClusterConfig?: PAIClusterConfig;
    // Ids of trial jobs waiting to be submitted; consumed in FIFO order by submitJobLoop()
    protected readonly jobQueue: string[];
    // Set by cleanUp() to make statusCheckingLoop() and submitJobLoop() exit
    protected stopping: boolean = false;
    protected paiToken?: string;
    // Timestamp (ms since epoch) of the last successful token retrieval
    protected paiTokenUpdateTime?: number;
    protected readonly paiTokenUpdateInterval: number;
    protected readonly experimentId!: string;
    protected readonly paiJobCollector: PAIJobInfoCollector;
    protected paiRestServerPort?: number;
    protected nniManagerIpConfig?: NNIManagerIpConfig;
    protected versionCheck: boolean = true;
    protected logCollection: string;
    protected isMultiPhase: boolean = false;
    protected authFileHdfsPath: string | undefined = undefined;
    protected portList?: string | undefined;
    protected paiJobRestServer?: PAIJobRestServer;
    // URL scheme used for PAI REST calls; updated by formatPAIHost()
    protected protocol: string = 'http';

    constructor() {
        this.log = getLogger();
        this.metricsEmitter = new EventEmitter();
        this.trialJobsMap = new Map<string, PAITrialJobDetail>();
        this.jobQueue = [];
        this.expRootDir = path.join('/nni', 'experiments', getExperimentId());
        this.experimentId = getExperimentId();
        this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);
        this.paiTokenUpdateInterval = 7200000; //2hours
        this.logCollection = 'none';
        this.log.info('Construct paiBase training service.');
    }

    /**
     * Start the REST server, then run the status checking loop and the job
     * submission loop until both of them finish.
     *
     * @throws Error if paiJobRestServer has not been assigned yet
     */
    public async run(): Promise<void> {
        this.log.info('Run PAI training service.');
        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }
        await this.paiJobRestServer.start();
        this.paiJobRestServer.setEnableVersionCheck = this.versionCheck;
        this.log.info(`PAI Training service rest server listening on: ${this.paiJobRestServer.endPoint}`);
        await Promise.all([
            this.statusCheckingLoop(),
            this.submitJobLoop()]);
        this.log.info('PAI training service exit.');
    }

    /**
     * Placeholder; always throws 'Not implemented!' in this base class.
     */
    public async submitTrialJob(_form: TrialJobApplicationForm): Promise<any> {
        throw new Error('Not implemented!');
    }

    /**
     * Placeholder; always throws 'Not implemented!' in this base class.
     */
    public async updateTrialJob(_trialJobId: string, _form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        throw new Error('Not implemented!');
    }

    /**
     * Placeholder; always throws 'Not implemented!' in this base class.
     * submitJobLoop() treats a resolved `true` as "submitted, dequeue the job"
     * and a resolved `false` as "stop draining the queue for this round".
     */
    protected async submitTrialJobToPAI(_trialJobId: string): Promise<boolean> {
        throw new Error('Not implemented!');
    }

    /**
     * Repeatedly drain the job queue until `stopping` is set.
     * The job at the head of the queue is only removed after it was submitted
     * successfully; a failed submission ends the current round and the same
     * job is retried after the delay.
     */
    protected async submitJobLoop(): Promise<void> {
        while (!this.stopping) {
            while (!this.stopping && this.jobQueue.length > 0) {
                const trialJobId: string = this.jobQueue[0];
                if (await this.submitTrialJobToPAI(trialJobId)) {
                    // Remove trial job with trialJobId from job queue
                    this.jobQueue.shift();
                } else {
                    // Break the while loop since failed to submitJob
                    break;
                }
            }
            await delay(3000);
        }
    }

    /**
     * Placeholder; always throws 'Not implemented!' in this base class.
     */
    public async setClusterMetadata(_key: string, _value: string): Promise<void> {
        throw new Error('Not implemented!');
    }

    /**
     * Return the details of every trial job currently stored in trialJobsMap.
     *
     * @throws Error (via getTrialJob) if the map is non-empty and the PAI
     *         cluster config is not initialized
     */
    public async listTrialJobs(): Promise<TrialJobDetail[]> {
        const jobs: TrialJobDetail[] = [];

        for (const key of this.trialJobsMap.keys()) {
            jobs.push(await this.getTrialJob(key));
        }

        return jobs;
    }

    /**
     * Look up a single trial job by id.
     *
     * @param trialJobId id of the trial job
     * @throws Error if the PAI cluster config is not initialized, or if no
     *         trial job with this id is known
     */
    public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
        if (this.paiClusterConfig === undefined) {
            throw new Error('PAI Cluster config is not initialized');
        }

        const paiTrialJob: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);

        if (paiTrialJob === undefined) {
            throw new Error(`trial job ${trialJobId} not found`);
        }

        return paiTrialJob;
    }

    /**
     * Register a listener for the 'metric' event of the metrics emitter.
     */
    public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.metricsEmitter.on('metric', listener);
    }

    /**
     * Unregister a listener previously added for the 'metric' event.
     */
    public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.metricsEmitter.off('metric', listener);
    }

    public get isMultiPhaseJobSupported(): boolean {
        return true;
    }

    /**
     * Cancel a trial job by sending a STOP request to the PAI REST server.
     *
     * A job whose status is still 'UNKNOWN' is marked 'USER_CANCELED' locally
     * and no request is sent.
     *
     * @param trialJobId id of the trial job to cancel
     * @param isEarlyStopped recorded on the job detail to mark the source of the cancellation
     * @returns a promise that resolves once PAI accepted the stop request; it
     *          rejects with an Error if the job is unknown, the cluster config
     *          or token is missing, the request fails, or PAI answers with an
     *          HTTP status >= 400
     */
    public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
        if (trialJobDetail === undefined) {
            return Promise.reject(new Error(`cancelTrialJob: trial job id ${trialJobId} not found`));
        }

        if (this.paiClusterConfig === undefined) {
            return Promise.reject(new Error('PAI Cluster config is not initialized'));
        }
        if (this.paiToken === undefined) {
            return Promise.reject(new Error('PAI token is not initialized'));
        }

        if (trialJobDetail.status === 'UNKNOWN') {
            trialJobDetail.status = 'USER_CANCELED';
            return Promise.resolve();
        }

        const stopJobUri: string = `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v1/user/`
            + `${this.paiClusterConfig.userName}/jobs/${trialJobDetail.paiJobName}/executionType`;
        const stopJobRequest: request.Options = {
            uri: stopJobUri,
            method: 'PUT',
            json: true,
            body: { value: 'STOP' },
            headers: {
                'Content-Type': 'application/json',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        // Set trialjobDetail's early stopped field, to mark the job's cancellation source
        trialJobDetail.isEarlyStopped = isEarlyStopped;
        const deferred: Deferred<void> = new Deferred<void>();

        request(stopJobRequest, (error: Error, response: request.Response, _body: any) => {
            const requestFailed: boolean = error !== undefined && error !== null;
            // `response` is only inspected when the request itself did not fail
            if (requestFailed || response.statusCode >= 400) {
                this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
                // Reject with an Error (not a bare string), like every other failure path of this method
                deferred.reject(requestFailed
                    ? error
                    : new Error(`Stop trial failed, http code: ${response.statusCode}`));
            } else {
                deferred.resolve();
            }
        });

        return deferred.promise;
    }

    /**
     * Placeholder; always throws 'Not implemented!' in this base class.
     * Note that this method is not async, so the error is thrown synchronously.
     */
    public getClusterMetadata(_key: string): Promise<string> {
        throw new Error('Not implemented!');
    }

    /**
     * Ask the background loops to stop and shut down the REST server.
     * A failure while stopping the REST server is logged, not rethrown.
     *
     * @throws Error if paiJobRestServer has not been assigned
     */
    public async cleanUp(): Promise<void> {
        this.log.info('Stopping PAI training service...');
        this.stopping = true;

        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }

        try {
            await this.paiJobRestServer.stop();
            this.log.info('PAI Training service rest server stopped successfully.');
        } catch (error) {
            // The caught value is not guaranteed to be an Error instance
            const reason: string = error instanceof Error ? error.message : String(error);
            this.log.error(`PAI Training service rest server stopped failed, error: ${reason}`);
        }
    }

    public get MetricsEmitter(): EventEmitter {
        return this.metricsEmitter;
    }

    /**
     * Strip a leading 'http://' or 'https://' from the host and remember the
     * scheme in `this.protocol`. A host without either prefix is returned
     * unchanged and `this.protocol` is left untouched.
     *
     * @param host host as configured by the user
     * @returns the host without the scheme prefix
     */
    protected formatPAIHost(host: string): string {
        if (host.startsWith('http://')) {
            this.protocol = 'http';
            return host.replace('http://', '');
        } else if (host.startsWith('https://')) {
            this.protocol = 'https';
            return host.replace('https://', '');
        } else {
            return host;
        }
    }

    /**
     * Until `stopping` is set: refresh the PAI token (only when a password is
     * configured), let the job collector retrieve the trial status, and
     * surface any error reported by the REST server.
     *
     * @throws the token error if the very first token retrieval fails
     * @throws Error if paiJobRestServer is not assigned or reports an error message
     */
    protected async statusCheckingLoop(): Promise<void> {
        while (!this.stopping) {
            if (this.paiClusterConfig && this.paiClusterConfig.passWord) {
                try {
                    await this.updatePaiToken();
                } catch (error) {
                    this.log.error(`${error}`);
                    // Only fail when the token could not be initialized the first time;
                    // afterwards the previously obtained token keeps being used.
                    if (this.paiToken === undefined) {
                        // Rethrow the original error instead of re-wrapping it, so that
                        // its message and stack are preserved.
                        throw error instanceof Error ? error : new Error(String(error));
                    }
                }
            }
            await this.paiJobCollector.retrieveTrialStatus(this.protocol, this.paiToken, this.paiClusterConfig);
            if (this.paiJobRestServer === undefined) {
                throw new Error('paiBaseJobRestServer not implemented!');
            }
            if (this.paiJobRestServer.getErrorMessage !== undefined) {
                throw new Error(this.paiJobRestServer.getErrorMessage);
            }
            await delay(3000);
        }
    }

    /**
     * Update pai token by the interval time or initialize the pai token
     *
     * Does nothing if a token was obtained less than paiTokenUpdateInterval
     * milliseconds ago. Otherwise posts the configured user name and password
     * to the PAI REST server and stores the returned token.
     *
     * @throws Error if the cluster config is not initialized
     * @returns a promise that rejects if the request fails, if PAI answers
     *          with a status other than 200, or if no answer arrives within 5 seconds
     */
    protected async updatePaiToken(): Promise<void> {
        const currentTime: number = new Date().getTime();
        //If pai token initialized and not reach the interval time, do not update
        if (this.paiTokenUpdateTime !== undefined && (currentTime - this.paiTokenUpdateTime) < this.paiTokenUpdateInterval) {
            return;
        }

        if (this.paiClusterConfig === undefined) {
            const paiClusterConfigError: string = `pai cluster config not initialized!`;
            this.log.error(`${paiClusterConfigError}`);
            throw Error(`${paiClusterConfigError}`);
        }

        const deferred: Deferred<void> = new Deferred<void>();
        const authenticationReq: request.Options = {
            uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v1/token`,
            method: 'POST',
            json: true,
            body: {
                username: this.paiClusterConfig.userName,
                password: this.paiClusterConfig.passWord
            }
        };

        request(authenticationReq, (error: Error, response: request.Response, body: any) => {
            if (error !== undefined && error !== null) {
                this.log.error(`Get PAI token failed: ${error.message}`);
                deferred.reject(new Error(`Get PAI token failed: ${error.message}`));
            } else if (response.statusCode !== 200) {
                this.log.error(`Get PAI token failed: get PAI Rest return code ${response.statusCode}`);
                // With `json: true` the body is usually an object; serialise it so the
                // message does not degrade to '[object Object]'.
                const detail: string = typeof response.body === 'string' ? response.body : JSON.stringify(response.body);
                deferred.reject(new Error(`Get PAI token failed: ${detail}, please check paiConfig username or password`));
            } else {
                // Only a successful response may update the token and its timestamp;
                // otherwise a failed login would suppress retries for a whole interval.
                this.paiToken = body.token;
                this.paiTokenUpdateTime = new Date().getTime();
                deferred.resolve();
            }
        });

        let timeoutId: NodeJS.Timer;
        const timeoutDelay: Promise<void> = new Promise<void>((_resolve, reject): void => {
            // Set timeout and reject the promise once reach timeout (5 seconds)
            timeoutId = setTimeout(
                () => reject(new Error('Get PAI token timeout. Please check your PAI cluster.')),
                5000);
        });

        return Promise.race([timeoutDelay, deferred.promise])
            .finally(() => { clearTimeout(timeoutId); });
    }
}

311
export { PAITrainingService };