Unverified Commit eaad9852 authored by cruiseliu's avatar cruiseliu Committed by GitHub
Browse files

Fix a bug that new TPE does not support dict metrics (#4531)

parent cb408193
...@@ -22,6 +22,7 @@ from scipy.special import erf # pylint: disable=no-name-in-module ...@@ -22,6 +22,7 @@ from scipy.special import erf # pylint: disable=no-name-in-module
from nni.tuner import Tuner from nni.tuner import Tuner
from nni.common.hpo_utils import OptimizeMode, format_search_space, deformat_parameters, format_parameters from nni.common.hpo_utils import OptimizeMode, format_search_space, deformat_parameters, format_parameters
from nni.utils import extract_scalar_reward
from . import random_tuner from . import random_tuner
_logger = logging.getLogger('nni.tuner.tpe') _logger = logging.getLogger('nni.tuner.tpe')
...@@ -126,9 +127,11 @@ class TpeTuner(Tuner): ...@@ -126,9 +127,11 @@ class TpeTuner(Tuner):
self._running_params[parameter_id] = params self._running_params[parameter_id] = params
return deformat_parameters(params, self.space) return deformat_parameters(params, self.space)
def receive_trial_result(self, parameter_id, _parameters, loss, **kwargs): def receive_trial_result(self, parameter_id, _parameters, value, **kwargs):
if self.optimize_mode is OptimizeMode.Maximize: if self.optimize_mode is OptimizeMode.Minimize:
loss = -loss loss = extract_scalar_reward(value)
else:
loss = -extract_scalar_reward(value)
if self.liar: if self.liar:
self.liar.update(loss) self.liar.update(loss)
params = self._running_params.pop(parameter_id) params = self._running_params.pop(parameter_id)
......
...@@ -58,6 +58,8 @@ class BuiltinTunersTestCase(TestCase): ...@@ -58,6 +58,8 @@ class BuiltinTunersTestCase(TestCase):
return receive return receive
def send_trial_result(self, tuner, parameter_id, parameters, metrics): def send_trial_result(self, tuner, parameter_id, parameters, metrics):
if parameter_id % 2 == 1:
metrics = {'default': metrics, 'extra': 'hello'}
tuner.receive_trial_result(parameter_id, parameters, metrics) tuner.receive_trial_result(parameter_id, parameters, metrics)
tuner.trial_end(parameter_id, True) tuner.trial_end(parameter_id, True)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment