openPaiEnvironmentService.ts 14.7 KB
Newer Older
1
2
3
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

4
5
import yaml from 'js-yaml';
import request from 'request';
SparkSnail's avatar
SparkSnail committed
6
import { Container, Scope } from 'typescript-ioc';
7
import { Deferred } from 'ts-deferred';
8
9
10
11
12
13
import * as component from 'common/component';
import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from 'common/experimentConfig';
import { ExperimentStartupInfo } from 'common/experimentStartupInfo';
import { getLogger, Logger } from 'common/log';
import { PAIClusterConfig } from 'training_service/pai/paiConfig';
import { NNIPAITrialConfig } from 'training_service/pai/paiConfig';
14
import { EnvironmentInformation, EnvironmentService } from '../environment';
15
import { SharedStorageService } from '../sharedStorage';
16
import { MountedStorageService } from '../storages/mountedStorageService';
SparkSnail's avatar
SparkSnail committed
17
import { StorageService } from '../storageService';
18

19
/**
 * Single object type that carries both the general experiment settings
 * (ExperimentConfig) and the OpenPAI-specific settings (OpenpaiConfig).
 */
interface FlattenOpenpaiConfig extends ExperimentConfig, OpenpaiConfig { }
20
21
22
23
24
25
26

/**
 * Collects PAI job information from the PAI cluster and updates PAI job status locally.
 */
@component.Singleton
export class OpenPaiEnvironmentService extends EnvironmentService {

liuzhe-lz's avatar
liuzhe-lz committed
27
    /** Logger tagged with this service's name. */
    private readonly log: Logger = getLogger('OpenPaiEnvironmentService');

    // NOTE(review): the next two fields are not read or written anywhere in this
    // class — presumably legacy; confirm before removing.
    private paiClusterConfig: PAIClusterConfig | undefined;
    private paiTrialConfig: NNIPAITrialConfig | undefined;

    /** Token sent as `Authorization: Bearer <token>` on every REST request. */
    private paiToken: string;
    /** 'http' or 'https', derived from the configured host. */
    private protocol: string;
    /** Identifier of the running experiment, taken from the startup info. */
    private experimentId: string;
    /** Flattened experiment + OpenPAI configuration. */
    private config: FlattenOpenpaiConfig;
34

35
    /**
     * Builds the service from the experiment configuration.
     *
     * Flattens the configuration for the 'openpai' platform, records the token and
     * the protocol implied by the host, binds StorageService to the mounted-storage
     * implementation as a singleton, and initializes that storage with the local
     * mount point and a per-experiment sub-directory.
     *
     * @param config experiment configuration to flatten
     * @param info startup information supplying the experiment id
     */
    constructor(config: ExperimentConfig, info: ExperimentStartupInfo) {
        super();

        this.experimentId = info.experimentId;
        this.config = flattenConfig(config, 'openpai');
        const settings = this.config;

        this.paiToken = settings.token;
        const isHttps = settings.host.toLowerCase().startsWith('https://');
        this.protocol = isHttps ? 'https' : 'http';

        // Register the mounted-storage implementation before resolving the service.
        Container.bind(StorageService).to(MountedStorageService).scope(Scope.Singleton);
        const storage = component.get<StorageService>(StorageService);

        const mountPoint = settings.localStorageMountPoint;
        const experimentRoot = storage.joinPath(mountPoint, this.experimentId);
        storage.initialize(mountPoint, experimentRoot);
    }

49
50
51
52
    /**
     * Interval used for the environment maintenance loop.
     * Always returns the constant 5000.
     * NOTE(review): presumably milliseconds — confirm against the caller.
     */
    public get environmentMaintenceLoopInterval(): number {
        return 5000;
    }

53
54
55
56
    /**
     * Whether this environment service comes with a storage service.
     * Always true here: the constructor binds and initializes a StorageService.
     */
    public get hasStorageService(): boolean {
        return true;
    }

57
58
59
60
    /**
     * Name of this environment service; always the string 'pai'.
     */
    public get getName(): string {
        return 'pai';
    }

61
62
63
64
65
66
67
68
    /**
     * Refreshes the status of the given environments from the OpenPAI job list.
     *
     * Sends one GET to `/rest-server/api/v2/jobs?username=<user>` and matches every
     * environment to the job whose `name` equals the environment's `envId`.
     * Job state mapping:
     *   - RUNNING, WAITING, SUCCEEDED, FAILED: applied unchanged through setStatus
     *   - STOPPED, STOPPING: USER_CANCELED
     *   - anything else: UNKNOWN (logged as an error)
     * A matched job without a state marks the environment FAILED; an environment
     * without a matching job is marked UNKNOWN.
     *
     * @param environments environments to update in place
     * @returns a promise that resolves once every environment has been processed.
     *          It rejects with a string message when the request fails, the server
     *          answers with an HTTP code >= 400, the body is not a job array, or at
     *          least one job is FAILED (all environments are still updated in that
     *          case; the first failed job is reported).
     * @throws Error if the PAI token is not initialized (surfaces as a rejection).
     */
    public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
        const deferred: Deferred<void> = new Deferred<void>();

        if (this.paiToken === undefined) {
            throw new Error('PAI token is not initialized');
        }

        const getJobInfoRequest: request.Options = {
            uri: `${this.config.host}/rest-server/api/v2/jobs?username=${this.config.username}`,
            method: 'GET',
            json: true,
            headers: {
                'Content-Type': 'application/json',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        request(getJobInfoRequest, (error: any, response: request.Response, body: any) => {
            // Anything thrown in here would otherwise leave the deferred pending forever.
            try {
                const hasError = error !== undefined && error !== null;
                // Status code 200 for success
                if (hasError || response.statusCode >= 400) {
                    const errorMessage: string = hasError ? error.message :
                        `OpenPAI: get environment list from PAI Cluster failed!, http code:${response.statusCode}, http body:' ${JSON.stringify(body)}`;
                    this.log.error(`${errorMessage}`);
                    deferred.reject(errorMessage);
                    return;
                }

                if (!Array.isArray(body)) {
                    const errorMessage = `OpenPAI: unexpected job list response, http code:${response.statusCode}, http body: ${JSON.stringify(body)}`;
                    this.log.error(errorMessage);
                    deferred.reject(errorMessage);
                    return;
                }

                // Index the returned jobs by name for constant-time lookup per environment.
                const jobInfos = new Map<string, any>();
                body.forEach((jobInfo: any) => {
                    jobInfos.set(jobInfo.name, jobInfo);
                });

                // Only the first failed job is reported; the remaining environments are still updated.
                let firstFailure: string | undefined;

                environments.forEach((environment) => {
                    if (!jobInfos.has(environment.envId)) {
                        this.log.error(`OpenPAI job ${environment.envId} is not found in job list.`);
                        environment.status = 'UNKNOWN';
                        return;
                    }

                    const jobResponse = jobInfos.get(environment.envId);
                    if (!(jobResponse && jobResponse.state)) {
                        this.log.error(`OpenPAI: job ${environment.envId} has no state returned. body:`, jobResponse);
                        // some error happens, and mark this environment
                        environment.status = 'FAILED';
                        return;
                    }

                    const oldEnvironmentStatus = environment.status;
                    switch (jobResponse.state) {
                        case 'RUNNING':
                        case 'WAITING':
                        case 'SUCCEEDED':
                            environment.setStatus(jobResponse.state);
                            break;
                        case 'FAILED':
                            environment.setStatus(jobResponse.state);
                            if (firstFailure === undefined) {
                                firstFailure = `OpenPAI: job ${environment.envId} is failed!`;
                            }
                            break;
                        case 'STOPPED':
                        case 'STOPPING':
                            environment.setStatus('USER_CANCELED');
                            break;
                        default:
                            this.log.error(`OpenPAI: job ${environment.envId} returns unknown state ${jobResponse.state}.`);
                            environment.setStatus('UNKNOWN');
                    }
                    if (oldEnvironmentStatus !== environment.status) {
                        this.log.debug(`OpenPAI: job ${environment.envId} change status ${oldEnvironmentStatus} to ${environment.status} due to job is ${jobResponse.state}.`);
                    }
                });

                if (firstFailure !== undefined) {
                    deferred.reject(firstFailure);
                } else {
                    deferred.resolve();
                }
            } catch (refreshError) {
                this.log.error(`OpenPAI: error when refreshing environments status ${refreshError}`);
                deferred.reject(refreshError);
            }
        });
        return deferred.promise;
    }

    /**
     * Submits one environment to OpenPAI as a job.
     *
     * Steps:
     *   1. Prefix the environment command with a `cd` into the environment root
     *      (and with the shared-storage mount command when shared storage is used),
     *      then fill in the runner working folder, tracking URL and GPU settings.
     *   2. Generate the job configuration as YAML.
     *   3. POST the YAML to `/rest-server/api/v2/jobs`.
     *
     * @param environment environment to start; its command, runnerWorkingFolder,
     *                    trackingUrl, useActiveGpu and maxTrialNumberPerGpu are
     *                    overwritten, and its status is set to FAILED when the
     *                    submission fails
     * @returns a promise that resolves when the server accepts the job and rejects
     *          with a string message when the request fails or the HTTP code is >= 400
     * @throws Error if the PAI token is not initialized (surfaces as a rejection).
     */
    public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
        const deferred: Deferred<void> = new Deferred<void>();

        if (this.paiToken === undefined) {
            throw new Error('PAI token is not initialized');
        }

        // Step 1. Prepare PAI job configuration
        let environmentRoot: string;
        if (environment.useSharedStorage) {
            const sharedStorage = component.get<SharedStorageService>(SharedStorageService);
            environmentRoot = sharedStorage.remoteWorkingRoot;
            // Collapse any existing `echo -e ` to `echo ` first, then turn every `echo `
            // into `echo -e `, so each echo ends up with exactly one `-e` flag.
            const mountCommand = sharedStorage.remoteMountCommand.replace(/echo -e /g, `echo `).replace(/echo /g, `echo -e `);
            environment.command = `${mountCommand} && cd ${environmentRoot} && ${environment.command}`;
        } else {
            environmentRoot = `${this.config.containerStorageMountPoint}/${this.experimentId}`;
            environment.command = `cd ${environmentRoot} && ${environment.command}`;
        }
        environment.runnerWorkingFolder = `${environmentRoot}/envs/${environment.id}`;
        environment.trackingUrl = `${this.config.host}/job-detail.html?username=${this.config.username}&jobName=${environment.envId}`;
        environment.useActiveGpu = false;  // does openpai supports these?
        environment.maxTrialNumberPerGpu = 1;

        // Step 2. Generate Job Configuration in yaml format
        const paiJobConfig = this.generateJobConfigInYamlFormat(environment);
        this.log.debug(`generated paiJobConfig: ${paiJobConfig}`);

        // Step 3. Submit PAI job via Rest call
        const submitJobRequest: request.Options = {
            uri: `${this.config.host}/rest-server/api/v2/jobs`,
            method: 'POST',
            body: paiJobConfig,
            followAllRedirects: true,
            headers: {
                'Content-Type': 'text/yaml',
                Authorization: `Bearer ${this.paiToken}`
            }
        };
        request(submitJobRequest, (error, response, body) => {
            // Anything thrown in here would otherwise leave the deferred pending forever.
            try {
                const hasError = error !== undefined && error !== null;
                // Status code 202 for success, refer https://github.com/microsoft/pai/blob/master/src/rest-server/docs/swagger.yaml
                if (hasError || response.statusCode >= 400) {
                    const errorMessage: string = hasError ? error.message :
                        `start environment ${environment.envId} failed, http code:${response.statusCode}, http body: ${body}`;

                    this.log.error(errorMessage);
                    environment.status = 'FAILED';
                    deferred.reject(errorMessage);
                    return;
                }
                deferred.resolve();
            } catch (submitError) {
                this.log.error(`OpenPAI error when starting environment ${submitError}`);
                deferred.reject(submitError);
            }
        });

        return deferred.promise;
    }

    /**
     * Asks OpenPAI to stop the job that backs the given environment.
     *
     * Sends `PUT /rest-server/api/v2/jobs/<user>~<envId>/executionType` with the body
     * `{ value: 'STOP' }`. Does nothing when the environment is already not alive.
     *
     * @param environment environment whose job should be stopped
     * @returns a promise that resolves when the stop request is accepted (or the
     *          environment is not alive) and rejects with the request error, or with
     *          a string message when the HTTP code is >= 400
     * @throws Error if the PAI token is not initialized (surfaces as a rejection).
     */
    public async stopEnvironment(environment: EnvironmentInformation): Promise<void> {
        if (environment.isAlive === false) {
            return;
        }
        if (this.paiToken === undefined) {
            throw new Error('PAI token is not initialized');
        }

        const deferred: Deferred<void> = new Deferred<void>();

        const stopJobRequest: request.Options = {
            uri: `${this.config.host}/rest-server/api/v2/jobs/${this.config.username}~${environment.envId}/executionType`,
            method: 'PUT',
            json: true,
            body: { value: 'STOP' },
            time: true,
            headers: {
                'Content-Type': 'application/json',
                Authorization: `Bearer ${this.paiToken}`
            }
        };

        this.log.debug(`stopping OpenPAI environment ${environment.envId}, ${stopJobRequest.uri}`);

        try {
            request(stopJobRequest, (error, response, _body) => {
                try {
                    const hasError = error !== undefined && error !== null;
                    // Status code 202 for success.
                    if (hasError || (response && response.statusCode >= 400)) {
                        const errorMessage: string = hasError ? error.message :
                            `OpenPAI: stop job ${environment.envId} failed, http code:${response.statusCode}, http body: ${_body}`;
                        this.log.error(`${errorMessage}`);
                        deferred.reject(hasError ? error :
                            `Stop trial failed, http code: ${response.statusCode}`);
                        return;
                    }
                    this.log.info(`OpenPAI job ${environment.envId} stopped.`);
                    deferred.resolve();
                } catch (innerError) {
                    this.log.error(`OpenPAI error when inner stopping environment ${innerError}`);
                    deferred.reject(innerError);
                }
            });
        } catch (outerError) {
            // request() itself threw before any callback could run.
            this.log.error(`OpenPAI error when stopping environment ${outerError}`);
            throw outerError;
        }

        return deferred.promise;
    }

    /**
     * Builds the OpenPAI job configuration for an environment and serializes it to YAML.
     *
     * When the user supplied `openpaiConfig`, a deep copy of it is used: the job name
     * is set to the environment's `envId`, `environment.nodeCount` is set to the sum of
     * the task roles' `instances` (1 when absent), and every task role's commands are
     * replaced by a single NNI command that wraps the original commands as
     * `--trial_command '<joined commands>'`.
     * Otherwise a default single-task-role job is generated from the experiment
     * configuration (docker image, GPU/CPU/memory, storage config name and, when
     * configured, the virtual cluster).
     *
     * @param environment environment to generate the job for; `nodeCount` is
     *                    overwritten when the user config defines task roles
     * @returns the job configuration as a YAML document
     */
    private generateJobConfigInYamlFormat(environment: EnvironmentInformation): string {
        const jobName = environment.envId;

        let nniJobConfig: any = undefined;
        if (this.config.openpaiConfig !== undefined) {
            // Deep clone so the user's configuration object is never mutated.
            nniJobConfig = JSON.parse(JSON.stringify(this.config.openpaiConfig));
            nniJobConfig.name = jobName;
            if (nniJobConfig.taskRoles) {

                environment.nodeCount = 0;
                // count instance
                for (const taskRoleName in nniJobConfig.taskRoles) {
                    const taskRole = nniJobConfig.taskRoles[taskRoleName];
                    let instanceCount = 1;
                    if (taskRole.instances) {
                        instanceCount = taskRole.instances;
                    }
                    environment.nodeCount += instanceCount;
                }

                // Each taskRole will generate new command in NNI's command format
                // Each command will be formatted to NNI style
                for (const taskRoleName in nniJobConfig.taskRoles) {
                    const taskRole = nniJobConfig.taskRoles[taskRoleName];
                    // Escape EVERY single quote as '\'' so the joined command survives being
                    // wrapped in single quotes below (a string pattern would only replace the first one).
                    const joinedCommand = taskRole.commands.join(" && ").replace(/'/g, "'\\''").trim();
                    const nniTrialCommand = `${environment.command} --node_count ${environment.nodeCount} --trial_command '${joinedCommand}'`;
                    this.log.debug(`replace command ${taskRole.commands} to ${[nniTrialCommand]}`);
                    taskRole.commands = [nniTrialCommand];
                }
            }

        } else {
            nniJobConfig = {
                protocolVersion: 2,
                name: jobName,
                type: 'job',
                jobRetryCount: 0,
                prerequisites: [
                    {
                        type: 'dockerimage',
                        uri: this.config.dockerImage,
                        name: 'docker_image_0'
                    }
                ],
                taskRoles: {
                    taskrole: {
                        instances: 1,
                        completion: {
                            minFailedInstances: 1,
                            minSucceededInstances: -1
                        },
                        taskRetryCount: 0,
                        dockerImage: 'docker_image_0',
                        resourcePerInstance: {
                            gpu: this.config.trialGpuNumber === undefined ? 0 : this.config.trialGpuNumber,
                            cpu: this.config.trialCpuNumber,
                            memoryMB: toMegaBytes(this.config.trialMemorySize)
                        },
                        commands: [
                            environment.command
                        ]
                    }
                },
                extras: {
                    'storages': [
                        {
                            name: this.config.storageConfigName
                        }
                    ],
                    submitFrom: 'submit-job-v2'
                }
            };
            if (this.config.virtualCluster) {
                nniJobConfig.defaults = {
                    virtualCluster: this.config.virtualCluster
                };
            }
        }
        return yaml.dump(nniJobConfig);
    }

    /**
     * Strips a leading 'http://' or 'https://' from the host and records the matching
     * protocol in `this.protocol`. A host without either prefix is returned unchanged
     * and the current protocol is left as it is.
     *
     * @param host host string as configured by the user
     * @returns the host without its scheme prefix
     */
    protected formatPAIHost(host: string): string {
        for (const scheme of ['http', 'https']) {
            const prefix = `${scheme}://`;
            if (host.startsWith(prefix)) {
                this.protocol = scheme;
                return host.replace(prefix, '');
            }
        }
        return host;
    }
}