Unverified Commit aea98dd6 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

[Retiarii] Export topk models (#3464)

parent 0494cae1
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import time import time
from typing import Iterable
from ..graph import Model, ModelStatus from ..graph import Model, ModelStatus
from .interface import AbstractExecutionEngine from .interface import AbstractExecutionEngine
...@@ -11,7 +12,7 @@ _execution_engine = None ...@@ -11,7 +12,7 @@ _execution_engine = None
_default_listener = None _default_listener = None
__all__ = ['get_execution_engine', 'get_and_register_default_listener', __all__ = ['get_execution_engine', 'get_and_register_default_listener',
'submit_models', 'wait_models', 'query_available_resources', 'list_models', 'submit_models', 'wait_models', 'query_available_resources',
'set_execution_engine', 'is_stopped_exec'] 'set_execution_engine', 'is_stopped_exec']
def set_execution_engine(engine) -> None: def set_execution_engine(engine) -> None:
...@@ -43,6 +44,12 @@ def submit_models(*models: Model) -> None: ...@@ -43,6 +44,12 @@ def submit_models(*models: Model) -> None:
engine.submit_models(*models) engine.submit_models(*models)
def list_models(*models: Model) -> Iterable[Model]:
    """Get all models submitted to the execution engine so far.

    Registers the default listener on the engine before querying, so that
    metrics reported by trials get attached to the returned models.

    NOTE(review): ``*models`` is accepted but never used — presumably kept for
    signature symmetry with ``submit_models``; confirm before removing.
    """
    engine = get_execution_engine()
    get_and_register_default_listener(engine)
    return engine.list_models()
def wait_models(*models: Model) -> None: def wait_models(*models: Model) -> None:
get_and_register_default_listener(get_execution_engine()) get_and_register_default_listener(get_execution_engine())
while True: while True:
......
...@@ -5,7 +5,7 @@ import logging ...@@ -5,7 +5,7 @@ import logging
import os import os
import random import random
import string import string
from typing import Dict, List from typing import Dict, Iterable, List
from .interface import AbstractExecutionEngine, AbstractGraphListener from .interface import AbstractExecutionEngine, AbstractGraphListener
from .. import codegen, utils from .. import codegen, utils
...@@ -53,6 +53,7 @@ class BaseExecutionEngine(AbstractExecutionEngine): ...@@ -53,6 +53,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
advisor.final_metric_callback = self._final_metric_callback advisor.final_metric_callback = self._final_metric_callback
self._running_models: Dict[int, Model] = dict() self._running_models: Dict[int, Model] = dict()
self._history: List[Model] = []
self.resources = 0 self.resources = 0
...@@ -60,6 +61,10 @@ class BaseExecutionEngine(AbstractExecutionEngine): ...@@ -60,6 +61,10 @@ class BaseExecutionEngine(AbstractExecutionEngine):
for model in models: for model in models:
data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator) data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator)
self._running_models[send_trial(data.dump())] = model self._running_models[send_trial(data.dump())] = model
self._history.append(model)
def list_models(self) -> Iterable[Model]:
    """Return all models submitted to this engine so far.

    The abstract interface states that engines should return copies of the
    submitted models; returning a shallow copy of the history list (rather
    than the live ``self._history``) prevents callers from mutating the
    engine's internal bookkeeping.
    """
    return list(self._history)
def register_graph_listener(self, listener: AbstractGraphListener) -> None: def register_graph_listener(self, listener: AbstractGraphListener) -> None:
self._listeners.append(listener) self._listeners.append(listener)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import logging import logging
from typing import List, Dict, Tuple from typing import Iterable, List, Dict, Tuple
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils from .. import codegen, utils
...@@ -58,6 +58,9 @@ class CGOExecutionEngine(AbstractExecutionEngine): ...@@ -58,6 +58,9 @@ class CGOExecutionEngine(AbstractExecutionEngine):
# model.config['trainer_module'], model.config['trainer_kwargs']) # model.config['trainer_module'], model.config['trainer_kwargs'])
# self._running_models[send_trial(data.dump())] = model # self._running_models[send_trial(data.dump())] = model
def list_models(self) -> Iterable[Model]:
    # Listing submitted models is not implemented for the CGO engine yet;
    # the base engine keeps a history, but CGO has no equivalent bookkeeping here.
    raise NotImplementedError
def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, PhysicalDevice]]: def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, PhysicalDevice]]:
# unique_models = set() # unique_models = set()
# for node in logical_plan.graph.nodes: # for node in logical_plan.graph.nodes:
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
from abc import ABC, abstractmethod, abstractclassmethod from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, NewType, List, Union from typing import Any, Iterable, NewType, List, Union
from ..graph import Model, MetricData from ..graph import Model, MetricData
...@@ -104,6 +104,15 @@ class AbstractExecutionEngine(ABC): ...@@ -104,6 +104,15 @@ class AbstractExecutionEngine(ABC):
""" """
raise NotImplementedError raise NotImplementedError
@abstractmethod
def list_models(self) -> Iterable[Model]:
    """
    Get all models that have been submitted.

    Execution engines should store a copy of every model that has been
    submitted and return the list of copies from this method.
    """
    raise NotImplementedError
@abstractmethod @abstractmethod
def query_available_resource(self) -> Union[List[WorkerInfo], int]: def query_available_resource(self) -> Union[List[WorkerInfo], int]:
""" """
......
...@@ -26,7 +26,9 @@ from nni.experiment.config.base import ConfigBase, PathLike ...@@ -26,7 +26,9 @@ from nni.experiment.config.base import ConfigBase, PathLike
from nni.experiment.pipe import Pipe from nni.experiment.pipe import Pipe
from nni.tools.nnictl.command_utils import kill_command from nni.tools.nnictl.command_utils import kill_command
from ..codegen import model_to_pytorch_script
from ..converter import convert_to_graph from ..converter import convert_to_graph
from ..execution import list_models
from ..graph import Model, Evaluator from ..graph import Model, Evaluator
from ..integration import RetiariiAdvisor from ..integration import RetiariiAdvisor
from ..mutator import Mutator from ..mutator import Mutator
...@@ -257,16 +259,31 @@ class RetiariiExperiment(Experiment): ...@@ -257,16 +259,31 @@ class RetiariiExperiment(Experiment):
self._dispatcher_thread = None self._dispatcher_thread = None
_logger.info('Experiment stopped') _logger.info('Experiment stopped')
def export_top_models(self, top_k: int = 1, optimize_mode: str = 'maximize', formatter: str = 'code') -> Any:
    """
    Export several top performing models.

    For one-shot algorithms, only top-1 is supported. For others, ``optimize_mode`` and
    ``formatter`` are available for customization.

    Parameters
    ----------
    top_k : int
        How many models are intended to be exported.
    optimize_mode : str
        ``maximize`` or ``minimize``. Not supported by one-shot algorithms.
        ``optimize_mode`` is likely to be removed and defined in strategy in future.
    formatter : str
        Only model code is supported for now. Not supported by one-shot algorithms.
    """
    if isinstance(self.trainer, BaseOneShotTrainer):
        assert top_k == 1, 'Only support top_k is 1 for now.'
        return self.trainer.export()
    # Validate arguments up front so bad inputs fail before any work is done.
    assert optimize_mode in ['maximize', 'minimize']
    assert formatter == 'code', 'Export formatter other than "code" is not supported yet.'
    # Models that never reported a final metric cannot be ranked.
    all_models = [m for m in list_models() if m.metric is not None]
    all_models.sort(key=lambda m: m.metric, reverse=optimize_mode == 'maximize')
    return [model_to_pytorch_script(model) for model in all_models[:top_k]]
def retrain_model(self, model): def retrain_model(self, model):
""" """
......
...@@ -49,7 +49,10 @@ if __name__ == '__main__': ...@@ -49,7 +49,10 @@ if __name__ == '__main__':
exp_config = RetiariiExeConfig('local') exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'mnist_search' exp_config.experiment_name = 'mnist_search'
exp_config.trial_concurrency = 2 exp_config.trial_concurrency = 2
exp_config.max_trial_number = 10 exp_config.max_trial_number = 2
exp_config.training_service.use_active_gpu = False exp_config.training_service.use_active_gpu = False
exp.run(exp_config, 8081 + random.randint(0, 100)) exp.run(exp_config, 8081 + random.randint(0, 100))
print('Final model:')
for model_code in exp.export_top_models():
print(model_code)
...@@ -37,6 +37,9 @@ class MockExecutionEngine(AbstractExecutionEngine): ...@@ -37,6 +37,9 @@ class MockExecutionEngine(AbstractExecutionEngine):
self._resource_left -= 1 self._resource_left -= 1
threading.Thread(target=self._model_complete, args=(model, )).start() threading.Thread(target=self._model_complete, args=(model, )).start()
def list_models(self) -> List[Model]:
    """Return a copy of all models submitted to this mock engine.

    A shallow copy is returned, matching the abstract interface's requirement
    that engines return copies rather than exposing internal state.
    """
    return list(self.models)
def query_available_resource(self) -> Union[List[WorkerInfo], int]: def query_available_resource(self) -> Union[List[WorkerInfo], int]:
return self._resource_left return self._resource_left
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment