Commit 5445bf4b authored by Zejun Lin's avatar Zejun Lin Committed by chicm-ms
Browse files

Add test code for `nnictl stop` command (#1349)

parent a651ecf4
......@@ -24,10 +24,10 @@ import sys
import time
import traceback
from utils import is_experiment_done, fetch_nni_log_path, read_last_line, remove_files, setup_experiment
from utils import is_experiment_done, get_experiment_id, get_nni_log_path, read_last_line, remove_files, setup_experiment, detect_port, snooze
from utils import GREEN, RED, CLEAR, EXPERIMENT_URL
def run():
def naive_test():
'''run naive integration test'''
to_remove = ['tuner_search_space.json', 'tuner_result.txt', 'assessor_result.txt']
to_remove = list(map(lambda file: 'naive_test/' + file, to_remove))
......@@ -38,7 +38,7 @@ def run():
print('Spawning trials...')
nnimanager_log_path = fetch_nni_log_path(EXPERIMENT_URL)
nnimanager_log_path = get_nni_log_path(EXPERIMENT_URL)
current_trial = 0
for _ in range(120):
......@@ -79,11 +79,36 @@ def run():
expected = set(open('naive_test/expected_assessor_result.txt'))
assert assessor_result == expected, 'Bad assessor result'
subprocess.run(['nnictl', 'stop'])
snooze()
def stop_experiment_test():
'''Test `nnictl stop` command, including `nnictl stop exp_id` and `nnictl stop all`.
Simple `nnictl stop` is not tested here since it is used in all other test code'''
subprocess.run(['nnictl', 'create', '--config', 'tuner_test/local.yml', '--port', '8080'], check=True)
subprocess.run(['nnictl', 'create', '--config', 'tuner_test/local.yml', '--port', '8888'], check=True)
subprocess.run(['nnictl', 'create', '--config', 'tuner_test/local.yml', '--port', '8989'], check=True)
# test cmd 'nnictl stop id`
experiment_id = get_experiment_id(EXPERIMENT_URL)
proc = subprocess.run(['nnictl', 'stop', experiment_id])
assert proc.returncode == 0, '`nnictl stop %s` failed with code %d' % (experiment_id, proc.returncode)
snooze()
assert not detect_port(8080), '`nnictl stop %s` failed to stop experiments' % experiment_id
# test cmd `nnictl stop all`
proc = subprocess.run(['nnictl', 'stop', 'all'])
assert proc.returncode == 0, '`nnictl stop all` failed with code %d' % proc.returncode
snooze()
assert not detect_port(8888) and not detect_port(8989), '`nnictl stop all` failed to stop experiments'
if __name__ == '__main__':
installed = (sys.argv[-1] != '--preinstall')
setup_experiment(installed)
try:
run()
naive_test()
stop_experiment_test()
# TODO: check the output of rest server
print(GREEN + 'PASS' + CLEAR)
except Exception as error:
......@@ -91,5 +116,3 @@ if __name__ == '__main__':
print('%r' % error)
traceback.print_exc()
sys.exit(1)
finally:
subprocess.run(['nnictl', 'stop'])
......@@ -23,15 +23,11 @@ import sys
import time
import traceback
from utils import get_yml_content, dump_yml_content, setup_experiment, fetch_nni_log_path, is_experiment_done
GREEN = '\33[32m'
RED = '\33[31m'
CLEAR = '\33[0m'
from utils import get_yml_content, dump_yml_content, setup_experiment, get_nni_log_path, is_experiment_done
from utils import GREEN, RED, CLEAR, EXPERIMENT_URL
TUNER_LIST = ['GridSearch', 'BatchTuner', 'TPE', 'Random', 'Anneal', 'Evolution']
ASSESSOR_LIST = ['Medianstop']
EXPERIMENT_URL = 'http://localhost:8080/api/v1/nni/experiment'
def switch(dispatch_type, dispatch_name):
......@@ -63,7 +59,7 @@ def test_builtin_dispatcher(dispatch_type, dispatch_name):
proc = subprocess.run(['nnictl', 'create', '--config', 'tuner_test/local.yml'])
assert proc.returncode == 0, '`nnictl create` failed with code %d' % proc.returncode
nnimanager_log_path = fetch_nni_log_path(EXPERIMENT_URL)
nnimanager_log_path = get_nni_log_path(EXPERIMENT_URL)
for _ in range(20):
time.sleep(3)
......
......@@ -22,9 +22,11 @@ import contextlib
import collections
import json
import os
import socket
import sys
import subprocess
import requests
import time
import ruamel.yaml as yaml
EXPERIMENT_DONE_SIGNAL = '"Experiment done"'
......@@ -76,10 +78,13 @@ def setup_experiment(installed=True):
pypath = ':'.join([sdk_path, cmd_path])
os.environ['PYTHONPATH'] = pypath
def fetch_nni_log_path(experiment_url):
def get_experiment_id(experiment_url):
experiment_id = requests.get(experiment_url).json()['id']
return experiment_id
def get_nni_log_path(experiment_url):
'''get nni's log path from nni's experiment url'''
experiment_profile = requests.get(experiment_url)
experiment_id = json.loads(experiment_profile.text)['id']
experiment_id = get_experiment_id(experiment_url)
experiment_path = os.path.join(os.path.expanduser('~'), 'nni', 'experiments', experiment_id)
nnimanager_log_path = os.path.join(experiment_path, 'log', 'nnimanager.log')
......@@ -98,7 +103,6 @@ def is_experiment_done(nnimanager_log_path):
def get_experiment_status(status_url):
nni_status = requests.get(status_url).json()
#print(nni_status)
return nni_status['status']
def get_succeeded_trial_num(trial_jobs_url):
......@@ -139,3 +143,17 @@ def deep_update(source, overrides):
else:
source[key] = overrides[key]
return source
def detect_port(port):
'''Detect if the port is used'''
socket_test = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
try:
socket_test.connect(('127.0.0.1', int(port)))
socket_test.close()
return True
except:
return False
def snooze():
'''Sleep to make sure previous stopped exp has enough time to exit'''
time.sleep(6)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment