"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "e75a9f5ac887b89fcc11b82aba2a38d68863acde"
Commit a0d6cd1c authored by suiguoxin's avatar suiguoxin
Browse files

fix optimize mode bug

parent ce4906f3
......@@ -100,7 +100,7 @@ The total search space is 1,204,224, we set the number of maximum trial to 1000.
| HyperBand |0.415550|0.415977|0.417186|
| GP |0.414353|0.418563|0.420263|
| GP |0.414395|0.418006|0.420431|
| GP |0.416807|0.418095|0.419767|
| GP |0.412943|0.416566|0.418443|
For Metis, there are about 300 trials because it runs slowly due to its high time complexity O(n^3) in Gaussian Process.
......
......@@ -29,7 +29,7 @@ from sklearn.gaussian_process.kernels import Matern
from sklearn.gaussian_process import GaussianProcessRegressor
from nni.tuner import Tuner
from nni.utils import extract_scalar_reward
from nni.utils import OptimizeMode, extract_scalar_reward
from .target_space import TargetSpace
from .util import UtilityFunction, acq_max
......@@ -134,7 +134,7 @@ class GPTuner(Tuner):
if value is dict, it should have "default" key.
"""
value = extract_scalar_reward(value)
if self.optimize_mode == 'minimize':
if self.optimize_mode == OptimizeMode.Minimize.value:
value = -value
logger.info("Received trial result.")
......
......@@ -82,14 +82,14 @@ def acq_max(f_acq, gp, y_max, bounds, space, num_warmup, num_starting_points):
# Warm up with random points
x_tries = [space.random_sample()
for _ in range(int(num_warmup)]
for _ in range(int(num_warmup))]
ys = f_acq(x_tries, gp=gp, y_max=y_max)
x_max = x_tries[ys.argmax()]
max_acq = ys.max()
# Explore the parameter space more throughly
x_seeds = [space.random_sample() for _ in range(int(num_starting_points)]
x_seeds = [space.random_sample() for _ in range(int(num_starting_points))]
bounds_minmax = np.array(
[[bound['_value'][0], bound['_value'][-1]] for bound in bounds])
......
......@@ -216,7 +216,7 @@ class MetisTuner(Tuner):
if value is dict, it should have "default" key.
"""
value = extract_scalar_reward(value)
if self.optimize_mode == OptimizeMode.Maximize:
if self.optimize_mode == OptimizeMode.Maximize.value:
value = -value
logger.info("Received trial result.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment