Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
f60d3d5e
Unverified
Commit
f60d3d5e
authored
Apr 29, 2022
by
liuzhe-lz
Committed by
GitHub
Apr 29, 2022
Browse files
WebSocket (step 1) - Python client (#4806)
parent
b39850f9
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
418 additions
and
82 deletions
+418
-82
dependencies/develop.txt
dependencies/develop.txt
+1
-0
nni/experiment/launcher.py
nni/experiment/launcher.py
+2
-2
nni/runtime/protocol.py
nni/runtime/protocol.py
+13
-65
nni/runtime/tuner_command_channel/__init__.py
nni/runtime/tuner_command_channel/__init__.py
+8
-0
nni/runtime/tuner_command_channel/command_type.py
nni/runtime/tuner_command_channel/command_type.py
+23
-0
nni/runtime/tuner_command_channel/legacy.py
nni/runtime/tuner_command_channel/legacy.py
+54
-0
nni/runtime/tuner_command_channel/shim.py
nni/runtime/tuner_command_channel/shim.py
+33
-0
nni/runtime/tuner_command_channel/websocket.py
nni/runtime/tuner_command_channel/websocket.py
+124
-0
test/ut/retiarii/test_cgo_engine.py
test/ut/retiarii/test_cgo_engine.py
+2
-2
test/ut/retiarii/test_engine.py
test/ut/retiarii/test_engine.py
+2
-2
test/ut/sdk/helper/websocket_server.py
test/ut/sdk/helper/websocket_server.py
+58
-0
test/ut/sdk/test_assessor.py
test/ut/sdk/test_assessor.py
+4
-4
test/ut/sdk/test_msg_dispatcher.py
test/ut/sdk/test_msg_dispatcher.py
+4
-4
test/ut/sdk/test_protocol.py
test/ut/sdk/test_protocol.py
+4
-3
test/ut/sdk/test_tuner_command_channel.py
test/ut/sdk/test_tuner_command_channel.py
+86
-0
No files found.
dependencies/develop.txt
View file @
f60d3d5e
aioconsole
coverage
cython
flake8
...
...
nni/experiment/launcher.py
View file @
f60d3d5e
...
...
@@ -177,8 +177,8 @@ def start_experiment_retiarii(exp_id, config, port, debug):
start_time
,
proc
=
_start_rest_server_retiarii
(
config
,
port
,
debug
,
exp_id
,
pipe
.
path
)
_logger
.
info
(
'Connecting IPC pipe...'
)
pipe_file
=
pipe
.
connect
()
nni
.
runtime
.
protocol
.
_in_file
=
pipe_file
nni
.
runtime
.
protocol
.
_out_file
=
pipe_file
nni
.
runtime
.
protocol
.
_
set_
in_file
(
pipe_file
)
nni
.
runtime
.
protocol
.
_
set_
out_file
(
pipe_file
)
_logger
.
info
(
'Starting web server...'
)
_check_rest_server
(
port
)
platform
=
'hybrid'
if
isinstance
(
config
.
training_service
,
list
)
else
config
.
training_service
.
platform
...
...
nni/runtime/protocol.py
View file @
f60d3d5e
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
logging
import
os
import
threading
from
enum
import
Enum
# pylint: disable=unused-import
_logger
=
logging
.
getLogger
(
__name__
)
from
.tuner_command_channel.command_type
import
CommandType
from
.tuner_command_channel.legacy
import
send
,
receive
# for unit test compatibility
def
_set_in_file
(
in_file
):
from
.tuner_command_channel
import
legacy
legacy
.
_in_file
=
in_file
class
CommandType
(
Enum
):
# in
Initialize
=
b
'IN'
RequestTrialJobs
=
b
'GE'
ReportMetricData
=
b
'ME'
UpdateSearchSpace
=
b
'SS'
ImportData
=
b
'FD'
AddCustomizedTrialJob
=
b
'AD'
TrialEnd
=
b
'EN'
Terminate
=
b
'TE'
Ping
=
b
'PI'
def
_set_out_file
(
out_file
):
from
.tuner_command_channel
import
legacy
legacy
.
_out_file
=
out_file
# out
Initialized
=
b
'ID'
NewTrialJob
=
b
'TR'
SendTrialJobParameter
=
b
'SP'
NoMoreTrialJobs
=
b
'NO'
KillTrialJob
=
b
'KI'
_lock
=
threading
.
Lock
()
try
:
if
os
.
environ
.
get
(
'NNI_PLATFORM'
)
!=
'unittest'
:
_in_file
=
open
(
3
,
'rb'
)
_out_file
=
open
(
4
,
'wb'
)
except
OSError
:
_logger
.
debug
(
'IPC pipeline not exists'
)
def
send
(
command
,
data
):
"""Send command to Training Service.
command: CommandType object.
data: string payload.
"""
global
_lock
try
:
_lock
.
acquire
()
data
=
data
.
encode
(
'utf8'
)
msg
=
b
'%b%014d%b'
%
(
command
.
value
,
len
(
data
),
data
)
_logger
.
debug
(
'Sending command, data: [%s]'
,
msg
)
_out_file
.
write
(
msg
)
_out_file
.
flush
()
finally
:
_lock
.
release
()
def
receive
():
"""Receive a command from Training Service.
Returns a tuple of command (CommandType) and payload (str)
"""
header
=
_in_file
.
read
(
16
)
_logger
.
debug
(
'Received command, header: [%s]'
,
header
)
if
header
is
None
or
len
(
header
)
<
16
:
# Pipe EOF encountered
_logger
.
debug
(
'Pipe EOF encountered'
)
return
None
,
None
length
=
int
(
header
[
2
:])
data
=
_in_file
.
read
(
length
)
command
=
CommandType
(
header
[:
2
])
data
=
data
.
decode
(
'utf8'
)
_logger
.
debug
(
'Received command, data: [%s]'
,
data
)
return
command
,
data
def
_get_out_file
():
from
.tuner_command_channel
import
legacy
return
legacy
.
_out_file
nni/runtime/tuner_command_channel/__init__.py
0 → 100644
View file @
f60d3d5e
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
The IPC channel between tuner/assessor and NNI manager.
Work in progress.
"""
nni/runtime/tuner_command_channel/command_type.py
0 → 100644
View file @
f60d3d5e
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
enum
import
Enum
class CommandType(Enum):
    """Command codes of the tuner command channel; every value is a two-byte ``bytes`` literal."""

    # in
    Initialize = b'IN'
    RequestTrialJobs = b'GE'
    ReportMetricData = b'ME'
    UpdateSearchSpace = b'SS'
    ImportData = b'FD'
    AddCustomizedTrialJob = b'AD'
    TrialEnd = b'EN'
    Terminate = b'TE'
    Ping = b'PI'

    # out
    Initialized = b'ID'
    NewTrialJob = b'TR'
    SendTrialJobParameter = b'SP'
    NoMoreTrialJobs = b'NO'
    KillTrialJob = b'KI'
nni/runtime/tuner_command_channel/legacy.py
0 → 100644
View file @
f60d3d5e
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
logging
import
os
import
threading
from
.command_type
import
CommandType
_logger = logging.getLogger(__name__)

# Held by send() while a message is written and flushed.
_lock = threading.Lock()

# Outside unit tests the channel reads from file descriptor 3 and writes to file descriptor 4.
# When they cannot be opened, `_in_file` / `_out_file` are simply left undefined.
if os.environ.get('NNI_PLATFORM') != 'unittest':
    try:
        _in_file = open(3, 'rb')
        _out_file = open(4, 'wb')
    except OSError:
        _logger.debug('IPC pipeline not exists')
def send(command, data):
    """Send command to Training Service.

    command: CommandType object.
    data: string payload.

    The message written to ``_out_file`` is the command value, followed by the
    payload length as a 14-digit zero-padded decimal, followed by the UTF-8
    encoded payload.
    """
    payload = data.encode('utf8')
    msg = b'%b%014d%b' % (command.value, len(payload), payload)
    # `with` acquires the lock before entering the guarded region, so the lock is
    # released only if it was actually taken; write and flush stay atomic per message.
    with _lock:
        _logger.debug('Sending command, data: [%s]', msg)
        _out_file.write(msg)
        _out_file.flush()
def receive():
    """Receive a command from Training Service.

    Returns a tuple of command (CommandType) and payload (str),
    or ``(None, None)`` when fewer than 16 header bytes could be read.
    """
    header_size = 16
    header = _in_file.read(header_size)
    _logger.debug('Received command, header: [%s]', header)

    reached_eof = header is None or len(header) < header_size
    if reached_eof:
        # Pipe EOF encountered
        _logger.debug('Pipe EOF encountered')
        return None, None

    # The first two header bytes are the command code, the rest is the decimal payload length.
    code, size_field = header[:2], header[2:]
    raw_payload = _in_file.read(int(size_field))
    command = CommandType(code)
    text = raw_payload.decode('utf8')
    _logger.debug('Received command, data: [%s]', text)
    return command, text
nni/runtime/tuner_command_channel/shim.py
0 → 100644
View file @
f60d3d5e
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Compatibility layer for old protocol APIs.
We are working on more semantic new APIs.
"""
from
__future__
import
annotations
from
.command_type
import
CommandType
from
.websocket
import
WebSocket
# Module-level connection used by send() and receive(); it is None until connect() assigns it.
_ws: WebSocket = None  # type: ignore
def connect(url: str) -> None:
    """Create a :class:`WebSocket` for *url*, store it as the module-level connection and connect it."""
    global _ws
    channel = WebSocket(url)
    _ws = channel
    channel.connect()
def send(command_type: CommandType, data: str) -> None:
    """Send *data* as one message, prefixed with the decoded value of *command_type*."""
    prefix = command_type.value.decode()
    _ws.send(prefix + data)
def receive() -> tuple[CommandType, str]:
    """
    Receive one message and split it into its command type and payload.

    The first two characters select the :class:`CommandType`; the remainder is the payload.
    The connection is closed from this side when a ``Terminate`` command arrives.

    Raises
    ------
    RuntimeError
        If the underlying connection returned ``None``.
    """
    message = _ws.receive()
    if message is None:
        raise RuntimeError('NNI manager closed connection')

    code, payload = message[:2], message[2:]
    command_type = CommandType(code.encode())
    if command_type is CommandType.Terminate:
        _ws.disconnect()
    return command_type, payload
nni/runtime/tuner_command_channel/websocket.py
0 → 100644
View file @
f60d3d5e
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Synchronized and object-oriented WebSocket class.
WebSocket guarantees that messages will not be divided at API level.
"""
from
__future__
import
annotations
__all__
=
[
'WebSocket'
]
import
asyncio
import
logging
from
threading
import
Lock
,
Thread
from
typing
import
Any
import
websockets
_logger = logging.getLogger(__name__)

# the singleton event loop
# (created in WebSocket.connect() and reset to None in _decrease_refcnt() when the count reaches zero)
_event_loop: asyncio.AbstractEventLoop = None  # type: ignore
# held while _event_loop and _event_loop_refcnt are modified
_event_loop_lock: Lock = Lock()
_event_loop_refcnt: int = 0  # number of connected websockets
class WebSocket:
    """
    A WebSocket connection.

    Call :meth:`connect` before :meth:`send` and :meth:`receive`.

    All methods are thread safe.

    Parameters
    ----------
    url
        The WebSocket URL.
        For tuner command channel it should be something like ``ws://localhost:8080/tuner``.
    """

    def __init__(self, url: str):
        self._url: str = url
        self._ws: Any = None  # the library does not provide type hints

    def connect(self) -> None:
        """
        Open the connection, starting the shared event loop thread if none is running.

        If the connection attempt raises, the reference taken on the shared event loop
        is released again before the exception is propagated.
        """
        global _event_loop, _event_loop_refcnt
        with _event_loop_lock:
            _event_loop_refcnt += 1
            if _event_loop is None:
                _logger.debug('Starting event loop.')
                # following line must be outside _run_event_loop
                # because _wait() might be executed before first line of the child thread
                _event_loop = asyncio.new_event_loop()
                thread = Thread(target=_run_event_loop, name='NNI-WebSocketEventLoop', daemon=True)
                thread.start()

        _logger.debug(f'Connecting to {self._url}')
        try:
            self._ws = _wait(_connect_async(self._url))
        except Exception:
            # Without this the reference count never returns to zero
            # and the event loop thread is never stopped.
            _decrease_refcnt()
            raise
        _logger.debug('Connected.')

    def disconnect(self) -> None:
        """
        Close the connection and release its reference on the shared event loop.

        Does nothing (apart from a debug log) when there is no connection.
        A failure to close is logged as a warning and not raised.
        """
        if self._ws is None:
            _logger.debug('disconnect: No connection.')
            return

        try:
            _wait(self._ws.close())
            _logger.debug('Connection closed by client.')
        except Exception as e:
            _logger.warning(f'Failed to close connection: {repr(e)}')
        self._ws = None
        _decrease_refcnt()

    def send(self, message: str) -> None:
        """Send *message* and block until the send has completed."""
        _logger.debug(f'Sending {message}')
        _wait(self._ws.send(message))

    def receive(self) -> str | None:
        """
        Return received message;
        or return ``None`` if the connection has been closed by peer.
        """
        try:
            msg = _wait(self._ws.recv())
            _logger.debug(f'Received {msg}')
        except websockets.ConnectionClosed:  # type: ignore
            _logger.debug('Connection closed by server.')
            self._ws = None
            _decrease_refcnt()
            return None

        # It seems the library infers whether a message is text or binary, so the type is not guaranteed.
        if isinstance(msg, bytes):
            return msg.decode()
        else:
            return msg
def _wait(coro):
    """Synchronized version of "await": submit *coro* to the shared event loop and block for its result."""
    return asyncio.run_coroutine_threadsafe(coro, _event_loop).result()
def _run_event_loop() -> None:
    """
    Thread target that runs the shared event loop until it is stopped.

    The event loop itself is blocking, and send/receive are also blocking,
    so they must run in different threads.
    """
    # Snapshot the global once: _decrease_refcnt() resets it to None from another thread,
    # and this thread must keep operating on the loop it was started for.
    loop = _event_loop
    asyncio.set_event_loop(loop)
    loop.run_forever()
    _logger.debug('Event loop stopped.')
async def _connect_async(url):
    """Await ``websockets.connect(url, max_size=None)`` and return the resulting connection."""
    # Theoretically this wrapper is meaningless and one could pass `websockets.connect(url)` directly,
    # but that does not work, raising "TypeError: A coroutine object is required".
    # Seems a design flaw in websockets library.
    connection = await websockets.connect(url, max_size=None)  # type: ignore
    return connection
def _decrease_refcnt() -> None:
    """Drop one reference on the shared event loop; when none remain, ask it to stop and forget it."""
    global _event_loop, _event_loop_refcnt
    with _event_loop_lock:
        _event_loop_refcnt -= 1
        if _event_loop_refcnt != 0:
            return
        # stop() has to be scheduled through call_soon_threadsafe since the loop runs in another thread
        loop = _event_loop
        loop.call_soon_threadsafe(loop.stop)
        _event_loop = None  # type: ignore
test/ut/retiarii/test_cgo_engine.py
View file @
f60d3d5e
...
...
@@ -298,8 +298,8 @@ class CGOEngineTest(unittest.TestCase):
os
.
makedirs
(
'generated'
,
exist_ok
=
True
)
from
nni.runtime
import
protocol
import
nni.runtime.platform.test
as
tt
protocol
.
_out_file
=
open
(
'generated/debug_protocol_out_file.py'
,
'wb'
)
protocol
.
_in_file
=
open
(
'generated/debug_protocol_out_file.py'
,
'rb'
)
protocol
.
_
set_
out_file
(
open
(
'generated/debug_protocol_out_file.py'
,
'wb'
)
)
protocol
.
_
set_
in_file
(
open
(
'generated/debug_protocol_out_file.py'
,
'rb'
)
)
models
=
_load_mnist
(
2
)
...
...
test/ut/retiarii/test_engine.py
View file @
f60d3d5e
...
...
@@ -64,10 +64,10 @@ class EngineTest(unittest.TestCase):
self
.
enclosing_dir
=
Path
(
__file__
).
parent
os
.
makedirs
(
self
.
enclosing_dir
/
'generated'
,
exist_ok
=
True
)
from
nni.runtime
import
protocol
protocol
.
_out_file
=
open
(
self
.
enclosing_dir
/
'generated/debug_protocol_out_file.py'
,
'wb'
)
protocol
.
_
set_
out_file
(
open
(
self
.
enclosing_dir
/
'generated/debug_protocol_out_file.py'
,
'wb'
)
)
def
tearDown
(
self
)
->
None
:
from
nni.runtime
import
protocol
protocol
.
_out_file
.
close
()
protocol
.
_
get_
out_file
()
.
close
()
nni
.
retiarii
.
execution
.
api
.
_execution_engine
=
None
nni
.
retiarii
.
integration_api
.
_advisor
=
None
test/ut/sdk/helper/websocket_server.py
0 → 100644
View file @
f60d3d5e
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
A WebSocket server that runs on a random port and accepts a single client.
It prints each message received from the client to stdout,
and sends each line read from stdin to the client.
"""
import
asyncio
import
sys
import
aioconsole
import
websockets
# Switch all three standard streams to UTF-8.
sys.stdin.reconfigure(encoding='utf_8')
sys.stdout.reconfigure(encoding='utf_8')
sys.stderr.reconfigure(encoding='utf_8')

# The client connection; assigned by on_connect() and used by read_stdin().
_ws = None
async def main():
    """Run the WebSocket server and the stdin reader concurrently."""
    jobs = (ws_server(), read_stdin())
    await asyncio.gather(*jobs)
async def read_stdin():
    """
    Send each line read from stdin to the connected client.

    Lines are decoded and stripped before being sent.
    The special line ``_close_`` is not sent; it terminates the process instead.
    """
    async_stdin, _ = await aioconsole.get_standard_streams()
    async for line in async_stdin:
        line = line.decode().strip()
        _debug(f'read from stdin: {line}')
        if line == '_close_':
            # sys.exit() instead of the exit() helper injected by the site module,
            # which is intended for interactive use and is absent under `python -S`.
            sys.exit()
        await _ws.send(line)
async def ws_server():
    """
    Serve WebSocket connections on ``localhost`` with port 0 and never return.

    The actual port of the first listening socket is printed to stdout as the first line of output.
    """
    async with websockets.serve(on_connect, 'localhost', 0) as server:
        address = server.sockets[0].getsockname()
        port = address[1]
        print(port, flush=True)
        _debug(f'port: {port}')
        # a future nobody resolves: keeps the server context open
        forever = asyncio.Future()
        await forever
async def on_connect(ws):
    """Remember *ws* as the module-level client and print every message it delivers to stdout."""
    global _ws
    _debug('connected')
    _ws = ws
    async for message in ws:
        _debug(f'received from websocket: {message}')
        print(message, flush=True)
def _debug(msg):
    """Debug hook that currently does nothing; uncomment the line below to write *msg* to stderr."""
    #sys.stderr.write(f'[server-debug] {msg}\n')
    return None
# Start the server only when this file is executed as a script.
if __name__ == '__main__':
    asyncio.run(main())
test/ut/sdk/test_assessor.py
View file @
f60d3d5e
...
...
@@ -34,15 +34,15 @@ _out_buf = BytesIO()
def
_reverse_io
():
_in_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
protocol
.
_out_file
=
_in_buf
protocol
.
_in_file
=
_out_buf
protocol
.
_
set_
out_file
(
_in_buf
)
protocol
.
_
set_
in_file
(
_out_buf
)
def
_restore_io
():
_in_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
protocol
.
_in_file
=
_in_buf
protocol
.
_out_file
=
_out_buf
protocol
.
_
set_
in_file
(
_in_buf
)
protocol
.
_
set_
out_file
(
_out_buf
)
class
AssessorTestCase
(
TestCase
):
...
...
test/ut/sdk/test_msg_dispatcher.py
View file @
f60d3d5e
...
...
@@ -45,15 +45,15 @@ _out_buf = BytesIO()
def
_reverse_io
():
_in_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
protocol
.
_out_file
=
_in_buf
protocol
.
_in_file
=
_out_buf
protocol
.
_
set_
out_file
(
_in_buf
)
protocol
.
_
set_
in_file
(
_out_buf
)
def
_restore_io
():
_in_buf
.
seek
(
0
)
_out_buf
.
seek
(
0
)
protocol
.
_in_file
=
_in_buf
protocol
.
_out_file
=
_out_buf
protocol
.
_
set_
in_file
(
_in_buf
)
protocol
.
_
set_
out_file
(
_out_buf
)
class
MsgDispatcherTestCase
(
TestCase
):
...
...
test/ut/sdk/test_protocol.py
View file @
f60d3d5e
...
...
@@ -9,11 +9,12 @@ from unittest import TestCase, main
def
_prepare_send
():
protocol
.
_out_file
=
BytesIO
()
return
protocol
.
_out_file
out_file
=
BytesIO
()
protocol
.
_set_out_file
(
out_file
)
return
out_file
def
_prepare_receive
(
data
):
protocol
.
_in_file
=
BytesIO
(
data
)
protocol
.
_
set_
in_file
(
BytesIO
(
data
)
)
class
ProtocolTestCase
(
TestCase
):
...
...
test/ut/sdk/test_tuner_command_channel.py
0 → 100644
View file @
f60d3d5e
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
atexit
from
dataclasses
import
dataclass
import
importlib
import
json
import
os
from
pathlib
import
Path
from
subprocess
import
Popen
,
PIPE
import
sys
import
time
from
nni.runtime.tuner_command_channel.websocket
import
WebSocket
# A helper server that connects its stdio to incoming WebSocket.
_server = None
# The WebSocket client under test; assigned by test_connect().
_client = None

# Messages used by the send/receive tests; the second one contains non-ASCII characters.
_command1 = 'T_hello world'
_command2 = 'T_你好'
## test cases ##
def test_connect():
    """Start the helper server via ``_init()`` and connect the shared client to the port it returns."""
    global _client
    port = _init()
    url = f'ws://localhost:{port}'
    _client = WebSocket(url)
    _client.connect()
def test_send():
    """Check that messages sent through the client come out of the server's stdout unchanged and in order."""
    # Send commands to server via channel, and get them back via server's stdout.
    commands = (_command1, _command2)
    for command in commands:
        _client.send(command)
    time.sleep(0.01)

    for command in commands:
        echoed = _server.stdout.readline().strip()
        assert echoed == command, echoed
def test_receive():
    """Check that lines written to the server's stdin are received by the client unchanged and in order."""
    # Send commands to server via stdin, and get them back via channel.
    commands = (_command1, _command2)
    for command in commands:
        _server.stdin.write(command + '\n')
    _server.stdin.flush()

    for command in commands:
        received = _client.receive()
        assert received == command, received
def test_disconnect():
    """Disconnect the client, then shut the helper server down and reap its process."""
    global _server
    _client.disconnect()

    # release the port
    _server.stdin.write('_close_\n')
    _server.stdin.flush()
    time.sleep(0.1)
    _server.terminate()
    # Collect the exit status so the child does not linger as a zombie
    # and Popen does not emit a ResourceWarning when it is garbage collected.
    _server.wait(timeout=5)
    _server = None
## helper ##
def _init():
    """Launch the helper server and return the integer it prints on its first stdout line (its port)."""
    global _server

    # launch a server that connects websocket to stdio
    test_dir = Path(__file__).parent
    script = (test_dir / 'helper/websocket_server.py').resolve()
    command = [sys.executable, str(script)]
    _server = Popen(command, stdin=PIPE, stdout=PIPE, encoding='utf_8')
    time.sleep(0.1)

    # if a test fails, make sure to stop the server
    def _stop_server_if_running():
        if _server is not None:
            _server.terminate()

    atexit.register(_stop_server_if_running)

    first_line = _server.stdout.readline()
    return int(first_line.strip())
# Allow running the test cases in order without a test runner.
if __name__ == '__main__':
    for case in (test_connect, test_send, test_receive, test_disconnect):
        case()
    print('pass')
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment