kubeflowTrainingService.ts 24.9 KB
Newer Older
liuzhe-lz's avatar
liuzhe-lz committed
1
2
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
3

4
'use strict';
5
6
7
8
9

import * as assert from 'assert';
import * as cpp from 'child-process-promise';
import * as fs from 'fs';
import * as path from 'path';
10
import * as component from '../../../common/component';
11
12
13

import { getExperimentId } from '../../../common/experimentStartupInfo';
import {
14
    NNIManagerIpConfig, TrialJobApplicationForm, TrialJobDetail, TrialJobStatus
15
16
} from '../../../common/trainingService';
import { delay, generateParamFileName, getExperimentRootDir, uniqueString } from '../../../common/utils';
17
18
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../../common/containerJobData';
import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey';
19
import { validateCodeDir } from '../../common/util';
20
21
22
import { NFSConfig } from '../kubernetesConfig';
import { KubernetesTrialJobDetail } from '../kubernetesData';
import { KubernetesTrainingService } from '../kubernetesTrainingService';
23
import { KubeflowOperatorClientFactory } from './kubeflowApiClient';
24
25
26
import { KubeflowClusterConfig, KubeflowClusterConfigAzure, KubeflowClusterConfigFactory, KubeflowClusterConfigNFS,
    KubeflowTrialConfig, KubeflowTrialConfigFactory, KubeflowTrialConfigPytorch, KubeflowTrialConfigTensorflow
} from './kubeflowConfig';
27
import { KubeflowJobInfoCollector } from './kubeflowJobInfoCollector';
28
import { KubeflowJobRestServer } from './kubeflowJobRestServer';
29
30
31
32
33
34
35
36
37

/**
 * Training Service implementation for Kubeflow
 * Refer to https://github.com/kubeflow/kubeflow for more information about Kubeflow.
 */
@component.Singleton
class KubeflowTrainingService extends KubernetesTrainingService implements KubernetesTrainingService {
    private kubeflowClusterConfig?: KubeflowClusterConfig;
    private kubeflowTrialConfig?: KubeflowTrialConfig;
38
    private readonly kubeflowJobInfoCollector: KubeflowJobInfoCollector;
39
40

    /**
     * Create the Kubeflow training service.
     *
     * Records the current experiment id and builds the job info collector,
     * handing it this service's `trialJobsMap`.
     */
    constructor() {
        super();
        this.experimentId = getExperimentId();
        this.kubeflowJobInfoCollector = new KubeflowJobInfoCollector(this.trialJobsMap);
        this.log.info('Construct Kubeflow training service.');
    }

    /**
     * Main loop of the Kubeflow training service.
     *
     * Resolves and starts the Kubeflow job REST server, forwards the version-check
     * setting to it, then repeatedly (with a 3 second delay between rounds) asks the
     * job info collector to refresh trial status until `this.stopping` is set.
     *
     * @throws Error when the REST server component cannot be resolved, or when the
     *         REST server exposes an error message while the loop is running.
     */
    public async run(): Promise<void> {
        this.log.info('Run Kubeflow training service.');
        this.kubernetesJobRestServer = component.get(KubeflowJobRestServer);
        if (this.kubernetesJobRestServer === undefined) {
            throw new Error('kubernetesJobRestServer not initialized!');
        }
        await this.kubernetesJobRestServer.start();
        this.kubernetesJobRestServer.setEnableVersionCheck = this.versionCheck;
        this.log.info(`Kubeflow Training service rest server listening on: ${this.kubernetesJobRestServer.endPoint}`);
        while (!this.stopping) {
            // collect metrics for Kubeflow jobs by interacting with Kubernetes API server
            await delay(3000);
            await this.kubeflowJobInfoCollector.retrieveTrialStatus(this.kubernetesCRDClient);
            if (this.kubernetesJobRestServer.getErrorMessage !== undefined) {
                // Mark the service as stopping BEFORE throwing: the original code placed
                // this assignment after the throw, where it could never execute.
                this.stopping = true;
                throw new Error(this.kubernetesJobRestServer.getErrorMessage);
            }
        }
        this.log.info('Kubeflow training service exit.');
    }

68
    /**
     * Submit a single trial as a Kubeflow job.
     *
     * Steps: generate a trial id, stage the run scripts in a local temp folder,
     * upload them together with the code directory to the configured storage,
     * build the Kubeflow job resource and create it through the CRD client.
     * The trial is registered in `trialJobsMap` only after the job was created.
     *
     * @param form trial application form (hyper-parameters, sequence id)
     * @returns detail record of the submitted trial; its initial status is
     *          'FAILED' when the upload step returned an empty output URL,
     *          otherwise 'WAITING'
     * @throws Error when the Kubeflow operator client has not been created yet
     */
    public async submitTrialJob(form: TrialJobApplicationForm): Promise<TrialJobDetail> {
        if (this.kubernetesCRDClient === undefined) {
            throw new Error('Kubeflow job operator client is undefined');
        }

        if (this.kubernetesRestServerPort === undefined) {
            // Lazily pick up the port from the REST server component
            const jobRestServer: KubeflowJobRestServer = component.get(KubeflowJobRestServer);
            this.kubernetesRestServerPort = jobRestServer.clusterRestServerPort;
        }

        const trialJobId: string = uniqueString(5);
        const containerWorkDir: string = path.join(this.CONTAINER_MOUNT_PATH, 'nni', getExperimentId(), trialJobId);
        const jobName: string = `nni-exp-${this.experimentId}-trial-${trialJobId}`.toLowerCase();
        const localStagingDir: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);

        // Stage the run scripts locally, then push everything to storage
        await this.prepareRunScript(localStagingDir, trialJobId, containerWorkDir, form);
        const outputUrl: string = await this.uploadCodeFiles(trialJobId, localStagingDir);

        // An empty output URL means nothing was uploaded
        const initialStatus: TrialJobStatus = outputUrl ? 'WAITING' : 'FAILED';
        const detail: KubernetesTrialJobDetail = new KubernetesTrialJobDetail(
            trialJobId,
            initialStatus,
            Date.now(),
            containerWorkDir,
            form,
            jobName,
            outputUrl
        );

        // Build the Kubeflow job resource and create it in the cluster
        const jobConfig: any = await this.prepareKubeflowConfig(trialJobId, containerWorkDir, jobName);
        await this.kubernetesCRDClient.createKubernetesJob(jobConfig);

        // Only remember the trial once the Kubeflow job exists
        this.trialJobsMap.set(trialJobId, detail);

        return detail;
    }
109

110
111
112
113
114
115
    /**
     * Apply one metadata entry to this training service.
     *
     * Recognized keys: NNI_MANAGER_IP, KUBEFLOW_CLUSTER_CONFIG, TRIAL_CONFIG,
     * VERSION_CHECK and LOG_COLLECTION. Any other key is silently ignored.
     *
     * @param key metadata key
     * @param value metadata value; parsed as JSON for the manager-ip, cluster
     *              and trial configuration keys
     * @throws Error (as a rejected promise) when the trial config arrives before
     *         the cluster config, or when code directory validation fails
     */
    public async setClusterMetadata(key: string, value: string): Promise<void> {
        switch (key) {
            case TrialConfigMetadataKey.NNI_MANAGER_IP: {
                this.nniManagerIpConfig = JSON.parse(value) as NNIManagerIpConfig;
                break;
            }

            case TrialConfigMetadataKey.KUBEFLOW_CLUSTER_CONFIG: {
                const clusterJson: object = JSON.parse(value);
                this.kubeflowClusterConfig = KubeflowClusterConfigFactory.generateKubeflowClusterConfig(clusterJson);

                // Prepare the storage backend selected by the cluster config
                if (this.kubeflowClusterConfig.storageType === 'azureStorage') {
                    const azureConfig = this.kubeflowClusterConfig as KubeflowClusterConfigAzure;
                    this.azureStorageAccountName = azureConfig.azureStorage.accountName;
                    this.azureStorageShare = azureConfig.azureStorage.azureShare;
                    await this.createAzureStorage(azureConfig.keyVault.vaultName, azureConfig.keyVault.name);
                } else if (this.kubeflowClusterConfig.storageType === 'nfs') {
                    const nfsConfig = this.kubeflowClusterConfig as KubeflowClusterConfigNFS;
                    await this.createNFSStorage(nfsConfig.nfs.server, nfsConfig.nfs.path);
                }

                // Client used to talk to the operator's custom resources
                this.kubernetesCRDClient = KubeflowOperatorClientFactory.createClient(
                    this.kubeflowClusterConfig.operator,
                    this.kubeflowClusterConfig.apiVersion
                );
                break;
            }

            case TrialConfigMetadataKey.TRIAL_CONFIG: {
                // The trial config is interpreted according to the cluster's operator
                if (this.kubeflowClusterConfig === undefined) {
                    this.log.error('kubeflow cluster config is not initialized');

                    throw new Error('kubeflow cluster config is not initialized');
                }

                const trialJson: object = JSON.parse(value);
                this.kubeflowTrialConfig = KubeflowTrialConfigFactory.generateKubeflowTrialConfig(
                    trialJson,
                    this.kubeflowClusterConfig.operator
                );

                // Validate to make sure codeDir doesn't have too many files
                try {
                    await validateCodeDir(this.kubeflowTrialConfig.codeDir);
                } catch (error) {
                    this.log.error(error);

                    throw new Error(error);
                }
                break;
            }

            case TrialConfigMetadataKey.VERSION_CHECK: {
                this.versionCheck = (value === 'true' || value === 'True');
                break;
            }

            case TrialConfigMetadataKey.LOG_COLLECTION: {
                this.logCollection = value;
                break;
            }

            default:
                // Unknown keys are ignored
                break;
        }
    }

174
175
    /**
     * Upload the staged trial files and the user's code directory to the
     * configured storage (Azure storage or NFS).
     *
     * @param trialJobId trial job id
     * @param trialLocalTempFolder local folder holding the generated run scripts
     * @returns the trial's output URL; an empty string when no storage branch ran
     * @throws Error when the cluster config, the trial config, or (for Azure
     *         storage) the Azure storage client is not initialized
     */
    private async uploadCodeFiles(trialJobId: string, trialLocalTempFolder: string): Promise<string> {
        if (this.kubeflowClusterConfig === undefined) {
            throw new Error('Kubeflow Cluster config is not initialized');
        }

        if (this.kubeflowTrialConfig === undefined) {
            throw new Error('Kubeflow Trial config is not initialized');
        }

        const storageKind = this.kubeflowClusterConfig.storage;
        assert(storageKind === undefined || storageKind === 'azureStorage' || storageKind === 'nfs');

        if (storageKind === 'azureStorage') {
            if (this.azureStorageClient === undefined) {
                throw new Error('azureStorageClient is not initialized');
            }

            const azureConfig = this.kubeflowClusterConfig as KubeflowClusterConfigAzure;
            const azureOutputUrl: string = await this.uploadFilesToAzureStorage(
                trialJobId, trialLocalTempFolder, this.kubeflowTrialConfig.codeDir, azureConfig.uploadRetryCount);

            return azureOutputUrl;
        }

        // An unset storage kind is handled the same way as 'nfs'
        if (storageKind === 'nfs' || storageKind === undefined) {
            const nfsClusterConfig = this.kubeflowClusterConfig as KubeflowClusterConfigNFS;
            const nfsTrialDir: string = `${this.trialLocalNFSTempFolder}/nni/${getExperimentId()}/${trialJobId}`;

            // Create work dir for current trial in NFS directory
            await cpp.exec(`mkdir -p ${nfsTrialDir}`);
            // Copy script files from local dir to NFS mounted dir
            await cpp.exec(`cp -r ${trialLocalTempFolder}/* ${nfsTrialDir}/.`);
            // Copy codeDir to NFS mounted dir
            await cpp.exec(`cp -r ${this.kubeflowTrialConfig.codeDir}/* ${nfsTrialDir}/.`);

            const nfsTarget: NFSConfig = nfsClusterConfig.nfs;
            const remoteOutputPath: string = path.join(nfsTarget.path, 'nni', getExperimentId(), trialJobId, 'output');

            return `nfs://${nfsTarget.server}:${remoteOutputPath}`;
        }

        return '';
    }
215

216
217
    /**
     * Stage everything a trial needs to start inside its containers into a
     * local temp folder: the NNI install script, one run script per configured
     * role (worker, plus ps for tf-operator or master for pytorch-operator) and
     * the hyper-parameter file.
     *
     * @param trialLocalTempFolder local folder to create and fill
     * @param trialJobId trial job id
     * @param trialWorkingFolder trial working folder passed to the generated run scripts
     * @param form trial application form providing the sequence id and hyper-parameters
     * @throws Error when the cluster config or the trial config is not initialized,
     *         or when the configured operator is neither 'tf-operator' nor
     *         'pytorch-operator'
     */
    private async prepareRunScript(trialLocalTempFolder: string, trialJobId: string, trialWorkingFolder: string,
                                   form: TrialJobApplicationForm): Promise<void> {
        if (this.kubeflowClusterConfig === undefined) {
            throw new Error('Kubeflow Cluster config is not initialized');
        }

        // Fail early with a clear message (consistent with the sibling methods)
        // instead of crashing on `.worker` of undefined after files were already written.
        if (this.kubeflowTrialConfig === undefined) {
            throw new Error('Kubeflow trial config is not initialized');
        }

        if (this.kubeflowClusterConfig.operator !== 'tf-operator'
            && this.kubeflowClusterConfig.operator !== 'pytorch-operator') {
            throw Error(`operator ${this.kubeflowClusterConfig.operator} is invalid`);
        }

        // Both operator-specific trial configs are read for a `worker` section below;
        // keep the loose typing the original code used for that access.
        const kubeflowTrialConfig: any = this.kubeflowTrialConfig;

        //create tmp trial working folder locally.
        await cpp.exec(`mkdir -p ${trialLocalTempFolder}`);

        // Write NNI installation file to local tmp files
        const runScriptContent: string = CONTAINER_INSTALL_NNI_SHELL_FORMAT;
        await fs.promises.writeFile(path.join(trialLocalTempFolder, 'install_nni.sh'), runScriptContent, { encoding: 'utf8' });

        // Write worker file content run_worker.sh to local tmp folders
        if (kubeflowTrialConfig.worker !== undefined) {
            const workerRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder,
                                                                                kubeflowTrialConfig.worker.command,
                                                                                form.sequenceId.toString(), 'worker',
                                                                                kubeflowTrialConfig.worker.gpuNum);
            await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_worker.sh'), workerRunScriptContent, { encoding: 'utf8' });
        }

        if (this.kubeflowClusterConfig.operator === 'tf-operator') {
            // Write parameter server file content run_ps.sh to local tmp folders
            const tensorflowTrialConfig: KubeflowTrialConfigTensorflow = <KubeflowTrialConfigTensorflow>this.kubeflowTrialConfig;
            if (tensorflowTrialConfig.ps !== undefined) {
                const psRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder,
                                                                                tensorflowTrialConfig.ps.command,
                                                                                form.sequenceId.toString(),
                                                                                'ps', tensorflowTrialConfig.ps.gpuNum);
                await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_ps.sh'), psRunScriptContent, { encoding: 'utf8' });
            }
        } else if (this.kubeflowClusterConfig.operator === 'pytorch-operator') {
            // Write master file content run_master.sh to local tmp folders
            const pytorchTrialConfig: KubeflowTrialConfigPytorch = <KubeflowTrialConfigPytorch>this.kubeflowTrialConfig;
            if (pytorchTrialConfig.master !== undefined) {
                const masterRunScriptContent: string = await this.generateRunScript('kubeflow', trialJobId, trialWorkingFolder,
                                                                                    pytorchTrialConfig.master.command,
                                                                                    form.sequenceId.toString(), 'master',
                                                                                    pytorchTrialConfig.master.gpuNum);
                await fs.promises.writeFile(path.join(trialLocalTempFolder, 'run_master.sh'), masterRunScriptContent, { encoding: 'utf8' });
            }
        }

        // Write file content ( parameter.cfg ) to local tmp folders
        if (form !== undefined) {
            await fs.promises.writeFile(path.join(trialLocalTempFolder, generateParamFileName(form.hyperParameters)),
                                        form.hyperParameters.value, { encoding: 'utf8' });
        }
    }
272

273
    /**
     * Compute the pod resource sections for the worker role and for the
     * non-worker role (ps for tf-operator, master for pytorch-operator), then
     * build the Kubeflow job resource from them.
     *
     * @param trialJobId trial job id
     * @param trialWorkingFolder trial working folder
     * @param kubeflowJobName name of the Kubeflow job to generate
     * @returns the Kubeflow job resource config object
     * @throws Error when the cluster config or the trial config is not initialized,
     *         or when the configured operator is not supported
     */
    private async prepareKubeflowConfig(trialJobId: string, trialWorkingFolder: string, kubeflowJobName: string): Promise<any> {
        if (this.kubeflowClusterConfig === undefined) {
            throw new Error('Kubeflow Cluster config is not initialized');
        }

        if (this.kubeflowTrialConfig === undefined) {
            throw new Error('Kubeflow trial config is not initialized');
        }

        // Only the two known operators are supported
        if (this.kubeflowClusterConfig.operator !== 'tf-operator'
            && this.kubeflowClusterConfig.operator !== 'pytorch-operator') {
            throw Error(`operator ${this.kubeflowClusterConfig.operator} is invalid`);
        }
        const looseTrialConfig: any = this.kubeflowTrialConfig;

        // Worker role: limits mirror requests (empty when no worker is configured)
        const workerPodResources: any = {};
        const workerSection: any = looseTrialConfig.worker;
        if (workerSection !== undefined) {
            workerPodResources.requests = this.generatePodResource(
                workerSection.memoryMB, workerSection.cpuNum, workerSection.gpuNum);
        }
        workerPodResources.limits = {...workerPodResources.requests};

        // Non-worker role: parameter server or master, depending on the operator
        const nonWorkerResources: any = {};
        if (this.kubeflowClusterConfig.operator === 'tf-operator') {
            const tfTrialConfig = this.kubeflowTrialConfig as KubeflowTrialConfigTensorflow;
            if (tfTrialConfig.ps !== undefined) {
                nonWorkerResources.requests = this.generatePodResource(
                    tfTrialConfig.ps.memoryMB, tfTrialConfig.ps.cpuNum, tfTrialConfig.ps.gpuNum);
                nonWorkerResources.limits = {...nonWorkerResources.requests};
            }
        } else if (this.kubeflowClusterConfig.operator === 'pytorch-operator') {
            const torchTrialConfig = this.kubeflowTrialConfig as KubeflowTrialConfigPytorch;
            nonWorkerResources.requests = this.generatePodResource(
                torchTrialConfig.master.memoryMB, torchTrialConfig.master.cpuNum, torchTrialConfig.master.gpuNum);
            nonWorkerResources.limits = {...nonWorkerResources.requests};
        }

        // Generate kubeflow job resource config object
        const jobConfig: any = await this.generateKubeflowJobConfig(
            trialJobId, trialWorkingFolder, kubeflowJobName, workerPodResources, nonWorkerResources);

        return jobConfig;
    }
320
321
322
323
324
325
326
327
328

    /**
     * Generate kubeflow resource config file
     * @param trialJobId trial job id
     * @param trialWorkingFolder working folder
     * @param kubeflowJobName job name
     * @param workerPodResources worker pod template
     * @param nonWorkerPodResources non-worker pod template, like ps or master
     * @returns the job resource object (apiVersion, kind, metadata, spec); `spec`
     *          is undefined when the trial config's operator type is neither
     *          'tf-operator' nor 'pytorch-operator'
     * @throws Error when the cluster config, trial config or operator client
     *         is not initialized
     */
    private async generateKubeflowJobConfig(trialJobId: string, trialWorkingFolder: string, kubeflowJobName: string, workerPodResources: any,
                                            nonWorkerPodResources?: any): Promise<any> {
        if (this.kubeflowClusterConfig === undefined) {
            throw new Error('Kubeflow Cluster config is not initialized');
        }

        if (this.kubeflowTrialConfig === undefined) {
            throw new Error('Kubeflow trial config is not initialized');
        }

        if (this.kubernetesCRDClient === undefined) {
            throw new Error('Kubeflow operator client is not initialized');
        }

        // Replica sections keyed by role name; the job spec wraps them under an
        // operator-specific property.
        const replicaSpecs: any = {};
        let jobSpec: object | undefined;

        if (this.kubeflowTrialConfig.operatorType === 'tf-operator') {
            const tfTrialConfig = this.kubeflowTrialConfig as KubeflowTrialConfigTensorflow;

            const workerSecretName = await this.createRegistrySecret(tfTrialConfig.worker.privateRegistryAuthPath);
            replicaSpecs.Worker = this.generateReplicaConfig(
                trialWorkingFolder, tfTrialConfig.worker.replicas, tfTrialConfig.worker.image,
                'run_worker.sh', workerPodResources, workerSecretName);

            if (tfTrialConfig.ps !== undefined) {
                const psSecretName: string | undefined = await this.createRegistrySecret(tfTrialConfig.ps.privateRegistryAuthPath);
                replicaSpecs.Ps = this.generateReplicaConfig(
                    trialWorkingFolder, tfTrialConfig.ps.replicas, tfTrialConfig.ps.image,
                    'run_ps.sh', nonWorkerPodResources, psSecretName);
            }

            jobSpec = {tfReplicaSpecs: replicaSpecs};
        } else if (this.kubeflowTrialConfig.operatorType === 'pytorch-operator') {
            const torchTrialConfig = this.kubeflowTrialConfig as KubeflowTrialConfigPytorch;

            if (torchTrialConfig.worker !== undefined) {
                const workerSecretName: string | undefined = await this.createRegistrySecret(torchTrialConfig.worker.privateRegistryAuthPath);
                replicaSpecs.Worker = this.generateReplicaConfig(
                    trialWorkingFolder, torchTrialConfig.worker.replicas, torchTrialConfig.worker.image,
                    'run_worker.sh', workerPodResources, workerSecretName);
            }

            const masterSecretName: string | undefined = await this.createRegistrySecret(torchTrialConfig.master.privateRegistryAuthPath);
            replicaSpecs.Master = this.generateReplicaConfig(
                trialWorkingFolder, torchTrialConfig.master.replicas, torchTrialConfig.master.image,
                'run_master.sh', nonWorkerPodResources, masterSecretName);

            jobSpec = {pytorchReplicaSpecs: replicaSpecs};
        }

        return {
            apiVersion: `kubeflow.org/${this.kubernetesCRDClient.apiVersion}`,
            kind: this.kubernetesCRDClient.jobKind,
            metadata: {
                name: kubeflowJobName,
                namespace: 'default',
                labels: {
                    app: this.NNI_KUBERNETES_TRIAL_LABEL,
                    expId: getExperimentId(),
                    trialId: trialJobId
                }
            },
            spec: jobSpec
        };
    }

    /**
     * Generate one replica config section (replica count plus pod template) of a
     * Kubeflow job, e.g. an entry of tf-operator's tfReplicaSpecs.
     * @param trialWorkingFolder trial working folder
     * @param replicaNumber replica number
     * @param replicaImage image
     * @param runScriptFile script file name
     * @param podResources pod resource config section
     * @param privateRegistrySecretName name of the image pull secret to attach to
     *        the pod; no imagePullSecrets entry is added when it is undefined or empty
     * @returns object of the form `{ replicas, template: { metadata, spec } }`
     * @throws Error when the cluster config, trial config or operator client
     *         is not initialized
     */
    private generateReplicaConfig(trialWorkingFolder: string, replicaNumber: number, replicaImage: string, runScriptFile: string,
                                  podResources: any, privateRegistrySecretName: string | undefined): any {
        if (this.kubeflowClusterConfig === undefined) {
            throw new Error('Kubeflow Cluster config is not initialized');
        }

        if (this.kubeflowTrialConfig === undefined) {
            throw new Error('Kubeflow trial config is not initialized');
        }

        if (this.kubernetesCRDClient === undefined) {
            throw new Error('Kubeflow operator client is not initialized');
        }

        // The config spec for volume field
        let nniVolume: object;
        if (this.kubeflowClusterConfig.storageType === 'azureStorage') {
            nniVolume = {
                name: 'nni-vol',
                azureFile: {
                    secretName: `${this.azureStorageSecretName}`,
                    shareName: `${this.azureStorageShare}`,
                    // Kubernetes spells this field `readOnly`; the previous
                    // lower-case `readonly` is not a field of the azureFile volume source.
                    readOnly: false
                }
            };
        } else {
            const nfsKubeflowClusterConfig: KubeflowClusterConfigNFS = <KubeflowClusterConfigNFS>this.kubeflowClusterConfig;
            nniVolume = {
                name: 'nni-vol',
                nfs: {
                    server: `${nfsKubeflowClusterConfig.nfs.server}`,
                    path: `${nfsKubeflowClusterConfig.nfs.path}`
                }
            };
        }

        // The config spec for container field
        const trialContainer: object = {
            // Kubeflow tensorflow operator requires that containers' name must be tensorflow
            // TODO: change the name based on operator's type
            name: this.kubernetesCRDClient.containerName,
            image: replicaImage,
            args: ['sh', path.join(trialWorkingFolder, runScriptFile)],
            volumeMounts: [
                {
                    name: 'nni-vol',
                    mountPath: this.CONTAINER_MOUNT_PATH
                }
            ],
            resources: podResources
        };

        const spec: any = {
            containers: [trialContainer],
            restartPolicy: 'ExitCode',
            volumes: [nniVolume]
        };
        if (privateRegistrySecretName) {
            spec.imagePullSecrets = [
                {
                    name: privateRegistrySecretName
                }
            ];
        }

        return {
            replicas: replicaNumber,
            template: {
                metadata: {
                    creationTimestamp: null
                },
                spec: spec
            }
        };
    }
}
469
export { KubeflowTrainingService };