Unverified Commit e6629a76 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

Merge .config & nni.sqlite (#3309)

* refactor util config

* merge NNICTL_HOME_DIR and NNI_HOME_DIR

* add default logDir and fix Config logdir

* fix save/load

* modify updater and tensorboard

* rename

* rename

* fix local clusterMetaData

* fix local config convert

* fix ut

* fix ut

* del unused code

* fix cursor
parent 85c0d841
...@@ -293,10 +293,11 @@ def to_rest_json(config: ExperimentConfig) -> Dict[str, Any]: ...@@ -293,10 +293,11 @@ def to_rest_json(config: ExperimentConfig) -> Dict[str, Any]:
request_data['logCollection'] = experiment_config.get('logCollection') request_data['logCollection'] = experiment_config.get('logCollection')
request_data['clusterMetaData'] = [] request_data['clusterMetaData'] = []
if experiment_config['trainingServicePlatform'] == 'local': if experiment_config['trainingServicePlatform'] == 'local':
if experiment_config.get('localConfig'):
request_data['clusterMetaData'].append(
{'key': 'local_config', 'value': experiment_config['localConfig']})
request_data['clusterMetaData'].append( request_data['clusterMetaData'].append(
{'key':'codeDir', 'value':experiment_config['trial']['codeDir']}) {'key': 'trial_config', 'value': experiment_config['trial']})
request_data['clusterMetaData'].append(
{'key': 'command', 'value': experiment_config['trial']['command']})
elif experiment_config['trainingServicePlatform'] == 'remote': elif experiment_config['trainingServicePlatform'] == 'remote':
request_data['clusterMetaData'].append( request_data['clusterMetaData'].append(
{'key': 'machine_list', 'value': experiment_config['machineList']}) {'key': 'machine_list', 'value': experiment_config['machineList']})
......
...@@ -119,5 +119,5 @@ def _init_experiment(config: ExperimentConfig, port: int, debug: bool) -> None: ...@@ -119,5 +119,5 @@ def _init_experiment(config: ExperimentConfig, port: int, debug: bool) -> None:
def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str, name: str, pid: int, logDir: str) -> None: def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str, name: str, pid: int, logDir: str) -> None:
experiment_config = Experiments() experiments_config = Experiments()
experiment_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir) experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir)
...@@ -4,56 +4,97 @@ ...@@ -4,56 +4,97 @@
import os import os
import json import json
import shutil import shutil
import sqlite3
import time import time
from .constants import NNICTL_HOME_DIR from .constants import NNI_HOME_DIR
from .command_utils import print_error from .command_utils import print_error
from .common_utils import get_file_lock from .common_utils import get_file_lock
def config_v0_to_v1(config: dict) -> dict:
if 'clusterMetaData' not in config:
return config
elif 'trainingServicePlatform' in config:
import copy
experiment_config = copy.deepcopy(config)
if experiment_config['trainingServicePlatform'] == 'hybrid':
inverse_config = {'hybridConfig': experiment_config['clusterMetaData']['hybrid_config']}
platform_list = inverse_config['hybridConfig']['trainingServicePlatforms']
for platform in platform_list:
inverse_config.update(_inverse_cluster_metadata(platform, experiment_config['clusterMetaData']))
experiment_config.update(inverse_config)
else:
inverse_config = _inverse_cluster_metadata(experiment_config['trainingServicePlatform'], experiment_config['clusterMetaData'])
experiment_config.update(inverse_config)
experiment_config.pop('clusterMetaData')
return experiment_config
else:
raise RuntimeError('experiment config key `trainingServicePlatform` not found')
def _inverse_cluster_metadata(platform: str, metadata_config: list) -> dict:
inverse_config = {}
if platform == 'local':
inverse_config['trial'] = {}
for kv in metadata_config:
if kv['key'] == 'local_config':
inverse_config['localConfig'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
elif platform == 'remote':
for kv in metadata_config:
if kv['key'] == 'machine_list':
inverse_config['machineList'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
elif kv['key'] == 'remote_config':
inverse_config['remoteConfig'] = kv['value']
elif platform == 'pai':
for kv in metadata_config:
if kv['key'] == 'pai_config':
inverse_config['paiConfig'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
elif platform == 'kubeflow':
for kv in metadata_config:
if kv['key'] == 'kubeflow_config':
inverse_config['kubeflowConfig'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
elif platform == 'frameworkcontroller':
for kv in metadata_config:
if kv['key'] == 'frameworkcontroller_config':
inverse_config['frameworkcontrollerConfig'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
elif platform == 'aml':
for kv in metadata_config:
if kv['key'] == 'aml_config':
inverse_config['amlConfig'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
else:
raise RuntimeError('training service platform not found')
return inverse_config
class Config: class Config:
'''a util class to load and save config''' '''a util class to load and save config'''
def __init__(self, file_path, home_dir=NNICTL_HOME_DIR): def __init__(self, experiment_id: str, log_dir: str):
config_path = os.path.join(home_dir, str(file_path)) self.experiment_id = experiment_id
os.makedirs(config_path, exist_ok=True) self.conn = sqlite3.connect(os.path.join(log_dir, experiment_id, 'db', 'nni.sqlite'))
self.config_file = os.path.join(config_path, '.config') self.refresh_config()
self.config = self.read_file()
def get_all_config(self): def refresh_config(self):
'''get all of config values''' '''refresh to get latest config'''
return json.dumps(self.config, indent=4, sort_keys=True, separators=(',', ':')) sql = 'select params from ExperimentProfile where id=? order by revision DESC'
args = (self.experiment_id,)
self.config = config_v0_to_v1(json.loads(self.conn.cursor().execute(sql, args).fetchone()[0]))
def set_config(self, key, value): def get_config(self):
'''set {key:value} paris to self.config'''
self.config = self.read_file()
self.config[key] = value
self.write_file()
def get_config(self, key):
'''get a value according to key''' '''get a value according to key'''
return self.config.get(key) return self.config
def write_file(self):
'''save config to local file'''
if self.config:
try:
with open(self.config_file, 'w') as file:
json.dump(self.config, file, indent=4)
except IOError as error:
print('Error:', error)
return
def read_file(self):
'''load config from local file'''
if os.path.exists(self.config_file):
try:
with open(self.config_file, 'r') as file:
return json.load(file)
except ValueError:
return {}
return {}
class Experiments: class Experiments:
'''Maintain experiment list''' '''Maintain experiment list'''
def __init__(self, home_dir=NNICTL_HOME_DIR): def __init__(self, home_dir=NNI_HOME_DIR):
os.makedirs(home_dir, exist_ok=True) os.makedirs(home_dir, exist_ok=True)
self.experiment_file = os.path.join(home_dir, '.experiment') self.experiment_file = os.path.join(home_dir, '.experiment')
self.lock = get_file_lock(self.experiment_file, stale=2) self.lock = get_file_lock(self.experiment_file, stale=2)
...@@ -61,7 +102,7 @@ class Experiments: ...@@ -61,7 +102,7 @@ class Experiments:
self.experiments = self.read_file() self.experiments = self.read_file()
def add_experiment(self, expId, port, startTime, platform, experiment_name, endTime='N/A', status='INITIALIZED', def add_experiment(self, expId, port, startTime, platform, experiment_name, endTime='N/A', status='INITIALIZED',
tag=[], pid=None, webuiUrl=[], logDir=[]): tag=[], pid=None, webuiUrl=[], logDir=''):
'''set {key:value} pairs to self.experiment''' '''set {key:value} pairs to self.experiment'''
with self.lock: with self.lock:
self.experiments = self.read_file() self.experiments = self.read_file()
...@@ -98,13 +139,6 @@ class Experiments: ...@@ -98,13 +139,6 @@ class Experiments:
self.experiments = self.read_file() self.experiments = self.read_file()
if expId in self.experiments: if expId in self.experiments:
self.experiments.pop(expId) self.experiments.pop(expId)
fileName = expId
if fileName:
logPath = os.path.join(NNICTL_HOME_DIR, fileName)
try:
shutil.rmtree(logPath)
except FileNotFoundError:
print_error('{0} does not exist.'.format(logPath))
self.write_file() self.write_file()
def get_all_experiments(self): def get_all_experiments(self):
......
...@@ -4,8 +4,6 @@ ...@@ -4,8 +4,6 @@
import os import os
from colorama import Fore from colorama import Fore
NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments')
NNI_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments') NNI_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments')
ERROR_INFO = 'ERROR: ' ERROR_INFO = 'ERROR: '
......
...@@ -20,15 +20,15 @@ from .config_utils import Config, Experiments ...@@ -20,15 +20,15 @@ from .config_utils import Config, Experiments
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \ from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \
detect_port, get_user detect_port, get_user
from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER from .constants import NNI_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER
from .command_utils import check_output_command, kill_command from .command_utils import check_output_command, kill_command
from .nnictl_utils import update_experiment from .nnictl_utils import update_experiment
def get_log_path(experiment_id): def get_log_path(experiment_id):
'''generate stdout and stderr log path''' '''generate stdout and stderr log path'''
os.makedirs(os.path.join(NNICTL_HOME_DIR, experiment_id, 'log'), exist_ok=True) os.makedirs(os.path.join(NNI_HOME_DIR, experiment_id, 'log'), exist_ok=True)
stdout_full_path = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log', 'nnictl_stdout.log') stdout_full_path = os.path.join(NNI_HOME_DIR, experiment_id, 'log', 'nnictl_stdout.log')
stderr_full_path = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log', 'nnictl_stderr.log') stderr_full_path = os.path.join(NNI_HOME_DIR, experiment_id, 'log', 'nnictl_stderr.log')
return stdout_full_path, stderr_full_path return stdout_full_path, stderr_full_path
def print_log_content(config_file_name): def print_log_content(config_file_name):
...@@ -375,10 +375,11 @@ def set_experiment(experiment_config, mode, port, config_file_name): ...@@ -375,10 +375,11 @@ def set_experiment(experiment_config, mode, port, config_file_name):
request_data['logCollection'] = experiment_config.get('logCollection') request_data['logCollection'] = experiment_config.get('logCollection')
request_data['clusterMetaData'] = [] request_data['clusterMetaData'] = []
if experiment_config['trainingServicePlatform'] == 'local': if experiment_config['trainingServicePlatform'] == 'local':
if experiment_config.get('localConfig'):
request_data['clusterMetaData'].append(
{'key': 'local_config', 'value': experiment_config['localConfig']})
request_data['clusterMetaData'].append( request_data['clusterMetaData'].append(
{'key':'codeDir', 'value':experiment_config['trial']['codeDir']}) {'key': 'trial_config', 'value': experiment_config['trial']})
request_data['clusterMetaData'].append(
{'key': 'command', 'value': experiment_config['trial']['command']})
elif experiment_config['trainingServicePlatform'] == 'remote': elif experiment_config['trainingServicePlatform'] == 'remote':
request_data['clusterMetaData'].append( request_data['clusterMetaData'].append(
{'key': 'machine_list', 'value': experiment_config['machineList']}) {'key': 'machine_list', 'value': experiment_config['machineList']})
...@@ -479,7 +480,6 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res ...@@ -479,7 +480,6 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
def launch_experiment(args, experiment_config, mode, experiment_id): def launch_experiment(args, experiment_config, mode, experiment_id):
'''follow steps to start rest server and start experiment''' '''follow steps to start rest server and start experiment'''
nni_config = Config(experiment_id)
# check packages for tuner # check packages for tuner
package_name, module_name = None, None package_name, module_name = None, None
if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'): if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
...@@ -499,7 +499,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id): ...@@ -499,7 +499,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
if package_name in ['SMAC', 'BOHB', 'PPOTuner']: if package_name in ['SMAC', 'BOHB', 'PPOTuner']:
print_error(f'The dependencies for {package_name} can be installed through pip install nni[{package_name}]') print_error(f'The dependencies for {package_name} can be installed through pip install nni[{package_name}]')
raise raise
log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else NNI_HOME_DIR
log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None
#view experiment mode do not need debug function, when view an experiment, there will be no new logs created #view experiment mode do not need debug function, when view an experiment, there will be no new logs created
foreground = False foreground = False
...@@ -510,12 +510,10 @@ def launch_experiment(args, experiment_config, mode, experiment_id): ...@@ -510,12 +510,10 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
# start rest server # start rest server
rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \ rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
mode, experiment_id, foreground, log_dir, log_level) mode, experiment_id, foreground, log_dir, log_level)
nni_config.set_config('restServerPid', rest_process.pid)
# save experiment information # save experiment information
nnictl_experiment_config = Experiments() Experiments().add_experiment(experiment_id, args.port, start_time,
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time, experiment_config['trainingServicePlatform'],
experiment_config['trainingServicePlatform'], experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
# Deal with annotation # Deal with annotation
if experiment_config.get('useAnnotation'): if experiment_config.get('useAnnotation'):
path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation') path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation')
...@@ -572,7 +570,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id): ...@@ -572,7 +570,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
web_ui_url_list = ['http://{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))] web_ui_url_list = ['http://{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
else: else:
web_ui_url_list = get_local_urls(args.port) web_ui_url_list = get_local_urls(args.port)
nni_config.set_config('webuiUrl', web_ui_url_list) Experiments().update_experiment(experiment_id, 'webuiUrl', web_ui_url_list)
print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list))) print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
if mode != 'view' and args.foreground: if mode != 'view' and args.foreground:
...@@ -587,8 +585,6 @@ def launch_experiment(args, experiment_config, mode, experiment_id): ...@@ -587,8 +585,6 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
def create_experiment(args): def create_experiment(args):
'''start a new experiment''' '''start a new experiment'''
experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8)) experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8))
nni_config = Config(experiment_id)
nni_config.set_config('experimentId', experiment_id)
config_path = os.path.abspath(args.config) config_path = os.path.abspath(args.config)
if not os.path.exists(config_path): if not os.path.exists(config_path):
print_error('Please set correct config path!') print_error('Please set correct config path!')
...@@ -610,8 +606,6 @@ def create_experiment(args): ...@@ -610,8 +606,6 @@ def create_experiment(args):
print_error(f'Config in v1 format validation failed. {repr(e)}') print_error(f'Config in v1 format validation failed. {repr(e)}')
exit(1) exit(1)
nni_config.set_config('experimentConfig', experiment_config)
nni_config.set_config('restServerPort', args.port)
try: try:
launch_experiment(args, experiment_config, 'new', experiment_id) launch_experiment(args, experiment_config, 'new', experiment_id)
except Exception as exception: except Exception as exception:
...@@ -624,8 +618,8 @@ def create_experiment(args): ...@@ -624,8 +618,8 @@ def create_experiment(args):
def manage_stopped_experiment(args, mode): def manage_stopped_experiment(args, mode):
'''view a stopped experiment''' '''view a stopped experiment'''
update_experiment() update_experiment()
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
experiment_id = None experiment_id = None
#find the latest stopped experiment #find the latest stopped experiment
if not args.id: if not args.id:
...@@ -633,17 +627,16 @@ def manage_stopped_experiment(args, mode): ...@@ -633,17 +627,16 @@ def manage_stopped_experiment(args, mode):
'You could use \'nnictl experiment list --all\' to show all experiments!'.format(mode)) 'You could use \'nnictl experiment list --all\' to show all experiments!'.format(mode))
exit(1) exit(1)
else: else:
if experiment_dict.get(args.id) is None: if experiments_dict.get(args.id) is None:
print_error('Id %s not exist!' % args.id) print_error('Id %s not exist!' % args.id)
exit(1) exit(1)
if experiment_dict[args.id]['status'] != 'STOPPED': if experiments_dict[args.id]['status'] != 'STOPPED':
print_error('Only stopped experiments can be {0}ed!'.format(mode)) print_error('Only stopped experiments can be {0}ed!'.format(mode))
exit(1) exit(1)
experiment_id = args.id experiment_id = args.id
print_normal('{0} experiment {1}...'.format(mode, experiment_id)) print_normal('{0} experiment {1}...'.format(mode, experiment_id))
nni_config = Config(experiment_id) experiment_config = Config(experiment_id, experiments_dict[args.id]['logDir']).get_config()
experiment_config = nni_config.get_config('experimentConfig') experiments_config.update_experiment(args.id, 'port', args.port)
nni_config.set_config('restServerPort', args.port)
try: try:
launch_experiment(args, experiment_config, mode, experiment_id) launch_experiment(args, experiment_config, mode, experiment_id)
except Exception as exception: except Exception as exception:
......
...@@ -19,8 +19,8 @@ import nni_node ...@@ -19,8 +19,8 @@ import nni_node
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url, metric_data_url from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url, metric_data_url
from .config_utils import Config, Experiments from .config_utils import Config, Experiments
from .constants import NNICTL_HOME_DIR, NNI_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, \ from .constants import NNI_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, EXPERIMENT_MONITOR_INFO, \
EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT
from .common_utils import print_normal, print_error, print_warning, detect_process, get_yml_content, generate_temp_dir from .common_utils import print_normal, print_error, print_warning, detect_process, get_yml_content, generate_temp_dir
from .common_utils import print_green from .common_utils import print_green
from .command_utils import check_output_command, kill_command from .command_utils import check_output_command, kill_command
...@@ -43,16 +43,16 @@ def get_experiment_status(port): ...@@ -43,16 +43,16 @@ def get_experiment_status(port):
def update_experiment(): def update_experiment():
'''Update the experiment status in config file''' '''Update the experiment status in config file'''
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
if not experiment_dict: if not experiments_dict:
return None return None
for key in experiment_dict.keys(): for key in experiments_dict.keys():
if isinstance(experiment_dict[key], dict): if isinstance(experiments_dict[key], dict):
if experiment_dict[key].get('status') != 'STOPPED': if experiments_dict[key].get('status') != 'STOPPED':
rest_pid = experiment_dict[key].get('pid') rest_pid = experiments_dict[key].get('pid')
if not detect_process(rest_pid): if not detect_process(rest_pid):
experiment_config.update_experiment(key, 'status', 'STOPPED') experiments_config.update_experiment(key, 'status', 'STOPPED')
continue continue
def check_experiment_id(args, update=True): def check_experiment_id(args, update=True):
...@@ -60,31 +60,31 @@ def check_experiment_id(args, update=True): ...@@ -60,31 +60,31 @@ def check_experiment_id(args, update=True):
''' '''
if update: if update:
update_experiment() update_experiment()
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
if not experiment_dict: if not experiments_dict:
print_normal('There is no experiment running...') print_normal('There is no experiment running...')
return None return None
if not args.id: if not args.id:
running_experiment_list = [] running_experiment_list = []
for key in experiment_dict.keys(): for key in experiments_dict.keys():
if isinstance(experiment_dict[key], dict): if isinstance(experiments_dict[key], dict):
if experiment_dict[key].get('status') != 'STOPPED': if experiments_dict[key].get('status') != 'STOPPED':
running_experiment_list.append(key) running_experiment_list.append(key)
elif isinstance(experiment_dict[key], list): elif isinstance(experiments_dict[key], list):
# if the config file is old version, remove the configuration from file # if the config file is old version, remove the configuration from file
experiment_config.remove_experiment(key) experiments_config.remove_experiment(key)
if len(running_experiment_list) > 1: if len(running_experiment_list) > 1:
print_error('There are multiple experiments, please set the experiment id...') print_error('There are multiple experiments, please set the experiment id...')
experiment_information = "" experiment_information = ""
for key in running_experiment_list: for key in running_experiment_list:
experiment_information += EXPERIMENT_DETAIL_FORMAT % (key, experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
experiment_dict[key].get('experimentName', 'N/A'), experiments_dict[key].get('experimentName', 'N/A'),
experiment_dict[key]['status'], experiments_dict[key]['status'],
experiment_dict[key].get('port', 'N/A'), experiments_dict[key].get('port', 'N/A'),
experiment_dict[key].get('platform'), experiments_dict[key].get('platform'),
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'], time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiments_dict[key]['startTime'] / 1000)) if isinstance(experiments_dict[key]['startTime'], int) else experiments_dict[key]['startTime'],
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime']) time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiments_dict[key]['endTime'] / 1000)) if isinstance(experiments_dict[key]['endTime'], int) else experiments_dict[key]['endTime'])
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information) print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
exit(1) exit(1)
elif not running_experiment_list: elif not running_experiment_list:
...@@ -92,7 +92,7 @@ def check_experiment_id(args, update=True): ...@@ -92,7 +92,7 @@ def check_experiment_id(args, update=True):
return None return None
else: else:
return running_experiment_list[0] return running_experiment_list[0]
if experiment_dict.get(args.id): if experiments_dict.get(args.id):
return args.id return args.id
else: else:
print_error('Id not correct.') print_error('Id not correct.')
...@@ -110,25 +110,25 @@ def parse_ids(args): ...@@ -110,25 +110,25 @@ def parse_ids(args):
8.If the id does not exist but match multiple prefix of the experiment ids, nnictl will give id information 8.If the id does not exist but match multiple prefix of the experiment ids, nnictl will give id information
''' '''
update_experiment() update_experiment()
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
if not experiment_dict: if not experiments_dict:
print_normal('Experiment is not running...') print_normal('Experiment is not running...')
return None return None
result_list = [] result_list = []
running_experiment_list = [] running_experiment_list = []
for key in experiment_dict.keys(): for key in experiments_dict.keys():
if isinstance(experiment_dict[key], dict): if isinstance(experiments_dict[key], dict):
if experiment_dict[key].get('status') != 'STOPPED': if experiments_dict[key].get('status') != 'STOPPED':
running_experiment_list.append(key) running_experiment_list.append(key)
elif isinstance(experiment_dict[key], list): elif isinstance(experiments_dict[key], list):
# if the config file is old version, remove the configuration from file # if the config file is old version, remove the configuration from file
experiment_config.remove_experiment(key) experiments_config.remove_experiment(key)
if args.all: if args.all:
return running_experiment_list return running_experiment_list
if args.port is not None: if args.port is not None:
for key in running_experiment_list: for key in running_experiment_list:
if experiment_dict[key].get('port') == args.port: if experiments_dict[key].get('port') == args.port:
result_list.append(key) result_list.append(key)
if args.id and result_list and args.id != result_list[0]: if args.id and result_list and args.id != result_list[0]:
print_error('Experiment id and resful server port not match') print_error('Experiment id and resful server port not match')
...@@ -139,12 +139,12 @@ def parse_ids(args): ...@@ -139,12 +139,12 @@ def parse_ids(args):
experiment_information = "" experiment_information = ""
for key in running_experiment_list: for key in running_experiment_list:
experiment_information += EXPERIMENT_DETAIL_FORMAT % (key, experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
experiment_dict[key].get('experimentName', 'N/A'), experiments_dict[key].get('experimentName', 'N/A'),
experiment_dict[key]['status'], experiments_dict[key]['status'],
experiment_dict[key].get('port', 'N/A'), experiments_dict[key].get('port', 'N/A'),
experiment_dict[key].get('platform'), experiments_dict[key].get('platform'),
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'], time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiments_dict[key]['startTime'] / 1000)) if isinstance(experiments_dict[key]['startTime'], int) else experiments_dict[key]['startTime'],
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime']) time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiments_dict[key]['endTime'] / 1000)) if isinstance(experiments_dict[key]['endTime'], int) else experiments_dict[key]['endTime'])
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information) print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
exit(1) exit(1)
else: else:
...@@ -182,9 +182,9 @@ def get_experiment_port(args): ...@@ -182,9 +182,9 @@ def get_experiment_port(args):
if experiment_id is None: if experiment_id is None:
print_error('Please set correct experiment id.') print_error('Please set correct experiment id.')
exit(1) exit(1)
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
return experiment_dict[experiment_id].get('port') return experiments_dict[experiment_id].get('port')
def convert_time_stamp_to_date(content): def convert_time_stamp_to_date(content):
'''Convert time stamp to date time format''' '''Convert time stamp to date time format'''
...@@ -200,9 +200,9 @@ def convert_time_stamp_to_date(content): ...@@ -200,9 +200,9 @@ def convert_time_stamp_to_date(content):
def check_rest(args): def check_rest(args):
'''check if restful server is running''' '''check if restful server is running'''
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port') rest_port = experiments_dict.get(get_config_filename(args)).get('port')
running, _ = check_rest_server_quick(rest_port) running, _ = check_rest_server_quick(rest_port)
if running: if running:
print_normal('Restful server is running...') print_normal('Restful server is running...')
...@@ -219,19 +219,19 @@ def stop_experiment(args): ...@@ -219,19 +219,19 @@ def stop_experiment(args):
if experiment_id_list: if experiment_id_list:
for experiment_id in experiment_id_list: for experiment_id in experiment_id_list:
print_normal('Stopping experiment %s' % experiment_id) print_normal('Stopping experiment %s' % experiment_id)
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
rest_pid = experiment_dict.get(experiment_id).get('pid') rest_pid = experiments_dict.get(experiment_id).get('pid')
if rest_pid: if rest_pid:
kill_command(rest_pid) kill_command(rest_pid)
tensorboard_pid_list = experiment_dict.get(experiment_id).get('tensorboardPidList') tensorboard_pid_list = experiments_dict.get(experiment_id).get('tensorboardPidList')
if tensorboard_pid_list: if tensorboard_pid_list:
for tensorboard_pid in tensorboard_pid_list: for tensorboard_pid in tensorboard_pid_list:
try: try:
kill_command(tensorboard_pid) kill_command(tensorboard_pid)
except Exception as exception: except Exception as exception:
print_error(exception) print_error(exception)
experiment_config.update_experiment(experiment_id, 'tensorboardPidList', []) experiments_config.update_experiment(experiment_id, 'tensorboardPidList', [])
print_normal('Stop experiment success.') print_normal('Stop experiment success.')
def trial_ls(args): def trial_ls(args):
...@@ -250,10 +250,11 @@ def trial_ls(args): ...@@ -250,10 +250,11 @@ def trial_ls(args):
if args.head and args.tail: if args.head and args.tail:
print_error('Head and tail cannot be set at the same time.') print_error('Head and tail cannot be set at the same time.')
return return
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port') experiment_id = get_config_filename(args)
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid') rest_port = experiments_dict.get(experiment_id).get('port')
rest_pid = experiments_dict.get(experiment_id).get('pid')
if not detect_process(rest_pid): if not detect_process(rest_pid):
print_error('Experiment is not running...') print_error('Experiment is not running...')
return return
...@@ -282,10 +283,11 @@ def trial_ls(args): ...@@ -282,10 +283,11 @@ def trial_ls(args):
def trial_kill(args): def trial_kill(args):
'''List trial''' '''List trial'''
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port') experiment_id = get_config_filename(args)
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid') rest_port = experiments_dict.get(experiment_id).get('port')
rest_pid = experiments_dict.get(experiment_id).get('pid')
if not detect_process(rest_pid): if not detect_process(rest_pid):
print_error('Experiment is not running...') print_error('Experiment is not running...')
return return
...@@ -304,20 +306,21 @@ def trial_kill(args): ...@@ -304,20 +306,21 @@ def trial_kill(args):
def trial_codegen(args): def trial_codegen(args):
'''Generate code for a specific trial''' '''Generate code for a specific trial'''
print_warning('Currently, this command is only for nni nas programming interface.') print_warning('Currently, this command is only for nni nas programming interface.')
exp_id = check_experiment_id(args) exp_id = get_config_filename(args)
nni_config = Config(get_config_filename(args)) experiment_config = Config(exp_id, Experiments().get_all_experiments()[exp_id]['logDir']).get_config()
if not nni_config.get_config('experimentConfig')['useAnnotation']: if not experiment_config.get('useAnnotation'):
print_error('The experiment is not using annotation') print_error('The experiment is not using annotation')
exit(1) exit(1)
code_dir = nni_config.get_config('experimentConfig')['trial']['codeDir'] code_dir = experiment_config['trial']['codeDir']
expand_annotations(code_dir, './exp_%s_trial_%s_code'%(exp_id, args.trial_id), exp_id, args.trial_id) expand_annotations(code_dir, './exp_%s_trial_%s_code'%(exp_id, args.trial_id), exp_id, args.trial_id)
def list_experiment(args): def list_experiment(args):
'''Get experiment information''' '''Get experiment information'''
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port') experiment_id = get_config_filename(args)
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid') rest_port = experiments_dict.get(experiment_id).get('port')
rest_pid = experiments_dict.get(experiment_id).get('pid')
if not detect_process(rest_pid): if not detect_process(rest_pid):
print_error('Experiment is not running...') print_error('Experiment is not running...')
return return
...@@ -336,9 +339,9 @@ def list_experiment(args): ...@@ -336,9 +339,9 @@ def list_experiment(args):
def experiment_status(args): def experiment_status(args):
'''Show the status of experiment''' '''Show the status of experiment'''
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port') rest_port = experiments_dict.get(get_config_filename(args)).get('port')
result, response = check_rest_server_quick(rest_port) result, response = check_rest_server_quick(rest_port)
if not result: if not result:
print_normal('Restful server is not running...') print_normal('Restful server is not running...')
...@@ -350,9 +353,9 @@ def log_internal(args, filetype): ...@@ -350,9 +353,9 @@ def log_internal(args, filetype):
'''internal function to call get_log_content''' '''internal function to call get_log_content'''
file_name = get_config_filename(args) file_name = get_config_filename(args)
if filetype == 'stdout': if filetype == 'stdout':
file_full_path = os.path.join(NNICTL_HOME_DIR, file_name, 'log', 'nnictl_stdout.log') file_full_path = os.path.join(NNI_HOME_DIR, file_name, 'log', 'nnictl_stdout.log')
else: else:
file_full_path = os.path.join(NNICTL_HOME_DIR, file_name, 'log', 'nnictl_stderr.log') file_full_path = os.path.join(NNI_HOME_DIR, file_name, 'log', 'nnictl_stderr.log')
print(check_output_command(file_full_path, head=args.head, tail=args.tail)) print(check_output_command(file_full_path, head=args.head, tail=args.tail))
def log_stdout(args): def log_stdout(args):
...@@ -401,9 +404,12 @@ def log_trial(args): ...@@ -401,9 +404,12 @@ def log_trial(args):
''''get trial log path''' ''''get trial log path'''
trial_id_path_dict = {} trial_id_path_dict = {}
trial_id_list = [] trial_id_list = []
nni_config = Config(get_config_filename(args)) experiments_config = Experiments()
rest_port = nni_config.get_config('restServerPort') experiments_dict = experiments_config.get_all_experiments()
rest_pid = nni_config.get_config('restServerPid') experiment_id = get_config_filename(args)
rest_port = experiments_dict.get(experiment_id).get('port')
rest_pid = experiments_dict.get(experiment_id).get('pid')
experiment_config = Config(experiment_id, experiments_dict.get(experiment_id).get('logDir')).get_config()
if not detect_process(rest_pid): if not detect_process(rest_pid):
print_error('Experiment is not running...') print_error('Experiment is not running...')
return return
...@@ -419,7 +425,7 @@ def log_trial(args): ...@@ -419,7 +425,7 @@ def log_trial(args):
else: else:
print_error('Restful server is not running...') print_error('Restful server is not running...')
exit(1) exit(1)
is_adl = nni_config.get_config('experimentConfig').get('trainingServicePlatform') == 'adl' is_adl = experiment_config.get('trainingServicePlatform') == 'adl'
if is_adl and not args.trial_id: if is_adl and not args.trial_id:
print_error('Trial ID is required to retrieve the log for adl. Please specify it with "--trial_id".') print_error('Trial ID is required to retrieve the log for adl. Please specify it with "--trial_id".')
exit(1) exit(1)
...@@ -428,7 +434,7 @@ def log_trial(args): ...@@ -428,7 +434,7 @@ def log_trial(args):
print_error('Trial id {0} not correct, please check your command!'.format(args.trial_id)) print_error('Trial id {0} not correct, please check your command!'.format(args.trial_id))
exit(1) exit(1)
if is_adl: if is_adl:
log_trial_adl_helper(args, nni_config.get_config('experimentId')) log_trial_adl_helper(args, experiment_id)
# adl has its own way to log trial, and it thus returns right after the helper returns # adl has its own way to log trial, and it thus returns right after the helper returns
return return
if trial_id_path_dict.get(args.trial_id): if trial_id_path_dict.get(args.trial_id):
...@@ -445,13 +451,15 @@ def log_trial(args): ...@@ -445,13 +451,15 @@ def log_trial(args):
def get_config(args): def get_config(args):
'''get config info''' '''get config info'''
nni_config = Config(get_config_filename(args)) experiment_id = get_config_filename(args)
print(nni_config.get_all_config()) experiment_config = Config(experiment_id, Experiments().get_all_experiments()[experiment_id]['logDir']).get_config()
print(json.dumps(experiment_config, indent=4))
def webui_url(args): def webui_url(args):
'''show the url of web ui''' '''show the url of web ui'''
nni_config = Config(get_config_filename(args)) experiment_id = get_config_filename(args)
print_normal('{0} {1}'.format('Web UI url:', ' '.join(nni_config.get_config('webuiUrl')))) experiments_dict = Experiments().get_all_experiments()
print_normal('{0} {1}'.format('Web UI url:', ' '.join(experiments_dict[experiment_id].get('webuiUrl'))))
def webui_nas(args): def webui_nas(args):
'''launch nas ui''' '''launch nas ui'''
...@@ -520,15 +528,15 @@ def hdfs_clean(host, user_name, output_dir, experiment_id=None): ...@@ -520,15 +528,15 @@ def hdfs_clean(host, user_name, output_dir, experiment_id=None):
def experiment_clean(args): def experiment_clean(args):
'''clean up the experiment data''' '''clean up the experiment data'''
experiment_id_list = [] experiment_id_list = []
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
if args.all: if args.all:
experiment_id_list = list(experiment_dict.keys()) experiment_id_list = list(experiments_dict.keys())
else: else:
if args.id is None: if args.id is None:
print_error('please set experiment id.') print_error('please set experiment id.')
exit(1) exit(1)
if args.id not in experiment_dict: if args.id not in experiments_dict:
print_error('Cannot find experiment {0}.'.format(args.id)) print_error('Cannot find experiment {0}.'.format(args.id))
exit(1) exit(1)
experiment_id_list.append(args.id) experiment_id_list.append(args.id)
...@@ -542,23 +550,23 @@ def experiment_clean(args): ...@@ -542,23 +550,23 @@ def experiment_clean(args):
else: else:
break break
for experiment_id in experiment_id_list: for experiment_id in experiment_id_list:
nni_config = Config(experiment_id) experiment_id = get_config_filename(args)
platform = nni_config.get_config('experimentConfig').get('trainingServicePlatform') experiment_config = Config(experiment_id, Experiments().get_all_experiments()[experiment_id]['logDir']).get_config()
experiment_id = nni_config.get_config('experimentId') platform = experiment_config.get('trainingServicePlatform')
if platform == 'remote': if platform == 'remote':
machine_list = nni_config.get_config('experimentConfig').get('machineList') machine_list = experiment_config.get('machineList')
remote_clean(machine_list, experiment_id) remote_clean(machine_list, experiment_id)
elif platform == 'pai': elif platform == 'pai':
host = nni_config.get_config('experimentConfig').get('paiConfig').get('host') host = experiment_config.get('paiConfig').get('host')
user_name = nni_config.get_config('experimentConfig').get('paiConfig').get('userName') user_name = experiment_config.get('paiConfig').get('userName')
output_dir = nni_config.get_config('experimentConfig').get('trial').get('outputDir') output_dir = experiment_config.get('trial').get('outputDir')
hdfs_clean(host, user_name, output_dir, experiment_id) hdfs_clean(host, user_name, output_dir, experiment_id)
elif platform != 'local': elif platform != 'local':
#TODO: support all platforms # TODO: support all platforms
print_warning('platform {0} clean up not supported yet.'.format(platform)) print_warning('platform {0} clean up not supported yet.'.format(platform))
exit(0) exit(0)
#clean local data # clean local data
local_base_dir = nni_config.get_config('experimentConfig').get('logDir') local_base_dir = experiments_config[experiment_id]['logDir']
if not local_base_dir: if not local_base_dir:
local_base_dir = NNI_HOME_DIR local_base_dir = NNI_HOME_DIR
local_experiment_dir = os.path.join(local_base_dir, experiment_id) local_experiment_dir = os.path.join(local_base_dir, experiment_id)
...@@ -567,9 +575,8 @@ def experiment_clean(args): ...@@ -567,9 +575,8 @@ def experiment_clean(args):
local_clean(os.path.join(local_experiment_dir, folder_name)) local_clean(os.path.join(local_experiment_dir, folder_name))
if not os.listdir(local_experiment_dir): if not os.listdir(local_experiment_dir):
local_clean(local_experiment_dir) local_clean(local_experiment_dir)
experiment_config = Experiments()
print_normal('removing metadata of experiment {0}'.format(experiment_id)) print_normal('removing metadata of experiment {0}'.format(experiment_id))
experiment_config.remove_experiment(experiment_id) experiments_config.remove_experiment(experiment_id)
print_normal('Done.') print_normal('Done.')
def get_platform_dir(config_content): def get_platform_dir(config_content):
...@@ -635,30 +642,30 @@ def platform_clean(args): ...@@ -635,30 +642,30 @@ def platform_clean(args):
def experiment_list(args): def experiment_list(args):
'''get the information of all experiments''' '''get the information of all experiments'''
update_experiment() update_experiment()
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
if not experiment_dict: if not experiments_dict:
print_normal('Cannot find experiments.') print_normal('Cannot find experiments.')
exit(1) exit(1)
experiment_id_list = [] experiment_id_list = []
if args.all: if args.all:
for key in experiment_dict.keys(): for key in experiments_dict.keys():
experiment_id_list.append(key) experiment_id_list.append(key)
else: else:
for key in experiment_dict.keys(): for key in experiments_dict.keys():
if experiment_dict[key]['status'] != 'STOPPED': if experiments_dict[key]['status'] != 'STOPPED':
experiment_id_list.append(key) experiment_id_list.append(key)
if not experiment_id_list: if not experiment_id_list:
print_warning('There is no experiment running...\nYou can use \'nnictl experiment list --all\' to list all experiments.') print_warning('There is no experiment running...\nYou can use \'nnictl experiment list --all\' to list all experiments.')
experiment_information = "" experiment_information = ""
for key in experiment_id_list: for key in experiment_id_list:
experiment_information += EXPERIMENT_DETAIL_FORMAT % (key, experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
experiment_dict[key].get('experimentName', 'N/A'), experiments_dict[key].get('experimentName', 'N/A'),
experiment_dict[key]['status'], experiments_dict[key]['status'],
experiment_dict[key].get('port', 'N/A'), experiments_dict[key].get('port', 'N/A'),
experiment_dict[key].get('platform'), experiments_dict[key].get('platform'),
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'], time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiments_dict[key]['startTime'] / 1000)) if isinstance(experiments_dict[key]['startTime'], int) else experiments_dict[key]['startTime'],
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime']) time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiments_dict[key]['endTime'] / 1000)) if isinstance(experiments_dict[key]['endTime'], int) else experiments_dict[key]['endTime'])
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information) print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
return experiment_id_list return experiment_id_list
...@@ -680,26 +687,26 @@ def get_time_interval(time1, time2): ...@@ -680,26 +687,26 @@ def get_time_interval(time1, time2):
def show_experiment_info(): def show_experiment_info():
'''show experiment information in monitor''' '''show experiment information in monitor'''
update_experiment() update_experiment()
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
if not experiment_dict: if not experiments_dict:
print('There is no experiment running...') print('There is no experiment running...')
exit(1) exit(1)
experiment_id_list = [] experiment_id_list = []
for key in experiment_dict.keys(): for key in experiments_dict.keys():
if experiment_dict[key]['status'] != 'STOPPED': if experiments_dict[key]['status'] != 'STOPPED':
experiment_id_list.append(key) experiment_id_list.append(key)
if not experiment_id_list: if not experiment_id_list:
print_warning('There is no experiment running...') print_warning('There is no experiment running...')
return return
for key in experiment_id_list: for key in experiment_id_list:
print(EXPERIMENT_MONITOR_INFO % (key, experiment_dict[key]['status'], experiment_dict[key]['port'], \ print(EXPERIMENT_MONITOR_INFO % (key, experiments_dict[key]['status'], experiments_dict[key]['port'], \
experiment_dict[key].get('platform'), time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'], \ experiments_dict[key].get('platform'), time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiments_dict[key]['startTime'] / 1000)) if isinstance(experiments_dict[key]['startTime'], int) else experiments_dict[key]['startTime'], \
get_time_interval(experiment_dict[key]['startTime'], experiment_dict[key]['endTime']))) get_time_interval(experiments_dict[key]['startTime'], experiments_dict[key]['endTime'])))
print(TRIAL_MONITOR_HEAD) print(TRIAL_MONITOR_HEAD)
running, response = check_rest_server_quick(experiment_dict[key]['port']) running, response = check_rest_server_quick(experiments_dict[key]['port'])
if running: if running:
response = rest_get(trial_jobs_url(experiment_dict[key]['port']), REST_TIME_OUT) response = rest_get(trial_jobs_url(experiments_dict[key]['port']), REST_TIME_OUT)
if response and check_response(response): if response and check_response(response):
content = json.loads(response.text) content = json.loads(response.text)
for index, value in enumerate(content): for index, value in enumerate(content):
...@@ -756,10 +763,11 @@ def export_trials_data(args): ...@@ -756,10 +763,11 @@ def export_trials_data(args):
groupby.setdefault(content['trialJobId'], []).append(json.loads(content['data'])) groupby.setdefault(content['trialJobId'], []).append(json.loads(content['data']))
return groupby return groupby
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port') experiment_id = get_config_filename(args)
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid') rest_port = experiments_dict.get(experiment_id).get('port')
rest_pid = experiments_dict.get(experiment_id).get('pid')
if not detect_process(rest_pid): if not detect_process(rest_pid):
print_error('Experiment is not running...') print_error('Experiment is not running...')
...@@ -825,22 +833,20 @@ def search_space_auto_gen(args): ...@@ -825,22 +833,20 @@ def search_space_auto_gen(args):
def save_experiment(args): def save_experiment(args):
'''save experiment data to a zip file''' '''save experiment data to a zip file'''
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
if args.id is None: if args.id is None:
print_error('Please set experiment id.') print_error('Please set experiment id.')
exit(1) exit(1)
if args.id not in experiment_dict: if args.id not in experiments_dict:
print_error('Cannot find experiment {0}.'.format(args.id)) print_error('Cannot find experiment {0}.'.format(args.id))
exit(1) exit(1)
if experiment_dict[args.id].get('status') != 'STOPPED': if experiments_dict[args.id].get('status') != 'STOPPED':
print_error('Can only save stopped experiment!') print_error('Can only save stopped experiment!')
exit(1) exit(1)
print_normal('Saving...') print_normal('Saving...')
nni_config = Config(args.id) experiment_config = Config(args.id, experiments_dict[args.id]['logDir']).get_config()
logDir = os.path.join(NNI_HOME_DIR, args.id) logDir = os.path.join(experiments_dict[args.id]['logDir'], args.id)
if nni_config.get_config('logDir'):
logDir = os.path.join(nni_config.get_config('logDir'), args.id)
temp_root_dir = generate_temp_dir() temp_root_dir = generate_temp_dir()
# Step1. Copy logDir to temp folder # Step1. Copy logDir to temp folder
...@@ -855,22 +861,21 @@ def save_experiment(args): ...@@ -855,22 +861,21 @@ def save_experiment(args):
os.makedirs(temp_nnictl_dir, exist_ok=True) os.makedirs(temp_nnictl_dir, exist_ok=True)
try: try:
with open(os.path.join(temp_nnictl_dir, '.experiment'), 'w') as file: with open(os.path.join(temp_nnictl_dir, '.experiment'), 'w') as file:
experiment_dict[args.id]['id'] = args.id experiments_dict[args.id]['id'] = args.id
json.dump(experiment_dict[args.id], file) json.dump(experiments_dict[args.id], file)
except IOError: except IOError:
print_error('Write file to %s failed!' % os.path.join(temp_nnictl_dir, '.experiment')) print_error('Write file to %s failed!' % os.path.join(temp_nnictl_dir, '.experiment'))
exit(1) exit(1)
nnictl_log_dir = os.path.join(NNICTL_HOME_DIR, args.id, 'log') nnictl_log_dir = os.path.join(NNI_HOME_DIR, args.id, 'log')
shutil.copytree(nnictl_log_dir, os.path.join(temp_nnictl_dir, args.id, 'log')) shutil.copytree(nnictl_log_dir, os.path.join(temp_nnictl_dir, args.id, 'log'))
shutil.copy(os.path.join(NNICTL_HOME_DIR, args.id, '.config'), os.path.join(temp_nnictl_dir, args.id, '.config'))
# Step3. Copy code dir # Step3. Copy code dir
if args.saveCodeDir: if args.saveCodeDir:
temp_code_dir = os.path.join(temp_root_dir, 'code') temp_code_dir = os.path.join(temp_root_dir, 'code')
shutil.copytree(nni_config.get_config('experimentConfig')['trial']['codeDir'], temp_code_dir) shutil.copytree(experiment_config['trial']['codeDir'], temp_code_dir)
# Step4. Copy searchSpace file # Step4. Copy searchSpace file
search_space_path = nni_config.get_config('experimentConfig').get('searchSpacePath') search_space_path = experiment_config.get('searchSpacePath')
if search_space_path: if search_space_path:
if not os.path.exists(search_space_path): if not os.path.exists(search_space_path):
print_warning('search space %s does not exist!' % search_space_path) print_warning('search space %s does not exist!' % search_space_path)
...@@ -928,10 +933,10 @@ def load_experiment(args): ...@@ -928,10 +933,10 @@ def load_experiment(args):
print_error('Invalid nnictl metadata file: %s' % err) print_error('Invalid nnictl metadata file: %s' % err)
shutil.rmtree(temp_root_dir) shutil.rmtree(temp_root_dir)
exit(1) exit(1)
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
experiment_id = experiment_metadata.get('id') experiment_id = experiment_metadata.get('id')
if experiment_id in experiment_dict: if experiment_id in experiments_dict:
print_error('Invalid: experiment id already exist!') print_error('Invalid: experiment id already exist!')
shutil.rmtree(temp_root_dir) shutil.rmtree(temp_root_dir)
exit(1) exit(1)
...@@ -942,25 +947,25 @@ def load_experiment(args): ...@@ -942,25 +947,25 @@ def load_experiment(args):
# Step2. Copy nnictl metadata # Step2. Copy nnictl metadata
src_path = os.path.join(nnictl_temp_dir, experiment_id) src_path = os.path.join(nnictl_temp_dir, experiment_id)
dest_path = os.path.join(NNICTL_HOME_DIR, experiment_id) dest_path = os.path.join(NNI_HOME_DIR, experiment_id)
if os.path.exists(dest_path): if os.path.exists(dest_path):
shutil.rmtree(dest_path) shutil.rmtree(dest_path)
shutil.copytree(src_path, dest_path) shutil.copytree(src_path, dest_path)
# Step3. Copy experiment data # Step3. Copy experiment data
nni_config = Config(experiment_id) os.rename(os.path.join(temp_root_dir, 'experiment'), os.path.join(temp_root_dir, experiment_id))
nnictl_exp_config = nni_config.get_config('experimentConfig') src_path = os.path.join(os.path.join(temp_root_dir, experiment_id))
experiment_config = Config(experiment_id, temp_root_dir).get_config()
if args.logDir: if args.logDir:
logDir = args.logDir logDir = args.logDir
nnictl_exp_config['logDir'] = logDir experiment_config['logDir'] = logDir
else: else:
if nnictl_exp_config.get('logDir'): if experiment_config.get('logDir'):
logDir = nnictl_exp_config['logDir'] logDir = experiment_config['logDir']
else: else:
logDir = NNI_HOME_DIR logDir = NNI_HOME_DIR
os.rename(os.path.join(temp_root_dir, 'experiment'), os.path.join(temp_root_dir, experiment_id))
src_path = os.path.join(os.path.join(temp_root_dir, experiment_id)) dest_path = os.path.join(logDir, experiment_id)
dest_path = os.path.join(os.path.join(logDir, experiment_id))
if os.path.exists(dest_path): if os.path.exists(dest_path):
shutil.rmtree(dest_path) shutil.rmtree(dest_path)
shutil.copytree(src_path, dest_path) shutil.copytree(src_path, dest_path)
...@@ -970,7 +975,7 @@ def load_experiment(args): ...@@ -970,7 +975,7 @@ def load_experiment(args):
if not os.path.isabs(codeDir): if not os.path.isabs(codeDir):
codeDir = os.path.join(os.getcwd(), codeDir) codeDir = os.path.join(os.getcwd(), codeDir)
print_normal('Expand codeDir to %s' % codeDir) print_normal('Expand codeDir to %s' % codeDir)
nnictl_exp_config['trial']['codeDir'] = codeDir experiment_config['trial']['codeDir'] = codeDir
archive_code_dir = os.path.join(temp_root_dir, 'code') archive_code_dir = os.path.join(temp_root_dir, 'code')
if os.path.exists(archive_code_dir): if os.path.exists(archive_code_dir):
file_list = os.listdir(archive_code_dir) file_list = os.listdir(archive_code_dir)
...@@ -985,44 +990,18 @@ def load_experiment(args): ...@@ -985,44 +990,18 @@ def load_experiment(args):
else: else:
shutil.copy(src_path, target_path) shutil.copy(src_path, target_path)
# Step5. Copy searchSpace file # Step5. Create experiment metadata
archive_search_space_dir = os.path.join(temp_root_dir, 'searchSpace') experiments_config.add_experiment(experiment_id,
if args.searchSpacePath: experiment_metadata.get('port'),
target_path = os.path.expanduser(args.searchSpacePath) experiment_metadata.get('startTime'),
else: experiment_metadata.get('platform'),
# set default path to codeDir experiment_metadata.get('experimentName'),
target_path = os.path.join(codeDir, 'search_space.json') experiment_metadata.get('endTime'),
if not os.path.isabs(target_path): experiment_metadata.get('status'),
target_path = os.path.join(os.getcwd(), target_path) experiment_metadata.get('tag'),
print_normal('Expand search space path to %s' % target_path) experiment_metadata.get('pid'),
nnictl_exp_config['searchSpacePath'] = target_path experiment_metadata.get('webUrl'),
# if the path already has a search space file, use the original one, otherwise use archived one logDir)
if not os.path.isfile(target_path):
if len(os.listdir(archive_search_space_dir)) == 0:
print_error('Archive file does not contain search space file!')
exit(1)
else:
for file in os.listdir(archive_search_space_dir):
source_path = os.path.join(archive_search_space_dir, file)
os.makedirs(os.path.dirname(target_path), exist_ok=True)
shutil.copyfile(source_path, target_path)
break
elif not args.searchSpacePath:
print_warning('%s exist, will not load search_space file' % target_path)
# Step6. Create experiment metadata
nni_config.set_config('experimentConfig', nnictl_exp_config)
experiment_config.add_experiment(experiment_id,
experiment_metadata.get('port'),
experiment_metadata.get('startTime'),
experiment_metadata.get('platform'),
experiment_metadata.get('experimentName'),
experiment_metadata.get('endTime'),
experiment_metadata.get('status'),
experiment_metadata.get('tag'),
experiment_metadata.get('pid'),
experiment_metadata.get('webUrl'),
experiment_metadata.get('logDir'))
print_normal('Load experiment %s succsss!' % experiment_id) print_normal('Load experiment %s succsss!' % experiment_id)
# Step6. Cleanup temp data # Step6. Cleanup temp data
......
...@@ -31,9 +31,9 @@ def parse_log_path(args, trial_content): ...@@ -31,9 +31,9 @@ def parse_log_path(args, trial_content):
exit(1) exit(1)
return path_list, host_list return path_list, host_list
def copy_data_from_remote(args, nni_config, trial_content, path_list, host_list, temp_nni_path): def copy_data_from_remote(args, experiment_config, trial_content, path_list, host_list, temp_nni_path):
'''use ssh client to copy data from remote machine to local machien''' '''use ssh client to copy data from remote machine to local machien'''
machine_list = nni_config.get_config('experimentConfig').get('machineList') machine_list = experiment_config.get('machineList')
machine_dict = {} machine_dict = {}
local_path_list = [] local_path_list = []
for machine in machine_list: for machine in machine_list:
...@@ -49,15 +49,15 @@ def copy_data_from_remote(args, nni_config, trial_content, path_list, host_list, ...@@ -49,15 +49,15 @@ def copy_data_from_remote(args, nni_config, trial_content, path_list, host_list,
print_normal('Copy done!') print_normal('Copy done!')
return local_path_list return local_path_list
def get_path_list(args, nni_config, trial_content, temp_nni_path): def get_path_list(args, experiment_config, trial_content, temp_nni_path):
'''get path list according to different platform''' '''get path list according to different platform'''
path_list, host_list = parse_log_path(args, trial_content) path_list, host_list = parse_log_path(args, trial_content)
platform = nni_config.get_config('experimentConfig').get('trainingServicePlatform') platform = experiment_config.get('trainingServicePlatform')
if platform == 'local': if platform == 'local':
print_normal('Log path: %s' % ' '.join(path_list)) print_normal('Log path: %s' % ' '.join(path_list))
return path_list return path_list
elif platform == 'remote': elif platform == 'remote':
path_list = copy_data_from_remote(args, nni_config, trial_content, path_list, host_list, temp_nni_path) path_list = copy_data_from_remote(args, experiment_config, trial_content, path_list, host_list, temp_nni_path)
print_normal('Log path: %s' % ' '.join(path_list)) print_normal('Log path: %s' % ' '.join(path_list))
return path_list return path_list
else: else:
...@@ -83,19 +83,19 @@ def start_tensorboard_process(args, experiment_id, path_list, temp_nni_path): ...@@ -83,19 +83,19 @@ def start_tensorboard_process(args, experiment_id, path_list, temp_nni_path):
url_list = get_local_urls(args.port) url_list = get_local_urls(args.port)
print_green('Start tensorboard success!') print_green('Start tensorboard success!')
print_normal('Tensorboard urls: ' + ' '.join(url_list)) print_normal('Tensorboard urls: ' + ' '.join(url_list))
experiment_config = Experiments() experiments_config = Experiments()
tensorboard_process_pid_list = experiment_config.get_all_experiments().get(experiment_id).get('tensorboardPidList') tensorboard_process_pid_list = experiments_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
if tensorboard_process_pid_list is None: if tensorboard_process_pid_list is None:
tensorboard_process_pid_list = [tensorboard_process.pid] tensorboard_process_pid_list = [tensorboard_process.pid]
else: else:
tensorboard_process_pid_list.append(tensorboard_process.pid) tensorboard_process_pid_list.append(tensorboard_process.pid)
experiment_config.update_experiment(experiment_id, 'tensorboardPidList', tensorboard_process_pid_list) experiments_config.update_experiment(experiment_id, 'tensorboardPidList', tensorboard_process_pid_list)
def stop_tensorboard(args): def stop_tensorboard(args):
'''stop tensorboard''' '''stop tensorboard'''
experiment_id = check_experiment_id(args) experiment_id = check_experiment_id(args)
experiment_config = Experiments() experiments_config = Experiments()
tensorboard_pid_list = experiment_config.get_all_experiments().get(experiment_id).get('tensorboardPidList') tensorboard_pid_list = experiments_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
if tensorboard_pid_list: if tensorboard_pid_list:
for tensorboard_pid in tensorboard_pid_list: for tensorboard_pid in tensorboard_pid_list:
try: try:
...@@ -103,7 +103,7 @@ def stop_tensorboard(args): ...@@ -103,7 +103,7 @@ def stop_tensorboard(args):
call(cmds) call(cmds)
except Exception as exception: except Exception as exception:
print_error(exception) print_error(exception)
experiment_config.update_experiment(experiment_id, 'tensorboardPidList', []) experiments_config.update_experiment(experiment_id, 'tensorboardPidList', [])
print_normal('Stop tensorboard success!') print_normal('Stop tensorboard success!')
else: else:
print_error('No tensorboard configuration!') print_error('No tensorboard configuration!')
...@@ -128,17 +128,17 @@ def start_tensorboard(args): ...@@ -128,17 +128,17 @@ def start_tensorboard(args):
return return
if args.id is None: if args.id is None:
args.id = experiment_id args.id = experiment_id
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
if experiment_dict[args.id]["status"] == "STOPPED": if experiments_dict[args.id]["status"] == "STOPPED":
print_error("Experiment {} is stopped...".format(args.id)) print_error("Experiment {} is stopped...".format(args.id))
return return
nni_config = Config(args.id) experiment_config = Config(args.id, experiments_dict[args.id]['logDir']).get_config()
if nni_config.get_config('experimentConfig').get('trainingServicePlatform') == 'adl': if experiment_config.get('trainingServicePlatform') == 'adl':
adl_tensorboard_helper(args) adl_tensorboard_helper(args)
return return
rest_port = nni_config.get_config('restServerPort') rest_port = experiments_dict[args.id]['port']
rest_pid = nni_config.get_config('restServerPid') rest_pid = experiments_dict[args.id]['pid']
if not detect_process(rest_pid): if not detect_process(rest_pid):
print_error('Experiment is not running...') print_error('Experiment is not running...')
return return
...@@ -158,9 +158,9 @@ def start_tensorboard(args): ...@@ -158,9 +158,9 @@ def start_tensorboard(args):
if len(trial_content) > 1 and not args.trial_id: if len(trial_content) > 1 and not args.trial_id:
print_error('There are multiple trials, please set trial id!') print_error('There are multiple trials, please set trial id!')
exit(1) exit(1)
experiment_id = nni_config.get_config('experimentId') experiment_id = args.id
temp_nni_path = os.path.join(tempfile.gettempdir(), 'nni', experiment_id) temp_nni_path = os.path.join(tempfile.gettempdir(), 'nni', experiment_id)
os.makedirs(temp_nni_path, exist_ok=True) os.makedirs(temp_nni_path, exist_ok=True)
path_list = get_path_list(args, nni_config, trial_content, temp_nni_path) path_list = get_path_list(args, experiment_config, trial_content, temp_nni_path)
start_tensorboard_process(args, experiment_id, path_list, temp_nni_path) start_tensorboard_process(args, experiment_id, path_list, temp_nni_path)
...@@ -23,11 +23,12 @@ def validate_file(path): ...@@ -23,11 +23,12 @@ def validate_file(path):
def validate_dispatcher(args): def validate_dispatcher(args):
'''validate if the dispatcher of the experiment supports importing data''' '''validate if the dispatcher of the experiment supports importing data'''
nni_config = Config(get_config_filename(args)).get_config('experimentConfig') experiment_id = get_config_filename(args)
if nni_config.get('tuner') and nni_config['tuner'].get('builtinTunerName'): experiment_config = Config(experiment_id, Experiments().get_all_experiments()[experiment_id]['logDir']).get_config()
dispatcher_name = nni_config['tuner']['builtinTunerName'] if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
elif nni_config.get('advisor') and nni_config['advisor'].get('builtinAdvisorName'): dispatcher_name = experiment_config['tuner']['builtinTunerName']
dispatcher_name = nni_config['advisor']['builtinAdvisorName'] elif experiment_config.get('advisor') and experiment_config['advisor'].get('builtinAdvisorName'):
dispatcher_name = experiment_config['advisor']['builtinAdvisorName']
else: # otherwise it should be a customized one else: # otherwise it should be a customized one
return return
if dispatcher_name not in TUNERS_SUPPORTING_IMPORT_DATA: if dispatcher_name not in TUNERS_SUPPORTING_IMPORT_DATA:
...@@ -58,9 +59,9 @@ def get_query_type(key): ...@@ -58,9 +59,9 @@ def get_query_type(key):
def update_experiment_profile(args, key, value): def update_experiment_profile(args, key, value):
'''call restful server to update experiment profile''' '''call restful server to update experiment profile'''
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port') rest_port = experiments_dict.get(get_config_filename(args)).get('port')
running, _ = check_rest_server_quick(rest_port) running, _ = check_rest_server_quick(rest_port)
if running: if running:
response = rest_get(experiment_url(rest_port), REST_TIME_OUT) response = rest_get(experiment_url(rest_port), REST_TIME_OUT)
...@@ -117,9 +118,10 @@ def import_data(args): ...@@ -117,9 +118,10 @@ def import_data(args):
validate_dispatcher(args) validate_dispatcher(args)
content = load_search_space(args.filename) content = load_search_space(args.filename)
nni_config = Config(get_config_filename(args)) experiments_dict = Experiments().get_all_experiments()
rest_port = nni_config.get_config('restServerPort') experiment_id = get_config_filename(args)
rest_pid = nni_config.get_config('restServerPid') rest_port = experiments_dict.get(experiment_id).get('port')
rest_pid = experiments_dict.get(experiment_id).get('pid')
if not detect_process(rest_pid): if not detect_process(rest_pid):
print_error('Experiment is not running...') print_error('Experiment is not running...')
return return
...@@ -137,8 +139,8 @@ def import_data(args): ...@@ -137,8 +139,8 @@ def import_data(args):
def import_data_to_restful_server(args, content): def import_data_to_restful_server(args, content):
'''call restful server to import data to the experiment''' '''call restful server to import data to the experiment'''
nni_config = Config(get_config_filename(args)) experiments_dict = Experiments().get_all_experiments()
rest_port = nni_config.get_config('restServerPort') rest_port = experiments_dict.get(get_config_filename(args)).get('port')
running, _ = check_rest_server_quick(rest_port) running, _ = check_rest_server_quick(rest_port)
if running: if running:
response = rest_post(import_data_url(rest_port), content, REST_TIME_OUT) response = rest_post(import_data_url(rest_port), content, REST_TIME_OUT)
......
...@@ -4,31 +4,27 @@ ...@@ -4,31 +4,27 @@
import argparse import argparse
from pathlib import Path from pathlib import Path
from subprocess import Popen, PIPE, STDOUT from subprocess import Popen, PIPE, STDOUT
from nni.tools.nnictl.config_utils import Config, Experiments from nni.tools.nnictl.config_utils import Experiments
from nni.tools.nnictl.common_utils import print_green from nni.tools.nnictl.common_utils import print_green
from nni.tools.nnictl.command_utils import kill_command from nni.tools.nnictl.command_utils import kill_command
from nni.tools.nnictl.nnictl_utils import get_yml_content from nni.tools.nnictl.nnictl_utils import get_yml_content
def create_mock_experiment(): def create_mock_experiment():
nnictl_experiment_config = Experiments() nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment('xOpEwA5w', '8080', 123456, nnictl_experiment_config.add_experiment('xOpEwA5w', 8080, 123456,
'local', 'example_sklearn-classification') 'local', 'example_sklearn-classification')
nni_config = Config('xOpEwA5w')
# mock process # mock process
cmds = ['sleep', '3600000'] cmds = ['sleep', '3600000']
process = Popen(cmds, stdout=PIPE, stderr=STDOUT) process = Popen(cmds, stdout=PIPE, stderr=STDOUT)
nni_config.set_config('restServerPid', process.pid) nnictl_experiment_config.update_experiment('xOpEwA5w', 'pid', process.pid)
nni_config.set_config('experimentId', 'xOpEwA5w') nnictl_experiment_config.update_experiment('xOpEwA5w', 'port', 8080)
nni_config.set_config('restServerPort', 8080) nnictl_experiment_config.update_experiment('xOpEwA5w', 'webuiUrl', ['http://localhost:8080'])
nni_config.set_config('webuiUrl', ['http://localhost:8080'])
yml_path = Path(__file__).parents[1] / 'config_files/valid/test.yml'
experiment_config = get_yml_content(str(yml_path))
nni_config.set_config('experimentConfig', experiment_config)
print_green("expriment start success, experiment id: xOpEwA5w") print_green("expriment start success, experiment id: xOpEwA5w")
def stop_mock_experiment(): def stop_mock_experiment():
config = Config('config') nnictl_experiment_config = Experiments()
kill_command(config.get_config('restServerPid')) experiments_dict = nnictl_experiment_config.get_all_experiments()
kill_command(experiments_dict['xOpEwA5w'].get('pid'))
nnictl_experiment_config = Experiments() nnictl_experiment_config = Experiments()
nnictl_experiment_config.remove_experiment('xOpEwA5w') nnictl_experiment_config.remove_experiment('xOpEwA5w')
......
{"experimentConfig": {"authorName": "default", "experimentName": "example_sklearn-classification", "trialConcurrency": 5, "maxExecDuration": 3600, "maxTrialNum": 100, "trainingServicePlatform": "local", "searchSpacePath": "../../../config_files/valid/search_space.json", "useAnnotation": false, "tuner": {"builtinTunerName": "TPE", "classArgs": {"optimize_mode": "maximize"}}, "trial": {"command": "python3 main.py", "codeDir": "../../../config_files/valid/.", "gpuNum": 0}}, "restServerPort": 8080, "restServerPid": 7952, "experimentId": "xOpEwA5w", "webuiUrl": ["http://localhost:8080"]}
{"experimentId": "xOpEwA5w"}
\ No newline at end of file
...@@ -9,31 +9,16 @@ HOME_PATH = str(Path(__file__).parent / "mock/nnictl_metadata") ...@@ -9,31 +9,16 @@ HOME_PATH = str(Path(__file__).parent / "mock/nnictl_metadata")
class CommonUtilsTestCase(TestCase): class CommonUtilsTestCase(TestCase):
# FIXME:
# `experiment.get_all_experiments()` returns empty dict. No idea why.
# Don't want to debug this because I will port the logic to `nni.experiment`.
#def test_get_experiment(self):
# experiment = Experiments(HOME_PATH)
# self.assertTrue('xOpEwA5w' in experiment.get_all_experiments())
def test_update_experiment(self): def test_update_experiment(self):
experiment = Experiments(HOME_PATH) experiment = Experiments(HOME_PATH)
experiment.add_experiment('xOpEwA5w', 8081, 'N/A', 'local', 'test', endTime='N/A', status='INITIALIZED') experiment.add_experiment('xOpEwA5w', 8081, 'N/A', 'local', 'test', endTime='N/A', status='INITIALIZED')
self.assertTrue('xOpEwA5w' in experiment.get_all_experiments()) self.assertTrue('xOpEwA5w' in experiment.get_all_experiments())
experiment.remove_experiment('xOpEwA5w') experiment.remove_experiment('xOpEwA5w')
self.assertFalse('xOpEwA5w' in experiment.get_all_experiments()) self.assertFalse('xOpEwA5w' in experiment.get_all_experiments())
def test_get_config(self): def test_get_config(self):
config = Config('config', HOME_PATH) config = Config('xOpEwA5w', HOME_PATH)
self.assertEqual(config.get_config('experimentId'), 'xOpEwA5w') self.assertEqual(config.get_config()['experimentName'], 'test_config')
def test_set_config(self):
config = Config('config', HOME_PATH)
self.assertNotEqual(config.get_config('experimentId'), 'testId')
config.set_config('experimentId', 'testId')
self.assertEqual(config.get_config('experimentId'), 'testId')
config.set_config('experimentId', 'xOpEwA5w')
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -54,7 +54,7 @@ class CommonUtilsTestCase(TestCase): ...@@ -54,7 +54,7 @@ class CommonUtilsTestCase(TestCase):
@responses.activate @responses.activate
def test_get_experiment_port(self): def test_get_experiment_port(self):
args = generate_args() args = generate_args()
self.assertEqual('8080', get_experiment_port(args)) self.assertEqual(8080, get_experiment_port(args))
@responses.activate @responses.activate
def test_check_rest(self): def test_check_rest(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment