Unverified Commit aea98dd6 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

[Retiarii] Export topk models (#3464)

parent 0494cae1
......@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import time
from typing import Iterable
from ..graph import Model, ModelStatus
from .interface import AbstractExecutionEngine
......@@ -11,7 +12,7 @@ _execution_engine = None
_default_listener = None
__all__ = ['get_execution_engine', 'get_and_register_default_listener',
'submit_models', 'wait_models', 'query_available_resources',
'list_models', 'submit_models', 'wait_models', 'query_available_resources',
'set_execution_engine', 'is_stopped_exec']
def set_execution_engine(engine) -> None:
......@@ -43,6 +44,12 @@ def submit_models(*models: Model) -> None:
engine.submit_models(*models)
def list_models(*models: Model) -> Iterable[Model]:
    """Return an iterable over all models the default execution engine has seen.

    Fetches the global execution engine, makes sure the default listener is
    registered on it, then delegates to ``engine.list_models()``.

    NOTE(review): the ``*models`` arguments are accepted but never used — the
    call is forwarded to the engine unfiltered; confirm whether filtering by
    the given models was intended.
    """
    engine = get_execution_engine()
    get_and_register_default_listener(engine)
    return engine.list_models()
def wait_models(*models: Model) -> None:
get_and_register_default_listener(get_execution_engine())
while True:
......
......@@ -5,7 +5,7 @@ import logging
import os
import random
import string
from typing import Dict, List
from typing import Dict, Iterable, List
from .interface import AbstractExecutionEngine, AbstractGraphListener
from .. import codegen, utils
......@@ -53,6 +53,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
advisor.final_metric_callback = self._final_metric_callback
self._running_models: Dict[int, Model] = dict()
self._history: List[Model] = []
self.resources = 0
......@@ -60,6 +61,10 @@ class BaseExecutionEngine(AbstractExecutionEngine):
for model in models:
data = BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator)
self._running_models[send_trial(data.dump())] = model
self._history.append(model)
def list_models(self) -> Iterable[Model]:
    """Return all models that have been submitted through this engine.

    ``self._history`` accumulates a reference to every model passed to
    ``submit_models``; the internal list itself (not a copy) is returned,
    so callers should treat it as read-only.
    """
    return self._history
def register_graph_listener(self, listener: AbstractGraphListener) -> None:
self._listeners.append(listener)
......
......@@ -2,7 +2,7 @@
# Licensed under the MIT license.
import logging
from typing import List, Dict, Tuple
from typing import Iterable, List, Dict, Tuple
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
......@@ -58,6 +58,9 @@ class CGOExecutionEngine(AbstractExecutionEngine):
# model.config['trainer_module'], model.config['trainer_kwargs'])
# self._running_models[send_trial(data.dump())] = model
def list_models(self) -> Iterable[Model]:
    """Listing submitted models is not yet implemented for the CGO engine."""
    raise NotImplementedError
def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, PhysicalDevice]]:
# unique_models = set()
# for node in logical_plan.graph.nodes:
......
......@@ -2,7 +2,7 @@
# Licensed under the MIT license.
from abc import ABC, abstractmethod, abstractclassmethod
from typing import Any, NewType, List, Union
from typing import Any, Iterable, NewType, List, Union
from ..graph import Model, MetricData
......@@ -104,6 +104,15 @@ class AbstractExecutionEngine(ABC):
"""
raise NotImplementedError
@abstractmethod
def list_models(self) -> Iterable[Model]:
    """
    Get all models that have been submitted.

    Execution engine should store a copy of models that have been submitted
    and return a list of those copies in this method.
    """
    raise NotImplementedError
@abstractmethod
def query_available_resource(self) -> Union[List[WorkerInfo], int]:
"""
......
......@@ -26,7 +26,9 @@ from nni.experiment.config.base import ConfigBase, PathLike
from nni.experiment.pipe import Pipe
from nni.tools.nnictl.command_utils import kill_command
from ..codegen import model_to_pytorch_script
from ..converter import convert_to_graph
from ..execution import list_models
from ..graph import Model, Evaluator
from ..integration import RetiariiAdvisor
from ..mutator import Mutator
......@@ -257,16 +259,31 @@ class RetiariiExperiment(Experiment):
self._dispatcher_thread = None
_logger.info('Experiment stopped')
def export_top_models(self, top_n: int = 1):
def export_top_models(self, top_k: int = 1, optimize_mode: str = 'maximize', formatter: str = 'code') -> Any:
"""
export several top performing models
Export several top performing models.
For one-shot algorithms, only top-1 is supported. For others, ``optimize_mode`` and ``formatter`` are
available for customization.
top_k : int
How many models are intended to be exported.
optimize_mode : str
``maximize`` or ``minimize``. Not supported by one-shot algorithms.
``optimize_mode`` is likely to be removed and defined in strategy in future.
formatter : str
Only model code is supported for now. Not supported by one-shot algorithms.
"""
if top_n != 1:
_logger.warning('Only support top_n is 1 for now.')
if isinstance(self.trainer, BaseOneShotTrainer):
assert top_k == 1, 'Only support top_k is 1 for now.'
return self.trainer.export()
else:
_logger.info('For this experiment, you can find out the best one from WebUI.')
all_models = filter(lambda m: m.metric is not None, list_models())
assert optimize_mode in ['maximize', 'minimize']
all_models = sorted(all_models, key=lambda m: m.metric, reverse=optimize_mode == 'maximize')
assert formatter == 'code', 'Export formatter other than "code" is not supported yet.'
if formatter == 'code':
return [model_to_pytorch_script(model) for model in all_models[:top_k]]
def retrain_model(self, model):
"""
......
......@@ -49,7 +49,10 @@ if __name__ == '__main__':
exp_config = RetiariiExeConfig('local')
exp_config.experiment_name = 'mnist_search'
exp_config.trial_concurrency = 2
exp_config.max_trial_number = 10
exp_config.max_trial_number = 2
exp_config.training_service.use_active_gpu = False
exp.run(exp_config, 8081 + random.randint(0, 100))
print('Final model:')
for model_code in exp.export_top_models():
print(model_code)
......@@ -37,6 +37,9 @@ class MockExecutionEngine(AbstractExecutionEngine):
self._resource_left -= 1
threading.Thread(target=self._model_complete, args=(model, )).start()
def list_models(self) -> List[Model]:
    """Return every model submitted to this mock engine (stored in ``self.models``)."""
    return self.models
def query_available_resource(self) -> Union[List[WorkerInfo], int]:
    """Return the count of trial resources still available (decremented on each submit)."""
    return self._resource_left
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment