client.py 2.73 KB
Newer Older
chenych's avatar
chenych committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import json
import requests
import argparse
import re


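'''
Usage examples:
Query the public knowledge base:   python client.py --query 'question'
Query a private knowledge base:    python client.py --query 'question' --user_id 'user_id'
Stream the answer as it arrives:   python client.py --query 'question' --stream
'''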

# URL template for the service; the single %s placeholder is filled with the
# route name ('work' in query(), 'stream' in stream_query()).
_HOST, _PORT = '127.0.0.1', 8000
base_url = 'http://{}:{}/%s'.format(_HOST, _PORT)


def query(query, user_id=None):
    """Send a question to the ``work`` endpoint and return the full answer.

    Args:
        query: The question text to send.
        user_id: Optional user identifier. When truthy it is added to the
            request body (the usage notes associate it with querying a
            private knowledge base).

    Returns:
        A ``(reply, references)`` tuple read from the JSON response body.
        On any failure (network error, non-200 status, unexpected body) the
        error is printed and ``('', [])`` is returned, so callers can always
        unpack two values.
    """
    url = base_url % 'work'
    header = {'Content-Type': 'application/json'}
    # History is always sent empty: this client keeps no conversation state.
    data = {
        'query': query,
        'history': []
    }
    if user_id:
        data['user_id'] = user_id
    try:
        resp = requests.post(url,
                             headers=header,
                             data=json.dumps(data),
                             timeout=300)
        if resp.status_code != 200:
            raise Exception(str((resp.status_code, resp.reason)))
        # Parse the body once instead of once per field.
        body = resp.json()
        return body['reply'], body['references']
    except Exception as e:
        # Deliberate best-effort at the CLI boundary: report and keep the
        # (reply, references) shape so the caller's unpacking does not crash.
        print(str(e))
        return '', []


def get_streaming_response(response: 'requests.Response'):
    """Print a streamed answer as it arrives and return the decoded text.

    The response body is split on NUL bytes (``delimiter=b"\\0"``). In each
    non-empty chunk, every ``data: "<JSON string literal>"`` fragment is
    decoded with ``json.loads`` and printed immediately. This covers
    ``\\uXXXX`` escapes, plain ASCII, multi-character fragments and
    surrogate pairs alike. Fragments that are not quoted JSON strings
    (e.g. ``data: [DONE]``) or that fail to decode are skipped.

    Args:
        response: An object exposing ``iter_lines`` like
            ``requests.Response``, typically opened with ``stream=True``.

    Returns:
        The concatenation of all decoded fragments, in arrival order.
    """
    # Compiled once per call rather than once per chunk.
    # Captures a complete JSON string literal, honoring backslash escapes.
    string_literal = re.compile(rb'data: ("(?:[^"\\]|\\.)*")')
    pieces = []
    for chunk in response.iter_lines(chunk_size=1024, decode_unicode=False,
                                     delimiter=b"\0"):
        if not chunk:
            continue
        for literal in string_literal.findall(chunk):
            try:
                text = json.loads(literal)
            except ValueError:  # includes JSONDecodeError / UnicodeDecodeError
                continue
            pieces.append(text)
            # Flush per fragment so the answer appears incrementally.
            print(text, end="", flush=True)
    return ''.join(pieces)


def stream_query(query, user_id=None):
    """Send a question to the ``stream`` endpoint and print the answer as it arrives.

    Args:
        query: The question text to send.
        user_id: Optional user identifier. When truthy it is added to the
            request body, mirroring ``query()``. Defaults to ``None`` so
            existing one-argument calls keep working; the command-line entry
            point passes it as a second argument.

    Errors (network failure, non-200 status) are printed rather than raised,
    matching the best-effort style of ``query()``.
    """
    url = base_url % 'stream'
    headers = {
        # NOTE(review): Content-Type describes the request body, which is
        # JSON here; "text/event-stream" would normally go in an Accept
        # header. Left unchanged because the server's expectations are not
        # visible from this file.
        "Content-Type": "text/event-stream",
        "Cache-Control": "no-cache",
        "Connection": "keep-alive"
    }
    # History is always sent empty: this client keeps no conversation state.
    data = {
        'query': query,
        'history': []
    }
    if user_id:
        data['user_id'] = user_id
    try:
        # With stream=True the connection stays checked out until the body is
        # consumed or the response is closed; the with-block guarantees it is
        # released even if decoding fails part-way.
        # NOTE(review): verify=False disables TLS certificate checks if
        # base_url is ever switched to https.
        with requests.get(url,
                          headers=headers,
                          data=json.dumps(data),
                          timeout=300,
                          verify=False,
                          stream=True) as resp:
            if resp.status_code != 200:
                raise Exception(str((resp.status_code, resp.reason)))
            get_streaming_response(resp)
    except Exception as e:
        print(str(e))


def parse_args(argv=None):
    """Parse the client's command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        An ``argparse.Namespace`` with ``query`` (str, default
        ``'your query'``), ``user_id`` (str, empty when not given) and
        ``stream`` (bool).
    """
    parser = argparse.ArgumentParser(
        description='Send a question to the local knowledge-base service.')
    parser.add_argument('--query',
                        default='your query',
                        help='question text to send')
    parser.add_argument('--user_id',
                        default='',
                        help='user id; when set it is included in the request '
                             '(private knowledge base)')
    parser.add_argument('--stream',
                        action='store_true',
                        help='use the streaming endpoint and print the answer '
                             'as it arrives')
    return parser.parse_args(argv)


if __name__ == '__main__':
    # Command-line entry point: --stream prints the answer incrementally;
    # otherwise the blocking endpoint's reply and references are printed once.
    cli_args = parse_args()
    if not cli_args.stream:
        answer, references = query(cli_args.query, cli_args.user_id)
        print('reply: {} \nref: {} '.format(answer, references))
    else:
        stream_query(cli_args.query, cli_args.user_id)