"""
NOTE: This API server is used only for demonstrating usage of AsyncLLMEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file; please
change `vllm/entrypoints/openai/api_server.py` instead.
"""

import json
import ssl
from typing import AsyncGenerator, Optional

import uvicorn
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid

logger = init_logger("vllm.entrypoints.api_server")

# Idle keep-alive timeout handed to uvicorn.run(timeout_keep_alive=...).
TIMEOUT_KEEP_ALIVE = 5  # seconds.
app = FastAPI()
# Assigned only in this module's ``if __name__ == "__main__":`` section;
# it stays None when the module is imported without running that section.
engine: Optional[AsyncLLMEngine] = None


@app.get("/health")
async def health() -> Response:
    """Liveness probe: reply with an empty HTTP 200 response."""
    ok_status = 200
    return Response(status_code=ok_status)


def _texts_payload(request_output) -> dict:
    """Build the ``{"text": [...]}`` payload for one engine output.

    Each entry is the output's prompt concatenated with one of the
    completions in ``request_output.outputs``.
    """
    prompt = request_output.prompt
    texts = [prompt + output.text for output in request_output.outputs]
    return {"text": texts}


def _bad_request(message: str) -> JSONResponse:
    """Return an HTTP 400 response whose body is ``{"error": message}``."""
    return JSONResponse({"error": message}, status_code=400)


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).

    Responses:
    - stream truthy: a ``StreamingResponse`` emitting one JSON object
      ``{"text": [...]}`` per engine output, each terminated by a NUL byte.
    - stream falsy: a ``JSONResponse`` ``{"text": [...]}`` built from the
      final engine output.
    - 400 with ``{"error": ...}`` when the body is not a JSON object, lacks
      ``prompt``, or carries fields ``SamplingParams`` rejects.
    - 499 when the client disconnects before a non-streaming request
      finishes; the request is aborted in the engine first.
    - 503 when the module-level ``engine`` has not been created.
    """
    try:
        request_dict = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass
        # ValueError, so this covers malformed and undecodable bodies.
        return _bad_request("Request body must be valid JSON.")
    if not isinstance(request_dict, dict):
        return _bad_request("Request body must be a JSON object.")
    if "prompt" not in request_dict:
        return _bad_request("Missing required field 'prompt'.")

    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    try:
        sampling_params = SamplingParams(**request_dict)
    except (TypeError, ValueError) as e:
        # Unknown field names raise TypeError from the keyword call;
        # presumably out-of-range values are rejected with ValueError.
        return _bad_request(f"Invalid sampling parameters: {e}")
    request_id = random_uuid()

    if engine is None:
        # An explicit check instead of `assert`, which `python -O` strips.
        return JSONResponse({"error": "Engine is not initialized."},
                            status_code=503)
    results_generator = engine.generate(prompt, sampling_params, request_id)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            ret = _texts_payload(request_output)
            yield (json.dumps(ret) + "\0").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case
    final_output = None
    async for request_output in results_generator:
        if await request.is_disconnected():
            # Abort the request if the client disconnects.
            await engine.abort(request_id)
            return Response(status_code=499)
        final_output = request_output

    assert final_output is not None
    return JSONResponse(_texts_payload(final_output))


if __name__ == "__main__":
    # Server-level CLI options; engine options are appended below by
    # AsyncEngineArgs.add_cli_args.
    parser = FlexibleArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument("--ssl-ca-certs",
                        type=str,
                        default=None,
                        help="The CA certificates file")
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see the CERT_* "
        "constants in the stdlib ssl module)")
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")
    parser.add_argument("--log-level", type=str, default="debug")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    # Rebinds the module-level `engine` global read by the /generate handler.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.API_SERVER)

    app.root_path = args.root_path

    logger.info("Available routes are:")
    for route in app.routes:
        # Skip routes without methods; also guards against a `methods`
        # attribute that is None, which ', '.join() would reject.
        route_methods = getattr(route, "methods", None)
        if not route_methods:
            continue
        # Sort for stable log output; `methods` is a set.
        methods = ", ".join(sorted(route_methods))
        logger.info("Route: %s, Methods: %s", route.path, methods)

    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level=args.log_level,
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs)