"vscode:/vscode.git/clone" did not exist on "36cc3c24e6776dfd8ed53608093af28c8bf7d3ce"
Unverified Commit d5857823 authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Config refactor (#4370)

parent cb090e8c
...@@ -10,16 +10,14 @@ export class AMLClusterConfig { ...@@ -10,16 +10,14 @@ export class AMLClusterConfig {
public readonly resourceGroup: string; public readonly resourceGroup: string;
public readonly workspaceName: string; public readonly workspaceName: string;
public readonly computeTarget: string; public readonly computeTarget: string;
public useActiveGpu?: boolean;
public maxTrialNumPerGpu?: number; public maxTrialNumPerGpu?: number;
constructor(subscriptionId: string, resourceGroup: string, workspaceName: string, computeTarget: string, constructor(subscriptionId: string, resourceGroup: string, workspaceName: string, computeTarget: string,
useActiveGpu?: boolean, maxTrialNumPerGpu?: number) { maxTrialNumPerGpu?: number) {
this.subscriptionId = subscriptionId; this.subscriptionId = subscriptionId;
this.resourceGroup = resourceGroup; this.resourceGroup = resourceGroup;
this.workspaceName = workspaceName; this.workspaceName = workspaceName;
this.computeTarget = computeTarget; this.computeTarget = computeTarget;
this.useActiveGpu = useActiveGpu;
this.maxTrialNumPerGpu = maxTrialNumPerGpu; this.maxTrialNumPerGpu = maxTrialNumPerGpu;
} }
} }
......
...@@ -5,7 +5,7 @@ import fs from 'fs'; ...@@ -5,7 +5,7 @@ import fs from 'fs';
import path from 'path'; import path from 'path';
import * as component from 'common/component'; import * as component from 'common/component';
import { getLogger, Logger } from 'common/log'; import { getLogger, Logger } from 'common/log';
import { ExperimentConfig, AmlConfig, flattenConfig } from 'common/experimentConfig'; import { AmlConfig } from 'common/experimentConfig';
import { ExperimentStartupInfo } from 'common/experimentStartupInfo'; import { ExperimentStartupInfo } from 'common/experimentStartupInfo';
import { validateCodeDir } from 'training_service/common/util'; import { validateCodeDir } from 'training_service/common/util';
import { AMLClient } from '../aml/amlClient'; import { AMLClient } from '../aml/amlClient';
...@@ -15,8 +15,6 @@ import { EventEmitter } from "events"; ...@@ -15,8 +15,6 @@ import { EventEmitter } from "events";
import { AMLCommandChannel } from '../channels/amlCommandChannel'; import { AMLCommandChannel } from '../channels/amlCommandChannel';
import { SharedStorageService } from '../sharedStorage' import { SharedStorageService } from '../sharedStorage'
interface FlattenAmlConfig extends ExperimentConfig, AmlConfig { }
/** /**
* Collector AML jobs info from AML cluster, and update aml job status locally * Collector AML jobs info from AML cluster, and update aml job status locally
*/ */
...@@ -26,13 +24,13 @@ export class AMLEnvironmentService extends EnvironmentService { ...@@ -26,13 +24,13 @@ export class AMLEnvironmentService extends EnvironmentService {
private readonly log: Logger = getLogger('AMLEnvironmentService'); private readonly log: Logger = getLogger('AMLEnvironmentService');
private experimentId: string; private experimentId: string;
private experimentRootDir: string; private experimentRootDir: string;
private config: FlattenAmlConfig; private config: AmlConfig;
constructor(config: ExperimentConfig, info: ExperimentStartupInfo) { constructor(config: AmlConfig, info: ExperimentStartupInfo) {
super(); super();
this.experimentId = info.experimentId; this.experimentId = info.experimentId;
this.experimentRootDir = info.logDir; this.experimentRootDir = info.logDir;
this.config = flattenConfig(config, 'aml'); this.config = config;
validateCodeDir(this.config.trialCodeDirectory); validateCodeDir(this.config.trialCodeDirectory);
} }
...@@ -98,9 +96,6 @@ export class AMLEnvironmentService extends EnvironmentService { ...@@ -98,9 +96,6 @@ export class AMLEnvironmentService extends EnvironmentService {
amlEnvironment.command = `mv envs outputs/envs && cd outputs && ${amlEnvironment.command}`; amlEnvironment.command = `mv envs outputs/envs && cd outputs && ${amlEnvironment.command}`;
} }
amlEnvironment.command = `import os\nos.system('${amlEnvironment.command}')`; amlEnvironment.command = `import os\nos.system('${amlEnvironment.command}')`;
if (this.config.deprecated && this.config.deprecated.useActiveGpu !== undefined) {
amlEnvironment.useActiveGpu = this.config.deprecated.useActiveGpu;
}
amlEnvironment.maxTrialNumberPerGpu = this.config.maxTrialNumberPerGpu; amlEnvironment.maxTrialNumberPerGpu = this.config.maxTrialNumberPerGpu;
await fs.promises.writeFile(path.join(environmentLocalTempFolder, 'nni_script.py'), amlEnvironment.command, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(environmentLocalTempFolder, 'nni_script.py'), amlEnvironment.command, { encoding: 'utf8' });
......
...@@ -5,7 +5,7 @@ import fs from 'fs'; ...@@ -5,7 +5,7 @@ import fs from 'fs';
import path from 'path'; import path from 'path';
import * as component from 'common/component'; import * as component from 'common/component';
import { getLogger, Logger } from 'common/log'; import { getLogger, Logger } from 'common/log';
import { ExperimentConfig, DlcConfig, flattenConfig } from 'common/experimentConfig'; import { DlcConfig } from 'common/experimentConfig';
import { ExperimentStartupInfo } from 'common/experimentStartupInfo'; import { ExperimentStartupInfo } from 'common/experimentStartupInfo';
import { DlcClient } from '../dlc/dlcClient'; import { DlcClient } from '../dlc/dlcClient';
import { DlcEnvironmentInformation } from '../dlc/dlcConfig'; import { DlcEnvironmentInformation } from '../dlc/dlcConfig';
...@@ -16,8 +16,6 @@ import { MountedStorageService } from '../storages/mountedStorageService'; ...@@ -16,8 +16,6 @@ import { MountedStorageService } from '../storages/mountedStorageService';
import { Scope } from 'typescript-ioc'; import { Scope } from 'typescript-ioc';
import { StorageService } from '../storageService'; import { StorageService } from '../storageService';
interface FlattenDlcConfig extends ExperimentConfig, DlcConfig { }
/** /**
* Collector DLC jobs info from DLC cluster, and update dlc job status locally * Collector DLC jobs info from DLC cluster, and update dlc job status locally
*/ */
...@@ -26,12 +24,12 @@ export class DlcEnvironmentService extends EnvironmentService { ...@@ -26,12 +24,12 @@ export class DlcEnvironmentService extends EnvironmentService {
private readonly log: Logger = getLogger('dlcEnvironmentService'); private readonly log: Logger = getLogger('dlcEnvironmentService');
private experimentId: string; private experimentId: string;
private config: FlattenDlcConfig; private config: DlcConfig;
constructor(config: ExperimentConfig, info: ExperimentStartupInfo) { constructor(config: DlcConfig, info: ExperimentStartupInfo) {
super(); super();
this.experimentId = info.experimentId; this.experimentId = info.experimentId;
this.config = flattenConfig(config, 'dlc'); this.config = config;
component.Container.bind(StorageService).to(MountedStorageService).scope(Scope.Singleton); component.Container.bind(StorageService).to(MountedStorageService).scope(Scope.Singleton);
const storageService = component.get<StorageService>(StorageService) const storageService = component.get<StorageService>(StorageService)
const remoteRoot = storageService.joinPath(this.config.localStorageMountPoint, 'nni-experiments', this.experimentId); const remoteRoot = storageService.joinPath(this.config.localStorageMountPoint, 'nni-experiments', this.experimentId);
......
...@@ -13,22 +13,23 @@ import { DlcEnvironmentService } from './dlcEnvironmentService'; ...@@ -13,22 +13,23 @@ import { DlcEnvironmentService } from './dlcEnvironmentService';
export async function createEnvironmentService(name: string, config: ExperimentConfig): Promise<EnvironmentService> { export async function createEnvironmentService(name: string, config: ExperimentConfig): Promise<EnvironmentService> {
const info = ExperimentStartupInfo.getInstance(); const info = ExperimentStartupInfo.getInstance();
const tsConfig: any = config.trainingService;
switch(name) { switch (name) {
case 'local': case 'local':
return new LocalEnvironmentService(config, info); return new LocalEnvironmentService(tsConfig, info);
case 'remote': case 'remote':
return new RemoteEnvironmentService(config, info); return new RemoteEnvironmentService(tsConfig, info);
case 'aml': case 'aml':
return new AMLEnvironmentService(config, info); return new AMLEnvironmentService(tsConfig, info);
case 'openpai': case 'openpai':
return new OpenPaiEnvironmentService(config, info); return new OpenPaiEnvironmentService(tsConfig, info);
case 'kubeflow': case 'kubeflow':
return new KubeflowEnvironmentService(config, info); return new KubeflowEnvironmentService(tsConfig, info);
case 'frameworkcontroller': case 'frameworkcontroller':
return new FrameworkControllerEnvironmentService(config, info); return new FrameworkControllerEnvironmentService(tsConfig, info);
case 'dlc': case 'dlc':
return new DlcEnvironmentService(config, info); return new DlcEnvironmentService(tsConfig, info);
} }
const esConfig = await getCustomEnvironmentServiceConfig(name); const esConfig = await getCustomEnvironmentServiceConfig(name);
...@@ -37,5 +38,5 @@ export async function createEnvironmentService(name: string, config: ExperimentC ...@@ -37,5 +38,5 @@ export async function createEnvironmentService(name: string, config: ExperimentC
} }
const esModule = importModule(esConfig.nodeModulePath); const esModule = importModule(esConfig.nodeModulePath);
const esClass = esModule[esConfig.nodeClassName] as any; const esClass = esModule[esConfig.nodeClassName] as any;
return new esClass(config, info); return new esClass(tsConfig, info);
} }
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
import * as fs from 'fs'; import * as fs from 'fs';
import * as path from 'path'; import * as path from 'path';
import * as component from '../../../../common/component'; import * as component from '../../../../common/component';
import { ExperimentConfig, FrameworkControllerConfig, flattenConfig, FrameworkControllerTaskRoleConfig } from '../../../../common/experimentConfig'; import { FrameworkControllerConfig, FrameworkControllerTaskRoleConfig, toMegaBytes } from '../../../../common/experimentConfig';
import { ExperimentStartupInfo } from '../../../../common/experimentStartupInfo'; import { ExperimentStartupInfo } from '../../../../common/experimentStartupInfo';
import { EnvironmentInformation } from '../../environment'; import { EnvironmentInformation } from '../../environment';
import { KubernetesEnvironmentService } from './kubernetesEnvironmentService'; import { KubernetesEnvironmentService } from './kubernetesEnvironmentService';
...@@ -15,23 +15,20 @@ import { FrameworkControllerClusterConfigAzure, FrameworkControllerJobStatus, Fr ...@@ -15,23 +15,20 @@ import { FrameworkControllerClusterConfigAzure, FrameworkControllerJobStatus, Fr
FrameworkControllerJobCompleteStatus } from '../../../kubernetes/frameworkcontroller/frameworkcontrollerConfig'; FrameworkControllerJobCompleteStatus } from '../../../kubernetes/frameworkcontroller/frameworkcontrollerConfig';
import { KeyVaultConfig, AzureStorage } from '../../../kubernetes/kubernetesConfig'; import { KeyVaultConfig, AzureStorage } from '../../../kubernetes/kubernetesConfig';
interface FlattenKubeflowConfig extends ExperimentConfig, FrameworkControllerConfig { }
@component.Singleton @component.Singleton
export class FrameworkControllerEnvironmentService extends KubernetesEnvironmentService { export class FrameworkControllerEnvironmentService extends KubernetesEnvironmentService {
private config: FlattenKubeflowConfig; private config: FrameworkControllerConfig;
private createStoragePromise?: Promise<void>; private createStoragePromise?: Promise<void>;
private readonly fcContainerPortMap: Map<string, number> = new Map<string, number>(); // store frameworkcontroller container port private readonly fcContainerPortMap: Map<string, number> = new Map<string, number>(); // store frameworkcontroller container port
constructor(config: ExperimentConfig, info: ExperimentStartupInfo) { constructor(config: FrameworkControllerConfig, info: ExperimentStartupInfo) {
super(config, info); super(config, info);
this.experimentId = info.experimentId; this.experimentId = info.experimentId;
this.config = flattenConfig(config, 'frameworkcontroller'); this.config = config;
// Create kubernetesCRDClient // Create kubernetesCRDClient
this.kubernetesCRDClient = FrameworkControllerClientFactory.createClient( this.kubernetesCRDClient = FrameworkControllerClientFactory.createClient(this.config.namespace);
this.config.namespace);
// Create storage // Create storage
if (this.config.storage.storageType === 'azureStorage') { if (this.config.storage.storageType === 'azureStorage') {
if (this.config.storage.azureShare === undefined || if (this.config.storage.azureShare === undefined ||
...@@ -40,27 +37,15 @@ export class FrameworkControllerEnvironmentService extends KubernetesEnvironment ...@@ -40,27 +37,15 @@ export class FrameworkControllerEnvironmentService extends KubernetesEnvironment
this.config.storage.keyVaultKey === undefined) { this.config.storage.keyVaultKey === undefined) {
throw new Error("Azure storage configuration error!"); throw new Error("Azure storage configuration error!");
} }
this.azureStorageAccountName = this.config.storage.azureAccount;
const azureStorage: AzureStorage = new AzureStorage(this.config.storage.azureShare, this.config.storage.azureAccount); this.azureStorageShare = this.config.storage.azureShare;
const keyValutConfig: KeyVaultConfig = new KeyVaultConfig(this.config.storage.keyVaultName, this.config.storage.keyVaultKey); this.createStoragePromise = this.createAzureStorage(this.config.storage.keyVaultName, this.config.storage.keyVaultKey);
const azureKubeflowClusterConfig: FrameworkControllerClusterConfigAzure = new FrameworkControllerClusterConfigAzure(
this.config.namespace, this.config.apiVersion, keyValutConfig, azureStorage);
this.azureStorageAccountName = azureKubeflowClusterConfig.azureStorage.accountName;
this.azureStorageShare = azureKubeflowClusterConfig.azureStorage.azureShare;
this.createStoragePromise = this.createAzureStorage(
azureKubeflowClusterConfig.keyVault.vaultName,
azureKubeflowClusterConfig.keyVault.name
);
} else if (this.config.storage.storageType === 'nfs') { } else if (this.config.storage.storageType === 'nfs') {
if (this.config.storage.server === undefined || if (this.config.storage.server === undefined ||
this.config.storage.path === undefined) { this.config.storage.path === undefined) {
throw new Error("NFS storage configuration error!"); throw new Error("NFS storage configuration error!");
} }
this.createStoragePromise = this.createNFSStorage( this.createStoragePromise = this.createNFSStorage(this.config.storage.server, this.config.storage.path);
this.config.storage.server,
this.config.storage.path
);
} }
} }
...@@ -91,9 +76,6 @@ export class FrameworkControllerEnvironmentService extends KubernetesEnvironment ...@@ -91,9 +76,6 @@ export class FrameworkControllerEnvironmentService extends KubernetesEnvironment
const expFolder = `${this.CONTAINER_MOUNT_PATH}/nni/${this.experimentId}`; const expFolder = `${this.CONTAINER_MOUNT_PATH}/nni/${this.experimentId}`;
environment.command = `cd ${expFolder} && ${environment.command} \ environment.command = `cd ${expFolder} && ${environment.command} \
1>${expFolder}/envs/${environment.id}/trialrunner_stdout 2>${expFolder}/envs/${environment.id}/trialrunner_stderr`; 1>${expFolder}/envs/${environment.id}/trialrunner_stdout 2>${expFolder}/envs/${environment.id}/trialrunner_stderr`;
if (this.config.deprecated && this.config.deprecated.useActiveGpu !== undefined) {
environment.useActiveGpu = this.config.deprecated.useActiveGpu;
}
environment.maxTrialNumberPerGpu = this.config.maxTrialNumberPerGpu; environment.maxTrialNumberPerGpu = this.config.maxTrialNumberPerGpu;
const frameworkcontrollerJobName: string = `nniexp${this.experimentId}env${environment.id}`.toLowerCase(); const frameworkcontrollerJobName: string = `nniexp${this.experimentId}env${environment.id}`.toLowerCase();
...@@ -148,7 +130,7 @@ export class FrameworkControllerEnvironmentService extends KubernetesEnvironment ...@@ -148,7 +130,7 @@ export class FrameworkControllerEnvironmentService extends KubernetesEnvironment
const podResources: any = []; const podResources: any = [];
for (const taskRole of this.config.taskRoles) { for (const taskRole of this.config.taskRoles) {
const resource: any = {}; const resource: any = {};
resource.requests = this.generatePodResource(taskRole.memorySize, taskRole.cpuNumber, taskRole.gpuNumber); resource.requests = this.generatePodResource(toMegaBytes(taskRole.memorySize), taskRole.cpuNumber, taskRole.gpuNumber);
resource.limits = {...resource.requests}; resource.limits = {...resource.requests};
podResources.push(resource); podResources.push(resource);
} }
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import fs from 'fs'; import fs from 'fs';
import path from 'path'; import path from 'path';
import * as component from 'common/component'; import * as component from 'common/component';
import { ExperimentConfig, KubeflowConfig, flattenConfig } from 'common/experimentConfig'; import { KubeflowConfig, toMegaBytes } from 'common/experimentConfig';
import { ExperimentStartupInfo } from 'common/experimentStartupInfo'; import { ExperimentStartupInfo } from 'common/experimentStartupInfo';
import { EnvironmentInformation } from 'training_service/reusable/environment'; import { EnvironmentInformation } from 'training_service/reusable/environment';
import { KubernetesEnvironmentService } from './kubernetesEnvironmentService'; import { KubernetesEnvironmentService } from './kubernetesEnvironmentService';
...@@ -12,19 +12,17 @@ import { KubeflowOperatorClientFactory } from 'training_service/kubernetes/kubef ...@@ -12,19 +12,17 @@ import { KubeflowOperatorClientFactory } from 'training_service/kubernetes/kubef
import { KubeflowClusterConfigAzure } from 'training_service/kubernetes/kubeflow/kubeflowConfig'; import { KubeflowClusterConfigAzure } from 'training_service/kubernetes/kubeflow/kubeflowConfig';
import { KeyVaultConfig, AzureStorage } from 'training_service/kubernetes/kubernetesConfig'; import { KeyVaultConfig, AzureStorage } from 'training_service/kubernetes/kubernetesConfig';
interface FlattenKubeflowConfig extends ExperimentConfig, KubeflowConfig { }
@component.Singleton @component.Singleton
export class KubeflowEnvironmentService extends KubernetesEnvironmentService { export class KubeflowEnvironmentService extends KubernetesEnvironmentService {
private config: FlattenKubeflowConfig; private config: KubeflowConfig;
private createStoragePromise?: Promise<void>; private createStoragePromise?: Promise<void>;
constructor(config: ExperimentConfig, info: ExperimentStartupInfo) { constructor(config: KubeflowConfig, info: ExperimentStartupInfo) {
super(config, info); super(config, info);
this.experimentId = info.experimentId; this.experimentId = info.experimentId;
this.config = flattenConfig(config, 'kubeflow'); this.config = config;
// Create kubernetesCRDClient // Create kubernetesCRDClient
this.kubernetesCRDClient = KubeflowOperatorClientFactory.createClient( this.kubernetesCRDClient = KubeflowOperatorClientFactory.createClient(
this.config.operator, this.config.apiVersion); this.config.operator, this.config.apiVersion);
...@@ -82,9 +80,6 @@ export class KubeflowEnvironmentService extends KubernetesEnvironmentService { ...@@ -82,9 +80,6 @@ export class KubeflowEnvironmentService extends KubernetesEnvironmentService {
const expFolder = `${this.CONTAINER_MOUNT_PATH}/nni/${this.experimentId}`; const expFolder = `${this.CONTAINER_MOUNT_PATH}/nni/${this.experimentId}`;
environment.command = `cd ${expFolder} && ${environment.command} \ environment.command = `cd ${expFolder} && ${environment.command} \
1>${expFolder}/envs/${environment.id}/trialrunner_stdout 2>${expFolder}/envs/${environment.id}/trialrunner_stderr`; 1>${expFolder}/envs/${environment.id}/trialrunner_stdout 2>${expFolder}/envs/${environment.id}/trialrunner_stderr`;
if (this.config.deprecated && this.config.deprecated.useActiveGpu !== undefined) {
environment.useActiveGpu = this.config.deprecated.useActiveGpu;
}
environment.maxTrialNumberPerGpu = this.config.maxTrialNumberPerGpu; environment.maxTrialNumberPerGpu = this.config.maxTrialNumberPerGpu;
const kubeflowJobName: string = `nniexp${this.experimentId}env${environment.id}`.toLowerCase(); const kubeflowJobName: string = `nniexp${this.experimentId}env${environment.id}`.toLowerCase();
...@@ -118,22 +113,22 @@ export class KubeflowEnvironmentService extends KubernetesEnvironmentService { ...@@ -118,22 +113,22 @@ export class KubeflowEnvironmentService extends KubernetesEnvironmentService {
private async prepareKubeflowConfig(envId: string, kubeflowJobName: string): Promise<any> { private async prepareKubeflowConfig(envId: string, kubeflowJobName: string): Promise<any> {
const workerPodResources: any = {}; const workerPodResources: any = {};
if (this.config.worker !== undefined) { if (this.config.worker !== undefined) {
workerPodResources.requests = this.generatePodResource(this.config.worker.memorySize, this.config.worker.cpuNumber, workerPodResources.requests = this.generatePodResource(toMegaBytes(this.config.worker.memorySize),
this.config.worker.gpuNumber); this.config.worker.cpuNumber, this.config.worker.gpuNumber);
} }
workerPodResources.limits = {...workerPodResources.requests}; workerPodResources.limits = {...workerPodResources.requests};
const nonWorkerResources: any = {}; const nonWorkerResources: any = {};
if (this.config.operator === 'tf-operator') { if (this.config.operator === 'tf-operator') {
if (this.config.ps !== undefined) { if (this.config.ps !== undefined) {
nonWorkerResources.requests = this.generatePodResource(this.config.ps.memorySize, this.config.ps.cpuNumber, nonWorkerResources.requests = this.generatePodResource(toMegaBytes(this.config.ps.memorySize),
this.config.ps.gpuNumber); this.config.ps.cpuNumber, this.config.ps.gpuNumber);
nonWorkerResources.limits = {...nonWorkerResources.requests}; nonWorkerResources.limits = {...nonWorkerResources.requests};
} }
} else if (this.config.operator === 'pytorch-operator') { } else if (this.config.operator === 'pytorch-operator') {
if (this.config.master !== undefined) { if (this.config.master !== undefined) {
nonWorkerResources.requests = this.generatePodResource(this.config.master.memorySize, this.config.master.cpuNumber, nonWorkerResources.requests = this.generatePodResource(toMegaBytes(this.config.master.memorySize),
this.config.master.gpuNumber); this.config.master.cpuNumber, this.config.master.gpuNumber);
nonWorkerResources.limits = {...nonWorkerResources.requests}; nonWorkerResources.limits = {...nonWorkerResources.requests};
} }
} }
......
...@@ -33,7 +33,7 @@ export class KubernetesEnvironmentService extends EnvironmentService { ...@@ -33,7 +33,7 @@ export class KubernetesEnvironmentService extends EnvironmentService {
protected log: Logger = getLogger('KubernetesEnvironmentService'); protected log: Logger = getLogger('KubernetesEnvironmentService');
protected environmentWorkingFolder: string; protected environmentWorkingFolder: string;
constructor(_config: ExperimentConfig, info: ExperimentStartupInfo) { constructor(_config: any, info: ExperimentStartupInfo) {
super(); super();
this.CONTAINER_MOUNT_PATH = '/tmp/mount'; this.CONTAINER_MOUNT_PATH = '/tmp/mount';
this.genericK8sClient = new GeneralK8sClient(); this.genericK8sClient = new GeneralK8sClient();
......
...@@ -6,7 +6,7 @@ import request from 'request'; ...@@ -6,7 +6,7 @@ import request from 'request';
import { Container, Scope } from 'typescript-ioc'; import { Container, Scope } from 'typescript-ioc';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import * as component from 'common/component'; import * as component from 'common/component';
import { ExperimentConfig, OpenpaiConfig, flattenConfig, toMegaBytes } from 'common/experimentConfig'; import { OpenpaiConfig, toMegaBytes } from 'common/experimentConfig';
import { ExperimentStartupInfo } from 'common/experimentStartupInfo'; import { ExperimentStartupInfo } from 'common/experimentStartupInfo';
import { getLogger, Logger } from 'common/log'; import { getLogger, Logger } from 'common/log';
import { PAIClusterConfig } from 'training_service/pai/paiConfig'; import { PAIClusterConfig } from 'training_service/pai/paiConfig';
...@@ -16,8 +16,6 @@ import { SharedStorageService } from '../sharedStorage'; ...@@ -16,8 +16,6 @@ import { SharedStorageService } from '../sharedStorage';
import { MountedStorageService } from '../storages/mountedStorageService'; import { MountedStorageService } from '../storages/mountedStorageService';
import { StorageService } from '../storageService'; import { StorageService } from '../storageService';
interface FlattenOpenpaiConfig extends ExperimentConfig, OpenpaiConfig { }
/** /**
* Collector PAI jobs info from PAI cluster, and update pai job status locally * Collector PAI jobs info from PAI cluster, and update pai job status locally
*/ */
...@@ -30,12 +28,12 @@ export class OpenPaiEnvironmentService extends EnvironmentService { ...@@ -30,12 +28,12 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
private paiToken: string; private paiToken: string;
private protocol: string; private protocol: string;
private experimentId: string; private experimentId: string;
private config: FlattenOpenpaiConfig; private config: OpenpaiConfig;
constructor(config: ExperimentConfig, info: ExperimentStartupInfo) { constructor(config: OpenpaiConfig, info: ExperimentStartupInfo) {
super(); super();
this.experimentId = info.experimentId; this.experimentId = info.experimentId;
this.config = flattenConfig(config, 'openpai'); this.config = config;
this.paiToken = this.config.token; this.paiToken = this.config.token;
this.protocol = this.config.host.toLowerCase().startsWith('https://') ? 'https' : 'http'; this.protocol = this.config.host.toLowerCase().startsWith('https://') ? 'https' : 'http';
Container.bind(StorageService) Container.bind(StorageService)
......
...@@ -7,7 +7,7 @@ import * as component from 'common/component'; ...@@ -7,7 +7,7 @@ import * as component from 'common/component';
import { getLogger, Logger } from 'common/log'; import { getLogger, Logger } from 'common/log';
import { EnvironmentInformation, EnvironmentService } from '../environment'; import { EnvironmentInformation, EnvironmentService } from '../environment';
import { getLogLevel } from 'common/utils'; import { getLogLevel } from 'common/utils';
import { ExperimentConfig, RemoteConfig, RemoteMachineConfig, flattenConfig } from 'common/experimentConfig'; import { RemoteConfig, RemoteMachineConfig } from 'common/experimentConfig';
import { ExperimentStartupInfo } from 'common/experimentStartupInfo'; import { ExperimentStartupInfo } from 'common/experimentStartupInfo';
import { execMkdir } from 'training_service/common/util'; import { execMkdir } from 'training_service/common/util';
import { ExecutorManager } from 'training_service/remote_machine/remoteMachineData'; import { ExecutorManager } from 'training_service/remote_machine/remoteMachineData';
...@@ -15,8 +15,6 @@ import { ShellExecutor } from 'training_service/remote_machine/shellExecutor'; ...@@ -15,8 +15,6 @@ import { ShellExecutor } from 'training_service/remote_machine/shellExecutor';
import { RemoteMachineEnvironmentInformation } from '../remote/remoteConfig'; import { RemoteMachineEnvironmentInformation } from '../remote/remoteConfig';
import { SharedStorageService } from '../sharedStorage' import { SharedStorageService } from '../sharedStorage'
interface FlattenRemoteConfig extends ExperimentConfig, RemoteConfig { }
@component.Singleton @component.Singleton
export class RemoteEnvironmentService extends EnvironmentService { export class RemoteEnvironmentService extends EnvironmentService {
...@@ -29,9 +27,9 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -29,9 +27,9 @@ export class RemoteEnvironmentService extends EnvironmentService {
private experimentRootDir: string; private experimentRootDir: string;
private remoteExperimentRootDir: string = ""; private remoteExperimentRootDir: string = "";
private experimentId: string; private experimentId: string;
private config: FlattenRemoteConfig; private config: RemoteConfig;
constructor(config: ExperimentConfig, info: ExperimentStartupInfo) { constructor(config: RemoteConfig, info: ExperimentStartupInfo) {
super(); super();
this.experimentId = info.experimentId; this.experimentId = info.experimentId;
this.environmentExecutorManagerMap = new Map<string, ExecutorManager>(); this.environmentExecutorManagerMap = new Map<string, ExecutorManager>();
...@@ -39,7 +37,7 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -39,7 +37,7 @@ export class RemoteEnvironmentService extends EnvironmentService {
this.remoteMachineMetaOccupiedMap = new Map<RemoteMachineConfig, boolean>(); this.remoteMachineMetaOccupiedMap = new Map<RemoteMachineConfig, boolean>();
this.experimentRootDir = info.logDir; this.experimentRootDir = info.logDir;
this.log = getLogger('RemoteEnvironmentService'); this.log = getLogger('RemoteEnvironmentService');
this.config = flattenConfig(config, 'remote'); this.config = config;
// codeDir is not a valid directory, throw Error // codeDir is not a valid directory, throw Error
if (!fs.lstatSync(this.config.trialCodeDirectory).isDirectory()) { if (!fs.lstatSync(this.config.trialCodeDirectory).isDirectory()) {
......
...@@ -26,9 +26,9 @@ class RouterTrainingService implements TrainingService { ...@@ -26,9 +26,9 @@ class RouterTrainingService implements TrainingService {
instance.log = getLogger('RouterTrainingService'); instance.log = getLogger('RouterTrainingService');
const platform = Array.isArray(config.trainingService) ? 'hybrid' : config.trainingService.platform; const platform = Array.isArray(config.trainingService) ? 'hybrid' : config.trainingService.platform;
if (platform === 'remote' && (<RemoteConfig>config.trainingService).reuseMode === false) { if (platform === 'remote' && (<RemoteConfig>config.trainingService).reuseMode === false) {
instance.internalTrainingService = new RemoteMachineTrainingService(config); instance.internalTrainingService = new RemoteMachineTrainingService(<RemoteConfig>config.trainingService);
} else if (platform === 'openpai' && (<OpenpaiConfig>config.trainingService).reuseMode === false) { } else if (platform === 'openpai' && (<OpenpaiConfig>config.trainingService).reuseMode === false) {
instance.internalTrainingService = new PAITrainingService(config); instance.internalTrainingService = new PAITrainingService(<OpenpaiConfig>config.trainingService);
} else if (platform === 'kubeflow' && (<KubeflowConfig>config.trainingService).reuseMode === false) { } else if (platform === 'kubeflow' && (<KubeflowConfig>config.trainingService).reuseMode === false) {
instance.internalTrainingService = new KubeflowTrainingService(); instance.internalTrainingService = new KubeflowTrainingService();
} else if (platform === 'frameworkcontroller' && (<FrameworkControllerConfig>config.trainingService).reuseMode === false) { } else if (platform === 'frameworkcontroller' && (<FrameworkControllerConfig>config.trainingService).reuseMode === false) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment