"git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "a6b52c529af28f07e5704d21a45b023b9f9230b7"
Commit 5445bf4b authored by Zejun Lin's avatar Zejun Lin Committed by chicm-ms
Browse files

Add test code for `nnictl stop` command (#1349)

parent a651ecf4
...@@ -24,10 +24,10 @@ import sys ...@@ -24,10 +24,10 @@ import sys
import time import time
import traceback import traceback
from utils import is_experiment_done, fetch_nni_log_path, read_last_line, remove_files, setup_experiment from utils import is_experiment_done, get_experiment_id, get_nni_log_path, read_last_line, remove_files, setup_experiment, detect_port, snooze
from utils import GREEN, RED, CLEAR, EXPERIMENT_URL from utils import GREEN, RED, CLEAR, EXPERIMENT_URL
def run(): def naive_test():
'''run naive integration test''' '''run naive integration test'''
to_remove = ['tuner_search_space.json', 'tuner_result.txt', 'assessor_result.txt'] to_remove = ['tuner_search_space.json', 'tuner_result.txt', 'assessor_result.txt']
to_remove = list(map(lambda file: 'naive_test/' + file, to_remove)) to_remove = list(map(lambda file: 'naive_test/' + file, to_remove))
...@@ -38,7 +38,7 @@ def run(): ...@@ -38,7 +38,7 @@ def run():
print('Spawning trials...') print('Spawning trials...')
nnimanager_log_path = fetch_nni_log_path(EXPERIMENT_URL) nnimanager_log_path = get_nni_log_path(EXPERIMENT_URL)
current_trial = 0 current_trial = 0
for _ in range(120): for _ in range(120):
...@@ -79,11 +79,36 @@ def run(): ...@@ -79,11 +79,36 @@ def run():
expected = set(open('naive_test/expected_assessor_result.txt')) expected = set(open('naive_test/expected_assessor_result.txt'))
assert assessor_result == expected, 'Bad assessor result' assert assessor_result == expected, 'Bad assessor result'
subprocess.run(['nnictl', 'stop'])
snooze()
def stop_experiment_test():
'''Test `nnictl stop` command, including `nnictl stop exp_id` and `nnictl stop all`.
Simple `nnictl stop` is not tested here since it is used in all other test code'''
subprocess.run(['nnictl', 'create', '--config', 'tuner_test/local.yml', '--port', '8080'], check=True)
subprocess.run(['nnictl', 'create', '--config', 'tuner_test/local.yml', '--port', '8888'], check=True)
subprocess.run(['nnictl', 'create', '--config', 'tuner_test/local.yml', '--port', '8989'], check=True)
# test cmd 'nnictl stop id`
experiment_id = get_experiment_id(EXPERIMENT_URL)
proc = subprocess.run(['nnictl', 'stop', experiment_id])
assert proc.returncode == 0, '`nnictl stop %s` failed with code %d' % (experiment_id, proc.returncode)
snooze()
assert not detect_port(8080), '`nnictl stop %s` failed to stop experiments' % experiment_id
# test cmd `nnictl stop all`
proc = subprocess.run(['nnictl', 'stop', 'all'])
assert proc.returncode == 0, '`nnictl stop all` failed with code %d' % proc.returncode
snooze()
assert not detect_port(8888) and not detect_port(8989), '`nnictl stop all` failed to stop experiments'
if __name__ == '__main__': if __name__ == '__main__':
installed = (sys.argv[-1] != '--preinstall') installed = (sys.argv[-1] != '--preinstall')
setup_experiment(installed) setup_experiment(installed)
try: try:
run() naive_test()
stop_experiment_test()
# TODO: check the output of rest server # TODO: check the output of rest server
print(GREEN + 'PASS' + CLEAR) print(GREEN + 'PASS' + CLEAR)
except Exception as error: except Exception as error:
...@@ -91,5 +116,3 @@ if __name__ == '__main__': ...@@ -91,5 +116,3 @@ if __name__ == '__main__':
print('%r' % error) print('%r' % error)
traceback.print_exc() traceback.print_exc()
sys.exit(1) sys.exit(1)
finally:
subprocess.run(['nnictl', 'stop'])
...@@ -23,15 +23,11 @@ import sys ...@@ -23,15 +23,11 @@ import sys
import time import time
import traceback import traceback
from utils import get_yml_content, dump_yml_content, setup_experiment, fetch_nni_log_path, is_experiment_done from utils import get_yml_content, dump_yml_content, setup_experiment, get_nni_log_path, is_experiment_done
from utils import GREEN, RED, CLEAR, EXPERIMENT_URL
GREEN = '\33[32m'
RED = '\33[31m'
CLEAR = '\33[0m'
TUNER_LIST = ['GridSearch', 'BatchTuner', 'TPE', 'Random', 'Anneal', 'Evolution'] TUNER_LIST = ['GridSearch', 'BatchTuner', 'TPE', 'Random', 'Anneal', 'Evolution']
ASSESSOR_LIST = ['Medianstop'] ASSESSOR_LIST = ['Medianstop']
EXPERIMENT_URL = 'http://localhost:8080/api/v1/nni/experiment'
def switch(dispatch_type, dispatch_name): def switch(dispatch_type, dispatch_name):
...@@ -63,7 +59,7 @@ def test_builtin_dispatcher(dispatch_type, dispatch_name): ...@@ -63,7 +59,7 @@ def test_builtin_dispatcher(dispatch_type, dispatch_name):
proc = subprocess.run(['nnictl', 'create', '--config', 'tuner_test/local.yml']) proc = subprocess.run(['nnictl', 'create', '--config', 'tuner_test/local.yml'])
assert proc.returncode == 0, '`nnictl create` failed with code %d' % proc.returncode assert proc.returncode == 0, '`nnictl create` failed with code %d' % proc.returncode
nnimanager_log_path = fetch_nni_log_path(EXPERIMENT_URL) nnimanager_log_path = get_nni_log_path(EXPERIMENT_URL)
for _ in range(20): for _ in range(20):
time.sleep(3) time.sleep(3)
......
...@@ -22,9 +22,11 @@ import contextlib ...@@ -22,9 +22,11 @@ import contextlib
import collections import collections
import json import json
import os import os
import socket
import sys import sys
import subprocess import subprocess
import requests import requests
import time
import ruamel.yaml as yaml import ruamel.yaml as yaml
EXPERIMENT_DONE_SIGNAL = '"Experiment done"' EXPERIMENT_DONE_SIGNAL = '"Experiment done"'
...@@ -76,10 +78,13 @@ def setup_experiment(installed=True): ...@@ -76,10 +78,13 @@ def setup_experiment(installed=True):
pypath = ':'.join([sdk_path, cmd_path]) pypath = ':'.join([sdk_path, cmd_path])
os.environ['PYTHONPATH'] = pypath os.environ['PYTHONPATH'] = pypath
def fetch_nni_log_path(experiment_url): def get_experiment_id(experiment_url):
experiment_id = requests.get(experiment_url).json()['id']
return experiment_id
def get_nni_log_path(experiment_url):
'''get nni's log path from nni's experiment url''' '''get nni's log path from nni's experiment url'''
experiment_profile = requests.get(experiment_url) experiment_id = get_experiment_id(experiment_url)
experiment_id = json.loads(experiment_profile.text)['id']
experiment_path = os.path.join(os.path.expanduser('~'), 'nni', 'experiments', experiment_id) experiment_path = os.path.join(os.path.expanduser('~'), 'nni', 'experiments', experiment_id)
nnimanager_log_path = os.path.join(experiment_path, 'log', 'nnimanager.log') nnimanager_log_path = os.path.join(experiment_path, 'log', 'nnimanager.log')
...@@ -98,7 +103,6 @@ def is_experiment_done(nnimanager_log_path): ...@@ -98,7 +103,6 @@ def is_experiment_done(nnimanager_log_path):
def get_experiment_status(status_url): def get_experiment_status(status_url):
nni_status = requests.get(status_url).json() nni_status = requests.get(status_url).json()
#print(nni_status)
return nni_status['status'] return nni_status['status']
def get_succeeded_trial_num(trial_jobs_url): def get_succeeded_trial_num(trial_jobs_url):
...@@ -139,3 +143,17 @@ def deep_update(source, overrides): ...@@ -139,3 +143,17 @@ def deep_update(source, overrides):
else: else:
source[key] = overrides[key] source[key] = overrides[key]
return source return source
def detect_port(port):
'''Detect if the port is used'''
socket_test = socket.socket(socket.AF_INET,socket.SOCK_STREAM)
try:
socket_test.connect(('127.0.0.1', int(port)))
socket_test.close()
return True
except:
return False
def snooze():
'''Sleep to make sure previous stopped exp has enough time to exit'''
time.sleep(6)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment