Unverified Commit e6629a76 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

Merge .config & nni.sqlite (#3309)

* refactor util config

* merge NNICTL_HOME_DIR and NNI_HOME_DIR

* add default logDir and fix Config logdir

* fix save/load

* modify updater and tensorboard

* rename

* rename

* fix local clusterMetaData

* fix local config convert

* fix ut

* fix ut

* del unused code

* fix cursor
parent 85c0d841
......@@ -293,10 +293,11 @@ def to_rest_json(config: ExperimentConfig) -> Dict[str, Any]:
request_data['logCollection'] = experiment_config.get('logCollection')
request_data['clusterMetaData'] = []
if experiment_config['trainingServicePlatform'] == 'local':
if experiment_config.get('localConfig'):
request_data['clusterMetaData'].append(
{'key': 'local_config', 'value': experiment_config['localConfig']})
request_data['clusterMetaData'].append(
{'key':'codeDir', 'value':experiment_config['trial']['codeDir']})
request_data['clusterMetaData'].append(
{'key': 'command', 'value': experiment_config['trial']['command']})
{'key': 'trial_config', 'value': experiment_config['trial']})
elif experiment_config['trainingServicePlatform'] == 'remote':
request_data['clusterMetaData'].append(
{'key': 'machine_list', 'value': experiment_config['machineList']})
......
......@@ -119,5 +119,5 @@ def _init_experiment(config: ExperimentConfig, port: int, debug: bool) -> None:
def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str, name: str, pid: int, logDir: str) -> None:
experiment_config = Experiments()
experiment_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir)
experiments_config = Experiments()
experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir)
......@@ -4,56 +4,97 @@
import os
import json
import shutil
import sqlite3
import time
from .constants import NNICTL_HOME_DIR
from .constants import NNI_HOME_DIR
from .command_utils import print_error
from .common_utils import get_file_lock
def config_v0_to_v1(config: dict) -> dict:
    '''Rebuild a flat experiment config from a profile that carries `clusterMetaData`.

    A config without a `clusterMetaData` key is returned unchanged (same object).
    Otherwise a deep copy is made, the `clusterMetaData` entries are translated
    back into top-level keys (`trial`, `localConfig`, `machineList`, ...) by
    `_inverse_cluster_metadata`, and `clusterMetaData` is dropped from the copy.
    The caller's dict is never mutated.

    `clusterMetaData` is treated as a list of `{'key': ..., 'value': ...}` entries,
    which is the shape `_inverse_cluster_metadata` iterates over.

    Raises:
        RuntimeError: if `trainingServicePlatform` is missing, if a hybrid config has
            no `hybrid_config` entry, or if a platform is not recognised.
    '''
    if 'clusterMetaData' not in config:
        return config
    if 'trainingServicePlatform' not in config:
        raise RuntimeError('experiment config key `trainingServicePlatform` not found')

    import copy
    experiment_config = copy.deepcopy(config)
    platform = experiment_config['trainingServicePlatform']
    metadata = experiment_config['clusterMetaData']

    if platform == 'hybrid':
        # The metadata is a list of key/value entries, so the hybrid section has to be
        # searched for; indexing the list with the string 'hybrid_config' raises TypeError.
        hybrid_values = [kv['value'] for kv in metadata if kv['key'] == 'hybrid_config']
        if not hybrid_values:
            raise RuntimeError('cluster metadata key `hybrid_config` not found')
        # Last entry wins, matching how repeated keys are resolved for the other platforms.
        inverse_config = {'hybridConfig': hybrid_values[-1]}
        for sub_platform in inverse_config['hybridConfig']['trainingServicePlatforms']:
            inverse_config.update(_inverse_cluster_metadata(sub_platform, metadata))
    else:
        inverse_config = _inverse_cluster_metadata(platform, metadata)

    experiment_config.update(inverse_config)
    experiment_config.pop('clusterMetaData')
    return experiment_config
def _inverse_cluster_metadata(platform: str, metadata_config: list) -> dict:
inverse_config = {}
if platform == 'local':
inverse_config['trial'] = {}
for kv in metadata_config:
if kv['key'] == 'local_config':
inverse_config['localConfig'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
elif platform == 'remote':
for kv in metadata_config:
if kv['key'] == 'machine_list':
inverse_config['machineList'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
elif kv['key'] == 'remote_config':
inverse_config['remoteConfig'] = kv['value']
elif platform == 'pai':
for kv in metadata_config:
if kv['key'] == 'pai_config':
inverse_config['paiConfig'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
elif platform == 'kubeflow':
for kv in metadata_config:
if kv['key'] == 'kubeflow_config':
inverse_config['kubeflowConfig'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
elif platform == 'frameworkcontroller':
for kv in metadata_config:
if kv['key'] == 'frameworkcontroller_config':
inverse_config['frameworkcontrollerConfig'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
elif platform == 'aml':
for kv in metadata_config:
if kv['key'] == 'aml_config':
inverse_config['amlConfig'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
else:
raise RuntimeError('training service platform not found')
return inverse_config
class Config:
'''a util class to load and save config'''
def __init__(self, file_path, home_dir=NNICTL_HOME_DIR):
config_path = os.path.join(home_dir, str(file_path))
os.makedirs(config_path, exist_ok=True)
self.config_file = os.path.join(config_path, '.config')
self.config = self.read_file()
def __init__(self, experiment_id: str, log_dir: str):
self.experiment_id = experiment_id
self.conn = sqlite3.connect(os.path.join(log_dir, experiment_id, 'db', 'nni.sqlite'))
self.refresh_config()
def get_all_config(self):
'''get all of config values'''
return json.dumps(self.config, indent=4, sort_keys=True, separators=(',', ':'))
def refresh_config(self):
'''refresh to get latest config'''
sql = 'select params from ExperimentProfile where id=? order by revision DESC'
args = (self.experiment_id,)
self.config = config_v0_to_v1(json.loads(self.conn.cursor().execute(sql, args).fetchone()[0]))
def set_config(self, key, value):
'''set {key:value} paris to self.config'''
self.config = self.read_file()
self.config[key] = value
self.write_file()
def get_config(self, key):
def get_config(self):
'''get a value according to key'''
return self.config.get(key)
def write_file(self):
'''save config to local file'''
if self.config:
try:
with open(self.config_file, 'w') as file:
json.dump(self.config, file, indent=4)
except IOError as error:
print('Error:', error)
return
def read_file(self):
'''load config from local file'''
if os.path.exists(self.config_file):
try:
with open(self.config_file, 'r') as file:
return json.load(file)
except ValueError:
return {}
return {}
return self.config
class Experiments:
'''Maintain experiment list'''
def __init__(self, home_dir=NNICTL_HOME_DIR):
def __init__(self, home_dir=NNI_HOME_DIR):
os.makedirs(home_dir, exist_ok=True)
self.experiment_file = os.path.join(home_dir, '.experiment')
self.lock = get_file_lock(self.experiment_file, stale=2)
......@@ -61,7 +102,7 @@ class Experiments:
self.experiments = self.read_file()
def add_experiment(self, expId, port, startTime, platform, experiment_name, endTime='N/A', status='INITIALIZED',
tag=[], pid=None, webuiUrl=[], logDir=[]):
tag=[], pid=None, webuiUrl=[], logDir=''):
'''set {key:value} pairs to self.experiment'''
with self.lock:
self.experiments = self.read_file()
......@@ -98,13 +139,6 @@ class Experiments:
self.experiments = self.read_file()
if expId in self.experiments:
self.experiments.pop(expId)
fileName = expId
if fileName:
logPath = os.path.join(NNICTL_HOME_DIR, fileName)
try:
shutil.rmtree(logPath)
except FileNotFoundError:
print_error('{0} does not exist.'.format(logPath))
self.write_file()
def get_all_experiments(self):
......
......@@ -4,8 +4,6 @@
import os
from colorama import Fore
NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments')
NNI_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments')
ERROR_INFO = 'ERROR: '
......
......@@ -20,15 +20,15 @@ from .config_utils import Config, Experiments
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \
detect_port, get_user
from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER
from .constants import NNI_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER
from .command_utils import check_output_command, kill_command
from .nnictl_utils import update_experiment
def get_log_path(experiment_id):
'''generate stdout and stderr log path'''
os.makedirs(os.path.join(NNICTL_HOME_DIR, experiment_id, 'log'), exist_ok=True)
stdout_full_path = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log', 'nnictl_stdout.log')
stderr_full_path = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log', 'nnictl_stderr.log')
os.makedirs(os.path.join(NNI_HOME_DIR, experiment_id, 'log'), exist_ok=True)
stdout_full_path = os.path.join(NNI_HOME_DIR, experiment_id, 'log', 'nnictl_stdout.log')
stderr_full_path = os.path.join(NNI_HOME_DIR, experiment_id, 'log', 'nnictl_stderr.log')
return stdout_full_path, stderr_full_path
def print_log_content(config_file_name):
......@@ -375,10 +375,11 @@ def set_experiment(experiment_config, mode, port, config_file_name):
request_data['logCollection'] = experiment_config.get('logCollection')
request_data['clusterMetaData'] = []
if experiment_config['trainingServicePlatform'] == 'local':
if experiment_config.get('localConfig'):
request_data['clusterMetaData'].append(
{'key': 'local_config', 'value': experiment_config['localConfig']})
request_data['clusterMetaData'].append(
{'key':'codeDir', 'value':experiment_config['trial']['codeDir']})
request_data['clusterMetaData'].append(
{'key': 'command', 'value': experiment_config['trial']['command']})
{'key': 'trial_config', 'value': experiment_config['trial']})
elif experiment_config['trainingServicePlatform'] == 'remote':
request_data['clusterMetaData'].append(
{'key': 'machine_list', 'value': experiment_config['machineList']})
......@@ -479,7 +480,6 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
def launch_experiment(args, experiment_config, mode, experiment_id):
'''follow steps to start rest server and start experiment'''
nni_config = Config(experiment_id)
# check packages for tuner
package_name, module_name = None, None
if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
......@@ -499,7 +499,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
if package_name in ['SMAC', 'BOHB', 'PPOTuner']:
print_error(f'The dependencies for {package_name} can be installed through pip install nni[{package_name}]')
raise
log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None
log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else NNI_HOME_DIR
log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None
#view experiment mode do not need debug function, when view an experiment, there will be no new logs created
foreground = False
......@@ -510,12 +510,10 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
# start rest server
rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
mode, experiment_id, foreground, log_dir, log_level)
nni_config.set_config('restServerPid', rest_process.pid)
# save experiment information
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time,
experiment_config['trainingServicePlatform'],
experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
Experiments().add_experiment(experiment_id, args.port, start_time,
experiment_config['trainingServicePlatform'],
experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
# Deal with annotation
if experiment_config.get('useAnnotation'):
path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation')
......@@ -572,7 +570,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
web_ui_url_list = ['http://{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
else:
web_ui_url_list = get_local_urls(args.port)
nni_config.set_config('webuiUrl', web_ui_url_list)
Experiments().update_experiment(experiment_id, 'webuiUrl', web_ui_url_list)
print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
if mode != 'view' and args.foreground:
......@@ -587,8 +585,6 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
def create_experiment(args):
'''start a new experiment'''
experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8))
nni_config = Config(experiment_id)
nni_config.set_config('experimentId', experiment_id)
config_path = os.path.abspath(args.config)
if not os.path.exists(config_path):
print_error('Please set correct config path!')
......@@ -610,8 +606,6 @@ def create_experiment(args):
print_error(f'Config in v1 format validation failed. {repr(e)}')
exit(1)
nni_config.set_config('experimentConfig', experiment_config)
nni_config.set_config('restServerPort', args.port)
try:
launch_experiment(args, experiment_config, 'new', experiment_id)
except Exception as exception:
......@@ -624,8 +618,8 @@ def create_experiment(args):
def manage_stopped_experiment(args, mode):
'''view a stopped experiment'''
update_experiment()
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
experiment_id = None
#find the latest stopped experiment
if not args.id:
......@@ -633,17 +627,16 @@ def manage_stopped_experiment(args, mode):
'You could use \'nnictl experiment list --all\' to show all experiments!'.format(mode))
exit(1)
else:
if experiment_dict.get(args.id) is None:
if experiments_dict.get(args.id) is None:
print_error('Id %s not exist!' % args.id)
exit(1)
if experiment_dict[args.id]['status'] != 'STOPPED':
if experiments_dict[args.id]['status'] != 'STOPPED':
print_error('Only stopped experiments can be {0}ed!'.format(mode))
exit(1)
experiment_id = args.id
print_normal('{0} experiment {1}...'.format(mode, experiment_id))
nni_config = Config(experiment_id)
experiment_config = nni_config.get_config('experimentConfig')
nni_config.set_config('restServerPort', args.port)
experiment_config = Config(experiment_id, experiments_dict[args.id]['logDir']).get_config()
experiments_config.update_experiment(args.id, 'port', args.port)
try:
launch_experiment(args, experiment_config, mode, experiment_id)
except Exception as exception:
......
......@@ -19,8 +19,8 @@ import nni_node
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_data_url, metric_data_url
from .config_utils import Config, Experiments
from .constants import NNICTL_HOME_DIR, NNI_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, \
EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT
from .constants import NNI_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, EXPERIMENT_MONITOR_INFO, \
TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT
from .common_utils import print_normal, print_error, print_warning, detect_process, get_yml_content, generate_temp_dir
from .common_utils import print_green
from .command_utils import check_output_command, kill_command
......@@ -43,16 +43,16 @@ def get_experiment_status(port):
def update_experiment():
'''Update the experiment status in config file'''
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
if not experiment_dict:
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
if not experiments_dict:
return None
for key in experiment_dict.keys():
if isinstance(experiment_dict[key], dict):
if experiment_dict[key].get('status') != 'STOPPED':
rest_pid = experiment_dict[key].get('pid')
for key in experiments_dict.keys():
if isinstance(experiments_dict[key], dict):
if experiments_dict[key].get('status') != 'STOPPED':
rest_pid = experiments_dict[key].get('pid')
if not detect_process(rest_pid):
experiment_config.update_experiment(key, 'status', 'STOPPED')
experiments_config.update_experiment(key, 'status', 'STOPPED')
continue
def check_experiment_id(args, update=True):
......@@ -60,31 +60,31 @@ def check_experiment_id(args, update=True):
'''
if update:
update_experiment()
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
if not experiment_dict:
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
if not experiments_dict:
print_normal('There is no experiment running...')
return None
if not args.id:
running_experiment_list = []
for key in experiment_dict.keys():
if isinstance(experiment_dict[key], dict):
if experiment_dict[key].get('status') != 'STOPPED':
for key in experiments_dict.keys():
if isinstance(experiments_dict[key], dict):
if experiments_dict[key].get('status') != 'STOPPED':
running_experiment_list.append(key)
elif isinstance(experiment_dict[key], list):
elif isinstance(experiments_dict[key], list):
# if the config file is old version, remove the configuration from file
experiment_config.remove_experiment(key)
experiments_config.remove_experiment(key)
if len(running_experiment_list) > 1:
print_error('There are multiple experiments, please set the experiment id...')
experiment_information = ""
for key in running_experiment_list:
experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
experiment_dict[key].get('experimentName', 'N/A'),
experiment_dict[key]['status'],
experiment_dict[key].get('port', 'N/A'),
experiment_dict[key].get('platform'),
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
experiments_dict[key].get('experimentName', 'N/A'),
experiments_dict[key]['status'],
experiments_dict[key].get('port', 'N/A'),
experiments_dict[key].get('platform'),
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiments_dict[key]['startTime'] / 1000)) if isinstance(experiments_dict[key]['startTime'], int) else experiments_dict[key]['startTime'],
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiments_dict[key]['endTime'] / 1000)) if isinstance(experiments_dict[key]['endTime'], int) else experiments_dict[key]['endTime'])
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
exit(1)
elif not running_experiment_list:
......@@ -92,7 +92,7 @@ def check_experiment_id(args, update=True):
return None
else:
return running_experiment_list[0]
if experiment_dict.get(args.id):
if experiments_dict.get(args.id):
return args.id
else:
print_error('Id not correct.')
......@@ -110,25 +110,25 @@ def parse_ids(args):
8.If the id does not exist but match multiple prefix of the experiment ids, nnictl will give id information
'''
update_experiment()
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
if not experiment_dict:
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
if not experiments_dict:
print_normal('Experiment is not running...')
return None
result_list = []
running_experiment_list = []
for key in experiment_dict.keys():
if isinstance(experiment_dict[key], dict):
if experiment_dict[key].get('status') != 'STOPPED':
for key in experiments_dict.keys():
if isinstance(experiments_dict[key], dict):
if experiments_dict[key].get('status') != 'STOPPED':
running_experiment_list.append(key)
elif isinstance(experiment_dict[key], list):
elif isinstance(experiments_dict[key], list):
# if the config file is old version, remove the configuration from file
experiment_config.remove_experiment(key)
experiments_config.remove_experiment(key)
if args.all:
return running_experiment_list
if args.port is not None:
for key in running_experiment_list:
if experiment_dict[key].get('port') == args.port:
if experiments_dict[key].get('port') == args.port:
result_list.append(key)
if args.id and result_list and args.id != result_list[0]:
print_error('Experiment id and resful server port not match')
......@@ -139,12 +139,12 @@ def parse_ids(args):
experiment_information = ""
for key in running_experiment_list:
experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
experiment_dict[key].get('experimentName', 'N/A'),
experiment_dict[key]['status'],
experiment_dict[key].get('port', 'N/A'),
experiment_dict[key].get('platform'),
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
experiments_dict[key].get('experimentName', 'N/A'),
experiments_dict[key]['status'],
experiments_dict[key].get('port', 'N/A'),
experiments_dict[key].get('platform'),
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiments_dict[key]['startTime'] / 1000)) if isinstance(experiments_dict[key]['startTime'], int) else experiments_dict[key]['startTime'],
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiments_dict[key]['endTime'] / 1000)) if isinstance(experiments_dict[key]['endTime'], int) else experiments_dict[key]['endTime'])
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
exit(1)
else:
......@@ -182,9 +182,9 @@ def get_experiment_port(args):
if experiment_id is None:
print_error('Please set correct experiment id.')
exit(1)
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
return experiment_dict[experiment_id].get('port')
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
return experiments_dict[experiment_id].get('port')
def convert_time_stamp_to_date(content):
'''Convert time stamp to date time format'''
......@@ -200,9 +200,9 @@ def convert_time_stamp_to_date(content):
def check_rest(args):
'''check if restful server is running'''
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
rest_port = experiments_dict.get(get_config_filename(args)).get('port')
running, _ = check_rest_server_quick(rest_port)
if running:
print_normal('Restful server is running...')
......@@ -219,19 +219,19 @@ def stop_experiment(args):
if experiment_id_list:
for experiment_id in experiment_id_list:
print_normal('Stopping experiment %s' % experiment_id)
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_pid = experiment_dict.get(experiment_id).get('pid')
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
rest_pid = experiments_dict.get(experiment_id).get('pid')
if rest_pid:
kill_command(rest_pid)
tensorboard_pid_list = experiment_dict.get(experiment_id).get('tensorboardPidList')
tensorboard_pid_list = experiments_dict.get(experiment_id).get('tensorboardPidList')
if tensorboard_pid_list:
for tensorboard_pid in tensorboard_pid_list:
try:
kill_command(tensorboard_pid)
except Exception as exception:
print_error(exception)
experiment_config.update_experiment(experiment_id, 'tensorboardPidList', [])
experiments_config.update_experiment(experiment_id, 'tensorboardPidList', [])
print_normal('Stop experiment success.')
def trial_ls(args):
......@@ -250,10 +250,11 @@ def trial_ls(args):
if args.head and args.tail:
print_error('Head and tail cannot be set at the same time.')
return
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
experiment_id = get_config_filename(args)
rest_port = experiments_dict.get(experiment_id).get('port')
rest_pid = experiments_dict.get(experiment_id).get('pid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
......@@ -282,10 +283,11 @@ def trial_ls(args):
def trial_kill(args):
'''List trial'''
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
experiment_id = get_config_filename(args)
rest_port = experiments_dict.get(experiment_id).get('port')
rest_pid = experiments_dict.get(experiment_id).get('pid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
......@@ -304,20 +306,21 @@ def trial_kill(args):
def trial_codegen(args):
'''Generate code for a specific trial'''
print_warning('Currently, this command is only for nni nas programming interface.')
exp_id = check_experiment_id(args)
nni_config = Config(get_config_filename(args))
if not nni_config.get_config('experimentConfig')['useAnnotation']:
exp_id = get_config_filename(args)
experiment_config = Config(exp_id, Experiments().get_all_experiments()[exp_id]['logDir']).get_config()
if not experiment_config.get('useAnnotation'):
print_error('The experiment is not using annotation')
exit(1)
code_dir = nni_config.get_config('experimentConfig')['trial']['codeDir']
code_dir = experiment_config['trial']['codeDir']
expand_annotations(code_dir, './exp_%s_trial_%s_code'%(exp_id, args.trial_id), exp_id, args.trial_id)
def list_experiment(args):
'''Get experiment information'''
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
experiment_id = get_config_filename(args)
rest_port = experiments_dict.get(experiment_id).get('port')
rest_pid = experiments_dict.get(experiment_id).get('pid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
......@@ -336,9 +339,9 @@ def list_experiment(args):
def experiment_status(args):
'''Show the status of experiment'''
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
rest_port = experiments_dict.get(get_config_filename(args)).get('port')
result, response = check_rest_server_quick(rest_port)
if not result:
print_normal('Restful server is not running...')
......@@ -350,9 +353,9 @@ def log_internal(args, filetype):
'''internal function to call get_log_content'''
file_name = get_config_filename(args)
if filetype == 'stdout':
file_full_path = os.path.join(NNICTL_HOME_DIR, file_name, 'log', 'nnictl_stdout.log')
file_full_path = os.path.join(NNI_HOME_DIR, file_name, 'log', 'nnictl_stdout.log')
else:
file_full_path = os.path.join(NNICTL_HOME_DIR, file_name, 'log', 'nnictl_stderr.log')
file_full_path = os.path.join(NNI_HOME_DIR, file_name, 'log', 'nnictl_stderr.log')
print(check_output_command(file_full_path, head=args.head, tail=args.tail))
def log_stdout(args):
......@@ -401,9 +404,12 @@ def log_trial(args):
''''get trial log path'''
trial_id_path_dict = {}
trial_id_list = []
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
experiment_id = get_config_filename(args)
rest_port = experiments_dict.get(experiment_id).get('port')
rest_pid = experiments_dict.get(experiment_id).get('pid')
experiment_config = Config(experiment_id, experiments_dict.get(experiment_id).get('logDir')).get_config()
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
......@@ -419,7 +425,7 @@ def log_trial(args):
else:
print_error('Restful server is not running...')
exit(1)
is_adl = nni_config.get_config('experimentConfig').get('trainingServicePlatform') == 'adl'
is_adl = experiment_config.get('trainingServicePlatform') == 'adl'
if is_adl and not args.trial_id:
print_error('Trial ID is required to retrieve the log for adl. Please specify it with "--trial_id".')
exit(1)
......@@ -428,7 +434,7 @@ def log_trial(args):
print_error('Trial id {0} not correct, please check your command!'.format(args.trial_id))
exit(1)
if is_adl:
log_trial_adl_helper(args, nni_config.get_config('experimentId'))
log_trial_adl_helper(args, experiment_id)
# adl has its own way to log trial, and it thus returns right after the helper returns
return
if trial_id_path_dict.get(args.trial_id):
......@@ -445,13 +451,15 @@ def log_trial(args):
def get_config(args):
'''get config info'''
nni_config = Config(get_config_filename(args))
print(nni_config.get_all_config())
experiment_id = get_config_filename(args)
experiment_config = Config(experiment_id, Experiments().get_all_experiments()[experiment_id]['logDir']).get_config()
print(json.dumps(experiment_config, indent=4))
def webui_url(args):
'''show the url of web ui'''
nni_config = Config(get_config_filename(args))
print_normal('{0} {1}'.format('Web UI url:', ' '.join(nni_config.get_config('webuiUrl'))))
experiment_id = get_config_filename(args)
experiments_dict = Experiments().get_all_experiments()
print_normal('{0} {1}'.format('Web UI url:', ' '.join(experiments_dict[experiment_id].get('webuiUrl'))))
def webui_nas(args):
'''launch nas ui'''
......@@ -520,15 +528,15 @@ def hdfs_clean(host, user_name, output_dir, experiment_id=None):
def experiment_clean(args):
'''clean up the experiment data'''
experiment_id_list = []
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
if args.all:
experiment_id_list = list(experiment_dict.keys())
experiment_id_list = list(experiments_dict.keys())
else:
if args.id is None:
print_error('please set experiment id.')
exit(1)
if args.id not in experiment_dict:
if args.id not in experiments_dict:
print_error('Cannot find experiment {0}.'.format(args.id))
exit(1)
experiment_id_list.append(args.id)
......@@ -542,23 +550,23 @@ def experiment_clean(args):
else:
break
for experiment_id in experiment_id_list:
nni_config = Config(experiment_id)
platform = nni_config.get_config('experimentConfig').get('trainingServicePlatform')
experiment_id = nni_config.get_config('experimentId')
experiment_id = get_config_filename(args)
experiment_config = Config(experiment_id, Experiments().get_all_experiments()[experiment_id]['logDir']).get_config()
platform = experiment_config.get('trainingServicePlatform')
if platform == 'remote':
machine_list = nni_config.get_config('experimentConfig').get('machineList')
machine_list = experiment_config.get('machineList')
remote_clean(machine_list, experiment_id)
elif platform == 'pai':
host = nni_config.get_config('experimentConfig').get('paiConfig').get('host')
user_name = nni_config.get_config('experimentConfig').get('paiConfig').get('userName')
output_dir = nni_config.get_config('experimentConfig').get('trial').get('outputDir')
host = experiment_config.get('paiConfig').get('host')
user_name = experiment_config.get('paiConfig').get('userName')
output_dir = experiment_config.get('trial').get('outputDir')
hdfs_clean(host, user_name, output_dir, experiment_id)
elif platform != 'local':
#TODO: support all platforms
# TODO: support all platforms
print_warning('platform {0} clean up not supported yet.'.format(platform))
exit(0)
#clean local data
local_base_dir = nni_config.get_config('experimentConfig').get('logDir')
# clean local data
local_base_dir = experiments_config[experiment_id]['logDir']
if not local_base_dir:
local_base_dir = NNI_HOME_DIR
local_experiment_dir = os.path.join(local_base_dir, experiment_id)
......@@ -567,9 +575,8 @@ def experiment_clean(args):
local_clean(os.path.join(local_experiment_dir, folder_name))
if not os.listdir(local_experiment_dir):
local_clean(local_experiment_dir)
experiment_config = Experiments()
print_normal('removing metadata of experiment {0}'.format(experiment_id))
experiment_config.remove_experiment(experiment_id)
experiments_config.remove_experiment(experiment_id)
print_normal('Done.')
def get_platform_dir(config_content):
......@@ -635,30 +642,30 @@ def platform_clean(args):
def experiment_list(args):
'''get the information of all experiments'''
update_experiment()
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
if not experiment_dict:
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
if not experiments_dict:
print_normal('Cannot find experiments.')
exit(1)
experiment_id_list = []
if args.all:
for key in experiment_dict.keys():
for key in experiments_dict.keys():
experiment_id_list.append(key)
else:
for key in experiment_dict.keys():
if experiment_dict[key]['status'] != 'STOPPED':
for key in experiments_dict.keys():
if experiments_dict[key]['status'] != 'STOPPED':
experiment_id_list.append(key)
if not experiment_id_list:
print_warning('There is no experiment running...\nYou can use \'nnictl experiment list --all\' to list all experiments.')
experiment_information = ""
for key in experiment_id_list:
experiment_information += EXPERIMENT_DETAIL_FORMAT % (key,
experiment_dict[key].get('experimentName', 'N/A'),
experiment_dict[key]['status'],
experiment_dict[key].get('port', 'N/A'),
experiment_dict[key].get('platform'),
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'],
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['endTime'] / 1000)) if isinstance(experiment_dict[key]['endTime'], int) else experiment_dict[key]['endTime'])
experiments_dict[key].get('experimentName', 'N/A'),
experiments_dict[key]['status'],
experiments_dict[key].get('port', 'N/A'),
experiments_dict[key].get('platform'),
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiments_dict[key]['startTime'] / 1000)) if isinstance(experiments_dict[key]['startTime'], int) else experiments_dict[key]['startTime'],
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiments_dict[key]['endTime'] / 1000)) if isinstance(experiments_dict[key]['endTime'], int) else experiments_dict[key]['endTime'])
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
return experiment_id_list
......@@ -680,26 +687,26 @@ def get_time_interval(time1, time2):
def show_experiment_info():
'''show experiment information in monitor'''
update_experiment()
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
if not experiment_dict:
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
if not experiments_dict:
print('There is no experiment running...')
exit(1)
experiment_id_list = []
for key in experiment_dict.keys():
if experiment_dict[key]['status'] != 'STOPPED':
for key in experiments_dict.keys():
if experiments_dict[key]['status'] != 'STOPPED':
experiment_id_list.append(key)
if not experiment_id_list:
print_warning('There is no experiment running...')
return
for key in experiment_id_list:
print(EXPERIMENT_MONITOR_INFO % (key, experiment_dict[key]['status'], experiment_dict[key]['port'], \
experiment_dict[key].get('platform'), time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiment_dict[key]['startTime'] / 1000)) if isinstance(experiment_dict[key]['startTime'], int) else experiment_dict[key]['startTime'], \
get_time_interval(experiment_dict[key]['startTime'], experiment_dict[key]['endTime'])))
print(EXPERIMENT_MONITOR_INFO % (key, experiments_dict[key]['status'], experiments_dict[key]['port'], \
experiments_dict[key].get('platform'), time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(experiments_dict[key]['startTime'] / 1000)) if isinstance(experiments_dict[key]['startTime'], int) else experiments_dict[key]['startTime'], \
get_time_interval(experiments_dict[key]['startTime'], experiments_dict[key]['endTime'])))
print(TRIAL_MONITOR_HEAD)
running, response = check_rest_server_quick(experiment_dict[key]['port'])
running, response = check_rest_server_quick(experiments_dict[key]['port'])
if running:
response = rest_get(trial_jobs_url(experiment_dict[key]['port']), REST_TIME_OUT)
response = rest_get(trial_jobs_url(experiments_dict[key]['port']), REST_TIME_OUT)
if response and check_response(response):
content = json.loads(response.text)
for index, value in enumerate(content):
......@@ -756,10 +763,11 @@ def export_trials_data(args):
groupby.setdefault(content['trialJobId'], []).append(json.loads(content['data']))
return groupby
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
rest_pid = experiment_dict.get(get_config_filename(args)).get('pid')
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
experiment_id = get_config_filename(args)
rest_port = experiments_dict.get(experiment_id).get('port')
rest_pid = experiments_dict.get(experiment_id).get('pid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
......@@ -825,22 +833,20 @@ def search_space_auto_gen(args):
def save_experiment(args):
'''save experiment data to a zip file'''
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
if args.id is None:
print_error('Please set experiment id.')
exit(1)
if args.id not in experiment_dict:
if args.id not in experiments_dict:
print_error('Cannot find experiment {0}.'.format(args.id))
exit(1)
if experiment_dict[args.id].get('status') != 'STOPPED':
if experiments_dict[args.id].get('status') != 'STOPPED':
print_error('Can only save stopped experiment!')
exit(1)
print_normal('Saving...')
nni_config = Config(args.id)
logDir = os.path.join(NNI_HOME_DIR, args.id)
if nni_config.get_config('logDir'):
logDir = os.path.join(nni_config.get_config('logDir'), args.id)
experiment_config = Config(args.id, experiments_dict[args.id]['logDir']).get_config()
logDir = os.path.join(experiments_dict[args.id]['logDir'], args.id)
temp_root_dir = generate_temp_dir()
# Step1. Copy logDir to temp folder
......@@ -855,22 +861,21 @@ def save_experiment(args):
os.makedirs(temp_nnictl_dir, exist_ok=True)
try:
with open(os.path.join(temp_nnictl_dir, '.experiment'), 'w') as file:
experiment_dict[args.id]['id'] = args.id
json.dump(experiment_dict[args.id], file)
experiments_dict[args.id]['id'] = args.id
json.dump(experiments_dict[args.id], file)
except IOError:
print_error('Write file to %s failed!' % os.path.join(temp_nnictl_dir, '.experiment'))
exit(1)
nnictl_log_dir = os.path.join(NNICTL_HOME_DIR, args.id, 'log')
nnictl_log_dir = os.path.join(NNI_HOME_DIR, args.id, 'log')
shutil.copytree(nnictl_log_dir, os.path.join(temp_nnictl_dir, args.id, 'log'))
shutil.copy(os.path.join(NNICTL_HOME_DIR, args.id, '.config'), os.path.join(temp_nnictl_dir, args.id, '.config'))
# Step3. Copy code dir
if args.saveCodeDir:
temp_code_dir = os.path.join(temp_root_dir, 'code')
shutil.copytree(nni_config.get_config('experimentConfig')['trial']['codeDir'], temp_code_dir)
shutil.copytree(experiment_config['trial']['codeDir'], temp_code_dir)
# Step4. Copy searchSpace file
search_space_path = nni_config.get_config('experimentConfig').get('searchSpacePath')
search_space_path = experiment_config.get('searchSpacePath')
if search_space_path:
if not os.path.exists(search_space_path):
print_warning('search space %s does not exist!' % search_space_path)
......@@ -928,10 +933,10 @@ def load_experiment(args):
print_error('Invalid nnictl metadata file: %s' % err)
shutil.rmtree(temp_root_dir)
exit(1)
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
experiment_id = experiment_metadata.get('id')
if experiment_id in experiment_dict:
if experiment_id in experiments_dict:
print_error('Invalid: experiment id already exist!')
shutil.rmtree(temp_root_dir)
exit(1)
......@@ -942,25 +947,25 @@ def load_experiment(args):
# Step2. Copy nnictl metadata
src_path = os.path.join(nnictl_temp_dir, experiment_id)
dest_path = os.path.join(NNICTL_HOME_DIR, experiment_id)
dest_path = os.path.join(NNI_HOME_DIR, experiment_id)
if os.path.exists(dest_path):
shutil.rmtree(dest_path)
shutil.copytree(src_path, dest_path)
# Step3. Copy experiment data
nni_config = Config(experiment_id)
nnictl_exp_config = nni_config.get_config('experimentConfig')
os.rename(os.path.join(temp_root_dir, 'experiment'), os.path.join(temp_root_dir, experiment_id))
src_path = os.path.join(os.path.join(temp_root_dir, experiment_id))
experiment_config = Config(experiment_id, temp_root_dir).get_config()
if args.logDir:
logDir = args.logDir
nnictl_exp_config['logDir'] = logDir
experiment_config['logDir'] = logDir
else:
if nnictl_exp_config.get('logDir'):
logDir = nnictl_exp_config['logDir']
if experiment_config.get('logDir'):
logDir = experiment_config['logDir']
else:
logDir = NNI_HOME_DIR
os.rename(os.path.join(temp_root_dir, 'experiment'), os.path.join(temp_root_dir, experiment_id))
src_path = os.path.join(os.path.join(temp_root_dir, experiment_id))
dest_path = os.path.join(os.path.join(logDir, experiment_id))
dest_path = os.path.join(logDir, experiment_id)
if os.path.exists(dest_path):
shutil.rmtree(dest_path)
shutil.copytree(src_path, dest_path)
......@@ -970,7 +975,7 @@ def load_experiment(args):
if not os.path.isabs(codeDir):
codeDir = os.path.join(os.getcwd(), codeDir)
print_normal('Expand codeDir to %s' % codeDir)
nnictl_exp_config['trial']['codeDir'] = codeDir
experiment_config['trial']['codeDir'] = codeDir
archive_code_dir = os.path.join(temp_root_dir, 'code')
if os.path.exists(archive_code_dir):
file_list = os.listdir(archive_code_dir)
......@@ -985,44 +990,18 @@ def load_experiment(args):
else:
shutil.copy(src_path, target_path)
# Step5. Copy searchSpace file
archive_search_space_dir = os.path.join(temp_root_dir, 'searchSpace')
if args.searchSpacePath:
target_path = os.path.expanduser(args.searchSpacePath)
else:
# set default path to codeDir
target_path = os.path.join(codeDir, 'search_space.json')
if not os.path.isabs(target_path):
target_path = os.path.join(os.getcwd(), target_path)
print_normal('Expand search space path to %s' % target_path)
nnictl_exp_config['searchSpacePath'] = target_path
# if the path already has a search space file, use the original one, otherwise use archived one
if not os.path.isfile(target_path):
if len(os.listdir(archive_search_space_dir)) == 0:
print_error('Archive file does not contain search space file!')
exit(1)
else:
for file in os.listdir(archive_search_space_dir):
source_path = os.path.join(archive_search_space_dir, file)
os.makedirs(os.path.dirname(target_path), exist_ok=True)
shutil.copyfile(source_path, target_path)
break
elif not args.searchSpacePath:
print_warning('%s exist, will not load search_space file' % target_path)
# Step6. Create experiment metadata
nni_config.set_config('experimentConfig', nnictl_exp_config)
experiment_config.add_experiment(experiment_id,
experiment_metadata.get('port'),
experiment_metadata.get('startTime'),
experiment_metadata.get('platform'),
experiment_metadata.get('experimentName'),
experiment_metadata.get('endTime'),
experiment_metadata.get('status'),
experiment_metadata.get('tag'),
experiment_metadata.get('pid'),
experiment_metadata.get('webUrl'),
experiment_metadata.get('logDir'))
# Step5. Create experiment metadata
experiments_config.add_experiment(experiment_id,
experiment_metadata.get('port'),
experiment_metadata.get('startTime'),
experiment_metadata.get('platform'),
experiment_metadata.get('experimentName'),
experiment_metadata.get('endTime'),
experiment_metadata.get('status'),
experiment_metadata.get('tag'),
experiment_metadata.get('pid'),
experiment_metadata.get('webUrl'),
logDir)
print_normal('Load experiment %s succsss!' % experiment_id)
# Step6. Cleanup temp data
......
......@@ -31,9 +31,9 @@ def parse_log_path(args, trial_content):
exit(1)
return path_list, host_list
def copy_data_from_remote(args, nni_config, trial_content, path_list, host_list, temp_nni_path):
def copy_data_from_remote(args, experiment_config, trial_content, path_list, host_list, temp_nni_path):
'''use ssh client to copy data from remote machine to local machien'''
machine_list = nni_config.get_config('experimentConfig').get('machineList')
machine_list = experiment_config.get('machineList')
machine_dict = {}
local_path_list = []
for machine in machine_list:
......@@ -49,15 +49,15 @@ def copy_data_from_remote(args, nni_config, trial_content, path_list, host_list,
print_normal('Copy done!')
return local_path_list
def get_path_list(args, nni_config, trial_content, temp_nni_path):
def get_path_list(args, experiment_config, trial_content, temp_nni_path):
'''get path list according to different platform'''
path_list, host_list = parse_log_path(args, trial_content)
platform = nni_config.get_config('experimentConfig').get('trainingServicePlatform')
platform = experiment_config.get('trainingServicePlatform')
if platform == 'local':
print_normal('Log path: %s' % ' '.join(path_list))
return path_list
elif platform == 'remote':
path_list = copy_data_from_remote(args, nni_config, trial_content, path_list, host_list, temp_nni_path)
path_list = copy_data_from_remote(args, experiment_config, trial_content, path_list, host_list, temp_nni_path)
print_normal('Log path: %s' % ' '.join(path_list))
return path_list
else:
......@@ -83,19 +83,19 @@ def start_tensorboard_process(args, experiment_id, path_list, temp_nni_path):
url_list = get_local_urls(args.port)
print_green('Start tensorboard success!')
print_normal('Tensorboard urls: ' + ' '.join(url_list))
experiment_config = Experiments()
tensorboard_process_pid_list = experiment_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
experiments_config = Experiments()
tensorboard_process_pid_list = experiments_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
if tensorboard_process_pid_list is None:
tensorboard_process_pid_list = [tensorboard_process.pid]
else:
tensorboard_process_pid_list.append(tensorboard_process.pid)
experiment_config.update_experiment(experiment_id, 'tensorboardPidList', tensorboard_process_pid_list)
experiments_config.update_experiment(experiment_id, 'tensorboardPidList', tensorboard_process_pid_list)
def stop_tensorboard(args):
'''stop tensorboard'''
experiment_id = check_experiment_id(args)
experiment_config = Experiments()
tensorboard_pid_list = experiment_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
experiments_config = Experiments()
tensorboard_pid_list = experiments_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
if tensorboard_pid_list:
for tensorboard_pid in tensorboard_pid_list:
try:
......@@ -103,7 +103,7 @@ def stop_tensorboard(args):
call(cmds)
except Exception as exception:
print_error(exception)
experiment_config.update_experiment(experiment_id, 'tensorboardPidList', [])
experiments_config.update_experiment(experiment_id, 'tensorboardPidList', [])
print_normal('Stop tensorboard success!')
else:
print_error('No tensorboard configuration!')
......@@ -128,17 +128,17 @@ def start_tensorboard(args):
return
if args.id is None:
args.id = experiment_id
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
if experiment_dict[args.id]["status"] == "STOPPED":
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
if experiments_dict[args.id]["status"] == "STOPPED":
print_error("Experiment {} is stopped...".format(args.id))
return
nni_config = Config(args.id)
if nni_config.get_config('experimentConfig').get('trainingServicePlatform') == 'adl':
experiment_config = Config(args.id, experiments_dict[args.id]['logDir']).get_config()
if experiment_config.get('trainingServicePlatform') == 'adl':
adl_tensorboard_helper(args)
return
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
rest_port = experiments_dict[args.id]['port']
rest_pid = experiments_dict[args.id]['pid']
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
......@@ -158,9 +158,9 @@ def start_tensorboard(args):
if len(trial_content) > 1 and not args.trial_id:
print_error('There are multiple trials, please set trial id!')
exit(1)
experiment_id = nni_config.get_config('experimentId')
experiment_id = args.id
temp_nni_path = os.path.join(tempfile.gettempdir(), 'nni', experiment_id)
os.makedirs(temp_nni_path, exist_ok=True)
path_list = get_path_list(args, nni_config, trial_content, temp_nni_path)
path_list = get_path_list(args, experiment_config, trial_content, temp_nni_path)
start_tensorboard_process(args, experiment_id, path_list, temp_nni_path)
......@@ -23,11 +23,12 @@ def validate_file(path):
def validate_dispatcher(args):
'''validate if the dispatcher of the experiment supports importing data'''
nni_config = Config(get_config_filename(args)).get_config('experimentConfig')
if nni_config.get('tuner') and nni_config['tuner'].get('builtinTunerName'):
dispatcher_name = nni_config['tuner']['builtinTunerName']
elif nni_config.get('advisor') and nni_config['advisor'].get('builtinAdvisorName'):
dispatcher_name = nni_config['advisor']['builtinAdvisorName']
experiment_id = get_config_filename(args)
experiment_config = Config(experiment_id, Experiments().get_all_experiments()[experiment_id]['logDir']).get_config()
if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
dispatcher_name = experiment_config['tuner']['builtinTunerName']
elif experiment_config.get('advisor') and experiment_config['advisor'].get('builtinAdvisorName'):
dispatcher_name = experiment_config['advisor']['builtinAdvisorName']
else: # otherwise it should be a customized one
return
if dispatcher_name not in TUNERS_SUPPORTING_IMPORT_DATA:
......@@ -58,9 +59,9 @@ def get_query_type(key):
def update_experiment_profile(args, key, value):
'''call restful server to update experiment profile'''
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port')
experiments_config = Experiments()
experiments_dict = experiments_config.get_all_experiments()
rest_port = experiments_dict.get(get_config_filename(args)).get('port')
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_get(experiment_url(rest_port), REST_TIME_OUT)
......@@ -117,9 +118,10 @@ def import_data(args):
validate_dispatcher(args)
content = load_search_space(args.filename)
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
experiments_dict = Experiments().get_all_experiments()
experiment_id = get_config_filename(args)
rest_port = experiments_dict.get(experiment_id).get('port')
rest_pid = experiments_dict.get(experiment_id).get('pid')
if not detect_process(rest_pid):
print_error('Experiment is not running...')
return
......@@ -137,8 +139,8 @@ def import_data(args):
def import_data_to_restful_server(args, content):
'''call restful server to import data to the experiment'''
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
experiments_dict = Experiments().get_all_experiments()
rest_port = experiments_dict.get(get_config_filename(args)).get('port')
running, _ = check_rest_server_quick(rest_port)
if running:
response = rest_post(import_data_url(rest_port), content, REST_TIME_OUT)
......
......@@ -4,31 +4,27 @@
import argparse
from pathlib import Path
from subprocess import Popen, PIPE, STDOUT
from nni.tools.nnictl.config_utils import Config, Experiments
from nni.tools.nnictl.config_utils import Experiments
from nni.tools.nnictl.common_utils import print_green
from nni.tools.nnictl.command_utils import kill_command
from nni.tools.nnictl.nnictl_utils import get_yml_content
def create_mock_experiment():
nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment('xOpEwA5w', '8080', 123456,
nnictl_experiment_config.add_experiment('xOpEwA5w', 8080, 123456,
'local', 'example_sklearn-classification')
nni_config = Config('xOpEwA5w')
# mock process
cmds = ['sleep', '3600000']
process = Popen(cmds, stdout=PIPE, stderr=STDOUT)
nni_config.set_config('restServerPid', process.pid)
nni_config.set_config('experimentId', 'xOpEwA5w')
nni_config.set_config('restServerPort', 8080)
nni_config.set_config('webuiUrl', ['http://localhost:8080'])
yml_path = Path(__file__).parents[1] / 'config_files/valid/test.yml'
experiment_config = get_yml_content(str(yml_path))
nni_config.set_config('experimentConfig', experiment_config)
nnictl_experiment_config.update_experiment('xOpEwA5w', 'pid', process.pid)
nnictl_experiment_config.update_experiment('xOpEwA5w', 'port', 8080)
nnictl_experiment_config.update_experiment('xOpEwA5w', 'webuiUrl', ['http://localhost:8080'])
print_green("expriment start success, experiment id: xOpEwA5w")
def stop_mock_experiment():
config = Config('config')
kill_command(config.get_config('restServerPid'))
nnictl_experiment_config = Experiments()
experiments_dict = nnictl_experiment_config.get_all_experiments()
kill_command(experiments_dict['xOpEwA5w'].get('pid'))
nnictl_experiment_config = Experiments()
nnictl_experiment_config.remove_experiment('xOpEwA5w')
......
{"experimentConfig": {"authorName": "default", "experimentName": "example_sklearn-classification", "trialConcurrency": 5, "maxExecDuration": 3600, "maxTrialNum": 100, "trainingServicePlatform": "local", "searchSpacePath": "../../../config_files/valid/search_space.json", "useAnnotation": false, "tuner": {"builtinTunerName": "TPE", "classArgs": {"optimize_mode": "maximize"}}, "trial": {"command": "python3 main.py", "codeDir": "../../../config_files/valid/.", "gpuNum": 0}}, "restServerPort": 8080, "restServerPid": 7952, "experimentId": "xOpEwA5w", "webuiUrl": ["http://localhost:8080"]}
{"experimentId": "xOpEwA5w"}
\ No newline at end of file
......@@ -9,31 +9,16 @@ HOME_PATH = str(Path(__file__).parent / "mock/nnictl_metadata")
class CommonUtilsTestCase(TestCase):
# FIXME:
# `experiment.get_all_experiments()` returns empty dict. No idea why.
# Don't want to debug this because I will port the logic to `nni.experiment`.
#def test_get_experiment(self):
# experiment = Experiments(HOME_PATH)
# self.assertTrue('xOpEwA5w' in experiment.get_all_experiments())
def test_update_experiment(self):
experiment = Experiments(HOME_PATH)
experiment.add_experiment('xOpEwA5w', 8081, 'N/A', 'local', 'test', endTime='N/A', status='INITIALIZED')
self.assertTrue('xOpEwA5w' in experiment.get_all_experiments())
experiment.remove_experiment('xOpEwA5w')
self.assertFalse('xOpEwA5w' in experiment.get_all_experiments())
def test_get_config(self):
config = Config('config', HOME_PATH)
self.assertEqual(config.get_config('experimentId'), 'xOpEwA5w')
def test_set_config(self):
config = Config('config', HOME_PATH)
self.assertNotEqual(config.get_config('experimentId'), 'testId')
config.set_config('experimentId', 'testId')
self.assertEqual(config.get_config('experimentId'), 'testId')
config.set_config('experimentId', 'xOpEwA5w')
config = Config('xOpEwA5w', HOME_PATH)
self.assertEqual(config.get_config()['experimentName'], 'test_config')
if __name__ == '__main__':
main()
......@@ -54,7 +54,7 @@ class CommonUtilsTestCase(TestCase):
@responses.activate
def test_get_experiment_port(self):
args = generate_args()
self.assertEqual('8080', get_experiment_port(args))
self.assertEqual(8080, get_experiment_port(args))
@responses.activate
def test_check_rest(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment