Commit 3379411e authored by SparkSnail's avatar SparkSnail Committed by Yan Ni
Browse files

Fix nnictl commands (#1216)

parent b8b49a34
...@@ -368,8 +368,14 @@ Debug mode will disable version check function in Trialkeeper. ...@@ -368,8 +368,14 @@ Debug mode will disable version check function in Trialkeeper.
* Usage * Usage
```bash ```bash
nnictl experiment list nnictl experiment list [OPTIONS]
``` ```
* Options
|Name, shorthand|Required|Default|Description|
|------|------|------ |------|
|--all| False| |list all of experiments|
* __nnictl experiment delete__ * __nnictl experiment delete__
...@@ -388,6 +394,7 @@ Debug mode will disable version check function in Trialkeeper. ...@@ -388,6 +394,7 @@ Debug mode will disable version check function in Trialkeeper.
|Name, shorthand|Required|Default|Description| |Name, shorthand|Required|Default|Description|
|------|------|------ |------| |------|------|------ |------|
|id| False| |ID of the experiment| |id| False| |ID of the experiment|
|--all| False| |delete all of experiments|
......
...@@ -119,7 +119,7 @@ def parse_args(): ...@@ -119,7 +119,7 @@ def parse_args():
parser_experiment_status.add_argument('id', nargs='?', help='the id of experiment') parser_experiment_status.add_argument('id', nargs='?', help='the id of experiment')
parser_experiment_status.set_defaults(func=experiment_status) parser_experiment_status.set_defaults(func=experiment_status)
parser_experiment_list = parser_experiment_subparsers.add_parser('list', help='list all of running experiment ids') parser_experiment_list = parser_experiment_subparsers.add_parser('list', help='list all of running experiment ids')
parser_experiment_list.add_argument('all', nargs='?', help='list all of experiments') parser_experiment_list.add_argument('--all', action='store_true', default=False, help='list all of experiments')
parser_experiment_list.set_defaults(func=experiment_list) parser_experiment_list.set_defaults(func=experiment_list)
parser_experiment_clean = parser_experiment_subparsers.add_parser('delete', help='clean up the experiment data') parser_experiment_clean = parser_experiment_subparsers.add_parser('delete', help='clean up the experiment data')
parser_experiment_clean.add_argument('id', nargs='?', help='the id of experiment') parser_experiment_clean.add_argument('id', nargs='?', help='the id of experiment')
......
...@@ -106,14 +106,14 @@ def check_experiment_id(args, update=True): ...@@ -106,14 +106,14 @@ def check_experiment_id(args, update=True):
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information) print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
exit(1) exit(1)
elif not running_experiment_list: elif not running_experiment_list:
print_error('There is no experiment running!') print_error('There is no experiment running.')
return None return None
else: else:
return running_experiment_list[0] return running_experiment_list[0]
if experiment_dict.get(args.id): if experiment_dict.get(args.id):
return args.id return args.id
else: else:
print_error('Id not correct!') print_error('Id not correct.')
return None return None
def parse_ids(args): def parse_ids(args):
...@@ -151,7 +151,7 @@ def parse_ids(args): ...@@ -151,7 +151,7 @@ def parse_ids(args):
exit(1) exit(1)
else: else:
result_list = running_experiment_list result_list = running_experiment_list
elif args.id == 'all': elif args.all:
result_list = running_experiment_list result_list = running_experiment_list
elif args.id.endswith('*'): elif args.id.endswith('*'):
for id in running_experiment_list: for id in running_experiment_list:
...@@ -168,15 +168,17 @@ def parse_ids(args): ...@@ -168,15 +168,17 @@ def parse_ids(args):
return None return None
if not result_list and args.id: if not result_list and args.id:
print_error('There are no experiments matched, please set correct experiment id...') print_error('There are no experiments matched, please set correct experiment id...')
elif not result_list: elif not result_list and not args.all:
print_error('There is no experiment running...') print_error('There is no experiment running...')
elif not result_list:
print_error('Cannot find experiments.')
return result_list return result_list
def get_config_filename(args): def get_config_filename(args):
'''get the file name of config file''' '''get the file name of config file'''
experiment_id = check_experiment_id(args) experiment_id = check_experiment_id(args)
if experiment_id is None: if experiment_id is None:
print_error('Please set correct experiment id!') print_error('Please set correct experiment id.')
exit(1) exit(1)
experiment_config = Experiments() experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiment_dict = experiment_config.get_all_experiments()
...@@ -186,7 +188,7 @@ def get_experiment_port(args): ...@@ -186,7 +188,7 @@ def get_experiment_port(args):
'''get the port of experiment''' '''get the port of experiment'''
experiment_id = check_experiment_id(args) experiment_id = check_experiment_id(args)
if experiment_id is None: if experiment_id is None:
print_error('Please set correct experiment id!') print_error('Please set correct experiment id.')
exit(1) exit(1)
experiment_config = Experiments() experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiment_dict = experiment_config.get_all_experiments()
...@@ -235,7 +237,7 @@ def stop_experiment(args): ...@@ -235,7 +237,7 @@ def stop_experiment(args):
except Exception as exception: except Exception as exception:
print_error(exception) print_error(exception)
nni_config.set_config('tensorboardPidList', []) nni_config.set_config('tensorboardPidList', [])
print_normal('Stop experiment success!') print_normal('Stop experiment success.')
experiment_config.update_experiment(experiment_id, 'status', 'STOPPED') experiment_config.update_experiment(experiment_id, 'status', 'STOPPED')
time_now = time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())) time_now = time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))
experiment_config.update_experiment(experiment_id, 'endTime', str(time_now)) experiment_config.update_experiment(experiment_id, 'endTime', str(time_now))
...@@ -360,10 +362,10 @@ def log_trial(args): ...@@ -360,10 +362,10 @@ def log_trial(args):
if trial_id_path_dict.get(args.trial_id): if trial_id_path_dict.get(args.trial_id):
print_normal('id:' + args.trial_id + ' path:' + trial_id_path_dict[args.trial_id]) print_normal('id:' + args.trial_id + ' path:' + trial_id_path_dict[args.trial_id])
else: else:
print_error('trial id is not valid!') print_error('trial id is not valid.')
exit(1) exit(1)
else: else:
print_error('please specific the trial id!') print_error('please specific the trial id.')
exit(1) exit(1)
else: else:
for key in trial_id_path_dict: for key in trial_id_path_dict:
...@@ -385,7 +387,7 @@ def local_clean(directory): ...@@ -385,7 +387,7 @@ def local_clean(directory):
try: try:
shutil.rmtree(directory) shutil.rmtree(directory)
except FileNotFoundError as err: except FileNotFoundError as err:
print_error('{0} does not exist!'.format(directory)) print_error('{0} does not exist.'.format(directory))
def remote_clean(machine_list, experiment_id=None): def remote_clean(machine_list, experiment_id=None):
'''clean up remote data''' '''clean up remote data'''
...@@ -435,19 +437,19 @@ def experiment_clean(args): ...@@ -435,19 +437,19 @@ def experiment_clean(args):
experiment_id_list = list(experiment_dict.keys()) experiment_id_list = list(experiment_dict.keys())
else: else:
if args.id is None: if args.id is None:
print_error('please set experiment id!') print_error('please set experiment id.')
exit(1) exit(1)
if args.id not in experiment_dict: if args.id not in experiment_dict:
print_error('can not find id {0}!'.format(args.id)) print_error('Cannot find experiment {0}.'.format(args.id))
exit(1) exit(1)
experiment_id_list.append(args.id) experiment_id_list.append(args.id)
while True: while True:
print('INFO: This action will delete experiment {0}, and its not recoverable.'.format(' '.join(experiment_id_list))) print('INFO: This action will delete experiment {0}, and it\'s not recoverable.'.format(' '.join(experiment_id_list)))
inputs = input('INFO: do you want to continue?[y/N]:') inputs = input('INFO: do you want to continue?[y/N]:')
if not inputs.lower() or inputs.lower() in ['n', 'no']: if not inputs.lower() or inputs.lower() in ['n', 'no']:
exit(0) exit(0)
elif inputs.lower() not in ['y', 'n', 'yes', 'no']: elif inputs.lower() not in ['y', 'n', 'yes', 'no']:
print_warning('please input Y or N!') print_warning('please input Y or N.')
else: else:
break break
for experiment_id in experiment_id_list: for experiment_id in experiment_id_list:
...@@ -464,7 +466,7 @@ def experiment_clean(args): ...@@ -464,7 +466,7 @@ def experiment_clean(args):
hdfs_clean(host, user_name, output_dir, experiment_id) hdfs_clean(host, user_name, output_dir, experiment_id)
elif platform != 'local': elif platform != 'local':
#TODO: support all platforms #TODO: support all platforms
print_warning('platform {0} clean up not supported yet!'.format(platform)) print_warning('platform {0} clean up not supported yet.'.format(platform))
exit(0) exit(0)
#clean local data #clean local data
home = str(Path.home()) home = str(Path.home())
...@@ -475,7 +477,7 @@ def experiment_clean(args): ...@@ -475,7 +477,7 @@ def experiment_clean(args):
experiment_config = Experiments() experiment_config = Experiments()
print_normal('removing metadata of experiment {0}'.format(experiment_id)) print_normal('removing metadata of experiment {0}'.format(experiment_id))
experiment_config.remove_experiment(experiment_id) experiment_config.remove_experiment(experiment_id)
print_normal('Finish!') print_normal('Done.')
def get_platform_dir(config_content): def get_platform_dir(config_content):
'''get the dir list to be deleted''' '''get the dir list to be deleted'''
...@@ -486,13 +488,13 @@ def get_platform_dir(config_content): ...@@ -486,13 +488,13 @@ def get_platform_dir(config_content):
for machine in machine_list: for machine in machine_list:
host = machine.get('ip') host = machine.get('ip')
port = machine.get('port') port = machine.get('port')
dir_list.append(host + ':' + str(port) + '/tmp/nni/experiments') dir_list.append(host + ':' + str(port) + '/tmp/nni')
elif platform == 'pai': elif platform == 'pai':
pai_config = config_content.get('paiConfig') pai_config = config_content.get('paiConfig')
host = config_content.get('paiConfig').get('host') host = config_content.get('paiConfig').get('host')
user_name = config_content.get('paiConfig').get('userName') user_name = config_content.get('paiConfig').get('userName')
output_dir = config_content.get('trial').get('outputDir') output_dir = config_content.get('trial').get('outputDir')
dir_list.append('hdfs://{0}:9000/{1}/nni/experiments'.format(host, user_name)) dir_list.append('server: {0}, path: {1}/nni'.format(host, user_name))
if output_dir: if output_dir:
dir_list.append(output_dir) dir_list.append(output_dir)
return dir_list return dir_list
...@@ -501,12 +503,15 @@ def platform_clean(args): ...@@ -501,12 +503,15 @@ def platform_clean(args):
'''clean up the experiment data''' '''clean up the experiment data'''
config_path = os.path.abspath(args.config) config_path = os.path.abspath(args.config)
if not os.path.exists(config_path): if not os.path.exists(config_path):
print_error('Please set correct config path!') print_error('Please set correct config path.')
exit(1) exit(1)
config_content = get_yml_content(config_path) config_content = get_yml_content(config_path)
platform = config_content.get('trainingServicePlatform') platform = config_content.get('trainingServicePlatform')
if platform == 'local':
print_normal('it doesn’t need to clean local platform.')
exit(0)
if platform not in ['remote', 'pai']: if platform not in ['remote', 'pai']:
print_normal('platform {0} not supported!'.format(platform)) print_normal('platform {0} not supported.'.format(platform))
exit(0) exit(0)
experiment_config = Experiments() experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiment_dict = experiment_config.get_all_experiments()
...@@ -514,7 +519,7 @@ def platform_clean(args): ...@@ -514,7 +519,7 @@ def platform_clean(args):
id_list = list(experiment_dict.keys()) id_list = list(experiment_dict.keys())
dir_list = get_platform_dir(config_content) dir_list = get_platform_dir(config_content)
if not dir_list: if not dir_list:
print_normal('No folder of NNI caches is found!') print_normal('No folder of NNI caches is found.')
exit(1) exit(1)
while True: while True:
print_normal('This command will remove below folders of NNI caches. If other users are using experiments on below hosts, it will be broken.') print_normal('This command will remove below folders of NNI caches. If other users are using experiments on below hosts, it will be broken.')
...@@ -524,7 +529,7 @@ def platform_clean(args): ...@@ -524,7 +529,7 @@ def platform_clean(args):
if not inputs.lower() or inputs.lower() in ['n', 'no']: if not inputs.lower() or inputs.lower() in ['n', 'no']:
exit(0) exit(0)
elif inputs.lower() not in ['y', 'n', 'yes', 'no']: elif inputs.lower() not in ['y', 'n', 'yes', 'no']:
print_warning('please input Y or N!') print_warning('please input Y or N.')
else: else:
break break
if platform == 'remote': if platform == 'remote':
...@@ -537,18 +542,18 @@ def platform_clean(args): ...@@ -537,18 +542,18 @@ def platform_clean(args):
user_name = config_content.get('paiConfig').get('userName') user_name = config_content.get('paiConfig').get('userName')
output_dir = config_content.get('trial').get('outputDir') output_dir = config_content.get('trial').get('outputDir')
hdfs_clean(host, user_name, output_dir, None) hdfs_clean(host, user_name, output_dir, None)
print_normal('Done!') print_normal('Done.')
def experiment_list(args): def experiment_list(args):
'''get the information of all experiments''' '''get the information of all experiments'''
experiment_config = Experiments() experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiment_dict = experiment_config.get_all_experiments()
if not experiment_dict: if not experiment_dict:
print('There is no experiment running...') print_normal('Cannot find experiments.')
exit(1) exit(1)
update_experiment() update_experiment()
experiment_id_list = [] experiment_id_list = []
if args.all and args.all == 'all': if args.all:
for key in experiment_dict.keys(): for key in experiment_dict.keys():
experiment_id_list.append(key) experiment_id_list.append(key)
else: else:
...@@ -556,7 +561,7 @@ def experiment_list(args): ...@@ -556,7 +561,7 @@ def experiment_list(args):
if experiment_dict[key]['status'] != 'STOPPED': if experiment_dict[key]['status'] != 'STOPPED':
experiment_id_list.append(key) experiment_id_list.append(key)
if not experiment_id_list: if not experiment_id_list:
print_warning('There is no experiment running...\nYou can use \'nnictl experiment list all\' to list all stopped experiments!') print_warning('There is no experiment running...\nYou can use \'nnictl experiment list --all\' to list all stopped experiments.')
experiment_information = "" experiment_information = ""
for key in experiment_id_list: for key in experiment_id_list:
experiment_information += (EXPERIMENT_DETAIL_FORMAT % (key, experiment_dict[key]['status'], experiment_dict[key]['port'],\ experiment_information += (EXPERIMENT_DETAIL_FORMAT % (key, experiment_dict[key]['status'], experiment_dict[key]['port'],\
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment