Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
eaad9852
Unverified
Commit
eaad9852
authored
Feb 11, 2022
by
cruiseliu
Committed by
GitHub
Feb 11, 2022
Browse files
Fix a bug that new TPE does not support dict metrics (#4531)
parent
cb408193
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
3 deletions
+8
-3
nni/algorithms/hpo/tpe_tuner.py
nni/algorithms/hpo/tpe_tuner.py
+6
-3
test/ut/sdk/test_builtin_tuners.py
test/ut/sdk/test_builtin_tuners.py
+2
-0
No files found.
nni/algorithms/hpo/tpe_tuner.py
View file @
eaad9852
...
@@ -22,6 +22,7 @@ from scipy.special import erf # pylint: disable=no-name-in-module
...
@@ -22,6 +22,7 @@ from scipy.special import erf # pylint: disable=no-name-in-module
from
nni.tuner
import
Tuner
from
nni.tuner
import
Tuner
from
nni.common.hpo_utils
import
OptimizeMode
,
format_search_space
,
deformat_parameters
,
format_parameters
from
nni.common.hpo_utils
import
OptimizeMode
,
format_search_space
,
deformat_parameters
,
format_parameters
from
nni.utils
import
extract_scalar_reward
from
.
import
random_tuner
from
.
import
random_tuner
_logger
=
logging
.
getLogger
(
'nni.tuner.tpe'
)
_logger
=
logging
.
getLogger
(
'nni.tuner.tpe'
)
...
@@ -126,9 +127,11 @@ class TpeTuner(Tuner):
...
@@ -126,9 +127,11 @@ class TpeTuner(Tuner):
self
.
_running_params
[
parameter_id
]
=
params
self
.
_running_params
[
parameter_id
]
=
params
return
deformat_parameters
(
params
,
self
.
space
)
return
deformat_parameters
(
params
,
self
.
space
)
def
receive_trial_result
(
self
,
parameter_id
,
_parameters
,
loss
,
**
kwargs
):
def
receive_trial_result
(
self
,
parameter_id
,
_parameters
,
value
,
**
kwargs
):
if
self
.
optimize_mode
is
OptimizeMode
.
Maximize
:
if
self
.
optimize_mode
is
OptimizeMode
.
Minimize
:
loss
=
-
loss
loss
=
extract_scalar_reward
(
value
)
else
:
loss
=
-
extract_scalar_reward
(
value
)
if
self
.
liar
:
if
self
.
liar
:
self
.
liar
.
update
(
loss
)
self
.
liar
.
update
(
loss
)
params
=
self
.
_running_params
.
pop
(
parameter_id
)
params
=
self
.
_running_params
.
pop
(
parameter_id
)
...
...
test/ut/sdk/test_builtin_tuners.py
View file @
eaad9852
...
@@ -58,6 +58,8 @@ class BuiltinTunersTestCase(TestCase):
...
@@ -58,6 +58,8 @@ class BuiltinTunersTestCase(TestCase):
return
receive
return
receive
def
send_trial_result
(
self
,
tuner
,
parameter_id
,
parameters
,
metrics
):
def
send_trial_result
(
self
,
tuner
,
parameter_id
,
parameters
,
metrics
):
if
parameter_id
%
2
==
1
:
metrics
=
{
'default'
:
metrics
,
'extra'
:
'hello'
}
tuner
.
receive_trial_result
(
parameter_id
,
parameters
,
metrics
)
tuner
.
receive_trial_result
(
parameter_id
,
parameters
,
metrics
)
tuner
.
trial_end
(
parameter_id
,
True
)
tuner
.
trial_end
(
parameter_id
,
True
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment