openPaiEnvironmentService.ts 16.6 KB
Newer Older
1
2
3
4
5
6
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

'use strict';

import * as fs from 'fs';
7
import * as yaml from 'js-yaml';
8
9
10
11
12
13
14
import * as request from 'request';
import { Deferred } from 'ts-deferred';
import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log';
import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey';
import { PAIClusterConfig } from '../../pai/paiConfig';
SparkSnail's avatar
SparkSnail committed
15
import { NNIPAITrialConfig } from '../../pai/paiConfig';
16
17
18
19
20
21
22
23
24
25
26
27
import { EnvironmentInformation, EnvironmentService } from '../environment';
import { StorageService } from '../storageService';


/**
 * Environment service backed by an OpenPAI cluster.
 *
 * Each environment is submitted as one PAI job (named after the environment's envId).
 * Job states are collected from the PAI REST server and mapped onto the local
 * environment statuses; jobs can also be stopped through the REST API.
 */
@component.Singleton
export class OpenPaiEnvironmentService extends EnvironmentService {

    private readonly log: Logger = getLogger();
    private paiClusterConfig: PAIClusterConfig | undefined;
    private paiTrialConfig: NNIPAITrialConfig | undefined;
    // Parsed content of the user-supplied PAI job yaml (trial config `paiConfigPath`), if any.
    private paiJobConfig: any;
    private paiToken?: string;
    // URL scheme for the PAI host; formatPAIHost() updates it when the configured host carries a scheme.
    private protocol: string = 'http';

    private experimentId: string;

    constructor() {
        super();
        this.experimentId = getExperimentId();
    }

    /** Fixed interval of the environment maintenance loop (5000; presumably milliseconds). */
    public get environmentMaintenceLoopInterval(): number {
        return 5000;
    }

    /** OpenPAI environments use a storage service; it is initialized when TRIAL_CONFIG is applied. */
    public get hasStorageService(): boolean {
        return true;
    }

    /** Short name of this environment service. */
    public get getName(): string {
        return 'pai';
    }

    /**
     * Applies one piece of experiment metadata.
     *
     * PAI_CLUSTER_CONFIG must be applied before TRIAL_CONFIG; otherwise TRIAL_CONFIG is
     * logged as an error and ignored. Unknown keys are only logged.
     *
     * @param key - a TrialConfigMetadataKey value.
     * @param value - JSON-encoded configuration for that key.
     */
    public async config(key: string, value: string): Promise<void> {
        switch (key) {
            case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG:
                this.paiClusterConfig = <PAIClusterConfig>JSON.parse(value);
                // Strip any scheme from the host; formatPAIHost records it in this.protocol.
                this.paiClusterConfig.host = this.formatPAIHost(this.paiClusterConfig.host);
                this.paiToken = this.paiClusterConfig.token;
                break;

            case TrialConfigMetadataKey.TRIAL_CONFIG: {
                if (this.paiClusterConfig === undefined) {
                    this.log.error('pai cluster config is not initialized');
                    break;
                }
                this.paiTrialConfig = <NNIPAITrialConfig>JSON.parse(value);

                // Root the storage service at this experiment's folder under the NNI manager's NFS mount.
                const storageService = component.get<StorageService>(StorageService);
                const remoteRoot = storageService.joinPath(this.paiTrialConfig.nniManagerNFSMountPath, this.experimentId);
                storageService.initialize(this.paiTrialConfig.nniManagerNFSMountPath, remoteRoot);

                if (this.paiTrialConfig.paiConfigPath) {
                    this.paiJobConfig = yaml.safeLoad(fs.readFileSync(this.paiTrialConfig.paiConfigPath, 'utf8'));
                }

                // Resource requirements from the cluster config win; fall back to the trial config.
                if (this.paiClusterConfig.gpuNum === undefined) {
                    this.paiClusterConfig.gpuNum = this.paiTrialConfig.gpuNum;
                }
                if (this.paiClusterConfig.cpuNum === undefined) {
                    this.paiClusterConfig.cpuNum = this.paiTrialConfig.cpuNum;
                }
                if (this.paiClusterConfig.memoryMB === undefined) {
                    this.paiClusterConfig.memoryMB = this.paiTrialConfig.memoryMB;
                }
                break;
            }
            default:
                this.log.debug(`OpenPAI not proccessed metadata key: '${key}', value: '${value}'`);
        }
    }

    /**
     * Fetches the user's job list from the PAI REST server and updates the status of
     * every given environment whose envId matches a job name.
     *
     * Environments with no matching job are marked 'UNKNOWN'; matched jobs without a
     * state are marked 'FAILED'. The returned promise rejects when the request fails,
     * when the response body is not a job list, or when any matched job is FAILED.
     * In the FAILED case, the remaining environments are still updated before rejecting.
     *
     * @throws Error if the cluster config or token has not been configured.
     */
    public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
        const deferred: Deferred<void> = new Deferred<void>();

        if (this.paiClusterConfig === undefined) {
            throw new Error('PAI Cluster config is not initialized');
        }
        if (this.paiToken === undefined) {
            throw new Error('PAI token is not initialized');
        }

        const getJobInfoRequest: request.Options = {
            uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs?username=${this.paiClusterConfig.userName}`,
            method: 'GET',
            json: true,
            headers: {
                'Content-Type': 'application/json',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        request(getJobInfoRequest, (error: any, response: request.Response, body: any) => {
            // Status code 200 for success
            if ((error !== undefined && error !== null) || response.statusCode >= 400) {
                const errorMessage: string = (error !== undefined && error !== null) ? error.message :
                    `OpenPAI: get environment list from PAI Cluster failed!, http code:${response.statusCode}, http body: ${JSON.stringify(body)}`;
                this.log.error(`${errorMessage}`);
                deferred.reject(errorMessage);
                return;
            }

            // An exception escaping this callback would never reach the caller and would leave
            // the returned promise pending forever, so every failure below settles the deferred.
            if (!Array.isArray(body)) {
                const errorMessage = `OpenPAI: get environment list from PAI Cluster returned an unexpected body: ${JSON.stringify(body)}`;
                this.log.error(errorMessage);
                deferred.reject(errorMessage);
                return;
            }

            try {
                const jobInfos = new Map<string, any>();
                for (const jobInfo of body) {
                    jobInfos.set(jobInfo.name, jobInfo);
                }

                environments.forEach((environment) => {
                    if (jobInfos.has(environment.envId)) {
                        const jobResponse = jobInfos.get(environment.envId);
                        if (jobResponse && jobResponse.state) {
                            const oldEnvironmentStatus = environment.status;
                            switch (jobResponse.state) {
                                case 'RUNNING':
                                case 'WAITING':
                                case 'SUCCEEDED':
                                    environment.setStatus(jobResponse.state);
                                    break;
                                case 'FAILED':
                                    environment.setStatus(jobResponse.state);
                                    // Rejection wins over the resolve() below; other environments keep updating.
                                    deferred.reject(`OpenPAI: job ${environment.envId} is failed!`);
                                    break;
                                case 'STOPPED':
                                case 'STOPPING':
                                    environment.setStatus('USER_CANCELED');
                                    break;
                                default:
                                    this.log.error(`OpenPAI: job ${environment.envId} returns unknown state ${jobResponse.state}.`);
                                    environment.setStatus('UNKNOWN');
                            }
                            if (oldEnvironmentStatus !== environment.status) {
                                this.log.debug(`OpenPAI: job ${environment.envId} change status ${oldEnvironmentStatus} to ${environment.status} due to job is ${jobResponse.state}.`);
                            }
                        } else {
                            this.log.error(`OpenPAI: job ${environment.envId} has no state returned. body:${JSON.stringify(jobResponse)}`);
                            // some error happens, and mark this environment
                            environment.status = 'FAILED';
                        }
                    } else {
                        this.log.error(`OpenPAI job ${environment.envId} is not found in job list.`);
                        environment.status = 'UNKNOWN';
                    }
                });
                deferred.resolve();
            } catch (updateError) {
                this.log.error(`OpenPAI: error when refreshing environments status: ${updateError}`);
                deferred.reject(updateError);
            }
        });
        return deferred.promise;
    }

    /**
     * Submits a PAI job for the given environment.
     *
     * Before submission, this sets the environment's runnerWorkingFolder, prefixes its
     * command with a `cd` into the experiment folder under the container NFS mount,
     * and fills trackingUrl, useActiveGpu and maxTrialNumberPerGpu from the cluster config.
     * On failure the environment is marked 'FAILED' and the promise rejects with the
     * error message.
     *
     * @throws Error if the cluster config, token or trial config has not been configured.
     */
    public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
        const deferred: Deferred<void> = new Deferred<void>();

        if (this.paiClusterConfig === undefined) {
            throw new Error('PAI Cluster config is not initialized');
        }
        if (this.paiToken === undefined) {
            throw new Error('PAI token is not initialized');
        }
        if (this.paiTrialConfig === undefined) {
            throw new Error('PAI trial config is not initialized');
        }

        // Step 1. Prepare PAI job configuration
        const environmentRoot = `${this.paiTrialConfig.containerNFSMountPath}/${this.experimentId}`;
        environment.runnerWorkingFolder = `${environmentRoot}/envs/${environment.id}`;
        environment.command = `cd ${environmentRoot} && ${environment.command}`;
        environment.trackingUrl = `${this.protocol}://${this.paiClusterConfig.host}/job-detail.html?username=${this.paiClusterConfig.userName}&jobName=${environment.envId}`;
        environment.useActiveGpu = this.paiClusterConfig.useActiveGpu;
        environment.maxTrialNumberPerGpu = this.paiClusterConfig.maxTrialNumPerGpu;

        // Step 2. Generate Job Configuration in yaml format
        const paiJobConfig = this.generateJobConfigInYamlFormat(environment);
        this.log.debug(`generated paiJobConfig: ${paiJobConfig}`);

        // Step 3. Submit PAI job via Rest call
        const submitJobRequest: request.Options = {
            uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs`,
            method: 'POST',
            body: paiJobConfig,
            followAllRedirects: true,
            headers: {
                'Content-Type': 'text/yaml',
                Authorization: `Bearer ${this.paiToken}`
            }
        };
        request(submitJobRequest, (error, response, body) => {
            // Status code 202 for success, refer https://github.com/microsoft/pai/blob/master/src/rest-server/docs/swagger.yaml
            if ((error !== undefined && error !== null) || response.statusCode >= 400) {
                const errorMessage: string = (error !== undefined && error !== null) ? error.message :
                    `start environment ${environment.envId} failed, http code:${response.statusCode}, http body: ${body}`;

                this.log.error(errorMessage);
                environment.status = 'FAILED';
                deferred.reject(errorMessage);
                return;
            }
            deferred.resolve();
        });

        return deferred.promise;
    }

    /**
     * Asks the PAI REST server to stop the job backing the given environment
     * (PUT `executionType` with value STOP).
     *
     * The promise rejects with the request error or an HTTP-status message on failure.
     */
    public async stopEnvironment(environment: EnvironmentInformation): Promise<void> {
        const deferred: Deferred<void> = new Deferred<void>();

        if (this.paiClusterConfig === undefined) {
            return Promise.reject(new Error('PAI Cluster config is not initialized'));
        }
        if (this.paiToken === undefined) {
            return Promise.reject(Error('PAI token is not initialized'));
        }

        const stopJobRequest: request.Options = {
            uri: `${this.protocol}://${this.paiClusterConfig.host}/rest-server/api/v2/jobs/${this.paiClusterConfig.userName}~${environment.envId}/executionType`,
            method: 'PUT',
            json: true,
            body: { value: 'STOP' },
            time: true,
            headers: {
                'Content-Type': 'application/json',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        this.log.debug(`stopping OpenPAI environment ${environment.envId}, ${stopJobRequest.uri}`);

        try {
            request(stopJobRequest, (error, response, _body) => {
                try {
                    // Status code 202 for success.
                    if ((error !== undefined && error !== null) || (response && response.statusCode >= 400)) {
                        const errorMessage: string = (error !== undefined && error !== null) ? error.message :
                            `OpenPAI: stop job ${environment.envId} failed, http code:${response.statusCode}, http body: ${_body}`;
                        this.log.error(`${errorMessage}`);
                        deferred.reject((error !== undefined && error !== null) ? error :
                            `Stop trial failed, http code: ${response.statusCode}`);
                    } else {
                        this.log.info(`OpenPAI job ${environment.envId} stopped.`);
                        deferred.resolve();
                    }
                } catch (error) {
                    this.log.error(`OpenPAI error when inner stopping environment ${error}`);
                    deferred.reject(error);
                }
            });
        } catch (error) {
            this.log.error(`OpenPAI error when stopping environment ${error}`);
            return Promise.reject(error);
        }

        return deferred.promise;
    }

    /**
     * Builds the PAI (protocol v2) job configuration for an environment and returns it as yaml text.
     *
     * With a user-supplied job yaml (`paiConfigPath`), the parsed config is deep-copied and
     * renamed to the envId. `environment.nodeCount` is set to the total instance count over
     * all task roles (1 when `instances` is unset). Each task role's commands are folded into
     * one NNI runner command that receives them via `--trial_command`. Without a job yaml,
     * a single-instance, single-task-role job is generated from the image, resources,
     * storage and virtual cluster in the cluster/trial configs.
     *
     * @throws Error if the trial config is missing or, in the generated case, the cluster
     * config or its gpuNum/cpuNum/memoryMB is missing.
     */
    private generateJobConfigInYamlFormat(environment: EnvironmentInformation): any {
        if (this.paiTrialConfig === undefined) {
            throw new Error('trial config is not initialized');
        }
        const jobName = environment.envId;

        let nniJobConfig: any = undefined;
        if (this.paiTrialConfig.paiConfigPath) {
            nniJobConfig = JSON.parse(JSON.stringify(this.paiJobConfig)); //Trick for deep clone in Typescript
            nniJobConfig.name = jobName;
            if (nniJobConfig.taskRoles) {

                environment.nodeCount = 0;
                // count instance
                for (const taskRoleName in nniJobConfig.taskRoles) {
                    const taskRole = nniJobConfig.taskRoles[taskRoleName];
                    let instanceCount = 1;
                    if (taskRole.instances) {
                        instanceCount = taskRole.instances;
                    }
                    environment.nodeCount += instanceCount;
                }

                // Each taskRole will generate new command in NNI's command format
                // Each command will be formatted to NNI style
                for (const taskRoleName in nniJobConfig.taskRoles) {
                    const taskRole = nniJobConfig.taskRoles[taskRoleName];
                    // The joined command is wrapped in single quotes below, so EVERY embedded ' must
                    // become '\'' (close quote, escaped quote, reopen quote). A string pattern would
                    // only replace the first occurrence, hence the global regex.
                    const joinedCommand = taskRole.commands.join(" && ").replace(/'/g, "'\\''").trim();
                    const nniTrialCommand = `${environment.command} --node_count ${environment.nodeCount} --trial_command '${joinedCommand}'`;
                    this.log.debug(`replace command ${taskRole.commands} to ${[nniTrialCommand]}`);
                    taskRole.commands = [nniTrialCommand];
                }
            }

        } else {
            if (this.paiClusterConfig === undefined) {
                throw new Error('PAI Cluster config is not initialized');
            }
            if (this.paiClusterConfig.gpuNum === undefined) {
                throw new Error('PAI Cluster gpuNum is not initialized');
            }
            if (this.paiClusterConfig.cpuNum === undefined) {
                throw new Error('PAI Cluster cpuNum is not initialized');
            }
            if (this.paiClusterConfig.memoryMB === undefined) {
                throw new Error('PAI Cluster memoryMB is not initialized');
            }

            nniJobConfig = {
                protocolVersion: 2,
                name: jobName,
                type: 'job',
                jobRetryCount: 0,
                prerequisites: [
                    {
                        type: 'dockerimage',
                        uri: this.paiTrialConfig.image,
                        name: 'docker_image_0'
                    }
                ],
                taskRoles: {
                    taskrole: {
                        instances: 1,
                        completion: {
                            minFailedInstances: 1,
                            minSucceededInstances: -1
                        },
                        taskRetryCount: 0,
                        dockerImage: 'docker_image_0',
                        resourcePerInstance: {
                            gpu: this.paiClusterConfig.gpuNum,
                            cpu: this.paiClusterConfig.cpuNum,
                            memoryMB: this.paiClusterConfig.memoryMB
                        },
                        commands: [
                            environment.command
                        ]
                    }
                },
                extras: {
                    'storages': [
                        {
                            name: this.paiTrialConfig.paiStorageConfigName
                        }
                    ],
                    submitFrom: 'submit-job-v2'
                }
            };
            if (this.paiTrialConfig.virtualCluster) {
                nniJobConfig.defaults = {
                    virtualCluster: this.paiTrialConfig.virtualCluster
                };
            }
        }
        return yaml.safeDump(nniJobConfig);
    }

    /**
     * Strips an explicit 'http://' or 'https://' prefix from the host and records the
     * scheme in this.protocol. A host without a scheme is returned unchanged and
     * this.protocol is left as it was ('http' by default).
     */
    protected formatPAIHost(host: string): string {
        if (host.startsWith('http://')) {
            this.protocol = 'http';
            return host.replace('http://', '');
        } else if (host.startsWith('https://')) {
            this.protocol = 'https';
            return host.replace('https://', '');
        } else {
            return host;
        }
    }
}