test_utils.py 1.07 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from lmdeploy.pytorch.utils import BasicStreamer, TerminalIO


def test_terminal_io(monkeypatch):
    """Check that ``TerminalIO.input`` returns one line typed on stdin.

    ``sys.stdin`` is replaced with an in-memory stream containing ``hello``
    followed by an empty line. Presumably the empty line terminates the
    input read by ``TerminalIO.input`` (TODO confirm against its
    implementation). The result is echoed through ``TerminalIO.output``;
    that echo is not captured or asserted here.
    """
    import io
    tio = TerminalIO()
    monkeypatch.setattr('sys.stdin', io.StringIO('hello\n\n'))
    string = tio.input()
    tio.output(string)
    assert string == 'hello'
    # TODO(review): add a history-navigation case. Feed an up-arrow escape
    # sequence ('\x1B[A') after the first entry and expect a second
    # tio.input() to return 'hello' again. A commented-out draft of this
    # check previously lived here.


def test_basic_streamer():
    """Check what BasicStreamer emits with and without prompt skipping.

    Each run feeds tokens 0..9, calling ``end()`` right after token 5 and
    once more after the last token. The decode callback adds 10 to a token,
    and the output callback records everything the streamer emits.
    """

    def stream_tokens(**streamer_kwargs):
        # Drive a fresh streamer and return the list of emitted items.
        emitted = []
        streamer = BasicStreamer(lambda token: token + 10, emitted.append,
                                 **streamer_kwargs)
        for token in range(10):
            streamer.put(token)
            if token == 5:
                streamer.end()
        streamer.end()
        return emitted

    # Default settings: the expected list omits the first decoded value of
    # each segment (10 and 16), and every end() contributes a '\n'.
    assert stream_tokens() == [11, 12, 13, 14, 15, '\n', 17, 18, 19, '\n']

    # skip_prompt=False: every decoded value is expected to be emitted.
    assert stream_tokens(skip_prompt=False) == [
        10, 11, 12, 13, 14, 15, '\n', 16, 17, 18, 19, '\n'
    ]