"""
NOTE: This API server is used only for demonstrating usage of AsyncLLMEngine and simple performance benchmarks.
It is not intended for production use. For production use, we recommend using our OpenAI-compatible server.
We will not accept PRs modifying this file; please change `vllm/entrypoints/openai/api_server.py` instead.
"""

import argparse
import json
from http import HTTPStatus
from typing import AsyncGenerator

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

# The FastAPI application that the route decorators below register on.
app = FastAPI()

# Idle keep-alive timeout (in seconds) handed to uvicorn.run as
# ``timeout_keep_alive``.
TIMEOUT_KEEP_ALIVE = 5  # seconds.

# Engine instance used by the request handlers. It stays None until the
# ``__main__`` block builds it from the command-line arguments.
engine = None


@app.get("/health")
async def health() -> Response:
    """Liveness probe.

    Always answers with an empty-bodied HTTP 200; the engine is not
    consulted.
    """
    status_ok = 200
    return Response(status_code=status_ok)


def _error_response(message: str, status: HTTPStatus) -> JSONResponse:
    """Build a JSON error response ``{"error": message}`` with *status*."""
    return JSONResponse({"error": message}, status_code=status.value)


def _text_outputs(request_output) -> list:
    """Return the prompt concatenated with each generated sequence's text.

    Only ``request_output.prompt`` and ``request_output.outputs[i].text``
    are read; the object is whatever ``engine.generate`` yields.
    """
    prompt = request_output.prompt
    return [prompt + output.text for output in request_output.outputs]


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation (a string).
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).

    Responses:
    - stream falsy: a JSON object ``{"text": [...]}`` holding, for every
      output sequence, the prompt followed by the generated text.
    - stream truthy: a stream of such JSON objects, one per item yielded by
      ``engine.generate``, each UTF-8 encoded and terminated by a NUL
      character.
    - 400 with ``{"error": ...}`` if the body is not valid JSON, is not an
      object, lacks a string ``prompt``, or carries sampling fields that
      ``SamplingParams`` rejects.
    - 499 if the client disconnects before a non-streaming request
      finishes (the engine request is aborted).
    - 503 with ``{"error": ...}`` if the module-level engine was never
      created, e.g. when ``app`` is served without running this file as
      ``__main__``.
    """
    if engine is None:
        # Without this guard, engine.generate below raises AttributeError
        # on None and the client sees an opaque 500.
        return _error_response("engine is not initialized",
                               HTTPStatus.SERVICE_UNAVAILABLE)

    try:
        request_dict = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass
        # ValueError.
        return _error_response("request body must be valid JSON",
                               HTTPStatus.BAD_REQUEST)
    if not isinstance(request_dict, dict) or not isinstance(
            request_dict.get("prompt"), str):
        return _error_response(
            "request body must be a JSON object with a string 'prompt' field",
            HTTPStatus.BAD_REQUEST)

    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    try:
        sampling_params = SamplingParams(**request_dict)
    except (TypeError, ValueError) as exc:
        # TypeError: an unknown field name (Python's keyword-argument check).
        # ValueError: presumably how SamplingParams reports out-of-range
        # values -- confirm against SamplingParams' validation.
        return _error_response(f"invalid sampling parameters: {exc}",
                               HTTPStatus.BAD_REQUEST)
    request_id = random_uuid()

    results_generator = engine.generate(prompt, sampling_params, request_id)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            ret = {"text": _text_outputs(request_output)}
            yield (json.dumps(ret) + "\0").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case
    final_output = None
    async for request_output in results_generator:
        if await request.is_disconnected():
            # Abort the request if the client disconnects.
            await engine.abort(request_id)
            return Response(status_code=499)
        final_output = request_output

    assert final_output is not None
    ret = {"text": _text_outputs(final_output)}
    return JSONResponse(ret)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Demo HTTP server around vLLM's AsyncLLMEngine.")
    parser.add_argument("--host",
                        type=str,
                        default=None,
                        help="host address to bind the server to")
    parser.add_argument("--port",
                        type=int,
                        default=8000,
                        help="port to listen on")
    parser.add_argument("--ssl-keyfile",
                        type=str,
                        default=None,
                        help="path to the SSL key file")
    parser.add_argument("--ssl-certfile",
                        type=str,
                        default=None,
                        help="path to the SSL certificate file")
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")
    # Previously hard-coded to "debug"; the default keeps that behavior.
    parser.add_argument(
        "--log-level",
        type=str,
        default="debug",
        choices=["critical", "error", "warning", "info", "debug", "trace"],
        help="uvicorn log level")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    engine_args = AsyncEngineArgs.from_cli_args(args)
    # Rebinds the module-level `engine` that the request handlers read.
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level=args.log_level,
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile)