Unverified Commit c1e926b9 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

fix generate parameter format (#3175)

parent 0d3f13a3
...@@ -89,4 +89,4 @@ Known Limitations: ...@@ -89,4 +89,4 @@ Known Limitations:
* Note that for nested search space: * Note that for nested search space:
* Only Random Search/TPE/Anneal/Evolution tuner supports nested search space * Only Random Search/TPE/Anneal/Evolution/Grid Search tuner supports nested search space
...@@ -255,4 +255,4 @@ Known Limitations: ...@@ -255,4 +255,4 @@ Known Limitations:
Note that for nested search space: Note that for nested search space:
* Only Random Search/TPE/Anneal/Evolution tuner supports nested search space * Only Random Search/TPE/Anneal/Evolution/Grid Search tuner supports nested search space
...@@ -169,7 +169,7 @@ class GridSearchTuner(Tuner): ...@@ -169,7 +169,7 @@ class GridSearchTuner(Tuner):
""" """
self.count += 1 self.count += 1
while self.count <= len(self.expanded_search_space) - 1: while self.count <= len(self.expanded_search_space) - 1:
_params_tuple = convert_dict2tuple(self.expanded_search_space[self.count]) _params_tuple = convert_dict2tuple(copy.deepcopy(self.expanded_search_space[self.count]))
if _params_tuple in self.supplement_data: if _params_tuple in self.supplement_data:
self.count += 1 self.count += 1
else: else:
...@@ -203,6 +203,6 @@ class GridSearchTuner(Tuner): ...@@ -203,6 +203,6 @@ class GridSearchTuner(Tuner):
if not _value: if not _value:
logger.info("Useless trial data, value is %s, skip this trial data.", _value) logger.info("Useless trial data, value is %s, skip this trial data.", _value)
continue continue
_params_tuple = convert_dict2tuple(_params) _params_tuple = convert_dict2tuple(copy.deepcopy(_params))
self.supplement_data[_params_tuple] = True self.supplement_data[_params_tuple] = True
logger.info("Successfully import data to grid search tuner.") logger.info("Successfully import data to grid search tuner.")
...@@ -109,6 +109,7 @@ def extract_scalar_history(trial_history, scalar_key='default'): ...@@ -109,6 +109,7 @@ def extract_scalar_history(trial_history, scalar_key='default'):
def convert_dict2tuple(value): def convert_dict2tuple(value):
""" """
convert dict type to tuple to solve unhashable problem. convert dict type to tuple to solve unhashable problem.
NOTE: this function will change original data.
""" """
if isinstance(value, dict): if isinstance(value, dict):
for _keys in value: for _keys in value:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment