Unverified Commit eea50784, authored by chicm-ms, committed by GitHub
Browse files

Dev pylint (#1697)

Fix pylint errors
parent 1f9b7617
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
import os import os
from .common_utils import print_error from .common_utils import print_error
from subprocess import call
from .command_utils import install_package_command from .command_utils import install_package_command
def check_environment(): def check_environment():
...@@ -29,6 +28,8 @@ def check_environment(): ...@@ -29,6 +28,8 @@ def check_environment():
import paramiko import paramiko
except: except:
install_package_command('paramiko') install_package_command('paramiko')
import paramiko
return paramiko
def copy_remote_directory_to_local(sftp, remote_path, local_path): def copy_remote_directory_to_local(sftp, remote_path, local_path):
'''copy remote directory to local machine''' '''copy remote directory to local machine'''
...@@ -49,8 +50,7 @@ def copy_remote_directory_to_local(sftp, remote_path, local_path): ...@@ -49,8 +50,7 @@ def copy_remote_directory_to_local(sftp, remote_path, local_path):
def create_ssh_sftp_client(host_ip, port, username, password): def create_ssh_sftp_client(host_ip, port, username, password):
'''create ssh client''' '''create ssh client'''
try: try:
check_environment() paramiko = check_environment()
import paramiko
conn = paramiko.Transport(host_ip, port) conn = paramiko.Transport(host_ip, port)
conn.connect(username=username, password=password) conn.connect(username=username, password=password)
sftp = paramiko.SFTPClient.from_transport(conn) sftp = paramiko.SFTPClient.from_transport(conn)
......
...@@ -19,21 +19,17 @@ ...@@ -19,21 +19,17 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import os import os
import psutil
import json import json
import datetime
import time
from subprocess import call, check_output, Popen, PIPE
from .rest_utils import rest_get, rest_delete, check_rest_server_quick, check_response
from .config_utils import Config, Experiments
from .url_utils import trial_jobs_url, experiment_url, trial_job_id_url, get_local_urls
from .constants import NNICTL_HOME_DIR, EXPERIMENT_INFORMATION_FORMAT, EXPERIMENT_DETAIL_FORMAT, COLOR_GREEN_FORMAT
import time
from .common_utils import print_normal, print_error, print_warning, detect_process, detect_port
from .nnictl_utils import *
import re import re
from .ssh_utils import create_ssh_sftp_client, copy_remote_directory_to_local
import tempfile import tempfile
from subprocess import call, Popen
from .rest_utils import rest_get, check_rest_server_quick, check_response
from .config_utils import Config, Experiments
from .url_utils import trial_jobs_url, get_local_urls
from .constants import COLOR_GREEN_FORMAT, REST_TIME_OUT
from .common_utils import print_normal, print_error, detect_process, detect_port
from .nnictl_utils import check_experiment_id, check_experiment_id
from .ssh_utils import create_ssh_sftp_client, copy_remote_directory_to_local
def parse_log_path(args, trial_content): def parse_log_path(args, trial_content):
'''parse log path''' '''parse log path'''
...@@ -43,7 +39,7 @@ def parse_log_path(args, trial_content): ...@@ -43,7 +39,7 @@ def parse_log_path(args, trial_content):
if args.trial_id and args.trial_id != 'all' and trial.get('id') != args.trial_id: if args.trial_id and args.trial_id != 'all' and trial.get('id') != args.trial_id:
continue continue
pattern = r'(?P<head>.+)://(?P<host>.+):(?P<path>.*)' pattern = r'(?P<head>.+)://(?P<host>.+):(?P<path>.*)'
match = re.search(pattern,trial['logPath']) match = re.search(pattern, trial['logPath'])
if match: if match:
path_list.append(match.group('path')) path_list.append(match.group('path'))
host_list.append(match.group('host')) host_list.append(match.group('host'))
...@@ -94,7 +90,8 @@ def start_tensorboard_process(args, nni_config, path_list, temp_nni_path): ...@@ -94,7 +90,8 @@ def start_tensorboard_process(args, nni_config, path_list, temp_nni_path):
if detect_port(args.port): if detect_port(args.port):
print_error('Port %s is used by another process, please reset port!' % str(args.port)) print_error('Port %s is used by another process, please reset port!' % str(args.port))
exit(1) exit(1)
with open(os.path.join(temp_nni_path, 'tensorboard_stdout'), 'a+') as stdout_file, open(os.path.join(temp_nni_path, 'tensorboard_stderr'), 'a+') as stderr_file: with open(os.path.join(temp_nni_path, 'tensorboard_stdout'), 'a+') as stdout_file, \
open(os.path.join(temp_nni_path, 'tensorboard_stderr'), 'a+') as stderr_file:
cmds = ['tensorboard', '--logdir', format_tensorboard_log_path(path_list), '--port', str(args.port)] cmds = ['tensorboard', '--logdir', format_tensorboard_log_path(path_list), '--port', str(args.port)]
tensorboard_process = Popen(cmds, stdout=stdout_file, stderr=stderr_file) tensorboard_process = Popen(cmds, stdout=stdout_file, stderr=stderr_file)
url_list = get_local_urls(args.port) url_list = get_local_urls(args.port)
......
...@@ -25,7 +25,7 @@ from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick, ...@@ -25,7 +25,7 @@ from .rest_utils import rest_put, rest_post, rest_get, check_rest_server_quick,
from .url_utils import experiment_url, import_data_url from .url_utils import experiment_url, import_data_url
from .config_utils import Config from .config_utils import Config
from .common_utils import get_json_content, print_normal, print_error, print_warning from .common_utils import get_json_content, print_normal, print_error, print_warning
from .nnictl_utils import check_experiment_id, get_experiment_port, get_config_filename from .nnictl_utils import get_experiment_port, get_config_filename
from .launcher_utils import parse_time from .launcher_utils import parse_time
from .constants import REST_TIME_OUT, TUNERS_SUPPORTING_IMPORT_DATA, TUNERS_NO_NEED_TO_IMPORT_DATA from .constants import REST_TIME_OUT, TUNERS_SUPPORTING_IMPORT_DATA, TUNERS_NO_NEED_TO_IMPORT_DATA
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import socket
import psutil import psutil
from socket import AddressFamily
BASE_URL = 'http://localhost' BASE_URL = 'http://localhost'
...@@ -83,8 +83,8 @@ def tensorboard_url(port): ...@@ -83,8 +83,8 @@ def tensorboard_url(port):
def get_local_urls(port): def get_local_urls(port):
'''get urls of local machine''' '''get urls of local machine'''
url_list = [] url_list = []
for name, info in psutil.net_if_addrs().items(): for _, info in psutil.net_if_addrs().items():
for addr in info: for addr in info:
if AddressFamily.AF_INET == addr.family: if socket.AddressFamily.AF_INET == addr.family:
url_list.append('http://{}:{}'.format(addr.address, port)) url_list.append('http://{}:{}'.format(addr.address, port))
return url_list return url_list
...@@ -27,19 +27,20 @@ from xml.dom import minidom ...@@ -27,19 +27,20 @@ from xml.dom import minidom
def check_ready_to_run(): def check_ready_to_run():
if sys.platform == 'win32': if sys.platform == 'win32':
pgrep_output = subprocess.check_output('wmic process where "CommandLine like \'%nni_gpu_tool.gpu_metrics_collector%\' and name like \'%python%\'" get processId') pgrep_output = subprocess.check_output(
'wmic process where "CommandLine like \'%nni_gpu_tool.gpu_metrics_collector%\' and name like \'%python%\'" get processId')
pidList = pgrep_output.decode("utf-8").strip().split() pidList = pgrep_output.decode("utf-8").strip().split()
pidList.pop(0) # remove the key word 'ProcessId' pidList.pop(0) # remove the key word 'ProcessId'
pidList = list(map(int, pidList)) pidList = list(map(int, pidList))
pidList.remove(os.getpid()) pidList.remove(os.getpid())
return len(pidList) == 0 return not pidList
else: else:
pgrep_output = subprocess.check_output('pgrep -fx \'python3 -m nni_gpu_tool.gpu_metrics_collector\'', shell=True) pgrep_output = subprocess.check_output('pgrep -fx \'python3 -m nni_gpu_tool.gpu_metrics_collector\'', shell=True)
pidList = [] pidList = []
for pid in pgrep_output.splitlines(): for pid in pgrep_output.splitlines():
pidList.append(int(pid)) pidList.append(int(pid))
pidList.remove(os.getpid()) pidList.remove(os.getpid())
return len(pidList) == 0 return not pidList
def main(argv): def main(argv):
metrics_output_dir = os.environ['METRIC_OUTPUT_DIR'] metrics_output_dir = os.environ['METRIC_OUTPUT_DIR']
...@@ -69,10 +70,14 @@ def parse_nvidia_smi_result(smi, outputDir): ...@@ -69,10 +70,14 @@ def parse_nvidia_smi_result(smi, outputDir):
outPut["gpuCount"] = len(gpuList) outPut["gpuCount"] = len(gpuList)
outPut["gpuInfos"] = [] outPut["gpuInfos"] = []
for gpuIndex, gpu in enumerate(gpuList): for gpuIndex, gpu in enumerate(gpuList):
gpuInfo ={} gpuInfo = {}
gpuInfo['index'] = gpuIndex gpuInfo['index'] = gpuIndex
gpuInfo['gpuUtil'] = gpu.getElementsByTagName('utilization')[0].getElementsByTagName('gpu_util')[0].childNodes[0].data.replace("%", "").strip() gpuInfo['gpuUtil'] = gpu.getElementsByTagName('utilization')[0]\
gpuInfo['gpuMemUtil'] = gpu.getElementsByTagName('utilization')[0].getElementsByTagName('memory_util')[0].childNodes[0].data.replace("%", "").strip() .getElementsByTagName('gpu_util')[0]\
.childNodes[0].data.replace("%", "").strip()
gpuInfo['gpuMemUtil'] = gpu.getElementsByTagName('utilization')[0]\
.getElementsByTagName('memory_util')[0]\
.childNodes[0].data.replace("%", "").strip()
processes = gpu.getElementsByTagName('processes') processes = gpu.getElementsByTagName('processes')
runningProNumber = len(processes[0].getElementsByTagName('process_info')) runningProNumber = len(processes[0].getElementsByTagName('process_info'))
gpuInfo['activeProcessNum'] = runningProNumber gpuInfo['activeProcessNum'] = runningProNumber
...@@ -81,8 +86,8 @@ def parse_nvidia_smi_result(smi, outputDir): ...@@ -81,8 +86,8 @@ def parse_nvidia_smi_result(smi, outputDir):
print(outPut) print(outPut)
outputFile.write("{}\n".format(json.dumps(outPut, sort_keys=True))) outputFile.write("{}\n".format(json.dumps(outPut, sort_keys=True)))
outputFile.flush(); outputFile.flush();
except : except:
e_info = sys.exc_info() # e_info = sys.exc_info()
print('xmldoc paring error') print('xmldoc paring error')
finally: finally:
os.umask(old_umask) os.umask(old_umask)
......
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
import os import os
import posixpath import posixpath
from pyhdfs import HdfsClient
from .log_utils import LogType, nni_log from .log_utils import LogType, nni_log
def copyHdfsDirectoryToLocal(hdfsDirectory, localDirectory, hdfsClient): def copyHdfsDirectoryToLocal(hdfsDirectory, localDirectory, hdfsClient):
...@@ -79,7 +78,8 @@ def copyDirectoryToHdfs(localDirectory, hdfsDirectory, hdfsClient): ...@@ -79,7 +78,8 @@ def copyDirectoryToHdfs(localDirectory, hdfsDirectory, hdfsClient):
try: try:
result = result and copyDirectoryToHdfs(file_path, hdfs_directory, hdfsClient) result = result and copyDirectoryToHdfs(file_path, hdfs_directory, hdfsClient)
except Exception as exception: except Exception as exception:
nni_log(LogType.Error, 'Copy local directory {0} to hdfs directory {1} error: {2}'.format(file_path, hdfs_directory, str(exception))) nni_log(LogType.Error,
'Copy local directory {0} to hdfs directory {1} error: {2}'.format(file_path, hdfs_directory, str(exception)))
result = False result = False
else: else:
hdfs_file_path = os.path.join(hdfsDirectory, file) hdfs_file_path = os.path.join(hdfsDirectory, file)
......
...@@ -33,8 +33,7 @@ from logging import StreamHandler ...@@ -33,8 +33,7 @@ from logging import StreamHandler
from queue import Queue from queue import Queue
from .rest_utils import rest_get, rest_post, rest_put, rest_delete from .rest_utils import rest_post
from .constants import NNI_EXP_ID, NNI_TRIAL_JOB_ID, STDOUT_API
from .url_utils import gen_send_stdout_url from .url_utils import gen_send_stdout_url
@unique @unique
...@@ -73,7 +72,7 @@ class NNIRestLogHanlder(StreamHandler): ...@@ -73,7 +72,7 @@ class NNIRestLogHanlder(StreamHandler):
log_entry['msg'] = self.format(record) log_entry['msg'] = self.format(record)
try: try:
response = rest_post(gen_send_stdout_url(self.host, self.port), json.dumps(log_entry), 10, True) rest_post(gen_send_stdout_url(self.host, self.port), json.dumps(log_entry), 10, True)
except Exception as e: except Exception as e:
self.orig_stderr.write(str(e) + '\n') self.orig_stderr.write(str(e) + '\n')
self.orig_stderr.flush() self.orig_stderr.flush()
...@@ -112,7 +111,7 @@ class RemoteLogger(object): ...@@ -112,7 +111,7 @@ class RemoteLogger(object):
self.orig_stdout.flush() self.orig_stdout.flush()
try: try:
self.logger.log(self.log_level, line.rstrip()) self.logger.log(self.log_level, line.rstrip())
except Exception as e: except Exception:
pass pass
class PipeLogReader(threading.Thread): class PipeLogReader(threading.Thread):
...@@ -147,15 +146,14 @@ class PipeLogReader(threading.Thread): ...@@ -147,15 +146,14 @@ class PipeLogReader(threading.Thread):
line = self.queue.get(True, 5) line = self.queue.get(True, 5)
try: try:
self.logger.log(self.log_level, line.rstrip()) self.logger.log(self.log_level, line.rstrip())
except Exception as e: except Exception:
pass pass
except Exception as e: except Exception:
if cur_process_exit == True: if cur_process_exit == True:
self._is_read_completed = True self._is_read_completed = True
break break
self.pip_log_reader_thread = threading.Thread(target = _populateQueue, self.pip_log_reader_thread = threading.Thread(target=_populateQueue, args=(self.pipeReader, self.queue))
args = (self.pipeReader, self.queue))
self.pip_log_reader_thread.daemon = True self.pip_log_reader_thread.daemon = True
self.start() self.start()
self.pip_log_reader_thread.start() self.pip_log_reader_thread.start()
...@@ -196,4 +194,4 @@ class PipeLogReader(threading.Thread): ...@@ -196,4 +194,4 @@ class PipeLogReader(threading.Thread):
def set_process_exit(self): def set_process_exit(self):
self.process_exit = True self.process_exit = True
return self.process_exit return self.process_exit
\ No newline at end of file
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import time
import requests import requests
def rest_get(url, timeout): def rest_get(url, timeout):
......
...@@ -18,16 +18,17 @@ ...@@ -18,16 +18,17 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import os
import shutil
import random
import string
import unittest import unittest
import json import json
import sys import sys
from pyhdfs import HdfsClient from pyhdfs import HdfsClient
from tools.nni_trial_tool.hdfsClientUtility import copyFileToHdfs, copyDirectoryToHdfs
sys.path.append("..") sys.path.append("..")
from trial.hdfsClientUtility import copyFileToHdfs, copyDirectoryToHdfs
import os
import shutil
import random
import string
class HDFSClientUtilityTest(unittest.TestCase): class HDFSClientUtilityTest(unittest.TestCase):
'''Unit test for hdfsClientUtility.py''' '''Unit test for hdfsClientUtility.py'''
...@@ -82,7 +83,8 @@ class HDFSClientUtilityTest(unittest.TestCase): ...@@ -82,7 +83,8 @@ class HDFSClientUtilityTest(unittest.TestCase):
with open('./{0}/{1}'.format(directory_name, file_name), 'w') as file: with open('./{0}/{1}'.format(directory_name, file_name), 'w') as file:
file.write(file_content) file.write(file_content)
result = copyDirectoryToHdfs('./{}'.format(directory_name), '/{0}/{1}'.format(self.hdfs_config['userName'], directory_name), self.hdfs_client) result = copyDirectoryToHdfs('./{}'.format(directory_name),
'/{0}/{1}'.format(self.hdfs_config['userName'], directory_name), self.hdfs_client)
self.assertTrue(result) self.assertTrue(result)
directory_list = self.hdfs_client.listdir('/{0}'.format(self.hdfs_config['userName'])) directory_list = self.hdfs_client.listdir('/{0}'.format(self.hdfs_config['userName']))
......
...@@ -18,32 +18,30 @@ ...@@ -18,32 +18,30 @@
# ============================================================================================================================== # # ============================================================================================================================== #
import argparse import argparse
import sys
import os import os
from subprocess import Popen, PIPE from subprocess import Popen
import time import time
import logging import logging
import shlex import shlex
import re import re
import sys import sys
import select
import json import json
import threading import threading
from pyhdfs import HdfsClient from pyhdfs import HdfsClient
import pkg_resources import pkg_resources
from .rest_utils import rest_post, rest_get from .rest_utils import rest_post, rest_get
from .url_utils import gen_send_stdout_url, gen_send_version_url, gen_parameter_meta_url from .url_utils import gen_send_version_url, gen_parameter_meta_url
from .constants import HOME_DIR, LOG_DIR, NNI_PLATFORM, STDOUT_FULL_PATH, STDERR_FULL_PATH, \ from .constants import LOG_DIR, NNI_PLATFORM, MULTI_PHASE, NNI_TRIAL_JOB_ID, NNI_SYS_DIR, NNI_EXP_ID
MULTI_PHASE, NNI_TRIAL_JOB_ID, NNI_SYS_DIR, NNI_EXP_ID
from .hdfsClientUtility import copyDirectoryToHdfs, copyHdfsDirectoryToLocal, copyHdfsFileToLocal from .hdfsClientUtility import copyDirectoryToHdfs, copyHdfsDirectoryToLocal, copyHdfsFileToLocal
from .log_utils import LogType, nni_log, RemoteLogger, PipeLogReader, StdOutputType from .log_utils import LogType, nni_log, RemoteLogger, StdOutputType
logger = logging.getLogger('trial_keeper') logger = logging.getLogger('trial_keeper')
regular = re.compile('v?(?P<version>[0-9](\.[0-9]){0,1}).*') regular = re.compile('v?(?P<version>[0-9](\.[0-9]){0,1}).*')
_hdfs_client = None _hdfs_client = None
def get_hdfs_client(args): def get_hdfs_client(args):
global _hdfs_client global _hdfs_client
...@@ -62,26 +60,29 @@ def get_hdfs_client(args): ...@@ -62,26 +60,29 @@ def get_hdfs_client(args):
if hdfs_host is not None and args.nni_hdfs_exp_dir is not None: if hdfs_host is not None and args.nni_hdfs_exp_dir is not None:
try: try:
if args.webhdfs_path: if args.webhdfs_path:
_hdfs_client = HdfsClient(hosts='{0}:80'.format(hdfs_host), user_name=args.pai_user_name, webhdfs_path=args.webhdfs_path, timeout=5) _hdfs_client = HdfsClient(hosts='{0}:80'.format(hdfs_host), user_name=args.pai_user_name,
webhdfs_path=args.webhdfs_path, timeout=5)
else: else:
# backward compatibility # backward compatibility
_hdfs_client = HdfsClient(hosts='{0}:{1}'.format(hdfs_host, '50070'), user_name=args.pai_user_name, timeout=5) _hdfs_client = HdfsClient(hosts='{0}:{1}'.format(hdfs_host, '50070'), user_name=args.pai_user_name,
timeout=5)
except Exception as e: except Exception as e:
nni_log(LogType.Error, 'Create HDFS client error: ' + str(e)) nni_log(LogType.Error, 'Create HDFS client error: ' + str(e))
raise e raise e
return _hdfs_client return _hdfs_client
def main_loop(args): def main_loop(args):
'''main loop logic for trial keeper''' '''main loop logic for trial keeper'''
if not os.path.exists(LOG_DIR): if not os.path.exists(LOG_DIR):
os.makedirs(LOG_DIR) os.makedirs(LOG_DIR)
stdout_file = open(STDOUT_FULL_PATH, 'a+') trial_keeper_syslogger = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'trial_keeper',
stderr_file = open(STDERR_FULL_PATH, 'a+') StdOutputType.Stdout, args.log_collection)
trial_keeper_syslogger = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'trial_keeper', StdOutputType.Stdout, args.log_collection)
# redirect trial keeper's stdout and stderr to syslog # redirect trial keeper's stdout and stderr to syslog
trial_syslogger_stdout = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'trial', StdOutputType.Stdout, args.log_collection) trial_syslogger_stdout = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'trial', StdOutputType.Stdout,
args.log_collection)
sys.stdout = sys.stderr = trial_keeper_syslogger sys.stdout = sys.stderr = trial_keeper_syslogger
hdfs_output_dir = None hdfs_output_dir = None
...@@ -97,8 +98,10 @@ def main_loop(args): ...@@ -97,8 +98,10 @@ def main_loop(args):
# Notice: We don't appoint env, which means subprocess wil inherit current environment and that is expected behavior # Notice: We don't appoint env, which means subprocess wil inherit current environment and that is expected behavior
log_pipe_stdout = trial_syslogger_stdout.get_pipelog_reader() log_pipe_stdout = trial_syslogger_stdout.get_pipelog_reader()
process = Popen(args.trial_command, shell = True, stdout = log_pipe_stdout, stderr = log_pipe_stdout) process = Popen(args.trial_command, shell=True, stdout=log_pipe_stdout, stderr=log_pipe_stdout)
nni_log(LogType.Info, 'Trial keeper spawns a subprocess (pid {0}) to run command: {1}'.format(process.pid, shlex.split(args.trial_command))) nni_log(LogType.Info, 'Trial keeper spawns a subprocess (pid {0}) to run command: {1}'.format(process.pid,
shlex.split(
args.trial_command)))
while True: while True:
retCode = process.poll() retCode = process.poll()
...@@ -110,9 +113,11 @@ def main_loop(args): ...@@ -110,9 +113,11 @@ def main_loop(args):
nni_local_output_dir = os.environ['NNI_OUTPUT_DIR'] nni_local_output_dir = os.environ['NNI_OUTPUT_DIR']
try: try:
if copyDirectoryToHdfs(nni_local_output_dir, hdfs_output_dir, hdfs_client): if copyDirectoryToHdfs(nni_local_output_dir, hdfs_output_dir, hdfs_client):
nni_log(LogType.Info, 'copy directory from {0} to {1} success!'.format(nni_local_output_dir, hdfs_output_dir)) nni_log(LogType.Info,
'copy directory from {0} to {1} success!'.format(nni_local_output_dir, hdfs_output_dir))
else: else:
nni_log(LogType.Info, 'copy directory from {0} to {1} failed!'.format(nni_local_output_dir, hdfs_output_dir)) nni_log(LogType.Info,
'copy directory from {0} to {1} failed!'.format(nni_local_output_dir, hdfs_output_dir))
except Exception as e: except Exception as e:
nni_log(LogType.Error, 'HDFS copy directory got exception: ' + str(e)) nni_log(LogType.Error, 'HDFS copy directory got exception: ' + str(e))
raise e raise e
...@@ -123,14 +128,16 @@ def main_loop(args): ...@@ -123,14 +128,16 @@ def main_loop(args):
time.sleep(2) time.sleep(2)
def trial_keeper_help_info(*args): def trial_keeper_help_info(*args):
print('please run --help to see guidance') print('please run --help to see guidance')
def check_version(args): def check_version(args):
try: try:
trial_keeper_version = pkg_resources.get_distribution('nni').version trial_keeper_version = pkg_resources.get_distribution('nni').version
except pkg_resources.ResolutionError as err: except pkg_resources.ResolutionError as err:
#package nni does not exist, try nni-tool package # package nni does not exist, try nni-tool package
nni_log(LogType.Error, 'Package nni does not exist!') nni_log(LogType.Error, 'Package nni does not exist!')
os._exit(1) os._exit(1)
if not args.nni_manager_version: if not args.nni_manager_version:
...@@ -145,21 +152,26 @@ def check_version(args): ...@@ -145,21 +152,26 @@ def check_version(args):
log_entry = {} log_entry = {}
if trial_keeper_version != nni_manager_version: if trial_keeper_version != nni_manager_version:
nni_log(LogType.Error, 'Version does not match!') nni_log(LogType.Error, 'Version does not match!')
error_message = 'NNIManager version is {0}, TrialKeeper version is {1}, NNI version does not match!'.format(nni_manager_version, trial_keeper_version) error_message = 'NNIManager version is {0}, TrialKeeper version is {1}, NNI version does not match!'.format(
nni_manager_version, trial_keeper_version)
log_entry['tag'] = 'VCFail' log_entry['tag'] = 'VCFail'
log_entry['msg'] = error_message log_entry['msg'] = error_message
rest_post(gen_send_version_url(args.nnimanager_ip, args.nnimanager_port), json.dumps(log_entry), 10, False) rest_post(gen_send_version_url(args.nnimanager_ip, args.nnimanager_port), json.dumps(log_entry), 10,
False)
os._exit(1) os._exit(1)
else: else:
nni_log(LogType.Info, 'Version match!') nni_log(LogType.Info, 'Version match!')
log_entry['tag'] = 'VCSuccess' log_entry['tag'] = 'VCSuccess'
rest_post(gen_send_version_url(args.nnimanager_ip, args.nnimanager_port), json.dumps(log_entry), 10, False) rest_post(gen_send_version_url(args.nnimanager_ip, args.nnimanager_port), json.dumps(log_entry), 10,
False)
except AttributeError as err: except AttributeError as err:
nni_log(LogType.Error, err) nni_log(LogType.Error, err)
def is_multi_phase(): def is_multi_phase():
return MULTI_PHASE and (MULTI_PHASE in ['True', 'true']) return MULTI_PHASE and (MULTI_PHASE in ['True', 'true'])
def download_parameter(meta_list, args): def download_parameter(meta_list, args):
""" """
Download parameter file to local working directory. Download parameter file to local working directory.
...@@ -171,7 +183,8 @@ def download_parameter(meta_list, args): ...@@ -171,7 +183,8 @@ def download_parameter(meta_list, args):
] ]
""" """
nni_log(LogType.Debug, str(meta_list)) nni_log(LogType.Debug, str(meta_list))
nni_log(LogType.Debug, 'NNI_SYS_DIR: {}, trial Id: {}, experiment ID: {}'.format(NNI_SYS_DIR, NNI_TRIAL_JOB_ID, NNI_EXP_ID)) nni_log(LogType.Debug,
'NNI_SYS_DIR: {}, trial Id: {}, experiment ID: {}'.format(NNI_SYS_DIR, NNI_TRIAL_JOB_ID, NNI_EXP_ID))
nni_log(LogType.Debug, 'NNI_SYS_DIR files: {}'.format(os.listdir(NNI_SYS_DIR))) nni_log(LogType.Debug, 'NNI_SYS_DIR files: {}'.format(os.listdir(NNI_SYS_DIR)))
for meta in meta_list: for meta in meta_list:
if meta['experimentId'] == NNI_EXP_ID and meta['trialId'] == NNI_TRIAL_JOB_ID: if meta['experimentId'] == NNI_EXP_ID and meta['trialId'] == NNI_TRIAL_JOB_ID:
...@@ -180,6 +193,7 @@ def download_parameter(meta_list, args): ...@@ -180,6 +193,7 @@ def download_parameter(meta_list, args):
hdfs_client = get_hdfs_client(args) hdfs_client = get_hdfs_client(args)
copyHdfsFileToLocal(meta['filePath'], param_fp, hdfs_client, override=False) copyHdfsFileToLocal(meta['filePath'], param_fp, hdfs_client, override=False)
def fetch_parameter_file(args): def fetch_parameter_file(args):
class FetchThread(threading.Thread): class FetchThread(threading.Thread):
def __init__(self, args): def __init__(self, args):
...@@ -203,6 +217,7 @@ def fetch_parameter_file(args): ...@@ -203,6 +217,7 @@ def fetch_parameter_file(args):
fetch_file_thread = FetchThread(args) fetch_file_thread = FetchThread(args)
fetch_file_thread.start() fetch_file_thread.start()
if __name__ == '__main__': if __name__ == '__main__':
'''NNI Trial Keeper main function''' '''NNI Trial Keeper main function'''
PARSER = argparse.ArgumentParser() PARSER = argparse.ArgumentParser()
...@@ -210,9 +225,9 @@ if __name__ == '__main__': ...@@ -210,9 +225,9 @@ if __name__ == '__main__':
PARSER.add_argument('--trial_command', type=str, help='Command to launch trial process') PARSER.add_argument('--trial_command', type=str, help='Command to launch trial process')
PARSER.add_argument('--nnimanager_ip', type=str, default='localhost', help='NNI manager rest server IP') PARSER.add_argument('--nnimanager_ip', type=str, default='localhost', help='NNI manager rest server IP')
PARSER.add_argument('--nnimanager_port', type=str, default='8081', help='NNI manager rest server port') PARSER.add_argument('--nnimanager_port', type=str, default='8081', help='NNI manager rest server port')
PARSER.add_argument('--pai_hdfs_output_dir', type=str, help='the output dir of pai_hdfs') # backward compatibility PARSER.add_argument('--pai_hdfs_output_dir', type=str, help='the output dir of pai_hdfs') # backward compatibility
PARSER.add_argument('--hdfs_output_dir', type=str, help='the output dir of hdfs') PARSER.add_argument('--hdfs_output_dir', type=str, help='the output dir of hdfs')
PARSER.add_argument('--pai_hdfs_host', type=str, help='the host of pai_hdfs') # backward compatibility PARSER.add_argument('--pai_hdfs_host', type=str, help='the host of pai_hdfs') # backward compatibility
PARSER.add_argument('--hdfs_host', type=str, help='the host of hdfs') PARSER.add_argument('--hdfs_host', type=str, help='the host of hdfs')
PARSER.add_argument('--pai_user_name', type=str, help='the username of hdfs') PARSER.add_argument('--pai_user_name', type=str, help='the username of hdfs')
PARSER.add_argument('--nni_hdfs_exp_dir', type=str, help='nni experiment directory in hdfs') PARSER.add_argument('--nni_hdfs_exp_dir', type=str, help='nni experiment directory in hdfs')
...@@ -233,4 +248,3 @@ if __name__ == '__main__': ...@@ -233,4 +248,3 @@ if __name__ == '__main__':
except Exception as e: except Exception as e:
nni_log(LogType.Error, 'Exit trial keeper with code 1 because Exception: {} is catched'.format(str(e))) nni_log(LogType.Error, 'Exit trial keeper with code 1 because Exception: {} is catched'.format(str(e)))
os._exit(1) os._exit(1)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment