paiTrainingService.ts 21 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
3

4
'use strict';
5

SparkSnail's avatar
SparkSnail committed
6
import * as fs from 'fs';
7
8
import * as path from 'path';
import * as request from 'request';
9
import * as component from '../../common/component';
10
11

import { EventEmitter } from 'events';
12
import { Deferred } from 'ts-deferred';
13
import { getExperimentId } from '../../common/experimentStartupInfo';
14
import { getLogger, Logger } from '../../common/log';
15
import { MethodNotImplementedError } from '../../common/errors';
16
import {
SparkSnail's avatar
SparkSnail committed
17
    HyperParameters, NNIManagerIpConfig, TrainingService,
18
    TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, LogType
19
} from '../../common/trainingService';
20
import { delay } from '../../common/utils';
21
import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from '../../common/experimentConfig';
22
import { PAIJobInfoCollector } from './paiJobInfoCollector';
23
import { PAIJobRestServer } from './paiJobRestServer';
24
import { PAITrialJobDetail, PAI_TRIAL_COMMAND_FORMAT } from './paiConfig';
SparkSnail's avatar
SparkSnail committed
25
26
27
import { String } from 'typescript-string-operations';
import {
    generateParamFileName,
28
    getIPV4Address, uniqueString
SparkSnail's avatar
SparkSnail committed
29
30
31
32
33
} from '../../common/utils';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { execMkdir, validateCodeDir, execCopydir } from '../common/util';

const yaml = require('js-yaml');
34

35
36
/**
 * Single type exposing both the experiment-level fields (`ExperimentConfig`) and the
 * OpenPAI-specific fields (`OpenpaiConfig`). The service stores the result of
 * `flattenConfig(config, 'openpai')` under this type.
 */
interface FlattenOpenpaiConfig extends ExperimentConfig, OpenpaiConfig { }

37
38
39
40
41
/**
 * Training Service implementation for OpenPAI (Open Platform for AI)
 * Refer to https://github.com/Microsoft/pai for more information about OpenPAI
 */
@component.Singleton
SparkSnail's avatar
SparkSnail committed
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class PAITrainingService implements TrainingService {
    private readonly log!: Logger;
    private readonly metricsEmitter: EventEmitter;
    private readonly trialJobsMap: Map<string, PAITrialJobDetail>;
    private readonly expRootDir: string;
    private readonly jobQueue: string[];
    private stopping: boolean = false;
    private paiToken?: string;
    private paiTokenUpdateTime?: number;
    private readonly paiTokenUpdateInterval: number;
    private readonly experimentId!: string;
    private readonly paiJobCollector: PAIJobInfoCollector;
    private paiRestServerPort?: number;
    private nniManagerIpConfig?: NNIManagerIpConfig;
    private versionCheck: boolean = true;
57
    private logCollection: string = 'none';
SparkSnail's avatar
SparkSnail committed
58
    private paiJobRestServer?: PAIJobRestServer;
59
    private protocol: string;
SparkSnail's avatar
SparkSnail committed
60
61
62
    private copyExpCodeDirPromise?: Promise<void>;
    private paiJobConfig: any;
    private nniVersion: string | undefined;
63
    private config: FlattenOpenpaiConfig;
64

65
    /**
     * Build the training service from an experiment configuration.
     *
     * Sets up the in-memory bookkeeping (trial map, submission queue, metrics emitter),
     * flattens the OpenPAI section of the configuration, creates the job REST server
     * object and immediately starts copying the trial code; the resulting promise is
     * kept in `copyExpCodeDirPromise` and awaited before the first job submission.
     *
     * @param config experiment configuration containing an OpenPAI training-service section
     */
    constructor(config: ExperimentConfig) {
        // Bookkeeping that does not depend on the configuration.
        this.log = getLogger('PAITrainingService');
        this.metricsEmitter = new EventEmitter();
        this.trialJobsMap = new Map<string, PAITrialJobDetail>();
        this.jobQueue = [];

        // Experiment identity and the status collector working on the shared trial map.
        this.expRootDir = path.join('/nni-experiments', getExperimentId());
        this.experimentId = getExperimentId();
        this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);

        // Token refresh period: two hours, in milliseconds.
        this.paiTokenUpdateInterval = 2 * 60 * 60 * 1000;
        this.log.info('Construct paiBase training service.');

        // Configuration-dependent state; the REST server receives `this`, so the
        // flattened configuration is assigned first.
        this.config = flattenConfig(config, 'openpai');
        this.paiJobRestServer = new PAIJobRestServer(this);
        this.paiToken = this.config.token;

        const normalizedHost: string = this.config.host.toLowerCase();
        if (normalizedHost.startsWith('https://')) {
            this.protocol = 'https';
        } else {
            this.protocol = 'http';
        }

        this.copyExpCodeDirPromise = this.copyTrialCode();
    }

    /**
     * Validate the trial code directory and copy it into this experiment's
     * `nni-code` folder under the local storage mount point.
     *
     * The destination is `<localStorageMountPoint>/<experimentId>/nni-code`, which mirrors
     * the `<containerStorageMountPoint>/<experimentId>/nni-code` path that trial commands
     * use inside the container, and sits next to the per-trial folders created under
     * `<localStorageMountPoint>/<experimentId>/`.
     * NOTE(review): this assumes the local and container mount points refer to the same
     * shared storage — confirm against the deployment setup.
     *
     * @returns a promise that settles once the copy has finished
     */
    private async copyTrialCode(): Promise<void> {
        const codeDir: string = this.config.trialCodeDirectory;
        await validateCodeDir(codeDir);

        // The destination must NOT be derived from the code directory itself: doing so
        // copies the directory into its own subfolder and leaves the shared storage empty.
        const nniManagerNFSExpCodeDir = path.join(this.config.localStorageMountPoint, this.experimentId, 'nni-code');
        await execMkdir(nniManagerNFSExpCodeDir);

        this.log.info(`Starting copy codeDir data from ${codeDir} to ${nniManagerNFSExpCodeDir}`);
        await execCopydir(codeDir, nniManagerNFSExpCodeDir);
    }

    /**
     * Start the job REST server, then run the status-checking loop and the
     * job-submission loop concurrently until both finish.
     *
     * @throws Error if the job REST server object does not exist
     */
    public async run(): Promise<void> {
        this.log.info('Run PAI training service.');

        const restServer = this.paiJobRestServer;
        if (restServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }

        await restServer.start();
        restServer.setEnableVersionCheck = this.versionCheck;
        this.log.info(`PAI Training service rest server listening on: ${restServer.endPoint}`);

        // Both loops exit once `stopping` is set; a rejection from either one propagates.
        const backgroundLoops: Promise<void>[] = [
            this.statusCheckingLoop(),
            this.submitJobLoop()
        ];
        await Promise.all(backgroundLoops);

        this.log.info('PAI training service exit.');
    }

104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
    /**
     * Polling loop that submits queued trial jobs in FIFO order.
     *
     * On every pass the queue is drained from the front: a job is removed only after
     * `submitTrialJobToPAI` resolves to a truthy value; a falsy result stops the pass
     * and leaves the job at the head of the queue. The loop sleeps 3 seconds between
     * passes and ends once `stopping` is set.
     */
    protected async submitJobLoop(): Promise<void> {
        const pollIntervalMs = 3000;

        while (!this.stopping) {
            for (;;) {
                if (this.stopping || this.jobQueue.length === 0) {
                    break;
                }
                const headJobId: string = this.jobQueue[0];
                const submitted: boolean = await this.submitTrialJobToPAI(headJobId);
                if (!submitted) {
                    // Submission did not go through; retry on the next pass.
                    break;
                }
                // Drop the job that was just submitted from the front of the queue.
                this.jobQueue.shift();
            }
            await delay(pollIntervalMs);
        }
    }

120
121
    /**
     * Collect the details of every trial job currently tracked by this service.
     *
     * @returns one entry per key of the trial map, in the map's iteration order
     */
    public async listTrialJobs(): Promise<TrialJobDetail[]> {
        const details: TrialJobDetail[] = [];

        for (const trialJobId of this.trialJobsMap.keys()) {
            const detail: TrialJobDetail = await this.getTrialJob(trialJobId);
            details.push(detail);
        }

        return details;
    }

130
131
132
133
    /**
     * Not supported by this training service.
     *
     * @param _trialJobId unused
     * @param _logType unused
     * @throws MethodNotImplementedError always (surfaced as a rejected promise, since the method is async)
     */
    public async getTrialLog(_trialJobId: string, _logType: LogType): Promise<string> {
        throw new MethodNotImplementedError();
    }

134
    /**
     * Look up a single trial job.
     *
     * @param trialJobId id of the trial job
     * @returns the stored detail object for that id
     * @throws Error if no trial job with this id is tracked
     */
    public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
        const detail = this.trialJobsMap.get(trialJobId);
        if (detail !== undefined) {
            return detail;
        }

        throw new Error(`trial job ${trialJobId} not found`);
    }

144
    /**
     * Subscribe a callback to the `'metric'` event of the internal metrics emitter.
     *
     * @param listener callback invoked with each emitted trial metric
     */
    public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        const emitter: EventEmitter = this.metricsEmitter;
        emitter.on('metric', listener);
    }

148
    /**
     * Unsubscribe a callback from the `'metric'` event of the internal metrics emitter.
     *
     * @param listener the callback previously passed to `addTrialJobMetricListener`
     */
    public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        const emitter: EventEmitter = this.metricsEmitter;
        emitter.off('metric', listener);
    }

QuanluZhang's avatar
QuanluZhang committed
152
    /**
     * Cancel a trial job by asking the PAI REST server to stop it.
     *
     * A job whose status is `'UNKNOWN'` is only marked `'USER_CANCELED'` locally and no
     * request is sent. For any other status, `isEarlyStopped` is recorded on the job
     * detail and a `PUT .../executionType` request with body `{ value: 'STOP' }` is issued.
     *
     * @param trialJobId id of the trial job to cancel
     * @param isEarlyStopped whether the cancellation originates from early stopping
     * @returns a promise resolved when the stop request succeeds; rejected with an `Error`
     *          when the job is unknown, the request fails, or the HTTP status is >= 400
     */
    public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
        if (trialJobDetail === undefined) {
            return Promise.reject(new Error(`cancelTrialJob: trial job id ${trialJobId} not found`));
        }

        if (trialJobDetail.status === 'UNKNOWN') {
            trialJobDetail.status = 'USER_CANCELED';
            return Promise.resolve();
        }

        const stopJobRequest: request.Options = {
            uri: `${this.config.host}/rest-server/api/v2/jobs/${this.config.username}~${trialJobDetail.paiJobName}/executionType`,
            method: 'PUT',
            json: true,
            body: { value: 'STOP' },
            headers: {
                'Content-Type': 'application/json',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        // Set trialjobDetail's early stopped field, to mark the job's cancellation source
        trialJobDetail.isEarlyStopped = isEarlyStopped;
        const deferred: Deferred<void> = new Deferred<void>();

        request(stopJobRequest, (error: Error, response: request.Response, _body: any) => {
            const hasTransportError: boolean = error !== undefined && error !== null;
            // Status code 202 for success.
            if (hasTransportError || response.statusCode >= 400) {
                this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
                // Always reject with an Error object (never a bare string) so callers
                // get a stack trace and can rely on `.message`.
                deferred.reject(hasTransportError
                    ? error
                    : new Error(`Stop trial failed, http code: ${response.statusCode}`));
            } else {
                deferred.resolve();
            }
        });

        return deferred.promise;
    }

    /**
     * Signal the background loops to stop and shut down the job REST server.
     *
     * A failure while stopping the REST server is logged, not rethrown.
     *
     * @throws Error if the job REST server object does not exist
     */
    public async cleanUp(): Promise<void> {
        this.log.info('Stopping PAI training service...');
        // Read by statusCheckingLoop / submitJobLoop on their next iteration.
        this.stopping = true;

        const restServer = this.paiJobRestServer;
        if (restServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }

        try {
            await restServer.stop();
            this.log.info('PAI Training service rest server stopped successfully.');
        } catch (error) {
            this.log.error(`PAI Training service rest server stopped failed, error: ${error.message}`);
        }
    }

chicm-ms's avatar
chicm-ms committed
208
    /**
     * The emitter on which trial metrics are published under the `'metric'` event.
     */
    public get MetricsEmitter(): EventEmitter {
        const emitter: EventEmitter = this.metricsEmitter;
        return emitter;
    }
211

SparkSnail's avatar
SparkSnail committed
212
213
214
    /**
     * Strip a leading `http://` or `https://` from a host string and remember the
     * scheme in `this.protocol`.
     *
     * A host without either prefix is returned unchanged and `this.protocol` is left
     * as it was. The prefix match is case-sensitive.
     *
     * @param host host string, optionally prefixed with a scheme
     * @returns the host without the scheme prefix
     */
    protected formatPAIHost(host: string): string {
        for (const scheme of ['http', 'https']) {
            const prefix = `${scheme}://`;
            if (host.startsWith(prefix)) {
                this.protocol = scheme;
                // The prefix sits at index 0, so slicing it off removes exactly that occurrence.
                return host.slice(prefix.length);
            }
        }

        return host;
    }

226
    /**
     * Polling loop that keeps trial statuses up to date.
     *
     * Each pass (every 3 seconds until `stopping` is set):
     *  1. refreshes the PAI token when a deprecated password is configured; a refresh
     *     failure is logged and the pass continues,
     *  2. asks the job collector to retrieve trial statuses,
     *  3. aborts the loop with an Error if the REST server is missing or reports an
     *     error message.
     */
    protected async statusCheckingLoop(): Promise<void> {
        const pollIntervalMs = 3000;

        while (!this.stopping) {
            const usesPasswordAuth = this.config.deprecated && this.config.deprecated.password;
            if (usesPasswordAuth) {
                try {
                    await this.updatePaiToken();
                } catch (error) {
                    this.log.error(`${error}`);
                }
            }

            await this.paiJobCollector.retrieveTrialStatus(this.protocol, this.paiToken, this.config);

            const restServer = this.paiJobRestServer;
            if (restServer === undefined) {
                throw new Error('paiBaseJobRestServer not implemented!');
            }

            const restServerError = restServer.getErrorMessage;
            if (restServerError !== undefined) {
                throw new Error(restServerError);
            }

            await delay(pollIntervalMs);
        }
    }

246
247
248
    /**
     * Obtain a PAI token with the configured username and deprecated password, unless a
     * token was fetched less than `paiTokenUpdateInterval` milliseconds ago.
     *
     * On success `paiToken` and `paiTokenUpdateTime` are updated. On any failure
     * (transport error, non-200 response, or no answer within 5 seconds) the previous
     * token and timestamp are left untouched so the next call retries.
     *
     * @returns a promise resolved once the token is fresh; rejected with an `Error` on failure or timeout
     */
    protected async updatePaiToken(): Promise<void> {
        const currentTime: number = new Date().getTime();
        // If pai token initialized and not reach the interval time, do not update
        if (this.paiTokenUpdateTime !== undefined && (currentTime - this.paiTokenUpdateTime) < this.paiTokenUpdateInterval) {
            return;
        }

        const deferred: Deferred<void> = new Deferred<void>();
        const authenticationReq: request.Options = {
            uri: `${this.config.host}/rest-server/api/v1/token`,
            method: 'POST',
            json: true,
            body: {
                username: this.config.username,
                password: this.config.deprecated.password
            }
        };

        request(authenticationReq, (error: Error, response: request.Response, body: any) => {
            if (error !== undefined && error !== null) {
                this.log.error(`Get PAI token failed: ${error.message}`);
                deferred.reject(new Error(`Get PAI token failed: ${error.message}`));
                return;
            }

            if (response.statusCode !== 200) {
                this.log.error(`Get PAI token failed: get PAI Rest return code ${response.statusCode}`);
                // With `json: true` the body may already be a parsed object; serialize it
                // so the message does not degrade to "[object Object]".
                const responseText: string = typeof response.body === 'string' ? response.body : JSON.stringify(response.body);
                deferred.reject(new Error(`Get PAI token failed: ${responseText}, please check paiConfig username or password`));
                // Stop here: a failed response must not overwrite the current token or
                // stamp the update time, otherwise retries are suppressed for a full interval.
                return;
            }

            this.paiToken = body.token;
            this.paiTokenUpdateTime = new Date().getTime();
            deferred.resolve();
        });

        let timeoutId: ReturnType<typeof setTimeout> | undefined;
        const timeoutDelay: Promise<void> = new Promise<void>((_resolve: () => void, reject: (reason: Error) => void): void => {
            // Set timeout and reject the promise once reach timeout (5 seconds)
            timeoutId = setTimeout(
                () => reject(new Error('Get PAI token timeout. Please check your PAI cluster.')),
                5000);
        });

        // Whichever settles first wins; the timer is always cleared afterwards.
        return Promise.race([timeoutDelay, deferred.promise])
            .finally(() => {
                if (timeoutId !== undefined) {
                    clearTimeout(timeoutId);
                }
            });
    }
SparkSnail's avatar
SparkSnail committed
294

295
296
    /**
     * No-op: both arguments are ignored and nothing is stored.
     *
     * @param _key unused
     * @param _value unused
     */
    public async setClusterMetadata(_key: string, _value: string): Promise<void> { return; }
    /**
     * Always resolves to an empty string; the key is ignored.
     *
     * @param _key unused
     */
    public async getClusterMetadata(_key: string): Promise<string> { return ''; }
SparkSnail's avatar
SparkSnail committed
297
298
299
300
301
302
303
304
305
306
307
308
309
310

    /**
     * Update the parameters of an existing trial (multi-phase) by writing the new
     * hyper-parameter file into the trial's log path.
     *
     * @param trialJobId id of the trial job to update
     * @param form application form carrying the new hyper-parameters
     * @returns the detail object of the updated trial job
     * @throws Error if no trial job with this id is tracked
     */
    public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        const detail = this.trialJobsMap.get(trialJobId);
        if (detail === undefined) {
            throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
        }

        // Write file content ( parameter.cfg ) to working folders
        const { hyperParameters } = form;
        await this.writeParameterFile(detail.logPath, hyperParameters);

        return detail;
    }

    /**
     * Register a new trial job and queue it for submission.
     *
     * The job is only recorded here (status `'WAITING'`) and appended to the job queue;
     * the actual submission to PAI is performed later by `submitJobLoop`.
     *
     * @param form application form describing the trial
     * @returns the detail object created for the new trial job
     */
    public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        this.log.info('submitTrialJob: form:',  form);

        const newTrialId: string = uniqueString(5);
        const jobNameOnPai = `nni_exp_${this.experimentId}_trial_${newTrialId}`;

        //TODO: use HDFS working folder instead
        const workingFolder: string = path.join(this.expRootDir, 'trials', newTrialId);
        const trialLogPath: string = path.join(this.config.localStorageMountPoint, this.experimentId, newTrialId);
        const detailPageUrl = `${this.config.host}/job-detail.html?username=${this.config.username}&jobName=${jobNameOnPai}`;

        const createdDetail: PAITrialJobDetail = new PAITrialJobDetail(
            newTrialId,
            'WAITING',
            jobNameOnPai,
            Date.now(),
            workingFolder,
            form,
            trialLogPath,
            detailPageUrl);

        this.trialJobsMap.set(newTrialId, createdDetail);
        this.jobQueue.push(newTrialId);

        return createdDetail;
    }

    /**
     * Fill `PAI_TRIAL_COMMAND_FORMAT` with the values for one trial and return it as a
     * single-line shell command (all line breaks removed).
     *
     * @param trialJobDetail trial whose id and sequence id are embedded in the command
     * @param command the user's trial command to wrap
     * @returns the formatted command with `\r\n`, `\n` and `\r` stripped
     */
    private generateNNITrialCommand(trialJobDetail: PAITrialJobDetail, command: string): string {
        // Paths as seen from inside the container, rooted at the experiment's folder
        // under the container storage mount point.
        const experimentDirInContainer = `${this.config.containerStorageMountPoint}/${this.experimentId}`;
        const codeDirInContainer = `${experimentDirInContainer}/nni-code`;
        const workDirInContainer = `${experimentDirInContainer}/${trialJobDetail.id}`;
        const outputDirInContainer = `${workDirInContainer}/nnioutput`;

        // Fall back to the detected IPv4 address when no manager IP is configured.
        const managerIp = this.config.nniManagerIp || getIPV4Address();

        const formattedCommand: string = String.Format(
            PAI_TRIAL_COMMAND_FORMAT,
            workDirInContainer,
            outputDirInContainer,
            trialJobDetail.id,
            this.experimentId,
            trialJobDetail.form.sequenceId,
            false,  // multi-phase
            codeDirInContainer,
            command,
            managerIp,
            this.paiRestServerPort,
            this.nniVersion,
            this.logCollection
        );

        return formattedCommand.replace(/\r\n|\n|\r/gm, '');
    }

    /**
     * Build the PAI job configuration for a trial and serialize it with `yaml.safeDump`.
     *
     * When the user supplied `openpaiConfig`, a deep copy of it is used: its name is
     * replaced and every task role's commands are joined with `" && "`, escaped and
     * wrapped into one NNI trial command. Otherwise a default single-task-role
     * configuration is generated from the flattened experiment configuration.
     *
     * @param trialJobDetail trial the job configuration is generated for
     * @returns the value returned by `yaml.safeDump` for the job configuration
     */
    private generateJobConfigInYamlFormat(trialJobDetail: PAITrialJobDetail): any {
        const jobName = `nni_exp_${this.experimentId}_trial_${trialJobDetail.id}`;
        const userJobConfig = this.config.openpaiConfig;

        if (userJobConfig !== undefined) {
            // JSON round-trip gives a deep clone, so the user's configuration is never mutated.
            const customizedConfig: any = JSON.parse(JSON.stringify(userJobConfig));
            customizedConfig.name = jobName;

            // Each taskRole will generate new command in NNI's command format
            // Each command will be formatted to NNI style
            for (const roleName in customizedConfig.taskRoles) {
                const role: any = customizedConfig.taskRoles[roleName];
                // Backslash-escape quotes, `$`, backticks and backslashes in the joined command.
                const escapedUserCommand: string = role.commands.join(" && ").replace(/(["'$`\\])/g, '\\$1');
                role.commands = [this.generateNNITrialCommand(trialJobDetail, escapedUserCommand)];
            }

            return yaml.safeDump(customizedConfig);
        }

        const generatedConfig: any = {
            protocolVersion: 2,
            name: jobName,
            type: 'job',
            jobRetryCount: 0,
            prerequisites: [
                {
                    type: 'dockerimage',
                    uri: this.config.dockerImage,
                    name: 'docker_image_0'
                }
            ],
            taskRoles: {
                taskrole: {
                    instances: 1,
                    completion: {
                        minFailedInstances: 1,
                        minSucceededInstances: -1
                    },
                    taskRetryCount: 0,
                    dockerImage: 'docker_image_0',
                    resourcePerInstance: {
                        gpu: this.config.trialGpuNumber,
                        cpu: this.config.trialCpuNumber,
                        memoryMB: toMegaBytes(this.config.trialMemorySize)
                    },
                    commands: [
                        this.generateNNITrialCommand(trialJobDetail, this.config.trialCommand)
                    ]
                }
            },
            extras: {
                'storages': [
                    {
                        name: this.config.storageConfigName
                    }
                ],
                submitFrom: 'submit-job-v2'
            }
        };

        // Only the generated configuration receives the deprecated virtual cluster setting.
        const deprecatedSettings = this.config.deprecated;
        if (deprecatedSettings && deprecatedSettings.virtualCluster) {
            generatedConfig.defaults = {
                virtualCluster: deprecatedSettings.virtualCluster
            };
        }

        return yaml.safeDump(generatedConfig);
    }

    /**
     * Prepare the local files of a trial and submit its job configuration to PAI.
     *
     * Steps: wait for the one-time trial code copy, create the trial's log folder, write
     * `install_nni.sh` and the hyper-parameter file into it, generate the YAML job
     * configuration, then POST it to `/rest-server/api/v2/jobs`.
     *
     * @param trialJobId id of a trial job previously registered by `submitTrialJob`
     * @returns a promise resolved to `true` when PAI accepts the job; rejected with an
     *          `Error` (and the trial marked `'FAILED'`) on a transport error or an HTTP
     *          status >= 400. It never resolves to `false`.
     * @throws Error if the trial job or the job REST server is missing
     */
    protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);

        if (trialJobDetail === undefined) {
            throw new Error(`Failed to find PAITrialJobDetail for job ${trialJobId}`);
        }

        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer is not initialized');
        }

        // Make sure experiment code files is copied from local to NFS
        if (this.copyExpCodeDirPromise !== undefined) {
            await this.copyExpCodeDirPromise;
            this.log.info(`Copy codeDir data finished.`);
            // All trials share same destination NFS code folder, only copy codeDir once for an experiment.
            // After copy data finished, set copyExpCodeDirPromise be undefined to avoid log content duplicated.
            this.copyExpCodeDirPromise = undefined;
        }

        this.paiRestServerPort = this.paiJobRestServer.clusterRestServerPort;

        // Step 1. Prepare PAI job configuration
        //create trial local working folder locally.
        await execMkdir(trialJobDetail.logPath);
        // Write NNI installation file to local files
        await fs.promises.writeFile(path.join(trialJobDetail.logPath, 'install_nni.sh'), CONTAINER_INSTALL_NNI_SHELL_FORMAT, { encoding: 'utf8' });

        // Write file content ( parameter.cfg ) to local working folders
        if (trialJobDetail.form !== undefined) {
            await this.writeParameterFile(trialJobDetail.logPath, trialJobDetail.form.hyperParameters);
        }

        //Generate Job Configuration in yaml format
        const paiJobConfig = this.generateJobConfigInYamlFormat(trialJobDetail);
        this.log.debug(paiJobConfig);

        // Step 2. Submit PAI job via Rest call
        // Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
        const submitJobRequest: request.Options = {
            uri: `${this.config.host}/rest-server/api/v2/jobs`,
            method: 'POST',
            body: paiJobConfig,
            followAllRedirects: true,
            headers: {
                'Content-Type': 'text/yaml',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        const deferred: Deferred<boolean> = new Deferred<boolean>();
        request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
            const hasTransportError: boolean = error !== undefined && error !== null;
            // If submit success, will get status code 202. refer: https://github.com/microsoft/pai/blob/master/src/rest-server/docs/swagger.yaml
            if (hasTransportError || response.statusCode >= 400) {
                const errorMessage: string = hasTransportError ? error.message :
                    `Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${body}`;
                this.log.error(errorMessage);
                trialJobDetail.status = 'FAILED';
                // Reject with an Error object, and settle the promise exactly once.
                deferred.reject(new Error(errorMessage));
                return;
            }

            trialJobDetail.submitTime = Date.now();
            deferred.resolve(true);
        });

        return deferred.promise;
    }

    /**
     * Write a trial's hyper-parameter value to a file inside `directory`.
     *
     * The file name is derived from the hyper-parameters by `generateParamFileName`;
     * the content is `hyperParameters.value`, written as UTF-8.
     *
     * @param directory folder that receives the parameter file
     * @param hyperParameters hyper-parameters to persist
     */
    private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> {
        const paramFileName: string = generateParamFileName(hyperParameters);
        const targetFile: string = path.join(directory, paramFileName);
        await fs.promises.writeFile(targetFile, hyperParameters.value, { encoding: 'utf8' });
    }
J-shang's avatar
J-shang committed
494
495
496
497
498
499
500
501

    /**
     * Not supported by this training service.
     *
     * Note: the method is not `async`, so the error is thrown synchronously rather than
     * delivered as a rejected promise.
     *
     * @param _trialJobId unused
     * @throws MethodNotImplementedError always
     */
    public getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
        throw new MethodNotImplementedError();
    }

    /**
     * Not supported by this training service.
     *
     * Note: the method is not `async`, so the error is thrown synchronously rather than
     * delivered as a rejected promise.
     *
     * @param _trialJobId unused
     * @param _subpath unused
     * @throws MethodNotImplementedError always
     */
    public fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
        throw new MethodNotImplementedError();
    }
502
503
}

504
export { PAITrainingService };