paiTrainingService.ts 12.4 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
3

4
'use strict';
5
6
7
8

import * as fs from 'fs';
import * as path from 'path';
import * as request from 'request';
9
import * as component from '../../common/component';
10
11

import { EventEmitter } from 'events';
12
13
import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations';
14
import { getExperimentId } from '../../common/experimentStartupInfo';
15
16
import { getLogger, Logger } from '../../common/log';
import {
17
    HyperParameters, NNIManagerIpConfig, TrainingService,
18
    TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
19
} from '../../common/trainingService';
20
import { delay, generateParamFileName,
21
    getExperimentRootDir, getIPV4Address, getVersion, uniqueString, unixPathJoin } from '../../common/utils';
22
23
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
24
import { execMkdir, validateCodeDir } from '../common/util';
25
import { PAIJobInfoCollector } from './paiJobInfoCollector';
26
import { PAIJobRestServer, ParameterFileMeta } from './paiJobRestServer';
27
import { PAIClusterConfig, PAITrialJobDetail } from './paiConfig';
28
29
30
31
32
33

/**
 * Training Service implementation for OpenPAI (Open Platform for AI).
 * Refer to https://github.com/Microsoft/pai for more info about OpenPAI.
 *
 * This abstract base holds the state and logic shared by PAI training services:
 * trial bookkeeping, the job-submission loop, the status-polling loop, PAI token
 * refresh and trial cancellation. Methods here that throw 'Not implemented!' are
 * presumably overridden by concrete subclasses.
 */
@component.Singleton
abstract class PAITrainingService implements TrainingService {
    protected readonly log: Logger;
    protected readonly metricsEmitter: EventEmitter;
    protected readonly trialJobsMap: Map<string, PAITrialJobDetail>;
    protected readonly expRootDir: string;
    protected paiClusterConfig?: PAIClusterConfig;
    // Trial job ids waiting to be submitted to PAI; the head is retried until it succeeds.
    protected readonly jobQueue: string[];
    // Set by cleanUp(); both background loops exit once this becomes true.
    protected stopping: boolean = false;
    protected paiToken?: string;
    // Epoch milliseconds of the last successful token refresh.
    protected paiTokenUpdateTime?: number;
    // Minimum time between token refreshes, in milliseconds.
    protected readonly paiTokenUpdateInterval: number;
    protected readonly experimentId: string;
    protected readonly paiJobCollector: PAIJobInfoCollector;
    protected paiRestServerPort?: number;
    protected nniManagerIpConfig?: NNIManagerIpConfig;
    protected versionCheck: boolean = true;
    protected logCollection: string;
    protected isMultiPhase: boolean = false;
    protected authFileHdfsPath: string | undefined = undefined;
    protected portList?: string | undefined;
    protected paiJobRestServer?: PAIJobRestServer;
    // Scheme used for PAI REST calls; may be switched to 'https' by formatPAIHost().
    protected protocol: string = 'http';

    constructor() {
        this.log = getLogger();
        this.metricsEmitter = new EventEmitter();
        this.trialJobsMap = new Map<string, PAITrialJobDetail>();
        this.jobQueue = [];
        this.expRootDir = path.join('/nni', 'experiments', getExperimentId());
        this.experimentId = getExperimentId();
        this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);
        this.paiTokenUpdateInterval = 7200000; // 2 hours
        this.logCollection = 'none';
        this.log.info('Construct paiBase training service.');
    }

    /**
     * Start the trial REST server, then run the status-checking and job-submission
     * loops concurrently until both exit (i.e. until the service is stopped or a
     * loop throws).
     *
     * @throws Error if paiJobRestServer has not been set.
     */
    public async run(): Promise<void> {
        this.log.info('Run PAI training service.');
        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }
        await this.paiJobRestServer.start();
        this.paiJobRestServer.setEnableVersionCheck = this.versionCheck;
        this.log.info(`PAI Training service rest server listening on: ${this.paiJobRestServer.endPoint}`);
        await Promise.all([
            this.statusCheckingLoop(),
            this.submitJobLoop()]);
        this.log.info('PAI training service exit.');
    }

    /** Not implemented in the base class. */
    public async submitTrialJob(form: TrialJobApplicationForm): Promise<any> {
        throw new Error('Not implemented!');
    }

    /** Not implemented in the base class. */
    public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        throw new Error('Not implemented!');
    }

    /**
     * Submit one queued trial to PAI. Not implemented in the base class.
     *
     * @returns true when the trial was submitted and may be dequeued.
     */
    protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
        throw new Error('Not implemented!');
    }

    /**
     * Drain jobQueue in FIFO order every 3 seconds until stopping is set.
     * A failed submission leaves the trial at the head of the queue so it is
     * retried on the next pass.
     */
    protected async submitJobLoop(): Promise<void> {
        while (!this.stopping) {
            while (!this.stopping && this.jobQueue.length > 0) {
                const trialJobId: string = this.jobQueue[0];
                if (await this.submitTrialJobToPAI(trialJobId)) {
                    // Remove trial job with trialJobId from job queue
                    this.jobQueue.shift();
                } else {
                    // Stop draining for this pass since the submission failed
                    break;
                }
            }
            await delay(3000);
        }
    }

    /** Not implemented in the base class. */
    public async setClusterMetadata(key: string, value: string): Promise<void> {
        throw new Error('Not implemented!');
    }

    /**
     * Return the details of every trial currently tracked in trialJobsMap.
     *
     * @throws Error (via getTrialJob) if the PAI cluster config is not initialized.
     */
    public async listTrialJobs(): Promise<TrialJobDetail[]> {
        const jobs: TrialJobDetail[] = [];
        for (const trialJobId of this.trialJobsMap.keys()) {
            jobs.push(await this.getTrialJob(trialJobId));
        }

        return jobs;
    }

    /**
     * Look up a single trial by id.
     *
     * @throws Error if the PAI cluster config is not initialized or the trial is unknown.
     */
    public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
        if (this.paiClusterConfig === undefined) {
            throw new Error('PAI Cluster config is not initialized');
        }

        const paiTrialJob: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
        if (paiTrialJob === undefined) {
            throw new Error(`trial job ${trialJobId} not found`);
        }

        return paiTrialJob;
    }

    /** Subscribe a listener to 'metric' events on the metrics emitter. */
    public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.metricsEmitter.on('metric', listener);
    }

    /** Unsubscribe a listener previously added with addTrialJobMetricListener. */
    public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.metricsEmitter.off('metric', listener);
    }

    public get isMultiPhaseJobSupported(): boolean {
        return true;
    }

    /**
     * Ask PAI to stop a trial's job by PUTting executionType=STOP.
     *
     * A trial whose status is 'UNKNOWN' is only marked 'USER_CANCELED' locally;
     * no request is sent to PAI.
     *
     * @param trialJobId     id of the trial to cancel.
     * @param isEarlyStopped recorded on the trial to mark the cancellation source.
     * @returns a promise rejected with an Error if the trial is unknown, the
     *          cluster config or token is missing, or the PAI request fails.
     */
    public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
        if (trialJobDetail === undefined) {
            return Promise.reject(new Error(`cancelTrialJob: trial job id ${trialJobId} not found`));
        }
        if (this.paiClusterConfig === undefined) {
            return Promise.reject(new Error('PAI Cluster config is not initialized'));
        }
        if (this.paiToken === undefined) {
            return Promise.reject(new Error('PAI token is not initialized'));
        }

        if (trialJobDetail.status === 'UNKNOWN') {
            trialJobDetail.status = 'USER_CANCELED';
            return Promise.resolve();
        }

        const stopJobUri: string = `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v1/user/`
            + `${this.paiClusterConfig.userName}/jobs/${trialJobDetail.paiJobName}/executionType`;
        const stopJobRequest: request.Options = {
            uri: stopJobUri,
            method: 'PUT',
            json: true,
            body: {value: 'STOP'},
            headers: {
                'Content-Type': 'application/json',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        // Set trialjobDetail's early stopped field, to mark the job's cancellation source
        trialJobDetail.isEarlyStopped = isEarlyStopped;
        const deferred: Deferred<void> = new Deferred<void>();

        request(stopJobRequest, (error: Error, response: request.Response, body: any) => {
            if ((error !== undefined && error !== null) || response.statusCode >= 400) {
                this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
                // Reject with an Error (not a bare string) so callers get a stack and .message.
                deferred.reject((error !== undefined && error !== null) ? error :
                    new Error(`Stop trial failed, http code: ${response.statusCode}`));
            } else {
                deferred.resolve();
            }
        });

        return deferred.promise;
    }

    /** Not implemented in the base class. */
    public getClusterMetadata(key: string): Promise<string> {
        throw new Error('Not implemented!');
    }

    /**
     * Signal both background loops to exit and stop the trial REST server.
     * Failure to stop the server is logged, not thrown.
     *
     * @throws Error if paiJobRestServer has not been set.
     */
    public async cleanUp(): Promise<void> {
        this.log.info('Stopping PAI training service...');
        this.stopping = true;

        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }

        try {
            await this.paiJobRestServer.stop();
            this.log.info('PAI Training service rest server stopped successfully.');
        } catch (error) {
            const message: string = error instanceof Error ? error.message : String(error);
            this.log.error(`PAI Training service rest server stopped failed, error: ${message}`);
        }
    }

    public get MetricsEmitter(): EventEmitter {
        return this.metricsEmitter;
    }

    /**
     * Strip an 'http://' or 'https://' prefix from a PAI host and record the
     * matching scheme in this.protocol. A host without a scheme prefix is returned
     * unchanged and this.protocol keeps its current value ('http' by default).
     */
    protected formatPAIHost(host: string): string {
        if (host.startsWith('http://')) {
            this.protocol = 'http';
            return host.replace('http://', '');
        } else if (host.startsWith('https://')) {
            this.protocol = 'https';
            return host.replace('https://', '');
        } else {
            return host;
        }
    }

    /**
     * Every 3 seconds until stopping is set: refresh the PAI token (when a
     * password is configured), poll trial status from PAI, and surface any error
     * reported by the trial REST server.
     *
     * @throws Error if the very first token fetch fails, if paiJobRestServer is
     *         unset, or if the REST server reports an error message.
     */
    protected async statusCheckingLoop(): Promise<void> {
        while (!this.stopping) {
            if (this.paiClusterConfig && this.paiClusterConfig.passWord) {
                try {
                    await this.updatePaiToken();
                } catch (error) {
                    this.log.error(`${error}`);
                    // Only fail hard when the token has never been initialized;
                    // a failed refresh keeps using the previous token.
                    if (this.paiToken === undefined) {
                        throw error instanceof Error ? error : new Error(String(error));
                    }
                }
            }
            await this.paiJobCollector.retrieveTrialStatus(this.protocol, this.paiToken, this.paiClusterConfig);
            if (this.paiJobRestServer === undefined) {
                throw new Error('paiBaseJobRestServer not implemented!');
            }
            if (this.paiJobRestServer.getErrorMessage !== undefined) {
                throw new Error(this.paiJobRestServer.getErrorMessage);
            }
            await delay(3000);
        }
    }

    /**
     * Update the PAI token once paiTokenUpdateInterval has elapsed, or
     * initialize it if it has never been fetched.
     *
     * On failure neither paiToken nor paiTokenUpdateTime is touched, so the next
     * call retries immediately instead of waiting a full interval.
     *
     * @throws Error if the cluster config is missing, the request fails, PAI
     *         answers with a non-200 status, or no answer arrives within 5 seconds.
     */
    protected async updatePaiToken(): Promise<void> {
        const currentTime: number = new Date().getTime();
        // If the token is initialized and the interval has not elapsed, do not update.
        if (this.paiTokenUpdateTime !== undefined && (currentTime - this.paiTokenUpdateTime) < this.paiTokenUpdateInterval) {
            return;
        }

        if (this.paiClusterConfig === undefined) {
            const paiClusterConfigError: string = `pai cluster config not initialized!`;
            this.log.error(`${paiClusterConfigError}`);
            throw new Error(`${paiClusterConfigError}`);
        }

        const authenticationReq: request.Options = {
            uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v1/token`,
            method: 'POST',
            json: true,
            body: {
                username: this.paiClusterConfig.userName,
                password: this.paiClusterConfig.passWord
            }
        };

        const deferred: Deferred<void> = new Deferred<void>();
        request(authenticationReq, (error: Error, response: request.Response, body: any) => {
            if (error !== undefined && error !== null) {
                this.log.error(`Get PAI token failed: ${error.message}`);
                deferred.reject(new Error(`Get PAI token failed: ${error.message}`));
                return;
            }
            if (response.statusCode !== 200) {
                this.log.error(`Get PAI token failed: get PAI Rest return code ${response.statusCode}`);
                // With json: true the body is usually a parsed object; stringify it so the
                // message is not "[object Object]".
                const detail: string = typeof body === 'string' ? body : JSON.stringify(body);
                deferred.reject(new Error(`Get PAI token failed: ${detail}, please check paiConfig username or password`));
                // Do not fall through: keep the previous token and update time.
                return;
            }
            this.paiToken = body.token;
            this.paiTokenUpdateTime = new Date().getTime();
            deferred.resolve();
        });

        let timeoutId: NodeJS.Timer;
        const timeoutDelay: Promise<void> = new Promise<void>((_resolve, reject): void => {
            // Reject once the timeout (5 seconds) is reached.
            timeoutId = setTimeout(
                () => reject(new Error('Get PAI token timeout. Please check your PAI cluster.')),
                5000);
        });

        return Promise.race([timeoutDelay, deferred.promise])
            .finally(() => { clearTimeout(timeoutId); });
    }
}

317
export { PAITrainingService };