"docs/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "6d9e4a9b784188ba8db5e610c2f56f8a03f44649"
Commit c49c24c4 authored by QuanluZhang's avatar QuanluZhang Committed by fishyds
Browse files

fix bug in smac search space convert (#940)

* fix bug in smac search space convert

* update docstring
parent 151013aa
...@@ -192,17 +192,20 @@ class SMACTuner(Tuner): ...@@ -192,17 +192,20 @@ class SMACTuner(Tuner):
Returns Returns
------- -------
dict dict
challenger dict dict which stores copy of challengers
""" """
converted_dict = {}
for key, value in challenger_dict.items(): for key, value in challenger_dict.items():
# convert to loguniform # convert to loguniform
if key in self.loguniform_key: if key in self.loguniform_key:
challenger_dict[key] = np.exp(challenger_dict[key]) converted_dict[key] = np.exp(challenger_dict[key])
# convert categorical back to original value # convert categorical back to original value
if key in self.categorical_dict: elif key in self.categorical_dict:
idx = challenger_dict[key] idx = challenger_dict[key]
challenger_dict[key] = self.categorical_dict[key][idx] converted_dict[key] = self.categorical_dict[key][idx]
return challenger_dict else:
converted_dict[key] = value
return converted_dict
def generate_parameters(self, parameter_id): def generate_parameters(self, parameter_id):
"""generate one instance of hyperparameters """generate one instance of hyperparameters
...@@ -220,13 +223,11 @@ class SMACTuner(Tuner): ...@@ -220,13 +223,11 @@ class SMACTuner(Tuner):
if self.first_one: if self.first_one:
init_challenger = self.smbo_solver.nni_smac_start() init_challenger = self.smbo_solver.nni_smac_start()
self.total_data[parameter_id] = init_challenger self.total_data[parameter_id] = init_challenger
json_tricks.dumps(init_challenger.get_dictionary())
return self.convert_loguniform_categorical(init_challenger.get_dictionary()) return self.convert_loguniform_categorical(init_challenger.get_dictionary())
else: else:
challengers = self.smbo_solver.nni_smac_request_challengers() challengers = self.smbo_solver.nni_smac_request_challengers()
for challenger in challengers: for challenger in challengers:
self.total_data[parameter_id] = challenger self.total_data[parameter_id] = challenger
json_tricks.dumps(challenger.get_dictionary())
return self.convert_loguniform_categorical(challenger.get_dictionary()) return self.convert_loguniform_categorical(challenger.get_dictionary())
def generate_multiple_parameters(self, parameter_id_list): def generate_multiple_parameters(self, parameter_id_list):
...@@ -247,7 +248,6 @@ class SMACTuner(Tuner): ...@@ -247,7 +248,6 @@ class SMACTuner(Tuner):
for one_id in parameter_id_list: for one_id in parameter_id_list:
init_challenger = self.smbo_solver.nni_smac_start() init_challenger = self.smbo_solver.nni_smac_start()
self.total_data[one_id] = init_challenger self.total_data[one_id] = init_challenger
json_tricks.dumps(init_challenger.get_dictionary())
params.append(self.convert_loguniform_categorical(init_challenger.get_dictionary())) params.append(self.convert_loguniform_categorical(init_challenger.get_dictionary()))
else: else:
challengers = self.smbo_solver.nni_smac_request_challengers() challengers = self.smbo_solver.nni_smac_request_challengers()
...@@ -257,7 +257,6 @@ class SMACTuner(Tuner): ...@@ -257,7 +257,6 @@ class SMACTuner(Tuner):
if cnt >= len(parameter_id_list): if cnt >= len(parameter_id_list):
break break
self.total_data[parameter_id_list[cnt]] = challenger self.total_data[parameter_id_list[cnt]] = challenger
json_tricks.dumps(challenger.get_dictionary())
params.append(self.convert_loguniform_categorical(challenger.get_dictionary())) params.append(self.convert_loguniform_categorical(challenger.get_dictionary()))
cnt += 1 cnt += 1
return params return params
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment