Unverified Commit 2fc47247 authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

[retiarii] refactor of nas experiment (#4841)

parent c80bda29
...@@ -54,6 +54,11 @@ class ConfigBase: ...@@ -54,6 +54,11 @@ class ConfigBase:
Config objects will remember where they are loaded; therefore relative paths can be resolved smartly. Config objects will remember where they are loaded; therefore relative paths can be resolved smartly.
If a config object is created with constructor, the base path will be current working directory. If a config object is created with constructor, the base path will be current working directory.
If it is loaded with ``ConfigBase.load(path)``, the base path will be ``path``'s parent. If it is loaded with ``ConfigBase.load(path)``, the base path will be ``path``'s parent.
.. attention::
All the classes that inherit ``ConfigBase`` are not allowed to use ``from __future__ import annotations``,
because ``ConfigBase`` uses ``typeguard`` to perform runtime check and it does not support lazy annotations.
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
......
...@@ -164,10 +164,11 @@ class ExperimentConfig(ConfigBase): ...@@ -164,10 +164,11 @@ class ExperimentConfig(ConfigBase):
# currently I have only seen one issue of this kind # currently I have only seen one issue of this kind
#Path(self.experiment_working_directory).mkdir(parents=True, exist_ok=True) #Path(self.experiment_working_directory).mkdir(parents=True, exist_ok=True)
utils.validate_gpu_indices(self.tuner_gpu_indices) if type(self).__name__ != 'RetiariiExeConfig':
utils.validate_gpu_indices(self.tuner_gpu_indices)
if self.tuner is None: if self.tuner is None:
raise ValueError('ExperimentConfig: tuner must be set') raise ValueError('ExperimentConfig: tuner must be set')
def _load_search_space_file(search_space_path): def _load_search_space_file(search_space_path):
# FIXME # FIXME
......
...@@ -84,20 +84,9 @@ class Experiment: ...@@ -84,20 +84,9 @@ class Experiment:
else: else:
self.config = config_or_platform self.config = config_or_platform
def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None: def _start_impl(self, port: int, debug: bool, run_mode: RunMode,
""" tuner_command_channel: str | None,
Start the experiment in background. tags: list[str] = []) -> ExperimentConfig:
This method will raise exception on failure.
If it returns, the experiment should have been successfully started.
Parameters
----------
port
The port of web UI.
debug
Whether to start in debug mode.
"""
assert self.config is not None assert self.config is not None
if run_mode is not RunMode.Detach: if run_mode is not RunMode.Detach:
atexit.register(self.stop) atexit.register(self.stop)
...@@ -111,7 +100,8 @@ class Experiment: ...@@ -111,7 +100,8 @@ class Experiment:
log_level = 'debug' if (debug or config.log_level == 'trace') else config.log_level log_level = 'debug' if (debug or config.log_level == 'trace') else config.log_level
start_experiment_logging(self.id, log_file, cast(str, log_level)) start_experiment_logging(self.id, log_file, cast(str, log_level))
self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode, self.url_prefix) self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode,
self.url_prefix, tuner_command_channel, tags)
assert self._proc is not None assert self._proc is not None
self.port = port # port will be None if start up failed self.port = port # port will be None if start up failed
...@@ -124,12 +114,27 @@ class Experiment: ...@@ -124,12 +114,27 @@ class Experiment:
ips = [f'http://{ip}:{port}' for ip in ips if ip] ips = [f'http://{ip}:{port}' for ip in ips if ip]
msg = 'Web portal URLs: ${CYAN}' + ' '.join(ips) msg = 'Web portal URLs: ${CYAN}' + ' '.join(ips)
_logger.info(msg) _logger.info(msg)
return config
def stop(self) -> None: def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None:
""" """
Stop the experiment. Start the experiment in background.
This method will raise exception on failure.
If it returns, the experiment should have been successfully started.
Parameters
----------
port
The port of web UI.
debug
Whether to start in debug mode.
run_mode
Running the experiment in foreground or background
""" """
_logger.info('Stopping experiment, please wait...') self._start_impl(port, debug, run_mode, None, [])
def _stop_impl(self) -> None:
atexit.unregister(self.stop) atexit.unregister(self.stop)
stop_experiment_logging(self.id) stop_experiment_logging(self.id)
...@@ -144,8 +149,24 @@ class Experiment: ...@@ -144,8 +149,24 @@ class Experiment:
self.id = None # type: ignore self.id = None # type: ignore
self.port = None self.port = None
self._proc = None self._proc = None
def stop(self) -> None:
"""
Stop the experiment.
"""
_logger.info('Stopping experiment, please wait...')
self._stop_impl()
_logger.info('Experiment stopped') _logger.info('Experiment stopped')
def _wait_completion(self) -> bool:
while True:
status = self.get_status()
if status == 'DONE' or status == 'STOPPED':
return True
if status == 'ERROR':
return False
time.sleep(10)
def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool | None: def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool | None:
""" """
Run the experiment. Run the experiment.
...@@ -159,13 +180,7 @@ class Experiment: ...@@ -159,13 +180,7 @@ class Experiment:
self.start(port, debug) self.start(port, debug)
if wait_completion: if wait_completion:
try: try:
while True: self._wait_completion()
time.sleep(10)
status = self.get_status()
if status == 'DONE' or status == 'STOPPED':
return True
if status == 'ERROR':
return False
except KeyboardInterrupt: except KeyboardInterrupt:
_logger.warning('KeyboardInterrupt detected') _logger.warning('KeyboardInterrupt detected')
self.stop() self.stop()
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import time import time
import warnings
from typing import Iterable from typing import Iterable
from ..graph import Model, ModelStatus from ..graph import Model, ModelStatus
...@@ -18,12 +19,12 @@ __all__ = ['get_execution_engine', 'get_and_register_default_listener', ...@@ -18,12 +19,12 @@ __all__ = ['get_execution_engine', 'get_and_register_default_listener',
def set_execution_engine(engine: AbstractExecutionEngine) -> None: def set_execution_engine(engine: AbstractExecutionEngine) -> None:
global _execution_engine global _execution_engine
if _execution_engine is None: if _execution_engine is not None:
_execution_engine = engine warnings.warn('Execution engine is already set. '
else: 'You should avoid instantiating RetiariiExperiment twice in one process. '
raise RuntimeError('Execution engine is already set. ' 'If you are running in a Jupyter notebook, please restart the kernel.',
'You should avoid instantiating RetiariiExperiment twice in one process. ' RuntimeWarning)
'If you are running in a Jupyter notebook, please restart the kernel.') _execution_engine = engine
def get_execution_engine() -> AbstractExecutionEngine: def get_execution_engine() -> AbstractExecutionEngine:
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from __future__ import annotations
import logging import logging
import os import os
import random import random
import string import string
from typing import Any, Dict, Iterable, List from typing import Any, Dict, Iterable, List
from nni.experiment import rest
from .interface import AbstractExecutionEngine, AbstractGraphListener from .interface import AbstractExecutionEngine, AbstractGraphListener
from .utils import get_mutation_summary from .utils import get_mutation_summary
from .. import codegen, utils from .. import codegen, utils
...@@ -54,12 +58,22 @@ class BaseExecutionEngine(AbstractExecutionEngine): ...@@ -54,12 +58,22 @@ class BaseExecutionEngine(AbstractExecutionEngine):
Resource management is implemented in this class. Resource management is implemented in this class.
""" """
def __init__(self) -> None: def __init__(self, rest_port: int | None = None, rest_url_prefix: str | None = None) -> None:
""" """
Upon initialization, advisor callbacks need to be registered. Upon initialization, advisor callbacks need to be registered.
Advisor will call the callbacks when the corresponding event has been triggered. Advisor will call the callbacks when the corresponding event has been triggered.
Base execution engine will get those callbacks and broadcast them to graph listener. Base execution engine will get those callbacks and broadcast them to graph listener.
Parameters
----------
rest_port
The port of the experiment's rest server
rest_url_prefix
The url prefix of the experiment's rest entry
""" """
self.port = rest_port
self.url_prefix = rest_url_prefix
self._listeners: List[AbstractGraphListener] = [] self._listeners: List[AbstractGraphListener] = []
# register advisor callbacks # register advisor callbacks
...@@ -123,8 +137,8 @@ class BaseExecutionEngine(AbstractExecutionEngine): ...@@ -123,8 +137,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
return self.resources return self.resources
def budget_exhausted(self) -> bool: def budget_exhausted(self) -> bool:
advisor = get_advisor() resp = rest.get(self.port, '/check-status', self.url_prefix)
return advisor.stopping return resp['status'] == 'DONE'
@classmethod @classmethod
def pack_model_data(cls, model: Model) -> Any: def pack_model_data(cls, model: Model) -> Any:
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from __future__ import annotations
import logging import logging
import os import os
import random import random
import string import string
import time import time
import threading import threading
from typing import Iterable, List, Dict, Tuple from typing import Iterable, List, Dict, Tuple, cast
from dataclasses import dataclass from dataclasses import dataclass
from nni.common.device import GPUDevice, Device from nni.common.device import GPUDevice, Device
from nni.experiment.config.training_services import RemoteConfig
from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData, Node from ..graph import Model, ModelStatus, MetricData, Node
...@@ -31,7 +34,6 @@ class TrialSubmission: ...@@ -31,7 +34,6 @@ class TrialSubmission:
placement: Dict[Node, Device] placement: Dict[Node, Device]
grouped_models: List[Model] grouped_models: List[Model]
class CGOExecutionEngine(AbstractExecutionEngine): class CGOExecutionEngine(AbstractExecutionEngine):
""" """
The execution engine with Cross-Graph Optimization (CGO). The execution engine with Cross-Graph Optimization (CGO).
...@@ -41,24 +43,35 @@ class CGOExecutionEngine(AbstractExecutionEngine): ...@@ -41,24 +43,35 @@ class CGOExecutionEngine(AbstractExecutionEngine):
Parameters Parameters
---------- ----------
devices : List[Device] training_service
Available devices for execution. The remote training service config.
max_concurrency : int max_concurrency
The maximum number of trials to run concurrently. The maximum number of trials to run concurrently.
batch_waiting_time: int batch_waiting_time
Seconds to wait for each batch of trial submission. Seconds to wait for each batch of trial submission.
The trials within one batch could apply cross-graph optimization. The trials within one batch could apply cross-graph optimization.
rest_port
The port of the experiment's rest server
rest_url_prefix
The url prefix of the experiment's rest entry
""" """
def __init__(self, devices: List[Device] = None, def __init__(self, training_service: RemoteConfig,
max_concurrency: int = None, max_concurrency: int = None,
batch_waiting_time: int = 60, batch_waiting_time: int = 60,
rest_port: int | None = None,
rest_url_prefix: str | None = None
) -> None: ) -> None:
self.port = rest_port
self.url_prefix = rest_url_prefix
self._listeners: List[AbstractGraphListener] = [] self._listeners: List[AbstractGraphListener] = []
self._running_models: Dict[int, Model] = dict() self._running_models: Dict[int, Model] = dict()
self.logical_plan_counter = 0 self.logical_plan_counter = 0
self.available_devices: List[Device] = [] self.available_devices: List[Device] = []
self.max_concurrency: int = max_concurrency self.max_concurrency: int = max_concurrency
devices = self._construct_devices(training_service)
for device in devices: for device in devices:
self.available_devices.append(device) self.available_devices.append(device)
self.all_devices = self.available_devices.copy() self.all_devices = self.available_devices.copy()
...@@ -88,6 +101,17 @@ class CGOExecutionEngine(AbstractExecutionEngine): ...@@ -88,6 +101,17 @@ class CGOExecutionEngine(AbstractExecutionEngine):
self._consumer_thread = threading.Thread(target=self._consume_models) self._consumer_thread = threading.Thread(target=self._consume_models)
self._consumer_thread.start() self._consumer_thread.start()
def _construct_devices(self, training_service):
devices = []
if hasattr(training_service, 'machine_list'):
for machine in cast(RemoteConfig, training_service).machine_list:
assert machine.gpu_indices is not None, \
'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
for gpu_idx in machine.gpu_indices:
devices.append(GPUDevice(machine.host, gpu_idx))
return devices
def join(self): def join(self):
self._stopped = True self._stopped = True
self._consumer_thread.join() self._consumer_thread.join()
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .experiment_config import *
from .engine_config import *
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import Optional, List
from nni.experiment.config.base import ConfigBase
__all__ = ['ExecutionEngineConfig', 'BaseEngineConfig', 'OneshotEngineConfig',
'PyEngineConfig', 'CgoEngineConfig', 'BenchmarkEngineConfig']
@dataclass(init=False)
class ExecutionEngineConfig(ConfigBase):
    """Base config class for Retiarii execution engines.

    Each concrete subclass sets a default ``name`` which identifies the
    engine; the string form of an engine config is resolved to a subclass
    by matching this field.
    """
    # Engine identifier; concrete subclasses provide a default value.
    name: str
@dataclass(init=False)
class PyEngineConfig(ExecutionEngineConfig):
    """Config for the ``py`` (pure-Python) execution engine."""
    name: str = 'py'
@dataclass(init=False)
class OneshotEngineConfig(ExecutionEngineConfig):
    """Config for the ``oneshot`` execution engine."""
    name: str = 'oneshot'
@dataclass(init=False)
class BaseEngineConfig(ExecutionEngineConfig):
    """Config for the ``base`` (graph-based) execution engine."""
    name: str = 'base'
    # Input used in GraphConverterWithShape. Currently supports a shape tuple only.
    dummy_input: Optional[List[int]] = None
@dataclass(init=False)
class CgoEngineConfig(ExecutionEngineConfig):
    """Config for the ``cgo`` (cross-graph optimization) execution engine."""
    name: str = 'cgo'
    # Maximum number of trials run concurrently by the CGO engine; None means unlimited.
    max_concurrency_cgo: Optional[int] = None
    # Seconds to wait for each batch of trial submission; trials within one
    # batch can apply cross-graph optimization.
    batch_waiting_time: Optional[int] = None
    # Input used in GraphConverterWithShape. Currently supports a shape tuple only.
    dummy_input: Optional[List[int]] = None
@dataclass(init=False)
class BenchmarkEngineConfig(ExecutionEngineConfig):
    """Config for the ``benchmark`` execution engine."""
    name: str = 'benchmark'
    # Name of the benchmark to evaluate against — TODO confirm the accepted values.
    benchmark: Optional[str] = None
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
from dataclasses import dataclass
from typing import Any, Union
from nni.experiment.config import utils, ExperimentConfig
from .engine_config import ExecutionEngineConfig
__all__ = ['RetiariiExeConfig']
def execution_engine_config_factory(engine_name):
    """Instantiate the execution engine config registered under *engine_name*.

    Raises ``ValueError`` when no engine config class matches the name.
    """
    # FIXME: may move this function to experiment utils in future
    config_cls = _get_ee_config_class(engine_name)
    if config_cls is None:
        raise ValueError(f'Invalid execution engine name: {engine_name}')
    return config_cls()
def _get_ee_config_class(engine_name):
    """Return the ``ExecutionEngineConfig`` subclass whose ``name`` matches, or None."""
    matches = (cls for cls in ExecutionEngineConfig.__subclasses__() if cls.name == engine_name)
    return next(matches, None)
@dataclass(init=False)
class RetiariiExeConfig(ExperimentConfig):
    """Experiment config for Retiarii (NAS) experiments.

    ``search_space``, ``trial_code_directory`` and ``trial_command`` are
    reserved: they are filled in by the framework rather than by users, and
    ``_canonicalize`` rejects user-set values.
    """
    # FIXME: refactor this class to inherit from a new common base class with HPO config
    # Reserved; must stay '' — the framework manages the search space itself.
    search_space: Any = ''
    # Reserved; users may only leave it as '.' or (per the check below) an absolute path.
    trial_code_directory: utils.PathLike = '.'
    # Reserved; rewritten below to the retiarii trial entry command.
    trial_command: str = '_reserved'
    # new config field for NAS: either an engine name ('py', 'base', ...) or a config object.
    execution_engine: Union[str, ExecutionEngineConfig]

    def __init__(self, training_service_platform: Union[str, None] = None,
                 execution_engine: Union[str, ExecutionEngineConfig] = 'py',
                 **kwargs):
        """Create a Retiarii experiment config.

        Parameters
        ----------
        training_service_platform
            Forwarded to ``ExperimentConfig``; selects the training service.
        execution_engine
            Engine name or an ``ExecutionEngineConfig`` instance; string names
            are resolved to config objects during canonicalization.
        """
        super().__init__(training_service_platform, **kwargs)
        self.execution_engine = execution_engine

    def _canonicalize(self, _parents):
        """Validate reserved fields and normalize ``execution_engine``.

        Raises ``ValueError`` if a user set any of the reserved fields.
        """
        msg = '{} is not supposed to be set in Retiarii experiment by users, your config is {}.'
        if self.search_space != '':
            raise ValueError(msg.format('search_space', self.search_space))
        # TODO: maybe we should also allow users to specify trial_code_directory
        if str(self.trial_code_directory) != '.' and not os.path.isabs(self.trial_code_directory):
            raise ValueError(msg.format('trial_code_directory', self.trial_code_directory))
        # Allow the framework-generated trial command, reject anything else user-set.
        if self.trial_command != '_reserved' and \
            not self.trial_command.startswith('python3 -m nni.retiarii.trial_entry '):
            raise ValueError(msg.format('trial_command', self.trial_command))

        # Resolve a string engine name into its config object.
        if isinstance(self.execution_engine, str):
            self.execution_engine = execution_engine_config_factory(self.execution_engine)
        if self.execution_engine.name in ('py', 'base', 'cgo'):
            # TODO: replace python3 with more elegant approach
            # maybe use sys.executable rendered in trial side (e.g., trial_runner)
            self.trial_command = 'python3 -m nni.retiarii.trial_entry ' + self.execution_engine.name

        super()._canonicalize([self])
This diff is collapsed.
...@@ -22,7 +22,10 @@ def get_advisor() -> 'RetiariiAdvisor': ...@@ -22,7 +22,10 @@ def get_advisor() -> 'RetiariiAdvisor':
def register_advisor(advisor: 'RetiariiAdvisor'): def register_advisor(advisor: 'RetiariiAdvisor'):
global _advisor global _advisor
assert _advisor is None if _advisor is not None:
warnings.warn('Advisor is already set. '
              'You should avoid instantiating RetiariiExperiment twice in one process. '
              'If you are running in a Jupyter notebook, please restart the kernel.')
_advisor = advisor _advisor = advisor
......
...@@ -18,8 +18,15 @@ _worker_fast_exit_on_terminate = True ...@@ -18,8 +18,15 @@ _worker_fast_exit_on_terminate = True
class MsgDispatcherBase(Recoverable): class MsgDispatcherBase(Recoverable):
"""This is where tuners and assessors are not defined yet. """
This is where tuners and assessors are not defined yet.
Inherits this class to make your own advisor. Inherits this class to make your own advisor.
.. note::
The class inheriting MsgDispatcherBase should be instantiated
after nnimanager (rest server) is started, so that the object
is ready to use right after its instantiation.
""" """
def __init__(self, command_channel_url=None): def __init__(self, command_channel_url=None):
...@@ -27,6 +34,16 @@ class MsgDispatcherBase(Recoverable): ...@@ -27,6 +34,16 @@ class MsgDispatcherBase(Recoverable):
if command_channel_url is None: if command_channel_url is None:
command_channel_url = dispatcher_env_vars.NNI_TUNER_COMMAND_CHANNEL command_channel_url = dispatcher_env_vars.NNI_TUNER_COMMAND_CHANNEL
self._channel = TunerCommandChannel(command_channel_url) self._channel = TunerCommandChannel(command_channel_url)
# NOTE: `connect()` should be put in __init__. First, this `connect()` affects nnimanager's
# starting process, without `connect()` nnimanager is blocked in `dispatcher.init()`.
# Second, nas experiment uses a thread to execute `run()` of this class, thus, there is
# no way to know when the websocket between nnimanager and dispatcher is built. The following
# logic may crash if the websocket is not built. One example is updating search space. If updating
# search space too soon, as the websocket has not been built, the rest api of updating search
# space will timeout.
# FIXME: this is making unittest happy
if not command_channel_url.startswith('ws://_unittest_'):
self._channel.connect()
self.default_command_queue = Queue() self.default_command_queue = Queue()
self.assessor_command_queue = Queue() self.assessor_command_queue = Queue()
self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,)) self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
...@@ -39,7 +56,6 @@ class MsgDispatcherBase(Recoverable): ...@@ -39,7 +56,6 @@ class MsgDispatcherBase(Recoverable):
""" """
_logger.info('Dispatcher started') _logger.info('Dispatcher started')
self._channel.connect()
self.default_worker.start() self.default_worker.start()
self.assessor_worker.start() self.assessor_worker.start()
......
...@@ -4,7 +4,6 @@ import warnings ...@@ -4,7 +4,6 @@ import warnings
import torch import torch
import torch.nn as torch_nn import torch.nn as torch_nn
from torchvision.models.utils import load_state_dict_from_url
import torch.nn.functional as F import torch.nn.functional as F
import sys import sys
......
...@@ -8,7 +8,7 @@ import nni.retiarii.evaluator.pytorch.cgo.evaluator as cgo ...@@ -8,7 +8,7 @@ import nni.retiarii.evaluator.pytorch.cgo.evaluator as cgo
from nni.retiarii import serialize from nni.retiarii import serialize
from base_mnasnet import MNASNet from base_mnasnet import MNASNet
from nni.experiment import RemoteMachineConfig from nni.experiment import RemoteMachineConfig
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig, CgoEngineConfig
from nni.retiarii.strategy import TPEStrategy from nni.retiarii.strategy import TPEStrategy
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
...@@ -59,8 +59,6 @@ if __name__ == '__main__': ...@@ -59,8 +59,6 @@ if __name__ == '__main__':
exp_config.max_trial_number = 10 exp_config.max_trial_number = 10
exp_config.trial_gpu_number = 1 exp_config.trial_gpu_number = 1
exp_config.training_service.reuse_mode = True exp_config.training_service.reuse_mode = True
exp_config.max_concurrency_cgo = 3
exp_config.batch_waiting_time = 0
rm_conf = RemoteMachineConfig() rm_conf = RemoteMachineConfig()
rm_conf.host = '127.0.0.1' rm_conf.host = '127.0.0.1'
...@@ -73,6 +71,6 @@ if __name__ == '__main__': ...@@ -73,6 +71,6 @@ if __name__ == '__main__':
rm_conf.max_trial_number_per_gpu = 3 rm_conf.max_trial_number_per_gpu = 3
exp_config.training_service.machine_list = [rm_conf] exp_config.training_service.machine_list = [rm_conf]
exp_config.execution_engine = 'cgo' exp_config.execution_engine = CgoEngineConfig(max_concurrency_cgo = 3, batch_waiting_time = 0)
exp.run(exp_config, 8099) exp.run(exp_config, 8099)
\ No newline at end of file
...@@ -9,6 +9,7 @@ from pytorch_lightning.utilities.seed import seed_everything ...@@ -9,6 +9,7 @@ from pytorch_lightning.utilities.seed import seed_everything
from pathlib import Path from pathlib import Path
import nni import nni
from nni.experiment.config import RemoteConfig, RemoteMachineConfig
import nni.runtime.platform.test import nni.runtime.platform.test
from nni.runtime.tuner_command_channel import legacy as protocol from nni.runtime.tuner_command_channel import legacy as protocol
import json import json
...@@ -263,13 +264,14 @@ class CGOEngineTest(unittest.TestCase): ...@@ -263,13 +264,14 @@ class CGOEngineTest(unittest.TestCase):
opt = DedupInputOptimizer() opt = DedupInputOptimizer()
opt.convert(lp) opt.convert(lp)
advisor = RetiariiAdvisor('ws://_placeholder_') advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = protocol.LegacyCommandChannel() advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start() advisor.default_worker.start()
advisor.assessor_worker.start() advisor.assessor_worker.start()
available_devices = [GPUDevice("test", 0), GPUDevice("test", 1), GPUDevice("test", 2), GPUDevice("test", 3)] remote = RemoteConfig(machine_list=[])
cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0) remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1,2,3]))
cgo = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)
phy_models = cgo._assemble(lp) phy_models = cgo._assemble(lp)
self.assertTrue(len(phy_models) == 1) self.assertTrue(len(phy_models) == 1)
...@@ -286,13 +288,14 @@ class CGOEngineTest(unittest.TestCase): ...@@ -286,13 +288,14 @@ class CGOEngineTest(unittest.TestCase):
opt = DedupInputOptimizer() opt = DedupInputOptimizer()
opt.convert(lp) opt.convert(lp)
advisor = RetiariiAdvisor('ws://_placeholder_') advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = protocol.LegacyCommandChannel() advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start() advisor.default_worker.start()
advisor.assessor_worker.start() advisor.assessor_worker.start()
available_devices = [GPUDevice("test", 0), GPUDevice("test", 1)] remote = RemoteConfig(machine_list=[])
cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0) remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1]))
cgo = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)
phy_models = cgo._assemble(lp) phy_models = cgo._assemble(lp)
self.assertTrue(len(phy_models) == 2) self.assertTrue(len(phy_models) == 2)
...@@ -311,13 +314,14 @@ class CGOEngineTest(unittest.TestCase): ...@@ -311,13 +314,14 @@ class CGOEngineTest(unittest.TestCase):
models = _load_mnist(2) models = _load_mnist(2)
advisor = RetiariiAdvisor('ws://_placeholder_') advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = protocol.LegacyCommandChannel() advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start() advisor.default_worker.start()
advisor.assessor_worker.start() advisor.assessor_worker.start()
cgo_engine = CGOExecutionEngine(devices=[GPUDevice("test", 0), GPUDevice("test", 1), remote = RemoteConfig(machine_list=[])
GPUDevice("test", 2), GPUDevice("test", 3)], batch_waiting_time=0) remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0,1,2,3]))
cgo_engine = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)
set_execution_engine(cgo_engine) set_execution_engine(cgo_engine)
submit_models(*models) submit_models(*models)
time.sleep(3) time.sleep(3)
......
...@@ -25,7 +25,7 @@ class EngineTest(unittest.TestCase): ...@@ -25,7 +25,7 @@ class EngineTest(unittest.TestCase):
def test_base_execution_engine(self): def test_base_execution_engine(self):
nni.retiarii.integration_api._advisor = None nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor('ws://_placeholder_') advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = LegacyCommandChannel() advisor._channel = LegacyCommandChannel()
advisor.default_worker.start() advisor.default_worker.start()
advisor.assessor_worker.start() advisor.assessor_worker.start()
...@@ -42,7 +42,7 @@ class EngineTest(unittest.TestCase): ...@@ -42,7 +42,7 @@ class EngineTest(unittest.TestCase):
def test_py_execution_engine(self): def test_py_execution_engine(self):
nni.retiarii.integration_api._advisor = None nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor('ws://_placeholder_') advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
advisor._channel = LegacyCommandChannel() advisor._channel = LegacyCommandChannel()
advisor.default_worker.start() advisor.default_worker.start()
advisor.assessor_worker.start() advisor.assessor_worker.start()
......
...@@ -57,7 +57,7 @@ class AssessorTestCase(TestCase): ...@@ -57,7 +57,7 @@ class AssessorTestCase(TestCase):
_restore_io() _restore_io()
assessor = NaiveAssessor() assessor = NaiveAssessor()
dispatcher = MsgDispatcher('ws://_placeholder_', None, assessor) dispatcher = MsgDispatcher('ws://_unittest_placeholder_', None, assessor)
dispatcher._channel = LegacyCommandChannel() dispatcher._channel = LegacyCommandChannel()
msg_dispatcher_base._worker_fast_exit_on_terminate = False msg_dispatcher_base._worker_fast_exit_on_terminate = False
......
...@@ -66,7 +66,7 @@ class MsgDispatcherTestCase(TestCase): ...@@ -66,7 +66,7 @@ class MsgDispatcherTestCase(TestCase):
_restore_io() _restore_io()
tuner = NaiveTuner() tuner = NaiveTuner()
dispatcher = MsgDispatcher('ws://_placeholder_', tuner) dispatcher = MsgDispatcher('ws://_unittest_placeholder_', tuner)
dispatcher._channel = LegacyCommandChannel() dispatcher._channel = LegacyCommandChannel()
msg_dispatcher_base._worker_fast_exit_on_terminate = False msg_dispatcher_base._worker_fast_exit_on_terminate = False
......
...@@ -303,8 +303,11 @@ class NNIManager implements Manager { ...@@ -303,8 +303,11 @@ class NNIManager implements Manager {
} }
this.trainingService.removeTrialJobMetricListener(this.trialJobMetricListener); this.trainingService.removeTrialJobMetricListener(this.trialJobMetricListener);
// NOTE: this sending TERMINATE should be out of the if clause,
// because when python dispatcher is started before nnimanager
// this.dispatcherPid would not have a valid value (i.e., not >0).
this.dispatcher.sendCommand(TERMINATE);
if (this.dispatcherPid > 0) { if (this.dispatcherPid > 0) {
this.dispatcher.sendCommand(TERMINATE);
// gracefully terminate tuner and assessor here, wait at most 30 seconds. // gracefully terminate tuner and assessor here, wait at most 30 seconds.
for (let i: number = 0; i < 30; i++) { for (let i: number = 0; i < 30; i++) {
if (!await isAlive(this.dispatcherPid)) { if (!await isAlive(this.dispatcherPid)) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment