api_client.py 2.78 KB
Newer Older
AllentDan's avatar
AllentDan committed
1
2
3
4
5
6
7
8
# Copyright (c) OpenMMLab. All rights reserved.
import json
from typing import Iterable, List, Optional, Tuple

import fire
import requests


9
10
11
12
13
14
15
16
17
def get_model_list(api_url: str):
    """Fetch the model identifiers listed at ``api_url``.

    Args:
        api_url (str): URL that is queried with an HTTP GET request.

    Returns:
        list | None: the ``id`` field of every entry found under the
            ``data`` key of the JSON reply (an empty list when that key is
            absent), or ``None`` when the response object has no ``text``
            attribute.
    """
    reply = requests.get(api_url)
    if not hasattr(reply, 'text'):
        return None
    payload = json.loads(reply.text)
    return [entry['id'] for entry in payload.pop('data', [])]


AllentDan's avatar
AllentDan committed
18
19
20
21
22
23
24
def get_streaming_response(
        prompt: str,
        api_url: str,
        instance_id: int,
        request_output_len: int = 512,
        stream: bool = True,
        sequence_start: bool = True,
        sequence_end: bool = True,
        ignore_eos: bool = False,
        stop: bool = False) -> Iterable[Tuple[str, int, Optional[str]]]:
    """POST a generation request and yield the decoded reply chunks.

    Every argument except ``api_url`` is forwarded unchanged in the JSON
    payload under a key of the same name; how the server interprets them is
    not visible from this client.

    Args:
        prompt (str): sent as ``prompt``.
        api_url (str): URL the request is posted to.
        instance_id (int): sent as ``instance_id``.
        request_output_len (int): sent as ``request_output_len``.
        stream (bool): sent as ``stream`` and also used as the ``stream``
            flag of the HTTP request itself.
        sequence_start (bool): sent as ``sequence_start``.
        sequence_end (bool): sent as ``sequence_end``.
        ignore_eos (bool): sent as ``ignore_eos``.
        stop (bool): sent as ``stop``.

    Yields:
        tuple: ``(text, tokens, finish_reason)`` taken from each non-empty
            chunk of the reply. Missing keys fall back to ``''``, ``0`` and
            ``None`` respectively.
    """
    headers = {'User-Agent': 'Test Client'}
    pload = {
        'prompt': prompt,
        'stream': stream,
        'instance_id': instance_id,
        'request_output_len': request_output_len,
        'sequence_start': sequence_start,
        'sequence_end': sequence_end,
        'ignore_eos': ignore_eos,
        'stop': stop
    }
    # Using the response as a context manager guarantees the underlying
    # connection is released even when the consumer stops iterating early
    # or an exception interrupts the stream.
    with requests.post(api_url, headers=headers, json=pload,
                       stream=stream) as response:
        # Chunks are separated by NUL bytes and decoded manually as UTF-8.
        for chunk in response.iter_lines(chunk_size=8192,
                                         decode_unicode=False,
                                         delimiter=b'\0'):
            if not chunk:
                continue
            data = json.loads(chunk.decode('utf-8'))
            output = data.pop('text', '')
            tokens = data.pop('tokens', 0)
            finish_reason = data.pop('finish_reason', None)
            yield output, tokens, finish_reason


def input_prompt():
    """Read a multi-line prompt from the console.

    Lines are collected until an empty line is entered; the collected lines
    are returned joined by newline characters.
    """
    print('\ndouble enter to end input >>> ', end='')
    collected = []
    while True:
        line = input()
        if line == '':  # an empty line terminates the prompt
            break
        collected.append(line)
    return '\n'.join(collected)


60
def main(restful_api_url: str, session_id: int = 0):
    """Run an interactive, multi-round chat loop in the terminal.

    Each round reads a prompt from the console, posts it to
    ``{restful_api_url}/generate`` and prints the streamed reply. Entering
    ``exit`` as the prompt terminates the program.

    Args:
        restful_api_url (str): base URL; ``/generate`` is appended to it.
        session_id (int): forwarded as ``instance_id`` with every request.
    """
    nth_round = 1
    while True:
        prompt = input_prompt()
        if prompt == 'exit':
            # The bare `exit()` helper is injected by the `site` module for
            # interactive use and may be missing (e.g. `python -S`); raising
            # SystemExit directly has the same effect everywhere.
            raise SystemExit(0)
        for output, _, finish_reason in get_streaming_response(
                prompt,
                f'{restful_api_url}/generate',
                instance_id=session_id,
                request_output_len=512,
                # only the first round starts a new sequence
                sequence_start=(nth_round == 1),
                sequence_end=False):
            if finish_reason == 'length':
                print('WARNING: exceed session max length.'
                      ' Please end the session.')
                continue
            # Flush so each streamed piece shows up immediately; without a
            # trailing newline the text would otherwise sit in the stdout
            # buffer until the next prompt.
            print(output, end='', flush=True)

        nth_round += 1


# When run as a script, hand `main` to python-fire so that its parameters
# are taken from the command line.
if __name__ == '__main__':
    fire.Fire(main)