Unverified Commit 54fef7fa authored by SparkSnail's avatar SparkSnail Committed by GitHub
Browse files

Add timeout for web_channel in trial_runner (#2710)

parent e5b58531
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import asyncio import asyncio
import os
import websockets import websockets
from .base_channel import BaseChannel from .base_channel import BaseChannel
...@@ -16,6 +16,7 @@ class WebChannel(BaseChannel): ...@@ -16,6 +16,7 @@ class WebChannel(BaseChannel):
self.args = args self.args = args
self.client = None self.client = None
self.in_cache = b"" self.in_cache = b""
self.timeout = 10
super(WebChannel, self).__init__(args) super(WebChannel, self).__init__(args)
...@@ -23,12 +24,15 @@ class WebChannel(BaseChannel): ...@@ -23,12 +24,15 @@ class WebChannel(BaseChannel):
def _inner_open(self): def _inner_open(self):
url = "ws://{}:{}".format(self.args.nnimanager_ip, self.args.nnimanager_port) url = "ws://{}:{}".format(self.args.nnimanager_ip, self.args.nnimanager_port)
nni_log(LogType.Info, 'WebChannel: connected with info %s' % url) try:
connect = asyncio.wait_for(websockets.connect(url), self.timeout)
connect = websockets.connect(url)
self._event_loop = asyncio.get_event_loop() self._event_loop = asyncio.get_event_loop()
client = self._event_loop.run_until_complete(connect) client = self._event_loop.run_until_complete(connect)
self.client = client self.client = client
nni_log(LogType.Info, 'WebChannel: connected with info %s' % url)
except asyncio.TimeoutError:
nni_log(LogType.Error, 'connect to %s timeout! Please make sure NNIManagerIP configured correctly, and accessable.' % url)
os._exit(1)
def _inner_close(self): def _inner_close(self):
if self.client is not None: if self.client is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment