"...composable_kernel_onnxruntime.git" did not exist on "7e9a9d32c7a9259a1bd57b0b461c36d089d26fe8"
Unverified commit 872554f1 authored by SparkSnail, committed by GitHub

Support heterogeneous environment service (#3097)

parent dec91f7e
**Run an Experiment on Heterogeneous Mode**
===========================================
Running NNI in heterogeneous mode means that NNI will run trial jobs on multiple kinds of training platforms. For example, NNI could submit trial jobs to a remote machine and to AML simultaneously.
## Setup environment
NNI supports [local](./LocalMode.md), [remote](./RemoteMachineMode.md), [pai](./PaiMode.md) and [AML](./AMLMode.md) platforms for the heterogeneous training service. Before starting an experiment that uses these modes, users should set up the corresponding environment for each platform. More details about environment setup can be found in the corresponding docs.
## Run an experiment
Use `examples/trials/mnist-tfv1` as an example. The content of the NNI config YAML file is as follows:
.. code-block:: yaml

    authorName: default
    experimentName: example_mnist
    trialConcurrency: 2
    maxExecDuration: 1h
    maxTrialNum: 10
    trainingServicePlatform: heterogeneous
    searchSpacePath: search_space.json
    #choice: true, false
    useAnnotation: false
    tuner:
      builtinTunerName: TPE
      classArgs:
        #choice: maximize, minimize
        optimize_mode: maximize
    trial:
      command: python3 mnist.py
      codeDir: .
      gpuNum: 1
    heterogeneousConfig:
      trainingServicePlatforms:
        - local
        - remote
    remoteConfig:
      reuse: true
    machineList:
      - ip: 10.1.1.1
        username: bob
        passwd: bob123
Configurations for heterogeneous mode:

heterogeneousConfig:
* trainingServicePlatforms. Required key. This field specifies the platforms used in heterogeneous mode; the value uses the YAML list format. NNI supports setting `local`, `remote`, `aml` and `pai` in this field.

Note: if a platform is set in `trainingServicePlatforms`, the corresponding configuration for that platform must also be set. For example, if `remote` is one of the platforms, the `machineList` and `remoteConfig` configurations are also required.
@@ -12,3 +12,4 @@ Introduction to NNI Training Services
     FrameworkController<./TrainingService/FrameworkControllerMode>
     DLTS<./TrainingService/DLTSMode>
     AML<./TrainingService/AMLMode>
+    Heterogeneous<./TrainingService/HeterogeneousMode>
authorName: default
experimentName: example_mnist
trialConcurrency: 3
maxExecDuration: 1h
maxTrialNum: 10
trainingServicePlatform: heterogeneous
searchSpacePath: search_space.json
#choice: true, false
useAnnotation: false
tuner:
  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner, GPTuner
  #SMAC (SMAC should be installed through nnictl)
  builtinTunerName: TPE
  classArgs:
    #choice: maximize, minimize
    optimize_mode: maximize
trial:
  command: python3 mnist.py
  codeDir: .
  gpuNum: 0
heterogeneousConfig:
  trainingServicePlatforms:
    - local
    - remote
remoteConfig:
  reuse: true
machineList:
  - ip: 10.1.1.1
    username: bob
    passwd: bob123
    #port can be skipped if using the default ssh port 22
    #port: 22
@@ -12,7 +12,8 @@ _trial_env_var_names = [
     'NNI_SYS_DIR',
     'NNI_OUTPUT_DIR',
     'NNI_TRIAL_SEQ_ID',
-    'MULTI_PHASE'
+    'MULTI_PHASE',
+    'REUSE_MODE'
 ]
 _dispatcher_env_var_names = [
...
@@ -31,7 +31,7 @@ def init_logger() -> None:
     if trial_platform == 'unittest':
         return
-    if trial_platform:
+    if trial_platform and not trial_env_vars.REUSE_MODE:
         _init_logger_trial()
         return
...
@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None:
     from .standalone import *
 elif trial_env_vars.NNI_PLATFORM == 'unittest':
     from .test import *
-elif trial_env_vars.NNI_PLATFORM in ('adl', 'local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'):
+elif trial_env_vars.NNI_PLATFORM in ('local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'heterogeneous'):
     from .local import *
 else:
     raise RuntimeError('Unknown platform %s' % trial_env_vars.NNI_PLATFORM)
@@ -19,6 +19,7 @@ _outputdir = trial_env_vars.NNI_OUTPUT_DIR
 if not os.path.exists(_outputdir):
     os.makedirs(_outputdir)
+_reuse_mode = trial_env_vars.REUSE_MODE
 _nni_platform = trial_env_vars.NNI_PLATFORM
 _multiphase = trial_env_vars.MULTI_PHASE
@@ -58,7 +59,7 @@ def get_next_parameter():
     return params
 def send_metric(string):
-    if _nni_platform != 'local':
+    if _nni_platform != 'local' or _reuse_mode in ('true', 'True'):
         assert len(string) < 1000000, 'Metric too long'
         print("NNISDK_MEb'%s'" % (string), flush=True)
     else:
...
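The practical effect of these two changes: when `REUSE_MODE` is set, a trial always writes its metrics to stdout framed as `NNISDK_MEb'...'`, even on the local platform, so the long-running trial runner can collect them from the output stream. A minimal sketch of recovering such framed metrics from captured stdout (the helper below is illustrative only; the real parsing lives in NNI's trial runner):

.. code-block:: python

    import re

    # Illustrative only: trials in reuse mode print metrics to stdout framed
    # as NNISDK_MEb'<payload>' (see send_metric above); a collector can
    # recover the payloads like this.
    METRIC_RE = re.compile(r"NNISDK_MEb'(.+?)'")

    def extract_metrics(stdout_text):
        return METRIC_RE.findall(stdout_text)

    sample = "epoch 1 done\nNNISDK_MEb'{\"type\": \"FINAL\", \"value\": \"0.93\"}'\n"
    print(extract_metrics(sample))  # ['{"type": "FINAL", "value": "0.93"}']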
@@ -124,7 +124,7 @@ common_schema = {
     Optional('maxExecDuration'): And(Regex(r'^[1-9][0-9]*[s|m|h|d]$', error='ERROR: maxExecDuration format is [digit]{s,m,h,d}')),
     Optional('maxTrialNum'): setNumberRange('maxTrialNum', int, 1, 99999),
     'trainingServicePlatform': setChoice(
-        'trainingServicePlatform', 'adl', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'),
+        'trainingServicePlatform', 'remote', 'local', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'heterogeneous'),
     Optional('searchSpacePath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'searchSpacePath'),
     Optional('multiPhase'): setType('multiPhase', bool),
     Optional('multiThread'): setType('multiThread', bool),
@@ -208,7 +208,7 @@ pai_trial_schema = {
 }
 pai_config_schema = {
-    'paiConfig': {
+    Optional('paiConfig'): {
         'userName': setType('userName', str),
         Or('passWord', 'token', only_one=True): str,
         'host': setType('host', str),
@@ -252,7 +252,7 @@ aml_trial_schema = {
 }
 aml_config_schema = {
-    'amlConfig': {
+    Optional('amlConfig'): {
         'subscriptionId': setType('subscriptionId', str),
         'resourceGroup': setType('resourceGroup', str),
         'workspaceName': setType('workspaceName', str),
@@ -262,6 +262,29 @@ aml_config_schema = {
     }
 }
+heterogeneous_trial_schema = {
+    'trial': {
+        'codeDir': setPathCheck('codeDir'),
+        Optional('nniManagerNFSMountPath'): setPathCheck('nniManagerNFSMountPath'),
+        Optional('containerNFSMountPath'): setType('containerNFSMountPath', str),
+        Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode', 'darts_mode'),
+        'command': setType('command', str),
+        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
+        Optional('cpuNum'): setNumberRange('cpuNum', int, 0, 99999),
+        Optional('memoryMB'): setType('memoryMB', int),
+        Optional('image'): setType('image', str),
+        Optional('virtualCluster'): setType('virtualCluster', str),
+        Optional('paiStorageConfigName'): setType('paiStorageConfigName', str),
+        Optional('paiConfigPath'): And(os.path.exists, error=SCHEMA_PATH_ERROR % 'paiConfigPath')
+    }
+}
+heterogeneous_config_schema = {
+    'heterogeneousConfig': {
+        'trainingServicePlatforms': ['local', 'remote', 'pai', 'aml']
+    }
+}
 adl_trial_schema = {
     'trial':{
         'codeDir': setType('codeDir', str),
@@ -404,7 +427,7 @@ remote_config_schema = {
 }
 machine_list_schema = {
-    'machineList': [Or(
+    Optional('machineList'): [Or(
         {
             'ip': setType('ip', str),
             Optional('port'): setNumberRange('port', int, 1, 65535),
@@ -438,6 +461,8 @@ training_service_schema_dict = {
     'frameworkcontroller': Schema({**common_schema, **frameworkcontroller_trial_schema, **frameworkcontroller_config_schema}),
     'aml': Schema({**common_schema, **aml_trial_schema, **aml_config_schema}),
     'dlts': Schema({**common_schema, **dlts_trial_schema, **dlts_config_schema}),
+    'heterogeneous': Schema({**common_schema, **heterogeneous_trial_schema, **heterogeneous_config_schema, **machine_list_schema,
+                             **pai_config_schema, **aml_config_schema, **remote_config_schema}),
 }
@@ -454,6 +479,7 @@ class NNIConfigSchema:
         self.validate_pai_trial_conifg(experiment_config)
         self.validate_kubeflow_operators(experiment_config)
         self.validate_eth0_device(experiment_config)
+        self.validate_heterogeneous_platforms(experiment_config)
     def validate_tuner_adivosr_assessor(self, experiment_config):
         if experiment_config.get('advisor'):
@@ -563,3 +589,16 @@ class NNIConfigSchema:
             and not experiment_config.get('nniManagerIp') \
             and 'eth0' not in netifaces.interfaces():
             raise SchemaError('This machine does not contain eth0 network device, please set nniManagerIp in config file!')
+    def validate_heterogeneous_platforms(self, experiment_config):
+        required_config_name_map = {
+            'remote': 'machineList',
+            'aml': 'amlConfig',
+            'pai': 'paiConfig'
+        }
+        if experiment_config.get('trainingServicePlatform') == 'heterogeneous':
+            for platform in experiment_config['heterogeneousConfig']['trainingServicePlatforms']:
+                config_name = required_config_name_map.get(platform)
+                if config_name and not experiment_config.get(config_name):
+                    raise SchemaError('Need to set {0} for {1} in heterogeneous mode!'.format(config_name, platform))
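For readers unfamiliar with the `schema` package used throughout this file: a list literal inside a schema matches any list whose elements each equal one of the listed values, which is exactly how `trainingServicePlatforms` is constrained above. A standalone sketch, using `schema` directly without NNI's helper functions:

.. code-block:: python

    from schema import Schema, SchemaError

    # A list literal in `schema` accepts any list whose elements are each one
    # of the given values, mirroring heterogeneous_config_schema above.
    hetero_schema = Schema({
        'heterogeneousConfig': {
            'trainingServicePlatforms': ['local', 'remote', 'pai', 'aml']
        }
    })

    hetero_schema.validate(
        {'heterogeneousConfig': {'trainingServicePlatforms': ['local', 'remote']}})  # passes
    try:
        hetero_schema.validate(
            {'heterogeneousConfig': {'trainingServicePlatforms': ['k8s']}})
    except SchemaError:
        print('k8s is not a supported heterogeneous platform')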
@@ -118,13 +118,6 @@ def set_local_config(experiment_config, port, config_file_name):
     request_data = dict()
     if experiment_config.get('localConfig'):
         request_data['local_config'] = experiment_config['localConfig']
-        if request_data['local_config']:
-            if request_data['local_config'].get('gpuIndices') and isinstance(request_data['local_config'].get('gpuIndices'), int):
-                request_data['local_config']['gpuIndices'] = str(request_data['local_config'].get('gpuIndices'))
-            if request_data['local_config'].get('maxTrialNumOnEachGpu'):
-                request_data['local_config']['maxTrialNumOnEachGpu'] = request_data['local_config'].get('maxTrialNumOnEachGpu')
-            if request_data['local_config'].get('useActiveGpu'):
-                request_data['local_config']['useActiveGpu'] = request_data['local_config'].get('useActiveGpu')
     response = rest_put(cluster_metadata_url(port), json.dumps(request_data), REST_TIME_OUT)
     err_message = ''
     if not response or not check_response(response):
@@ -306,6 +299,37 @@ def set_aml_config(experiment_config, port, config_file_name):
     #set trial_config
     return set_trial_config(experiment_config, port, config_file_name), err_message
+def set_heterogeneous_config(experiment_config, port, config_file_name):
+    '''set heterogeneous configuration'''
+    heterogeneous_config_data = dict()
+    heterogeneous_config_data['heterogeneous_config'] = experiment_config['heterogeneousConfig']
+    platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms']
+    for platform in platform_list:
+        if platform == 'aml':
+            heterogeneous_config_data['aml_config'] = experiment_config['amlConfig']
+        elif platform == 'remote':
+            if experiment_config.get('remoteConfig'):
+                heterogeneous_config_data['remote_config'] = experiment_config['remoteConfig']
+            heterogeneous_config_data['machine_list'] = experiment_config['machineList']
+        elif platform == 'local' and experiment_config.get('localConfig'):
+            heterogeneous_config_data['local_config'] = experiment_config['localConfig']
+        elif platform == 'pai':
+            heterogeneous_config_data['pai_config'] = experiment_config['paiConfig']
+    response = rest_put(cluster_metadata_url(port), json.dumps(heterogeneous_config_data), REST_TIME_OUT)
+    err_message = None
+    if not response or not response.status_code == 200:
+        if response is not None:
+            err_message = response.text
+            _, stderr_full_path = get_log_path(config_file_name)
+            with open(stderr_full_path, 'a+') as fout:
+                fout.write(json.dumps(json.loads(err_message), indent=4, sort_keys=True, separators=(',', ':')))
+        return False, err_message
+    result, message = setNNIManagerIp(experiment_config, port, config_file_name)
+    if not result:
+        return result, message
+    #set trial_config
+    return set_trial_config(experiment_config, port, config_file_name), err_message
 def set_experiment(experiment_config, mode, port, config_file_name):
     '''Call startExperiment (rest POST /experiment) with yaml file content'''
     request_data = dict()
@@ -387,6 +411,21 @@ def set_experiment(experiment_config, mode, port, config_file_name):
             {'key': 'aml_config', 'value': experiment_config['amlConfig']})
         request_data['clusterMetaData'].append(
             {'key': 'trial_config', 'value': experiment_config['trial']})
+    elif experiment_config['trainingServicePlatform'] == 'heterogeneous':
+        request_data['clusterMetaData'].append(
+            {'key': 'heterogeneous_config', 'value': experiment_config['heterogeneousConfig']})
+        platform_list = experiment_config['heterogeneousConfig']['trainingServicePlatforms']
+        request_dict = {
+            'aml': {'key': 'aml_config', 'value': experiment_config.get('amlConfig')},
+            'remote': {'key': 'machine_list', 'value': experiment_config.get('machineList')},
+            'pai': {'key': 'pai_config', 'value': experiment_config.get('paiConfig')},
+            'local': {'key': 'local_config', 'value': experiment_config.get('localConfig')}
+        }
+        for platform in platform_list:
+            if request_dict.get(platform):
+                request_data['clusterMetaData'].append(request_dict[platform])
+        request_data['clusterMetaData'].append(
+            {'key': 'trial_config', 'value': experiment_config['trial']})
     response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT, show_error=True)
     if check_response(response):
         return response
@@ -420,6 +459,8 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
         config_result, err_msg = set_dlts_config(experiment_config, port, config_file_name)
     elif platform == 'aml':
         config_result, err_msg = set_aml_config(experiment_config, port, config_file_name)
+    elif platform == 'heterogeneous':
+        config_result, err_msg = set_heterogeneous_config(experiment_config, port, config_file_name)
     else:
         raise Exception(ERROR_INFO % 'Unsupported platform!')
         exit(1)
...
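Taken together, `set_heterogeneous_config` and the new `set_experiment` branch send each selected platform's configuration twice: once as cluster metadata over PUT, and once inside the experiment POST's `clusterMetaData` list. For the local + remote example config earlier in this commit, the PUT body assembled by `set_heterogeneous_config` would look roughly like this (shape inferred from the code above; values from the sample YAML):

.. code-block:: python

    import json

    # Approximate cluster-metadata PUT body for the local + remote example;
    # key names follow set_heterogeneous_config above. No local_config key is
    # added because the sample YAML sets no localConfig.
    heterogeneous_config_data = {
        'heterogeneous_config': {'trainingServicePlatforms': ['local', 'remote']},
        'remote_config': {'reuse': True},
        'machine_list': [{'ip': '10.1.1.1', 'username': 'bob', 'passwd': 'bob123'}],
    }
    print(json.dumps(heterogeneous_config_data, indent=2))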
@@ -25,7 +25,6 @@ def main_loop(args):
     '''main loop logic for trial runner'''
     idle_last_time = datetime.now()
     gpu_refresh_last_time = datetime.now() - timedelta(minutes=1)
-
     try:
         if args.job_pid_file:
             with open(args.job_pid_file, 'w') as job_file:
@@ -188,6 +187,7 @@ if __name__ == '__main__':
     os.environ['NNI_EXP_ID'] = args.exp_id
     os.environ['MULTI_PHASE'] = "true"
     os.environ['NNI_TRIAL_JOB_ID'] = "runner"
+    os.environ['REUSE_MODE'] = "true"
     from .log_utils import LogType, RemoteLogger, StdOutputType, nni_log
     from .trial import Trial
...
@@ -28,6 +28,7 @@ import { RouterTrainingService } from './training_service/reusable/routerTrainingService';
 import { PAIYarnTrainingService } from './training_service/pai/paiYarn/paiYarnTrainingService';
 import { DLTSTrainingService } from './training_service/dlts/dltsTrainingService';
 function initStartupInfo(
     startExpMode: string, experimentId: string, basePort: number, platform: string,
     logDirectory: string, experimentLogLevel: string, readonly: boolean, dispatcherPipe: string): void {
@@ -36,22 +37,15 @@ function initStartupInfo(
 }
 async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> {
-    if (platformMode === 'adl') {
+    const routerPlatformMode = ['remote', 'pai', 'aml', 'heterogeneous'];
+    if (routerPlatformMode.includes(platformMode)) {
         Container.bind(TrainingService)
-            .to(AdlTrainingService)
+            .to(RouterTrainingService)
             .scope(Scope.Singleton);
     } else if (platformMode === 'local') {
         Container.bind(TrainingService)
             .to(LocalTrainingService)
             .scope(Scope.Singleton);
-    } else if (platformMode === 'remote') {
-        Container.bind(TrainingService)
-            .to(RouterTrainingService)
-            .scope(Scope.Singleton);
-    } else if (platformMode === 'pai') {
-        Container.bind(TrainingService)
-            .to(RouterTrainingService)
-            .scope(Scope.Singleton);
     } else if (platformMode === 'paiYarn') {
         Container.bind(TrainingService)
             .to(PAIYarnTrainingService)
@@ -68,9 +62,9 @@ async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> {
         Container.bind(TrainingService)
             .to(DLTSTrainingService)
             .scope(Scope.Singleton);
-    } else if (platformMode === 'aml') {
+    } else if (platformMode === 'adl') {
         Container.bind(TrainingService)
-            .to(RouterTrainingService)
+            .to(AdlTrainingService)
             .scope(Scope.Singleton);
     } else {
         throw new Error(`Error: unsupported mode: ${platformMode}`);
@@ -103,7 +97,7 @@ async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> {
 function usage(): void {
     console.info('usage: node main.js --port <port> --mode \
-    <adl/local/remote/pai/kubeflow/frameworkcontroller/paiYarn/aml> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>');
+    <local/remote/pai/kubeflow/frameworkcontroller/paiYarn/aml/adl/heterogeneous> --start_mode <new/resume> --experiment_id <id> --foreground <true/false>');
 }
 const strPort: string = parseArg(['--port', '-p']);
@@ -123,7 +117,7 @@ const foreground: boolean = foregroundArg.toLowerCase() === 'true' ? true : false;
 const port: number = parseInt(strPort, 10);
 const mode: string = parseArg(['--mode', '-m']);
-if (!['adl', 'local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml'].includes(mode)) {
+if (!['local', 'remote', 'pai', 'kubeflow', 'frameworkcontroller', 'paiYarn', 'dlts', 'aml', 'adl', 'heterogeneous'].includes(mode)) {
     console.log(`FATAL: unknown mode: ${mode}`);
     usage();
     process.exit(1);
...
@@ -23,7 +23,8 @@ export namespace ValidationSchemas {
         local_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
             gpuIndices: joi.string(),
             maxTrialNumPerGpu: joi.number(),
-            useActiveGpu: joi.boolean()
+            useActiveGpu: joi.boolean(),
+            reuse: joi.boolean()
         }),
         trial_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
             image: joi.string().min(1),
@@ -182,6 +183,9 @@ export namespace ValidationSchemas {
             maxTrialNumPerGpu: joi.number(),
             useActiveGpu: joi.boolean()
         }),
+        heterogeneous_config: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
+            trainingServicePlatforms: joi.array(),
+        }),
         nni_manager_ip: joi.object({ // eslint-disable-line @typescript-eslint/camelcase
             nniManagerIp: joi.string().min(1)
         }),
...
@@ -11,6 +11,7 @@ export enum TrialConfigMetadataKey {
     LOCAL_CONFIG = 'local_config',
     TRIAL_CONFIG = 'trial_config',
     REMOTE_CONFIG = 'remote_config',
+    HETEROGENEOUS_CONFIG = 'heterogeneous_config',
     EXPERIMENT_ID = 'experimentId',
     MULTI_PHASE = 'multiPhase',
     RANDOM_SCHEDULER = 'random_scheduler',
@@ -22,5 +23,8 @@ export enum TrialConfigMetadataKey {
     DLTS_CLUSTER_CONFIG = 'dlts_config',
     AML_CLUSTER_CONFIG = 'aml_config',
     VERSION_CHECK = 'version_check',
-    LOG_COLLECTION = 'log_collection'
+    LOG_COLLECTION = 'log_collection',
+    // Used to set the platform list for heterogeneous mode in reuse mode;
+    // a temporary change, the config schema will be refactored in the future
+    PLATFORM_LIST = 'platform_list'
 }
@@ -78,7 +78,7 @@ class LocalTrialJobDetail implements TrialJobDetail {
 /**
  * Local training service config
  */
-class LocalConfig {
+export class LocalConfig {
     public maxTrialNumPerGpu?: number;
     public gpuIndices?: string;
     public useActiveGpu?: boolean;
...
@@ -358,6 +358,10 @@ class RemoteMachineTrainingService implements TrainingService {
             case TrialConfigMetadataKey.LOG_COLLECTION:
                 this.logCollection = value;
                 break;
+            case TrialConfigMetadataKey.REMOTE_CONFIG:
+                // remote_config is consumed by remoteEnvironmentService to set reuse mode;
+                // it has to be caught here too, otherwise the unknown-key exception below is thrown
+                break;
             default:
                 //Reject for unknown keys
                 throw new Error(`Uknown key: ${key}`);
...
@@ -8,6 +8,7 @@ import { getBasePort, getExperimentId } from "../../../common/experimentStartupInfo";
 import { INITIALIZED } from '../../../core/commands';
 import { CommandChannel, RunnerConnection } from "../commandChannel";
 import { Channel, EnvironmentInformation } from "../environment";
+import { EventEmitter } from "events";
 class WebRunnerConnection extends RunnerConnection {
     public readonly clients: WebSocket[] = [];
@@ -29,7 +30,7 @@ class WebRunnerConnection extends RunnerConnection {
 export class WebCommandChannel extends CommandChannel {
     private readonly expId: string = getExperimentId();
+    private static commandChannel: WebCommandChannel;
     private webSocketServer: SocketServer | undefined;
     private clients: Map<WebSocket, WebRunnerConnection | undefined> = new Map<WebSocket, WebRunnerConnection | undefined>();
@@ -41,6 +42,18 @@ export class WebCommandChannel extends CommandChannel {
         // do nothing
     }
+    // WebCommandChannel is a singleton: one experiment starts at most one instance
+    private constructor(commandEmitter: EventEmitter) {
+        super(commandEmitter);
+    }
+
+    public static getInstance(commandEmitter: EventEmitter): CommandChannel {
+        if (!this.commandChannel) {
+            this.commandChannel = new WebCommandChannel(commandEmitter);
+        }
+        return this.commandChannel;
+    }
+
     public async start(): Promise<void> {
         const port = getBasePort() + 1;
         this.webSocketServer = new SocketServer({ port });
...
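The private constructor plus `getInstance` makes `WebCommandChannel` a lazily created singleton, so in heterogeneous mode every environment service that keeps the default web channel shares one WebSocket server instead of each opening its own. The same pattern in a few lines of Python (a stand-in sketch, not NNI code):

.. code-block:: python

    # Stand-in sketch of the lazy-singleton pattern used by WebCommandChannel:
    # the first caller creates the instance, later callers get the same object.
    class SingletonChannel:
        _instance = None

        def __init__(self, emitter):
            self.emitter = emitter

        @classmethod
        def get_instance(cls, emitter):
            if cls._instance is None:
                cls._instance = cls(emitter)
            return cls._instance

    a = SingletonChannel.get_instance('emitter-1')
    b = SingletonChannel.get_instance('emitter-2')
    assert a is b  # the second emitter is ignored, as in the TypeScript version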
@@ -3,12 +3,12 @@
 'use strict';
-import { EventEmitter } from "events";
 import { getLogger, Logger } from "../../common/log";
 import { TrialJobStatus } from "../../common/trainingService";
 import { GPUInfo } from "../../training_service/common/gpuData";
-import { WebCommandChannel } from "./channels/webCommandChannel";
 import { CommandChannel } from "./commandChannel";
+import { WebCommandChannel } from './channels/webCommandChannel';
+import { EventEmitter } from "events";
 export type EnvironmentStatus = 'UNKNOWN' | 'WAITING' | 'RUNNING' | 'SUCCEEDED' | 'FAILED' | 'USER_CANCELED';
@@ -75,6 +75,8 @@ export class EnvironmentInformation {
     public maxTrialNumberPerGpu?: number;
     public useActiveGpu?: boolean;
+    public environmentService?: EnvironmentService;
     constructor(id: string, name: string, envId?: string) {
         this.log = getLogger();
         this.id = id;
@@ -127,6 +129,8 @@ export abstract class EnvironmentService {
     public abstract refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void>;
     public abstract stopEnvironment(environment: EnvironmentInformation): Promise<void>;
     public abstract startEnvironment(environment: EnvironmentInformation): Promise<void>;
+    // Make public for ut
+    protected commandChannel: CommandChannel | undefined;
     // It is used to set prefetched environment count, default value is 0 for OpenPAI and AML mode,
     // in remote mode, this value is set to the length of machine list.
@@ -134,6 +138,20 @@ export abstract class EnvironmentService {
         return 0;
     }
+    public abstract get getName(): string;
+
+    // Initialize command channel, use WebCommandChannel as default command channel
+    public initCommandChannel(eventEmitter: EventEmitter): void {
+        this.commandChannel = WebCommandChannel.getInstance(eventEmitter);
+    }
+
+    public get getCommandChannel(): CommandChannel {
+        if (this.commandChannel === undefined) {
+            throw new Error("Command channel not initialized!");
+        }
+        return this.commandChannel;
+    }
+
     // It depends on environment pressure and settings
     // for example, OpenPAI relies on API calls, and there is an limitation for frequence, so it need to be bigger.
     public get environmentMaintenceLoopInterval(): number {
@@ -147,10 +165,6 @@ export abstract class EnvironmentService {
         return true;
     }
-    public createCommandChannel(commandEmitter: EventEmitter): CommandChannel {
-        return new WebCommandChannel(commandEmitter);
-    }
     public createEnvironmentInformation(envId: string, envName: string): EnvironmentInformation {
         return new EnvironmentInformation(envId, envName);
     }
...
@@ -3,7 +3,6 @@
 'use strict';
-import { EventEmitter } from "events";
 import * as fs from 'fs';
 import * as path from 'path';
 import * as component from '../../../common/component';
@@ -14,13 +13,13 @@ import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey';
 import { validateCodeDir } from '../../common/util';
 import { AMLClient } from '../aml/amlClient';
 import { AMLClusterConfig, AMLEnvironmentInformation, AMLTrialConfig } from '../aml/amlConfig';
-import { AMLCommandChannel } from '../channels/amlCommandChannel';
-import { CommandChannel } from "../commandChannel";
 import { EnvironmentInformation, EnvironmentService } from '../environment';
+import { EventEmitter } from "events";
+import { AMLCommandChannel } from '../channels/amlCommandChannel';
 /**
- * Collector PAI jobs info from PAI cluster, and update pai job status locally
+ * Collector AML jobs info from AML cluster, and update aml job status locally
  */
 @component.Singleton
 export class AMLEnvironmentService extends EnvironmentService {
@@ -41,14 +40,18 @@ export class AMLEnvironmentService extends EnvironmentService {
         return false;
     }
-    public createCommandChannel(commandEmitter: EventEmitter): CommandChannel {
-        return new AMLCommandChannel(commandEmitter);
+    public initCommandChannel(eventEmitter: EventEmitter): void {
+        this.commandChannel = new AMLCommandChannel(eventEmitter);
     }
     public createEnvironmentInformation(envId: string, envName: string): EnvironmentInformation {
         return new AMLEnvironmentInformation(envId, envName);
     }
+    public get getName(): string {
+        return 'aml';
+    }
     public async config(key: string, value: string): Promise<void> {
         switch (key) {
             case TrialConfigMetadataKey.AML_CLUSTER_CONFIG:
...
import { AMLEnvironmentService } from './amlEnvironmentService';
import { OpenPaiEnvironmentService } from './openPaiEnvironmentService';
import { LocalEnvironmentService } from './localEnvironmentService';
import { RemoteEnvironmentService } from './remoteEnvironmentService';
import { EnvironmentService } from '../environment';

export class EnvironmentServiceFactory {
    public static createEnvironmentService(name: string): EnvironmentService {
        switch (name) {
            case 'local':
                return new LocalEnvironmentService();
            case 'remote':
                return new RemoteEnvironmentService();
            case 'aml':
                return new AMLEnvironmentService();
            case 'pai':
                return new OpenPaiEnvironmentService();
            default:
                throw new Error(`${name} not supported!`);
        }
    }
}
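This factory gives the router one place to build an `EnvironmentService` per entry in `trainingServicePlatforms`. A hedged Python sketch of the same dispatch, with placeholder classes standing in for the TypeScript services:

.. code-block:: python

    # Placeholder classes stand in for the TypeScript environment services.
    class LocalEnvironmentService: ...
    class RemoteEnvironmentService: ...
    class AMLEnvironmentService: ...
    class OpenPaiEnvironmentService: ...

    _SERVICES = {
        'local': LocalEnvironmentService,
        'remote': RemoteEnvironmentService,
        'aml': AMLEnvironmentService,
        'pai': OpenPaiEnvironmentService,
    }

    def create_environment_service(name):
        if name not in _SERVICES:
            raise ValueError('{0} not supported!'.format(name))
        return _SERVICES[name]()

    print(create_environment_service('remote'))  # a RemoteEnvironmentService instance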
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

'use strict';

import * as fs from 'fs';
import * as path from 'path';
import * as tkill from 'tree-kill';
import * as component from '../../../common/component';
import { getExperimentId } from '../../../common/experimentStartupInfo';
import { getLogger, Logger } from '../../../common/log';
import { TrialConfigMetadataKey } from '../../common/trialConfigMetadataKey';
import { EnvironmentInformation, EnvironmentService } from '../environment';
import { TrialConfig } from '../../common/trialConfig';
import { getExperimentRootDir, isAlive } from '../../../common/utils';
import { execMkdir, runScript, execCopydir } from '../../common/util';

@component.Singleton
export class LocalEnvironmentService extends EnvironmentService {

    private readonly log: Logger = getLogger();
    private localTrialConfig: TrialConfig | undefined;
    private experimentRootDir: string;
    private experimentId: string;

    constructor() {
        super();
        this.experimentId = getExperimentId();
        this.experimentRootDir = getExperimentRootDir();
    }

    public get environmentMaintenceLoopInterval(): number {
        return 100;
    }

    public get hasStorageService(): boolean {
        return false;
    }

    public get getName(): string {
        return 'local';
    }

    public async config(key: string, value: string): Promise<void> {
        switch (key) {
            case TrialConfigMetadataKey.TRIAL_CONFIG:
                this.localTrialConfig = <TrialConfig>JSON.parse(value);
                break;
            default:
                this.log.debug(`Local mode does not proccess metadata key: '${key}', value: '${value}'`);
        }
    }

    public async refreshEnvironmentsStatus(environments: EnvironmentInformation[]): Promise<void> {
        environments.forEach(async (environment) => {
            const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`;
            const runnerReturnCodeFilePath: string = `${environment.runnerWorkingFolder}/code`;
            /* eslint-disable require-atomic-updates */
            try {
                // check if pid file exist
                const pidExist = await fs.existsSync(jobpidPath);
                if (!pidExist) {
                    return;
                }
                const pid: string = await fs.promises.readFile(jobpidPath, 'utf8');
                const alive: boolean = await isAlive(pid);
                environment.status = 'RUNNING';
                // if the process of jobpid is not alive any more
                if (!alive) {
                    if (fs.existsSync(runnerReturnCodeFilePath)) {
                        const runnerReturnCode: string = await fs.promises.readFile(runnerReturnCodeFilePath, 'utf8');
                        const match: RegExpMatchArray | null = runnerReturnCode.trim()
                            .match(/^-?(\d+)\s+(\d+)$/);
                        if (match !== null) {
                            const { 1: code } = match;
                            // Update trial job's status based on result code
                            if (parseInt(code, 10) === 0) {
                                environment.setStatus('SUCCEEDED');
                            } else {
                                environment.setStatus('FAILED');
                            }
                        }
                    }
                }
            } catch (error) {
                this.log.error(`Update job status exception, error is ${error.message}`);
            }
        });
    }

    public async startEnvironment(environment: EnvironmentInformation): Promise<void> {
        if (this.localTrialConfig === undefined) {
            throw new Error('Local trial config is not initialized');
        }
        // Need refactor, this temp folder path is not appropriate, there are two expId in this path
        const localTempFolder: string = path.join(this.experimentRootDir, this.experimentId,
            "environment-temp", "envs");
        const localEnvCodeFolder: string = path.join(this.experimentRootDir, "envs");
        environment.runnerWorkingFolder = path.join(localEnvCodeFolder, environment.id);
        await execMkdir(environment.runnerWorkingFolder);
        await execCopydir(localTempFolder, localEnvCodeFolder);
        environment.command = `cd ${this.experimentRootDir} && \
${environment.command} --job_pid_file ${environment.runnerWorkingFolder}/pid \
1>${environment.runnerWorkingFolder}/trialrunner_stdout 2>${environment.runnerWorkingFolder}/trialrunner_stderr \
&& echo $? \`date +%s%3N\` >${environment.runnerWorkingFolder}/code`;
        await fs.promises.writeFile(path.join(localEnvCodeFolder, 'nni_run.sh'),
            environment.command, { encoding: 'utf8', mode: 0o777 });
        // Execute command in local machine
        runScript(path.join(localEnvCodeFolder, 'nni_run.sh'));
        environment.trackingUrl = `${environment.runnerWorkingFolder}`;
    }

    public async stopEnvironment(environment: EnvironmentInformation): Promise<void> {
        const jobpidPath: string = `${environment.runnerWorkingFolder}/pid`;
        const pid: string = await fs.promises.readFile(jobpidPath, 'utf8');
        tkill(Number(pid), 'SIGKILL');
    }
}
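Note the small contract between the runner script and the status poller: `nni_run.sh` ends with ``echo $? `date +%s%3N` ``, so the `code` file holds "<exit code> <epoch milliseconds>", and `refreshEnvironmentsStatus` parses exactly that with `/^-?(\d+)\s+(\d+)$/`. A quick Python equivalent of the check (illustrative only):

.. code-block:: python

    import re

    # Mirrors the regex in refreshEnvironmentsStatus: the `code` file holds
    # "<exit_code> <epoch_millis>" written by nni_run.sh when the runner exits.
    def parse_runner_code_file(content):
        match = re.match(r'^-?(\d+)\s+(\d+)$', content.strip())
        if match is None:
            return None
        exit_code, finished_at_ms = int(match.group(1)), int(match.group(2))
        return ('SUCCEEDED' if exit_code == 0 else 'FAILED', finished_at_ms)

    print(parse_runner_code_file('0 1604549123456'))    # ('SUCCEEDED', 1604549123456)
    print(parse_runner_code_file('137 1604549123456'))  # ('FAILED', 1604549123456)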