Unverified Commit 872554f1 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Support heterogeneous environment service (#3097)

parent dec91f7e
...@@ -45,6 +45,10 @@ export class OpenPaiEnvironmentService extends EnvironmentService { ...@@ -45,6 +45,10 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
return true; return true;
} }
public get getName(): string {
return 'pai';
}
public async config(key: string, value: string): Promise<void> { public async config(key: string, value: string): Promise<void> {
switch (key) { switch (key) {
case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG: case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG:
......
...@@ -63,6 +63,10 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -63,6 +63,10 @@ export class RemoteEnvironmentService extends EnvironmentService {
return false; return false;
} }
public get getName(): string {
return 'remote';
}
public async config(key: string, value: string): Promise<void> { public async config(key: string, value: string): Promise<void> {
switch (key) { switch (key) {
case TrialConfigMetadataKey.MACHINE_LIST: case TrialConfigMetadataKey.MACHINE_LIST:
...@@ -135,6 +139,14 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -135,6 +139,14 @@ export class RemoteEnvironmentService extends EnvironmentService {
await executor.allowPermission(true, nniRootDir); await executor.allowPermission(true, nniRootDir);
} }
public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
const tasks: Promise<void>[] = [];
environments.forEach(async (environment) => {
tasks.push(this.refreshEnvironment(environment));
});
await Promise.all(tasks);
}
private async refreshEnvironment(environment: EnvironmentInformation): Promise<void> { private async refreshEnvironment(environment: EnvironmentInformation): Promise<void> {
const executor = await this.getExecutor(environment.id); const executor = await this.getExecutor(environment.id);
const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`; const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`;
...@@ -176,14 +188,6 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -176,14 +188,6 @@ export class RemoteEnvironmentService extends EnvironmentService {
} }
} }
public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
const tasks: Promise<void>[] = [];
environments.forEach(async (environment) => {
tasks.push(this.refreshEnvironment(environment));
});
await Promise.all(tasks);
}
/** /**
* If a environment is finished, release the connection resource * If a environment is finished, release the connection resource
* @param environment remote machine environment job detail * @param environment remote machine environment job detail
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
export class HeterogenousConfig {
public readonly trainingServicePlatforms: string[];
constructor(trainingServicePlatforms: string[]) {
this.trainingServicePlatforms = trainingServicePlatforms;
}
}
...@@ -13,14 +13,11 @@ import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; ...@@ -13,14 +13,11 @@ import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { PAIClusterConfig } from '../pai/paiConfig'; import { PAIClusterConfig } from '../pai/paiConfig';
import { PAIK8STrainingService } from '../pai/paiK8S/paiK8STrainingService'; import { PAIK8STrainingService } from '../pai/paiK8S/paiK8STrainingService';
import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService'; import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService';
import { EnvironmentService } from './environment';
import { OpenPaiEnvironmentService } from './environments/openPaiEnvironmentService';
import { AMLEnvironmentService } from './environments/amlEnvironmentService';
import { RemoteEnvironmentService } from './environments/remoteEnvironmentService';
import { MountedStorageService } from './storages/mountedStorageService'; import { MountedStorageService } from './storages/mountedStorageService';
import { StorageService } from './storageService'; import { StorageService } from './storageService';
import { TrialDispatcher } from './trialDispatcher'; import { TrialDispatcher } from './trialDispatcher';
import { RemoteConfig } from './remote/remoteConfig'; import { RemoteConfig } from './remote/remoteConfig';
import { HeterogenousConfig } from './heterogenous/heterogenousConfig';
/** /**
...@@ -31,7 +28,6 @@ import { RemoteConfig } from './remote/remoteConfig'; ...@@ -31,7 +28,6 @@ import { RemoteConfig } from './remote/remoteConfig';
class RouterTrainingService implements TrainingService { class RouterTrainingService implements TrainingService {
protected readonly log!: Logger; protected readonly log!: Logger;
private internalTrainingService: TrainingService | undefined; private internalTrainingService: TrainingService | undefined;
private metaDataCache: Map<string, string> = new Map<string, string>();
constructor() { constructor() {
this.log = getLogger(); this.log = getLogger();
...@@ -99,75 +95,70 @@ class RouterTrainingService implements TrainingService { ...@@ -99,75 +95,70 @@ class RouterTrainingService implements TrainingService {
public async setClusterMetadata(key: string, value: string): Promise<void> { public async setClusterMetadata(key: string, value: string): Promise<void> {
if (this.internalTrainingService === undefined) { if (this.internalTrainingService === undefined) {
if (key === TrialConfigMetadataKey.PAI_CLUSTER_CONFIG) { // Need to refactor configuration, remove heterogeneous_config field in the future
if (key === TrialConfigMetadataKey.HETEROGENEOUS_CONFIG){
this.internalTrainingService = component.get(TrialDispatcher);
const heterogenousConfig: HeterogenousConfig = <HeterogenousConfig>JSON.parse(value);
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
// Initialize storageService for pai, only support singleton for now, need refactor
if (heterogenousConfig.trainingServicePlatforms.includes('pai')) {
Container.bind(StorageService)
.to(MountedStorageService)
.scope(Scope.Singleton);
}
await this.internalTrainingService.setClusterMetadata('platform_list',
heterogenousConfig.trainingServicePlatforms.join(','));
} else if (key === TrialConfigMetadataKey.LOCAL_CONFIG) {
this.internalTrainingService = component.get(TrialDispatcher);
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
await this.internalTrainingService.setClusterMetadata('platform_list', 'local');
} else if (key === TrialConfigMetadataKey.PAI_CLUSTER_CONFIG) {
const config = <PAIClusterConfig>JSON.parse(value); const config = <PAIClusterConfig>JSON.parse(value);
if (config.reuse === true) { if (config.reuse === true) {
this.log.info(`reuse flag enabled, use EnvironmentManager.`); this.log.info(`reuse flag enabled, use EnvironmentManager.`);
this.internalTrainingService = component.get(TrialDispatcher); this.internalTrainingService = component.get(TrialDispatcher);
// TODO to support other serivces later.
Container.bind(EnvironmentService)
.to(OpenPaiEnvironmentService)
.scope(Scope.Singleton);
// TODO to support other storages later. // TODO to support other storages later.
Container.bind(StorageService) Container.bind(StorageService)
.to(MountedStorageService) .to(MountedStorageService)
.scope(Scope.Singleton); .scope(Scope.Singleton);
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
await this.internalTrainingService.setClusterMetadata('platform_list', 'pai');
} else { } else {
this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`); this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`);
this.internalTrainingService = component.get(PAIK8STrainingService); this.internalTrainingService = component.get(PAIK8STrainingService);
} }
for (const [key, value] of this.metaDataCache) {
if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
}
if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
this.metaDataCache.clear();
} else if (key === TrialConfigMetadataKey.AML_CLUSTER_CONFIG) { } else if (key === TrialConfigMetadataKey.AML_CLUSTER_CONFIG) {
this.internalTrainingService = component.get(TrialDispatcher); this.internalTrainingService = component.get(TrialDispatcher);
Container.bind(EnvironmentService)
.to(AMLEnvironmentService)
.scope(Scope.Singleton);
for (const [key, value] of this.metaDataCache) {
if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
}
if (this.internalTrainingService === undefined) { if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!"); throw new Error("internalTrainingService not initialized!");
} }
await this.internalTrainingService.setClusterMetadata(key, value); await this.internalTrainingService.setClusterMetadata('platform_list', 'aml');
this.metaDataCache.clear();
} else if (key === TrialConfigMetadataKey.REMOTE_CONFIG) { } else if (key === TrialConfigMetadataKey.REMOTE_CONFIG) {
const config = <RemoteConfig>JSON.parse(value); const config = <RemoteConfig>JSON.parse(value);
if (config.reuse === true) { if (config.reuse === true) {
this.log.info(`reuse flag enabled, use EnvironmentManager.`); this.log.info(`reuse flag enabled, use EnvironmentManager.`);
this.internalTrainingService = component.get(TrialDispatcher); this.internalTrainingService = component.get(TrialDispatcher);
Container.bind(EnvironmentService) if (this.internalTrainingService === undefined) {
.to(RemoteEnvironmentService) throw new Error("internalTrainingService not initialized!");
.scope(Scope.Singleton); }
await this.internalTrainingService.setClusterMetadata('platform_list', 'remote');
} else { } else {
this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`); this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`);
this.internalTrainingService = component.get(RemoteMachineTrainingService); this.internalTrainingService = component.get(RemoteMachineTrainingService);
} }
} else {
this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`);
this.metaDataCache.set(key, value);
} }
} else {
await this.internalTrainingService.setClusterMetadata(key, value);
} }
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
} }
public async getClusterMetadata(key: string): Promise<string> { public async getClusterMetadata(key: string): Promise<string> {
......
...@@ -3,15 +3,13 @@ ...@@ -3,15 +3,13 @@
import * as chai from 'chai'; import * as chai from 'chai';
import * as path from 'path'; import * as path from 'path';
import { Scope } from "typescript-ioc";
import * as component from '../../../common/component';
import { getLogger, Logger } from "../../../common/log"; import { getLogger, Logger } from "../../../common/log";
import { TrialJobApplicationForm, TrialJobStatus } from '../../../common/trainingService'; import { TrialJobApplicationForm, TrialJobStatus } from '../../../common/trainingService';
import { cleanupUnitTest, delay, prepareUnitTest, uniqueString } from '../../../common/utils'; import { cleanupUnitTest, delay, prepareUnitTest, uniqueString } from '../../../common/utils';
import { INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, SEND_TRIAL_JOB_PARAMETER, TRIAL_END, GPU_INFO } from '../../../core/commands'; import { INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, SEND_TRIAL_JOB_PARAMETER, TRIAL_END, GPU_INFO } from '../../../core/commands';
import { TrialConfigMetadataKey } from '../../../training_service/common/trialConfigMetadataKey'; import { TrialConfigMetadataKey } from '../../../training_service/common/trialConfigMetadataKey';
import { Command } from '../commandChannel'; import { Command, CommandChannel } from '../commandChannel';
import { EnvironmentInformation, EnvironmentService } from "../environment"; import { Channel, EnvironmentInformation, EnvironmentService } from "../environment";
import { TrialDetail } from '../trial'; import { TrialDetail } from '../trial';
import { TrialDispatcher } from "../trialDispatcher"; import { TrialDispatcher } from "../trialDispatcher";
import { UtCommandChannel } from './utCommandChannel'; import { UtCommandChannel } from './utCommandChannel';
...@@ -54,7 +52,7 @@ async function waitResult<TResult>(callback: () => Promise<TResult | undefined>, ...@@ -54,7 +52,7 @@ async function waitResult<TResult>(callback: () => Promise<TResult | undefined>,
return undefined; return undefined;
} }
async function waitResultMust<TResult>(callback: () => Promise<TResult | undefined>, waitMs: number = 1000, interval: number = 1): Promise<TResult> { async function waitResultMust<TResult>(callback: () => Promise<TResult | undefined>, waitMs: number = 10000, interval: number = 1): Promise<TResult> {
const result = await waitResult(callback, waitMs, interval, true); const result = await waitResult(callback, waitMs, interval, true);
// this error should be thrown in waitResult already. // this error should be thrown in waitResult already.
if (result === undefined) { if (result === undefined) {
...@@ -201,16 +199,21 @@ describe('Unit Test for TrialDispatcher', () => { ...@@ -201,16 +199,21 @@ describe('Unit Test for TrialDispatcher', () => {
nniManagerIp: "127.0.0.1", nniManagerIp: "127.0.0.1",
} }
trialDispatcher = new TrialDispatcher(); trialDispatcher = new TrialDispatcher();
component.Container.bind(EnvironmentService)
.to(UtEnvironmentService)
.scope(Scope.Singleton);
await trialDispatcher.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, JSON.stringify(trialConfig)); await trialDispatcher.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, JSON.stringify(trialConfig));
await trialDispatcher.setClusterMetadata(TrialConfigMetadataKey.NNI_MANAGER_IP, JSON.stringify(nniManagerIpConfig)); await trialDispatcher.setClusterMetadata(TrialConfigMetadataKey.NNI_MANAGER_IP, JSON.stringify(nniManagerIpConfig));
trialRunPromise = trialDispatcher.run(); // set ut environment
let environmentServiceList: EnvironmentService[] = [];
environmentService = new UtEnvironmentService();
environmentServiceList.push(environmentService);
trialDispatcher.environmentServiceList = environmentServiceList;
// set ut command channel
environmentService.initCommandChannel(trialDispatcher.commandEmitter);
commandChannel = environmentService.getCommandChannel as UtCommandChannel;
trialDispatcher.commandChannelSet = new Set<CommandChannel>().add(environmentService.getCommandChannel);
trialDispatcher.environmentMaintenceLoopInterval = 1000;
environmentService = component.get(EnvironmentService) as UtEnvironmentService; trialRunPromise = trialDispatcher.run();
commandChannel = environmentService.testGetCommandChannel();
}); });
afterEach(async () => { afterEach(async () => {
...@@ -258,9 +261,6 @@ describe('Unit Test for TrialDispatcher', () => { ...@@ -258,9 +261,6 @@ describe('Unit Test for TrialDispatcher', () => {
await waitEnvironment(2, previousEnvironments, environmentService, commandChannel); await waitEnvironment(2, previousEnvironments, environmentService, commandChannel);
await verifyTrialRunning(commandChannel, trialDetail); await verifyTrialRunning(commandChannel, trialDetail);
await verifyTrialResult(commandChannel, trialDetail, -1); await verifyTrialResult(commandChannel, trialDetail, -1);
await waitResultMust<true>(async () => {
return environment.status === 'USER_CANCELED' ? true : undefined;
});
chai.assert.equal(environmentService.testGetEnvironments().size, 2, "as env not reused, so only 2 envs should be here."); chai.assert.equal(environmentService.testGetEnvironments().size, 2, "as env not reused, so only 2 envs should be here.");
const trials = await trialDispatcher.listTrialJobs(); const trials = await trialDispatcher.listTrialJobs();
...@@ -433,12 +433,10 @@ describe('Unit Test for TrialDispatcher', () => { ...@@ -433,12 +433,10 @@ describe('Unit Test for TrialDispatcher', () => {
let environment = await waitEnvironment(1, previousEnvironments, environmentService, commandChannel); let environment = await waitEnvironment(1, previousEnvironments, environmentService, commandChannel);
await verifyTrialRunning(commandChannel, trialDetail); await verifyTrialRunning(commandChannel, trialDetail);
await verifyTrialResult(commandChannel, trialDetail, 0); await verifyTrialResult(commandChannel, trialDetail, 0);
environmentService.testSetEnvironmentStatus(environment, 'SUCCEEDED'); environmentService.testSetEnvironmentStatus(environment, 'SUCCEEDED');
await waitResultMust<boolean>(async () => { await waitResultMust<boolean>(async () => {
return environment.status === 'SUCCEEDED' ? true : undefined; return environment.status === 'SUCCEEDED' ? true : undefined;
}); });
trialDetail = await newTrial(trialDispatcher); trialDetail = await newTrial(trialDispatcher);
await waitEnvironment(2, previousEnvironments, environmentService, commandChannel); await waitEnvironment(2, previousEnvironments, environmentService, commandChannel);
await verifyTrialRunning(commandChannel, trialDetail); await verifyTrialRunning(commandChannel, trialDetail);
......
// Copyright (c) Microsoft Corporation. // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license. // Licensed under the MIT license.
import { EnvironmentInformation, EnvironmentService, EnvironmentStatus } from "../environment"; import { Channel, EnvironmentInformation, EnvironmentService, EnvironmentStatus } from "../environment";
import { EventEmitter } from "events"; import { EventEmitter } from 'events';
import { CommandChannel } from "../commandChannel";
import { UtCommandChannel } from "./utCommandChannel"; import { UtCommandChannel } from "./utCommandChannel";
export class UtEnvironmentService extends EnvironmentService { export class UtEnvironmentService extends EnvironmentService {
private commandChannel: UtCommandChannel | undefined;
private allEnvironments = new Map<string, EnvironmentInformation>(); private allEnvironments = new Map<string, EnvironmentInformation>();
private hasMoreEnvironmentsInternal = true; private hasMoreEnvironmentsInternal = true;
...@@ -23,6 +21,14 @@ export class UtEnvironmentService extends EnvironmentService { ...@@ -23,6 +21,14 @@ export class UtEnvironmentService extends EnvironmentService {
return 1; return 1;
} }
public get getName(): string {
return 'ut';
}
public initCommandChannel(eventEmitter: EventEmitter): void {
this.commandChannel = new UtCommandChannel(eventEmitter);
}
public testSetEnvironmentStatus(environment: EnvironmentInformation, newStatus: EnvironmentStatus): void { public testSetEnvironmentStatus(environment: EnvironmentInformation, newStatus: EnvironmentStatus): void {
environment.status = newStatus; environment.status = newStatus;
} }
...@@ -35,13 +41,6 @@ export class UtEnvironmentService extends EnvironmentService { ...@@ -35,13 +41,6 @@ export class UtEnvironmentService extends EnvironmentService {
return this.allEnvironments; return this.allEnvironments;
} }
public testGetCommandChannel(): UtCommandChannel {
if (this.commandChannel === undefined) {
throw new Error(`command channel shouldn't be undefined.`);
}
return this.commandChannel;
}
public testSetNoMoreEnvironment(hasMore: boolean): void { public testSetNoMoreEnvironment(hasMore: boolean): void {
this.hasMoreEnvironmentsInternal = hasMore; this.hasMoreEnvironmentsInternal = hasMore;
} }
...@@ -50,11 +49,6 @@ export class UtEnvironmentService extends EnvironmentService { ...@@ -50,11 +49,6 @@ export class UtEnvironmentService extends EnvironmentService {
return this.hasMoreEnvironmentsInternal; return this.hasMoreEnvironmentsInternal;
} }
public createCommandChannel(commandEmitter: EventEmitter): CommandChannel {
this.commandChannel = new UtCommandChannel(commandEmitter)
return this.commandChannel;
}
public async config(_key: string, _value: string): Promise<void> { public async config(_key: string, _value: string): Promise<void> {
// do nothing // do nothing
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment