Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
29e60c22
Unverified
Commit
29e60c22
authored
Mar 06, 2020
by
QuanluZhang
Committed by
GitHub
Mar 06, 2020
Browse files
make assessors support metric data in dict (#2121)
parent
46342a74
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
41 additions
and
16 deletions
+41
-16
src/sdk/pynni/nni/curvefitting_assessor/curvefitting_assessor.py
.../pynni/nni/curvefitting_assessor/curvefitting_assessor.py
+5
-3
src/sdk/pynni/nni/medianstop_assessor/medianstop_assessor.py
src/sdk/pynni/nni/medianstop_assessor/medianstop_assessor.py
+5
-12
src/sdk/pynni/nni/msg_dispatcher.py
src/sdk/pynni/nni/msg_dispatcher.py
+1
-0
src/sdk/pynni/nni/utils.py
src/sdk/pynni/nni/utils.py
+30
-1
No files found.
src/sdk/pynni/nni/curvefitting_assessor/curvefitting_assessor.py
View file @
29e60c22
...
...
@@ -4,6 +4,7 @@
import
logging
import
datetime
from
nni.assessor
import
Assessor
,
AssessResult
from
nni.utils
import
extract_scalar_history
from
.model_factory
import
CurveModel
logger
=
logging
.
getLogger
(
'curvefitting_Assessor'
)
...
...
@@ -91,10 +92,11 @@ class CurvefittingAssessor(Assessor):
Exception
unrecognize exception in curvefitting_assessor
"""
self
.
trial_history
=
trial_history
scalar_trial_history
=
extract_scalar_history
(
trial_history
)
self
.
trial_history
=
scalar_trial_history
if
not
self
.
set_best_performance
:
return
AssessResult
.
Good
curr_step
=
len
(
trial_history
)
curr_step
=
len
(
scalar_
trial_history
)
if
curr_step
<
self
.
start_step
:
return
AssessResult
.
Good
...
...
@@ -106,7 +108,7 @@ class CurvefittingAssessor(Assessor):
start_time
=
datetime
.
datetime
.
now
()
# Predict the final result
curvemodel
=
CurveModel
(
self
.
target_pos
)
predict_y
=
curvemodel
.
predict
(
trial_history
)
predict_y
=
curvemodel
.
predict
(
scalar_
trial_history
)
logger
.
info
(
'Prediction done. Trial job id = %s. Predict value = %s'
,
trial_job_id
,
predict_y
)
if
predict_y
is
None
:
logger
.
info
(
'wait for more information to predict precisely'
)
...
...
src/sdk/pynni/nni/medianstop_assessor/medianstop_assessor.py
View file @
29e60c22
...
...
@@ -3,6 +3,7 @@
import
logging
from
nni.assessor
import
Assessor
,
AssessResult
from
nni.utils
import
extract_scalar_history
logger
=
logging
.
getLogger
(
'medianstop_Assessor'
)
...
...
@@ -91,20 +92,12 @@ class MedianstopAssessor(Assessor):
if
curr_step
<
self
.
_start_step
:
return
AssessResult
.
Good
try
:
num_trial_history
=
[
float
(
ele
)
for
ele
in
trial_history
]
except
(
TypeError
,
ValueError
)
as
error
:
logger
.
warning
(
'incorrect data type or value:'
)
logger
.
exception
(
error
)
except
Exception
as
error
:
logger
.
warning
(
'unrecognized exception in medianstop_assessor:'
)
logger
.
exception
(
error
)
self
.
_update_data
(
trial_job_id
,
num_trial_history
)
scalar_trial_history
=
extract_scalar_history
(
trial_history
)
self
.
_update_data
(
trial_job_id
,
scalar_trial_history
)
if
self
.
_high_better
:
best_history
=
max
(
trial_history
)
best_history
=
max
(
scalar_
trial_history
)
else
:
best_history
=
min
(
trial_history
)
best_history
=
min
(
scalar_
trial_history
)
avg_array
=
[]
for
id_
in
self
.
_completed_avg_history
:
...
...
src/sdk/pynni/nni/msg_dispatcher.py
View file @
29e60c22
...
...
@@ -234,4 +234,5 @@ class MsgDispatcher(MsgDispatcherBase):
if
multi_thread_enabled
():
self
.
_handle_final_metric_data
(
data
)
else
:
data
[
'value'
]
=
to_json
(
data
[
'value'
])
self
.
enqueue_command
(
CommandType
.
ReportMetricData
,
data
)
src/sdk/pynni/nni/utils.py
View file @
29e60c22
...
...
@@ -62,6 +62,13 @@ def extract_scalar_reward(value, scalar_key='default'):
"""
Extract scalar reward from trial result.
Parameters
----------
value : int, float, dict
the reported final metric data
scalar_key : str
the key name that indicates the numeric number
Raises
------
RuntimeError
...
...
@@ -78,6 +85,26 @@ def extract_scalar_reward(value, scalar_key='default'):
return
reward
def
extract_scalar_history
(
trial_history
,
scalar_key
=
'default'
):
"""
Extract scalar value from a list of intermediate results.
Parameters
----------
trial_history : list
accumulated intermediate results of a trial
scalar_key : str
the key name that indicates the numeric number
Raises
------
RuntimeError
Incorrect final result: the final result should be float/int,
or a dict which has a key named "default" whose value is float/int.
"""
return
[
extract_scalar_reward
(
ele
,
scalar_key
)
for
ele
in
trial_history
]
def
convert_dict2tuple
(
value
):
"""
convert dict type to tuple to solve unhashable problem.
...
...
@@ -90,7 +117,9 @@ def convert_dict2tuple(value):
def
init_dispatcher_logger
():
""" Initialize dispatcher logging configuration"""
"""
Initialize dispatcher logging configuration
"""
logger_file_path
=
'dispatcher.log'
if
dispatcher_env_vars
.
NNI_LOG_DIRECTORY
is
not
None
:
logger_file_path
=
os
.
path
.
join
(
dispatcher_env_vars
.
NNI_LOG_DIRECTORY
,
logger_file_path
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment