"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file; please
change `vllm/entrypoints/openai/api_server.py` instead.
"""
import asyncio
import json
import ssl
from argparse import Namespace
from http import HTTPStatus
from typing import Any, AsyncGenerator, Optional

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation,
                        random_uuid)
from vllm.version import __version__ as VLLM_VERSION


# Explicit logger name rather than __name__; presumably so the name stays
# stable when this file is executed directly as a script.
logger = init_logger("vllm.entrypoints.api_server")

# Handed to serve_http as `timeout_keep_alive` (seconds).
TIMEOUT_KEEP_ALIVE = 5

# Shared FastAPI application; the route decorators below register on it.
app = FastAPI()
# Engine being served; assigned by init_app() before the server starts.
engine: Optional[AsyncLLMEngine] = None


@app.get("/health")
async def health() -> Response:
    """Report liveness with an empty ``200 OK`` response."""
    healthy = Response(status_code=200)
    return healthy


def _render_texts(request_output) -> list:
    """Return ``prompt + completion`` strings for every candidate output.

    Shared by the streaming and non-streaming paths of ``generate`` so both
    emit the same ``{"text": [...]}`` payload.

    Args:
        request_output: An item yielded by ``engine.generate``; only its
            ``prompt`` and ``outputs[*].text`` attributes are read.

    Returns:
        One string per entry in ``request_output.outputs``.
    """
    prompt = request_output.prompt
    # The echoed prompt is prepended to each completion, so it must be set.
    assert prompt is not None
    return [prompt + output.text for output in request_output.outputs]


def _bad_request(message: str) -> JSONResponse:
    """Build a 400 response carrying ``{"error": message}``."""
    return JSONResponse({"error": message},
                        status_code=HTTPStatus.BAD_REQUEST.value)


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).

    Responses:
    - 400 with ``{"error": ...}`` when the body is not valid JSON, is not a
      JSON object, has no ``prompt`` field, or carries fields that
      ``SamplingParams`` rejects.
    - streaming: a sequence of ``{"text": [...]}`` JSON documents, each
      followed by a NUL byte.
    - non-streaming: one ``{"text": [...]}`` document built from the last
      output the engine yielded.
    - 499 when the non-streaming iteration is cancelled (the generator is
      wrapped with ``request.is_disconnected`` as its cancellation check).
    """
    # Reject malformed input up front so no request reaches the engine.
    try:
        request_dict = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        return _bad_request("request body is not valid JSON")
    if not isinstance(request_dict, dict) or "prompt" not in request_dict:
        return _bad_request(
            "request body must be a JSON object with a 'prompt' field")

    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    try:
        sampling_params = SamplingParams(**request_dict)
    except (TypeError, ValueError) as exc:
        # Unknown keyword fields raise TypeError; SamplingParams' own
        # validation presumably raises ValueError. Both are client errors.
        return _bad_request(str(exc))
    request_id = random_uuid()

    assert engine is not None
    results_generator = engine.generate(prompt, sampling_params, request_id)
    results_generator = iterate_with_cancellation(
        results_generator, is_cancelled=request.is_disconnected)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            ret = {"text": _render_texts(request_output)}
            yield (json.dumps(ret) + "\0").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case: keep only the last (most complete) output.
    final_output = None
    try:
        async for request_output in results_generator:
            final_output = request_output
    except asyncio.CancelledError:
        return Response(status_code=499)

    assert final_output is not None
    ret = {"text": _render_texts(final_output)}
    return JSONResponse(ret)


def build_app(args: Namespace) -> FastAPI:
    """Apply CLI options to the shared module-level ``app`` and return it.

    Only ``args.root_path`` is read; it becomes the FastAPI ``root_path``.
    The same ``app`` object is mutated and returned on every call.
    """
    # Only an attribute of the module global is set, so no `global` is needed.
    shared_app = app
    shared_app.root_path = args.root_path
    return shared_app


async def init_app(
    args: Namespace,
    llm_engine: Optional[AsyncLLMEngine] = None,
) -> FastAPI:
    """Configure the module-level app and install the engine it will serve.

    Args:
        args: Parsed CLI namespace. ``root_path`` is applied to the app; the
            engine CLI fields are read only when ``llm_engine`` is None.
        llm_engine: A pre-built engine to reuse. When None, a new engine is
            created from ``args`` with ``UsageContext.API_SERVER``.

    Returns:
        The configured FastAPI application. As a side effect the module
        global ``engine`` is set.
    """
    global engine

    app = build_app(args)

    # Only derive engine args when we actually have to build an engine, so a
    # caller passing its own engine needs no engine fields on `args`.
    if llm_engine is None:
        engine_args = AsyncEngineArgs.from_cli_args(args)
        llm_engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.API_SERVER)
    engine = llm_engine

    return app


async def run_server(args: Namespace,
                     llm_engine: Optional[AsyncLLMEngine] = None,
                     **uvicorn_kwargs: Any) -> None:
    """Initialise the app and engine, serve HTTP, and wait for shutdown.

    Args:
        args: Parsed CLI namespace (host, port, log level, SSL options, and
            whatever ``init_app`` reads).
        llm_engine: Optional pre-built engine forwarded to ``init_app``.
        **uvicorn_kwargs: Extra keyword arguments forwarded verbatim to
            ``serve_http``. A key that repeats one set here raises TypeError.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    server_app = await init_app(args, llm_engine)
    assert engine is not None

    # TLS settings copied straight from the CLI namespace.
    tls_settings = {
        "ssl_keyfile": args.ssl_keyfile,
        "ssl_certfile": args.ssl_certfile,
        "ssl_ca_certs": args.ssl_ca_certs,
        "ssl_cert_reqs": args.ssl_cert_reqs,
    }
    # serve_http returns an awaitable that completes on server shutdown.
    shutdown = await serve_http(
        server_app,
        engine=engine,
        host=args.host,
        port=args.port,
        log_level=args.log_level,
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        **tls_settings,
        **uvicorn_kwargs,
    )
    await shutdown


if __name__ == "__main__":
    # CLI: server options first, then every AsyncEngineArgs option.
    parser = FlexibleArgumentParser()
    parser.add_argument("--host",
                        type=str,
                        default=None,
                        help="Host to serve on")
    parser.add_argument("--port",
                        type=int,
                        default=8000,
                        help="Port to serve on")
    parser.add_argument("--ssl-keyfile",
                        type=str,
                        default=None,
                        help="The file path to the SSL key file")
    parser.add_argument("--ssl-certfile",
                        type=str,
                        default=None,
                        help="The file path to the SSL cert file")
    parser.add_argument("--ssl-ca-certs",
                        type=str,
                        default=None,
                        help="The CA certificates file")
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see stdlib ssl "
        "module's CERT_NONE / CERT_OPTIONAL / CERT_REQUIRED constants)")
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")
    parser.add_argument("--log-level",
                        type=str,
                        default="debug",
                        help="Log level passed to the HTTP server")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    asyncio.run(run_server(args))