Commit a74fa40d (unverified), authored by liuzhe-lz, committed by GitHub
Browse files

Let tuner auto reconnect to NNI manager (#5166)

parent cbd5d8be
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import argparse
import logging
import json
import base64
import json
import logging
import os
import traceback
from .runtime.msg_dispatcher import MsgDispatcher
from .runtime.msg_dispatcher_base import MsgDispatcherBase
......@@ -65,6 +66,7 @@ def main():
tuner._on_error()
if assessor is not None:
assessor._on_error()
dispatcher.report_error(traceback.format_exc())
raise
......
......@@ -84,6 +84,16 @@ class MsgDispatcherBase(Recoverable):
_logger.info('Dispatcher terminiated')
def report_error(self, error: str) -> None:
    '''
    Report dispatcher error to NNI manager.

    This is best effort: if the error command cannot be sent, the failure is
    logged (with its traceback) and swallowed, so that the caller can go on
    re-raising the original error.

    Parameters
    ----------
    error
        Text describing the error; it is sent as the payload of a
        ``CommandType.Error`` command.
    '''
    _logger.info(f'Report error to NNI manager: {error}')
    try:
        self.send(CommandType.Error, error)
    except Exception:
        # Keep the cause of the failure in the log. Without it, any problem
        # would be indistinguishable from a broken connection.
        _logger.error('Connection to NNI manager is broken. Failed to report error.', exc_info=True)
# Send a command to NNI manager by delegating to the command channel's `_send`.
# `command` is forwarded as the command type and `data` as its payload.
def send(self, command, data):
self._channel._send(command, data)
......
......@@ -9,9 +9,14 @@ from __future__ import annotations
__all__ = ['TunerCommandChannel']
import logging
import time
from .command_type import CommandType
from .websocket import WebSocket
_logger = logging.getLogger(__name__)
class TunerCommandChannel:
"""
A channel to communicate with NNI manager.
......@@ -35,7 +40,9 @@ class TunerCommandChannel:
"""
# Store the URL and create the underlying WebSocket wrapper for it.
def __init__(self, url: str):
# kept so that a new WebSocket can be created for the same URL when reconnecting
self._url = url
self._channel = WebSocket(url)
# seconds to sleep before each successive reconnect attempt
self._retry_intervals = [0, 1, 10]
# Open the connection by delegating to the underlying WebSocket object.
def connect(self) -> None:
self._channel.connect()
......@@ -51,11 +58,50 @@ class TunerCommandChannel:
def _send(self, command_type: CommandType, data: str) -> None:
    """
    Send one command, reconnecting if the connection has been closed.

    The wire format is the decoded command type value followed directly by ``data``.

    Raises
    ------
    RuntimeError
        Propagated from ``_retry_send`` when all reconnect attempts fail.
    """
    command = command_type.value.decode() + data
    # Send exactly once, inside the guard, so that a closed connection leads
    # to the retry path instead of escaping to the caller.
    try:
        self._channel.send(command)
    except WebSocket.ConnectionClosed:
        self._retry_send(command)
def _retry_send(self, command: str) -> None:
    """
    Re-establish the connection and send ``command`` again.

    One attempt is made per entry of ``self._retry_intervals``, sleeping that
    many seconds first. Each attempt replaces ``self._channel`` with a new
    ``WebSocket`` for the same URL.

    Raises
    ------
    RuntimeError
        If no attempt succeeds.
    """
    _logger.warning('Connection lost. Trying to reconnect...')
    for i, interval in enumerate(self._retry_intervals):
        _logger.info(f'Attempt #{i}, wait {interval} seconds...')
        time.sleep(interval)
        self._channel = WebSocket(self._url)
        try:
            # NOTE(review): WebSocket.__init__ as visible only stores the URL and
            # leaves the underlying socket unset, so the new object has to be
            # connected before it can send. Confirm there is no implicit connect.
            self._channel.connect()
            self._channel.send(command)
        except Exception as e:
            # Any failure only costs this attempt; the next interval is tried.
            _logger.exception(e)
        else:
            _logger.info('Reconnected.')
            return
    _logger.error('Failed to reconnect.')
    raise RuntimeError('Connection lost')
def _receive(self) -> tuple[CommandType, str] | tuple[None, None]:
    """
    Receive one command, reconnecting if the connection has been closed.

    Returns
    -------
    The command type (parsed from the first two characters of the message)
    and the remaining characters as its payload.

    Raises
    ------
    RuntimeError
        Propagated from ``_retry_receive`` when all reconnect attempts fail.
    """
    # Receive exactly once: a second call would silently discard a message.
    try:
        command = self._channel.receive()
    except WebSocket.ConnectionClosed:
        # A closed connection may be signalled by this exception or by a
        # ``None`` result; both are funnelled into the reconnect path below.
        _logger.warning('ConnectionClosed exception on receiving.')
        command = None
    if command is None:
        command = self._retry_receive()
    command_type = CommandType(command[:2].encode())
    return command_type, command[2:]
def _retry_receive(self) -> str:
    """
    Re-establish the connection and receive the next message.

    One attempt is made per entry of ``self._retry_intervals``, sleeping that
    many seconds first. Each attempt replaces ``self._channel`` with a new
    ``WebSocket`` for the same URL.

    Returns
    -------
    The first message received on the new connection.

    Raises
    ------
    RuntimeError
        If no attempt yields a message.
    """
    _logger.warning('Connection lost. Trying to reconnect...')
    for i, interval in enumerate(self._retry_intervals):
        _logger.info(f'Attempt #{i}, wait {interval} seconds...')
        time.sleep(interval)
        self._channel = WebSocket(self._url)
        try:
            # NOTE(review): WebSocket.__init__ as visible only stores the URL and
            # leaves the underlying socket unset, so the new object has to be
            # connected before it can receive. Confirm there is no implicit connect.
            self._channel.connect()
            command = self._channel.receive()
        except Exception as e:
            # Same policy as _retry_send: a failed attempt (connection refused,
            # closed again, ...) must not abort the remaining attempts.
            _logger.exception(e)
            command = None
        if command is not None:
            _logger.info('Reconnected')
            return command
    _logger.error('Failed to reconnect.')
    raise RuntimeError('Connection lost')
......@@ -21,3 +21,4 @@ class CommandType(Enum):
SendTrialJobParameter = b'SP'
NoMoreTrialJobs = b'NO'
KillTrialJob = b'KI'
Error = b'ER'
......@@ -14,7 +14,7 @@ __all__ = ['WebSocket']
import asyncio
import logging
from threading import Lock, Thread
from typing import Any
from typing import Any, Type
import websockets
......@@ -39,6 +39,9 @@ class WebSocket:
The WebSocket URL.
For tuner command channel it should be something like ``ws://localhost:8080/tuner``.
"""
ConnectionClosed: Type[Exception] = websockets.ConnectionClosed # type: ignore
def __init__(self, url: str):
self._url: str = url
self._ws: Any = None # the library does not provide type hints
......@@ -74,7 +77,13 @@ class WebSocket:
def send(self, message: str) -> None:
    """
    Send a text message over the connection.

    Raises
    ------
    websockets.ConnectionClosed
        If the connection has been closed. Before re-raising, the socket
        reference is cleared and ``_decrease_refcnt()`` is called, so this
        object must not be used for further traffic.
    """
    _logger.debug(f'Sending {message}')
    # Send exactly once, inside the guard, so that the clean-up below always
    # runs when the connection turns out to be closed.
    try:
        _wait(self._ws.send(message))
    except websockets.ConnectionClosed:  # type: ignore
        _logger.debug('Connection closed by server.')
        self._ws = None
        _decrease_refcnt()
        raise
def receive(self) -> str | None:
"""
......@@ -88,7 +97,7 @@ class WebSocket:
_logger.debug('Connection closed by server.')
self._ws = None
_decrease_refcnt()
return None
raise
# seems the library will inference whether it's text or binary, so we don't have guarantee
if isinstance(msg, bytes):
......
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
/**
* TODO: Back ported from 3.0 draft.
*
* An augmented version of ts-deferred.
*
* You can `await deferred.promise` more than once and they will be resolved together.
*
* You can resolve a deferred multiple times with identical value and it will be ignored.
*
* If a deferred is resolved and/or rejected with conflict values,
* it will throw error and log both values or reasons.
**/
import util from 'util';
import { Logger, getLogger } from 'common/log';
const logger = getLogger('common.deferred');
/**
 *  A promise that is settled from outside through `resolve()` / `reject()`.
 *
 *  - `promise` may be read any number of times; all returned promises settle together.
 *  - Resolving again with the identical value (`===`) is ignored with a debug log.
 *  - Rejecting an already rejected deferred is ignored with a warning.
 *  - Any other second settlement is a conflict: it is logged and an Error is thrown.
 **/
export class Deferred<T> {
    // Settle functions of the promises handed out before this deferred was settled.
    private resolveCallbacks: ((value: T) => void)[] = [];
    private rejectCallbacks: ((reason: Error) => void)[] = [];
    private isResolved: boolean = false;
    private isRejected: boolean = false;
    private resolvedValue?: T;
    private rejectedReason?: Error;

    /** A promise following this deferred. Already settled if the deferred is. */
    public get promise(): Promise<T> {
        // use getter to compat ts-deferred
        if (this.isResolved) {
            return Promise.resolve(this.resolvedValue) as Promise<T>;
        }
        if (this.isRejected) {
            return Promise.reject(this.rejectedReason) as Promise<T>;
        }
        return new Promise<T>((resolutionFunc, rejectionFunc) => {
            this.resolveCallbacks.push(resolutionFunc);
            this.rejectCallbacks.push(rejectionFunc);
        });
    }

    /** Whether the deferred has been resolved or rejected. */
    public get settled(): boolean {
        // use getter for consistent api style
        return this.isResolved || this.isRejected;
    }

    /**
     *  Resolve all pending and future promises with `value`.
     *  Throws if the deferred was rejected, or resolved with a different value.
     **/
    public resolve = (value: T): void => {
        if (!this.isResolved && !this.isRejected) {
            this.isResolved = true;
            this.resolvedValue = value;
            const callbacks = this.resolveCallbacks;
            this.releaseCallbacks();
            for (const callback of callbacks) {
                callback(value);
            }
        } else if (this.isResolved && this.resolvedValue === value) {
            // Strict equality: only the identical value counts as a harmless repeat.
            logger.debug('Double resolve:', value);
        } else {
            const msg = this.errorMessage('trying to resolve with value: ' + util.inspect(value));
            logger.error(msg);
            throw new Error('Conflict Deferred result. ' + msg);
        }
    }

    /**
     *  Reject all pending and future promises with `reason`.
     *  Throws if the deferred was already resolved.
     **/
    public reject = (reason: Error): void => {
        if (!this.isResolved && !this.isRejected) {
            this.isRejected = true;
            this.rejectedReason = reason;
            const callbacks = this.rejectCallbacks;
            this.releaseCallbacks();
            for (const callback of callbacks) {
                callback(reason);
            }
        } else if (this.isRejected) {
            logger.warning('Double reject:', this.rejectedReason, reason);
        } else {
            const msg = this.errorMessage('trying to reject with reason: ' + util.inspect(reason));
            logger.error(msg);
            throw new Error('Conflict Deferred result. ' + msg);
        }
    }

    /** Once settled, no pending callback can be needed again; let them be collected. */
    private releaseCallbacks(): void {
        this.resolveCallbacks = [];
        this.rejectCallbacks = [];
    }

    /** Describe the conflict between the existing settlement and the attempted one. */
    private errorMessage(curStat: string): string {
        let prevStat = '';
        if (this.isResolved) {
            prevStat = 'Already resolved with value: ' + util.inspect(this.resolvedValue);
        }
        if (this.isRejected) {
            prevStat = 'Already rejected with reason: ' + util.inspect(this.rejectedReason);
        }
        return prevStat + ' ; ' + curStat;
    }
}
......@@ -10,6 +10,24 @@ export async function createDispatcherInterface(): Promise<IpcInterface> {
class WsIpcInterface implements IpcInterface {
private channel: WebSocketChannel = getWebSocketChannel();
private commandListener?: (commandType: string, content: string) => void;
private errorListener?: (error: Error) => void;
constructor() {
    // Every message starts with a two-character command type; the rest is its content.
    // 'ER' messages go to the error listener, everything else to the command listener.
    // Messages arriving before the matching listener is registered are dropped.
    const dispatch = (raw: string): void => {
        const tag = raw.slice(0, 2);
        const payload = raw.slice(2);
        if (tag !== 'ER') {
            if (this.commandListener !== undefined) {
                this.commandListener(tag, payload);
            }
            return;
        }
        if (this.errorListener !== undefined) {
            this.errorListener(new Error(payload));
        }
    };
    this.channel.onCommand(dispatch);
}
public async init(): Promise<void> {
await this.channel.init();
......@@ -25,12 +43,10 @@ class WsIpcInterface implements IpcInterface {
}
/**
 *  Register the listener for commands coming from the tuner.
 *
 *  The listener is only stored here. The constructor's channel subscription splits
 *  each message and calls it, so subscribing to the channel again would deliver
 *  every command twice (and hand 'ER' messages to the command listener as well).
 **/
public onCommand(listener: (commandType: string, content: string) => void): void {
    this.commandListener = listener;
}
/**
 *  Register the listener for errors reported by the tuner ('ER' messages).
 *
 *  The listener is only stored here; the constructor's channel subscription calls it.
 *  NOTE(review): channel-level errors are deliberately not forwarded any more,
 *  presumably because a closed connection may now be followed by a reconnect.
 *  Confirm that nothing relies on being told about connection errors.
 **/
public onError(listener: (error: Error) => void): void {
    this.errorListener = listener;
}
}
......@@ -13,9 +13,9 @@
import assert from 'assert/strict';
import { EventEmitter } from 'events';
import { Deferred } from 'ts-deferred';
import type WebSocket from 'ws';
import { Deferred } from 'common/deferred';
import { Logger, getLogger } from 'common/log';
const logger: Logger = getLogger('tuner_command_channel.WebSocketChannel');
......@@ -38,46 +38,38 @@ export function getWebSocketChannel(): WebSocketChannel {
/**
 *  The callback to serve WebSocket connection request. Used by REST server module.
 *
 *  If it is invoked more than once, the previous connection will be dropped.
 **/
export function serveWebSocket(ws: WebSocket): void {
    // Hand the socket over exactly once; registering it twice would make the
    // channel treat the second registration as a reconnect.
    channelSingleton.serveWebSocket(ws);
}
class WebSocketChannelImpl implements WebSocketChannel {
// Resolved when a tuner connects; rejected when none connects before init() times out.
private deferredInit: Deferred<void> = new Deferred<void>();
// Carries the 'command' and 'error' events to registered callbacks.
private emitter: EventEmitter = new EventEmitter();
// Handle of the heartbeat interval belonging to the current connection.
private heartbeatTimer!: NodeJS.Timer;
// True while a connection is being served.
private serving: boolean = false;
// True while a ping is outstanding; cleared when the pong arrives.
private waitingPong: boolean = false;
// The current connection; undefined until the first tuner connects.
private ws!: WebSocket;
/**
 *  Start serving a tuner connection.
 *
 *  The first call marks the channel as connected. A later call with a different
 *  socket is treated as a reconnect: the previous connection is dropped and the
 *  new one takes its place.
 **/
public serveWebSocket(ws: WebSocket): void {
    if (this.ws === ws) {
        // Same socket registered again. Running the set-up a second time would
        // drop the connection that has just been accepted.
        return;
    }

    if (this.ws === undefined) {
        logger.debug('Connected.');
    } else {
        logger.warning('Reconnecting. Drop previous connection.');
        this.dropConnection('Reconnected');
    }

    this.serving = true;
    // A ping sent on the previous connection will never be answered on this one;
    // without the reset the first heartbeat could terminate the new connection.
    this.waitingPong = false;

    this.ws = ws;
    this.ws.on('close', this.handleWsClose);
    this.ws.on('error', this.handleWsError);
    this.ws.on('message', this.handleWsMessage);
    this.ws.on('pong', this.handleWsPong);

    this.heartbeatTimer = setInterval(this.heartbeat.bind(this), heartbeatInterval);

    this.deferredInit.resolve();
}

/**
 *  @deprecated Former name of `serveWebSocket()`, kept so existing callers still work.
 *  A second connection is no longer refused; it replaces the previous one.
 **/
public setWebSocket(ws: WebSocket): void {
    this.serveWebSocket(ws);
}
public init(): Promise<void> {
......@@ -85,13 +77,12 @@ class WebSocketChannelImpl implements WebSocketChannel {
logger.debug('Waiting connection...');
// TODO: This is a quick fix. It should check tuner's process status instead.
setTimeout(() => {
if (this.deferredInit !== null) {
if (!this.deferredInit.settled) {
const msg = 'Tuner did not connect in 10 seconds. Please check tuner (dispatcher) log.';
this.deferredInit.reject(new Error('tuner_command_channel: ' + msg));
this.deferredInit = null;
}
}, 10000);
return this.deferredInit!.promise;
return this.deferredInit.promise;
} else {
logger.debug('Initialized.');
......@@ -127,6 +118,49 @@ class WebSocketChannelImpl implements WebSocketChannel {
this.emitter.on('error', callback);
}
/* The following callbacks must be auto-bound arrow functions, so that the very same function objects registered with ws.on() can later be removed with ws.off(). */
// 'close' event: report the closed connection through handleError().
private handleWsClose = (): void => {
this.handleError(new Error('tuner_command_channel: Tuner closed connection'));
}
// 'error' event: pass the socket error on to handleError() unchanged.
private handleWsError = (error: Error): void => {
this.handleError(error);
}
// 'message' event: pass the payload to receive(); the binary flag is not used.
private handleWsMessage = (data: Buffer, _isBinary: boolean): void => {
this.receive(data);
}
// 'pong' event: the outstanding ping has been answered.
private handleWsPong = (): void => {
this.waitingPong = false;
}
/**
 *  Stop serving the current connection and close it with code 1001.
 *
 *  The regular handlers are detached first, so that nothing happening on the
 *  stale socket is reported as a channel event. Replacement handlers only log.
 *  Does nothing when there is no connection yet.
 **/
private dropConnection(reason: string): void {
    if (this.ws === undefined) {
        return;
    }
    const staleWs = this.ws;

    this.serving = false;
    clearInterval(this.heartbeatTimer);

    staleWs.off('close', this.handleWsClose);
    staleWs.off('error', this.handleWsError);
    staleWs.off('message', this.handleWsMessage);
    staleWs.off('pong', this.handleWsPong);

    staleWs.on('close', () => {
        logger.info('Connection dropped');
    });
    // An 'error' event without any listener is thrown by the emitter, so the
    // stale socket must keep one even though its errors no longer matter.
    staleWs.on('error', (error) => {
        logger.warning('Error on dropped connection:', error);
    });
    staleWs.on('message', (data, _isBinary) => {
        logger.error('Received message after reconnect:', data);
    });
    staleWs.on('pong', () => {
        logger.error('Received pong after reconnect.');
    });

    staleWs.close(1001, reason);
}
private heartbeat(): void {
if (this.waitingPong) {
this.ws.terminate(); // this will trigger "close" event
......@@ -137,7 +171,7 @@ class WebSocketChannelImpl implements WebSocketChannel {
this.ws.ping();
}
/** Log an incoming message and publish its text form as a 'command' event. */
private receive(data: Buffer): void {
    logger.debug('Received', data);
    this.emitter.emit('command', data.toString());
}
......
......@@ -68,12 +68,24 @@ async function testError(): Promise<void> {
client.resume();
}
// WebSocket might get broken in long experiments. Simulate reconnect.
async function testReconnect(): Promise<void> {
    // Close the current client and start a new one, then repeat the init and
    // send checks on the new connection.
    client.close();
    startClient();
    // Await the checks: otherwise this test case finishes before they run and
    // their failures would not be attributed to it.
    await testInit();
    await testSend();
}
// Clean up.
async function testShutdown(): Promise<void> {
    const channel = getWebSocketChannel();
    await channel.shutdown();
    // Close the client only once, inside the guard: a failure to close it is
    // logged and must not prevent the server from being closed.
    try {
        client.close();
    } catch (error) {
        console.log('Error on clean up:', error);
    }
    server.close();
}
......@@ -83,6 +95,7 @@ describe('## tuner_command_channel ##', () => {
it('send', testSend);
it('receive', testReceive);
it('catch error', testError);
it('reconnect', testReconnect);
it('shutdown', testShutdown);
});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment