paiTrainingService.ts 12 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
3

4
'use strict';
5
6
7
8

import * as fs from 'fs';
import * as path from 'path';
import * as request from 'request';
9
import * as component from '../../common/component';
10
11

import { EventEmitter } from 'events';
12
13
import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations';
14
import { getExperimentId } from '../../common/experimentStartupInfo';
15
16
import { getLogger, Logger } from '../../common/log';
import {
17
    HyperParameters, NNIManagerIpConfig, TrainingService,
18
    TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
19
} from '../../common/trainingService';
20
import { delay, generateParamFileName,
21
    getExperimentRootDir, getIPV4Address, getVersion, uniqueString, unixPathJoin } from '../../common/utils';
22
23
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
24
import { execMkdir, validateCodeDir } from '../common/util';
25
import { PAIJobInfoCollector } from './paiJobInfoCollector';
26
import { PAIJobRestServer, ParameterFileMeta } from './paiJobRestServer';
27
import { PAIClusterConfig, PAITrialJobDetail } from './paiConfig';
28
29
30
31
32
33

/**
 * Training Service implementation for OpenPAI (Open Platform for AI)
 * Refer https://github.com/Microsoft/pai for more info about OpenPAI
 *
 * This abstract base holds the state and behaviour shared by PAI training services:
 * trial bookkeeping, the job-submission and status-checking loops, PAI token refresh,
 * and trial cancellation through the PAI REST API. The trial submission / metadata
 * methods here only throw 'Not implemented!' and are meant to be overridden.
 */
@component.Singleton
abstract class PAITrainingService implements TrainingService {
    protected readonly log!: Logger;
    // Emits 'metric' events to listeners registered via addTrialJobMetricListener.
    protected readonly metricsEmitter: EventEmitter;
    // Trial job id -> trial detail; the same Map instance is handed to paiJobCollector.
    protected readonly trialJobsMap: Map<string, PAITrialJobDetail>;
    protected readonly expRootDir: string;
    protected paiClusterConfig?: PAIClusterConfig;
    // FIFO of trial job ids waiting to be submitted; consumed by submitJobLoop.
    protected readonly jobQueue: string[];
    // Set by cleanUp to terminate submitJobLoop and statusCheckingLoop.
    protected stopping: boolean = false;
    protected paiToken? : string;
    // Milliseconds since epoch of the last successful token refresh.
    protected paiTokenUpdateTime?: number;
    // Minimum age (ms) of the token before updatePaiToken requests a new one.
    protected readonly paiTokenUpdateInterval: number;
    protected readonly experimentId!: string;
    protected readonly paiJobCollector: PAIJobInfoCollector;
    protected paiRestServerPort?: number;
    protected nniManagerIpConfig?: NNIManagerIpConfig;
    protected versionCheck: boolean = true;
    protected logCollection: string;
    protected isMultiPhase: boolean = false;
    protected authFileHdfsPath: string | undefined = undefined;
    protected portList?: string | undefined;
    // Must be assigned (e.g. by a subclass) before run()/cleanUp() are called.
    protected paiJobRestServer?: PAIJobRestServer;

    constructor() {
        this.log = getLogger();
        this.metricsEmitter = new EventEmitter();
        this.trialJobsMap = new Map<string, PAITrialJobDetail>();
        this.jobQueue = [];
        this.expRootDir = path.join('/nni', 'experiments', getExperimentId());
        this.experimentId = getExperimentId();
        this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);
        this.paiTokenUpdateInterval = 7200000; // 2 hours, in milliseconds
        this.logCollection = 'none';
        this.log.info('Construct paiBase training service.');
    }

    /**
     * Start the trial REST server, then run the status-checking and job-submission
     * loops until both finish (i.e. until cleanUp sets `stopping`, or a loop throws).
     * @throws Error if paiJobRestServer has not been assigned.
     */
    public async run(): Promise<void> {
        this.log.info('Run PAI training service.');
        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }
        await this.paiJobRestServer.start();
        this.paiJobRestServer.setEnableVersionCheck = this.versionCheck;
        this.log.info(`PAI Training service rest server listening on: ${this.paiJobRestServer.endPoint}`);
        await Promise.all([
            this.statusCheckingLoop(),
            this.submitJobLoop()]);
        this.log.info('PAI training service exit.');
    }

    public async submitTrialJob(form: TrialJobApplicationForm): Promise<any> {
        throw new Error('Not implemented!');
    }

    public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        throw new Error('Not implemented!');
    }

    /**
     * Submit one queued trial to PAI.
     * @returns true if the trial was submitted and can be removed from the queue.
     */
    protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
        throw new Error('Not implemented!');
    }

    /**
     * Drain jobQueue in FIFO order every 3 seconds. A failed submission leaves the
     * trial at the head of the queue so it is retried on the next iteration.
     */
    protected async submitJobLoop(): Promise<void> {
        while (!this.stopping) {
            while (!this.stopping && this.jobQueue.length > 0) {
                const trialJobId: string = this.jobQueue[0];
                if (await this.submitTrialJobToPAI(trialJobId)) {
                    // Remove trial job with trialJobId from job queue
                    this.jobQueue.shift();
                } else {
                    // Break the while loop since failed to submitJob
                    break;
                }
            }
            await delay(3000);
        }
    }

    public async setClusterMetadata(key: string, value: string): Promise<void> {
        throw new Error('Not implemented!');
    }

    /**
     * Return the details of every known trial, in trialJobsMap insertion order.
     */
    public async listTrialJobs(): Promise<TrialJobDetail[]> {
        const jobs: TrialJobDetail[] = [];
        for (const trialJobId of this.trialJobsMap.keys()) {
            jobs.push(await this.getTrialJob(trialJobId));
        }

        return jobs;
    }

    /**
     * Look up a single trial.
     * @throws Error if the cluster config is not set or the trial id is unknown.
     */
    public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
        if (this.paiClusterConfig === undefined) {
            throw new Error('PAI Cluster config is not initialized');
        }

        const paiTrialJob: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
        if (paiTrialJob === undefined) {
            throw new Error(`trial job ${trialJobId} not found`);
        }

        return paiTrialJob;
    }

    public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.metricsEmitter.on('metric', listener);
    }

    public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.metricsEmitter.off('metric', listener);
    }

    public get isMultiPhaseJobSupported(): boolean {
        return true;
    }

    /**
     * Ask PAI to stop a trial by PUTting executionType=STOP for its PAI job.
     * All failures surface as a rejected promise carrying an Error.
     * @param trialJobId id of the trial to stop.
     * @param isEarlyStopped recorded on the trial detail to mark the cancellation source.
     */
    public async cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
        if (trialJobDetail === undefined) {
            const notFoundMessage: string = `cancelTrialJob: trial job id ${trialJobId} not found`;
            this.log.error(notFoundMessage);
            throw new Error(notFoundMessage);
        }
        if (this.paiClusterConfig === undefined) {
            throw new Error('PAI Cluster config is not initialized');
        }
        if (this.paiToken === undefined) {
            throw new Error('PAI token is not initialized');
        }

        const stopJobRequest: request.Options = {
            uri: `${this.paiClusterConfig.host}/rest-server/api/v1/user/${this.paiClusterConfig.userName}`
                + `/jobs/${trialJobDetail.paiJobName}/executionType`,
            method: 'PUT',
            json: true,
            body: {value: 'STOP'},
            headers: {
                'Content-Type': 'application/json',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        // Set trialjobDetail's early stopped field, to mark the job's cancellation source
        trialJobDetail.isEarlyStopped = isEarlyStopped;

        const deferred: Deferred<void> = new Deferred<void>();
        request(stopJobRequest, (error: Error, response: request.Response, body: any) => {
            const hasTransportError: boolean = error !== undefined && error !== null;
            if (hasTransportError || response.statusCode >= 400) {
                this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
                deferred.reject(new Error(hasTransportError ? error.message :
                    `Stop trial failed, http code: ${response.statusCode}`));
            } else {
                deferred.resolve();
            }
        });

        return deferred.promise;
    }

    public getClusterMetadata(key: string): Promise<string> {
        throw new Error('Not implemented!');
    }

    /**
     * Signal both loops to stop and shut down the trial REST server.
     * A failure to stop the server is logged, not rethrown.
     * @throws Error if paiJobRestServer has not been assigned.
     */
    public async cleanUp(): Promise<void> {
        this.log.info('Stopping PAI training service...');
        this.stopping = true;

        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }

        try {
            await this.paiJobRestServer.stop();
            this.log.info('PAI Training service rest server stopped successfully.');
        } catch (error) {
            const reason: string = error instanceof Error ? error.message : `${error}`;
            this.log.error(`PAI Training service rest server stopped failed, error: ${reason}`);
        }
    }

    public get MetricsEmitter(): EventEmitter {
        return this.metricsEmitter;
    }

    /**
     * Normalize a user-supplied PAI host into a URL prefix.
     * @returns host unchanged if it starts with 'http://' or 'https://', otherwise `http://${host}`.
     */
    protected formatPAIHost(host: string): string {
        // If users' host start with 'http://' or 'https://', use the original host,
        // or format to 'http://${host}'
        if (host.startsWith('http://') || host.startsWith('https://')) {
            return host;
        } else {
            return `http://${host}`;
        }
    }

    /**
     * Every 3 seconds: refresh the PAI token (when a password is configured), collect
     * trial statuses, and abort if the trial REST server reported an error.
     */
    protected async statusCheckingLoop(): Promise<void> {
        while (!this.stopping) {
            if (this.paiClusterConfig && this.paiClusterConfig.passWord) {
                try {
                    await this.updatePaiToken();
                } catch (error) {
                    this.log.error(`${error}`);
                    // Only fatal when no token was ever obtained; otherwise the previous
                    // token is kept and the refresh is retried on a later iteration.
                    if (this.paiToken === undefined) {
                        throw error;
                    }
                }
            }
            await this.paiJobCollector.retrieveTrialStatus(this.paiToken, this.paiClusterConfig);
            if (this.paiJobRestServer === undefined) {
                throw new Error('paiBaseJobRestServer not implemented!');
            }
            if (this.paiJobRestServer.getErrorMessage !== undefined) {
                throw new Error(this.paiJobRestServer.getErrorMessage);
            }
            await delay(3000);
        }
    }

    /**
     * Update pai token by the interval time or initialize the pai token.
     *
     * No-op while the current token is younger than paiTokenUpdateInterval.
     * Otherwise POSTs the configured username/password to the PAI token endpoint;
     * paiToken and paiTokenUpdateTime are only written on an HTTP 200 response.
     * @throws Error if the cluster config is missing, the request fails, the server
     *         returns a non-200 status, or no response arrives within 5 seconds.
     */
    protected async updatePaiToken(): Promise<void> {
        const currentTime: number = new Date().getTime();
        // If pai token initialized and not reach the interval time, do not update
        if (this.paiTokenUpdateTime !== undefined && (currentTime - this.paiTokenUpdateTime) < this.paiTokenUpdateInterval) {
            return;
        }

        if (this.paiClusterConfig === undefined) {
            const paiClusterConfigError: string = `pai cluster config not initialized!`;
            this.log.error(`${paiClusterConfigError}`);
            throw Error(`${paiClusterConfigError}`);
        }

        const authenticationReq: request.Options = {
            uri: `${this.paiClusterConfig.host}/rest-server/api/v1/token`,
            method: 'POST',
            json: true,
            body: {
                username: this.paiClusterConfig.userName,
                password: this.paiClusterConfig.passWord
            }
        };

        const deferred: Deferred<void> = new Deferred<void>();
        request(authenticationReq, (error: Error, response: request.Response, body: any) => {
            if (error !== undefined && error !== null) {
                this.log.error(`Get PAI token failed: ${error.message}`);
                deferred.reject(new Error(`Get PAI token failed: ${error.message}`));
                return;
            }
            if (response.statusCode !== 200) {
                this.log.error(`Get PAI token failed: get PAI Rest return code ${response.statusCode}`);
                // With json: true the body is usually an object; stringify it so the message is readable.
                const bodyText: string = typeof body === 'string' ? body : JSON.stringify(body);
                deferred.reject(new Error(`Get PAI token failed: ${bodyText}, please check paiConfig username or password`));
                // Must not fall through: that would overwrite a valid token with garbage and
                // stamp the update time, suppressing retries for a whole interval.
                return;
            }
            this.paiToken = body.token;
            this.paiTokenUpdateTime = new Date().getTime();
            deferred.resolve();
        });

        let timeoutId: NodeJS.Timer;
        const timeoutDelay: Promise<void> = new Promise<void>((_resolve, reject): void => {
            // Set timeout and reject the promise once reach timeout (5 seconds)
            timeoutId = setTimeout(
                () => reject(new Error('Get PAI token timeout. Please check your PAI cluster.')),
                5000);
        });

        return Promise.race([timeoutDelay, deferred.promise])
            .finally(() => { clearTimeout(timeoutId); });
    }
}

export { PAITrainingService };