Commit d6febf29 authored by suiguoxin's avatar suiguoxin
Browse files

Merge branch 'master' of git://github.com/microsoft/nni

parents 77c95479 c2179921
...@@ -76,11 +76,12 @@ def _generate_file_search_space(path, module): ...@@ -76,11 +76,12 @@ def _generate_file_search_space(path, module):
return search_space return search_space
def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''): def expand_annotations(src_dir, dst_dir, exp_id='', trial_id='', nas_mode=None):
"""Expand annotations in user code. """Expand annotations in user code.
Return dst_dir if annotation detected; return src_dir if not. Return dst_dir if annotation detected; return src_dir if not.
src_dir: directory path of user code (str) src_dir: directory path of user code (str)
dst_dir: directory to place generated files (str) dst_dir: directory to place generated files (str)
nas_mode: the mode of NAS given that NAS interface is used
""" """
if src_dir[-1] == slash: if src_dir[-1] == slash:
src_dir = src_dir[:-1] src_dir = src_dir[:-1]
...@@ -108,7 +109,7 @@ def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''): ...@@ -108,7 +109,7 @@ def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
dst_path = os.path.join(dst_subdir, file_name) dst_path = os.path.join(dst_subdir, file_name)
if file_name.endswith('.py'): if file_name.endswith('.py'):
if trial_id == '': if trial_id == '':
annotated |= _expand_file_annotations(src_path, dst_path) annotated |= _expand_file_annotations(src_path, dst_path, nas_mode)
else: else:
module = package + file_name[:-3] module = package + file_name[:-3]
annotated |= _generate_specific_file(src_path, dst_path, exp_id, trial_id, module) annotated |= _generate_specific_file(src_path, dst_path, exp_id, trial_id, module)
...@@ -120,10 +121,10 @@ def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''): ...@@ -120,10 +121,10 @@ def expand_annotations(src_dir, dst_dir, exp_id='', trial_id=''):
return dst_dir if annotated else src_dir return dst_dir if annotated else src_dir
def _expand_file_annotations(src_path, dst_path): def _expand_file_annotations(src_path, dst_path, nas_mode):
with open(src_path) as src, open(dst_path, 'w') as dst: with open(src_path) as src, open(dst_path, 'w') as dst:
try: try:
annotated_code = code_generator.parse(src.read()) annotated_code = code_generator.parse(src.read(), nas_mode)
if annotated_code is None: if annotated_code is None:
shutil.copyfile(src_path, dst_path) shutil.copyfile(src_path, dst_path)
return False return False
......
...@@ -21,14 +21,14 @@ ...@@ -21,14 +21,14 @@
import ast import ast
import astor import astor
from nni_cmd.common_utils import print_warning
# pylint: disable=unidiomatic-typecheck # pylint: disable=unidiomatic-typecheck
def parse_annotation_mutable_layers(code, lineno): def parse_annotation_mutable_layers(code, lineno, nas_mode):
"""Parse the string of mutable layers in annotation. """Parse the string of mutable layers in annotation.
Return a list of AST Expr nodes Return a list of AST Expr nodes
code: annotation string (excluding '@') code: annotation string (excluding '@')
nas_mode: the mode of NAS
""" """
module = ast.parse(code) module = ast.parse(code)
assert type(module) is ast.Module, 'internal error #1' assert type(module) is ast.Module, 'internal error #1'
...@@ -110,6 +110,9 @@ def parse_annotation_mutable_layers(code, lineno): ...@@ -110,6 +110,9 @@ def parse_annotation_mutable_layers(code, lineno):
else: else:
target_call_args.append(ast.Dict(keys=[], values=[])) target_call_args.append(ast.Dict(keys=[], values=[]))
target_call_args.append(ast.Num(n=0)) target_call_args.append(ast.Num(n=0))
target_call_args.append(ast.Str(s=nas_mode))
if nas_mode in ['enas_mode', 'oneshot_mode']:
target_call_args.append(ast.Name(id='tensorflow'))
target_call = ast.Call(func=target_call_attr, args=target_call_args, keywords=[]) target_call = ast.Call(func=target_call_attr, args=target_call_args, keywords=[])
node = ast.Assign(targets=[layer_output], value=target_call) node = ast.Assign(targets=[layer_output], value=target_call)
nodes.append(node) nodes.append(node)
...@@ -277,10 +280,11 @@ class FuncReplacer(ast.NodeTransformer): ...@@ -277,10 +280,11 @@ class FuncReplacer(ast.NodeTransformer):
class Transformer(ast.NodeTransformer): class Transformer(ast.NodeTransformer):
"""Transform original code to annotated code""" """Transform original code to annotated code"""
def __init__(self): def __init__(self, nas_mode=None):
self.stack = [] self.stack = []
self.last_line = 0 self.last_line = 0
self.annotated = False self.annotated = False
self.nas_mode = nas_mode
def visit(self, node): def visit(self, node):
if isinstance(node, (ast.expr, ast.stmt)): if isinstance(node, (ast.expr, ast.stmt)):
...@@ -316,8 +320,11 @@ class Transformer(ast.NodeTransformer): ...@@ -316,8 +320,11 @@ class Transformer(ast.NodeTransformer):
return node # not an annotation, ignore it return node # not an annotation, ignore it
if string.startswith('@nni.get_next_parameter'): if string.startswith('@nni.get_next_parameter'):
deprecated_message = "'@nni.get_next_parameter' is deprecated in annotation due to inconvenience. Please remove this line in the trial code." call_node = parse_annotation(string[1:]).value
print_warning(deprecated_message) if call_node.args:
# it is used in enas mode as it needs to retrieve the next subgraph for training
call_attr = ast.Attribute(value=ast.Name(id='nni', ctx=ast.Load()), attr='reload_tensorflow_variables', ctx=ast.Load())
return ast.Expr(value=ast.Call(func=call_attr, args=call_node.args, keywords=[]))
if string.startswith('@nni.report_intermediate_result') \ if string.startswith('@nni.report_intermediate_result') \
or string.startswith('@nni.report_final_result') \ or string.startswith('@nni.report_final_result') \
...@@ -325,7 +332,8 @@ class Transformer(ast.NodeTransformer): ...@@ -325,7 +332,8 @@ class Transformer(ast.NodeTransformer):
return parse_annotation(string[1:]) # expand annotation string to code return parse_annotation(string[1:]) # expand annotation string to code
if string.startswith('@nni.mutable_layers'): if string.startswith('@nni.mutable_layers'):
return parse_annotation_mutable_layers(string[1:], node.lineno) nodes = parse_annotation_mutable_layers(string[1:], node.lineno, self.nas_mode)
return nodes
if string.startswith('@nni.variable') \ if string.startswith('@nni.variable') \
or string.startswith('@nni.function_choice'): or string.startswith('@nni.function_choice'):
...@@ -343,17 +351,18 @@ class Transformer(ast.NodeTransformer): ...@@ -343,17 +351,18 @@ class Transformer(ast.NodeTransformer):
return node return node
def parse(code): def parse(code, nas_mode=None):
"""Annotate user code. """Annotate user code.
Return annotated code (str) if annotation detected; return None if not. Return annotated code (str) if annotation detected; return None if not.
code: original user code (str) code: original user code (str),
nas_mode: the mode of NAS given that NAS interface is used
""" """
try: try:
ast_tree = ast.parse(code) ast_tree = ast.parse(code)
except Exception: except Exception:
raise RuntimeError('Bad Python code') raise RuntimeError('Bad Python code')
transformer = Transformer() transformer = Transformer(nas_mode)
try: try:
transformer.visit(ast_tree) transformer.visit(ast_tree)
except AssertionError as exc: except AssertionError as exc:
...@@ -369,5 +378,9 @@ def parse(code): ...@@ -369,5 +378,9 @@ def parse(code):
if type(nodes[i]) is ast.ImportFrom and nodes[i].module == '__future__': if type(nodes[i]) is ast.ImportFrom and nodes[i].module == '__future__':
last_future_import = i last_future_import = i
nodes.insert(last_future_import + 1, import_nni) nodes.insert(last_future_import + 1, import_nni)
# enas and oneshot modes for tensorflow need tensorflow module, so we import it here
if nas_mode in ['enas_mode', 'oneshot_mode']:
import_tf = ast.Import(names=[ast.alias(name='tensorflow', asname=None)])
nodes.insert(last_future_import + 1, import_tf)
return astor.to_source(ast_tree) return astor.to_source(ast_tree)
...@@ -196,7 +196,8 @@ common_trial_schema = { ...@@ -196,7 +196,8 @@ common_trial_schema = {
'trial':{ 'trial':{
'command': setType('command', str), 'command': setType('command', str),
'codeDir': setPathCheck('codeDir'), 'codeDir': setPathCheck('codeDir'),
'gpuNum': setNumberRange('gpuNum', int, 0, 99999) 'gpuNum': setNumberRange('gpuNum', int, 0, 99999),
Optional('nasMode'): setChoice('classic_mode', 'enas_mode', 'oneshot_mode')
} }
} }
......
...@@ -377,7 +377,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen ...@@ -377,7 +377,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
if not os.path.isdir(path): if not os.path.isdir(path):
os.makedirs(path) os.makedirs(path)
path = tempfile.mkdtemp(dir=path) path = tempfile.mkdtemp(dir=path)
code_dir = expand_annotations(experiment_config['trial']['codeDir'], path) code_dir = expand_annotations(experiment_config['trial']['codeDir'], path, nas_mode=experiment_config['trial']['nasMode'])
experiment_config['trial']['codeDir'] = code_dir experiment_config['trial']['codeDir'] = code_dir
search_space = generate_search_space(code_dir) search_space = generate_search_space(code_dir)
experiment_config['searchSpace'] = json.dumps(search_space) experiment_config['searchSpace'] = json.dumps(search_space)
......
...@@ -121,6 +121,19 @@ def parse_args(): ...@@ -121,6 +121,19 @@ def parse_args():
parser_experiment_list = parser_experiment_subparsers.add_parser('list', help='list all of running experiment ids') parser_experiment_list = parser_experiment_subparsers.add_parser('list', help='list all of running experiment ids')
parser_experiment_list.add_argument('all', nargs='?', help='list all of experiments') parser_experiment_list.add_argument('all', nargs='?', help='list all of experiments')
parser_experiment_list.set_defaults(func=experiment_list) parser_experiment_list.set_defaults(func=experiment_list)
parser_experiment_clean = parser_experiment_subparsers.add_parser('delete', help='clean up the experiment data')
parser_experiment_clean.add_argument('id', nargs='?', help='the id of experiment')
parser_experiment_clean.add_argument('--all', action='store_true', default=False, help='delete all of experiments')
parser_experiment_clean.set_defaults(func=experiment_clean)
#parse experiment command
parser_platform = subparsers.add_parser('platform', help='get platform information')
#add subparsers for parser_experiment
parser_platform_subparsers = parser_platform.add_subparsers()
parser_platform_clean = parser_platform_subparsers.add_parser('clean', help='clean up the platform data')
parser_platform_clean.add_argument('--config', '-c', required=True, dest='config', help='the path of yaml config file')
parser_platform_clean.set_defaults(func=platform_clean)
#import tuning data #import tuning data
parser_import_data = parser_experiment_subparsers.add_parser('import', help='import additional data') parser_import_data = parser_experiment_subparsers.add_parser('import', help='import additional data')
parser_import_data.add_argument('id', nargs='?', help='the id of experiment') parser_import_data.add_argument('id', nargs='?', help='the id of experiment')
......
...@@ -24,6 +24,10 @@ import psutil ...@@ -24,6 +24,10 @@ import psutil
import json import json
import datetime import datetime
import time import time
import re
from pathlib import Path
from pyhdfs import HdfsClient, HdfsFileNotFoundException
import shutil
from subprocess import call, check_output from subprocess import call, check_output
from nni_annotation import expand_annotations from nni_annotation import expand_annotations
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
...@@ -31,8 +35,9 @@ from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_ ...@@ -31,8 +35,9 @@ from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, export_
from .config_utils import Config, Experiments from .config_utils import Config, Experiments
from .constants import NNICTL_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, \ from .constants import NNICTL_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, \
EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT EXPERIMENT_MONITOR_INFO, TRIAL_MONITOR_HEAD, TRIAL_MONITOR_CONTENT, TRIAL_MONITOR_TAIL, REST_TIME_OUT
from .common_utils import print_normal, print_error, print_warning, detect_process from .common_utils import print_normal, print_error, print_warning, detect_process, get_yml_content
from .command_utils import check_output_command, kill_command from .command_utils import check_output_command, kill_command
from .ssh_utils import create_ssh_sftp_client, remove_remote_directory
def get_experiment_time(port): def get_experiment_time(port):
'''get the startTime and endTime of an experiment''' '''get the startTime and endTime of an experiment'''
...@@ -73,10 +78,11 @@ def update_experiment(): ...@@ -73,10 +78,11 @@ def update_experiment():
if status: if status:
experiment_config.update_experiment(key, 'status', status) experiment_config.update_experiment(key, 'status', status)
def check_experiment_id(args): def check_experiment_id(args, update=True):
'''check if the id is valid '''check if the id is valid
''' '''
update_experiment() if update:
update_experiment()
experiment_config = Experiments() experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiment_dict = experiment_config.get_all_experiments()
if not experiment_dict: if not experiment_dict:
...@@ -170,7 +176,7 @@ def get_config_filename(args): ...@@ -170,7 +176,7 @@ def get_config_filename(args):
'''get the file name of config file''' '''get the file name of config file'''
experiment_id = check_experiment_id(args) experiment_id = check_experiment_id(args)
if experiment_id is None: if experiment_id is None:
print_error('Please set the experiment id!') print_error('Please set correct experiment id!')
exit(1) exit(1)
experiment_config = Experiments() experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiment_dict = experiment_config.get_all_experiments()
...@@ -180,7 +186,7 @@ def get_experiment_port(args): ...@@ -180,7 +186,7 @@ def get_experiment_port(args):
'''get the port of experiment''' '''get the port of experiment'''
experiment_id = check_experiment_id(args) experiment_id = check_experiment_id(args)
if experiment_id is None: if experiment_id is None:
print_error('Please set the experiment id!') print_error('Please set correct experiment id!')
exit(1) exit(1)
experiment_config = Experiments() experiment_config = Experiments()
experiment_dict = experiment_config.get_all_experiments() experiment_dict = experiment_config.get_all_experiments()
...@@ -373,6 +379,166 @@ def webui_url(args): ...@@ -373,6 +379,166 @@ def webui_url(args):
nni_config = Config(get_config_filename(args)) nni_config = Config(get_config_filename(args))
print_normal('{0} {1}'.format('Web UI url:', ' '.join(nni_config.get_config('webuiUrl')))) print_normal('{0} {1}'.format('Web UI url:', ' '.join(nni_config.get_config('webuiUrl'))))
def local_clean(directory):
    '''Remove a local experiment data folder.

    directory: path of the folder to delete (str)
    A missing folder is reported via print_error instead of raising, so a
    batch cleanup over several experiments keeps going.
    '''
    print_normal('removing folder {0}'.format(directory))
    try:
        shutil.rmtree(directory)
    except FileNotFoundError:
        # fix: the original bound the exception as `err` but never used it
        print_error('{0} does not exist!'.format(directory))
def remote_clean(machine_list, experiment_id=None):
    '''Delete NNI experiment data under /tmp/nni/experiments on remote machines.

    machine_list: machine config dicts with ip/port/username/passwd keys (list)
    experiment_id: when given, only that experiment's subfolder is removed;
                   otherwise the whole experiments folder is wiped.
    '''
    # the target path is the same for every machine, so build it once
    if experiment_id:
        remote_dir = '/' + '/'.join(['tmp', 'nni', 'experiments', experiment_id])
    else:
        remote_dir = '/' + '/'.join(['tmp', 'nni', 'experiments'])
    for machine in machine_list:
        host = machine.get('ip')
        port = machine.get('port')
        sftp = create_ssh_sftp_client(host, port, machine.get('username'), machine.get('passwd'))
        print_normal('removing folder {0}'.format(host + ':' + str(port) + remote_dir))
        remove_remote_directory(sftp, remote_dir)
def hdfs_clean(host, user_name, output_dir, experiment_id=None):
    '''Delete NNI experiment data stored on HDFS (PAI platform).

    host: HDFS/PAI host name or IP (str)
    user_name: HDFS user whose /<user>/nni/experiments tree is cleaned (str)
    output_dir: optional hdfs:// output URI from the trial config; if it points
                at the same host, its folder is removed as well
    experiment_id: when given, only that experiment's subfolder is removed
    '''
    hdfs_client = HdfsClient(hosts='{0}:80'.format(host), user_name=user_name, webhdfs_path='/webhdfs/api/v1', timeout=5)
    if experiment_id:
        full_path = '/' + '/'.join([user_name, 'nni', 'experiments', experiment_id])
    else:
        full_path = '/' + '/'.join([user_name, 'nni', 'experiments'])
    print_normal('removing folder {0} in hdfs'.format(full_path))
    hdfs_client.delete(full_path, recursive=True)
    if output_dir:
        # fix: dots in the IPv4 octets were unescaped ('.') and matched ANY
        # character; also use a raw string so '\.' is not a warning
        pattern = re.compile(r'hdfs://(?P<host>([0-9]{1,3}\.){3}[0-9]{1,3})(:[0-9]{2,5})?(?P<baseDir>/.*)?')
        match_result = pattern.match(output_dir)
        if match_result:
            output_host = match_result.group('host')
            output_dir = match_result.group('baseDir')
            #check if the host is valid
            if output_host != host:
                # NOTE(review): output_dir was already rewritten to baseDir
                # above, so the warning shows the path, not the full URI
                print_warning('The host in {0} is not consistent with {1}'.format(output_dir, host))
            else:
                if experiment_id:
                    output_dir = output_dir + '/' + experiment_id
                print_normal('removing folder {0} in hdfs'.format(output_dir))
                hdfs_client.delete(output_dir, recursive=True)
def experiment_clean(args):
    '''Clean up the data of one experiment (args.id) or of all experiments (args.all).

    For each selected experiment: remove its remote/HDFS data according to the
    training platform, then its local log folder, then its metadata entry.
    Asks for interactive confirmation first; exits the process on bad input.
    '''
    experiment_id_list = []
    experiment_config = Experiments()
    experiment_dict = experiment_config.get_all_experiments()
    if args.all:
        experiment_id_list = list(experiment_dict.keys())
    else:
        if args.id is None:
            print_error('please set experiment id!')
            exit(1)
        if args.id not in experiment_dict:
            print_error('can not find id {0}!'.format(args.id))
            exit(1)
        experiment_id_list.append(args.id)
    # loop until the user gives an understandable answer; default (empty) is "no"
    while True:
        print('INFO: This action will delete experiment {0}, and it’s not recoverable.'.format(' '.join(experiment_id_list)))
        inputs = input('INFO: do you want to continue?[y/N]:')
        if not inputs.lower() or inputs.lower() in ['n', 'no']:
            exit(0)
        elif inputs.lower() not in ['y', 'n', 'yes', 'no']:
            print_warning('please input Y or N!')
        else:
            break
    for experiment_id in experiment_id_list:
        nni_config = Config(experiment_dict[experiment_id]['fileName'])
        platform = nni_config.get_config('experimentConfig').get('trainingServicePlatform')
        # NOTE(review): this rebinds the loop variable from the stored config;
        # presumably it always equals the key — confirm, else metadata removal
        # below targets a different id than the one selected
        experiment_id = nni_config.get_config('experimentId')
        if platform == 'remote':
            machine_list = nni_config.get_config('experimentConfig').get('machineList')
            remote_clean(machine_list, experiment_id)
        elif platform == 'pai':
            host = nni_config.get_config('experimentConfig').get('paiConfig').get('host')
            user_name = nni_config.get_config('experimentConfig').get('paiConfig').get('userName')
            output_dir = nni_config.get_config('experimentConfig').get('trial').get('outputDir')
            hdfs_clean(host, user_name, output_dir, experiment_id)
        elif platform != 'local':
            #TODO: support all platforms
            print_warning('platform {0} clean up not supported yet!'.format(platform))
            exit(0)
        #clean local data
        home = str(Path.home())
        local_dir = nni_config.get_config('experimentConfig').get('logDir')
        if not local_dir:
            # no custom logDir configured: fall back to the default location
            local_dir = os.path.join(home, 'nni', 'experiments', experiment_id)
        local_clean(local_dir)
        experiment_config = Experiments()
        print_normal('removing metadata of experiment {0}'.format(experiment_id))
        experiment_config.remove_experiment(experiment_id)
    print_normal('Finish!')
def get_platform_dir(config_content):
    '''Collect the remote cache folders a platform clean would delete.

    config_content: parsed experiment YAML config (dict)
    Returns a list of human-readable folder descriptions (str); empty when the
    platform keeps no remote caches.
    '''
    platform = config_content.get('trainingServicePlatform')
    dir_list = []
    if platform == 'remote':
        machine_list = config_content.get('machineList')
        for machine in machine_list:
            host = machine.get('ip')
            port = machine.get('port')
            dir_list.append(host + ':' + str(port) + '/tmp/nni/experiments')
    elif platform == 'pai':
        # fix: pai_config was fetched but unused; the original re-read
        # config_content.get('paiConfig') for every field
        pai_config = config_content.get('paiConfig')
        host = pai_config.get('host')
        user_name = pai_config.get('userName')
        output_dir = config_content.get('trial').get('outputDir')
        dir_list.append('hdfs://{0}:9000/{1}/nni/experiments'.format(host, user_name))
        if output_dir:
            dir_list.append(output_dir)
    return dir_list
def platform_clean(args):
    '''Clean up NNI caches on the training platform described by a config file.

    args.config: path of the experiment YAML config file
    Only 'remote' and 'pai' platforms keep remote caches; other platforms exit.
    Asks for interactive confirmation before deleting anything.
    '''
    config_path = os.path.abspath(args.config)
    if not os.path.exists(config_path):
        print_error('Please set correct config path!')
        exit(1)
    config_content = get_yml_content(config_path)
    platform = config_content.get('trainingServicePlatform')
    if platform not in ['remote', 'pai']:
        print_normal('platform {0} not supported!'.format(platform))
        exit(0)
    # refresh experiment statuses before touching anything
    # (fix: the original also built an Experiments() dict and id_list that
    # were never used)
    update_experiment()
    dir_list = get_platform_dir(config_content)
    if not dir_list:
        print_normal('No folder of NNI caches is found!')
        exit(1)
    while True:
        print_normal('This command will remove below folders of NNI caches. If other users are using experiments on below hosts, it will be broken.')
        for directory in dir_list:
            print('  ' + directory)
        inputs = input('INFO: do you want to continue?[y/N]:')
        if not inputs.lower() or inputs.lower() in ['n', 'no']:
            exit(0)
        elif inputs.lower() not in ['y', 'n', 'yes', 'no']:
            print_warning('please input Y or N!')
        else:
            break
    if platform == 'remote':
        # fix: remote_clean already iterates every machine itself; the original
        # called it once per machine, deleting the same folders N times
        remote_clean(config_content.get('machineList'), None)
    elif platform == 'pai':
        pai_config = config_content.get('paiConfig')
        host = pai_config.get('host')
        user_name = pai_config.get('userName')
        output_dir = config_content.get('trial').get('outputDir')
        hdfs_clean(host, user_name, output_dir, None)
    print_normal('Done!')
def experiment_list(args): def experiment_list(args):
'''get the information of all experiments''' '''get the information of all experiments'''
experiment_config = Experiments() experiment_config = Experiments()
...@@ -393,7 +559,6 @@ def experiment_list(args): ...@@ -393,7 +559,6 @@ def experiment_list(args):
print_warning('There is no experiment running...\nYou can use \'nnictl experiment list all\' to list all stopped experiments!') print_warning('There is no experiment running...\nYou can use \'nnictl experiment list all\' to list all stopped experiments!')
experiment_information = "" experiment_information = ""
for key in experiment_id_list: for key in experiment_id_list:
experiment_information += (EXPERIMENT_DETAIL_FORMAT % (key, experiment_dict[key]['status'], experiment_dict[key]['port'],\ experiment_information += (EXPERIMENT_DETAIL_FORMAT % (key, experiment_dict[key]['status'], experiment_dict[key]['port'],\
experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], experiment_dict[key]['endTime'])) experiment_dict[key].get('platform'), experiment_dict[key]['startTime'], experiment_dict[key]['endTime']))
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information) print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
......
...@@ -57,3 +57,17 @@ def create_ssh_sftp_client(host_ip, port, username, password): ...@@ -57,3 +57,17 @@ def create_ssh_sftp_client(host_ip, port, username, password):
return sftp return sftp
except Exception as exception: except Exception as exception:
print_error('Create ssh client error %s\n' % exception) print_error('Create ssh client error %s\n' % exception)
def remove_remote_directory(sftp, directory):
    '''Recursively delete a directory on a remote machine through an SFTP client.

    sftp: connected SFTP client (listdir/remove/rmdir interface)
    directory: absolute remote path to delete (str)
    Entries that raise IOError on remove() are assumed to be subdirectories
    and are cleaned recursively; a failure on the directory itself is printed.
    '''
    try:
        for entry in sftp.listdir(directory):
            entry_path = '/'.join([directory, entry])
            try:
                sftp.remove(entry_path)
            except IOError:
                # remove() fails on directories, so recurse into this entry
                remove_remote_directory(sftp, entry_path)
        sftp.rmdir(directory)
    except IOError as err:
        print_error(err)
\ No newline at end of file
...@@ -36,6 +36,8 @@ STDERR_FULL_PATH = os.path.join(LOG_DIR, 'stderr') ...@@ -36,6 +36,8 @@ STDERR_FULL_PATH = os.path.join(LOG_DIR, 'stderr')
STDOUT_API = '/stdout' STDOUT_API = '/stdout'
VERSION_API = '/version' VERSION_API = '/version'
PARAMETER_META_API = '/parameter-file-meta'
NNI_SYS_DIR = os.environ['NNI_SYS_DIR'] NNI_SYS_DIR = os.environ['NNI_SYS_DIR']
NNI_TRIAL_JOB_ID = os.environ['NNI_TRIAL_JOB_ID'] NNI_TRIAL_JOB_ID = os.environ['NNI_TRIAL_JOB_ID']
NNI_EXP_ID = os.environ['NNI_EXP_ID'] NNI_EXP_ID = os.environ['NNI_EXP_ID']
\ No newline at end of file MULTI_PHASE = os.environ['MULTI_PHASE']
...@@ -28,18 +28,49 @@ import re ...@@ -28,18 +28,49 @@ import re
import sys import sys
import select import select
import json import json
import threading
from pyhdfs import HdfsClient from pyhdfs import HdfsClient
import pkg_resources import pkg_resources
from .rest_utils import rest_post from .rest_utils import rest_post, rest_get
from .url_utils import gen_send_stdout_url, gen_send_version_url from .url_utils import gen_send_stdout_url, gen_send_version_url, gen_parameter_meta_url
from .constants import HOME_DIR, LOG_DIR, NNI_PLATFORM, STDOUT_FULL_PATH, STDERR_FULL_PATH from .constants import HOME_DIR, LOG_DIR, NNI_PLATFORM, STDOUT_FULL_PATH, STDERR_FULL_PATH, \
from .hdfsClientUtility import copyDirectoryToHdfs, copyHdfsDirectoryToLocal MULTI_PHASE, NNI_TRIAL_JOB_ID, NNI_SYS_DIR, NNI_EXP_ID
from .hdfsClientUtility import copyDirectoryToHdfs, copyHdfsDirectoryToLocal, copyHdfsFileToLocal
from .log_utils import LogType, nni_log, RemoteLogger, PipeLogReader, StdOutputType from .log_utils import LogType, nni_log, RemoteLogger, PipeLogReader, StdOutputType
logger = logging.getLogger('trial_keeper') logger = logging.getLogger('trial_keeper')
regular = re.compile('v?(?P<version>[0-9](\.[0-9]){0,1}).*') regular = re.compile('v?(?P<version>[0-9](\.[0-9]){0,1}).*')
# process-wide cached HDFS client (built lazily on first use)
_hdfs_client = None

def get_hdfs_client(args):
    '''Return a cached HdfsClient built from the trial keeper's CLI arguments.

    Returns None when no HDFS host is configured, or when no experiment
    directory was given.  Raises whatever HdfsClient construction raises,
    after logging it.
    '''
    global _hdfs_client
    if _hdfs_client is not None:
        return _hdfs_client
    # --hdfs_host is the legacy flag; --pai_hdfs_host is its PAI successor
    hdfs_host = args.hdfs_host or args.pai_hdfs_host
    if not hdfs_host:
        return None
    if args.nni_hdfs_exp_dir is not None:
        try:
            if args.webhdfs_path:
                _hdfs_client = HdfsClient(hosts='{0}:80'.format(hdfs_host), user_name=args.pai_user_name, webhdfs_path=args.webhdfs_path, timeout=5)
            else:
                # backward compatibility: default namenode webhdfs port
                _hdfs_client = HdfsClient(hosts='{0}:{1}'.format(hdfs_host, '50070'), user_name=args.pai_user_name, timeout=5)
        except Exception as e:
            nni_log(LogType.Error, 'Create HDFS client error: ' + str(e))
            raise e
    return _hdfs_client
def main_loop(args): def main_loop(args):
'''main loop logic for trial keeper''' '''main loop logic for trial keeper'''
...@@ -52,28 +83,16 @@ def main_loop(args): ...@@ -52,28 +83,16 @@ def main_loop(args):
# redirect trial keeper's stdout and stderr to syslog # redirect trial keeper's stdout and stderr to syslog
trial_syslogger_stdout = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'trial', StdOutputType.Stdout, args.log_collection) trial_syslogger_stdout = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'trial', StdOutputType.Stdout, args.log_collection)
sys.stdout = sys.stderr = trial_keeper_syslogger sys.stdout = sys.stderr = trial_keeper_syslogger
# backward compatibility
hdfs_host = None
hdfs_output_dir = None hdfs_output_dir = None
if args.hdfs_host:
hdfs_host = args.hdfs_host
elif args.pai_hdfs_host:
hdfs_host = args.pai_hdfs_host
if args.hdfs_output_dir: if args.hdfs_output_dir:
hdfs_output_dir = args.hdfs_output_dir hdfs_output_dir = args.hdfs_output_dir
elif args.pai_hdfs_output_dir: elif args.pai_hdfs_output_dir:
hdfs_output_dir = args.pai_hdfs_output_dir hdfs_output_dir = args.pai_hdfs_output_dir
if hdfs_host is not None and args.nni_hdfs_exp_dir is not None: hdfs_client = get_hdfs_client(args)
try:
if args.webhdfs_path: if hdfs_client is not None:
hdfs_client = HdfsClient(hosts='{0}:80'.format(hdfs_host), user_name=args.pai_user_name, webhdfs_path=args.webhdfs_path, timeout=5)
else:
# backward compatibility
hdfs_client = HdfsClient(hosts='{0}:{1}'.format(hdfs_host, '50070'), user_name=args.pai_user_name, timeout=5)
except Exception as e:
nni_log(LogType.Error, 'Create HDFS client error: ' + str(e))
raise e
copyHdfsDirectoryToLocal(args.nni_hdfs_exp_dir, os.getcwd(), hdfs_client) copyHdfsDirectoryToLocal(args.nni_hdfs_exp_dir, os.getcwd(), hdfs_client)
# Notice: We don't appoint env, which means subprocess wil inherit current environment and that is expected behavior # Notice: We don't appoint env, which means subprocess wil inherit current environment and that is expected behavior
...@@ -138,6 +157,52 @@ def check_version(args): ...@@ -138,6 +157,52 @@ def check_version(args):
except AttributeError as err: except AttributeError as err:
nni_log(LogType.Error, err) nni_log(LogType.Error, err)
def is_multi_phase():
    '''Whether multi-phase mode is enabled via the MULTI_PHASE setting.

    Mirrors short-circuit `and` semantics: when MULTI_PHASE is falsy the
    falsy value itself is returned, otherwise a bool.
    '''
    value = MULTI_PHASE
    return value and value in ['True', 'true']
def download_parameter(meta_list, args):
    """
    Download this trial's parameter file into its local working directory.

    meta_list format is defined in paiJobRestServer.ts
    example meta_list:
    [
        {"experimentId":"yWFJarYa","trialId":"UpPkl","filePath":"/chec/nni/experiments/yWFJarYa/trials/UpPkl/parameter_1.cfg"},
        {"experimentId":"yWFJarYa","trialId":"aIUMA","filePath":"/chec/nni/experiments/yWFJarYa/trials/aIUMA/parameter_1.cfg"}
    ]
    """
    nni_log(LogType.Debug, str(meta_list))
    nni_log(LogType.Debug, 'NNI_SYS_DIR: {}, trial Id: {}, experiment ID: {}'.format(NNI_SYS_DIR, NNI_TRIAL_JOB_ID, NNI_EXP_ID))
    nni_log(LogType.Debug, 'NNI_SYS_DIR files: {}'.format(os.listdir(NNI_SYS_DIR)))
    # only entries addressed to this exact experiment + trial are relevant
    own_entries = (m for m in meta_list
                   if m['experimentId'] == NNI_EXP_ID and m['trialId'] == NNI_TRIAL_JOB_ID)
    for meta in own_entries:
        local_path = os.path.join(NNI_SYS_DIR, os.path.basename(meta['filePath']))
        if not os.path.exists(local_path):
            # fetch from HDFS only when the file is not already present
            copyHdfsFileToLocal(meta['filePath'], local_path, get_hdfs_client(args), override=False)
def fetch_parameter_file(args):
    '''Start a background thread that polls the NNI manager for parameter-file
    metadata and downloads any file meant for this trial.

    args: parsed command-line arguments (needs nnimanager_ip / nnimanager_port)
    NOTE(review): the thread loops forever and is not marked daemon, so it may
    keep the trial keeper process alive after the trial exits — confirm intended.
    '''
    class FetchThread(threading.Thread):
        # one poller instance per trial keeper process
        def __init__(self, args):
            super(FetchThread, self).__init__()
            self.args = args

        def run(self):
            uri = gen_parameter_meta_url(self.args.nnimanager_ip, self.args.nnimanager_port)
            nni_log(LogType.Info, uri)

            # poll every 5 seconds; 200 means metadata is available
            while True:
                res = rest_get(uri, 10)
                nni_log(LogType.Debug, 'status code: {}'.format(res.status_code))
                if res.status_code == 200:
                    meta_list = res.json()
                    download_parameter(meta_list, self.args)
                else:
                    nni_log(LogType.Warning, 'rest response: {}'.format(str(res)))
                time.sleep(5)

    fetch_file_thread = FetchThread(args)
    fetch_file_thread.start()
if __name__ == '__main__': if __name__ == '__main__':
'''NNI Trial Keeper main function''' '''NNI Trial Keeper main function'''
PARSER = argparse.ArgumentParser() PARSER = argparse.ArgumentParser()
...@@ -159,6 +224,8 @@ if __name__ == '__main__': ...@@ -159,6 +224,8 @@ if __name__ == '__main__':
exit(1) exit(1)
check_version(args) check_version(args)
try: try:
if NNI_PLATFORM == 'pai' and is_multi_phase():
fetch_parameter_file(args)
main_loop(args) main_loop(args)
except SystemExit as se: except SystemExit as se:
nni_log(LogType.Info, 'NNI trial keeper exit with code {}'.format(se.code)) nni_log(LogType.Info, 'NNI trial keeper exit with code {}'.format(se.code))
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from .constants import API_ROOT_URL, BASE_URL, STDOUT_API, NNI_TRIAL_JOB_ID, NNI_EXP_ID, VERSION_API from .constants import API_ROOT_URL, BASE_URL, STDOUT_API, NNI_TRIAL_JOB_ID, NNI_EXP_ID, VERSION_API, PARAMETER_META_API
def gen_send_stdout_url(ip, port): def gen_send_stdout_url(ip, port):
'''Generate send stdout url''' '''Generate send stdout url'''
...@@ -26,4 +26,8 @@ def gen_send_stdout_url(ip, port): ...@@ -26,4 +26,8 @@ def gen_send_stdout_url(ip, port):
def gen_send_version_url(ip, port): def gen_send_version_url(ip, port):
'''Generate send error url''' '''Generate send error url'''
return '{0}:{1}{2}{3}/{4}/{5}'.format(BASE_URL.format(ip), port, API_ROOT_URL, VERSION_API, NNI_EXP_ID, NNI_TRIAL_JOB_ID) return '{0}:{1}{2}{3}/{4}/{5}'.format(BASE_URL.format(ip), port, API_ROOT_URL, VERSION_API, NNI_EXP_ID, NNI_TRIAL_JOB_ID)
\ No newline at end of file
def gen_parameter_meta_url(ip, port):
    '''Generate the parameter-file metadata url on the NNI manager REST server'''
    return '{0}:{1}{2}{3}'.format(BASE_URL.format(ip), port, API_ROOT_URL, PARAMETER_META_API)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment