"docs/en/git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "0a2f60ba0198f8d567b536313bfba329588f9c3f"
Commit 4deeeae1 authored by Zejun Lin's avatar Zejun Lin Committed by chicm-ms
Browse files

Empower tuner to realize trial ending (#949)

* enable tuner to aware trial ending

* fix docstring
parent c3074a8b
......@@ -150,6 +150,8 @@ class MsgDispatcher(MsgDispatcherBase):
_trial_history.pop(trial_job_id)
if self.assessor is not None:
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
if self.tuner is not None:
self.tuner.trial_end(json_tricks.loads(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED')
return True
def _handle_final_metric_data(self, data):
......
......@@ -150,6 +150,8 @@ class MultiPhaseMsgDispatcher(MsgDispatcherBase):
_trial_history.pop(trial_job_id)
if self.assessor is not None:
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
if self.tuner is not None:
self.tuner.trial_end(json_tricks.loads(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED', trial_job_id)
return True
def _handle_intermediate_metric_data(self, data):
......
......@@ -32,7 +32,7 @@ class MultiPhaseTuner(Recoverable):
def generate_parameters(self, parameter_id, trial_job_id=None):
"""Returns a set of trial (hyper-)parameters, as a serializable object.
User code must override either this function or 'generate_multiple_parameters()'.
parameter_id: int
parameter_id: identifier of the parameter (int)
"""
raise NotImplementedError('Tuner: generate_parameters not implemented')
......@@ -46,20 +46,30 @@ class MultiPhaseTuner(Recoverable):
def receive_trial_result(self, parameter_id, parameters, value, trial_job_id):
"""Invoked when a trial reports its final result. Must override.
parameter_id: int
parameter_id: identifier of the parameter (int)
parameters: object created by 'generate_parameters()'
value: object reported by trial
trial_job_id: identifier of the trial (str)
"""
raise NotImplementedError('Tuner: receive_trial_result not implemented')
def receive_customized_trial_result(self, parameter_id, parameters, value, trial_job_id):
"""Invoked when a trial added by WebUI reports its final result. Do nothing by default.
parameter_id: int
parameter_id: identifier of the parameter (int)
parameters: object created by user
value: object reported by trial
trial_job_id: identifier of the trial (str)
"""
_logger.info('Customized trial job %s ignored by tuner', parameter_id)
def trial_end(self, parameter_id, success, trial_job_id):
"""Invoked when a trial is completed or terminated. Do nothing by default.
parameter_id: identifier of the parameter (int)
success: True if the trial successfully completed; False if failed or terminated
trial_job_id: identifier of the trial (str)
"""
pass
def update_search_space(self, search_space):
"""Update the search space of tuner. Must override.
search_space: JSON object
......
......@@ -71,6 +71,13 @@ class Tuner(Recoverable):
"""
_logger.info('Customized trial job %s ignored by tuner', parameter_id)
def trial_end(self, parameter_id, success):
"""Invoked when a trial is completed or terminated. Do nothing by default.
parameter_id: int
success: True if the trial successfully completed; False if failed or terminated
"""
pass
def update_search_space(self, search_space):
"""Update the search space of tuner. Must override.
search_space: JSON object
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment