Unverified Commit 95f731e4 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

experiment management backend (#3081)



* step 1 nnictl generate experimentId & merge folder

* step 2.1 modify .experiment structure

* step 2.2 add lock for .experiment rw in nnictl

* step 2.2 add filelock dependence

* step 2.2 remove uniqueString from main.js

* fix test bug

* fix test bug

* setp 3.1 add experiment manager

* step 3.2 add getExperimentsInfo

* fix eslint

* add a simple file lock to support stale

* step 3.3 add test

* divide abs experiment manager from manager

* experiment manager refactor

* support .experiment sync update status

* nnictl no longer uses rest api to update status or endtime

* nnictl no longer uses rest api to update status or endtime

* fix eslint

* support .experiment sync update endtime

* fix test

* fix settimeout bug

* fix test

* adjust experiment endTime

* separate simple file lock class

* modify name

* add 'id' in .experiment

* update rest api format

* fix eslint

* fix issue in comments

* fix rest api format

* add indent in json in experiments manager

* fix unittest

* fix unittest

* refector file lock

* fix eslint

* remove '__enter__' in filelock

* filelock support never expire
Co-authored-by: default avatarNing Shang <nishang@microsoft.com>
parent fc0ff8ce
...@@ -5,11 +5,14 @@ import os ...@@ -5,11 +5,14 @@ import os
import sys import sys
import json import json
import tempfile import tempfile
import time
import socket import socket
import string import string
import random import random
import ruamel.yaml as yaml import ruamel.yaml as yaml
import psutil import psutil
import filelock
import glob
from colorama import Fore from colorama import Fore
from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO from .constants import ERROR_INFO, NORMAL_INFO, WARNING_INFO
...@@ -95,3 +98,36 @@ def generate_temp_dir(): ...@@ -95,3 +98,36 @@ def generate_temp_dir():
temp_dir = generate_folder_name() temp_dir = generate_folder_name()
os.makedirs(temp_dir) os.makedirs(temp_dir)
return temp_dir return temp_dir
class SimplePreemptiveLock(filelock.SoftFileLock):
'''this is a lock support check lock expiration, if you do not need check expiration, you can use SoftFileLock'''
def __init__(self, lock_file, stale=-1):
super(__class__, self).__init__(lock_file, timeout=-1)
self._lock_file_name = '{}.{}'.format(self._lock_file, os.getpid())
self._stale = stale
def _acquire(self):
open_mode = os.O_WRONLY | os.O_CREAT | os.O_EXCL | os.O_TRUNC
try:
lock_file_names = glob.glob(self._lock_file + '.*')
for file_name in lock_file_names:
if os.path.exists(file_name) and (self._stale < 0 or time.time() - os.stat(file_name).st_mtime < self._stale):
return None
fd = os.open(self._lock_file_name, open_mode)
except (IOError, OSError):
pass
else:
self._lock_file_fd = fd
return None
def _release(self):
os.close(self._lock_file_fd)
self._lock_file_fd = None
try:
os.remove(self._lock_file_name)
except OSError:
pass
return None
def get_file_lock(path: string, stale=-1):
return SimplePreemptiveLock(path + '.lock', stale=-1)
...@@ -4,8 +4,10 @@ ...@@ -4,8 +4,10 @@
import os import os
import json import json
import shutil import shutil
import time
from .constants import NNICTL_HOME_DIR from .constants import NNICTL_HOME_DIR
from .command_utils import print_error from .command_utils import print_error
from .common_utils import get_file_lock
class Config: class Config:
'''a util class to load and save config''' '''a util class to load and save config'''
...@@ -34,7 +36,7 @@ class Config: ...@@ -34,7 +36,7 @@ class Config:
if self.config: if self.config:
try: try:
with open(self.config_file, 'w') as file: with open(self.config_file, 'w') as file:
json.dump(self.config, file) json.dump(self.config, file, indent=4)
except IOError as error: except IOError as error:
print('Error:', error) print('Error:', error)
return return
...@@ -54,22 +56,33 @@ class Experiments: ...@@ -54,22 +56,33 @@ class Experiments:
def __init__(self, home_dir=NNICTL_HOME_DIR): def __init__(self, home_dir=NNICTL_HOME_DIR):
os.makedirs(home_dir, exist_ok=True) os.makedirs(home_dir, exist_ok=True)
self.experiment_file = os.path.join(home_dir, '.experiment') self.experiment_file = os.path.join(home_dir, '.experiment')
self.lock = get_file_lock(self.experiment_file, stale=2)
with self.lock:
self.experiments = self.read_file() self.experiments = self.read_file()
def add_experiment(self, expId, port, startTime, file_name, platform, experiment_name, endTime='N/A', status='INITIALIZED'): def add_experiment(self, expId, port, startTime, platform, experiment_name, endTime='N/A', status='INITIALIZED',
'''set {key:value} paris to self.experiment''' tag=[], pid=None, webuiUrl=[], logDir=[]):
'''set {key:value} pairs to self.experiment'''
with self.lock:
self.experiments = self.read_file()
self.experiments[expId] = {} self.experiments[expId] = {}
self.experiments[expId]['id'] = expId
self.experiments[expId]['port'] = port self.experiments[expId]['port'] = port
self.experiments[expId]['startTime'] = startTime self.experiments[expId]['startTime'] = startTime
self.experiments[expId]['endTime'] = endTime self.experiments[expId]['endTime'] = endTime
self.experiments[expId]['status'] = status self.experiments[expId]['status'] = status
self.experiments[expId]['fileName'] = file_name
self.experiments[expId]['platform'] = platform self.experiments[expId]['platform'] = platform
self.experiments[expId]['experimentName'] = experiment_name self.experiments[expId]['experimentName'] = experiment_name
self.experiments[expId]['tag'] = tag
self.experiments[expId]['pid'] = pid
self.experiments[expId]['webuiUrl'] = webuiUrl
self.experiments[expId]['logDir'] = logDir
self.write_file() self.write_file()
def update_experiment(self, expId, key, value): def update_experiment(self, expId, key, value):
'''Update experiment''' '''Update experiment'''
with self.lock:
self.experiments = self.read_file()
if expId not in self.experiments: if expId not in self.experiments:
return False return False
self.experiments[expId][key] = value self.experiments[expId][key] = value
...@@ -78,8 +91,11 @@ class Experiments: ...@@ -78,8 +91,11 @@ class Experiments:
def remove_experiment(self, expId): def remove_experiment(self, expId):
'''remove an experiment by id''' '''remove an experiment by id'''
with self.lock:
self.experiments = self.read_file()
if expId in self.experiments: if expId in self.experiments:
fileName = self.experiments.pop(expId).get('fileName') self.experiments.pop(expId)
fileName = expId
if fileName: if fileName:
logPath = os.path.join(NNICTL_HOME_DIR, fileName) logPath = os.path.join(NNICTL_HOME_DIR, fileName)
try: try:
...@@ -96,7 +112,7 @@ class Experiments: ...@@ -96,7 +112,7 @@ class Experiments:
'''save config to local file''' '''save config to local file'''
try: try:
with open(self.experiment_file, 'w') as file: with open(self.experiment_file, 'w') as file:
json.dump(self.experiments, file) json.dump(self.experiments, file, indent=4)
except IOError as error: except IOError as error:
print('Error:', error) print('Error:', error)
return '' return ''
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
import os import os
from colorama import Fore from colorama import Fore
NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), '.local', 'nnictl') NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments')
NNI_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments') NNI_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments')
......
...@@ -23,10 +23,11 @@ from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SU ...@@ -23,10 +23,11 @@ from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SU
from .command_utils import check_output_command, kill_command from .command_utils import check_output_command, kill_command
from .nnictl_utils import update_experiment from .nnictl_utils import update_experiment
def get_log_path(config_file_name): def get_log_path(experiment_id):
'''generate stdout and stderr log path''' '''generate stdout and stderr log path'''
stdout_full_path = os.path.join(NNICTL_HOME_DIR, config_file_name, 'stdout') os.makedirs(os.path.join(NNICTL_HOME_DIR, experiment_id, 'log'), exist_ok=True)
stderr_full_path = os.path.join(NNICTL_HOME_DIR, config_file_name, 'stderr') stdout_full_path = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log', 'nnictl_stdout.log')
stderr_full_path = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log', 'nnictl_stderr.log')
return stdout_full_path, stderr_full_path return stdout_full_path, stderr_full_path
def print_log_content(config_file_name): def print_log_content(config_file_name):
...@@ -38,7 +39,7 @@ def print_log_content(config_file_name): ...@@ -38,7 +39,7 @@ def print_log_content(config_file_name):
print_normal(' Stderr:') print_normal(' Stderr:')
print(check_output_command(stderr_full_path)) print(check_output_command(stderr_full_path))
def start_rest_server(port, platform, mode, config_file_name, foreground=False, experiment_id=None, log_dir=None, log_level=None): def start_rest_server(port, platform, mode, experiment_id, foreground=False, log_dir=None, log_level=None):
'''Run nni manager process''' '''Run nni manager process'''
if detect_port(port): if detect_port(port):
print_error('Port %s is used by another process, please reset the port!\n' \ print_error('Port %s is used by another process, please reset the port!\n' \
...@@ -63,7 +64,8 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False, ...@@ -63,7 +64,8 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
node_command = os.path.join(entry_dir, 'node.exe') node_command = os.path.join(entry_dir, 'node.exe')
else: else:
node_command = os.path.join(entry_dir, 'node') node_command = os.path.join(entry_dir, 'node')
cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(port), '--mode', platform] cmds = [node_command, '--max-old-space-size=4096', entry_file, '--port', str(port), '--mode', platform, \
'--experiment_id', experiment_id]
if mode == 'view': if mode == 'view':
cmds += ['--start_mode', 'resume'] cmds += ['--start_mode', 'resume']
cmds += ['--readonly', 'true'] cmds += ['--readonly', 'true']
...@@ -73,13 +75,12 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False, ...@@ -73,13 +75,12 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
cmds += ['--log_dir', log_dir] cmds += ['--log_dir', log_dir]
if log_level is not None: if log_level is not None:
cmds += ['--log_level', log_level] cmds += ['--log_level', log_level]
if mode in ['resume', 'view']:
cmds += ['--experiment_id', experiment_id]
if foreground: if foreground:
cmds += ['--foreground', 'true'] cmds += ['--foreground', 'true']
stdout_full_path, stderr_full_path = get_log_path(config_file_name) stdout_full_path, stderr_full_path = get_log_path(experiment_id)
with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file: with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())) start_time = time.time()
time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))
#add time information in the header of log files #add time information in the header of log files
log_header = LOG_HEADER % str(time_now) log_header = LOG_HEADER % str(time_now)
stdout_file.write(log_header) stdout_file.write(log_header)
...@@ -95,7 +96,7 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False, ...@@ -95,7 +96,7 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=PIPE) process = Popen(cmds, cwd=entry_dir, stdout=PIPE, stderr=PIPE)
else: else:
process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file) process = Popen(cmds, cwd=entry_dir, stdout=stdout_file, stderr=stderr_file)
return process, str(time_now) return process, int(start_time * 1000)
def set_trial_config(experiment_config, port, config_file_name): def set_trial_config(experiment_config, port, config_file_name):
'''set trial configuration''' '''set trial configuration'''
...@@ -432,9 +433,9 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res ...@@ -432,9 +433,9 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
raise Exception(ERROR_INFO % 'Rest server stopped!') raise Exception(ERROR_INFO % 'Rest server stopped!')
exit(1) exit(1)
def launch_experiment(args, experiment_config, mode, config_file_name, experiment_id=None): def launch_experiment(args, experiment_config, mode, experiment_id):
'''follow steps to start rest server and start experiment''' '''follow steps to start rest server and start experiment'''
nni_config = Config(config_file_name) nni_config = Config(experiment_id)
# check packages for tuner # check packages for tuner
package_name, module_name = None, None package_name, module_name = None, None
if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'): if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
...@@ -445,15 +446,15 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -445,15 +446,15 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
module_name, _ = get_builtin_module_class_name('advisors', package_name) module_name, _ = get_builtin_module_class_name('advisors', package_name)
if package_name and module_name: if package_name and module_name:
try: try:
stdout_full_path, stderr_full_path = get_log_path(config_file_name) stdout_full_path, stderr_full_path = get_log_path(experiment_id)
with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file: with open(stdout_full_path, 'a+') as stdout_file, open(stderr_full_path, 'a+') as stderr_file:
check_call([sys.executable, '-c', 'import %s'%(module_name)], stdout=stdout_file, stderr=stderr_file) check_call([sys.executable, '-c', 'import %s'%(module_name)], stdout=stdout_file, stderr=stderr_file)
except CalledProcessError: except CalledProcessError:
print_error('some errors happen when import package %s.' %(package_name)) print_error('some errors happen when import package %s.' %(package_name))
print_log_content(config_file_name) print_log_content(experiment_id)
if package_name in INSTALLABLE_PACKAGE_META: if package_name in INSTALLABLE_PACKAGE_META:
print_error('If %s is not installed, it should be installed through '\ print_error('If %s is not installed, it should be installed through '\
'\'nnictl package install --name %s\''%(package_name, package_name)) '\'nnictl package install --name %s\'' % (package_name, package_name))
exit(1) exit(1)
log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None
log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None
...@@ -465,7 +466,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -465,7 +466,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
log_level = 'debug' log_level = 'debug'
# start rest server # start rest server
rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \ rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
mode, config_file_name, foreground, experiment_id, log_dir, log_level) mode, experiment_id, foreground, log_dir, log_level)
nni_config.set_config('restServerPid', rest_process.pid) nni_config.set_config('restServerPid', rest_process.pid)
# Deal with annotation # Deal with annotation
if experiment_config.get('useAnnotation'): if experiment_config.get('useAnnotation'):
...@@ -491,7 +492,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -491,7 +492,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
print_normal('Successfully started Restful server!') print_normal('Successfully started Restful server!')
else: else:
print_error('Restful server start failed!') print_error('Restful server start failed!')
print_log_content(config_file_name) print_log_content(experiment_id)
try: try:
kill_command(rest_process.pid) kill_command(rest_process.pid)
except Exception: except Exception:
...@@ -500,21 +501,25 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -500,21 +501,25 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
if mode != 'view': if mode != 'view':
# set platform configuration # set platform configuration
set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,\ set_platform_config(experiment_config['trainingServicePlatform'], experiment_config, args.port,\
config_file_name, rest_process) experiment_id, rest_process)
# start a new experiment # start a new experiment
print_normal('Starting experiment...') print_normal('Starting experiment...')
# save experiment information
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time,
experiment_config['trainingServicePlatform'],
experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
# set debug configuration # set debug configuration
if mode != 'view' and experiment_config.get('debug') is None: if mode != 'view' and experiment_config.get('debug') is None:
experiment_config['debug'] = args.debug experiment_config['debug'] = args.debug
response = set_experiment(experiment_config, mode, args.port, config_file_name) response = set_experiment(experiment_config, mode, args.port, experiment_id)
if response: if response:
if experiment_id is None: if experiment_id is None:
experiment_id = json.loads(response.text).get('experiment_id') experiment_id = json.loads(response.text).get('experiment_id')
nni_config.set_config('experimentId', experiment_id)
else: else:
print_error('Start experiment failed!') print_error('Start experiment failed!')
print_log_content(config_file_name) print_log_content(experiment_id)
try: try:
kill_command(rest_process.pid) kill_command(rest_process.pid)
except Exception: except Exception:
...@@ -526,12 +531,6 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -526,12 +531,6 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
web_ui_url_list = get_local_urls(args.port) web_ui_url_list = get_local_urls(args.port)
nni_config.set_config('webuiUrl', web_ui_url_list) nni_config.set_config('webuiUrl', web_ui_url_list)
# save experiment information
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time, config_file_name,
experiment_config['trainingServicePlatform'],
experiment_config['experimentName'])
print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list))) print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
if mode != 'view' and args.foreground: if mode != 'view' and args.foreground:
try: try:
...@@ -544,8 +543,9 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -544,8 +543,9 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
def create_experiment(args): def create_experiment(args):
'''start a new experiment''' '''start a new experiment'''
config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8)) experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8))
nni_config = Config(config_file_name) nni_config = Config(experiment_id)
nni_config.set_config('experimentId', experiment_id)
config_path = os.path.abspath(args.config) config_path = os.path.abspath(args.config)
if not os.path.exists(config_path): if not os.path.exists(config_path):
print_error('Please set correct config path!') print_error('Please set correct config path!')
...@@ -560,9 +560,9 @@ def create_experiment(args): ...@@ -560,9 +560,9 @@ def create_experiment(args):
nni_config.set_config('experimentConfig', experiment_config) nni_config.set_config('experimentConfig', experiment_config)
nni_config.set_config('restServerPort', args.port) nni_config.set_config('restServerPort', args.port)
try: try:
launch_experiment(args, experiment_config, 'new', config_file_name) launch_experiment(args, experiment_config, 'new', experiment_id)
except Exception as exception: except Exception as exception:
nni_config = Config(config_file_name) nni_config = Config(experiment_id)
restServerPid = nni_config.get_config('restServerPid') restServerPid = nni_config.get_config('restServerPid')
if restServerPid: if restServerPid:
kill_command(restServerPid) kill_command(restServerPid)
...@@ -589,17 +589,13 @@ def manage_stopped_experiment(args, mode): ...@@ -589,17 +589,13 @@ def manage_stopped_experiment(args, mode):
exit(1) exit(1)
experiment_id = args.id experiment_id = args.id
print_normal('{0} experiment {1}...'.format(mode, experiment_id)) print_normal('{0} experiment {1}...'.format(mode, experiment_id))
nni_config = Config(experiment_dict[experiment_id]['fileName']) nni_config = Config(experiment_id)
experiment_config = nni_config.get_config('experimentConfig') experiment_config = nni_config.get_config('experimentConfig')
experiment_id = nni_config.get_config('experimentId') nni_config.set_config('restServerPort', args.port)
new_config_file_name = ''.join(random.sample(string.ascii_letters + string.digits, 8))
new_nni_config = Config(new_config_file_name)
new_nni_config.set_config('experimentConfig', experiment_config)
new_nni_config.set_config('restServerPort', args.port)
try: try:
launch_experiment(args, experiment_config, mode, new_config_file_name, experiment_id) launch_experiment(args, experiment_config, mode, experiment_id)
except Exception as exception: except Exception as exception:
nni_config = Config(new_config_file_name) nni_config = Config(experiment_id)
restServerPid = nni_config.get_config('restServerPid') restServerPid = nni_config.get_config('restServerPid')
if restServerPid: if restServerPid:
kill_command(restServerPid) kill_command(restServerPid)
......
...@@ -32,6 +32,8 @@ def parse_time(time): ...@@ -32,6 +32,8 @@ def parse_time(time):
def parse_path(experiment_config, config_path): def parse_path(experiment_config, config_path):
'''Parse path in config file''' '''Parse path in config file'''
expand_path(experiment_config, 'searchSpacePath') expand_path(experiment_config, 'searchSpacePath')
if experiment_config.get('logDir'):
expand_path(experiment_config, 'logDir')
if experiment_config.get('trial'): if experiment_config.get('trial'):
expand_path(experiment_config['trial'], 'codeDir') expand_path(experiment_config['trial'], 'codeDir')
if experiment_config['trial'].get('authFile'): if experiment_config['trial'].get('authFile'):
...@@ -65,6 +67,8 @@ def parse_path(experiment_config, config_path): ...@@ -65,6 +67,8 @@ def parse_path(experiment_config, config_path):
root_path = os.path.dirname(config_path) root_path = os.path.dirname(config_path)
if experiment_config.get('searchSpacePath'): if experiment_config.get('searchSpacePath'):
parse_relative_path(root_path, experiment_config, 'searchSpacePath') parse_relative_path(root_path, experiment_config, 'searchSpacePath')
if experiment_config.get('logDir'):
parse_relative_path(root_path, experiment_config, 'logDir')
if experiment_config.get('trial'): if experiment_config.get('trial'):
parse_relative_path(root_path, experiment_config['trial'], 'codeDir') parse_relative_path(root_path, experiment_config['trial'], 'codeDir')
if experiment_config['trial'].get('authFile'): if experiment_config['trial'].get('authFile'):
......
...@@ -30,7 +30,7 @@ def get_experiment_time(port): ...@@ -30,7 +30,7 @@ def get_experiment_time(port):
'''get the startTime and endTime of an experiment''' '''get the startTime and endTime of an experiment'''
response = rest_get(experiment_url(port), REST_TIME_OUT) response = rest_get(experiment_url(port), REST_TIME_OUT)
if response and check_response(response): if response and check_response(response):
content = convert_time_stamp_to_date(json.loads(response.text)) content = json.loads(response.text)
return content.get('startTime'), content.get('endTime') return content.get('startTime'), content.get('endTime')
return None, None return None, None
...@@ -50,20 +50,11 @@ def update_experiment(): ...@@ -50,20 +50,11 @@ def update_experiment():
for key in experiment_dict.keys(): for key in experiment_dict.keys():
if isinstance(experiment_dict[key], dict): if isinstance(experiment_dict[key], dict):
if experiment_dict[key].get('status') != 'STOPPED': if experiment_dict[key].get('status') != 'STOPPED':
nni_config = Config(experiment_dict[key]['fileName']) nni_config = Config(key)
rest_pid = nni_config.get_config('restServerPid') rest_pid = nni_config.get_config('restServerPid')
if not detect_process(rest_pid): if not detect_process(rest_pid):
experiment_config.update_experiment(key, 'status', 'STOPPED') experiment_config.update_experiment(key, 'status', 'STOPPED')
continue continue
rest_port = nni_config.get_config('restServerPort')
startTime, endTime = get_experiment_time(rest_port)
if startTime:
experiment_config.update_experiment(key, 'startTime', startTime)
if endTime:
experiment_config.update_experiment(key, 'endTime', endTime)
status = get_experiment_status(rest_port)
if status:
experiment_config.update_experiment(key, 'status', status)
def check_experiment_id(args, update=True): def check_experiment_id(args, update=True):
'''check if the id is valid '''check if the id is valid
...@@ -184,9 +175,7 @@ def get_config_filename(args): ...@@ -184,9 +175,7 @@ def get_config_filename(args):
if experiment_id is None: if experiment_id is None:
print_error('Please set correct experiment id.') print_error('Please set correct experiment id.')
exit(1) exit(1)
experiment_config = Experiments() return experiment_id
experiment_dict = experiment_config.get_all_experiments()
return experiment_dict[experiment_id]['fileName']
def get_experiment_port(args): def get_experiment_port(args):
'''get the port of experiment''' '''get the port of experiment'''
...@@ -228,11 +217,9 @@ def stop_experiment(args): ...@@ -228,11 +217,9 @@ def stop_experiment(args):
exit(1) exit(1)
experiment_id_list = parse_ids(args) experiment_id_list = parse_ids(args)
if experiment_id_list: if experiment_id_list:
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
for experiment_id in experiment_id_list: for experiment_id in experiment_id_list:
print_normal('Stopping experiment %s' % experiment_id) print_normal('Stopping experiment %s' % experiment_id)
nni_config = Config(experiment_dict[experiment_id]['fileName']) nni_config = Config(experiment_id)
rest_pid = nni_config.get_config('restServerPid') rest_pid = nni_config.get_config('restServerPid')
if rest_pid: if rest_pid:
kill_command(rest_pid) kill_command(rest_pid)
...@@ -245,9 +232,6 @@ def stop_experiment(args): ...@@ -245,9 +232,6 @@ def stop_experiment(args):
print_error(exception) print_error(exception)
nni_config.set_config('tensorboardPidList', []) nni_config.set_config('tensorboardPidList', [])
print_normal('Stop experiment success.') print_normal('Stop experiment success.')
experiment_config.update_experiment(experiment_id, 'status', 'STOPPED')
time_now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time()))
experiment_config.update_experiment(experiment_id, 'endTime', str(time_now))
def trial_ls(args): def trial_ls(args):
'''List trial''' '''List trial'''
...@@ -553,7 +537,7 @@ def experiment_clean(args): ...@@ -553,7 +537,7 @@ def experiment_clean(args):
else: else:
break break
for experiment_id in experiment_id_list: for experiment_id in experiment_id_list:
nni_config = Config(experiment_dict[experiment_id]['fileName']) nni_config = Config(experiment_id)
platform = nni_config.get_config('experimentConfig').get('trainingServicePlatform') platform = nni_config.get_config('experimentConfig').get('trainingServicePlatform')
experiment_id = nni_config.get_config('experimentId') experiment_id = nni_config.get_config('experimentId')
if platform == 'remote': if platform == 'remote':
...@@ -668,18 +652,15 @@ def experiment_list(args): ...@@ -668,18 +652,15 @@ def experiment_list(args):
experiment_dict[key]['status'], experiment_dict[key]['status'],
experiment_dict[key]['port'], experiment_dict[key]['port'],
experiment_dict[key].get('platform'), experiment_dict[key].get('platform'),
experiment_dict[key]['startTime'], time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
experiment_dict[key]['endTime']) time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information) print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
return experiment_id_list return experiment_id_list
def get_time_interval(time1, time2): def get_time_interval(time1, time2):
'''get the interval of two times''' '''get the interval of two times'''
try: try:
#convert time to timestamp seconds = int((time2 - time1) / 1000)
time1 = time.mktime(time.strptime(time1, '%Y/%m/%d %H:%M:%S'))
time2 = time.mktime(time.strptime(time2, '%Y/%m/%d %H:%M:%S'))
seconds = (datetime.fromtimestamp(time2) - datetime.fromtimestamp(time1)).seconds
#convert seconds to day:hour:minute:second #convert seconds to day:hour:minute:second
days = seconds / 86400 days = seconds / 86400
seconds %= 86400 seconds %= 86400
...@@ -708,7 +689,7 @@ def show_experiment_info(): ...@@ -708,7 +689,7 @@ def show_experiment_info():
return return
for key in experiment_id_list: for key in experiment_id_list:
print(EXPERIMENT_MONITOR_INFO % (key, experiment_dict[key]['status'], experiment_dict[key]['port'], \ print(EXPERIMENT_MONITOR_INFO % (key, experiment_dict[key]['status'], experiment_dict[key]['port'], \
experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], \ experiment_dict[key].get('platform'), time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'], \
get_time_interval(experiment_dict[key]['startTime'], experiment_dict[key]['endTime']))) get_time_interval(experiment_dict[key]['startTime'], experiment_dict[key]['endTime'])))
print(TRIAL_MONITOR_HEAD) print(TRIAL_MONITOR_HEAD)
running, response = check_rest_server_quick(experiment_dict[key]['port']) running, response = check_rest_server_quick(experiment_dict[key]['port'])
...@@ -850,7 +831,7 @@ def save_experiment(args): ...@@ -850,7 +831,7 @@ def save_experiment(args):
print_error('Can only save stopped experiment!') print_error('Can only save stopped experiment!')
exit(1) exit(1)
print_normal('Saving...') print_normal('Saving...')
nni_config = Config(experiment_dict[args.id]['fileName']) nni_config = Config(args.id)
logDir = os.path.join(NNI_HOME_DIR, args.id) logDir = os.path.join(NNI_HOME_DIR, args.id)
if nni_config.get_config('logDir'): if nni_config.get_config('logDir'):
logDir = os.path.join(nni_config.get_config('logDir'), args.id) logDir = os.path.join(nni_config.get_config('logDir'), args.id)
...@@ -873,8 +854,8 @@ def save_experiment(args): ...@@ -873,8 +854,8 @@ def save_experiment(args):
except IOError: except IOError:
print_error('Write file to %s failed!' % os.path.join(temp_nnictl_dir, '.experiment')) print_error('Write file to %s failed!' % os.path.join(temp_nnictl_dir, '.experiment'))
exit(1) exit(1)
nnictl_config_dir = os.path.join(NNICTL_HOME_DIR, experiment_dict[args.id]['fileName']) nnictl_config_dir = os.path.join(NNICTL_HOME_DIR, args.id)
shutil.copytree(nnictl_config_dir, os.path.join(temp_nnictl_dir, experiment_dict[args.id]['fileName'])) shutil.copytree(nnictl_config_dir, os.path.join(temp_nnictl_dir, args.id))
# Step3. Copy code dir # Step3. Copy code dir
if args.saveCodeDir: if args.saveCodeDir:
...@@ -947,20 +928,20 @@ def load_experiment(args): ...@@ -947,20 +928,20 @@ def load_experiment(args):
print_error('Invalid: experiment id already exist!') print_error('Invalid: experiment id already exist!')
shutil.rmtree(temp_root_dir) shutil.rmtree(temp_root_dir)
exit(1) exit(1)
if not os.path.exists(os.path.join(nnictl_temp_dir, experiment_metadata.get('fileName'))): if not os.path.exists(os.path.join(nnictl_temp_dir, experiment_id)):
print_error('Invalid: experiment metadata does not exist!') print_error('Invalid: experiment metadata does not exist!')
shutil.rmtree(temp_root_dir) shutil.rmtree(temp_root_dir)
exit(1) exit(1)
# Step2. Copy nnictl metadata # Step2. Copy nnictl metadata
src_path = os.path.join(nnictl_temp_dir, experiment_metadata.get('fileName')) src_path = os.path.join(nnictl_temp_dir, experiment_id)
dest_path = os.path.join(NNICTL_HOME_DIR, experiment_metadata.get('fileName')) dest_path = os.path.join(NNICTL_HOME_DIR, experiment_id)
if os.path.exists(dest_path): if os.path.exists(dest_path):
shutil.rmtree(dest_path) shutil.rmtree(dest_path)
shutil.copytree(src_path, dest_path) shutil.copytree(src_path, dest_path)
# Step3. Copy experiment data # Step3. Copy experiment data
nni_config = Config(experiment_metadata.get('fileName')) nni_config = Config(experiment_id)
nnictl_exp_config = nni_config.get_config('experimentConfig') nnictl_exp_config = nni_config.get_config('experimentConfig')
if args.logDir: if args.logDir:
logDir = args.logDir logDir = args.logDir
...@@ -1027,13 +1008,15 @@ def load_experiment(args): ...@@ -1027,13 +1008,15 @@ def load_experiment(args):
experiment_config.add_experiment(experiment_id, experiment_config.add_experiment(experiment_id,
experiment_metadata.get('port'), experiment_metadata.get('port'),
experiment_metadata.get('startTime'), experiment_metadata.get('startTime'),
experiment_metadata.get('fileName'),
experiment_metadata.get('platform'), experiment_metadata.get('platform'),
experiment_metadata.get('experimentName'), experiment_metadata.get('experimentName'),
experiment_metadata.get('endTime'), experiment_metadata.get('endTime'),
experiment_metadata.get('status')) experiment_metadata.get('status'),
experiment_metadata.get('tag'),
experiment_metadata.get('pid'),
experiment_metadata.get('webUrl'),
experiment_metadata.get('logDir'))
print_normal('Load experiment %s succsss!' % experiment_id) print_normal('Load experiment %s succsss!' % experiment_id)
# Step6. Cleanup temp data # Step6. Cleanup temp data
shutil.rmtree(temp_root_dir) shutil.rmtree(temp_root_dir)
...@@ -11,7 +11,7 @@ from .config_utils import Config, Experiments ...@@ -11,7 +11,7 @@ from .config_utils import Config, Experiments
from .url_utils import trial_jobs_url, get_local_urls from .url_utils import trial_jobs_url, get_local_urls
from .constants import REST_TIME_OUT from .constants import REST_TIME_OUT
from .common_utils import print_normal, print_warning, print_error, print_green, detect_process, detect_port, check_tensorboard_version from .common_utils import print_normal, print_warning, print_error, print_green, detect_process, detect_port, check_tensorboard_version
from .nnictl_utils import check_experiment_id, check_experiment_id from .nnictl_utils import check_experiment_id
from .ssh_utils import create_ssh_sftp_client, copy_remote_directory_to_local from .ssh_utils import create_ssh_sftp_client, copy_remote_directory_to_local
def parse_log_path(args, trial_content): def parse_log_path(args, trial_content):
...@@ -95,8 +95,7 @@ def stop_tensorboard(args): ...@@ -95,8 +95,7 @@ def stop_tensorboard(args):
experiment_id = check_experiment_id(args) experiment_id = check_experiment_id(args)
experiment_config = Experiments() experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiment_dict = experiment_config.get_all_experiments()
config_file_name = experiment_dict[experiment_id]['fileName'] nni_config = Config(experiment_id)
nni_config = Config(config_file_name)
tensorboard_pid_list = nni_config.get_config('tensorboardPidList') tensorboard_pid_list = nni_config.get_config('tensorboardPidList')
if tensorboard_pid_list: if tensorboard_pid_list:
for tensorboard_pid in tensorboard_pid_list: for tensorboard_pid in tensorboard_pid_list:
...@@ -136,7 +135,7 @@ def start_tensorboard(args): ...@@ -136,7 +135,7 @@ def start_tensorboard(args):
print_error("Experiment {} is stopped...".format(args.id)) print_error("Experiment {} is stopped...".format(args.id))
return return
config_file_name = experiment_dict[experiment_id]['fileName'] config_file_name = experiment_dict[experiment_id]['fileName']
nni_config = Config(config_file_name) nni_config = Config(args.id)
if nni_config.get_config('experimentConfig').get('trainingServicePlatform') == 'adl': if nni_config.get_config('experimentConfig').get('trainingServicePlatform') == 'adl':
adl_tensorboard_helper(args) adl_tensorboard_helper(args)
return return
......
...@@ -73,6 +73,7 @@ dependencies = [ ...@@ -73,6 +73,7 @@ dependencies = [
'scikit-learn>=0.23.2', 'scikit-learn>=0.23.2',
'pkginfo', 'pkginfo',
'websockets', 'websockets',
'filelock',
'prettytable' 'prettytable'
] ]
......
...@@ -11,9 +11,9 @@ from nni.tools.nnictl.nnictl_utils import get_yml_content ...@@ -11,9 +11,9 @@ from nni.tools.nnictl.nnictl_utils import get_yml_content
def create_mock_experiment(): def create_mock_experiment():
nnictl_experiment_config = Experiments() nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment('xOpEwA5w', '8080', '1970/01/1 01:01:01', 'aGew0x', nnictl_experiment_config.add_experiment('xOpEwA5w', '8080', 123456,
'local', 'example_sklearn-classification') 'local', 'example_sklearn-classification')
nni_config = Config('aGew0x') nni_config = Config('xOpEwA5w')
# mock process # mock process
cmds = ['sleep', '3600000'] cmds = ['sleep', '3600000']
process = Popen(cmds, stdout=PIPE, stderr=STDOUT) process = Popen(cmds, stdout=PIPE, stderr=STDOUT)
......
...@@ -19,7 +19,7 @@ class CommonUtilsTestCase(TestCase): ...@@ -19,7 +19,7 @@ class CommonUtilsTestCase(TestCase):
def test_update_experiment(self): def test_update_experiment(self):
experiment = Experiments(HOME_PATH) experiment = Experiments(HOME_PATH)
experiment.add_experiment('xOpEwA5w', 8081, 'N/A', 'aGew0x', 'local', 'test', endTime='N/A', status='INITIALIZED') experiment.add_experiment('xOpEwA5w', 8081, 'N/A', 'local', 'test', endTime='N/A', status='INITIALIZED')
self.assertTrue('xOpEwA5w' in experiment.get_all_experiments()) self.assertTrue('xOpEwA5w' in experiment.get_all_experiments())
experiment.remove_experiment('xOpEwA5w') experiment.remove_experiment('xOpEwA5w')
self.assertFalse('xOpEwA5w' in experiment.get_all_experiments()) self.assertFalse('xOpEwA5w' in experiment.get_all_experiments())
......
...@@ -46,7 +46,7 @@ class CommonUtilsTestCase(TestCase): ...@@ -46,7 +46,7 @@ class CommonUtilsTestCase(TestCase):
@responses.activate @responses.activate
def test_get_config_file_name(self): def test_get_config_file_name(self):
args = generate_args() args = generate_args()
self.assertEqual('aGew0x', get_config_filename(args)) self.assertEqual('xOpEwA5w', get_config_filename(args))
@responses.activate @responses.activate
def test_get_experiment_port(self): def test_get_experiment_port(self):
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
abstract class ExperimentManager {
public abstract getExperimentsInfo(): Promise<JSON>;
public abstract setExperimentPath(newPath: string): void;
public abstract setExperimentInfo(experimentId: string, key: string, value: any): void;
public abstract stop(): Promise<void>;
}
export {ExperimentManager};
...@@ -11,13 +11,16 @@ import { ChildProcess, spawn, StdioOptions } from 'child_process'; ...@@ -11,13 +11,16 @@ import { ChildProcess, spawn, StdioOptions } from 'child_process';
import * as fs from 'fs'; import * as fs from 'fs';
import * as os from 'os'; import * as os from 'os';
import * as path from 'path'; import * as path from 'path';
import * as lockfile from 'lockfile';
import { Deferred } from 'ts-deferred'; import { Deferred } from 'ts-deferred';
import { Container } from 'typescript-ioc'; import { Container } from 'typescript-ioc';
import * as util from 'util'; import * as util from 'util';
import * as glob from 'glob';
import { Database, DataStore } from './datastore'; import { Database, DataStore } from './datastore';
import { ExperimentStartupInfo, getExperimentStartupInfo, setExperimentStartupInfo } from './experimentStartupInfo'; import { ExperimentStartupInfo, getExperimentStartupInfo, setExperimentStartupInfo } from './experimentStartupInfo';
import { ExperimentParams, Manager } from './manager'; import { ExperimentParams, Manager } from './manager';
import { ExperimentManager } from './experimentManager';
import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService'; import { HyperParameters, TrainingService, TrialJobStatus } from './trainingService';
import { logLevelNameMap } from './log'; import { logLevelNameMap } from './log';
...@@ -43,6 +46,10 @@ function getCheckpointDir(): string { ...@@ -43,6 +46,10 @@ function getCheckpointDir(): string {
return path.join(getExperimentRootDir(), 'checkpoint'); return path.join(getExperimentRootDir(), 'checkpoint');
} }
function getExperimentsInfoPath(): string {
return path.join(os.homedir(), 'nni-experiments', '.experiment');
}
function mkDirP(dirPath: string): Promise<void> { function mkDirP(dirPath: string): Promise<void> {
const deferred: Deferred<void> = new Deferred<void>(); const deferred: Deferred<void> = new Deferred<void>();
fs.exists(dirPath, (exists: boolean) => { fs.exists(dirPath, (exists: boolean) => {
...@@ -184,6 +191,7 @@ function prepareUnitTest(): void { ...@@ -184,6 +191,7 @@ function prepareUnitTest(): void {
Container.snapshot(DataStore); Container.snapshot(DataStore);
Container.snapshot(TrainingService); Container.snapshot(TrainingService);
Container.snapshot(Manager); Container.snapshot(Manager);
Container.snapshot(ExperimentManager);
const logLevel: string = parseArg(['--log_level', '-ll']); const logLevel: string = parseArg(['--log_level', '-ll']);
if (logLevel.length > 0 && !logLevelNameMap.has(logLevel)) { if (logLevel.length > 0 && !logLevelNameMap.has(logLevel)) {
...@@ -211,6 +219,7 @@ function cleanupUnitTest(): void { ...@@ -211,6 +219,7 @@ function cleanupUnitTest(): void {
Container.restore(DataStore); Container.restore(DataStore);
Container.restore(Database); Container.restore(Database);
Container.restore(ExperimentStartupInfo); Container.restore(ExperimentStartupInfo);
Container.restore(ExperimentManager);
} }
let cachedipv4Address: string = ''; let cachedipv4Address: string = '';
...@@ -416,8 +425,29 @@ function unixPathJoin(...paths: any[]): string { ...@@ -416,8 +425,29 @@ function unixPathJoin(...paths: any[]): string {
return dir; return dir;
} }
/**
* lock a file sync
*/
function withLockSync(func: Function, filePath: string, lockOpts: {[key: string]: any}, ...args: any): any {
const lockName = path.join(path.dirname(filePath), path.basename(filePath) + `.lock.${process.pid}`);
if (typeof lockOpts.stale === 'number'){
const lockPath = path.join(path.dirname(filePath), path.basename(filePath) + '.lock.*');
const lockFileNames: string[] = glob.sync(lockPath);
const canLock: boolean = lockFileNames.map((fileName) => {
return fs.existsSync(fileName) && Date.now() - fs.statSync(fileName).mtimeMs > lockOpts.stale;
}).filter(isExpired=>isExpired === false).length === 0;
if (!canLock) {
throw new Error('File has been locked.');
}
}
lockfile.lockSync(lockName, lockOpts);
const result = func(...args);
lockfile.unlockSync(lockName);
return result;
}
export { export {
countFilesRecursively, validateFileNameRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir, countFilesRecursively, validateFileNameRecursively, generateParamFileName, getMsgDispatcherCommand, getCheckpointDir, getExperimentsInfoPath,
getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, getLogDir, getExperimentRootDir, getJobCancelStatus, getDefaultDatabaseDir, getIPV4Address, unixPathJoin, withLockSync,
mkDirP, mkDirPSync, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomInt, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine mkDirP, mkDirPSync, delay, prepareUnitTest, parseArg, cleanupUnitTest, uniqueString, randomInt, randomSelect, getLogLevel, getVersion, getCmdPy, getTunerProc, isAlive, killPid, getNewLine
}; };
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import * as fs from 'fs';
import * as os from 'os';
import * as path from 'path';
import * as assert from 'assert';
import { getLogger, Logger } from '../common/log';
import { isAlive, withLockSync, getExperimentsInfoPath, delay } from '../common/utils';
import { ExperimentManager } from '../common/experimentManager';
import { Deferred } from 'ts-deferred';
interface CrashedInfo {
experimentId: string;
isCrashed: boolean;
}
interface FileInfo {
buffer: Buffer;
mtime: number;
}
class NNIExperimentsManager implements ExperimentManager {
private experimentsPath: string;
private log: Logger;
private profileUpdateTimer: {[key: string]: any};
constructor() {
this.experimentsPath = getExperimentsInfoPath();
this.log = getLogger();
this.profileUpdateTimer = {};
}
public async getExperimentsInfo(): Promise<JSON> {
const fileInfo: FileInfo = await this.withLockIterated(this.readExperimentsInfo, 100);
const experimentsInformation = JSON.parse(fileInfo.buffer.toString());
const expIdList: Array<string> = Object.keys(experimentsInformation).filter((expId) => {
return experimentsInformation[expId]['status'] !== 'STOPPED';
});
const updateList: Array<CrashedInfo> = (await Promise.all(expIdList.map((expId) => {
return this.checkCrashed(expId, experimentsInformation[expId]['pid']);
}))).filter(crashedInfo => crashedInfo.isCrashed);
if (updateList.length > 0){
const result = await this.withLockIterated(this.updateAllStatus, 100, updateList.map(crashedInfo => crashedInfo.experimentId), fileInfo.mtime);
if (result !== undefined) {
return JSON.parse(JSON.stringify(Object.keys(result).map(key=>result[key])));
} else {
await delay(500);
return await this.getExperimentsInfo();
}
} else {
return JSON.parse(JSON.stringify(Object.keys(experimentsInformation).map(key=>experimentsInformation[key])));
}
}
public setExperimentPath(newPath: string): void {
if (newPath[0] === '~') {
newPath = path.join(os.homedir(), newPath.slice(1));
}
if (!path.isAbsolute(newPath)) {
newPath = path.resolve(newPath);
}
this.log.info(`Set new experiment information path: ${newPath}`);
this.experimentsPath = newPath;
}
public setExperimentInfo(experimentId: string, key: string, value: any): void {
try {
if (this.profileUpdateTimer[key] !== undefined) {
// if a new call with the same timerId occurs, destroy the unfinished old one
clearTimeout(this.profileUpdateTimer[key]);
this.profileUpdateTimer[key] = undefined;
}
this.withLockSync(() => {
const experimentsInformation = JSON.parse(fs.readFileSync(this.experimentsPath).toString());
assert(experimentId in experimentsInformation, `Experiment Manager: Experiment Id ${experimentId} not found, this should not happen`);
experimentsInformation[experimentId][key] = value;
fs.writeFileSync(this.experimentsPath, JSON.stringify(experimentsInformation, null, 4));
});
} catch (err) {
this.log.error(err);
this.log.debug(`Experiment Manager: Retry set key value: ${experimentId} {${key}: ${value}}`);
if (err.code === 'EEXIST' || err.message === 'File has been locked.') {
this.profileUpdateTimer[key] = setTimeout(this.setExperimentInfo.bind(this), 100, experimentId, key, value);
}
}
}
private async withLockIterated (func: Function, retry: number, ...args: any): Promise<any> {
if (retry < 0) {
throw new Error('Lock file out of retries.');
}
try {
return this.withLockSync(func, ...args);
} catch(err) {
if (err.code === 'EEXIST' || err.message === 'File has been locked.') {
// retry wait is 50ms
await delay(50);
return await this.withLockIterated(func, retry - 1, ...args);
}
throw err;
}
}
private withLockSync (func: Function, ...args: any): any {
return withLockSync(func.bind(this), this.experimentsPath, {stale: 2 * 1000}, ...args);
}
private readExperimentsInfo(): FileInfo {
const buffer: Buffer = fs.readFileSync(this.experimentsPath);
const mtime: number = fs.statSync(this.experimentsPath).mtimeMs;
return {buffer: buffer, mtime: mtime};
}
private async checkCrashed(expId: string, pid: number): Promise<CrashedInfo> {
const alive: boolean = await isAlive(pid);
return {experimentId: expId, isCrashed: !alive}
}
private updateAllStatus(updateList: Array<string>, timestamp: number): {[key: string]: any} | undefined {
if (timestamp !== fs.statSync(this.experimentsPath).mtimeMs) {
return;
} else {
const experimentsInformation = JSON.parse(fs.readFileSync(this.experimentsPath).toString());
updateList.forEach((expId: string) => {
if (experimentsInformation[expId]) {
experimentsInformation[expId]['status'] = 'STOPPED';
} else {
this.log.error(`Experiment Manager: Experiment Id ${expId} not found, this should not happen`);
}
});
fs.writeFileSync(this.experimentsPath, JSON.stringify(experimentsInformation, null, 4));
return experimentsInformation;
}
}
public async stop(): Promise<void> {
this.log.debug('Stopping experiment manager.');
await this.cleanUp().catch(err=>this.log.error(err.message));
this.log.debug('Experiment manager stopped.');
}
private async cleanUp(): Promise<void> {
const deferred = new Deferred<void>();
if (this.isUndone()) {
this.log.debug('Experiment manager: something undone');
setTimeout(((deferred: Deferred<void>): void => {
if (this.isUndone()) {
deferred.reject(new Error('Still has undone after 5s, forced stop.'));
} else {
deferred.resolve();
}
}).bind(this), 5 * 1000, deferred);
} else {
this.log.debug('Experiment manager: all clean up');
deferred.resolve();
}
return deferred.promise;
}
private isUndone(): boolean {
return Object.keys(this.profileUpdateTimer).filter((key: string) => {
return this.profileUpdateTimer[key] !== undefined;
}).length > 0;
}
}
export { NNIExperimentsManager };
...@@ -15,6 +15,7 @@ import { ...@@ -15,6 +15,7 @@ import {
ExperimentParams, ExperimentProfile, Manager, ExperimentStatus, ExperimentParams, ExperimentProfile, Manager, ExperimentStatus,
NNIManagerStatus, ProfileUpdateType, TrialJobStatistics NNIManagerStatus, ProfileUpdateType, TrialJobStatistics
} from '../common/manager'; } from '../common/manager';
import { ExperimentManager } from '../common/experimentManager';
import { import {
TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus, LogType TrainingService, TrialJobApplicationForm, TrialJobDetail, TrialJobMetric, TrialJobStatus, LogType
} from '../common/trainingService'; } from '../common/trainingService';
...@@ -31,6 +32,7 @@ import { createDispatcherInterface, IpcInterface } from './ipcInterface'; ...@@ -31,6 +32,7 @@ import { createDispatcherInterface, IpcInterface } from './ipcInterface';
class NNIManager implements Manager { class NNIManager implements Manager {
private trainingService: TrainingService; private trainingService: TrainingService;
private dispatcher: IpcInterface | undefined; private dispatcher: IpcInterface | undefined;
private experimentManager: ExperimentManager;
private currSubmittedTrialNum: number; // need to be recovered private currSubmittedTrialNum: number; // need to be recovered
private trialConcurrencyChange: number; // >0: increase, <0: decrease private trialConcurrencyChange: number; // >0: increase, <0: decrease
private log: Logger; private log: Logger;
...@@ -49,6 +51,7 @@ class NNIManager implements Manager { ...@@ -49,6 +51,7 @@ class NNIManager implements Manager {
this.currSubmittedTrialNum = 0; this.currSubmittedTrialNum = 0;
this.trialConcurrencyChange = 0; this.trialConcurrencyChange = 0;
this.trainingService = component.get(TrainingService); this.trainingService = component.get(TrainingService);
this.experimentManager = component.get(ExperimentManager);
assert(this.trainingService); assert(this.trainingService);
this.dispatcherPid = 0; this.dispatcherPid = 0;
this.waitingTrials = []; this.waitingTrials = [];
...@@ -467,7 +470,9 @@ class NNIManager implements Manager { ...@@ -467,7 +470,9 @@ class NNIManager implements Manager {
} }
} }
await this.trainingService.cleanUp(); await this.trainingService.cleanUp();
this.experimentProfile.endTime = Date.now(); if (this.experimentProfile.endTime === undefined) {
this.setEndtime();
}
await this.storeExperimentProfile(); await this.storeExperimentProfile();
this.setStatus('STOPPED'); this.setStatus('STOPPED');
} }
...@@ -596,7 +601,7 @@ class NNIManager implements Manager { ...@@ -596,7 +601,7 @@ class NNIManager implements Manager {
assert(allFinishedTrialJobNum <= waitSubmittedToFinish); assert(allFinishedTrialJobNum <= waitSubmittedToFinish);
if (allFinishedTrialJobNum >= waitSubmittedToFinish) { if (allFinishedTrialJobNum >= waitSubmittedToFinish) {
this.setStatus('DONE'); this.setStatus('DONE');
this.experimentProfile.endTime = Date.now(); this.setEndtime();
await this.storeExperimentProfile(); await this.storeExperimentProfile();
// write this log for travis CI // write this log for travis CI
this.log.info('Experiment done.'); this.log.info('Experiment done.');
...@@ -796,6 +801,7 @@ class NNIManager implements Manager { ...@@ -796,6 +801,7 @@ class NNIManager implements Manager {
this.log.error(err.stack); this.log.error(err.stack);
} }
this.status.errors.push(err.message); this.status.errors.push(err.message);
this.setEndtime();
this.setStatus('ERROR'); this.setStatus('ERROR');
} }
...@@ -803,9 +809,15 @@ class NNIManager implements Manager { ...@@ -803,9 +809,15 @@ class NNIManager implements Manager {
if (status !== this.status.status) { if (status !== this.status.status) {
this.log.info(`Change NNIManager status from: ${this.status.status} to: ${status}`); this.log.info(`Change NNIManager status from: ${this.status.status} to: ${status}`);
this.status.status = status; this.status.status = status;
this.experimentManager.setExperimentInfo(this.experimentProfile.id, 'status', this.status.status);
} }
} }
private setEndtime(): void {
this.experimentProfile.endTime = Date.now();
this.experimentManager.setExperimentInfo(this.experimentProfile.id, 'endTime', this.experimentProfile.endTime);
}
private createEmptyExperimentProfile(): ExperimentProfile { private createEmptyExperimentProfile(): ExperimentProfile {
return { return {
id: getExperimentId(), id: getExperimentId(),
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
'use strict';
import { assert, expect } from 'chai';
import * as fs from 'fs';
import { Container, Scope } from 'typescript-ioc';
import * as component from '../../common/component';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { ExperimentManager } from '../../common/experimentManager';
import { NNIExperimentsManager } from '../nniExperimentsManager';
describe('Unit test for experiment manager', function () {
let experimentManager: NNIExperimentsManager;
const mockedInfo = {
"test": {
"port": 8080,
"startTime": 1605246730756,
"endTime": "N/A",
"status": "INITIALIZED",
"platform": "local",
"experimentName": "testExp",
"tag": [], "pid": 11111,
"webuiUrl": [],
"logDir": null
}
}
before(() => {
prepareUnitTest();
fs.writeFileSync('.experiment.test', JSON.stringify(mockedInfo));
Container.bind(ExperimentManager).to(NNIExperimentsManager).scope(Scope.Singleton);
experimentManager = component.get(NNIExperimentsManager);
experimentManager.setExperimentPath('.experiment.test');
});
after(() => {
if (fs.existsSync('.experiment.test')) {
fs.unlinkSync('.experiment.test');
}
cleanupUnitTest();
});
it('test getExperimentsInfo', () => {
return experimentManager.getExperimentsInfo().then(function (experimentsInfo: {[key: string]: any}) {
new Array(experimentsInfo)
for (let idx in experimentsInfo) {
if (experimentsInfo[idx]['id'] === 'test') {
expect(experimentsInfo[idx]['status']).to.be.oneOf(['STOPPED', 'ERROR']);
break;
}
}
}).catch((error) => {
assert.fail(error);
})
});
});
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
'use strict'; 'use strict';
import * as fs from 'fs';
import * as os from 'os'; import * as os from 'os';
import { assert, expect } from 'chai'; import { assert, expect } from 'chai';
import { Container, Scope } from 'typescript-ioc'; import { Container, Scope } from 'typescript-ioc';
...@@ -10,9 +11,10 @@ import { Container, Scope } from 'typescript-ioc'; ...@@ -10,9 +11,10 @@ import { Container, Scope } from 'typescript-ioc';
import * as component from '../../common/component'; import * as component from '../../common/component';
import { Database, DataStore } from '../../common/datastore'; import { Database, DataStore } from '../../common/datastore';
import { Manager, ExperimentProfile} from '../../common/manager'; import { Manager, ExperimentProfile} from '../../common/manager';
import { ExperimentManager } from '../../common/experimentManager';
import { TrainingService } from '../../common/trainingService'; import { TrainingService } from '../../common/trainingService';
import { cleanupUnitTest, prepareUnitTest } from '../../common/utils'; import { cleanupUnitTest, prepareUnitTest } from '../../common/utils';
import { NNIDataStore } from '../nniDataStore'; import { NNIExperimentsManager } from '../nniExperimentsManager';
import { NNIManager } from '../nnimanager'; import { NNIManager } from '../nnimanager';
import { SqlDB } from '../sqlDatabase'; import { SqlDB } from '../sqlDatabase';
import { MockedTrainingService } from './mockedTrainingService'; import { MockedTrainingService } from './mockedTrainingService';
...@@ -25,6 +27,7 @@ async function initContainer(): Promise<void> { ...@@ -25,6 +27,7 @@ async function initContainer(): Promise<void> {
Container.bind(Manager).to(NNIManager).scope(Scope.Singleton); Container.bind(Manager).to(NNIManager).scope(Scope.Singleton);
Container.bind(Database).to(SqlDB).scope(Scope.Singleton); Container.bind(Database).to(SqlDB).scope(Scope.Singleton);
Container.bind(DataStore).to(MockedDataStore).scope(Scope.Singleton); Container.bind(DataStore).to(MockedDataStore).scope(Scope.Singleton);
Container.bind(ExperimentManager).to(NNIExperimentsManager).scope(Scope.Singleton);
await component.get<DataStore>(DataStore).init(); await component.get<DataStore>(DataStore).init();
} }
...@@ -87,9 +90,26 @@ describe('Unit test for nnimanager', function () { ...@@ -87,9 +90,26 @@ describe('Unit test for nnimanager', function () {
revision: 0 revision: 0
} }
let mockedInfo = {
"unittest": {
"port": 8080,
"startTime": 1605246730756,
"endTime": "N/A",
"status": "INITIALIZED",
"platform": "local",
"experimentName": "testExp",
"tag": [], "pid": 11111,
"webuiUrl": [],
"logDir": null
}
}
before(async () => { before(async () => {
await initContainer(); await initContainer();
fs.writeFileSync('.experiment.test', JSON.stringify(mockedInfo));
const experimentsManager: ExperimentManager = component.get(ExperimentManager);
experimentsManager.setExperimentPath('.experiment.test');
nniManager = component.get(Manager); nniManager = component.get(Manager);
const expId: string = await nniManager.startExperiment(experimentParams); const expId: string = await nniManager.startExperiment(experimentParams);
assert.strictEqual(expId, 'unittest'); assert.strictEqual(expId, 'unittest');
......
...@@ -12,11 +12,13 @@ import { Database, DataStore } from './common/datastore'; ...@@ -12,11 +12,13 @@ import { Database, DataStore } from './common/datastore';
import { setExperimentStartupInfo } from './common/experimentStartupInfo'; import { setExperimentStartupInfo } from './common/experimentStartupInfo';
import { getLogger, Logger, logLevelNameMap } from './common/log'; import { getLogger, Logger, logLevelNameMap } from './common/log';
import { Manager, ExperimentStartUpMode } from './common/manager'; import { Manager, ExperimentStartUpMode } from './common/manager';
import { ExperimentManager } from './common/experimentManager';
import { TrainingService } from './common/trainingService'; import { TrainingService } from './common/trainingService';
import { getLogDir, mkDirP, parseArg, uniqueString } from './common/utils'; import { getLogDir, mkDirP, parseArg } from './common/utils';
import { NNIDataStore } from './core/nniDataStore'; import { NNIDataStore } from './core/nniDataStore';
import { NNIManager } from './core/nnimanager'; import { NNIManager } from './core/nnimanager';
import { SqlDB } from './core/sqlDatabase'; import { SqlDB } from './core/sqlDatabase';
import { NNIExperimentsManager } from './core/nniExperimentsManager';
import { NNIRestServer } from './rest_server/nniRestServer'; import { NNIRestServer } from './rest_server/nniRestServer';
import { FrameworkControllerTrainingService } from './training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService'; import { FrameworkControllerTrainingService } from './training_service/kubernetes/frameworkcontroller/frameworkcontrollerTrainingService';
import { AdlTrainingService } from './training_service/kubernetes/adl/adlTrainingService'; import { AdlTrainingService } from './training_service/kubernetes/adl/adlTrainingService';
...@@ -27,11 +29,10 @@ import { PAIYarnTrainingService } from './training_service/pai/paiYarn/paiYarnTr ...@@ -27,11 +29,10 @@ import { PAIYarnTrainingService } from './training_service/pai/paiYarn/paiYarnTr
import { DLTSTrainingService } from './training_service/dlts/dltsTrainingService'; import { DLTSTrainingService } from './training_service/dlts/dltsTrainingService';
function initStartupInfo( function initStartupInfo(
startExpMode: string, resumeExperimentId: string, basePort: number, platform: string, startExpMode: string, experimentId: string, basePort: number, platform: string,
logDirectory: string, experimentLogLevel: string, readonly: boolean): void { logDirectory: string, experimentLogLevel: string, readonly: boolean): void {
const createNew: boolean = (startExpMode === ExperimentStartUpMode.NEW); const createNew: boolean = (startExpMode === ExperimentStartUpMode.NEW);
const expId: string = createNew ? uniqueString(8) : resumeExperimentId; setExperimentStartupInfo(createNew, experimentId, basePort, platform, logDirectory, experimentLogLevel, readonly);
setExperimentStartupInfo(createNew, expId, basePort, platform, logDirectory, experimentLogLevel, readonly);
} }
async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> { async function initContainer(foreground: boolean, platformMode: string, logFileName?: string): Promise<void> {
...@@ -83,6 +84,9 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN ...@@ -83,6 +84,9 @@ async function initContainer(foreground: boolean, platformMode: string, logFileN
Container.bind(DataStore) Container.bind(DataStore)
.to(NNIDataStore) .to(NNIDataStore)
.scope(Scope.Singleton); .scope(Scope.Singleton);
Container.bind(ExperimentManager)
.to(NNIExperimentsManager)
.scope(Scope.Singleton);
const DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log'); const DEFAULT_LOGFILE: string = path.join(getLogDir(), 'nnimanager.log');
if (foreground) { if (foreground) {
logFileName = undefined; logFileName = undefined;
...@@ -133,7 +137,7 @@ if (![ExperimentStartUpMode.NEW, ExperimentStartUpMode.RESUME].includes(startMod ...@@ -133,7 +137,7 @@ if (![ExperimentStartUpMode.NEW, ExperimentStartUpMode.RESUME].includes(startMod
} }
const experimentId: string = parseArg(['--experiment_id', '-id']); const experimentId: string = parseArg(['--experiment_id', '-id']);
if ((startMode === ExperimentStartUpMode.RESUME) && experimentId.trim().length < 1) { if (experimentId.trim().length < 1) {
console.log(`FATAL: cannot resume the experiment, invalid experiment_id: ${experimentId}`); console.log(`FATAL: cannot resume the experiment, invalid experiment_id: ${experimentId}`);
usage(); usage();
process.exit(1); process.exit(1);
...@@ -185,6 +189,8 @@ async function cleanUp(): Promise<void> { ...@@ -185,6 +189,8 @@ async function cleanUp(): Promise<void> {
try { try {
const nniManager: Manager = component.get(Manager); const nniManager: Manager = component.get(Manager);
await nniManager.stopExperiment(); await nniManager.stopExperiment();
const experimentManager: ExperimentManager = component.get(ExperimentManager);
await experimentManager.stop();
const ds: DataStore = component.get(DataStore); const ds: DataStore = component.get(DataStore);
await ds.close(); await ds.close();
const restServer: NNIRestServer = component.get(NNIRestServer); const restServer: NNIRestServer = component.get(NNIRestServer);
......
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
"ignore": "^5.1.4", "ignore": "^5.1.4",
"js-base64": "^2.4.9", "js-base64": "^2.4.9",
"kubernetes-client": "^6.5.0", "kubernetes-client": "^6.5.0",
"lockfile": "^1.0.4",
"python-shell": "^2.0.1", "python-shell": "^2.0.1",
"rx": "^4.1.0", "rx": "^4.1.0",
"sqlite3": "^5.0.0", "sqlite3": "^5.0.0",
...@@ -39,6 +40,7 @@ ...@@ -39,6 +40,7 @@
"@types/glob": "^7.1.1", "@types/glob": "^7.1.1",
"@types/js-base64": "^2.3.1", "@types/js-base64": "^2.3.1",
"@types/js-yaml": "^3.12.5", "@types/js-yaml": "^3.12.5",
"@types/lockfile": "^1.0.0",
"@types/mocha": "^8.0.3", "@types/mocha": "^8.0.3",
"@types/node": "10.12.18", "@types/node": "10.12.18",
"@types/request": "^2.47.1", "@types/request": "^2.47.1",
......
...@@ -12,20 +12,22 @@ import { NNIError, NNIErrorNames } from '../common/errors'; ...@@ -12,20 +12,22 @@ import { NNIError, NNIErrorNames } from '../common/errors';
import { isNewExperiment, isReadonly } from '../common/experimentStartupInfo'; import { isNewExperiment, isReadonly } from '../common/experimentStartupInfo';
import { getLogger, Logger } from '../common/log'; import { getLogger, Logger } from '../common/log';
import { ExperimentProfile, Manager, TrialJobStatistics } from '../common/manager'; import { ExperimentProfile, Manager, TrialJobStatistics } from '../common/manager';
import { ExperimentManager } from '../common/experimentManager';
import { ValidationSchemas } from './restValidationSchemas'; import { ValidationSchemas } from './restValidationSchemas';
import { NNIRestServer } from './nniRestServer'; import { NNIRestServer } from './nniRestServer';
import { getVersion } from '../common/utils'; import { getVersion } from '../common/utils';
import { NNIManager } from "../core/nnimanager";
const expressJoi = require('express-joi-validator'); const expressJoi = require('express-joi-validator');
class NNIRestHandler { class NNIRestHandler {
private restServer: NNIRestServer; private restServer: NNIRestServer;
private nniManager: NNIManager; private nniManager: Manager;
private experimentsManager: ExperimentManager;
private log: Logger; private log: Logger;
constructor(rs: NNIRestServer) { constructor(rs: NNIRestServer) {
this.nniManager = component.get(Manager); this.nniManager = component.get(Manager);
this.experimentsManager = component.get(ExperimentManager);
this.restServer = rs; this.restServer = rs;
this.log = getLogger(); this.log = getLogger();
} }
...@@ -61,6 +63,7 @@ class NNIRestHandler { ...@@ -61,6 +63,7 @@ class NNIRestHandler {
this.getLatestMetricData(router); this.getLatestMetricData(router);
this.getTrialLog(router); this.getTrialLog(router);
this.exportData(router); this.exportData(router);
this.getExperimentsInfo(router);
// Express-joi-validator configuration // Express-joi-validator configuration
router.use((err: any, _req: Request, res: Response, _next: any) => { router.use((err: any, _req: Request, res: Response, _next: any) => {
...@@ -306,6 +309,16 @@ class NNIRestHandler { ...@@ -306,6 +309,16 @@ class NNIRestHandler {
}); });
} }
private getExperimentsInfo(router: Router): void {
router.get('/experiments-info', (req: Request, res: Response) => {
this.experimentsManager.getExperimentsInfo().then((experimentInfo: JSON) => {
res.send(JSON.stringify(experimentInfo));
}).catch((err: Error) => {
this.handleError(err, res);
});
});
}
private setErrorPathForFailedJob(jobInfo: TrialJobInfo): TrialJobInfo { private setErrorPathForFailedJob(jobInfo: TrialJobInfo): TrialJobInfo {
if (jobInfo === undefined || jobInfo.status !== 'FAILED' || jobInfo.logPath === undefined) { if (jobInfo === undefined || jobInfo.status !== 'FAILED' || jobInfo.logPath === undefined) {
return jobInfo; return jobInfo;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment