# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from nni.runtime import protocol
from nni.runtime.protocol import CommandType, send, receive

from io import BytesIO
from unittest import TestCase, main


def _prepare_send():
    """Redirect the protocol's outgoing stream to a fresh in-memory buffer.

    Returns:
        The ``BytesIO`` just installed via ``protocol._set_out_file``; call
        ``getvalue()`` on it to inspect the bytes ``send`` wrote.
    """
    out_file = BytesIO()
    protocol._set_out_file(out_file)
    return out_file

def _prepare_receive(data):
    """Make the protocol read from an in-memory buffer pre-filled with *data*.

    Args:
        data: Raw encoded bytes for the next ``receive`` call to consume.
    """
    protocol._set_in_file(BytesIO(data))


class ProtocolTestCase(TestCase):
    """Byte-level tests for ``send`` / ``receive`` in ``nni.runtime.protocol``.

    The expected frames show the wire layout exercised here:
    - a two-letter command code (``TR`` for ``CommandType.NewTrialJob``,
      ``IN`` for ``CommandType.Initialize``);
    - a zero-padded decimal payload length;
    - the UTF-8 payload.

    The length counts encoded bytes, not characters. The ``*_zh`` cases pin
    this down: two CJK characters are 6 bytes.

    NOTE(review): each test swaps the protocol module's stream through the
    private ``_set_in_file`` / ``_set_out_file`` hooks and never restores it.
    Confirm that nothing else in the test run depends on the original streams.
    """

    def test_send_en(self):
        """An ASCII payload is framed with its length and written verbatim."""
        out_file = _prepare_send()
        send(CommandType.NewTrialJob, 'CONTENT')
        self.assertEqual(out_file.getvalue(), b'TR00000000000007CONTENT')

    def test_send_zh(self):
        """A non-ASCII payload is UTF-8 encoded; the length is its byte count."""
        out_file = _prepare_send()
        send(CommandType.NewTrialJob, '你好')
        self.assertEqual(out_file.getvalue(), 'TR00000000000006你好'.encode('utf8'))

    def test_receive_en(self):
        """An ASCII frame decodes to its command enum member and text payload."""
        _prepare_receive(b'IN00000000000005hello')
        command, data = receive()
        self.assertIs(command, CommandType.Initialize)
        self.assertEqual(data, 'hello')

    def test_receive_zh(self):
        """A frame whose length is a UTF-8 byte count decodes back to text."""
        _prepare_receive('IN00000000000006世界'.encode('utf8'))
        command, data = receive()
        self.assertIs(command, CommandType.Initialize)
        self.assertEqual(data, '世界')


if __name__ == '__main__':
    # Allow running this file directly; unittest.main collects the TestCase
    # classes defined in this module and runs them.
    main()