# SPDX-License-Identifier: Apache-2.0
"""Example Python client for `vllm.entrypoints.api_server`.

Start the demo server:
    python -m vllm.entrypoints.api_server --model <model_name>

NOTE: The API server is used only for demonstration and simple performance
benchmarks. It is not intended for production use.
For production use, we recommend `vllm serve` and the OpenAI client API.
"""

import argparse
import json
from argparse import Namespace
from collections.abc import Iterable
from typing import Optional

import requests


def clear_line(n: int = 1) -> None:
    """Move the cursor up one line and erase it, repeated ``n`` times.

    Lets the streaming client overwrite the lines it printed for the
    previous update. Each step writes the ANSI "cursor up" sequence followed
    by the ANSI "erase entire line" sequence, flushing immediately so the
    terminal updates right away.
    """
    cursor_up, erase_line = "\033[1A", "\x1b[2K"
    for _step in range(n):
        print(cursor_up, end=erase_line, flush=True)


def post_http_request(
    prompt: str,
    api_url: str,
    n: int = 1,
    stream: bool = False,
    *,
    temperature: float = 0.0,
    max_tokens: int = 16,
    timeout: Optional[float] = None,
) -> requests.Response:
    """POST a generation request to the demo server.

    Args:
        prompt: Text to send as the ``"prompt"`` field.
        api_url: Full endpoint URL, e.g. ``http://localhost:8000/generate``.
        n: Number of completions to request (sent as the ``"n"`` field).
        stream: Sent as the ``"stream"`` field and also passed to
            ``requests.post`` so the body can be read incrementally.
        temperature: Sampling temperature (``"temperature"`` field). Defaults
            to 0.0, the value this client always used before.
        max_tokens: Generation length limit (``"max_tokens"`` field).
            Defaults to 16, the value this client always used before.
        timeout: Forwarded to ``requests.post``. ``None`` (the default) means
            wait indefinitely; pass a number of seconds to bound the wait.

    Returns:
        The raw ``requests.Response``. The caller is responsible for checking
        its status and closing it.
    """
    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "n": n,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stream": stream,
    }
    return requests.post(
        api_url, headers=headers, json=pload, stream=stream, timeout=timeout
    )

def get_streaming_response(response: requests.Response) -> Iterable[list[str]]:
    """Yield the ``"text"`` field of each streamed update from the server.

    The body is read as newline-delimited JSON. Every non-empty line is
    decoded as UTF-8, parsed, and its ``"text"`` value yielded.

    Raises:
        requests.HTTPError: if the server replied with a 4xx/5xx status.
            Previously such a reply failed later with a confusing
            ``KeyError``/``JSONDecodeError`` while parsing the error body.
    """
    response.raise_for_status()
    for chunk in response.iter_lines(
        chunk_size=8192, decode_unicode=False, delimiter=b"\n"
    ):
        if not chunk:
            # Empty segments (e.g. after a trailing delimiter) carry no data.
            continue
        data = json.loads(chunk.decode("utf-8"))
        yield data["text"]


def get_response(response: requests.Response) -> list[str]:
    """Return the ``"text"`` field of a non-streaming server response.

    Raises:
        requests.HTTPError: if the server replied with a 4xx/5xx status,
            rather than failing with a ``KeyError`` on the error body.
    """
    response.raise_for_status()
    data = json.loads(response.content)
    return data["text"]


def parse_args(argv: Optional[list[str]] = None) -> Namespace:
    """Parse the client's command-line options.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, matching the previous zero-argument behaviour.

    Returns:
        A Namespace with ``host``, ``port``, ``n``, ``prompt`` and ``stream``.
    """
    parser = argparse.ArgumentParser(
        description="Send a prompt to the vLLM demo API server."
    )
    parser.add_argument(
        "--host", type=str, default="localhost", help="Server host name."
    )
    parser.add_argument("--port", type=int, default=8000, help="Server port.")
    parser.add_argument(
        "--n", type=int, default=1, help="Number of completions to request."
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="San Francisco is a",
        help="Prompt text to send.",
    )
    parser.add_argument(
        "--stream",
        action="store_true",
        help="Stream partial results and redraw them in place.",
    )
    return parser.parse_args(argv)


def _print_candidates(texts: list[str]) -> int:
    """Print one labelled line per candidate text; return how many were printed."""
    # NOTE(review): the request payload does not enable beam search, so
    # "Beam candidate" is only a label for the i-th returned text.
    for i, text in enumerate(texts):
        print(f"Beam candidate {i}: {text!r}", flush=True)
    return len(texts)


def main(args: Namespace):
    """Send ``args.prompt`` to the demo server and print the returned texts.

    Args:
        args: Parsed options as produced by ``parse_args`` (``host``,
            ``port``, ``n``, ``prompt``, ``stream``).

    In streaming mode, each update is redrawn in place by erasing the lines
    printed for the previous update.
    """
    prompt = args.prompt
    api_url = f"http://{args.host}:{args.port}/generate"
    print(f"Prompt: {prompt!r}\n", flush=True)

    # Close the response deterministically. With stream=True, requests keeps
    # the connection open until the body is fully consumed or closed.
    with post_http_request(prompt, api_url, args.n, args.stream) as response:
        if args.stream:
            num_printed_lines = 0
            for texts in get_streaming_response(response):
                # Erase the previous update before printing the new one.
                clear_line(num_printed_lines)
                num_printed_lines = _print_candidates(texts)
        else:
            _print_candidates(get_response(response))


if __name__ == "__main__":
    # Script entry point: parse the CLI options and run a single request.
    main(parse_args())