Unverified Commit eaad9852 authored by cruiseliu's avatar cruiseliu Committed by GitHub
Browse files

Fix a bug that new TPE does not support dict metrics (#4531)

parent cb408193
......@@ -22,6 +22,7 @@ from scipy.special import erf # pylint: disable=no-name-in-module
from nni.tuner import Tuner
from nni.common.hpo_utils import OptimizeMode, format_search_space, deformat_parameters, format_parameters
from nni.utils import extract_scalar_reward
from . import random_tuner
_logger = logging.getLogger('nni.tuner.tpe')
......@@ -126,9 +127,11 @@ class TpeTuner(Tuner):
self._running_params[parameter_id] = params
return deformat_parameters(params, self.space)
def receive_trial_result(self, parameter_id, _parameters, loss, **kwargs):
if self.optimize_mode is OptimizeMode.Maximize:
loss = -loss
def receive_trial_result(self, parameter_id, _parameters, value, **kwargs):
if self.optimize_mode is OptimizeMode.Minimize:
loss = extract_scalar_reward(value)
else:
loss = -extract_scalar_reward(value)
if self.liar:
self.liar.update(loss)
params = self._running_params.pop(parameter_id)
......
......@@ -58,6 +58,8 @@ class BuiltinTunersTestCase(TestCase):
return receive
def send_trial_result(self, tuner, parameter_id, parameters, metrics):
if parameter_id % 2 == 1:
metrics = {'default': metrics, 'extra': 'hello'}
tuner.receive_trial_result(parameter_id, parameters, metrics)
tuner.trial_end(parameter_id, True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment