Unverified Commit 2fc47247 authored by QuanluZhang, committed by GitHub

[retiarii] refactor of nas experiment (#4841)

parent c80bda29
......@@ -54,6 +54,11 @@ class ConfigBase:
Config objects remember where they were loaded from; therefore relative paths can be resolved intelligently.
If a config object is created with the constructor, the base path is the current working directory.
If it is loaded with ``ConfigBase.load(path)``, the base path is ``path``'s parent.
.. attention::
Classes that inherit ``ConfigBase`` are not allowed to use ``from __future__ import annotations``,
because ``ConfigBase`` uses ``typeguard`` to perform runtime checks, and typeguard does not support lazy annotations.
"""
def __init__(self, **kwargs):
......
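To make the attention note concrete, here is a minimal sketch of a conforming subclass. ``MyConfig`` and its fields are hypothetical, not part of this change; the forbidden import is shown commented out:

    # from __future__ import annotations   # forbidden for ConfigBase subclasses:
    #                                      # typeguard cannot check lazy (string) annotations
    from dataclasses import dataclass
    from typing import Optional
    from nni.experiment.config.base import ConfigBase

    @dataclass(init=False)
    class MyConfig(ConfigBase):        # hypothetical subclass for illustration
        name: str = 'my-config'        # concrete annotation, checked by typeguard at runtime
        timeout: Optional[int] = None  # use typing.Optional rather than `int | None` syntax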
......@@ -164,6 +164,7 @@ class ExperimentConfig(ConfigBase):
# currently I have only seen one issue of this kind
#Path(self.experiment_working_directory).mkdir(parents=True, exist_ok=True)
if type(self).__name__ != 'RetiariiExeConfig':
utils.validate_gpu_indices(self.tuner_gpu_indices)
if self.tuner is None:
......
......@@ -84,20 +84,9 @@ class Experiment:
else:
self.config = config_or_platform
def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None:
"""
Start the experiment in background.
This method raises an exception on failure.
If it returns, the experiment has started successfully.
Parameters
----------
port
The port of web UI.
debug
Whether to start in debug mode.
"""
def _start_impl(self, port: int, debug: bool, run_mode: RunMode,
tuner_command_channel: str | None,
tags: list[str] = []) -> ExperimentConfig:
assert self.config is not None
if run_mode is not RunMode.Detach:
atexit.register(self.stop)
......@@ -111,7 +100,8 @@ class Experiment:
log_level = 'debug' if (debug or config.log_level == 'trace') else config.log_level
start_experiment_logging(self.id, log_file, cast(str, log_level))
self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode, self.url_prefix)
self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode,
self.url_prefix, tuner_command_channel, tags)
assert self._proc is not None
self.port = port # port will be None if start up failed
......@@ -124,12 +114,27 @@ class Experiment:
ips = [f'http://{ip}:{port}' for ip in ips if ip]
msg = 'Web portal URLs: ${CYAN}' + ' '.join(ips)
_logger.info(msg)
return config
def stop(self) -> None:
def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None:
"""
Stop the experiment.
Start the experiment in background.
This method raises an exception on failure.
If it returns, the experiment has started successfully.
Parameters
----------
port
The port of web UI.
debug
Whether to start in debug mode.
run_mode
Whether to run the experiment in the foreground or in the background.
"""
_logger.info('Stopping experiment, please wait...')
self._start_impl(port, debug, run_mode, None, [])
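For reference, a hedged usage sketch of the public ``start``/``stop`` pair defined here; the local platform and trial command are illustrative assumptions, and ``RunMode`` is assumed to be exported alongside ``Experiment``:

    from nni.experiment import Experiment, RunMode

    exp = Experiment('local')                     # platform shorthand accepted by the constructor
    exp.config.trial_command = 'python trial.py'  # hypothetical trial script
    exp.config.trial_concurrency = 1
    exp.start(port=8080, debug=False, run_mode=RunMode.Background)
    # ... interact with the running experiment ...
    exp.stop()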
def _stop_impl(self) -> None:
atexit.unregister(self.stop)
stop_experiment_logging(self.id)
......@@ -144,8 +149,24 @@ class Experiment:
self.id = None # type: ignore
self.port = None
self._proc = None
def stop(self) -> None:
"""
Stop the experiment.
"""
_logger.info('Stopping experiment, please wait...')
self._stop_impl()
_logger.info('Experiment stopped')
def _wait_completion(self) -> bool:
while True:
status = self.get_status()
if status == 'DONE' or status == 'STOPPED':
return True
if status == 'ERROR':
return False
time.sleep(10)
def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool | None:
"""
Run the experiment.
......@@ -159,13 +180,7 @@ class Experiment:
self.start(port, debug)
if wait_completion:
try:
while True:
time.sleep(10)
status = self.get_status()
if status == 'DONE' or status == 'STOPPED':
return True
if status == 'ERROR':
return False
self._wait_completion()
except KeyboardInterrupt:
_logger.warning('KeyboardInterrupt detected')
self.stop()
......
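And the blocking variant: ``run`` now delegates its polling loop to ``_wait_completion``, so a caller simply does (sketch, reusing ``exp`` from the snippet above):

    success = exp.run(port=8080, wait_completion=True)  # True on DONE/STOPPED, False on ERROR
    if success:
        print('final status:', exp.get_status())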
......@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import time
import warnings
from typing import Iterable
from ..graph import Model, ModelStatus
......@@ -18,12 +19,12 @@ __all__ = ['get_execution_engine', 'get_and_register_default_listener',
def set_execution_engine(engine: AbstractExecutionEngine) -> None:
global _execution_engine
if _execution_engine is None:
_execution_engine = engine
else:
raise RuntimeError('Execution engine is already set. '
if _execution_engine is not None:
warnings.warn('Execution engine is already set. '
'You should avoid instantiating RetiariiExperiment twice in one process. '
'If you are running in a Jupyter notebook, please restart the kernel.')
'If you are running in a Jupyter notebook, please restart the kernel.',
RuntimeWarning)
_execution_engine = engine
def get_execution_engine() -> AbstractExecutionEngine:
......
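A minimal sketch of the new non-fatal behavior: setting the engine twice now emits a ``RuntimeWarning`` and overwrites, instead of raising. The plain ``object()`` stubs stand in for real ``AbstractExecutionEngine`` instances:

    import warnings
    from nni.retiarii.execution.api import set_execution_engine, get_execution_engine

    engine_a, engine_b = object(), object()    # stand-ins for real engines
    set_execution_engine(engine_a)             # first call: no warning
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        set_execution_engine(engine_b)         # second call: warns instead of raising
    assert any(issubclass(w.category, RuntimeWarning) for w in caught)
    assert get_execution_engine() is engine_b  # the last engine set wins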
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
import os
import random
import string
from typing import Any, Dict, Iterable, List
from nni.experiment import rest
from .interface import AbstractExecutionEngine, AbstractGraphListener
from .utils import get_mutation_summary
from .. import codegen, utils
......@@ -54,12 +58,22 @@ class BaseExecutionEngine(AbstractExecutionEngine):
Resource management is implemented in this class.
"""
def __init__(self) -> None:
def __init__(self, rest_port: int | None = None, rest_url_prefix: str | None = None) -> None:
"""
Upon initialization, advisor callbacks need to be registered.
The advisor calls these callbacks when the corresponding events are triggered.
The base execution engine receives those callbacks and broadcasts them to the graph listeners.
Parameters
----------
rest_port
The port of the experiment's REST server.
rest_url_prefix
The URL prefix of the experiment's REST endpoint.
"""
self.port = rest_port
self.url_prefix = rest_url_prefix
self._listeners: List[AbstractGraphListener] = []
# register advisor callbacks
......@@ -123,8 +137,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
return self.resources
def budget_exhausted(self) -> bool:
advisor = get_advisor()
return advisor.stopping
resp = rest.get(self.port, '/check-status', self.url_prefix)
return resp['status'] == 'DONE'
@classmethod
def pack_model_data(cls, model: Model) -> Any:
......
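``budget_exhausted`` now asks the experiment's REST server instead of the advisor. A standalone sketch of the equivalent call; the port is an assumption and a running experiment is required:

    from nni.experiment import rest

    resp = rest.get(8080, '/check-status', None)  # (port, api_path, url_prefix)
    print(resp['status'] == 'DONE')               # budget exhausted iff status is DONE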
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import logging
import os
import random
import string
import time
import threading
from typing import Iterable, List, Dict, Tuple
from typing import Iterable, List, Dict, Tuple, cast
from dataclasses import dataclass
from nni.common.device import GPUDevice, Device
from nni.experiment.config.training_services import RemoteConfig
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData, Node
......@@ -31,7 +34,6 @@ class TrialSubmission:
placement: Dict[Node, Device]
grouped_models: List[Model]
class CGOExecutionEngine(AbstractExecutionEngine):
"""
The execution engine with Cross-Graph Optimization (CGO).
......@@ -41,24 +43,35 @@ class CGOExecutionEngine(AbstractExecutionEngine):
Parameters
----------
devices : List[Device]
Available devices for execution.
max_concurrency : int
training_service
The remote training service config.
max_concurrency
The maximum number of trials to run concurrently.
batch_waiting_time: int
batch_waiting_time
Seconds to wait for each batch of trial submission.
The trials within one batch could apply cross-graph optimization.
rest_port
The port of the experiment's REST server.
rest_url_prefix
The URL prefix of the experiment's REST endpoint.
"""
def __init__(self, devices: List[Device] = None,
def __init__(self, training_service: RemoteConfig,
max_concurrency: int | None = None,
batch_waiting_time: int = 60,
rest_port: int | None = None,
rest_url_prefix: str | None = None
) -> None:
self.port = rest_port
self.url_prefix = rest_url_prefix
self._listeners: List[AbstractGraphListener] = []
self._running_models: Dict[int, Model] = dict()
self.logical_plan_counter = 0
self.available_devices: List[Device] = []
self.max_concurrency: int = max_concurrency
devices = self._construct_devices(training_service)
for device in devices:
self.available_devices.append(device)
self.all_devices = self.available_devices.copy()
......@@ -88,6 +101,17 @@ class CGOExecutionEngine(AbstractExecutionEngine):
self._consumer_thread = threading.Thread(target=self._consume_models)
self._consumer_thread.start()
def _construct_devices(self, training_service):
devices = []
if hasattr(training_service, 'machine_list'):
for machine in cast(RemoteConfig, training_service).machine_list:
assert machine.gpu_indices is not None, \
'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
for gpu_idx in machine.gpu_indices:
devices.append(GPUDevice(machine.host, gpu_idx))
return devices
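To make the mapping concrete, a sketch of what ``_construct_devices`` yields for one remote machine with two visible GPUs (this mirrors the updated unit tests further below):

    from nni.experiment.config import RemoteConfig, RemoteMachineConfig

    remote = RemoteConfig(machine_list=[])
    remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0, 1]))
    # CGOExecutionEngine(training_service=remote, ...) would then construct
    # [GPUDevice('test', 0), GPUDevice('test', 1)] as its available devices.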
def join(self):
self._stopped = True
self._consumer_thread.join()
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .experiment_config import *
from .engine_config import *
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import Optional, List
from nni.experiment.config.base import ConfigBase
__all__ = ['ExecutionEngineConfig', 'BaseEngineConfig', 'OneshotEngineConfig',
'PyEngineConfig', 'CgoEngineConfig', 'BenchmarkEngineConfig']
@dataclass(init=False)
class ExecutionEngineConfig(ConfigBase):
name: str
@dataclass(init=False)
class PyEngineConfig(ExecutionEngineConfig):
name: str = 'py'
@dataclass(init=False)
class OneshotEngineConfig(ExecutionEngineConfig):
name: str = 'oneshot'
@dataclass(init=False)
class BaseEngineConfig(ExecutionEngineConfig):
name: str = 'base'
# Input used in GraphConverterWithShape. Currently supports shape tuples only.
dummy_input: Optional[List[int]] = None
@dataclass(init=False)
class CgoEngineConfig(ExecutionEngineConfig):
name: str = 'cgo'
max_concurrency_cgo: Optional[int] = None
batch_waiting_time: Optional[int] = None
# Input used in GraphConverterWithShape. Currently supports shape tuples only.
dummy_input: Optional[List[int]] = None
@dataclass(init=False)
class BenchmarkEngineConfig(ExecutionEngineConfig):
name: str = 'benchmark'
benchmark: Optional[str] = None
\ No newline at end of file
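A short usage sketch for these configs; the import path follows the example script change further below:

    from nni.retiarii.experiment.pytorch import CgoEngineConfig

    engine = CgoEngineConfig(max_concurrency_cgo=3, batch_waiting_time=0)
    assert engine.name == 'cgo'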
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
from dataclasses import dataclass
from typing import Any, Union
from nni.experiment.config import utils, ExperimentConfig
from .engine_config import ExecutionEngineConfig
__all__ = ['RetiariiExeConfig']
def execution_engine_config_factory(engine_name):
# FIXME: may move this function to experiment utils in future
cls = _get_ee_config_class(engine_name)
if cls is None:
raise ValueError(f'Invalid execution engine name: {engine_name}')
return cls()
def _get_ee_config_class(engine_name):
for cls in ExecutionEngineConfig.__subclasses__():
if cls.name == engine_name:
return cls
return None
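A usage sketch of the factory, assuming this module's namespace:

    engine = execution_engine_config_factory('base')  # resolves to BaseEngineConfig()
    try:
        execution_engine_config_factory('no-such-engine')
    except ValueError as err:
        print(err)  # Invalid execution engine name: no-such-engine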
@dataclass(init=False)
class RetiariiExeConfig(ExperimentConfig):
# FIXME: refactor this class to inherit from a new common base class with HPO config
search_space: Any = ''
trial_code_directory: utils.PathLike = '.'
trial_command: str = '_reserved'
# new config field for NAS
execution_engine: Union[str, ExecutionEngineConfig]
def __init__(self, training_service_platform: Union[str, None] = None,
execution_engine: Union[str, ExecutionEngineConfig] = 'py',
**kwargs):
super().__init__(training_service_platform, **kwargs)
self.execution_engine = execution_engine
def _canonicalize(self, _parents):
msg = '{} is not supposed to be set by users in a Retiarii experiment; your config is {}.'
if self.search_space != '':
raise ValueError(msg.format('search_space', self.search_space))
# TODO: maybe we should also allow users to specify trial_code_directory
if str(self.trial_code_directory) != '.' and not os.path.isabs(self.trial_code_directory):
raise ValueError(msg.format('trial_code_directory', self.trial_code_directory))
if self.trial_command != '_reserved' and \
not self.trial_command.startswith('python3 -m nni.retiarii.trial_entry '):
raise ValueError(msg.format('trial_command', self.trial_command))
if isinstance(self.execution_engine, str):
self.execution_engine = execution_engine_config_factory(self.execution_engine)
if self.execution_engine.name in ('py', 'base', 'cgo'):
# TODO: replace python3 with more elegant approach
# maybe use sys.executable rendered in trial side (e.g., trial_runner)
self.trial_command = 'python3 -m nni.retiarii.trial_entry ' + self.execution_engine.name
super()._canonicalize([self])
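Putting it together, a hedged sketch of building this config; 'local' and the field values are illustrative. Both the string form and the config-object form of ``execution_engine`` are accepted:

    from nni.retiarii.experiment.pytorch import RetiariiExeConfig

    config = RetiariiExeConfig('local', execution_engine='py')
    config.max_trial_number = 10  # inherited ExperimentConfig field
    # _canonicalize() later replaces the 'py' string with PyEngineConfig() and
    # sets trial_command to 'python3 -m nni.retiarii.trial_entry py'.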
......@@ -22,7 +22,10 @@ def get_advisor() -> 'RetiariiAdvisor':
def register_advisor(advisor: 'RetiariiAdvisor'):
global _advisor
assert _advisor is None
if _advisor is not None:
warnings.warn('Advisor is already set. '
'You should avoid instantiating RetiariiExperiment twice in one process. '
'If you are running in a Jupyter notebook, please restart the kernel.')
_advisor = advisor
......
......@@ -18,8 +18,15 @@ _worker_fast_exit_on_terminate = True
class MsgDispatcherBase(Recoverable):
"""This is where tuners and assessors are not defined yet.
"""
This is the base class in which tuners and assessors are not yet defined.
Inherit from this class to make your own advisor.
.. note::
A class inheriting MsgDispatcherBase should be instantiated
after nnimanager (the REST server) has started, so that the object
is ready to use right after its instantiation.
"""
def __init__(self, command_channel_url=None):
......@@ -27,6 +34,16 @@ class MsgDispatcherBase(Recoverable):
if command_channel_url is None:
command_channel_url = dispatcher_env_vars.NNI_TUNER_COMMAND_CHANNEL
self._channel = TunerCommandChannel(command_channel_url)
# NOTE: `connect()` should be put in __init__. First, this `connect()` affects nnimanager's
# starting process: without `connect()`, nnimanager is blocked in `dispatcher.init()`.
# Second, the NAS experiment uses a thread to execute `run()` of this class, so there is
# no way to know when the websocket between nnimanager and the dispatcher has been built.
# The following logic may crash if the websocket is not built. One example is updating the
# search space: if it is updated too soon, while the websocket has not been built, the REST
# API for updating the search space will time out.
# FIXME: this is to make unittests happy
if not command_channel_url.startswith('ws://_unittest_'):
self._channel.connect()
self.default_command_queue = Queue()
self.assessor_command_queue = Queue()
self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
......@@ -39,7 +56,6 @@ class MsgDispatcherBase(Recoverable):
"""
_logger.info('Dispatcher started')
self._channel.connect()
self.default_worker.start()
self.assessor_worker.start()
......
......@@ -4,7 +4,6 @@ import warnings
import torch
import torch.nn as torch_nn
from torchvision.models.utils import load_state_dict_from_url
import torch.nn.functional as F
import sys
......
......@@ -8,7 +8,7 @@ import nni.retiarii.evaluator.pytorch.cgo.evaluator as cgo
from nni.retiarii import serialize
from base_mnasnet import MNASNet
from nni.experiment import RemoteMachineConfig
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig, CgoEngineConfig
from nni.retiarii.strategy import TPEStrategy
from torchvision import transforms
from torchvision.datasets import CIFAR10
......@@ -59,8 +59,6 @@ if __name__ == '__main__':
exp_config.max_trial_number = 10
exp_config.trial_gpu_number = 1
exp_config.training_service.reuse_mode = True
exp_config.max_concurrency_cgo = 3
exp_config.batch_waiting_time = 0
rm_conf = RemoteMachineConfig()
rm_conf.host = '127.0.0.1'
......@@ -73,6 +71,6 @@ if __name__ == '__main__':
rm_conf.max_trial_number_per_gpu = 3
exp_config.training_service.machine_list = [rm_conf]
exp_config.execution_engine = 'cgo'
exp_config.execution_engine = CgoEngineConfig(max_concurrency_cgo=3, batch_waiting_time=0)
exp.run(exp_config, 8099)
......@@ -9,6 +9,7 @@ from pytorch_lightning.utilities.seed import seed_everything
from pathlib import Path
import nni
from nni.experiment.config import RemoteConfig, RemoteMachineConfig
import nni.runtime.platform.test
from nni.runtime.tuner_command_channel import legacy as protocol
import json
......@@ -263,13 +264,14 @@ class CGOEngineTest(unittest.TestCase):
opt = DedupInputOptimizer()
opt.convert(lp)
advisor = RetiariiAdvisor('ws://_placeholder_')
advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
available_devices = [GPUDevice("test", 0), GPUDevice("test", 1), GPUDevice("test", 2), GPUDevice("test", 3)]
cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0)
remote = RemoteConfig(machine_list=[])
remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0, 1, 2, 3]))
cgo = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)
phy_models = cgo._assemble(lp)
self.assertTrue(len(phy_models) == 1)
......@@ -286,13 +288,14 @@ class CGOEngineTest(unittest.TestCase):
opt = DedupInputOptimizer()
opt.convert(lp)
advisor = RetiariiAdvisor('ws://_placeholder_')
advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
available_devices = [GPUDevice("test", 0), GPUDevice("test", 1)]
cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0)
remote = RemoteConfig(machine_list=[])
remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0, 1]))
cgo = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)
phy_models = cgo._assemble(lp)
self.assertTrue(len(phy_models) == 2)
......@@ -311,13 +314,14 @@ class CGOEngineTest(unittest.TestCase):
models = _load_mnist(2)
advisor = RetiariiAdvisor('ws://_placeholder_')
advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
cgo_engine = CGOExecutionEngine(devices=[GPUDevice("test", 0), GPUDevice("test", 1),
GPUDevice("test", 2), GPUDevice("test", 3)], batch_waiting_time=0)
remote = RemoteConfig(machine_list=[])
remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0, 1, 2, 3]))
cgo_engine = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)
set_execution_engine(cgo_engine)
submit_models(*models)
time.sleep(3)
......
......@@ -25,7 +25,7 @@ class EngineTest(unittest.TestCase):
def test_base_execution_engine(self):
nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor('ws://_placeholder_')
advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
......@@ -42,7 +42,7 @@ class EngineTest(unittest.TestCase):
def test_py_execution_engine(self):
nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor('ws://_placeholder_')
advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
......
......@@ -57,7 +57,7 @@ class AssessorTestCase(TestCase):
_restore_io()
assessor = NaiveAssessor()
dispatcher = MsgDispatcher('ws://_placeholder_', None, assessor)
dispatcher = MsgDispatcher('ws://_unittest_placeholder_', None, assessor)
dispatcher._channel = LegacyCommandChannel()
msg_dispatcher_base._worker_fast_exit_on_terminate = False
......
......@@ -66,7 +66,7 @@ class MsgDispatcherTestCase(TestCase):
_restore_io()
tuner = NaiveTuner()
dispatcher = MsgDispatcher('ws://_placeholder_', tuner)
dispatcher = MsgDispatcher('ws://_unittest_placeholder_', tuner)
dispatcher._channel = LegacyCommandChannel()
msg_dispatcher_base._worker_fast_exit_on_terminate = False
......
......@@ -303,8 +303,11 @@ class NNIManager implements Manager {
}
this.trainingService.removeTrialJobMetricListener(this.trialJobMetricListener);
if (this.dispatcherPid > 0) {
// NOTE: sending TERMINATE should be outside the if clause,
// because when the python dispatcher is started before nnimanager,
// this.dispatcherPid does not have a valid value (i.e., not > 0).
this.dispatcher.sendCommand(TERMINATE);
if (this.dispatcherPid > 0) {
// gracefully terminate tuner and assessor here, wait at most 30 seconds.
for (let i: number = 0; i < 30; i++) {
if (!await isAlive(this.dispatcherPid)) {
......