Commit 3379411e authored by SparkSnail, committed by Yan Ni
Browse files

Fix nnictl commands (#1216)

parent b8b49a34
......@@ -368,8 +368,14 @@ Debug mode will disable version check function in Trialkeeper.
* Usage
```bash
nnictl experiment list
nnictl experiment list [OPTIONS]
```
* Options
|Name, shorthand|Required|Default|Description|
|------|------|------ |------|
|--all| False| |list all experiments|
* __nnictl experiment delete__
......@@ -388,6 +394,7 @@ Debug mode will disable version check function in Trialkeeper.
|Name, shorthand|Required|Default|Description|
|------|------|------ |------|
|id| False| |ID of the experiment|
|--all| False| |delete all experiments|
......
......@@ -119,7 +119,7 @@ def parse_args():
parser_experiment_status.add_argument('id', nargs='?', help='the id of experiment')
parser_experiment_status.set_defaults(func=experiment_status)
parser_experiment_list = parser_experiment_subparsers.add_parser('list', help='list all of running experiment ids')
parser_experiment_list.add_argument('all', nargs='?', help='list all of experiments')
parser_experiment_list.add_argument('--all', action='store_true', default=False, help='list all of experiments')
parser_experiment_list.set_defaults(func=experiment_list)
parser_experiment_clean = parser_experiment_subparsers.add_parser('delete', help='clean up the experiment data')
parser_experiment_clean.add_argument('id', nargs='?', help='the id of experiment')
......
......@@ -106,14 +106,14 @@ def check_experiment_id(args, update=True):
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
exit(1)
elif not running_experiment_list:
print_error('There is no experiment running!')
print_error('There is no experiment running.')
return None
else:
return running_experiment_list[0]
if experiment_dict.get(args.id):
return args.id
else:
print_error('Id not correct!')
print_error('Id not correct.')
return None
def parse_ids(args):
......@@ -151,7 +151,7 @@ def parse_ids(args):
exit(1)
else:
result_list = running_experiment_list
elif args.id == 'all':
elif args.all:
result_list = running_experiment_list
elif args.id.endswith('*'):
for id in running_experiment_list:
......@@ -168,15 +168,17 @@ def parse_ids(args):
return None
if not result_list and args.id:
print_error('There are no experiments matched, please set correct experiment id...')
elif not result_list:
elif not result_list and not args.all:
print_error('There is no experiment running...')
elif not result_list:
print_error('Cannot find experiments.')
return result_list
def get_config_filename(args):
'''get the file name of config file'''
experiment_id = check_experiment_id(args)
if experiment_id is None:
print_error('Please set correct experiment id!')
print_error('Please set correct experiment id.')
exit(1)
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
......@@ -186,7 +188,7 @@ def get_experiment_port(args):
'''get the port of experiment'''
experiment_id = check_experiment_id(args)
if experiment_id is None:
print_error('Please set correct experiment id!')
print_error('Please set correct experiment id.')
exit(1)
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
......@@ -235,7 +237,7 @@ def stop_experiment(args):
except Exception as exception:
print_error(exception)
nni_config.set_config('tensorboardPidList', [])
print_normal('Stop experiment success!')
print_normal('Stop experiment success.')
experiment_config.update_experiment(experiment_id, 'status', 'STOPPED')
time_now = time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))
experiment_config.update_experiment(experiment_id, 'endTime', str(time_now))
......@@ -360,10 +362,10 @@ def log_trial(args):
if trial_id_path_dict.get(args.trial_id):
print_normal('id:' + args.trial_id + ' path:' + trial_id_path_dict[args.trial_id])
else:
print_error('trial id is not valid!')
print_error('trial id is not valid.')
exit(1)
else:
print_error('please specific the trial id!')
print_error('please specific the trial id.')
exit(1)
else:
for key in trial_id_path_dict:
......@@ -385,7 +387,7 @@ def local_clean(directory):
try:
shutil.rmtree(directory)
except FileNotFoundError as err:
print_error('{0} does not exist!'.format(directory))
print_error('{0} does not exist.'.format(directory))
def remote_clean(machine_list, experiment_id=None):
'''clean up remote data'''
......@@ -435,19 +437,19 @@ def experiment_clean(args):
experiment_id_list = list(experiment_dict.keys())
else:
if args.id is None:
print_error('please set experiment id!')
print_error('please set experiment id.')
exit(1)
if args.id not in experiment_dict:
print_error('can not find id {0}!'.format(args.id))
print_error('Cannot find experiment {0}.'.format(args.id))
exit(1)
experiment_id_list.append(args.id)
while True:
print('INFO: This action will delete experiment {0}, and its not recoverable.'.format(' '.join(experiment_id_list)))
print('INFO: This action will delete experiment {0}, and it\'s not recoverable.'.format(' '.join(experiment_id_list)))
inputs = input('INFO: do you want to continue?[y/N]:')
if not inputs.lower() or inputs.lower() in ['n', 'no']:
exit(0)
elif inputs.lower() not in ['y', 'n', 'yes', 'no']:
print_warning('please input Y or N!')
print_warning('please input Y or N.')
else:
break
for experiment_id in experiment_id_list:
......@@ -464,7 +466,7 @@ def experiment_clean(args):
hdfs_clean(host, user_name, output_dir, experiment_id)
elif platform != 'local':
#TODO: support all platforms
print_warning('platform {0} clean up not supported yet!'.format(platform))
print_warning('platform {0} clean up not supported yet.'.format(platform))
exit(0)
#clean local data
home = str(Path.home())
......@@ -475,7 +477,7 @@ def experiment_clean(args):
experiment_config = Experiments()
print_normal('removing metadata of experiment {0}'.format(experiment_id))
experiment_config.remove_experiment(experiment_id)
print_normal('Finish!')
print_normal('Done.')
def get_platform_dir(config_content):
'''get the dir list to be deleted'''
......@@ -486,13 +488,13 @@ def get_platform_dir(config_content):
for machine in machine_list:
host = machine.get('ip')
port = machine.get('port')
dir_list.append(host + ':' + str(port) + '/tmp/nni/experiments')
dir_list.append(host + ':' + str(port) + '/tmp/nni')
elif platform == 'pai':
pai_config = config_content.get('paiConfig')
host = config_content.get('paiConfig').get('host')
user_name = config_content.get('paiConfig').get('userName')
output_dir = config_content.get('trial').get('outputDir')
dir_list.append('hdfs://{0}:9000/{1}/nni/experiments'.format(host, user_name))
dir_list.append('server: {0}, path: {1}/nni'.format(host, user_name))
if output_dir:
dir_list.append(output_dir)
return dir_list
......@@ -501,12 +503,15 @@ def platform_clean(args):
'''clean up the experiment data'''
config_path = os.path.abspath(args.config)
if not os.path.exists(config_path):
print_error('Please set correct config path!')
print_error('Please set correct config path.')
exit(1)
config_content = get_yml_content(config_path)
platform = config_content.get('trainingServicePlatform')
if platform == 'local':
print_normal('it doesn’t need to clean local platform.')
exit(0)
if platform not in ['remote', 'pai']:
print_normal('platform {0} not supported!'.format(platform))
print_normal('platform {0} not supported.'.format(platform))
exit(0)
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
......@@ -514,7 +519,7 @@ def platform_clean(args):
id_list = list(experiment_dict.keys())
dir_list = get_platform_dir(config_content)
if not dir_list:
print_normal('No folder of NNI caches is found!')
print_normal('No folder of NNI caches is found.')
exit(1)
while True:
print_normal('This command will remove below folders of NNI caches. If other users are using experiments on below hosts, it will be broken.')
......@@ -524,7 +529,7 @@ def platform_clean(args):
if not inputs.lower() or inputs.lower() in ['n', 'no']:
exit(0)
elif inputs.lower() not in ['y', 'n', 'yes', 'no']:
print_warning('please input Y or N!')
print_warning('please input Y or N.')
else:
break
if platform == 'remote':
......@@ -537,18 +542,18 @@ def platform_clean(args):
user_name = config_content.get('paiConfig').get('userName')
output_dir = config_content.get('trial').get('outputDir')
hdfs_clean(host, user_name, output_dir, None)
print_normal('Done!')
print_normal('Done.')
def experiment_list(args):
'''get the information of all experiments'''
experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments()
if not experiment_dict:
print('There is no experiment running...')
print_normal('Cannot find experiments.')
exit(1)
update_experiment()
experiment_id_list = []
if args.all and args.all == 'all':
if args.all:
for key in experiment_dict.keys():
experiment_id_list.append(key)
else:
......@@ -556,7 +561,7 @@ def experiment_list(args):
if experiment_dict[key]['status'] != 'STOPPED':
experiment_id_list.append(key)
if not experiment_id_list:
print_warning('There is no experiment running...\nYou can use \'nnictl experiment list all\' to list all stopped experiments!')
print_warning('There is no experiment running...\nYou can use \'nnictl experiment list --all\' to list all stopped experiments.')
experiment_information = ""
for key in experiment_id_list:
experiment_information += (EXPERIMENT_DETAIL_FORMAT % (key, experiment_dict[key]['status'], experiment_dict[key]['port'],\
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment