"tools/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "16cb83775abf5b21f1df1cdbe45a7bf57639f953"
Unverified Commit 063d6b74 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #3580 from microsoft/v2.2

[do not Squash!] Merge V2.2 back to master
parents 08986c6b e1295888
...@@ -175,12 +175,14 @@ class NNIManager implements Manager { ...@@ -175,12 +175,14 @@ class NNIManager implements Manager {
nextSequenceId: 0, nextSequenceId: 0,
revision: 0 revision: 0
}; };
this.config = config;
this.log.info(`Starting experiment: ${this.experimentProfile.id}`); this.log.info(`Starting experiment: ${this.experimentProfile.id}`);
await this.storeExperimentProfile(); await this.storeExperimentProfile();
if (this.trainingService === undefined) {
this.log.info('Setup training service...'); this.log.info('Setup training service...');
this.trainingService = await this.initTrainingService(config); this.trainingService = await this.initTrainingService(config);
}
this.log.info('Setup tuner...'); this.log.info('Setup tuner...');
const dispatcherCommand: string = getMsgDispatcherCommand(config); const dispatcherCommand: string = getMsgDispatcherCommand(config);
...@@ -198,18 +200,22 @@ class NNIManager implements Manager { ...@@ -198,18 +200,22 @@ class NNIManager implements Manager {
} }
public async resumeExperiment(readonly: boolean): Promise<void> { public async resumeExperiment(readonly: boolean): Promise<void> {
this.log.info(`Resuming experiment: ${this.experimentProfile.id}`);
//Fetch back the experiment profile //Fetch back the experiment profile
const experimentId: string = getExperimentId(); const experimentId: string = getExperimentId();
this.log.info(`Resuming experiment: ${experimentId}`);
this.experimentProfile = await this.dataStore.getExperimentProfile(experimentId); this.experimentProfile = await this.dataStore.getExperimentProfile(experimentId);
this.readonly = readonly; this.readonly = readonly;
if (readonly) { if (readonly) {
this.setStatus('VIEWED');
return Promise.resolve(); return Promise.resolve();
} }
this.log.info('Setup training service...');
const config: ExperimentConfig = this.experimentProfile.params; const config: ExperimentConfig = this.experimentProfile.params;
this.config = config;
if (this.trainingService === undefined) {
this.log.info('Setup training service...');
this.trainingService = await this.initTrainingService(config); this.trainingService = await this.initTrainingService(config);
}
this.log.info('Setup tuner...'); this.log.info('Setup tuner...');
const dispatcherCommand: string = getMsgDispatcherCommand(config); const dispatcherCommand: string = getMsgDispatcherCommand(config);
...@@ -254,12 +260,35 @@ class NNIManager implements Manager { ...@@ -254,12 +260,35 @@ class NNIManager implements Manager {
return this.dataStore.getTrialJob(trialJobId); return this.dataStore.getTrialJob(trialJobId);
} }
public async setClusterMetadata(_key: string, _value: string): Promise<void> { public async setClusterMetadata(key: string, value: string): Promise<void> {
throw new Error('Calling removed API setClusterMetadata'); // Hack for supporting v2 config, need refactor
if (this.trainingService === undefined) {
this.log.info('Setup training service...');
switch (key) {
case 'kubeflow_config': {
const kubeflowModule = await import('../training_service/kubernetes/kubeflow/kubeflowTrainingService');
this.trainingService = new kubeflowModule.KubeflowTrainingService();
break;
}
case 'frameworkcontroller_config': {
const fcModule = await import('../training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService');
this.trainingService = new fcModule.FrameworkControllerTrainingService();
break;
}
case 'adl_config': {
const adlModule = await import('../training_service/kubernetes/adl/adlTrainingService');
this.trainingService = new adlModule.AdlTrainingService();
break;
}
default:
throw new Error("Setup training service failed.");
}
}
await this.trainingService.setClusterMetadata(key, value);
} }
public getClusterMetadata(_key: string): Promise<string> { public getClusterMetadata(key: string): Promise<string> {
throw new Error('Calling removed API getClusterMetadata'); return this.trainingService.getClusterMetadata(key);
} }
public async getTrialJobStatistics(): Promise<TrialJobStatistics[]> { public async getTrialJobStatistics(): Promise<TrialJobStatistics[]> {
...@@ -404,8 +433,17 @@ class NNIManager implements Manager { ...@@ -404,8 +433,17 @@ class NNIManager implements Manager {
} }
private async initTrainingService(config: ExperimentConfig): Promise<TrainingService> { private async initTrainingService(config: ExperimentConfig): Promise<TrainingService> {
this.config = config; let platform: string;
const platform = Array.isArray(config.trainingService) ? 'hybrid' : config.trainingService.platform; if (Array.isArray(config.trainingService)) {
platform = 'hybrid';
} else if (config.trainingService.platform) {
platform = config.trainingService.platform;
} else {
platform = (config as any).trainingServicePlatform;
}
if (!platform) {
throw new Error('Cannot detect training service platform');
}
if (['remote', 'pai', 'aml', 'hybrid'].includes(platform)) { if (['remote', 'pai', 'aml', 'hybrid'].includes(platform)) {
const module_ = await import('../training_service/reusable/routerTrainingService'); const module_ = await import('../training_service/reusable/routerTrainingService');
......
...@@ -131,6 +131,9 @@ export namespace ValidationSchemas { ...@@ -131,6 +131,9 @@ export namespace ValidationSchemas {
maxTrialNumPerGpu: joi.number(), maxTrialNumPerGpu: joi.number(),
useActiveGpu: joi.boolean(), useActiveGpu: joi.boolean(),
}), }),
adl_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
// hack for v2 configuration
}),
kubeflow_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase kubeflow_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
operator: joi.string().min(1).required(), operator: joi.string().min(1).required(),
storage: joi.string().min(1), storage: joi.string().min(1),
...@@ -194,6 +197,8 @@ export namespace ValidationSchemas { ...@@ -194,6 +197,8 @@ export namespace ValidationSchemas {
nni_manager_ip: joi.object({ // eslint-disable-line @typescript-eslint/camelcase nni_manager_ip: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
nniManagerIp: joi.string().min(1) nniManagerIp: joi.string().min(1)
}), }),
version_check: joi.boolean(), // eslint-disable-line @typescript-eslint/camelcase
log_collection: joi.string(), // eslint-disable-line @typescript-eslint/camelcase
remote_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase remote_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
reuse: joi.boolean() reuse: joi.boolean()
}), }),
......
...@@ -19,6 +19,7 @@ import {validateCodeDir} from '../../common/util'; ...@@ -19,6 +19,7 @@ import {validateCodeDir} from '../../common/util';
import {NFSConfig} from '../kubernetesConfig'; import {NFSConfig} from '../kubernetesConfig';
import {KubernetesTrialJobDetail} from '../kubernetesData'; import {KubernetesTrialJobDetail} from '../kubernetesData';
import {KubernetesTrainingService} from '../kubernetesTrainingService'; import {KubernetesTrainingService} from '../kubernetesTrainingService';
import {KubernetesJobRestServer} from '../kubernetesJobRestServer';
import {FrameworkControllerClientFactory} from './frameworkcontrollerApiClient'; import {FrameworkControllerClientFactory} from './frameworkcontrollerApiClient';
import { import {
FrameworkControllerClusterConfig, FrameworkControllerClusterConfig,
...@@ -52,7 +53,7 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -52,7 +53,7 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
} }
public async run(): Promise<void> { public async run(): Promise<void> {
this.kubernetesJobRestServer = component.get(FrameworkControllerJobRestServer); this.kubernetesJobRestServer = new KubernetesJobRestServer(this);
if (this.kubernetesJobRestServer === undefined) { if (this.kubernetesJobRestServer === undefined) {
throw new Error('kubernetesJobRestServer not initialized!'); throw new Error('kubernetesJobRestServer not initialized!');
} }
...@@ -140,10 +141,11 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple ...@@ -140,10 +141,11 @@ class FrameworkControllerTrainingService extends KubernetesTrainingService imple
const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId); const trialLocalTempFolder: string = path.join(getExperimentRootDir(), 'trials-local', trialJobId);
let frameworkcontrollerJobName: string = `nniexp${this.experimentId}trial${trialJobId}`.toLowerCase(); let frameworkcontrollerJobName: string = `nniexp${this.experimentId}trial${trialJobId}`.toLowerCase();
// Create frameworkcontroller job based on generated frameworkcontroller job resource config let frameworkcontrollerJobConfig: any;
let frameworkcontrollerJobConfig = JSON.parse(JSON.stringify(this.fcTemplate));
if (this.fcTemplate !== undefined) { if (this.fcTemplate !== undefined) {
// Create frameworkcontroller job based on generated frameworkcontroller job resource config
frameworkcontrollerJobConfig = JSON.parse(JSON.stringify(this.fcTemplate));
// add a custom name extension to the job name and apply it to the custom template // add a custom name extension to the job name and apply it to the custom template
frameworkcontrollerJobName += "xx" + this.fcTemplate.metadata.name; frameworkcontrollerJobName += "xx" + this.fcTemplate.metadata.name;
// Process custom task roles commands // Process custom task roles commands
......
...@@ -19,6 +19,7 @@ import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey'; ...@@ -19,6 +19,7 @@ import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey';
import { validateCodeDir } from '../../common/util'; import { validateCodeDir } from '../../common/util';
import { NFSConfig } from '../kubernetesConfig'; import { NFSConfig } from '../kubernetesConfig';
import { KubernetesTrialJobDetail } from '../kubernetesData'; import { KubernetesTrialJobDetail } from '../kubernetesData';
import { KubernetesJobRestServer } from '../kubernetesJobRestServer';
import { KubernetesTrainingService } from '../kubernetesTrainingService'; import { KubernetesTrainingService } from '../kubernetesTrainingService';
import { KubeflowOperatorClientFactory } from './kubeflowApiClient'; import { KubeflowOperatorClientFactory } from './kubeflowApiClient';
import { KubeflowClusterConfig, KubeflowClusterConfigAzure, KubeflowClusterConfigFactory, KubeflowClusterConfigNFS, import { KubeflowClusterConfig, KubeflowClusterConfigAzure, KubeflowClusterConfigFactory, KubeflowClusterConfigNFS,
...@@ -46,7 +47,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber ...@@ -46,7 +47,7 @@ class KubeflowTrainingService extends KubernetesTrainingService implements Kuber
public async run(): Promise<void> { public async run(): Promise<void> {
this.log.info('Run Kubeflow training service.'); this.log.info('Run Kubeflow training service.');
this.kubernetesJobRestServer = component.get(KubeflowJobRestServer); this.kubernetesJobRestServer = new KubernetesJobRestServer(this);
if (this.kubernetesJobRestServer === undefined) { if (this.kubernetesJobRestServer === undefined) {
throw new Error('kubernetesJobRestServer not initialized!'); throw new Error('kubernetesJobRestServer not initialized!');
} }
......
...@@ -16,7 +16,6 @@ import { KubernetesTrainingService } from './kubernetesTrainingService'; ...@@ -16,7 +16,6 @@ import { KubernetesTrainingService } from './kubernetesTrainingService';
export class KubernetesJobRestServer extends ClusterJobRestServer { export class KubernetesJobRestServer extends ClusterJobRestServer {
@Inject @Inject
private readonly kubernetesTrainingService? : KubernetesTrainingService; private readonly kubernetesTrainingService? : KubernetesTrainingService;
/** /**
* constructor to provide NNIRestServer's own rest property, e.g. port * constructor to provide NNIRestServer's own rest property, e.g. port
*/ */
......
...@@ -146,6 +146,10 @@ class LinuxCommands extends OsCommands { ...@@ -146,6 +146,10 @@ class LinuxCommands extends OsCommands {
public fileExistCommand(filePath: string): string { public fileExistCommand(filePath: string): string {
return `test -e ${filePath} && echo True || echo False`; return `test -e ${filePath} && echo True || echo False`;
} }
public getCurrentPath(): string {
return `pwd`;
}
} }
export { LinuxCommands }; export { LinuxCommands };
...@@ -134,6 +134,10 @@ class WindowsCommands extends OsCommands { ...@@ -134,6 +134,10 @@ class WindowsCommands extends OsCommands {
public fileExistCommand(filePath: string): string { public fileExistCommand(filePath: string): string {
return `powershell Test-Path ${filePath} -PathType Leaf`; return `powershell Test-Path ${filePath} -PathType Leaf`;
} }
public getCurrentPath(): string {
return `chdir`;
}
} }
export { WindowsCommands }; export { WindowsCommands };
...@@ -30,6 +30,7 @@ abstract class OsCommands { ...@@ -30,6 +30,7 @@ abstract class OsCommands {
public abstract executeScript(script: string, isFile: boolean): string; public abstract executeScript(script: string, isFile: boolean): string;
public abstract setPythonPath(pythonPath: string | undefined, command: string | undefined): string | undefined; public abstract setPythonPath(pythonPath: string | undefined, command: string | undefined): string | undefined;
public abstract fileExistCommand(filePath: string): string | undefined; public abstract fileExistCommand(filePath: string): string | undefined;
public abstract getCurrentPath(): string;
public joinPath(...paths: string[]): string { public joinPath(...paths: string[]): string {
let dir: string = paths.filter((path: any) => path !== '').join(this.pathSpliter); let dir: string = paths.filter((path: any) => path !== '').join(this.pathSpliter);
......
...@@ -169,6 +169,16 @@ class ShellExecutor { ...@@ -169,6 +169,16 @@ class ShellExecutor {
return this.tempPath; return this.tempPath;
} }
public async getCurrentPath(): Promise<string> {
const commandText = this.osCommands && this.osCommands.getCurrentPath();
const commandResult = await this.execute(commandText);
if (commandResult.exitCode == 0) {
return commandResult.stdout;
} else {
throw Error(commandResult.stderr);
}
}
public getRemoteScriptsPath(experimentId: string): string { public getRemoteScriptsPath(experimentId: string): string {
return this.joinPath(this.getRemoteExperimentRootDir(experimentId), 'scripts'); return this.joinPath(this.getRemoteExperimentRootDir(experimentId), 'scripts');
} }
......
...@@ -128,6 +128,10 @@ export class EnvironmentInformation { ...@@ -128,6 +128,10 @@ export class EnvironmentInformation {
export abstract class EnvironmentService { export abstract class EnvironmentService {
public async init(): Promise<void> {
return;
}
public abstract get hasStorageService(): boolean; public abstract get hasStorageService(): boolean;
public abstract refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void>; public abstract refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void>;
public abstract stopEnvironment(environment: EnvironmentInformation): Promise<void>; public abstract stopEnvironment(environment: EnvironmentInformation): Promise<void>;
......
...@@ -9,15 +9,16 @@ import * as component from '../../../common/component'; ...@@ -9,15 +9,16 @@ import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo'; import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log'; import { getLogger, Logger } from '../../../common/log';
import { getExperimentRootDir } from '../../../common/utils'; import { getExperimentRootDir } from '../../../common/utils';
import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey'; import { ExperimentConfig, AmlConfig, flattenConfig } from '../../../common/experimentConfig';
import { validateCodeDir } from '../../common/util'; import { validateCodeDir } from '../../common/util';
import { AMLClient } from '../aml/amlClient'; import { AMLClient } from '../aml/amlClient';
import { AMLClusterConfig, AMLEnvironmentInformation, AMLTrialConfig } from '../aml/amlConfig'; import { AMLEnvironmentInformation } from '../aml/amlConfig';
import { EnvironmentInformation, EnvironmentService } from '../environment'; import { EnvironmentInformation, EnvironmentService } from '../environment';
import { EventEmitter } from "events"; import { EventEmitter } from "events";
import { AMLCommandChannel } from '../channels/amlCommandChannel'; import { AMLCommandChannel } from '../channels/amlCommandChannel';
import { SharedStorageService } from '../sharedStorage' import { SharedStorageService } from '../sharedStorage'
interface FlattenAmlConfig extends ExperimentConfig, AmlConfig { }
/** /**
* Collector AML jobs info from AML cluster, and update aml job status locally * Collector AML jobs info from AML cluster, and update aml job status locally
...@@ -26,15 +27,16 @@ import { SharedStorageService } from '../sharedStorage' ...@@ -26,15 +27,16 @@ import { SharedStorageService } from '../sharedStorage'
export class AMLEnvironmentService extends EnvironmentService { export class AMLEnvironmentService extends EnvironmentService {
private readonly log: Logger = getLogger(); private readonly log: Logger = getLogger();
public amlClusterConfig: AMLClusterConfig | undefined;
public amlTrialConfig: AMLTrialConfig | undefined;
private experimentId: string; private experimentId: string;
private experimentRootDir: string; private experimentRootDir: string;
private config: FlattenAmlConfig;
constructor() { constructor(config: ExperimentConfig) {
super(); super();
this.experimentId = getExperimentId(); this.experimentId = getExperimentId();
this.experimentRootDir = getExperimentRootDir(); this.experimentRootDir = getExperimentRootDir();
this.config = flattenConfig(config, 'aml');
validateCodeDir(this.config.trialCodeDirectory);
} }
public get hasStorageService(): boolean { public get hasStorageService(): boolean {
...@@ -53,27 +55,6 @@ export class AMLEnvironmentService extends EnvironmentService { ...@@ -53,27 +55,6 @@ export class AMLEnvironmentService extends EnvironmentService {
return 'aml'; return 'aml';
} }
public async config(key: string, value: string): Promise<void> {
switch (key) {
case TrialConfigMetadataKey.AML_CLUSTER_CONFIG:
this.amlClusterConfig = <AMLClusterConfig>JSON.parse(value);
break;
case TrialConfigMetadataKey.TRIAL_CONFIG: {
if (this.amlClusterConfig === undefined) {
this.log.error('aml cluster config is not initialized');
break;
}
this.amlTrialConfig = <AMLTrialConfig>JSON.parse(value);
// Validate to make sure codeDir doesn't have too many files
await validateCodeDir(this.amlTrialConfig.codeDir);
break;
}
default:
this.log.debug(`AML not proccessed metadata key: '${key}', value: '${value}'`);
}
}
public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> { public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
environments.forEach(async (environment) => { environments.forEach(async (environment) => {
const amlClient = (environment as AMLEnvironmentInformation).amlClient; const amlClient = (environment as AMLEnvironmentInformation).amlClient;
...@@ -107,12 +88,6 @@ export class AMLEnvironmentService extends EnvironmentService { ...@@ -107,12 +88,6 @@ export class AMLEnvironmentService extends EnvironmentService {
} }
public async startEnvironment(environment: EnvironmentInformation): Promise<void> { public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
if (this.amlClusterConfig === undefined) {
throw new Error('AML Cluster config is not initialized');
}
if (this.amlTrialConfig === undefined) {
throw new Error('AML trial config is not initialized');
}
const amlEnvironment: AMLEnvironmentInformation = environment as AMLEnvironmentInformation; const amlEnvironment: AMLEnvironmentInformation = environment as AMLEnvironmentInformation;
const environmentLocalTempFolder = path.join(this.experimentRootDir, "environment-temp"); const environmentLocalTempFolder = path.join(this.experimentRootDir, "environment-temp");
if (!fs.existsSync(environmentLocalTempFolder)) { if (!fs.existsSync(environmentLocalTempFolder)) {
...@@ -126,22 +101,24 @@ export class AMLEnvironmentService extends EnvironmentService { ...@@ -126,22 +101,24 @@ export class AMLEnvironmentService extends EnvironmentService {
amlEnvironment.command = `mv envs outputs/envs && cd outputs && ${amlEnvironment.command}`; amlEnvironment.command = `mv envs outputs/envs && cd outputs && ${amlEnvironment.command}`;
} }
amlEnvironment.command = `import os\nos.system('${amlEnvironment.command}')`; amlEnvironment.command = `import os\nos.system('${amlEnvironment.command}')`;
amlEnvironment.useActiveGpu = this.amlClusterConfig.useActiveGpu; amlEnvironment.useActiveGpu = !!this.config.deprecated.useActiveGpu;
amlEnvironment.maxTrialNumberPerGpu = this.amlClusterConfig.maxTrialNumPerGpu; amlEnvironment.maxTrialNumberPerGpu = this.config.maxTrialNumberPerGpu;
await fs.promises.writeFile(path.join(environmentLocalTempFolder, 'nni_script.py'), amlEnvironment.command, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(environmentLocalTempFolder, 'nni_script.py'), amlEnvironment.command, { encoding: 'utf8' });
const amlClient = new AMLClient( const amlClient = new AMLClient(
this.amlClusterConfig.subscriptionId, this.config.subscriptionId,
this.amlClusterConfig.resourceGroup, this.config.resourceGroup,
this.amlClusterConfig.workspaceName, this.config.workspaceName,
this.experimentId, this.experimentId,
this.amlClusterConfig.computeTarget, this.config.computeTarget,
this.amlTrialConfig.image, this.config.dockerImage,
'nni_script.py', 'nni_script.py',
environmentLocalTempFolder environmentLocalTempFolder
); );
amlEnvironment.id = await amlClient.submit(); amlEnvironment.id = await amlClient.submit();
this.log.debug('aml: before getTrackingUrl');
amlEnvironment.trackingUrl = await amlClient.getTrackingUrl(); amlEnvironment.trackingUrl = await amlClient.getTrackingUrl();
this.log.debug('aml: after getTrackingUrl');
amlEnvironment.amlClient = amlClient; amlEnvironment.amlClient = amlClient;
} }
......
...@@ -13,7 +13,7 @@ export class EnvironmentServiceFactory { ...@@ -13,7 +13,7 @@ export class EnvironmentServiceFactory {
case 'remote': case 'remote':
return new RemoteEnvironmentService(config); return new RemoteEnvironmentService(config);
case 'aml': case 'aml':
return new AMLEnvironmentService(); return new AMLEnvironmentService(config);
case 'openpai': case 'openpai':
return new OpenPaiEnvironmentService(config); return new OpenPaiEnvironmentService(config);
default: default:
......
...@@ -27,7 +27,7 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -27,7 +27,7 @@ export class RemoteEnvironmentService extends EnvironmentService {
private readonly environmentExecutorManagerMap: Map<string, ExecutorManager>; private readonly environmentExecutorManagerMap: Map<string, ExecutorManager>;
private readonly remoteMachineMetaOccupiedMap: Map<RemoteMachineConfig, boolean>; private readonly remoteMachineMetaOccupiedMap: Map<RemoteMachineConfig, boolean>;
private readonly log: Logger; private readonly log: Logger;
private sshConnectionPromises: any[]; private sshConnectionPromises: Promise<void[]>;
private experimentRootDir: string; private experimentRootDir: string;
private remoteExperimentRootDir: string = ""; private remoteExperimentRootDir: string = "";
private experimentId: string; private experimentId: string;
...@@ -39,7 +39,6 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -39,7 +39,6 @@ export class RemoteEnvironmentService extends EnvironmentService {
this.environmentExecutorManagerMap = new Map<string, ExecutorManager>(); this.environmentExecutorManagerMap = new Map<string, ExecutorManager>();
this.machineExecutorManagerMap = new Map<RemoteMachineConfig, ExecutorManager>(); this.machineExecutorManagerMap = new Map<RemoteMachineConfig, ExecutorManager>();
this.remoteMachineMetaOccupiedMap = new Map<RemoteMachineConfig, boolean>(); this.remoteMachineMetaOccupiedMap = new Map<RemoteMachineConfig, boolean>();
this.sshConnectionPromises = [];
this.experimentRootDir = getExperimentRootDir(); this.experimentRootDir = getExperimentRootDir();
this.experimentId = getExperimentId(); this.experimentId = getExperimentId();
this.log = getLogger(); this.log = getLogger();
...@@ -50,9 +49,18 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -50,9 +49,18 @@ export class RemoteEnvironmentService extends EnvironmentService {
throw new Error(`codeDir ${this.config.trialCodeDirectory} is not a directory`); throw new Error(`codeDir ${this.config.trialCodeDirectory} is not a directory`);
} }
this.sshConnectionPromises = this.config.machineList.map( this.sshConnectionPromises = Promise.all(this.config.machineList.map(
machine => this.initRemoteMachineOnConnected(machine) machine => this.initRemoteMachineOnConnected(machine)
); ));
}
public async init(): Promise<void> {
await this.sshConnectionPromises;
this.log.info('ssh connection initialized!');
Array.from(this.machineExecutorManagerMap.keys()).forEach(rmMeta => {
// initialize remoteMachineMetaOccupiedMap, false means not occupied
this.remoteMachineMetaOccupiedMap.set(rmMeta, false);
});
} }
public get prefetchedEnvironmentCount(): number { public get prefetchedEnvironmentCount(): number {
...@@ -204,16 +212,6 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -204,16 +212,6 @@ export class RemoteEnvironmentService extends EnvironmentService {
} }
public async startEnvironment(environment: EnvironmentInformation): Promise<void> { public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
if (this.sshConnectionPromises.length > 0) {
await Promise.all(this.sshConnectionPromises);
this.log.info('ssh connection initialized!');
// set sshConnectionPromises to [] to avoid log information duplicated
this.sshConnectionPromises = [];
Array.from(this.machineExecutorManagerMap.keys()).forEach(rmMeta => {
// initialize remoteMachineMetaOccupiedMap, false means not occupied
this.remoteMachineMetaOccupiedMap.set(rmMeta, false);
});
}
const remoteEnvironment: RemoteMachineEnvironmentInformation = environment as RemoteMachineEnvironmentInformation; const remoteEnvironment: RemoteMachineEnvironmentInformation = environment as RemoteMachineEnvironmentInformation;
remoteEnvironment.status = 'WAITING'; remoteEnvironment.status = 'WAITING';
// schedule machine for environment, generate command // schedule machine for environment, generate command
...@@ -238,7 +236,10 @@ export class RemoteEnvironmentService extends EnvironmentService { ...@@ -238,7 +236,10 @@ export class RemoteEnvironmentService extends EnvironmentService {
const executor = await this.getExecutor(environment.id); const executor = await this.getExecutor(environment.id);
if (environment.useSharedStorage) { if (environment.useSharedStorage) {
this.remoteExperimentRootDir = component.get<SharedStorageService>(SharedStorageService).remoteWorkingRoot; this.remoteExperimentRootDir = component.get<SharedStorageService>(SharedStorageService).remoteWorkingRoot;
const remoteMountCommand = component.get<SharedStorageService>(SharedStorageService).remoteMountCommand.replace(/echo -e /g, `echo `).replace(/echo /g, `echo -e `); if (!this.remoteExperimentRootDir.startsWith('/')) {
this.remoteExperimentRootDir = executor.joinPath((await executor.getCurrentPath()).trim(), this.remoteExperimentRootDir);
}
const remoteMountCommand = component.get<SharedStorageService>(SharedStorageService).remoteMountCommand.replace(/echo -e /g, `echo `).replace(/echo /g, `echo -e `).replace(/\\\$/g, `\\\\\\$`);
const result = await executor.executeScript(remoteMountCommand, false, false); const result = await executor.executeScript(remoteMountCommand, false, false);
if (result.exitCode !== 0) { if (result.exitCode !== 0) {
throw new Error(`Mount shared storage on remote machine failed.\n ERROR: ${result.stderr}`); throw new Error(`Mount shared storage on remote machine failed.\n ERROR: ${result.stderr}`);
......
...@@ -122,7 +122,6 @@ class TrialDispatcher implements TrainingService { ...@@ -122,7 +122,6 @@ class TrialDispatcher implements TrainingService {
this.environmentServiceList.push(env); this.environmentServiceList.push(env);
} }
// FIXME: max?
this.environmentMaintenceLoopInterval = Math.max( this.environmentMaintenceLoopInterval = Math.max(
...this.environmentServiceList.map((env) => env.environmentMaintenceLoopInterval) ...this.environmentServiceList.map((env) => env.environmentMaintenceLoopInterval)
); );
...@@ -211,6 +210,7 @@ class TrialDispatcher implements TrainingService { ...@@ -211,6 +210,7 @@ class TrialDispatcher implements TrainingService {
} }
public async run(): Promise<void> { public async run(): Promise<void> {
await Promise.all(this.environmentServiceList.map(env => env.init()));
for(const environmentService of this.environmentServiceList) { for(const environmentService of this.environmentServiceList) {
const runnerSettings: RunnerSettings = new RunnerSettings(); const runnerSettings: RunnerSettings = new RunnerSettings();
...@@ -497,9 +497,10 @@ class TrialDispatcher implements TrainingService { ...@@ -497,9 +497,10 @@ class TrialDispatcher implements TrainingService {
liveEnvironmentsCount++; liveEnvironmentsCount++;
if (environment.status === "RUNNING" && environment.isRunnerReady) { if (environment.status === "RUNNING" && environment.isRunnerReady) {
// if environment is not reusable and used, stop and not count as idle; // if environment is not reusable and used, stop and not count as idle;
const reuseMode = Array.isArray(this.config.trainingService) || (this.config.trainingService as any).reuseMode;
if ( if (
0 === environment.runningTrialCount && 0 === environment.runningTrialCount &&
!(this.config as any).reuseMode && reuseMode === false &&
environment.assignedTrialCount > 0 environment.assignedTrialCount > 0
) { ) {
if (environment.environmentService === undefined) { if (environment.environmentService === undefined) {
......
...@@ -237,7 +237,7 @@ class App extends React.Component<{}, AppState> { ...@@ -237,7 +237,7 @@ class App extends React.Component<{}, AppState> {
} }
// experiment status and /trial-jobs api's status could decide website update // experiment status and /trial-jobs api's status could decide website update
if (['DONE', 'ERROR', 'STOPPED'].includes(EXPERIMENT.status) || TRIALS.jobListError()) { if (['DONE', 'ERROR', 'STOPPED', 'VIEWED'].includes(EXPERIMENT.status) || TRIALS.jobListError()) {
// experiment finished, refresh once more to ensure consistency // experiment finished, refresh once more to ensure consistency
this.setState(() => ({ interval: 0, isUpdate: false })); this.setState(() => ({ interval: 0, isUpdate: false }));
return; return;
......
...@@ -54,7 +54,7 @@ class ExperimentSummaryPanel extends React.Component<ExpDrawerProps, ExpDrawerSt ...@@ -54,7 +54,7 @@ class ExperimentSummaryPanel extends React.Component<ExpDrawerProps, ExpDrawerSt
this.setState({ experiment: JSON.stringify(result, null, 4) }); this.setState({ experiment: JSON.stringify(result, null, 4) });
} }
if (['DONE', 'ERROR', 'STOPPED'].includes(EXPERIMENT.status)) { if (['DONE', 'ERROR', 'STOPPED', 'VIEWED'].includes(EXPERIMENT.status)) {
if (this.refreshId !== null || this.refreshId !== undefined) { if (this.refreshId !== null || this.refreshId !== undefined) {
window.clearInterval(this.refreshId); window.clearInterval(this.refreshId);
} }
......
...@@ -30,6 +30,7 @@ export const EditExperimentParam = (): any => { ...@@ -30,6 +30,7 @@ export const EditExperimentParam = (): any => {
const { title, field, editType, maxExecDuration, maxTrialNum, trialConcurrency, updateOverviewPage } = useContext( const { title, field, editType, maxExecDuration, maxTrialNum, trialConcurrency, updateOverviewPage } = useContext(
EditExpeParamContext EditExpeParamContext
); );
const originMaxDurationStr = EXPERIMENT.profile.params.maxExperimentDuration;
const { maxDurationUnit, changeMaxDurationUnit } = useContext(AppContext); const { maxDurationUnit, changeMaxDurationUnit } = useContext(AppContext);
const [unit, setUnit] = useState(maxDurationUnit); const [unit, setUnit] = useState(maxDurationUnit);
let defaultVal = ''; let defaultVal = '';
...@@ -101,13 +102,7 @@ export const EditExperimentParam = (): any => { ...@@ -101,13 +102,7 @@ export const EditExperimentParam = (): any => {
} }
if (isMaxDuration) { if (isMaxDuration) {
const maxDura = JSON.parse(editInputVal); const maxDura = JSON.parse(editInputVal);
if (unit === 'm') { newProfile.params[field] = `${maxDura}${unit}`;
newProfile.params[field] = maxDura * 60;
} else if (unit === 'h') {
newProfile.params[field] = maxDura * 3600;
} else {
newProfile.params[field] = maxDura * 24 * 60 * 60;
}
} else { } else {
newProfile.params[field] = parseInt(editInputVal, 10); newProfile.params[field] = parseInt(editInputVal, 10);
} }
...@@ -118,9 +113,12 @@ export const EditExperimentParam = (): any => { ...@@ -118,9 +113,12 @@ export const EditExperimentParam = (): any => {
params: { update_type: editType } params: { update_type: editType }
}); });
if (res.status === 200) { if (res.status === 200) {
showMessageInfo(`Successfully updated experiment's ${field}`, 'success'); if (isMaxDuration) {
changeMaxDurationUnit(unit); changeMaxDurationUnit(unit);
} }
showMessageInfo(`Successfully updated experiment's ${field}`, 'success');
updateOverviewPage();
}
} catch (error) { } catch (error) {
if (error.response && error.response.data.error) { if (error.response && error.response.data.error) {
showMessageInfo(`Failed to update trial ${field}\n${error.response.data.error}`, 'error'); showMessageInfo(`Failed to update trial ${field}\n${error.response.data.error}`, 'error');
...@@ -132,9 +130,14 @@ export const EditExperimentParam = (): any => { ...@@ -132,9 +130,14 @@ export const EditExperimentParam = (): any => {
showMessageInfo(`Failed to update trial ${field}\nUnknown error`, 'error'); showMessageInfo(`Failed to update trial ${field}\nUnknown error`, 'error');
} }
setEditValInput(defaultVal); setEditValInput(defaultVal);
// confirm trial config panel val
if (isMaxDuration) {
newProfile.params[field] = originMaxDurationStr;
} else {
newProfile.params[field] = beforeParam;
}
} }
showPencil(); showPencil();
updateOverviewPage();
} }
function cancelEdit(): void { function cancelEdit(): void {
...@@ -162,7 +165,7 @@ export const EditExperimentParam = (): any => { ...@@ -162,7 +165,7 @@ export const EditExperimentParam = (): any => {
<EditExpeParamContext.Consumer> <EditExpeParamContext.Consumer>
{(value): React.ReactNode => { {(value): React.ReactNode => {
let editClassName = ''; let editClassName = '';
if (value.field === 'maxExecDuration') { if (value.field === 'maxExperimentDuration') {
editClassName = isShowPencil ? 'noEditDuration' : 'editDuration'; editClassName = isShowPencil ? 'noEditDuration' : 'editDuration';
} }
return ( return (
......
...@@ -50,7 +50,7 @@ export const ExpDuration = (): any => ( ...@@ -50,7 +50,7 @@ export const ExpDuration = (): any => (
<EditExpeParamContext.Provider <EditExpeParamContext.Provider
value={{ value={{
editType: CONTROLTYPE[0], editType: CONTROLTYPE[0],
field: 'maxExecDuration', field: 'maxExperimentDuration',
title: 'Max duration', title: 'Max duration',
maxExecDuration: maxExecDurationStr, maxExecDuration: maxExecDurationStr,
maxTrialNum: EXPERIMENT.maxTrialNumber, maxTrialNum: EXPERIMENT.maxTrialNumber,
......
...@@ -89,7 +89,7 @@ export const TrialCount = (): any => { ...@@ -89,7 +89,7 @@ export const TrialCount = (): any => {
<EditExpeParamContext.Provider <EditExpeParamContext.Provider
value={{ value={{
title: MAX_TRIAL_NUMBERS, title: MAX_TRIAL_NUMBERS,
field: 'maxTrialNum', field: 'maxTrialNumber',
editType: CONTROLTYPE[1], editType: CONTROLTYPE[1],
maxExecDuration: '', maxExecDuration: '',
maxTrialNum: EXPERIMENT.maxTrialNumber, maxTrialNum: EXPERIMENT.maxTrialNumber,
......
...@@ -3,8 +3,7 @@ import { Stack, Panel, PrimaryButton } from '@fluentui/react'; ...@@ -3,8 +3,7 @@ import { Stack, Panel, PrimaryButton } from '@fluentui/react';
import { EXPERIMENT } from '../../static/datamodel'; import { EXPERIMENT } from '../../static/datamodel';
import MonacoEditor from 'react-monaco-editor'; import MonacoEditor from 'react-monaco-editor';
import { MONACO } from '../../static/const'; import { MONACO } from '../../static/const';
import { AppContext } from '../../App'; import { convertDuration, caclMonacoEditorHeight } from '../../static/function';
import { convertDuration, convertTimeAsUnit, caclMonacoEditorHeight } from '../../static/function';
import { prettyStringify } from '../../static/json_util'; import { prettyStringify } from '../../static/json_util';
import lodash from 'lodash'; import lodash from 'lodash';
import '../../static/style/logDrawer.scss'; import '../../static/style/logDrawer.scss';
...@@ -69,14 +68,6 @@ class TrialConfigPanel extends React.Component<LogDrawerProps, LogDrawerState> { ...@@ -69,14 +68,6 @@ class TrialConfigPanel extends React.Component<LogDrawerProps, LogDrawerState> {
const prettyWidth = innerWidth > 1400 ? 100 : 60; const prettyWidth = innerWidth > 1400 ? 100 : 60;
return (
<AppContext.Consumer>
{(value): React.ReactNode => {
const unit = value.maxDurationUnit;
profile.params.maxExecDuration = `${convertTimeAsUnit(
unit,
profile.params.maxExecDuration
)}${unit}`;
const showProfile = JSON.stringify(profile, filter, 2); const showProfile = JSON.stringify(profile, filter, 2);
return ( return (
<Stack> <Stack>
...@@ -117,9 +108,6 @@ class TrialConfigPanel extends React.Component<LogDrawerProps, LogDrawerState> { ...@@ -117,9 +108,6 @@ class TrialConfigPanel extends React.Component<LogDrawerProps, LogDrawerState> {
</Panel> </Panel>
</Stack> </Stack>
); );
}}
</AppContext.Consumer>
);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment