"docs/vscode:/vscode.git/clone" did not exist on "af89df8c56b197a3e3d6fa08767f8d27971e3a3a"
Unverified Commit 872554f1 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Support heterogeneous environment service (#3097)

parent dec91f7e
......@@ -45,6 +45,10 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
return true;
}
public get getName(): string {
return 'pai';
}
public async config(key: string, value: string): Promise<void> {
switch (key) {
case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG:
......
......@@ -63,6 +63,10 @@ export class RemoteEnvironmentService extends EnvironmentService {
return false;
}
public get getName(): string {
return 'remote';
}
public async config(key: string, value: string): Promise<void> {
switch (key) {
case TrialConfigMetadataKey.MACHINE_LIST:
......@@ -134,7 +138,15 @@ export class RemoteEnvironmentService extends EnvironmentService {
await executor.createFolder(remoteGpuScriptCollectorDir, true);
await executor.allowPermission(true, nniRootDir);
}
public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
const tasks: Promise<void>[] = [];
environments.forEach(async (environment) => {
tasks.push(this.refreshEnvironment(environment));
});
await Promise.all(tasks);
}
private async refreshEnvironment(environment: EnvironmentInformation): Promise<void> {
const executor = await this.getExecutor(environment.id);
const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`;
......@@ -176,14 +188,6 @@ export class RemoteEnvironmentService extends EnvironmentService {
}
}
public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
const tasks: Promise<void>[] = [];
environments.forEach(async (environment) => {
tasks.push(this.refreshEnvironment(environment));
});
await Promise.all(tasks);
}
/**
* If a environment is finished, release the connection resource
* @param environment remote machine environment job detail
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
export class HeterogenousConfig {
public readonly trainingServicePlatforms: string[];
constructor(trainingServicePlatforms: string[]) {
this.trainingServicePlatforms = trainingServicePlatforms;
}
}
......@@ -13,14 +13,11 @@ import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { PAIClusterConfig } from '../pai/paiConfig';
import { PAIK8STrainingService } from '../pai/paiK8S/paiK8STrainingService';
import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService';
import { EnvironmentService } from './environment';
import { OpenPaiEnvironmentService } from './environments/openPaiEnvironmentService';
import { AMLEnvironmentService } from './environments/amlEnvironmentService';
import { RemoteEnvironmentService } from './environments/remoteEnvironmentService';
import { MountedStorageService } from './storages/mountedStorageService';
import { StorageService } from './storageService';
import { TrialDispatcher } from './trialDispatcher';
import { RemoteConfig } from './remote/remoteConfig';
import { HeterogenousConfig } from './heterogenous/heterogenousConfig';
/**
......@@ -31,7 +28,6 @@ import { RemoteConfig } from './remote/remoteConfig';
class RouterTrainingService implements TrainingService {
protected readonly log!: Logger;
private internalTrainingService: TrainingService | undefined;
private metaDataCache: Map<string, string> = new Map<string, string>();
constructor() {
this.log = getLogger();
......@@ -99,75 +95,70 @@ class RouterTrainingService implements TrainingService {
public async setClusterMetadata(key: string, value: string): Promise<void> {
if (this.internalTrainingService === undefined) {
if (key === TrialConfigMetadataKey.PAI_CLUSTER_CONFIG) {
// Need to refactor configuration, remove heterogeneous_config field in the future
if (key === TrialConfigMetadataKey.HETEROGENEOUS_CONFIG){
this.internalTrainingService = component.get(TrialDispatcher);
const heterogenousConfig: HeterogenousConfig = <HeterogenousConfig>JSON.parse(value);
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
// Initialize storageService for pai, only support singleton for now, need refactor
if (heterogenousConfig.trainingServicePlatforms.includes('pai')) {
Container.bind(StorageService)
.to(MountedStorageService)
.scope(Scope.Singleton);
}
await this.internalTrainingService.setClusterMetadata('platform_list',
heterogenousConfig.trainingServicePlatforms.join(','));
} else if (key === TrialConfigMetadataKey.LOCAL_CONFIG) {
this.internalTrainingService = component.get(TrialDispatcher);
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
await this.internalTrainingService.setClusterMetadata('platform_list', 'local');
} else if (key === TrialConfigMetadataKey.PAI_CLUSTER_CONFIG) {
const config = <PAIClusterConfig>JSON.parse(value);
if (config.reuse === true) {
this.log.info(`reuse flag enabled, use EnvironmentManager.`);
this.internalTrainingService = component.get(TrialDispatcher);
// TODO to support other serivces later.
Container.bind(EnvironmentService)
.to(OpenPaiEnvironmentService)
.scope(Scope.Singleton);
// TODO to support other storages later.
Container.bind(StorageService)
.to(MountedStorageService)
.scope(Scope.Singleton);
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
await this.internalTrainingService.setClusterMetadata('platform_list', 'pai');
} else {
this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`);
this.internalTrainingService = component.get(PAIK8STrainingService);
}
for (const [key, value] of this.metaDataCache) {
if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
}
if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
this.metaDataCache.clear();
} else if (key === TrialConfigMetadataKey.AML_CLUSTER_CONFIG) {
this.internalTrainingService = component.get(TrialDispatcher);
Container.bind(EnvironmentService)
.to(AMLEnvironmentService)
.scope(Scope.Singleton);
for (const [key, value] of this.metaDataCache) {
if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
}
if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!");
throw new Error("internalTrainingService not initialized!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
this.metaDataCache.clear();
await this.internalTrainingService.setClusterMetadata('platform_list', 'aml');
} else if (key === TrialConfigMetadataKey.REMOTE_CONFIG) {
const config = <RemoteConfig>JSON.parse(value);
if (config.reuse === true) {
this.log.info(`reuse flag enabled, use EnvironmentManager.`);
this.internalTrainingService = component.get(TrialDispatcher);
Container.bind(EnvironmentService)
.to(RemoteEnvironmentService)
.scope(Scope.Singleton);
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
await this.internalTrainingService.setClusterMetadata('platform_list', 'remote');
} else {
this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`);
this.internalTrainingService = component.get(RemoteMachineTrainingService);
}
} else {
this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`);
this.metaDataCache.set(key, value);
}
} else {
await this.internalTrainingService.setClusterMetadata(key, value);
}
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
}
public async getClusterMetadata(key: string): Promise<string> {
......
......@@ -3,15 +3,13 @@
import * as chai from 'chai';
import * as path from 'path';
import { Scope } from "typescript-ioc";
import * as component from '../../../common/component';
import { getLogger, Logger } from "../../../common/log";
import { TrialJobApplicationForm, TrialJobStatus } from '../../../common/trainingService';
import { cleanupUnitTest, delay, prepareUnitTest, uniqueString } from '../../../common/utils';
import { INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, SEND_TRIAL_JOB_PARAMETER, TRIAL_END, GPU_INFO } from '../../../core/commands';
import { TrialConfigMetadataKey } from '../../../training_service/common/trialConfigMetadataKey';
import { Command } from '../commandChannel';
import { EnvironmentInformation, EnvironmentService } from "../environment";
import { Command, CommandChannel } from '../commandChannel';
import { Channel, EnvironmentInformation, EnvironmentService } from "../environment";
import { TrialDetail } from '../trial';
import { TrialDispatcher } from "../trialDispatcher";
import { UtCommandChannel } from './utCommandChannel';
......@@ -54,7 +52,7 @@ async function waitResult<TResult>(callback: () => Promise<TResult | undefined>,
return undefined;
}
async function waitResultMust<TResult>(callback: () => Promise<TResult | undefined>, waitMs: number = 1000, interval: number = 1): Promise<TResult> {
async function waitResultMust<TResult>(callback: () => Promise<TResult | undefined>, waitMs: number = 10000, interval: number = 1): Promise<TResult> {
const result = await waitResult(callback, waitMs, interval, true);
// this error should be thrown in waitResult already.
if (result === undefined) {
......@@ -201,16 +199,21 @@ describe('Unit Test for TrialDispatcher', () => {
nniManagerIp: "127.0.0.1",
}
trialDispatcher = new TrialDispatcher();
component.Container.bind(EnvironmentService)
.to(UtEnvironmentService)
.scope(Scope.Singleton);
await trialDispatcher.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, JSON.stringify(trialConfig));
await trialDispatcher.setClusterMetadata(TrialConfigMetadataKey.NNI_MANAGER_IP, JSON.stringify(nniManagerIpConfig));
trialRunPromise = trialDispatcher.run();
// set ut environment
let environmentServiceList: EnvironmentService[] = [];
environmentService = new UtEnvironmentService();
environmentServiceList.push(environmentService);
trialDispatcher.environmentServiceList = environmentServiceList;
// set ut command channel
environmentService.initCommandChannel(trialDispatcher.commandEmitter);
commandChannel = environmentService.getCommandChannel as UtCommandChannel;
trialDispatcher.commandChannelSet = new Set<CommandChannel>().add(environmentService.getCommandChannel);
trialDispatcher.environmentMaintenceLoopInterval = 1000;
environmentService = component.get(EnvironmentService) as UtEnvironmentService;
commandChannel = environmentService.testGetCommandChannel();
trialRunPromise = trialDispatcher.run();
});
afterEach(async () => {
......@@ -258,9 +261,6 @@ describe('Unit Test for TrialDispatcher', () => {
await waitEnvironment(2, previousEnvironments, environmentService, commandChannel);
await verifyTrialRunning(commandChannel, trialDetail);
await verifyTrialResult(commandChannel, trialDetail, -1);
await waitResultMust<true>(async () => {
return environment.status === 'USER_CANCELED' ? true : undefined;
});
chai.assert.equal(environmentService.testGetEnvironments().size, 2, "as env not reused, so only 2 envs should be here.");
const trials = await trialDispatcher.listTrialJobs();
......@@ -433,12 +433,10 @@ describe('Unit Test for TrialDispatcher', () => {
let environment = await waitEnvironment(1, previousEnvironments, environmentService, commandChannel);
await verifyTrialRunning(commandChannel, trialDetail);
await verifyTrialResult(commandChannel, trialDetail, 0);
environmentService.testSetEnvironmentStatus(environment, 'SUCCEEDED');
await waitResultMust<boolean>(async () => {
return environment.status === 'SUCCEEDED' ? true : undefined;
});
trialDetail = await newTrial(trialDispatcher);
await waitEnvironment(2, previousEnvironments, environmentService, commandChannel);
await verifyTrialRunning(commandChannel, trialDetail);
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import { EnvironmentInformation, EnvironmentService, EnvironmentStatus } from "../environment";
import { EventEmitter } from "events";
import { CommandChannel } from "../commandChannel";
import { Channel, EnvironmentInformation, EnvironmentService, EnvironmentStatus } from "../environment";
import { EventEmitter } from 'events';
import { UtCommandChannel } from "./utCommandChannel";
export class UtEnvironmentService extends EnvironmentService {
private commandChannel: UtCommandChannel | undefined;
private allEnvironments = new Map<string, EnvironmentInformation>();
private hasMoreEnvironmentsInternal = true;
......@@ -23,6 +21,14 @@ export class UtEnvironmentService extends EnvironmentService {
return 1;
}
public get getName(): string {
return 'ut';
}
public initCommandChannel(eventEmitter: EventEmitter): void {
this.commandChannel = new UtCommandChannel(eventEmitter);
}
public testSetEnvironmentStatus(environment: EnvironmentInformation, newStatus: EnvironmentStatus): void {
environment.status = newStatus;
}
......@@ -35,13 +41,6 @@ export class UtEnvironmentService extends EnvironmentService {
return this.allEnvironments;
}
public testGetCommandChannel(): UtCommandChannel {
if (this.commandChannel === undefined) {
throw new Error(`command channel shouldn't be undefined.`);
}
return this.commandChannel;
}
public testSetNoMoreEnvironment(hasMore: boolean): void {
this.hasMoreEnvironmentsInternal = hasMore;
}
......@@ -50,11 +49,6 @@ export class UtEnvironmentService extends EnvironmentService {
return this.hasMoreEnvironmentsInternal;
}
public createCommandChannel(commandEmitter: EventEmitter): CommandChannel {
this.commandChannel = new UtCommandChannel(commandEmitter)
return this.commandChannel;
}
public async config(_key: string, _value: string): Promise<void> {
// do nothing
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment