# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Example Python client for `vllm.entrypoints.api_server`.

Start the demo server:
    python -m vllm.entrypoints.api_server --model <model_name>

NOTE: The API server is used only for demonstration and simple performance
benchmarks. It is not intended for production use.
For production use, we recommend `vllm serve` and the OpenAI client API.
"""
11

Zhuohan Li's avatar
Zhuohan Li committed
12
13
import argparse
import json
14
from argparse import Namespace
15
from collections.abc import Iterable
Zhuohan Li's avatar
Zhuohan Li committed
16

17
18
19
20
import requests


def clear_line(n: int = 1) -> None:
    """Erase the last ``n`` lines printed to the terminal.

    For every line, the ANSI "cursor up" escape sequence is printed,
    immediately followed by the "erase line" sequence, and stdout is flushed.

    Args:
        n: Number of lines to move up and erase. Zero or a negative value
            prints nothing.
    """
    LINE_UP = "\033[1A"  # ANSI CSI "1 A": move the cursor up one line.
    LINE_CLEAR = "\x1b[2K"  # ANSI CSI "2 K": erase the whole current line.
    for _ in range(n):
        # ``end`` replaces the default newline, so only the two escape
        # sequences are written for each erased line.
        print(LINE_UP, end=LINE_CLEAR, flush=True)


27
28
29
def post_http_request(
    prompt: str,
    api_url: str,
    n: int = 1,
    stream: bool = False,
    *,
    temperature: float = 0.0,
    max_tokens: int = 16,
    timeout: "float | tuple[float, float] | None" = None,
) -> requests.Response:
    """POST a generation request to the demo API server.

    Args:
        prompt: Text sent as the ``"prompt"`` field of the JSON payload.
        api_url: Full URL the request is posted to.
        n: Value sent as the ``"n"`` field of the payload.
        stream: Sent as the ``"stream"`` field of the payload and also passed
            to ``requests.post`` as its ``stream`` argument.
        temperature: Value sent as the ``"temperature"`` field.
        max_tokens: Value sent as the ``"max_tokens"`` field.
        timeout: Forwarded to ``requests.post``. ``None`` (the default) means
            no timeout, which matches the previous behaviour.

    Returns:
        The ``requests.Response`` object. The HTTP status is not checked
        here; callers decide how to handle error responses.
    """
    headers = {"User-Agent": "Test Client"}
    pload = {
        "prompt": prompt,
        "n": n,
        "temperature": temperature,
        "max_tokens": max_tokens,
        "stream": stream,
    }
    return requests.post(
        api_url, headers=headers, json=pload, stream=stream, timeout=timeout
    )

Zhuohan Li's avatar
Zhuohan Li committed
41

42
def get_streaming_response(response: requests.Response) -> Iterable[list[str]]:
    """Yield the ``"text"`` field of every JSON line in a streamed response.

    The response body is read line by line (split on ``b"\\n"``). Each
    non-empty line is decoded as UTF-8, parsed as JSON, and its ``"text"``
    value is yielded.

    Args:
        response: Response whose body is consumed incrementally.

    Yields:
        The ``"text"`` value of each parsed line.

    Raises:
        requests.HTTPError: If the response carries a 4xx/5xx status. Because
            this is a generator, the check runs on the first iteration.
        json.JSONDecodeError: If a line is not valid JSON.
        KeyError: If a parsed line has no ``"text"`` key.
    """
    # Fail with a clear HTTP error instead of a confusing KeyError or
    # JSONDecodeError when the server answers with an error body.
    response.raise_for_status()
    for chunk in response.iter_lines(
        chunk_size=8192, decode_unicode=False, delimiter=b"\n"
    ):
        if not chunk:
            # Skip empty lines produced by the iterator.
            continue
        data = json.loads(chunk.decode("utf-8"))
        yield data["text"]


52
def get_response(response: requests.Response) -> list[str]:
    """Return the ``"text"`` field of a non-streaming JSON response.

    Args:
        response: Response whose full body is parsed as JSON.

    Returns:
        The value stored under the ``"text"`` key of the parsed body.

    Raises:
        requests.HTTPError: If the response carries a 4xx/5xx status.
        json.JSONDecodeError: If the body is not valid JSON.
        KeyError: If the parsed body has no ``"text"`` key.
    """
    # Report HTTP errors directly rather than as a KeyError or
    # JSONDecodeError raised while parsing an error body.
    response.raise_for_status()
    data = json.loads(response.content)
    return data["text"]


58
59
60
61
62
63
64
65
66
67
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options of the example client.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        Namespace with the attributes ``host``, ``port``, ``n``, ``prompt``
        and ``stream``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--host", type=str, default="localhost", help="API server host name"
    )
    parser.add_argument("--port", type=int, default=8000, help="API server port")
    parser.add_argument(
        "--n", type=int, default=1, help="value of the 'n' request field"
    )
    parser.add_argument(
        "--prompt", type=str, default="San Francisco is a", help="prompt text"
    )
    parser.add_argument(
        "--stream", action="store_true", help="request a streamed response"
    )
    return parser.parse_args(argv)


68
def main(args: Namespace) -> None:
    """Send one generation request and print the returned candidates.

    Builds the ``/generate`` URL from ``args.host`` and ``args.port``, posts
    ``args.prompt``, and prints each returned text as a "Beam candidate" line.
    In streaming mode the previously printed candidate lines are erased
    before each update is printed.

    Args:
        args: Parsed options providing ``prompt``, ``host``, ``port``, ``n``
            and ``stream``.
    """
    prompt = args.prompt
    api_url = f"http://{args.host}:{args.port}/generate"
    n = args.n
    stream = args.stream

    print(f"Prompt: {prompt!r}\n", flush=True)
    response = post_http_request(prompt, api_url, n, stream)

    if stream:
        num_printed_lines = 0
        for candidates in get_streaming_response(response):
            # Erase the lines of the previous update before redrawing.
            clear_line(num_printed_lines)
            for i, line in enumerate(candidates):
                print(f"Beam candidate {i}: {line!r}", flush=True)
            # One line is printed per candidate; remember how many to erase.
            num_printed_lines = len(candidates)
    else:
        output = get_response(response)
        for i, line in enumerate(output):
            print(f"Beam candidate {i}: {line!r}", flush=True)
89
90
91


if __name__ == "__main__":
    # Script entry point: parse the command line, then run the client.
    args = parse_args()
    main(args)