Unverified Commit 98c1a77f authored by liuzhe-lz, committed by GitHub

Support multiple HPO experiments in one process (#4855)

parent 5dc80762
......@@ -9,7 +9,6 @@ import base64
from .runtime.msg_dispatcher import MsgDispatcher
from .runtime.msg_dispatcher_base import MsgDispatcherBase
from .runtime.protocol import connect_websocket
from .tools.package_utils import create_builtin_class_instance, create_customized_class_instance
logger = logging.getLogger('nni.main')
......@@ -21,10 +20,6 @@ if os.environ.get('COVERAGE_PROCESS_START'):
def main():
# the url should be "ws://localhost:{port}/tuner" or "ws://localhost:{port}/{url_prefix}/tuner"
ws_url = os.environ['NNI_TUNER_COMMAND_CHANNEL']
connect_websocket(ws_url)
parser = argparse.ArgumentParser(description='Dispatcher command line parser')
parser.add_argument('--exp_params', type=str, required=True)
args, _ = parser.parse_known_args()
......@@ -56,7 +51,10 @@ def main():
assessor = _create_algo(exp_params['assessor'], 'assessor')
else:
assessor = None
dispatcher = MsgDispatcher(tuner, assessor)
# the url should be "ws://localhost:{port}/tuner" or "ws://localhost:{port}/{url_prefix}/tuner"
url = os.environ['NNI_TUNER_COMMAND_CHANNEL']
dispatcher = MsgDispatcher(url, tuner, assessor)
try:
dispatcher.run()
......
......@@ -14,7 +14,7 @@ from ConfigSpace.read_and_write import pcs_new
import nni
from nni import ClassArgsValidator
from nni.runtime.protocol import CommandType, send
from nni.runtime.tuner_command_channel import CommandType
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.utils import OptimizeMode, MetricType, extract_scalar_reward
from nni.runtime.common import multi_phase_enabled
......@@ -483,7 +483,7 @@ class BOHB(MsgDispatcherBase):
raise ValueError('Error: Search space is None')
# generate the first bracket
self.generate_new_bracket()
send(CommandType.Initialized, '')
self.send(CommandType.Initialized, '')
def generate_new_bracket(self):
"""generate a new bracket"""
......@@ -541,7 +541,7 @@ class BOHB(MsgDispatcherBase):
'parameter_source': 'algorithm',
'parameters': ''
}
send(CommandType.NoMoreTrialJobs, nni.dump(ret))
self.send(CommandType.NoMoreTrialJobs, nni.dump(ret))
return None
assert self.generated_hyper_configs
params = self.generated_hyper_configs.pop(0)
......@@ -572,7 +572,7 @@ class BOHB(MsgDispatcherBase):
"""
ret = self._get_one_trial_job()
if ret is not None:
send(CommandType.NewTrialJob, nni.dump(ret))
self.send(CommandType.NewTrialJob, nni.dump(ret))
self.credit -= 1
def handle_update_search_space(self, data):
......@@ -664,7 +664,7 @@ class BOHB(MsgDispatcherBase):
ret['parameter_index'] = one_unsatisfied['parameter_index']
# update parameter_id in self.job_id_para_id_map
self.job_id_para_id_map[ret['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, nni.dump(ret))
self.send(CommandType.SendTrialJobParameter, nni.dump(ret))
for _ in range(self.credit):
self._request_one_trial_job()
......@@ -712,7 +712,7 @@ class BOHB(MsgDispatcherBase):
ret['parameter_index'] = data['parameter_index']
# update parameter_id in self.job_id_para_id_map
self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, nni.dump(ret))
self.send(CommandType.SendTrialJobParameter, nni.dump(ret))
else:
assert 'value' in data
value = extract_scalar_reward(data['value'])
......
......@@ -18,7 +18,7 @@ from nni import ClassArgsValidator
from nni.common.hpo_utils import validate_search_space
from nni.runtime.common import multi_phase_enabled
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send
from nni.runtime.tuner_command_channel import CommandType
from nni.utils import NodeType, OptimizeMode, MetricType, extract_scalar_reward
from nni import parameter_expressions
......@@ -432,7 +432,7 @@ class Hyperband(MsgDispatcherBase):
search space
"""
self.handle_update_search_space(data)
send(CommandType.Initialized, '')
self.send(CommandType.Initialized, '')
def handle_request_trial_jobs(self, data):
"""
......@@ -449,7 +449,7 @@ class Hyperband(MsgDispatcherBase):
def _request_one_trial_job(self):
ret = self._get_one_trial_job()
if ret is not None:
send(CommandType.NewTrialJob, nni.dump(ret))
self.send(CommandType.NewTrialJob, nni.dump(ret))
self.credit -= 1
def _get_one_trial_job(self):
......@@ -478,7 +478,7 @@ class Hyperband(MsgDispatcherBase):
'parameter_source': 'algorithm',
'parameters': ''
}
send(CommandType.NoMoreTrialJobs, nni.dump(ret))
self.send(CommandType.NoMoreTrialJobs, nni.dump(ret))
return None
assert self.generated_hyper_configs
......@@ -553,7 +553,7 @@ class Hyperband(MsgDispatcherBase):
if data['parameter_index'] is not None:
ret['parameter_index'] = data['parameter_index']
self.job_id_para_id_map[data['trial_job_id']] = ret['parameter_id']
send(CommandType.SendTrialJobParameter, nni.dump(ret))
self.send(CommandType.SendTrialJobParameter, nni.dump(ret))
else:
value = extract_scalar_reward(data['value'])
bracket_id, i, _ = data['parameter_id'].split('_')
......
......@@ -24,7 +24,7 @@ from nni.experiment.config import utils
from nni.experiment.config.base import ConfigBase
from nni.experiment.config.training_service import TrainingServiceConfig
from nni.experiment.config.training_services import RemoteConfig
from nni.runtime.protocol import connect_websocket
from nni.runtime.tuner_command_channel import TunerCommandChannel
from nni.tools.nnictl.command_utils import kill_command
from ..codegen import model_to_pytorch_script
......@@ -274,7 +274,8 @@ class RetiariiExperiment(Experiment):
from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy
if not isinstance(strategy, OneShotStrategy):
self._dispatcher = RetiariiAdvisor()
# FIXME: Dispatcher should not be created this early.
self._dispatcher = RetiariiAdvisor('_placeholder_')
else:
self._dispatcher = cast(RetiariiAdvisor, None)
self._dispatcher_thread: Optional[Thread] = None
......@@ -357,13 +358,14 @@ class RetiariiExperiment(Experiment):
self._proc = launcher.start_experiment('create', self.id, self.config, port, debug, # type: ignore
RunMode.Background, None, ws_url, ['retiarii'])
assert self._proc is not None
connect_websocket(ws_url)
self.port = port # port will be None if startup failed
# dispatcher must be launched after the pipe is initialized
# the logic to launch the dispatcher in the background should be refactored into the dispatcher API
self._dispatcher = self._create_dispatcher()
if self._dispatcher is not None:
self._dispatcher._channel = TunerCommandChannel(ws_url)
self._dispatcher_thread = Thread(target=self._dispatcher.run)
self._dispatcher_thread.start()
......
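The net effect of this hunk: the experiment no longer calls a module-level connect_websocket; it gives the dispatcher its own TunerCommandChannel and runs it in a background thread. A minimal sketch of that flow, assuming a placeholder ws_url (in a real run the URL comes from the launched NNI manager):

```python
from threading import Thread

from nni.retiarii.integration import RetiariiAdvisor
from nni.runtime.tuner_command_channel import TunerCommandChannel

ws_url = 'ws://localhost:8080/tuner'               # placeholder; produced by the NNI manager in practice

dispatcher = RetiariiAdvisor(ws_url)               # the advisor now takes the channel URL directly
dispatcher._channel = TunerCommandChannel(ws_url)  # the experiment re-points the channel before running
Thread(target=dispatcher.run).start()              # run() connects the channel and starts the worker threads
```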
......@@ -9,7 +9,7 @@ import nni
from nni.common.serializer import PayloadTooLarge
from nni.common.version import version_dump
from nni.runtime.msg_dispatcher_base import MsgDispatcherBase
from nni.runtime.protocol import CommandType, send
from nni.runtime.tuner_command_channel import CommandType
from nni.utils import MetricType
from .graph import MetricData
......@@ -48,8 +48,8 @@ class RetiariiAdvisor(MsgDispatcherBase):
final_metric_callback
"""
def __init__(self):
super(RetiariiAdvisor, self).__init__()
def __init__(self, url: str):
super().__init__(url)
register_advisor(self) # register the current advisor as the "global only" advisor
self.search_space = None
......@@ -69,7 +69,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
search space
"""
self.handle_update_search_space(data)
send(CommandType.Initialized, '')
self.send(CommandType.Initialized, '')
def _validate_placement_constraint(self, placement_constraint):
if placement_constraint is None:
......@@ -138,14 +138,14 @@ class RetiariiAdvisor(MsgDispatcherBase):
# trial parameters can be super large, disable pickle size limit here
# nevertheless, it could still be blocked by the pipe / nni-manager
send(CommandType.NewTrialJob, send_payload)
self.send(CommandType.NewTrialJob, send_payload)
if self.send_trial_callback is not None:
self.send_trial_callback(parameters) # pylint: disable=not-callable
return self.parameters_count
def mark_experiment_as_ending(self):
send(CommandType.NoMoreTrialJobs, '')
self.send(CommandType.NoMoreTrialJobs, '')
def handle_request_trial_jobs(self, num_trials):
_logger.debug('Request trial jobs: %s', num_trials)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
_multi_thread = False
_multi_phase = False
def enable_multi_thread():
global _multi_thread
_multi_thread = True
def multi_thread_enabled():
return _multi_thread
def enable_multi_phase():
global _multi_phase
_multi_phase = True
......
......@@ -22,7 +22,8 @@ _dispatcher_env_var_names = [
'NNI_CHECKPOINT_DIRECTORY',
'NNI_LOG_DIRECTORY',
'NNI_LOG_LEVEL',
'NNI_INCLUDE_INTERMEDIATE_RESULTS'
'NNI_INCLUDE_INTERMEDIATE_RESULTS',
'NNI_TUNER_COMMAND_CHANNEL',
]
def _load_env_vars(env_var_names):
......
......@@ -7,10 +7,10 @@ from collections import defaultdict
from nni import NoMoreTrialError
from nni.assessor import AssessResult
from .common import multi_thread_enabled, multi_phase_enabled
from .common import multi_phase_enabled
from .env_vars import dispatcher_env_vars
from .msg_dispatcher_base import MsgDispatcherBase
from .protocol import CommandType, send
from .tuner_command_channel import CommandType
from ..common.serializer import dump, load
from ..utils import MetricType
......@@ -67,8 +67,8 @@ def _pack_parameter(parameter_id, params, customized=False, trial_job_id=None, p
class MsgDispatcher(MsgDispatcherBase):
def __init__(self, tuner, assessor=None):
super(MsgDispatcher, self).__init__()
def __init__(self, command_channel_url, tuner, assessor=None):
super().__init__(command_channel_url)
self.tuner = tuner
self.assessor = assessor
if assessor is None:
......@@ -88,12 +88,12 @@ class MsgDispatcher(MsgDispatcherBase):
"""Data is search space
"""
self.tuner.update_search_space(data)
send(CommandType.Initialized, '')
self.send(CommandType.Initialized, '')
def send_trial_callback(self, id_, params):
"""For tuner to issue trial config when the config is generated
"""
send(CommandType.NewTrialJob, _pack_parameter(id_, params))
self.send(CommandType.NewTrialJob, _pack_parameter(id_, params))
def handle_request_trial_jobs(self, data):
# data: number of trial jobs
......@@ -102,10 +102,10 @@ class MsgDispatcher(MsgDispatcherBase):
params_list = self.tuner.generate_multiple_parameters(ids, st_callback=self.send_trial_callback)
for i, _ in enumerate(params_list):
send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i]))
self.send(CommandType.NewTrialJob, _pack_parameter(ids[i], params_list[i]))
# when fewer parameters are returned than requested.
if len(params_list) < len(ids):
send(CommandType.NoMoreTrialJobs, _pack_parameter(ids[0], ''))
self.send(CommandType.NoMoreTrialJobs, _pack_parameter(ids[0], ''))
def handle_update_search_space(self, data):
self.tuner.update_search_space(data)
......@@ -148,7 +148,7 @@ class MsgDispatcher(MsgDispatcherBase):
param = self.tuner.generate_parameters(param_id, trial_job_id=data['trial_job_id'])
except NoMoreTrialError:
param = None
send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'],
self.send(CommandType.SendTrialJobParameter, _pack_parameter(param_id, param, trial_job_id=data['trial_job_id'],
parameter_index=data['parameter_index']))
else:
raise ValueError('Data type not supported: {}'.format(data['type']))
......@@ -222,7 +222,7 @@ class MsgDispatcher(MsgDispatcherBase):
if result is AssessResult.Bad:
_logger.debug('BAD, kill %s', trial_job_id)
send(CommandType.KillTrialJob, dump(trial_job_id))
self.send(CommandType.KillTrialJob, dump(trial_job_id))
# notify tuner
_logger.debug('env var: NNI_INCLUDE_INTERMEDIATE_RESULTS: [%s]',
dispatcher_env_vars.NNI_INCLUDE_INTERMEDIATE_RESULTS)
......@@ -237,8 +237,5 @@ class MsgDispatcher(MsgDispatcherBase):
"""
_logger.debug('Early stop notify tuner data: [%s]', data)
data['type'] = MetricType.FINAL
if multi_thread_enabled():
self._handle_final_metric_data(data)
else:
data['value'] = dump(data['value'])
self.enqueue_command(CommandType.ReportMetricData, data)
data['value'] = dump(data['value'])
self.enqueue_command(CommandType.ReportMetricData, data)
......@@ -3,14 +3,12 @@
import threading
import logging
from multiprocessing.dummy import Pool as ThreadPool
from queue import Queue, Empty
from .common import multi_thread_enabled
from .env_vars import dispatcher_env_vars
from ..common import load
from ..recoverable import Recoverable
from .protocol import CommandType, receive
from .tuner_command_channel import CommandType, TunerCommandChannel
_logger = logging.getLogger(__name__)
......@@ -24,59 +22,52 @@ class MsgDispatcherBase(Recoverable):
Inherit this class to make your own advisor.
"""
def __init__(self):
def __init__(self, command_channel_url=None):
self.stopping = False
if multi_thread_enabled():
self.pool = ThreadPool()
self.thread_results = []
else:
self.default_command_queue = Queue()
self.assessor_command_queue = Queue()
self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
self.assessor_worker = threading.Thread(target=self.command_queue_worker,
args=(self.assessor_command_queue,))
self.default_worker.start()
self.assessor_worker.start()
self.worker_exceptions = []
if command_channel_url is None:
command_channel_url = dispatcher_env_vars.NNI_TUNER_COMMAND_CHANNEL
self._channel = TunerCommandChannel(command_channel_url)
self.default_command_queue = Queue()
self.assessor_command_queue = Queue()
self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
self.assessor_worker = threading.Thread(target=self.command_queue_worker, args=(self.assessor_command_queue,))
self.worker_exceptions = []
def run(self):
"""Run the tuner.
This function never returns unless an exception is raised.
"""
_logger.info('Dispatcher started')
self._channel.connect()
self.default_worker.start()
self.assessor_worker.start()
if dispatcher_env_vars.NNI_MODE == 'resume':
self.load_checkpoint()
while not self.stopping:
command, data = receive()
command, data = self._channel._receive()
if data:
data = load(data)
if command is None or command is CommandType.Terminate:
break
if multi_thread_enabled():
result = self.pool.map_async(self.process_command_thread, [(command, data)])
self.thread_results.append(result)
if any([thread_result.ready() and not thread_result.successful() for thread_result in
self.thread_results]):
_logger.debug('Caught thread exception')
break
else:
self.enqueue_command(command, data)
if self.worker_exceptions:
break
self.enqueue_command(command, data)
if self.worker_exceptions:
break
_logger.info('Dispatcher exiting...')
self.stopping = True
if multi_thread_enabled():
self.pool.close()
self.pool.join()
else:
self.default_worker.join()
self.assessor_worker.join()
self.default_worker.join()
self.assessor_worker.join()
self._channel.disconnect()
_logger.info('Dispatcher terminated')
def send(self, command, data):
self._channel._send(command, data)
def command_queue_worker(self, command_queue):
"""Process commands in command queues.
"""
......@@ -112,19 +103,6 @@ class MsgDispatcherBase(Recoverable):
if qsize >= QUEUE_LEN_WARNING_MARK:
_logger.warning('assessor queue length: %d', qsize)
def process_command_thread(self, request):
"""Worker thread to process a command.
"""
command, data = request
if multi_thread_enabled():
try:
self.process_command(command, data)
except Exception as e:
_logger.exception(str(e))
raise
else:
pass
def process_command(self, command, data):
_logger.debug('process_command: command: [%s], data: [%s]', command, data)
......@@ -242,4 +220,4 @@ class MsgDispatcherBase(Recoverable):
hyper_params: the string that is sent by message dispatcher during the creation of trials.
"""
raise NotImplementedError('handle_trial_end not implemented')
\ No newline at end of file
raise NotImplementedError('handle_trial_end not implemented')
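Because the command channel is now per-dispatcher state (created in __init__, connected in run()) instead of a module-level global, several dispatchers can coexist in one process, which is what the commit title promises. A minimal sketch, assuming placeholder URLs and a stub tuner that are not part of this change:

```python
from threading import Thread

from nni.runtime.msg_dispatcher import MsgDispatcher
from nni.tuner import Tuner

class DummyTuner(Tuner):
    """Stand-in tuner for illustration; a real experiment uses a built-in or custom tuner."""
    def generate_parameters(self, parameter_id, **kwargs):
        return {}
    def receive_trial_result(self, parameter_id, parameters, value, **kwargs):
        pass
    def update_search_space(self, search_space):
        pass

# Each dispatcher owns its own TunerCommandChannel, keyed by its experiment's URL.
dispatcher_a = MsgDispatcher('ws://localhost:8080/tuner', DummyTuner())
dispatcher_b = MsgDispatcher('ws://localhost:8081/tuner', DummyTuner())

# run() connects the channel, starts the worker threads, and loops on _receive().
Thread(target=dispatcher_a.run).start()
Thread(target=dispatcher_b.run).start()
```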
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: disable=unused-import
from __future__ import annotations
from .tuner_command_channel.command_type import CommandType
from .tuner_command_channel import legacy
from .tuner_command_channel import shim
_use_ws = False
def connect_websocket(url: str):
global _use_ws
_use_ws = True
shim.connect(url)
def send(command: CommandType, data: str) -> None:
if _use_ws:
shim.send(command, data)
else:
legacy.send(command, data)
def receive() -> tuple[CommandType, str] | tuple[None, None]:
if _use_ws:
return shim.receive()
else:
return legacy.receive()
# for unit test compatibility
def _set_in_file(in_file):
legacy._in_file = in_file
def _set_out_file(out_file):
legacy._out_file = out_file
def _get_out_file():
return legacy._out_file
......@@ -2,7 +2,8 @@
# Licensed under the MIT license.
"""
The IPC channel between tuner/assessor and NNI manager.
Work in progress.
Low level APIs for algorithms to communicate with NNI manager.
"""
from .command_type import CommandType
from .channel import TunerCommandChannel
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Low level APIs for algorithms to communicate with NNI manager.
"""
from __future__ import annotations
__all__ = ['TunerCommandChannel']
from .command_type import CommandType
from .websocket import WebSocket
class TunerCommandChannel:
"""
A channel to communicate with NNI manager.
Each NNI experiment has a channel URL for tuner/assessor/strategy algorithm.
The channel can only be connected once, so for each Python side :class:`~nni.experiment.Experiment` object,
there should be exactly one corresponding ``TunerCommandChannel`` instance.
:meth:`connect` must be invoked before sending or receiving data.
The constructor has no side effects, so ``TunerCommandChannel`` can be created anywhere.
But :meth:`connect` requires an initialized NNI manager; otherwise the behavior is unpredictable.
:meth:`_send` and :meth:`_receive` are underscore-prefixed because their signatures are scheduled to change by v3.0.
Parameters
----------
url
The command channel URL.
For now it must be like ``"ws://localhost:8080/tuner"`` or ``"ws://localhost:8080/url-prefix/tuner"``.
"""
def __init__(self, url: str):
self._channel = WebSocket(url)
def connect(self) -> None:
self._channel.connect()
def disconnect(self) -> None:
self._channel.disconnect()
# TODO: Define semantic command class like `KillTrialJob(trial_id='abc')`.
# def send(self, command: Command) -> None:
# ...
# def receive(self) -> Command | None:
# ...
def _send(self, command_type: CommandType, data: str) -> None:
command = command_type.value.decode() + data
self._channel.send(command)
def _receive(self) -> tuple[CommandType, str] | tuple[None, None]:
command = self._channel.receive()
if command is None:
raise RuntimeError('NNI manager closed connection')
command_type = CommandType(command[:2].encode())
return command_type, command[2:]
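A minimal usage sketch of the channel with a placeholder URL (in a real run the URL arrives through the NNI_TUNER_COMMAND_CHANNEL environment variable); the underscore-prefixed methods are the interim API described in the docstring above:

```python
from nni.runtime.tuner_command_channel import CommandType, TunerCommandChannel

channel = TunerCommandChannel('ws://localhost:8080/tuner')  # constructor has no side effects
channel.connect()                                           # requires a running NNI manager

channel._send(CommandType.Initialized, '')   # command type plus string payload
command, data = channel._receive()           # blocks until the manager sends a command

channel.disconnect()
```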
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = [
'CommandType',
'LegacyCommandChannel',
'send',
'receive',
'_set_in_file',
'_set_out_file',
'_get_out_file',
]
import logging
import os
import threading
......@@ -18,6 +28,29 @@ try:
except OSError:
_logger.debug('IPC pipeline does not exist')
def _set_in_file(in_file):
global _in_file
_in_file = in_file
def _set_out_file(out_file):
global _out_file
_out_file = out_file
def _get_out_file():
return _out_file
class LegacyCommandChannel:
def connect(self):
pass
def disconnect(self):
pass
def _send(self, command, data):
send(command, data)
def _receive(self):
return receive()
def send(command, data):
"""Send command to Training Service.
......
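These module-level helpers survive mainly so unit tests can drive a dispatcher through in-memory pipes instead of a websocket. A sketch of that pattern, mirroring the test changes later in this commit (the URL is a never-connected placeholder and the None tuner is a stand-in):

```python
from io import BytesIO

from nni.runtime.msg_dispatcher import MsgDispatcher
from nni.runtime.tuner_command_channel.legacy import (
    LegacyCommandChannel, _set_in_file, _set_out_file)

# Back the legacy pipes with in-memory buffers.
_set_in_file(BytesIO())
_set_out_file(BytesIO())

dispatcher = MsgDispatcher('ws://_placeholder_', None)  # URL is never actually connected
dispatcher._channel = LegacyCommandChannel()            # swap the websocket channel for the pipes
```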
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Compatibility layer for old protocol APIs.
We are working on more semantic new APIs.
"""
from __future__ import annotations
from .command_type import CommandType
from .websocket import WebSocket
_ws: WebSocket = None # type: ignore
def connect(url: str) -> None:
global _ws
_ws = WebSocket(url)
_ws.connect()
def send(command_type: CommandType, data: str) -> None:
command = command_type.value.decode() + data
_ws.send(command)
def receive() -> tuple[CommandType, str]:
command = _ws.receive()
if command is None:
raise RuntimeError('NNI manager closed connection')
command_type = CommandType(command[:2].encode())
if command_type is CommandType.Terminate:
_ws.disconnect()
return command_type, command[2:]
......@@ -2,7 +2,6 @@
# Licensed under the MIT license.
import json
import logging
import os
from schema import And, Optional, Or, Regex, Schema, SchemaError
......@@ -77,7 +76,6 @@ class AlgoSchema:
if not builtin_name or not class_args:
return
logging.getLogger('nni.protocol').setLevel(logging.ERROR) # we know IPC is not there, don't complain
validator = create_validator_instance(algo_type+'s', builtin_name)
if validator:
try:
......
......@@ -88,11 +88,14 @@ def create_experiment(args):
exp = Experiment(config)
exp.url_prefix = url_prefix
run_mode = RunMode.Foreground if foreground else RunMode.Detach
exp.start(port, debug, run_mode)
_logger.info(f'To stop experiment run "nnictl stop {exp.id}" or "nnictl stop --all"')
_logger.info('Reference: https://nni.readthedocs.io/en/stable/reference/nnictl.html')
if foreground:
exp.run(port, debug=debug)
else:
exp.start(port, debug, RunMode.Detach)
_logger.info(f'To stop experiment run "nnictl stop {exp.id}" or "nnictl stop --all"')
_logger.info('Reference: https://nni.readthedocs.io/en/stable/reference/nnictl.html')
def resume_experiment(args):
exp_id = args.id
......
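In Python-API terms the change maps the two nnictl modes onto Experiment methods as sketched below; config, port, debug and foreground stand in for values computed earlier in create_experiment, and the imports are assumed to match what this module already uses:

```python
from nni.experiment import Experiment, RunMode

exp = Experiment(config)                     # ExperimentConfig built from the user's YAML

if foreground:
    exp.run(port, debug=debug)               # blocks in the foreground until the experiment finishes
else:
    exp.start(port, debug, RunMode.Detach)   # returns immediately; stop later with `nnictl stop <id>`
```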
......@@ -10,6 +10,7 @@ from pathlib import Path
import nni
import nni.runtime.platform.test
from nni.runtime.tuner_command_channel import legacy as protocol
import json
try:
......@@ -262,7 +263,11 @@ class CGOEngineTest(unittest.TestCase):
opt = DedupInputOptimizer()
opt.convert(lp)
advisor = RetiariiAdvisor()
advisor = RetiariiAdvisor('ws://_placeholder_')
advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
available_devices = [GPUDevice("test", 0), GPUDevice("test", 1), GPUDevice("test", 2), GPUDevice("test", 3)]
cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0)
......@@ -281,7 +286,11 @@ class CGOEngineTest(unittest.TestCase):
opt = DedupInputOptimizer()
opt.convert(lp)
advisor = RetiariiAdvisor()
advisor = RetiariiAdvisor('ws://_placeholder_')
advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
available_devices = [GPUDevice("test", 0), GPUDevice("test", 1)]
cgo = CGOExecutionEngine(devices=available_devices, batch_waiting_time=0)
......@@ -296,14 +305,17 @@ class CGOEngineTest(unittest.TestCase):
_reset()
nni.retiarii.debug_configs.framework = 'pytorch'
os.makedirs('generated', exist_ok=True)
from nni.runtime import protocol
import nni.runtime.platform.test as tt
protocol._set_out_file(open('generated/debug_protocol_out_file.py', 'wb'))
protocol._set_in_file(open('generated/debug_protocol_out_file.py', 'rb'))
models = _load_mnist(2)
advisor = RetiariiAdvisor()
advisor = RetiariiAdvisor('ws://_placeholder_')
advisor._channel = protocol.LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
cgo_engine = CGOExecutionEngine(devices=[GPUDevice("test", 0), GPUDevice("test", 1),
GPUDevice("test", 2), GPUDevice("test", 3)], batch_waiting_time=0)
set_execution_engine(cgo_engine)
......
......@@ -11,7 +11,7 @@ from nni.retiarii.execution.base import BaseExecutionEngine
from nni.retiarii.execution.python import PurePythonExecutionEngine
from nni.retiarii.graph import DebugEvaluator
from nni.retiarii.integration import RetiariiAdvisor
from nni.runtime.tuner_command_channel.legacy import *
class EngineTest(unittest.TestCase):
def test_codegen(self):
......@@ -25,7 +25,11 @@ class EngineTest(unittest.TestCase):
def test_base_execution_engine(self):
nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor()
advisor = RetiariiAdvisor('ws://_placeholder_')
advisor._channel = LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
set_execution_engine(BaseExecutionEngine())
with open(self.enclosing_dir / 'mnist_pytorch.json') as f:
model = Model._load(json.load(f))
......@@ -38,7 +42,11 @@ class EngineTest(unittest.TestCase):
def test_py_execution_engine(self):
nni.retiarii.integration_api._advisor = None
nni.retiarii.execution.api._execution_engine = None
advisor = RetiariiAdvisor()
advisor = RetiariiAdvisor('ws://_placeholder_')
advisor._channel = LegacyCommandChannel()
advisor.default_worker.start()
advisor.assessor_worker.start()
set_execution_engine(PurePythonExecutionEngine())
model = Model._load({
'_model': {
......@@ -63,11 +71,9 @@ class EngineTest(unittest.TestCase):
def setUp(self) -> None:
self.enclosing_dir = Path(__file__).parent
os.makedirs(self.enclosing_dir / 'generated', exist_ok=True)
from nni.runtime import protocol
protocol._set_out_file(open(self.enclosing_dir / 'generated/debug_protocol_out_file.py', 'wb'))
_set_out_file(open(self.enclosing_dir / 'generated/debug_protocol_out_file.py', 'wb'))
def tearDown(self) -> None:
from nni.runtime import protocol
protocol._get_out_file().close()
_get_out_file().close()
nni.retiarii.execution.api._execution_engine = None
nni.retiarii.integration_api._advisor = None
......@@ -8,8 +8,7 @@ from unittest import TestCase, main
from nni.assessor import Assessor, AssessResult
from nni.runtime import msg_dispatcher_base as msg_dispatcher_base
from nni.runtime.msg_dispatcher import MsgDispatcher
from nni.runtime import protocol
from nni.runtime.protocol import CommandType, send, receive
from nni.runtime.tuner_command_channel.legacy import *
_trials = []
_end_trials = []
......@@ -34,15 +33,15 @@ _out_buf = BytesIO()
def _reverse_io():
_in_buf.seek(0)
_out_buf.seek(0)
protocol._set_out_file(_in_buf)
protocol._set_in_file(_out_buf)
_set_out_file(_in_buf)
_set_in_file(_out_buf)
def _restore_io():
_in_buf.seek(0)
_out_buf.seek(0)
protocol._set_in_file(_in_buf)
protocol._set_out_file(_out_buf)
_set_in_file(_in_buf)
_set_out_file(_out_buf)
class AssessorTestCase(TestCase):
......@@ -58,7 +57,8 @@ class AssessorTestCase(TestCase):
_restore_io()
assessor = NaiveAssessor()
dispatcher = MsgDispatcher(None, assessor)
dispatcher = MsgDispatcher('ws://_placeholder_', None, assessor)
dispatcher._channel = LegacyCommandChannel()
msg_dispatcher_base._worker_fast_exit_on_terminate = False
dispatcher.run()
......