Commit 4a0cc125 (unverified), authored by QuanluZhang, committed by GitHub
Browse files

[Retiarii] fix experiment early exit (#3547)

parent 336d671c
...@@ -13,7 +13,7 @@ _default_listener = None ...@@ -13,7 +13,7 @@ _default_listener = None
__all__ = ['get_execution_engine', 'get_and_register_default_listener', __all__ = ['get_execution_engine', 'get_and_register_default_listener',
'list_models', 'submit_models', 'wait_models', 'query_available_resources', 'list_models', 'submit_models', 'wait_models', 'query_available_resources',
'set_execution_engine', 'is_stopped_exec'] 'set_execution_engine', 'is_stopped_exec', 'budget_exhausted']
def set_execution_engine(engine) -> None: def set_execution_engine(engine) -> None:
global _execution_engine global _execution_engine
...@@ -22,6 +22,7 @@ def set_execution_engine(engine) -> None: ...@@ -22,6 +22,7 @@ def set_execution_engine(engine) -> None:
else: else:
raise RuntimeError('execution engine is already set') raise RuntimeError('execution engine is already set')
def get_execution_engine() -> AbstractExecutionEngine: def get_execution_engine() -> AbstractExecutionEngine:
""" """
Currently we assume the default execution engine is BaseExecutionEngine. Currently we assume the default execution engine is BaseExecutionEngine.
...@@ -67,3 +68,8 @@ def query_available_resources() -> int: ...@@ -67,3 +68,8 @@ def query_available_resources() -> int:
def is_stopped_exec(model: Model) -> bool: def is_stopped_exec(model: Model) -> bool:
return model.status in (ModelStatus.Trained, ModelStatus.Failed) return model.status in (ModelStatus.Trained, ModelStatus.Failed)
def budget_exhausted() -> bool:
engine = get_execution_engine()
return engine.budget_exhausted()
...@@ -104,6 +104,10 @@ class BaseExecutionEngine(AbstractExecutionEngine): ...@@ -104,6 +104,10 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def query_available_resource(self) -> int: def query_available_resource(self) -> int:
return self.resources return self.resources
def budget_exhausted(self) -> bool:
advisor = get_advisor()
return advisor.stopping
@classmethod @classmethod
def trial_execute_graph(cls) -> None: def trial_execute_graph(cls) -> None:
""" """
......
...@@ -130,6 +130,9 @@ class CGOExecutionEngine(AbstractExecutionEngine): ...@@ -130,6 +130,9 @@ class CGOExecutionEngine(AbstractExecutionEngine):
def query_available_resource(self) -> List[WorkerInfo]: def query_available_resource(self) -> List[WorkerInfo]:
raise NotImplementedError # move the method from listener to here? raise NotImplementedError # move the method from listener to here?
def budget_exhausted(self) -> bool:
raise NotImplementedError
@classmethod @classmethod
def trial_execute_graph(cls) -> None: def trial_execute_graph(cls) -> None:
""" """
......
...@@ -123,6 +123,13 @@ class AbstractExecutionEngine(ABC): ...@@ -123,6 +123,13 @@ class AbstractExecutionEngine(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def budget_exhausted(self) -> bool:
"""
Check whether user configured max trial number or max execution duration has been reached
"""
raise NotImplementedError
@abstractmethod @abstractmethod
def register_graph_listener(self, listener: AbstractGraphListener) -> None: def register_graph_listener(self, listener: AbstractGraphListener) -> None:
""" """
......
...@@ -165,7 +165,8 @@ class RetiariiExperiment(Experiment): ...@@ -165,7 +165,8 @@ class RetiariiExperiment(Experiment):
_logger.info('Start strategy...') _logger.info('Start strategy...')
self.strategy.run(base_model_ir, self.applied_mutators) self.strategy.run(base_model_ir, self.applied_mutators)
_logger.info('Strategy exit') _logger.info('Strategy exit')
self._dispatcher.mark_experiment_as_ending() # TODO: find out a proper way to show no more trial message on WebUI
#self._dispatcher.mark_experiment_as_ending()
def start(self, port: int = 8080, debug: bool = False) -> None: def start(self, port: int = 8080, debug: bool = False) -> None:
""" """
...@@ -210,11 +211,12 @@ class RetiariiExperiment(Experiment): ...@@ -210,11 +211,12 @@ class RetiariiExperiment(Experiment):
msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL
_logger.info(msg) _logger.info(msg)
Thread(target=self._check_exp_status).start() exp_status_checker = Thread(target=self._check_exp_status)
exp_status_checker.start()
self._start_strategy() self._start_strategy()
# TODO: the experiment should be completed, when strategy exits and there is no running job # TODO: the experiment should be completed, when strategy exits and there is no running job
# _logger.info('Waiting for submitted trial jobs to finish...')
_logger.info('Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...') _logger.info('Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...')
exp_status_checker.join()
def _create_dispatcher(self): def _create_dispatcher(self):
return self._dispatcher return self._dispatcher
...@@ -240,7 +242,12 @@ class RetiariiExperiment(Experiment): ...@@ -240,7 +242,12 @@ class RetiariiExperiment(Experiment):
try: try:
while True: while True:
time.sleep(10) time.sleep(10)
status = self.get_status() # this if is to deal with the situation that
# nnimanager is cleaned up by ctrl+c first
if self._proc.poll() is None:
status = self.get_status()
else:
return False
if status == 'DONE' or status == 'STOPPED': if status == 'DONE' or status == 'STOPPED':
return True return True
if status == 'ERROR': if status == 'ERROR':
...@@ -261,7 +268,10 @@ class RetiariiExperiment(Experiment): ...@@ -261,7 +268,10 @@ class RetiariiExperiment(Experiment):
nni.runtime.log.stop_experiment_log(self.id) nni.runtime.log.stop_experiment_log(self.id)
if self._proc is not None: if self._proc is not None:
try: try:
rest.delete(self.port, '/experiment') # this if is to deal with the situation that
# nnimanager is cleaned up by ctrl+c first
if self._proc.poll() is None:
rest.delete(self.port, '/experiment')
except Exception as e: except Exception as e:
_logger.exception(e) _logger.exception(e)
_logger.warning('Cannot gracefully stop experiment, killing NNI process...') _logger.warning('Cannot gracefully stop experiment, killing NNI process...')
......
...@@ -6,7 +6,7 @@ import time ...@@ -6,7 +6,7 @@ import time
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
from .. import Sampler, submit_models, query_available_resources, is_stopped_exec from .. import Sampler, submit_models, query_available_resources, is_stopped_exec, budget_exhausted
from .base import BaseStrategy from .base import BaseStrategy
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -54,7 +54,7 @@ class TPEStrategy(BaseStrategy): ...@@ -54,7 +54,7 @@ class TPEStrategy(BaseStrategy):
self.tpe_sampler.update_sample_space(sample_space) self.tpe_sampler.update_sample_space(sample_space)
_logger.info('TPE strategy has been started.') _logger.info('TPE strategy has been started.')
while True: while not budget_exhausted():
avail_resource = query_available_resources() avail_resource = query_available_resources()
if avail_resource > 0: if avail_resource > 0:
model = base_model model = base_model
...@@ -70,13 +70,13 @@ class TPEStrategy(BaseStrategy): ...@@ -70,13 +70,13 @@ class TPEStrategy(BaseStrategy):
else: else:
time.sleep(2) time.sleep(2)
_logger.warning('num of running models: %d', len(self.running_models)) _logger.debug('num of running models: %d', len(self.running_models))
to_be_deleted = [] to_be_deleted = []
for _id, _model in self.running_models.items(): for _id, _model in self.running_models.items():
if is_stopped_exec(_model): if is_stopped_exec(_model):
if _model.metric is not None: if _model.metric is not None:
self.tpe_sampler.receive_result(_id, _model.metric) self.tpe_sampler.receive_result(_id, _model.metric)
_logger.warning('tpe receive results: %d, %s', _id, _model.metric) _logger.debug('tpe receive results: %d, %s', _id, _model.metric)
to_be_deleted.append(_id) to_be_deleted.append(_id)
for _id in to_be_deleted: for _id in to_be_deleted:
del self.running_models[_id] del self.running_models[_id]
...@@ -43,6 +43,9 @@ class MockExecutionEngine(AbstractExecutionEngine): ...@@ -43,6 +43,9 @@ class MockExecutionEngine(AbstractExecutionEngine):
def query_available_resource(self) -> Union[List[WorkerInfo], int]: def query_available_resource(self) -> Union[List[WorkerInfo], int]:
return self._resource_left return self._resource_left
def budget_exhausted(self) -> bool:
pass
def register_graph_listener(self, listener: AbstractGraphListener) -> None: def register_graph_listener(self, listener: AbstractGraphListener) -> None:
pass pass
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment