Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
a74fa40d
Unverified
Commit
a74fa40d
authored
Oct 19, 2022
by
liuzhe-lz
Committed by
GitHub
Oct 19, 2022
Browse files
Let tuner auto reconnect to NNI manager (#5166)
parent
cbd5d8be
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
267 additions
and
40 deletions
+267
-40
nni/__main__.py
nni/__main__.py
+5
-3
nni/runtime/msg_dispatcher_base.py
nni/runtime/msg_dispatcher_base.py
+10
-0
nni/runtime/tuner_command_channel/channel.py
nni/runtime/tuner_command_channel/channel.py
+49
-3
nni/runtime/tuner_command_channel/command_type.py
nni/runtime/tuner_command_channel/command_type.py
+1
-0
nni/runtime/tuner_command_channel/websocket.py
nni/runtime/tuner_command_channel/websocket.py
+12
-3
ts/nni_manager/common/deferred.ts
ts/nni_manager/common/deferred.ts
+96
-0
ts/nni_manager/core/tuner_command_channel/shim.ts
ts/nni_manager/core/tuner_command_channel/shim.ts
+20
-4
ts/nni_manager/core/tuner_command_channel/websocket_channel.ts
...i_manager/core/tuner_command_channel/websocket_channel.ts
+60
-26
ts/nni_manager/test/core/tuner_command_channel.test.ts
ts/nni_manager/test/core/tuner_command_channel.test.ts
+14
-1
No files found.
nni/__main__.py
View file @
a74fa40d
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
os
import
argparse
import
logging
import
json
import
base64
import
json
import
logging
import
os
import
traceback
from
.runtime.msg_dispatcher
import
MsgDispatcher
from
.runtime.msg_dispatcher_base
import
MsgDispatcherBase
...
...
@@ -65,6 +66,7 @@ def main():
tuner
.
_on_error
()
if
assessor
is
not
None
:
assessor
.
_on_error
()
dispatcher
.
report_error
(
traceback
.
format_exc
())
raise
...
...
nni/runtime/msg_dispatcher_base.py
View file @
a74fa40d
...
...
@@ -84,6 +84,16 @@ class MsgDispatcherBase(Recoverable):
_logger
.
info
(
'Dispatcher terminiated'
)
def
report_error
(
self
,
error
:
str
)
->
None
:
'''
Report dispatcher error to NNI manager.
'''
_logger
.
info
(
f
'Report error to NNI manager:
{
error
}
'
)
try
:
self
.
send
(
CommandType
.
Error
,
error
)
except
Exception
:
_logger
.
error
(
'Connection to NNI manager is broken. Failed to report error.'
)
def
send
(
self
,
command
,
data
):
self
.
_channel
.
_send
(
command
,
data
)
...
...
nni/runtime/tuner_command_channel/channel.py
View file @
a74fa40d
...
...
@@ -9,9 +9,14 @@ from __future__ import annotations
__all__
=
[
'TunerCommandChannel'
]
import
logging
import
time
from
.command_type
import
CommandType
from
.websocket
import
WebSocket
_logger
=
logging
.
getLogger
(
__name__
)
class
TunerCommandChannel
:
"""
A channel to communicate with NNI manager.
...
...
@@ -35,7 +40,9 @@ class TunerCommandChannel:
"""
def
__init__
(
self
,
url
:
str
):
self
.
_url
=
url
self
.
_channel
=
WebSocket
(
url
)
self
.
_retry_intervals
=
[
0
,
1
,
10
]
def
connect
(
self
)
->
None
:
self
.
_channel
.
connect
()
...
...
@@ -51,11 +58,50 @@ class TunerCommandChannel:
def
_send
(
self
,
command_type
:
CommandType
,
data
:
str
)
->
None
:
command
=
command_type
.
value
.
decode
()
+
data
try
:
self
.
_channel
.
send
(
command
)
except
WebSocket
.
ConnectionClosed
:
self
.
_retry_send
(
command
)
def
_retry_send
(
self
,
command
:
str
)
->
None
:
_logger
.
warning
(
'Connection lost. Trying to reconnect...'
)
for
i
,
interval
in
enumerate
(
self
.
_retry_intervals
):
_logger
.
info
(
f
'Attempt #
{
i
}
, wait
{
interval
}
seconds...'
)
time
.
sleep
(
interval
)
self
.
_channel
=
WebSocket
(
self
.
_url
)
try
:
self
.
_channel
.
send
(
command
)
_logger
.
info
(
'Reconnected.'
)
return
except
Exception
as
e
:
_logger
.
exception
(
e
)
_logger
.
error
(
'Failed to reconnect.'
)
raise
RuntimeError
(
'Connection lost'
)
def
_receive
(
self
)
->
tuple
[
CommandType
,
str
]
|
tuple
[
None
,
None
]:
try
:
command
=
self
.
_channel
.
receive
()
except
WebSocket
.
ConnectionClosed
:
# this is for robustness and should never happen
_logger
.
warning
(
'ConnectionClosed exception on receiving.'
)
command
=
None
if
command
is
None
:
raise
RuntimeError
(
'NNI manager closed connection'
)
command
=
self
.
_retry_receive
(
)
command_type
=
CommandType
(
command
[:
2
].
encode
())
return
command_type
,
command
[
2
:]
def
_retry_receive
(
self
)
->
str
:
_logger
.
warning
(
'Connection lost. Trying to reconnect...'
)
for
i
,
interval
in
enumerate
(
self
.
_retry_intervals
):
_logger
.
info
(
f
'Attempt #
{
i
}
, wait
{
interval
}
seconds...'
)
time
.
sleep
(
interval
)
self
.
_channel
=
WebSocket
(
self
.
_url
)
try
:
command
=
self
.
_channel
.
receive
()
except
WebSocket
.
ConnectionClosed
:
command
=
None
# for robustness
if
command
is
not
None
:
_logger
.
info
(
'Reconnected'
)
return
command
_logger
.
error
(
'Failed to reconnect.'
)
raise
RuntimeError
(
'Connection lost'
)
nni/runtime/tuner_command_channel/command_type.py
View file @
a74fa40d
...
...
@@ -21,3 +21,4 @@ class CommandType(Enum):
SendTrialJobParameter
=
b
'SP'
NoMoreTrialJobs
=
b
'NO'
KillTrialJob
=
b
'KI'
Error
=
b
'ER'
nni/runtime/tuner_command_channel/websocket.py
View file @
a74fa40d
...
...
@@ -14,7 +14,7 @@ __all__ = ['WebSocket']
import
asyncio
import
logging
from
threading
import
Lock
,
Thread
from
typing
import
Any
from
typing
import
Any
,
Type
import
websockets
...
...
@@ -39,6 +39,9 @@ class WebSocket:
The WebSocket URL.
For tuner command channel it should be something like ``ws://localhost:8080/tuner``.
"""
ConnectionClosed
:
Type
[
Exception
]
=
websockets
.
ConnectionClosed
# type: ignore
def
__init__
(
self
,
url
:
str
):
self
.
_url
:
str
=
url
self
.
_ws
:
Any
=
None
# the library does not provide type hints
...
...
@@ -74,7 +77,13 @@ class WebSocket:
def
send
(
self
,
message
:
str
)
->
None
:
_logger
.
debug
(
f
'Sending
{
message
}
'
)
try
:
_wait
(
self
.
_ws
.
send
(
message
))
except
websockets
.
ConnectionClosed
:
# type: ignore
_logger
.
debug
(
'Connection closed by server.'
)
self
.
_ws
=
None
_decrease_refcnt
()
raise
def
receive
(
self
)
->
str
|
None
:
"""
...
...
@@ -88,7 +97,7 @@ class WebSocket:
_logger
.
debug
(
'Connection closed by server.'
)
self
.
_ws
=
None
_decrease_refcnt
()
r
eturn
Non
e
r
ais
e
# seems the library will inference whether it's text or binary, so we don't have guarantee
if
isinstance
(
msg
,
bytes
):
...
...
ts/nni_manager/common/deferred.ts
0 → 100644
View file @
a74fa40d
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
/**
* TODO: Back ported from 3.0 draft.
*
* An augmented version of ts-deferred.
*
* You can `await deferred.promise` more than once and they will be resolved together.
*
* You can resolve a deferred multiple times with identical value and it will be ignored.
*
* If a deferred is resolved and/or rejected with conflict values,
* it will throw error and log both values or reasons.
**/
import
util
from
'
util
'
;
import
{
Logger
,
getLogger
}
from
'
common/log
'
;
const
logger
=
getLogger
(
'
common.deferred
'
);
export
class
Deferred
<
T
>
{
private
resolveCallbacks
:
any
[]
=
[];
private
rejectCallbacks
:
any
[]
=
[];
private
isResolved
:
boolean
=
false
;
private
isRejected
:
boolean
=
false
;
private
resolvedValue
?:
T
;
private
rejectedReason
?:
Error
;
public
get
promise
():
Promise
<
T
>
{
// use getter to compat ts-deferred
if
(
this
.
isResolved
)
{
return
Promise
.
resolve
(
this
.
resolvedValue
)
as
Promise
<
T
>
;
}
if
(
this
.
isRejected
)
{
return
Promise
.
reject
(
this
.
rejectedReason
)
as
Promise
<
T
>
;
}
return
new
Promise
<
T
>
((
resolutionFunc
,
rejectionFunc
)
=>
{
this
.
resolveCallbacks
.
push
(
resolutionFunc
);
this
.
rejectCallbacks
.
push
(
rejectionFunc
);
});
}
public
get
settled
():
boolean
{
// use getter for consistent api style
return
this
.
isResolved
||
this
.
isRejected
;
}
public
resolve
=
(
value
:
T
):
void
=>
{
if
(
!
this
.
isResolved
&&
!
this
.
isRejected
)
{
this
.
isResolved
=
true
;
this
.
resolvedValue
=
value
;
for
(
const
callback
of
this
.
resolveCallbacks
)
{
callback
(
value
);
}
}
else
if
(
this
.
isResolved
&&
this
.
resolvedValue
==
value
)
{
logger
.
debug
(
'
Double resolve:
'
,
value
);
}
else
{
const
msg
=
this
.
errorMessage
(
'
trying to resolve with value:
'
+
util
.
inspect
(
value
));
logger
.
error
(
msg
);
throw
new
Error
(
'
Conflict Deferred result.
'
+
msg
);
}
}
public
reject
=
(
reason
:
Error
):
void
=>
{
if
(
!
this
.
isResolved
&&
!
this
.
isRejected
)
{
this
.
isRejected
=
true
;
this
.
rejectedReason
=
reason
;
for
(
const
callback
of
this
.
rejectCallbacks
)
{
callback
(
reason
);
}
}
else
if
(
this
.
isRejected
)
{
logger
.
warning
(
'
Double reject:
'
,
this
.
rejectedReason
,
reason
);
}
else
{
const
msg
=
this
.
errorMessage
(
'
trying to reject with reason:
'
+
util
.
inspect
(
reason
));
logger
.
error
(
msg
);
throw
new
Error
(
'
Conflict Deferred result.
'
+
msg
);
}
}
private
errorMessage
(
curStat
:
string
):
string
{
let
prevStat
=
''
;
if
(
this
.
isResolved
)
{
prevStat
=
'
Already resolved with value:
'
+
util
.
inspect
(
this
.
resolvedValue
);
}
if
(
this
.
isRejected
)
{
prevStat
=
'
Already rejected with reason:
'
+
util
.
inspect
(
this
.
rejectedReason
);
}
return
prevStat
+
'
;
'
+
curStat
;
}
}
ts/nni_manager/core/tuner_command_channel/shim.ts
View file @
a74fa40d
...
...
@@ -10,6 +10,24 @@ export async function createDispatcherInterface(): Promise<IpcInterface> {
class
WsIpcInterface
implements
IpcInterface
{
private
channel
:
WebSocketChannel
=
getWebSocketChannel
();
private
commandListener
?:
(
commandType
:
string
,
content
:
string
)
=>
void
;
private
errorListener
?:
(
error
:
Error
)
=>
void
;
constructor
()
{
this
.
channel
.
onCommand
((
command
:
string
)
=>
{
const
commandType
=
command
.
slice
(
0
,
2
);
const
content
=
command
.
slice
(
2
);
if
(
commandType
===
'
ER
'
)
{
if
(
this
.
errorListener
!==
undefined
)
{
this
.
errorListener
(
new
Error
(
content
));
}
}
else
{
if
(
this
.
commandListener
!==
undefined
)
{
this
.
commandListener
(
commandType
,
content
);
}
}
});
}
public
async
init
():
Promise
<
void
>
{
await
this
.
channel
.
init
();
...
...
@@ -25,12 +43,10 @@ class WsIpcInterface implements IpcInterface {
}
public
onCommand
(
listener
:
(
commandType
:
string
,
content
:
string
)
=>
void
):
void
{
this
.
channel
.
onCommand
((
command
:
string
)
=>
{
listener
(
command
.
slice
(
0
,
2
),
command
.
slice
(
2
));
});
this
.
commandListener
=
listener
;
}
public
onError
(
listener
:
(
error
:
Error
)
=>
void
):
void
{
this
.
channel
.
onError
(
listener
)
;
this
.
errorListener
=
listener
;
}
}
ts/nni_manager/core/tuner_command_channel/websocket_channel.ts
View file @
a74fa40d
...
...
@@ -13,9 +13,9 @@
import
assert
from
'
assert/strict
'
;
import
{
EventEmitter
}
from
'
events
'
;
import
{
Deferred
}
from
'
ts-deferred
'
;
import
type
WebSocket
from
'
ws
'
;
import
{
Deferred
}
from
'
common/deferred
'
;
import
{
Logger
,
getLogger
}
from
'
common/log
'
;
const
logger
:
Logger
=
getLogger
(
'
tuner_command_channel.WebSocketChannel
'
);
...
...
@@ -38,46 +38,38 @@ export function getWebSocketChannel(): WebSocketChannel {
/**
* The callback to serve WebSocket connection request. Used by REST server module.
* It should only be invoked once, or an error will be raised.
*
* Type hint of express-ws is somewhat problematic. Don't want to waste time on it so use `any`.
* If it is invoked more than once, the previous connection will be dropped.
**/
export
function
serveWebSocket
(
ws
:
WebSocket
):
void
{
channelSingleton
.
se
t
WebSocket
(
ws
);
channelSingleton
.
se
rve
WebSocket
(
ws
);
}
class
WebSocketChannelImpl
implements
WebSocketChannel
{
private
deferredInit
:
Deferred
<
void
>
|
null
=
new
Deferred
<
void
>
();
private
deferredInit
:
Deferred
<
void
>
=
new
Deferred
<
void
>
();
private
emitter
:
EventEmitter
=
new
EventEmitter
();
private
heartbeatTimer
!
:
NodeJS
.
Timer
;
private
serving
:
boolean
=
false
;
private
waitingPong
:
boolean
=
false
;
private
ws
!
:
WebSocket
;
public
setWebSocket
(
ws
:
WebSocket
):
void
{
if
(
this
.
ws
!==
undefined
)
{
logger
.
error
(
'
A second client is trying to connect.
'
);
ws
.
close
(
4030
,
'
Already serving a tuner
'
);
return
;
}
if
(
this
.
deferredInit
===
null
)
{
logger
.
error
(
'
Connection timed out.
'
);
ws
.
close
(
4080
,
'
Timeout
'
);
return
;
public
serveWebSocket
(
ws
:
WebSocket
):
void
{
if
(
this
.
ws
===
undefined
)
{
logger
.
debug
(
'
Connected.
'
);
}
else
{
logger
.
warning
(
'
Reconnecting. Drop previous connection.
'
);
this
.
dropConnection
(
'
Reconnected
'
);
}
logger
.
debug
(
'
Connected.
'
);
this
.
serving
=
true
;
this
.
ws
=
ws
;
ws
.
on
(
'
close
'
,
()
=>
{
this
.
handle
Error
(
new
Error
(
'
tuner_command_channel: Tuner closed connection
'
));
}
);
ws
.
on
(
'
error
'
,
this
.
handleError
.
bind
(
this
)
);
ws
.
on
(
'
message
'
,
this
.
receive
.
bind
(
this
)
);
ws
.
on
(
'
pong
'
,
()
=>
{
this
.
waitingPong
=
false
;
}
);
this
.
ws
.
on
(
'
close
'
,
this
.
handle
WsClose
);
this
.
ws
.
on
(
'
error
'
,
this
.
handle
Ws
Error
);
this
.
ws
.
on
(
'
message
'
,
this
.
handleWsMessage
);
this
.
ws
.
on
(
'
pong
'
,
this
.
handleWsPong
);
this
.
heartbeatTimer
=
setInterval
(
this
.
heartbeat
.
bind
(
this
),
heartbeatInterval
);
this
.
deferredInit
.
resolve
();
this
.
deferredInit
=
null
;
}
public
init
():
Promise
<
void
>
{
...
...
@@ -85,13 +77,12 @@ class WebSocketChannelImpl implements WebSocketChannel {
logger
.
debug
(
'
Waiting connection...
'
);
// TODO: This is a quick fix. It should check tuner's process status instead.
setTimeout
(()
=>
{
if
(
this
.
deferredInit
!==
null
)
{
if
(
!
this
.
deferredInit
.
settled
)
{
const
msg
=
'
Tuner did not connect in 10 seconds. Please check tuner (dispatcher) log.
'
;
this
.
deferredInit
.
reject
(
new
Error
(
'
tuner_command_channel:
'
+
msg
));
this
.
deferredInit
=
null
;
}
},
10000
);
return
this
.
deferredInit
!
.
promise
;
return
this
.
deferredInit
.
promise
;
}
else
{
logger
.
debug
(
'
Initialized.
'
);
...
...
@@ -127,6 +118,49 @@ class WebSocketChannelImpl implements WebSocketChannel {
this
.
emitter
.
on
(
'
error
'
,
callback
);
}
/* Following callbacks must be auto-binded arrow functions to be turned off */
private
handleWsClose
=
():
void
=>
{
this
.
handleError
(
new
Error
(
'
tuner_command_channel: Tuner closed connection
'
));
}
private
handleWsError
=
(
error
:
Error
):
void
=>
{
this
.
handleError
(
error
);
}
private
handleWsMessage
=
(
data
:
Buffer
,
_isBinary
:
boolean
):
void
=>
{
this
.
receive
(
data
);
}
private
handleWsPong
=
():
void
=>
{
this
.
waitingPong
=
false
;
}
private
dropConnection
(
reason
:
string
):
void
{
if
(
this
.
ws
===
undefined
)
{
return
;
}
this
.
serving
=
false
;
clearInterval
(
this
.
heartbeatTimer
);
this
.
ws
.
off
(
'
close
'
,
this
.
handleWsClose
);
this
.
ws
.
off
(
'
error
'
,
this
.
handleWsError
);
this
.
ws
.
off
(
'
message
'
,
this
.
handleWsMessage
);
this
.
ws
.
off
(
'
pong
'
,
this
.
handleWsPong
);
this
.
ws
.
on
(
'
close
'
,
()
=>
{
logger
.
info
(
'
Connection dropped
'
);
});
this
.
ws
.
on
(
'
message
'
,
(
data
,
_isBinary
)
=>
{
logger
.
error
(
'
Received message after reconnect:
'
,
data
);
});
this
.
ws
.
on
(
'
pong
'
,
()
=>
{
logger
.
error
(
'
Received pong after reconnect.
'
);
});
this
.
ws
.
close
(
1001
,
reason
);
}
private
heartbeat
():
void
{
if
(
this
.
waitingPong
)
{
this
.
ws
.
terminate
();
// this will trigger "close" event
...
...
@@ -137,7 +171,7 @@ class WebSocketChannelImpl implements WebSocketChannel {
this
.
ws
.
ping
();
}
private
receive
(
data
:
Buffer
,
_isBinary
:
boolean
):
void
{
private
receive
(
data
:
Buffer
):
void
{
logger
.
debug
(
'
Received
'
,
data
);
this
.
emitter
.
emit
(
'
command
'
,
data
.
toString
());
}
...
...
ts/nni_manager/test/core/tuner_command_channel.test.ts
View file @
a74fa40d
...
...
@@ -68,12 +68,24 @@ async function testError(): Promise<void> {
client
.
resume
();
}
// WebSocket might get broken in long experiments. Simulate reconnect.
async
function
testReconnect
():
Promise
<
void
>
{
client
.
close
();
startClient
();
testInit
();
testSend
();
}
// Clean up.
async
function
testShutdown
():
Promise
<
void
>
{
const
channel
=
getWebSocketChannel
();
await
channel
.
shutdown
();
try
{
client
.
close
();
}
catch
(
error
)
{
console
.
log
(
'
Error on clean up:
'
,
error
);
}
server
.
close
();
}
...
...
@@ -83,6 +95,7 @@ describe('## tuner_command_channel ##', () => {
it
(
'
send
'
,
testSend
);
it
(
'
receive
'
,
testReceive
);
it
(
'
catch error
'
,
testError
);
it
(
'
reconnect
'
,
testReconnect
);
it
(
'
shutdown
'
,
testShutdown
);
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment