"driver/driver.hip.cpp" did not exist on "120ab94aa18d00bf5fdf6b77b512a7f702425e80"
Unverified commit e6629a76, authored by J-shang and committed by GitHub
Browse files

Merge .config & nni.sqlite (#3309)

* refactor util config

* merge NNICTL_HOME_DIR and NNI_HOME_DIR

* add default logDir and fix Config logdir

* fix save/load

* modify updater and tensorboard

* rename

* rename

* fix local clusterMetaData

* fix local config convert

* fix ut

* fix ut

* del unused code

* fix cursor
parent 85c0d841
...@@ -293,10 +293,11 @@ def to_rest_json(config: ExperimentConfig) -> Dict[str, Any]: ...@@ -293,10 +293,11 @@ def to_rest_json(config: ExperimentConfig) -> Dict[str, Any]:
request_data['logCollection'] = experiment_config.get('logCollection') request_data['logCollection'] = experiment_config.get('logCollection')
request_data['clusterMetaData'] = [] request_data['clusterMetaData'] = []
if experiment_config['trainingServicePlatform'] == 'local': if experiment_config['trainingServicePlatform'] == 'local':
if experiment_config.get('localConfig'):
request_data['clusterMetaData'].append(
{'key': 'local_config', 'value': experiment_config['localConfig']})
request_data['clusterMetaData'].append( request_data['clusterMetaData'].append(
{'key':'codeDir', 'value':experiment_config['trial']['codeDir']}) {'key': 'trial_config', 'value': experiment_config['trial']})
request_data['clusterMetaData'].append(
{'key': 'command', 'value': experiment_config['trial']['command']})
elif experiment_config['trainingServicePlatform'] == 'remote': elif experiment_config['trainingServicePlatform'] == 'remote':
request_data['clusterMetaData'].append( request_data['clusterMetaData'].append(
{'key': 'machine_list', 'value': experiment_config['machineList']}) {'key': 'machine_list', 'value': experiment_config['machineList']})
......
...@@ -119,5 +119,5 @@ def _init_experiment(config: ExperimentConfig, port: int, debug: bool) -> None: ...@@ -119,5 +119,5 @@ def _init_experiment(config: ExperimentConfig, port: int, debug: bool) -> None:
def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str, name: str, pid: int, logDir: str) -> None: def _save_experiment_information(experiment_id: str, port: int, start_time: int, platform: str, name: str, pid: int, logDir: str) -> None:
experiment_config = Experiments() experiments_config = Experiments()
experiment_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir) experiments_config.add_experiment(experiment_id, port, start_time, platform, name, pid=pid, logDir=logDir)
...@@ -4,56 +4,97 @@ ...@@ -4,56 +4,97 @@
import os import os
import json import json
import shutil import shutil
import sqlite3
import time import time
from .constants import NNICTL_HOME_DIR from .constants import NNI_HOME_DIR
from .command_utils import print_error from .command_utils import print_error
from .common_utils import get_file_lock from .common_utils import get_file_lock
def config_v0_to_v1(config: dict) -> dict:
    '''Convert a stored experiment config that still carries `clusterMetaData`
    back to the layout with per-platform top-level keys.

    A config without `clusterMetaData` is returned unchanged (same object).
    Otherwise a deep copy is made, the cluster metadata entries are turned back
    into keys such as `trial` / `machineList` / `hybridConfig`, and the
    `clusterMetaData` key is dropped from the copy. The input is never mutated.

    Raises:
        RuntimeError: if `trainingServicePlatform` is missing, or if a hybrid
            config has no `hybrid_config` entry in its cluster metadata.
    '''
    if 'clusterMetaData' not in config:
        return config
    if 'trainingServicePlatform' not in config:
        raise RuntimeError('experiment config key `trainingServicePlatform` not found')

    import copy
    experiment_config = copy.deepcopy(config)
    metadata = experiment_config.pop('clusterMetaData')
    platform = experiment_config['trainingServicePlatform']

    if platform == 'hybrid':
        # Cluster metadata is a list of {'key': ..., 'value': ...} entries, so the
        # hybrid config has to be searched for rather than indexed by name.
        # A plain mapping is still accepted, matching the previous indexing.
        if isinstance(metadata, dict):
            hybrid_config = metadata['hybrid_config']
        else:
            for kv in metadata:
                if kv['key'] == 'hybrid_config':
                    hybrid_config = kv['value']
                    break
            else:
                raise RuntimeError('cluster metadata key `hybrid_config` not found')
        inverse_config = {'hybridConfig': hybrid_config}
        # Each platform participating in the hybrid run contributes its own keys.
        for sub_platform in hybrid_config['trainingServicePlatforms']:
            inverse_config.update(_inverse_cluster_metadata(sub_platform, metadata))
    else:
        inverse_config = _inverse_cluster_metadata(platform, metadata)

    experiment_config.update(inverse_config)
    return experiment_config
def _inverse_cluster_metadata(platform: str, metadata_config: list) -> dict:
inverse_config = {}
if platform == 'local':
inverse_config['trial'] = {}
for kv in metadata_config:
if kv['key'] == 'local_config':
inverse_config['localConfig'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
elif platform == 'remote':
for kv in metadata_config:
if kv['key'] == 'machine_list':
inverse_config['machineList'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
elif kv['key'] == 'remote_config':
inverse_config['remoteConfig'] = kv['value']
elif platform == 'pai':
for kv in metadata_config:
if kv['key'] == 'pai_config':
inverse_config['paiConfig'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
elif platform == 'kubeflow':
for kv in metadata_config:
if kv['key'] == 'kubeflow_config':
inverse_config['kubeflowConfig'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
elif platform == 'frameworkcontroller':
for kv in metadata_config:
if kv['key'] == 'frameworkcontroller_config':
inverse_config['frameworkcontrollerConfig'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
elif platform == 'aml':
for kv in metadata_config:
if kv['key'] == 'aml_config':
inverse_config['amlConfig'] = kv['value']
elif kv['key'] == 'trial_config':
inverse_config['trial'] = kv['value']
else:
raise RuntimeError('training service platform not found')
return inverse_config
class Config: class Config:
'''a util class to load and save config''' '''a util class to load and save config'''
def __init__(self, file_path, home_dir=NNICTL_HOME_DIR): def __init__(self, experiment_id: str, log_dir: str):
config_path = os.path.join(home_dir, str(file_path)) self.experiment_id = experiment_id
os.makedirs(config_path, exist_ok=True) self.conn = sqlite3.connect(os.path.join(log_dir, experiment_id, 'db', 'nni.sqlite'))
self.config_file = os.path.join(config_path, '.config') self.refresh_config()
self.config = self.read_file()
def get_all_config(self): def refresh_config(self):
'''get all of config values''' '''refresh to get latest config'''
return json.dumps(self.config, indent=4, sort_keys=True, separators=(',', ':')) sql = 'select params from ExperimentProfile where id=? order by revision DESC'
args = (self.experiment_id,)
self.config = config_v0_to_v1(json.loads(self.conn.cursor().execute(sql, args).fetchone()[0]))
def set_config(self, key, value): def get_config(self):
'''set {key:value} paris to self.config'''
self.config = self.read_file()
self.config[key] = value
self.write_file()
def get_config(self, key):
'''get a value according to key''' '''get a value according to key'''
return self.config.get(key) return self.config
def write_file(self):
'''save config to local file'''
if self.config:
try:
with open(self.config_file, 'w') as file:
json.dump(self.config, file, indent=4)
except IOError as error:
print('Error:', error)
return
def read_file(self):
'''load config from local file'''
if os.path.exists(self.config_file):
try:
with open(self.config_file, 'r') as file:
return json.load(file)
except ValueError:
return {}
return {}
class Experiments: class Experiments:
'''Maintain experiment list''' '''Maintain experiment list'''
def __init__(self, home_dir=NNICTL_HOME_DIR): def __init__(self, home_dir=NNI_HOME_DIR):
os.makedirs(home_dir, exist_ok=True) os.makedirs(home_dir, exist_ok=True)
self.experiment_file = os.path.join(home_dir, '.experiment') self.experiment_file = os.path.join(home_dir, '.experiment')
self.lock = get_file_lock(self.experiment_file, stale=2) self.lock = get_file_lock(self.experiment_file, stale=2)
...@@ -61,7 +102,7 @@ class Experiments: ...@@ -61,7 +102,7 @@ class Experiments:
self.experiments = self.read_file() self.experiments = self.read_file()
def add_experiment(self, expId, port, startTime, platform, experiment_name, endTime='N/A', status='INITIALIZED', def add_experiment(self, expId, port, startTime, platform, experiment_name, endTime='N/A', status='INITIALIZED',
tag=[], pid=None, webuiUrl=[], logDir=[]): tag=[], pid=None, webuiUrl=[], logDir=''):
'''set {key:value} pairs to self.experiment''' '''set {key:value} pairs to self.experiment'''
with self.lock: with self.lock:
self.experiments = self.read_file() self.experiments = self.read_file()
...@@ -98,13 +139,6 @@ class Experiments: ...@@ -98,13 +139,6 @@ class Experiments:
self.experiments = self.read_file() self.experiments = self.read_file()
if expId in self.experiments: if expId in self.experiments:
self.experiments.pop(expId) self.experiments.pop(expId)
fileName = expId
if fileName:
logPath = os.path.join(NNICTL_HOME_DIR, fileName)
try:
shutil.rmtree(logPath)
except FileNotFoundError:
print_error('{0} does not exist.'.format(logPath))
self.write_file() self.write_file()
def get_all_experiments(self): def get_all_experiments(self):
......
...@@ -4,8 +4,6 @@ ...@@ -4,8 +4,6 @@
import os import os
from colorama import Fore from colorama import Fore
NNICTL_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments')
NNI_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments') NNI_HOME_DIR = os.path.join(os.path.expanduser('~'), 'nni-experiments')
ERROR_INFO = 'ERROR: ' ERROR_INFO = 'ERROR: '
......
...@@ -20,15 +20,15 @@ from .config_utils import Config, Experiments ...@@ -20,15 +20,15 @@ from .config_utils import Config, Experiments
from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \ from .common_utils import get_yml_content, get_json_content, print_error, print_normal, print_warning, \
detect_port, get_user detect_port, get_user
from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER from .constants import NNI_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SUCCESS_INFO, LOG_HEADER
from .command_utils import check_output_command, kill_command from .command_utils import check_output_command, kill_command
from .nnictl_utils import update_experiment from .nnictl_utils import update_experiment
def get_log_path(experiment_id): def get_log_path(experiment_id):
'''generate stdout and stderr log path''' '''generate stdout and stderr log path'''
os.makedirs(os.path.join(NNICTL_HOME_DIR, experiment_id, 'log'), exist_ok=True) os.makedirs(os.path.join(NNI_HOME_DIR, experiment_id, 'log'), exist_ok=True)
stdout_full_path = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log', 'nnictl_stdout.log') stdout_full_path = os.path.join(NNI_HOME_DIR, experiment_id, 'log', 'nnictl_stdout.log')
stderr_full_path = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log', 'nnictl_stderr.log') stderr_full_path = os.path.join(NNI_HOME_DIR, experiment_id, 'log', 'nnictl_stderr.log')
return stdout_full_path, stderr_full_path return stdout_full_path, stderr_full_path
def print_log_content(config_file_name): def print_log_content(config_file_name):
...@@ -375,10 +375,11 @@ def set_experiment(experiment_config, mode, port, config_file_name): ...@@ -375,10 +375,11 @@ def set_experiment(experiment_config, mode, port, config_file_name):
request_data['logCollection'] = experiment_config.get('logCollection') request_data['logCollection'] = experiment_config.get('logCollection')
request_data['clusterMetaData'] = [] request_data['clusterMetaData'] = []
if experiment_config['trainingServicePlatform'] == 'local': if experiment_config['trainingServicePlatform'] == 'local':
if experiment_config.get('localConfig'):
request_data['clusterMetaData'].append(
{'key': 'local_config', 'value': experiment_config['localConfig']})
request_data['clusterMetaData'].append( request_data['clusterMetaData'].append(
{'key':'codeDir', 'value':experiment_config['trial']['codeDir']}) {'key': 'trial_config', 'value': experiment_config['trial']})
request_data['clusterMetaData'].append(
{'key': 'command', 'value': experiment_config['trial']['command']})
elif experiment_config['trainingServicePlatform'] == 'remote': elif experiment_config['trainingServicePlatform'] == 'remote':
request_data['clusterMetaData'].append( request_data['clusterMetaData'].append(
{'key': 'machine_list', 'value': experiment_config['machineList']}) {'key': 'machine_list', 'value': experiment_config['machineList']})
...@@ -479,7 +480,6 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res ...@@ -479,7 +480,6 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
def launch_experiment(args, experiment_config, mode, experiment_id): def launch_experiment(args, experiment_config, mode, experiment_id):
'''follow steps to start rest server and start experiment''' '''follow steps to start rest server and start experiment'''
nni_config = Config(experiment_id)
# check packages for tuner # check packages for tuner
package_name, module_name = None, None package_name, module_name = None, None
if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'): if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
...@@ -499,7 +499,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id): ...@@ -499,7 +499,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
if package_name in ['SMAC', 'BOHB', 'PPOTuner']: if package_name in ['SMAC', 'BOHB', 'PPOTuner']:
print_error(f'The dependencies for {package_name} can be installed through pip install nni[{package_name}]') print_error(f'The dependencies for {package_name} can be installed through pip install nni[{package_name}]')
raise raise
log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else None log_dir = experiment_config['logDir'] if experiment_config.get('logDir') else NNI_HOME_DIR
log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None log_level = experiment_config['logLevel'] if experiment_config.get('logLevel') else None
#view experiment mode do not need debug function, when view an experiment, there will be no new logs created #view experiment mode do not need debug function, when view an experiment, there will be no new logs created
foreground = False foreground = False
...@@ -510,12 +510,10 @@ def launch_experiment(args, experiment_config, mode, experiment_id): ...@@ -510,12 +510,10 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
# start rest server # start rest server
rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \ rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], \
mode, experiment_id, foreground, log_dir, log_level) mode, experiment_id, foreground, log_dir, log_level)
nni_config.set_config('restServerPid', rest_process.pid)
# save experiment information # save experiment information
nnictl_experiment_config = Experiments() Experiments().add_experiment(experiment_id, args.port, start_time,
nnictl_experiment_config.add_experiment(experiment_id, args.port, start_time, experiment_config['trainingServicePlatform'],
experiment_config['trainingServicePlatform'], experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
experiment_config['experimentName'], pid=rest_process.pid, logDir=log_dir)
# Deal with annotation # Deal with annotation
if experiment_config.get('useAnnotation'): if experiment_config.get('useAnnotation'):
path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation') path = os.path.join(tempfile.gettempdir(), get_user(), 'nni', 'annotation')
...@@ -572,7 +570,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id): ...@@ -572,7 +570,7 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
web_ui_url_list = ['http://{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))] web_ui_url_list = ['http://{0}:{1}'.format(experiment_config['nniManagerIp'], str(args.port))]
else: else:
web_ui_url_list = get_local_urls(args.port) web_ui_url_list = get_local_urls(args.port)
nni_config.set_config('webuiUrl', web_ui_url_list) Experiments().update_experiment(experiment_id, 'webuiUrl', web_ui_url_list)
print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list))) print_normal(EXPERIMENT_SUCCESS_INFO % (experiment_id, ' '.join(web_ui_url_list)))
if mode != 'view' and args.foreground: if mode != 'view' and args.foreground:
...@@ -587,8 +585,6 @@ def launch_experiment(args, experiment_config, mode, experiment_id): ...@@ -587,8 +585,6 @@ def launch_experiment(args, experiment_config, mode, experiment_id):
def create_experiment(args): def create_experiment(args):
'''start a new experiment''' '''start a new experiment'''
experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8)) experiment_id = ''.join(random.sample(string.ascii_letters + string.digits, 8))
nni_config = Config(experiment_id)
nni_config.set_config('experimentId', experiment_id)
config_path = os.path.abspath(args.config) config_path = os.path.abspath(args.config)
if not os.path.exists(config_path): if not os.path.exists(config_path):
print_error('Please set correct config path!') print_error('Please set correct config path!')
...@@ -610,8 +606,6 @@ def create_experiment(args): ...@@ -610,8 +606,6 @@ def create_experiment(args):
print_error(f'Config in v1 format validation failed. {repr(e)}') print_error(f'Config in v1 format validation failed. {repr(e)}')
exit(1) exit(1)
nni_config.set_config('experimentConfig', experiment_config)
nni_config.set_config('restServerPort', args.port)
try: try:
launch_experiment(args, experiment_config, 'new', experiment_id) launch_experiment(args, experiment_config, 'new', experiment_id)
except Exception as exception: except Exception as exception:
...@@ -624,8 +618,8 @@ def create_experiment(args): ...@@ -624,8 +618,8 @@ def create_experiment(args):
def manage_stopped_experiment(args, mode): def manage_stopped_experiment(args, mode):
'''view a stopped experiment''' '''view a stopped experiment'''
update_experiment() update_experiment()
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
experiment_id = None experiment_id = None
#find the latest stopped experiment #find the latest stopped experiment
if not args.id: if not args.id:
...@@ -633,17 +627,16 @@ def manage_stopped_experiment(args, mode): ...@@ -633,17 +627,16 @@ def manage_stopped_experiment(args, mode):
'You could use \'nnictl experiment list --all\' to show all experiments!'.format(mode)) 'You could use \'nnictl experiment list --all\' to show all experiments!'.format(mode))
exit(1) exit(1)
else: else:
if experiment_dict.get(args.id) is None: if experiments_dict.get(args.id) is None:
print_error('Id %s not exist!' % args.id) print_error('Id %s not exist!' % args.id)
exit(1) exit(1)
if experiment_dict[args.id]['status'] != 'STOPPED': if experiments_dict[args.id]['status'] != 'STOPPED':
print_error('Only stopped experiments can be {0}ed!'.format(mode)) print_error('Only stopped experiments can be {0}ed!'.format(mode))
exit(1) exit(1)
experiment_id = args.id experiment_id = args.id
print_normal('{0} experiment {1}...'.format(mode, experiment_id)) print_normal('{0} experiment {1}...'.format(mode, experiment_id))
nni_config = Config(experiment_id) experiment_config = Config(experiment_id, experiments_dict[args.id]['logDir']).get_config()
experiment_config = nni_config.get_config('experimentConfig') experiments_config.update_experiment(args.id, 'port', args.port)
nni_config.set_config('restServerPort', args.port)
try: try:
launch_experiment(args, experiment_config, mode, experiment_id) launch_experiment(args, experiment_config, mode, experiment_id)
except Exception as exception: except Exception as exception:
......
This diff is collapsed.
...@@ -31,9 +31,9 @@ def parse_log_path(args, trial_content): ...@@ -31,9 +31,9 @@ def parse_log_path(args, trial_content):
exit(1) exit(1)
return path_list, host_list return path_list, host_list
def copy_data_from_remote(args, nni_config, trial_content, path_list, host_list, temp_nni_path): def copy_data_from_remote(args, experiment_config, trial_content, path_list, host_list, temp_nni_path):
'''use ssh client to copy data from remote machine to local machien''' '''use ssh client to copy data from remote machine to local machien'''
machine_list = nni_config.get_config('experimentConfig').get('machineList') machine_list = experiment_config.get('machineList')
machine_dict = {} machine_dict = {}
local_path_list = [] local_path_list = []
for machine in machine_list: for machine in machine_list:
...@@ -49,15 +49,15 @@ def copy_data_from_remote(args, nni_config, trial_content, path_list, host_list, ...@@ -49,15 +49,15 @@ def copy_data_from_remote(args, nni_config, trial_content, path_list, host_list,
print_normal('Copy done!') print_normal('Copy done!')
return local_path_list return local_path_list
def get_path_list(args, nni_config, trial_content, temp_nni_path): def get_path_list(args, experiment_config, trial_content, temp_nni_path):
'''get path list according to different platform''' '''get path list according to different platform'''
path_list, host_list = parse_log_path(args, trial_content) path_list, host_list = parse_log_path(args, trial_content)
platform = nni_config.get_config('experimentConfig').get('trainingServicePlatform') platform = experiment_config.get('trainingServicePlatform')
if platform == 'local': if platform == 'local':
print_normal('Log path: %s' % ' '.join(path_list)) print_normal('Log path: %s' % ' '.join(path_list))
return path_list return path_list
elif platform == 'remote': elif platform == 'remote':
path_list = copy_data_from_remote(args, nni_config, trial_content, path_list, host_list, temp_nni_path) path_list = copy_data_from_remote(args, experiment_config, trial_content, path_list, host_list, temp_nni_path)
print_normal('Log path: %s' % ' '.join(path_list)) print_normal('Log path: %s' % ' '.join(path_list))
return path_list return path_list
else: else:
...@@ -83,19 +83,19 @@ def start_tensorboard_process(args, experiment_id, path_list, temp_nni_path): ...@@ -83,19 +83,19 @@ def start_tensorboard_process(args, experiment_id, path_list, temp_nni_path):
url_list = get_local_urls(args.port) url_list = get_local_urls(args.port)
print_green('Start tensorboard success!') print_green('Start tensorboard success!')
print_normal('Tensorboard urls: ' + ' '.join(url_list)) print_normal('Tensorboard urls: ' + ' '.join(url_list))
experiment_config = Experiments() experiments_config = Experiments()
tensorboard_process_pid_list = experiment_config.get_all_experiments().get(experiment_id).get('tensorboardPidList') tensorboard_process_pid_list = experiments_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
if tensorboard_process_pid_list is None: if tensorboard_process_pid_list is None:
tensorboard_process_pid_list = [tensorboard_process.pid] tensorboard_process_pid_list = [tensorboard_process.pid]
else: else:
tensorboard_process_pid_list.append(tensorboard_process.pid) tensorboard_process_pid_list.append(tensorboard_process.pid)
experiment_config.update_experiment(experiment_id, 'tensorboardPidList', tensorboard_process_pid_list) experiments_config.update_experiment(experiment_id, 'tensorboardPidList', tensorboard_process_pid_list)
def stop_tensorboard(args): def stop_tensorboard(args):
'''stop tensorboard''' '''stop tensorboard'''
experiment_id = check_experiment_id(args) experiment_id = check_experiment_id(args)
experiment_config = Experiments() experiments_config = Experiments()
tensorboard_pid_list = experiment_config.get_all_experiments().get(experiment_id).get('tensorboardPidList') tensorboard_pid_list = experiments_config.get_all_experiments().get(experiment_id).get('tensorboardPidList')
if tensorboard_pid_list: if tensorboard_pid_list:
for tensorboard_pid in tensorboard_pid_list: for tensorboard_pid in tensorboard_pid_list:
try: try:
...@@ -103,7 +103,7 @@ def stop_tensorboard(args): ...@@ -103,7 +103,7 @@ def stop_tensorboard(args):
call(cmds) call(cmds)
except Exception as exception: except Exception as exception:
print_error(exception) print_error(exception)
experiment_config.update_experiment(experiment_id, 'tensorboardPidList', []) experiments_config.update_experiment(experiment_id, 'tensorboardPidList', [])
print_normal('Stop tensorboard success!') print_normal('Stop tensorboard success!')
else: else:
print_error('No tensorboard configuration!') print_error('No tensorboard configuration!')
...@@ -128,17 +128,17 @@ def start_tensorboard(args): ...@@ -128,17 +128,17 @@ def start_tensorboard(args):
return return
if args.id is None: if args.id is None:
args.id = experiment_id args.id = experiment_id
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
if experiment_dict[args.id]["status"] == "STOPPED": if experiments_dict[args.id]["status"] == "STOPPED":
print_error("Experiment {} is stopped...".format(args.id)) print_error("Experiment {} is stopped...".format(args.id))
return return
nni_config = Config(args.id) experiment_config = Config(args.id, experiments_dict[args.id]['logDir']).get_config()
if nni_config.get_config('experimentConfig').get('trainingServicePlatform') == 'adl': if experiment_config.get('trainingServicePlatform') == 'adl':
adl_tensorboard_helper(args) adl_tensorboard_helper(args)
return return
rest_port = nni_config.get_config('restServerPort') rest_port = experiments_dict[args.id]['port']
rest_pid = nni_config.get_config('restServerPid') rest_pid = experiments_dict[args.id]['pid']
if not detect_process(rest_pid): if not detect_process(rest_pid):
print_error('Experiment is not running...') print_error('Experiment is not running...')
return return
...@@ -158,9 +158,9 @@ def start_tensorboard(args): ...@@ -158,9 +158,9 @@ def start_tensorboard(args):
if len(trial_content) > 1 and not args.trial_id: if len(trial_content) > 1 and not args.trial_id:
print_error('There are multiple trials, please set trial id!') print_error('There are multiple trials, please set trial id!')
exit(1) exit(1)
experiment_id = nni_config.get_config('experimentId') experiment_id = args.id
temp_nni_path = os.path.join(tempfile.gettempdir(), 'nni', experiment_id) temp_nni_path = os.path.join(tempfile.gettempdir(), 'nni', experiment_id)
os.makedirs(temp_nni_path, exist_ok=True) os.makedirs(temp_nni_path, exist_ok=True)
path_list = get_path_list(args, nni_config, trial_content, temp_nni_path) path_list = get_path_list(args, experiment_config, trial_content, temp_nni_path)
start_tensorboard_process(args, experiment_id, path_list, temp_nni_path) start_tensorboard_process(args, experiment_id, path_list, temp_nni_path)
...@@ -23,11 +23,12 @@ def validate_file(path): ...@@ -23,11 +23,12 @@ def validate_file(path):
def validate_dispatcher(args): def validate_dispatcher(args):
'''validate if the dispatcher of the experiment supports importing data''' '''validate if the dispatcher of the experiment supports importing data'''
nni_config = Config(get_config_filename(args)).get_config('experimentConfig') experiment_id = get_config_filename(args)
if nni_config.get('tuner') and nni_config['tuner'].get('builtinTunerName'): experiment_config = Config(experiment_id, Experiments().get_all_experiments()[experiment_id]['logDir']).get_config()
dispatcher_name = nni_config['tuner']['builtinTunerName'] if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
elif nni_config.get('advisor') and nni_config['advisor'].get('builtinAdvisorName'): dispatcher_name = experiment_config['tuner']['builtinTunerName']
dispatcher_name = nni_config['advisor']['builtinAdvisorName'] elif experiment_config.get('advisor') and experiment_config['advisor'].get('builtinAdvisorName'):
dispatcher_name = experiment_config['advisor']['builtinAdvisorName']
else: # otherwise it should be a customized one else: # otherwise it should be a customized one
return return
if dispatcher_name not in TUNERS_SUPPORTING_IMPORT_DATA: if dispatcher_name not in TUNERS_SUPPORTING_IMPORT_DATA:
...@@ -58,9 +59,9 @@ def get_query_type(key): ...@@ -58,9 +59,9 @@ def get_query_type(key):
def update_experiment_profile(args, key, value): def update_experiment_profile(args, key, value):
'''call restful server to update experiment profile''' '''call restful server to update experiment profile'''
experiment_config = Experiments() experiments_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiments_dict = experiments_config.get_all_experiments()
rest_port = experiment_dict.get(get_config_filename(args)).get('port') rest_port = experiments_dict.get(get_config_filename(args)).get('port')
running, _ = check_rest_server_quick(rest_port) running, _ = check_rest_server_quick(rest_port)
if running: if running:
response = rest_get(experiment_url(rest_port), REST_TIME_OUT) response = rest_get(experiment_url(rest_port), REST_TIME_OUT)
...@@ -117,9 +118,10 @@ def import_data(args): ...@@ -117,9 +118,10 @@ def import_data(args):
validate_dispatcher(args) validate_dispatcher(args)
content = load_search_space(args.filename) content = load_search_space(args.filename)
nni_config = Config(get_config_filename(args)) experiments_dict = Experiments().get_all_experiments()
rest_port = nni_config.get_config('restServerPort') experiment_id = get_config_filename(args)
rest_pid = nni_config.get_config('restServerPid') rest_port = experiments_dict.get(experiment_id).get('port')
rest_pid = experiments_dict.get(experiment_id).get('pid')
if not detect_process(rest_pid): if not detect_process(rest_pid):
print_error('Experiment is not running...') print_error('Experiment is not running...')
return return
...@@ -137,8 +139,8 @@ def import_data(args): ...@@ -137,8 +139,8 @@ def import_data(args):
def import_data_to_restful_server(args, content): def import_data_to_restful_server(args, content):
'''call restful server to import data to the experiment''' '''call restful server to import data to the experiment'''
nni_config = Config(get_config_filename(args)) experiments_dict = Experiments().get_all_experiments()
rest_port = nni_config.get_config('restServerPort') rest_port = experiments_dict.get(get_config_filename(args)).get('port')
running, _ = check_rest_server_quick(rest_port) running, _ = check_rest_server_quick(rest_port)
if running: if running:
response = rest_post(import_data_url(rest_port), content, REST_TIME_OUT) response = rest_post(import_data_url(rest_port), content, REST_TIME_OUT)
......
...@@ -4,31 +4,27 @@ ...@@ -4,31 +4,27 @@
import argparse import argparse
from pathlib import Path from pathlib import Path
from subprocess import Popen, PIPE, STDOUT from subprocess import Popen, PIPE, STDOUT
from nni.tools.nnictl.config_utils import Config, Experiments from nni.tools.nnictl.config_utils import Experiments
from nni.tools.nnictl.common_utils import print_green from nni.tools.nnictl.common_utils import print_green
from nni.tools.nnictl.command_utils import kill_command from nni.tools.nnictl.command_utils import kill_command
from nni.tools.nnictl.nnictl_utils import get_yml_content from nni.tools.nnictl.nnictl_utils import get_yml_content
def create_mock_experiment(): def create_mock_experiment():
nnictl_experiment_config = Experiments() nnictl_experiment_config = Experiments()
nnictl_experiment_config.add_experiment('xOpEwA5w', '8080', 123456, nnictl_experiment_config.add_experiment('xOpEwA5w', 8080, 123456,
'local', 'example_sklearn-classification') 'local', 'example_sklearn-classification')
nni_config = Config('xOpEwA5w')
# mock process # mock process
cmds = ['sleep', '3600000'] cmds = ['sleep', '3600000']
process = Popen(cmds, stdout=PIPE, stderr=STDOUT) process = Popen(cmds, stdout=PIPE, stderr=STDOUT)
nni_config.set_config('restServerPid', process.pid) nnictl_experiment_config.update_experiment('xOpEwA5w', 'pid', process.pid)
nni_config.set_config('experimentId', 'xOpEwA5w') nnictl_experiment_config.update_experiment('xOpEwA5w', 'port', 8080)
nni_config.set_config('restServerPort', 8080) nnictl_experiment_config.update_experiment('xOpEwA5w', 'webuiUrl', ['http://localhost:8080'])
nni_config.set_config('webuiUrl', ['http://localhost:8080'])
yml_path = Path(__file__).parents[1] / 'config_files/valid/test.yml'
experiment_config = get_yml_content(str(yml_path))
nni_config.set_config('experimentConfig', experiment_config)
print_green("expriment start success, experiment id: xOpEwA5w") print_green("expriment start success, experiment id: xOpEwA5w")
def stop_mock_experiment(): def stop_mock_experiment():
config = Config('config') nnictl_experiment_config = Experiments()
kill_command(config.get_config('restServerPid')) experiments_dict = nnictl_experiment_config.get_all_experiments()
kill_command(experiments_dict['xOpEwA5w'].get('pid'))
nnictl_experiment_config = Experiments() nnictl_experiment_config = Experiments()
nnictl_experiment_config.remove_experiment('xOpEwA5w') nnictl_experiment_config.remove_experiment('xOpEwA5w')
......
{"experimentConfig": {"authorName": "default", "experimentName": "example_sklearn-classification", "trialConcurrency": 5, "maxExecDuration": 3600, "maxTrialNum": 100, "trainingServicePlatform": "local", "searchSpacePath": "../../../config_files/valid/search_space.json", "useAnnotation": false, "tuner": {"builtinTunerName": "TPE", "classArgs": {"optimize_mode": "maximize"}}, "trial": {"command": "python3 main.py", "codeDir": "../../../config_files/valid/.", "gpuNum": 0}}, "restServerPort": 8080, "restServerPid": 7952, "experimentId": "xOpEwA5w", "webuiUrl": ["http://localhost:8080"]}
{"experimentId": "xOpEwA5w"}
\ No newline at end of file
...@@ -9,31 +9,16 @@ HOME_PATH = str(Path(__file__).parent / "mock/nnictl_metadata") ...@@ -9,31 +9,16 @@ HOME_PATH = str(Path(__file__).parent / "mock/nnictl_metadata")
class CommonUtilsTestCase(TestCase): class CommonUtilsTestCase(TestCase):
# FIXME:
# `experiment.get_all_experiments()` returns empty dict. No idea why.
# Don't want to debug this because I will port the logic to `nni.experiment`.
#def test_get_experiment(self):
# experiment = Experiments(HOME_PATH)
# self.assertTrue('xOpEwA5w' in experiment.get_all_experiments())
def test_update_experiment(self): def test_update_experiment(self):
experiment = Experiments(HOME_PATH) experiment = Experiments(HOME_PATH)
experiment.add_experiment('xOpEwA5w', 8081, 'N/A', 'local', 'test', endTime='N/A', status='INITIALIZED') experiment.add_experiment('xOpEwA5w', 8081, 'N/A', 'local', 'test', endTime='N/A', status='INITIALIZED')
self.assertTrue('xOpEwA5w' in experiment.get_all_experiments()) self.assertTrue('xOpEwA5w' in experiment.get_all_experiments())
experiment.remove_experiment('xOpEwA5w') experiment.remove_experiment('xOpEwA5w')
self.assertFalse('xOpEwA5w' in experiment.get_all_experiments()) self.assertFalse('xOpEwA5w' in experiment.get_all_experiments())
def test_get_config(self): def test_get_config(self):
config = Config('config', HOME_PATH) config = Config('xOpEwA5w', HOME_PATH)
self.assertEqual(config.get_config('experimentId'), 'xOpEwA5w') self.assertEqual(config.get_config()['experimentName'], 'test_config')
def test_set_config(self):
config = Config('config', HOME_PATH)
self.assertNotEqual(config.get_config('experimentId'), 'testId')
config.set_config('experimentId', 'testId')
self.assertEqual(config.get_config('experimentId'), 'testId')
config.set_config('experimentId', 'xOpEwA5w')
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -54,7 +54,7 @@ class CommonUtilsTestCase(TestCase): ...@@ -54,7 +54,7 @@ class CommonUtilsTestCase(TestCase):
@responses.activate @responses.activate
def test_get_experiment_port(self): def test_get_experiment_port(self):
args = generate_args() args = generate_args()
self.assertEqual('8080', get_experiment_port(args)) self.assertEqual(8080, get_experiment_port(args))
@responses.activate @responses.activate
def test_check_rest(self): def test_check_rest(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment