kubeflowTrainingService.ts 25.1 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
3

4
'use strict';
5
6
7
8
9

import * as assert from 'assert';
import * as cpp from 'child-process-promise';
import * as fs from 'fs';
import * as path from 'path';
10
import * as component from '../../../common/component';
11
12
13

import { getExperimentId } from '../../../common/experimentStartupInfo';
import {
14
    NNIManagerIpConfig, TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
15
16
} from '../../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, uniqueString } from '../../../common/utils';
17
18
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../../common/containerJobData';
import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey';
19
import { validateCodeDir } from '../../common/util';
20
21
22
import { NFSConfig } from '../kubernetesConfig';
import { KubernetesTrialJobDetail } from '../kubernetesData';
import { KubernetesTrainingService } from '../kubernetesTrainingService';
23
import { KubeflowOperatorClientFactory } from './kubeflowApiClient';
24
25
26
import { KubeflowClusterConfig, KubeflowClusterConfigAzure, KubeflowClusterConfigFactory, KubeflowClusterConfigNFS,
    KubeflowTrialConfig, KubeflowTrialConfigFactory, KubeflowTrialConfigPytorch, KubeflowTrialConfigTensorflow
} from './kubeflowConfig';
27
import { KubeflowJobInfoCollector } from './kubeflowJobInfoCollector';
28
import { KubeflowJobRestServer } from './kubeflowJobRestServer';
29

30
// tslint:disable: no-unsafe-any no-any
31
32
33
34
35
36
37
38
/**
 * Training Service implementation for Kubeflow
 * Refer https://github.com/kubeflow/kubeflow for more info about Kubeflow
 */
@component.Singleton
class KubeflowTrainingService extends KubernetesTrainingService implements KubernetesTrainingService {
    private kubeflowClusterConfig?: KubeflowClusterConfig;
    private kubeflowTrialConfig?: KubeflowTrialConfig;
    private readonly kubeflowJobInfoCollector: KubeflowJobInfoCollector;

    constructor() {
        super();
        this.kubeflowJobInfoCollector = new KubeflowJobInfoCollector(this.trialJobsMap);
        this.experimentId = getExperimentId();
        this.log.info('Construct Kubeflow training service.');
    }

    /**
     * Start the Kubeflow job REST server, then poll trial status every 3 seconds until `stopping` is set.
     * Throws when the REST server is missing or reports an error message.
     */
    public async run(): Promise<void> {
        this.log.info('Run Kubeflow training service.');
        this.kubernetesJobRestServer = component.get(KubeflowJobRestServer);
        if (this.kubernetesJobRestServer === undefined) {
            throw new Error('kubernetesJobRestServer not initialized!');
        }
        await this.kubernetesJobRestServer.start();
        this.kubernetesJobRestServer.setEnableVersionCheck = this.versionCheck;
        this.log.info(`Kubeflow Training service rest server listening on: ${this.kubernetesJobRestServer.endPoint}`);
        while (!this.stopping) {
            // collect metrics for Kubeflow jobs by interacting with Kubernetes API server
            await delay(3000);
            await this.kubeflowJobInfoCollector.retrieveTrialStatus(this.kubernetesCRDClient);
            if (this.kubernetesJobRestServer.getErrorMessage !== undefined) {
                // Mark the service as stopping *before* throwing; a statement placed after `throw` never runs.
                this.stopping = true;
                throw new Error(this.kubernetesJobRestServer.getErrorMessage);
            }
        }
        this.log.info('Kubeflow training service exit.');
    }

    /**
     * Submit a trial.
     *
     * The method prepares the trial's scripts locally and uploads them together with codeDir to the
     * configured storage. It then creates the Kubeflow job through the CRD client.
     * @param form trial job application form
     * @returns detail of the submitted trial; its initial status is 'FAILED' when the upload returned
     *          an empty output URL, otherwise 'WAITING'
     */
    public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        if (this.kubernetesCRDClient === undefined) {
            throw new Error('Kubeflow job operator client is undefined');
        }

        if (this.kubernetesRestServerPort === undefined) {
            const restServer: KubeflowJobRestServer = component.get(KubeflowJobRestServer);
            this.kubernetesRestServerPort = restServer.clusterRestServerPort;
        }
        const trialJobId: string = uniqueString(5);
        const trialWorkingFolder: string = path.join(this.CONTAINER_MOUNT_PATH, 'nni', getExperimentId(), trialJobId);
        const kubeflowJobName: string = `nni-exp-${this.experimentId}-trial-${trialJobId}`.toLowerCase();
        const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
        // prepare the run scripts
        await this.prepareRunScript(trialLocalTempFolder, trialJobId, trialWorkingFolder, form);

        // upload files to storage
        const trialJobOutputUrl: string = await this.uploadCodeFiles(trialJobId, trialLocalTempFolder);
        const initStatus: TrialJobStatus = trialJobOutputUrl ? 'WAITING' : 'FAILED';

        const trialJobDetail: KubernetesTrialJobDetail = new KubernetesTrialJobDetail(
            trialJobId,
            initStatus,
            Date.now(),
            trialWorkingFolder,
            form,
            kubeflowJobName,
            trialJobOutputUrl
        );

        // Generate kubeflow job resource config object
        const kubeflowJobConfig: any = await this.prepareKubeflowConfig(trialJobId, trialWorkingFolder, kubeflowJobName);
        // Create kubeflow job based on generated kubeflow job resource config
        await this.kubernetesCRDClient.createKubernetesJob(kubeflowJobConfig);

        // Register the trial job detail only after the Kubeflow job was created successfully
        this.trialJobsMap.set(trialJobId, trialJobDetail);

        return trialJobDetail;
    }

    // tslint:disable:no-redundant-jsdoc
    /**
     * Apply one cluster metadata entry.
     *
     * KUBEFLOW_CLUSTER_CONFIG prepares the storage (Azure or NFS) and creates the operator client.
     * TRIAL_CONFIG requires the cluster config to be set first, and it validates codeDir.
     * @param key metadata key (a TrialConfigMetadataKey value); unknown keys are ignored
     * @param value metadata value; JSON text for the config keys
     */
    public async setClusterMetadata(key: string, value: string): Promise<void> {
        switch (key) {
            case TrialConfigMetadataKey.NNI_MANAGER_IP:
                this.nniManagerIpConfig = <NNIManagerIpConfig>JSON.parse(value);
                break;

            case TrialConfigMetadataKey.KUBEFLOW_CLUSTER_CONFIG: {
                const kubeflowClusterJsonObject: object = JSON.parse(value);
                this.kubeflowClusterConfig = KubeflowClusterConfigFactory.generateKubeflowClusterConfig(kubeflowClusterJsonObject);
                if (this.kubeflowClusterConfig.storageType === 'azureStorage') {
                    const azureKubeflowClusterConfig: KubeflowClusterConfigAzure = <KubeflowClusterConfigAzure>this.kubeflowClusterConfig;
                    this.azureStorageAccountName = azureKubeflowClusterConfig.azureStorage.accountName;
                    this.azureStorageShare = azureKubeflowClusterConfig.azureStorage.azureShare;
                    await this.createAzureStorage(
                        azureKubeflowClusterConfig.keyVault.vaultName,
                        azureKubeflowClusterConfig.keyVault.name
                    );
                } else if (this.kubeflowClusterConfig.storageType === 'nfs') {
                    const nfsKubeflowClusterConfig: KubeflowClusterConfigNFS = <KubeflowClusterConfigNFS>this.kubeflowClusterConfig;
                    await this.createNFSStorage(
                        nfsKubeflowClusterConfig.nfs.server,
                        nfsKubeflowClusterConfig.nfs.path
                    );
                }
                this.kubernetesCRDClient = KubeflowOperatorClientFactory.createClient(
                    this.kubeflowClusterConfig.operator, this.kubeflowClusterConfig.apiVersion);
                break;
            }

            case TrialConfigMetadataKey.TRIAL_CONFIG: {
                if (this.kubeflowClusterConfig === undefined) {
                    this.log.error('kubeflow cluster config is not initialized');

                    return Promise.reject(new Error('kubeflow cluster config is not initialized'));
                }

                const kubeflowTrialJsonObject: object = JSON.parse(value);
                this.kubeflowTrialConfig = KubeflowTrialConfigFactory.generateKubeflowTrialConfig(
                    kubeflowTrialJsonObject,
                    this.kubeflowClusterConfig.operator
                );

                // Validate to make sure codeDir doesn't have too many files
                try {
                    await validateCodeDir(this.kubeflowTrialConfig.codeDir);
                } catch (error) {
                    this.log.error(error);

                    // Propagate Error instances as-is; `new Error(error)` would stringify them into "Error: ...".
                    return Promise.reject(error instanceof Error ? error : new Error(String(error)));
                }
                break;
            }

            case TrialConfigMetadataKey.VERSION_CHECK:
                this.versionCheck = (value === 'true' || value === 'True');
                break;

            case TrialConfigMetadataKey.LOG_COLLECTION:
                this.logCollection = value;
                break;

            default:
                break;
        }
    }

    /**
     * Upload the trial's locally generated scripts and the codeDir to the configured storage
     * (NFS or Azure storage).
     * @param trialJobId trial job id
     * @param trialLocalTempFolder local folder that holds the generated run scripts
     * @returns the trial job output URL
     */
    private async uploadCodeFiles(trialJobId: string, trialLocalTempFolder: string): Promise<string> {
        if (this.kubeflowClusterConfig === undefined) {
            throw new Error('Kubeflow Cluster config is not initialized');
        }

        if (this.kubeflowTrialConfig === undefined) {
            throw new Error('Kubeflow Trial config is not initialized');
        }

        // Use `storageType`, like setClusterMetadata and generateReplicaConfig do, so the upload target
        // always matches the storage whose client/volume was configured.
        assert(this.kubeflowClusterConfig.storageType === 'azureStorage'
            || this.kubeflowClusterConfig.storageType === 'nfs');

        if (this.kubeflowClusterConfig.storageType === 'azureStorage') {
            if (this.azureStorageClient === undefined) {
                throw new Error('azureStorageClient is not initialized');
            }
            const azureKubeflowClusterConfig: KubeflowClusterConfigAzure = <KubeflowClusterConfigAzure>this.kubeflowClusterConfig;

            return this.uploadFilesToAzureStorage(trialJobId, trialLocalTempFolder, this.kubeflowTrialConfig.codeDir,
                                                  azureKubeflowClusterConfig.uploadRetryCount);
        }

        const nfsKubeflowClusterConfig: KubeflowClusterConfigNFS = <KubeflowClusterConfigNFS>this.kubeflowClusterConfig;
        const trialNfsFolder: string = `${this.trialLocalNFSTempFolder}/nni/${getExperimentId()}/${trialJobId}`;
        // Create work dir for current trial in NFS directory
        await cpp.exec(`mkdir -p ${trialNfsFolder}`);
        // Copy script files from local dir to NFS mounted dir
        await cpp.exec(`cp -r ${trialLocalTempFolder}/* ${trialNfsFolder}/.`);
        // Copy codeDir to NFS mounted dir
        await cpp.exec(`cp -r ${this.kubeflowTrialConfig.codeDir}/* ${trialNfsFolder}/.`);
        const nfsConfig: NFSConfig = nfsKubeflowClusterConfig.nfs;

        return `nfs://${nfsConfig.server}:${path.join(nfsConfig.path, 'nni', getExperimentId(), trialJobId, 'output')}`;
    }

    /**
     * Write the trial's scripts into a local temp folder.
     *
     * The files written are: install_nni.sh; run_worker.sh when a worker is configured; run_ps.sh
     * (tf-operator with ps) or run_master.sh (pytorch-operator with master); and the hyper-parameter file.
     * @param trialLocalTempFolder local folder to write into (created if missing)
     * @param trialJobId trial job id
     * @param trialWorkingFolder trial working folder as seen inside the container
     * @param form trial job application form
     */
    private async prepareRunScript(trialLocalTempFolder: string, trialJobId: string, trialWorkingFolder: string,
                                   form: TrialJobApplicationForm): Promise<void> {
        if (this.kubeflowClusterConfig === undefined) {
            throw new Error('Kubeflow Cluster config is not initialized');
        }

        if (this.kubeflowTrialConfig === undefined) {
            throw new Error('Kubeflow trial config is not initialized');
        }

        // initialize kubeflow trial config to specific type
        let kubeflowTrialConfig: any;
        if (this.kubeflowClusterConfig.operator === 'tf-operator') {
            kubeflowTrialConfig = <KubeflowTrialConfigTensorflow>this.kubeflowTrialConfig;
        } else if (this.kubeflowClusterConfig.operator === 'pytorch-operator') {
            kubeflowTrialConfig = <KubeflowTrialConfigPytorch>this.kubeflowTrialConfig;
        } else {
            throw new Error(`operator ${this.kubeflowClusterConfig.operator} is invalid`);
        }

        // create tmp trial working folder locally.
        await cpp.exec(`mkdir -p ${trialLocalTempFolder}`);
        const runScriptContent: string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
        // Write NNI installation file to local tmp files
        await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });

        // Write worker file content run_worker.sh to local tmp folders
        if (kubeflowTrialConfig.worker !== undefined) {
            const workerRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder,
                                                                                kubeflowTrialConfig.worker.command,
                                                                                form.sequenceId.toString(), 'worker',
                                                                                kubeflowTrialConfig.worker.gpuNum);
            await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_worker.sh'), workerRunScriptContent, { encoding: 'utf8' });
        }
        // Write parameter server file content run_ps.sh (tf) or run_master.sh (pytorch) to local tmp folders
        if (this.kubeflowClusterConfig.operator === 'tf-operator') {
            const tensorflowTrialConfig: KubeflowTrialConfigTensorflow = <KubeflowTrialConfigTensorflow>this.kubeflowTrialConfig;
            if (tensorflowTrialConfig.ps !== undefined) {
                const psRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder,
                                                                                tensorflowTrialConfig.ps.command,
                                                                                form.sequenceId.toString(),
                                                                                'ps', tensorflowTrialConfig.ps.gpuNum);
                await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_ps.sh'), psRunScriptContent, { encoding: 'utf8' });
            }
        } else if (this.kubeflowClusterConfig.operator === 'pytorch-operator') {
            const pytorchTrialConfig: KubeflowTrialConfigPytorch = <KubeflowTrialConfigPytorch>this.kubeflowTrialConfig;
            if (pytorchTrialConfig.master !== undefined) {
                const masterRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder,
                                                                                    pytorchTrialConfig.master.command,
                                                                                    form.sequenceId.toString(), 'master',
                                                                                    pytorchTrialConfig.master.gpuNum);
                await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_master.sh'), masterRunScriptContent, { encoding: 'utf8' });
            }
        }
        // Write file content ( parameter.cfg ) to local tmp folders
        if (form !== undefined) {
            await fs.promises.writeFile(path.join(trialLocalTempFolder, generateParamFileName(form.hyperParameters)),
                                        form.hyperParameters.value, { encoding: 'utf8' });
        }
    }

    /**
     * Build the pod resource sections for worker and non-worker (ps/master) replicas, then generate
     * the Kubeflow job resource config object.
     * @param trialJobId trial job id
     * @param trialWorkingFolder trial working folder as seen inside the container
     * @param kubeflowJobName Kubernetes name of the Kubeflow job
     */
    private async prepareKubeflowConfig(trialJobId: string, trialWorkingFolder: string, kubeflowJobName: string): Promise<any> {
        if (this.kubeflowClusterConfig === undefined) {
            throw new Error('Kubeflow Cluster config is not initialized');
        }

        if (this.kubeflowTrialConfig === undefined) {
            throw new Error('Kubeflow trial config is not initialized');
        }

        // initialize kubeflow trial config to specific type
        let kubeflowTrialConfig: any;
        if (this.kubeflowClusterConfig.operator === 'tf-operator') {
            kubeflowTrialConfig = <KubeflowTrialConfigTensorflow>this.kubeflowTrialConfig;
        } else if (this.kubeflowClusterConfig.operator === 'pytorch-operator') {
            kubeflowTrialConfig = <KubeflowTrialConfigPytorch>this.kubeflowTrialConfig;
        } else {
            throw new Error(`operator ${this.kubeflowClusterConfig.operator} is invalid`);
        }

        const workerPodResources: any = {};
        if (kubeflowTrialConfig.worker !== undefined) {
            workerPodResources.requests = this.generatePodResource(kubeflowTrialConfig.worker.memoryMB, kubeflowTrialConfig.worker.cpuNum,
                                                                   kubeflowTrialConfig.worker.gpuNum);
        }
        // limits mirror requests (an empty object when no worker is configured)
        workerPodResources.limits = {...workerPodResources.requests};

        const nonWorkerResources: any = {};
        if (this.kubeflowClusterConfig.operator === 'tf-operator') {
            const tensorflowTrialConfig: KubeflowTrialConfigTensorflow = <KubeflowTrialConfigTensorflow>this.kubeflowTrialConfig;
            if (tensorflowTrialConfig.ps !== undefined) {
                nonWorkerResources.requests = this.generatePodResource(tensorflowTrialConfig.ps.memoryMB, tensorflowTrialConfig.ps.cpuNum,
                                                                       tensorflowTrialConfig.ps.gpuNum);
                nonWorkerResources.limits = {...nonWorkerResources.requests};
            }
        } else if (this.kubeflowClusterConfig.operator === 'pytorch-operator') {
            const pyTorchTrialConfig: KubeflowTrialConfigPytorch = <KubeflowTrialConfigPytorch>this.kubeflowTrialConfig;
            nonWorkerResources.requests = this.generatePodResource(pyTorchTrialConfig.master.memoryMB, pyTorchTrialConfig.master.cpuNum,
                                                                   pyTorchTrialConfig.master.gpuNum);
            nonWorkerResources.limits = {...nonWorkerResources.requests};
        }

        // Generate kubeflow job resource config object
        return this.generateKubeflowJobConfig(trialJobId, trialWorkingFolder, kubeflowJobName, workerPodResources,
                                              nonWorkerResources);
    }

    /**
     * Generate kubeflow resource config file
     * @param trialJobId trial job id
     * @param trialWorkingFolder working folder
     * @param kubeflowJobName job name
     * @param workerPodResources worker pod template
     * @param nonWorkerPodResources non-worker pod template, like ps or master
     */
    private async generateKubeflowJobConfig(trialJobId: string, trialWorkingFolder: string, kubeflowJobName: string, workerPodResources: any,
                                            nonWorkerPodResources?: any): Promise<any> {
        if (this.kubeflowClusterConfig === undefined) {
            throw new Error('Kubeflow Cluster config is not initialized');
        }

        if (this.kubeflowTrialConfig === undefined) {
            throw new Error('Kubeflow trial config is not initialized');
        }

        if (this.kubernetesCRDClient === undefined) {
            throw new Error('Kubeflow operator client is not initialized');
        }

        const replicaSpecsObj: any = {};
        // Job spec section; stays undefined for an unrecognized operator type
        let jobSpec: object | undefined;
        if (this.kubeflowTrialConfig.operatorType === 'tf-operator') {
            const tensorflowTrialConfig: KubeflowTrialConfigTensorflow = <KubeflowTrialConfigTensorflow>this.kubeflowTrialConfig;
            const workerSecretName: string | undefined = await this.createRegistrySecret(tensorflowTrialConfig.worker.privateRegistryAuthPath);
            replicaSpecsObj.Worker = this.generateReplicaConfig(trialWorkingFolder, tensorflowTrialConfig.worker.replicas,
                                                                tensorflowTrialConfig.worker.image, 'run_worker.sh', workerPodResources, workerSecretName);
            if (tensorflowTrialConfig.ps !== undefined) {
                const psSecretName: string | undefined = await this.createRegistrySecret(tensorflowTrialConfig.ps.privateRegistryAuthPath);
                replicaSpecsObj.Ps = this.generateReplicaConfig(trialWorkingFolder, tensorflowTrialConfig.ps.replicas,
                                                                tensorflowTrialConfig.ps.image, 'run_ps.sh', nonWorkerPodResources, psSecretName);
            }
            jobSpec = {tfReplicaSpecs: replicaSpecsObj};
        } else if (this.kubeflowTrialConfig.operatorType === 'pytorch-operator') {
            const pytorchTrialConfig: KubeflowTrialConfigPytorch = <KubeflowTrialConfigPytorch>this.kubeflowTrialConfig;
            if (pytorchTrialConfig.worker !== undefined) {
                const workerSecretName: string | undefined = await this.createRegistrySecret(pytorchTrialConfig.worker.privateRegistryAuthPath);
                replicaSpecsObj.Worker = this.generateReplicaConfig(trialWorkingFolder, pytorchTrialConfig.worker.replicas,
                                                                    pytorchTrialConfig.worker.image, 'run_worker.sh', workerPodResources, workerSecretName);
            }
            const masterSecretName: string | undefined = await this.createRegistrySecret(pytorchTrialConfig.master.privateRegistryAuthPath);
            replicaSpecsObj.Master = this.generateReplicaConfig(trialWorkingFolder, pytorchTrialConfig.master.replicas,
                                                                pytorchTrialConfig.master.image, 'run_master.sh', nonWorkerPodResources, masterSecretName);
            jobSpec = {pytorchReplicaSpecs: replicaSpecsObj};
        }

        return {
            apiVersion: `kubeflow.org/${this.kubernetesCRDClient.apiVersion}`,
            kind: this.kubernetesCRDClient.jobKind,
            metadata: {
                name: kubeflowJobName,
                namespace: 'default',
                labels: {
                    app: this.NNI_KUBERNETES_TRIAL_LABEL,
                    expId: getExperimentId(),
                    trialId: trialJobId
                }
            },
            spec: jobSpec
        };
    }

    /**
     * Generate one replica config section (Worker, Ps or Master) of a Kubeflow job spec.
     * @param trialWorkingFolder trial working folder
     * @param replicaNumber replica number
     * @param replicaImage image
     * @param runScriptFile script file name, resolved under trialWorkingFolder
     * @param podResources pod resource config section
     * @param privateRegistrySecretName image pull secret name; no imagePullSecrets entry is added when falsy
     */
    private generateReplicaConfig(trialWorkingFolder: string, replicaNumber: number, replicaImage: string, runScriptFile: string,
                                  podResources: any, privateRegistrySecretName: string | undefined): any {
        if (this.kubeflowClusterConfig === undefined) {
            throw new Error('Kubeflow Cluster config is not initialized');
        }

        if (this.kubeflowTrialConfig === undefined) {
            throw new Error('Kubeflow trial config is not initialized');
        }

        if (this.kubernetesCRDClient === undefined) {
            throw new Error('Kubeflow operator client is not initialized');
        }

        // The config spec for volume field
        let volumes: object[];
        if (this.kubeflowClusterConfig.storageType === 'azureStorage') {
            volumes = [{
                name: 'nni-vol',
                azureFile: {
                    secretName: `${this.azureStorageSecretName}`,
                    shareName: `${this.azureStorageShare}`,
                    // Kubernetes' AzureFileVolumeSource field is camel-cased `readOnly`
                    readOnly: false
                }
            }];
        } else {
            const nfsKubeflowClusterConfig: KubeflowClusterConfigNFS = <KubeflowClusterConfigNFS>this.kubeflowClusterConfig;
            volumes = [{
                name: 'nni-vol',
                nfs: {
                    server: `${nfsKubeflowClusterConfig.nfs.server}`,
                    path: `${nfsKubeflowClusterConfig.nfs.path}`
                }
            }];
        }

        // The config spec for container field
        const containers: object[] = [{
            // The container name comes from the operator client, because operators such as tf-operator
            // require a specific container name
            name: this.kubernetesCRDClient.containerName,
            image: replicaImage,
            args: ['sh', `${path.join(trialWorkingFolder, runScriptFile)}`],
            volumeMounts: [{
                name: 'nni-vol',
                mountPath: this.CONTAINER_MOUNT_PATH
            }],
            resources: podResources
        }];

        const spec: any = {
            containers: containers,
            restartPolicy: 'ExitCode',
            volumes: volumes
        };
        if (privateRegistrySecretName) {
            spec.imagePullSecrets = [{ name: privateRegistrySecretName }];
        }

        return {
            replicas: replicaNumber,
            template: {
                metadata: {
                    // tslint:disable-next-line:no-null-keyword
                    creationTimestamp: null
                },
                spec: spec
            }
        };
    }
}
472
473
// tslint:enable: no-unsafe-any no-any
export { KubeflowTrainingService };