paiTrainingService.ts 21.1 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
3

4
'use strict';
5

SparkSnail's avatar
SparkSnail committed
6
import * as fs from 'fs';
7
8
import * as path from 'path';
import * as request from 'request';
9
import * as component from '../../common/component';
10
11

import { EventEmitter } from 'events';
12
import { Deferred } from 'ts-deferred';
13
import { getExperimentId } from '../../common/experimentStartupInfo';
14
import { getLogger, Logger } from '../../common/log';
15
import { MethodNotImplementedError } from '../../common/errors';
16
import {
SparkSnail's avatar
SparkSnail committed
17
    HyperParameters, NNIManagerIpConfig, TrainingService,
Yuge Zhang's avatar
Yuge Zhang committed
18
    TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
19
} from '../../common/trainingService';
20
import { delay } from '../../common/utils';
21
import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from '../../common/experimentConfig';
22
import { PAIJobInfoCollector } from './paiJobInfoCollector';
23
import { PAIJobRestServer } from './paiJobRestServer';
24
import { PAITrialJobDetail, PAI_TRIAL_COMMAND_FORMAT } from './paiConfig';
SparkSnail's avatar
SparkSnail committed
25
import { String } from 'typescript-string-operations';
liuzhe-lz's avatar
liuzhe-lz committed
26
import { generateParamFileName, getIPV4Address, uniqueString } from '../../common/utils';
SparkSnail's avatar
SparkSnail committed
27
28
29
30
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { execMkdir, validateCodeDir, execCopydir } from '../common/util';

const yaml = require('js-yaml');
31

32
33
/**
 * Shape of the configuration object this service stores in `this.config`:
 * the experiment-level fields of `ExperimentConfig` combined with the
 * OpenPAI-specific fields of `OpenpaiConfig` on a single object
 * (the result of `flattenConfig(config, 'openpai')` in the constructor).
 */
interface FlattenOpenpaiConfig extends ExperimentConfig, OpenpaiConfig { }

34
35
36
37
38
/**
 * Training Service implementation for OpenPAI (Open Platform for AI)
 * Refer https://github.com/Microsoft/pai for more info about OpenPAI
 */
@component.Singleton
SparkSnail's avatar
SparkSnail committed
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
class PAITrainingService implements TrainingService {
    private readonly log!: Logger;
    private readonly metricsEmitter: EventEmitter;
    private readonly trialJobsMap: Map<string, PAITrialJobDetail>;
    private readonly expRootDir: string;
    private readonly jobQueue: string[];
    private stopping: boolean = false;
    private paiToken?: string;
    private paiTokenUpdateTime?: number;
    private readonly paiTokenUpdateInterval: number;
    private readonly experimentId!: string;
    private readonly paiJobCollector: PAIJobInfoCollector;
    private paiRestServerPort?: number;
    private nniManagerIpConfig?: NNIManagerIpConfig;
    private versionCheck: boolean = true;
54
    private logCollection: string = 'none';
SparkSnail's avatar
SparkSnail committed
55
    private paiJobRestServer?: PAIJobRestServer;
56
    private protocol: string;
SparkSnail's avatar
SparkSnail committed
57
58
59
    private copyExpCodeDirPromise?: Promise<void>;
    private paiJobConfig: any;
    private nniVersion: string | undefined;
60
    private config: FlattenOpenpaiConfig;
61

62
    /**
     * Create the OpenPAI training service for the current experiment.
     *
     * Sets up the in-memory bookkeeping (metrics emitter, trial map, job queue,
     * job info collector), stores the result of `flattenConfig(config, 'openpai')`,
     * creates the job REST server object, and starts copying the trial code
     * directory. The copy is not awaited here; `submitTrialJobToPAI` awaits it
     * before the first submission.
     *
     * @param config Experiment configuration passed to `flattenConfig`.
     */
    constructor(config: ExperimentConfig) {
        this.log = getLogger('PAITrainingService');

        // In-memory state shared by the submit and status loops.
        this.metricsEmitter = new EventEmitter();
        this.trialJobsMap = new Map<string, PAITrialJobDetail>();
        this.jobQueue = [];

        this.expRootDir = path.join('/nni-experiments', getExperimentId());
        this.experimentId = getExperimentId();
        this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);
        this.paiTokenUpdateInterval = 2 * 60 * 60 * 1000; // 2 hours, in milliseconds
        this.log.info('Construct paiBase training service.');

        this.config = flattenConfig(config, 'openpai');
        this.paiJobRestServer = new PAIJobRestServer(this);
        this.paiToken = this.config.token;

        // The protocol is derived from the scheme of the configured host; anything but https is treated as http.
        const hostIsHttps: boolean = this.config.host.toLowerCase().startsWith('https://');
        this.protocol = hostIsHttps ? 'https' : 'http';

        this.copyExpCodeDirPromise = this.copyTrialCode();
    }

    /**
     * Validate the trial code directory and copy it into the experiment's
     * shared-storage folder, `<localStorageMountPoint>/<experimentId>/nni-code`.
     *
     * The destination must live under `localStorageMountPoint`: trial containers
     * look for the code at `<containerStorageMountPoint>/<experimentId>/nni-code`
     * (see `generateNNITrialCommand`), and the per-trial folders are created under
     * `localStorageMountPoint` as well (see `submitTrialJob`). Building the
     * destination from `trialCodeDirectory` would copy the directory into a
     * sub-folder of itself and leave nothing for the containers to read.
     *
     * @throws Whatever `validateCodeDir`, `execMkdir` or `execCopydir` reject with.
     */
    private async copyTrialCode(): Promise<void> {
        const codeDir: string = this.config.trialCodeDirectory;
        await validateCodeDir(codeDir);

        const nniManagerNFSExpCodeDir = path.join(this.config.localStorageMountPoint, this.experimentId, 'nni-code');
        await execMkdir(nniManagerNFSExpCodeDir);

        this.log.info(`Starting copy codeDir data from ${codeDir} to ${nniManagerNFSExpCodeDir}`);
        await execCopydir(codeDir, nniManagerNFSExpCodeDir);
    }

    /**
     * Main entry point of the training service.
     *
     * Starts the job REST server, forwards the version-check flag to it, then
     * runs the status-checking loop and the job-submission loop concurrently.
     * Resolves once both loops have finished.
     *
     * @throws Error if the REST server object was never created, or whatever
     *         either loop rejects with.
     */
    public async run(): Promise<void> {
        this.log.info('Run PAI training service.');

        const restServer = this.paiJobRestServer;
        if (restServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }

        await restServer.start();
        restServer.setEnableVersionCheck = this.versionCheck;
        this.log.info(`PAI Training service rest server listening on: ${restServer.endPoint}`);

        const loops: Promise<void>[] = [this.statusCheckingLoop(), this.submitJobLoop()];
        await Promise.all(loops);

        this.log.info('PAI training service exit.');
    }

101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
    /**
     * Drain the pending-trial queue until the service is asked to stop.
     *
     * On each pass the head of `jobQueue` is submitted; it is removed from the
     * queue only when `submitTrialJobToPAI` resolves to `true`. A `false` result
     * ends the current pass and the same job is retried after `delay(3000)`.
     */
    protected async submitJobLoop(): Promise<void> {
        while (!this.stopping) {
            while (!this.stopping && this.jobQueue.length > 0) {
                const submitted: boolean = await this.submitTrialJobToPAI(this.jobQueue[0]);
                if (!submitted) {
                    // Submission did not go through; leave the job at the head and retry on the next pass.
                    break;
                }
                // Submission accepted: drop the job from the queue.
                this.jobQueue.shift();
            }
            await delay(3000);
        }
    }

117
118
    /**
     * Return the details of every trial job currently tracked in `trialJobsMap`.
     *
     * @returns One `TrialJobDetail` per tracked trial, in map iteration order.
     */
    public async listTrialJobs(): Promise<TrialJobDetail[]> {
        const details: TrialJobDetail[] = [];
        for (const trialJobId of this.trialJobsMap.keys()) {
            const detail: TrialJobDetail = await this.getTrialJob(trialJobId);
            details.push(detail);
        }
        return details;
    }

Yuge Zhang's avatar
Yuge Zhang committed
127
    /**
     * Not supported by this training service.
     *
     * @returns A promise that always rejects with `MethodNotImplementedError`.
     */
    public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
        return Promise.reject(new MethodNotImplementedError());
    }

131
    /**
     * Look up a single trial job by id.
     *
     * @param trialJobId Id of the trial to fetch.
     * @returns The tracked detail object for that trial.
     * @throws Error if no trial with this id is tracked.
     */
    public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
        const detail = this.trialJobsMap.get(trialJobId);
        if (detail !== undefined) {
            return detail;
        }
        throw new Error(`trial job ${trialJobId} not found`);
    }

141
    /**
     * Subscribe a callback to the 'metric' event of the metrics emitter.
     *
     * @param listener Callback invoked with each emitted `TrialJobMetric`.
     */
    public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        // `addListener` is the long-form alias of `on`.
        this.metricsEmitter.addListener('metric', listener);
    }

145
    /**
     * Unsubscribe a callback previously registered for the 'metric' event.
     *
     * @param listener The same function reference that was registered.
     */
    public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        // `removeListener` is the long-form alias of `off`.
        this.metricsEmitter.removeListener('metric', listener);
    }

QuanluZhang's avatar
QuanluZhang committed
149
    /**
     * Ask the OpenPAI cluster to stop a trial job.
     *
     * A trial whose status is still 'UNKNOWN' is marked 'USER_CANCELED' locally
     * and no REST call is made. Otherwise a PUT with `{ value: 'STOP' }` is sent
     * to the job's `executionType` endpoint.
     *
     * @param trialJobId Id of the trial to cancel.
     * @param isEarlyStopped Recorded on the trial detail to mark where the
     *        cancellation came from; defaults to `false`.
     * @returns A promise that resolves once the cluster accepts the stop request.
     *          It rejects with an `Error` when the trial is unknown, when the
     *          request fails at transport level, or when the cluster answers with
     *          an HTTP status >= 400. Rejections are always `Error` instances so
     *          callers can rely on `.message`.
     */
    public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
        if (trialJobDetail === undefined) {
            return Promise.reject(new Error(`cancelTrialJob: trial job id ${trialJobId} not found`));
        }

        if (trialJobDetail.status === 'UNKNOWN') {
            trialJobDetail.status = 'USER_CANCELED';
            return Promise.resolve();
        }

        const stopJobRequest: request.Options = {
            uri: `${this.config.host}/rest-server/api/v2/jobs/${this.config.username}~${trialJobDetail.paiJobName}/executionType`,
            method: 'PUT',
            json: true,
            body: { value: 'STOP' },
            headers: {
                'Content-Type': 'application/json',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        // Set trialjobDetail's early stopped field, to mark the job's cancellation source
        trialJobDetail.isEarlyStopped = isEarlyStopped;
        const deferred: Deferred<void> = new Deferred<void>();

        request(stopJobRequest, (error: Error, response: request.Response, _body: any) => {
            // Status code 202 for success.
            const hasTransportError: boolean = error !== undefined && error !== null;
            if (hasTransportError || response.statusCode >= 400) {
                this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
                deferred.reject(hasTransportError
                    ? error
                    : new Error(`Stop trial failed, http code: ${response.statusCode}`));
            } else {
                deferred.resolve();
            }
        });

        return deferred.promise;
    }

    /**
     * Shut the training service down.
     *
     * Raises the `stopping` flag so both background loops exit, then stops the
     * job REST server. A failure while stopping the server is logged and not
     * re-thrown.
     *
     * @throws Error if the REST server object was never created.
     */
    public async cleanUp(): Promise<void> {
        this.log.info('Stopping PAI training service...');
        this.stopping = true;

        const restServer = this.paiJobRestServer;
        if (restServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }

        try {
            await restServer.stop();
            this.log.info('PAI Training service rest server stopped successfully.');
        } catch (error) {
            this.log.error(`PAI Training service rest server stopped failed, error: ${error.message}`);
        }
    }

chicm-ms's avatar
chicm-ms committed
205
    /**
     * The emitter on which 'metric' events are published; listeners added via
     * `addTrialJobMetricListener` are attached to this same object.
     */
    public get MetricsEmitter(): EventEmitter {
        const { metricsEmitter } = this;
        return metricsEmitter;
    }
208

SparkSnail's avatar
SparkSnail committed
209
210
211
    /**
     * Strip a leading 'http://' or 'https://' from a host string.
     *
     * When one of those prefixes is present, `this.protocol` is set to the
     * matching scheme and the remainder of the string is returned. A host
     * without either prefix is returned unchanged and `this.protocol` is left
     * as it was.
     *
     * @param host Host string as supplied by the user.
     * @returns The host without its scheme prefix.
     */
    protected formatPAIHost(host: string): string {
        for (const scheme of ['http', 'https']) {
            const prefix: string = `${scheme}://`;
            if (host.startsWith(prefix)) {
                this.protocol = scheme;
                return host.slice(prefix.length);
            }
        }
        return host;
    }

223
    /**
     * Periodically refresh trial statuses until the service is asked to stop.
     *
     * Each iteration:
     *  1. refreshes the PAI token when a deprecated password is configured
     *     (a refresh failure is logged and does not end the loop);
     *  2. asks the job info collector to retrieve trial statuses;
     *  3. throws if the REST server is missing or reports an error message;
     *  4. waits `delay(3000)` before the next iteration.
     *
     * @throws Error when the REST server object is undefined or exposes an error message.
     */
    protected async statusCheckingLoop(): Promise<void> {
        while (!this.stopping) {
            const legacyOptions = this.config.deprecated;
            if (legacyOptions && legacyOptions.password) {
                try {
                    await this.updatePaiToken();
                } catch (error) {
                    this.log.error(`${error}`);
                }
            }

            await this.paiJobCollector.retrieveTrialStatus(this.protocol, this.paiToken, this.config);

            const restServer = this.paiJobRestServer;
            if (restServer === undefined) {
                throw new Error('paiBaseJobRestServer not implemented!');
            }

            const restServerError = restServer.getErrorMessage;
            if (restServerError !== undefined) {
                throw new Error(restServerError);
            }

            await delay(3000);
        }
    }

243
244
245
    /**
     * Initialize the PAI token, or refresh it once the update interval has elapsed.
     *
     * Does nothing when a token was obtained less than `paiTokenUpdateInterval`
     * milliseconds ago. Otherwise POSTs the configured username and deprecated
     * password to `/rest-server/api/v1/token` and, only on an HTTP 200 answer,
     * stores the returned token and the time it was obtained.
     *
     * A non-200 answer or a transport error rejects and leaves the current token
     * and timestamp untouched, so the next call retries instead of waiting out
     * the interval with an invalid token.
     *
     * @throws Error when the request fails, the server answers with a status
     *         other than 200, or no answer arrives within 5 seconds.
     */
    protected async updatePaiToken(): Promise<void> {
        const currentTime: number = new Date().getTime();
        // If pai token initialized and not reach the interval time, do not update
        if (this.paiTokenUpdateTime !== undefined && (currentTime - this.paiTokenUpdateTime) < this.paiTokenUpdateInterval) {
            return;
        }

        const deferred: Deferred<void> = new Deferred<void>();
        const authenticationReq: request.Options = {
            uri: `${this.config.host}/rest-server/api/v1/token`,
            method: 'POST',
            json: true,
            body: {
                username: this.config.username,
                password: this.config.deprecated.password
            }
        };

        request(authenticationReq, (error: Error, response: request.Response, body: any) => {
            if (error !== undefined && error !== null) {
                this.log.error(`Get PAI token failed: ${error.message}`);
                deferred.reject(new Error(`Get PAI token failed: ${error.message}`));
                return;
            }

            if (response.statusCode !== 200) {
                this.log.error(`Get PAI token failed: get PAI Rest return code ${response.statusCode}`);
                // With `json: true` the body may already be a parsed object; serialize it so the message stays readable.
                const responseText: string = typeof response.body === 'string'
                    ? response.body
                    : JSON.stringify(response.body);
                deferred.reject(new Error(`Get PAI token failed: ${responseText}, please check paiConfig username or password`));
                // Stop here: a failed login must not overwrite the token or its timestamp.
                return;
            }

            this.paiToken = body.token;
            this.paiTokenUpdateTime = new Date().getTime();
            deferred.resolve();
        });

        let timeoutId: ReturnType<typeof setTimeout> | undefined;
        const timeoutDelay: Promise<void> = new Promise<void>((_resolve, reject): void => {
            // Set timeout and reject the promise once reach timeout (5 seconds)
            timeoutId = setTimeout(
                () => reject(new Error('Get PAI token timeout. Please check your PAI cluster.')),
                5000);
        });

        return Promise.race([timeoutDelay, deferred.promise])
            .finally(() => {
                if (timeoutId !== undefined) {
                    clearTimeout(timeoutId);
                }
            });
    }
SparkSnail's avatar
SparkSnail committed
291

292
293
    /**
     * No-op in this training service: both arguments are ignored and the
     * returned promise resolves immediately.
     */
    public async setClusterMetadata(_key: string, _value: string): Promise<void> {
        // Intentionally empty.
    }
    /**
     * Always resolves to the empty string; the key is ignored.
     */
    public async getClusterMetadata(_key: string): Promise<string> {
        const noMetadata: string = '';
        return noMetadata;
    }
SparkSnail's avatar
SparkSnail committed
294
295
296
297
298
299
300
301
302
303
304
305
306
307

    /**
     * Update trial parameters for multi-phase: writes the hyper-parameters of
     * `form` as a parameter file into the trial's existing log folder.
     *
     * @param trialJobId Id of the trial to update.
     * @param form Application form carrying the new hyper-parameters.
     * @returns The tracked detail object of the trial.
     * @throws Error if no trial with this id is tracked.
     */
    public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        const detail = this.trialJobsMap.get(trialJobId);
        if (detail !== undefined) {
            // Write file content ( parameter.cfg ) to working folders
            await this.writeParameterFile(detail.logPath, form.hyperParameters);
            return detail;
        }
        throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
    }

    /**
     * Register a new trial and queue it for submission.
     *
     * Only local bookkeeping happens here: a fresh 5-character id is generated,
     * a `PAITrialJobDetail` in state 'WAITING' is stored in `trialJobsMap`, and
     * the id is appended to `jobQueue`. The actual submission to the cluster is
     * done by `submitJobLoop`.
     *
     * @param form Application form describing the trial.
     * @returns The newly created trial detail.
     */
    public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        this.log.info('submitTrialJob: form:',  form);

        const newTrialId: string = uniqueString(5);
        const jobName: string = `nni_exp_${this.experimentId}_trial_${newTrialId}`;

        const detail: PAITrialJobDetail = new PAITrialJobDetail(
            newTrialId,
            'WAITING',
            jobName,
            Date.now(),
            //TODO: use HDFS working folder instead
            path.join(this.expRootDir, 'trials', newTrialId),
            form,
            // Per-trial folder on the locally mounted shared storage.
            path.join(this.config.localStorageMountPoint, this.experimentId, newTrialId),
            `${this.config.host}/job-detail.html?username=${this.config.username}&jobName=${jobName}`);

        this.trialJobsMap.set(newTrialId, detail);
        this.jobQueue.push(newTrialId);

        return detail;
    }

liuzhe-lz's avatar
liuzhe-lz committed
332
    /**
     * Wrap a user command into the single-line shell command that a trial
     * container runs, by filling `PAI_TRIAL_COMMAND_FORMAT`.
     *
     * Paths are expressed relative to the storage mount point inside the
     * container. The NNI manager address is the configured `nniManagerIp` or,
     * when that is not set, the result of `getIPV4Address()`.
     *
     * @param trialJobDetail Trial the command is generated for.
     * @param command User command to embed.
     * @returns The formatted command with all line breaks removed.
     */
    private async generateNNITrialCommand(trialJobDetail: PAITrialJobDetail, command: string): Promise<string> {
        const experimentDirInContainer: string = `${this.config.containerStorageMountPoint}/${this.experimentId}`;
        const codeDirInContainer = `${experimentDirInContainer}/nni-code`;
        const workingDirInContainer: string = `${experimentDirInContainer}/${trialJobDetail.id}`;
        const managerIp = this.config.nniManagerIp || await getIPV4Address();

        const formatted: string = String.Format(
            PAI_TRIAL_COMMAND_FORMAT,
            workingDirInContainer,
            `${workingDirInContainer}/nnioutput`,
            trialJobDetail.id,
            this.experimentId,
            trialJobDetail.form.sequenceId,
            false,  // multi-phase
            codeDirInContainer,
            command,
            managerIp,
            this.paiRestServerPort,
            this.nniVersion,
            this.logCollection
        );

        // Collapse the template onto one line.
        return formatted.replace(/\r\n|\n|\r/gm, '');
    }

liuzhe-lz's avatar
liuzhe-lz committed
356
    /**
     * Build the OpenPAI job description for a trial and serialize it to YAML.
     *
     * Two sources are supported:
     *  - no `openpaiConfig` in the configuration: a default protocol-version-2
     *    job with a single task role is assembled from the docker image,
     *    resource and storage settings, plus the deprecated virtual cluster
     *    when one is configured;
     *  - a user-supplied `openpaiConfig`: it is deep-cloned, renamed to this
     *    trial's job name, and the commands of every task role are joined with
     *    " && ", shell-escaped and replaced by one NNI-formatted command.
     *
     * @param trialJobDetail Trial the job description is generated for.
     * @returns The result of `yaml.safeDump` on the job description.
     */
    private async generateJobConfigInYamlFormat(trialJobDetail: PAITrialJobDetail): Promise<any> {
        const paiJobName = `nni_exp_${this.experimentId}_trial_${trialJobDetail.id}`
        let jobConfig: any = undefined;

        if (this.config.openpaiConfig === undefined) {
            const trialCommand: string = await this.generateNNITrialCommand(trialJobDetail, this.config.trialCommand);

            jobConfig = {
                protocolVersion: 2,
                name: paiJobName,
                type: 'job',
                jobRetryCount: 0,
                prerequisites: [
                    {
                        type: 'dockerimage',
                        uri: this.config.dockerImage,
                        name: 'docker_image_0'
                    }
                ],
                taskRoles: {
                    taskrole: {
                        instances: 1,
                        completion: {
                            minFailedInstances: 1,
                            minSucceededInstances: -1
                        },
                        taskRetryCount: 0,
                        dockerImage: 'docker_image_0',
                        resourcePerInstance: {
                            gpu: this.config.trialGpuNumber,
                            cpu: this.config.trialCpuNumber,
                            memoryMB: toMegaBytes(this.config.trialMemorySize)
                        },
                        commands: [trialCommand]
                    }
                },
                extras: {
                    'storages': [
                        {
                            name: this.config.storageConfigName
                        }
                    ],
                    submitFrom: 'submit-job-v2'
                }
            };

            const legacyOptions = this.config.deprecated;
            if (legacyOptions && legacyOptions.virtualCluster) {
                jobConfig.defaults = {
                    virtualCluster: legacyOptions.virtualCluster
                };
            }
        } else {
            // JSON round-trip gives a deep clone, so the shared configuration is never mutated.
            jobConfig = JSON.parse(JSON.stringify(this.config.openpaiConfig));
            jobConfig.name = paiJobName;

            // Each taskRole gets a single command, rewritten into NNI's command format.
            for (const roleName in jobConfig.taskRoles) {
                const userCommands = jobConfig.taskRoles[roleName].commands;
                const escapedCommand = userCommands.join(" && ").replace(/(["'$`\\])/g, '\\$1');
                const wrappedCommand = await this.generateNNITrialCommand(trialJobDetail, escapedCommand);
                jobConfig.taskRoles[roleName].commands = [wrappedCommand];
            }
        }

        return yaml.safeDump(jobConfig);
    }

    /**
     * Prepare the local trial folder and submit one queued trial to OpenPAI.
     *
     * Steps:
     *  1. wait (once per experiment) for the trial code copy to finish;
     *  2. create the trial's local folder and write `install_nni.sh` and the
     *     parameter file into it;
     *  3. generate the job description in YAML and POST it to
     *     `/rest-server/api/v2/jobs`.
     *
     * @param trialJobId Id of a trial present in `trialJobsMap`.
     * @returns A promise resolving to `true` once the cluster accepts the job.
     *          On a transport error or an HTTP status >= 400 the trial is marked
     *          'FAILED' and the promise rejects with an `Error` carrying the
     *          logged message; it is never resolved on the failure path.
     * @throws Error if the trial is not tracked or the REST server object is missing.
     */
    protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);

        if (trialJobDetail === undefined) {
            throw new Error(`Failed to find PAITrialJobDetail for job ${trialJobId}`);
        }

        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer is not initialized');
        }

        // Make sure experiment code files is copied from local to NFS
        if (this.copyExpCodeDirPromise !== undefined) {
            await this.copyExpCodeDirPromise;
            this.log.info(`Copy codeDir data finished.`);
            // All trials share same destination NFS code folder, only copy codeDir once for an experiment.
            // After copy data finished, set copyExpCodeDirPromise be undefined to avoid log content duplicated.
            this.copyExpCodeDirPromise = undefined;
        }

        this.paiRestServerPort = this.paiJobRestServer.clusterRestServerPort;

        // Step 1. Prepare PAI job configuration
        // Create the trial's working folder locally.
        await execMkdir(trialJobDetail.logPath);
        // Write NNI installation file to local files
        await fs.promises.writeFile(path.join(trialJobDetail.logPath, 'install_nni.sh'), CONTAINER_INSTALL_NNI_SHELL_FORMAT, { encoding: 'utf8' });

        // Write file content ( parameter.cfg ) to local working folders
        if (trialJobDetail.form !== undefined) {
            await this.writeParameterFile(trialJobDetail.logPath, trialJobDetail.form.hyperParameters);
        }

        // Generate Job Configuration in yaml format
        const paiJobConfig = await this.generateJobConfigInYamlFormat(trialJobDetail);
        this.log.debug(paiJobConfig);

        // Step 2. Submit PAI job via Rest call
        // Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
        const submitJobRequest: request.Options = {
            uri: `${this.config.host}/rest-server/api/v2/jobs`,
            method: 'POST',
            body: paiJobConfig,
            followAllRedirects: true,
            headers: {
                'Content-Type': 'text/yaml',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        const deferred: Deferred<boolean> = new Deferred<boolean>();
        request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
            // If submit success, will get status code 202. refer: https://github.com/microsoft/pai/blob/master/src/rest-server/docs/swagger.yaml
            const hasTransportError: boolean = error !== undefined && error !== null;
            if (hasTransportError || response.statusCode >= 400) {
                const errorMessage: string = hasTransportError ? error.message :
                    `Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${body}`;
                this.log.error(errorMessage);
                trialJobDetail.status = 'FAILED';
                deferred.reject(new Error(errorMessage));
                return;
            }

            trialJobDetail.submitTime = Date.now();
            deferred.resolve(true);
        });

        return deferred.promise;
    }

    /**
     * Write a trial's hyper-parameter payload to a file inside `directory`.
     * The file name comes from `generateParamFileName(hyperParameters)` and the
     * content is `hyperParameters.value`, written as UTF-8.
     *
     * @param directory Folder that receives the parameter file.
     * @param hyperParameters Hyper-parameters to persist.
     */
    private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> {
        const paramFileName: string = generateParamFileName(hyperParameters);
        const target: string = path.join(directory, paramFileName);
        await fs.promises.writeFile(target, hyperParameters.value, { encoding: 'utf8' });
    }
J-shang's avatar
J-shang committed
491
492
493
494
495
496
497
498

    /**
     * Not supported by this training service.
     *
     * Declared `async` so the failure is delivered as a rejected promise, the
     * same way `getTrialFile` reports it, instead of a synchronous throw from a
     * Promise-returning method.
     *
     * @returns A promise that always rejects with `MethodNotImplementedError`.
     */
    public async getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
        throw new MethodNotImplementedError();
    }

    /**
     * Not supported by this training service.
     *
     * Declared `async` so the failure is delivered as a rejected promise, the
     * same way `getTrialFile` reports it, instead of a synchronous throw from a
     * Promise-returning method.
     *
     * @returns A promise that always rejects with `MethodNotImplementedError`.
     */
    public async fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
        throw new MethodNotImplementedError();
    }
499
500
}

501
export { PAITrainingService };