Unverified Commit be2f3c71 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

fix bug in NaN support (#2077)

parent 914cc1fe
...@@ -557,7 +557,8 @@ class BOHB(MsgDispatcherBase): ...@@ -557,7 +557,8 @@ class BOHB(MsgDispatcherBase):
Data type not supported Data type not supported
""" """
logger.debug('handle report metric data = %s', data) logger.debug('handle report metric data = %s', data)
if 'value' in data:
data['value'] = json_tricks.loads(data['value'])
if data['type'] == MetricType.REQUEST_PARAMETER: if data['type'] == MetricType.REQUEST_PARAMETER:
assert multi_phase_enabled() assert multi_phase_enabled()
assert data['trial_job_id'] is not None assert data['trial_job_id'] is not None
...@@ -627,6 +628,8 @@ class BOHB(MsgDispatcherBase): ...@@ -627,6 +628,8 @@ class BOHB(MsgDispatcherBase):
AssertionError AssertionError
data doesn't have required key 'parameter' and 'value' data doesn't have required key 'parameter' and 'value'
""" """
for entry in data:
entry['value'] = json_tricks.loads(entry['value'])
_completed_num = 0 _completed_num = 0
for trial_info in data: for trial_info in data:
logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data)) logger.info("Importing data, current processing progress %s / %s", _completed_num, len(data))
......
...@@ -380,6 +380,8 @@ class Hyperband(MsgDispatcherBase): ...@@ -380,6 +380,8 @@ class Hyperband(MsgDispatcherBase):
ValueError ValueError
Data type not supported Data type not supported
""" """
if 'value' in data:
data['value'] = json_tricks.loads(data['value'])
if data['type'] == MetricType.REQUEST_PARAMETER: if data['type'] == MetricType.REQUEST_PARAMETER:
assert multi_phase_enabled() assert multi_phase_enabled()
assert data['trial_job_id'] is not None assert data['trial_job_id'] is not None
......
...@@ -113,6 +113,8 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -113,6 +113,8 @@ class MsgDispatcher(MsgDispatcherBase):
"""Import additional data for tuning """Import additional data for tuning
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value' data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
""" """
for entry in data:
entry['value'] = json_tricks.loads(entry['value'])
self.tuner.import_data(data) self.tuner.import_data(data)
def handle_add_customized_trial(self, data): def handle_add_customized_trial(self, data):
...@@ -128,7 +130,8 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -128,7 +130,8 @@ class MsgDispatcher(MsgDispatcherBase):
- 'type': report type, support {'FINAL', 'PERIODICAL'} - 'type': report type, support {'FINAL', 'PERIODICAL'}
""" """
# metrics value is dumped as json string in trial, so we need to decode it here # metrics value is dumped as json string in trial, so we need to decode it here
data['value'] = json_tricks.loads(data['value']) if 'value' in data:
data['value'] = json_tricks.loads(data['value'])
if data['type'] == MetricType.FINAL: if data['type'] == MetricType.FINAL:
self._handle_final_metric_data(data) self._handle_final_metric_data(data)
elif data['type'] == MetricType.PERIODICAL: elif data['type'] == MetricType.PERIODICAL:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment