api_server.py 4.03 KB
Newer Older
1
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI-compatible server.
We are also not going to accept PRs modifying this file; please
change `vllm/entrypoints/openai/api_server.py` instead.
"""

Zhuohan Li's avatar
Zhuohan Li committed
9
10
import argparse
import json
11
import ssl
Zhuohan Li's avatar
Zhuohan Li committed
12
13
from typing import AsyncGenerator

14
import uvicorn
15
from fastapi import FastAPI, Request
16
from fastapi.responses import JSONResponse, Response, StreamingResponse
Zhuohan Li's avatar
Zhuohan Li committed
17

Woosuk Kwon's avatar
Woosuk Kwon committed
18
19
20
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
yhu422's avatar
yhu422 committed
21
from vllm.usage.usage_lib import UsageContext
Woosuk Kwon's avatar
Woosuk Kwon committed
22
from vllm.utils import random_uuid
Zhuohan Li's avatar
Zhuohan Li committed
23

24
# Handed to uvicorn.run() as `timeout_keep_alive` when the server is launched.
TIMEOUT_KEEP_ALIVE = 5  # seconds.
Zhuohan Li's avatar
Zhuohan Li committed
25
# Application object; the route decorators in this file register handlers on it.
app = FastAPI()
26
# Placeholder; rebound to an AsyncLLMEngine in the `__main__` block before the
# server starts, so it stays None if this module is merely imported.
engine = None
Zhuohan Li's avatar
Zhuohan Li committed
27
28


29
30
31
32
33
34
@app.get("/health")
async def health() -> Response:
    """Liveness probe: answer with an empty HTTP 200 response."""
    ok_response = Response(status_code=200)
    return ok_response


Zhuohan Li's avatar
Zhuohan Li committed
35
@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).

    Returns:
        - 400 with a JSON error body when the payload is not a JSON object or
          has no `prompt` field.
        - A `StreamingResponse` when `stream` is truthy; every chunk is a
          JSON document `{"text": [...]}` followed by a NUL byte.
        - 499 when the client disconnects during a non-streaming request.
        - Otherwise a `JSONResponse` `{"text": [...]}` built from the last
          output the engine produced. Each entry is the prompt concatenated
          with one generated completion.

    Raises:
        RuntimeError: if the engine finishes without yielding any output.
    """
    request_dict = await request.json()
    # Reject malformed payloads explicitly instead of letting `.pop` fail
    # with a KeyError/TypeError that would surface as an opaque 500.
    if not isinstance(request_dict, dict) or "prompt" not in request_dict:
        return JSONResponse(
            {"error": "Request body must be a JSON object with a 'prompt'."},
            status_code=400)

    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    # Everything left in the body is forwarded as sampling parameters.
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()

    results_generator = engine.generate(prompt, sampling_params, request_id)

    def build_payload(request_output) -> dict:
        """Return `{"text": [...]}` with the prompt prefixed to each output."""
        echoed_prompt = request_output.prompt
        return {
            "text": [
                echoed_prompt + output.text
                for output in request_output.outputs
            ]
        }

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            # "\0" terminates each JSON chunk in the byte stream.
            chunk = json.dumps(build_payload(request_output)) + "\0"
            yield chunk.encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case
    final_output = None
    async for request_output in results_generator:
        if await request.is_disconnected():
            # Abort the request if the client disconnects.
            await engine.abort(request_id)
            # 499: non-standard "client closed request" status code.
            return Response(status_code=499)
        final_output = request_output

    # Explicit check rather than `assert`, which is stripped under `-O`.
    if final_output is None:
        raise RuntimeError(
            f"Engine produced no output for request {request_id}.")
    return JSONResponse(build_payload(final_output))
Zhuohan Li's avatar
Zhuohan Li committed
79
80
81
82


if __name__ == "__main__":
    # --- Command line: server flags first, then the engine's own flags. ---
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8000)
    # The key and certificate flags share the same shape (optional path).
    for tls_path_flag in ("--ssl-keyfile", "--ssl-certfile"):
        parser.add_argument(tls_path_flag, type=str, default=None)
    parser.add_argument(
        "--ssl-ca-certs",
        type=str,
        default=None,
        help="The CA certificates file",
    )
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see stdlib ssl module's)",
    )
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy",
    )
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    # --- Engine: built from the parsed flags and bound to the module-level
    # `engine` name that the /generate handler reads. ---
    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(
        engine_args,
        usage_context=UsageContext.API_SERVER,
    )

    # --- HTTP server. ---
    app.root_path = args.root_path
    tls_options = {
        "ssl_keyfile": args.ssl_keyfile,
        "ssl_certfile": args.ssl_certfile,
        "ssl_ca_certs": args.ssl_ca_certs,
        "ssl_cert_reqs": args.ssl_cert_reqs,
    }
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="debug",
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        **tls_options,
    )