Unverified Commit 93f96d4f authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Support aml (#2615)

parent f5caa193
...@@ -9,7 +9,7 @@ import * as path from 'path'; ...@@ -9,7 +9,7 @@ import * as path from 'path';
import { Writable } from 'stream'; import { Writable } from 'stream';
import { String } from 'typescript-string-operations'; import { String } from 'typescript-string-operations';
import * as component from '../../common/component'; import * as component from '../../common/component';
import { getExperimentId, getPlatform, getBasePort } from '../../common/experimentStartupInfo'; import { getBasePort, getExperimentId, getPlatform } from '../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../common/log'; import { getLogger, Logger } from '../../common/log';
import { NNIManagerIpConfig, TrainingService, TrialJobApplicationForm, TrialJobMetric, TrialJobStatus } from '../../common/trainingService'; import { NNIManagerIpConfig, TrainingService, TrialJobApplicationForm, TrialJobMetric, TrialJobStatus } from '../../common/trainingService';
import { delay, getExperimentRootDir, getLogLevel, getVersion, mkDirPSync, uniqueString } from '../../common/utils'; import { delay, getExperimentRootDir, getLogLevel, getVersion, mkDirPSync, uniqueString } from '../../common/utils';
...@@ -19,9 +19,9 @@ import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData'; ...@@ -19,9 +19,9 @@ import { CONTAINER_INSTALL_NNI_SHELL_FORMAT } from '../common/containerJobData';
import { TrialConfig } from '../common/trialConfig'; import { TrialConfig } from '../common/trialConfig';
import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey'; import { TrialConfigMetadataKey } from '../common/trialConfigMetadataKey';
import { validateCodeDir } from '../common/util'; import { validateCodeDir } from '../common/util';
import { WebCommandChannel } from './channels/webCommandChannel';
import { Command, CommandChannel } from './commandChannel'; import { Command, CommandChannel } from './commandChannel';
import { EnvironmentInformation, EnvironmentService, NodeInfomation, RunnerSettings } from './environment'; import { EnvironmentInformation, EnvironmentService, NodeInfomation, RunnerSettings } from './environment';
import { MountedStorageService } from './storages/mountedStorageService';
import { StorageService } from './storageService'; import { StorageService } from './storageService';
import { TrialDetail } from './trial'; import { TrialDetail } from './trial';
...@@ -40,6 +40,7 @@ class TrialDispatcher implements TrainingService { ...@@ -40,6 +40,7 @@ class TrialDispatcher implements TrainingService {
private readonly metricsEmitter: EventEmitter; private readonly metricsEmitter: EventEmitter;
private readonly experimentId: string; private readonly experimentId: string;
private readonly experimentRootDir: string;
private enableVersionCheck: boolean = true; private enableVersionCheck: boolean = true;
...@@ -58,6 +59,8 @@ class TrialDispatcher implements TrainingService { ...@@ -58,6 +59,8 @@ class TrialDispatcher implements TrainingService {
this.environments = new Map<string, EnvironmentInformation>(); this.environments = new Map<string, EnvironmentInformation>();
this.metricsEmitter = new EventEmitter(); this.metricsEmitter = new EventEmitter();
this.experimentId = getExperimentId(); this.experimentId = getExperimentId();
this.experimentRootDir = getExperimentRootDir();
this.runnerSettings = new RunnerSettings(); this.runnerSettings = new RunnerSettings();
this.runnerSettings.experimentId = this.experimentId; this.runnerSettings.experimentId = this.experimentId;
this.runnerSettings.platform = getPlatform(); this.runnerSettings.platform = getPlatform();
...@@ -158,14 +161,14 @@ class TrialDispatcher implements TrainingService { ...@@ -158,14 +161,14 @@ class TrialDispatcher implements TrainingService {
const environmentService = component.get<EnvironmentService>(EnvironmentService); const environmentService = component.get<EnvironmentService>(EnvironmentService);
this.commandEmitter = new EventEmitter(); this.commandEmitter = new EventEmitter();
this.commandChannel = new WebCommandChannel(this.commandEmitter); this.commandChannel = environmentService.getCommandChannel(this.commandEmitter);
// TODO it's a hard code of web channel, it needs to be improved. // TODO it's a hard code of web channel, it needs to be improved.
this.runnerSettings.nniManagerPort = getBasePort() + 1; this.runnerSettings.nniManagerPort = getBasePort() + 1;
this.runnerSettings.commandChannel = this.commandChannel.channelName; this.runnerSettings.commandChannel = this.commandChannel.channelName;
// for AML channel, other channels can ignore this. // for AML channel, other channels can ignore this.
this.commandChannel.config("MetricEmitter", this.metricsEmitter); await this.commandChannel.config("MetricEmitter", this.metricsEmitter);
// start channel // start channel
this.commandEmitter.on("command", (command: Command): void => { this.commandEmitter.on("command", (command: Command): void => {
...@@ -173,41 +176,50 @@ class TrialDispatcher implements TrainingService { ...@@ -173,41 +176,50 @@ class TrialDispatcher implements TrainingService {
this.log.error(`TrialDispatcher: error on handle env ${command.environment.id} command: ${command.command}, data: ${command.data}, error: ${err}`); this.log.error(`TrialDispatcher: error on handle env ${command.environment.id} command: ${command.command}, data: ${command.data}, error: ${err}`);
}) })
}); });
this.commandChannel.start(); await this.commandChannel.start();
this.log.info(`TrialDispatcher: started channel: ${this.commandChannel.constructor.name}`); this.log.info(`TrialDispatcher: started channel: ${this.commandChannel.constructor.name}`);
if (this.trialConfig === undefined) { if (this.trialConfig === undefined) {
throw new Error(`trial config shouldn't be undefined in run()`); throw new Error(`trial config shouldn't be undefined in run()`);
} }
this.log.info(`TrialDispatcher: copying code and settings.`);
let storageService: StorageService;
if (environmentService.hasStorageService) { if (environmentService.hasStorageService) {
this.log.info(`TrialDispatcher: copying code and settings.`); this.log.debug(`TrialDispatcher: use existing storage service.`);
const storageService = component.get<StorageService>(StorageService); storageService = component.get<StorageService>(StorageService);
// Copy the compressed file to remoteDirectory and delete it } else {
const codeDir = path.resolve(this.trialConfig.codeDir); this.log.debug(`TrialDispatcher: create temp storage service to temp folder.`);
const envDir = storageService.joinPath("envs"); storageService = new MountedStorageService();
const codeFileName = await storageService.copyDirectory(codeDir, envDir, true); const environmentLocalTempFolder = path.join(this.experimentRootDir, this.experimentId, "environment-temp");
storageService.rename(codeFileName, "nni-code.tar.gz"); storageService.initialize(this.trialConfig.codeDir, environmentLocalTempFolder);
}
const installFileName = storageService.joinPath(envDir, 'install_nni.sh');
await storageService.save(CONTAINER_INSTALL_NNI_SHELL_FORMAT, installFileName); // Copy the compressed file to remoteDirectory and delete it
const codeDir = path.resolve(this.trialConfig.codeDir);
const runnerSettings = storageService.joinPath(envDir, "settings.json"); const envDir = storageService.joinPath("envs");
await storageService.save(JSON.stringify(this.runnerSettings), runnerSettings); const codeFileName = await storageService.copyDirectory(codeDir, envDir, true);
storageService.rename(codeFileName, "nni-code.tar.gz");
if (this.isDeveloping) {
let trialToolsPath = path.join(__dirname, "../../../../../tools/nni_trial_tool"); const installFileName = storageService.joinPath(envDir, 'install_nni.sh');
if (false === fs.existsSync(trialToolsPath)) { await storageService.save(CONTAINER_INSTALL_NNI_SHELL_FORMAT, installFileName);
trialToolsPath = path.join(__dirname, "..\\..\\..\\..\\..\\tools\\nni_trial_tool");
} const runnerSettings = storageService.joinPath(envDir, "settings.json");
await storageService.copyDirectory(trialToolsPath, envDir, true); await storageService.save(JSON.stringify(this.runnerSettings), runnerSettings);
if (this.isDeveloping) {
let trialToolsPath = path.join(__dirname, "../../../../../tools/nni_trial_tool");
if (false === fs.existsSync(trialToolsPath)) {
trialToolsPath = path.join(__dirname, "..\\..\\..\\..\\..\\tools\\nni_trial_tool");
} }
await storageService.copyDirectory(trialToolsPath, envDir, true);
} }
this.log.info(`TrialDispatcher: run loop started.`); this.log.info(`TrialDispatcher: run loop started.`);
await Promise.all([ await Promise.all([
this.environmentMaintenanceLoop(), this.environmentMaintenanceLoop(),
this.trialManagementLoop(), this.trialManagementLoop(),
this.commandChannel.run(),
]); ]);
} }
...@@ -274,7 +286,7 @@ class TrialDispatcher implements TrainingService { ...@@ -274,7 +286,7 @@ class TrialDispatcher implements TrainingService {
} }
this.commandEmitter.off("command", this.handleCommand); this.commandEmitter.off("command", this.handleCommand);
this.commandChannel.stop(); await this.commandChannel.stop();
} }
private async environmentMaintenanceLoop(): Promise<void> { private async environmentMaintenanceLoop(): Promise<void> {
...@@ -396,7 +408,6 @@ class TrialDispatcher implements TrainingService { ...@@ -396,7 +408,6 @@ class TrialDispatcher implements TrainingService {
break; break;
} }
} }
let liveEnvironmentsCount = 0; let liveEnvironmentsCount = 0;
const idleEnvironments: EnvironmentInformation[] = []; const idleEnvironments: EnvironmentInformation[] = [];
this.environments.forEach((environment) => { this.environments.forEach((environment) => {
...@@ -407,7 +418,6 @@ class TrialDispatcher implements TrainingService { ...@@ -407,7 +418,6 @@ class TrialDispatcher implements TrainingService {
} }
} }
}); });
while (idleEnvironments.length > 0 && waitingTrials.length > 0) { while (idleEnvironments.length > 0 && waitingTrials.length > 0) {
const trial = waitingTrials.shift(); const trial = waitingTrials.shift();
const idleEnvironment = idleEnvironments.shift(); const idleEnvironment = idleEnvironments.shift();
...@@ -442,14 +452,10 @@ class TrialDispatcher implements TrainingService { ...@@ -442,14 +452,10 @@ class TrialDispatcher implements TrainingService {
environment.command = "[ -d \"nni_trial_tool\" ] && echo \"nni_trial_tool exists already\" || (mkdir ./nni_trial_tool && tar -xof ../nni_trial_tool.tar.gz -C ./nni_trial_tool) && pip3 install websockets && " + environment.command; environment.command = "[ -d \"nni_trial_tool\" ] && echo \"nni_trial_tool exists already\" || (mkdir ./nni_trial_tool && tar -xof ../nni_trial_tool.tar.gz -C ./nni_trial_tool) && pip3 install websockets && " + environment.command;
} }
if (environmentService.hasStorageService) { environment.command = `mkdir -p envs/${envId} && cd envs/${envId} && ${environment.command}`;
const storageService = component.get<StorageService>(StorageService);
environment.workingFolder = storageService.joinPath("envs", envId);
await storageService.createDirectory(environment.workingFolder);
}
this.environments.set(environment.id, environment);
await environmentService.startEnvironment(environment); await environmentService.startEnvironment(environment);
this.environments.set(environment.id, environment);
if (environment.status === "FAILED") { if (environment.status === "FAILED") {
environment.isIdle = false; environment.isIdle = false;
......
...@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None: ...@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None:
from .standalone import * from .standalone import *
elif trial_env_vars.NNI_PLATFORM == 'unittest': elif trial_env_vars.NNI_PLATFORM == 'unittest':
from .test import * from .test import *
elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts'): elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'):
from .local import * from .local import *
else: else:
raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM) raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM)
...@@ -116,7 +116,7 @@ common_schema = { ...@@ -116,7 +116,7 @@ common_schema = {
Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')), Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999), Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
'trainingServicePlatform': setChoice( 'trainingServicePlatform': setChoice(
'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts'), 'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'),
Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'), Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'),
Optional('multiPhase'): setType('multiPhase', bool), Optional('multiPhase'): setType('multiPhase', bool),
Optional('multiThread'): setType('multiThread', bool), Optional('multiThread'): setType('multiThread', bool),
...@@ -234,6 +234,23 @@ dlts_config_schema = { ...@@ -234,6 +234,23 @@ dlts_config_schema = {
} }
} }
# Schema fragment for the 'trial' section of an AML experiment config.
# Merged with common_schema and aml_config_schema to build the 'aml' entry
# of training_service_schema_dict.
aml_trial_schema = {
    'trial':{
        # codeDir goes through setPathCheck; the remaining keys are declared
        # with setType(..., str).
        'codeDir': setPathCheck('codeDir'),
        'command': setType('command', str),
        'image': setType('image', str),
        'computeTarget': setType('computeTarget', str)
    }
}
# Schema fragment for the 'amlConfig' section: the three keys below are all
# declared with setType(..., str).
# NOTE(review): presumably these identify the Azure ML workspace -- confirm
# against the AML environment service that consumes them.
aml_config_schema = {
    'amlConfig': {
        'subscriptionId': setType('subscriptionId', str),
        'resourceGroup': setType('resourceGroup', str),
        'workspaceName': setType('workspaceName', str),
    }
}
kubeflow_trial_schema = { kubeflow_trial_schema = {
'trial':{ 'trial':{
'codeDir': setPathCheck('codeDir'), 'codeDir': setPathCheck('codeDir'),
...@@ -374,6 +391,7 @@ training_service_schema_dict = { ...@@ -374,6 +391,7 @@ training_service_schema_dict = {
'paiYarn': Schema({**common_schema, **pai_yarn_trial_schema, **pai_yarn_config_schema}), 'paiYarn': Schema({**common_schema, **pai_yarn_trial_schema, **pai_yarn_config_schema}),
'kubeflow': Schema({**common_schema, **kubeflow_trial_schema, **kubeflow_config_schema}), 'kubeflow': Schema({**common_schema, **kubeflow_trial_schema, **kubeflow_config_schema}),
'frameworkcontroller': Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema}), 'frameworkcontroller': Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema}),
'aml': Schema({**common_schema, **aml_trial_schema, **aml_config_schema}),
'dlts': Schema({**common_schema, **dlts_trial_schema, **dlts_config_schema}), 'dlts': Schema({**common_schema, **dlts_trial_schema, **dlts_config_schema}),
} }
......
...@@ -272,6 +272,25 @@ def set_dlts_config(experiment_config, port, config_file_name): ...@@ -272,6 +272,25 @@ def set_dlts_config(experiment_config, port, config_file_name):
#set trial_config #set trial_config
return set_trial_config(experiment_config, port, config_file_name), err_message return set_trial_config(experiment_config, port, config_file_name), err_message
def set_aml_config(experiment_config, port, config_file_name):
    '''Set aml configuration.

    Sends experiment_config['amlConfig'] (wrapped under the 'aml_config' key)
    to cluster_metadata_url(port) via rest_put, then sets the NNI manager IP
    and finally the trial config.

    Parameters:
        experiment_config: experiment configuration; must contain 'amlConfig'.
        port: port passed to cluster_metadata_url and the follow-up setters.
        config_file_name: used to look up the stderr log path via get_log_path.

    Returns:
        A (result, message) tuple.
        - (False, err_message) when the cluster metadata request fails;
          err_message is the response text, or None when there was no
          response at all.
        - The (result, message) pair from setNNIManagerIp when that step
          reports failure.
        - (set_trial_config(...), None) otherwise.
    '''
    aml_config_data = {'aml_config': experiment_config['amlConfig']}
    response = rest_put(cluster_metadata_url(port), json.dumps(aml_config_data), REST_TIME_OUT)
    if not response or response.status_code != 200:
        err_message = None
        # Only a real response carries text worth logging; with no response
        # there is nothing to parse or write.
        if response is not None:
            err_message = response.text
            _, stderr_full_path = get_log_path(config_file_name)
            try:
                # Pretty-print JSON error bodies for readability in the log.
                log_text = json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':'))
            except ValueError:
                # The error body is not valid JSON: log it verbatim rather
                # than masking the real error with a decode exception.
                log_text = err_message
            with open(stderr_full_path, 'a+') as fout:
                fout.write(log_text)
        return False, err_message
    result, message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not result:
        return result, message
    #set trial_config
    return set_trial_config(experiment_config, port, config_file_name), None
def set_experiment(experiment_config, mode, port, config_file_name): def set_experiment(experiment_config, mode, port, config_file_name):
'''Call startExperiment (rest POST /experiment) with yaml file content''' '''Call startExperiment (rest POST /experiment) with yaml file content'''
request_data = dict() request_data = dict()
...@@ -374,6 +393,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res ...@@ -374,6 +393,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
config_result, err_msg = set_frameworkcontroller_config(experiment_config, port, config_file_name) config_result, err_msg = set_frameworkcontroller_config(experiment_config, port, config_file_name)
elif platform == 'dlts': elif platform == 'dlts':
config_result, err_msg = set_dlts_config(experiment_config, port, config_file_name) config_result, err_msg = set_dlts_config(experiment_config, port, config_file_name)
elif platform == 'aml':
config_result, err_msg = set_aml_config(experiment_config, port, config_file_name)
else: else:
raise Exception(ERROR_INFO % 'Unsupported platform!') raise Exception(ERROR_INFO % 'Unsupported platform!')
exit(1) exit(1)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from azureml.core.run import Run # pylint: disable=import-error
from .base_channel import BaseChannel
from .log_utils import LogType, nni_log
class AMLChannel(BaseChannel):
    """Command channel that exchanges messages through the run's metrics.

    Outgoing messages are logged under the 'trial_runner' metric of the run
    obtained from Run.get_context(); incoming messages are read from the
    'nni_manager' metric of the same run.
    """

    def __init__(self, args):
        self.args = args
        self.run = Run.get_context()
        super(AMLChannel, self).__init__(args)
        # Index of the last 'nni_manager' entry already handed to the caller;
        # -1 means nothing has been consumed yet.
        self.current_message_index = -1

    def _inner_open(self):
        # Nothing to open for this channel.
        pass

    def _inner_close(self):
        # Nothing to close for this channel.
        pass

    def _inner_send(self, message):
        """Decode `message` as utf8 and log it under 'trial_runner'.

        Any exception is reported through nni_log and not re-raised.
        """
        try:
            self.run.log('trial_runner', message.decode('utf8'))
        except Exception as exception:
            nni_log(LogType.Error, 'meet unhandled exception when send message: %s' % exception)

    def _inner_receive(self):
        """Return the not-yet-seen 'nni_manager' entries, utf8-encoded."""
        metrics = self.run.get_metrics()
        if 'nni_manager' not in metrics:
            return []
        incoming = metrics['nni_manager']
        if not incoming:
            return []

        fresh = []
        if type(incoming) is list:
            # Several entries: take everything after the last consumed index.
            last_index = len(incoming) - 1
            if self.current_message_index < last_index:
                fresh = incoming[self.current_message_index + 1:]
                self.current_message_index = last_index
        elif self.current_message_index == -1:
            # Non-list value (presumably a single logged entry -- verify
            # against the Run.get_metrics contract): deliver it once.
            fresh = [incoming]
            self.current_message_index += 1

        # receive message is string, to get consistent result, encode it here.
        return [item.encode('utf8') for item in fresh]
...@@ -210,6 +210,9 @@ if __name__ == '__main__': ...@@ -210,6 +210,9 @@ if __name__ == '__main__':
command_channel = None command_channel = None
if args.command_channel == "file": if args.command_channel == "file":
command_channel = FileChannel(args) command_channel = FileChannel(args)
elif args.command_channel == 'aml':
from .aml_channel import AMLChannel
command_channel = AMLChannel(args)
else: else:
command_channel = WebChannel(args) command_channel = WebChannel(args)
command_channel.open() command_channel.open()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment