// paiTrainingService.ts
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
3

4
'use strict';
5
6
7

import * as path from 'path';
import * as request from 'request';
8
import * as component from '../../common/component';
9
10

import { EventEmitter } from 'events';
11
import { Deferred } from 'ts-deferred';
12
import { getExperimentId } from '../../common/experimentStartupInfo';
13
import { getLogger, Logger } from '../../common/log';
14
import { MethodNotImplementedError } from '../../common/errors';
15
import {
16
    NNIManagerIpConfig, TrainingService,
17
    TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType
18
} from '../../common/trainingService';
19
import { delay } from '../../common/utils';
20
import { PAIJobInfoCollector } from './paiJobInfoCollector';
21
import { PAIJobRestServer } from './paiJobRestServer';
22
import { PAIClusterConfig, PAITrialJobDetail } from './paiConfig';
23
24
25
26
27
28

/**
 * Training Service implementation for OpenPAI (Open Platform for AI)
 * Refer https://github.com/Microsoft/pai for more info about OpenPAI
 *
 * Abstract base holding the state and background loops shared by PAI training
 * services: a submission queue drained by `submitJobLoop`, a status poller
 * (`statusCheckingLoop`) that also refreshes the PAI REST token, and job
 * cancellation through the PAI v2 REST API. The placeholder methods below throw
 * 'Not implemented!', and `run()` requires `paiJobRestServer` to be assigned,
 * so concrete subclasses presumably supply those.
 */
@component.Singleton
abstract class PAITrainingService implements TrainingService {
    protected readonly log: Logger;
    protected readonly metricsEmitter: EventEmitter;
    protected readonly trialJobsMap: Map<string, PAITrialJobDetail>;
    protected readonly expRootDir: string;
    protected paiClusterConfig?: PAIClusterConfig;
    // Trial IDs awaiting submission; the head entry is retried until submission succeeds.
    protected readonly jobQueue: string[];
    // Set by cleanUp() so that submitJobLoop/statusCheckingLoop exit.
    protected stopping: boolean = false;
    protected paiToken?: string;
    // Epoch milliseconds of the last successful token refresh.
    protected paiTokenUpdateTime?: number;
    // Minimum token age (ms) before updatePaiToken() requests a new one.
    protected readonly paiTokenUpdateInterval: number;
    protected readonly experimentId: string;
    protected readonly paiJobCollector: PAIJobInfoCollector;
    protected paiRestServerPort?: number;
    protected nniManagerIpConfig?: NNIManagerIpConfig;
    protected versionCheck: boolean = true;
    protected logCollection: string;
    protected isMultiPhase: boolean = false;
    protected authFileHdfsPath: string | undefined = undefined;
    protected portList?: string | undefined;
    protected paiJobRestServer?: PAIJobRestServer;
    // URI scheme used for PAI REST calls; updated by formatPAIHost().
    protected protocol: string = 'http';

    constructor() {
        this.log = getLogger();
        this.metricsEmitter = new EventEmitter();
        this.trialJobsMap = new Map<string, PAITrialJobDetail>();
        this.jobQueue = [];
        this.expRootDir = path.join('/nni-experiments', getExperimentId());
        this.experimentId = getExperimentId();
        this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);
        this.paiTokenUpdateInterval = 7200000; //2hours
        this.logCollection = 'none';
        this.log.info('Construct paiBase training service.');
    }

    /**
     * Starts the job REST server, then runs the status-checking and submission
     * loops concurrently until both finish (they exit once `stopping` is set).
     *
     * @throws Error if `paiJobRestServer` has not been assigned, or if either loop throws.
     */
    public async run(): Promise<void> {
        this.log.info('Run PAI training service.');
        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }
        await this.paiJobRestServer.start();
        this.paiJobRestServer.setEnableVersionCheck = this.versionCheck;
        this.log.info(`PAI Training service rest server listening on: ${this.paiJobRestServer.endPoint}`);
        await Promise.all([
            this.statusCheckingLoop(),
            this.submitJobLoop()]);
        this.log.info('PAI training service exit.');
    }

    public async submitTrialJob(_form: TrialJobApplicationForm): Promise<any> {
        throw new Error('Not implemented!');
    }

    public async updateTrialJob(_trialJobId: string, _form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        throw new Error('Not implemented!');
    }

    /**
     * Submits one queued trial to PAI.
     *
     * @returns true when the trial was submitted, so submitJobLoop dequeues it;
     *          false leaves it at the head of the queue to be retried later.
     */
    protected async submitTrialJobToPAI(_trialJobId: string): Promise<boolean> {
        throw new Error('Not implemented!');
    }

    /**
     * Drains `jobQueue` in FIFO order every 3 seconds until `stopping` is set.
     * A failed submission stops the current pass so the same job is retried first.
     */
    protected async submitJobLoop(): Promise<void> {
        while (!this.stopping) {
            while (!this.stopping && this.jobQueue.length > 0) {
                const trialJobId: string = this.jobQueue[0];
                if (await this.submitTrialJobToPAI(trialJobId)) {
                    // Remove trial job with trialJobId from job queue
                    this.jobQueue.shift();
                } else {
                    // Break the while loop since failed to submitJob
                    break;
                }
            }
            await delay(3000);
        }
    }

    public async setClusterMetadata(_key: string, _value: string): Promise<void> {
        throw new Error('Not implemented!');
    }

    /** Returns the details of every trial known to this service, in map insertion order. */
    public async listTrialJobs(): Promise<TrialJobDetail[]> {
        const jobs: TrialJobDetail[] = [];
        for (const key of this.trialJobsMap.keys()) {
            jobs.push(await this.getTrialJob(key));
        }
        return jobs;
    }

    public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
        throw new MethodNotImplementedError();
    }

    /**
     * Looks up a trial by ID.
     *
     * @throws Error if the cluster config is not set or the trial is unknown.
     */
    public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
        if (this.paiClusterConfig === undefined) {
            throw new Error('PAI Cluster config is not initialized');
        }
        const paiTrialJob: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
        if (paiTrialJob === undefined) {
            throw new Error(`trial job ${trialJobId} not found`);
        }
        return paiTrialJob;
    }

    public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.metricsEmitter.on('metric', listener);
    }

    public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.metricsEmitter.off('metric', listener);
    }

    public get isMultiPhaseJobSupported(): boolean {
        return true;
    }

    /**
     * Asks PAI to stop a trial by setting its executionType to STOP.
     *
     * A trial whose local status is still 'UNKNOWN' is marked 'USER_CANCELED'
     * locally without contacting PAI.
     *
     * @param isEarlyStopped recorded on the trial to mark the cancellation source.
     * @returns a promise that rejects with an Error when the trial is unknown,
     *          the config/token is missing, or the PAI request fails.
     */
    public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
        if (trialJobDetail === undefined) {
            return Promise.reject(new Error(`cancelTrialJob: trial job id ${trialJobId} not found`));
        }
        if (this.paiClusterConfig === undefined) {
            return Promise.reject(new Error('PAI Cluster config is not initialized'));
        }
        if (this.paiToken === undefined) {
            return Promise.reject(new Error('PAI token is not initialized'));
        }

        if (trialJobDetail.status === 'UNKNOWN') {
            trialJobDetail.status = 'USER_CANCELED';
            return Promise.resolve();
        }

        const stopJobRequest: request.Options = {
            uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs/${this.paiClusterConfig.userName}~${trialJobDetail.paiJobName}/executionType`,
            method: 'PUT',
            json: true,
            body: { value: 'STOP' },
            headers: {
                'Content-Type': 'application/json',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        // Set trialjobDetail's early stopped field, to mark the job's cancellation source
        trialJobDetail.isEarlyStopped = isEarlyStopped;
        const deferred: Deferred<void> = new Deferred<void>();

        request(stopJobRequest, (error: Error, response: request.Response, _body: any) => {
            // Status code 202 for success.
            const hasError: boolean = error !== undefined && error !== null;
            if (hasError || response.statusCode >= 400) {
                this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
                // Reject with an Error (not a bare string) so callers get a message and stack.
                deferred.reject(new Error(hasError ? error.message :
                    `Stop trial failed, http code: ${response.statusCode}`));
            } else {
                deferred.resolve();
            }
        });

        return deferred.promise;
    }

    // async so the placeholder rejects like its sibling placeholders instead of throwing synchronously.
    public async getClusterMetadata(_key: string): Promise<string> {
        throw new Error('Not implemented!');
    }

    /**
     * Signals the background loops to stop and shuts down the job REST server.
     * A failure while stopping the server is logged, not rethrown.
     *
     * @throws Error if `paiJobRestServer` has not been assigned.
     */
    public async cleanUp(): Promise<void> {
        this.log.info('Stopping PAI training service...');
        this.stopping = true;

        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }

        try {
            await this.paiJobRestServer.stop();
            this.log.info('PAI Training service rest server stopped successfully.');
        } catch (error) {
            const message: string = error instanceof Error ? error.message : `${error}`;
            this.log.error(`PAI Training service rest server stopped failed, error: ${message}`);
        }
    }

    public get MetricsEmitter(): EventEmitter {
        return this.metricsEmitter;
    }

    /**
     * Strips an explicit 'http://' or 'https://' prefix from `host` and records
     * that scheme in `this.protocol`. A host without a scheme prefix is returned
     * unchanged and `this.protocol` keeps its current value ('http' by default).
     */
    protected formatPAIHost(host: string): string {
        const httpPrefix: string = 'http://';
        const httpsPrefix: string = 'https://';
        if (host.startsWith(httpPrefix)) {
            this.protocol = 'http';
            return host.slice(httpPrefix.length);
        }
        if (host.startsWith(httpsPrefix)) {
            this.protocol = 'https';
            return host.slice(httpsPrefix.length);
        }
        return host;
    }

    /**
     * Every 3 seconds until `stopping` is set: refreshes the PAI token (when a
     * password is configured), refreshes trial statuses, and surfaces any error
     * reported by the job REST server.
     *
     * @throws Error if the very first token fetch fails, if the REST server is
     *         missing, or if the REST server reports an error message.
     */
    protected async statusCheckingLoop(): Promise<void> {
        while (!this.stopping) {
            if (this.paiClusterConfig && this.paiClusterConfig.passWord) {
                try {
                    await this.updatePaiToken();
                } catch (error) {
                    this.log.error(`${error}`);
                    // Only fatal when no token was ever obtained; otherwise keep the old token and retry next round.
                    if (this.paiToken === undefined) {
                        // Rethrow as-is: wrapping an Error in `new Error(error)` nests "Error: Error: ...".
                        throw error instanceof Error ? error : new Error(`${error}`);
                    }
                }
            }
            await this.paiJobCollector.retrieveTrialStatus(this.protocol, this.paiToken, this.paiClusterConfig);
            if (this.paiJobRestServer === undefined) {
                throw new Error('paiBaseJobRestServer not implemented!');
            }
            if (this.paiJobRestServer.getErrorMessage !== undefined) {
                throw new Error(this.paiJobRestServer.getErrorMessage);
            }
            await delay(3000);
        }
    }

    /**
     * Update pai token by the interval time or initialize the pai token.
     *
     * Does nothing while the current token is younger than `paiTokenUpdateInterval`.
     * Otherwise POSTs the configured credentials to the PAI v1 token endpoint and,
     * on HTTP 200, stores `body.token` and the refresh time.
     *
     * @throws Error if the cluster config is missing, the request fails, PAI
     *         answers with a non-200 status, or no answer arrives within 5 seconds.
     */
    protected async updatePaiToken(): Promise<void> {
        const currentTime: number = new Date().getTime();
        //If pai token initialized and not reach the interval time, do not update
        if (this.paiTokenUpdateTime !== undefined && (currentTime - this.paiTokenUpdateTime) < this.paiTokenUpdateInterval) {
            return;
        }

        if (this.paiClusterConfig === undefined) {
            const paiClusterConfigError: string = `pai cluster config not initialized!`;
            this.log.error(`${paiClusterConfigError}`);
            throw Error(`${paiClusterConfigError}`);
        }

        const authenticationReq: request.Options = {
            uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v1/token`,
            method: 'POST',
            json: true,
            body: {
                username: this.paiClusterConfig.userName,
                password: this.paiClusterConfig.passWord
            }
        };

        const deferred: Deferred<void> = new Deferred<void>();
        request(authenticationReq, (error: Error, response: request.Response, body: any) => {
            if (error !== undefined && error !== null) {
                this.log.error(`Get PAI token failed: ${error.message}`);
                deferred.reject(new Error(`Get PAI token failed: ${error.message}`));
                return;
            }
            if (response.statusCode !== 200) {
                this.log.error(`Get PAI token failed: get PAI Rest return code ${response.statusCode}`);
                deferred.reject(new Error(`Get PAI token failed: ${response.body}, please check paiConfig username or password`));
                // Must not fall through: that would overwrite a valid token with the
                // error body's token field and postpone the next refresh attempt.
                return;
            }
            this.paiToken = body.token;
            this.paiTokenUpdateTime = new Date().getTime();
            deferred.resolve();
        });

        let timeoutId: NodeJS.Timer;
        const timeoutDelay: Promise<void> = new Promise<void>((_resolve, reject): void => {
            // Set timeout and reject the promise once reach timeout (5 seconds)
            timeoutId = setTimeout(
                () => reject(new Error('Get PAI token timeout. Please check your PAI cluster.')),
                5000);
        });

        return Promise.race([timeoutDelay, deferred.promise])
            .finally(() => { clearTimeout(timeoutId); });
    }
}

export { PAITrainingService };