Unverified Commit bcc640c4 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

[nas] fix issue introduced by the trial recovery feature (#5109)

parent 87677df8
...@@ -648,8 +648,11 @@ class BOHB(MsgDispatcherBase): ...@@ -648,8 +648,11 @@ class BOHB(MsgDispatcherBase):
event: the job's state event: the job's state
hyper_params: the hyperparameters (a string) generated and returned by tuner hyper_params: the hyperparameters (a string) generated and returned by tuner
""" """
logger.debug('Tuner handle trial end, result is %s', data)
hyper_params = nni.load(data['hyper_params']) hyper_params = nni.load(data['hyper_params'])
if self.is_created_in_previous_exp(hyper_params['parameter_id']):
# The end of the recovered trial is ignored
return
logger.debug('Tuner handle trial end, result is %s', data)
self._handle_trial_end(hyper_params['parameter_id']) self._handle_trial_end(hyper_params['parameter_id'])
if data['trial_job_id'] in self.job_id_para_id_map: if data['trial_job_id'] in self.job_id_para_id_map:
del self.job_id_para_id_map[data['trial_job_id']] del self.job_id_para_id_map[data['trial_job_id']]
...@@ -695,6 +698,13 @@ class BOHB(MsgDispatcherBase): ...@@ -695,6 +698,13 @@ class BOHB(MsgDispatcherBase):
ValueError ValueError
Data type not supported Data type not supported
""" """
if self.is_created_in_previous_exp(data['parameter_id']):
if data['type'] == MetricType.FINAL:
# only deal with final metric using import data
param = self.get_previous_param(data['parameter_id'])
trial_data = [{'parameter': param, 'value': nni.load(data['value'])}]
self.handle_import_data(trial_data)
return
logger.debug('handle report metric data = %s', data) logger.debug('handle report metric data = %s', data)
if 'value' in data: if 'value' in data:
data['value'] = nni.load(data['value']) data['value'] = nni.load(data['value'])
...@@ -752,7 +762,10 @@ class BOHB(MsgDispatcherBase): ...@@ -752,7 +762,10 @@ class BOHB(MsgDispatcherBase):
'Data type not supported: {}'.format(data['type'])) 'Data type not supported: {}'.format(data['type']))
def handle_add_customized_trial(self, data): def handle_add_customized_trial(self, data):
pass global _next_parameter_id
# data: parameters
previous_max_param_id = self.recover_parameter_id(data)
_next_parameter_id = previous_max_param_id + 1
def handle_import_data(self, data): def handle_import_data(self, data):
"""Import additional data for tuning """Import additional data for tuning
......
...@@ -522,6 +522,9 @@ class Hyperband(MsgDispatcherBase): ...@@ -522,6 +522,9 @@ class Hyperband(MsgDispatcherBase):
hyper_params: the hyperparameters (a string) generated and returned by tuner hyper_params: the hyperparameters (a string) generated and returned by tuner
""" """
hyper_params = nni.load(data['hyper_params']) hyper_params = nni.load(data['hyper_params'])
if self.is_created_in_previous_exp(hyper_params['parameter_id']):
# The end of the recovered trial is ignored
return
self._handle_trial_end(hyper_params['parameter_id']) self._handle_trial_end(hyper_params['parameter_id'])
if data['trial_job_id'] in self.job_id_para_id_map: if data['trial_job_id'] in self.job_id_para_id_map:
del self.job_id_para_id_map[data['trial_job_id']] del self.job_id_para_id_map[data['trial_job_id']]
...@@ -538,6 +541,9 @@ class Hyperband(MsgDispatcherBase): ...@@ -538,6 +541,9 @@ class Hyperband(MsgDispatcherBase):
ValueError ValueError
Data type not supported Data type not supported
""" """
if self.is_created_in_previous_exp(data['parameter_id']):
# do not support recovering the algorithm state
return
if 'value' in data: if 'value' in data:
data['value'] = nni.load(data['value']) data['value'] = nni.load(data['value'])
# multiphase? need to check # multiphase? need to check
...@@ -576,7 +582,10 @@ class Hyperband(MsgDispatcherBase): ...@@ -576,7 +582,10 @@ class Hyperband(MsgDispatcherBase):
raise ValueError('Data type not supported: {}'.format(data['type'])) raise ValueError('Data type not supported: {}'.format(data['type']))
def handle_add_customized_trial(self, data): def handle_add_customized_trial(self, data):
pass global _next_parameter_id
# data: parameters
previous_max_param_id = self.recover_parameter_id(data)
_next_parameter_id = previous_max_param_id + 1
def handle_import_data(self, data): def handle_import_data(self, data):
pass pass
...@@ -218,19 +218,6 @@ class TpeTuner(Tuner): ...@@ -218,19 +218,6 @@ class TpeTuner(Tuner):
self.dedup.add_history(param) self.dedup.add_history(param)
_logger.info(f'Replayed {len(data)} FINISHED trials') _logger.info(f'Replayed {len(data)} FINISHED trials')
def import_customized_data(self, data): # for dedup customized / resumed
if isinstance(data, str):
data = nni.load(data)
for trial in data:
# {'parameter_id': 0, 'parameter_source': 'resumed', 'parameters': {'batch_size': 128, ...}
if isinstance(trial, str):
trial = nni.load(trial)
param = format_parameters(trial['parameters'], self.space)
self._running_params[trial['parameter_id']] = param
self.dedup.add_history(param)
_logger.info(f'Replayed {len(data)} RUNING/WAITING trials')
def suggest(args, rng, space, history): def suggest(args, rng, space, history):
params = {} params = {}
for key, spec in space.items(): for key, spec in space.items():
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
__all__ = ['RetiariiAdvisor'] __all__ = ['RetiariiAdvisor']
import logging import logging
import time
import os import os
from typing import Any, Callable, Optional, Dict, List, Tuple from typing import Any, Callable, Optional, Dict, List, Tuple
...@@ -60,11 +61,12 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -60,11 +61,12 @@ class RetiariiAdvisor(MsgDispatcherBase):
self.final_metric_callback: Optional[Callable[[int, MetricData], None]] = None self.final_metric_callback: Optional[Callable[[int, MetricData], None]] = None
self.parameters_count = 0 self.parameters_count = 0
# Sometimes messages arrive first before the callbacks get registered. # Sometimes messages arrive first before the callbacks get registered.
# Or in case that we allow engine to be absent during the experiment. # Or in case that we allow engine to be absent during the experiment.
# Here we need to store the messages and invoke them later. # Here we need to store the messages and invoke them later.
self.call_queue: List[Tuple[str, list]] = [] self.call_queue: List[Tuple[str, list]] = []
# this is for waiting the to-be-recovered trials from nnimanager
self._advisor_initialized = False
def register_callbacks(self, callbacks: Dict[str, Callable[..., None]]): def register_callbacks(self, callbacks: Dict[str, Callable[..., None]]):
""" """
...@@ -167,6 +169,10 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -167,6 +169,10 @@ class RetiariiAdvisor(MsgDispatcherBase):
Parameter ID that is assigned to this parameter, Parameter ID that is assigned to this parameter,
which will be used for identification in future. which will be used for identification in future.
""" """
while not self._advisor_initialized:
_logger.info('Wait for RetiariiAdvisor to be initialized...')
time.sleep(0.5)
self.parameters_count += 1 self.parameters_count += 1
if placement_constraint is None: if placement_constraint is None:
placement_constraint = { placement_constraint = {
...@@ -204,6 +210,7 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -204,6 +210,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
self.send(CommandType.NoMoreTrialJobs, '') self.send(CommandType.NoMoreTrialJobs, '')
def handle_request_trial_jobs(self, num_trials): def handle_request_trial_jobs(self, num_trials):
self._advisor_initialized = True
_logger.debug('Request trial jobs: %s', num_trials) _logger.debug('Request trial jobs: %s', num_trials)
self.invoke_callback('request_trial_jobs', num_trials) self.invoke_callback('request_trial_jobs', num_trials)
...@@ -212,10 +219,22 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -212,10 +219,22 @@ class RetiariiAdvisor(MsgDispatcherBase):
self.search_space = data self.search_space = data
def handle_trial_end(self, data): def handle_trial_end(self, data):
# TODO: we should properly handle the trials in self._customized_parameter_ids instead of ignoring
id_ = nni.load(data['hyper_params'])['parameter_id']
if self.is_created_in_previous_exp(id_):
_logger.info('The end of the recovered trial %d is ignored', id_)
return
_logger.debug('Trial end: %s', data) _logger.debug('Trial end: %s', data)
self.invoke_callback('trial_end', nni.load(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED') self.invoke_callback('trial_end', id_, data['event'] == 'SUCCEEDED')
def handle_report_metric_data(self, data): def handle_report_metric_data(self, data):
# TODO: we should properly handle the trials in self._customized_parameter_ids instead of ignoring
if self.is_created_in_previous_exp(data['parameter_id']):
_logger.info('The metrics of the recovered trial %d are ignored', data['parameter_id'])
return
# NOTE: this part is not aligned with hpo tuners.
# in hpo tuners, trial_job_id is used for intermediate results handling
# parameter_id is for final result handling.
_logger.debug('Metric reported: %s', data) _logger.debug('Metric reported: %s', data)
if data['type'] == MetricType.REQUEST_PARAMETER: if data['type'] == MetricType.REQUEST_PARAMETER:
raise ValueError('Request parameter not supported') raise ValueError('Request parameter not supported')
...@@ -239,4 +258,5 @@ class RetiariiAdvisor(MsgDispatcherBase): ...@@ -239,4 +258,5 @@ class RetiariiAdvisor(MsgDispatcherBase):
pass pass
def handle_add_customized_trial(self, data): def handle_add_customized_trial(self, data):
pass previous_max_param_id = self.recover_parameter_id(data)
self.parameters_count = previous_max_param_id
...@@ -12,6 +12,7 @@ from typing import NewType, Any ...@@ -12,6 +12,7 @@ from typing import NewType, Any
import nni import nni
from nni.common.version import version_check from nni.common.version import version_check
# NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor # NOTE: this is only for passing flake8, we cannot import RetiariiAdvisor
# because it would induce cycled import # because it would induce cycled import
RetiariiAdvisor = NewType('RetiariiAdvisor', Any) RetiariiAdvisor = NewType('RetiariiAdvisor', Any)
......
...@@ -4,8 +4,12 @@ ...@@ -4,8 +4,12 @@
from __future__ import annotations from __future__ import annotations
import os import os
import nni
class Recoverable: class Recoverable:
def __init__(self):
self.recovered_max_param_id = -1
self.recovered_trial_params = {}
def load_checkpoint(self) -> None: def load_checkpoint(self) -> None:
pass pass
...@@ -18,3 +22,29 @@ class Recoverable: ...@@ -18,3 +22,29 @@ class Recoverable:
if ckp_path is not None and os.path.isdir(ckp_path): if ckp_path is not None and os.path.isdir(ckp_path):
return ckp_path return ckp_path
return None return None
def recover_parameter_id(self, data) -> int:
# this is for handling the resuming of the interrupted data: parameters
if not isinstance(data, list):
data = [data]
previous_max_param_id = 0
for trial in data:
# {'parameter_id': 0, 'parameter_source': 'resumed', 'parameters': {'batch_size': 128, ...}
if isinstance(trial, str):
trial = nni.load(trial)
if not isinstance(trial['parameter_id'], int):
# for dealing with user customized trials
# skip for now
continue
self.recovered_trial_params[trial['parameter_id']] = trial['parameters']
if previous_max_param_id < trial['parameter_id']:
previous_max_param_id = trial['parameter_id']
self.recovered_max_param_id = previous_max_param_id
return previous_max_param_id
def is_created_in_previous_exp(self, param_id: int) -> bool:
return param_id <= self.recovered_max_param_id
def get_previous_param(self, param_id: int) -> dict:
return self.recovered_trial_params[param_id]
\ No newline at end of file
...@@ -120,15 +120,10 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -120,15 +120,10 @@ class MsgDispatcher(MsgDispatcherBase):
self.tuner.import_data(data) self.tuner.import_data(data)
def handle_add_customized_trial(self, data): def handle_add_customized_trial(self, data):
global _next_parameter_id
# data: parameters # data: parameters
if not isinstance(data, list): previous_max_param_id = self.recover_parameter_id(data)
data = [data] _next_parameter_id = previous_max_param_id + 1
for _ in data:
id_ = _create_parameter_id()
_customized_parameter_ids.add(id_)
self.tuner.import_customized_data(data)
def handle_report_metric_data(self, data): def handle_report_metric_data(self, data):
""" """
...@@ -137,6 +132,13 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -137,6 +132,13 @@ class MsgDispatcher(MsgDispatcherBase):
- 'value': metric value reported by nni.report_final_result() - 'value': metric value reported by nni.report_final_result()
- 'type': report type, support {'FINAL', 'PERIODICAL'} - 'type': report type, support {'FINAL', 'PERIODICAL'}
""" """
if self.is_created_in_previous_exp(data['parameter_id']):
if data['type'] == MetricType.FINAL:
# only deal with final metric using import data
param = self.get_previous_param(data['parameter_id'])
trial_data = [{'parameter': param, 'value': load(data['value'])}]
self.handle_import_data(trial_data)
return
# metrics value is dumped as json string in trial, so we need to decode it here # metrics value is dumped as json string in trial, so we need to decode it here
if 'value' in data: if 'value' in data:
data['value'] = load(data['value']) data['value'] = load(data['value'])
...@@ -166,6 +168,10 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -166,6 +168,10 @@ class MsgDispatcher(MsgDispatcherBase):
- event: the job's state - event: the job's state
- hyper_params: the hyperparameters generated and returned by tuner - hyper_params: the hyperparameters generated and returned by tuner
""" """
id_ = load(data['hyper_params'])['parameter_id']
if self.is_created_in_previous_exp(id_):
# The end of the recovered trial is ignored
return
trial_job_id = data['trial_job_id'] trial_job_id = data['trial_job_id']
_ended_trials.add(trial_job_id) _ended_trials.add(trial_job_id)
if trial_job_id in _trial_history: if trial_job_id in _trial_history:
...@@ -173,7 +179,7 @@ class MsgDispatcher(MsgDispatcherBase): ...@@ -173,7 +179,7 @@ class MsgDispatcher(MsgDispatcherBase):
if self.assessor is not None: if self.assessor is not None:
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED') self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
if self.tuner is not None: if self.tuner is not None:
self.tuner.trial_end(load(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED') self.tuner.trial_end(id_, data['event'] == 'SUCCEEDED')
def _handle_final_metric_data(self, data): def _handle_final_metric_data(self, data):
"""Call tuner to process final results """Call tuner to process final results
......
...@@ -30,6 +30,7 @@ class MsgDispatcherBase(Recoverable): ...@@ -30,6 +30,7 @@ class MsgDispatcherBase(Recoverable):
""" """
def __init__(self, command_channel_url=None): def __init__(self, command_channel_url=None):
super().__init__()
self.stopping = False self.stopping = False
if command_channel_url is None: if command_channel_url is None:
command_channel_url = dispatcher_env_vars.NNI_TUNER_COMMAND_CHANNEL command_channel_url = dispatcher_env_vars.NNI_TUNER_COMMAND_CHANNEL
......
...@@ -219,14 +219,6 @@ class Tuner(Recoverable): ...@@ -219,14 +219,6 @@ class Tuner(Recoverable):
# data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value' # data: a list of dictionarys, each of which has at least two keys, 'parameter' and 'value'
pass pass
def import_customized_data(self, data: list[TrialRecord]) -> None:
"""
Internal API under revising, not recommended for end users.
"""
# Import resume data for avoiding duplications
# data: a list of dictionarys, each of which has at least two keys, 'parameter_id' and 'parameters'
pass
def _on_exit(self) -> None: def _on_exit(self) -> None:
pass pass
......
...@@ -319,6 +319,9 @@ class CGOEngineTest(unittest.TestCase): ...@@ -319,6 +319,9 @@ class CGOEngineTest(unittest.TestCase):
advisor._channel = protocol.LegacyCommandChannel() advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start() advisor.default_worker.start()
advisor.assessor_worker.start() advisor.assessor_worker.start()
# this is because RetiariiAdvisor only works after `_advisor_initialized` becomes True.
# normally it becomes true when `handle_request_trial_jobs` is invoked
advisor._advisor_initialized = True
remote = RemoteConfig(machine_list=[]) remote = RemoteConfig(machine_list=[])
remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1,2,3])) remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1,2,3]))
......
...@@ -27,6 +27,7 @@ class EngineTest(unittest.TestCase): ...@@ -27,6 +27,7 @@ class EngineTest(unittest.TestCase):
nni.retiarii.integration_api._advisor = None nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor('ws://_unittest_placeholder_') advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._advisor_initialized = True
advisor._channel = LegacyCommandChannel() advisor._channel = LegacyCommandChannel()
advisor.default_worker.start() advisor.default_worker.start()
advisor.assessor_worker.start() advisor.assessor_worker.start()
...@@ -44,6 +45,7 @@ class EngineTest(unittest.TestCase): ...@@ -44,6 +45,7 @@ class EngineTest(unittest.TestCase):
nni.retiarii.integration_api._advisor = None nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor('ws://_unittest_placeholder_') advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._advisor_initialized = True
advisor._channel = LegacyCommandChannel() advisor._channel = LegacyCommandChannel()
advisor.default_worker.start() advisor.default_worker.start()
advisor.assessor_worker.start() advisor.assessor_worker.start()
......
...@@ -48,11 +48,11 @@ class AssessorTestCase(TestCase): ...@@ -48,11 +48,11 @@ class AssessorTestCase(TestCase):
def test_assessor(self): def test_assessor(self):
pass pass
_reverse_io() _reverse_io()
send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":"2"}') send(CommandType.ReportMetricData, '{"parameter_id": 0,"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":"2"}')
send(CommandType.ReportMetricData, '{"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":"2"}') send(CommandType.ReportMetricData, '{"parameter_id": 1,"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":"2"}')
send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":"3"}') send(CommandType.ReportMetricData, '{"parameter_id": 0,"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":"3"}')
send(CommandType.TrialEnd, '{"trial_job_id":"A","event":"SYS_CANCELED"}') send(CommandType.TrialEnd, '{"trial_job_id":"A","event":"SYS_CANCELED","hyper_params":"{\\"parameter_id\\": 0}"}')
send(CommandType.TrialEnd, '{"trial_job_id":"B","event":"SUCCEEDED"}') send(CommandType.TrialEnd, '{"trial_job_id":"B","event":"SUCCEEDED","hyper_params":"{\\"parameter_id\\": 1}"}')
send(CommandType.NewTrialJob, 'null') send(CommandType.NewTrialJob, 'null')
_restore_io() _restore_io()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment