client.py 2.28 KB
Newer Older
Rayyyyy's avatar
Rayyyyy committed
1
2
3
import json
import argparse
import requests
chenych's avatar
chenych committed
4
5
import configparser
from typing import Iterable, List
Rayyyyy's avatar
Rayyyyy committed
6
7


chenych's avatar
chenych committed
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
    """Yield the "text" field of every JSON message in a streamed response.

    The body is split on NUL bytes (``b"\\0"``); each non-empty piece is
    decoded as UTF-8 and parsed as a JSON object.

    Args:
        response: The HTTP response to read from.

    Yields:
        The value stored under the ``"text"`` key of each message
        (annotated as a list of strings).
    """
    pieces = response.iter_lines(
        chunk_size=1024, decode_unicode=False, delimiter=b"\0")
    for piece in pieces:
        # Empty pieces carry no message; skip them.
        if not piece:
            continue
        yield json.loads(piece.decode("utf-8"))["text"]


def get_response(response: requests.Response) -> List[str]:
    """Return the "text" field of a non-streamed JSON response.

    Args:
        response: The HTTP response whose body is a UTF-8 encoded JSON object.

    Returns:
        The value stored under the ``"text"`` key (annotated as a list of
        strings).
    """
    return json.loads(response.content.decode("utf-8"))["text"]


def clear_line(n: int = 1) -> None:
    """Move the terminal cursor up and erase the line, ``n`` times.

    Each iteration writes the escape sequence ``ESC[1A`` followed by
    ``ESC[2K`` to standard output and flushes it.

    Args:
        n: Number of times to emit the sequence; nothing is written when
            ``n`` is zero or negative.
    """
    cursor_up = '\033[1A'
    erase_line = '\x1b[2K'
    for _ in range(n):
        print(cursor_up + erase_line, end='', flush=True)


if __name__ == "__main__":
    # Command-line client: sends one query to the local inference HTTP
    # service and prints the returned candidates.
    parser = argparse.ArgumentParser()
    parser.add_argument('--query', default='请写一首诗')
    parser.add_argument('--use_hf', action='store_true')
    parser.add_argument(
        '--config_path', default='../config.ini', help='config目录')
    args = parser.parse_args()

    print(args.query)

    config = configparser.ConfigParser()
    # ConfigParser.read() silently skips files it cannot open and returns the
    # list of files it did parse; report a missing config up front instead of
    # failing later with an unrelated NoSectionError.
    if not config.read(args.config_path):
        parser.error(f"cannot read config file: {args.config_path}")
    stream_chat = config.getboolean('llm', 'stream_chat')

    # Endpoint selection. Streaming takes precedence over --use_hf, so the
    # flag only has an effect when stream_chat is disabled.
    if stream_chat:
        func = 'vllm_inference_stream'
    elif args.use_hf:
        func = 'hf_inference'
    else:
        func = 'vllm_inference'
    api_url = f"http://localhost:8888/{func}"

    headers = {"Content-Type": "application/json"}
    payload = json.dumps({"query": args.query, "history": []}).encode("utf-8")

    # A single request serves both modes; the context manager releases the
    # connection even if printing or parsing fails part-way through.
    with requests.get(api_url, headers=headers, data=payload,
                      verify=False, stream=stream_chat) as response:
        # Surface HTTP errors directly rather than as a JSON/KeyError below.
        response.raise_for_status()

        if stream_chat:
            # Each streamed message replaces the previously printed block:
            # erase the lines from the last message, then print the new ones.
            num_printed_lines = 0
            for candidates in get_streaming_response(response):
                clear_line(num_printed_lines)
                num_printed_lines = 0
                for i, line in enumerate(candidates):
                    num_printed_lines += 1
                    print(f"Beam candidate {i}: {line!r}", flush=True)
        else:
            for i, line in enumerate(get_response(response)):
                print(f"Beam candidate {i}: {line!r}", flush=True)