Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
5d7c1cd8
Unverified
Commit
5d7c1cd8
authored
Sep 28, 2020
by
SparkSnail
Committed by
GitHub
Sep 28, 2020
Browse files
Add nnictl ut (#2912)
parent
9369b719
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
382 additions
and
6 deletions
+382
-6
deployment/pypi/setup.py
deployment/pypi/setup.py
+1
-0
setup.py
setup.py
+1
-0
tools/nni_cmd/config_utils.py
tools/nni_cmd/config_utils.py
+5
-5
tools/nni_cmd/nnictl_utils.py
tools/nni_cmd/nnictl_utils.py
+10
-1
tools/nni_cmd/tests/config_files/test_files/test_json.json
tools/nni_cmd/tests/config_files/test_files/test_json.json
+1
-0
tools/nni_cmd/tests/config_files/test_files/test_yaml.yml
tools/nni_cmd/tests/config_files/test_files/test_yaml.yml
+1
-0
tools/nni_cmd/tests/mock/experiment.py
tools/nni_cmd/tests/mock/experiment.py
+45
-0
tools/nni_cmd/tests/mock/nnictl_metadata/.experiment
tools/nni_cmd/tests/mock/nnictl_metadata/.experiment
+1
-0
tools/nni_cmd/tests/mock/nnictl_metadata/aGew0x/.config
tools/nni_cmd/tests/mock/nnictl_metadata/aGew0x/.config
+1
-0
tools/nni_cmd/tests/mock/nnictl_metadata/aGew0x/stderr
tools/nni_cmd/tests/mock/nnictl_metadata/aGew0x/stderr
+1
-0
tools/nni_cmd/tests/mock/nnictl_metadata/aGew0x/stdout
tools/nni_cmd/tests/mock/nnictl_metadata/aGew0x/stdout
+1
-0
tools/nni_cmd/tests/mock/nnictl_metadata/config/.config
tools/nni_cmd/tests/mock/nnictl_metadata/config/.config
+1
-0
tools/nni_cmd/tests/mock/restful_server.py
tools/nni_cmd/tests/mock/restful_server.py
+186
-0
tools/nni_cmd/tests/test_common_utils.py
tools/nni_cmd/tests/test_common_utils.py
+31
-0
tools/nni_cmd/tests/test_config_utils.py
tools/nni_cmd/tests/test_config_utils.py
+33
-0
tools/nni_cmd/tests/test_nnictl_utils.py
tools/nni_cmd/tests/test_nnictl_utils.py
+62
-0
tools/setup.py
tools/setup.py
+1
-0
No files found.
deployment/pypi/setup.py
View file @
5d7c1cd8
...
...
@@ -54,6 +54,7 @@ setuptools.setup(
'ruamel.yaml'
,
'psutil'
,
'requests'
,
'responses'
,
'astor'
,
'PythonWebHDFS'
,
'hyperopt==0.1.2'
,
...
...
setup.py
View file @
5d7c1cd8
...
...
@@ -37,6 +37,7 @@ setup(
'psutil'
,
'ruamel.yaml'
,
'requests'
,
'responses'
,
'scipy'
,
'schema'
,
'PythonWebHDFS'
,
...
...
tools/nni_cmd/config_utils.py
View file @
5d7c1cd8
...
...
@@ -9,8 +9,8 @@ from .command_utils import print_error
class
Config
:
'''a util class to load and save config'''
def
__init__
(
self
,
file_path
):
config_path
=
os
.
path
.
join
(
NNICTL_HOME_DIR
,
str
(
file_path
))
def
__init__
(
self
,
file_path
,
home_dir
=
NNICTL_HOME_DIR
):
config_path
=
os
.
path
.
join
(
home_dir
,
str
(
file_path
))
os
.
makedirs
(
config_path
,
exist_ok
=
True
)
self
.
config_file
=
os
.
path
.
join
(
config_path
,
'.config'
)
self
.
config
=
self
.
read_file
()
...
...
@@ -51,9 +51,9 @@ class Config:
class
Experiments
:
'''Maintain experiment list'''
def
__init__
(
self
):
os
.
makedirs
(
NNICTL_HOME_DIR
,
exist_ok
=
True
)
self
.
experiment_file
=
os
.
path
.
join
(
NNICTL_HOME_DIR
,
'.experiment'
)
def
__init__
(
self
,
home_dir
=
NNICTL_HOME_DIR
):
os
.
makedirs
(
home_dir
,
exist_ok
=
True
)
self
.
experiment_file
=
os
.
path
.
join
(
home_dir
,
'.experiment'
)
self
.
experiments
=
self
.
read_file
()
def
add_experiment
(
self
,
expId
,
port
,
startTime
,
file_name
,
platform
,
experiment_name
,
endTime
=
'N/A'
,
status
=
'INITIALIZED'
):
...
...
tools/nni_cmd/nnictl_utils.py
View file @
5d7c1cd8
...
...
@@ -213,10 +213,11 @@ def check_rest(args):
nni_config
=
Config
(
get_config_filename
(
args
))
rest_port
=
nni_config
.
get_config
(
'restServerPort'
)
running
,
_
=
check_rest_server_quick
(
rest_port
)
if
not
running
:
if
running
:
print_normal
(
'Restful server is running...'
)
else
:
print_normal
(
'Restful server is not running...'
)
return
running
def
stop_experiment
(
args
):
'''Stop the experiment which is running'''
...
...
@@ -284,10 +285,12 @@ def trial_ls(args):
for
index
,
value
in
enumerate
(
content
):
content
[
index
]
=
convert_time_stamp_to_date
(
value
)
print
(
json
.
dumps
(
content
,
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
return
content
else
:
print_error
(
'List trial failed...'
)
else
:
print_error
(
'Restful server is not running...'
)
return
None
def
trial_kill
(
args
):
'''List trial'''
...
...
@@ -302,10 +305,12 @@ def trial_kill(args):
response
=
rest_delete
(
trial_job_id_url
(
rest_port
,
args
.
trial_id
),
REST_TIME_OUT
)
if
response
and
check_response
(
response
):
print
(
response
.
text
)
return
True
else
:
print_error
(
'Kill trial job failed...'
)
else
:
print_error
(
'Restful server is not running...'
)
return
False
def
trial_codegen
(
args
):
'''Generate code for a specific trial'''
...
...
@@ -332,10 +337,12 @@ def list_experiment(args):
if
response
and
check_response
(
response
):
content
=
convert_time_stamp_to_date
(
json
.
loads
(
response
.
text
))
print
(
json
.
dumps
(
content
,
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
return
content
else
:
print_error
(
'List experiment failed...'
)
else
:
print_error
(
'Restful server is not running...'
)
return
None
def
experiment_status
(
args
):
'''Show the status of experiment'''
...
...
@@ -346,6 +353,7 @@ def experiment_status(args):
print_normal
(
'Restful server is not running...'
)
else
:
print
(
json
.
dumps
(
json
.
loads
(
response
.
text
),
indent
=
4
,
sort_keys
=
True
,
separators
=
(
','
,
':'
)))
return
result
def
log_internal
(
args
,
filetype
):
'''internal function to call get_log_content'''
...
...
@@ -618,6 +626,7 @@ def experiment_list(args):
experiment_dict
[
key
][
'startTime'
],
experiment_dict
[
key
][
'endTime'
])
print
(
EXPERIMENT_INFORMATION_FORMAT
%
experiment_information
)
return
experiment_id_list
def
get_time_interval
(
time1
,
time2
):
'''get the interval of two times'''
...
...
tools/nni_cmd/tests/config_files/test_files/test_json.json
0 → 100644
View file @
5d7c1cd8
{
"field"
:
"test"
}
tools/nni_cmd/tests/config_files/test_files/test_yaml.yml
0 → 100644
View file @
5d7c1cd8
field
:
test
tools/nni_cmd/tests/mock/experiment.py
0 → 100644
View file @
5d7c1cd8
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
argparse
from
subprocess
import
Popen
,
PIPE
,
STDOUT
from
nni_cmd.config_utils
import
Config
,
Experiments
from
nni_cmd.common_utils
import
print_green
from
nni_cmd.command_utils
import
kill_command
from
nni_cmd.nnictl_utils
import
get_yml_content
def
create_mock_experiment
():
nnictl_experiment_config
=
Experiments
()
nnictl_experiment_config
.
add_experiment
(
'xOpEwA5w'
,
'8080'
,
'1970/01/1 01:01:01'
,
'aGew0x'
,
'local'
,
'example_sklearn-classification'
)
nni_config
=
Config
(
'aGew0x'
)
# mock process
cmds
=
[
'sleep'
,
'3600000'
]
process
=
Popen
(
cmds
,
stdout
=
PIPE
,
stderr
=
STDOUT
)
nni_config
.
set_config
(
'restServerPid'
,
process
.
pid
)
nni_config
.
set_config
(
'experimentId'
,
'xOpEwA5w'
)
nni_config
.
set_config
(
'restServerPort'
,
8080
)
nni_config
.
set_config
(
'webuiUrl'
,
[
'http://localhost:8080'
])
experiment_config
=
get_yml_content
(
'./tests/config_files/valid/test.yml'
)
nni_config
.
set_config
(
'experimentConfig'
,
experiment_config
)
print_green
(
"expriment start success, experiment id: xOpEwA5w"
)
def
stop_mock_experiment
():
config
=
Config
(
'config'
)
kill_command
(
config
.
get_config
(
'restServerPid'
))
nnictl_experiment_config
=
Experiments
()
nnictl_experiment_config
.
remove_experiment
(
'xOpEwA5w'
)
def
generate_args_parser
():
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'id'
,
nargs
=
'?'
)
parser
.
add_argument
(
'--port'
,
'-p'
,
dest
=
'port'
)
parser
.
add_argument
(
'--all'
,
'-a'
,
action
=
'store_true'
)
parser
.
add_argument
(
'--head'
,
type
=
int
)
parser
.
add_argument
(
'--tail'
,
type
=
int
)
return
parser
def
generate_args
():
parser
=
generate_args_parser
()
args
=
parser
.
parse_args
([
'xOpEwA5w'
])
return
args
tools/nni_cmd/tests/mock/nnictl_metadata/.experiment
0 → 100644
View file @
5d7c1cd8
{"xOpEwA5w": {"port": 8080, "startTime": "1970/01/1 01:01:01", "endTime": "1970-01-2 01:01:01", "status": "RUNNING", "fileName": "aGew0x", "platform": "local", "experimentName": "example_sklearn-classification"}}
tools/nni_cmd/tests/mock/nnictl_metadata/aGew0x/.config
0 → 100644
View file @
5d7c1cd8
{
"experimentConfig"
: {
"authorName"
:
"default"
,
"experimentName"
:
"example_sklearn-classification"
,
"trialConcurrency"
:
5
,
"maxExecDuration"
:
3600
,
"maxTrialNum"
:
100
,
"trainingServicePlatform"
:
"local"
,
"searchSpacePath"
:
"../../../config_files/valid/search_space.json"
,
"useAnnotation"
:
false
,
"tuner"
: {
"builtinTunerName"
:
"TPE"
,
"classArgs"
: {
"optimize_mode"
:
"maximize"
}},
"trial"
: {
"command"
:
"python3 main.py"
,
"codeDir"
:
"../../../config_files/valid/."
,
"gpuNum"
:
0
}},
"restServerPort"
:
8080
,
"restServerPid"
:
7952
,
"experimentId"
:
"xOpEwA5w"
,
"webuiUrl"
: [
"http://localhost:8080"
]}
tools/nni_cmd/tests/mock/nnictl_metadata/aGew0x/stderr
0 → 100644
View file @
5d7c1cd8
stderr
tools/nni_cmd/tests/mock/nnictl_metadata/aGew0x/stdout
0 → 100644
View file @
5d7c1cd8
stdout
tools/nni_cmd/tests/mock/nnictl_metadata/config/.config
0 → 100644
View file @
5d7c1cd8
{
"experimentId"
:
"xOpEwA5w"
}
tools/nni_cmd/tests/mock/restful_server.py
0 → 100644
View file @
5d7c1cd8
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
responses
def
mock_check_status
():
responses
.
add
(
responses
.
GET
,
"http://localhost:8080/api/v1/nni/check-status"
,
json
=
{
"status"
:
"RUNNING"
,
"errors"
:[]},
status
=
200
)
def
mock_version
():
responses
.
add
(
responses
.
GET
,
"http://localhost:8080/api/v1/nni/version"
,
json
=
{
'value'
:
1.8
},
status
=
200
)
def
mock_get_experiment_profile
():
responses
.
add
(
responses
.
GET
,
"http://localhost:8080/api/v1/nni/experiment"
,
json
=
{
"id"
:
"bkfhOdUl"
,
"revision"
:
5
,
"execDuration"
:
10
,
"logDir"
:
"/home/shinyang/nni-experiments/bkfhOdUl"
,
"nextSequenceId"
:
2
,
"params"
:{
"authorName"
:
"default"
,
"experimentName"
:
"example_sklearn-classification"
,
"trialConcurrency"
:
1
,
"maxExecDuration"
:
3600
,
"maxTrialNum"
:
1
,
"searchSpace"
:
"{
\"
C
\"
: {
\"
_type
\"
:
\"
uniform
\"
,
\"
_value
\"
: [0.1, 1]},
\
\"
kernel
\"
: {
\"
_type
\"
:
\"
choice
\"
,
\"
_value
\"
: [
\"
linear
\"
,
\"
rbf
\"
,
\"
poly
\"
,
\"
sigmoid
\"
]},
\
\"
degree
\"
: {
\"
_type
\"
:
\"
choice
\"
,
\"
_value
\"
: [1, 2, 3, 4]},
\"
gamma
\"
: {
\"
_type
\"
:
\"
uniform
\"
,
\
\"
_value
\"
: [0.01, 0.1]}}"
,
\
"trainingServicePlatform"
:
"local"
,
"tuner"
:{
"builtinTunerName"
:
"TPE"
,
"classArgs"
:{
"optimize_mode"
:
"maximize"
},
\
"checkpointDir"
:
"/home/shinyang/nni-experiments/bkfhOdUl/checkpoint"
},
"versionCheck"
:
"true"
,
\
"clusterMetaData"
:[{
"key"
:
"codeDir"
,
"value"
:
"/home/shinyang/folder/examples/trials/sklearn/classification/."
},
\
{
"key"
:
"command"
,
"value"
:
"python3 main.py"
}]},
"startTime"
:
1600326895536
,
"endTime"
:
1600326910605
},
status
=
200
)
def
mock_update_experiment_profile
():
responses
.
add
(
responses
.
PUT
,
'http://localhost:8080/api/v1/nni/experiment'
,
json
=
{
"status"
:
"RUNNING"
,
"errors"
:[]},
status
=
200
,
content_type
=
'application/json'
,
)
def
mock_import_data
():
responses
.
add
(
responses
.
POST
,
'http://localhost:8080/api/v1/nni/experiment/import-data'
,
json
=
{
"result"
:
"data"
},
status
=
201
,
content_type
=
'application/json'
,
)
def
mock_start_experiment
():
responses
.
add
(
responses
.
POST
,
'http://localhost:8080/api/v1/nni/experiment'
,
json
=
{
"status"
:
"RUNNING"
,
"errors"
:[]},
status
=
201
,
content_type
=
'application/json'
,
)
def
mock_get_trial_job_statistics
():
responses
.
add
(
responses
.
GET
,
'http://localhost:8080/api/v1/nni/job-statistics'
,
json
=
[{
"trialJobStatus"
:
"SUCCEEDED"
,
"trialJobNumber"
:
1
}],
status
=
200
,
content_type
=
'application/json'
,
)
def
mock_set_cluster_metadata
():
responses
.
add
(
responses
.
PUT
,
'http://localhost:8080/api/v1/nni/experiment/cluster-metadata'
,
json
=
[{
"trialJobStatus"
:
"SUCCEEDED"
,
"trialJobNumber"
:
1
}],
status
=
201
,
content_type
=
'application/json'
,
)
def
mock_list_trial_jobs
():
responses
.
add
(
responses
.
GET
,
'http://localhost:8080/api/v1/nni/trial-jobs'
,
json
=
[{
"id"
:
"GPInz"
,
"status"
:
"SUCCEEDED"
,
"hyperParameters"
:[
"{
\"
parameter_id
\"
:0,
\
\"
parameter_source
\"
:
\"
algorithm
\"
,
\"
parameters
\"
:{
\"
C
\"
:0.8748364659110364,
\
\"
kernel
\"
:
\"
linear
\"
,
\"
degree
\"
:1,
\"
gamma
\"
:0.040451413392113666},
\
\"
parameter_index
\"
:0}"
],
"logPath"
:
"file://localhost:/home/shinyang/nni-experiments/bkfhOdUl/trials/GPInz"
,
"startTime"
:
1600326905581
,
"sequenceId"
:
0
,
"endTime"
:
1600326906629
,
"finalMetricData"
:[{
"timestamp"
:
1600326906493
,
"trialJobId"
:
"GPInz"
,
"parameterId"
:
"0"
,
"type"
:
"FINAL"
,
"sequence"
:
0
,
"data"
:
"
\"
0.9866666666666667
\"
"
}]}],
status
=
200
,
content_type
=
'application/json'
,
)
def
mock_get_trial_job
():
responses
.
add
(
responses
.
GET
,
'http://localhost:8080/api/v1/nni/trial-jobs/:id'
,
json
=
{
"id"
:
"GPInz"
,
"status"
:
"SUCCEEDED"
,
"hyperParameters"
:[
"{
\"
parameter_id
\"
:0,
\
\"
parameter_source
\"
:
\"
algorithm
\"
,
\"
parameters
\"
:{
\"
C
\"
:0.8748364659110364,
\
\"
kernel
\"
:
\"
linear
\"
,
\"
degree
\"
:1,
\"
gamma
\"
:0.040451413392113666},
\
\"
parameter_index
\"
:0}"
],
"logPath"
:
"file://localhost:/home/shinyang/nni-experiments/bkfhOdUl/trials/GPInz"
,
"startTime"
:
1600326905581
,
"sequenceId"
:
0
,
"endTime"
:
1600326906629
,
"finalMetricData"
:[{
"timestamp"
:
1600326906493
,
"trialJobId"
:
"GPInz"
,
"parameterId"
:
"0"
,
"type"
:
"FINAL"
,
"sequence"
:
0
,
"data"
:
"
\"
0.9866666666666667
\"
"
}]},
status
=
200
,
content_type
=
'application/json'
,
)
def
mock_add_trial_job
():
responses
.
add
(
responses
.
POST
,
'http://localhost:8080/api/v1/nni/trial-jobs'
,
json
=
[{
"trialJobStatus"
:
"SUCCEEDED"
,
"trialJobNumber"
:
1
}],
status
=
201
,
content_type
=
'application/json'
,
)
def
mock_cancel_trial_job
():
responses
.
add
(
responses
.
DELETE
,
'http://localhost:8080/api/v1/nni/trial-jobs/:id'
,
json
=
[{
"trialJobStatus"
:
"SUCCEEDED"
,
"trialJobNumber"
:
1
}],
status
=
200
,
content_type
=
'application/json'
,
)
def
mock_get_metric_data
():
responses
.
add
(
responses
.
DELETE
,
'http://localhost:8080/api/v1/nni/metric-data/:job_id*?'
,
json
=
[{
"timestamp"
:
1600326906486
,
"trialJobId"
:
"GPInz"
,
"parameterId"
:
"0"
,
"type"
:
"PERIODICAL"
,
"sequence"
:
0
,
"data"
:
"
\"
0.9866666666666667
\"
"
},
{
"timestamp"
:
1600326906493
,
"trialJobId"
:
"GPInz"
,
"parameterId"
:
"0"
,
"type"
:
"FINAL"
,
"sequence"
:
0
,
"data"
:
"
\"
0.9866666666666667
\"
"
}],
status
=
200
,
content_type
=
'application/json'
,
)
def
mock_get_metric_data_by_range
():
responses
.
add
(
responses
.
DELETE
,
'http://localhost:8080/api/v1/nni/metric-data-range/:min_seq_id/:max_seq_id'
,
json
=
[{
"timestamp"
:
1600326906486
,
"trialJobId"
:
"GPInz"
,
"parameterId"
:
"0"
,
"type"
:
"PERIODICAL"
,
"sequence"
:
0
,
"data"
:
"
\"
0.9866666666666667
\"
"
},
{
"timestamp"
:
1600326906493
,
"trialJobId"
:
"GPInz"
,
"parameterId"
:
"0"
,
"type"
:
"FINAL"
,
"sequence"
:
0
,
"data"
:
"
\"
0.9866666666666667
\"
"
}],
status
=
200
,
content_type
=
'application/json'
,
)
def
mock_get_latest_metric_data
():
responses
.
add
(
responses
.
DELETE
,
'http://localhost:8080/api/v1/nni/metric-data-latest/'
,
json
=
[{
"timestamp"
:
1600326906493
,
"trialJobId"
:
"GPInz"
,
"parameterId"
:
"0"
,
"type"
:
"FINAL"
,
"sequence"
:
0
,
"data"
:
"
\"
0.9866666666666667
\"
"
},{
"timestamp"
:
1600326906486
,
"trialJobId"
:
"GPInz"
,
"parameterId"
:
"0"
,
"type"
:
"PERIODICAL"
,
"sequence"
:
0
,
"data"
:
"
\"
0.9866666666666667
\"
"
}],
status
=
200
,
content_type
=
'application/json'
,
)
def
mock_get_trial_log
():
responses
.
add
(
responses
.
DELETE
,
'http://localhost:8080/api/v1/nni/trial-log/:id/:type'
,
json
=
{
"status"
:
"RUNNING"
,
"errors"
:[]},
status
=
200
,
content_type
=
'application/json'
,
)
def
mock_export_data
():
responses
.
add
(
responses
.
DELETE
,
'http://localhost:8080/api/v1/nni/export-data'
,
json
=
{
"status"
:
"RUNNING"
,
"errors"
:[]},
status
=
200
,
content_type
=
'application/json'
,
)
def
init_response
():
mock_check_status
()
mock_version
()
mock_get_experiment_profile
()
mock_set_cluster_metadata
()
mock_list_trial_jobs
()
mock_get_trial_job
()
mock_add_trial_job
()
mock_cancel_trial_job
()
mock_get_metric_data
()
mock_get_metric_data_by_range
()
mock_get_latest_metric_data
()
mock_get_trial_log
()
mock_export_data
()
tools/nni_cmd/tests/test_common_utils.py
0 → 100644
View file @
5d7c1cd8
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
unittest
import
TestCase
,
main
from
nni_cmd.common_utils
import
get_yml_content
,
get_json_content
,
detect_process
from
mock.restful_server
import
init_response
from
subprocess
import
Popen
,
PIPE
,
STDOUT
from
nni_cmd.command_utils
import
kill_command
class
CommonUtilsTestCase
(
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
init_response
()
def
test_get_yml
(
self
):
content
=
get_yml_content
(
'./tests/config_files/test_files/test_yaml.yml'
)
self
.
assertEqual
(
content
,
{
'field'
:
'test'
})
def
test_get_json
(
self
):
content
=
get_json_content
(
'./tests/config_files/test_files/test_json.json'
)
self
.
assertEqual
(
content
,
{
'field'
:
'test'
})
def
test_detect_process
(
self
):
cmds
=
[
'sleep'
,
'360000'
]
process
=
Popen
(
cmds
,
stdout
=
PIPE
,
stderr
=
STDOUT
)
self
.
assertTrue
(
detect_process
(
process
.
pid
))
kill_command
(
process
.
pid
)
if
__name__
==
'__main__'
:
main
()
tools/nni_cmd/tests/test_config_utils.py
0 → 100644
View file @
5d7c1cd8
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
unittest
import
TestCase
,
main
from
nni_cmd.config_utils
import
Config
,
Experiments
HOME_PATH
=
"./tests/mock/nnictl_metadata"
class
CommonUtilsTestCase
(
TestCase
):
def
test_get_experiment
(
self
):
experiment
=
Experiments
(
HOME_PATH
)
self
.
assertTrue
(
'xOpEwA5w'
in
experiment
.
get_all_experiments
())
def
test_update_experiment
(
self
):
experiment
=
Experiments
(
HOME_PATH
)
experiment
.
add_experiment
(
'xOpEwA5w'
,
8081
,
'N/A'
,
'aGew0x'
,
'local'
,
'test'
,
endTime
=
'N/A'
,
status
=
'INITIALIZED'
)
self
.
assertTrue
(
'xOpEwA5w'
in
experiment
.
get_all_experiments
())
experiment
.
remove_experiment
(
'xOpEwA5w'
)
self
.
assertFalse
(
'xOpEwA5w'
in
experiment
.
get_all_experiments
())
def
test_get_config
(
self
):
config
=
Config
(
'config'
,
HOME_PATH
)
self
.
assertEqual
(
config
.
get_config
(
'experimentId'
),
'xOpEwA5w'
)
def
test_set_config
(
self
):
config
=
Config
(
'config'
,
HOME_PATH
)
self
.
assertNotEqual
(
config
.
get_config
(
'experimentId'
),
'testId'
)
config
.
set_config
(
'experimentId'
,
'testId'
)
self
.
assertEqual
(
config
.
get_config
(
'experimentId'
),
'testId'
)
config
.
set_config
(
'experimentId'
,
'xOpEwA5w'
)
if
__name__
==
'__main__'
:
main
()
tools/nni_cmd/tests/test_nnictl_utils.py
0 → 100644
View file @
5d7c1cd8
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
mock.restful_server
import
init_response
from
mock.experiment
import
create_mock_experiment
,
stop_mock_experiment
,
generate_args_parser
,
\
generate_args
from
nni_cmd.nnictl_utils
import
get_experiment_time
,
get_experiment_status
,
\
check_experiment_id
,
parse_ids
,
get_config_filename
,
get_experiment_port
,
check_rest
,
\
trial_ls
,
list_experiment
from
unittest
import
TestCase
,
main
import
responses
class
CommonUtilsTestCase
(
TestCase
):
@
classmethod
def
setUp
(
self
):
init_response
()
create_mock_experiment
()
@
classmethod
def
tearDown
(
self
):
stop_mock_experiment
()
@
responses
.
activate
def
test_get_experiment_status
(
self
):
self
.
assertEqual
(
'RUNNING'
,
get_experiment_status
(
8080
))
@
responses
.
activate
def
test_check_experiment_id
(
self
):
parser
=
generate_args_parser
()
args
=
parser
.
parse_args
([
'xOpEwA5w'
])
self
.
assertEqual
(
'xOpEwA5w'
,
check_experiment_id
(
args
))
@
responses
.
activate
def
test_parse_ids
(
self
):
parser
=
generate_args_parser
()
args
=
parser
.
parse_args
([
'xOpEwA5w'
])
self
.
assertEqual
([
'xOpEwA5w'
],
parse_ids
(
args
))
@
responses
.
activate
def
test_get_config_file_name
(
self
):
args
=
generate_args
()
self
.
assertEqual
(
'aGew0x'
,
get_config_filename
(
args
))
@
responses
.
activate
def
test_get_experiment_port
(
self
):
args
=
generate_args
()
self
.
assertEqual
(
'8080'
,
get_experiment_port
(
args
))
@
responses
.
activate
def
test_check_rest
(
self
):
args
=
generate_args
()
self
.
assertEqual
(
True
,
check_rest
(
args
))
@
responses
.
activate
def
test_trial_ls
(
self
):
args
=
generate_args
()
trials
=
trial_ls
(
args
)
self
.
assertEqual
(
trials
[
0
][
'id'
],
'GPInz'
)
if
__name__
==
'__main__'
:
main
()
tools/setup.py
View file @
5d7c1cd8
...
...
@@ -11,6 +11,7 @@ setuptools.setup(
python_requires
=
'>=3.6'
,
install_requires
=
[
'requests'
,
'responses'
,
'ruamel.yaml'
,
'psutil'
,
'astor'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment