Unverified Commit 872554f1 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Support heterogeneous environment service (#3097)

parent dec91f7e
......@@ -45,6 +45,10 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
return true;
}
/**
 * Short name identifying this environment service.
 * @returns the literal string 'pai'
 */
public get getName(): string {
return 'pai';
}
public async config(key: string, value: string): Promise<void> {
switch (key) {
case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG:
......
......@@ -63,6 +63,10 @@ export class RemoteEnvironmentService extends EnvironmentService {
return false;
}
/**
 * Short name identifying this environment service.
 * @returns the literal string 'remote'
 */
public get getName(): string {
return 'remote';
}
public async config(key: string, value: string): Promise<void> {
switch (key) {
case TrialConfigMetadataKey.MACHINE_LIST:
......@@ -134,7 +138,15 @@ export class RemoteEnvironmentService extends EnvironmentService {
await executor.createFolder(remoteGpuScriptCollectorDir, true);
await executor.allowPermission(true, nniRootDir);
}
/**
 * Refresh the status of every given environment concurrently.
 *
 * One refreshEnvironment() call is started per environment. The returned promise
 * resolves once all of them have resolved and rejects as soon as one of them rejects.
 *
 * @param environments environments whose status should be refreshed
 */
public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
    // Collect the promises with map(): an async forEach callback would add one
    // extra promise per element that nobody awaits.
    await Promise.all(environments.map((environment) => this.refreshEnvironment(environment)));
}
private async refreshEnvironment(environment: EnvironmentInformation): Promise<void> {
const executor = await this.getExecutor(environment.id);
const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`;
......@@ -176,14 +188,6 @@ export class RemoteEnvironmentService extends EnvironmentService {
}
}
/**
 * Refresh the status of every given environment concurrently.
 *
 * One refreshEnvironment() call is started per environment. The returned promise
 * resolves once all of them have resolved and rejects as soon as one of them rejects.
 *
 * @param environments environments whose status should be refreshed
 */
public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
    // Collect the promises with map(): an async forEach callback would add one
    // extra promise per element that nobody awaits.
    await Promise.all(environments.map((environment) => this.refreshEnvironment(environment)));
}
/**
* If an environment is finished, release the connection resource
* @param environment remote machine environment job detail
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
/**
 * Holds the list of training service platform names it was constructed with.
 */
export class HeterogenousConfig {
    /**
     * @param trainingServicePlatforms platform names; the array is stored as given, not copied
     */
    constructor(public readonly trainingServicePlatforms: string[]) {}
}
......@@ -13,14 +13,11 @@ import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { PAIClusterConfig } from '../pai/paiConfig';
import { PAIK8STrainingService } from '../pai/paiK8S/paiK8STrainingService';
import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService';
import { EnvironmentService } from './environment';
import { OpenPaiEnvironmentService } from './environments/openPaiEnvironmentService';
import { AMLEnvironmentService } from './environments/amlEnvironmentService';
import { RemoteEnvironmentService } from './environments/remoteEnvironmentService';
import { MountedStorageService } from './storages/mountedStorageService';
import { StorageService } from './storageService';
import { TrialDispatcher } from './trialDispatcher';
import { RemoteConfig } from './remote/remoteConfig';
import { HeterogenousConfig } from './heterogenous/heterogenousConfig';
/**
......@@ -31,7 +28,6 @@ import { RemoteConfig } from './remote/remoteConfig';
class RouterTrainingService implements TrainingService {
protected readonly log!: Logger;
private internalTrainingService: TrainingService | undefined;
private metaDataCache: Map<string, string> = new Map<string, string>();
constructor() {
this.log = getLogger();
......@@ -99,75 +95,70 @@ class RouterTrainingService implements TrainingService {
public async setClusterMetadata(key: string, value: string): Promise<void> {
if (this.internalTrainingService === undefined) {
if (key === TrialConfigMetadataKey.PAI_CLUSTER_CONFIG) {
// Need to refactor configuration, remove heterogeneous_config field in the future
if (key === TrialConfigMetadataKey.HETEROGENEOUS_CONFIG){
this.internalTrainingService = component.get(TrialDispatcher);
const heterogenousConfig: HeterogenousConfig = <HeterogenousConfig>JSON.parse(value);
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
// Initialize storageService for pai, only support singleton for now, need refactor
if (heterogenousConfig.trainingServicePlatforms.includes('pai')) {
Container.bind(StorageService)
.to(MountedStorageService)
.scope(Scope.Singleton);
}
await this.internalTrainingService.setClusterMetadata('platform_list',
heterogenousConfig.trainingServicePlatforms.join(','));
} else if (key === TrialConfigMetadataKey.LOCAL_CONFIG) {
this.internalTrainingService = component.get(TrialDispatcher);
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
await this.internalTrainingService.setClusterMetadata('platform_list', 'local');
} else if (key === TrialConfigMetadataKey.PAI_CLUSTER_CONFIG) {
const config = <PAIClusterConfig>JSON.parse(value);
if (config.reuse === true) {
this.log.info(`reuse flag enabled, use EnvironmentManager.`);
this.internalTrainingService = component.get(TrialDispatcher);
// TODO: support other services later.
Container.bind(EnvironmentService)
.to(OpenPaiEnvironmentService)
.scope(Scope.Singleton);
// TODO to support other storages later.
Container.bind(StorageService)
.to(MountedStorageService)
.scope(Scope.Singleton);
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
await this.internalTrainingService.setClusterMetadata('platform_list', 'pai');
} else {
this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`);
this.internalTrainingService = component.get(PAIK8STrainingService);
}
for (const [key, value] of this.metaDataCache) {
if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
}
if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
this.metaDataCache.clear();
} else if (key === TrialConfigMetadataKey.AML_CLUSTER_CONFIG) {
this.internalTrainingService = component.get(TrialDispatcher);
Container.bind(EnvironmentService)
.to(AMLEnvironmentService)
.scope(Scope.Singleton);
for (const [key, value] of this.metaDataCache) {
if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
}
if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!");
throw new Error("internalTrainingService not initialized!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
this.metaDataCache.clear();
await this.internalTrainingService.setClusterMetadata('platform_list', 'aml');
} else if (key === TrialConfigMetadataKey.REMOTE_CONFIG) {
const config = <RemoteConfig>JSON.parse(value);
if (config.reuse === true) {
this.log.info(`reuse flag enabled, use EnvironmentManager.`);
this.internalTrainingService = component.get(TrialDispatcher);
Container.bind(EnvironmentService)
.to(RemoteEnvironmentService)
.scope(Scope.Singleton);
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
await this.internalTrainingService.setClusterMetadata('platform_list', 'remote');
} else {
this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`);
this.internalTrainingService = component.get(RemoteMachineTrainingService);
}
} else {
this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`);
this.metaDataCache.set(key, value);
}
} else {
await this.internalTrainingService.setClusterMetadata(key, value);
}
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
}
public async getClusterMetadata(key: string): Promise<string> {
......
......@@ -3,15 +3,13 @@
import * as chai from 'chai';
import * as path from 'path';
import { Scope } from "typescript-ioc";
import * as component from '../../../common/component';
import { getLogger, Logger } from "../../../common/log";
import { TrialJobApplicationForm, TrialJobStatus } from '../../../common/trainingService';
import { cleanupUnitTest, delay, prepareUnitTest, uniqueString } from '../../../common/utils';
import { INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, SEND_TRIAL_JOB_PARAMETER, TRIAL_END, GPU_INFO } from '../../../core/commands';
import { TrialConfigMetadataKey } from '../../../training_service/common/trialConfigMetadataKey';
import { Command } from '../commandChannel';
import { EnvironmentInformation, EnvironmentService } from "../environment";
import { Command, CommandChannel } from '../commandChannel';
import { Channel, EnvironmentInformation, EnvironmentService } from "../environment";
import { TrialDetail } from '../trial';
import { TrialDispatcher } from "../trialDispatcher";
import { UtCommandChannel } from './utCommandChannel';
......@@ -54,7 +52,7 @@ async function waitResult<TResult>(callback: () => Promise<TResult | undefined>,
return undefined;
}
async function waitResultMust<TResult>(callback: () => Promise<TResult | undefined>, waitMs: number = 1000, interval: number = 1): Promise<TResult> {
async function waitResultMust<TResult>(callback: () => Promise<TResult | undefined>, waitMs: number = 10000, interval: number = 1): Promise<TResult> {
const result = await waitResult(callback, waitMs, interval, true);
// this error should be thrown in waitResult already.
if (result === undefined) {
......@@ -201,16 +199,21 @@ describe('Unit Test for TrialDispatcher', () => {
nniManagerIp: "127.0.0.1",
}
trialDispatcher = new TrialDispatcher();
component.Container.bind(EnvironmentService)
.to(UtEnvironmentService)
.scope(Scope.Singleton);
await trialDispatcher.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, JSON.stringify(trialConfig));
await trialDispatcher.setClusterMetadata(TrialConfigMetadataKey.NNI_MANAGER_IP, JSON.stringify(nniManagerIpConfig));
trialRunPromise = trialDispatcher.run();
// set ut environment
let environmentServiceList: EnvironmentService[] = [];
environmentService = new UtEnvironmentService();
environmentServiceList.push(environmentService);
trialDispatcher.environmentServiceList = environmentServiceList;
// set ut command channel
environmentService.initCommandChannel(trialDispatcher.commandEmitter);
commandChannel = environmentService.getCommandChannel as UtCommandChannel;
trialDispatcher.commandChannelSet = new Set<CommandChannel>().add(environmentService.getCommandChannel);
trialDispatcher.environmentMaintenceLoopInterval = 1000;
environmentService = component.get(EnvironmentService) as UtEnvironmentService;
commandChannel = environmentService.testGetCommandChannel();
trialRunPromise = trialDispatcher.run();
});
afterEach(async () => {
......@@ -258,9 +261,6 @@ describe('Unit Test for TrialDispatcher', () => {
await waitEnvironment(2, previousEnvironments, environmentService, commandChannel);
await verifyTrialRunning(commandChannel, trialDetail);
await verifyTrialResult(commandChannel, trialDetail, -1);
await waitResultMust<true>(async () => {
return environment.status === 'USER_CANCELED' ? true : undefined;
});
chai.assert.equal(environmentService.testGetEnvironments().size, 2, "as env not reused, so only 2 envs should be here.");
const trials = await trialDispatcher.listTrialJobs();
......@@ -433,12 +433,10 @@ describe('Unit Test for TrialDispatcher', () => {
let environment = await waitEnvironment(1, previousEnvironments, environmentService, commandChannel);
await verifyTrialRunning(commandChannel, trialDetail);
await verifyTrialResult(commandChannel, trialDetail, 0);
environmentService.testSetEnvironmentStatus(environment, 'SUCCEEDED');
await waitResultMust<boolean>(async () => {
return environment.status === 'SUCCEEDED' ? true : undefined;
});
trialDetail = await newTrial(trialDispatcher);
await waitEnvironment(2, previousEnvironments, environmentService, commandChannel);
await verifyTrialRunning(commandChannel, trialDetail);
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import { EnvironmentInformation, EnvironmentService, EnvironmentStatus } from "../environment";
import { EventEmitter } from "events";
import { CommandChannel } from "../commandChannel";
import { Channel, EnvironmentInformation, EnvironmentService, EnvironmentStatus } from "../environment";
import { EventEmitter } from 'events';
import { UtCommandChannel } from "./utCommandChannel";
export class UtEnvironmentService extends EnvironmentService {
private commandChannel: UtCommandChannel | undefined;
private allEnvironments = new Map<string, EnvironmentInformation>();
private hasMoreEnvironmentsInternal = true;
......@@ -23,6 +21,14 @@ export class UtEnvironmentService extends EnvironmentService {
return 1;
}
/**
 * Short name identifying this unit-test environment service.
 * @returns the literal string 'ut'
 */
public get getName(): string {
return 'ut';
}
/**
 * Create a new UtCommandChannel from the given emitter and keep it in
 * this.commandChannel, replacing any previously stored channel.
 * @param eventEmitter emitter passed to the UtCommandChannel constructor
 */
public initCommandChannel(eventEmitter: EventEmitter): void {
this.commandChannel = new UtCommandChannel(eventEmitter);
}
/**
 * Test helper: overwrite the status of an environment in place.
 * @param environment environment object to mutate
 * @param newStatus status assigned to environment.status
 */
public testSetEnvironmentStatus(environment: EnvironmentInformation, newStatus: EnvironmentStatus): void {
environment.status = newStatus;
}
......@@ -35,13 +41,6 @@ export class UtEnvironmentService extends EnvironmentService {
return this.allEnvironments;
}
/**
 * Test helper: return the command channel held by this service.
 * @returns the stored UtCommandChannel
 * @throws Error when no command channel has been stored yet
 */
public testGetCommandChannel(): UtCommandChannel {
    const channel = this.commandChannel;
    if (channel !== undefined) {
        return channel;
    }
    throw new Error(`command channel shouldn't be undefined.`);
}
/**
 * Test helper: set the hasMoreEnvironmentsInternal flag.
 * NOTE: despite the "NoMore" in the method name, the argument is stored unchanged
 * (not negated).
 * @param hasMore value assigned to hasMoreEnvironmentsInternal
 */
public testSetNoMoreEnvironment(hasMore: boolean): void {
this.hasMoreEnvironmentsInternal = hasMore;
}
......@@ -50,11 +49,6 @@ export class UtEnvironmentService extends EnvironmentService {
return this.hasMoreEnvironmentsInternal;
}
/**
 * Build a UtCommandChannel on top of the given emitter, remember it on this
 * service, and hand it back to the caller.
 * @param commandEmitter emitter passed to the UtCommandChannel constructor
 * @returns the newly created channel (the same instance stored in this.commandChannel)
 */
public createCommandChannel(commandEmitter: EventEmitter): CommandChannel {
    const channel = new UtCommandChannel(commandEmitter);
    this.commandChannel = channel;
    return channel;
}
/**
 * No-op configuration hook: this unit-test service ignores every key/value pair.
 * @param _key unused
 * @param _value unused
 */
public async config(_key: string, _value: string): Promise<void> {
// do nothing
}
......
......@@ -10,10 +10,10 @@ import { Writable } from 'stream';
import { String } from 'typescript-string-operations';
import * as component from '../../common/component';
import { NNIError, NNIErrorNames, MethodNotImplementedError } from '../../common/errors';
import { getBasePort, getExperimentId, getPlatform } from '../../common/experimentStartupInfo';
import { getBasePort, getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log';
import { NNIManagerIpConfig, TrainingService, TrialJobApplicationForm, TrialJobMetric, TrialJobStatus, LogType } from '../../common/trainingService';
import { delay, getExperimentRootDir, getIPV4Address, getLogLevel, getVersion, mkDirPSync, uniqueString } from '../../common/utils';
import { delay, getExperimentRootDir, getIPV4Address, getLogLevel, getVersion, mkDirPSync, randomSelect, uniqueString } from '../../common/utils';
import { GPU_INFO, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, REPORT_METRIC_DATA, SEND_TRIAL_JOB_PARAMETER, STDOUT, TRIAL_END, VERSION_CHECK } from '../../core/commands';
import { ScheduleResultType } from '../../training_service/common/gpuData';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
......@@ -22,6 +22,7 @@ import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { validateCodeDir } from '../common/util';
import { Command, CommandChannel } from './commandChannel';
import { EnvironmentInformation, EnvironmentService, NodeInformation, RunnerSettings, TrialGpuSummary } from './environment';
import { EnvironmentServiceFactory } from './environments/environmentServiceFactory';
import { GpuScheduler } from './gpuScheduler';
import { MountedStorageService } from './storages/mountedStorageService';
import { StorageService } from './storageService';
......@@ -45,13 +46,16 @@ class TrialDispatcher implements TrainingService {
private enableVersionCheck: boolean = true;
private trialConfig: TrialConfig | undefined;
private runnerSettings: RunnerSettings;
private commandEmitter: EventEmitter | undefined;
private commandChannel: CommandChannel | undefined;
private readonly trials: Map<string, TrialDetail>;
private readonly environments: Map<string, EnvironmentInformation>;
// make public for ut
public environmentServiceList: EnvironmentService[] = [];
public commandChannelSet: Set<CommandChannel>;
public commandEmitter: EventEmitter;
public environmentMaintenceLoopInterval: number = -1;
private nniManagerIp: string | undefined;
// used to accelerate the trial manager loop
// true means there are updates, and the trial loop should run a cycle immediately.
......@@ -62,6 +66,7 @@ class TrialDispatcher implements TrainingService {
private enableGpuScheduler: boolean = false;
// records whether the user wants to reuse environments
private reuseEnvironment: boolean = true;
private logCollection: string = '';
private gpuScheduler: GpuScheduler;
......@@ -76,10 +81,7 @@ class TrialDispatcher implements TrainingService {
this.metricsEmitter = new EventEmitter();
this.experimentId = getExperimentId();
this.experimentRootDir = getExperimentRootDir();
this.runnerSettings = new RunnerSettings();
this.runnerSettings.experimentId = this.experimentId;
this.runnerSettings.platform = getPlatform();
this.commandChannelSet = new Set<CommandChannel>();
const logLevel = getLogLevel();
this.log.debug(`current folder ${__dirname}`);
......@@ -89,6 +91,8 @@ class TrialDispatcher implements TrainingService {
this.isDeveloping = true;
}
this.commandEmitter = new EventEmitter();
this.gpuScheduler = new GpuScheduler();
}
......@@ -122,13 +126,7 @@ class TrialDispatcher implements TrainingService {
const trialId: string = uniqueString(5);
const environmentService = component.get<EnvironmentService>(EnvironmentService);
let trialWorkingFolder: string = "";
if (environmentService.hasStorageService) {
const storageService = component.get<StorageService>(StorageService);
trialWorkingFolder = storageService.joinPath('trials', trialId);
}
const trialJobDetail: TrialDetail = new TrialDetail(trialId, "WAITING", Date.now(), trialWorkingFolder, form);
const trialJobDetail: TrialDetail = new TrialDetail(trialId, "WAITING", Date.now(), "", form);
this.trials.set(trialId, trialJobDetail);
......@@ -142,23 +140,20 @@ class TrialDispatcher implements TrainingService {
if (environment === undefined) {
throw new Error(`TrialDispatcher: trial ${trialJobId}'s env shouldn't be undefined in updateTrialJob.`);
}
if (this.commandChannel === undefined) {
throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in updateTrialJob.`);
if (environment.environmentService === undefined) {
throw new Error(`Environment ${environment.id} does not assigned environment service.`);
}
const message = {
"trialId": trialJobId,
"parameters": form.hyperParameters,
}
await this.commandChannel.sendCommand(environment, SEND_TRIAL_JOB_PARAMETER, message);
await environment.environmentService.getCommandChannel.sendCommand(environment, SEND_TRIAL_JOB_PARAMETER, message);
return trialDetail;
}
public async cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean | undefined): Promise<void> {
if (this.commandChannel === undefined) {
throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in cancelTrialJob.`);
}
const trial = await this.getTrialJob(trialJobId);
switch (trial.status) {
case "RUNNING":
......@@ -166,8 +161,8 @@ class TrialDispatcher implements TrainingService {
case "UNKNOWN":
{
const environment = trial.environment;
if (environment) {
await this.commandChannel.sendCommand(environment, KILL_TRIAL_JOB, trial.id);
if (environment && environment.environmentService) {
await environment.environmentService.getCommandChannel.sendCommand(environment, KILL_TRIAL_JOB, trial.id);
trial.isEarlyStopped = isEarlyStopped;
trial.status = trial.isEarlyStopped === true ?
'EARLY_STOPPED' : 'USER_CANCELED';
......@@ -179,70 +174,71 @@ class TrialDispatcher implements TrainingService {
}
public async run(): Promise<void> {
const environmentService = component.get<EnvironmentService>(EnvironmentService);
this.commandEmitter = new EventEmitter();
this.commandChannel = environmentService.createCommandChannel(this.commandEmitter);
// TODO it's a hard code of web channel, it needs to be improved.
if (this.runnerSettings.nniManagerIP === "" || this.runnerSettings.nniManagerIP === null) {
this.runnerSettings.nniManagerIP = getIPV4Address();
if (this.trialConfig === undefined) {
throw new Error(`trial config shouldn't be undefined in run()`);
}
for(const environmentService of this.environmentServiceList) {
const runnerSettings: RunnerSettings = new RunnerSettings();
runnerSettings.nniManagerIP = this.nniManagerIp === undefined? getIPV4Address() : this.nniManagerIp;
runnerSettings.nniManagerPort = getBasePort() + 1;
runnerSettings.commandChannel = environmentService.getCommandChannel.channelName;
runnerSettings.enableGpuCollector = this.enableGpuScheduler;
runnerSettings.command = this.trialConfig.command;
runnerSettings.nniManagerVersion = this.enableVersionCheck ? await getVersion() : '';
runnerSettings.logCollection = this.logCollection;
runnerSettings.platform = environmentService.getName;
runnerSettings.experimentId = this.experimentId;
await environmentService.getCommandChannel.start();
this.log.info(`TrialDispatcher: started channel: ${environmentService.getCommandChannel.constructor.name}`);
this.log.info(`TrialDispatcher: copying code and settings.`);
let storageService: StorageService;
if (environmentService.hasStorageService) {
this.log.debug(`TrialDispatcher: use existing storage service.`);
storageService = component.get<StorageService>(StorageService);
} else {
this.log.debug(`TrialDispatcher: create temp storage service to temp folder.`);
storageService = new MountedStorageService();
const environmentLocalTempFolder = path.join(this.experimentRootDir, this.experimentId, "environment-temp");
storageService.initialize(this.trialConfig.codeDir, environmentLocalTempFolder);
}
// Copy the compressed file to remoteDirectory and delete it
const codeDir = path.resolve(this.trialConfig.codeDir);
const envDir = storageService.joinPath("envs");
const codeFileName = await storageService.copyDirectory(codeDir, envDir, true);
storageService.rename(codeFileName, "nni-code.tar.gz");
const installFileName = storageService.joinPath(envDir, 'install_nni.sh');
await storageService.save(CONTAINER_INSTALL_NNI_SHELL_FORMAT, installFileName);
const runnerSettingsConfig = storageService.joinPath(envDir, "settings.json");
await storageService.save(JSON.stringify(runnerSettings), runnerSettingsConfig);
if (this.isDeveloping) {
let trialToolsPath = path.join(__dirname, "../../../../../tools/nni_trial_tool");
if (false === fs.existsSync(trialToolsPath)) {
trialToolsPath = path.join(__dirname, "..\\..\\..\\..\\..\\tools\\nni_trial_tool");
}
await storageService.copyDirectory(trialToolsPath, envDir, true);
}
}
this.runnerSettings.nniManagerPort = getBasePort() + 1;
this.runnerSettings.commandChannel = this.commandChannel.channelName;
// start channel
this.commandEmitter.on("command", (command: Command): void => {
this.handleCommand(command).catch((err: Error) => {
this.log.error(`TrialDispatcher: error on handle env ${command.environment.id} command: ${command.command}, data: ${command.data}, error: ${err}`);
})
});
await this.commandChannel.start();
this.log.info(`TrialDispatcher: started channel: ${this.commandChannel.constructor.name}`);
if (this.trialConfig === undefined) {
throw new Error(`trial config shouldn't be undefined in run()`);
}
this.log.info(`TrialDispatcher: copying code and settings.`);
let storageService: StorageService;
if (environmentService.hasStorageService) {
this.log.debug(`TrialDispatcher: use existing storage service.`);
storageService = component.get<StorageService>(StorageService);
} else {
this.log.debug(`TrialDispatcher: create temp storage service to temp folder.`);
storageService = new MountedStorageService();
const environmentLocalTempFolder = path.join(this.experimentRootDir, this.experimentId, "environment-temp");
storageService.initialize(this.trialConfig.codeDir, environmentLocalTempFolder);
}
// Copy the compressed file to remoteDirectory and delete it
const codeDir = path.resolve(this.trialConfig.codeDir);
const envDir = storageService.joinPath("envs");
const codeFileName = await storageService.copyDirectory(codeDir, envDir, true);
storageService.rename(codeFileName, "nni-code.tar.gz");
const installFileName = storageService.joinPath(envDir, 'install_nni.sh');
await storageService.save(CONTAINER_INSTALL_NNI_SHELL_FORMAT, installFileName);
const runnerSettings = storageService.joinPath(envDir, "settings.json");
await storageService.save(JSON.stringify(this.runnerSettings), runnerSettings);
// FIXME: what the hell is this?
if (this.isDeveloping) {
let trialToolsPath = path.join(__dirname, "../../../../../tools/nni_trial_tool");
if (false === fs.existsSync(trialToolsPath)) {
trialToolsPath = path.join(__dirname, "..\\..\\..\\..\\..\\tools\\nni_trial_tool");
}
await storageService.copyDirectory(trialToolsPath, envDir, true);
}
await this.prefetchEnvironments();
this.log.info(`TrialDispatcher: run loop started.`);
await Promise.all([
this.environmentMaintenanceLoop(),
this.trialManagementLoop(),
this.commandChannel.run(),
]);
const promiseList: Promise<void>[] = [];
for(const commandChannel of this.commandChannelSet) {
promiseList.push(commandChannel.run());
}
promiseList.push(this.environmentMaintenanceLoop());
promiseList.push(this.trialManagementLoop());
await Promise.all(promiseList);
}
public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
......@@ -260,14 +256,14 @@ class TrialDispatcher implements TrainingService {
public async setClusterMetadata(key: string, value: string): Promise<void> {
switch (key) {
case TrialConfigMetadataKey.NNI_MANAGER_IP:
this.runnerSettings.nniManagerIP = (<NNIManagerIpConfig>JSON.parse(value)).nniManagerIp;
this.nniManagerIp = (<NNIManagerIpConfig>JSON.parse(value)).nniManagerIp;
break;
case TrialConfigMetadataKey.VERSION_CHECK:
this.enableVersionCheck = (value === 'true' || value === 'True');
this.runnerSettings.nniManagerVersion = this.enableVersionCheck ? await getVersion() : '';
break;
case TrialConfigMetadataKey.LOG_COLLECTION:
this.runnerSettings.logCollection = value;
this.logCollection = value;
break;
case TrialConfigMetadataKey.TRIAL_CONFIG:
this.trialConfig = <TrialConfig>JSON.parse(value);
......@@ -279,15 +275,25 @@ class TrialDispatcher implements TrainingService {
this.log.info(`TrialDispatcher: GPU scheduler is enabled.`)
this.enableGpuScheduler = true;
}
this.runnerSettings.enableGpuCollector = this.enableGpuScheduler;
this.runnerSettings.command = this.trialConfig.command;
// Validate to make sure codeDir doesn't have too many files
await validateCodeDir(this.trialConfig.codeDir);
break;
case TrialConfigMetadataKey.PLATFORM_LIST: {
const platforms: string[] = value.split(",");
for(const platform of platforms) {
const environmentService: EnvironmentService = EnvironmentServiceFactory.createEnvironmentService(platform);
environmentService.initCommandChannel(this.commandEmitter);
this.environmentMaintenceLoopInterval =
Math.max(environmentService.environmentMaintenceLoopInterval, this.environmentMaintenceLoopInterval);
this.commandChannelSet.add(environmentService.getCommandChannel);
this.environmentServiceList.push(environmentService);
}
}
}
for(const environmentService of this.environmentServiceList) {
await environmentService.config(key, value);
}
const environmentService = component.get<EnvironmentService>(EnvironmentService);
await environmentService.config(key, value);
}
public getClusterMetadata(_key: string): Promise<string> {
......@@ -295,48 +301,75 @@ class TrialDispatcher implements TrainingService {
}
/**
 * Shut the dispatcher down: raise the stop flag, stop every environment that is
 * still alive, then stop the command channels.
 *
 * @throws Error if commandChannel or commandEmitter is undefined, or if an alive
 *         environment has no environmentService assigned
 */
public async cleanUp(): Promise<void> {
if (this.commandChannel === undefined) {
throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in cleanUp.`);
}
if (this.commandEmitter === undefined) {
throw new Error(`TrialDispatcher: commandEmitter shouldn't be undefined in cleanUp.`);
}
// environmentMaintenanceLoop and trialManagementLoop both run `while (!this.stopping)`.
this.stopping = true;
this.shouldUpdateTrials = true;
const environmentService = component.get<EnvironmentService>(EnvironmentService);
// Iterate over a snapshot of the map's values.
const environments = [...this.environments.values()];
for (let index = 0; index < environments.length; index++) {
const environment = environments[index];
// Only environments still marked alive are stopped.
if (environment.isAlive === true) {
this.log.info(`stopping environment ${environment.id}...`);
await environmentService.stopEnvironment(environment);
await this.commandChannel.close(environment);
if (environment.environmentService === undefined) {
throw new Error(`${environment.id} do not have environmentService!`);
}
// Stop through the service attached to this particular environment.
await environment.environmentService.stopEnvironment(environment);
this.log.info(`stopped environment ${environment.id}.`);
}
}
// NOTE(review): run() registers an arrow-function wrapper as the "command" listener,
// not this.handleCommand itself, so this off() call has no matching listener to
// remove — confirm and keep a reference to the registered wrapper if removal is intended.
this.commandEmitter.off("command", this.handleCommand);
await this.commandChannel.stop();
// Channels are stopped one after another, not in parallel.
for (const commandChannel of this.commandChannelSet) {
await commandChannel.stop();
}
}
private async environmentMaintenanceLoop(): Promise<void> {
if (this.commandChannel === undefined) {
throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in environmentMaintenanceLoop.`);
}
const environmentService = component.get<EnvironmentService>(EnvironmentService);
while (!this.stopping) {
const environments: EnvironmentInformation[] = [];
for (const environment of this.environments.values()) {
if (environment.isAlive === true) {
environments.push(environment);
} else {
await this.commandChannel.close(environment);
if (environment.environmentService === undefined) {
throw new Error(`${environment.id} do not have environment service!`);
}
await environment.environmentService.getCommandChannel.close(environment);
}
}
// Group environments according to environmentService
const environmentServiceDict: Map<EnvironmentService, EnvironmentInformation[]> =
new Map<EnvironmentService, EnvironmentInformation[]>();
for (const environment of environments) {
if (environment.environmentService === undefined) {
throw new Error(`${environment.id} do not have environment service!`);
}
if (!environmentServiceDict.has(environment.environmentService)) {
environmentServiceDict.set(environment.environmentService, [environment]);
} else {
const environmentsList: EnvironmentInformation[] | undefined = environmentServiceDict.get(environment.environmentService);
if (environmentsList === undefined) {
throw new Error(`Environment list not initialized!`);
}
environmentsList.push(environment);
environmentServiceDict.set(environment.environmentService, environmentsList);
}
}
// Refresh all environments
const taskList: Promise<void>[] = [];
for (const environmentService of environmentServiceDict.keys()) {
const environmentsList: EnvironmentInformation[] | undefined = environmentServiceDict.get(environmentService);
if (environmentsList) {
taskList.push(environmentService.refreshEnvironmentsStatus(environmentsList));
}
}
await environmentService.refreshEnvironmentsStatus(environments);
await Promise.all(taskList);
environments.forEach((environment) => {
for (const environment of environments) {
if (environment.environmentService === undefined) {
throw new Error(`${environment.id} do not have environment service!`);
}
const oldIsAlive = environment.isAlive;
switch (environment.status) {
case 'WAITING':
......@@ -351,16 +384,16 @@ class TrialDispatcher implements TrainingService {
if (oldIsAlive !== environment.isAlive) {
this.log.debug(`set environment ${environment.id} isAlive from ${oldIsAlive} to ${environment.isAlive} due to status is ${environment.status}.`);
}
});
}
this.shouldUpdateTrials = true;
await delay(environmentService.environmentMaintenceLoopInterval);
if (this.environmentMaintenceLoopInterval === -1) {
throw new Error("EnvironmentMaintenceLoopInterval not initialized!");
}
await delay(this.environmentMaintenceLoopInterval);
}
}
private async trialManagementLoop(): Promise<void> {
if (this.commandChannel === undefined) {
throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in trialManagementLoop.`);
}
const interval = 1;
while (!this.stopping) {
......@@ -400,6 +433,11 @@ class TrialDispatcher implements TrainingService {
liveTrialsCount++;
continue;
}
if (environment.environmentService === undefined) {
throw new Error(`${environment.id} does not has environment service!`);
}
trial.url = environment.trackingUrl;
const environmentStatus = environment.status;
......@@ -414,7 +452,7 @@ class TrialDispatcher implements TrainingService {
// for example, in horovod, it's just sleep command, has no impact on trial result.
if (environment.nodeCount > completedCount) {
this.log.info(`stop partial completed trial ${trial.id}`);
await this.commandChannel.sendCommand(environment, KILL_TRIAL_JOB, trial.id);
await environment.environmentService.getCommandChannel.sendCommand(environment, KILL_TRIAL_JOB, trial.id);
}
for (const node of trial.nodes.values()) {
if (node.status === "FAILED") {
......@@ -463,8 +501,10 @@ class TrialDispatcher implements TrainingService {
false === this.reuseEnvironment &&
environment.assignedTrialCount > 0
) {
const environmentService = component.get<EnvironmentService>(EnvironmentService);
await environmentService.stopEnvironment(environment);
if (environment.environmentService === undefined) {
throw new Error(`${environment.id} does not has environment service!`);
}
await environment.environmentService.stopEnvironment(environment);
continue;
}
......@@ -556,11 +596,13 @@ class TrialDispatcher implements TrainingService {
}
if (neededEnvironmentCount > 0) {
const environmentService = component.get<EnvironmentService>(EnvironmentService);
let requestedCount = 0;
let hasMoreEnvironments = false;
for (let index = 0; index < neededEnvironmentCount; index++) {
if (true === environmentService.hasMoreEnvironments) {
await this.requestEnvironment();
const environmentService: EnvironmentService | undefined = this.selectEnvironmentService();
if (environmentService !== undefined) {
hasMoreEnvironments = true;
await this.requestEnvironment(environmentService);
requestedCount++;
this.isLoggedNoMoreEnvironment = false;
} else {
......@@ -570,7 +612,7 @@ class TrialDispatcher implements TrainingService {
}
}
}
if (environmentService.hasMoreEnvironments === true || requestedCount > 0) {
if (hasMoreEnvironments === true || requestedCount > 0) {
this.log.info(`requested new environment, live trials: ${liveTrialsCount}, ` +
`live environments: ${liveEnvironmentsCount}, neededEnvironmentCount: ${neededEnvironmentCount}, ` +
`requestedCount: ${requestedCount}`);
......@@ -580,25 +622,37 @@ class TrialDispatcher implements TrainingService {
}
}
private async prefetchEnvironments (): Promise<void> {
const environmentService = component.get<EnvironmentService>(EnvironmentService);
const number = environmentService.prefetchedEnvironmentCount;
this.log.info(`Initialize environments total number: ${number}`);
for (let index = 0; index < number; index++) {
await this.requestEnvironment();
// Schedule an environment platform (service) for a new environment
private selectEnvironmentService(): EnvironmentService | undefined {
    // Candidates are the registered services that report having more environments.
    const candidates = this.environmentServiceList.filter(
        (service) => service.hasMoreEnvironments
    );
    // Nothing to choose from: signal that to the caller with undefined.
    if (candidates.length === 0) {
        return undefined;
    }
    // Random scheduler
    return randomSelect(candidates);
}
private async requestEnvironment(): Promise<void> {
if (this.commandChannel === undefined) {
throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in requestEnvironment.`);
private async prefetchEnvironments (): Promise<void> {
    // Walk the registered services in order; every request is awaited
    // before the next one is issued, so requests never overlap.
    for (const service of this.environmentServiceList) {
        // Read the prefetch count once per service, before requesting anything.
        const prefetchCount = service.prefetchedEnvironmentCount;
        this.log.info(`Initialize environments total number: ${prefetchCount}`);
        let requested = 0;
        while (requested < prefetchCount) {
            await this.requestEnvironment(service);
            requested++;
        }
    }
}
const environmentService = component.get<EnvironmentService>(EnvironmentService);
private async requestEnvironment(environmentService: EnvironmentService): Promise<void> {
const envId = uniqueString(5);
const envName = `nni_exp_${this.experimentId}_env_${envId}`;
const environment = environmentService.createEnvironmentInformation(envId, envName);
environment.environmentService = environmentService;
this.log.info(`Assign environment service ${environmentService.getName} to environment ${envId}`);
environment.command = `sh ../install_nni.sh && python3 -m nni.tools.trial_tool.trial_runner`;
if (this.isDeveloping) {
......@@ -616,15 +670,11 @@ class TrialDispatcher implements TrainingService {
} else {
environment.isAlive = true;
}
await this.commandChannel.open(environment);
await environment.environmentService.getCommandChannel.open(environment);
this.log.info(`requested environment ${environment.id} and job id is ${environment.envId}.`);
}
private async allocateEnvironment(trial: TrialDetail, environment: EnvironmentInformation): Promise<void> {
if (this.commandChannel === undefined) {
throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in allocateEnvironment.`);
}
if (this.trialConfig === undefined) {
throw new Error(`TrialDispatcher: trialConfig shouldn't be undefined in allocateEnvironment.`);
}
......@@ -653,6 +703,13 @@ class TrialDispatcher implements TrainingService {
environment.runningTrialCount++;
environment.assignedTrialCount++;
trial.environment = environment;
if (environment.environmentService === undefined) {
throw new Error(`${environment.id} environmentService not initialized!`);
}
if (environment.environmentService.hasStorageService) {
const storageService = component.get<StorageService>(StorageService);
trial.workingDirectory = storageService.joinPath('trials', trial.id);
}
trial.settings = {
trialId: trial.id,
gpuIndices: gpuIndices,
......@@ -661,7 +718,10 @@ class TrialDispatcher implements TrainingService {
}
trial.startTime = Date.now();
trial.status = "RUNNING";
await this.commandChannel.sendCommand(trial.environment, NEW_TRIAL_JOB, trial.settings);
if (environment.environmentService === undefined) {
throw new Error(`${environment.id} does not have environment service!`);
}
await environment.environmentService.getCommandChannel.sendCommand(trial.environment, NEW_TRIAL_JOB, trial.settings);
}
/**
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment