paiTrainingService.ts 18.5 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
3

4
5
6
7
import fs from 'fs';
import path from 'path';
import request from 'request';
import * as component from 'common/component';
8
9

import { EventEmitter } from 'events';
10
import { Deferred } from 'ts-deferred';
11
12
13
import { getExperimentId } from 'common/experimentStartupInfo';
import { getLogger, Logger } from 'common/log';
import { MethodNotImplementedError } from 'common/errors';
14
import {
SparkSnail's avatar
SparkSnail committed
15
    HyperParameters, NNIManagerIpConfig, TrainingService,
Yuge Zhang's avatar
Yuge Zhang committed
16
    TrialJobApplicationForm, TrialJobDetail, TrialJobMetric
17
18
} from 'common/trainingService';
import { delay } from 'common/utils';
liuzhe-lz's avatar
liuzhe-lz committed
19
import { OpenpaiConfig, toMegaBytes } from 'common/experimentConfig';
20
import { PAIJobInfoCollector } from './paiJobInfoCollector';
21
import { PAIJobRestServer } from './paiJobRestServer';
22
import { PAITrialJobDetail, PAI_TRIAL_COMMAND_FORMAT } from './paiConfig';
SparkSnail's avatar
SparkSnail committed
23
import { String } from 'typescript-string-operations';
24
import { generateParamFileName, getIPV4Address, uniqueString } from 'common/utils';
SparkSnail's avatar
SparkSnail committed
25
26
27
28
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { execMkdir, validateCodeDir, execCopydir } from '../common/util';

const yaml = require('js-yaml');
29
30
31
32
33
34

/**
 * Training Service implementation for OpenPAI (Open Platform for AI).
 * Refer to https://github.com/Microsoft/pai for more information about OpenPAI.
 */
@component.Singleton
SparkSnail's avatar
SparkSnail committed
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class PAITrainingService implements TrainingService {
    /** Logger for this service; created in the constructor. */
    private readonly log!: Logger;
    /** Emitter on which 'metric' listeners are registered. */
    private readonly metricsEmitter: EventEmitter;
    /** Trial job details keyed by trial job id. */
    private readonly trialJobsMap: Map<string, PAITrialJobDetail>;
    /** Experiment root directory, `/nni-experiments/<experimentId>`. */
    private readonly expRootDir: string;
    /** Ids of trial jobs waiting to be submitted, in arrival order. */
    private readonly jobQueue: string[];
    /** Set to true by cleanUp(); the background loops poll it to terminate. */
    private stopping = false;
    /** Bearer token sent in the Authorization header of PAI REST calls. */
    private paiToken?: string;
    private paiTokenUpdateTime?: number;
    /** Token update interval in milliseconds (two hours). */
    private readonly paiTokenUpdateInterval: number;
    /** Id of the running experiment; assigned in the constructor. */
    private readonly experimentId!: string;
    /** Collector used by the status loop to refresh trial statuses. */
    private readonly paiJobCollector: PAIJobInfoCollector;
    /** Port of the job rest server, captured when a trial is submitted. */
    private paiRestServerPort?: number;
    private nniManagerIpConfig?: NNIManagerIpConfig;
    /** Version-check flag handed to the rest server; the constructor sets it to `!config.debug`. */
    private versionCheck = true;
    /** Log collection mode passed into the generated trial command. */
    private logCollection = 'none';
    /** Rest server started by run() and stopped by cleanUp(). */
    private paiJobRestServer?: PAIJobRestServer;
    /** 'http' or 'https', derived from the configured host. */
    private protocol: string;
    /** Pending copy of the trial code directory; cleared once it has been awaited. */
    private copyExpCodeDirPromise?: Promise<void>;
    private paiJobConfig: any;
    /** NNI version passed into the generated trial command. */
    private nniVersion: string | undefined;
    /** OpenPAI configuration supplied to the constructor. */
    private config: OpenpaiConfig;
57

liuzhe-lz's avatar
liuzhe-lz committed
58
    /**
     * Builds the training service from an OpenPAI experiment configuration.
     *
     * Besides initialising in-memory state, construction starts copying the trial
     * code directory in the background. The resulting promise is stored in
     * `copyExpCodeDirPromise` and awaited before the first trial is submitted.
     *
     * @param config OpenPAI configuration of the experiment.
     */
    constructor(config: OpenpaiConfig) {
        this.log = getLogger('PAITrainingService');
        this.metricsEmitter = new EventEmitter();
        this.trialJobsMap = new Map<string, PAITrialJobDetail>();
        this.jobQueue = [];
        // Resolve the experiment id once and derive the root directory from it.
        this.experimentId = getExperimentId();
        this.expRootDir = path.join('/nni-experiments', this.experimentId);
        this.paiJobCollector = new PAIJobInfoCollector(this.trialJobsMap);
        this.paiTokenUpdateInterval = 7200000; // 2 hours, in milliseconds
        this.log.info('Construct paiBase training service.');
        this.config = config;
        this.versionCheck = !this.config.debug;
        this.paiJobRestServer = new PAIJobRestServer(this);
        this.paiToken = this.config.token;
        this.protocol = this.config.host.toLowerCase().startsWith('https://') ? 'https' : 'http';

        const copyPromise: Promise<void> = this.copyTrialCode();
        // The copy is only awaited when the first trial is submitted, which can be
        // long after it has failed. Attach a handler now so that such a failure is
        // logged instead of surfacing as an unhandled promise rejection. The stored
        // promise is the original one, so awaiting it later still rejects.
        copyPromise.catch((error: unknown) => {
            const reason = error instanceof Error ? error.message : `${error}`;
            this.log.error(`Copy codeDir data failed: ${reason}`);
        });
        this.copyExpCodeDirPromise = copyPromise;
    }

    /**
     * Validates the trial code directory, then copies it to
     * `<localStorageMountPoint>/<experimentId>/nni-code`.
     */
    private async copyTrialCode(): Promise<void> {
        const sourceDir = this.config.trialCodeDirectory;
        await validateCodeDir(sourceDir);

        const targetDir = path.join(this.config.localStorageMountPoint, this.experimentId, 'nni-code');
        await execMkdir(targetDir);
        this.log.info(`Starting copy codeDir data from ${sourceDir} to ${targetDir}`);
        await execCopydir(sourceDir, targetDir);
    }

    /**
     * Starts the job rest server, then runs the status-checking loop and the
     * job-submission loop until both finish.
     *
     * @throws Error when the rest server has not been created.
     */
    public async run(): Promise<void> {
        this.log.info('Run PAI training service.');

        const restServer = this.paiJobRestServer;
        if (restServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }
        await restServer.start();
        restServer.setEnableVersionCheck = this.versionCheck;
        this.log.info(`PAI Training service rest server listening on: ${restServer.endPoint}`);

        const backgroundLoops = [this.statusCheckingLoop(), this.submitJobLoop()];
        await Promise.all(backgroundLoops);
        this.log.info('PAI training service exit.');
    }

98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
    /**
     * Drains the job queue in order until the service is stopping.
     *
     * A job is removed from the queue only after it has been submitted. When a
     * submission reports failure, draining pauses and is retried after the next
     * three-second delay.
     */
    protected async submitJobLoop(): Promise<void> {
        while (!this.stopping) {
            let submitted = true;
            while (submitted && !this.stopping && this.jobQueue.length > 0) {
                const headJobId: string = this.jobQueue[0];
                submitted = await this.submitTrialJobToPAI(headJobId);
                if (submitted) {
                    // The head of the queue has been submitted; drop it.
                    this.jobQueue.shift();
                }
            }
            await delay(3000);
        }
    }

114
115
    /**
     * Returns the details of every trial job known to this service.
     */
    public async listTrialJobs(): Promise<TrialJobDetail[]> {
        const details: TrialJobDetail[] = [];
        for (const [trialJobId] of this.trialJobsMap) {
            const detail = await this.getTrialJob(trialJobId);
            details.push(detail);
        }

        return details;
    }

Yuge Zhang's avatar
Yuge Zhang committed
124
    /**
     * Not supported by this training service.
     *
     * @throws MethodNotImplementedError always (as a rejected promise).
     */
    public async getTrialFile(_trialJobId: string, _fileName: string): Promise<string | Buffer> {
        throw new MethodNotImplementedError();
    }

128
    /**
     * Looks up one trial job by id.
     *
     * @param trialJobId id of the trial job.
     * @returns the stored detail object for that trial.
     * @throws Error when no trial with that id is known.
     */
    public async getTrialJob(trialJobId: string): Promise<TrialJobDetail> {
        const detail = this.trialJobsMap.get(trialJobId);
        if (detail !== undefined) {
            return detail;
        }

        throw new Error(`trial job ${trialJobId} not found`);
    }

138
    /**
     * Registers a listener for 'metric' events.
     *
     * @param listener callback invoked with each trial job metric.
     */
    public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.metricsEmitter.addListener('metric', listener);
    }

142
    /**
     * Unregisters a listener previously added for 'metric' events.
     *
     * @param listener the callback to remove.
     */
    public removeTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
        this.metricsEmitter.removeListener('metric', listener);
    }

QuanluZhang's avatar
QuanluZhang committed
146
    /**
     * Cancels a trial job.
     *
     * A trial whose status is 'UNKNOWN' is marked 'USER_CANCELED' locally without
     * contacting the cluster. Otherwise a STOP execution-type request is sent to the
     * PAI rest server.
     *
     * @param trialJobId id of the trial job to cancel.
     * @param isEarlyStopped recorded on the trial detail to mark the cancellation source.
     * @returns a promise that resolves once the stop request has been accepted.
     *          It rejects with an Error when the trial is unknown, the request
     *          fails, or the server answers with an HTTP status >= 400.
     */
    public cancelTrialJob(trialJobId: string, isEarlyStopped: boolean = false): Promise<void> {
        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);
        if (trialJobDetail === undefined) {
            return Promise.reject(new Error(`cancelTrialJob: trial job id ${trialJobId} not found`));
        }

        if (trialJobDetail.status === 'UNKNOWN') {
            trialJobDetail.status = 'USER_CANCELED';
            return Promise.resolve();
        }

        const stopJobRequest: request.Options = {
            uri: `${this.config.host}/rest-server/api/v2/jobs/${this.config.username}~${trialJobDetail.paiJobName}/executionType`,
            method: 'PUT',
            json: true,
            body: { value: 'STOP' },
            headers: {
                'Content-Type': 'application/json',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        // Set trialJobDetail's early stopped field, to mark the job's cancellation source
        trialJobDetail.isEarlyStopped = isEarlyStopped;
        const deferred: Deferred<void> = new Deferred<void>();

        request(stopJobRequest, (error: Error, response: request.Response, _body: any) => {
            // Status code 202 for success.
            const transportFailed: boolean = error !== undefined && error !== null;
            if (transportFailed || response.statusCode >= 400) {
                this.log.error(`PAI Training service: stop trial ${trialJobId} to PAI Cluster failed!`);
                // Always reject with an Error object (never a bare string) so that
                // callers get a message and a stack trace.
                const failure: Error = transportFailed
                    ? error
                    : new Error(`Stop trial failed, http code: ${response.statusCode}`);
                deferred.reject(failure);
            } else {
                deferred.resolve();
            }
        });

        return deferred.promise;
    }

    /**
     * Signals the background loops to stop and shuts down the job rest server.
     *
     * A failure while stopping the rest server is logged and not rethrown.
     *
     * @throws Error when the rest server has not been created.
     */
    public async cleanUp(): Promise<void> {
        this.log.info('Stopping PAI training service...');
        this.stopping = true;

        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer not initialized!');
        }

        try {
            await this.paiJobRestServer.stop();
            this.log.info('PAI Training service rest server stopped successfully.');
        } catch (error) {
            // A thrown value is not guaranteed to be an Error; narrow it before
            // reading `.message` so that the log never prints "undefined".
            const reason = error instanceof Error ? error.message : `${error}`;
            this.log.error(`PAI Training service rest server stopped failed, error: ${reason}`);
        }
    }

chicm-ms's avatar
chicm-ms committed
202
    /** Emitter on which trial metric listeners are registered. */
    public get MetricsEmitter(): EventEmitter {
        const emitter: EventEmitter = this.metricsEmitter;
        return emitter;
    }
205

SparkSnail's avatar
SparkSnail committed
206
207
208
    /**
     * Strips a leading 'http://' or 'https://' from a host string and records the
     * scheme in `this.protocol`.
     *
     * The scheme is matched case-insensitively, consistent with how the
     * constructor derives the protocol from the configured host. A host without
     * either scheme is returned unchanged and `this.protocol` is left untouched.
     *
     * @param host host as supplied by the user, with or without a scheme.
     * @returns the host without its scheme prefix.
     */
    protected formatPAIHost(host: string): string {
        const schemeMatch = /^(https?):\/\//i.exec(host);
        if (schemeMatch === null) {
            return host;
        }

        this.protocol = schemeMatch[1].toLowerCase();
        return host.slice(schemeMatch[0].length);
    }

220
    /**
     * Refreshes trial statuses every three seconds until the service is stopping.
     *
     * @throws Error when the rest server is missing, or when the rest server
     *         reports an error message.
     */
    protected async statusCheckingLoop(): Promise<void> {
        while (!this.stopping) {
            await this.paiJobCollector.retrieveTrialStatus(this.protocol, this.paiToken, this.config);

            const restServer = this.paiJobRestServer;
            if (restServer === undefined) {
                throw new Error('paiBaseJobRestServer not implemented!');
            }

            const restServerError = restServer.getErrorMessage;
            if (restServerError !== undefined) {
                throw new Error(restServerError);
            }

            await delay(3000);
        }
    }

233
234
    /** No-op: both arguments are ignored. */
    public async setClusterMetadata(_key: string, _value: string): Promise<void> {
        // Intentionally empty.
    }
    /** Always resolves to an empty string; the key is ignored. */
    public async getClusterMetadata(_key: string): Promise<string> {
        const noMetadata = '';
        return noMetadata;
    }
SparkSnail's avatar
SparkSnail committed
235
236
237
238
239
240
241
242
243
244
245
246
247
248

    // update trial parameters for multi-phase
    /**
     * Writes the hyper-parameters of `form` as a parameter file into the trial's
     * log path.
     *
     * @param trialJobId id of the trial to update.
     * @param form form carrying the hyper-parameters to write.
     * @returns the stored detail object of the trial.
     * @throws Error when the trial is unknown.
     */
    public async updateTrialJob(trialJobId: string, form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        const detail = this.trialJobsMap.get(trialJobId);
        if (detail === undefined) {
            throw new Error(`updateTrialJob failed: ${trialJobId} not found`);
        }

        // Write file content ( parameter.cfg ) to working folders
        const parameters = form.hyperParameters;
        await this.writeParameterFile(detail.logPath, parameters);

        return detail;
    }

    /**
     * Registers a new trial in 'WAITING' state and queues it for submission.
     *
     * The job is not sent to the cluster here; the submission loop picks it up
     * from the queue.
     *
     * @param form trial application form; a five-character id is generated when
     *             `form.id` is undefined.
     * @returns the detail object created for the trial.
     */
    public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        this.log.info('submitTrialJob: form:',  form);

        let trialJobId: string;
        if (form.id === undefined) {
            trialJobId = uniqueString(5);
        } else {
            trialJobId = form.id;
        }

        const paiJobName = `nni_exp_${this.experimentId}_trial_${trialJobId}`;
        //TODO: use HDFS working folder instead
        const workingFolder = path.join(this.expRootDir, 'trials', trialJobId);
        const logPath = path.join(this.config.localStorageMountPoint, this.experimentId, trialJobId);
        const detailUrl = `${this.config.host}/job-detail.html?username=${this.config.username}&jobName=${paiJobName}`;

        const submittedAt = Date.now();
        const detail: PAITrialJobDetail = new PAITrialJobDetail(
            trialJobId, 'WAITING', paiJobName, submittedAt, workingFolder, form, logPath, detailUrl);

        this.trialJobsMap.set(trialJobId, detail);
        this.jobQueue.push(trialJobId);

        return detail;
    }

liuzhe-lz's avatar
liuzhe-lz committed
273
    /**
     * Fills the PAI trial command template for one trial and returns it as a
     * single line (all line breaks removed).
     *
     * @param trialJobDetail trial the command is generated for.
     * @param command user command to embed in the template.
     * @returns the formatted command.
     */
    private async generateNNITrialCommand(trialJobDetail: PAITrialJobDetail, command: string): Promise<string> {
        const experimentMount = `${this.config.containerStorageMountPoint}/${this.experimentId}`;
        const codeDirInContainer = `${experimentMount}/nni-code`;
        const workingDirInContainer = `${experimentMount}/${trialJobDetail.id}`;
        // Fall back to this machine's IPv4 address when no manager IP is configured.
        const managerIp = this.config.nniManagerIp || await getIPV4Address();

        const formatted: string = String.Format(
            PAI_TRIAL_COMMAND_FORMAT,
            workingDirInContainer,
            `${workingDirInContainer}/nnioutput`,
            trialJobDetail.id,
            this.experimentId,
            trialJobDetail.form.sequenceId,
            false,  // multi-phase
            codeDirInContainer,
            command,
            managerIp,
            this.paiRestServerPort,
            this.nniVersion,
            this.logCollection
        );

        return formatted.replace(/\r\n|\n|\r/gm, '');
    }

liuzhe-lz's avatar
liuzhe-lz committed
297
    /**
     * Builds the PAI job configuration for a trial and serialises it to YAML.
     *
     * When the experiment supplies its own `openpaiConfig`, a deep copy of it is
     * used: the job name is overwritten and the commands of every task role are
     * joined, shell-escaped and wrapped in the NNI trial command. Otherwise a
     * single-task-role default configuration is generated from the experiment
     * settings.
     *
     * @param trialJobDetail trial the configuration is generated for.
     * @returns the YAML text produced by `yaml.safeDump`.
     */
    private async generateJobConfigInYamlFormat(trialJobDetail: PAITrialJobDetail): Promise<any> {
        const jobName = `nni_exp_${this.experimentId}_trial_${trialJobDetail.id}`;

        if (this.config.openpaiConfig !== undefined) {
            // A JSON round trip yields a deep copy, so the shared config is never mutated.
            const customConfig: any = JSON.parse(JSON.stringify(this.config.openpaiConfig));
            customConfig.name = jobName;
            // Every task role gets its commands replaced by one NNI-style command.
            for (const roleName in customConfig.taskRoles) {
                const role = customConfig.taskRoles[roleName];
                const escapedCommand = role.commands.join(" && ").replace(/(["'$`\\])/g, '\\$1');
                const wrappedCommand = await this.generateNNITrialCommand(trialJobDetail, escapedCommand);
                role.commands = [wrappedCommand];
            }

            return yaml.safeDump(customConfig);
        }

        const dockerImageName = 'docker_image_0';
        const memoryMB = toMegaBytes(this.config.trialMemorySize);
        const trialCommand = await this.generateNNITrialCommand(trialJobDetail, this.config.trialCommand);

        const defaultConfig: any = {
            protocolVersion: 2,
            name: jobName,
            type: 'job',
            jobRetryCount: 0,
            prerequisites: [
                { type: 'dockerimage', uri: this.config.dockerImage, name: dockerImageName }
            ],
            taskRoles: {
                taskrole: {
                    instances: 1,
                    completion: { minFailedInstances: 1, minSucceededInstances: -1 },
                    taskRetryCount: 0,
                    dockerImage: dockerImageName,
                    resourcePerInstance: {
                        gpu: this.config.trialGpuNumber,
                        cpu: this.config.trialCpuNumber,
                        memoryMB: memoryMB
                    },
                    commands: [trialCommand]
                }
            },
            extras: {
                storages: [
                    { name: this.config.storageConfigName }
                ],
                submitFrom: 'submit-job-v2'
            }
        };

        if (this.config.virtualCluster) {
            defaultConfig.defaults = { virtualCluster: this.config.virtualCluster };
        }

        return yaml.safeDump(defaultConfig);
    }

    /**
     * Prepares the local files of a trial and submits the job to the PAI rest server.
     *
     * Steps: wait for the one-off copy of the trial code directory, create the
     * trial's log directory, write `install_nni.sh` and the parameter file,
     * generate the YAML job configuration, and POST it to the cluster.
     *
     * @param trialJobId id of the queued trial to submit.
     * @returns a promise resolving to true once the cluster has accepted the job.
     * @throws Error when the trial or the rest server is missing. The returned
     *         promise rejects with an Error (the trial is marked 'FAILED') when
     *         the request fails or the server answers with an HTTP status >= 400.
     */
    protected async submitTrialJobToPAI(trialJobId: string): Promise<boolean> {
        const trialJobDetail: PAITrialJobDetail | undefined = this.trialJobsMap.get(trialJobId);

        if (trialJobDetail === undefined) {
            throw new Error(`Failed to find PAITrialJobDetail for job ${trialJobId}`);
        }

        if (this.paiJobRestServer === undefined) {
            throw new Error('paiJobRestServer is not initialized');
        }

        // Make sure experiment code files is copied from local to NFS
        if (this.copyExpCodeDirPromise !== undefined) {
            await this.copyExpCodeDirPromise;
            this.log.info(`Copy codeDir data finished.`);
            // All trials share same destination NFS code folder, only copy codeDir once for an experiment.
            // After copy data finished, set copyExpCodeDirPromise be undefined to avoid log content duplicated.
            this.copyExpCodeDirPromise = undefined;
        }

        this.paiRestServerPort = this.paiJobRestServer.clusterRestServerPort;

        // Step 1. Prepare PAI job configuration
        // Create trial local working folder locally.
        await execMkdir(trialJobDetail.logPath);
        // Write NNI installation file to local files
        await fs.promises.writeFile(path.join(trialJobDetail.logPath, 'install_nni.sh'), CONTAINER_INSTALL_NNI_SHELL_FORMAT, { encoding: 'utf8' });

        // Write file content ( parameter.cfg ) to local working folders
        if (trialJobDetail.form !== undefined) {
            await this.writeParameterFile(trialJobDetail.logPath, trialJobDetail.form.hyperParameters);
        }

        // Generate Job Configuration in yaml format
        const paiJobConfig = await this.generateJobConfigInYamlFormat(trialJobDetail);
        this.log.debug(paiJobConfig);

        // Step 2. Submit PAI job via Rest call
        // Refer https://github.com/Microsoft/pai/blob/master/docs/rest-server/API.md for more detail about PAI Rest API
        const submitJobRequest: request.Options = {
            uri: `${this.config.host}/rest-server/api/v2/jobs`,
            method: 'POST',
            body: paiJobConfig,
            followAllRedirects: true,
            headers: {
                'Content-Type': 'text/yaml',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        const deferred: Deferred<boolean> = new Deferred<boolean>();
        request(submitJobRequest, (error: Error, response: request.Response, body: any) => {
            // If submit success, will get status code 202. refer: https://github.com/microsoft/pai/blob/master/src/rest-server/docs/swagger.yaml
            const transportFailed: boolean = error !== undefined && error !== null;
            if (transportFailed || response.statusCode >= 400) {
                // Reject with an Error object rather than a bare string, and settle
                // the promise exactly once on this path.
                const failure: Error = transportFailed
                    ? error
                    : new Error(`Submit trial ${trialJobId} failed, http code:${response.statusCode}, http body: ${body}`);
                this.log.error(failure.message);
                trialJobDetail.status = 'FAILED';
                deferred.reject(failure);
                return;
            }

            trialJobDetail.submitTime = Date.now();
            deferred.resolve(true);
        });

        return deferred.promise;
    }

    /**
     * Writes the value of `hyperParameters` as UTF-8 text into `directory`,
     * using the file name produced by `generateParamFileName`.
     */
    private async writeParameterFile(directory: string, hyperParameters: HyperParameters): Promise<void> {
        const fileName = generateParamFileName(hyperParameters);
        const target = path.join(directory, fileName);
        await fs.promises.writeFile(target, hyperParameters.value, { encoding: 'utf8' });
    }
J-shang's avatar
J-shang committed
432
433
434
435
436
437
438
439

    /**
     * Not supported by this training service.
     *
     * @throws MethodNotImplementedError always, synchronously.
     */
    public getTrialOutputLocalPath(_trialJobId: string): Promise<string> {
        throw new MethodNotImplementedError();
    }

    /**
     * Not supported by this training service.
     *
     * @throws MethodNotImplementedError always, synchronously.
     */
    public fetchTrialOutput(_trialJobId: string, _subpath: string): Promise<void> {
        throw new MethodNotImplementedError();
    }
440
441
}

442
export { PAITrainingService };