Unverified Commit f60d3d5e authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

WebSocket (step 1) - Python client (#4806)

parent b39850f9
aioconsole
coverage coverage
cython cython
flake8 flake8
......
...@@ -177,8 +177,8 @@ def start_experiment_retiarii(exp_id, config, port, debug): ...@@ -177,8 +177,8 @@ def start_experiment_retiarii(exp_id, config, port, debug):
start_time, proc = _start_rest_server_retiarii(config, port, debug, exp_id, pipe.path) start_time, proc = _start_rest_server_retiarii(config, port, debug, exp_id, pipe.path)
_logger.info('Connecting IPC pipe...') _logger.info('Connecting IPC pipe...')
pipe_file = pipe.connect() pipe_file = pipe.connect()
nni.runtime.protocol._in_file = pipe_file nni.runtime.protocol._set_in_file(pipe_file)
nni.runtime.protocol._out_file = pipe_file nni.runtime.protocol._set_out_file(pipe_file)
_logger.info('Starting web server...') _logger.info('Starting web server...')
_check_rest_server(port) _check_rest_server(port)
platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform platform = 'hybrid' if isinstance(config.training_service, list) else config.training_service.platform
......
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import logging # pylint: disable=unused-import
import os
import threading
from enum import Enum
_logger = logging.getLogger(__name__) from .tuner_command_channel.command_type import CommandType
from .tuner_command_channel.legacy import send, receive
# for unit test compatibility
def _set_in_file(in_file):
from .tuner_command_channel import legacy
legacy._in_file = in_file
class CommandType(Enum): def _set_out_file(out_file):
# in from .tuner_command_channel import legacy
Initialize = b'IN' legacy._out_file = out_file
RequestTrialJobs = b'GE'
ReportMetricData = b'ME'
UpdateSearchSpace = b'SS'
ImportData = b'FD'
AddCustomizedTrialJob = b'AD'
TrialEnd = b'EN'
Terminate = b'TE'
Ping = b'PI'
# out def _get_out_file():
Initialized = b'ID' from .tuner_command_channel import legacy
NewTrialJob = b'TR' return legacy._out_file
SendTrialJobParameter = b'SP'
NoMoreTrialJobs = b'NO'
KillTrialJob = b'KI'
# Held by send() while a message is written, so writes from different threads do not interleave.
_lock = threading.Lock()

# The IPC pipe is opened from file descriptors 3 (input) and 4 (output).
# Nothing is opened when NNI_PLATFORM is 'unittest'.
if os.environ.get('NNI_PLATFORM') != 'unittest':
    try:
        _in_file = open(3, 'rb')
        _out_file = open(4, 'wb')
    except OSError:
        _logger.debug('IPC pipeline not exists')
def send(command, data):
    """Send command to Training Service.

    command: CommandType object.
    data: string payload.

    The message written to the output file is the command's 2-byte code,
    the payload length in bytes as a 14-digit zero-padded decimal number,
    and then the UTF-8 encoded payload.
    """
    payload = data.encode('utf8')
    msg = b'%b%014d%b' % (command.value, len(payload), payload)
    # ``with`` releases the lock only if it was really acquired; calling
    # ``acquire()`` inside a ``try`` whose ``finally`` releases could try to
    # release a lock that was never taken.
    with _lock:
        _logger.debug('Sending command, data: [%s]', msg)
        _out_file.write(msg)
        _out_file.flush()
def receive():
    """Receive a command from Training Service.

    Returns a tuple of command (CommandType) and payload (str).
    Returns ``(None, None)`` if fewer than 16 header bytes could be read.
    """
    header_size = 16  # 2-byte command code followed by a 14-digit payload length
    header = _in_file.read(header_size)
    _logger.debug('Received command, header: [%s]', header)
    if header is None or len(header) < header_size:
        # Pipe EOF encountered
        _logger.debug('Pipe EOF encountered')
        return None, None
    code, size_field = header[:2], header[2:]
    payload = _in_file.read(int(size_field))
    command = CommandType(code)
    text = payload.decode('utf8')
    _logger.debug('Received command, data: [%s]', text)
    return command, text
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
The IPC channel between tuner/assessor and NNI manager.
Work in progress.
"""
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from enum import Enum
class CommandType(Enum):
    """Command codes of the tuner command channel.

    Each member's value is the 2-byte code that identifies the command on the wire.
    """

    # in
    Initialize = b'IN'
    RequestTrialJobs = b'GE'
    ReportMetricData = b'ME'
    UpdateSearchSpace = b'SS'
    ImportData = b'FD'
    AddCustomizedTrialJob = b'AD'
    TrialEnd = b'EN'
    Terminate = b'TE'
    Ping = b'PI'

    # out
    Initialized = b'ID'
    NewTrialJob = b'TR'
    SendTrialJobParameter = b'SP'
    NoMoreTrialJobs = b'NO'
    KillTrialJob = b'KI'
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import threading
from .command_type import CommandType
_logger = logging.getLogger(__name__)

# Held by send() while a message is written, so writes from different threads do not interleave.
_lock = threading.Lock()

# The IPC pipe is opened from file descriptors 3 (input) and 4 (output).
# Nothing is opened when NNI_PLATFORM is 'unittest'.
if os.environ.get('NNI_PLATFORM') != 'unittest':
    try:
        _in_file = open(3, 'rb')
        _out_file = open(4, 'wb')
    except OSError:
        _logger.debug('IPC pipeline not exists')
def send(command, data):
    """Send command to Training Service.

    command: CommandType object.
    data: string payload.

    The message written to the output file is the command's 2-byte code,
    the payload length in bytes as a 14-digit zero-padded decimal number,
    and then the UTF-8 encoded payload.
    """
    payload = data.encode('utf8')
    msg = b'%b%014d%b' % (command.value, len(payload), payload)
    # ``with`` releases the lock only if it was really acquired; calling
    # ``acquire()`` inside a ``try`` whose ``finally`` releases could try to
    # release a lock that was never taken.
    with _lock:
        _logger.debug('Sending command, data: [%s]', msg)
        _out_file.write(msg)
        _out_file.flush()
def receive():
    """Receive a command from Training Service.

    Returns a tuple of command (CommandType) and payload (str).
    Returns ``(None, None)`` if fewer than 16 header bytes could be read.
    """
    header_size = 16  # 2-byte command code followed by a 14-digit payload length
    header = _in_file.read(header_size)
    _logger.debug('Received command, header: [%s]', header)
    if header is None or len(header) < header_size:
        # Pipe EOF encountered
        _logger.debug('Pipe EOF encountered')
        return None, None
    code, size_field = header[:2], header[2:]
    payload = _in_file.read(int(size_field))
    command = CommandType(code)
    text = payload.decode('utf8')
    _logger.debug('Received command, data: [%s]', text)
    return command, text
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Compatibility layer for old protocol APIs.
We are working on more semantic new APIs.
"""
from __future__ import annotations
from .command_type import CommandType
from .websocket import WebSocket
# The module-wide connection used by connect(), send() and receive().
_ws: WebSocket = None  # type: ignore


def connect(url: str) -> None:
    """Create the module-wide WebSocket for ``url`` and connect it.

    The new object replaces whatever a previous call stored; a previously
    stored connection is not closed here.
    """
    global _ws
    channel = WebSocket(url)
    _ws = channel
    channel.connect()
def send(command_type: CommandType, data: str) -> None:
    """Send one command: the command's code decoded to text, immediately followed by ``data``."""
    code = command_type.value.decode()
    _ws.send(code + data)
def receive() -> tuple[CommandType, str]:
    """Receive one command and split it into its type and payload.

    The first two characters of the message select the command type and the
    rest is returned as the payload. The connection is closed from this side
    as soon as a ``Terminate`` command is received.

    Raises
    ------
    RuntimeError
        If the underlying WebSocket reports that the connection was closed.
    """
    message = _ws.receive()
    if message is None:
        raise RuntimeError('NNI manager closed connection')
    code, payload = message[:2], message[2:]
    command_type = CommandType(code.encode())
    if command_type is CommandType.Terminate:
        _ws.disconnect()
    return command_type, payload
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Synchronized and object-oriented WebSocket class.
WebSocket guarantees that messages will not be divided at API level.
"""
from __future__ import annotations
__all__ = ['WebSocket']
import asyncio
import logging
from threading import Lock, Thread
from typing import Any
import websockets
_logger = logging.getLogger(__name__)
# the singleton event loop
# Created by WebSocket.connect() when no loop exists, and reset to None by
# _decrease_refcnt() when the reference count drops to zero.
_event_loop: asyncio.AbstractEventLoop = None # type: ignore
# Held while the loop is created or stopped and while the reference count is changed.
_event_loop_lock: Lock = Lock()
_event_loop_refcnt: int = 0 # number of connected websockets
class WebSocket:
    """
    A WebSocket connection.

    Call :meth:`connect` before :meth:`send` and :meth:`receive`.

    All methods are thread safe.

    Parameters
    ----------
    url
        The WebSocket URL.
        For tuner command channel it should be something like ``ws://localhost:8080/tuner``.
    """

    def __init__(self, url: str):
        self._url: str = url
        self._ws: Any = None  # the library does not provide type hints

    def connect(self) -> None:
        """
        Connect to the URL, blocking until the connection is established.

        The shared event loop and its thread are started first if they are not running.
        If connecting fails, the exception is re-raised after the event loop
        reference taken for this connection has been given back.
        """
        global _event_loop, _event_loop_refcnt
        with _event_loop_lock:
            _event_loop_refcnt += 1
            if _event_loop is None:
                _logger.debug('Starting event loop.')
                # following line must be outside _run_event_loop
                # because _wait() might be executed before first line of the child thread
                _event_loop = asyncio.new_event_loop()
                thread = Thread(target=_run_event_loop, name='NNI-WebSocketEventLoop', daemon=True)
                thread.start()

        _logger.debug('Connecting to %s', self._url)
        try:
            self._ws = _wait(_connect_async(self._url))
        except Exception:
            # The count was increased above; without this the failed attempt
            # would keep the event loop alive forever.
            _decrease_refcnt()
            raise
        _logger.debug('Connected.')

    def disconnect(self) -> None:
        """
        Close the connection. Do nothing if there is no connection.

        A failure to close is logged as a warning and not raised.
        """
        if self._ws is None:
            _logger.debug('disconnect: No connection.')
            return

        try:
            _wait(self._ws.close())
            _logger.debug('Connection closed by client.')
        except Exception as e:
            _logger.warning('Failed to close connection: %r', e)
        self._ws = None
        _decrease_refcnt()

    def send(self, message: str) -> None:
        """Send ``message``, blocking until the library's send call returns."""
        _logger.debug('Sending %s', message)
        _wait(self._ws.send(message))

    def receive(self) -> str | None:
        """
        Return received message;
        or return ``None`` if the connection has been closed by peer.
        """
        try:
            msg = _wait(self._ws.recv())
            _logger.debug('Received %s', msg)
        except websockets.ConnectionClosed:  # type: ignore
            _logger.debug('Connection closed by server.')
            self._ws = None
            _decrease_refcnt()
            return None

        # seems the library will infer whether it's text or binary, so we don't have guarantee
        if isinstance(msg, bytes):
            return msg.decode()
        else:
            return msg
def _wait(coro):
    """Synchronized version of "await".

    Submits ``coro`` to the shared event loop and blocks the calling thread
    until it finishes, returning its result or raising its exception.
    """
    return asyncio.run_coroutine_threadsafe(coro, _event_loop).result()
def _run_event_loop() -> None:
    """Thread target that runs the shared event loop until it is stopped.

    The event loop itself is blocking, and send/receive are also blocking,
    so they must run in different threads.
    """
    asyncio.set_event_loop(_event_loop)
    _event_loop.run_forever()
    _logger.debug('Event loop stopped.')
async def _connect_async(url):
    """Open a connection to ``url`` with no limit on message size.

    Theoretically this function is meaningless and one can directly use `websockets.connect(url)`,
    but it will not work, raising "TypeError: A coroutine object is required".
    Seems a design flaw in websockets library.
    """
    connection = await websockets.connect(url, max_size=None)  # type: ignore
    return connection
def _decrease_refcnt() -> None:
    """Drop one reference to the shared event loop.

    When the count reaches zero the loop is asked to stop and the global is
    cleared, so the next connection creates a new loop.
    """
    global _event_loop, _event_loop_refcnt
    with _event_loop_lock:
        _event_loop_refcnt -= 1
        if _event_loop_refcnt != 0:
            return
        _event_loop.call_soon_threadsafe(_event_loop.stop)
        _event_loop = None  # type: ignore
...@@ -298,8 +298,8 @@ class CGOEngineTest(unittest.TestCase): ...@@ -298,8 +298,8 @@ class CGOEngineTest(unittest.TestCase):
os.makedirs('generated', exist_ok=True) os.makedirs('generated', exist_ok=True)
from nni.runtime import protocol from nni.runtime import protocol
import nni.runtime.platform.test as tt import nni.runtime.platform.test as tt
protocol._out_file = open('generated/debug_protocol_out_file.py', 'wb') protocol._set_out_file(open('generated/debug_protocol_out_file.py', 'wb'))
protocol._in_file = open('generated/debug_protocol_out_file.py', 'rb') protocol._set_in_file(open('generated/debug_protocol_out_file.py', 'rb'))
models = _load_mnist(2) models = _load_mnist(2)
......
...@@ -64,10 +64,10 @@ class EngineTest(unittest.TestCase): ...@@ -64,10 +64,10 @@ class EngineTest(unittest.TestCase):
self.enclosing_dir = Path(__file__).parent self.enclosing_dir = Path(__file__).parent
os.makedirs(self.enclosing_dir / 'generated', exist_ok=True) os.makedirs(self.enclosing_dir / 'generated', exist_ok=True)
from nni.runtime import protocol from nni.runtime import protocol
protocol._out_file = open(self.enclosing_dir / 'generated/debug_protocol_out_file.py', 'wb') protocol._set_out_file(open(self.enclosing_dir / 'generated/debug_protocol_out_file.py', 'wb'))
def tearDown(self) -> None: def tearDown(self) -> None:
from nni.runtime import protocol from nni.runtime import protocol
protocol._out_file.close() protocol._get_out_file().close()
nni.retiarii.execution.api._execution_engine = None nni.retiarii.execution.api._execution_engine = None
nni.retiarii.integration_api._advisor = None nni.retiarii.integration_api._advisor = None
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
A WebSocket server runs on random port, accepting one single client.
It prints each message received from client to stdout,
and send each line read from stdin to the client.
"""
import asyncio
import sys
import aioconsole
import websockets
# Force UTF-8 on the standard streams; messages are read from stdin and printed to stdout.
sys.stdin.reconfigure(encoding='utf_8')
sys.stdout.reconfigure(encoding='utf_8')
sys.stderr.reconfigure(encoding='utf_8')
# The client connection; assigned by on_connect().
_ws = None
async def main():
    """Run the WebSocket server and the stdin forwarder concurrently."""
    server_task = ws_server()
    stdin_task = read_stdin()
    await asyncio.gather(server_task, stdin_task)
async def read_stdin():
    """Send each line read from stdin to the connected client.

    Lines are decoded and stripped before being sent.
    The special line ``_close_`` is not forwarded; it raises ``SystemExit`` instead.
    """
    async_stdin, _ = await aioconsole.get_standard_streams()
    async for raw_line in async_stdin:
        line = raw_line.decode().strip()
        _debug(f'read from stdin: {line}')
        if line == '_close_':
            # Use sys.exit() rather than the builtin exit(): the latter is added
            # by the site module for interactive use and is not always available.
            sys.exit()
        await _ws.send(line)
async def ws_server():
    """Serve on localhost, print the chosen port to stdout, then wait forever.

    Port 0 is requested, so the actual port is read back from the listening socket.
    """
    async with websockets.serve(on_connect, 'localhost', 0) as server:
        listening_socket = server.sockets[0]
        port = listening_socket.getsockname()[1]
        print(port, flush=True)
        _debug(f'port: {port}')
        # this future is never completed, so the server keeps running
        await asyncio.Future()
async def on_connect(ws):
    """Remember the client connection and print every message it sends to stdout."""
    global _ws
    _debug('connected')
    _ws = ws
    async for message in ws:
        _debug(f'received from websocket: {message}')
        print(message, flush=True)
def _debug(msg):
    """Debug hook; does nothing unless the line below is uncommented."""
    #sys.stderr.write(f'[server-debug] {msg}\n')
    return None
if __name__ == '__main__':
    # script entry point: run the server and the stdin forwarder
    asyncio.run(main())
...@@ -34,15 +34,15 @@ _out_buf = BytesIO() ...@@ -34,15 +34,15 @@ _out_buf = BytesIO()
def _reverse_io(): def _reverse_io():
_in_buf.seek(0) _in_buf.seek(0)
_out_buf.seek(0) _out_buf.seek(0)
protocol._out_file = _in_buf protocol._set_out_file(_in_buf)
protocol._in_file = _out_buf protocol._set_in_file(_out_buf)
def _restore_io(): def _restore_io():
_in_buf.seek(0) _in_buf.seek(0)
_out_buf.seek(0) _out_buf.seek(0)
protocol._in_file = _in_buf protocol._set_in_file(_in_buf)
protocol._out_file = _out_buf protocol._set_out_file(_out_buf)
class AssessorTestCase(TestCase): class AssessorTestCase(TestCase):
......
...@@ -45,15 +45,15 @@ _out_buf = BytesIO() ...@@ -45,15 +45,15 @@ _out_buf = BytesIO()
def _reverse_io(): def _reverse_io():
_in_buf.seek(0) _in_buf.seek(0)
_out_buf.seek(0) _out_buf.seek(0)
protocol._out_file = _in_buf protocol._set_out_file(_in_buf)
protocol._in_file = _out_buf protocol._set_in_file(_out_buf)
def _restore_io(): def _restore_io():
_in_buf.seek(0) _in_buf.seek(0)
_out_buf.seek(0) _out_buf.seek(0)
protocol._in_file = _in_buf protocol._set_in_file(_in_buf)
protocol._out_file = _out_buf protocol._set_out_file(_out_buf)
class MsgDispatcherTestCase(TestCase): class MsgDispatcherTestCase(TestCase):
......
...@@ -9,11 +9,12 @@ from unittest import TestCase, main ...@@ -9,11 +9,12 @@ from unittest import TestCase, main
def _prepare_send(): def _prepare_send():
protocol._out_file = BytesIO() out_file = BytesIO()
return protocol._out_file protocol._set_out_file(out_file)
return out_file
def _prepare_receive(data): def _prepare_receive(data):
protocol._in_file = BytesIO(data) protocol._set_in_file(BytesIO(data))
class ProtocolTestCase(TestCase): class ProtocolTestCase(TestCase):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import atexit
from dataclasses import dataclass
import importlib
import json
import os
from pathlib import Path
from subprocess import Popen, PIPE
import sys
import time
from nni.runtime.tuner_command_channel.websocket import WebSocket
# A helper server that connects its stdio to incoming WebSocket.
# Assigned by _init() and reset to None by test_disconnect().
_server = None
# The WebSocket under test; assigned by test_connect().
_client = None
# Messages exchanged in both directions; the second one contains non-ASCII text.
_command1 = 'T_hello world'
_command2 = 'T_你好'
## test cases ##
def test_connect():
    """Start the helper server and connect the client to the port it reports."""
    global _client
    port = _init()
    url = f'ws://localhost:{port}'
    _client = WebSocket(url)
    _client.connect()
def test_send():
    """Send commands to server via channel, and get them back via server's stdout."""
    _client.send(_command1)
    _client.send(_command2)
    time.sleep(0.01)

    first_echo = _server.stdout.readline().strip()
    assert first_echo == _command1, first_echo

    second_echo = _server.stdout.readline().strip()
    assert second_echo == _command2, second_echo
def test_receive():
    """Send commands to server via stdin, and get them back via channel."""
    for command in (_command1, _command2):
        _server.stdin.write(command + '\n')
    _server.stdin.flush()

    first_message = _client.receive()
    assert first_message == _command1, first_message

    second_message = _client.receive()
    assert second_message == _command2, second_message
def test_disconnect():
    """Disconnect the client, then shut the helper server down."""
    global _server

    _client.disconnect()

    # release the port
    _server.stdin.write('_close_\n')
    _server.stdin.flush()
    time.sleep(0.1)
    _server.terminate()
    # Wait for the child to exit so it is reaped; dropping the Popen object
    # without waiting leaves an unreaped process behind.
    _server.wait()
    _server = None
## helper ##
def _init():
    """Launch the helper server and return the port number it prints on its first output line."""
    global _server

    # launch a server that connects websocket to stdio
    test_dir = Path(__file__).parent
    script = (test_dir / 'helper/websocket_server.py').resolve()
    command = [sys.executable, str(script)]
    _server = Popen(command, stdin=PIPE, stdout=PIPE, encoding='utf_8')
    time.sleep(0.1)

    # if a test fails, make sure to stop the server
    atexit.register(lambda: _server is None or _server.terminate())

    port_line = _server.stdout.readline()
    return int(port_line.strip())
if __name__ == '__main__':
    # run the cases in this order: each one uses state set up by the ones before it
    for case in (test_connect, test_send, test_receive, test_disconnect):
        case()
    print('pass')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment