Unverified Commit dc54f4ad authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

AML config v2 (#3552)

parent 7fd07766
...@@ -15,6 +15,7 @@ class AmlConfig(TrainingServiceConfig): ...@@ -15,6 +15,7 @@ class AmlConfig(TrainingServiceConfig):
workspace_name: str workspace_name: str
compute_target: str compute_target: str
docker_image: str = 'msranni/nni:latest' docker_image: str = 'msranni/nni:latest'
max_trial_number_per_gpu: int = 1
_validation_rules = { _validation_rules = {
'platform': lambda value: (value == 'aml', 'cannot be modified') 'platform': lambda value: (value == 'aml', 'cannot be modified')
......
...@@ -134,7 +134,7 @@ def to_v2(v1) -> ExperimentConfig: ...@@ -134,7 +134,7 @@ def to_v2(v1) -> ExperimentConfig:
_move_field(aml_config, ts, 'resourceGroup', 'resource_group') _move_field(aml_config, ts, 'resourceGroup', 'resource_group')
_move_field(aml_config, ts, 'workspaceName', 'workspace_name') _move_field(aml_config, ts, 'workspaceName', 'workspace_name')
_move_field(aml_config, ts, 'computeTarget', 'compute_target') _move_field(aml_config, ts, 'computeTarget', 'compute_target')
_deprecate(aml_config, v2, 'maxTrialNumPerGpu') _move_field(aml_config, ts, 'maxTrialNumPerGpu', 'max_trial_number_per_gpu')
_deprecate(aml_config, v2, 'useActiveGpu') _deprecate(aml_config, v2, 'useActiveGpu')
assert not aml_config, aml_config assert not aml_config, aml_config
......
...@@ -65,6 +65,7 @@ export interface AmlConfig extends TrainingServiceConfig { ...@@ -65,6 +65,7 @@ export interface AmlConfig extends TrainingServiceConfig {
workspaceName: string; workspaceName: string;
computeTarget: string; computeTarget: string;
dockerImage: string; dockerImage: string;
maxTrialNumberPerGpu: number;
} }
/* Kubeflow */ /* Kubeflow */
......
...@@ -9,15 +9,16 @@ import * as component from '../../../common/component'; ...@@ -9,15 +9,16 @@ import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo'; import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log'; import { getLogger, Logger } from '../../../common/log';
import { getExperimentRootDir } from '../../../common/utils'; import { getExperimentRootDir } from '../../../common/utils';
import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey'; import { ExperimentConfig, AmlConfig, flattenConfig } from '../../../common/experimentConfig';
import { validateCodeDir } from '../../common/util'; import { validateCodeDir } from '../../common/util';
import { AMLClient } from '../aml/amlClient'; import { AMLClient } from '../aml/amlClient';
import { AMLClusterConfig, AMLEnvironmentInformation, AMLTrialConfig } from '../aml/amlConfig'; import { AMLEnvironmentInformation } from '../aml/amlConfig';
import { EnvironmentInformation, EnvironmentService } from '../environment'; import { EnvironmentInformation, EnvironmentService } from '../environment';
import { EventEmitter } from "events"; import { EventEmitter } from "events";
import { AMLCommandChannel } from '../channels/amlCommandChannel'; import { AMLCommandChannel } from '../channels/amlCommandChannel';
import { SharedStorageService } from '../sharedStorage' import { SharedStorageService } from '../sharedStorage'
interface FlattenAmlConfig extends ExperimentConfig, AmlConfig { }
/** /**
* Collector AML jobs info from AML cluster, and update aml job status locally * Collector AML jobs info from AML cluster, and update aml job status locally
...@@ -26,15 +27,16 @@ import { SharedStorageService } from '../sharedStorage' ...@@ -26,15 +27,16 @@ import { SharedStorageService } from '../sharedStorage'
export class AMLEnvironmentService extends EnvironmentService { export class AMLEnvironmentService extends EnvironmentService {
private readonly log: Logger = getLogger(); private readonly log: Logger = getLogger();
public amlClusterConfig: AMLClusterConfig | undefined;
public amlTrialConfig: AMLTrialConfig | undefined;
private experimentId: string; private experimentId: string;
private experimentRootDir: string; private experimentRootDir: string;
private config: FlattenAmlConfig;
constructor() { constructor(config: ExperimentConfig) {
super(); super();
this.experimentId = getExperimentId(); this.experimentId = getExperimentId();
this.experimentRootDir = getExperimentRootDir(); this.experimentRootDir = getExperimentRootDir();
this.config = flattenConfig(config, 'aml');
validateCodeDir(this.config.trialCodeDirectory);
} }
public get hasStorageService(): boolean { public get hasStorageService(): boolean {
...@@ -53,27 +55,6 @@ export class AMLEnvironmentService extends EnvironmentService { ...@@ -53,27 +55,6 @@ export class AMLEnvironmentService extends EnvironmentService {
return 'aml'; return 'aml';
} }
public async config(key: string, value: string): Promise<void> {
switch (key) {
case TrialConfigMetadataKey.AML_CLUSTER_CONFIG:
this.amlClusterConfig = <AMLClusterConfig>JSON.parse(value);
break;
case TrialConfigMetadataKey.TRIAL_CONFIG: {
if (this.amlClusterConfig === undefined) {
this.log.error('aml cluster config is not initialized');
break;
}
this.amlTrialConfig = <AMLTrialConfig>JSON.parse(value);
// Validate to make sure codeDir doesn't have too many files
await validateCodeDir(this.amlTrialConfig.codeDir);
break;
}
default:
this.log.debug(`AML not proccessed metadata key: '${key}', value: '${value}'`);
}
}
public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> { public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
environments.forEach(async (environment) => { environments.forEach(async (environment) => {
const amlClient = (environment as AMLEnvironmentInformation).amlClient; const amlClient = (environment as AMLEnvironmentInformation).amlClient;
...@@ -107,12 +88,6 @@ export class AMLEnvironmentService extends EnvironmentService { ...@@ -107,12 +88,6 @@ export class AMLEnvironmentService extends EnvironmentService {
} }
public async startEnvironment(environment: EnvironmentInformation): Promise<void> { public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
if (this.amlClusterConfig === undefined) {
throw new Error('AML Cluster config is not initialized');
}
if (this.amlTrialConfig === undefined) {
throw new Error('AML trial config is not initialized');
}
const amlEnvironment: AMLEnvironmentInformation = environment as AMLEnvironmentInformation; const amlEnvironment: AMLEnvironmentInformation = environment as AMLEnvironmentInformation;
const environmentLocalTempFolder = path.join(this.experimentRootDir, "environment-temp"); const environmentLocalTempFolder = path.join(this.experimentRootDir, "environment-temp");
if (!fs.existsSync(environmentLocalTempFolder)) { if (!fs.existsSync(environmentLocalTempFolder)) {
...@@ -126,22 +101,24 @@ export class AMLEnvironmentService extends EnvironmentService { ...@@ -126,22 +101,24 @@ export class AMLEnvironmentService extends EnvironmentService {
amlEnvironment.command = `mv envs outputs/envs && cd outputs && ${amlEnvironment.command}`; amlEnvironment.command = `mv envs outputs/envs && cd outputs && ${amlEnvironment.command}`;
} }
amlEnvironment.command = `import os\nos.system('${amlEnvironment.command}')`; amlEnvironment.command = `import os\nos.system('${amlEnvironment.command}')`;
amlEnvironment.useActiveGpu = this.amlClusterConfig.useActiveGpu; amlEnvironment.useActiveGpu = !!this.config.deprecated.useActiveGpu;
amlEnvironment.maxTrialNumberPerGpu = this.amlClusterConfig.maxTrialNumPerGpu; amlEnvironment.maxTrialNumberPerGpu = this.config.maxTrialNumberPerGpu;
await fs.promises.writeFile(path.join(environmentLocalTempFolder, 'nni_script.py'), amlEnvironment.command, { encoding: 'utf8' }); await fs.promises.writeFile(path.join(environmentLocalTempFolder, 'nni_script.py'), amlEnvironment.command, { encoding: 'utf8' });
const amlClient = new AMLClient( const amlClient = new AMLClient(
this.amlClusterConfig.subscriptionId, this.config.subscriptionId,
this.amlClusterConfig.resourceGroup, this.config.resourceGroup,
this.amlClusterConfig.workspaceName, this.config.workspaceName,
this.experimentId, this.experimentId,
this.amlClusterConfig.computeTarget, this.config.computeTarget,
this.amlTrialConfig.image, this.config.dockerImage,
'nni_script.py', 'nni_script.py',
environmentLocalTempFolder environmentLocalTempFolder
); );
amlEnvironment.id = await amlClient.submit(); amlEnvironment.id = await amlClient.submit();
this.log.debug('aml: before getTrackingUrl');
amlEnvironment.trackingUrl = await amlClient.getTrackingUrl(); amlEnvironment.trackingUrl = await amlClient.getTrackingUrl();
this.log.debug('aml: after getTrackingUrl');
amlEnvironment.amlClient = amlClient; amlEnvironment.amlClient = amlClient;
} }
......
...@@ -13,7 +13,7 @@ export class EnvironmentServiceFactory { ...@@ -13,7 +13,7 @@ export class EnvironmentServiceFactory {
case 'remote': case 'remote':
return new RemoteEnvironmentService(config); return new RemoteEnvironmentService(config);
case 'aml': case 'aml':
return new AMLEnvironmentService(); return new AMLEnvironmentService(config);
case 'openpai': case 'openpai':
return new OpenPaiEnvironmentService(config); return new OpenPaiEnvironmentService(config);
default: default:
......
...@@ -500,7 +500,7 @@ class TrialDispatcher implements TrainingService { ...@@ -500,7 +500,7 @@ class TrialDispatcher implements TrainingService {
const reuseMode = Array.isArray(this.config.trainingService) || (this.config.trainingService as any).reuseMode; const reuseMode = Array.isArray(this.config.trainingService) || (this.config.trainingService as any).reuseMode;
if ( if (
0 === environment.runningTrialCount && 0 === environment.runningTrialCount &&
!reuseMode && reuseMode === false &&
environment.assignedTrialCount > 0 environment.assignedTrialCount > 0
) { ) {
if (environment.environmentService === undefined) { if (environment.environmentService === undefined) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment