"docs/zh_CN/Tutorial/InstallationLinux.md" did not exist on "df4f05c78a5cc3fd3dc5f9929bf35df414a2fff6"
Unverified Commit 95f731e4 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

experiment management backend (#3081)



* step 1 nnictl generate experimentId & merge folder

* step 2.1 modify .experiment structure

* step 2.2 add lock for .experiment rw in nnictl

* step 2.2 add filelock dependence

* step 2.2 remove uniqueString from main.js

* fix test bug

* fix test bug

* step 3.1 add experiment manager

* step 3.2 add getExperimentsInfo

* fix eslint

* add a simple file lock to support stale

* step 3.3 add test

* divide abs experiment manager from manager

* experiment manager refactor

* support .experiment sync update status

* nnictl no longer uses rest api to update status or endtime

* nnictl no longer uses rest api to update status or endtime

* fix eslint

* support .experiment sync update endtime

* fix test

* fix settimeout bug

* fix test

* adjust experiment endTime

* separate simple file lock class

* modify name

* add 'id' in .experiment

* update rest api format

* fix eslint

* fix issue in comments

* fix rest api format

* add indent in json in experiments manager

* fix unittest

* fix unittest

* refactor file lock

* fix eslint

* remove '__enter__' in filelock

* filelock support never expire
Co-authored-by: default avatarNing Shang <nishang@microsoft.com>
parent fc0ff8ce
......@@ -5,11 +5,14 @@ import os
import sys
import json
import tempfile
import time
import socket
import string
import random
import ruamel.yaml as yaml
import psutil
import filelock
import glob
from colorama import Fore
from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO
......@@ -95,3 +98,36 @@ def generate_temp_dir():
temp_dir = generate_folder_name()
os.makedirs(temp_dir)
return temp_dir
class SimplePreemptiveLock(filelock.SoftFileLock):
    '''Soft file lock whose lock files may expire.

    Each process uses its own lock file named ``<lock_file>.<pid>``. The lock
    is considered held while any ``<lock_file>.*`` file exists and is not yet
    stale. ``stale`` is the maximum age in seconds (compared against the file's
    mtime); a negative value means lock files never expire. If expiration
    checking is not needed, plain ``SoftFileLock`` is sufficient.
    '''

    def __init__(self, lock_file, stale=-1):
        super().__init__(lock_file, timeout=-1)
        self._stale = stale
        # per-process lock file, so holders can be told apart by pid suffix
        self._lock_file_name = '{}.{}'.format(self._lock_file, os.getpid())

    def _acquire(self):
        '''Try once to take the lock.

        Always returns None; success is signalled by ``self._lock_file_fd``
        being set to the descriptor of the newly created lock file.
        '''
        flags = os.O_WRONLY | os.O_CREAT | os.O_EXCL | os.O_TRUNC
        try:
            for candidate in glob.glob(self._lock_file + '.*'):
                if not os.path.exists(candidate):
                    continue
                # a live (non-expired) lock file from any process blocks us
                if self._stale < 0 or time.time() - os.stat(candidate).st_mtime < self._stale:
                    return None
            # O_EXCL makes creation fail if our own lock file already exists
            self._lock_file_fd = os.open(self._lock_file_name, flags)
        except (IOError, OSError):
            # could not inspect or create a lock file: lock not acquired this round
            pass
        return None

    def _release(self):
        '''Close the lock file descriptor and delete this process's lock file.'''
        os.close(self._lock_file_fd)
        self._lock_file_fd = None
        try:
            os.remove(self._lock_file_name)
        except OSError:
            # best effort: the file may already be gone
            pass
        return None
def get_file_lock(path: str, stale=-1):
    '''Return a ``SimplePreemptiveLock`` guarding ``path``.

    The lock file base name is ``path + '.lock'``. ``stale`` is forwarded to
    the lock: the age in seconds after which an existing lock file is ignored,
    or a negative value for locks that never expire.
    '''
    # forward the caller's value; it was previously discarded in favour of -1
    return SimplePreemptiveLock(path + '.lock', stale=stale)
......@@ -4,8 +4,10 @@
import os
import json
import shutil
import time
from .constants import NNICTL_HOME_DIR
from .command_utils import print_error
from .common_utils import get_file_lock
class Config:
'''a util class to load and save config'''
......@@ -34,7 +36,7 @@ class Config:
if self.config:
try:
with open(self.config_file, 'w') as file:
json.dump(self.config, file)
json.dump(self.config, file, indent=4)
except IOError as error:
print('Error:', error)
return
......@@ -54,39 +56,53 @@ class Experiments:
def __init__(self, home_dir=NNICTL_HOME_DIR):
os.makedirs(home_dir, exist_ok=True)
self.experiment_file = os.path.join(home_dir, '.experiment')
self.experiments = self.read_file()
self.lock = get_file_lock(self.experiment_file, stale=2)
with self.lock:
self.experiments = self.read_file()
def add_experiment(self, expId, port, startTime, file_name, platform, experiment_name, endTime='N/A', status='INITIALIZED'):
'''set {key:value} paris to self.experiment'''
self.experiments[expId] = {}
self.experiments[expId]['port'] = port
self.experiments[expId]['startTime'] = startTime
self.experiments[expId]['endTime'] = endTime
self.experiments[expId]['status'] = status
self.experiments[expId]['fileName'] = file_name
self.experiments[expId]['platform'] = platform
self.experiments[expId]['experimentName'] = experiment_name
self.write_file()
def add_experiment(self, expId, port, startTime, platform, experiment_name, endTime='N/A', status='INITIALIZED',
                   tag=[], pid=None, webuiUrl=[], logDir=[]):
    '''Create (or overwrite) the record for ``expId`` and persist it.

    The experiment table is re-read from disk and written back while holding
    ``self.lock``, so the update is applied to the latest on-disk state.
    List arguments are copied before being stored, so the shared default
    lists are never aliased into a record; any other value (including an
    explicit ``None``) is stored as given.
    '''
    def _own(value):
        # copy lists so neither the defaults nor the caller's list are aliased
        return list(value) if isinstance(value, list) else value

    with self.lock:
        self.experiments = self.read_file()
        self.experiments[expId] = {
            'id': expId,
            'port': port,
            'startTime': startTime,
            'endTime': endTime,
            'status': status,
            'platform': platform,
            'experimentName': experiment_name,
            'tag': _own(tag),
            'pid': pid,
            'webuiUrl': _own(webuiUrl),
            'logDir': _own(logDir),
        }
        self.write_file()
def update_experiment(self, expId, key, value):
'''Update experiment'''
if expId not in self.experiments:
return False
self.experiments[expId][key] = value
self.write_file()
return True
with self.lock:
self.experiments = self.read_file()
if expId not in self.experiments:
return False
self.experiments[expId][key] = value
self.write_file()
return True
def remove_experiment(self, expId):
'''remove an experiment by id'''
if expId in self.experiments:
fileName = self.experiments.pop(expId).get('fileName')
if fileName:
logPath = os.path.join(NNICTL_HOME_DIR, fileName)
try:
shutil.rmtree(logPath)
except FileNotFoundError:
print_error('{0} does not exist.'.format(logPath))
self.write_file()
with self.lock:
self.experiments = self.read_file()
if expId in self.experiments:
self.experiments.pop(expId)
fileName = expId
if fileName:
logPath = os.path.join(NNICTL_HOME_DIR, fileName)
try:
shutil.rmtree(logPath)
except FileNotFoundError:
print_error('{0} does not exist.'.format(logPath))
self.write_file()
def get_all_experiments(self):
'''return all of experiments'''
......@@ -96,7 +112,7 @@ class Experiments:
'''save config to local file'''
try:
with open(self.experiment_file, 'w') as file:
json.dump(self.experiments, file)
json.dump(self.experiments, file, indent=4)
except IOError as error:
print('Error:', error)
return ''
......
......@@ -4,7 +4,7 @@
import os
from colorama import Fore
NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), '.local', 'nnictl')
NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments')
NNI_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments')
......
......@@ -23,10 +23,11 @@ from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SU
from .command_utils import check_output_command, kill_command
from .nnictl_utils import update_experiment
def get_log_path(config_file_name):
def get_log_path(experiment_id):
'''generate stdout and stderr log path'''
stdout_full_path = os.path.join(NNICTL_HOME_DIR, config_file_name, 'stdout')
stderr_full_path = os.path.join(NNICTL_HOME_DIR, config_file_name, 'stderr')
os.makedirs(os.path.join(NNICTL_HOME_DIR, experiment_id, 'log'), exist_ok=True)
stdout_full_path = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log', 'nnictl_stdout.log')
stderr_full_path = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log', 'nnictl_stderr.log')
return stdout_full_path, stderr_full_path
def print_log_content(config_file_name):
......@@ -38,7 +39,7 @@ def print_log_content(config_file_name):
print_normal(' Stderr:')
print(check_output_command(stderr_full_path))
def start_rest_server(port, platform, mode, config_file_name, foreground=False, experiment_id=None, log_dir=None, log_level=None):
def start_rest_server(port, platform, mode, experiment_id, foreground=False, log_dir=None, log_level=None):
'''Run nni manager process'''
if detect_port(port):
print_error('Port %s is used by another process, please reset the port!\n' \
......@@ -63,7 +64,8 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
node_command = os.path.join(entry_dir, 'node.exe')
else:
node_command = os.path.join(entry_dir, 'node')
cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(port), '--mode', platform]
cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(port), '--mode', platform, \
'--experiment_id', experiment_id]
if mode == 'view':
cmds += ['--start_mode', 'resume']
cmds += ['--readonly', 'true']
......@@ -73,13 +75,12 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
cmds += ['--log_dir', log_dir]
if log_level is not None:
cmds += ['--log_level', log_level]
if mode in ['resume', 'view']:
cmds += ['--experiment_id', experiment_id]
if foreground:
cmds += ['--foreground', 'true']
stdout_full_path, stderr_full_path = get_log_path(config_file_name)
stdout_full_path, stderr_full_path = get_log_path(experiment_id)
with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
start_time = time.time()
time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
#add time information in the header of log files
log_header = LOG_HEADER % str(time_now)
stdout_file.write(log_header)
......@@ -95,7 +96,7 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=PIPE)
else:
process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file)
return process, str(time_now)
return process, int(start_time * 1000)
def set_trial_config(experiment_config, port, config_file_name):
'''set trial configuration'''
......@@ -432,9 +433,9 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(1)
def launch_experiment(args, experiment_config, mode, config_file_name, experiment_id=None):
def launch_experiment(args, experiment_config, mode, experiment_id):
'''follow steps to start rest server and start experiment'''
nni_config = Config(config_file_name)
nni_config = Config(experiment_id)
# check packages for tuner
package_name, module_name = None, None
if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
......@@ -445,15 +446,15 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
module_name, _ = get_builtin_module_class_name('advisors', package_name)
if package_name and module_name:
try:
stdout_full_path, stderr_full_path = get_log_path(config_file_name)
stdout_full_path, stderr_full_path = get_log_path(experiment_id)
with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
check_call([sys.executable, '-c', 'import %s'%(module_name)], stdout=stdout_file, stderr=stderr_file)
except CalledProcessError:
print_error('some errors happen when import package %s.' %(package_name))
print_log_content(config_file_name)
print_log_content(experiment_id)
if package_name in INSTALLABLE_PACKAGE_META:
print_error('If %s is not installed, it should be installed through '\
'\'nnictl package install --name %s\''%(package_name, package_name))
'\'nnictl package install --name %s\'' % (package_name, package_name))
exit(1)
log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None
log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None
......@@ -465,7 +466,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
log_level = 'debug'
# start rest server
rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
mode, config_file_name, foreground, experiment_id, log_dir, log_level)
mode, experiment_id, foreground, log_dir, log_level)
nni_config.set_config('restServerPid', rest_process.pid)
# Deal with annotation
if experiment_config.get('useAnnotation'):
......@@ -491,7 +492,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
print_normal('Successfully started Restful server!')
else:
print_error('Restful server start failed!')
print_log_content(config_file_name)
print_log_content(experiment_id)
try:
kill_command(rest_process.pid)
except Exception:
......@@ -500,21 +501,25 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
if mode != 'view':
# set platform configuration
set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,\
config_file_name, rest_process)
experiment_id, rest_process)
# start a new experiment
print_normal('Starting experiment...')
# save experiment information
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time,
experiment_config['trainingServicePlatform'],
experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
# set debug configuration
if mode != 'view' and experiment_config.get('debug') is None:
experiment_config['debug'] = args.debug
response = set_experiment(experiment_config, mode, args.port, config_file_name)
response = set_experiment(experiment_config, mode, args.port, experiment_id)
if response:
if experiment_id is None:
experiment_id = json.loads(response.text).get('experiment_id')
nni_config.set_config('experimentId', experiment_id)
else:
print_error('Start experiment failed!')
print_log_content(config_file_name)
print_log_content(experiment_id)
try:
kill_command(rest_process.pid)
except Exception:
......@@ -526,12 +531,6 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
web_ui_url_list = get_local_urls(args.port)
nni_config.set_config('webuiUrl', web_ui_url_list)
# save experiment information
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time, config_file_name,
experiment_config['trainingServicePlatform'],
experiment_config['experimentName'])
print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
if mode != 'view' and args.foreground:
try:
......@@ -544,8 +543,9 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
def create_experiment(args):
'''start a new experiment'''
config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8))
nni_config = Config(config_file_name)
experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8))
nni_config = Config(experiment_id)
nni_config.set_config('experimentId', experiment_id)
config_path = os.path.abspath(args.config)
if not os.path.exists(config_path):
print_error('Please set correct config path!')
......@@ -560,9 +560,9 @@ def create_experiment(args):
nni_config.set_config('experimentConfig', experiment_config)
nni_config.set_config('restServerPort', args.port)
try:
launch_experiment(args, experiment_config, 'new', config_file_name)
launch_experiment(args, experiment_config, 'new', experiment_id)
except Exception as exception:
nni_config = Config(config_file_name)
nni_config = Config(experiment_id)
restServerPid = nni_config.get_config('restServerPid')
if restServerPid:
kill_command(restServerPid)
......@@ -589,17 +589,13 @@ def manage_stopped_experiment(args, mode):
exit(1)
experiment_id = args.id
print_normal('{0} experiment {1}...'.format(mode, experiment_id))
nni_config = Config(experiment_dict[experiment_id]['fileName'])
nni_config = Config(experiment_id)
experiment_config = nni_config.get_config('experimentConfig')
experiment_id = nni_config.get_config('experimentId')
new_config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8))
new_nni_config = Config(new_config_file_name)
new_nni_config.set_config('experimentConfig', experiment_config)
new_nni_config.set_config('restServerPort', args.port)
nni_config.set_config('restServerPort', args.port)
try:
launch_experiment(args, experiment_config, mode, new_config_file_name, experiment_id)
launch_experiment(args, experiment_config, mode, experiment_id)
except Exception as exception:
nni_config = Config(new_config_file_name)
nni_config = Config(experiment_id)
restServerPid = nni_config.get_config('restServerPid')
if restServerPid:
kill_command(restServerPid)
......
......@@ -32,6 +32,8 @@ def parse_time(time):
def parse_path(experiment_config, config_path):
'''Parse path in config file'''
expand_path(experiment_config, 'searchSpacePath')
if experiment_config.get('logDir'):
expand_path(experiment_config, 'logDir')
if experiment_config.get('trial'):
expand_path(experiment_config['trial'], 'codeDir')
if experiment_config['trial'].get('authFile'):
......@@ -65,6 +67,8 @@ def parse_path(experiment_config, config_path):
root_path = os.path.dirname(config_path)
if experiment_config.get('searchSpacePath'):
parse_relative_path(root_path, experiment_config, 'searchSpacePath')
if experiment_config.get('logDir'):
parse_relative_path(root_path, experiment_config, 'logDir')
if experiment_config.get('trial'):
parse_relative_path(root_path, experiment_config['trial'], 'codeDir')
if experiment_config['trial'].get('authFile'):
......
......@@ -30,7 +30,7 @@ def get_experiment_time(port):
'''get the startTime and endTime of an experiment'''
response = rest_get(experiment_url(port), REST_TIME_OUT)
if response and check_response(response):
content = convert_time_stamp_to_date(json.loads(response.text))
content = json.loads(response.text)
return content.get('startTime'), content.get('endTime')
return None, None
......@@ -50,20 +50,11 @@ def update_experiment():
for key in experiment_dict.keys():
if isinstance(experiment_dict[key], dict):
if experiment_dict[key].get('status') != 'STOPPED':
nni_config = Config(experiment_dict[key]['fileName'])
nni_config = Config(key)
rest_pid = nni_config.get_config('restServerPid')
if not detect_process(rest_pid):
experiment_config.update_experiment(key, 'status', 'STOPPED')
continue
rest_port = nni_config.get_config('restServerPort')
startTime, endTime = get_experiment_time(rest_port)
if startTime:
experiment_config.update_experiment(key, 'startTime', startTime)
if endTime:
experiment_config.update_experiment(key, 'endTime', endTime)
status = get_experiment_status(rest_port)
if status:
experiment_config.update_experiment(key, 'status', status)
def check_experiment_id(args, update=True):
'''check if the id is valid
......@@ -184,9 +175,7 @@ def get_config_filename(args):
if experiment_id is None:
print_error('Please set correct experiment id.')
exit(1)
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
return experiment_dict[experiment_id]['fileName']
return experiment_id
def get_experiment_port(args):
'''get the port of experiment'''
......@@ -228,11 +217,9 @@ def stop_experiment(args):
exit(1)
experiment_id_list = parse_ids(args)
if experiment_id_list:
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
for experiment_id in experiment_id_list:
print_normal('Stopping experiment %s' % experiment_id)
nni_config = Config(experiment_dict[experiment_id]['fileName'])
nni_config = Config(experiment_id)
rest_pid = nni_config.get_config('restServerPid')
if rest_pid:
kill_command(rest_pid)
......@@ -245,9 +232,6 @@ def stop_experiment(args):
print_error(exception)
nni_config.set_config('tensorboardPidList', [])
print_normal('Stop experiment success.')
experiment_config.update_experiment(experiment_id, 'status', 'STOPPED')
time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
experiment_config.update_experiment(experiment_id, 'endTime', str(time_now))
def trial_ls(args):
'''List trial'''
......@@ -553,7 +537,7 @@ def experiment_clean(args):
else:
break
for experiment_id in experiment_id_list:
nni_config = Config(experiment_dict[experiment_id]['fileName'])
nni_config = Config(experiment_id)
platform = nni_config.get_config('experimentConfig').get('trainingServicePlatform')
experiment_id = nni_config.get_config('experimentId')
if platform == 'remote':
......@@ -668,18 +652,15 @@ def experiment_list(args):
experiment_dict[key]['status'],
experiment_dict[key]['port'],
experiment_dict[key].get('platform'),
experiment_dict[key]['startTime'],
experiment_dict[key]['endTime'])
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
return experiment_id_list
def get_time_interval(time1, time2):
'''get the interval of two times'''
try:
#convert time to timestamp
time1 = time.mktime(time.strptime(time1, '%Y/%m/%d %H:%M:%S'))
time2 = time.mktime(time.strptime(time2, '%Y/%m/%d %H:%M:%S'))
seconds = (datetime.fromtimestamp(time2) - datetime.fromtimestamp(time1)).seconds
seconds = int((time2 - time1) / 1000)
#convert seconds to day:hour:minute:second
days = seconds / 86400
seconds %= 86400
......@@ -708,8 +689,8 @@ def show_experiment_info():
return
for key in experiment_id_list:
print(EXPERIMENT_MONITOR_INFO % (key, experiment_dict[key]['status'], experiment_dict[key]['port'], \
experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], \
get_time_interval(experiment_dict[key]['startTime'], experiment_dict[key]['endTime'])))
experiment_dict[key].get('platform'), time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'], \
get_time_interval(experiment_dict[key]['startTime'], experiment_dict[key]['endTime'])))
print(TRIAL_MONITOR_HEAD)
running, response = check_rest_server_quick(experiment_dict[key]['port'])
if running:
......@@ -850,7 +831,7 @@ def save_experiment(args):
print_error('Can only save stopped experiment!')
exit(1)
print_normal('Saving...')
nni_config = Config(experiment_dict[args.id]['fileName'])
nni_config = Config(args.id)
logDir = os.path.join(NNI_HOME_DIR, args.id)
if nni_config.get_config('logDir'):
logDir = os.path.join(nni_config.get_config('logDir'), args.id)
......@@ -873,8 +854,8 @@ def save_experiment(args):
except IOError:
print_error('Write file to %s failed!' % os.path.join(temp_nnictl_dir, '.experiment'))
exit(1)
nnictl_config_dir = os.path.join(NNICTL_HOME_DIR, experiment_dict[args.id]['fileName'])
shutil.copytree(nnictl_config_dir, os.path.join(temp_nnictl_dir, experiment_dict[args.id]['fileName']))
nnictl_config_dir = os.path.join(NNICTL_HOME_DIR, args.id)
shutil.copytree(nnictl_config_dir, os.path.join(temp_nnictl_dir, args.id))
# Step3. Copy code dir
if args.saveCodeDir:
......@@ -947,20 +928,20 @@ def load_experiment(args):
print_error('Invalid: experiment id already exist!')
shutil.rmtree(temp_root_dir)
exit(1)
if not os.path.exists(os.path.join(nnictl_temp_dir, experiment_metadata.get('fileName'))):
if not os.path.exists(os.path.join(nnictl_temp_dir, experiment_id)):
print_error('Invalid: experiment metadata does not exist!')
shutil.rmtree(temp_root_dir)
exit(1)
# Step2. Copy nnictl metadata
src_path = os.path.join(nnictl_temp_dir, experiment_metadata.get('fileName'))
dest_path = os.path.join(NNICTL_HOME_DIR, experiment_metadata.get('fileName'))
src_path = os.path.join(nnictl_temp_dir, experiment_id)
dest_path = os.path.join(NNICTL_HOME_DIR, experiment_id)
if os.path.exists(dest_path):
shutil.rmtree(dest_path)
shutil.copytree(src_path, dest_path)
# Step3. Copy experiment data
nni_config = Config(experiment_metadata.get('fileName'))
nni_config = Config(experiment_id)
nnictl_exp_config = nni_config.get_config('experimentConfig')
if args.logDir:
logDir = args.logDir
......@@ -1027,13 +1008,15 @@ def load_experiment(args):
experiment_config.add_experiment(experiment_id,
experiment_metadata.get('port'),
experiment_metadata.get('startTime'),
experiment_metadata.get('fileName'),
experiment_metadata.get('platform'),
experiment_metadata.get('experimentName'),
experiment_metadata.get('endTime'),
experiment_metadata.get('status'))
experiment_metadata.get('status'),
experiment_metadata.get('tag'),
experiment_metadata.get('pid'),
experiment_metadata.get('webUrl'),
experiment_metadata.get('logDir'))
print_normal('Load experiment %s succsss!' % experiment_id)
# Step6. Cleanup temp data
shutil.rmtree(temp_root_dir)
......@@ -11,7 +11,7 @@ from .config_utils import Config, Experiments
from .url_utils import trial_jobs_url, get_local_urls
from .constants import REST_TIME_OUT
from .common_utils import print_normal, print_warning, print_error, print_green, detect_process, detect_port, check_tensorboard_version
from .nnictl_utils import check_experiment_id, check_experiment_id
from .nnictl_utils import check_experiment_id
from .ssh_utils import create_ssh_sftp_client, copy_remote_directory_to_local
def parse_log_path(args, trial_content):
......@@ -95,8 +95,7 @@ def stop_tensorboard(args):
experiment_id = check_experiment_id(args)
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
config_file_name = experiment_dict[experiment_id]['fileName']
nni_config = Config(config_file_name)
nni_config = Config(experiment_id)
tensorboard_pid_list = nni_config.get_config('tensorboardPidList')
if tensorboard_pid_list:
for tensorboard_pid in tensorboard_pid_list:
......@@ -136,7 +135,7 @@ def start_tensorboard(args):
print_error("Experiment {} is stopped...".format(args.id))
return
config_file_name = experiment_dict[experiment_id]['fileName']
nni_config = Config(config_file_name)
nni_config = Config(args.id)
if nni_config.get_config('experimentConfig').get('trainingServicePlatform') == 'adl':
adl_tensorboard_helper(args)
return
......
......@@ -73,6 +73,7 @@ dependencies = [
'scikit-learn>=0.23.2',
'pkginfo',
'websockets',
'filelock',
'prettytable'
]
......
......@@ -11,9 +11,9 @@ from nni.tools.nnictl.nnictl_utils import get_yml_content
def create_mock_experiment():
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment('xOpEwA5w', '8080', '1970/01/1 01:01:01', 'aGew0x',
nnictl_experiment_config.add_experiment('xOpEwA5w', '8080', 123456,
'local', 'example_sklearn-classification')
nni_config = Config('aGew0x')
nni_config = Config('xOpEwA5w')
# mock process
cmds = ['sleep', '3600000']
process = Popen(cmds, stdout=PIPE, stderr=STDOUT)
......
......@@ -19,7 +19,7 @@ class CommonUtilsTestCase(TestCase):
def test_update_experiment(self):
experiment = Experiments(HOME_PATH)
experiment.add_experiment('xOpEwA5w', 8081, 'N/A', 'aGew0x', 'local', 'test', endTime='N/A', status='INITIALIZED')
experiment.add_experiment('xOpEwA5w', 8081, 'N/A', 'local', 'test', endTime='N/A', status='INITIALIZED')
self.assertTrue('xOpEwA5w' in experiment.get_all_experiments())
experiment.remove_experiment('xOpEwA5w')
self.assertFalse('xOpEwA5w' in experiment.get_all_experiments())
......
......@@ -46,7 +46,7 @@ class CommonUtilsTestCase(TestCase):
@responses.activate
def test_get_config_file_name(self):
args = generate_args()
self.assertEqual('aGew0x', get_config_filename(args))
self.assertEqual('xOpEwA5w', get_config_filename(args))
@responses.activate
def test_get_experiment_port(self):
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
/**
 * Abstract contract for a manager of experiment information.
 * All four operations are abstract and must be supplied by a concrete subclass.
 */
abstract class ExperimentManager {
// Resolves with the experiments' information as a JSON value.
public abstract getExperimentsInfo(): Promise<JSON>;
// Changes the path the manager uses for experiment information
// (presumably the `.experiment` file; confirm against the implementation).
public abstract setExperimentPath(newPath: string): void;
// Sets `key` to `value` for the experiment identified by `experimentId`.
public abstract setExperimentInfo(experimentId: string, key: string, value: any): void;
// Asynchronously stops the manager. NOTE(review): what is torn down is implementation-defined; confirm.
public abstract stop(): Promise<void>;
}
export {ExperimentManager};
......@@ -11,13 +11,16 @@ import { ChildProcess, spawn, StdioOptions } from 'child_process';
import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
import * as lockfile from 'lockfile';
import { Deferred } from 'ts-deferred';
import { Container } from 'typescript-ioc';
import * as util from 'util';
import * as glob from 'glob';
import { Database, DataStore } from './datastore';
import { ExperimentStartupInfo, getExperimentStartupInfo, setExperimentStartupInfo } from './experimentStartupInfo';
import { ExperimentParams, Manager } from './manager';
import { ExperimentManager } from './experimentManager';
import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService';
import { logLevelNameMap } from './log';
......@@ -43,6 +46,10 @@ function getCheckpointDir(): string {
return path.join(getExperimentRootDir(), 'checkpoint');
}
/**
 * Location of the experiments information file:
 * `<home directory>/nni-experiments/.experiment`.
 */
function getExperimentsInfoPath(): string {
    const experimentsRoot: string = path.join(os.homedir(), 'nni-experiments');
    return path.join(experimentsRoot, '.experiment');
}
function mkDirP(dirPath: string): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>();
fs.exists(dirPath, (exists: boolean) => {
......@@ -184,6 +191,7 @@ function prepareUnitTest(): void {
Container.snapshot(DataStore);
Container.snapshot(TrainingService);
Container.snapshot(Manager);
Container.snapshot(ExperimentManager);
const logLevel: string = parseArg(['--log_level', '-ll']);
if (logLevel.length > 0 && !logLevelNameMap.has(logLevel)) {
......@@ -211,6 +219,7 @@ function cleanupUnitTest(): void {
Container.restore(DataStore);
Container.restore(Database);
Container.restore(ExperimentStartupInfo);
Container.restore(ExperimentManager);
}
let cachedipv4Address: string = '';
......@@ -416,8 +425,29 @@ function unixPathJoin(...paths: any[]): string {
return dir;
}
/**
 * Run `func(...args)` synchronously while holding a per-process lock file for `filePath`.
 *
 * The lock file is `<filePath>.lock.<pid>`. When `lockOpts.stale` is a number, every
 * existing `<filePath>.lock.*` file must be older than `lockOpts.stale` milliseconds
 * (by mtime), otherwise an Error('File has been locked.') is thrown.
 * The lock is always released, even when `func` throws.
 *
 * @param func     callback executed while the lock is held
 * @param filePath file the lock protects
 * @param lockOpts options forwarded to `lockfile.lockSync`; `stale` is also used for the check above
 * @param args     arguments forwarded to `func`
 * @returns whatever `func` returns
 */
function withLockSync(func: Function, filePath: string, lockOpts: {[key: string]: any}, ...args: any): any {
    const lockDir: string = path.dirname(filePath);
    const lockBase: string = path.basename(filePath);
    const lockName = path.join(lockDir, lockBase + `.lock.${process.pid}`);
    if (typeof lockOpts.stale === 'number'){
        const lockFileNames: string[] = glob.sync(path.join(lockDir, lockBase + '.lock.*'));
        // NOTE(review): a lock file that disappears between glob and existsSync counts as
        // "not expired", i.e. still held; kept as in the original.
        const canLock: boolean = lockFileNames.map((fileName) => {
            return fs.existsSync(fileName) && Date.now() - fs.statSync(fileName).mtimeMs > lockOpts.stale;
        }).filter(isExpired=>isExpired === false).length === 0;
        if (!canLock) {
            throw new Error('File has been locked.');
        }
    }
    lockfile.lockSync(lockName, lockOpts);
    try {
        return func(...args);
    } finally {
        // release in `finally` so an exception in `func` cannot leave the lock file behind
        lockfile.unlockSync(lockName);
    }
}
export {
countFilesRecursively, validateFileNameRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir,
getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin,
countFilesRecursively, validateFileNameRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir, getExperimentsInfoPath,
getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, withLockSync,
mkDirP, mkDirPSync, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomInt, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine
};
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
import * as assert from 'assert';
import { getLogger, Logger } from '../common/log';
import { isAlive, withLockSync, getExperimentsInfoPath, delay } from '../common/utils';
import { ExperimentManager } from '../common/experimentManager';
import { Deferred } from 'ts-deferred';
// Result of a crash check for one experiment.
interface CrashedInfo {
// id of the experiment that was checked
experimentId: string;
// getExperimentsInfo keeps only the entries where this is true and updates their status
isCrashed: boolean;
}
// Raw contents plus a modification time, as returned by the locked read in getExperimentsInfo.
interface FileInfo {
// raw bytes; getExperimentsInfo parses buffer.toString() as JSON
buffer: Buffer;
// NOTE(review): presumably the file's mtime at read time (units unconfirmed); forwarded to updateAllStatus
mtime: number;
}
class NNIExperimentsManager implements ExperimentManager {
private experimentsPath: string;
private log: Logger;
private profileUpdateTimer: {[key: string]: any};
// Start from the default experiments-info path, the logger returned by getLogger(),
// and an empty table of timer handles (entries are cleared with clearTimeout in setExperimentInfo).
constructor() {
this.experimentsPath = getExperimentsInfoPath();
this.log = getLogger();
this.profileUpdateTimer = {};
}
/**
 * Read the experiments information, mark experiments found crashed, and
 * resolve with the list of experiment records.
 *
 * Only experiments whose status is not 'STOPPED' are checked. When any of them
 * is reported crashed, their ids are handed to `updateAllStatus`; if that call
 * yields `undefined`, the whole operation is retried after 500 ms.
 */
public async getExperimentsInfo(): Promise<JSON> {
    // flatten an id -> record table into a JSON array of records
    const toRecordList = (table: {[key: string]: any}): JSON => {
        const records = Object.keys(table).map(key=>table[key]);
        return JSON.parse(JSON.stringify(records));
    };

    const fileInfo: FileInfo = await this.withLockIterated(this.readExperimentsInfo, 100);
    const experimentsInformation = JSON.parse(fileInfo.buffer.toString());

    const activeIds: Array<string> = Object.keys(experimentsInformation).filter((expId) => {
        return experimentsInformation[expId]['status'] !== 'STOPPED';
    });
    const checks: Array<CrashedInfo> = await Promise.all(activeIds.map((expId) => {
        return this.checkCrashed(expId, experimentsInformation[expId]['pid']);
    }));
    const crashed: Array<CrashedInfo> = checks.filter(crashedInfo => crashedInfo.isCrashed);

    if (crashed.length === 0) {
        return toRecordList(experimentsInformation);
    }

    const crashedIds = crashed.map(crashedInfo => crashedInfo.experimentId);
    const result = await this.withLockIterated(this.updateAllStatus, 100, crashedIds, fileInfo.mtime);
    if (result !== undefined) {
        return toRecordList(result);
    }
    // no updated table came back: wait briefly and start over
    await delay(500);
    return await this.getExperimentsInfo();
}
public setExperimentPath(newPath: string): void {
if (newPath[0] === '~') {
newPath = path.join(os.homedir(), newPath.slice(1));
}
if (!path.isAbsolute(newPath)) {
newPath = path.resolve(newPath);
}
this.log.info(`Set new experiment information path: ${newPath}`);
this.experimentsPath = newPath;
}
public setExperimentInfo(experimentId: string, key: string, value: any): void {
try {
if (this.profileUpdateTimer[key] !== undefined) {
// if a new call with the same timerId occurs, destroy the unfinished old one
clearTimeout(this.profileUpdateTimer[key]);
this.profileUpdateTimer[key] = undefined;
}
this.withLockSync(() => {
const experimentsInformation = JSON.parse(fs.readFileSync(this.experimentsPath).toString());
assert(experimentId in experimentsInformation, `Experiment Manager: Experiment Id ${experimentId} not found, this should not happen`);
experimentsInformation[experimentId][key] = value;
fs.writeFileSync(this.experimentsPath, JSON.stringify(experimentsInformation, null, 4));
});
} catch (err) {
this.log.error(err);
this.log.debug(`Experiment Manager: Retry set key value: ${experimentId} {${key}: ${value}}`);
if (err.code === 'EEXIST' || err.message === 'File has been locked.') {
this.profileUpdateTimer[key] = setTimeout(this.setExperimentInfo.bind(this), 100, experimentId, key, value);
}
}
}
private async withLockIterated (func: Function, retry: number, ...args: any): Promise<any> {
if (retry < 0) {
throw new Error('Lock file out of retries.');
}
try {
return this.withLockSync(func, ...args);
} catch(err) {
if (err.code === 'EEXIST' || err.message === 'File has been locked.') {
// retry wait is 50ms
await delay(50);
return await this.withLockIterated(func, retry - 1, ...args);
}
throw err;
}
}
private withLockSync (func: Function, ...args: any): any {
return withLockSync(func.bind(this), this.experimentsPath, {stale: 2 * 1000}, ...args);
}
private readExperimentsInfo(): FileInfo {
const buffer: Buffer = fs.readFileSync(this.experimentsPath);
const mtime: number = fs.statSync(this.experimentsPath).mtimeMs;
return {buffer: buffer, mtime: mtime};
}
private async checkCrashed(expId: string, pid: number): Promise<CrashedInfo> {
const alive: boolean = await isAlive(pid);
return {experimentId: expId, isCrashed: !alive}
}
private updateAllStatus(updateList: Array<string>, timestamp: number): {[key: string]: any} | undefined {
if (timestamp !== fs.statSync(this.experimentsPath).mtimeMs) {
return;
} else {
const experimentsInformation = JSON.parse(fs.readFileSync(this.experimentsPath).toString());
updateList.forEach((expId: string) => {
if (experimentsInformation[expId]) {
experimentsInformation[expId]['status'] = 'STOPPED';
} else {
this.log.error(`Experiment Manager: Experiment Id ${expId} not found, this should not happen`);
}
});
fs.writeFileSync(this.experimentsPath, JSON.stringify(experimentsInformation, null, 4));
return experimentsInformation;
}
}
public async stop(): Promise<void> {
this.log.debug('Stopping experiment manager.');
await this.cleanUp().catch(err=>this.log.error(err.message));
this.log.debug('Experiment manager stopped.');
}
private async cleanUp(): Promise<void> {
const deferred = new Deferred<void>();
if (this.isUndone()) {
this.log.debug('Experiment manager: something undone');
setTimeout(((deferred: Deferred<void>): void => {
if (this.isUndone()) {
deferred.reject(new Error('Still has undone after 5s, forced stop.'));
} else {
deferred.resolve();
}
}).bind(this), 5 * 1000, deferred);
} else {
this.log.debug('Experiment manager: all clean up');
deferred.resolve();
}
return deferred.promise;
}
private isUndone(): boolean {
return Object.keys(this.profileUpdateTimer).filter((key: string) => {
return this.profileUpdateTimer[key] !== undefined;
}).length > 0;
}
}
export { NNIExperimentsManager };
......@@ -15,6 +15,7 @@ import {
ExperimentParams, ExperimentProfile, Manager, ExperimentStatus,
NNIManagerStatus, ProfileUpdateType, TrialJobStatistics
} from '../common/manager';
import { ExperimentManager } from '../common/experimentManager';
import {
TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus, LogType
} from '../common/trainingService';
......@@ -31,6 +32,7 @@ import { createDispatcherInterface, IpcInterface } from './ipcInterface';
class NNIManager implements Manager {
private trainingService: TrainingService;
private dispatcher: IpcInterface | undefined;
private experimentManager: ExperimentManager;
private currSubmittedTrialNum: number; // need to be recovered
private trialConcurrencyChange: number; // >0: increase, <0: decrease
private log: Logger;
......@@ -49,6 +51,7 @@ class NNIManager implements Manager {
this.currSubmittedTrialNum = 0;
this.trialConcurrencyChange = 0;
this.trainingService = component.get(TrainingService);
this.experimentManager = component.get(ExperimentManager);
assert(this.trainingService);
this.dispatcherPid = 0;
this.waitingTrials = [];
......@@ -467,7 +470,9 @@ class NNIManager implements Manager {
}
}
await this.trainingService.cleanUp();
this.experimentProfile.endTime = Date.now();
if (this.experimentProfile.endTime === undefined) {
this.setEndtime();
}
await this.storeExperimentProfile();
this.setStatus('STOPPED');
}
......@@ -596,7 +601,7 @@ class NNIManager implements Manager {
assert(allFinishedTrialJobNum <= waitSubmittedToFinish);
if (allFinishedTrialJobNum >= waitSubmittedToFinish) {
this.setStatus('DONE');
this.experimentProfile.endTime = Date.now();
this.setEndtime();
await this.storeExperimentProfile();
// write this log for travis CI
this.log.info('Experiment done.');
......@@ -796,6 +801,7 @@ class NNIManager implements Manager {
this.log.error(err.stack);
}
this.status.errors.push(err.message);
this.setEndtime();
this.setStatus('ERROR');
}
......@@ -803,9 +809,15 @@ class NNIManager implements Manager {
if (status !== this.status.status) {
this.log.info(`Change NNIManager status from: ${this.status.status} to: ${status}`);
this.status.status = status;
this.experimentManager.setExperimentInfo(this.experimentProfile.id, 'status', this.status.status);
}
}
private setEndtime(): void {
    // Stamp the profile with the current time and hand the same value to the
    // experiment manager under the 'endTime' key.
    const finishedAt: number = Date.now();
    this.experimentProfile.endTime = finishedAt;
    this.experimentManager.setExperimentInfo(this.experimentProfile.id, 'endTime', finishedAt);
}
private createEmptyExperimentProfile(): ExperimentProfile {
return {
id: getExperimentId(),
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { assert, expect } from 'chai';
import * as fs from 'fs';
import { Container, Scope } from 'typescript-ioc';
import * as component from '../../common/component';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { ExperimentManager } from '../../common/experimentManager';
import { NNIExperimentsManager } from '../nniExperimentsManager';
describe('Unit test for experiment manager', function () {
    let experimentManager: NNIExperimentsManager;
    // One fake experiment record. The 'id' field is required: getExperimentsInfo returns only
    // the record values (the top-level keys are dropped), so the test locates the record by it.
    // NOTE(review): the test relies on no live process having pid 11111 - confirm this holds on CI.
    const mockedInfo = {
        "test": {
            "id": "test",
            "port": 8080,
            "startTime": 1605246730756,
            "endTime": "N/A",
            "status": "INITIALIZED",
            "platform": "local",
            "experimentName": "testExp",
            "tag": [], "pid": 11111,
            "webuiUrl": [],
            "logDir": null
        }
    }

    before(() => {
        prepareUnitTest();
        fs.writeFileSync('.experiment.test', JSON.stringify(mockedInfo));
        Container.bind(ExperimentManager).to(NNIExperimentsManager).scope(Scope.Singleton);
        experimentManager = component.get(NNIExperimentsManager);
        experimentManager.setExperimentPath('.experiment.test');
    });

    after(() => {
        if (fs.existsSync('.experiment.test')) {
            fs.unlinkSync('.experiment.test');
        }
        cleanupUnitTest();
    });

    it('test getExperimentsInfo', () => {
        return experimentManager.getExperimentsInfo().then(function (experimentsInfo: {[key: string]: any}) {
            // Track whether the record was seen so the test cannot pass without asserting anything.
            let found = false;
            for (const idx in experimentsInfo) {
                if (experimentsInfo[idx]['id'] === 'test') {
                    found = true;
                    expect(experimentsInfo[idx]['status']).to.be.oneOf(['STOPPED', 'ERROR']);
                    break;
                }
            }
            assert.isTrue(found, 'experiment "test" is missing from getExperimentsInfo result');
        }).catch((error) => {
            assert.fail(error);
        })
    });
});
......@@ -3,6 +3,7 @@
'use strict';
import * as fs from 'fs';
import * as os from 'os';
import { assert, expect } from 'chai';
import { Container, Scope } from 'typescript-ioc';
......@@ -10,9 +11,10 @@ import { Container, Scope } from 'typescript-ioc';
import * as component from '../../common/component';
import { Database, DataStore } from '../../common/datastore';
import { Manager, ExperimentProfile} from '../../common/manager';
import { ExperimentManager } from '../../common/experimentManager';
import { TrainingService } from '../../common/trainingService';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { NNIDataStore } from '../nniDataStore';
import { NNIExperimentsManager } from '../nniExperimentsManager';
import { NNIManager } from '../nnimanager';
import { SqlDB } from '../sqlDatabase';
import { MockedTrainingService } from './mockedTrainingService';
......@@ -25,6 +27,7 @@ async function initContainer(): Promise<void> {
Container.bind(Manager).to(NNIManager).scope(Scope.Singleton);
Container.bind(Database).to(SqlDB).scope(Scope.Singleton);
Container.bind(DataStore).to(MockedDataStore).scope(Scope.Singleton);
Container.bind(ExperimentManager).to(NNIExperimentsManager).scope(Scope.Singleton);
await component.get<DataStore>(DataStore).init();
}
......@@ -87,9 +90,26 @@ describe('Unit test for nnimanager', function () {
revision: 0
}
let mockedInfo = {
"unittest": {
"port": 8080,
"startTime": 1605246730756,
"endTime": "N/A",
"status": "INITIALIZED",
"platform": "local",
"experimentName": "testExp",
"tag": [], "pid": 11111,
"webuiUrl": [],
"logDir": null
}
}
before(async () => {
await initContainer();
fs.writeFileSync('.experiment.test', JSON.stringify(mockedInfo));
const experimentsManager: ExperimentManager = component.get(ExperimentManager);
experimentsManager.setExperimentPath('.experiment.test');
nniManager = component.get(Manager);
const expId: string = await nniManager.startExperiment(experimentParams);
assert.strictEqual(expId, 'unittest');
......
......@@ -12,11 +12,13 @@ import { Database, DataStore } from './common/datastore';
import { setExperimentStartupInfo } from './common/experimentStartupInfo';
import { getLogger, Logger, logLevelNameMap } from './common/log';
import { Manager, ExperimentStartUpMode } from './common/manager';
import { ExperimentManager } from './common/experimentManager';
import { TrainingService } from './common/trainingService';
import { getLogDir, mkDirP, parseArg, uniqueString } from './common/utils';
import { getLogDir, mkDirP, parseArg } from './common/utils';
import { NNIDataStore } from './core/nniDataStore';
import { NNIManager } from './core/nnimanager';
import { SqlDB } from './core/sqlDatabase';
import { NNIExperimentsManager } from './core/nniExperimentsManager';
import { NNIRestServer } from './rest_server/nniRestServer';
import { FrameworkControllerTrainingService } from './training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService';
import { AdlTrainingService } from './training_service/kubernetes/adl/adlTrainingService';
......@@ -27,11 +29,10 @@ import { PAIYarnTrainingService } from './training_service/pai/paiYarn/paiYarnTr
import { DLTSTrainingService } from './training_service/dlts/dltsTrainingService';
function initStartupInfo(
startExpMode: string, resumeExperimentId: string, basePort: number, platform: string,
startExpMode: string, experimentId: string, basePort: number, platform: string,
logDirectory: string, experimentLogLevel: string, readonly: boolean): void {
const createNew: boolean = (startExpMode === ExperimentStartUpMode.NEW);
const expId: string = createNew ? uniqueString(8) : resumeExperimentId;
setExperimentStartupInfo(createNew, expId, basePort, platform, logDirectory, experimentLogLevel, readonly);
setExperimentStartupInfo(createNew, experimentId, basePort, platform, logDirectory, experimentLogLevel, readonly);
}
async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> {
......@@ -83,6 +84,9 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN
Container.bind(DataStore)
.to(NNIDataStore)
.scope(Scope.Singleton);
Container.bind(ExperimentManager)
.to(NNIExperimentsManager)
.scope(Scope.Singleton);
const DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log');
if (foreground) {
logFileName = undefined;
......@@ -133,7 +137,7 @@ if (![ExperimentStartUpMode.NEW, ExperimentStartUpMode.RESUME].includes(startMod
}
const experimentId: string = parseArg(['--experiment_id', '-id']);
if ((startMode === ExperimentStartUpMode.RESUME) && experimentId.trim().length < 1) {
if (experimentId.trim().length < 1) {
console.log(`FATAL: cannot resume the experiment, invalid experiment_id: ${experimentId}`);
usage();
process.exit(1);
......@@ -185,6 +189,8 @@ async function cleanUp(): Promise<void> {
try {
const nniManager: Manager = component.get(Manager);
await nniManager.stopExperiment();
const experimentManager: ExperimentManager = component.get(ExperimentManager);
await experimentManager.stop();
const ds: DataStore = component.get(DataStore);
await ds.close();
const restServer: NNIRestServer = component.get(NNIRestServer);
......
......@@ -18,6 +18,7 @@
"ignore": "^5.1.4",
"js-base64": "^2.4.9",
"kubernetes-client": "^6.5.0",
"lockfile": "^1.0.4",
"python-shell": "^2.0.1",
"rx": "^4.1.0",
"sqlite3": "^5.0.0",
......@@ -39,6 +40,7 @@
"@types/glob": "^7.1.1",
"@types/js-base64": "^2.3.1",
"@types/js-yaml": "^3.12.5",
"@types/lockfile": "^1.0.0",
"@types/mocha": "^8.0.3",
"@types/node": "10.12.18",
"@types/request": "^2.47.1",
......
......@@ -12,20 +12,22 @@ import { NNIError, NNIErrorNames } from '../common/errors';
import { isNewExperiment, isReadonly } from '../common/experimentStartupInfo';
import { getLogger, Logger } from '../common/log';
import { ExperimentProfile, Manager, TrialJobStatistics } from '../common/manager';
import { ExperimentManager } from '../common/experimentManager';
import { ValidationSchemas } from './restValidationSchemas';
import { NNIRestServer } from './nniRestServer';
import { getVersion } from '../common/utils';
import { NNIManager } from "../core/nnimanager";
const expressJoi = require('express-joi-validator');
class NNIRestHandler {
private restServer: NNIRestServer;
private nniManager: NNIManager;
private nniManager: Manager;
private experimentsManager: ExperimentManager;
private log: Logger;
constructor(rs: NNIRestServer) {
this.nniManager = component.get(Manager);
this.experimentsManager = component.get(ExperimentManager);
this.restServer = rs;
this.log = getLogger();
}
......@@ -61,6 +63,7 @@ class NNIRestHandler {
this.getLatestMetricData(router);
this.getTrialLog(router);
this.exportData(router);
this.getExperimentsInfo(router);
// Express-joi-validator configuration
router.use((err: any, _req: Request, res: Response, _next: any) => {
......@@ -306,6 +309,16 @@ class NNIRestHandler {
});
}
private getExperimentsInfo(router: Router): void {
    // GET /experiments-info: reply with the experiments information serialized as a JSON
    // string; failures are delegated to handleError.
    router.get('/experiments-info', (_request: Request, response: Response) => {
        const reply = (info: JSON): void => {
            response.send(JSON.stringify(info));
        };
        const fail = (error: Error): void => {
            this.handleError(error, response);
        };
        this.experimentsManager.getExperimentsInfo().then(reply).catch(fail);
    });
}
private setErrorPathForFailedJob(jobInfo: TrialJobInfo): TrialJobInfo {
if (jobInfo === undefined || jobInfo.status !== 'FAILED' || jobInfo.logPath === undefined) {
return jobInfo;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment