// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
3

4
'use strict';
5

SparkSnail's avatar
SparkSnail committed
6
import * as fs from 'fs';
7
8
import * as path from 'path';
import * as request from 'request';
9
import * as component from '../../common/component';
10
11

import { EventEmitter } from 'events';
12
import { Deferred } from 'ts-deferred';
13
import { getExperimentId } from '../../common/experimentStartupInfo';
14
import { getLogger, Logger } from '../../common/log';
15
import { MethodNotImplementedError } from '../../common/errors';
16
import {
SparkSnail's avatar
SparkSnail committed
17
    HyperParameters, NNIManagerIpConfig, TrainingService,
Yuge Zhang's avatar
Yuge Zhang committed
18
    TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
19
} from '../../common/trainingService';
20
import { delay } from '../../common/utils';
21
import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from '../../common/experimentConfig';
22
import { PAIJobInfoCollector } from './paiJobInfoCollector';
23
import { PAIJobRestServer } from './paiJobRestServer';
24
import { PAITrialJobDetail, PAI_TRIAL_COMMAND_FORMAT } from './paiConfig';
SparkSnail's avatar
SparkSnail committed
25
import { String } from 'typescript-string-operations';
liuzhe-lz's avatar
liuzhe-lz committed
26
import { generateParamFileName, getIPV4Address, uniqueString } from '../../common/utils';
SparkSnail's avatar
SparkSnail committed
27
28
29
30
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { execMkdir, validateCodeDir, execCopydir } from '../common/util';

const yaml = require('js-yaml');
31

32
33
/**
 * Single flat type exposing the members of both ExperimentConfig and OpenpaiConfig.
 * It is the type of the value the constructor stores from `flattenConfig(config, 'openpai')`.
 */
interface FlattenOpenpaiConfig extends ExperimentConfig, OpenpaiConfig { }

34
35
36
37
38
/**
 * Training service backend that runs NNI trials as jobs on an OpenPAI
 * (Open Platform for AI) cluster.
 * See https://github.com/Microsoft/pai for background on OpenPAI.
 */
@component.Singleton
class PAITrainingService implements TrainingService {
    // Logging and metric plumbing.
    private readonly log!: Logger;
    private readonly metricsEmitter: EventEmitter;

    // Experiment identity, flattened configuration and local experiment root.
    private readonly experimentId!: string;
    private readonly expRootDir: string;
    private config: FlattenOpenpaiConfig;
    private nniManagerIpConfig?: NNIManagerIpConfig;
    private nniVersion: string | undefined;
    private versionCheck = true;
    private logCollection = 'none';

    // Trial bookkeeping: every known trial, plus the ids still waiting to be submitted.
    private readonly trialJobsMap: Map<string, PAITrialJobDetail>;
    private readonly jobQueue: string[];
    private readonly paiJobCollector: PAIJobInfoCollector;
    private copyExpCodeDirPromise?: Promise<void>;
    private paiJobConfig: any;
    private stopping = false;

    // OpenPAI access: bearer token, token refresh bookkeeping and URL scheme.
    private paiToken?: string;
    private paiTokenUpdateTime?: number;
    private readonly paiTokenUpdateInterval: number;
    private protocol: string;

    // REST server started by run(); its port is forwarded to the trial command.
    private paiJobRestServer?: PAIJobRestServer;
    private paiRestServerPort?: number;
61

62
    /**
     * Builds the service from the experiment configuration and immediately
     * starts copying the trial code directory in the background.
     *
     * @param config experiment configuration; flattened with the 'openpai' section.
     */
    constructor(config: ExperimentConfig) {
        this.log = getLogger('PAITrainingService');
        this.metricsEmitter = new EventEmitter();
        this.trialJobsMap = new Map();
        this.jobQueue = [];
        this.expRootDir = path.join('/nni-experiments', getExperimentId());
        this.experimentId = getExperimentId();
        this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);
        // Token refresh period: two hours, expressed in milliseconds.
        this.paiTokenUpdateInterval = 2 * 60 * 60 * 1000;
        this.log.info('Construct paiBase training service.');

        this.config = flattenConfig(config, 'openpai');
        // Version checking is switched off when the experiment runs in debug mode.
        this.versionCheck = !this.config.debug;
        this.paiJobRestServer = new PAIJobRestServer(this);
        this.paiToken = this.config.token;

        const hostIsHttps: boolean = this.config.host.toLowerCase().startsWith('https://');
        this.protocol = hostIsHttps ? 'https' : 'http';

        // Kick off the code copy now; submitTrialJobToPAI awaits this promise before the first submit.
        this.copyExpCodeDirPromise = this.copyTrialCode();
    }

    /**
     * Validates the trial code directory and copies it into
     * `<localStorageMountPoint>/<experimentId>/nni-code`.
     */
    private async copyTrialCode(): Promise<void> {
        const sourceDir = this.config.trialCodeDirectory;
        await validateCodeDir(sourceDir);

        const targetDir = path.join(this.config.localStorageMountPoint, this.experimentId, 'nni-code');
        await execMkdir(targetDir);

        this.log.info(`Starting copy codeDir data from ${sourceDir} to ${targetDir}`);
        await execCopydir(sourceDir, targetDir);
    }

    /**
     * Starts the REST server, then runs the status-checking and job-submission
     * loops until both of them finish.
     *
     * @throws Error when the REST server has not been created.
     */
    public async run(): Promise<void> {
        this.log.info('Run PAI training service.');

        const restServer = this.paiJobRestServer;
        if (restServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }

        await restServer.start();
        restServer.setEnableVersionCheck = this.versionCheck;
        this.log.info(`PAI Training service rest server listening on: ${restServer.endPoint}`);

        const loops: Promise<void>[] = [this.statusCheckingLoop(), this.submitJobLoop()];
        await Promise.all(loops);

        this.log.info('PAI training service exit.');
    }

102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
    /**
     * Drains the job queue in FIFO order until the service is stopping.
     * A trial id is removed from the queue only after its submission reported
     * success; otherwise draining pauses and resumes on the next 3 second tick.
     */
    protected async submitJobLoop(): Promise<void> {
        while (!this.stopping) {
            for (;;) {
                if (this.stopping || this.jobQueue.length === 0) {
                    break;
                }
                const headJobId: string = this.jobQueue[0];
                const submitted: boolean = await this.submitTrialJobToPAI(headJobId);
                if (!submitted) {
                    // Submission did not succeed: leave the id queued and wait for the next tick.
                    break;
                }
                this.jobQueue.shift();
            }
            await delay(3000);
        }
    }

118
119
    /**
     * Returns the details of every trial currently registered in the trial map.
     */
    public async listTrialJobs(): Promise<TrialJobDetail[]> {
        const details: TrialJobDetail[] = [];
        for (const [trialJobId] of this.trialJobsMap) {
            const detail: TrialJobDetail = await this.getTrialJob(trialJobId);
            details.push(detail);
        }

        return details;
    }

Yuge Zhang's avatar
Yuge Zhang committed
128
    /**
     * Not supported by this training service.
     * Because the method is async, callers receive a rejected promise.
     *
     * @throws MethodNotImplementedError always.
     */
    public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
        throw new MethodNotImplementedError();
    }

132
    /**
     * Looks up a trial by id.
     *
     * @param trialJobId id of the trial to fetch.
     * @returns the stored trial detail.
     * @throws Error when no trial with that id is registered.
     */
    public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
        const found = this.trialJobsMap.get(trialJobId);
        if (found !== undefined) {
            return found;
        }

        throw new Error(`trial job ${trialJobId} not found`);
    }

142
    /**
     * Subscribes a listener to 'metric' events of the internal emitter.
     *
     * @param listener callback invoked with each emitted trial metric.
     */
    public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        // addListener is the long-form alias of EventEmitter.on.
        this.metricsEmitter.addListener('metric', listener);
    }

146
    /**
     * Unsubscribes a listener from 'metric' events of the internal emitter.
     *
     * @param listener callback previously passed to addTrialJobMetricListener.
     */
    public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        // removeListener is the long-form alias of EventEmitter.off.
        this.metricsEmitter.removeListener('metric', listener);
    }

QuanluZhang's avatar
QuanluZhang committed
150
    /**
     * Cancels a trial by asking the OpenPAI REST server to stop its job.
     *
     * A trial whose status is still 'UNKNOWN' is marked 'USER_CANCELED' locally
     * and no REST call is made.
     *
     * @param trialJobId id of the trial to cancel.
     * @param isEarlyStopped recorded on the trial detail to mark the source of the cancellation.
     * @returns a promise resolved once the stop request is accepted; rejected with an
     *          Error when the trial is unknown, the request fails, or the server
     *          answers with an HTTP status >= 400.
     */
    public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
        if (trialJobDetail === undefined) {
            return Promise.reject(new Error(`cancelTrialJob: trial job id ${trialJobId} not found`));
        }

        if (trialJobDetail.status === 'UNKNOWN') {
            trialJobDetail.status = 'USER_CANCELED';
            return Promise.resolve();
        }

        const stopJobRequest: request.Options = {
            uri: `${this.config.host}/rest-server/api/v2/jobs/${this.config.username}~${trialJobDetail.paiJobName}/executionType`,
            method: 'PUT',
            json: true,
            body: { value: 'STOP' },
            headers: {
                'Content-Type': 'application/json',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        // Set trialjobDetail's early stopped field, to mark the job's cancellation source
        trialJobDetail.isEarlyStopped = isEarlyStopped;
        const deferred: Deferred<void> = new Deferred<void>();

        request(stopJobRequest, (error: Error, response: request.Response, _body: any) => {
            const transportFailed: boolean = error !== undefined && error !== null;
            // Status code 202 for success.
            if (!transportFailed && response.statusCode < 400) {
                deferred.resolve();
                return;
            }

            this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
            // Always reject with an Error object (never a bare string) so callers can rely on `.message`.
            deferred.reject(transportFailed
                ? error
                : new Error(`Stop trial failed, http code: ${response.statusCode}`));
        });

        return deferred.promise;
    }

    /**
     * Signals the background loops to stop and shuts down the REST server.
     * A failure while stopping the REST server is logged, not rethrown.
     *
     * @throws Error when the REST server has not been created.
     */
    public async cleanUp(): Promise<void> {
        this.log.info('Stopping PAI training service...');
        this.stopping = true;

        const restServer = this.paiJobRestServer;
        if (restServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }

        try {
            await restServer.stop();
            this.log.info('PAI Training service rest server stopped successfully.');
        } catch (error) {
            this.log.error(`PAI Training service rest server stopped failed, error: ${error.message}`);
        }
    }

chicm-ms's avatar
chicm-ms committed
206
    /**
     * Emitter on which 'metric' events are published; the same instance the
     * add/remove metric listener methods register on.
     */
    public get MetricsEmitter(): EventEmitter {
        return this.metricsEmitter;
    }
209

SparkSnail's avatar
SparkSnail committed
210
211
212
    /**
     * Strips a leading 'http://' or 'https://' from the host and remembers the
     * scheme in `this.protocol`. A host without one of these prefixes is
     * returned unchanged and `this.protocol` is left as it was.
     *
     * The prefix is matched case-insensitively, in line with the constructor,
     * which lower-cases the host before inspecting its scheme.
     *
     * @param host host string as supplied by the user.
     * @returns the host without its scheme prefix.
     */
    protected formatPAIHost(host: string): string {
        // 'https' is listed first only for readability; neither prefix is a prefix of the other.
        for (const scheme of ['https', 'http']) {
            const prefix: string = `${scheme}://`;
            if (host.slice(0, prefix.length).toLowerCase() === prefix) {
                this.protocol = scheme;
                return host.slice(prefix.length);
            }
        }

        return host;
    }

224
    /**
     * Polls trial status every 3 seconds until the service is stopping.
     * When a deprecated password is configured, the token is refreshed before
     * each poll; a refresh failure is logged and polling continues.
     *
     * @throws Error when the REST server is missing or has reported an error message.
     */
    protected async statusCheckingLoop(): Promise<void> {
        while (!this.stopping) {
            const legacyOptions = this.config.deprecated;
            if (legacyOptions && legacyOptions.password) {
                try {
                    await this.updatePaiToken();
                } catch (error) {
                    this.log.error(`${error}`);
                }
            }

            await this.paiJobCollector.retrieveTrialStatus(this.protocol, this.paiToken, this.config);

            const restServer = this.paiJobRestServer;
            if (restServer === undefined) {
                throw new Error('paiBaseJobRestServer not implemented!');
            }

            const restServerError = restServer.getErrorMessage;
            if (restServerError !== undefined) {
                throw new Error(restServerError);
            }

            await delay(3000);
        }
    }

244
245
246
    /**
     * Fetches a new PAI token with the configured username/password, unless a
     * token was obtained less than `paiTokenUpdateInterval` milliseconds ago.
     *
     * The request is raced against a 5 second timeout.
     *
     * @returns a promise resolved once the token is stored (or no refresh is due);
     *          rejected with an Error on transport failure, on a non-200 response,
     *          or on timeout.
     */
    protected async updatePaiToken(): Promise<void> {
        const currentTime: number = Date.now();
        // If pai token initialized and not reach the interval time, do not update
        if (this.paiTokenUpdateTime !== undefined && (currentTime - this.paiTokenUpdateTime) < this.paiTokenUpdateInterval) {
            return;
        }

        const deferred: Deferred<void> = new Deferred<void>();
        const authenticationReq: request.Options = {
            uri: `${this.config.host}/rest-server/api/v1/token`,
            method: 'POST',
            json: true,
            body: {
                username: this.config.username,
                password: this.config.deprecated.password
            }
        };

        request(authenticationReq, (error: Error, response: request.Response, body: any) => {
            if (error !== undefined && error !== null) {
                this.log.error(`Get PAI token failed: ${error.message}`);
                deferred.reject(new Error(`Get PAI token failed: ${error.message}`));
                return;
            }

            if (response.statusCode !== 200) {
                this.log.error(`Get PAI token failed: get PAI Rest return code ${response.statusCode}`);
                deferred.reject(new Error(`Get PAI token failed: ${response.body}, please check paiConfig username or password`));
                // Stop here: a failed login must not overwrite the current token nor
                // refresh the update timestamp, otherwise retries would be suppressed
                // for a whole update interval.
                return;
            }

            this.paiToken = body.token;
            this.paiTokenUpdateTime = Date.now();
            deferred.resolve();
        });

        let timeoutId: ReturnType<typeof setTimeout> | undefined;
        const timeoutDelay: Promise<void> = new Promise<void>((_resolve, reject): void => {
            // Set timeout and reject the promise once reach timeout (5 seconds)
            timeoutId = setTimeout(
                () => reject(new Error('Get PAI token timeout. Please check your PAI cluster.')),
                5000);
        });

        return Promise.race([timeoutDelay, deferred.promise])
            .finally(() => {
                if (timeoutId !== undefined) {
                    clearTimeout(timeoutId);
                }
            });
    }
SparkSnail's avatar
SparkSnail committed
292

293
294
    /** No-op: the given key and value are ignored. */
    public async setClusterMetadata(_key: string, _value: string): Promise<void> {
        return;
    }

    /** Always resolves to an empty string, whatever key is requested. */
    public async getClusterMetadata(_key: string): Promise<string> {
        return '';
    }
SparkSnail's avatar
SparkSnail committed
295
296
297
298
299
300
301
302
303
304
305
306
307
308

    /**
     * Updates trial parameters for multi-phase: writes the form's
     * hyper-parameters as a parameter file into the trial's log path.
     *
     * @param trialJobId id of the trial to update.
     * @param form form carrying the new hyper-parameters.
     * @returns the (unchanged) stored trial detail.
     * @throws Error when no trial with that id is registered.
     */
    public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        const detail = this.trialJobsMap.get(trialJobId);
        if (detail === undefined) {
            throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
        }

        // Write file content ( parameter.cfg ) to working folders
        await this.writeParameterFile(detail.logPath, form.hyperParameters);

        return detail;
    }

    /**
     * Registers a new trial in 'WAITING' state and appends it to the job queue.
     * The actual submission to OpenPAI is performed later by submitJobLoop.
     *
     * @param form application form of the trial.
     * @returns the freshly created trial detail.
     */
    public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        this.log.info('submitTrialJob: form:',  form);

        const id: string = uniqueString(5);
        const jobName: string = `nni_exp_${this.experimentId}_trial_${id}`;
        //TODO: use HDFS working folder instead
        const workingFolder: string = path.join(this.expRootDir, 'trials', id);
        const trialLogPath: string = path.join(this.config.localStorageMountPoint, this.experimentId, id);
        const detailUrl: string = `${this.config.host}/job-detail.html?username=${this.config.username}&jobName=${jobName}`;

        const detail: PAITrialJobDetail = new PAITrialJobDetail(
            id, 'WAITING', jobName, Date.now(), workingFolder, form, trialLogPath, detailUrl);

        this.trialJobsMap.set(id, detail);
        this.jobQueue.push(id);

        return detail;
    }

liuzhe-lz's avatar
liuzhe-lz committed
333
    /**
     * Fills PAI_TRIAL_COMMAND_FORMAT with the trial's values and strips every
     * line break from the result.
     *
     * @param trialJobDetail trial the command is generated for.
     * @param command user command embedded into the generated command.
     * @returns the single-line command string.
     */
    private async generateNNITrialCommand(trialJobDetail: PAITrialJobDetail, command: string): Promise<string> {
        const experimentDirInContainer: string = `${this.config.containerStorageMountPoint}/${this.experimentId}`;
        const codeDirInContainer = `${experimentDirInContainer}/nni-code`;
        const workDirInContainer: string = `${experimentDirInContainer}/${trialJobDetail.id}`;
        // Fall back to the detected IPv4 address when no manager IP is configured.
        const managerIp = this.config.nniManagerIp || await getIPV4Address();

        const formatted: string = String.Format(
            PAI_TRIAL_COMMAND_FORMAT,
            `${workDirInContainer}`,
            `${workDirInContainer}/nnioutput`,
            trialJobDetail.id,
            this.experimentId,
            trialJobDetail.form.sequenceId,
            false,  // multi-phase
            codeDirInContainer,
            command,
            managerIp,
            this.paiRestServerPort,
            this.nniVersion,
            this.logCollection
        );

        return formatted.replace(/\r\n|\n|\r/gm, '');
    }

liuzhe-lz's avatar
liuzhe-lz committed
357
    /**
     * Builds the OpenPAI job configuration for a trial and serialises it to YAML.
     *
     * When `openpaiConfig` is set, a deep copy of it is used: its name is replaced
     * and each task role's commands are joined and wrapped into one NNI trial
     * command. Otherwise a single-task-role configuration is assembled from the
     * experiment settings.
     *
     * @param trialJobDetail trial the job configuration is generated for.
     * @returns the YAML text produced by `yaml.safeDump`.
     */
    private async generateJobConfigInYamlFormat(trialJobDetail: PAITrialJobDetail): Promise<any> {
        const jobName = `nni_exp_${this.experimentId}_trial_${trialJobDetail.id}`;

        if (this.config.openpaiConfig !== undefined) {
            // JSON round trip gives a deep clone, so the shared configuration is never mutated.
            const customConfig: any = JSON.parse(JSON.stringify(this.config.openpaiConfig));
            customConfig.name = jobName;
            // Each taskRole will generate new command in NNI's command format
            for (const roleKey in customConfig.taskRoles) {
                const role: any = customConfig.taskRoles[roleKey];
                // Escape quotes, '$', backticks and backslashes before embedding the joined commands.
                const joinedCommands = role.commands.join(" && ").replace(/(["'$`\\])/g, '\\$1');
                role.commands = [await this.generateNNITrialCommand(trialJobDetail, joinedCommands)];
            }

            return yaml.safeDump(customConfig);
        }

        // Property order below is kept as-is because it determines the order of the dumped YAML.
        const defaultConfig: any = {
            protocolVersion: 2,
            name: jobName,
            type: 'job',
            jobRetryCount: 0,
            prerequisites: [
                {
                    type: 'dockerimage',
                    uri: this.config.dockerImage,
                    name: 'docker_image_0'
                }
            ],
            taskRoles: {
                taskrole: {
                    instances: 1,
                    completion: {
                        minFailedInstances: 1,
                        minSucceededInstances: -1
                    },
                    taskRetryCount: 0,
                    dockerImage: 'docker_image_0',
                    resourcePerInstance: {
                        gpu: this.config.trialGpuNumber,
                        cpu: this.config.trialCpuNumber,
                        memoryMB: toMegaBytes(this.config.trialMemorySize)
                    },
                    commands: [
                        await this.generateNNITrialCommand(trialJobDetail, this.config.trialCommand)
                    ]
                }
            },
            extras: {
                'storages': [
                    {
                        name: this.config.storageConfigName
                    }
                ],
                submitFrom: 'submit-job-v2'
            }
        };

        const legacyOptions = this.config.deprecated;
        if (legacyOptions && legacyOptions.virtualCluster) {
            defaultConfig.defaults = {
                virtualCluster: legacyOptions.virtualCluster
            };
        }

        return yaml.safeDump(defaultConfig);
    }

    /**
     * Prepares the local files of a trial and submits its job to OpenPAI.
     *
     * Steps: wait for the one-time experiment code copy, create the trial log
     * directory, write `install_nni.sh` and the parameter file, generate the YAML
     * job configuration and POST it to the REST server.
     *
     * @param trialJobId id of a trial already present in the trial map.
     * @returns a promise resolved with `true` once the job has been accepted.
     * @throws Error when the trial or the REST server is missing. The returned
     *         promise is rejected with an Error (and the trial marked 'FAILED')
     *         when the request fails or the server answers with a status >= 400.
     */
    protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);

        if (trialJobDetail === undefined) {
            throw new Error(`Failed to find PAITrialJobDetail for job ${trialJobId}`);
        }

        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer is not initialized');
        }

        // Make sure experiment code files are copied from local to NFS
        if (this.copyExpCodeDirPromise !== undefined) {
            await this.copyExpCodeDirPromise;
            this.log.info(`Copy codeDir data finished.`);
            // All trials share the same destination NFS code folder, so codeDir is copied once per experiment.
            // Clearing the promise afterwards avoids logging the message above for every trial.
            this.copyExpCodeDirPromise = undefined;
        }

        this.paiRestServerPort = this.paiJobRestServer.clusterRestServerPort;

        // Step 1. Prepare PAI job configuration
        // Create the trial working folder locally.
        await execMkdir(trialJobDetail.logPath);
        // Write NNI installation file to local files
        await fs.promises.writeFile(path.join(trialJobDetail.logPath, 'install_nni.sh'), CONTAINER_INSTALL_NNI_SHELL_FORMAT, { encoding: 'utf8' });

        // Write file content ( parameter.cfg ) to local working folders
        if (trialJobDetail.form !== undefined) {
            await this.writeParameterFile(trialJobDetail.logPath, trialJobDetail.form.hyperParameters);
        }

        // Generate Job Configuration in yaml format
        const paiJobConfig = await this.generateJobConfigInYamlFormat(trialJobDetail);
        this.log.debug(paiJobConfig);

        // Step 2. Submit PAI job via Rest call
        // Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
        const submitJobRequest: request.Options = {
            uri: `${this.config.host}/rest-server/api/v2/jobs`,
            method: 'POST',
            body: paiJobConfig,
            followAllRedirects: true,
            headers: {
                'Content-Type': 'text/yaml',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        const deferred: Deferred<boolean> = new Deferred<boolean>();
        request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
            const transportFailed: boolean = error !== undefined && error !== null;
            // If submit success, will get status code 202. refer: https://github.com/microsoft/pai/blob/master/src/rest-server/docs/swagger.yaml
            if (transportFailed || response.statusCode >= 400) {
                const errorMessage: string = transportFailed ? error.message :
                    `Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${body}`;
                this.log.error(errorMessage);
                trialJobDetail.status = 'FAILED';
                // Reject with an Error object (not a bare string) and do not resolve afterwards.
                deferred.reject(new Error(errorMessage));
                return;
            }

            trialJobDetail.submitTime = Date.now();
            deferred.resolve(true);
        });

        return deferred.promise;
    }

    /**
     * Writes the hyper-parameter value as UTF-8 text into `directory`, under the
     * file name produced by `generateParamFileName`.
     *
     * @param directory directory receiving the parameter file.
     * @param hyperParameters hyper-parameters whose `value` is written.
     */
    private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> {
        const parameterFileName = generateParamFileName(hyperParameters);
        const targetPath: string = path.join(directory, parameterFileName);
        await fs.promises.writeFile(targetPath, hyperParameters.value, 'utf8');
    }
J-shang's avatar
J-shang committed
492
493
494
495
496
497
498
499

    /**
     * Not supported by this training service.
     * The method is not async, so the error is thrown synchronously instead of
     * being delivered through the returned promise.
     *
     * @throws MethodNotImplementedError always.
     */
    public getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
        throw new MethodNotImplementedError();
    }

    /**
     * Not supported by this training service.
     * The method is not async, so the error is thrown synchronously instead of
     * being delivered through the returned promise.
     *
     * @throws MethodNotImplementedError always.
     */
    public fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
        throw new MethodNotImplementedError();
    }
500
501
}

502
export { PAITrainingService };