# SPDX-License-Identifier: Apache-2.0
"""Example Python client for `vllm.entrypoints.api_server`

Start the demo server:
    python -m vllm.entrypoints.api_server --model <model_name>

NOTE: The API server is used only for demonstration and simple performance
benchmarks. It is not intended for production use.
For production use, we recommend `vllm serve` and the OpenAI client API.
"""

import argparse
import json
from argparse import Namespace
from collections.abc import Iterable
from typing import Optional

import requests


def clear_line(n: int = 1) -> None:
    """Erase the most recently printed terminal lines.

    For each of the ``n`` lines, writes the ANSI "cursor up one line"
    sequence followed by the "erase entire line" sequence to stdout and
    flushes immediately.

    Args:
        n: Number of lines to move up and erase. Zero or a negative value
            writes nothing.
    """
    LINE_UP = '\033[1A'
    LINE_CLEAR = '\x1b[2K'
    for _ in range(n):
        # The erase sequence is passed as `end` so that no newline is
        # emitted after moving the cursor up.
        print(LINE_UP, end=LINE_CLEAR, flush=True)


def post_http_request(prompt: str,
                      api_url: str,
                      n: int = 1,
                      stream: bool = False,
                      timeout: Optional[float] = None) -> requests.Response:
    """POST a generation request to the demo API server.

    The JSON payload carries the prompt together with fixed sampling
    settings (``temperature=0.0``, ``max_tokens=16``) and the ``n`` and
    ``stream`` values supplied by the caller.

    Args:
        prompt: Prompt text sent under the ``"prompt"`` key.
        api_url: Full URL of the endpoint to POST to.
        n: Value sent under the ``"n"`` key.
        stream: Sent under the ``"stream"`` key and also passed to
            ``requests.post`` so the body is not downloaded eagerly.
        timeout: Seconds to wait for the server, forwarded to
            ``requests.post``. ``None`` (the default) waits indefinitely,
            which matches the previous behaviour.

    Returns:
        The ``requests.Response`` object; its status is not checked here.
    """
    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "n": n,
        "temperature": 0.0,
        "max_tokens": 16,
        "stream": stream,
    }
    response = requests.post(api_url,
                             headers=headers,
                             json=pload,
                             stream=stream,
                             timeout=timeout)
    return response


def get_streaming_response(response: requests.Response) -> Iterable[list[str]]:
    """Yield the ``"text"`` field of every chunk in a streamed response.

    The response body is split on newline bytes. Each non-empty piece is
    decoded as UTF-8, parsed as JSON, and the value stored under its
    ``"text"`` key is yielded.

    Args:
        response: A response whose body is a newline-delimited sequence of
            JSON objects.

    Yields:
        The ``"text"`` value of each JSON object (annotated as a list of
        strings).
    """
    for chunk in response.iter_lines(chunk_size=8192,
                                     decode_unicode=False,
                                     delimiter=b"\n"):
        # Splitting on the delimiter can produce empty pieces; skip them.
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"]
            yield output


def get_response(response: requests.Response) -> list[str]:
    """Return the ``"text"`` field of a non-streamed response.

    Args:
        response: A response whose full body is a single JSON object.

    Returns:
        The value stored under the ``"text"`` key (annotated as a list of
        strings).
    """
    data = json.loads(response.content)
    output = data["text"]
    return output


def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the client's command-line options.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        A namespace with ``host``, ``port``, ``n``, ``prompt`` and
        ``stream`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host",
                        type=str,
                        default="localhost",
                        help="host name of the API server")
    parser.add_argument("--port",
                        type=int,
                        default=8000,
                        help="port of the API server")
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="value of the 'n' field sent with the request")
    parser.add_argument("--prompt",
                        type=str,
                        default="San Francisco is a",
                        help="prompt text to send")
    parser.add_argument("--stream",
                        action="store_true",
                        help="request a streamed response")
    return parser.parse_args(argv)


def main(args: Namespace) -> None:
    """Send one prompt to the server's ``/generate`` endpoint and print the results.

    Args:
        args: Parsed options providing ``prompt``, ``host``, ``port``, ``n``
            and ``stream``.
    """
    prompt = args.prompt
    api_url = f"http://{args.host}:{args.port}/generate"
    n = args.n
    stream = args.stream

    print(f"Prompt: {prompt!r}\n", flush=True)
    response = post_http_request(prompt, api_url, n, stream)

    if stream:
        num_printed_lines = 0
        for h in get_streaming_response(response):
            # Erase the lines printed for the previous chunk so that each
            # update overwrites the last one in place.
            clear_line(num_printed_lines)
            num_printed_lines = 0
            for i, line in enumerate(h):
                num_printed_lines += 1
                print(f"Beam candidate {i}: {line!r}", flush=True)
    else:
        output = get_response(response)
        for i, line in enumerate(output):
            print(f"Beam candidate {i}: {line!r}", flush=True)


# Script entry point: parse the command line and run the client.
if __name__ == "__main__":
    args = parse_args()
    main(args)