Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
5445bf4b
"git@developer.sourcefind.cn:OpenDAS/tilelang.git" did not exist on "a6b52c529af28f07e5704d21a45b023b9f9230b7"
Commit
5445bf4b
authored
Jul 26, 2019
by
Zejun Lin
Committed by
chicm-ms
Jul 26, 2019
Browse files
Add test code for `nnictl stop` command (#1349)
parent
a651ecf4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
54 additions
and
17 deletions
+54
-17
test/naive_test.py
test/naive_test.py
+29
-6
test/tuner_test.py
test/tuner_test.py
+3
-7
test/utils.py
test/utils.py
+22
-4
No files found.
test/naive_test.py
View file @
5445bf4b
...
@@ -24,10 +24,10 @@ import sys
...
@@ -24,10 +24,10 @@ import sys
import
time
import
time
import
traceback
import
traceback
from
utils
import
is_experiment_done
,
fetch
_nni_log_path
,
read_last_line
,
remove_files
,
setup_experiment
from
utils
import
is_experiment_done
,
get_experiment_id
,
get
_nni_log_path
,
read_last_line
,
remove_files
,
setup_experiment
,
detect_port
,
snooze
from
utils
import
GREEN
,
RED
,
CLEAR
,
EXPERIMENT_URL
from
utils
import
GREEN
,
RED
,
CLEAR
,
EXPERIMENT_URL
def
run
():
def
naive_test
():
'''run naive integration test'''
'''run naive integration test'''
to_remove
=
[
'tuner_search_space.json'
,
'tuner_result.txt'
,
'assessor_result.txt'
]
to_remove
=
[
'tuner_search_space.json'
,
'tuner_result.txt'
,
'assessor_result.txt'
]
to_remove
=
list
(
map
(
lambda
file
:
'naive_test/'
+
file
,
to_remove
))
to_remove
=
list
(
map
(
lambda
file
:
'naive_test/'
+
file
,
to_remove
))
...
@@ -38,7 +38,7 @@ def run():
...
@@ -38,7 +38,7 @@ def run():
print
(
'Spawning trials...'
)
print
(
'Spawning trials...'
)
nnimanager_log_path
=
f
et
ch
_nni_log_path
(
EXPERIMENT_URL
)
nnimanager_log_path
=
g
et_nni_log_path
(
EXPERIMENT_URL
)
current_trial
=
0
current_trial
=
0
for
_
in
range
(
120
):
for
_
in
range
(
120
):
...
@@ -79,11 +79,36 @@ def run():
...
@@ -79,11 +79,36 @@ def run():
expected
=
set
(
open
(
'naive_test/expected_assessor_result.txt'
))
expected
=
set
(
open
(
'naive_test/expected_assessor_result.txt'
))
assert
assessor_result
==
expected
,
'Bad assessor result'
assert
assessor_result
==
expected
,
'Bad assessor result'
subprocess
.
run
([
'nnictl'
,
'stop'
])
snooze
()
def
stop_experiment_test
():
'''Test `nnictl stop` command, including `nnictl stop exp_id` and `nnictl stop all`.
Simple `nnictl stop` is not tested here since it is used in all other test code'''
subprocess
.
run
([
'nnictl'
,
'create'
,
'--config'
,
'tuner_test/local.yml'
,
'--port'
,
'8080'
],
check
=
True
)
subprocess
.
run
([
'nnictl'
,
'create'
,
'--config'
,
'tuner_test/local.yml'
,
'--port'
,
'8888'
],
check
=
True
)
subprocess
.
run
([
'nnictl'
,
'create'
,
'--config'
,
'tuner_test/local.yml'
,
'--port'
,
'8989'
],
check
=
True
)
# test cmd 'nnictl stop id`
experiment_id
=
get_experiment_id
(
EXPERIMENT_URL
)
proc
=
subprocess
.
run
([
'nnictl'
,
'stop'
,
experiment_id
])
assert
proc
.
returncode
==
0
,
'`nnictl stop %s` failed with code %d'
%
(
experiment_id
,
proc
.
returncode
)
snooze
()
assert
not
detect_port
(
8080
),
'`nnictl stop %s` failed to stop experiments'
%
experiment_id
# test cmd `nnictl stop all`
proc
=
subprocess
.
run
([
'nnictl'
,
'stop'
,
'all'
])
assert
proc
.
returncode
==
0
,
'`nnictl stop all` failed with code %d'
%
proc
.
returncode
snooze
()
assert
not
detect_port
(
8888
)
and
not
detect_port
(
8989
),
'`nnictl stop all` failed to stop experiments'
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
installed
=
(
sys
.
argv
[
-
1
]
!=
'--preinstall'
)
installed
=
(
sys
.
argv
[
-
1
]
!=
'--preinstall'
)
setup_experiment
(
installed
)
setup_experiment
(
installed
)
try
:
try
:
run
()
naive_test
()
stop_experiment_test
()
# TODO: check the output of rest server
# TODO: check the output of rest server
print
(
GREEN
+
'PASS'
+
CLEAR
)
print
(
GREEN
+
'PASS'
+
CLEAR
)
except
Exception
as
error
:
except
Exception
as
error
:
...
@@ -91,5 +116,3 @@ if __name__ == '__main__':
...
@@ -91,5 +116,3 @@ if __name__ == '__main__':
print
(
'%r'
%
error
)
print
(
'%r'
%
error
)
traceback
.
print_exc
()
traceback
.
print_exc
()
sys
.
exit
(
1
)
sys
.
exit
(
1
)
finally
:
subprocess
.
run
([
'nnictl'
,
'stop'
])
test/tuner_test.py
View file @
5445bf4b
...
@@ -23,15 +23,11 @@ import sys
...
@@ -23,15 +23,11 @@ import sys
import
time
import
time
import
traceback
import
traceback
from
utils
import
get_yml_content
,
dump_yml_content
,
setup_experiment
,
fetch_nni_log_path
,
is_experiment_done
from
utils
import
get_yml_content
,
dump_yml_content
,
setup_experiment
,
get_nni_log_path
,
is_experiment_done
from
utils
import
GREEN
,
RED
,
CLEAR
,
EXPERIMENT_URL
GREEN
=
'
\33
[32m'
RED
=
'
\33
[31m'
CLEAR
=
'
\33
[0m'
TUNER_LIST
=
[
'GridSearch'
,
'BatchTuner'
,
'TPE'
,
'Random'
,
'Anneal'
,
'Evolution'
]
TUNER_LIST
=
[
'GridSearch'
,
'BatchTuner'
,
'TPE'
,
'Random'
,
'Anneal'
,
'Evolution'
]
ASSESSOR_LIST
=
[
'Medianstop'
]
ASSESSOR_LIST
=
[
'Medianstop'
]
EXPERIMENT_URL
=
'http://localhost:8080/api/v1/nni/experiment'
def
switch
(
dispatch_type
,
dispatch_name
):
def
switch
(
dispatch_type
,
dispatch_name
):
...
@@ -63,7 +59,7 @@ def test_builtin_dispatcher(dispatch_type, dispatch_name):
...
@@ -63,7 +59,7 @@ def test_builtin_dispatcher(dispatch_type, dispatch_name):
proc
=
subprocess
.
run
([
'nnictl'
,
'create'
,
'--config'
,
'tuner_test/local.yml'
])
proc
=
subprocess
.
run
([
'nnictl'
,
'create'
,
'--config'
,
'tuner_test/local.yml'
])
assert
proc
.
returncode
==
0
,
'`nnictl create` failed with code %d'
%
proc
.
returncode
assert
proc
.
returncode
==
0
,
'`nnictl create` failed with code %d'
%
proc
.
returncode
nnimanager_log_path
=
f
et
ch
_nni_log_path
(
EXPERIMENT_URL
)
nnimanager_log_path
=
g
et_nni_log_path
(
EXPERIMENT_URL
)
for
_
in
range
(
20
):
for
_
in
range
(
20
):
time
.
sleep
(
3
)
time
.
sleep
(
3
)
...
...
test/utils.py
View file @
5445bf4b
...
@@ -22,9 +22,11 @@ import contextlib
...
@@ -22,9 +22,11 @@ import contextlib
import
collections
import
collections
import
json
import
json
import
os
import
os
import
socket
import
sys
import
sys
import
subprocess
import
subprocess
import
requests
import
requests
import
time
import
ruamel.yaml
as
yaml
import
ruamel.yaml
as
yaml
EXPERIMENT_DONE_SIGNAL
=
'"Experiment done"'
EXPERIMENT_DONE_SIGNAL
=
'"Experiment done"'
...
@@ -76,10 +78,13 @@ def setup_experiment(installed=True):
...
@@ -76,10 +78,13 @@ def setup_experiment(installed=True):
pypath
=
':'
.
join
([
sdk_path
,
cmd_path
])
pypath
=
':'
.
join
([
sdk_path
,
cmd_path
])
os
.
environ
[
'PYTHONPATH'
]
=
pypath
os
.
environ
[
'PYTHONPATH'
]
=
pypath
def
fetch_nni_log_path
(
experiment_url
):
def
get_experiment_id
(
experiment_url
):
experiment_id
=
requests
.
get
(
experiment_url
).
json
()[
'id'
]
return
experiment_id
def
get_nni_log_path
(
experiment_url
):
'''get nni's log path from nni's experiment url'''
'''get nni's log path from nni's experiment url'''
experiment_profile
=
requests
.
get
(
experiment_url
)
experiment_id
=
get_experiment_id
(
experiment_url
)
experiment_id
=
json
.
loads
(
experiment_profile
.
text
)[
'id'
]
experiment_path
=
os
.
path
.
join
(
os
.
path
.
expanduser
(
'~'
),
'nni'
,
'experiments'
,
experiment_id
)
experiment_path
=
os
.
path
.
join
(
os
.
path
.
expanduser
(
'~'
),
'nni'
,
'experiments'
,
experiment_id
)
nnimanager_log_path
=
os
.
path
.
join
(
experiment_path
,
'log'
,
'nnimanager.log'
)
nnimanager_log_path
=
os
.
path
.
join
(
experiment_path
,
'log'
,
'nnimanager.log'
)
...
@@ -98,7 +103,6 @@ def is_experiment_done(nnimanager_log_path):
...
@@ -98,7 +103,6 @@ def is_experiment_done(nnimanager_log_path):
def
get_experiment_status
(
status_url
):
def
get_experiment_status
(
status_url
):
nni_status
=
requests
.
get
(
status_url
).
json
()
nni_status
=
requests
.
get
(
status_url
).
json
()
#print(nni_status)
return
nni_status
[
'status'
]
return
nni_status
[
'status'
]
def
get_succeeded_trial_num
(
trial_jobs_url
):
def
get_succeeded_trial_num
(
trial_jobs_url
):
...
@@ -139,3 +143,17 @@ def deep_update(source, overrides):
...
@@ -139,3 +143,17 @@ def deep_update(source, overrides):
else
:
else
:
source
[
key
]
=
overrides
[
key
]
source
[
key
]
=
overrides
[
key
]
return
source
return
source
def
detect_port
(
port
):
'''Detect if the port is used'''
socket_test
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
try
:
socket_test
.
connect
((
'127.0.0.1'
,
int
(
port
)))
socket_test
.
close
()
return
True
except
:
return
False
def
snooze
():
'''Sleep to make sure previous stopped exp has enough time to exit'''
time
.
sleep
(
6
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment