Commit 4deeeae1 authored by Zejun Lin's avatar Zejun Lin Committed by chicm-ms
Browse files

Empower tuner to realize trial ending (#949)

* enable tuner to aware trial ending

* fix docstring
parent c3074a8b
...@@ -150,6 +150,8 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -150,6 +150,8 @@ class MsgDispatcher(MsgDispatcherBase):
_trial_history.pop(trial_job_id) _trial_history.pop(trial_job_id)
if self.assessor is not None: if self.assessor is not None:
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED') self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
if self.tuner is not None:
self.tuner.trial_end(json_tricks.loads(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED')
return True return True
def _handle_final_metric_data(self, data): def _handle_final_metric_data(self, data):
......
...@@ -150,6 +150,8 @@ class MultiPhaseMsgDispatcher(MsgDispatcherBase): ...@@ -150,6 +150,8 @@ class MultiPhaseMsgDispatcher(MsgDispatcherBase):
_trial_history.pop(trial_job_id) _trial_history.pop(trial_job_id)
if self.assessor is not None: if self.assessor is not None:
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED') self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
if self.tuner is not None:
self.tuner.trial_end(json_tricks.loads(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED', trial_job_id)
return True return True
def _handle_intermediate_metric_data(self, data): def _handle_intermediate_metric_data(self, data):
......
...@@ -32,7 +32,7 @@ class MultiPhaseTuner(Recoverable): ...@@ -32,7 +32,7 @@ class MultiPhaseTuner(Recoverable):
def generate_parameters(self, parameter_id, trial_job_id=None): def generate_parameters(self, parameter_id, trial_job_id=None):
"""Returns a set of trial (hyper-)parameters, as a serializable object. """Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'. User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: int parameter_id: identifier of the parameter (int)
""" """
raise NotImplementedError('Tuner: generate_parameters not implemented') raise NotImplementedError('Tuner: generate_parameters not implemented')
...@@ -46,20 +46,30 @@ class MultiPhaseTuner(Recoverable): ...@@ -46,20 +46,30 @@ class MultiPhaseTuner(Recoverable):
def receive_trial_result(self, parameter_id, parameters, value, trial_job_id): def receive_trial_result(self, parameter_id, parameters, value, trial_job_id):
"""Invoked when a trial reports its final result. Must override. """Invoked when a trial reports its final result. Must override.
parameter_id: int parameter_id: identifier of the parameter (int)
parameters: object created by 'generate_parameters()' parameters: object created by 'generate_parameters()'
value: object reported by trial value: object reported by trial
trial_job_id: identifier of the trial (str)
""" """
raise NotImplementedError('Tuner: receive_trial_result not implemented') raise NotImplementedError('Tuner: receive_trial_result not implemented')
def receive_customized_trial_result(self, parameter_id, parameters, value, trial_job_id): def receive_customized_trial_result(self, parameter_id, parameters, value, trial_job_id):
"""Invoked when a trial added by WebUI reports its final result. Do nothing by default. """Invoked when a trial added by WebUI reports its final result. Do nothing by default.
parameter_id: int parameter_id: identifier of the parameter (int)
parameters: object created by user parameters: object created by user
value: object reported by trial value: object reported by trial
trial_job_id: identifier of the trial (str)
""" """
_logger.info('Customized trial job %s ignored by tuner', parameter_id) _logger.info('Customized trial job %s ignored by tuner', parameter_id)
def trial_end(self, parameter_id, success, trial_job_id):
"""Invoked when a trial is completed or terminated. Do nothing by default.
parameter_id: identifier of the parameter (int)
success: True if the trial successfully completed; False if failed or terminated
trial_job_id: identifier of the trial (str)
"""
pass
def update_search_space(self, search_space): def update_search_space(self, search_space):
"""Update the search space of tuner. Must override. """Update the search space of tuner. Must override.
search_space: JSON object search_space: JSON object
......
...@@ -71,6 +71,13 @@ class Tuner(Recoverable): ...@@ -71,6 +71,13 @@ class Tuner(Recoverable):
""" """
_logger.info('Customized trial job %s ignored by tuner', parameter_id) _logger.info('Customized trial job %s ignored by tuner', parameter_id)
def trial_end(self, parameter_id, success):
"""Invoked when a trial is completed or terminated. Do nothing by default.
parameter_id: int
success: True if the trial successfully completed; False if failed or terminated
"""
pass
def update_search_space(self, search_space): def update_search_space(self, search_space):
"""Update the search space of tuner. Must override. """Update the search space of tuner. Must override.
search_space: JSON object search_space: JSON object
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment