Unverified Commit c037a7c1 authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Merge pull request #213 from microsoft/master

merge master
parents 49972952 901012eb
...@@ -101,20 +101,20 @@ class CompressorTestCase(TestCase): ...@@ -101,20 +101,20 @@ class CompressorTestCase(TestCase):
def test_tf_pruner(self): def test_tf_pruner(self):
model = TfMnist() model = TfMnist()
configure_list = [{'sparsity': 0.8, 'op_types': ['default']}] configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
tf_compressor.LevelPruner(configure_list).compress_default_graph() tf_compressor.LevelPruner(tf.get_default_graph(), configure_list).compress()
def test_tf_quantizer(self): def test_tf_quantizer(self):
model = TfMnist() model = TfMnist()
tf_compressor.NaiveQuantizer([{'op_types': ['default']}]).compress_default_graph() tf_compressor.NaiveQuantizer(tf.get_default_graph(), [{'op_types': ['default']}]).compress()
def test_torch_pruner(self): def test_torch_pruner(self):
model = TorchMnist() model = TorchMnist()
configure_list = [{'sparsity': 0.8, 'op_types': ['default']}] configure_list = [{'sparsity': 0.8, 'op_types': ['default']}]
torch_compressor.LevelPruner(configure_list).compress(model) torch_compressor.LevelPruner(model, configure_list).compress()
def test_torch_quantizer(self): def test_torch_quantizer(self):
model = TorchMnist() model = TorchMnist()
torch_compressor.NaiveQuantizer([{'op_types': ['default']}]).compress(model) torch_compressor.NaiveQuantizer(model, [{'op_types': ['default']}]).compress()
if __name__ == '__main__': if __name__ == '__main__':
......
# Copyright (c) Microsoft Corporation. All rights reserved.
#
# MIT License
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
# associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
# NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ==================================================================================================
import json
from io import BytesIO
from unittest import TestCase, main
import nni.protocol
from nni.msg_dispatcher import MsgDispatcher
from nni.protocol import CommandType, send, receive
from nni.tuner import Tuner
from nni.utils import extract_scalar_reward
class NaiveTuner(Tuner):
    """Minimal tuner used to exercise MsgDispatcher in unit tests.

    Every generated parameter set embeds the tuner's internal state
    (counter, received results, current search space), so the test can
    verify dispatcher behavior purely from the outgoing command stream.
    """

    def __init__(self):
        self.param = 0
        self.trial_results = []
        self.search_space = None
        self._accept_customized_trials()

    def generate_parameters(self, parameter_id, **kwargs):
        # Report internal state inside the generated parameters so the
        # test's main loop never has to pause and introspect the tuner.
        self.param += 2
        snapshot = {
            'param': self.param,
            'trial_results': self.trial_results,
            'search_space': self.search_space,
        }
        return snapshot

    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
        entry = (
            parameter_id,
            parameters['param'],
            extract_scalar_reward(value),
            kwargs.get("customized"),
        )
        self.trial_results.append(entry)

    def update_search_space(self, search_space):
        self.search_space = search_space
# Paired in-memory streams standing in for the dispatcher's pipes.
_in_buf = BytesIO()
_out_buf = BytesIO()


def _reverse_io():
    """Rewind both buffers and cross-wire them so the test writes into
    the tuner's incoming stream and reads from its outgoing stream."""
    for stream in (_in_buf, _out_buf):
        stream.seek(0)
    nni.protocol._out_file = _in_buf
    nni.protocol._in_file = _out_buf


def _restore_io():
    """Rewind both buffers and wire them the normal way around."""
    for stream in (_in_buf, _out_buf):
        stream.seek(0)
    nni.protocol._in_file = _in_buf
    nni.protocol._out_file = _out_buf
class MsgDispatcherTestCase(TestCase):
    """End-to-end test of MsgDispatcher driven through in-memory streams."""

    def test_msg_dispatcher(self):
        # Feed a scripted command sequence into the tuner's incoming stream.
        _reverse_io()  # now we are sending to Tuner's incoming stream
        send(CommandType.RequestTrialJobs, '2')
        send(CommandType.ReportMetricData, '{"parameter_id":0,"type":"PERIODICAL","value":10}')
        send(CommandType.ReportMetricData, '{"parameter_id":1,"type":"FINAL","value":11}')
        send(CommandType.UpdateSearchSpace, '{"name":"SS0"}')
        send(CommandType.AddCustomizedTrialJob, '{"param":-1}')
        send(CommandType.ReportMetricData, '{"parameter_id":2,"type":"FINAL","value":22}')
        send(CommandType.RequestTrialJobs, '1')
        # KillTrialJob is intentionally unsupported by the dispatcher and
        # should terminate the run with an AssertionError (checked below).
        send(CommandType.KillTrialJob, 'null')
        _restore_io()

        tuner = NaiveTuner()
        dispatcher = MsgDispatcher(tuner)
        # Keep the worker thread from calling os._exit on failure so the
        # exception is captured in dispatcher.worker_exceptions instead.
        nni.msg_dispatcher_base._worker_fast_exit_on_terminate = False

        dispatcher.run()
        e = dispatcher.worker_exceptions[0]
        self.assertIs(type(e), AssertionError)
        self.assertEqual(e.args[0], 'Unsupported command: CommandType.KillTrialJob')

        # Replay the dispatcher's responses and verify them in order.
        _reverse_io()  # now we are receiving from Tuner's outgoing stream

        # Two NewTrialJob replies for the first RequestTrialJobs '2'.
        self._assert_params(0, 2, [], None)
        self._assert_params(1, 4, [], None)

        command, data = receive()  # this one is customized
        data = json.loads(data)
        self.assertIs(command, CommandType.NewTrialJob)
        self.assertEqual(data['parameter_id'], 2)
        self.assertEqual(data['parameter_source'], 'customized')
        self.assertEqual(data['parameters'], {'param': -1})

        # Reply to the second RequestTrialJobs: by now the tuner has seen
        # both FINAL results (the customized one flagged True) and SS0.
        self._assert_params(3, 6, [[1, 4, 11, False], [2, -1, 22, True]], {'name': 'SS0'})

        self.assertEqual(len(_out_buf.read()), 0)  # no more commands

    def _assert_params(self, parameter_id, param, trial_results, search_space):
        """Receive one NewTrialJob command and check its embedded state."""
        command, data = receive()
        self.assertIs(command, CommandType.NewTrialJob)
        data = json.loads(data)
        self.assertEqual(data['parameter_id'], parameter_id)
        self.assertEqual(data['parameter_source'], 'algorithm')
        self.assertEqual(data['parameters']['param'], param)
        self.assertEqual(data['parameters']['trial_results'], trial_results)
        self.assertEqual(data['parameters']['search_space'], search_space)
if __name__ == '__main__':
main()
...@@ -17,107 +17,184 @@ ...@@ -17,107 +17,184 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT
# OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
# ================================================================================================== # ==================================================================================================
import glob
import nni.protocol
from nni.protocol import CommandType, send, receive
from nni.tuner import Tuner
from nni.msg_dispatcher import MsgDispatcher
from nni.utils import extract_scalar_reward
from io import BytesIO
import json import json
import logging
import os
import shutil
import sys
from unittest import TestCase, main from unittest import TestCase, main
from nni.batch_tuner.batch_tuner import BatchTuner
from nni.evolution_tuner.evolution_tuner import EvolutionTuner
from nni.gp_tuner.gp_tuner import GPTuner
from nni.gridsearch_tuner.gridsearch_tuner import GridSearchTuner
from nni.hyperopt_tuner.hyperopt_tuner import HyperoptTuner
from nni.metis_tuner.metis_tuner import MetisTuner
try:
from nni.smac_tuner.smac_tuner import SMACTuner
except ImportError:
assert sys.platform == "win32"
from nni.tuner import Tuner
class NaiveTuner(Tuner): logging.basicConfig(level=logging.INFO)
def __init__(self): logger = logging.getLogger('test_tuner')
self.param = 0
self.trial_results = []
self.search_space = None
self._accept_customized_trials()
def generate_parameters(self, parameter_id, **kwargs):
# report Tuner's internal states to generated parameters,
# so we don't need to pause the main loop
self.param += 2
return {
'param': self.param,
'trial_results': self.trial_results,
'search_space': self.search_space
}
def receive_trial_result(self, parameter_id, parameters, value, customized, **kwargs):
reward = extract_scalar_reward(value)
self.trial_results.append((parameter_id, parameters['param'], reward, customized))
def update_search_space(self, search_space):
self.search_space = search_space
_in_buf = BytesIO()
_out_buf = BytesIO()
def _reverse_io():
_in_buf.seek(0)
_out_buf.seek(0)
nni.protocol._out_file = _in_buf
nni.protocol._in_file = _out_buf
def _restore_io():
_in_buf.seek(0)
_out_buf.seek(0)
nni.protocol._in_file = _in_buf
nni.protocol._out_file = _out_buf
class TunerTestCase(TestCase): class TunerTestCase(TestCase):
def test_tuner(self): """
_reverse_io() # now we are sending to Tuner's incoming stream Targeted at testing functions of built-in tuners, including
send(CommandType.RequestTrialJobs, '2') - [ ] load_checkpoint
send(CommandType.ReportMetricData, '{"parameter_id":0,"type":"PERIODICAL","value":10}') - [ ] save_checkpoint
send(CommandType.ReportMetricData, '{"parameter_id":1,"type":"FINAL","value":11}') - [X] update_search_space
send(CommandType.UpdateSearchSpace, '{"name":"SS0"}') - [X] generate_multiple_parameters
send(CommandType.AddCustomizedTrialJob, '{"param":-1}') - [ ] import_data
send(CommandType.ReportMetricData, '{"parameter_id":2,"type":"FINAL","value":22}') - [ ] trial_end
send(CommandType.RequestTrialJobs, '1') - [ ] receive_trial_result
send(CommandType.KillTrialJob, 'null') """
_restore_io()
def search_space_test_one(self, tuner_factory, search_space):
tuner = NaiveTuner() tuner = tuner_factory()
dispatcher = MsgDispatcher(tuner) self.assertIsInstance(tuner, Tuner)
nni.msg_dispatcher_base._worker_fast_exit_on_terminate = False tuner.update_search_space(search_space)
dispatcher.run() parameters = tuner.generate_multiple_parameters(list(range(0, 50)))
e = dispatcher.worker_exceptions[0] logger.info(parameters)
self.assertIs(type(e), AssertionError) self.check_range(parameters, search_space)
self.assertEqual(e.args[0], 'Unsupported command: CommandType.KillTrialJob') if not parameters: # TODO: not strict
raise ValueError("No parameters generated")
_reverse_io() # now we are receiving from Tuner's outgoing stream return parameters
self._assert_params(0, 2, [], None)
self._assert_params(1, 4, [], None) def check_range(self, generated_params, search_space):
EPS = 1E-6
command, data = receive() # this one is customized for param in generated_params:
data = json.loads(data) if self._testMethodName == "test_batch":
self.assertIs(command, CommandType.NewTrialJob) param = {list(search_space.keys())[0]: param}
self.assertEqual(data['parameter_id'], 2) for k, v in param.items():
self.assertEqual(data['parameter_source'], 'customized') if k.startswith("_mutable_layer"):
self.assertEqual(data['parameters'], {'param': -1}) _, block, layer, choice = k.split("/")
cand = search_space[block]["_value"][layer].get(choice)
self._assert_params(3, 6, [[1, 4, 11, False], [2, -1, 22, True]], {'name': 'SS0'}) # cand could be None, e.g., optional_inputs_chosen_state
if choice == "layer_choice":
self.assertEqual(len(_out_buf.read()), 0) # no more commands self.assertIn(v, cand)
if choice == "optional_input_size":
def _assert_params(self, parameter_id, param, trial_results, search_space): if isinstance(cand, int):
command, data = receive() self.assertEqual(v, cand)
self.assertIs(command, CommandType.NewTrialJob) else:
data = json.loads(data) self.assertGreaterEqual(v, cand[0])
self.assertEqual(data['parameter_id'], parameter_id) self.assertLessEqual(v, cand[1])
self.assertEqual(data['parameter_source'], 'algorithm') if choice == "optional_inputs":
self.assertEqual(data['parameters']['param'], param) pass # ignore for now
self.assertEqual(data['parameters']['trial_results'], trial_results) continue
self.assertEqual(data['parameters']['search_space'], search_space) item = search_space[k]
if item["_type"] == "choice":
self.assertIn(v, item["_value"])
if item["_type"] == "randint":
self.assertIsInstance(v, int)
if item["_type"] == "uniform":
self.assertIsInstance(v, float)
if item["_type"] in ("randint", "uniform", "quniform", "loguniform", "qloguniform"):
self.assertGreaterEqual(v, item["_value"][0])
self.assertLessEqual(v, item["_value"][1])
if item["_type"].startswith("q"):
multiple = v / item["_value"][2]
print(k, v, multiple, item)
if item["_value"][0] + EPS < v < item["_value"][1] - EPS:
self.assertAlmostEqual(int(round(multiple)), multiple)
if item["_type"] in ("qlognormal", "lognormal"):
self.assertGreaterEqual(v, 0)
if item["_type"] == "mutable_layer":
for layer_name in item["_value"].keys():
self.assertIn(v[layer_name]["chosen_layer"], item["layer_choice"])
def search_space_test_all(self, tuner_factory, supported_types=None, ignore_types=None):
# NOTE(yuge): ignore types
# Supported types are listed in the table. They are meant to be supported and should be correct.
# Other than those, all the rest are "unsupported", which are expected to produce ridiculous results
# or throw some exceptions. However, there are certain types I can't check. For example, generate
# "normal" using GP Tuner returns successfully and results are fine if we check the range (-inf to +inf),
# but they make no sense: it's not a normal distribution. So they are ignored in tests for now.
with open(os.path.join(os.path.dirname(__file__), "assets/search_space.json"), "r") as fp:
search_space_all = json.load(fp)
if supported_types is None:
supported_types = ["choice", "randint", "uniform", "quniform", "loguniform", "qloguniform",
"normal", "qnormal", "lognormal", "qlognormal"]
full_supported_search_space = dict()
for single in search_space_all:
single_keyword = single.split("_")
space = search_space_all[single]
expected_fail = not any([t in single_keyword for t in supported_types]) or "fail" in single_keyword
if ignore_types is not None and any([t in ignore_types for t in single_keyword]):
continue
if "fail" in space:
if self._testMethodName.split("_", 1)[1] in space.pop("fail"):
expected_fail = True
single_search_space = {single: space}
if not expected_fail:
# supports this key
self.search_space_test_one(tuner_factory, single_search_space)
full_supported_search_space.update(single_search_space)
else:
# unsupported key
with self.assertRaises(Exception, msg="Testing {}".format(single)) as cm:
self.search_space_test_one(tuner_factory, single_search_space)
logger.info("%s %s %s", tuner_factory, single, cm.exception)
if not any(t in self._testMethodName for t in ["batch", "grid_search"]):
# grid search fails for too many combinations
logger.info("Full supported search space: %s", full_supported_search_space)
self.search_space_test_one(tuner_factory, full_supported_search_space)
def test_grid_search(self):
self.search_space_test_all(lambda: GridSearchTuner(),
supported_types=["choice", "randint", "quniform"])
def test_tpe(self):
self.search_space_test_all(lambda: HyperoptTuner("tpe"))
def test_random_search(self):
self.search_space_test_all(lambda: HyperoptTuner("random_search"))
def test_anneal(self):
self.search_space_test_all(lambda: HyperoptTuner("anneal"))
def test_smac(self):
if sys.platform == "win32":
return # smac doesn't work on windows
self.search_space_test_all(lambda: SMACTuner(),
supported_types=["choice", "randint", "uniform", "quniform", "loguniform"])
def test_batch(self):
self.search_space_test_all(lambda: BatchTuner(),
supported_types=["choice"])
def test_evolution(self):
# Needs enough population size, otherwise it will throw a runtime error
self.search_space_test_all(lambda: EvolutionTuner(population_size=100))
def test_gp(self):
self.search_space_test_all(lambda: GPTuner(),
supported_types=["choice", "randint", "uniform", "quniform", "loguniform",
"qloguniform"],
ignore_types=["normal", "lognormal", "qnormal", "qlognormal"])
def test_metis(self):
self.search_space_test_all(lambda: MetisTuner(),
supported_types=["choice", "randint", "uniform", "quniform"])
def test_networkmorphism(self):
pass
def test_ppo(self):
pass
def tearDown(self):
file_list = glob.glob("smac3*") + ["param_config_space.pcs", "scenario.txt", "model_path"]
for file in file_list:
if os.path.exists(file):
if os.path.isdir(file):
shutil.rmtree(file)
else:
os.remove(file)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -24,9 +24,10 @@ import glob ...@@ -24,9 +24,10 @@ import glob
import subprocess import subprocess
import time import time
import traceback import traceback
import json
from utils import setup_experiment, get_experiment_status, get_yml_content, dump_yml_content, \ from utils import setup_experiment, get_experiment_status, get_yml_content, dump_yml_content, \
parse_max_duration_time, get_succeeded_trial_num, print_stderr, deep_update parse_max_duration_time, get_succeeded_trial_num, deep_update, print_failed_job_log, get_failed_trial_jobs
from utils import GREEN, RED, CLEAR, STATUS_URL, TRIAL_JOBS_URL from utils import GREEN, RED, CLEAR, STATUS_URL, TRIAL_JOBS_URL
def gen_new_config(config_file, training_service='local'): def gen_new_config(config_file, training_service='local'):
...@@ -37,18 +38,18 @@ def gen_new_config(config_file, training_service='local'): ...@@ -37,18 +38,18 @@ def gen_new_config(config_file, training_service='local'):
config = get_yml_content(config_file) config = get_yml_content(config_file)
new_config_file = config_file + '.tmp' new_config_file = config_file + '.tmp'
ts = get_yml_content('training_service.yml')[training_service] it_config = get_yml_content('training_service.yml')
print(ts)
# hack for kubeflow trial config # hack for kubeflow trial config
if training_service == 'kubeflow': if training_service == 'kubeflow':
ts['trial']['worker']['command'] = config['trial']['command'] it_config[training_service]['trial']['worker']['command'] = config['trial']['command']
config['trial'].pop('command') config['trial'].pop('command')
if 'gpuNum' in config['trial']: if 'gpuNum' in config['trial']:
config['trial'].pop('gpuNum') config['trial'].pop('gpuNum')
deep_update(config, ts) deep_update(config, it_config['all'])
print(config) deep_update(config, it_config[training_service])
dump_yml_content(new_config_file, config) dump_yml_content(new_config_file, config)
return new_config_file, config return new_config_file, config
...@@ -57,6 +58,7 @@ def run_test(config_file, training_service, local_gpu=False): ...@@ -57,6 +58,7 @@ def run_test(config_file, training_service, local_gpu=False):
'''run test per configuration file''' '''run test per configuration file'''
new_config_file, config = gen_new_config(config_file, training_service) new_config_file, config = gen_new_config(config_file, training_service)
print(json.dumps(config, sort_keys=True, indent=4))
if training_service == 'local' and not local_gpu and config['trial']['gpuNum'] > 0: if training_service == 'local' and not local_gpu and config['trial']['gpuNum'] > 0:
print('no gpu, skiping: ', config_file) print('no gpu, skiping: ', config_file)
...@@ -72,14 +74,12 @@ def run_test(config_file, training_service, local_gpu=False): ...@@ -72,14 +74,12 @@ def run_test(config_file, training_service, local_gpu=False):
for _ in range(0, max_duration+30, sleep_interval): for _ in range(0, max_duration+30, sleep_interval):
time.sleep(sleep_interval) time.sleep(sleep_interval)
status = get_experiment_status(STATUS_URL) status = get_experiment_status(STATUS_URL)
if status == 'DONE': if status in ['DONE', 'ERROR'] or get_failed_trial_jobs(TRIAL_JOBS_URL):
num_succeeded = get_succeeded_trial_num(TRIAL_JOBS_URL)
if training_service == 'local':
print_stderr(TRIAL_JOBS_URL)
assert num_succeeded == max_trial_num, 'only %d succeeded trial jobs, there should be %d' % (num_succeeded, max_trial_num)
break break
assert status == 'DONE', 'Failed to finish in maxExecDuration' print_failed_job_log(config['trainingServicePlatform'], TRIAL_JOBS_URL)
if status != 'DONE' or get_succeeded_trial_num(TRIAL_JOBS_URL) < max_trial_num:
raise AssertionError('Failed to finish in maxExecDuration')
finally: finally:
if os.path.exists(new_config_file): if os.path.exists(new_config_file):
os.remove(new_config_file) os.remove(new_config_file)
......
...@@ -26,7 +26,7 @@ import traceback ...@@ -26,7 +26,7 @@ import traceback
import json import json
import requests import requests
from utils import get_experiment_status, get_yml_content, parse_max_duration_time, get_succeeded_trial_num, print_stderr from utils import get_experiment_status, get_yml_content, parse_max_duration_time, get_succeeded_trial_num, print_failed_job_log
from utils import GREEN, RED, CLEAR, STATUS_URL, TRIAL_JOBS_URL, METRICS_URL from utils import GREEN, RED, CLEAR, STATUS_URL, TRIAL_JOBS_URL, METRICS_URL
def run_test(): def run_test():
...@@ -49,7 +49,9 @@ def run_test(): ...@@ -49,7 +49,9 @@ def run_test():
#print('experiment status:', status) #print('experiment status:', status)
if status == 'DONE': if status == 'DONE':
num_succeeded = get_succeeded_trial_num(TRIAL_JOBS_URL) num_succeeded = get_succeeded_trial_num(TRIAL_JOBS_URL)
print_stderr(TRIAL_JOBS_URL) print_failed_job_log('local', TRIAL_JOBS_URL)
if sys.platform == "win32":
time.sleep(sleep_interval) # Windows seems to have some issues on updating in time
assert num_succeeded == max_trial_num, 'only %d succeeded trial jobs, there should be %d' % (num_succeeded, max_trial_num) assert num_succeeded == max_trial_num, 'only %d succeeded trial jobs, there should be %d' % (num_succeeded, max_trial_num)
check_metrics() check_metrics()
break break
......
all:
logCollection: http
kubeflow: kubeflow:
maxExecDuration: 15m maxExecDuration: 15m
nniManagerIp: nniManagerIp:
......
...@@ -20,9 +20,7 @@ echo "===========================Testing: nni_sdk===========================" ...@@ -20,9 +20,7 @@ echo "===========================Testing: nni_sdk==========================="
cd ${CWD}/../src/sdk/pynni/ cd ${CWD}/../src/sdk/pynni/
python3 -m unittest discover -v tests python3 -m unittest discover -v tests
# -------------For typescript unittest-------------
# -------------For typescrip unittest-------------
cd ${CWD}/../src/nni_manager cd ${CWD}/../src/nni_manager
echo "" echo ""
echo "===========================Testing: nni_manager===========================" echo "===========================Testing: nni_manager==========================="
......
...@@ -81,13 +81,18 @@ def get_experiment_id(experiment_url): ...@@ -81,13 +81,18 @@ def get_experiment_id(experiment_url):
experiment_id = requests.get(experiment_url).json()['id'] experiment_id = requests.get(experiment_url).json()['id']
return experiment_id return experiment_id
def get_nni_log_path(experiment_url): def get_experiment_dir(experiment_url):
'''get nni's log path from nni's experiment url''' '''get experiment root directory'''
experiment_id = get_experiment_id(experiment_url) experiment_id = get_experiment_id(experiment_url)
experiment_path = os.path.join(os.path.expanduser('~'), 'nni', 'experiments', experiment_id) return os.path.join(os.path.expanduser('~'), 'nni', 'experiments', experiment_id)
nnimanager_log_path = os.path.join(experiment_path, 'log', 'nnimanager.log')
return nnimanager_log_path def get_nni_log_dir(experiment_url):
'''get nni's log directory from nni's experiment url'''
return os.path.join(get_experiment_dir(experiment_url), 'log')
def get_nni_log_path(experiment_url):
'''get nni's log path from nni's experiment url'''
return os.path.join(get_nni_log_dir(experiment_url), 'nnimanager.log')
def is_experiment_done(nnimanager_log_path): def is_experiment_done(nnimanager_log_path):
'''check if the experiment is done successfully''' '''check if the experiment is done successfully'''
...@@ -104,7 +109,6 @@ def get_experiment_status(status_url): ...@@ -104,7 +109,6 @@ def get_experiment_status(status_url):
def get_succeeded_trial_num(trial_jobs_url): def get_succeeded_trial_num(trial_jobs_url):
trial_jobs = requests.get(trial_jobs_url).json() trial_jobs = requests.get(trial_jobs_url).json()
print(trial_jobs)
num_succeed = 0 num_succeed = 0
for trial_job in trial_jobs: for trial_job in trial_jobs:
if trial_job['status'] in ['SUCCEEDED', 'EARLY_STOPPED']: if trial_job['status'] in ['SUCCEEDED', 'EARLY_STOPPED']:
...@@ -112,17 +116,31 @@ def get_succeeded_trial_num(trial_jobs_url): ...@@ -112,17 +116,31 @@ def get_succeeded_trial_num(trial_jobs_url):
print('num_succeed:', num_succeed) print('num_succeed:', num_succeed)
return num_succeed return num_succeed
def print_stderr(trial_jobs_url): def get_failed_trial_jobs(trial_jobs_url):
'''Return failed trial jobs'''
trial_jobs = requests.get(trial_jobs_url).json() trial_jobs = requests.get(trial_jobs_url).json()
failed_jobs = []
for trial_job in trial_jobs:
if trial_job['status'] in ['FAILED']:
failed_jobs.append(trial_job)
return failed_jobs
def print_failed_job_log(training_service, trial_jobs_url):
'''Print job log of FAILED trial jobs'''
trial_jobs = get_failed_trial_jobs(trial_jobs_url)
for trial_job in trial_jobs: for trial_job in trial_jobs:
if trial_job['status'] == 'FAILED': if training_service == 'local':
if sys.platform == "win32": if sys.platform == "win32":
p = trial_job['stderrPath'].split(':') p = trial_job['stderrPath'].split(':')
stderr_path = ':'.join([p[-2], p[-1]]) log_filename = ':'.join([p[-2], p[-1]])
subprocess.run(['type', stderr_path], shell=True)
else: else:
stderr_path = trial_job['stderrPath'].split(':')[-1] log_filename = trial_job['stderrPath'].split(':')[-1]
subprocess.run(['cat', stderr_path]) else:
log_filename = os.path.join(get_experiment_dir(EXPERIMENT_URL), 'trials', trial_job['id'], 'stdout_log_collection.log')
with open(log_filename, 'r') as f:
log_content = f.read()
print(log_filename, flush=True)
print(log_content, flush=True)
def parse_max_duration_time(max_exec_duration): def parse_max_duration_time(max_exec_duration):
unit = max_exec_duration[-1] unit = max_exec_duration[-1]
......
...@@ -265,11 +265,15 @@ pai_trial_schema = { ...@@ -265,11 +265,15 @@ pai_trial_schema = {
} }
pai_config_schema = { pai_config_schema = {
'paiConfig':{ 'paiConfig': Or({
'userName': setType('userName', str), 'userName': setType('userName', str),
'passWord': setType('passWord', str), 'passWord': setType('passWord', str),
'host': setType('host', str) 'host': setType('host', str)
} }, {
'userName': setType('userName', str),
'token': setType('token', str),
'host': setType('host', str)
})
} }
kubeflow_trial_schema = { kubeflow_trial_schema = {
......
...@@ -27,19 +27,20 @@ from xml.dom import minidom ...@@ -27,19 +27,20 @@ from xml.dom import minidom
def check_ready_to_run(): def check_ready_to_run():
if sys.platform == 'win32': if sys.platform == 'win32':
pgrep_output = subprocess.check_output('wmic process where "CommandLine like \'%nni_gpu_tool.gpu_metrics_collector%\' and name like \'%python%\'" get processId') pgrep_output = subprocess.check_output(
'wmic process where "CommandLine like \'%nni_gpu_tool.gpu_metrics_collector%\' and name like \'%python%\'" get processId')
pidList = pgrep_output.decode("utf-8").strip().split() pidList = pgrep_output.decode("utf-8").strip().split()
pidList.pop(0) # remove the key word 'ProcessId' pidList.pop(0) # remove the key word 'ProcessId'
pidList = list(map(int, pidList)) pidList = list(map(int, pidList))
pidList.remove(os.getpid()) pidList.remove(os.getpid())
return len(pidList) == 0 return not pidList
else: else:
pgrep_output = subprocess.check_output('pgrep -fx \'python3 -m nni_gpu_tool.gpu_metrics_collector\'', shell=True) pgrep_output = subprocess.check_output('pgrep -fx \'python3 -m nni_gpu_tool.gpu_metrics_collector\'', shell=True)
pidList = [] pidList = []
for pid in pgrep_output.splitlines(): for pid in pgrep_output.splitlines():
pidList.append(int(pid)) pidList.append(int(pid))
pidList.remove(os.getpid()) pidList.remove(os.getpid())
return len(pidList) == 0 return not pidList
def main(argv): def main(argv):
metrics_output_dir = os.environ['METRIC_OUTPUT_DIR'] metrics_output_dir = os.environ['METRIC_OUTPUT_DIR']
...@@ -69,10 +70,14 @@ def parse_nvidia_smi_result(smi, outputDir): ...@@ -69,10 +70,14 @@ def parse_nvidia_smi_result(smi, outputDir):
outPut["gpuCount"] = len(gpuList) outPut["gpuCount"] = len(gpuList)
outPut["gpuInfos"] = [] outPut["gpuInfos"] = []
for gpuIndex, gpu in enumerate(gpuList): for gpuIndex, gpu in enumerate(gpuList):
gpuInfo ={} gpuInfo = {}
gpuInfo['index'] = gpuIndex gpuInfo['index'] = gpuIndex
gpuInfo['gpuUtil'] = gpu.getElementsByTagName('utilization')[0].getElementsByTagName('gpu_util')[0].childNodes[0].data.replace("%", "").strip() gpuInfo['gpuUtil'] = gpu.getElementsByTagName('utilization')[0]\
gpuInfo['gpuMemUtil'] = gpu.getElementsByTagName('utilization')[0].getElementsByTagName('memory_util')[0].childNodes[0].data.replace("%", "").strip() .getElementsByTagName('gpu_util')[0]\
.childNodes[0].data.replace("%", "").strip()
gpuInfo['gpuMemUtil'] = gpu.getElementsByTagName('utilization')[0]\
.getElementsByTagName('memory_util')[0]\
.childNodes[0].data.replace("%", "").strip()
processes = gpu.getElementsByTagName('processes') processes = gpu.getElementsByTagName('processes')
runningProNumber = len(processes[0].getElementsByTagName('process_info')) runningProNumber = len(processes[0].getElementsByTagName('process_info'))
gpuInfo['activeProcessNum'] = runningProNumber gpuInfo['activeProcessNum'] = runningProNumber
...@@ -81,8 +86,8 @@ def parse_nvidia_smi_result(smi, outputDir): ...@@ -81,8 +86,8 @@ def parse_nvidia_smi_result(smi, outputDir):
print(outPut) print(outPut)
outputFile.write("{}\n".format(json.dumps(outPut, sort_keys=True))) outputFile.write("{}\n".format(json.dumps(outPut, sort_keys=True)))
outputFile.flush(); outputFile.flush();
except : except:
e_info = sys.exc_info() # e_info = sys.exc_info()
print('xmldoc paring error') print('xmldoc paring error')
finally: finally:
os.umask(old_umask) os.umask(old_umask)
......
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
import os import os
import posixpath import posixpath
from pyhdfs import HdfsClient
from .log_utils import LogType, nni_log from .log_utils import LogType, nni_log
def copyHdfsDirectoryToLocal(hdfsDirectory, localDirectory, hdfsClient): def copyHdfsDirectoryToLocal(hdfsDirectory, localDirectory, hdfsClient):
...@@ -79,7 +78,8 @@ def copyDirectoryToHdfs(localDirectory, hdfsDirectory, hdfsClient): ...@@ -79,7 +78,8 @@ def copyDirectoryToHdfs(localDirectory, hdfsDirectory, hdfsClient):
try: try:
result = result and copyDirectoryToHdfs(file_path, hdfs_directory, hdfsClient) result = result and copyDirectoryToHdfs(file_path, hdfs_directory, hdfsClient)
except Exception as exception: except Exception as exception:
nni_log(LogType.Error, 'Copy local directory {0} to hdfs directory {1} error: {2}'.format(file_path, hdfs_directory, str(exception))) nni_log(LogType.Error,
'Copy local directory {0} to hdfs directory {1} error: {2}'.format(file_path, hdfs_directory, str(exception)))
result = False result = False
else: else:
hdfs_file_path = os.path.join(hdfsDirectory, file) hdfs_file_path = os.path.join(hdfsDirectory, file)
......
...@@ -33,8 +33,7 @@ from logging import StreamHandler ...@@ -33,8 +33,7 @@ from logging import StreamHandler
from queue import Queue from queue import Queue
from .rest_utils import rest_get, rest_post, rest_put, rest_delete from .rest_utils import rest_post
from .constants import NNI_EXP_ID, NNI_TRIAL_JOB_ID, STDOUT_API
from .url_utils import gen_send_stdout_url from .url_utils import gen_send_stdout_url
@unique @unique
...@@ -73,7 +72,7 @@ class NNIRestLogHanlder(StreamHandler): ...@@ -73,7 +72,7 @@ class NNIRestLogHanlder(StreamHandler):
log_entry['msg'] = self.format(record) log_entry['msg'] = self.format(record)
try: try:
response = rest_post(gen_send_stdout_url(self.host, self.port), json.dumps(log_entry), 10, True) rest_post(gen_send_stdout_url(self.host, self.port), json.dumps(log_entry), 10, True)
except Exception as e: except Exception as e:
self.orig_stderr.write(str(e) + '\n') self.orig_stderr.write(str(e) + '\n')
self.orig_stderr.flush() self.orig_stderr.flush()
...@@ -112,7 +111,7 @@ class RemoteLogger(object): ...@@ -112,7 +111,7 @@ class RemoteLogger(object):
self.orig_stdout.flush() self.orig_stdout.flush()
try: try:
self.logger.log(self.log_level, line.rstrip()) self.logger.log(self.log_level, line.rstrip())
except Exception as e: except Exception:
pass pass
class PipeLogReader(threading.Thread): class PipeLogReader(threading.Thread):
...@@ -147,15 +146,14 @@ class PipeLogReader(threading.Thread): ...@@ -147,15 +146,14 @@ class PipeLogReader(threading.Thread):
line = self.queue.get(True, 5) line = self.queue.get(True, 5)
try: try:
self.logger.log(self.log_level, line.rstrip()) self.logger.log(self.log_level, line.rstrip())
except Exception as e: except Exception:
pass pass
except Exception as e: except Exception:
if cur_process_exit == True: if cur_process_exit == True:
self._is_read_completed = True self._is_read_completed = True
break break
self.pip_log_reader_thread = threading.Thread(target = _populateQueue, self.pip_log_reader_thread = threading.Thread(target=_populateQueue, args=(self.pipeReader, self.queue))
args = (self.pipeReader, self.queue))
self.pip_log_reader_thread.daemon = True self.pip_log_reader_thread.daemon = True
self.start() self.start()
self.pip_log_reader_thread.start() self.pip_log_reader_thread.start()
...@@ -196,4 +194,4 @@ class PipeLogReader(threading.Thread): ...@@ -196,4 +194,4 @@ class PipeLogReader(threading.Thread):
def set_process_exit(self): def set_process_exit(self):
self.process_exit = True self.process_exit = True
return self.process_exit return self.process_exit
\ No newline at end of file
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import time
import requests import requests
def rest_get(url, timeout): def rest_get(url, timeout):
......
...@@ -18,16 +18,17 @@ ...@@ -18,16 +18,17 @@
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
import os
import shutil
import random
import string
import unittest import unittest
import json import json
import sys import sys
from pyhdfs import HdfsClient from pyhdfs import HdfsClient
from tools.nni_trial_tool.hdfsClientUtility import copyFileToHdfs, copyDirectoryToHdfs
sys.path.append("..") sys.path.append("..")
from trial.hdfsClientUtility import copyFileToHdfs, copyDirectoryToHdfs
import os
import shutil
import random
import string
class HDFSClientUtilityTest(unittest.TestCase): class HDFSClientUtilityTest(unittest.TestCase):
'''Unit test for hdfsClientUtility.py''' '''Unit test for hdfsClientUtility.py'''
...@@ -82,7 +83,8 @@ class HDFSClientUtilityTest(unittest.TestCase): ...@@ -82,7 +83,8 @@ class HDFSClientUtilityTest(unittest.TestCase):
with open('./{0}/{1}'.format(directory_name, file_name), 'w') as file: with open('./{0}/{1}'.format(directory_name, file_name), 'w') as file:
file.write(file_content) file.write(file_content)
result = copyDirectoryToHdfs('./{}'.format(directory_name), '/{0}/{1}'.format(self.hdfs_config['userName'], directory_name), self.hdfs_client) result = copyDirectoryToHdfs('./{}'.format(directory_name),
'/{0}/{1}'.format(self.hdfs_config['userName'], directory_name), self.hdfs_client)
self.assertTrue(result) self.assertTrue(result)
directory_list = self.hdfs_client.listdir('/{0}'.format(self.hdfs_config['userName'])) directory_list = self.hdfs_client.listdir('/{0}'.format(self.hdfs_config['userName']))
......
...@@ -18,32 +18,30 @@ ...@@ -18,32 +18,30 @@
# ============================================================================================================================== # # ============================================================================================================================== #
import argparse import argparse
import sys
import os import os
from subprocess import Popen, PIPE from subprocess import Popen
import time import time
import logging import logging
import shlex import shlex
import re import re
import sys import sys
import select
import json import json
import threading import threading
from pyhdfs import HdfsClient from pyhdfs import HdfsClient
import pkg_resources import pkg_resources
from .rest_utils import rest_post, rest_get from .rest_utils import rest_post, rest_get
from .url_utils import gen_send_stdout_url, gen_send_version_url, gen_parameter_meta_url from .url_utils import gen_send_version_url, gen_parameter_meta_url
from .constants import HOME_DIR, LOG_DIR, NNI_PLATFORM, STDOUT_FULL_PATH, STDERR_FULL_PATH, \ from .constants import LOG_DIR, NNI_PLATFORM, MULTI_PHASE, NNI_TRIAL_JOB_ID, NNI_SYS_DIR, NNI_EXP_ID
MULTI_PHASE, NNI_TRIAL_JOB_ID, NNI_SYS_DIR, NNI_EXP_ID
from .hdfsClientUtility import copyDirectoryToHdfs, copyHdfsDirectoryToLocal, copyHdfsFileToLocal from .hdfsClientUtility import copyDirectoryToHdfs, copyHdfsDirectoryToLocal, copyHdfsFileToLocal
from .log_utils import LogType, nni_log, RemoteLogger, PipeLogReader, StdOutputType from .log_utils import LogType, nni_log, RemoteLogger, StdOutputType
logger = logging.getLogger('trial_keeper') logger = logging.getLogger('trial_keeper')
regular = re.compile('v?(?P<version>[0-9](\.[0-9]){0,1}).*') regular = re.compile('v?(?P<version>[0-9](\.[0-9]){0,1}).*')
_hdfs_client = None _hdfs_client = None
def get_hdfs_client(args): def get_hdfs_client(args):
global _hdfs_client global _hdfs_client
...@@ -62,26 +60,29 @@ def get_hdfs_client(args): ...@@ -62,26 +60,29 @@ def get_hdfs_client(args):
if hdfs_host is not None and args.nni_hdfs_exp_dir is not None: if hdfs_host is not None and args.nni_hdfs_exp_dir is not None:
try: try:
if args.webhdfs_path: if args.webhdfs_path:
_hdfs_client = HdfsClient(hosts='{0}:80'.format(hdfs_host), user_name=args.pai_user_name, webhdfs_path=args.webhdfs_path, timeout=5) _hdfs_client = HdfsClient(hosts='{0}:80'.format(hdfs_host), user_name=args.pai_user_name,
webhdfs_path=args.webhdfs_path, timeout=5)
else: else:
# backward compatibility # backward compatibility
_hdfs_client = HdfsClient(hosts='{0}:{1}'.format(hdfs_host, '50070'), user_name=args.pai_user_name, timeout=5) _hdfs_client = HdfsClient(hosts='{0}:{1}'.format(hdfs_host, '50070'), user_name=args.pai_user_name,
timeout=5)
except Exception as e: except Exception as e:
nni_log(LogType.Error, 'Create HDFS client error: ' + str(e)) nni_log(LogType.Error, 'Create HDFS client error: ' + str(e))
raise e raise e
return _hdfs_client return _hdfs_client
def main_loop(args): def main_loop(args):
'''main loop logic for trial keeper''' '''main loop logic for trial keeper'''
if not os.path.exists(LOG_DIR): if not os.path.exists(LOG_DIR):
os.makedirs(LOG_DIR) os.makedirs(LOG_DIR)
stdout_file = open(STDOUT_FULL_PATH, 'a+') trial_keeper_syslogger = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'trial_keeper',
stderr_file = open(STDERR_FULL_PATH, 'a+') StdOutputType.Stdout, args.log_collection)
trial_keeper_syslogger = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'trial_keeper', StdOutputType.Stdout, args.log_collection)
# redirect trial keeper's stdout and stderr to syslog # redirect trial keeper's stdout and stderr to syslog
trial_syslogger_stdout = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'trial', StdOutputType.Stdout, args.log_collection) trial_syslogger_stdout = RemoteLogger(args.nnimanager_ip, args.nnimanager_port, 'trial', StdOutputType.Stdout,
args.log_collection)
sys.stdout = sys.stderr = trial_keeper_syslogger sys.stdout = sys.stderr = trial_keeper_syslogger
hdfs_output_dir = None hdfs_output_dir = None
...@@ -97,8 +98,10 @@ def main_loop(args): ...@@ -97,8 +98,10 @@ def main_loop(args):
# Notice: We don't appoint env, which means subprocess wil inherit current environment and that is expected behavior # Notice: We don't appoint env, which means subprocess wil inherit current environment and that is expected behavior
log_pipe_stdout = trial_syslogger_stdout.get_pipelog_reader() log_pipe_stdout = trial_syslogger_stdout.get_pipelog_reader()
process = Popen(args.trial_command, shell = True, stdout = log_pipe_stdout, stderr = log_pipe_stdout) process = Popen(args.trial_command, shell=True, stdout=log_pipe_stdout, stderr=log_pipe_stdout)
nni_log(LogType.Info, 'Trial keeper spawns a subprocess (pid {0}) to run command: {1}'.format(process.pid, shlex.split(args.trial_command))) nni_log(LogType.Info, 'Trial keeper spawns a subprocess (pid {0}) to run command: {1}'.format(process.pid,
shlex.split(
args.trial_command)))
while True: while True:
retCode = process.poll() retCode = process.poll()
...@@ -110,9 +113,11 @@ def main_loop(args): ...@@ -110,9 +113,11 @@ def main_loop(args):
nni_local_output_dir = os.environ['NNI_OUTPUT_DIR'] nni_local_output_dir = os.environ['NNI_OUTPUT_DIR']
try: try:
if copyDirectoryToHdfs(nni_local_output_dir, hdfs_output_dir, hdfs_client): if copyDirectoryToHdfs(nni_local_output_dir, hdfs_output_dir, hdfs_client):
nni_log(LogType.Info, 'copy directory from {0} to {1} success!'.format(nni_local_output_dir, hdfs_output_dir)) nni_log(LogType.Info,
'copy directory from {0} to {1} success!'.format(nni_local_output_dir, hdfs_output_dir))
else: else:
nni_log(LogType.Info, 'copy directory from {0} to {1} failed!'.format(nni_local_output_dir, hdfs_output_dir)) nni_log(LogType.Info,
'copy directory from {0} to {1} failed!'.format(nni_local_output_dir, hdfs_output_dir))
except Exception as e: except Exception as e:
nni_log(LogType.Error, 'HDFS copy directory got exception: ' + str(e)) nni_log(LogType.Error, 'HDFS copy directory got exception: ' + str(e))
raise e raise e
...@@ -123,14 +128,16 @@ def main_loop(args): ...@@ -123,14 +128,16 @@ def main_loop(args):
time.sleep(2) time.sleep(2)
def trial_keeper_help_info(*args): def trial_keeper_help_info(*args):
print('please run --help to see guidance') print('please run --help to see guidance')
def check_version(args): def check_version(args):
try: try:
trial_keeper_version = pkg_resources.get_distribution('nni').version trial_keeper_version = pkg_resources.get_distribution('nni').version
except pkg_resources.ResolutionError as err: except pkg_resources.ResolutionError as err:
#package nni does not exist, try nni-tool package # package nni does not exist, try nni-tool package
nni_log(LogType.Error, 'Package nni does not exist!') nni_log(LogType.Error, 'Package nni does not exist!')
os._exit(1) os._exit(1)
if not args.nni_manager_version: if not args.nni_manager_version:
...@@ -145,21 +152,26 @@ def check_version(args): ...@@ -145,21 +152,26 @@ def check_version(args):
log_entry = {} log_entry = {}
if trial_keeper_version != nni_manager_version: if trial_keeper_version != nni_manager_version:
nni_log(LogType.Error, 'Version does not match!') nni_log(LogType.Error, 'Version does not match!')
error_message = 'NNIManager version is {0}, TrialKeeper version is {1}, NNI version does not match!'.format(nni_manager_version, trial_keeper_version) error_message = 'NNIManager version is {0}, TrialKeeper version is {1}, NNI version does not match!'.format(
nni_manager_version, trial_keeper_version)
log_entry['tag'] = 'VCFail' log_entry['tag'] = 'VCFail'
log_entry['msg'] = error_message log_entry['msg'] = error_message
rest_post(gen_send_version_url(args.nnimanager_ip, args.nnimanager_port), json.dumps(log_entry), 10, False) rest_post(gen_send_version_url(args.nnimanager_ip, args.nnimanager_port), json.dumps(log_entry), 10,
False)
os._exit(1) os._exit(1)
else: else:
nni_log(LogType.Info, 'Version match!') nni_log(LogType.Info, 'Version match!')
log_entry['tag'] = 'VCSuccess' log_entry['tag'] = 'VCSuccess'
rest_post(gen_send_version_url(args.nnimanager_ip, args.nnimanager_port), json.dumps(log_entry), 10, False) rest_post(gen_send_version_url(args.nnimanager_ip, args.nnimanager_port), json.dumps(log_entry), 10,
False)
except AttributeError as err: except AttributeError as err:
nni_log(LogType.Error, err) nni_log(LogType.Error, err)
def is_multi_phase(): def is_multi_phase():
return MULTI_PHASE and (MULTI_PHASE in ['True', 'true']) return MULTI_PHASE and (MULTI_PHASE in ['True', 'true'])
def download_parameter(meta_list, args): def download_parameter(meta_list, args):
""" """
Download parameter file to local working directory. Download parameter file to local working directory.
...@@ -171,7 +183,8 @@ def download_parameter(meta_list, args): ...@@ -171,7 +183,8 @@ def download_parameter(meta_list, args):
] ]
""" """
nni_log(LogType.Debug, str(meta_list)) nni_log(LogType.Debug, str(meta_list))
nni_log(LogType.Debug, 'NNI_SYS_DIR: {}, trial Id: {}, experiment ID: {}'.format(NNI_SYS_DIR, NNI_TRIAL_JOB_ID, NNI_EXP_ID)) nni_log(LogType.Debug,
'NNI_SYS_DIR: {}, trial Id: {}, experiment ID: {}'.format(NNI_SYS_DIR, NNI_TRIAL_JOB_ID, NNI_EXP_ID))
nni_log(LogType.Debug, 'NNI_SYS_DIR files: {}'.format(os.listdir(NNI_SYS_DIR))) nni_log(LogType.Debug, 'NNI_SYS_DIR files: {}'.format(os.listdir(NNI_SYS_DIR)))
for meta in meta_list: for meta in meta_list:
if meta['experimentId'] == NNI_EXP_ID and meta['trialId'] == NNI_TRIAL_JOB_ID: if meta['experimentId'] == NNI_EXP_ID and meta['trialId'] == NNI_TRIAL_JOB_ID:
...@@ -180,6 +193,7 @@ def download_parameter(meta_list, args): ...@@ -180,6 +193,7 @@ def download_parameter(meta_list, args):
hdfs_client = get_hdfs_client(args) hdfs_client = get_hdfs_client(args)
copyHdfsFileToLocal(meta['filePath'], param_fp, hdfs_client, override=False) copyHdfsFileToLocal(meta['filePath'], param_fp, hdfs_client, override=False)
def fetch_parameter_file(args): def fetch_parameter_file(args):
class FetchThread(threading.Thread): class FetchThread(threading.Thread):
def __init__(self, args): def __init__(self, args):
...@@ -203,6 +217,7 @@ def fetch_parameter_file(args): ...@@ -203,6 +217,7 @@ def fetch_parameter_file(args):
fetch_file_thread = FetchThread(args) fetch_file_thread = FetchThread(args)
fetch_file_thread.start() fetch_file_thread.start()
if __name__ == '__main__': if __name__ == '__main__':
'''NNI Trial Keeper main function''' '''NNI Trial Keeper main function'''
PARSER = argparse.ArgumentParser() PARSER = argparse.ArgumentParser()
...@@ -210,9 +225,9 @@ if __name__ == '__main__': ...@@ -210,9 +225,9 @@ if __name__ == '__main__':
PARSER.add_argument('--trial_command', type=str, help='Command to launch trial process') PARSER.add_argument('--trial_command', type=str, help='Command to launch trial process')
PARSER.add_argument('--nnimanager_ip', type=str, default='localhost', help='NNI manager rest server IP') PARSER.add_argument('--nnimanager_ip', type=str, default='localhost', help='NNI manager rest server IP')
PARSER.add_argument('--nnimanager_port', type=str, default='8081', help='NNI manager rest server port') PARSER.add_argument('--nnimanager_port', type=str, default='8081', help='NNI manager rest server port')
PARSER.add_argument('--pai_hdfs_output_dir', type=str, help='the output dir of pai_hdfs') # backward compatibility PARSER.add_argument('--pai_hdfs_output_dir', type=str, help='the output dir of pai_hdfs') # backward compatibility
PARSER.add_argument('--hdfs_output_dir', type=str, help='the output dir of hdfs') PARSER.add_argument('--hdfs_output_dir', type=str, help='the output dir of hdfs')
PARSER.add_argument('--pai_hdfs_host', type=str, help='the host of pai_hdfs') # backward compatibility PARSER.add_argument('--pai_hdfs_host', type=str, help='the host of pai_hdfs') # backward compatibility
PARSER.add_argument('--hdfs_host', type=str, help='the host of hdfs') PARSER.add_argument('--hdfs_host', type=str, help='the host of hdfs')
PARSER.add_argument('--pai_user_name', type=str, help='the username of hdfs') PARSER.add_argument('--pai_user_name', type=str, help='the username of hdfs')
PARSER.add_argument('--nni_hdfs_exp_dir', type=str, help='nni experiment directory in hdfs') PARSER.add_argument('--nni_hdfs_exp_dir', type=str, help='nni experiment directory in hdfs')
...@@ -233,4 +248,3 @@ if __name__ == '__main__': ...@@ -233,4 +248,3 @@ if __name__ == '__main__':
except Exception as e: except Exception as e:
nni_log(LogType.Error, 'Exit trial keeper with code 1 because Exception: {} is catched'.format(str(e))) nni_log(LogType.Error, 'Exit trial keeper with code 1 because Exception: {} is catched'.format(str(e)))
os._exit(1) os._exit(1)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment