Unverified Commit 29e60c22 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

make assessors support metric data in dict (#2121)

parent 46342a74
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import logging import logging
import datetime import datetime
from nni.assessor import Assessor, AssessResult from nni.assessor import Assessor, AssessResult
from nni.utils import extract_scalar_history
from .model_factory import CurveModel from .model_factory import CurveModel
logger = logging.getLogger('curvefitting_Assessor') logger = logging.getLogger('curvefitting_Assessor')
...@@ -91,10 +92,11 @@ class CurvefittingAssessor(Assessor): ...@@ -91,10 +92,11 @@ class CurvefittingAssessor(Assessor):
Exception Exception
unrecognize exception in curvefitting_assessor unrecognize exception in curvefitting_assessor
""" """
self.trial_history = trial_history scalar_trial_history = extract_scalar_history(trial_history)
self.trial_history = scalar_trial_history
if not self.set_best_performance: if not self.set_best_performance:
return AssessResult.Good return AssessResult.Good
curr_step = len(trial_history) curr_step = len(scalar_trial_history)
if curr_step < self.start_step: if curr_step < self.start_step:
return AssessResult.Good return AssessResult.Good
...@@ -106,7 +108,7 @@ class CurvefittingAssessor(Assessor): ...@@ -106,7 +108,7 @@ class CurvefittingAssessor(Assessor):
start_time = datetime.datetime.now() start_time = datetime.datetime.now()
# Predict the final result # Predict the final result
curvemodel = CurveModel(self.target_pos) curvemodel = CurveModel(self.target_pos)
predict_y = curvemodel.predict(trial_history) predict_y = curvemodel.predict(scalar_trial_history)
logger.info('Prediction done. Trial job id = %s. Predict value = %s', trial_job_id, predict_y) logger.info('Prediction done. Trial job id = %s. Predict value = %s', trial_job_id, predict_y)
if predict_y is None: if predict_y is None:
logger.info('wait for more information to predict precisely') logger.info('wait for more information to predict precisely')
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import logging import logging
from nni.assessor import Assessor, AssessResult from nni.assessor import Assessor, AssessResult
from nni.utils import extract_scalar_history
logger = logging.getLogger('medianstop_Assessor') logger = logging.getLogger('medianstop_Assessor')
...@@ -91,20 +92,12 @@ class MedianstopAssessor(Assessor): ...@@ -91,20 +92,12 @@ class MedianstopAssessor(Assessor):
if curr_step < self._start_step: if curr_step < self._start_step:
return AssessResult.Good return AssessResult.Good
try: scalar_trial_history = extract_scalar_history(trial_history)
num_trial_history = [float(ele) for ele in trial_history] self._update_data(trial_job_id, scalar_trial_history)
except (TypeError, ValueError) as error:
logger.warning('incorrect data type or value:')
logger.exception(error)
except Exception as error:
logger.warning('unrecognized exception in medianstop_assessor:')
logger.exception(error)
self._update_data(trial_job_id, num_trial_history)
if self._high_better: if self._high_better:
best_history = max(trial_history) best_history = max(scalar_trial_history)
else: else:
best_history = min(trial_history) best_history = min(scalar_trial_history)
avg_array = [] avg_array = []
for id_ in self._completed_avg_history: for id_ in self._completed_avg_history:
......
...@@ -234,4 +234,5 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -234,4 +234,5 @@ class MsgDispatcher(MsgDispatcherBase):
if multi_thread_enabled(): if multi_thread_enabled():
self._handle_final_metric_data(data) self._handle_final_metric_data(data)
else: else:
data['value'] = to_json(data['value'])
self.enqueue_command(CommandType.ReportMetricData, data) self.enqueue_command(CommandType.ReportMetricData, data)
...@@ -62,6 +62,13 @@ def extract_scalar_reward(value, scalar_key='default'): ...@@ -62,6 +62,13 @@ def extract_scalar_reward(value, scalar_key='default'):
""" """
Extract scalar reward from trial result. Extract scalar reward from trial result.
Parameters
----------
value : int, float, dict
the reported final metric data
scalar_key : str
the key name that indicates the numeric number
Raises Raises
------ ------
RuntimeError RuntimeError
...@@ -78,6 +85,26 @@ def extract_scalar_reward(value, scalar_key='default'): ...@@ -78,6 +85,26 @@ def extract_scalar_reward(value, scalar_key='default'):
return reward return reward
def extract_scalar_history(trial_history, scalar_key='default'):
"""
Extract scalar value from a list of intermediate results.
Parameters
----------
trial_history : list
accumulated intermediate results of a trial
scalar_key : str
the key name that indicates the numeric number
Raises
------
RuntimeError
Incorrect final result: the final result should be float/int,
or a dict which has a key named "default" whose value is float/int.
"""
return [extract_scalar_reward(ele, scalar_key) for ele in trial_history]
def convert_dict2tuple(value): def convert_dict2tuple(value):
""" """
convert dict type to tuple to solve unhashable problem. convert dict type to tuple to solve unhashable problem.
...@@ -90,7 +117,9 @@ def convert_dict2tuple(value): ...@@ -90,7 +117,9 @@ def convert_dict2tuple(value):
def init_dispatcher_logger(): def init_dispatcher_logger():
""" Initialize dispatcher logging configuration""" """
Initialize dispatcher logging configuration
"""
logger_file_path = 'dispatcher.log' logger_file_path = 'dispatcher.log'
if dispatcher_env_vars.NNI_LOG_DIRECTORY is not None: if dispatcher_env_vars.NNI_LOG_DIRECTORY is not None:
logger_file_path = os.path.join(dispatcher_env_vars.NNI_LOG_DIRECTORY, logger_file_path) logger_file_path = os.path.join(dispatcher_env_vars.NNI_LOG_DIRECTORY, logger_file_path)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment