Unverified Commit e9040c9b authored by chicm-ms, committed by GitHub

Merge pull request #23 from microsoft/master

pull code
parents 256f27af ed63175c
@@ -76,11 +76,12 @@ def _generate_file_search_space(path, module):
     return search_space
 
-def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
+def expand_annotations(src_dir, dst_dir, exp_id='', trial_id='', nas_mode=None):
     """Expand annotations in user code.
     Return dst_dir if annotation detected; return src_dir if not.
     src_dir: directory path of user code (str)
     dst_dir: directory to place generated files (str)
+    nas_mode: the mode of NAS given that NAS interface is used
     """
     if src_dir[-1] == slash:
         src_dir = src_dir[:-1]
@@ -108,7 +109,7 @@ def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
             dst_path = os.path.join(dst_subdir, file_name)
             if file_name.endswith('.py'):
                 if trial_id == '':
-                    annotated |= _expand_file_annotations(src_path, dst_path)
+                    annotated |= _expand_file_annotations(src_path, dst_path, nas_mode)
                 else:
                     module = package + file_name[:-3]
                     annotated |= _generate_specific_file(src_path, dst_path, exp_id, trial_id, module)
@@ -120,10 +121,10 @@ def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
     return dst_dir if annotated else src_dir
 
-def _expand_file_annotations(src_path, dst_path):
+def _expand_file_annotations(src_path, dst_path, nas_mode):
     with open(src_path) as src, open(dst_path, 'w') as dst:
         try:
-            annotated_code = code_generator.parse(src.read())
+            annotated_code = code_generator.parse(src.read(), nas_mode)
             if annotated_code is None:
                 shutil.copyfile(src_path, dst_path)
                 return False
...
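For context, the updated entry point can be exercised directly. A minimal usage sketch, assuming the nni_annotation package is installed (both directory paths here are hypothetical):

    # minimal sketch of the new signature; both paths are hypothetical
    from nni_annotation import expand_annotations

    code_dir = expand_annotations('/home/user/mnist', '/tmp/nni/annotated', nas_mode='classic_mode')
    # per the docstring: returns '/tmp/nni/annotated' if an annotation was
    # detected and expanded, otherwise returns '/home/user/mnist' unchanged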
@@ -21,14 +21,14 @@
 import ast
 
 import astor
-from nni_cmd.common_utils import print_warning
 
 # pylint: disable=unidiomatic-typecheck
 
-def parse_annotation_mutable_layers(code, lineno):
+def parse_annotation_mutable_layers(code, lineno, nas_mode):
     """Parse the string of mutable layers in annotation.
     Return a list of AST Expr nodes
     code: annotation string (excluding '@')
+    nas_mode: the mode of NAS
     """
     module = ast.parse(code)
     assert type(module) is ast.Module, 'internal error #1'
@@ -110,6 +110,9 @@ def parse_annotation_mutable_layers(code, lineno):
             else:
                 target_call_args.append(ast.Dict(keys=[], values=[]))
             target_call_args.append(ast.Num(n=0))
+            target_call_args.append(ast.Str(s=nas_mode))
+            if nas_mode in ['enas_mode', 'oneshot_mode']:
+                target_call_args.append(ast.Name(id='tensorflow'))
             target_call = ast.Call(func=target_call_attr, args=target_call_args, keywords=[])
             node = ast.Assign(targets=[layer_output], value=target_call)
             nodes.append(node)
@@ -277,10 +280,11 @@ class FuncReplacer(ast.NodeTransformer):
 class Transformer(ast.NodeTransformer):
     """Transform original code to annotated code"""
 
-    def __init__(self):
+    def __init__(self, nas_mode=None):
         self.stack = []
         self.last_line = 0
         self.annotated = False
+        self.nas_mode = nas_mode
 
     def visit(self, node):
         if isinstance(node, (ast.expr, ast.stmt)):
@@ -316,8 +320,11 @@ class Transformer(ast.NodeTransformer):
             return node  # not an annotation, ignore it
 
         if string.startswith('@nni.get_next_parameter'):
-            deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. Please remove this line in the trial code."
-            print_warning(deprecated_message)
+            call_node = parse_annotation(string[1:]).value
+            if call_node.args:
+                # it is used in enas mode as it needs to retrieve the next subgraph for training
+                call_attr = ast.Attribute(value=ast.Name(id='nni', ctx=ast.Load()), attr='reload_tensorflow_variables', ctx=ast.Load())
+                return ast.Expr(value=ast.Call(func=call_attr, args=call_node.args, keywords=[]))
 
         if string.startswith('@nni.report_intermediate_result') \
                 or string.startswith('@nni.report_final_result') \
@@ -325,7 +332,8 @@
             return parse_annotation(string[1:])  # expand annotation string to code
 
         if string.startswith('@nni.mutable_layers'):
-            return parse_annotation_mutable_layers(string[1:], node.lineno)
+            nodes = parse_annotation_mutable_layers(string[1:], node.lineno, self.nas_mode)
+            return nodes
 
         if string.startswith('@nni.variable') \
                 or string.startswith('@nni.function_choice'):
@@ -343,17 +351,18 @@
         return node
 
-def parse(code):
+def parse(code, nas_mode=None):
     """Annotate user code.
     Return annotated code (str) if annotation detected; return None if not.
-    code: original user code (str)
+    code: original user code (str),
+    nas_mode: the mode of NAS given that NAS interface is used
     """
     try:
         ast_tree = ast.parse(code)
     except Exception:
         raise RuntimeError('Bad Python code')
 
-    transformer = Transformer()
+    transformer = Transformer(nas_mode)
     try:
         transformer.visit(ast_tree)
     except AssertionError as exc:
@@ -369,5 +378,9 @@ def parse(code):
         if type(nodes[i]) is ast.ImportFrom and nodes[i].module == '__future__':
             last_future_import = i
     nodes.insert(last_future_import + 1, import_nni)
+    # enas and oneshot modes for tensorflow need tensorflow module, so we import it here
+    if nas_mode in ['enas_mode', 'oneshot_mode']:
+        import_tf = ast.Import(names=[ast.alias(name='tensorflow', asname=None)])
+        nodes.insert(last_future_import + 1, import_tf)
 
     return astor.to_source(ast_tree)
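The practical effect of threading nas_mode into parse() is easiest to see on a tiny trial snippet. A hedged sketch, assuming code_generator is importable from the nni_annotation package (the annotated source string is hypothetical):

    # hedged sketch: in enas_mode, parse() should inject 'import tensorflow'
    # alongside 'import nni', and an @nni.get_next_parameter(...) annotation
    # with arguments should expand into nni.reload_tensorflow_variables(...)
    from nni_annotation import code_generator

    src = "'''@nni.get_next_parameter(tf, sess)'''\npass\n"
    print(code_generator.parse(src, nas_mode='enas_mode'))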
@@ -104,6 +104,21 @@ tuner_schema_dict = {
         },
         Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
     },
+    'GPTuner': {
+        'builtinTunerName': 'GPTuner',
+        'classArgs': {
+            Optional('optimize_mode'): setChoice('optimize_mode', 'maximize', 'minimize'),
+            Optional('utility'): setChoice('utility', 'ei', 'ucb', 'poi'),
+            Optional('kappa'): setType('kappa', float),
+            Optional('xi'): setType('xi', float),
+            Optional('nu'): setType('nu', float),
+            Optional('alpha'): setType('alpha', float),
+            Optional('cold_start_num'): setType('cold_start_num', int),
+            Optional('selection_num_warm_up'): setType('selection_num_warm_up', int),
+            Optional('selection_num_starting_points'): setType('selection_num_starting_points', int),
+        },
+        Optional('gpuNum'): setNumberRange('gpuNum', int, 0, 99999),
+    },
     'customized': {
         'codeDir': setPathCheck('codeDir'),
         'classFileName': setType('classFileName', str),
@@ -181,7 +196,8 @@ common_trial_schema = {
     'trial':{
         'command': setType('command', str),
         'codeDir': setPathCheck('codeDir'),
-        'gpuNum': setNumberRange('gpuNum', int, 0, 99999)
+        'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
+        Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode')
     }
 }
@@ -199,6 +215,7 @@ pai_trial_schema = {
         Optional('outputDir'): And(Regex(r'hdfs://(([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(/.*)?'),\
                                    error='ERROR: outputDir format error, outputDir format is hdfs://xxx.xxx.xxx.xxx:xxx'),
         Optional('virtualCluster'): setType('virtualCluster', str),
+        Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode')
     }
 }
@@ -213,6 +230,7 @@ pai_config_schema = {
 kubeflow_trial_schema = {
     'trial':{
         'codeDir': setPathCheck('codeDir'),
+        Optional('nasMode'): setChoice('nasMode', 'classic_mode', 'enas_mode', 'oneshot_mode'),
         Optional('ps'): {
             'replicas': setType('replicas', int),
             'command': setType('command', str),
...
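In the experiment configuration this surfaces as one new optional trial field. A sketch of the equivalent structure the schema above validates (all values hypothetical):

    # hypothetical trial section; in the YAML config this corresponds to trial.nasMode
    trial_config = {
        'command': 'python3 mnist.py',
        'codeDir': '/home/user/mnist',
        'gpuNum': 0,
        'nasMode': 'enas_mode',  # one of: classic_mode, enas_mode, oneshot_mode
    }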
@@ -377,7 +377,8 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
         if not os.path.isdir(path):
             os.makedirs(path)
         path = tempfile.mkdtemp(dir=path)
-        code_dir = expand_annotations(experiment_config['trial']['codeDir'], path)
+        nas_mode = experiment_config['trial'].get('nasMode', 'classic_mode')
+        code_dir = expand_annotations(experiment_config['trial']['codeDir'], path, nas_mode=nas_mode)
         experiment_config['trial']['codeDir'] = code_dir
         search_space = generate_search_space(code_dir)
         experiment_config['searchSpace'] = json.dumps(search_space)
...
@@ -119,8 +119,21 @@ def parse_args():
     parser_experiment_status.add_argument('id', nargs='?', help='the id of experiment')
     parser_experiment_status.set_defaults(func=experiment_status)
     parser_experiment_list = parser_experiment_subparsers.add_parser('list', help='list all of running experiment ids')
-    parser_experiment_list.add_argument('all', nargs='?', help='list all of experiments')
+    parser_experiment_list.add_argument('--all', action='store_true', default=False, help='list all of experiments')
     parser_experiment_list.set_defaults(func=experiment_list)
+    parser_experiment_clean = parser_experiment_subparsers.add_parser('delete', help='clean up the experiment data')
+    parser_experiment_clean.add_argument('id', nargs='?', help='the id of experiment')
+    parser_experiment_clean.add_argument('--all', action='store_true', default=False, help='delete all of experiments')
+    parser_experiment_clean.set_defaults(func=experiment_clean)
+    #parse platform command
+    parser_platform = subparsers.add_parser('platform', help='get platform information')
+    #add subparsers for parser_platform
+    parser_platform_subparsers = parser_platform.add_subparsers()
+    parser_platform_clean = parser_platform_subparsers.add_parser('clean', help='clean up the platform data')
+    parser_platform_clean.add_argument('--config', '-c', required=True, dest='config', help='the path of yaml config file')
+    parser_platform_clean.set_defaults(func=platform_clean)
     #import tuning data
     parser_import_data = parser_experiment_subparsers.add_parser('import', help='import additional data')
     parser_import_data.add_argument('id', nargs='?', help='the id of experiment')
...
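The change from a positional `all` to a `--all` flag is behavioral, not cosmetic: with action='store_true', args.all becomes a boolean, which is what the reworked checks in parse_ids and experiment_list below rely on. A standalone sketch of the difference (the parser here is hypothetical):

    # standalone sketch of the new flag semantics; the parser is hypothetical
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('id', nargs='?', help='the id of experiment')
    parser.add_argument('--all', action='store_true', default=False, help='list all of experiments')

    args = parser.parse_args(['--all'])
    assert args.all is True and args.id is None  # previously args.all was the string 'all' or None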
@@ -24,6 +24,10 @@ import psutil
 import json
 import datetime
 import time
+import re
+from pathlib import Path
+from pyhdfs import HdfsClient, HdfsFileNotFoundException
+import shutil
 from subprocess import call, check_output
 from nni_annotation import expand_annotations
 from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
@@ -31,8 +35,9 @@ from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_
 from .config_utils import Config, Experiments
 from .constants import NNICTL_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, \
                        EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT
-from .common_utils import print_normal, print_error, print_warning, detect_process
+from .common_utils import print_normal, print_error, print_warning, detect_process, get_yml_content
 from .command_utils import check_output_command, kill_command
+from .ssh_utils import create_ssh_sftp_client, remove_remote_directory
 
 def get_experiment_time(port):
     '''get the startTime and endTime of an experiment'''
@@ -73,10 +78,11 @@ def update_experiment():
         if status:
             experiment_config.update_experiment(key, 'status', status)
 
-def check_experiment_id(args):
+def check_experiment_id(args, update=True):
     '''check if the id is valid
     '''
-    update_experiment()
+    if update:
+        update_experiment()
     experiment_config = Experiments()
     experiment_dict = experiment_config.get_all_experiments()
     if not experiment_dict:
@@ -100,14 +106,14 @@ def check_experiment_id(args):
             print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
             exit(1)
         elif not running_experiment_list:
-            print_error('There is no experiment running!')
+            print_error('There is no experiment running.')
             return None
         else:
             return running_experiment_list[0]
     if experiment_dict.get(args.id):
         return args.id
     else:
-        print_error('Id not correct!')
+        print_error('Id not correct.')
         return None
 
 def parse_ids(args):
@@ -145,7 +151,7 @@ def parse_ids(args):
             exit(1)
         else:
             result_list = running_experiment_list
-    elif args.id == 'all':
+    elif args.all:
         result_list = running_experiment_list
     elif args.id.endswith('*'):
         for id in running_experiment_list:
@@ -170,7 +176,7 @@ def get_config_filename(args):
     '''get the file name of config file'''
     experiment_id = check_experiment_id(args)
     if experiment_id is None:
-        print_error('Please set the experiment id!')
+        print_error('Please set the correct experiment id.')
         exit(1)
     experiment_config = Experiments()
     experiment_dict = experiment_config.get_all_experiments()
@@ -180,7 +186,7 @@ def get_experiment_port(args):
     '''get the port of experiment'''
     experiment_id = check_experiment_id(args)
     if experiment_id is None:
-        print_error('Please set the experiment id!')
+        print_error('Please set the correct experiment id.')
         exit(1)
     experiment_config = Experiments()
     experiment_dict = experiment_config.get_all_experiments()
@@ -229,7 +235,7 @@ def stop_experiment(args):
                 except Exception as exception:
                     print_error(exception)
             nni_config.set_config('tensorboardPidList', [])
-        print_normal('Stop experiment success!')
+        print_normal('Stop experiment success.')
         experiment_config.update_experiment(experiment_id, 'status', 'STOPPED')
         time_now = time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time()))
         experiment_config.update_experiment(experiment_id, 'endTime', str(time_now))
@@ -354,10 +360,10 @@ def log_trial(args):
             if trial_id_path_dict.get(args.trial_id):
                 print_normal('id:' + args.trial_id + ' path:' + trial_id_path_dict[args.trial_id])
             else:
-                print_error('trial id is not valid!')
+                print_error('trial id is not valid.')
                 exit(1)
         else:
-            print_error('please specific the trial id!')
+            print_error('please specify the trial id.')
             exit(1)
     else:
         for key in trial_id_path_dict:
@@ -373,16 +379,179 @@ def webui_url(args):
     nni_config = Config(get_config_filename(args))
     print_normal('{0} {1}'.format('Web UI url:', ' '.join(nni_config.get_config('webuiUrl'))))
 
+def local_clean(directory):
+    '''clean up local data'''
+    print_normal('removing folder {0}'.format(directory))
+    try:
+        shutil.rmtree(directory)
+    except FileNotFoundError:
+        print_error('{0} does not exist.'.format(directory))
+
+def remote_clean(machine_list, experiment_id=None):
+    '''clean up remote data'''
+    for machine in machine_list:
+        passwd = machine.get('passwd')
+        userName = machine.get('username')
+        host = machine.get('ip')
+        port = machine.get('port')
+        if experiment_id:
+            remote_dir = '/' + '/'.join(['tmp', 'nni', 'experiments', experiment_id])
+        else:
+            remote_dir = '/' + '/'.join(['tmp', 'nni', 'experiments'])
+        sftp = create_ssh_sftp_client(host, port, userName, passwd)
+        print_normal('removing folder {0}'.format(host + ':' + str(port) + remote_dir))
+        remove_remote_directory(sftp, remote_dir)
+
+def hdfs_clean(host, user_name, output_dir, experiment_id=None):
+    '''clean up hdfs data'''
+    hdfs_client = HdfsClient(hosts='{0}:80'.format(host), user_name=user_name, webhdfs_path='/webhdfs/api/v1', timeout=5)
+    if experiment_id:
+        full_path = '/' + '/'.join([user_name, 'nni', 'experiments', experiment_id])
+    else:
+        full_path = '/' + '/'.join([user_name, 'nni', 'experiments'])
+    print_normal('removing folder {0} in hdfs'.format(full_path))
+    hdfs_client.delete(full_path, recursive=True)
+    if output_dir:
+        pattern = re.compile('hdfs://(?P<host>([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(?P<baseDir>/.*)?')
+        match_result = pattern.match(output_dir)
+        if match_result:
+            output_host = match_result.group('host')
+            output_dir = match_result.group('baseDir')
+            #check if the host is valid
+            if output_host != host:
+                print_warning('The host in {0} is not consistent with {1}'.format(output_dir, host))
+            else:
+                if experiment_id:
+                    output_dir = output_dir + '/' + experiment_id
+                print_normal('removing folder {0} in hdfs'.format(output_dir))
+                hdfs_client.delete(output_dir, recursive=True)
+
+def experiment_clean(args):
+    '''clean up the experiment data'''
+    experiment_id_list = []
+    experiment_config = Experiments()
+    experiment_dict = experiment_config.get_all_experiments()
+    if args.all:
+        experiment_id_list = list(experiment_dict.keys())
+    else:
+        if args.id is None:
+            print_error('please set experiment id.')
+            exit(1)
+        if args.id not in experiment_dict:
+            print_error('Cannot find experiment {0}.'.format(args.id))
+            exit(1)
+        experiment_id_list.append(args.id)
+    while True:
+        print('INFO: This action will delete experiment {0}, and it\'s not recoverable.'.format(' '.join(experiment_id_list)))
+        inputs = input('INFO: do you want to continue?[y/N]:')
+        if not inputs.lower() or inputs.lower() in ['n', 'no']:
+            exit(0)
+        elif inputs.lower() not in ['y', 'n', 'yes', 'no']:
+            print_warning('please input Y or N.')
+        else:
+            break
+    for experiment_id in experiment_id_list:
+        nni_config = Config(experiment_dict[experiment_id]['fileName'])
+        platform = nni_config.get_config('experimentConfig').get('trainingServicePlatform')
+        experiment_id = nni_config.get_config('experimentId')
+        if platform == 'remote':
+            machine_list = nni_config.get_config('experimentConfig').get('machineList')
+            remote_clean(machine_list, experiment_id)
+        elif platform == 'pai':
+            host = nni_config.get_config('experimentConfig').get('paiConfig').get('host')
+            user_name = nni_config.get_config('experimentConfig').get('paiConfig').get('userName')
+            output_dir = nni_config.get_config('experimentConfig').get('trial').get('outputDir')
+            hdfs_clean(host, user_name, output_dir, experiment_id)
+        elif platform != 'local':
+            #TODO: support all platforms
+            print_warning('platform {0} clean up not supported yet.'.format(platform))
+            exit(0)
+        #clean local data
+        home = str(Path.home())
+        local_dir = nni_config.get_config('experimentConfig').get('logDir')
+        if not local_dir:
+            local_dir = os.path.join(home, 'nni', 'experiments', experiment_id)
+        local_clean(local_dir)
+        experiment_config = Experiments()
+        print_normal('removing metadata of experiment {0}'.format(experiment_id))
+        experiment_config.remove_experiment(experiment_id)
+    print_normal('Done.')
+
+def get_platform_dir(config_content):
+    '''get the dir list to be deleted'''
+    platform = config_content.get('trainingServicePlatform')
+    dir_list = []
+    if platform == 'remote':
+        machine_list = config_content.get('machineList')
+        for machine in machine_list:
+            host = machine.get('ip')
+            port = machine.get('port')
+            dir_list.append(host + ':' + str(port) + '/tmp/nni')
+    elif platform == 'pai':
+        pai_config = config_content.get('paiConfig')
+        host = pai_config.get('host')
+        user_name = pai_config.get('userName')
+        output_dir = config_content.get('trial').get('outputDir')
+        dir_list.append('server: {0}, path: {1}/nni'.format(host, user_name))
+        if output_dir:
+            dir_list.append(output_dir)
+    return dir_list
+
+def platform_clean(args):
+    '''clean up the platform data'''
+    config_path = os.path.abspath(args.config)
+    if not os.path.exists(config_path):
+        print_error('Please set correct config path.')
+        exit(1)
+    config_content = get_yml_content(config_path)
+    platform = config_content.get('trainingServicePlatform')
+    if platform == 'local':
+        print_normal('There is no need to clean up the local platform.')
+        exit(0)
+    if platform not in ['remote', 'pai']:
+        print_normal('platform {0} not supported.'.format(platform))
+        exit(0)
+    experiment_config = Experiments()
+    experiment_dict = experiment_config.get_all_experiments()
+    update_experiment()
+    id_list = list(experiment_dict.keys())
+    dir_list = get_platform_dir(config_content)
+    if not dir_list:
+        print_normal('No NNI cache folders found.')
+        exit(1)
+    while True:
+        print_normal('This command will remove the following NNI cache folders. If other users are running experiments on these hosts, their experiments will be broken.')
+        for dir in dir_list:
+            print('  ' + dir)
+        inputs = input('INFO: do you want to continue?[y/N]:')
+        if not inputs.lower() or inputs.lower() in ['n', 'no']:
+            exit(0)
+        elif inputs.lower() not in ['y', 'n', 'yes', 'no']:
+            print_warning('please input Y or N.')
+        else:
+            break
+    if platform == 'remote':
+        machine_list = config_content.get('machineList')
+        remote_clean(machine_list, None)
+    elif platform == 'pai':
+        pai_config = config_content.get('paiConfig')
+        host = pai_config.get('host')
+        user_name = pai_config.get('userName')
+        output_dir = config_content.get('trial').get('outputDir')
+        hdfs_clean(host, user_name, output_dir, None)
+    print_normal('Done.')
+
 def experiment_list(args):
     '''get the information of all experiments'''
     experiment_config = Experiments()
     experiment_dict = experiment_config.get_all_experiments()
     if not experiment_dict:
-        print('There is no experiment running...')
+        print_normal('Cannot find experiments.')
         exit(1)
     update_experiment()
     experiment_id_list = []
-    if args.all and args.all == 'all':
+    if args.all:
         for key in experiment_dict.keys():
             experiment_id_list.append(key)
     else:
@@ -390,10 +559,9 @@ def experiment_list(args):
         if experiment_dict[key]['status'] != 'STOPPED':
             experiment_id_list.append(key)
     if not experiment_id_list:
-        print_warning('There is no experiment running...\nYou can use \'nnictl experiment list all\' to list all stopped experiments!')
+        print_warning('There is no experiment running...\nYou can use \'nnictl experiment list --all\' to list all stopped experiments.')
     experiment_information = ""
     for key in experiment_id_list:
         experiment_information += (EXPERIMENT_DETAIL_FORMAT % (key, experiment_dict[key]['status'], experiment_dict[key]['port'],\
                 experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], experiment_dict[key]['endTime']))
     print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
...
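One detail worth calling out in hdfs_clean above: outputDir is parsed with a named-group regex before anything is deleted, and deletion is skipped when the host does not match. A standalone sketch of that parsing in isolation (the URL is hypothetical):

    # standalone sketch of the outputDir parsing used by hdfs_clean; the URL is hypothetical
    import re

    pattern = re.compile('hdfs://(?P<host>([0-9]{1,3}.){3}[0-9]{1,3})(:[0-9]{2,5})?(?P<baseDir>/.*)?')
    match_result = pattern.match('hdfs://10.10.10.10:9000/nni/output')
    assert match_result.group('host') == '10.10.10.10'
    assert match_result.group('baseDir') == '/nni/output'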
@@ -57,3 +57,17 @@ def create_ssh_sftp_client(host_ip, port, username, password):
         return sftp
     except Exception as exception:
         print_error('Create ssh client error %s\n' % exception)
+
+def remove_remote_directory(sftp, directory):
+    '''remove a directory in remote machine'''
+    try:
+        files = sftp.listdir(directory)
+        for file in files:
+            filepath = '/'.join([directory, file])
+            try:
+                sftp.remove(filepath)
+            except IOError:
+                remove_remote_directory(sftp, filepath)
+        sftp.rmdir(directory)
+    except IOError as err:
+        print_error(err)
\ No newline at end of file
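remove_remote_directory leans on SFTP semantics: remove raises an IOError on a directory, which is what triggers the recursive descent into it. A hedged usage sketch, assuming both helpers from this module are imported (host, port, and credentials are hypothetical):

    # hedged usage sketch; host, port, and credentials are hypothetical
    sftp = create_ssh_sftp_client('10.10.10.10', 22, 'nni_user', 'nni_passwd')
    remove_remote_directory(sftp, '/tmp/nni/experiments')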
@@ -224,7 +224,7 @@ if __name__ == '__main__':
         exit(1)
     check_version(args)
     try:
-        if is_multi_phase():
+        if NNI_PLATFORM == 'pai' and is_multi_phase():
             fetch_parameter_file(args)
         main_loop(args)
     except SystemExit as se:
...