# Copyright (c) OpenMMLab. All rights reserved.
import os

import fire

from lmdeploy.serve.turbomind.chatbot import Chatbot


def input_prompt(reader=None):
    """Input a prompt in the console interface.

    Lines are read one at a time until an empty line is entered (i.e. the
    user presses Enter twice). The collected lines are joined with ``'\\n'``
    and returned.

    Args:
        reader (callable, optional): zero-argument function returning the
          next line of input without its trailing newline. Defaults to the
          builtin ``input``; exposed so the function can be driven without
          an interactive console.

    Returns:
        str: the multi-line prompt, or ``''`` when the very first line read
          is empty.

    Raises:
        EOFError: propagated from ``reader`` (e.g. builtin ``input``) when
          the input stream is closed before the terminating empty line.
    """
    if reader is None:
        # resolved at call time so later patches of builtins.input apply
        reader = input
    print('\ndouble enter to end input >>> ', end='')
    sentinel = ''  # ends when this string is seen
    # iter(callable, sentinel) keeps calling reader() until it returns ''
    return '\n'.join(iter(reader, sentinel))


def main(tritonserver_addr: str,
         session_id: int = 1,
         stream_output: bool = True,
         request_output_len: int = 512):
    """An example to communicate with inference server through the command line
    interface.

    Type a prompt and press Enter twice to send it. Two control words are
    handled locally instead of being sent to the model: ``exit`` quits the
    program, and ``end`` calls ``chatbot.end(session_id)``. Closing the
    input stream (e.g. Ctrl-D) also quits.

    Args:
        tritonserver_addr (str): the address in format "ip:port" of
          triton inference server
        session_id (int): the id of the chat session used for every request
        stream_output (bool): indicator for streaming output or not
        request_output_len (int): value passed as ``request_output_len`` to
          every inference request (presumably the maximum number of tokens
          to generate per reply -- confirm against ``Chatbot``)
    """
    log_level = os.environ.get('SERVICE_LOG_LEVEL', 'WARNING')
    # NOTE(review): with display=True the chatbot presumably renders the
    # streamed tokens itself, which is why the streaming branch below
    # prints nothing -- confirm against Chatbot.
    chatbot = Chatbot(tritonserver_addr,
                      log_level=log_level,
                      display=stream_output)
    nth_round = 1
    while True:
        try:
            prompt = input_prompt()
        except EOFError:
            # stdin was closed: quit the same way the 'exit' command does
            # instead of crashing with a traceback.
            raise SystemExit(0)
        if prompt == 'exit':
            # The `exit` builtin is injected by the `site` module and is not
            # guaranteed to exist; SystemExit is what it raises anyway.
            raise SystemExit(0)
        elif prompt == 'end':
            chatbot.end(session_id)
        else:
            # distinct request id per round within this session
            request_id = f'{session_id}-{nth_round}'
            if stream_output:
                # Drain the generator; each item unpacks to
                # (status, res, n_token) but none of them is used here.
                for _, _, _ in chatbot.stream_infer(
                        session_id,
                        prompt,
                        request_id=request_id,
                        request_output_len=request_output_len):
                    continue
            else:
                _, res, _ = chatbot.infer(
                    session_id,
                    prompt,
                    request_id=request_id,
                    request_output_len=request_output_len)
                print(res)
        nth_round += 1


if __name__ == '__main__':
    # Script entry point: python-fire builds the command-line interface
    # from main's signature.
    fire.Fire(component=main)