Unverified commit bcc640c4 authored by QuanluZhang, committed by GitHub

[nas] fix issue introduced by the trial recovery feature (#5109)

parent 87677df8
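The commit applies one recurring pattern across the BOHB, Hyperband, TPE and Retiarii dispatchers changed below: parameter ids of trials recovered from the previous, interrupted experiment are recorded first, and later trial-end or metric events carrying those ids are either ignored or replayed to the tuner as imported history. The following is only a rough sketch of that pattern; `SomeDispatcher`, the plain `'FINAL'` string, and the inline dict payloads are illustrative stand-ins, not the actual NNI classes touched in the hunks below.

class SomeDispatcher:
    def __init__(self):
        # largest parameter id recovered from the previous experiment (-1 means none)
        self.recovered_max_param_id = -1
        # parameter id -> parameters of the recovered trials
        self.recovered_trial_params = {}

    def is_created_in_previous_exp(self, param_id):
        # ids at or below the recovered maximum belong to the previous experiment
        return param_id <= self.recovered_max_param_id

    def handle_trial_end(self, data):
        if self.is_created_in_previous_exp(data['parameter_id']):
            return  # the end of a recovered trial is ignored
        # ... normal trial-end handling would go here ...

    def handle_report_metric_data(self, data):
        if self.is_created_in_previous_exp(data['parameter_id']):
            if data['type'] == 'FINAL':
                # replay a recovered final metric to the tuner as imported history
                param = self.recovered_trial_params[data['parameter_id']]
                self.handle_import_data([{'parameter': param, 'value': data['value']}])
            return
        # ... normal metric handling would go here ...

    def handle_import_data(self, data):
        pass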
......@@ -648,8 +648,11 @@ class BOHB(MsgDispatcherBase):
event: the job's state
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
logger.debug('Tuner handle trial end, result is %s', data)
hyper_params = nni.load(data['hyper_params'])
if self.is_created_in_previous_exp(hyper_params['parameter_id']):
# The end of the recovered trial is ignored
return
logger.debug('Tuner handle trial end, result is %s', data)
self._handle_trial_end(hyper_params['parameter_id'])
if data['trial_job_id'] in self.job_id_para_id_map:
del self.job_id_para_id_map[data['trial_job_id']]
......@@ -695,6 +698,13 @@ class BOHB(MsgDispatcherBase):
ValueError
Data type not supported
"""
if self.is_created_in_previous_exp(data['parameter_id']):
if data['type'] == MetricType.FINAL:
# only deal with the final metric, by importing it via handle_import_data
param = self.get_previous_param(data['parameter_id'])
trial_data = [{'parameter': param, 'value': nni.load(data['value'])}]
self.handle_import_data(trial_data)
return
logger.debug('handle report metric data = %s', data)
if 'value' in data:
data['value'] = nni.load(data['value'])
......@@ -752,7 +762,10 @@ class BOHB(MsgDispatcherBase):
'Data type not supported: {}'.format(data['type']))
def handle_add_customized_trial(self, data):
pass
global _next_parameter_id
# data: parameters
previous_max_param_id = self.recover_parameter_id(data)
_next_parameter_id = previous_max_param_id + 1
def handle_import_data(self, data):
"""Import additional data for tuning
......
......@@ -522,6 +522,9 @@ class Hyperband(MsgDispatcherBase):
hyper_params: the hyperparameters (a string) generated and returned by tuner
"""
hyper_params = nni.load(data['hyper_params'])
if self.is_created_in_previous_exp(hyper_params['parameter_id']):
# The end of the recovered trial is ignored
return
self._handle_trial_end(hyper_params['parameter_id'])
if data['trial_job_id'] in self.job_id_para_id_map:
del self.job_id_para_id_map[data['trial_job_id']]
......@@ -538,6 +541,9 @@ class Hyperband(MsgDispatcherBase):
ValueError
Data type not supported
"""
if self.is_created_in_previous_exp(data['parameter_id']):
# do not support recovering the algorithm state
return
if 'value' in data:
data['value'] = nni.load(data['value'])
# multiphase? need to check
......@@ -576,7 +582,10 @@ class Hyperband(MsgDispatcherBase):
raise ValueError('Data type not supported: {}'.format(data['type']))
def handle_add_customized_trial(self, data):
pass
global _next_parameter_id
# data: parameters
previous_max_param_id = self.recover_parameter_id(data)
_next_parameter_id = previous_max_param_id + 1
def handle_import_data(self, data):
pass
......@@ -218,19 +218,6 @@ class TpeTuner(Tuner):
self.dedup.add_history(param)
_logger.info(f'Replayed {len(data)} FINISHED trials')
def import_customized_data(self, data): # for dedup customized / resumed
if isinstance(data, str):
data = nni.load(data)
for trial in data:
# {'parameter_id': 0, 'parameter_source': 'resumed', 'parameters': {'batch_size': 128, ...}
if isinstance(trial, str):
trial = nni.load(trial)
param = format_parameters(trial['parameters'], self.space)
self._running_params[trial['parameter_id']] = param
self.dedup.add_history(param)
_logger.info(f'Replayed {len(data)} RUNNING/WAITING trials')
def suggest(args, rng, space, history):
params = {}
for key, spec in space.items():
......
......@@ -4,6 +4,7 @@
__all__ = ['RetiariiAdvisor']
import logging
import time
import os
from typing import Any, Callable, Optional, Dict, List, Tuple
......@@ -60,11 +61,12 @@ class RetiariiAdvisor(MsgDispatcherBase):
self.final_metric_callback: Optional[Callable[[int, MetricData], None]] = None
self.parameters_count = 0
# Sometimes messages arrive before the callbacks get registered,
# or the engine is allowed to be absent during the experiment.
# In these cases we need to store the messages and invoke them later.
self.call_queue: List[Tuple[str, list]] = []
# this is for waiting for the to-be-recovered trials from the NNI manager
self._advisor_initialized = False
def register_callbacks(self, callbacks: Dict[str, Callable[..., None]]):
"""
......@@ -167,6 +169,10 @@ class RetiariiAdvisor(MsgDispatcherBase):
Parameter ID that is assigned to this parameter,
which will be used for identification in future.
"""
while not self._advisor_initialized:
_logger.info('Wait for RetiariiAdvisor to be initialized...')
time.sleep(0.5)
self.parameters_count += 1
if placement_constraint is None:
placement_constraint = {
......@@ -204,6 +210,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
self.send(CommandType.NoMoreTrialJobs, '')
def handle_request_trial_jobs(self, num_trials):
self._advisor_initialized = True
_logger.debug('Request trial jobs: %s', num_trials)
self.invoke_callback('request_trial_jobs', num_trials)
......@@ -212,10 +219,22 @@ class RetiariiAdvisor(MsgDispatcherBase):
self.search_space = data
def handle_trial_end(self, data):
# TODO: we should properly handle the trials in self._customized_parameter_ids instead of ignoring them
id_ = nni.load(data['hyper_params'])['parameter_id']
if self.is_created_in_previous_exp(id_):
_logger.info('The end of the recovered trial %d is ignored', id_)
return
_logger.debug('Trial end: %s', data)
self.invoke_callback('trial_end', nni.load(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED')
self.invoke_callback('trial_end', id_, data['event'] == 'SUCCEEDED')
def handle_report_metric_data(self, data):
# TODO: we should properly handle the trials in self._customized_parameter_ids instead of ignoring them
if self.is_created_in_previous_exp(data['parameter_id']):
_logger.info('The metrics of the recovered trial %d are ignored', data['parameter_id'])
return
# NOTE: this part is not aligned with HPO tuners.
# In HPO tuners, trial_job_id is used for handling intermediate results,
# while parameter_id is used for handling final results.
_logger.debug('Metric reported: %s', data)
if data['type'] == MetricType.REQUEST_PARAMETER:
raise ValueError('Request parameter not supported')
......@@ -239,4 +258,5 @@ class RetiariiAdvisor(MsgDispatcherBase):
pass
def handle_add_customized_trial(self, data):
pass
previous_max_param_id = self.recover_parameter_id(data)
self.parameters_count = previous_max_param_id
......@@ -12,6 +12,7 @@ from typing import NewType, Any
import nni
from nni.common.version import version_check
# NOTE: this is only for passing flake8; we cannot import RetiariiAdvisor
# because it would introduce a circular import
RetiariiAdvisor = NewType('RetiariiAdvisor', Any)
......
......@@ -4,8 +4,12 @@
from __future__ import annotations
import os
import nni
class Recoverable:
def __init__(self):
self.recovered_max_param_id = -1
self.recovered_trial_params = {}
def load_checkpoint(self) -> None:
pass
......@@ -18,3 +22,29 @@ class Recoverable:
if ckp_path is not None and os.path.isdir(ckp_path):
return ckp_path
return None
def recover_parameter_id(self, data) -> int:
# this is for handling resumption of an interrupted experiment; data: parameters of previous trials
if not isinstance(data, list):
data = [data]
previous_max_param_id = 0
for trial in data:
# {'parameter_id': 0, 'parameter_source': 'resumed', 'parameters': {'batch_size': 128, ...}}
if isinstance(trial, str):
trial = nni.load(trial)
if not isinstance(trial['parameter_id'], int):
# for dealing with user customized trials
# skip for now
continue
self.recovered_trial_params[trial['parameter_id']] = trial['parameters']
if previous_max_param_id < trial['parameter_id']:
previous_max_param_id = trial['parameter_id']
self.recovered_max_param_id = previous_max_param_id
return previous_max_param_id
def is_created_in_previous_exp(self, param_id: int) -> bool:
return param_id <= self.recovered_max_param_id
def get_previous_param(self, param_id: int) -> dict:
return self.recovered_trial_params[param_id]
\ No newline at end of file
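A hypothetical usage of the new helpers, assuming the Recoverable class from the hunk above is in scope and using a made-up resumed-trial payload for illustration:

# resumed trials as a dispatcher might receive them (illustrative values only)
resumed_trials = [
    {'parameter_id': 0, 'parameter_source': 'resumed', 'parameters': {'batch_size': 128}},
    {'parameter_id': 3, 'parameter_source': 'resumed', 'parameters': {'batch_size': 256}},
]

recoverable = Recoverable()
max_id = recoverable.recover_parameter_id(resumed_trials)    # returns 3
# any id at or below the recovered maximum is treated as created in the previous
# experiment, so its trial-end and metric events will be skipped or imported
assert recoverable.is_created_in_previous_exp(2)
assert not recoverable.is_created_in_previous_exp(4)
print(recoverable.get_previous_param(0))                     # {'batch_size': 128}
# dispatchers then continue issuing new parameter ids from max_id + 1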
......@@ -120,15 +120,10 @@ class MsgDispatcher(MsgDispatcherBase):
self.tuner.import_data(data)
def handle_add_customized_trial(self, data):
global _next_parameter_id
# data: parameters
if not isinstance(data, list):
data = [data]
for _ in data:
id_ = _create_parameter_id()
_customized_parameter_ids.add(id_)
self.tuner.import_customized_data(data)
previous_max_param_id = self.recover_parameter_id(data)
_next_parameter_id = previous_max_param_id + 1
def handle_report_metric_data(self, data):
"""
......@@ -137,6 +132,13 @@ class MsgDispatcher(MsgDispatcherBase):
- 'value': metric value reported by nni.report_final_result()
- 'type': report type, support {'FINAL', 'PERIODICAL'}
"""
if self.is_created_in_previous_exp(data['parameter_id']):
if data['type'] == MetricType.FINAL:
# only deal with the final metric, by importing it via handle_import_data
param = self.get_previous_param(data['parameter_id'])
trial_data = [{'parameter': param, 'value': load(data['value'])}]
self.handle_import_data(trial_data)
return
# the metric value is dumped as a JSON string in the trial, so we need to decode it here
if 'value' in data:
data['value'] = load(data['value'])
......@@ -166,6 +168,10 @@ class MsgDispatcher(MsgDispatcherBase):
- event: the job's state
- hyper_params: the hyperparameters generated and returned by tuner
"""
id_ = load(data['hyper_params'])['parameter_id']
if self.is_created_in_previous_exp(id_):
# The end of the recovered trial is ignored
return
trial_job_id = data['trial_job_id']
_ended_trials.add(trial_job_id)
if trial_job_id in _trial_history:
......@@ -173,7 +179,7 @@ class MsgDispatcher(MsgDispatcherBase):
if self.assessor is not None:
self.assessor.trial_end(trial_job_id, data['event'] == 'SUCCEEDED')
if self.tuner is not None:
self.tuner.trial_end(load(data['hyper_params'])['parameter_id'], data['event'] == 'SUCCEEDED')
self.tuner.trial_end(id_, data['event'] == 'SUCCEEDED')
def _handle_final_metric_data(self, data):
"""Call tuner to process final results
......
......@@ -30,6 +30,7 @@ class MsgDispatcherBase(Recoverable):
"""
def __init__(self, command_channel_url=None):
super().__init__()
self.stopping = False
if command_channel_url is None:
command_channel_url = dispatcher_env_vars.NNI_TUNER_COMMAND_CHANNEL
......
......@@ -219,14 +219,6 @@ class Tuner(Recoverable):
# data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
pass
def import_customized_data(self, data: list[TrialRecord]) -> None:
"""
Internal API under revising, not recommended for end users.
"""
# Import resume data to avoid duplications
# data: a list of dictionaries, each of which has at least two keys, 'parameter_id' and 'parameters'
pass
def _on_exit(self) -> None:
pass
......
......@@ -319,6 +319,9 @@ class CGOEngineTest(unittest.TestCase):
advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
# this is because RetiariiAdvisor only works after `_advisor_initialized` becomes True.
# Normally it becomes True when `handle_request_trial_jobs` is invoked.
advisor._advisor_initialized = True
remote = RemoteConfig(machine_list=[])
remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1,2,3]))
......
......@@ -27,6 +27,7 @@ class EngineTest(unittest.TestCase):
nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._advisor_initialized = True
advisor._channel = LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
......@@ -44,6 +45,7 @@ class EngineTest(unittest.TestCase):
nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._advisor_initialized = True
advisor._channel = LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
......
......@@ -48,11 +48,11 @@ class AssessorTestCase(TestCase):
def test_assessor(self):
pass
_reverse_io()
send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":"2"}')
send(CommandType.ReportMetricData, '{"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":"2"}')
send(CommandType.ReportMetricData, '{"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":"3"}')
send(CommandType.TrialEnd, '{"trial_job_id":"A","event":"SYS_CANCELED"}')
send(CommandType.TrialEnd, '{"trial_job_id":"B","event":"SUCCEEDED"}')
send(CommandType.ReportMetricData, '{"parameter_id": 0,"trial_job_id":"A","type":"PERIODICAL","sequence":0,"value":"2"}')
send(CommandType.ReportMetricData, '{"parameter_id": 1,"trial_job_id":"B","type":"PERIODICAL","sequence":0,"value":"2"}')
send(CommandType.ReportMetricData, '{"parameter_id": 0,"trial_job_id":"A","type":"PERIODICAL","sequence":1,"value":"3"}')
send(CommandType.TrialEnd, '{"trial_job_id":"A","event":"SYS_CANCELED","hyper_params":"{\\"parameter_id\\": 0}"}')
send(CommandType.TrialEnd, '{"trial_job_id":"B","event":"SUCCEEDED","hyper_params":"{\\"parameter_id\\": 1}"}')
send(CommandType.NewTrialJob, 'null')
_restore_io()
......