"vscode:/vscode.git/clone" did not exist on "c9f8c7a7f7548dd28cde0285571484a9e6d92bb8"
paiTrainingService.ts 12.2 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
3

4
'use strict';
5
6
7
8

import * as fs from 'fs';
import * as path from 'path';
import * as request from 'request';
9
import * as component from '../../common/component';
10
11

import { EventEmitter } from 'events';
12
13
import { Deferred } from 'ts-deferred';
import { String } from 'typescript-string-operations';
14
import { getExperimentId } from '../../common/experimentStartupInfo';
15
16
import { getLogger, Logger } from '../../common/log';
import {
17
    HyperParameters, NNIManagerIpConfig, TrainingService,
18
    TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
19
} from '../../common/trainingService';
20
import { delay, generateParamFileName,
21
    getExperimentRootDir, getIPV4Address, getVersion, uniqueString, unixPathJoin } from '../../common/utils';
22
23
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
24
import { execMkdir, validateCodeDir } from '../common/util';
25
import { PAIJobInfoCollector } from './paiJobInfoCollector';
26
import { PAIJobRestServer, ParameterFileMeta } from './paiJobRestServer';
27
import { PAIClusterConfig, PAITrialJobDetail } from './paiConfig';
28
29
30
31
32
33

/**
 * Training Service implementation for OpenPAI (Open Platform for AI)
 * Refer https://github.com/Microsoft/pai for more info about OpenPAI
 *
 * This abstract base class owns the state and the two long-running loops
 * (job submission and status polling). The methods that throw
 * 'Not implemented!' are expected to be overridden by concrete subclasses.
 */
@component.Singleton
abstract class PAITrainingService implements TrainingService {
    protected readonly log!: Logger;
    /** Emitter on which 'metric' events are published to registered listeners. */
    protected readonly metricsEmitter: EventEmitter;
    /** All known trial jobs, keyed by trial job id. */
    protected readonly trialJobsMap: Map<string, PAITrialJobDetail>;
    protected readonly expRootDir: string;
    protected paiClusterConfig?: PAIClusterConfig;
    /** Ids of trial jobs waiting for submission; consumed front-first by submitJobLoop(). */
    protected readonly jobQueue: string[];
    /** Set by cleanUp(); makes both background loops exit. */
    protected stopping: boolean = false;
    protected paiToken? : string;
    /** Timestamp (ms since epoch) of the last successful token refresh. */
    protected paiTokenUpdateTime?: number;
    /** Minimum time in milliseconds between two token refreshes. */
    protected readonly paiTokenUpdateInterval: number;
    protected readonly experimentId!: string;
    protected readonly paiJobCollector: PAIJobInfoCollector;
    // NOTE(review): the following fields are not read in this base class;
    // presumably they are populated and used by subclasses - confirm.
    protected paiRestServerPort?: number;
    protected nniManagerIpConfig?: NNIManagerIpConfig;
    protected versionCheck: boolean = true;
    protected logCollection: string;
    protected isMultiPhase: boolean = false;
    protected authFileHdfsPath: string | undefined = undefined;
    protected portList?: string | undefined;
    protected paiJobRestServer?: PAIJobRestServer;
    /** URL scheme used for every PAI REST call; updated by formatPAIHost(). */
    protected protocol: string = 'http';

    constructor() {
        this.log = getLogger();
        this.metricsEmitter = new EventEmitter();
        this.trialJobsMap = new Map<string, PAITrialJobDetail>();
        this.jobQueue = [];
        this.expRootDir = path.join('/nni', 'experiments', getExperimentId());
        this.experimentId = getExperimentId();
        this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);
        this.paiTokenUpdateInterval = 7200000; //2hours
        this.logCollection = 'none';
        this.log.info('Construct paiBase training service.');
    }

    /**
     * Start the job REST server, then run the status-checking and job-submission
     * loops until both have finished (i.e. until `stopping` is set or one of them throws).
     *
     * @throws Error if `paiJobRestServer` has not been set before calling.
     */
    public async run(): Promise<void> {
        this.log.info('Run PAI training service.');
        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }
        await this.paiJobRestServer.start();
        this.paiJobRestServer.setEnableVersionCheck = this.versionCheck;
        this.log.info(`PAI Training service rest server listening on: ${this.paiJobRestServer.endPoint}`);
        await Promise.all([
            this.statusCheckingLoop(),
            this.submitJobLoop()]);
        this.log.info('PAI training service exit.');
    }

    /** Placeholder; always rejects with 'Not implemented!' in this base class. */
    public async submitTrialJob(form: TrialJobApplicationForm): Promise<any> {
        throw new Error('Not implemented!');
    }

    /** Placeholder; always rejects with 'Not implemented!' in this base class. */
    public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        throw new Error('Not implemented!');
    }

    /**
     * Placeholder; always rejects with 'Not implemented!' in this base class.
     * submitJobLoop() treats a resolved `true` as "submitted, dequeue it" and
     * `false` as "could not submit now, retry later".
     */
    protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
        throw new Error('Not implemented!');
    }

    /**
     * Drain the job queue in order until `stopping` is set. When a submission
     * returns false the job stays at the head of the queue and is retried
     * after the 3 second pause.
     */
    protected async submitJobLoop(): Promise<void> {
        while (!this.stopping) {
            while (!this.stopping && this.jobQueue.length > 0) {
                const trialJobId: string = this.jobQueue[0];
                if (await this.submitTrialJobToPAI(trialJobId)) {
                    // Remove trial job with trialJobId from job queue
                    this.jobQueue.shift();
                } else {
                    // Break the while loop since failed to submitJob
                    break;
                }
            }
            await delay(3000);
        }
    }

    /** Placeholder; always rejects with 'Not implemented!' in this base class. */
    public async setClusterMetadata(key: string, value: string): Promise<void> {
        throw new Error('Not implemented!');
    }

    /**
     * Return the details of every trial job currently in `trialJobsMap`.
     *
     * @throws Error (as a rejection) under the same conditions as getTrialJob().
     */
    public async listTrialJobs(): Promise<TrialJobDetail[]> {
        const jobs: TrialJobDetail[] = [];

        for (const trialJobId of this.trialJobsMap.keys()) {
            jobs.push(await this.getTrialJob(trialJobId));
        }

        return jobs;
    }

    /**
     * Look up a single trial job by id.
     *
     * @throws Error if the PAI cluster config is not set, or if the id is unknown.
     */
    public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
        if (this.paiClusterConfig === undefined) {
            throw new Error('PAI Cluster config is not initialized');
        }

        const paiTrialJob: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);

        if (paiTrialJob === undefined) {
            throw new Error(`trial job ${trialJobId} not found`);
        }

        return paiTrialJob;
    }

    /** Register a listener for 'metric' events. */
    public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.metricsEmitter.on('metric', listener);
    }

    /** Unregister a listener previously added with addTrialJobMetricListener(). */
    public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.metricsEmitter.off('metric', listener);
    }

    public get isMultiPhaseJobSupported(): boolean {
        return true;
    }

    /**
     * Ask the PAI REST server to stop the job that backs the given trial.
     *
     * Every failure is reported through the returned promise as an `Error`
     * (never as a synchronous throw, a bare string or an empty rejection), so
     * callers can rely on a single error-handling path.
     *
     * @param trialJobId id of the trial to cancel
     * @param isEarlyStopped recorded on the trial job detail to mark the source of the cancellation
     * @returns a promise that resolves once the REST server has answered with a status below 400
     */
    public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
        if (trialJobDetail === undefined) {
            const notFoundMessage: string = `cancelTrialJob: trial job id ${trialJobId} not found`;
            this.log.error(notFoundMessage);

            return Promise.reject(new Error(notFoundMessage));
        }

        if (this.paiClusterConfig === undefined) {
            return Promise.reject(new Error('PAI Cluster config is not initialized'));
        }
        if (this.paiToken === undefined) {
            return Promise.reject(new Error('PAI token is not initialized'));
        }

        const deferred: Deferred<void> = new Deferred<void>();
        const stopJobRequest: request.Options = {
            uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v1/user/${this.paiClusterConfig.userName}`
                + `/jobs/${trialJobDetail.paiJobName}/executionType`,
            method: 'PUT',
            json: true,
            body: {value: 'STOP'},
            headers: {
                'Content-Type': 'application/json',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        // Set trialjobDetail's early stopped field, to mark the job's cancellation source
        trialJobDetail.isEarlyStopped = isEarlyStopped;

        request(stopJobRequest, (error: Error, response: request.Response) => {
            const hasTransportError: boolean = error !== undefined && error !== null;
            // `response` is only inspected when there is no transport error.
            if (hasTransportError || response.statusCode >= 400) {
                this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
                deferred.reject(new Error(hasTransportError ? error.message :
                    `Stop trial failed, http code: ${response.statusCode}`));
            } else {
                deferred.resolve();
            }
        });

        return deferred.promise;
    }

    /** Placeholder; always throws 'Not implemented!' in this base class. */
    public getClusterMetadata(key: string): Promise<string> {
        throw new Error('Not implemented!');
    }

    /**
     * Signal both background loops to stop and shut down the job REST server.
     * A failure while stopping the server is logged, not rethrown.
     *
     * @throws Error if `paiJobRestServer` has not been set.
     */
    public async cleanUp(): Promise<void> {
        this.log.info('Stopping PAI training service...');
        this.stopping = true;

        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }

        try {
            await this.paiJobRestServer.stop();
            this.log.info('PAI Training service rest server stopped successfully.');
        } catch (error) {
            // The caught value is not guaranteed to be an Error instance.
            const reason: string = error instanceof Error ? error.message : `${error}`;
            this.log.error(`PAI Training service rest server stopped failed, error: ${reason}`);
        }
    }

    public get MetricsEmitter(): EventEmitter {
        return this.metricsEmitter;
    }

    /**
     * Split an optional 'http://' or 'https://' prefix off the configured host.
     *
     * When a prefix is present it is removed from the returned value and
     * remembered in `this.protocol`; otherwise the host is returned untouched
     * and `this.protocol` keeps its current value.
     *
     * @param host host string as supplied by the user
     * @returns the host without a leading scheme
     */
    protected formatPAIHost(host: string): string {
        if (host.startsWith('http://')) {
            this.protocol = 'http';
            return host.replace('http://', '');
        } else if (host.startsWith('https://')) {
            this.protocol = 'https';
            return host.replace('https://', '');
        } else {
            return host;
        }
    }

    /**
     * Poll trial status every 3 seconds until `stopping` is set.
     *
     * Each round first refreshes the PAI token when a password is configured.
     * A refresh failure is fatal only while no token has ever been obtained;
     * afterwards it is logged and the previous token keeps being used.
     *
     * @throws Error if the initial token cannot be obtained, if the REST server
     *         is not set, or if the REST server reports an error message.
     */
    protected async statusCheckingLoop(): Promise<void> {
        while (!this.stopping) {
            if (this.paiClusterConfig && this.paiClusterConfig.passWord) {
                try {
                    await this.updatePaiToken();
                } catch (error) {
                    this.log.error(`${error}`);
                    //only throw error when initialize paiToken first time
                    if (this.paiToken === undefined) {
                        // Rethrow the original error so its message and stack are preserved.
                        throw error instanceof Error ? error : new Error(`${error}`);
                    }
                }
            }
            await this.paiJobCollector.retrieveTrialStatus(this.protocol, this.paiToken, this.paiClusterConfig);
            if (this.paiJobRestServer === undefined) {
                throw new Error('paiBaseJobRestServer not implemented!');
            }
            if (this.paiJobRestServer.getErrorMessage !== undefined) {
                throw new Error(this.paiJobRestServer.getErrorMessage);
            }
            await delay(3000);
        }
    }

    /**
     * Update pai token by the interval time or initialize the pai token
     *
     * Does nothing when the last successful refresh is more recent than
     * `paiTokenUpdateInterval`. Otherwise posts the configured user name and
     * password to the token endpoint and stores the returned token. The token
     * and its timestamp are only updated on an HTTP 200 answer.
     *
     * @throws Error (as a rejection) when the cluster config is missing, the
     *         request fails, the server answers with a non-200 status, or no
     *         answer arrives within 5 seconds.
     */
    protected async updatePaiToken(): Promise<void> {
        const currentTime: number = new Date().getTime();
        //If pai token initialized and not reach the interval time, do not update
        if (this.paiTokenUpdateTime !== undefined && (currentTime - this.paiTokenUpdateTime) < this.paiTokenUpdateInterval) {
            return;
        }

        if (this.paiClusterConfig === undefined) {
            const paiClusterConfigError: string = `pai cluster config not initialized!`;
            this.log.error(`${paiClusterConfigError}`);
            throw Error(`${paiClusterConfigError}`);
        }

        const deferred: Deferred<void> = new Deferred<void>();
        const authenticationReq: request.Options = {
            uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v1/token`,
            method: 'POST',
            json: true,
            body: {
                username: this.paiClusterConfig.userName,
                password: this.paiClusterConfig.passWord
            }
        };

        request(authenticationReq, (error: Error, response: request.Response, body: any) => {
            if (error !== undefined && error !== null) {
                this.log.error(`Get PAI token failed: ${error.message}`);
                deferred.reject(new Error(`Get PAI token failed: ${error.message}`));

                return;
            }
            if (response.statusCode !== 200) {
                this.log.error(`Get PAI token failed: get PAI Rest return code ${response.statusCode}`);
                // With `json: true` the body is usually a parsed object; serialise it so
                // the message does not degrade to "[object Object]".
                const detail: string = typeof response.body === 'string' ? response.body : JSON.stringify(response.body);
                deferred.reject(new Error(`Get PAI token failed: ${detail}, please check paiConfig username or password`));

                // Do not fall through: a failed login must not overwrite the token
                // or refresh the timestamp, otherwise retries are suppressed.
                return;
            }
            this.paiToken = body.token;
            this.paiTokenUpdateTime = new Date().getTime();
            deferred.resolve();
        });

        let timeoutId: NodeJS.Timer;
        const timeoutDelay: Promise<void> = new Promise<void>((_resolve, reject): void => {
            // Set timeout and reject the promise once reach timeout (5 seconds)
            timeoutId = setTimeout(
                () => reject(new Error('Get PAI token timeout. Please check your PAI cluster.')),
                5000);
        });

        return Promise.race([timeoutDelay, deferred.promise])
            .finally(() => { clearTimeout(timeoutId); });
    }
}

314
export { PAITrainingService };