"example/vscode:/vscode.git/clone" did not exist on "9a17e7fbfdf57480e39a527d13367bbc9e7a0b04"
Commit 4bc204bf authored by Shufan Huang's avatar Shufan Huang Committed by SparkSnail
Browse files

Fix bug bash of import data feature (#1009)

parent 68c26dd4
...@@ -595,6 +595,9 @@ class BOHB(MsgDispatcherBase): ...@@ -595,6 +595,9 @@ class BOHB(MsgDispatcherBase):
_params = trial_info["parameter"] _params = trial_info["parameter"]
assert "value" in trial_info assert "value" in trial_info
_value = trial_info['value'] _value = trial_info['value']
if not _value:
logger.info("Useless trial data, value is %s, skip this trial data." %_value)
continue
budget_exist_flag = False budget_exist_flag = False
barely_params = dict() barely_params = dict()
for keys in _params: for keys in _params:
......
...@@ -164,6 +164,11 @@ class GridSearchTuner(Tuner): ...@@ -164,6 +164,11 @@ class GridSearchTuner(Tuner):
_completed_num += 1 _completed_num += 1
assert "parameter" in trial_info assert "parameter" in trial_info
_params = trial_info["parameter"] _params = trial_info["parameter"]
assert "value" in trial_info
_value = trial_info['value']
if not _value:
logger.info("Useless trial data, value is %s, skip this trial data." %_value)
continue
_params_tuple = convert_dict2tuple(_params) _params_tuple = convert_dict2tuple(_params)
self.supplement_data[_params_tuple] = True self.supplement_data[_params_tuple] = True
logger.info("Successfully import data to grid search tuner.") logger.info("Successfully import data to grid search tuner.")
...@@ -139,6 +139,27 @@ def json2vals(in_x, vals, out_y, name=ROOT): ...@@ -139,6 +139,27 @@ def json2vals(in_x, vals, out_y, name=ROOT):
for i, temp in enumerate(in_x): for i, temp in enumerate(in_x):
json2vals(temp, vals[i], out_y, name + '[%d]' % i) json2vals(temp, vals[i], out_y, name + '[%d]' % i)
def params2tuner_params(in_x, parameter):
"""
change parameters in NNI format to parameters in hyperopt format.
For example, NNI receive parameters like:
{'dropout_rate': 0.8, 'conv_size': 3, 'hidden_size': 512}
Will change to format in hyperopt, like:
{'dropout_rate': 0.8, 'conv_size': {'_index': 1, '_value': 3}, 'hidden_size': {'_index': 1, '_value': 512}}
"""
tuner_params = dict()
for key in parameter.keys():
value = parameter[key]
_type = in_x[key][TYPE]
if _type == 'choice':
_idx = in_x[key][VALUE].index(value)
tuner_params[key] = {
INDEX: _idx,
VALUE: value
}
else:
tuner_params[key] = value
return tuner_params
def _split_index(params): def _split_index(params):
""" """
...@@ -373,8 +394,11 @@ class HyperoptTuner(Tuner): ...@@ -373,8 +394,11 @@ class HyperoptTuner(Tuner):
_params = trial_info["parameter"] _params = trial_info["parameter"]
assert "value" in trial_info assert "value" in trial_info
_value = trial_info['value'] _value = trial_info['value']
if not _value:
logger.info("Useless trial data, value is %s, skip this trial data." %_value)
continue
self.supplement_data_num += 1 self.supplement_data_num += 1
_parameter_id = '_'.join(["ImportData", str(self.supplement_data_num)]) _parameter_id = '_'.join(["ImportData", str(self.supplement_data_num)])
self.total_data[_parameter_id] = _params self.total_data[_parameter_id] = params2tuner_params(self.json, _params)
self.receive_trial_result(parameter_id=_parameter_id, parameters=_params, value=_value) self.receive_trial_result(parameter_id=_parameter_id, parameters=_params, value=_value)
logger.info("Successfully import data to TPE/Anneal tuner.") logger.info("Successfully import data to TPE/Anneal tuner.")
...@@ -417,6 +417,9 @@ class MetisTuner(Tuner): ...@@ -417,6 +417,9 @@ class MetisTuner(Tuner):
_params = trial_info["parameter"] _params = trial_info["parameter"]
assert "value" in trial_info assert "value" in trial_info
_value = trial_info['value'] _value = trial_info['value']
if not _value:
logger.info("Useless trial data, value is %s, skip this trial data." %_value)
continue
self.supplement_data_num += 1 self.supplement_data_num += 1
_parameter_id = '_'.join(["ImportData", str(self.supplement_data_num)]) _parameter_id = '_'.join(["ImportData", str(self.supplement_data_num)])
self.total_data.append(_params) self.total_data.append(_params)
......
...@@ -136,7 +136,7 @@ def import_data(args): ...@@ -136,7 +136,7 @@ def import_data(args):
args.port = get_experiment_port(args) args.port = get_experiment_port(args)
if args.port is not None: if args.port is not None:
if import_data_to_restful_server(args, content): if import_data_to_restful_server(args, content):
print_normal('Import data success!') pass
else: else:
print_error('Import data failed!') print_error('Import data failed!')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment