Unverified Commit 2fc47247 authored by QuanluZhang, committed by GitHub

[retiarii] refactor of nas experiment (#4841)

parent c80bda29
@@ -54,6 +54,11 @@ class ConfigBase:
     Config objects will remember where they are loaded; therefore relative paths can be resolved smartly.
     If a config object is created with constructor, the base path will be current working directory.
     If it is loaded with ``ConfigBase.load(path)``, the base path will be ``path``'s parent.
+
+    .. attention::
+
+        All the classes that inherit ``ConfigBase`` are not allowed to use ``from __future__ import annotations``,
+        because ``ConfigBase`` uses ``typeguard`` to perform runtime checks and it does not support lazy annotations.
     """

     def __init__(self, **kwargs):
...
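The constraint in the new attention block is easy to verify in isolation. A minimal sketch (not part of this commit) of why PEP 563 lazy annotations defeat runtime checkers such as typeguard: with the future import enabled, dataclass field annotations are stored as plain strings that a checker would first have to resolve.

```python
# Minimal sketch: with `from __future__ import annotations`, annotations are
# stored as strings, so a runtime checker sees 'int' rather than <class 'int'>.
from dataclasses import dataclass, fields

@dataclass
class Demo:
    port: int

for f in fields(Demo):
    # Prints <class 'int'> without the future import; with it, f.type would be
    # the string 'int', which typeguard-style runtime checks cannot use directly.
    print(f.type)
```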
@@ -164,10 +164,11 @@ class ExperimentConfig(ConfigBase):
         # currently I have only seen one issue of this kind
         #Path(self.experiment_working_directory).mkdir(parents=True, exist_ok=True)

-        utils.validate_gpu_indices(self.tuner_gpu_indices)
+        if type(self).__name__ != 'RetiariiExeConfig':
+            utils.validate_gpu_indices(self.tuner_gpu_indices)

-        if self.tuner is None:
-            raise ValueError('ExperimentConfig: tuner must be set')
+            if self.tuner is None:
+                raise ValueError('ExperimentConfig: tuner must be set')

 def _load_search_space_file(search_space_path):
     # FIXME
...
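The string comparison on `type(self).__name__` above is presumably chosen over `isinstance` so the base config module does not have to import `RetiariiExeConfig`, which would create a circular dependency with the NAS package. A self-contained sketch of the pattern (class names here are illustrative, not NNI's):

```python
class BaseConfig:
    def validate(self):
        # Skip the base-only checks for the NAS subclass without importing it.
        if type(self).__name__ != 'NasConfig':
            print('validating tuner GPU indices')

class NasConfig(BaseConfig):
    pass

BaseConfig().validate()  # runs the check
NasConfig().validate()   # skips it
```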
@@ -84,20 +84,9 @@ class Experiment:
         else:
             self.config = config_or_platform

-    def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None:
-        """
-        Start the experiment in background.
-        This method will raise exception on failure.
-        If it returns, the experiment should have been successfully started.
-
-        Parameters
-        ----------
-        port
-            The port of web UI.
-        debug
-            Whether to start in debug mode.
-        """
+    def _start_impl(self, port: int, debug: bool, run_mode: RunMode,
+                    tuner_command_channel: str | None,
+                    tags: list[str] = []) -> ExperimentConfig:
         assert self.config is not None

         if run_mode is not RunMode.Detach:
             atexit.register(self.stop)
@@ -111,7 +100,8 @@ class Experiment:
         log_level = 'debug' if (debug or config.log_level == 'trace') else config.log_level
         start_experiment_logging(self.id, log_file, cast(str, log_level))

-        self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode, self.url_prefix)
+        self._proc = launcher.start_experiment(self._action, self.id, config, port, debug, run_mode,
+                                               self.url_prefix, tuner_command_channel, tags)
         assert self._proc is not None

         self.port = port  # port will be None if start up failed
@@ -124,12 +114,27 @@ class Experiment:
         ips = [f'http://{ip}:{port}' for ip in ips if ip]
         msg = 'Web portal URLs: ${CYAN}' + ' '.join(ips)
         _logger.info(msg)
+        return config

-    def stop(self) -> None:
+    def start(self, port: int = 8080, debug: bool = False, run_mode: RunMode = RunMode.Background) -> None:
         """
-        Stop the experiment.
+        Start the experiment in background.
+        This method will raise exception on failure.
+        If it returns, the experiment should have been successfully started.
+
+        Parameters
+        ----------
+        port
+            The port of web UI.
+        debug
+            Whether to start in debug mode.
+        run_mode
+            Whether to run the experiment in foreground or background.
         """
-        _logger.info('Stopping experiment, please wait...')
+        self._start_impl(port, debug, run_mode, None, [])
+
+    def _stop_impl(self) -> None:
         atexit.unregister(self.stop)
         stop_experiment_logging(self.id)
@@ -144,8 +149,24 @@ class Experiment:
         self.id = None  # type: ignore
         self.port = None
         self._proc = None

+    def stop(self) -> None:
+        """
+        Stop the experiment.
+        """
+        _logger.info('Stopping experiment, please wait...')
+        self._stop_impl()
         _logger.info('Experiment stopped')

+    def _wait_completion(self) -> bool:
+        while True:
+            status = self.get_status()
+            if status == 'DONE' or status == 'STOPPED':
+                return True
+            if status == 'ERROR':
+                return False
+            time.sleep(10)
+
     def run(self, port: int = 8080, wait_completion: bool = True, debug: bool = False) -> bool | None:
         """
         Run the experiment.
@@ -159,13 +180,7 @@ class Experiment:
         self.start(port, debug)
         if wait_completion:
             try:
-                while True:
-                    time.sleep(10)
-                    status = self.get_status()
-                    if status == 'DONE' or status == 'STOPPED':
-                        return True
-                    if status == 'ERROR':
-                        return False
+                self._wait_completion()
             except KeyboardInterrupt:
                 _logger.warning('KeyboardInterrupt detected')
                 self.stop()
...
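The extracted `_wait_completion` helper encodes a simple polling contract: DONE or STOPPED means success, ERROR means failure, anything else means keep waiting. A standalone sketch with an injectable status source, so the loop can be exercised without a live experiment:

```python
import time

def wait_completion(get_status, poll_interval: float = 10.0) -> bool:
    # DONE/STOPPED -> True, ERROR -> False, anything else -> poll again.
    while True:
        status = get_status()
        if status in ('DONE', 'STOPPED'):
            return True
        if status == 'ERROR':
            return False
        time.sleep(poll_interval)

# Fake status source standing in for Experiment.get_status:
statuses = iter(['RUNNING', 'RUNNING', 'DONE'])
print(wait_completion(lambda: next(statuses), poll_interval=0))  # True
```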
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.

 import time
+import warnings
 from typing import Iterable

 from ..graph import Model, ModelStatus
@@ -18,12 +19,12 @@ __all__ = ['get_execution_engine', 'get_and_register_default_listener',

 def set_execution_engine(engine: AbstractExecutionEngine) -> None:
     global _execution_engine
-    if _execution_engine is None:
-        _execution_engine = engine
-    else:
-        raise RuntimeError('Execution engine is already set. '
-                           'You should avoid instantiating RetiariiExperiment twice in one process. '
-                           'If you are running in a Jupyter notebook, please restart the kernel.')
+    if _execution_engine is not None:
+        warnings.warn('Execution engine is already set. '
+                      'You should avoid instantiating RetiariiExperiment twice in one process. '
+                      'If you are running in a Jupyter notebook, please restart the kernel.',
+                      RuntimeWarning)
+    _execution_engine = engine

 def get_execution_engine() -> AbstractExecutionEngine:
...
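This change relaxes the module-level singleton from fail-fast to warn-and-overwrite, which matters mainly for notebook re-runs. A stripped-down sketch of the new behavior:

```python
import warnings

_engine = None

def set_engine(engine) -> None:
    global _engine
    if _engine is not None:
        warnings.warn('Execution engine is already set. '
                      'Avoid instantiating the experiment twice in one process.',
                      RuntimeWarning)
    _engine = engine  # overwrite instead of raising

set_engine('first')
set_engine('second')  # warns but succeeds; a notebook can recover by re-running
print(_engine)        # 'second'
```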
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations

 import logging
 import os
 import random
 import string
 from typing import Any, Dict, Iterable, List

+from nni.experiment import rest
 from .interface import AbstractExecutionEngine, AbstractGraphListener
 from .utils import get_mutation_summary
 from .. import codegen, utils
@@ -54,12 +58,22 @@ class BaseExecutionEngine(AbstractExecutionEngine):
     Resource management is implemented in this class.
     """

-    def __init__(self) -> None:
+    def __init__(self, rest_port: int | None = None, rest_url_prefix: str | None = None) -> None:
         """
         Upon initialization, advisor callbacks need to be registered.
         Advisor will call the callbacks when the corresponding event has been triggered.
         Base execution engine will get those callbacks and broadcast them to graph listener.
+
+        Parameters
+        ----------
+        rest_port
+            The port of the experiment's rest server
+        rest_url_prefix
+            The url prefix of the experiment's rest entry
         """
+        self.port = rest_port
+        self.url_prefix = rest_url_prefix
+
         self._listeners: List[AbstractGraphListener] = []

         # register advisor callbacks
@@ -123,8 +137,8 @@ class BaseExecutionEngine(AbstractExecutionEngine):
         return self.resources

     def budget_exhausted(self) -> bool:
-        advisor = get_advisor()
-        return advisor.stopping
+        resp = rest.get(self.port, '/check-status', self.url_prefix)
+        return resp['status'] == 'DONE'

     @classmethod
     def pack_model_data(cls, model: Model) -> Any:
...
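With this change, `budget_exhausted` no longer consults the in-process advisor but asks the experiment's REST server. A hedged sketch of an equivalent standalone query; the `/api/v1/nni` base path and the localhost host are assumptions based on NNI's default REST layout, not something stated in this diff:

```python
import json
import urllib.request

def budget_exhausted(port: int, url_prefix: str = '') -> bool:
    # Assumed URL layout: http://localhost:<port>[/<prefix>]/api/v1/nni/check-status
    prefix = f'/{url_prefix.strip("/")}' if url_prefix else ''
    url = f'http://localhost:{port}{prefix}/api/v1/nni/check-status'
    with urllib.request.urlopen(url) as resp:
        body = json.load(resp)
    # The experiment is considered exhausted once the rest server reports DONE.
    return body['status'] == 'DONE'
```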
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from __future__ import annotations

 import logging
 import os
 import random
 import string
 import time
 import threading
-from typing import Iterable, List, Dict, Tuple
+from typing import Iterable, List, Dict, Tuple, cast
 from dataclasses import dataclass

 from nni.common.device import GPUDevice, Device
+from nni.experiment.config.training_services import RemoteConfig
 from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
 from .. import codegen, utils
 from ..graph import Model, ModelStatus, MetricData, Node
@@ -31,7 +34,6 @@ class TrialSubmission:
     placement: Dict[Node, Device]
     grouped_models: List[Model]

 class CGOExecutionEngine(AbstractExecutionEngine):
     """
     The execution engine with Cross-Graph Optimization (CGO).
@@ -41,24 +43,35 @@ class CGOExecutionEngine(AbstractExecutionEngine):
     Parameters
     ----------
-    devices : List[Device]
-        Available devices for execution.
-    max_concurrency : int
+    training_service
+        The remote training service config.
+    max_concurrency
         The maximum number of trials to run concurrently.
-    batch_waiting_time: int
+    batch_waiting_time
         Seconds to wait for each batch of trial submission.
         The trials within one batch could apply cross-graph optimization.
+    rest_port
+        The port of the experiment's rest server
+    rest_url_prefix
+        The url prefix of the experiment's rest entry
     """

-    def __init__(self, devices: List[Device] = None,
+    def __init__(self, training_service: RemoteConfig,
                  max_concurrency: int = None,
                  batch_waiting_time: int = 60,
+                 rest_port: int | None = None,
+                 rest_url_prefix: str | None = None
                  ) -> None:
+        self.port = rest_port
+        self.url_prefix = rest_url_prefix
+
         self._listeners: List[AbstractGraphListener] = []
         self._running_models: Dict[int, Model] = dict()
         self.logical_plan_counter = 0
         self.available_devices: List[Device] = []
         self.max_concurrency: int = max_concurrency
+
+        devices = self._construct_devices(training_service)
         for device in devices:
             self.available_devices.append(device)
         self.all_devices = self.available_devices.copy()
@@ -88,6 +101,17 @@ class CGOExecutionEngine(AbstractExecutionEngine):
         self._consumer_thread = threading.Thread(target=self._consume_models)
         self._consumer_thread.start()

+    def _construct_devices(self, training_service):
+        devices = []
+        if hasattr(training_service, 'machine_list'):
+            for machine in cast(RemoteConfig, training_service).machine_list:
+                assert machine.gpu_indices is not None, \
+                    'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
+                assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
+                for gpu_idx in machine.gpu_indices:
+                    devices.append(GPUDevice(machine.host, gpu_idx))
+        return devices
+
     def join(self):
         self._stopped = True
         self._consumer_thread.join()
...
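The new `_construct_devices` simply flattens each remote machine's `gpu_indices` into per-GPU device objects. A self-contained sketch of the same logic, using plain dataclasses instead of NNI's config classes to stay runnable on its own:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class GPUDevice:
    host: str
    gpu_index: int

@dataclass
class Machine:
    host: str
    gpu_indices: List[int]

@dataclass
class Remote:
    machine_list: List[Machine] = field(default_factory=list)

def construct_devices(ts: Remote) -> List[GPUDevice]:
    # One device entry per (machine, gpu index) pair.
    return [GPUDevice(m.host, i) for m in ts.machine_list for i in m.gpu_indices]

remote = Remote([Machine('10.0.0.1', [0, 1]), Machine('10.0.0.2', [0])])
print(construct_devices(remote))  # three GPUDevice entries
```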
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .experiment_config import *
from .engine_config import *
\ No newline at end of file
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from dataclasses import dataclass
from typing import Optional, List
from nni.experiment.config.base import ConfigBase
__all__ = ['ExecutionEngineConfig', 'BaseEngineConfig', 'OneshotEngineConfig',
'PyEngineConfig', 'CgoEngineConfig', 'BenchmarkEngineConfig']
@dataclass(init=False)
class ExecutionEngineConfig(ConfigBase):
name: str
@dataclass(init=False)
class PyEngineConfig(ExecutionEngineConfig):
name: str = 'py'
@dataclass(init=False)
class OneshotEngineConfig(ExecutionEngineConfig):
name: str = 'oneshot'
@dataclass(init=False)
class BaseEngineConfig(ExecutionEngineConfig):
name: str = 'base'
# Input used in GraphConverterWithShape. Currently supports shape tuples only.
dummy_input: Optional[List[int]] = None
@dataclass(init=False)
class CgoEngineConfig(ExecutionEngineConfig):
name: str = 'cgo'
max_concurrency_cgo: Optional[int] = None
batch_waiting_time: Optional[int] = None
# Input used in GraphConverterWithShape. Currently supports shape tuples only.
dummy_input: Optional[List[int]] = None
@dataclass(init=False)
class BenchmarkEngineConfig(ExecutionEngineConfig):
name: str = 'benchmark'
benchmark: Optional[str] = None
\ No newline at end of file
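With these dataclasses, engine selection moves from bare strings to typed config objects. Usage sketch, mirroring the updated CGO example near the bottom of this commit (which imports `CgoEngineConfig` from `nni.retiarii.experiment.pytorch`):

```python
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, CgoEngineConfig

exp_config = RetiariiExeConfig('remote')
# Engine-specific knobs now live on the engine config, not the experiment config.
exp_config.execution_engine = CgoEngineConfig(max_concurrency_cgo=3,
                                              batch_waiting_time=0)
```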
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
from dataclasses import dataclass
from typing import Any, Union
from nni.experiment.config import utils, ExperimentConfig
from .engine_config import ExecutionEngineConfig
__all__ = ['RetiariiExeConfig']
def execution_engine_config_factory(engine_name):
# FIXME: may move this function to experiment utils in future
cls = _get_ee_config_class(engine_name)
if cls is None:
raise ValueError(f'Invalid execution engine name: {engine_name}')
return cls()
def _get_ee_config_class(engine_name):
for cls in ExecutionEngineConfig.__subclasses__():
if cls.name == engine_name:
return cls
return None
@dataclass(init=False)
class RetiariiExeConfig(ExperimentConfig):
# FIXME: refactor this class to inherit from a new common base class with HPO config
search_space: Any = ''
trial_code_directory: utils.PathLike = '.'
trial_command: str = '_reserved'
# new config field for NAS
execution_engine: Union[str, ExecutionEngineConfig]
def __init__(self, training_service_platform: Union[str, None] = None,
execution_engine: Union[str, ExecutionEngineConfig] = 'py',
**kwargs):
super().__init__(training_service_platform, **kwargs)
self.execution_engine = execution_engine
def _canonicalize(self, _parents):
msg = '{} is not supposed to be set in Retiarii experiment by users, your config is {}.'
if self.search_space != '':
raise ValueError(msg.format('search_space', self.search_space))
# TODO: maybe we should also allow users to specify trial_code_directory
if str(self.trial_code_directory) != '.' and not os.path.isabs(self.trial_code_directory):
raise ValueError(msg.format('trial_code_directory', self.trial_code_directory))
if self.trial_command != '_reserved' and \
not self.trial_command.startswith('python3 -m nni.retiarii.trial_entry '):
raise ValueError(msg.format('trial_command', self.trial_command))
if isinstance(self.execution_engine, str):
self.execution_engine = execution_engine_config_factory(self.execution_engine)
if self.execution_engine.name in ('py', 'base', 'cgo'):
# TODO: replace python3 with more elegant approach
# maybe use sys.executable rendered in trial side (e.g., trial_runner)
self.trial_command = 'python3 -m nni.retiarii.trial_entry ' + self.execution_engine.name
super()._canonicalize([self])
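The factory above resolves an engine name by scanning `ExecutionEngineConfig.__subclasses__()`, so a new engine registers itself simply by subclassing. A self-contained sketch of the same pattern (note that `__subclasses__()` only sees direct subclasses):

```python
from dataclasses import dataclass

@dataclass(init=False)
class EngineConfig:
    name: str

@dataclass(init=False)
class PyEngine(EngineConfig):
    name: str = 'py'

@dataclass(init=False)
class CgoEngine(EngineConfig):
    name: str = 'cgo'

def engine_factory(engine_name: str) -> EngineConfig:
    # Subclassing is the registration mechanism: no explicit registry needed.
    for cls in EngineConfig.__subclasses__():
        if cls.name == engine_name:
            return cls()
    raise ValueError(f'Invalid execution engine name: {engine_name}')

print(type(engine_factory('cgo')).__name__)  # CgoEngine
```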
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-import atexit
+from __future__ import annotations
+
 import logging
-import os
-import socket
-import time
 import warnings
-from dataclasses import dataclass
-from pathlib import Path
-from subprocess import Popen
 from threading import Thread
-from typing import Any, List, Optional, Union, cast
+from typing import Any, List, Union, cast

 import colorama
-import psutil
 import torch
 import torch.nn as nn

-import nni.runtime.log
-from nni.common.device import GPUDevice
-from nni.experiment import Experiment, RunMode, launcher, management, rest
-from nni.experiment.config import utils
-from nni.experiment.config.base import ConfigBase
-from nni.experiment.config.training_service import TrainingServiceConfig
+from nni.experiment import Experiment, RunMode
 from nni.experiment.config.training_services import RemoteConfig
-from nni.runtime.tuner_command_channel import TunerCommandChannel
-from nni.tools.nnictl.command_utils import kill_command
+from .config import (
+    RetiariiExeConfig, OneshotEngineConfig, BaseEngineConfig,
+    PyEngineConfig, CgoEngineConfig, BenchmarkEngineConfig
+)
 from ..codegen import model_to_pytorch_script
 from ..converter import convert_to_graph
 from ..converter.graph_gen import GraphConverterWithShape
@@ -46,79 +39,7 @@ from ..strategy.utils import dry_run_for_formatted_search_space

 _logger = logging.getLogger(__name__)

-__all__ = ['RetiariiExeConfig', 'RetiariiExperiment']
+__all__ = ['RetiariiExperiment']
-@dataclass(init=False)
-class RetiariiExeConfig(ConfigBase):
-    experiment_name: Optional[str] = None
-    search_space: Any = ''  # TODO: remove
-    trial_command: str = '_reserved'
-    trial_code_directory: utils.PathLike = '.'
-    trial_concurrency: int
-    trial_gpu_number: int = 0
-    devices: Optional[List[Union[str, GPUDevice]]] = None
-    max_experiment_duration: Optional[str] = None
-    max_trial_number: Optional[int] = None
-    max_concurrency_cgo: Optional[int] = None
-    batch_waiting_time: Optional[int] = None
-    nni_manager_ip: Optional[str] = None
-    debug: bool = False
-    log_level: str = 'info'
-    experiment_working_directory: utils.PathLike = '~/nni-experiments'
-    # remove configuration of tuner/assessor/advisor
-    training_service: TrainingServiceConfig
-    execution_engine: str = 'py'
-
-    # input used in GraphConverterWithShape. Currently support shape tuple only.
-    dummy_input: Optional[List[int]] = None
-
-    # input used for benchmark engine.
-    benchmark: Optional[str] = None
-
-    def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
-        super().__init__(**kwargs)
-        if training_service_platform is not None:
-            assert 'training_service' not in kwargs
-            self.training_service = utils.training_service_config_factory(platform=training_service_platform)
-        self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry py'
-
-    def __setattr__(self, key, value):
-        fixed_attrs = {'search_space': '',
-                       'trial_command': '_reserved'}
-        if key in fixed_attrs and fixed_attrs[key] != value:
-            raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
-        # 'trial_code_directory' is handled differently because the path will be converted to absolute path by us
-        if key == 'trial_code_directory' and not (str(value) == '.' or os.path.isabs(value)):
-            raise AttributeError(f'{key} is not supposed to be set in Retiarii mode by users!')
-        if key == 'execution_engine':
-            assert value in ['base', 'py', 'cgo', 'benchmark', 'oneshot'], f'The specified execution engine "{value}" is not supported.'
-            self.__dict__['trial_command'] = 'python3 -m nni.retiarii.trial_entry ' + value
-        self.__dict__[key] = value
-
-    def validate(self, initialized_tuner: bool = False) -> None:
-        super().validate()
-
-    @property
-    def _canonical_rules(self):
-        return _canonical_rules
-
-    @property
-    def _validation_rules(self):
-        return _validation_rules
-
-_canonical_rules = {
-}
-
-_validation_rules = {
-    'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not directory'),
-    'trial_concurrency': lambda value: value > 0,
-    'trial_gpu_number': lambda value: value >= 0,
-    'max_trial_number': lambda value: value > 0,
-    'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"],
-    'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
-}
-
 def preprocess_model(base_model, evaluator, applied_mutators, full_ir=True, dummy_input=None, oneshot=False):
@@ -252,9 +173,14 @@ class RetiariiExperiment(Experiment):
     ...     final_model = Net()
     """

-    def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None),
-                 applied_mutators: List[Mutator] = cast(List[Mutator], None), strategy: BaseStrategy = cast(BaseStrategy, None),
+    def __init__(self, base_model: nn.Module,
+                 evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None),
+                 applied_mutators: List[Mutator] = cast(List[Mutator], None),
+                 strategy: BaseStrategy = cast(BaseStrategy, None),
                  trainer: BaseOneShotTrainer = cast(BaseOneShotTrainer, None)):
+        super().__init__(None)
+        self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
+
         if trainer is not None:
             warnings.warn('Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
                           'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)
@@ -263,25 +189,13 @@ class RetiariiExperiment(Experiment):
         if evaluator is None:
             raise ValueError('Evaluator should not be none.')

-        # TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
-        self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
-        self.port: Optional[int] = None
-
         self.base_model = base_model
         self.evaluator: Union[Evaluator, BaseOneShotTrainer] = evaluator
         self.applied_mutators = applied_mutators
         self.strategy = strategy

-        from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy
-        if not isinstance(strategy, OneShotStrategy):
-            # FIXME: Dispatcher should not be created this early.
-            self._dispatcher = RetiariiAdvisor('_placeholder_')
-        else:
-            self._dispatcher = cast(RetiariiAdvisor, None)
-        self._dispatcher_thread: Optional[Thread] = None
-        self._proc: Optional[Popen] = None
-        self.url_prefix = None
+        self._dispatcher = None
+        self._dispatcher_thread = None

         # check for sanity
         if not is_model_wrapped(base_model):
@@ -290,11 +204,12 @@ class RetiariiExperiment(Experiment):
                           'but it may cause inconsistent behavior compared to the time when you add it.' + colorama.Style.RESET_ALL,
                           RuntimeWarning)

-    def _start_strategy(self):
+    def _run_strategy(self, config: RetiariiExeConfig):
         base_model_ir, self.applied_mutators = preprocess_model(
             self.base_model, self.evaluator, self.applied_mutators,
-            full_ir=self.config.execution_engine not in ['py', 'benchmark'],
-            dummy_input=self.config.dummy_input
+            full_ir=not isinstance(config.execution_engine, (PyEngineConfig, BenchmarkEngineConfig)),
+            dummy_input=config.execution_engine.dummy_input
+            if isinstance(config.execution_engine, (BaseEngineConfig, CgoEngineConfig)) else None
         )

         _logger.info('Start strategy...')
@@ -303,102 +218,49 @@ class RetiariiExperiment(Experiment):
         self.strategy.run(base_model_ir, self.applied_mutators)
         _logger.info('Strategy exit')
         # TODO: find out a proper way to show no more trial message on WebUI
-        # self._dispatcher.mark_experiment_as_ending()

-    def start(self, port: int = 8080, debug: bool = False) -> None:
-        """
-        Start the experiment in background.
-        This method will raise exception on failure.
-        If it returns, the experiment should have been successfully started.
-
-        Parameters
-        ----------
-        port
-            The port of web UI.
-        debug
-            Whether to start in debug mode.
-        """
-        atexit.register(self.stop)
-
-        self.config = self.config.canonical_copy()
-
-        # we will probably need a execution engine factory to make this clean and elegant
-        if self.config.execution_engine == 'base':
+    def _create_execution_engine(self, config: RetiariiExeConfig) -> None:
+        # TODO: we will probably need an execution engine factory to make this clean and elegant
+        if isinstance(config.execution_engine, BaseEngineConfig):
             from ..execution.base import BaseExecutionEngine
-            engine = BaseExecutionEngine()
-        elif self.config.execution_engine == 'cgo':
+            engine = BaseExecutionEngine(self.port, self.url_prefix)
+        elif isinstance(config.execution_engine, CgoEngineConfig):
             from ..execution.cgo_engine import CGOExecutionEngine
-            assert self.config.training_service.platform == 'remote', \
+            assert not isinstance(config.training_service, list) \
+                and config.training_service.platform == 'remote', \
                 "CGO execution engine currently only supports remote training service"
-            assert self.config.batch_waiting_time is not None and self.config.max_concurrency_cgo is not None
-            devices = self._construct_devices()
-            engine = CGOExecutionEngine(devices,
-                                        max_concurrency=self.config.max_concurrency_cgo,
-                                        batch_waiting_time=self.config.batch_waiting_time)
-        elif self.config.execution_engine == 'py':
+            assert config.execution_engine.batch_waiting_time is not None \
+                and config.execution_engine.max_concurrency_cgo is not None
+            engine = CGOExecutionEngine(cast(RemoteConfig, config.training_service),
+                                        max_concurrency=config.execution_engine.max_concurrency_cgo,
+                                        batch_waiting_time=config.execution_engine.batch_waiting_time,
+                                        rest_port=self.port,
+                                        rest_url_prefix=self.url_prefix)
+        elif isinstance(config.execution_engine, PyEngineConfig):
             from ..execution.python import PurePythonExecutionEngine
-            engine = PurePythonExecutionEngine()
-        elif self.config.execution_engine == 'benchmark':
+            engine = PurePythonExecutionEngine(self.port, self.url_prefix)
+        elif isinstance(config.execution_engine, BenchmarkEngineConfig):
             from ..execution.benchmark import BenchmarkExecutionEngine
-            assert self.config.benchmark is not None, '"benchmark" must be set when benchmark execution engine is used.'
-            engine = BenchmarkExecutionEngine(self.config.benchmark)
+            assert config.execution_engine.benchmark is not None, \
+                '"benchmark" must be set when benchmark execution engine is used.'
+            engine = BenchmarkExecutionEngine(config.execution_engine.benchmark)
         else:
-            raise ValueError(f'Unsupported engine type: {self.config.execution_engine}')
+            raise ValueError(f'Unsupported engine type: {config.execution_engine}')
         set_execution_engine(engine)

-        self.id = management.generate_experiment_id()
-
-        log_file = Path(self.config.experiment_working_directory, self.id, 'log', 'experiment.log')
-        log_file.parent.mkdir(parents=True, exist_ok=True)
-        log_level = 'debug' if (debug or self.config.log_level == 'trace') else self.config.log_level
-        nni.runtime.log.start_experiment_logging(self.id, log_file, cast(str, log_level))
-
-        ws_url = f'ws://localhost:{port}/tuner'
-        self._proc = launcher.start_experiment('create', self.id, self.config, port, debug,  # type: ignore
-                                               RunMode.Background, None, ws_url, ['retiarii'])
-        assert self._proc is not None
-
-        self.port = port  # port will be None if start up failed
-
-        # dispatcher must be launched after pipe initialized
-        # the logic to launch dispatcher in background should be refactored into dispatcher api
-        self._dispatcher = self._create_dispatcher()
-        if self._dispatcher is not None:
-            self._dispatcher._channel = TunerCommandChannel(ws_url)
-            self._dispatcher_thread = Thread(target=self._dispatcher.run)
-            self._dispatcher_thread.start()
-
-        ips = [self.config.nni_manager_ip]
-        for interfaces in psutil.net_if_addrs().values():
-            for interface in interfaces:
-                if interface.family == socket.AF_INET:
-                    ips.append(interface.address)
-        ips = [f'http://{ip}:{port}' for ip in ips if ip]
-        msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips) + colorama.Style.RESET_ALL
-        _logger.info(msg)
-
-        exp_status_checker = Thread(target=self._check_exp_status)
-        exp_status_checker.start()
-        self._start_strategy()
-        # TODO: the experiment should be completed, when strategy exits and there is no running job
-        _logger.info('Waiting for experiment to become DONE (you can ctrl+c if there is no running trial jobs)...')
-        exp_status_checker.join()
-
-    def _construct_devices(self):
-        devices = []
-        if hasattr(self.config.training_service, 'machine_list'):
-            for machine in cast(RemoteConfig, self.config.training_service).machine_list:
-                assert machine.gpu_indices is not None, \
-                    'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
-                assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
-                for gpu_idx in machine.gpu_indices:
-                    devices.append(GPUDevice(machine.host, gpu_idx))
-        return devices
-
-    def _create_dispatcher(self):
-        return self._dispatcher
+    def start(self, *args, **kwargs) -> None:
+        """
+        By design, the only difference between `start` and `run` is that `start` is asynchronous,
+        while `run` waits for the experiment to complete. RetiariiExperiment always waits for the
+        experiment to complete, as the strategy runs in the foreground.
+        """
+        raise NotImplementedError('RetiariiExperiment is not supposed to provide `start` method')

-    def run(self, config: Optional[RetiariiExeConfig] = None, port: int = 8080, debug: bool = False) -> None:
+    def run(self,
+            config: RetiariiExeConfig | None = None,
+            port: int = 8080,
+            debug: bool = False) -> None:
         """
         Run the experiment.
         This function will block until experiment finish or error.
@@ -410,75 +272,47 @@ class RetiariiExperiment(Experiment):
             # 'In case you want to stick to the old implementation, '
             # 'please consider using ``trainer.fit()`` instead of experiment.', DeprecationWarning)
             self.evaluator.fit()
+            return

         if config is None:
             warnings.warn('config = None is deprecate in future. If you are running a one-shot experiment, '
                           'please consider creating a config and set execution engine to `oneshot`.', DeprecationWarning)
-            config = RetiariiExeConfig()
-            config.execution_engine = 'oneshot'
+            self.config = RetiariiExeConfig()
+            self.config.execution_engine = OneshotEngineConfig()
+        else:
+            self.config = config

-        if config.execution_engine == 'oneshot':
+        if isinstance(self.config.execution_engine, OneshotEngineConfig) \
+                or (isinstance(self.config.execution_engine, str) and self.config.execution_engine == 'oneshot'):
+            # this is hacky, will be refactored when oneshot can run on training services
             base_model_ir, self.applied_mutators = preprocess_model(self.base_model, self.evaluator, self.applied_mutators, oneshot=True)
             self.strategy.run(base_model_ir, self.applied_mutators)
         else:
-            assert config is not None, 'You are using classic search mode, config cannot be None!'
-            self.config = config
-            self.start(port, debug)
-
-    def _check_exp_status(self) -> bool:
-        """
-        Run the experiment.
-        This function will block until experiment finish or error.
-        Return `True` when experiment done; or return `False` when experiment failed.
-        """
-        assert self._proc is not None
-        try:
-            while True:
-                time.sleep(10)
-                # this if is to deal with the situation that
-                # nnimanager is cleaned up by ctrl+c first
-                if self._proc.poll() is None:
-                    status = self.get_status()
-                else:
-                    return False
-                if status == 'DONE' or status == 'STOPPED':
-                    return True
-                if status == 'ERROR':
-                    return False
-        except KeyboardInterrupt:
-            _logger.warning('KeyboardInterrupt detected')
-        finally:
-            self.stop()
-        raise RuntimeError('Check experiment status failed.')
+            ws_url = f'ws://localhost:{port}/tuner'
+            canonicalized_config = self._start_impl(port, debug, RunMode.Background, ws_url, ['retiarii'])
+            canonicalized_config = cast(RetiariiExeConfig, canonicalized_config)
+            self._dispatcher = RetiariiAdvisor(ws_url)
+            self._dispatcher_thread = Thread(target=self._dispatcher.run, daemon=True)
+            self._dispatcher_thread.start()
+            # FIXME: engine cannot be created twice
+            self._create_execution_engine(canonicalized_config)
+            try:
+                self._run_strategy(canonicalized_config)
+                # FIXME: move this logic to strategy with a new API provided by execution engine
+                self._wait_completion()
+            except KeyboardInterrupt:
+                _logger.warning('KeyboardInterrupt detected')
+                self.stop()
+            _logger.info('Search process is done, the experiment is still alive, `stop()` can terminate the experiment.')
     def stop(self) -> None:
         """
         Stop background experiment.
         """
         _logger.info('Stopping experiment, please wait...')
-        atexit.unregister(self.stop)
-
-        # stop strategy first
-        if self._dispatcher_thread is not None:
-            self._dispatcher.stopping = True
-            self._dispatcher_thread.join(timeout=1)
-
-        if self.id is not None:
-            nni.runtime.log.stop_experiment_logging(self.id)
-        if self._proc is not None:
-            try:
-                # this if is to deal with the situation that
-                # nnimanager is cleaned up by ctrl+c first
-                if self._proc.poll() is None:
-                    rest.delete(self.port, '/experiment')
-            except Exception as e:
-                _logger.exception(e)
-                _logger.warning('Cannot gracefully stop experiment, killing NNI process...')
-                kill_command(self._proc.pid)
-
-        self.id = cast(str, None)
-        self.port = cast(int, None)
-        self._proc = None
+        self._stop_impl()
+        if self._dispatcher_thread:
+            self._dispatcher_thread.join()
         self._dispatcher = cast(RetiariiAdvisor, None)
         self._dispatcher_thread = None
         _logger.info('Experiment stopped')
@@ -502,8 +336,11 @@ class RetiariiExperiment(Experiment):
             If ``code``, the python code of model will be returned.
             If ``dict``, the mutation history will be returned.
         """
+        # TODO: the base class may also need this method
         if formatter == 'code':
-            assert self.config.execution_engine != 'py', 'You should use `dict` formatter when using Python execution engine.'
+            config = self.config.canonical_copy()
+            assert not isinstance(config.execution_engine, PyEngineConfig), \
+                'You should use `dict` formatter when using Python execution engine.'
         if isinstance(self.evaluator, BaseOneShotTrainer):
             assert top_k == 1, 'Only support top_k is 1 for now.'
             return self.evaluator.export()
@@ -520,9 +357,3 @@ class RetiariiExperiment(Experiment):
             return [model_to_pytorch_script(model) for model in all_models[:top_k]]
         elif formatter == 'dict':
             return [get_mutation_dict(model) for model in all_models[:top_k]]
-
-    def retrain_model(self, model):
-        """
-        this function retrains the exported model, and test it to output test accuracy
-        """
-        raise NotImplementedError
@@ -22,7 +22,10 @@ def get_advisor() -> 'RetiariiAdvisor':

 def register_advisor(advisor: 'RetiariiAdvisor'):
     global _advisor
-    assert _advisor is None
+    if _advisor is not None:
+        warnings.warn('Advisor is already set. '
+                      'You should avoid instantiating RetiariiExperiment twice in one process. '
+                      'If you are running in a Jupyter notebook, please restart the kernel.')
     _advisor = advisor
...
@@ -18,8 +18,15 @@ _worker_fast_exit_on_terminate = True

 class MsgDispatcherBase(Recoverable):
-    """This is where tuners and assessors are not defined yet.
+    """
+    This is where tuners and assessors are not defined yet.
     Inherits this class to make your own advisor.
+
+    .. note::
+        The class inheriting MsgDispatcherBase should be instantiated
+        after nnimanager (rest server) is started, so that the object
+        is ready to use right after its instantiation.
     """

     def __init__(self, command_channel_url=None):
@@ -27,6 +34,16 @@ class MsgDispatcherBase(Recoverable):
         if command_channel_url is None:
             command_channel_url = dispatcher_env_vars.NNI_TUNER_COMMAND_CHANNEL
         self._channel = TunerCommandChannel(command_channel_url)
+        # NOTE: `connect()` should be put in __init__. First, this `connect()` affects nnimanager's
+        # starting process; without `connect()` nnimanager is blocked in `dispatcher.init()`.
+        # Second, the nas experiment uses a thread to execute `run()` of this class, so there is
+        # no way to know when the websocket between nnimanager and dispatcher is built. The following
+        # logic may crash if the websocket is not built. One example is updating the search space:
+        # if it is updated too soon, while the websocket has not been built, the rest api for
+        # updating the search space will time out.
+        # FIXME: this is for making unittest happy
+        if not command_channel_url.startswith('ws://_unittest_'):
+            self._channel.connect()
         self.default_command_queue = Queue()
         self.assessor_command_queue = Queue()
         self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
@@ -39,7 +56,6 @@ class MsgDispatcherBase(Recoverable):
         """
         _logger.info('Dispatcher started')

-        self._channel.connect()
         self.default_worker.start()
         self.assessor_worker.start()
...
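The ordering concern in the NOTE above reduces to a small race: if the channel connects lazily inside `run()` on a background thread, an early caller can race it; connecting eagerly in `__init__` removes the race. An illustrative sketch (not NNI code):

```python
import threading

class Dispatcher:
    def __init__(self):
        self.connected = False
        self.connect()  # eager connect, as in the change above

    def connect(self):
        self.connected = True

    def run(self):
        # With lazy connection this assert could fire, depending on scheduling.
        assert self.connected, 'channel must be connected before run()'

d = Dispatcher()
threading.Thread(target=d.run).start()  # no race: connection already exists
# A REST-triggered call (e.g., updating the search space) arriving right after
# construction would likewise find the channel ready.
```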
@@ -4,7 +4,6 @@ import warnings
 import torch
 import torch.nn as torch_nn
-from torchvision.models.utils import load_state_dict_from_url
 import torch.nn.functional as F
 import sys
...
@@ -8,7 +8,7 @@ import nni.retiarii.evaluator.pytorch.cgo.evaluator as cgo
 from nni.retiarii import serialize
 from base_mnasnet import MNASNet
 from nni.experiment import RemoteMachineConfig
-from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
+from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig, CgoEngineConfig
 from nni.retiarii.strategy import TPEStrategy
 from torchvision import transforms
 from torchvision.datasets import CIFAR10
@@ -59,8 +59,6 @@ if __name__ == '__main__':
     exp_config.max_trial_number = 10
     exp_config.trial_gpu_number = 1
     exp_config.training_service.reuse_mode = True
-    exp_config.max_concurrency_cgo = 3
-    exp_config.batch_waiting_time = 0

     rm_conf = RemoteMachineConfig()
     rm_conf.host = '127.0.0.1'
@@ -73,6 +71,6 @@ if __name__ == '__main__':
     rm_conf.max_trial_number_per_gpu = 3
     exp_config.training_service.machine_list = [rm_conf]

-    exp_config.execution_engine = 'cgo'
+    exp_config.execution_engine = CgoEngineConfig(max_concurrency_cgo=3, batch_waiting_time=0)

     exp.run(exp_config, 8099)
\ No newline at end of file
@@ -9,6 +9,7 @@ from pytorch_lightning.utilities.seed import seed_everything
 from pathlib import Path

 import nni
+from nni.experiment.config import RemoteConfig, RemoteMachineConfig
 import nni.runtime.platform.test
 from nni.runtime.tuner_command_channel import legacy as protocol
 import json
@@ -263,13 +264,14 @@ class CGOEngineTest(unittest.TestCase):
         opt = DedupInputOptimizer()
         opt.convert(lp)

-        advisor = RetiariiAdvisor('ws://_placeholder_')
+        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
         advisor._channel = protocol.LegacyCommandChannel()
         advisor.default_worker.start()
         advisor.assessor_worker.start()

-        available_devices = [GPUDevice("test", 0), GPUDevice("test", 1), GPUDevice("test", 2), GPUDevice("test", 3)]
-        cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0)
+        remote = RemoteConfig(machine_list=[])
+        remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0, 1, 2, 3]))
+        cgo = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)

         phy_models = cgo._assemble(lp)
         self.assertTrue(len(phy_models) == 1)
@@ -286,13 +288,14 @@ class CGOEngineTest(unittest.TestCase):
         opt = DedupInputOptimizer()
         opt.convert(lp)

-        advisor = RetiariiAdvisor('ws://_placeholder_')
+        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
         advisor._channel = protocol.LegacyCommandChannel()
         advisor.default_worker.start()
         advisor.assessor_worker.start()

-        available_devices = [GPUDevice("test", 0), GPUDevice("test", 1)]
-        cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0)
+        remote = RemoteConfig(machine_list=[])
+        remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0, 1]))
+        cgo = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)

         phy_models = cgo._assemble(lp)
         self.assertTrue(len(phy_models) == 2)
@@ -311,13 +314,14 @@ class CGOEngineTest(unittest.TestCase):
         models = _load_mnist(2)

-        advisor = RetiariiAdvisor('ws://_placeholder_')
+        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
         advisor._channel = protocol.LegacyCommandChannel()
         advisor.default_worker.start()
         advisor.assessor_worker.start()

-        cgo_engine = CGOExecutionEngine(devices=[GPUDevice("test", 0), GPUDevice("test", 1),
-                                                 GPUDevice("test", 2), GPUDevice("test", 3)], batch_waiting_time=0)
+        remote = RemoteConfig(machine_list=[])
+        remote.machine_list.append(RemoteMachineConfig(host='test', gpu_indices=[0, 1, 2, 3]))
+        cgo_engine = CGOExecutionEngine(training_service=remote, batch_waiting_time=0)
         set_execution_engine(cgo_engine)

         submit_models(*models)
         time.sleep(3)
...
@@ -25,7 +25,7 @@ class EngineTest(unittest.TestCase):
     def test_base_execution_engine(self):
         nni.retiarii.integration_api._advisor = None
         nni.retiarii.execution.api._execution_engine = None
-        advisor = RetiariiAdvisor('ws://_placeholder_')
+        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
         advisor._channel = LegacyCommandChannel()
         advisor.default_worker.start()
         advisor.assessor_worker.start()
@@ -42,7 +42,7 @@ class EngineTest(unittest.TestCase):
     def test_py_execution_engine(self):
         nni.retiarii.integration_api._advisor = None
         nni.retiarii.execution.api._execution_engine = None
-        advisor = RetiariiAdvisor('ws://_placeholder_')
+        advisor = RetiariiAdvisor('ws://_unittest_placeholder_')
         advisor._channel = LegacyCommandChannel()
         advisor.default_worker.start()
         advisor.assessor_worker.start()
...
@@ -57,7 +57,7 @@ class AssessorTestCase(TestCase):
         _restore_io()

         assessor = NaiveAssessor()
-        dispatcher = MsgDispatcher('ws://_placeholder_', None, assessor)
+        dispatcher = MsgDispatcher('ws://_unittest_placeholder_', None, assessor)
         dispatcher._channel = LegacyCommandChannel()

         msg_dispatcher_base._worker_fast_exit_on_terminate = False
...
@@ -66,7 +66,7 @@ class MsgDispatcherTestCase(TestCase):
         _restore_io()

         tuner = NaiveTuner()
-        dispatcher = MsgDispatcher('ws://_placeholder_', tuner)
+        dispatcher = MsgDispatcher('ws://_unittest_placeholder_', tuner)
         dispatcher._channel = LegacyCommandChannel()

         msg_dispatcher_base._worker_fast_exit_on_terminate = False
...
@@ -303,8 +303,11 @@ class NNIManager implements Manager {
         }
         this.trainingService.removeTrialJobMetricListener(this.trialJobMetricListener);
+        // NOTE: sending TERMINATE should happen outside the if clause below,
+        // because when the python dispatcher is started before nnimanager,
+        // this.dispatcherPid would not have a valid value (i.e., not > 0).
+        this.dispatcher.sendCommand(TERMINATE);
         if (this.dispatcherPid > 0) {
-            this.dispatcher.sendCommand(TERMINATE);
             // gracefully terminate tuner and assessor here, wait at most 30 seconds.
             for (let i: number = 0; i < 30; i++) {
                 if (!await isAlive(this.dispatcherPid)) {
...