"src/lib/i18n/locales/nl-NL/translation.json" did not exist on "66342140a397d8966ea8f7f08e6c43316d4ab5e2"
paiTrainingService.ts 21 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
3

4
5
6
7
import fs from 'fs';
import path from 'path';
import request from 'request';
import * as component from 'common/component';
8
9

import { EventEmitter } from 'events';
10
import { Deferred } from 'ts-deferred';
11
12
13
import { getExperimentId } from 'common/experimentStartupInfo';
import { getLogger, Logger } from 'common/log';
import { MethodNotImplementedError } from 'common/errors';
14
import {
SparkSnail's avatar
SparkSnail committed
15
    HyperParameters, NNIManagerIpConfig, TrainingService,
Yuge Zhang's avatar
Yuge Zhang committed
16
    TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
17
18
19
} from 'common/trainingService';
import { delay } from 'common/utils';
import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from 'common/experimentConfig';
20
import { PAIJobInfoCollector } from './paiJobInfoCollector';
21
import { PAIJobRestServer } from './paiJobRestServer';
22
import { PAITrialJobDetail, PAI_TRIAL_COMMAND_FORMAT } from './paiConfig';
SparkSnail's avatar
SparkSnail committed
23
import { String } from 'typescript-string-operations';
24
import { generateParamFileName, getIPV4Address, uniqueString } from 'common/utils';
SparkSnail's avatar
SparkSnail committed
25
26
27
28
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { execMkdir, validateCodeDir, execCopydir } from '../common/util';

const yaml = require('js-yaml');
29

30
31
/** Experiment config combined with the OpenPAI-specific training-service fields (result type of `flattenConfig(config, 'openpai')`). */
type FlattenOpenpaiConfig = ExperimentConfig & OpenpaiConfig;

32
33
34
35
36
/**
 * Training Service implementation for OpenPAI (Open Platform for AI)
 * Refer to https://github.com/Microsoft/pai for more information about OpenPAI.
 */
@component.Singleton
SparkSnail's avatar
SparkSnail committed
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
class PAITrainingService implements TrainingService {
    // --- collaborators and bookkeeping ---
    private readonly log!: Logger;
    private readonly metricsEmitter: EventEmitter;
    private readonly trialJobsMap: Map<string, PAITrialJobDetail>;
    private readonly paiJobCollector: PAIJobInfoCollector;
    private paiJobRestServer?: PAIJobRestServer;

    // --- experiment identity and paths ---
    private readonly experimentId!: string;
    private readonly expRootDir: string;
    private config: FlattenOpenpaiConfig;

    // --- submission state ---
    private readonly jobQueue: string[];
    private stopping: boolean = false;
    private copyExpCodeDirPromise?: Promise<void>;
    private paiRestServerPort?: number;

    // --- PAI authentication ---
    private paiToken?: string;
    private paiTokenUpdateTime?: number;
    private readonly paiTokenUpdateInterval: number;
    private protocol: string;

    // --- trial command options ---
    private nniManagerIpConfig?: NNIManagerIpConfig;
    private versionCheck: boolean = true;
    private logCollection: string = 'none';
    private paiJobConfig: any;
    private nniVersion: string | undefined;

    /**
     * Build the OpenPAI training service from the experiment configuration.
     *
     * Flattens the config for the 'openpai' platform, creates the job REST server, and
     * starts copying the trial code directory to the local storage mount point.
     */
    constructor(config: ExperimentConfig) {
        this.log = getLogger('PAITrainingService');
        this.metricsEmitter = new EventEmitter();
        this.trialJobsMap = new Map<string, PAITrialJobDetail>();
        this.jobQueue = [];

        const experimentId: string = getExperimentId();
        this.expRootDir = path.join('/nni-experiments', experimentId);
        this.experimentId = experimentId;

        this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);
        this.paiTokenUpdateInterval = 2 * 60 * 60 * 1000;  // 2 hours, in milliseconds
        this.log.info('Construct paiBase training service.');

        this.config = flattenConfig(config, 'openpai');
        // Version checking is turned off whenever the config runs in debug mode.
        this.versionCheck = !this.config.debug;
        this.paiJobRestServer = new PAIJobRestServer(this);
        this.paiToken = this.config.token;

        const isHttps: boolean = this.config.host.toLowerCase().startsWith('https://');
        this.protocol = isHttps ? 'https' : 'http';

        // Start the code copy right away; submitTrialJobToPAI awaits it before the first submission.
        this.copyExpCodeDirPromise = this.copyTrialCode();
    }

    /**
     * Validate the trial code directory, then copy it once into the shared
     * `<localStorageMountPoint>/<experimentId>/nni-code` folder used by every trial.
     */
    private async copyTrialCode(): Promise<void> {
        const sourceDir: string = this.config.trialCodeDirectory;
        await validateCodeDir(sourceDir);

        const targetDir: string = path.join(this.config.localStorageMountPoint, this.experimentId, 'nni-code');
        await execMkdir(targetDir);

        this.log.info(`Starting copy codeDir data from ${sourceDir} to ${targetDir}`);
        await execCopydir(sourceDir, targetDir);
    }

    /**
     * Start the job REST server, then run the status-polling and submission loops
     * until both finish (i.e. until cleanUp sets the stopping flag).
     *
     * @throws Error if the job REST server was never created
     */
    public async run(): Promise<void> {
        this.log.info('Run PAI training service.');

        const restServer = this.paiJobRestServer;
        if (restServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }

        await restServer.start();
        restServer.setEnableVersionCheck = this.versionCheck;
        this.log.info(`PAI Training service rest server listening on: ${restServer.endPoint}`);

        const loops: Promise<void>[] = [this.statusCheckingLoop(), this.submitJobLoop()];
        await Promise.all(loops);

        this.log.info('PAI training service exit.');
    }

100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
    /**
     * Every 3 seconds, submit queued trials to PAI in FIFO order.
     * A trial stays at the head of the queue until submitTrialJobToPAI reports success,
     * so a trial that is not accepted is retried on the next pass.
     */
    protected async submitJobLoop(): Promise<void> {
        while (!this.stopping) {
            let accepted = true;
            while (accepted && !this.stopping && this.jobQueue.length > 0) {
                accepted = await this.submitTrialJobToPAI(this.jobQueue[0]);
                if (accepted) {
                    // Remove trial job with trialJobId from job queue
                    this.jobQueue.shift();
                }
            }
            await delay(3000);
        }
    }

116
117
    /**
     * Return the detail of every known trial, in insertion order of the trial map.
     */
    public async listTrialJobs(): Promise<TrialJobDetail[]> {
        const collected: TrialJobDetail[] = [];
        const idIterator = this.trialJobsMap.keys();
        for (let step = idIterator.next(); !step.done; step = idIterator.next()) {
            collected.push(await this.getTrialJob(step.value));
        }
        return collected;
    }

Yuge Zhang's avatar
Yuge Zhang committed
126
    /** Not supported by the OpenPAI training service; always rejects with MethodNotImplementedError. */
    public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
        return Promise.reject(new MethodNotImplementedError());
    }

130
    /**
     * Look up a single trial by id.
     *
     * @throws Error if no trial with this id has been submitted
     */
    public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
        const detail = this.trialJobsMap.get(trialJobId);
        if (detail !== undefined) {
            return detail;
        }
        throw new Error(`trial job ${trialJobId} not found`);
    }

140
    /** Subscribe a listener to the 'metric' events of this service's metrics emitter. */
    public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.metricsEmitter.addListener('metric', listener);
    }

144
    /** Unsubscribe a listener previously added with addTrialJobMetricListener. */
    public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.metricsEmitter.removeListener('metric', listener);
    }

QuanluZhang's avatar
QuanluZhang committed
148
    /**
     * Stop a trial's job on the PAI cluster.
     *
     * A trial whose status is 'UNKNOWN' is only marked 'USER_CANCELED' locally and no request is sent.
     * Otherwise the job's executionType is set to STOP through the PAI v2 REST API.
     *
     * @param trialJobId id of the trial to cancel
     * @param isEarlyStopped recorded on the trial detail to mark where the cancellation came from
     * @returns resolves once PAI accepts the stop request; rejects with an Error when the trial is
     *          unknown, the request itself fails, or PAI answers with an HTTP status >= 400
     */
    public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
        if (trialJobDetail === undefined) {
            return Promise.reject(new Error(`cancelTrialJob: trial job id ${trialJobId} not found`));
        }

        if (trialJobDetail.status === 'UNKNOWN') {
            trialJobDetail.status = 'USER_CANCELED';
            return Promise.resolve();
        }

        const stopJobRequest: request.Options = {
            uri: `${this.config.host}/rest-server/api/v2/jobs/${this.config.username}~${trialJobDetail.paiJobName}/executionType`,
            method: 'PUT',
            json: true,
            body: { value: 'STOP' },
            headers: {
                'Content-Type': 'application/json',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        // Set trialjobDetail's early stopped field, to mark the job's cancellation source
        trialJobDetail.isEarlyStopped = isEarlyStopped;
        const deferred: Deferred<void> = new Deferred<void>();

        request(stopJobRequest, (error: Error, response: request.Response, _body: any) => {
            // Status code 202 for success.
            if (error !== undefined && error !== null) {
                this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
                // Reject with a real Error (not just its message string) so callers can use .message/.stack.
                deferred.reject(error);
            } else if (response.statusCode >= 400) {
                this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
                deferred.reject(new Error(`Stop trial failed, http code: ${response.statusCode}`));
            } else {
                deferred.resolve();
            }
        });

        return deferred.promise;
    }

    /**
     * Signal the polling/submission loops to stop, then stop the job REST server.
     * A failure while stopping the REST server is logged, not rethrown.
     *
     * @throws Error if the job REST server was never created
     */
    public async cleanUp(): Promise<void> {
        this.log.info('Stopping PAI training service...');
        this.stopping = true;

        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }

        try {
            await this.paiJobRestServer.stop();
            this.log.info('PAI Training service rest server stopped successfully.');
        } catch (error) {
            // The rejection value is not guaranteed to be an Error; avoid logging "undefined".
            const reason: string = error instanceof Error ? error.message : String(error);
            this.log.error(`PAI Training service rest server stopped failed, error: ${reason}`);
        }
    }

chicm-ms's avatar
chicm-ms committed
204
    /** Emitter on which trial metrics are published under the 'metric' event. */
    public get MetricsEmitter(): EventEmitter {
        const emitter: EventEmitter = this.metricsEmitter;
        return emitter;
    }
207

SparkSnail's avatar
SparkSnail committed
208
209
210
    /**
     * Strip a leading 'http://' or 'https://' from the host and remember the scheme in
     * `this.protocol`. A host without either prefix is returned unchanged and the protocol is left as is.
     */
    protected formatPAIHost(host: string): string {
        const schemes: Array<[string, string]> = [['http://', 'http'], ['https://', 'https']];
        for (const [prefix, scheme] of schemes) {
            if (host.startsWith(prefix)) {
                this.protocol = scheme;
                return host.slice(prefix.length);
            }
        }
        return host;
    }

222
    /**
     * Every 3 seconds: refresh the PAI token (only for password-based configs), pull trial
     * statuses from PAI, and abort if the job REST server has recorded an error.
     *
     * @throws Error if the job REST server is missing or reports an error message
     */
    protected async statusCheckingLoop(): Promise<void> {
        while (!this.stopping) {
            if (this.config.deprecated?.password) {
                try {
                    await this.updatePaiToken();
                } catch (error) {
                    // A failed refresh is logged; polling continues with the current token.
                    this.log.error(`${error}`);
                }
            }

            await this.paiJobCollector.retrieveTrialStatus(this.protocol, this.paiToken, this.config);

            const restServer = this.paiJobRestServer;
            if (restServer === undefined) {
                throw new Error('paiBaseJobRestServer not implemented!');
            }
            const restServerError = restServer.getErrorMessage;
            if (restServerError !== undefined) {
                throw new Error(restServerError);
            }

            await delay(3000);
        }
    }

242
243
244
    /**
     * Update pai token by the interval time or initialize the pai token.
     *
     * Does nothing if a token was obtained less than `paiTokenUpdateInterval` ms ago. Otherwise it
     * POSTs the configured username/password to the PAI v1 token API. The token and refresh time are
     * stored only on an HTTP 200 answer, so a failed attempt is retried on the next call.
     *
     * @throws Error if the request fails, PAI answers with a non-200 status, or no answer arrives within 5 seconds
     */
    protected async updatePaiToken(): Promise<void> {
        const currentTime: number = Date.now();
        // If pai token initialized and not reach the interval time, do not update
        if (this.paiTokenUpdateTime !== undefined && (currentTime - this.paiTokenUpdateTime) < this.paiTokenUpdateInterval) {
            return;
        }

        const authenticationReq: request.Options = {
            uri: `${this.config.host}/rest-server/api/v1/token`,
            method: 'POST',
            json: true,
            body: {
                username: this.config.username,
                password: this.config.deprecated.password
            }
        };

        const deferred: Deferred<void> = new Deferred<void>();
        request(authenticationReq, (error: Error, response: request.Response, body: any) => {
            if (error !== undefined && error !== null) {
                this.log.error(`Get PAI token failed: ${error.message}`);
                deferred.reject(new Error(`Get PAI token failed: ${error.message}`));
                return;
            }
            if (response.statusCode !== 200) {
                this.log.error(`Get PAI token failed: get PAI Rest return code ${response.statusCode}`);
                deferred.reject(new Error(`Get PAI token failed: ${response.body}, please check paiConfig username or password`));
                // Do NOT fall through: keep the previous token and leave the update time unset,
                // otherwise a bad token would be cached and no retry attempted for the whole interval.
                return;
            }
            this.paiToken = body.token;
            this.paiTokenUpdateTime = Date.now();
            deferred.resolve();
        });

        let timeoutId: NodeJS.Timer;
        const timeoutDelay: Promise<void> = new Promise<void>((_resolve: () => void, reject: (reason: Error) => void): void => {
            // Set timeout and reject the promise once reach timeout (5 seconds)
            timeoutId = setTimeout(
                () => reject(new Error('Get PAI token timeout. Please check your PAI cluster.')),
                5000);
        });

        return Promise.race([timeoutDelay, deferred.promise])
            .finally(() => { clearTimeout(timeoutId); });
    }
SparkSnail's avatar
SparkSnail committed
290

291
292
    /** Cluster metadata is not used by this training service: setting is a no-op. */
    public async setClusterMetadata(_key: string, _value: string): Promise<void> {
        return undefined;
    }

    /** Cluster metadata is not used by this training service: every key reads as ''. */
    public async getClusterMetadata(_key: string): Promise<string> {
        const noMetadata = '';
        return noMetadata;
    }
SparkSnail's avatar
SparkSnail committed
293
294
295
296
297
298
299
300
301
302
303
304
305
306

    /**
     * Update trial parameters for multi-phase: write the new hyper-parameter file
     * into the trial's local log folder.
     *
     * @throws Error if the trial id is unknown
     */
    public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        const detail = this.trialJobsMap.get(trialJobId);
        if (detail === undefined) {
            throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
        }

        await this.writeParameterFile(detail.logPath, form.hyperParameters);
        return detail;
    }

    /**
     * Register a new trial in 'WAITING' state and enqueue it. The actual PAI submission
     * happens later in submitJobLoop.
     *
     * @returns the newly created trial detail
     */
    public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        this.log.info('submitTrialJob: form:', form);

        const trialJobId: string = uniqueString(5);
        const paiJobName = `nni_exp_${this.experimentId}_trial_${trialJobId}`;

        const detail = new PAITrialJobDetail(
            trialJobId,
            'WAITING',
            paiJobName,
            Date.now(),
            // TODO: use HDFS working folder instead
            path.join(this.expRootDir, 'trials', trialJobId),
            form,
            path.join(this.config.localStorageMountPoint, this.experimentId, trialJobId),
            `${this.config.host}/job-detail.html?username=${this.config.username}&jobName=${paiJobName}`);

        this.trialJobsMap.set(trialJobId, detail);
        this.jobQueue.push(trialJobId);
        return detail;
    }

liuzhe-lz's avatar
liuzhe-lz committed
331
    /**
     * Wrap a user command into the NNI trial launcher command (PAI_TRIAL_COMMAND_FORMAT),
     * using container-side paths under `containerStorageMountPoint`, and flatten it onto one line.
     *
     * The manager IP falls back to this host's IPv4 address when `nniManagerIp` is empty.
     */
    private async generateNNITrialCommand(trialJobDetail: PAITrialJobDetail, command: string): Promise<string> {
        const containerExpDir = `${this.config.containerStorageMountPoint}/${this.experimentId}`;
        const containerNFSExpCodeDir = `${containerExpDir}/nni-code`;
        const containerWorkingDir = `${containerExpDir}/${trialJobDetail.id}`;
        const nniManagerIp: string = this.config.nniManagerIp || await getIPV4Address();

        const formatted: string = String.Format(
            PAI_TRIAL_COMMAND_FORMAT,
            containerWorkingDir,
            `${containerWorkingDir}/nnioutput`,
            trialJobDetail.id,
            this.experimentId,
            trialJobDetail.form.sequenceId,
            false,  // multi-phase
            containerNFSExpCodeDir,
            command,
            nniManagerIp,
            this.paiRestServerPort,
            this.nniVersion,
            this.logCollection
        );

        // Drop every line break so the result is a single-line shell command.
        return formatted.replace(/\r\n|\n|\r/gm, '');
    }

liuzhe-lz's avatar
liuzhe-lz committed
355
    /**
     * Produce the PAI v2 job protocol (YAML text) for a trial.
     *
     * If the user supplied `openpaiConfig`, a deep copy of it is used. It is renamed, and each task
     * role's commands are joined with ' && ', shell-escaped, and wrapped in the NNI trial command.
     * Otherwise a single-task-role job is built from trialCommand, dockerImage, resources and
     * storageConfigName, plus an optional `defaults.virtualCluster`.
     */
    private async generateJobConfigInYamlFormat(trialJobDetail: PAITrialJobDetail): Promise<any> {
        const jobName = `nni_exp_${this.experimentId}_trial_${trialJobDetail.id}`;

        if (this.config.openpaiConfig !== undefined) {
            // JSON round trip = deep clone, so the user's config object is never mutated.
            const userJobConfig: any = JSON.parse(JSON.stringify(this.config.openpaiConfig));
            userJobConfig.name = jobName;
            for (const roleKey of Object.keys(userJobConfig.taskRoles || {})) {
                const role: any = userJobConfig.taskRoles[roleKey];
                const escapedCommands: string = role.commands.join(" && ").replace(/(["'$`\\])/g, '\\$1');
                role.commands = [await this.generateNNITrialCommand(trialJobDetail, escapedCommands)];
            }
            return yaml.safeDump(userJobConfig);
        }

        const resourcePerInstance = {
            gpu: this.config.trialGpuNumber,
            cpu: this.config.trialCpuNumber,
            memoryMB: toMegaBytes(this.config.trialMemorySize)
        };
        const dockerImageUri = this.config.dockerImage;
        const wrappedCommand: string = await this.generateNNITrialCommand(trialJobDetail, this.config.trialCommand);

        // Key order here determines key order in the emitted YAML; keep it stable.
        const defaultJobConfig: any = {
            protocolVersion: 2,
            name: jobName,
            type: 'job',
            jobRetryCount: 0,
            prerequisites: [
                {
                    type: 'dockerimage',
                    uri: dockerImageUri,
                    name: 'docker_image_0'
                }
            ],
            taskRoles: {
                taskrole: {
                    instances: 1,
                    completion: {
                        minFailedInstances: 1,
                        minSucceededInstances: -1
                    },
                    taskRetryCount: 0,
                    dockerImage: 'docker_image_0',
                    resourcePerInstance,
                    commands: [wrappedCommand]
                }
            },
            extras: {
                'storages': [
                    {
                        name: this.config.storageConfigName
                    }
                ],
                submitFrom: 'submit-job-v2'
            }
        };

        if (this.config.virtualCluster) {
            defaultJobConfig.defaults = {
                virtualCluster: this.config.virtualCluster
            };
        }
        return yaml.safeDump(defaultJobConfig);
    }

    /**
     * Prepare a trial's local folder and submit its job to PAI.
     *
     * Steps: wait for the one-time code copy, create the trial's local log folder, write
     * install_nni.sh and the parameter file, build the job YAML, and POST it to the PAI v2 jobs API.
     *
     * @returns true once PAI accepts the job (submitTime is then recorded)
     * @throws Error if the trial or REST server is missing, or if the submission fails
     *         (the trial is marked 'FAILED' before rejecting)
     */
    protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);

        if (trialJobDetail === undefined) {
            throw new Error(`Failed to find PAITrialJobDetail for job ${trialJobId}`);
        }

        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer is not initialized');
        }

        // Make sure experiment code files is copied from local to NFS
        if (this.copyExpCodeDirPromise !== undefined) {
            await this.copyExpCodeDirPromise;
            this.log.info(`Copy codeDir data finished.`);
            // All trials share same destination NFS code folder, only copy codeDir once for an experiment.
            // After copy data finished, set copyExpCodeDirPromise be undefined to avoid log content duplicated.
            this.copyExpCodeDirPromise = undefined;
        }

        this.paiRestServerPort = this.paiJobRestServer.clusterRestServerPort;

        // Step 1. Prepare PAI job configuration
        // Create the trial's local working folder.
        await execMkdir(trialJobDetail.logPath);
        // Write NNI installation file to local files
        await fs.promises.writeFile(path.join(trialJobDetail.logPath, 'install_nni.sh'), CONTAINER_INSTALL_NNI_SHELL_FORMAT, { encoding: 'utf8' });

        // Write file content ( parameter.cfg ) to local working folders
        if (trialJobDetail.form !== undefined) {
            await this.writeParameterFile(trialJobDetail.logPath, trialJobDetail.form.hyperParameters);
        }

        // Generate Job Configuration in yaml format
        const paiJobConfig = await this.generateJobConfigInYamlFormat(trialJobDetail);
        this.log.debug(paiJobConfig);

        // Step 2. Submit PAI job via Rest call
        // Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
        const submitJobRequest: request.Options = {
            uri: `${this.config.host}/rest-server/api/v2/jobs`,
            method: 'POST',
            body: paiJobConfig,
            followAllRedirects: true,
            headers: {
                'Content-Type': 'text/yaml',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        const deferred: Deferred<boolean> = new Deferred<boolean>();
        request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
            // If submit success, will get status code 202. refer: https://github.com/microsoft/pai/blob/master/src/rest-server/docs/swagger.yaml
            if ((error !== undefined && error !== null) || response.statusCode >= 400) {
                const errorMessage: string = (error !== undefined && error !== null) ? error.message :
                    `Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${body}`;
                this.log.error(errorMessage);
                trialJobDetail.status = 'FAILED';
                // Reject with an Error (not a bare string) and stop here; resolving afterwards is meaningless.
                deferred.reject(new Error(errorMessage));
                return;
            }
            trialJobDetail.submitTime = Date.now();
            deferred.resolve(true);
        });

        return deferred.promise;
    }

    /** Write the hyper-parameter value, UTF-8 encoded, into `directory` under the name from generateParamFileName. */
    private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> {
        const fileName: string = generateParamFileName(hyperParameters);
        return fs.promises.writeFile(path.join(directory, fileName), hyperParameters.value, 'utf8');
    }
J-shang's avatar
J-shang committed
490
491
492
493
494
495
496
497

    /** Not supported by the OpenPAI training service; throws synchronously. */
    public getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
        const unsupported = new MethodNotImplementedError();
        throw unsupported;
    }

    /** Not supported by the OpenPAI training service; throws synchronously. */
    public fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
        const unsupported = new MethodNotImplementedError();
        throw unsupported;
    }
498
499
}

500
// Module export: the OpenPAI training service class defined above.
export { PAITrainingService };