test_tuner_command_channel.py 2.07 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import atexit
from dataclasses import dataclass
import importlib
import json
import os
from pathlib import Path
from subprocess import Popen, PIPE
import sys
import time

from nni.runtime.tuner_command_channel.websocket import WebSocket

# A helper server that connects its stdio to incoming WebSocket.
# Assigned the Popen handle by _init(); reset to None by test_disconnect().
_server = None
# The WebSocket client under test; created by test_connect().
_client = None

# Commands sent through the channel in both directions by test_send() and
# test_receive(). The second one contains non-ASCII characters.
_command1 = 'T_hello world'
_command2 = 'T_你好'

## test cases ##

def test_connect():
    """Start the helper server and connect the shared client to it.

    The client is stored in the module-level ``_client`` so that the
    following test cases can reuse the same connection.
    """
    global _client
    port = _init()
    url = f'ws://localhost:{port}'
    # Publish the client before connecting, so the global is set even if
    # connect() raises.
    _client = WebSocket(url)
    _client.connect()

def test_send():
    """Send both commands over the channel and read them back from server stdout."""
    commands = (_command1, _command2)

    for command in commands:
        _client.send(command)
    time.sleep(0.01)

    # Each command is expected back as one line of the server's stdout,
    # in the order it was sent.
    for expected in commands:
        echoed = _server.stdout.readline().strip()
        assert echoed == expected, echoed

def test_receive():
    """Write both commands to server stdin and expect them back from the channel."""
    commands = (_command1, _command2)

    # One command per line; flush once so the server sees both.
    server_input = _server.stdin
    for command in commands:
        server_input.write(command + '\n')
    server_input.flush()

    # The channel should deliver them in the order they were written.
    for expected in commands:
        incoming = _client.receive()
        assert incoming == expected, incoming

def test_disconnect():
    """Disconnect the client, then stop and reap the helper server.

    The server is asked to close via its stdin, given a moment to do so,
    and then terminated. Unlike a bare ``terminate()``, the process is
    also waited on (so it does not linger as a zombie or trigger a
    "subprocess is still running" ResourceWarning) and its pipes are
    closed. The cleanup runs even when disconnecting or writing to the
    server fails, and ``_server`` is always reset to ``None`` afterwards.
    """
    global _server
    # Local import: the file's top-level import only brings in Popen and PIPE.
    from subprocess import TimeoutExpired

    try:
        _client.disconnect()

        # release the port
        _server.stdin.write('_close_\n')
        _server.stdin.flush()
        time.sleep(0.1)
    finally:
        # Clear the global first so the atexit hook registered by _init()
        # becomes a no-op.
        server, _server = _server, None
        if server is not None:
            server.terminate()
            try:
                server.wait(timeout=5)
            except TimeoutExpired:
                # The server ignored the termination request; force it.
                server.kill()
                server.wait()
            for pipe in (server.stdin, server.stdout):
                try:
                    pipe.close()
                except OSError:
                    # Closing stdin may re-flush into a dead process
                    # (BrokenPipeError); nothing left to recover here.
                    pass

## helper ##

def _init():
    """Launch the helper WebSocket server and return the port it reports.

    The server process is stored in the module-level ``_server``. Its first
    line of stdout is parsed as an integer and returned.
    """
    global _server

    def _stop_server():
        # Does nothing once test_disconnect() has reset the handle to None.
        if _server is not None:
            _server.terminate()

    # launch a server that connects websocket to stdio
    helper_path = (Path(__file__).parent / 'helper/websocket_server.py').resolve()
    argv = [sys.executable, str(helper_path)]
    _server = Popen(argv, stdin=PIPE, stdout=PIPE, encoding='utf_8')
    time.sleep(0.1)

    # if a test fails, make sure to stop the server
    atexit.register(_stop_server)

    port_line = _server.stdout.readline()
    return int(port_line.strip())

if __name__ == '__main__':
    # The cases share the module-level client and server, so they must run
    # in this order: connect, send, receive, disconnect.
    for _case in (test_connect, test_send, test_receive, test_disconnect):
        _case()
    print('pass')