Unverified Commit a74fa40d authored by liuzhe-lz's avatar liuzhe-lz Committed by GitHub
Browse files

Let tuner auto reconnect to NNI manager (#5166)

parent cbd5d8be
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
import os
import argparse import argparse
import logging
import json
import base64 import base64
import json
import logging
import os
import traceback
from .runtime.msg_dispatcher import MsgDispatcher from .runtime.msg_dispatcher import MsgDispatcher
from .runtime.msg_dispatcher_base import MsgDispatcherBase from .runtime.msg_dispatcher_base import MsgDispatcherBase
...@@ -65,6 +66,7 @@ def main(): ...@@ -65,6 +66,7 @@ def main():
tuner._on_error() tuner._on_error()
if assessor is not None: if assessor is not None:
assessor._on_error() assessor._on_error()
dispatcher.report_error(traceback.format_exc())
raise raise
......
...@@ -84,6 +84,16 @@ class MsgDispatcherBase(Recoverable): ...@@ -84,6 +84,16 @@ class MsgDispatcherBase(Recoverable):
_logger.info('Dispatcher terminiated') _logger.info('Dispatcher terminiated')
def report_error(self, error: str) -> None:
'''
Report dispatcher error to NNI manager.
'''
_logger.info(f'Report error to NNI manager: {error}')
try:
self.send(CommandType.Error, error)
except Exception:
_logger.error('Connection to NNI manager is broken. Failed to report error.')
def send(self, command, data): def send(self, command, data):
self._channel._send(command, data) self._channel._send(command, data)
......
...@@ -9,9 +9,14 @@ from __future__ import annotations ...@@ -9,9 +9,14 @@ from __future__ import annotations
__all__ = ['TunerCommandChannel'] __all__ = ['TunerCommandChannel']
import logging
import time
from .command_type import CommandType from .command_type import CommandType
from .websocket import WebSocket from .websocket import WebSocket
_logger = logging.getLogger(__name__)
class TunerCommandChannel: class TunerCommandChannel:
""" """
A channel to communicate with NNI manager. A channel to communicate with NNI manager.
...@@ -35,7 +40,9 @@ class TunerCommandChannel: ...@@ -35,7 +40,9 @@ class TunerCommandChannel:
""" """
def __init__(self, url: str): def __init__(self, url: str):
self._url = url
self._channel = WebSocket(url) self._channel = WebSocket(url)
self._retry_intervals = [0, 1, 10]
def connect(self) -> None: def connect(self) -> None:
self._channel.connect() self._channel.connect()
...@@ -51,11 +58,50 @@ class TunerCommandChannel: ...@@ -51,11 +58,50 @@ class TunerCommandChannel:
def _send(self, command_type: CommandType, data: str) -> None: def _send(self, command_type: CommandType, data: str) -> None:
command = command_type.value.decode() + data command = command_type.value.decode() + data
self._channel.send(command) try:
self._channel.send(command)
except WebSocket.ConnectionClosed:
self._retry_send(command)
def _retry_send(self, command: str) -> None:
_logger.warning('Connection lost. Trying to reconnect...')
for i, interval in enumerate(self._retry_intervals):
_logger.info(f'Attempt #{i}, wait {interval} seconds...')
time.sleep(interval)
self._channel = WebSocket(self._url)
try:
self._channel.send(command)
_logger.info('Reconnected.')
return
except Exception as e:
_logger.exception(e)
_logger.error('Failed to reconnect.')
raise RuntimeError('Connection lost')
def _receive(self) -> tuple[CommandType, str] | tuple[None, None]: def _receive(self) -> tuple[CommandType, str] | tuple[None, None]:
command = self._channel.receive() try:
command = self._channel.receive()
except WebSocket.ConnectionClosed:
# this is for robustness and should never happen
_logger.warning('ConnectionClosed exception on receiving.')
command = None
if command is None: if command is None:
raise RuntimeError('NNI manager closed connection') command = self._retry_receive()
command_type = CommandType(command[:2].encode()) command_type = CommandType(command[:2].encode())
return command_type, command[2:] return command_type, command[2:]
def _retry_receive(self) -> str:
_logger.warning('Connection lost. Trying to reconnect...')
for i, interval in enumerate(self._retry_intervals):
_logger.info(f'Attempt #{i}, wait {interval} seconds...')
time.sleep(interval)
self._channel = WebSocket(self._url)
try:
command = self._channel.receive()
except WebSocket.ConnectionClosed:
command = None # for robustness
if command is not None:
_logger.info('Reconnected')
return command
_logger.error('Failed to reconnect.')
raise RuntimeError('Connection lost')
...@@ -21,3 +21,4 @@ class CommandType(Enum): ...@@ -21,3 +21,4 @@ class CommandType(Enum):
SendTrialJobParameter = b'SP' SendTrialJobParameter = b'SP'
NoMoreTrialJobs = b'NO' NoMoreTrialJobs = b'NO'
KillTrialJob = b'KI' KillTrialJob = b'KI'
Error = b'ER'
...@@ -14,7 +14,7 @@ __all__ = ['WebSocket'] ...@@ -14,7 +14,7 @@ __all__ = ['WebSocket']
import asyncio import asyncio
import logging import logging
from threading import Lock, Thread from threading import Lock, Thread
from typing import Any from typing import Any, Type
import websockets import websockets
...@@ -39,6 +39,9 @@ class WebSocket: ...@@ -39,6 +39,9 @@ class WebSocket:
The WebSocket URL. The WebSocket URL.
For tuner command channel it should be something like ``ws://localhost:8080/tuner``. For tuner command channel it should be something like ``ws://localhost:8080/tuner``.
""" """
ConnectionClosed: Type[Exception] = websockets.ConnectionClosed # type: ignore
def __init__(self, url: str): def __init__(self, url: str):
self._url: str = url self._url: str = url
self._ws: Any = None # the library does not provide type hints self._ws: Any = None # the library does not provide type hints
...@@ -74,7 +77,13 @@ class WebSocket: ...@@ -74,7 +77,13 @@ class WebSocket:
def send(self, message: str) -> None: def send(self, message: str) -> None:
_logger.debug(f'Sending {message}') _logger.debug(f'Sending {message}')
_wait(self._ws.send(message)) try:
_wait(self._ws.send(message))
except websockets.ConnectionClosed: # type: ignore
_logger.debug('Connection closed by server.')
self._ws = None
_decrease_refcnt()
raise
def receive(self) -> str | None: def receive(self) -> str | None:
""" """
...@@ -88,7 +97,7 @@ class WebSocket: ...@@ -88,7 +97,7 @@ class WebSocket:
_logger.debug('Connection closed by server.') _logger.debug('Connection closed by server.')
self._ws = None self._ws = None
_decrease_refcnt() _decrease_refcnt()
return None raise
# seems the library will inference whether it's text or binary, so we don't have guarantee # seems the library will inference whether it's text or binary, so we don't have guarantee
if isinstance(msg, bytes): if isinstance(msg, bytes):
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
/**
* TODO: Back ported from 3.0 draft.
*
* An augmented version of ts-deferred.
*
* You can `await deferred.promise` more than once and they will be resolved together.
*
* You can resolve a deferred multiple times with identical value and it will be ignored.
*
* If a deferred is resolved and/or rejected with conflict values,
* it will throw error and log both values or reasons.
**/
import util from 'util';
import { Logger, getLogger } from 'common/log';
const logger = getLogger('common.deferred');
export class Deferred<T> {
private resolveCallbacks: any[] = [];
private rejectCallbacks: any[] = [];
private isResolved: boolean = false;
private isRejected: boolean = false;
private resolvedValue?: T;
private rejectedReason?: Error;
public get promise(): Promise<T> {
// use getter to compat ts-deferred
if (this.isResolved) {
return Promise.resolve(this.resolvedValue) as Promise<T>;
}
if (this.isRejected) {
return Promise.reject(this.rejectedReason) as Promise<T>;
}
return new Promise<T>((resolutionFunc, rejectionFunc) => {
this.resolveCallbacks.push(resolutionFunc);
this.rejectCallbacks.push(rejectionFunc);
});
}
public get settled(): boolean {
// use getter for consistent api style
return this.isResolved || this.isRejected;
}
public resolve = (value: T): void => {
if (!this.isResolved && ! this.isRejected) {
this.isResolved = true;
this.resolvedValue = value;
for (const callback of this.resolveCallbacks) {
callback(value);
}
} else if (this.isResolved && this.resolvedValue == value) {
logger.debug('Double resolve:', value);
} else {
const msg = this.errorMessage('trying to resolve with value: ' + util.inspect(value));
logger.error(msg);
throw new Error('Conflict Deferred result. ' + msg);
}
}
public reject = (reason: Error): void => {
if (!this.isResolved && !this.isRejected) {
this.isRejected = true;
this.rejectedReason = reason;
for (const callback of this.rejectCallbacks) {
callback(reason);
}
} else if (this.isRejected) {
logger.warning('Double reject:', this.rejectedReason, reason);
} else {
const msg = this.errorMessage('trying to reject with reason: ' + util.inspect(reason));
logger.error(msg);
throw new Error('Conflict Deferred result. ' + msg);
}
}
private errorMessage(curStat: string): string {
let prevStat = '';
if (this.isResolved) {
prevStat = 'Already resolved with value: ' + util.inspect(this.resolvedValue);
}
if (this.isRejected) {
prevStat = 'Already rejected with reason: ' + util.inspect(this.rejectedReason);
}
return prevStat + ' ; ' + curStat;
}
}
...@@ -10,6 +10,24 @@ export async function createDispatcherInterface(): Promise<IpcInterface> { ...@@ -10,6 +10,24 @@ export async function createDispatcherInterface(): Promise<IpcInterface> {
class WsIpcInterface implements IpcInterface { class WsIpcInterface implements IpcInterface {
private channel: WebSocketChannel = getWebSocketChannel(); private channel: WebSocketChannel = getWebSocketChannel();
private commandListener?: (commandType: string, content: string) => void;
private errorListener?: (error: Error) => void;
constructor() {
this.channel.onCommand((command: string) => {
const commandType = command.slice(0, 2);
const content = command.slice(2);
if (commandType === 'ER') {
if (this.errorListener !== undefined) {
this.errorListener(new Error(content));
}
} else {
if (this.commandListener !== undefined) {
this.commandListener(commandType, content);
}
}
});
}
public async init(): Promise<void> { public async init(): Promise<void> {
await this.channel.init(); await this.channel.init();
...@@ -25,12 +43,10 @@ class WsIpcInterface implements IpcInterface { ...@@ -25,12 +43,10 @@ class WsIpcInterface implements IpcInterface {
} }
public onCommand(listener: (commandType: string, content: string) => void): void { public onCommand(listener: (commandType: string, content: string) => void): void {
this.channel.onCommand((command: string) => { this.commandListener = listener;
listener(command.slice(0, 2), command.slice(2));
});
} }
public onError(listener: (error: Error) => void): void { public onError(listener: (error: Error) => void): void {
this.channel.onError(listener); this.errorListener = listener;
} }
} }
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
import assert from 'assert/strict'; import assert from 'assert/strict';
import { EventEmitter } from 'events'; import { EventEmitter } from 'events';
import { Deferred } from 'ts-deferred';
import type WebSocket from 'ws'; import type WebSocket from 'ws';
import { Deferred } from 'common/deferred';
import { Logger, getLogger } from 'common/log'; import { Logger, getLogger } from 'common/log';
const logger: Logger = getLogger('tuner_command_channel.WebSocketChannel'); const logger: Logger = getLogger('tuner_command_channel.WebSocketChannel');
...@@ -38,46 +38,38 @@ export function getWebSocketChannel(): WebSocketChannel { ...@@ -38,46 +38,38 @@ export function getWebSocketChannel(): WebSocketChannel {
/** /**
* The callback to serve WebSocket connection request. Used by REST server module. * The callback to serve WebSocket connection request. Used by REST server module.
* It should only be invoked once, or an error will be raised. * If it is invoked more than once, the previous connection will be dropped.
*
* Type hint of express-ws is somewhat problematic. Don't want to waste time on it so use `any`.
**/ **/
export function serveWebSocket(ws: WebSocket): void { export function serveWebSocket(ws: WebSocket): void {
channelSingleton.setWebSocket(ws); channelSingleton.serveWebSocket(ws);
} }
class WebSocketChannelImpl implements WebSocketChannel { class WebSocketChannelImpl implements WebSocketChannel {
private deferredInit: Deferred<void> | null = new Deferred<void>(); private deferredInit: Deferred<void> = new Deferred<void>();
private emitter: EventEmitter = new EventEmitter(); private emitter: EventEmitter = new EventEmitter();
private heartbeatTimer!: NodeJS.Timer; private heartbeatTimer!: NodeJS.Timer;
private serving: boolean = false; private serving: boolean = false;
private waitingPong: boolean = false; private waitingPong: boolean = false;
private ws!: WebSocket; private ws!: WebSocket;
public setWebSocket(ws: WebSocket): void { public serveWebSocket(ws: WebSocket): void {
if (this.ws !== undefined) { if (this.ws === undefined) {
logger.error('A second client is trying to connect.'); logger.debug('Connected.');
ws.close(4030, 'Already serving a tuner'); } else {
return; logger.warning('Reconnecting. Drop previous connection.');
} this.dropConnection('Reconnected');
if (this.deferredInit === null) {
logger.error('Connection timed out.');
ws.close(4080, 'Timeout');
return;
} }
logger.debug('Connected.');
this.serving = true; this.serving = true;
this.ws = ws; this.ws = ws;
ws.on('close', () => { this.handleError(new Error('tuner_command_channel: Tuner closed connection')); }); this.ws.on('close', this.handleWsClose);
ws.on('error', this.handleError.bind(this)); this.ws.on('error', this.handleWsError);
ws.on('message', this.receive.bind(this)); this.ws.on('message', this.handleWsMessage);
ws.on('pong', () => { this.waitingPong = false; }); this.ws.on('pong', this.handleWsPong);
this.heartbeatTimer = setInterval(this.heartbeat.bind(this), heartbeatInterval); this.heartbeatTimer = setInterval(this.heartbeat.bind(this), heartbeatInterval);
this.deferredInit.resolve(); this.deferredInit.resolve();
this.deferredInit = null;
} }
public init(): Promise<void> { public init(): Promise<void> {
...@@ -85,13 +77,12 @@ class WebSocketChannelImpl implements WebSocketChannel { ...@@ -85,13 +77,12 @@ class WebSocketChannelImpl implements WebSocketChannel {
logger.debug('Waiting connection...'); logger.debug('Waiting connection...');
// TODO: This is a quick fix. It should check tuner's process status instead. // TODO: This is a quick fix. It should check tuner's process status instead.
setTimeout(() => { setTimeout(() => {
if (this.deferredInit !== null) { if (!this.deferredInit.settled) {
const msg = 'Tuner did not connect in 10 seconds. Please check tuner (dispatcher) log.'; const msg = 'Tuner did not connect in 10 seconds. Please check tuner (dispatcher) log.';
this.deferredInit.reject(new Error('tuner_command_channel: ' + msg)); this.deferredInit.reject(new Error('tuner_command_channel: ' + msg));
this.deferredInit = null;
} }
}, 10000); }, 10000);
return this.deferredInit!.promise; return this.deferredInit.promise;
} else { } else {
logger.debug('Initialized.'); logger.debug('Initialized.');
...@@ -127,6 +118,49 @@ class WebSocketChannelImpl implements WebSocketChannel { ...@@ -127,6 +118,49 @@ class WebSocketChannelImpl implements WebSocketChannel {
this.emitter.on('error', callback); this.emitter.on('error', callback);
} }
/* Following callbacks must be auto-binded arrow functions to be turned off */
private handleWsClose = (): void => {
this.handleError(new Error('tuner_command_channel: Tuner closed connection'));
}
private handleWsError = (error: Error): void => {
this.handleError(error);
}
private handleWsMessage = (data: Buffer, _isBinary: boolean): void => {
this.receive(data);
}
private handleWsPong = (): void => {
this.waitingPong = false;
}
private dropConnection(reason: string): void {
if (this.ws === undefined) {
return;
}
this.serving = false;
clearInterval(this.heartbeatTimer);
this.ws.off('close', this.handleWsClose);
this.ws.off('error', this.handleWsError);
this.ws.off('message', this.handleWsMessage);
this.ws.off('pong', this.handleWsPong);
this.ws.on('close', () => {
logger.info('Connection dropped');
});
this.ws.on('message', (data, _isBinary) => {
logger.error('Received message after reconnect:', data);
});
this.ws.on('pong', () => {
logger.error('Received pong after reconnect.');
});
this.ws.close(1001, reason);
}
private heartbeat(): void { private heartbeat(): void {
if (this.waitingPong) { if (this.waitingPong) {
this.ws.terminate(); // this will trigger "close" event this.ws.terminate(); // this will trigger "close" event
...@@ -137,7 +171,7 @@ class WebSocketChannelImpl implements WebSocketChannel { ...@@ -137,7 +171,7 @@ class WebSocketChannelImpl implements WebSocketChannel {
this.ws.ping(); this.ws.ping();
} }
private receive(data: Buffer, _isBinary: boolean): void { private receive(data: Buffer): void {
logger.debug('Received', data); logger.debug('Received', data);
this.emitter.emit('command', data.toString()); this.emitter.emit('command', data.toString());
} }
......
...@@ -68,12 +68,24 @@ async function testError(): Promise<void> { ...@@ -68,12 +68,24 @@ async function testError(): Promise<void> {
client.resume(); client.resume();
} }
// WebSocket might get broken in long experiments. Simulate reconnect.
async function testReconnect(): Promise<void> {
client.close();
startClient();
testInit();
testSend();
}
// Clean up. // Clean up.
async function testShutdown(): Promise<void> { async function testShutdown(): Promise<void> {
const channel = getWebSocketChannel(); const channel = getWebSocketChannel();
await channel.shutdown(); await channel.shutdown();
client.close(); try {
client.close();
} catch (error) {
console.log('Error on clean up:', error);
}
server.close(); server.close();
} }
...@@ -83,6 +95,7 @@ describe('## tuner_command_channel ##', () => { ...@@ -83,6 +95,7 @@ describe('## tuner_command_channel ##', () => {
it('send', testSend); it('send', testSend);
it('receive', testReceive); it('receive', testReceive);
it('catch error', testError); it('catch error', testError);
it('reconnect', testReconnect);
it('shutdown', testShutdown); it('shutdown', testShutdown);
}); });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment