api_client.py 3.12 KB
Newer Older
AllentDan's avatar
AllentDan committed
1
2
3
4
5
6
7
# Copyright (c) OpenMMLab. All rights reserved.
import json
from typing import Iterable, List, Optional, Tuple

import requests


8
9
10
11
12
13
14
15
16
def get_model_list(api_url: str):
    """Fetch the ids of the models advertised at ``api_url``.

    A GET request is sent to ``api_url`` and the response body is parsed as
    JSON. The ``id`` of every entry found under the ``data`` key is
    collected; a missing ``data`` key is treated as an empty list.

    Args:
        api_url: URL that is queried with a GET request.

    Returns:
        A list with the ``id`` of each entry, or ``None`` when the response
        object has no ``text`` attribute.
    """
    response = requests.get(api_url)
    if not hasattr(response, 'text'):
        return None
    payload = json.loads(response.text)
    entries = payload.pop('data', [])
    return [entry['id'] for entry in entries]


AllentDan's avatar
AllentDan committed
17
18
def get_streaming_response(
        prompt: str,
        api_url: str,
        session_id: int,
        request_output_len: int = 512,
        stream: bool = True,
        sequence_start: bool = True,
        sequence_end: bool = True,
        ignore_eos: bool = False,
        stop: bool = False) -> Iterable[Tuple[str, int, Optional[str]]]:
    """POST a generation request and yield the decoded reply line by line.

    All arguments except ``api_url`` are forwarded unchanged as fields of the
    JSON request body; their exact meaning is defined by the server.

    Args:
        prompt: Sent as ``prompt``.
        api_url: Full URL the request is POSTed to.
        session_id: Sent as ``session_id``.
        request_output_len: Sent as ``request_output_len``.
        stream: Sent as ``stream`` and also passed to ``requests.post`` as
            its ``stream`` option.
        sequence_start: Sent as ``sequence_start``.
        sequence_end: Sent as ``sequence_end``.
        ignore_eos: Sent as ``ignore_eos``.
        stop: Sent as ``stop``.

    Yields:
        A ``(text, tokens, finish_reason)`` tuple for every non-empty,
        newline-delimited line of the response body. Each line is decoded as
        UTF-8 JSON; keys missing from a line default to ``''``, ``0`` and
        ``None`` respectively.

    Raises:
        json.JSONDecodeError: If a non-empty line is not valid JSON.
    """
    headers = {'User-Agent': 'Test Client'}
    pload = {
        'prompt': prompt,
        'stream': stream,
        'session_id': session_id,
        'request_output_len': request_output_len,
        'sequence_start': sequence_start,
        'sequence_end': sequence_end,
        'ignore_eos': ignore_eos,
        'stop': stop,
    }
    # The response is used as a context manager so that the underlying
    # connection is released even when the caller stops iterating early
    # (closing the generator unwinds the ``with`` block).
    with requests.post(api_url,
                       headers=headers,
                       json=pload,
                       stream=stream) as response:
        for chunk in response.iter_lines(chunk_size=8192,
                                         decode_unicode=False,
                                         delimiter=b'\n'):
            if not chunk:
                # Empty lines between payloads carry no data.
                continue
            data = json.loads(chunk.decode('utf-8'))
            output = data.pop('text', '')
            tokens = data.pop('tokens', 0)
            finish_reason = data.pop('finish_reason', None)
            yield output, tokens, finish_reason


def input_prompt():
    """Read a multi-line prompt from the console.

    Lines are collected until an empty line is entered (i.e. Enter is
    pressed twice in a row).

    Returns:
        The collected lines joined with ``'\\n'``; the terminating empty
        line is not included.
    """
    print('\ndouble enter to end input >>> ', end='')
    lines = []
    while True:
        line = input()
        if line == '':
            # An empty line marks the end of the prompt.
            break
        lines.append(line)
    return '\n'.join(lines)


59
def main(restful_api_url: str, session_id: int = 0):
    """Run an interactive, multi-round chat loop against a RESTful server.

    Each prompt is read with ``input_prompt`` and sent to
    ``{restful_api_url}/generate``; the streamed text is printed as it
    arrives. Entering ``exit`` sends one last empty request with
    ``sequence_end=True`` and then raises ``SystemExit(0)``.

    Args:
        restful_api_url: Base URL of the server; ``/generate`` is appended.
        session_id: Session identifier sent with every request.
    """
    generate_url = f'{restful_api_url}/generate'
    nth_round = 1
    while True:
        prompt = input_prompt()
        if prompt == 'exit':
            # Send an empty, zero-length request flagged as the end of the
            # sequence and drain the reply (presumably this lets the server
            # close the session -- confirm against the server API).
            for _ in get_streaming_response(
                    '',
                    generate_url,
                    session_id=session_id,
                    request_output_len=0,
                    sequence_start=(nth_round == 1),
                    sequence_end=True):
                pass
            # ``exit()`` is a helper injected by ``site`` for interactive
            # use; raising SystemExit works in every interpreter setup.
            raise SystemExit(0)

        for output, _tokens, finish_reason in get_streaming_response(
                prompt,
                generate_url,
                session_id=session_id,
                request_output_len=512,
                sequence_start=(nth_round == 1),
                sequence_end=False):
            if finish_reason == 'length':
                print('WARNING: exceed session max length.'
                      ' Please end the session.')
                continue
            # Flush so partial text shows up immediately instead of waiting
            # in the stdout buffer until a newline is printed.
            print(output, end='', flush=True)

        nth_round += 1


if __name__ == '__main__':
    # Imported inside the guard, so `fire` is only needed when this file is
    # executed as a script.
    from fire import Fire

    # Expose `main` as a command-line interface.
    Fire(main)