Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
4bc204bf
Commit
4bc204bf
authored
Apr 24, 2019
by
Shufan Huang
Committed by
SparkSnail
Apr 24, 2019
Browse files
Fix bug bash of import data feature (#1009)
parent
68c26dd4
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
37 additions
and
2 deletions
+37
-2
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
+3
-0
src/sdk/pynni/nni/gridsearch_tuner/gridsearch_tuner.py
src/sdk/pynni/nni/gridsearch_tuner/gridsearch_tuner.py
+5
-0
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
+25
-1
src/sdk/pynni/nni/metis_tuner/metis_tuner.py
src/sdk/pynni/nni/metis_tuner/metis_tuner.py
+3
-0
tools/nni_cmd/updater.py
tools/nni_cmd/updater.py
+1
-1
No files found.
src/sdk/pynni/nni/bohb_advisor/bohb_advisor.py
View file @
4bc204bf
...
...
@@ -595,6 +595,9 @@ class BOHB(MsgDispatcherBase):
_params
=
trial_info
[
"parameter"
]
assert
"value"
in
trial_info
_value
=
trial_info
[
'value'
]
if
not
_value
:
logger
.
info
(
"Useless trial data, value is %s, skip this trial data."
%
_value
)
continue
budget_exist_flag
=
False
barely_params
=
dict
()
for
keys
in
_params
:
...
...
src/sdk/pynni/nni/gridsearch_tuner/gridsearch_tuner.py
View file @
4bc204bf
...
...
@@ -164,6 +164,11 @@ class GridSearchTuner(Tuner):
_completed_num
+=
1
assert
"parameter"
in
trial_info
_params
=
trial_info
[
"parameter"
]
assert
"value"
in
trial_info
_value
=
trial_info
[
'value'
]
if
not
_value
:
logger
.
info
(
"Useless trial data, value is %s, skip this trial data."
%
_value
)
continue
_params_tuple
=
convert_dict2tuple
(
_params
)
self
.
supplement_data
[
_params_tuple
]
=
True
logger
.
info
(
"Successfully import data to grid search tuner."
)
src/sdk/pynni/nni/hyperopt_tuner/hyperopt_tuner.py
View file @
4bc204bf
...
...
@@ -139,6 +139,27 @@ def json2vals(in_x, vals, out_y, name=ROOT):
for
i
,
temp
in
enumerate
(
in_x
):
json2vals
(
temp
,
vals
[
i
],
out_y
,
name
+
'[%d]'
%
i
)
def
params2tuner_params
(
in_x
,
parameter
):
"""
change parameters in NNI format to parameters in hyperopt format.
For example, NNI receive parameters like:
{'dropout_rate': 0.8, 'conv_size': 3, 'hidden_size': 512}
Will change to format in hyperopt, like:
{'dropout_rate': 0.8, 'conv_size': {'_index': 1, '_value': 3}, 'hidden_size': {'_index': 1, '_value': 512}}
"""
tuner_params
=
dict
()
for
key
in
parameter
.
keys
():
value
=
parameter
[
key
]
_type
=
in_x
[
key
][
TYPE
]
if
_type
==
'choice'
:
_idx
=
in_x
[
key
][
VALUE
].
index
(
value
)
tuner_params
[
key
]
=
{
INDEX
:
_idx
,
VALUE
:
value
}
else
:
tuner_params
[
key
]
=
value
return
tuner_params
def
_split_index
(
params
):
"""
...
...
@@ -373,8 +394,11 @@ class HyperoptTuner(Tuner):
_params
=
trial_info
[
"parameter"
]
assert
"value"
in
trial_info
_value
=
trial_info
[
'value'
]
if
not
_value
:
logger
.
info
(
"Useless trial data, value is %s, skip this trial data."
%
_value
)
continue
self
.
supplement_data_num
+=
1
_parameter_id
=
'_'
.
join
([
"ImportData"
,
str
(
self
.
supplement_data_num
)])
self
.
total_data
[
_parameter_id
]
=
_
params
self
.
total_data
[
_parameter_id
]
=
params
2tuner_params
(
self
.
json
,
_params
)
self
.
receive_trial_result
(
parameter_id
=
_parameter_id
,
parameters
=
_params
,
value
=
_value
)
logger
.
info
(
"Successfully import data to TPE/Anneal tuner."
)
src/sdk/pynni/nni/metis_tuner/metis_tuner.py
View file @
4bc204bf
...
...
@@ -417,6 +417,9 @@ class MetisTuner(Tuner):
_params
=
trial_info
[
"parameter"
]
assert
"value"
in
trial_info
_value
=
trial_info
[
'value'
]
if
not
_value
:
logger
.
info
(
"Useless trial data, value is %s, skip this trial data."
%
_value
)
continue
self
.
supplement_data_num
+=
1
_parameter_id
=
'_'
.
join
([
"ImportData"
,
str
(
self
.
supplement_data_num
)])
self
.
total_data
.
append
(
_params
)
...
...
tools/nni_cmd/updater.py
View file @
4bc204bf
...
...
@@ -136,7 +136,7 @@ def import_data(args):
args
.
port
=
get_experiment_port
(
args
)
if
args
.
port
is
not
None
:
if
import_data_to_restful_server
(
args
,
content
):
p
rint_normal
(
'Import data success!'
)
p
ass
else
:
print_error
(
'Import data failed!'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment