Unverified commit 5d7c1cd8 authored by SparkSnail, committed by GitHub

Add nnictl ut (#2912)

parent 9369b719
......@@ -54,6 +54,7 @@ setuptools.setup(
'ruamel.yaml',
'psutil',
'requests',
'responses',
'astor',
'PythonWebHDFS',
'hyperopt==0.1.2',
......
......@@ -37,6 +37,7 @@ setup(
'psutil',
'ruamel.yaml',
'requests',
'responses',
'scipy',
'schema',
'PythonWebHDFS',
......
......@@ -9,8 +9,8 @@ from .command_utils import print_error
class Config:
'''a util class to load and save config'''
def __init__(self, file_path):
config_path = os.path.join(NNICTL_HOME_DIR, str(file_path))
def __init__(self, file_path, home_dir=NNICTL_HOME_DIR):
config_path = os.path.join(home_dir, str(file_path))
os.makedirs(config_path, exist_ok=True)
self.config_file = os.path.join(config_path, '.config')
self.config = self.read_file()
......@@ -51,9 +51,9 @@ class Config:
class Experiments:
'''Maintain experiment list'''
def __init__(self):
os.makedirs(NNICTL_HOME_DIR, exist_ok=True)
self.experiment_file = os.path.join(NNICTL_HOME_DIR, '.experiment')
def __init__(self, home_dir=NNICTL_HOME_DIR):
os.makedirs(home_dir, exist_ok=True)
self.experiment_file = os.path.join(home_dir, '.experiment')
self.experiments = self.read_file()
def add_experiment(self, expId, port, startTime, file_name, platform, experiment_name, endTime='N/A', status='INITIALIZED'):
......
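
The home_dir parameter added above is what makes these metadata classes unit-testable: a test can point Config and Experiments at a disposable directory instead of the real NNICTL_HOME_DIR. A minimal sketch of that usage (illustrative only; the temporary directory is hypothetical and not part of this diff):

import tempfile
from nni_cmd.config_utils import Config, Experiments

test_home = tempfile.mkdtemp()                 # throwaway stand-in for NNICTL_HOME_DIR
experiments = Experiments(home_dir=test_home)  # .experiment file lives under test_home
experiments.add_experiment('xOpEwA5w', '8080', 'N/A', 'aGew0x', 'local', 'demo')
assert 'xOpEwA5w' in experiments.get_all_experiments()

config = Config('aGew0x', home_dir=test_home)  # per-experiment .config under test_home/aGew0x
config.set_config('restServerPort', 8080)
assert config.get_config('restServerPort') == 8080
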
......@@ -213,10 +213,11 @@ def check_rest(args):
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
running, _ = check_rest_server_quick(rest_port)
if not running:
if running:
print_normal('Restful server is running...')
else:
print_normal('Restful server is not running...')
return running
def stop_experiment(args):
'''Stop the experiment which is running'''
......@@ -284,10 +285,12 @@ def trial_ls(args):
for index, value in enumerate(content):
content[index] = convert_time_stamp_to_date(value)
print(json.dumps(content, indent=4, sort_keys=True, separators=(',', ':')))
return content
else:
print_error('List trial failed...')
else:
print_error('Restful server is not running...')
return None
def trial_kill(args):
'''Kill trial jobs'''
......@@ -302,10 +305,12 @@ def trial_kill(args):
response = rest_delete(trial_job_id_url(rest_port, args.trial_id), REST_TIME_OUT)
if response and check_response(response):
print(response.text)
return True
else:
print_error('Kill trial job failed...')
else:
print_error('Restful server is not running...')
return False
def trial_codegen(args):
'''Generate code for a specific trial'''
......@@ -332,10 +337,12 @@ def list_experiment(args):
if response and check_response(response):
content = convert_time_stamp_to_date(json.loads(response.text))
print(json.dumps(content, indent=4, sort_keys=True, separators=(',', ':')))
return content
else:
print_error('List experiment failed...')
else:
print_error('Restful server is not running...')
return None
def experiment_status(args):
'''Show the status of experiment'''
......@@ -346,6 +353,7 @@ def experiment_status(args):
print_normal('Restful server is not running...')
else:
print(json.dumps(json.loads(response.text), indent=4, sort_keys=True, separators=(',', ':')))
return result
def log_internal(args, filetype):
'''internal function to call get_log_content'''
......@@ -618,6 +626,7 @@ def experiment_list(args):
experiment_dict[key]['startTime'],
experiment_dict[key]['endTime'])
print(EXPERIMENT_INFORMATION_FORMAT % experiment_information)
return experiment_id_list
def get_time_interval(time1, time2):
'''get the interval of two times'''
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import argparse
from subprocess import Popen, PIPE, STDOUT
from nni_cmd.config_utils import Config, Experiments
from nni_cmd.common_utils import print_green
from nni_cmd.command_utils import kill_command
from nni_cmd.nnictl_utils import get_yml_content

def create_mock_experiment():
    nnictl_experiment_config = Experiments()
    nnictl_experiment_config.add_experiment('xOpEwA5w', '8080', '1970/01/1 01:01:01', 'aGew0x',
                                            'local', 'example_sklearn-classification')
    nni_config = Config('aGew0x')
    # mock process
    cmds = ['sleep', '3600000']
    process = Popen(cmds, stdout=PIPE, stderr=STDOUT)
    nni_config.set_config('restServerPid', process.pid)
    nni_config.set_config('experimentId', 'xOpEwA5w')
    nni_config.set_config('restServerPort', 8080)
    nni_config.set_config('webuiUrl', ['http://localhost:8080'])
    experiment_config = get_yml_content('./tests/config_files/valid/test.yml')
    nni_config.set_config('experimentConfig', experiment_config)
print_green("expriment start success, experiment id: xOpEwA5w")

def stop_mock_experiment():
    config = Config('aGew0x')
    kill_command(config.get_config('restServerPid'))
    nnictl_experiment_config = Experiments()
    nnictl_experiment_config.remove_experiment('xOpEwA5w')

def generate_args_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('id', nargs='?')
    parser.add_argument('--port', '-p', dest='port')
    parser.add_argument('--all', '-a', action='store_true')
    parser.add_argument('--head', type=int)
    parser.add_argument('--tail', type=int)
    return parser

def generate_args():
    parser = generate_args_parser()
    args = parser.parse_args(['xOpEwA5w'])
    return args
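
Taken together, these helpers give each test a self-contained fixture: create_mock_experiment registers experiment 'xOpEwA5w' backed by a long-sleeping placeholder process, generate_args fabricates the argparse namespace the nnictl commands expect, and stop_mock_experiment tears both down. A condensed usage sketch (illustrative only; the test classes further down wrap the same calls in setUp/tearDown):

from mock.experiment import create_mock_experiment, stop_mock_experiment, generate_args
from nni_cmd.nnictl_utils import get_config_filename

create_mock_experiment()                      # register 'xOpEwA5w' and start the placeholder process
args = generate_args()                        # Namespace with id='xOpEwA5w'
assert get_config_filename(args) == 'aGew0x'  # commands resolve the config dir via the experiment list
stop_mock_experiment()                        # kill the placeholder process and drop the entry
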
{"xOpEwA5w": {"port": 8080, "startTime": "1970/01/1 01:01:01", "endTime": "1970-01-2 01:01:01", "status": "RUNNING", "fileName": "aGew0x", "platform": "local", "experimentName": "example_sklearn-classification"}}
{"experimentConfig": {"authorName": "default", "experimentName": "example_sklearn-classification", "trialConcurrency": 5, "maxExecDuration": 3600, "maxTrialNum": 100, "trainingServicePlatform": "local", "searchSpacePath": "../../../config_files/valid/search_space.json", "useAnnotation": false, "tuner": {"builtinTunerName": "TPE", "classArgs": {"optimize_mode": "maximize"}}, "trial": {"command": "python3 main.py", "codeDir": "../../../config_files/valid/.", "gpuNum": 0}}, "restServerPort": 8080, "restServerPid": 7952, "experimentId": "xOpEwA5w", "webuiUrl": ["http://localhost:8080"]}
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import responses

def mock_check_status():
    responses.add(
        responses.GET,
        "http://localhost:8080/api/v1/nni/check-status",
        json={"status":"RUNNING","errors":[]},
        status=200
    )

def mock_version():
    responses.add(
        responses.GET,
        "http://localhost:8080/api/v1/nni/version",
        json={'value':1.8},
        status=200
    )

def mock_get_experiment_profile():
    responses.add(
        responses.GET,
        "http://localhost:8080/api/v1/nni/experiment",
        json={"id":"bkfhOdUl","revision":5,"execDuration":10,"logDir":"/home/shinyang/nni-experiments/bkfhOdUl",
              "nextSequenceId":2,"params":{"authorName":"default","experimentName":"example_sklearn-classification",
              "trialConcurrency":1,"maxExecDuration":3600,"maxTrialNum":1,
              "searchSpace":"{\"C\": {\"_type\": \"uniform\", \"_value\": [0.1, 1]}, "
                            "\"kernel\": {\"_type\": \"choice\", \"_value\": [\"linear\", \"rbf\", \"poly\", \"sigmoid\"]}, "
                            "\"degree\": {\"_type\": \"choice\", \"_value\": [1, 2, 3, 4]}, \"gamma\": {\"_type\": \"uniform\", "
                            "\"_value\": [0.01, 0.1]}}",
              "trainingServicePlatform":"local","tuner":{"builtinTunerName":"TPE","classArgs":{"optimize_mode":"maximize"},
              "checkpointDir":"/home/shinyang/nni-experiments/bkfhOdUl/checkpoint"},"versionCheck":"true",
              "clusterMetaData":[{"key":"codeDir","value":"/home/shinyang/folder/examples/trials/sklearn/classification/."},
              {"key":"command","value":"python3 main.py"}]},"startTime":1600326895536,"endTime":1600326910605},
        status=200
    )

def mock_update_experiment_profile():
    responses.add(
        responses.PUT, 'http://localhost:8080/api/v1/nni/experiment',
        json={"status":"RUNNING","errors":[]},
        status=200,
        content_type='application/json',
    )

def mock_import_data():
    responses.add(
        responses.POST, 'http://localhost:8080/api/v1/nni/experiment/import-data',
        json={"result":"data"},
        status=201,
        content_type='application/json',
    )

def mock_start_experiment():
    responses.add(
        responses.POST, 'http://localhost:8080/api/v1/nni/experiment',
        json={"status":"RUNNING","errors":[]},
        status=201,
        content_type='application/json',
    )

def mock_get_trial_job_statistics():
    responses.add(
        responses.GET, 'http://localhost:8080/api/v1/nni/job-statistics',
        json=[{"trialJobStatus":"SUCCEEDED","trialJobNumber":1}],
        status=200,
        content_type='application/json',
    )

def mock_set_cluster_metadata():
    responses.add(
        responses.PUT, 'http://localhost:8080/api/v1/nni/experiment/cluster-metadata',
        json=[{"trialJobStatus":"SUCCEEDED","trialJobNumber":1}],
        status=201,
        content_type='application/json',
    )

def mock_list_trial_jobs():
    responses.add(
        responses.GET, 'http://localhost:8080/api/v1/nni/trial-jobs',
        json=[{"id":"GPInz","status":"SUCCEEDED","hyperParameters":["{\"parameter_id\":0, "
              "\"parameter_source\":\"algorithm\",\"parameters\":{\"C\":0.8748364659110364, "
              "\"kernel\":\"linear\",\"degree\":1,\"gamma\":0.040451413392113666}, "
              "\"parameter_index\":0}"],"logPath":"file://localhost:/home/shinyang/nni-experiments/bkfhOdUl/trials/GPInz",
              "startTime":1600326905581,"sequenceId":0,"endTime":1600326906629,
              "finalMetricData":[{"timestamp":1600326906493,"trialJobId":"GPInz","parameterId":"0",
              "type":"FINAL","sequence":0,"data":"\"0.9866666666666667\""}]}],
        status=200,
        content_type='application/json',
    )

def mock_get_trial_job():
    responses.add(
        responses.GET, 'http://localhost:8080/api/v1/nni/trial-jobs/:id',
        json={"id":"GPInz","status":"SUCCEEDED","hyperParameters":["{\"parameter_id\":0, "
              "\"parameter_source\":\"algorithm\",\"parameters\":{\"C\":0.8748364659110364, "
              "\"kernel\":\"linear\",\"degree\":1,\"gamma\":0.040451413392113666}, "
              "\"parameter_index\":0}"],"logPath":"file://localhost:/home/shinyang/nni-experiments/bkfhOdUl/trials/GPInz",
              "startTime":1600326905581,"sequenceId":0,"endTime":1600326906629,
              "finalMetricData":[{"timestamp":1600326906493,"trialJobId":"GPInz","parameterId":"0","type":"FINAL",
              "sequence":0,"data":"\"0.9866666666666667\""}]},
        status=200,
        content_type='application/json',
    )

def mock_add_trial_job():
    responses.add(
        responses.POST, 'http://localhost:8080/api/v1/nni/trial-jobs',
        json=[{"trialJobStatus":"SUCCEEDED","trialJobNumber":1}],
        status=201,
        content_type='application/json',
    )

def mock_cancel_trial_job():
    responses.add(
        responses.DELETE, 'http://localhost:8080/api/v1/nni/trial-jobs/:id',
        json=[{"trialJobStatus":"SUCCEEDED","trialJobNumber":1}],
        status=200,
        content_type='application/json',
    )

def mock_get_metric_data():
    responses.add(
        responses.GET, 'http://localhost:8080/api/v1/nni/metric-data/:job_id*?',
        json=[{"timestamp":1600326906486,"trialJobId":"GPInz","parameterId":"0",
              "type":"PERIODICAL","sequence":0,"data":"\"0.9866666666666667\""},
              {"timestamp":1600326906493,"trialJobId":"GPInz","parameterId":"0",
              "type":"FINAL","sequence":0,"data":"\"0.9866666666666667\""}],
        status=200,
        content_type='application/json',
    )

def mock_get_metric_data_by_range():
    responses.add(
        responses.GET, 'http://localhost:8080/api/v1/nni/metric-data-range/:min_seq_id/:max_seq_id',
        json=[{"timestamp":1600326906486,"trialJobId":"GPInz","parameterId":"0",
              "type":"PERIODICAL","sequence":0,"data":"\"0.9866666666666667\""},
              {"timestamp":1600326906493,"trialJobId":"GPInz","parameterId":"0",
              "type":"FINAL","sequence":0,"data":"\"0.9866666666666667\""}],
        status=200,
        content_type='application/json',
    )

def mock_get_latest_metric_data():
    responses.add(
        responses.GET, 'http://localhost:8080/api/v1/nni/metric-data-latest/',
        json=[{"timestamp":1600326906493,"trialJobId":"GPInz","parameterId":"0",
              "type":"FINAL","sequence":0,"data":"\"0.9866666666666667\""},{"timestamp":1600326906486,
              "trialJobId":"GPInz","parameterId":"0","type":"PERIODICAL",
              "sequence":0,"data":"\"0.9866666666666667\""}],
        status=200,
        content_type='application/json',
    )

def mock_get_trial_log():
    responses.add(
        responses.GET, 'http://localhost:8080/api/v1/nni/trial-log/:id/:type',
        json={"status":"RUNNING","errors":[]},
        status=200,
        content_type='application/json',
    )

def mock_export_data():
    responses.add(
        responses.GET, 'http://localhost:8080/api/v1/nni/export-data',
        json={"status":"RUNNING","errors":[]},
        status=200,
        content_type='application/json',
    )

def init_response():
    mock_check_status()
    mock_version()
    mock_get_experiment_profile()
    mock_set_cluster_metadata()
    mock_list_trial_jobs()
    mock_get_trial_job()
    mock_add_trial_job()
    mock_cancel_trial_job()
    mock_get_metric_data()
    mock_get_metric_data_by_range()
    mock_get_latest_metric_data()
    mock_get_trial_log()
    mock_export_data()
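
Every helper above registers a canned payload on the default responses mock, so once a test method is wrapped in @responses.activate, any call the nnictl code makes through requests to these URLs is intercepted instead of reaching a real NNI rest server. A minimal sketch of the pattern (illustrative only; it probes the mock with requests directly rather than through an nnictl command):

import requests
import responses
from mock.restful_server import init_response

@responses.activate
def probe_mocked_server():
    init_response()
    reply = requests.get('http://localhost:8080/api/v1/nni/check-status')
    assert reply.json() == {'status': 'RUNNING', 'errors': []}

probe_mocked_server()
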
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from unittest import TestCase, main
from subprocess import Popen, PIPE, STDOUT
from nni_cmd.common_utils import get_yml_content, get_json_content, detect_process
from nni_cmd.command_utils import kill_command
from mock.restful_server import init_response

class CommonUtilsTestCase(TestCase):
    @classmethod
    def setUpClass(cls):
        init_response()

    def test_get_yml(self):
        content = get_yml_content('./tests/config_files/test_files/test_yaml.yml')
        self.assertEqual(content, {'field':'test'})

    def test_get_json(self):
        content = get_json_content('./tests/config_files/test_files/test_json.json')
        self.assertEqual(content, {'field':'test'})

    def test_detect_process(self):
        cmds = ['sleep', '360000']
        process = Popen(cmds, stdout=PIPE, stderr=STDOUT)
        self.assertTrue(detect_process(process.pid))
        kill_command(process.pid)

if __name__ == '__main__':
    main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from unittest import TestCase, main
from nni_cmd.config_utils import Config, Experiments

HOME_PATH = "./tests/mock/nnictl_metadata"

class ConfigUtilsTestCase(TestCase):
    def test_get_experiment(self):
        experiment = Experiments(HOME_PATH)
        self.assertTrue('xOpEwA5w' in experiment.get_all_experiments())

    def test_update_experiment(self):
        experiment = Experiments(HOME_PATH)
        experiment.add_experiment('xOpEwA5w', 8081, 'N/A', 'aGew0x', 'local', 'test', endTime='N/A', status='INITIALIZED')
        self.assertTrue('xOpEwA5w' in experiment.get_all_experiments())
        experiment.remove_experiment('xOpEwA5w')
        self.assertFalse('xOpEwA5w' in experiment.get_all_experiments())

    def test_get_config(self):
        config = Config('config', HOME_PATH)
        self.assertEqual(config.get_config('experimentId'), 'xOpEwA5w')

    def test_set_config(self):
        config = Config('config', HOME_PATH)
        self.assertNotEqual(config.get_config('experimentId'), 'testId')
        config.set_config('experimentId', 'testId')
        self.assertEqual(config.get_config('experimentId'), 'testId')
        config.set_config('experimentId', 'xOpEwA5w')

if __name__ == '__main__':
    main()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from unittest import TestCase, main
import responses
from mock.restful_server import init_response
from mock.experiment import create_mock_experiment, stop_mock_experiment, generate_args_parser, \
    generate_args
from nni_cmd.nnictl_utils import get_experiment_time, get_experiment_status, \
    check_experiment_id, parse_ids, get_config_filename, get_experiment_port, check_rest, \
    trial_ls, list_experiment

class NnictlUtilsTestCase(TestCase):
    def setUp(self):
        init_response()
        create_mock_experiment()

    def tearDown(self):
        stop_mock_experiment()

    @responses.activate
    def test_get_experiment_status(self):
        self.assertEqual('RUNNING', get_experiment_status(8080))

    @responses.activate
    def test_check_experiment_id(self):
        parser = generate_args_parser()
        args = parser.parse_args(['xOpEwA5w'])
        self.assertEqual('xOpEwA5w', check_experiment_id(args))

    @responses.activate
    def test_parse_ids(self):
        parser = generate_args_parser()
        args = parser.parse_args(['xOpEwA5w'])
        self.assertEqual(['xOpEwA5w'], parse_ids(args))

    @responses.activate
    def test_get_config_file_name(self):
        args = generate_args()
        self.assertEqual('aGew0x', get_config_filename(args))

    @responses.activate
    def test_get_experiment_port(self):
        args = generate_args()
        self.assertEqual('8080', get_experiment_port(args))

    @responses.activate
    def test_check_rest(self):
        args = generate_args()
        self.assertEqual(True, check_rest(args))

    @responses.activate
    def test_trial_ls(self):
        args = generate_args()
        trials = trial_ls(args)
        self.assertEqual(trials[0]['id'], 'GPInz')

if __name__ == '__main__':
    main()
......@@ -11,6 +11,7 @@ setuptools.setup(
python_requires = '>=3.6',
install_requires = [
'requests',
'responses',
'ruamel.yaml',
'psutil',
'astor',
......