Commit 66b36b84 authored by Zejun Lin's avatar Zejun Lin Committed by fishyds
Browse files

update loguniform for smac (#430)

* modify loguniform and lognormal

* fix bug

* fix bug

* update doc

* update doc

* fix

* update tpe for loguniform

* update tpe for loguniform

* update for loguniform

* update for loguniform

* update loguniform and qloguniform

* update doc

* update

* revert

* revert

* revert

* revert

* update loguniform for smac

* update loguniform for smac

* update loguniform for smac

* update loguniform for smac
parent c78c523b
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
import os import os
import json import json
import numpy as np
def get_json_content(file_path): def get_json_content(file_path):
'''Load json file content''' '''Load json file content'''
...@@ -63,7 +64,9 @@ def generate_pcs(nni_search_space_content): ...@@ -63,7 +64,9 @@ def generate_pcs(nni_search_space_content):
json.dumps(search_space[key]['_value']), json.dumps(search_space[key]['_value']),
json.dumps(search_space[key]['_value'][0]))) json.dumps(search_space[key]['_value'][0])))
elif search_space[key]['_type'] == 'loguniform': elif search_space[key]['_type'] == 'loguniform':
pcs_fd.write('%s real %s [%s] log\n' % ( # use np.round here to ensure that the rounded defaut value is in the range, which will be rounded in configure_space package
search_space[key]['_value'] = list(np.round(np.log(search_space[key]['_value']), 10))
pcs_fd.write('%s real %s [%s]\n' % (
key, key,
json.dumps(search_space[key]['_value']), json.dumps(search_space[key]['_value']),
json.dumps(search_space[key]['_value'][0]))) json.dumps(search_space[key]['_value'][0])))
......
...@@ -57,6 +57,7 @@ class SMACTuner(Tuner): ...@@ -57,6 +57,7 @@ class SMACTuner(Tuner):
self.smbo_solver = None self.smbo_solver = None
self.first_one = True self.first_one = True
self.update_ss_done = False self.update_ss_done = False
self.loguniform_key = set()
def _main_cli(self): def _main_cli(self):
''' '''
...@@ -130,6 +131,7 @@ class SMACTuner(Tuner): ...@@ -130,6 +131,7 @@ class SMACTuner(Tuner):
generate_scenario(search_space) generate_scenario(search_space)
self.optimizer = self._main_cli() self.optimizer = self._main_cli()
self.smbo_solver = self.optimizer.solver self.smbo_solver = self.optimizer.solver
self.loguniform_key = {key for key in search_space.keys() if search_space[key]['_type'] == 'loguniform'}
self.update_ss_done = True self.update_ss_done = True
else: else:
self.logger.warning('update search space is not supported.') self.logger.warning('update search space is not supported.')
...@@ -150,6 +152,15 @@ class SMACTuner(Tuner): ...@@ -150,6 +152,15 @@ class SMACTuner(Tuner):
else: else:
self.smbo_solver.nni_smac_receive_runs(self.total_data[parameter_id], reward) self.smbo_solver.nni_smac_receive_runs(self.total_data[parameter_id], reward)
def convert_loguniform(self, challenger_dict):
'''
convert the values of type `loguniform` back to their initial range
'''
for key, value in challenger_dict.items():
if key in self.loguniform_key:
challenger_dict[key] = np.exp(challenger_dict[key])
return challenger_dict
def generate_parameters(self, parameter_id): def generate_parameters(self, parameter_id):
''' '''
generate one instance of hyperparameters generate one instance of hyperparameters
...@@ -158,25 +169,27 @@ class SMACTuner(Tuner): ...@@ -158,25 +169,27 @@ class SMACTuner(Tuner):
init_challenger = self.smbo_solver.nni_smac_start() init_challenger = self.smbo_solver.nni_smac_start()
self.total_data[parameter_id] = init_challenger self.total_data[parameter_id] = init_challenger
json_tricks.dumps(init_challenger.get_dictionary()) json_tricks.dumps(init_challenger.get_dictionary())
return init_challenger.get_dictionary() return self.convert_loguniform(init_challenger.get_dictionary())
else: else:
challengers = self.smbo_solver.nni_smac_request_challengers() challengers = self.smbo_solver.nni_smac_request_challengers()
for challenger in challengers: for challenger in challengers:
self.total_data[parameter_id] = challenger self.total_data[parameter_id] = challenger
json_tricks.dumps(challenger.get_dictionary()) json_tricks.dumps(challenger.get_dictionary())
return challenger.get_dictionary() return self.convert_loguniform(challenger.get_dictionary())
def generate_multiple_parameters(self, parameter_id_list): def generate_multiple_parameters(self, parameter_id_list):
''' '''
generate mutiple instances of hyperparameters generate mutiple instances of hyperparameters
''' '''
if self.first_one: if self.first_one:
params = [] params = []
for one_id in parameter_id_list: for one_id in parameter_id_list:
init_challenger = self.smbo_solver.nni_smac_start() init_challenger = self.smbo_solver.nni_smac_start()
self.total_data[one_id] = init_challenger self.total_data[one_id] = init_challenger
json_tricks.dumps(init_challenger.get_dictionary()) json_tricks.dumps(init_challenger.get_dictionary())
params.append(init_challenger.get_dictionary()) params.append(self.convert_loguniform(init_challenger.get_dictionary()))
else: else:
challengers = self.smbo_solver.nni_smac_request_challengers() challengers = self.smbo_solver.nni_smac_request_challengers()
cnt = 0 cnt = 0
...@@ -186,6 +199,6 @@ class SMACTuner(Tuner): ...@@ -186,6 +199,6 @@ class SMACTuner(Tuner):
break break
self.total_data[parameter_id_list[cnt]] = challenger self.total_data[parameter_id_list[cnt]] = challenger
json_tricks.dumps(challenger.get_dictionary()) json_tricks.dumps(challenger.get_dictionary())
params.append(challenger.get_dictionary()) params.append(self.convert_loguniform(challenger.get_dictionary()))
cnt += 1 cnt += 1
return params return params
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment