Commit 93f96d4f authored by SparkSnail, committed by GitHub

Support aml (#2615)

parent f5caa193
**Run an Experiment on Azure Machine Learning**
===
NNI supports running experiments on [AML](https://azure.microsoft.com/en-us/services/machine-learning/) (Azure Machine Learning); this is called aml mode.
## Setup environment
Step 1. Install NNI by following the install guide [here](../Tutorial/QuickStart.md).
Step 2. Create an AML account by following the document [here](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-manage-workspace-cli).
Step 3. Get your account information.
![](../../img/aml_account.png)
Step 4. Install the AML Python packages.
```
python3 -m pip install azureml --user
python3 -m pip install azureml-sdk --user
```
## Run an experiment
Take `examples/trials/mnist-tfv1` as an example. The NNI config YAML file looks like this:
```yaml
authorName: default
experimentName: example_mnist
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 10
trainingServicePlatform: aml
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
  #SMAC (SMAC should be installed through nnictl)
  builtinTunerName: TPE
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
trial:
  command: python3 mnist.py
  codeDir: .
  computeTarget: ${replace_to_your_computeTarget}
  image: msranni/nni
amlConfig:
  subscriptionId: ${replace_to_your_subscriptionId}
  resourceGroup: ${replace_to_your_resourceGroup}
  workspaceName: ${replace_to_your_workspaceName}
```
Note: Set `trainingServicePlatform: aml` in the NNI config YAML file to start an experiment in aml mode.
Compared with [LocalMode](LocalMode.md), the trial configuration in aml mode has these additional keys:
* computeTarget
    * required key. The compute cluster name you want to use in your AML workspace.
* image
    * required key. The docker image name used in the job.

amlConfig:
* subscriptionId
    * the subscriptionId of your account
* resourceGroup
    * the resourceGroup of your account
* workspaceName
    * the workspaceName of your account
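
The keys listed above can be checked before launching an experiment. The following is a minimal sketch, not part of NNI: the helper `missing_aml_keys` and the two constant tuples are hypothetical names, the config is passed as an already-parsed dictionary, and treating all three `amlConfig` keys as mandatory is an assumption made for illustration.

```python
# Hypothetical helper, not an NNI API: report which aml-mode keys are absent.
REQUIRED_TRIAL_KEYS = ("command", "codeDir", "computeTarget", "image")
# Assumption: all three amlConfig keys are treated as mandatory here.
REQUIRED_AML_KEYS = ("subscriptionId", "resourceGroup", "workspaceName")


def missing_aml_keys(config):
    """Return dotted names of keys that aml mode needs but `config` lacks."""
    missing = []
    if config.get("trainingServicePlatform") != "aml":
        missing.append("trainingServicePlatform")
    sections = (("trial", REQUIRED_TRIAL_KEYS), ("amlConfig", REQUIRED_AML_KEYS))
    for section, keys in sections:
        values = config.get(section) or {}
        for key in keys:
            if not values.get(key):
                missing.append(f"{section}.{key}")
    return missing


example = {
    "trainingServicePlatform": "aml",
    "trial": {"command": "python3 mnist.py", "codeDir": ".", "image": "msranni/nni"},
    "amlConfig": {"subscriptionId": "sub-id", "resourceGroup": "rg", "workspaceName": "ws"},
}
print(missing_aml_keys(example))  # ['trial.computeTarget']
```

An empty list means every key named in this document is present; it says nothing about whether the values point at a real workspace or cluster.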
\ No newline at end of file
@@ -10,3 +10,4 @@ Introduction to NNI Training Services
 Kubeflow<./TrainingService/KubeflowMode>
 FrameworkController<./TrainingService/FrameworkControllerMode>
 DLTS<./TrainingService/DLTSMode>
+AML<./TrainingService/AMLMode>
authorName: default
experimentName: example_mnist_pytorch
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 10
trainingServicePlatform: aml
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
  #SMAC (SMAC should be installed through nnictl)
  builtinTunerName: TPE
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
trial:
  command: python3 mnist.py
  codeDir: .
  computeTarget: ${replace_to_your_computeTarget}
  image: msranni/nni
amlConfig:
  subscriptionId: ${replace_to_your_subscriptionId}
  resourceGroup: ${replace_to_your_resourceGroup}
  workspaceName: ${replace_to_your_workspaceName}
authorName: default
experimentName: example_mnist
trialConcurrency: 1
maxExecDuration: 1h
maxTrialNum: 10
trainingServicePlatform: aml
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
  #SMAC (SMAC should be installed through nnictl)
  builtinTunerName: TPE
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
trial:
  command: python3 mnist.py
  codeDir: .
  computeTarget: ${replace_to_your_computeTarget}
  image: msranni/nni
amlConfig:
  subscriptionId: ${replace_to_your_subscriptionId}
  resourceGroup: ${replace_to_your_resourceGroup}
  workspaceName: ${replace_to_your_workspaceName}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import sys
import time
import json
from argparse import ArgumentParser
from azureml.core import Experiment, RunConfiguration, ScriptRunConfig
from azureml.core.compute import ComputeTarget
from azureml.core.run import RUNNING_STATES, RunStatus, Run
from azureml.core import Workspace
from azureml.core.conda_dependencies import CondaDependencies
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('--subscription_id', help='the subscription id of aml')
    parser.add_argument('--resource_group', help='the resource group of aml')
    parser.add_argument('--workspace_name', help='the workspace name of aml')
    parser.add_argument('--compute_target', help='the compute cluster name of aml')
    parser.add_argument('--docker_image', help='the docker image of job')
    parser.add_argument('--experiment_name', help='the experiment name')
    parser.add_argument('--script_dir', help='script directory')
    parser.add_argument('--script_name', help='script name')
    args = parser.parse_args()
    ws = Workspace(args.subscription_id, args.resource_group, args.workspace_name)
    compute_target = ComputeTarget(workspace=ws, name=args.compute_target)
    experiment = Experiment(ws, args.experiment_name)
    run_config = RunConfiguration()
    dependencies = CondaDependencies()
    dependencies.add_pip_package("azureml-sdk")
    dependencies.add_pip_package("azureml")
    run_config.environment.python.conda_dependencies = dependencies
    run_config.environment.docker.enabled = True
    run_config.environment.docker.base_image = args.docker_image
    run_config.target = compute_target
    run_config.node_count = 1
    config = ScriptRunConfig(source_directory=args.script_dir, script=args.script_name, run_config=run_config)
    run = experiment.submit(config)
    print(run.get_details()["runId"])
    while True:
        line = sys.stdin.readline().rstrip()
        if line == 'update_status':
            print('status:' + run.get_status())
        elif line == 'tracking_url':
            print('tracking_url:' + run.get_portal_url())
        elif line == 'stop':
            run.cancel()
            exit(0)
        elif line == 'receive':
            print('receive:' + json.dumps(run.get_metrics()))
        elif line:
            items = line.split(':')
            if items[0] == 'command':
                run.log('nni_manager', line[8:])
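
The loop at the end of this script is a small line protocol over stdin and stdout: each input line is a request, and most requests produce one `tag:payload` reply. The sketch below restates that dispatch as a pure function so it can be exercised without an AML workspace. `FakeRun` and `handle_line` are hypothetical names introduced for illustration; `FakeRun` only imitates the handful of `Run` methods the script calls.

```python
import json


class FakeRun:
    """Hypothetical stand-in for the AML run object; records what was called."""

    def __init__(self):
        self.logged = []
        self.cancelled = False

    def get_status(self):
        return "Running"

    def get_portal_url(self):
        return "https://example.invalid/run/1"

    def get_metrics(self):
        return {"trial_runner": ["hello"]}

    def log(self, name, value):
        self.logged.append((name, value))

    def cancel(self):
        self.cancelled = True


def handle_line(run, line):
    """Return the reply for one stdin line, or None when nothing is printed."""
    if line == "update_status":
        return "status:" + run.get_status()
    if line == "tracking_url":
        return "tracking_url:" + run.get_portal_url()
    if line == "receive":
        return "receive:" + json.dumps(run.get_metrics())
    if line == "stop":
        run.cancel()
        return None
    # Anything else is only acted on when it has the form "command:<payload>".
    prefix, sep, payload = line.partition(":")
    if sep and prefix == "command":
        run.log("nni_manager", payload)
    return None


run = FakeRun()
print(handle_line(run, "update_status"))  # status:Running
handle_line(run, "command:start:trial")
print(run.logged)  # [('nni_manager', 'start:trial')]
```

Using `partition(":")` keeps any further colons inside the payload, which matches the `line[8:]` slice in the script.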
@@ -65,6 +65,10 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN
 Container.bind(TrainingService)
 .to(DLTSTrainingService)
 .scope(Scope.Singleton);
+} else if (platformMode === 'aml') {
+Container.bind(TrainingService)
+.to(RouterTrainingService)
+.scope(Scope.Singleton);
 } else {
 throw new Error(`Error: unsupported mode: ${platformMode}`);
 }
@@ -93,7 +97,7 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN
 function usage(): void {
 console.info('usage: node main.js --port <port> --mode \
-<local/remote/pai/kubeflow/frameworkcontroller/paiYarn> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>');
+<local/remote/pai/kubeflow/frameworkcontroller/paiYarn/aml> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>');
 }
 const strPort: string = parseArg(['--port', '-p']);
@@ -113,7 +117,7 @@ const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : fals
 const port: number = parseInt(strPort, 10);
 const mode: string = parseArg(['--mode', '-m']);
-if (!['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts'].includes(mode)) {
+if (!['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'].includes(mode)) {
 console.log(`FATAL: unknown mode: ${mode}`);
 usage();
 process.exit(1);
......
@@ -19,6 +19,7 @@
 "ignore": "^5.1.4",
 "js-base64": "^2.4.9",
 "kubernetes-client": "^6.5.0",
+"python-shell": "^2.0.1",
 "rx": "^4.1.0",
 "sqlite3": "^4.0.2",
 "ssh2": "^0.6.1",
......
@@ -39,6 +39,8 @@ export namespace ValidationSchemas {
 nniManagerNFSMountPath: joi.string().min(1),
 containerNFSMountPath: joi.string().min(1),
 paiConfigPath: joi.string(),
+computeTarget: joi.string(),
+nodeCount: joi.number(),
 paiStorageConfigName: joi.string().min(1),
 nasMode: joi.string().valid('classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
 portList: joi.array().items(joi.object({
@@ -150,6 +152,11 @@ export namespace ValidationSchemas {
 email: joi.string().min(1),
 password: joi.string().min(1)
 }),
+aml_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
+subscriptionId: joi.string().min(1),
+resourceGroup: joi.string().min(1),
+workspaceName: joi.string().min(1)
+}),
 nni_manager_ip: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
 nniManagerIp: joi.string().min(1)
 })
......
@@ -19,6 +19,7 @@ export enum TrialConfigMetadataKey {
 NNI_MANAGER_IP = 'nni_manager_ip',
 FRAMEWORKCONTROLLER_CLUSTER_CONFIG = 'frameworkcontroller_config',
 DLTS_CLUSTER_CONFIG = 'dlts_config',
+AML_CLUSTER_CONFIG = 'aml_config',
 VERSION_CHECK = 'version_check',
 LOG_COLLECTION = 'log_collection'
 }
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { Deferred } from 'ts-deferred';
import { PythonShell } from 'python-shell';
export class AMLClient {
public subscriptionId: string;
public resourceGroup: string;
public workspaceName: string;
public experimentId: string;
public image: string;
public scriptName: string;
public pythonShellClient: undefined | PythonShell;
public codeDir: string;
public computeTarget: string;
constructor(
subscriptionId: string,
resourceGroup: string,
workspaceName: string,
experimentId: string,
computeTarget: string,
image: string,
scriptName: string,
codeDir: string,
) {
this.subscriptionId = subscriptionId;
this.resourceGroup = resourceGroup;
this.workspaceName = workspaceName;
this.experimentId = experimentId;
this.image = image;
this.scriptName = scriptName;
this.codeDir = codeDir;
this.computeTarget = computeTarget;
}
public submit(): Promise<string> {
const deferred: Deferred<string> = new Deferred<string>();
this.pythonShellClient = new PythonShell('amlUtil.py', {
scriptPath: './config/aml',
pythonOptions: ['-u'], // get print results in real-time
args: [
'--subscription_id', this.subscriptionId,
'--resource_group', this.resourceGroup,
'--workspace_name', this.workspaceName,
'--compute_target', this.computeTarget,
'--docker_image', this.image,
'--experiment_name', `nni_exp_${this.experimentId}`,
'--script_dir', this.codeDir,
'--script_name', this.scriptName
]
});
this.pythonShellClient.on('message', function (envId: any) {
// received a message sent from the Python script (a simple "print" statement)
deferred.resolve(envId);
});
return deferred.promise;
}
public stop(): void {
if (this.pythonShellClient === undefined) {
throw Error('python shell client not initialized!');
}
this.pythonShellClient.send('stop');
}
public getTrackingUrl(): Promise<string> {
const deferred: Deferred<string> = new Deferred<string>();
if (this.pythonShellClient === undefined) {
throw Error('python shell client not initialized!');
}
this.pythonShellClient.send('tracking_url');
let trackingUrl = '';
this.pythonShellClient.on('message', function (status: any) {
const items = status.split(':');
if (items[0] === 'tracking_url') {
// The URL itself contains ':' (as in "https://"), so rejoin with ':' to keep it intact.
trackingUrl = items.slice(1).join(':');
}
deferred.resolve(trackingUrl);
});
return deferred.promise;
}
public updateStatus(oldStatus: string): Promise<string> {
const deferred: Deferred<string> = new Deferred<string>();
if (this.pythonShellClient === undefined) {
throw Error('python shell client not initialized!');
}
let newStatus = oldStatus;
this.pythonShellClient.send('update_status');
this.pythonShellClient.on('message', function (status: any) {
const items = status.split(':');
if (items[0] === 'status') {
// Rejoin with ':' so a payload containing colons is not altered.
newStatus = items.slice(1).join(':');
}
deferred.resolve(newStatus);
});
return deferred.promise;
}
public sendCommand(message: string): void {
if (this.pythonShellClient === undefined) {
throw Error('python shell client not initialized!');
}
this.pythonShellClient.send(`command:${message}`);
}
public receiveCommand(): Promise<any> {
const deferred: Deferred<any> = new Deferred<any>();
if (this.pythonShellClient === undefined) {
throw Error('python shell client not initialized!');
}
this.pythonShellClient.send('receive');
this.pythonShellClient.on('message', function (command: any) {
const items = command.split(':')
if (items[0] === 'receive') {
deferred.resolve(JSON.parse(command.slice(8)))
}
});
return deferred.promise;
}
}
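
Every reply the client listens for has the shape `tag:payload`, and the payload can itself contain colons (a portal URL does, and so does JSON). A parser should therefore treat only the first colon as the separator. The sketch below shows that rule in Python; `parse_reply` is a hypothetical name, and the URL is a placeholder rather than a real AML portal address.

```python
import json


def parse_reply(message, expected_tag):
    """Return the payload of a `tag:payload` line, or None if the tag differs."""
    tag, sep, payload = message.partition(":")
    if not sep or tag != expected_tag:
        return None
    return payload


# Only the first ':' separates tag from payload, so the URL survives unchanged.
url = parse_reply("tracking_url:https://example.invalid/runs/1", "tracking_url")
print(url)  # https://example.invalid/runs/1

# The 'receive' payload is JSON and may contain ':' as well.
metrics = json.loads(parse_reply('receive:{"trial_runner": "hi"}', "receive"))
print(metrics["trial_runner"])  # hi
```

Returning `None` for a non-matching tag lets a caller ignore replies that belong to a different request instead of resolving with a default value.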
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { TrialConfig } from '../../common/trialConfig';
import { EnvironmentInformation } from '../environment';
import { AMLClient } from '../aml/amlClient';
export class AMLClusterConfig {
public readonly subscriptionId: string;
public readonly resourceGroup: string;
public readonly workspaceName: string;
constructor(subscriptionId: string, resourceGroup: string, workspaceName: string) {
this.subscriptionId = subscriptionId;
this.resourceGroup = resourceGroup;
this.workspaceName = workspaceName;
}
}
export class AMLTrialConfig extends TrialConfig {
public readonly image: string;
public readonly command: string;
public readonly codeDir: string;
public readonly computeTarget: string;
constructor(codeDir: string, command: string, image: string, computeTarget: string) {
super("", codeDir, 0);
this.codeDir = codeDir;
this.command = command;
this.image = image;
this.computeTarget = computeTarget;
}
}
export class AMLEnvironmentInformation extends EnvironmentInformation {
public amlClient?: AMLClient;
}
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { EventEmitter } from 'events';
import { delay } from "../../../common/utils";
import { AMLEnvironmentInformation } from '../aml/amlConfig';
import { CommandChannel, RunnerConnection } from "../commandChannel";
import { Channel, EnvironmentInformation } from "../environment";
class AMLRunnerConnection extends RunnerConnection {
}
export class AMLCommandChannel extends CommandChannel {
private stopping: boolean = false;
private currentMessageIndex: number = -1;
private sendQueues: [EnvironmentInformation, string][] = [];
private readonly NNI_METRICS_PATTERN: string = `NNISDK_MEb'(?<metrics>.*?)'`;
public constructor(commandEmitter: EventEmitter) {
super(commandEmitter);
}
public get channelName(): Channel {
return "aml";
}
public async config(_key: string, _value: any): Promise<void> {
// do nothing
}
public async start(): Promise<void> {
// do nothing
}
public async stop(): Promise<void> {
this.stopping = true;
}
public async run(): Promise<void> {
// start command loops
await Promise.all([
this.receiveLoop(),
this.sendLoop()
]);
}
protected async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise<void> {
this.sendQueues.push([environment, message]);
}
protected createRunnerConnection(environment: EnvironmentInformation): RunnerConnection {
return new AMLRunnerConnection(environment);
}
private async sendLoop(): Promise<void> {
const intervalSeconds = 0.5;
while (!this.stopping) {
const start = new Date();
if (this.sendQueues.length > 0) {
while (this.sendQueues.length > 0) {
const item = this.sendQueues.shift();
if (item === undefined) {
break;
}
const environment = item[0];
const message = item[1];
const amlClient = (environment as AMLEnvironmentInformation).amlClient;
if (!amlClient) {
throw new Error('aml client not initialized!');
}
amlClient.sendCommand(message);
}
}
const end = new Date();
const delayMs = intervalSeconds * 1000 - (end.valueOf() - start.valueOf());
if (delayMs > 0) {
await delay(delayMs);
}
}
}
private async receiveLoop(): Promise<void> {
const intervalSeconds = 2;
while (!this.stopping) {
const start = new Date();
const runnerConnections = [...this.runnerConnections.values()] as AMLRunnerConnection[];
for (const runnerConnection of runnerConnections) {
// to loop all commands
const amlClient = (runnerConnection.environment as AMLEnvironmentInformation).amlClient;
if (!amlClient) {
throw new Error('AML client not initialized!');
}
const command = await amlClient.receiveCommand();
if (command && Object.prototype.hasOwnProperty.call(command, "trial_runner")) {
const messages = command['trial_runner'];
if (messages) {
if (messages instanceof Object && this.currentMessageIndex < messages.length - 1) {
for (let index = this.currentMessageIndex + 1; index < messages.length; index ++) {
this.handleCommand(runnerConnection.environment, messages[index]);
}
this.currentMessageIndex = messages.length - 1;
} else if (this.currentMessageIndex === -1){
this.handleCommand(runnerConnection.environment, messages);
this.currentMessageIndex += 1;
}
}
}
}
const end = new Date();
const delayMs = intervalSeconds * 1000 - (end.valueOf() - start.valueOf());
if (delayMs > 0) {
await delay(delayMs);
}
}
}
}
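
The receive loop above polls the whole `trial_runner` metric on every pass and uses `currentMessageIndex` to skip messages it has already handled. The code treats the metric as either a single value or an array. The sketch below isolates that bookkeeping as a pure function; `new_messages` is a hypothetical name, and the simplification of ignoring an empty array is an assumption of this sketch.

```python
def new_messages(metric, last_index):
    """Return (unseen messages, updated last_index) for one poll.

    `metric` is the polled value: a list once several messages were logged,
    a single value when only one was logged, or None when nothing was logged.
    `last_index` is the index of the last handled message, -1 initially.
    """
    if isinstance(metric, list):
        fresh = metric[last_index + 1:]
        if fresh:
            last_index = len(metric) - 1
        return fresh, last_index
    # A single value is handled once, and only if nothing was handled before.
    if metric is not None and last_index == -1:
        return [metric], 0
    return [], last_index


index = -1
for polled in ("a", ["a", "b"], ["a", "b"], ["a", "b", "c"]):
    fresh, index = new_messages(polled, index)
    print(fresh, index)
```

Each message is delivered exactly once across the four polls, even though every poll returns the full history.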
@@ -6,7 +6,7 @@
 import * as component from "../../../common/component";
 import { delay } from "../../../common/utils";
 import { CommandChannel, RunnerConnection } from "../commandChannel";
-import { EnvironmentInformation, Channel } from "../environment";
+import { Channel, EnvironmentInformation } from "../environment";
 import { StorageService } from "../storageService";
 class FileHandler {
@@ -38,15 +38,21 @@ export class FileCommandChannel extends CommandChannel {
 }
 public async start(): Promise<void> {
-// start command loops
-this.receiveLoop();
-this.sendLoop();
+// do nothing
 }
 public async stop(): Promise<void> {
 this.stopping = true;
 }
+public async run(): Promise<void> {
+// start command loops
+await Promise.all([
+this.receiveLoop(),
+this.sendLoop()
+]);
+}
 protected async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise<void> {
 this.sendQueues.push([environment, message]);
 }
......
@@ -66,6 +66,10 @@ export class WebCommandChannel extends CommandChannel {
 }
 }
+public async run(): Promise<void>{
+// do nothing
+}
 protected async sendCommandInternal(environment: EnvironmentInformation, message: string): Promise<void> {
 if (this.webSocketServer === undefined) {
 throw new Error(`WebCommandChannel: uninitialized!`)
......
@@ -59,6 +59,9 @@ export abstract class CommandChannel {
 public abstract start(): Promise<void>;
 public abstract stop(): Promise<void>;
+// Pull-based command channels need loop to check messages, the loop should be started with await here.
+public abstract run(): Promise<void>;
 protected abstract sendCommandInternal(environment: EnvironmentInformation, message: string): Promise<void>;
 protected abstract createRunnerConnection(environment: EnvironmentInformation): RunnerConnection;
......
@@ -14,7 +14,6 @@ import { CommandChannel } from "./commandChannel";
 export type EnvironmentStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED';
 export type Channel = "web" | "file" | "aml" | "ut";
 export class EnvironmentInformation {
 private log: Logger;
@@ -65,6 +64,7 @@ export class EnvironmentInformation {
 }
 }
 }
 export abstract class EnvironmentService {
 public abstract get hasStorageService(): boolean;
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as fs from 'fs';
import * as path from 'path';
import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log';
import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey';
import { AMLClusterConfig, AMLTrialConfig } from '../aml/amlConfig';
import { EnvironmentInformation, EnvironmentService } from '../environment';
import { AMLEnvironmentInformation } from '../aml/amlConfig';
import { AMLClient } from '../aml/amlClient';
import {
NNIManagerIpConfig,
} from '../../../common/trainingService';
import { validateCodeDir } from '../../common/util';
import { getExperimentRootDir } from '../../../common/utils';
import { AMLCommandChannel } from '../channels/amlCommandChannel';
import { CommandChannel } from "../commandChannel";
import { EventEmitter } from "events";
/**
 * Submits NNI environments as AML runs, refreshes their status, and stops them
 */
@component.Singleton
export class AMLEnvironmentService extends EnvironmentService {
private readonly log: Logger = getLogger();
public amlClusterConfig: AMLClusterConfig | undefined;
public amlTrialConfig: AMLTrialConfig | undefined;
private amlJobConfig: any;
private stopping: boolean = false;
private versionCheck: boolean = true;
private isMultiPhase: boolean = false;
private nniVersion?: string;
private experimentId: string;
private nniManagerIpConfig?: NNIManagerIpConfig;
private experimentRootDir: string;
constructor() {
super();
this.experimentId = getExperimentId();
this.experimentRootDir = getExperimentRootDir();
}
public get hasStorageService(): boolean {
return false;
}
public getCommandChannel(commandEmitter: EventEmitter): CommandChannel {
return new AMLCommandChannel(commandEmitter);
}
public createEnviornmentInfomation(envId: string, envName: string): EnvironmentInformation {
return new AMLEnvironmentInformation(envId, envName);
}
public async config(key: string, value: string): Promise<void> {
switch (key) {
case TrialConfigMetadataKey.AML_CLUSTER_CONFIG:
this.amlClusterConfig = <AMLClusterConfig>JSON.parse(value);
break;
case TrialConfigMetadataKey.TRIAL_CONFIG: {
if (this.amlClusterConfig === undefined) {
this.log.error('aml cluster config is not initialized');
break;
}
this.amlTrialConfig = <AMLTrialConfig>JSON.parse(value);
// Validate to make sure codeDir doesn't have too many files
await validateCodeDir(this.amlTrialConfig.codeDir);
break;
}
default:
this.log.debug(`AML: unprocessed metadata key: '${key}', value: '${value}'`);
}
}
public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
environments.forEach(async (environment) => {
const amlClient = (environment as AMLEnvironmentInformation).amlClient;
if (!amlClient) {
throw new Error('AML client not initialized!');
}
const status = await amlClient.updateStatus(environment.status);
switch (status.toUpperCase()) {
case 'WAITING':
case 'RUNNING':
case 'QUEUED':
// RUNNING status is set by runner, and ignore waiting status
break;
case 'COMPLETED':
case 'SUCCEEDED':
environment.setFinalStatus('SUCCEEDED');
break;
case 'FAILED':
environment.setFinalStatus('FAILED');
break;
case 'STOPPED':
case 'STOPPING':
environment.setFinalStatus('USER_CANCELED');
break;
default:
environment.setFinalStatus('UNKNOWN');
}
});
}
public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
if (this.amlClusterConfig === undefined) {
throw new Error('AML Cluster config is not initialized');
}
if (this.amlTrialConfig === undefined) {
throw new Error('AML trial config is not initialized');
}
const amlEnvironment: AMLEnvironmentInformation = environment as AMLEnvironmentInformation;
const environmentLocalTempFolder = path.join(this.experimentRootDir, this.experimentId, "environment-temp");
environment.command = `import os\nos.system('${amlEnvironment.command}')`;
await fs.promises.writeFile(path.join(environmentLocalTempFolder, 'nni_script.py'), amlEnvironment.command ,{ encoding: 'utf8' });
const amlClient = new AMLClient(
this.amlClusterConfig.subscriptionId,
this.amlClusterConfig.resourceGroup,
this.amlClusterConfig.workspaceName,
this.experimentId,
this.amlTrialConfig.computeTarget,
this.amlTrialConfig.image,
'nni_script.py',
environmentLocalTempFolder
);
amlEnvironment.id = await amlClient.submit();
amlEnvironment.trackingUrl = await amlClient.getTrackingUrl();
amlEnvironment.amlClient = amlClient;
}
public async stopEnvironment(environment: EnvironmentInformation): Promise<void> {
const amlEnvironment: AMLEnvironmentInformation = environment as AMLEnvironmentInformation;
const amlClient = amlEnvironment.amlClient;
if (!amlClient) {
throw new Error('AML client not initialized!');
}
amlClient.stop();
}
}
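
The `switch` in `refreshEnvironmentsStatus` is a lookup from the status string reported for the AML run to one of NNI's final environment statuses. The table below restates it in Python as a sketch; `to_final_status` and the two constants are hypothetical names, and only the status strings that appear in the code above are listed.

```python
# Statuses that leave the environment untouched (no final status is set).
NON_FINAL = {"WAITING", "RUNNING", "QUEUED"}

# Reported status (upper-cased) -> final NNI environment status.
FINAL_STATUS = {
    "COMPLETED": "SUCCEEDED",
    "SUCCEEDED": "SUCCEEDED",
    "FAILED": "FAILED",
    "STOPPED": "USER_CANCELED",
    "STOPPING": "USER_CANCELED",
}


def to_final_status(aml_status):
    """Return the final status to set, or None to leave the environment as is."""
    status = aml_status.upper()
    if status in NON_FINAL:
        return None
    # As in the switch above, any unlisted status falls through to UNKNOWN.
    return FINAL_STATUS.get(status, "UNKNOWN")


print(to_final_status("Completed"))  # SUCCEEDED
print(to_final_status("Running"))    # None
```

Because unlisted values become a final `UNKNOWN`, any status string not named in the table ends the environment from NNI's point of view.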
@@ -167,8 +167,9 @@ export class OpenPaiEnvironmentService extends EnvironmentService {
 }
 // Step 1. Prepare PAI job configuration
-environment.runnerWorkingFolder = `${this.paiTrialConfig.containerNFSMountPath}/${this.experimentId}/envs/${environment.id}`;
-environment.command = `cd ${environment.runnerWorkingFolder} && ${environment.command}`
+const environmentRoot = `${this.paiTrialConfig.containerNFSMountPath}/${this.experimentId}`;
+environment.runnerWorkingFolder = `${environmentRoot}/envs/${environment.id}`;
+environment.command = `cd ${environmentRoot} && ${environment.command}`
 environment.trackingUrl = `${this.protocol}://${this.paiClusterConfig.host}/job-detail.html?username=${this.paiClusterConfig.userName}&jobName=${environment.jobId}`
 // Step 2. Generate Job Configuration in yaml format
......
@@ -13,6 +13,7 @@ import { PAIClusterConfig } from '../pai/paiConfig';
 import { PAIK8STrainingService } from '../pai/paiK8S/paiK8STrainingService';
 import { EnvironmentService } from './environment';
 import { OpenPaiEnvironmentService } from './environments/openPaiEnvironmentService';
+import { AMLEnvironmentService } from './environments/amlEnvironmentService';
 import { MountedStorageService } from './storages/mountedStorageService';
 import { StorageService } from './storageService';
 import { TrialDispatcher } from './trialDispatcher';
@@ -120,6 +121,25 @@ class RouterTrainingService implements TrainingService {
 }
 await this.internalTrainingService.setClusterMetadata(key, value);
+this.metaDataCache.clear();
+} else if (key === TrialConfigMetadataKey.AML_CLUSTER_CONFIG) {
+this.internalTrainingService = component.get(TrialDispatcher);
+Container.bind(EnvironmentService)
+.to(AMLEnvironmentService)
+.scope(Scope.Singleton);
+for (const [key, value] of this.metaDataCache) {
+if (this.internalTrainingService === undefined) {
+throw new Error("TrainingService is not assigned!");
+}
+await this.internalTrainingService.setClusterMetadata(key, value);
+}
+if (this.internalTrainingService === undefined) {
+throw new Error("TrainingService is not assigned!");
+}
+await this.internalTrainingService.setClusterMetadata(key, value);
 this.metaDataCache.clear();
 } else {
 this.log.debug(`caching metadata key:{} value:{}, as training service is not determined.`);
......