Commit df6145a2 authored by Yuge Zhang
Browse files

Merge branch 'master' of https://github.com/microsoft/nni into dev-retiarii

parents 0f0c6288 f8424a9f
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
/**
 * Holds the list of training service platform names.
 */
export class HeterogenousConfig {
    /**
     * @param trainingServicePlatforms platform names; the array is stored as given (same reference, no copy)
     */
    constructor(public readonly trainingServicePlatforms: string[]) {}
}
...@@ -13,14 +13,11 @@ import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; ...@@ -13,14 +13,11 @@ import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { PAIClusterConfig } from '../pai/paiConfig'; import { PAIClusterConfig } from '../pai/paiConfig';
import { PAIK8STrainingService } from '../pai/paiK8S/paiK8STrainingService'; import { PAIK8STrainingService } from '../pai/paiK8S/paiK8STrainingService';
import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService'; import { RemoteMachineTrainingService } from '../remote_machine/remoteMachineTrainingService';
import { EnvironmentService } from './environment';
import { OpenPaiEnvironmentService } from './environments/openPaiEnvironmentService';
import { AMLEnvironmentService } from './environments/amlEnvironmentService';
import { RemoteEnvironmentService } from './environments/remoteEnvironmentService';
import { MountedStorageService } from './storages/mountedStorageService'; import { MountedStorageService } from './storages/mountedStorageService';
import { StorageService } from './storageService'; import { StorageService } from './storageService';
import { TrialDispatcher } from './trialDispatcher'; import { TrialDispatcher } from './trialDispatcher';
import { RemoteConfig } from './remote/remoteConfig'; import { RemoteConfig } from './remote/remoteConfig';
import { HeterogenousConfig } from './heterogenous/heterogenousConfig';
/** /**
...@@ -31,7 +28,6 @@ import { RemoteConfig } from './remote/remoteConfig'; ...@@ -31,7 +28,6 @@ import { RemoteConfig } from './remote/remoteConfig';
class RouterTrainingService implements TrainingService { class RouterTrainingService implements TrainingService {
protected readonly log!: Logger; protected readonly log!: Logger;
private internalTrainingService: TrainingService | undefined; private internalTrainingService: TrainingService | undefined;
private metaDataCache: Map<string, string> = new Map<string, string>();
constructor() { constructor() {
this.log = getLogger(); this.log = getLogger();
...@@ -99,75 +95,70 @@ class RouterTrainingService implements TrainingService { ...@@ -99,75 +95,70 @@ class RouterTrainingService implements TrainingService {
public async setClusterMetadata(key: string, value: string): Promise<void> { public async setClusterMetadata(key: string, value: string): Promise<void> {
if (this.internalTrainingService === undefined) { if (this.internalTrainingService === undefined) {
if (key === TrialConfigMetadataKey.PAI_CLUSTER_CONFIG) { // Need to refactor configuration, remove heterogeneous_config field in the future
if (key === TrialConfigMetadataKey.HETEROGENEOUS_CONFIG){
this.internalTrainingService = component.get(TrialDispatcher);
const heterogenousConfig: HeterogenousConfig = <HeterogenousConfig>JSON.parse(value);
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
// Initialize storageService for pai, only support singleton for now, need refactor
if (heterogenousConfig.trainingServicePlatforms.includes('pai')) {
Container.bind(StorageService)
.to(MountedStorageService)
.scope(Scope.Singleton);
}
await this.internalTrainingService.setClusterMetadata('platform_list',
heterogenousConfig.trainingServicePlatforms.join(','));
} else if (key === TrialConfigMetadataKey.LOCAL_CONFIG) {
this.internalTrainingService = component.get(TrialDispatcher);
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
await this.internalTrainingService.setClusterMetadata('platform_list', 'local');
} else if (key === TrialConfigMetadataKey.PAI_CLUSTER_CONFIG) {
const config = <PAIClusterConfig>JSON.parse(value); const config = <PAIClusterConfig>JSON.parse(value);
if (config.reuse === true) { if (config.reuse === true) {
this.log.info(`reuse flag enabled, use EnvironmentManager.`); this.log.info(`reuse flag enabled, use EnvironmentManager.`);
this.internalTrainingService = component.get(TrialDispatcher); this.internalTrainingService = component.get(TrialDispatcher);
// TODO to support other serivces later.
Container.bind(EnvironmentService)
.to(OpenPaiEnvironmentService)
.scope(Scope.Singleton);
// TODO to support other storages later. // TODO to support other storages later.
Container.bind(StorageService) Container.bind(StorageService)
.to(MountedStorageService) .to(MountedStorageService)
.scope(Scope.Singleton); .scope(Scope.Singleton);
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
await this.internalTrainingService.setClusterMetadata('platform_list', 'pai');
} else { } else {
this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`); this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`);
this.internalTrainingService = component.get(PAIK8STrainingService); this.internalTrainingService = component.get(PAIK8STrainingService);
} }
for (const [key, value] of this.metaDataCache) {
if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
}
if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
this.metaDataCache.clear();
} else if (key === TrialConfigMetadataKey.AML_CLUSTER_CONFIG) { } else if (key === TrialConfigMetadataKey.AML_CLUSTER_CONFIG) {
this.internalTrainingService = component.get(TrialDispatcher); this.internalTrainingService = component.get(TrialDispatcher);
Container.bind(EnvironmentService)
.to(AMLEnvironmentService)
.scope(Scope.Singleton);
for (const [key, value] of this.metaDataCache) {
if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
}
if (this.internalTrainingService === undefined) { if (this.internalTrainingService === undefined) {
throw new Error("TrainingService is not assigned!"); throw new Error("internalTrainingService not initialized!");
} }
await this.internalTrainingService.setClusterMetadata(key, value); await this.internalTrainingService.setClusterMetadata('platform_list', 'aml');
this.metaDataCache.clear();
} else if (key === TrialConfigMetadataKey.REMOTE_CONFIG) { } else if (key === TrialConfigMetadataKey.REMOTE_CONFIG) {
const config = <RemoteConfig>JSON.parse(value); const config = <RemoteConfig>JSON.parse(value);
if (config.reuse === true) { if (config.reuse === true) {
this.log.info(`reuse flag enabled, use EnvironmentManager.`); this.log.info(`reuse flag enabled, use EnvironmentManager.`);
this.internalTrainingService = component.get(TrialDispatcher); this.internalTrainingService = component.get(TrialDispatcher);
Container.bind(EnvironmentService) if (this.internalTrainingService === undefined) {
.to(RemoteEnvironmentService) throw new Error("internalTrainingService not initialized!");
.scope(Scope.Singleton); }
await this.internalTrainingService.setClusterMetadata('platform_list', 'remote');
} else { } else {
this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`); this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`);
this.internalTrainingService = component.get(RemoteMachineTrainingService); this.internalTrainingService = component.get(RemoteMachineTrainingService);
} }
} else {
this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`);
this.metaDataCache.set(key, value);
} }
} else {
await this.internalTrainingService.setClusterMetadata(key, value);
} }
if (this.internalTrainingService === undefined) {
throw new Error("internalTrainingService not initialized!");
}
await this.internalTrainingService.setClusterMetadata(key, value);
} }
public async getClusterMetadata(key: string): Promise<string> { public async getClusterMetadata(key: string): Promise<string> {
......
...@@ -3,15 +3,13 @@ ...@@ -3,15 +3,13 @@
import * as chai from 'chai'; import * as chai from 'chai';
import * as path from 'path'; import * as path from 'path';
import { Scope } from "typescript-ioc";
import * as component from '../../../common/component';
import { getLogger, Logger } from "../../../common/log"; import { getLogger, Logger } from "../../../common/log";
import { TrialJobApplicationForm, TrialJobStatus } from '../../../common/trainingService'; import { TrialJobApplicationForm, TrialJobStatus } from '../../../common/trainingService';
import { cleanupUnitTest, delay, prepareUnitTest, uniqueString } from '../../../common/utils'; import { cleanupUnitTest, delay, prepareUnitTest, uniqueString } from '../../../common/utils';
import { INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, SEND_TRIAL_JOB_PARAMETER, TRIAL_END, GPU_INFO } from '../../../core/commands'; import { INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, SEND_TRIAL_JOB_PARAMETER, TRIAL_END, GPU_INFO } from '../../../core/commands';
import { TrialConfigMetadataKey } from '../../../training_service/common/trialConfigMetadataKey'; import { TrialConfigMetadataKey } from '../../../training_service/common/trialConfigMetadataKey';
import { Command } from '../commandChannel'; import { Command, CommandChannel } from '../commandChannel';
import { EnvironmentInformation, EnvironmentService } from "../environment"; import { Channel, EnvironmentInformation, EnvironmentService } from "../environment";
import { TrialDetail } from '../trial'; import { TrialDetail } from '../trial';
import { TrialDispatcher } from "../trialDispatcher"; import { TrialDispatcher } from "../trialDispatcher";
import { UtCommandChannel } from './utCommandChannel'; import { UtCommandChannel } from './utCommandChannel';
...@@ -54,7 +52,7 @@ async function waitResult<TResult>(callback: () => Promise<TResult | undefined>, ...@@ -54,7 +52,7 @@ async function waitResult<TResult>(callback: () => Promise<TResult | undefined>,
return undefined; return undefined;
} }
async function waitResultMust<TResult>(callback: () => Promise<TResult | undefined>, waitMs: number = 1000, interval: number = 1): Promise<TResult> { async function waitResultMust<TResult>(callback: () => Promise<TResult | undefined>, waitMs: number = 10000, interval: number = 1): Promise<TResult> {
const result = await waitResult(callback, waitMs, interval, true); const result = await waitResult(callback, waitMs, interval, true);
// this error should be thrown in waitResult already. // this error should be thrown in waitResult already.
if (result === undefined) { if (result === undefined) {
...@@ -201,16 +199,21 @@ describe('Unit Test for TrialDispatcher', () => { ...@@ -201,16 +199,21 @@ describe('Unit Test for TrialDispatcher', () => {
nniManagerIp: "127.0.0.1", nniManagerIp: "127.0.0.1",
} }
trialDispatcher = new TrialDispatcher(); trialDispatcher = new TrialDispatcher();
component.Container.bind(EnvironmentService)
.to(UtEnvironmentService)
.scope(Scope.Singleton);
await trialDispatcher.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, JSON.stringify(trialConfig)); await trialDispatcher.setClusterMetadata(TrialConfigMetadataKey.TRIAL_CONFIG, JSON.stringify(trialConfig));
await trialDispatcher.setClusterMetadata(TrialConfigMetadataKey.NNI_MANAGER_IP, JSON.stringify(nniManagerIpConfig)); await trialDispatcher.setClusterMetadata(TrialConfigMetadataKey.NNI_MANAGER_IP, JSON.stringify(nniManagerIpConfig));
trialRunPromise = trialDispatcher.run(); // set ut environment
let environmentServiceList: EnvironmentService[] = [];
environmentService = new UtEnvironmentService();
environmentServiceList.push(environmentService);
trialDispatcher.environmentServiceList = environmentServiceList;
// set ut command channel
environmentService.initCommandChannel(trialDispatcher.commandEmitter);
commandChannel = environmentService.getCommandChannel as UtCommandChannel;
trialDispatcher.commandChannelSet = new Set<CommandChannel>().add(environmentService.getCommandChannel);
trialDispatcher.environmentMaintenceLoopInterval = 1000;
environmentService = component.get(EnvironmentService) as UtEnvironmentService; trialRunPromise = trialDispatcher.run();
commandChannel = environmentService.testGetCommandChannel();
}); });
afterEach(async () => { afterEach(async () => {
...@@ -258,9 +261,6 @@ describe('Unit Test for TrialDispatcher', () => { ...@@ -258,9 +261,6 @@ describe('Unit Test for TrialDispatcher', () => {
await waitEnvironment(2, previousEnvironments, environmentService, commandChannel); await waitEnvironment(2, previousEnvironments, environmentService, commandChannel);
await verifyTrialRunning(commandChannel, trialDetail); await verifyTrialRunning(commandChannel, trialDetail);
await verifyTrialResult(commandChannel, trialDetail, -1); await verifyTrialResult(commandChannel, trialDetail, -1);
await waitResultMust<true>(async () => {
return environment.status === 'USER_CANCELED' ? true : undefined;
});
chai.assert.equal(environmentService.testGetEnvironments().size, 2, "as env not reused, so only 2 envs should be here."); chai.assert.equal(environmentService.testGetEnvironments().size, 2, "as env not reused, so only 2 envs should be here.");
const trials = await trialDispatcher.listTrialJobs(); const trials = await trialDispatcher.listTrialJobs();
...@@ -433,12 +433,10 @@ describe('Unit Test for TrialDispatcher', () => { ...@@ -433,12 +433,10 @@ describe('Unit Test for TrialDispatcher', () => {
let environment = await waitEnvironment(1, previousEnvironments, environmentService, commandChannel); let environment = await waitEnvironment(1, previousEnvironments, environmentService, commandChannel);
await verifyTrialRunning(commandChannel, trialDetail); await verifyTrialRunning(commandChannel, trialDetail);
await verifyTrialResult(commandChannel, trialDetail, 0); await verifyTrialResult(commandChannel, trialDetail, 0);
environmentService.testSetEnvironmentStatus(environment, 'SUCCEEDED'); environmentService.testSetEnvironmentStatus(environment, 'SUCCEEDED');
await waitResultMust<boolean>(async () => { await waitResultMust<boolean>(async () => {
return environment.status === 'SUCCEEDED' ? true : undefined; return environment.status === 'SUCCEEDED' ? true : undefined;
}); });
trialDetail = await newTrial(trialDispatcher); trialDetail = await newTrial(trialDispatcher);
await waitEnvironment(2, previousEnvironments, environmentService, commandChannel); await waitEnvironment(2, previousEnvironments, environmentService, commandChannel);
await verifyTrialRunning(commandChannel, trialDetail); await verifyTrialRunning(commandChannel, trialDetail);
......
// Copyright (c) Microsoft Corporation. // Copyright (c) Microsoft Corporation.
// Licensed under the MIT license. // Licensed under the MIT license.
import { EnvironmentInformation, EnvironmentService, EnvironmentStatus } from "../environment"; import { Channel, EnvironmentInformation, EnvironmentService, EnvironmentStatus } from "../environment";
import { EventEmitter } from "events"; import { EventEmitter } from 'events';
import { CommandChannel } from "../commandChannel";
import { UtCommandChannel } from "./utCommandChannel"; import { UtCommandChannel } from "./utCommandChannel";
export class UtEnvironmentService extends EnvironmentService { export class UtEnvironmentService extends EnvironmentService {
private commandChannel: UtCommandChannel | undefined;
private allEnvironments = new Map<string, EnvironmentInformation>(); private allEnvironments = new Map<string, EnvironmentInformation>();
private hasMoreEnvironmentsInternal = true; private hasMoreEnvironmentsInternal = true;
...@@ -23,6 +21,14 @@ export class UtEnvironmentService extends EnvironmentService { ...@@ -23,6 +21,14 @@ export class UtEnvironmentService extends EnvironmentService {
return 1; return 1;
} }
/** Name reported for this environment service: always the fixed string 'ut'. */
public get getName(): string {
    const serviceName = 'ut';
    return serviceName;
}
/**
 * Builds a UtCommandChannel on top of the supplied emitter and keeps it
 * on this service as `commandChannel`.
 *
 * @param eventEmitter emitter handed to the UtCommandChannel constructor
 */
public initCommandChannel(eventEmitter: EventEmitter): void {
    const utChannel = new UtCommandChannel(eventEmitter);
    this.commandChannel = utChannel;
}
public testSetEnvironmentStatus(environment: EnvironmentInformation, newStatus: EnvironmentStatus): void { public testSetEnvironmentStatus(environment: EnvironmentInformation, newStatus: EnvironmentStatus): void {
environment.status = newStatus; environment.status = newStatus;
} }
...@@ -35,13 +41,6 @@ export class UtEnvironmentService extends EnvironmentService { ...@@ -35,13 +41,6 @@ export class UtEnvironmentService extends EnvironmentService {
return this.allEnvironments; return this.allEnvironments;
} }
public testGetCommandChannel(): UtCommandChannel {
if (this.commandChannel === undefined) {
throw new Error(`command channel shouldn't be undefined.`);
}
return this.commandChannel;
}
public testSetNoMoreEnvironment(hasMore: boolean): void { public testSetNoMoreEnvironment(hasMore: boolean): void {
this.hasMoreEnvironmentsInternal = hasMore; this.hasMoreEnvironmentsInternal = hasMore;
} }
...@@ -50,11 +49,6 @@ export class UtEnvironmentService extends EnvironmentService { ...@@ -50,11 +49,6 @@ export class UtEnvironmentService extends EnvironmentService {
return this.hasMoreEnvironmentsInternal; return this.hasMoreEnvironmentsInternal;
} }
public createCommandChannel(commandEmitter: EventEmitter): CommandChannel {
this.commandChannel = new UtCommandChannel(commandEmitter)
return this.commandChannel;
}
public async config(_key: string, _value: string): Promise<void> { public async config(_key: string, _value: string): Promise<void> {
// do nothing // do nothing
} }
......
...@@ -10,10 +10,10 @@ import { Writable } from 'stream'; ...@@ -10,10 +10,10 @@ import { Writable } from 'stream';
import { String } from 'typescript-string-operations'; import { String } from 'typescript-string-operations';
import * as component from '../../common/component'; import * as component from '../../common/component';
import { NNIError, NNIErrorNames, MethodNotImplementedError } from '../../common/errors'; import { NNIError, NNIErrorNames, MethodNotImplementedError } from '../../common/errors';
import { getBasePort, getExperimentId, getPlatform } from '../../common/experimentStartupInfo'; import { getBasePort, getExperimentId } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { NNIManagerIpConfig, TrainingService, TrialJobApplicationForm, TrialJobMetric, TrialJobStatus, LogType } from '../../common/trainingService'; import { NNIManagerIpConfig, TrainingService, TrialJobApplicationForm, TrialJobMetric, TrialJobStatus, LogType } from '../../common/trainingService';
import { delay, getExperimentRootDir, getIPV4Address, getLogLevel, getVersion, mkDirPSync, uniqueString } from '../../common/utils'; import { delay, getExperimentRootDir, getIPV4Address, getLogLevel, getVersion, mkDirPSync, randomSelect, uniqueString } from '../../common/utils';
import { GPU_INFO, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, REPORT_METRIC_DATA, SEND_TRIAL_JOB_PARAMETER, STDOUT, TRIAL_END, VERSION_CHECK } from '../../core/commands'; import { GPU_INFO, INITIALIZED, KILL_TRIAL_JOB, NEW_TRIAL_JOB, REPORT_METRIC_DATA, SEND_TRIAL_JOB_PARAMETER, STDOUT, TRIAL_END, VERSION_CHECK } from '../../core/commands';
import { ScheduleResultType } from '../../training_service/common/gpuData'; import { ScheduleResultType } from '../../training_service/common/gpuData';
import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData'; import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
...@@ -22,6 +22,7 @@ import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; ...@@ -22,6 +22,7 @@ import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { validateCodeDir } from '../common/util'; import { validateCodeDir } from '../common/util';
import { Command, CommandChannel } from './commandChannel'; import { Command, CommandChannel } from './commandChannel';
import { EnvironmentInformation, EnvironmentService, NodeInformation, RunnerSettings, TrialGpuSummary } from './environment'; import { EnvironmentInformation, EnvironmentService, NodeInformation, RunnerSettings, TrialGpuSummary } from './environment';
import { EnvironmentServiceFactory } from './environments/environmentServiceFactory';
import { GpuScheduler } from './gpuScheduler'; import { GpuScheduler } from './gpuScheduler';
import { MountedStorageService } from './storages/mountedStorageService'; import { MountedStorageService } from './storages/mountedStorageService';
import { StorageService } from './storageService'; import { StorageService } from './storageService';
...@@ -45,13 +46,16 @@ class TrialDispatcher implements TrainingService { ...@@ -45,13 +46,16 @@ class TrialDispatcher implements TrainingService {
private enableVersionCheck: boolean = true; private enableVersionCheck: boolean = true;
private trialConfig: TrialConfig | undefined; private trialConfig: TrialConfig | undefined;
private runnerSettings: RunnerSettings;
private commandEmitter: EventEmitter | undefined;
private commandChannel: CommandChannel | undefined;
private readonly trials: Map<string, TrialDetail>; private readonly trials: Map<string, TrialDetail>;
private readonly environments: Map<string, EnvironmentInformation>; private readonly environments: Map<string, EnvironmentInformation>;
// make public for ut
public environmentServiceList: EnvironmentService[] = [];
public commandChannelSet: Set<CommandChannel>;
public commandEmitter: EventEmitter;
public environmentMaintenceLoopInterval: number = -1;
private nniManagerIp: string | undefined;
// uses to accelerate trial manager loop // uses to accelerate trial manager loop
// true means there is updates, and trial loop should run a cycle immediately. // true means there is updates, and trial loop should run a cycle immediately.
...@@ -62,6 +66,7 @@ class TrialDispatcher implements TrainingService { ...@@ -62,6 +66,7 @@ class TrialDispatcher implements TrainingService {
private enableGpuScheduler: boolean = false; private enableGpuScheduler: boolean = false;
// uses to save if user like to reuse environment // uses to save if user like to reuse environment
private reuseEnvironment: boolean = true; private reuseEnvironment: boolean = true;
private logCollection: string = '';
private gpuScheduler: GpuScheduler; private gpuScheduler: GpuScheduler;
...@@ -76,10 +81,7 @@ class TrialDispatcher implements TrainingService { ...@@ -76,10 +81,7 @@ class TrialDispatcher implements TrainingService {
this.metricsEmitter = new EventEmitter(); this.metricsEmitter = new EventEmitter();
this.experimentId = getExperimentId(); this.experimentId = getExperimentId();
this.experimentRootDir = getExperimentRootDir(); this.experimentRootDir = getExperimentRootDir();
this.commandChannelSet = new Set<CommandChannel>();
this.runnerSettings = new RunnerSettings();
this.runnerSettings.experimentId = this.experimentId;
this.runnerSettings.platform = getPlatform();
const logLevel = getLogLevel(); const logLevel = getLogLevel();
this.log.debug(`current folder ${__dirname}`); this.log.debug(`current folder ${__dirname}`);
...@@ -89,6 +91,8 @@ class TrialDispatcher implements TrainingService { ...@@ -89,6 +91,8 @@ class TrialDispatcher implements TrainingService {
this.isDeveloping = true; this.isDeveloping = true;
} }
this.commandEmitter = new EventEmitter();
this.gpuScheduler = new GpuScheduler(); this.gpuScheduler = new GpuScheduler();
} }
...@@ -122,13 +126,7 @@ class TrialDispatcher implements TrainingService { ...@@ -122,13 +126,7 @@ class TrialDispatcher implements TrainingService {
const trialId: string = uniqueString(5); const trialId: string = uniqueString(5);
const environmentService = component.get<EnvironmentService>(EnvironmentService); const trialJobDetail: TrialDetail = new TrialDetail(trialId, "WAITING", Date.now(), "", form);
let trialWorkingFolder: string = "";
if (environmentService.hasStorageService) {
const storageService = component.get<StorageService>(StorageService);
trialWorkingFolder = storageService.joinPath('trials', trialId);
}
const trialJobDetail: TrialDetail = new TrialDetail(trialId, "WAITING", Date.now(), trialWorkingFolder, form);
this.trials.set(trialId, trialJobDetail); this.trials.set(trialId, trialJobDetail);
...@@ -142,23 +140,20 @@ class TrialDispatcher implements TrainingService { ...@@ -142,23 +140,20 @@ class TrialDispatcher implements TrainingService {
if (environment === undefined) { if (environment === undefined) {
throw new Error(`TrialDispatcher: trial ${trialJobId}'s env shouldn't be undefined in updateTrialJob.`); throw new Error(`TrialDispatcher: trial ${trialJobId}'s env shouldn't be undefined in updateTrialJob.`);
} }
if (this.commandChannel === undefined) { if (environment.environmentService === undefined) {
throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in updateTrialJob.`); throw new Error(`Environment ${environment.id} does not assigned environment service.`);
} }
const message = { const message = {
"trialId": trialJobId, "trialId": trialJobId,
"parameters": form.hyperParameters, "parameters": form.hyperParameters,
} }
await this.commandChannel.sendCommand(environment, SEND_TRIAL_JOB_PARAMETER, message); await environment.environmentService.getCommandChannel.sendCommand(environment, SEND_TRIAL_JOB_PARAMETER, message);
return trialDetail; return trialDetail;
} }
public async cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean | undefined): Promise<void> { public async cancelTrialJob(trialJobId: string, isEarlyStopped?: boolean | undefined): Promise<void> {
if (this.commandChannel === undefined) {
throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in cancelTrialJob.`);
}
const trial = await this.getTrialJob(trialJobId); const trial = await this.getTrialJob(trialJobId);
switch (trial.status) { switch (trial.status) {
case "RUNNING": case "RUNNING":
...@@ -166,8 +161,8 @@ class TrialDispatcher implements TrainingService { ...@@ -166,8 +161,8 @@ class TrialDispatcher implements TrainingService {
case "UNKNOWN": case "UNKNOWN":
{ {
const environment = trial.environment; const environment = trial.environment;
if (environment) { if (environment && environment.environmentService) {
await this.commandChannel.sendCommand(environment, KILL_TRIAL_JOB, trial.id); await environment.environmentService.getCommandChannel.sendCommand(environment, KILL_TRIAL_JOB, trial.id);
trial.isEarlyStopped = isEarlyStopped; trial.isEarlyStopped = isEarlyStopped;
trial.status = trial.isEarlyStopped === true ? trial.status = trial.isEarlyStopped === true ?
'EARLY_STOPPED' : 'USER_CANCELED'; 'EARLY_STOPPED' : 'USER_CANCELED';
...@@ -179,70 +174,71 @@ class TrialDispatcher implements TrainingService { ...@@ -179,70 +174,71 @@ class TrialDispatcher implements TrainingService {
} }
public async run(): Promise<void> { public async run(): Promise<void> {
const environmentService = component.get<EnvironmentService>(EnvironmentService); if (this.trialConfig === undefined) {
throw new Error(`trial config shouldn't be undefined in run()`);
this.commandEmitter = new EventEmitter(); }
this.commandChannel = environmentService.createCommandChannel(this.commandEmitter); for(const environmentService of this.environmentServiceList) {
// TODO it's a hard code of web channel, it needs to be improved. const runnerSettings: RunnerSettings = new RunnerSettings();
if (this.runnerSettings.nniManagerIP === "" || this.runnerSettings.nniManagerIP === null) { runnerSettings.nniManagerIP = this.nniManagerIp === undefined? getIPV4Address() : this.nniManagerIp;
this.runnerSettings.nniManagerIP = getIPV4Address(); runnerSettings.nniManagerPort = getBasePort() + 1;
runnerSettings.commandChannel = environmentService.getCommandChannel.channelName;
runnerSettings.enableGpuCollector = this.enableGpuScheduler;
runnerSettings.command = this.trialConfig.command;
runnerSettings.nniManagerVersion = this.enableVersionCheck ? await getVersion() : '';
runnerSettings.logCollection = this.logCollection;
runnerSettings.platform = environmentService.getName;
runnerSettings.experimentId = this.experimentId;
await environmentService.getCommandChannel.start();
this.log.info(`TrialDispatcher: started channel: ${environmentService.getCommandChannel.constructor.name}`);
this.log.info(`TrialDispatcher: copying code and settings.`);
let storageService: StorageService;
if (environmentService.hasStorageService) {
this.log.debug(`TrialDispatcher: use existing storage service.`);
storageService = component.get<StorageService>(StorageService);
} else {
this.log.debug(`TrialDispatcher: create temp storage service to temp folder.`);
storageService = new MountedStorageService();
const environmentLocalTempFolder = path.join(this.experimentRootDir, this.experimentId, "environment-temp");
storageService.initialize(this.trialConfig.codeDir, environmentLocalTempFolder);
}
// Copy the compressed file to remoteDirectory and delete it
const codeDir = path.resolve(this.trialConfig.codeDir);
const envDir = storageService.joinPath("envs");
const codeFileName = await storageService.copyDirectory(codeDir, envDir, true);
storageService.rename(codeFileName, "nni-code.tar.gz");
const installFileName = storageService.joinPath(envDir, 'install_nni.sh');
await storageService.save(CONTAINER_INSTALL_NNI_SHELL_FORMAT, installFileName);
const runnerSettingsConfig = storageService.joinPath(envDir, "settings.json");
await storageService.save(JSON.stringify(runnerSettings), runnerSettingsConfig);
if (this.isDeveloping) {
let trialToolsPath = path.join(__dirname, "../../../../../tools/nni_trial_tool");
if (false === fs.existsSync(trialToolsPath)) {
trialToolsPath = path.join(__dirname, "..\\..\\..\\..\\..\\tools\\nni_trial_tool");
}
await storageService.copyDirectory(trialToolsPath, envDir, true);
}
} }
this.runnerSettings.nniManagerPort = getBasePort() + 1;
this.runnerSettings.commandChannel = this.commandChannel.channelName;
// start channel // start channel
this.commandEmitter.on("command", (command: Command): void => { this.commandEmitter.on("command", (command: Command): void => {
this.handleCommand(command).catch((err: Error) => { this.handleCommand(command).catch((err: Error) => {
this.log.error(`TrialDispatcher: error on handle env ${command.environment.id} command: ${command.command}, data: ${command.data}, error: ${err}`); this.log.error(`TrialDispatcher: error on handle env ${command.environment.id} command: ${command.command}, data: ${command.data}, error: ${err}`);
}) })
}); });
await this.commandChannel.start();
this.log.info(`TrialDispatcher: started channel: ${this.commandChannel.constructor.name}`);
if (this.trialConfig === undefined) {
throw new Error(`trial config shouldn't be undefined in run()`);
}
this.log.info(`TrialDispatcher: copying code and settings.`);
let storageService: StorageService;
if (environmentService.hasStorageService) {
this.log.debug(`TrialDispatcher: use existing storage service.`);
storageService = component.get<StorageService>(StorageService);
} else {
this.log.debug(`TrialDispatcher: create temp storage service to temp folder.`);
storageService = new MountedStorageService();
const environmentLocalTempFolder = path.join(this.experimentRootDir, this.experimentId, "environment-temp");
storageService.initialize(this.trialConfig.codeDir, environmentLocalTempFolder);
}
// Copy the compressed file to remoteDirectory and delete it
const codeDir = path.resolve(this.trialConfig.codeDir);
const envDir = storageService.joinPath("envs");
const codeFileName = await storageService.copyDirectory(codeDir, envDir, true);
storageService.rename(codeFileName, "nni-code.tar.gz");
const installFileName = storageService.joinPath(envDir, 'install_nni.sh');
await storageService.save(CONTAINER_INSTALL_NNI_SHELL_FORMAT, installFileName);
const runnerSettings = storageService.joinPath(envDir, "settings.json");
await storageService.save(JSON.stringify(this.runnerSettings), runnerSettings);
// FIXME: what the hell is this?
if (this.isDeveloping) {
let trialToolsPath = path.join(__dirname, "../../../../../tools/nni_trial_tool");
if (false === fs.existsSync(trialToolsPath)) {
trialToolsPath = path.join(__dirname, "..\\..\\..\\..\\..\\tools\\nni_trial_tool");
}
await storageService.copyDirectory(trialToolsPath, envDir, true);
}
await this.prefetchEnvironments(); await this.prefetchEnvironments();
this.log.info(`TrialDispatcher: run loop started.`); this.log.info(`TrialDispatcher: run loop started.`);
await Promise.all([ const promiseList: Promise<void>[] = [];
this.environmentMaintenanceLoop(), for(const commandChannel of this.commandChannelSet) {
this.trialManagementLoop(), promiseList.push(commandChannel.run());
this.commandChannel.run(), }
]); promiseList.push(this.environmentMaintenanceLoop());
promiseList.push(this.trialManagementLoop());
await Promise.all(promiseList);
} }
public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void { public addTrialJobMetricListener(listener: (metric: TrialJobMetric) => void): void {
...@@ -260,14 +256,14 @@ class TrialDispatcher implements TrainingService { ...@@ -260,14 +256,14 @@ class TrialDispatcher implements TrainingService {
public async setClusterMetadata(key: string, value: string): Promise<void> { public async setClusterMetadata(key: string, value: string): Promise<void> {
switch (key) { switch (key) {
case TrialConfigMetadataKey.NNI_MANAGER_IP: case TrialConfigMetadataKey.NNI_MANAGER_IP:
this.runnerSettings.nniManagerIP = (<NNIManagerIpConfig>JSON.parse(value)).nniManagerIp; this.nniManagerIp = (<NNIManagerIpConfig>JSON.parse(value)).nniManagerIp;
break; break;
case TrialConfigMetadataKey.VERSION_CHECK: case TrialConfigMetadataKey.VERSION_CHECK:
this.enableVersionCheck = (value === 'true' || value === 'True'); this.enableVersionCheck = (value === 'true' || value === 'True');
this.runnerSettings.nniManagerVersion = this.enableVersionCheck ? await getVersion() : '';
break; break;
case TrialConfigMetadataKey.LOG_COLLECTION: case TrialConfigMetadataKey.LOG_COLLECTION:
this.runnerSettings.logCollection = value; this.logCollection = value;
break; break;
case TrialConfigMetadataKey.TRIAL_CONFIG: case TrialConfigMetadataKey.TRIAL_CONFIG:
this.trialConfig = <TrialConfig>JSON.parse(value); this.trialConfig = <TrialConfig>JSON.parse(value);
...@@ -279,15 +275,25 @@ class TrialDispatcher implements TrainingService { ...@@ -279,15 +275,25 @@ class TrialDispatcher implements TrainingService {
this.log.info(`TrialDispatcher: GPU scheduler is enabled.`) this.log.info(`TrialDispatcher: GPU scheduler is enabled.`)
this.enableGpuScheduler = true; this.enableGpuScheduler = true;
} }
this.runnerSettings.enableGpuCollector = this.enableGpuScheduler;
this.runnerSettings.command = this.trialConfig.command;
// Validate to make sure codeDir doesn't have too many files // Validate to make sure codeDir doesn't have too many files
await validateCodeDir(this.trialConfig.codeDir); await validateCodeDir(this.trialConfig.codeDir);
break; break;
case TrialConfigMetadataKey.PLATFORM_LIST: {
const platforms: string[] = value.split(",");
for(const platform of platforms) {
const environmentService: EnvironmentService = EnvironmentServiceFactory.createEnvironmentService(platform);
environmentService.initCommandChannel(this.commandEmitter);
this.environmentMaintenceLoopInterval =
Math.max(environmentService.environmentMaintenceLoopInterval, this.environmentMaintenceLoopInterval);
this.commandChannelSet.add(environmentService.getCommandChannel);
this.environmentServiceList.push(environmentService);
}
}
}
for(const environmentService of this.environmentServiceList) {
await environmentService.config(key, value);
} }
const environmentService = component.get<EnvironmentService>(EnvironmentService);
await environmentService.config(key, value);
} }
public getClusterMetadata(_key: string): Promise<string> { public getClusterMetadata(_key: string): Promise<string> {
...@@ -295,48 +301,75 @@ class TrialDispatcher implements TrainingService { ...@@ -295,48 +301,75 @@ class TrialDispatcher implements TrainingService {
} }
public async cleanUp(): Promise<void> { public async cleanUp(): Promise<void> {
if (this.commandChannel === undefined) {
throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in cleanUp.`);
}
if (this.commandEmitter === undefined) { if (this.commandEmitter === undefined) {
throw new Error(`TrialDispatcher: commandEmitter shouldn't be undefined in cleanUp.`); throw new Error(`TrialDispatcher: commandEmitter shouldn't be undefined in cleanUp.`);
} }
this.stopping = true; this.stopping = true;
this.shouldUpdateTrials = true; this.shouldUpdateTrials = true;
const environmentService = component.get<EnvironmentService>(EnvironmentService);
const environments = [...this.environments.values()]; const environments = [...this.environments.values()];
for (let index = 0; index < environments.length; index++) { for (let index = 0; index < environments.length; index++) {
const environment = environments[index]; const environment = environments[index];
if (environment.isAlive === true) { if (environment.isAlive === true) {
this.log.info(`stopping environment ${environment.id}...`); this.log.info(`stopping environment ${environment.id}...`);
await environmentService.stopEnvironment(environment); if (environment.environmentService === undefined) {
await this.commandChannel.close(environment); throw new Error(`${environment.id} do not have environmentService!`);
}
await environment.environmentService.stopEnvironment(environment);
this.log.info(`stopped environment ${environment.id}.`); this.log.info(`stopped environment ${environment.id}.`);
} }
} }
this.commandEmitter.off("command", this.handleCommand); this.commandEmitter.off("command", this.handleCommand);
await this.commandChannel.stop(); for (const commandChannel of this.commandChannelSet) {
await commandChannel.stop();
}
} }
private async environmentMaintenanceLoop(): Promise<void> { private async environmentMaintenanceLoop(): Promise<void> {
if (this.commandChannel === undefined) {
throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in environmentMaintenanceLoop.`);
}
const environmentService = component.get<EnvironmentService>(EnvironmentService);
while (!this.stopping) { while (!this.stopping) {
const environments: EnvironmentInformation[] = []; const environments: EnvironmentInformation[] = [];
for (const environment of this.environments.values()) { for (const environment of this.environments.values()) {
if (environment.isAlive === true) { if (environment.isAlive === true) {
environments.push(environment); environments.push(environment);
} else { } else {
await this.commandChannel.close(environment); if (environment.environmentService === undefined) {
throw new Error(`${environment.id} do not have environment service!`);
}
await environment.environmentService.getCommandChannel.close(environment);
}
}
// Group environments according to environmentService
const environmentServiceDict: Map<EnvironmentService, EnvironmentInformation[]> =
new Map<EnvironmentService, EnvironmentInformation[]>();
for (const environment of environments) {
if (environment.environmentService === undefined) {
throw new Error(`${environment.id} do not have environment service!`);
}
if (!environmentServiceDict.has(environment.environmentService)) {
environmentServiceDict.set(environment.environmentService, [environment]);
} else {
const environmentsList: EnvironmentInformation[] | undefined = environmentServiceDict.get(environment.environmentService);
if (environmentsList === undefined) {
throw new Error(`Environment list not initialized!`);
}
environmentsList.push(environment);
environmentServiceDict.set(environment.environmentService, environmentsList);
}
}
// Refresh all environments
const taskList: Promise<void>[] = [];
for (const environmentService of environmentServiceDict.keys()) {
const environmentsList: EnvironmentInformation[] | undefined = environmentServiceDict.get(environmentService);
if (environmentsList) {
taskList.push(environmentService.refreshEnvironmentsStatus(environmentsList));
} }
} }
await environmentService.refreshEnvironmentsStatus(environments); await Promise.all(taskList);
environments.forEach((environment) => { for (const environment of environments) {
if (environment.environmentService === undefined) {
throw new Error(`${environment.id} do not have environment service!`);
}
const oldIsAlive = environment.isAlive; const oldIsAlive = environment.isAlive;
switch (environment.status) { switch (environment.status) {
case 'WAITING': case 'WAITING':
...@@ -351,16 +384,16 @@ class TrialDispatcher implements TrainingService { ...@@ -351,16 +384,16 @@ class TrialDispatcher implements TrainingService {
if (oldIsAlive !== environment.isAlive) { if (oldIsAlive !== environment.isAlive) {
this.log.debug(`set environment ${environment.id} isAlive from ${oldIsAlive} to ${environment.isAlive} due to status is ${environment.status}.`); this.log.debug(`set environment ${environment.id} isAlive from ${oldIsAlive} to ${environment.isAlive} due to status is ${environment.status}.`);
} }
}); }
this.shouldUpdateTrials = true; this.shouldUpdateTrials = true;
await delay(environmentService.environmentMaintenceLoopInterval); if (this.environmentMaintenceLoopInterval === -1) {
throw new Error("EnvironmentMaintenceLoopInterval not initialized!");
}
await delay(this.environmentMaintenceLoopInterval);
} }
} }
private async trialManagementLoop(): Promise<void> { private async trialManagementLoop(): Promise<void> {
if (this.commandChannel === undefined) {
throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in trialManagementLoop.`);
}
const interval = 1; const interval = 1;
while (!this.stopping) { while (!this.stopping) {
...@@ -400,6 +433,11 @@ class TrialDispatcher implements TrainingService { ...@@ -400,6 +433,11 @@ class TrialDispatcher implements TrainingService {
liveTrialsCount++; liveTrialsCount++;
continue; continue;
} }
if (environment.environmentService === undefined) {
throw new Error(`${environment.id} does not has environment service!`);
}
trial.url = environment.trackingUrl; trial.url = environment.trackingUrl;
const environmentStatus = environment.status; const environmentStatus = environment.status;
...@@ -414,7 +452,7 @@ class TrialDispatcher implements TrainingService { ...@@ -414,7 +452,7 @@ class TrialDispatcher implements TrainingService {
// for example, in horovod, it's just sleep command, has no impact on trial result. // for example, in horovod, it's just sleep command, has no impact on trial result.
if (environment.nodeCount > completedCount) { if (environment.nodeCount > completedCount) {
this.log.info(`stop partial completed trial ${trial.id}`); this.log.info(`stop partial completed trial ${trial.id}`);
await this.commandChannel.sendCommand(environment, KILL_TRIAL_JOB, trial.id); await environment.environmentService.getCommandChannel.sendCommand(environment, KILL_TRIAL_JOB, trial.id);
} }
for (const node of trial.nodes.values()) { for (const node of trial.nodes.values()) {
if (node.status === "FAILED") { if (node.status === "FAILED") {
...@@ -463,8 +501,10 @@ class TrialDispatcher implements TrainingService { ...@@ -463,8 +501,10 @@ class TrialDispatcher implements TrainingService {
false === this.reuseEnvironment && false === this.reuseEnvironment &&
environment.assignedTrialCount > 0 environment.assignedTrialCount > 0
) { ) {
const environmentService = component.get<EnvironmentService>(EnvironmentService); if (environment.environmentService === undefined) {
await environmentService.stopEnvironment(environment); throw new Error(`${environment.id} does not has environment service!`);
}
await environment.environmentService.stopEnvironment(environment);
continue; continue;
} }
...@@ -556,11 +596,13 @@ class TrialDispatcher implements TrainingService { ...@@ -556,11 +596,13 @@ class TrialDispatcher implements TrainingService {
} }
if (neededEnvironmentCount > 0) { if (neededEnvironmentCount > 0) {
const environmentService = component.get<EnvironmentService>(EnvironmentService);
let requestedCount = 0; let requestedCount = 0;
let hasMoreEnvironments = false;
for (let index = 0; index < neededEnvironmentCount; index++) { for (let index = 0; index < neededEnvironmentCount; index++) {
if (true === environmentService.hasMoreEnvironments) { const environmentService: EnvironmentService | undefined = this.selectEnvironmentService();
await this.requestEnvironment(); if (environmentService !== undefined) {
hasMoreEnvironments = true;
await this.requestEnvironment(environmentService);
requestedCount++; requestedCount++;
this.isLoggedNoMoreEnvironment = false; this.isLoggedNoMoreEnvironment = false;
} else { } else {
...@@ -570,7 +612,7 @@ class TrialDispatcher implements TrainingService { ...@@ -570,7 +612,7 @@ class TrialDispatcher implements TrainingService {
} }
} }
} }
if (environmentService.hasMoreEnvironments === true || requestedCount > 0) { if (hasMoreEnvironments === true || requestedCount > 0) {
this.log.info(`requested new environment, live trials: ${liveTrialsCount}, ` + this.log.info(`requested new environment, live trials: ${liveTrialsCount}, ` +
`live environments: ${liveEnvironmentsCount}, neededEnvironmentCount: ${neededEnvironmentCount}, ` + `live environments: ${liveEnvironmentsCount}, neededEnvironmentCount: ${neededEnvironmentCount}, ` +
`requestedCount: ${requestedCount}`); `requestedCount: ${requestedCount}`);
...@@ -580,25 +622,37 @@ class TrialDispatcher implements TrainingService { ...@@ -580,25 +622,37 @@ class TrialDispatcher implements TrainingService {
} }
} }
private async prefetchEnvironments (): Promise<void> { // Schedule a environment platform for environment
const environmentService = component.get<EnvironmentService>(EnvironmentService); private selectEnvironmentService(): EnvironmentService | undefined {
const number = environmentService.prefetchedEnvironmentCount; const validEnvironmentServiceList = [];
this.log.info(`Initialize environments total number: ${number}`); for(const environmentService of this.environmentServiceList){
for (let index = 0; index < number; index++) { if (environmentService.hasMoreEnvironments) {
await this.requestEnvironment(); validEnvironmentServiceList.push(environmentService);
}
} }
if (validEnvironmentServiceList.length === 0) {
return undefined;
}
// Random scheduler
return randomSelect(validEnvironmentServiceList);
} }
private async requestEnvironment(): Promise<void> { private async prefetchEnvironments (): Promise<void> {
if (this.commandChannel === undefined) { for (const environmentService of this.environmentServiceList) {
throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in requestEnvironment.`); const number = environmentService.prefetchedEnvironmentCount;
this.log.info(`Initialize environments total number: ${number}`);
for (let index = 0; index < number; index++) {
await this.requestEnvironment(environmentService);
}
} }
}
const environmentService = component.get<EnvironmentService>(EnvironmentService); private async requestEnvironment(environmentService: EnvironmentService): Promise<void> {
const envId = uniqueString(5); const envId = uniqueString(5);
const envName = `nni_exp_${this.experimentId}_env_${envId}`; const envName = `nni_exp_${this.experimentId}_env_${envId}`;
const environment = environmentService.createEnvironmentInformation(envId, envName); const environment = environmentService.createEnvironmentInformation(envId, envName);
environment.environmentService = environmentService;
this.log.info(`Assign environment service ${environmentService.getName} to environment ${envId}`);
environment.command = `sh ../install_nni.sh && python3 -m nni.tools.trial_tool.trial_runner`; environment.command = `sh ../install_nni.sh && python3 -m nni.tools.trial_tool.trial_runner`;
if (this.isDeveloping) { if (this.isDeveloping) {
...@@ -616,15 +670,11 @@ class TrialDispatcher implements TrainingService { ...@@ -616,15 +670,11 @@ class TrialDispatcher implements TrainingService {
} else { } else {
environment.isAlive = true; environment.isAlive = true;
} }
await environment.environmentService.getCommandChannel.open(environment);
await this.commandChannel.open(environment);
this.log.info(`requested environment ${environment.id} and job id is ${environment.envId}.`); this.log.info(`requested environment ${environment.id} and job id is ${environment.envId}.`);
} }
private async allocateEnvironment(trial: TrialDetail, environment: EnvironmentInformation): Promise<void> { private async allocateEnvironment(trial: TrialDetail, environment: EnvironmentInformation): Promise<void> {
if (this.commandChannel === undefined) {
throw new Error(`TrialDispatcher: commandChannel shouldn't be undefined in allocateEnvironment.`);
}
if (this.trialConfig === undefined) { if (this.trialConfig === undefined) {
throw new Error(`TrialDispatcher: trialConfig shouldn't be undefined in allocateEnvironment.`); throw new Error(`TrialDispatcher: trialConfig shouldn't be undefined in allocateEnvironment.`);
} }
...@@ -653,6 +703,13 @@ class TrialDispatcher implements TrainingService { ...@@ -653,6 +703,13 @@ class TrialDispatcher implements TrainingService {
environment.runningTrialCount++; environment.runningTrialCount++;
environment.assignedTrialCount++; environment.assignedTrialCount++;
trial.environment = environment; trial.environment = environment;
if (environment.environmentService === undefined) {
throw new Error(`${environment.id} environmentService not initialized!`);
}
if (environment.environmentService.hasStorageService) {
const storageService = component.get<StorageService>(StorageService);
trial.workingDirectory = storageService.joinPath('trials', trial.id);
}
trial.settings = { trial.settings = {
trialId: trial.id, trialId: trial.id,
gpuIndices: gpuIndices, gpuIndices: gpuIndices,
...@@ -661,7 +718,10 @@ class TrialDispatcher implements TrainingService { ...@@ -661,7 +718,10 @@ class TrialDispatcher implements TrainingService {
} }
trial.startTime = Date.now(); trial.startTime = Date.now();
trial.status = "RUNNING"; trial.status = "RUNNING";
await this.commandChannel.sendCommand(trial.environment, NEW_TRIAL_JOB, trial.settings); if (environment.environmentService === undefined) {
throw new Error(`${environment.id} does not have environment service!`);
}
await environment.environmentService.getCommandChannel.sendCommand(trial.environment, NEW_TRIAL_JOB, trial.settings);
} }
/** /**
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment