# SPDX-License-Identifier: Apache-2.0
"""
3
4
5
6
7
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
8
"""
9
import asyncio
Zhuohan Li's avatar
Zhuohan Li committed
10
import json
11
import ssl
12
from argparse import Namespace
13
14
from collections.abc import AsyncGenerator
from typing import Any, Optional
Zhuohan Li's avatar
Zhuohan Li committed
15

16
from fastapi import FastAPI, Request
17
from fastapi.responses import JSONResponse, Response, StreamingResponse
Zhuohan Li's avatar
Zhuohan Li committed
18

Woosuk Kwon's avatar
Woosuk Kwon committed
19
20
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
21
from vllm.entrypoints.launcher import serve_http
22
from vllm.entrypoints.utils import with_cancellation
23
from vllm.logger import init_logger
Woosuk Kwon's avatar
Woosuk Kwon committed
24
from vllm.sampling_params import SamplingParams
yhu422's avatar
yhu422 committed
25
from vllm.usage.usage_lib import UsageContext
26
from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit
27
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
28

29
30
logger = init_logger("vllm.entrypoints.api_server")

# Keep-alive timeout handed to the HTTP server by run_server().
TIMEOUT_KEEP_ALIVE = 5  # seconds.

# Shared FastAPI application; the route handlers below register on it.
app = FastAPI()
# Assigned by init_app(); remains None until the server is initialized.
engine: Optional[AsyncLLMEngine] = None


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


Zhuohan Li's avatar
Zhuohan Li committed
42
@app.post("/generate")
43
async def generate(request: Request) -> Response:
44
    """Generate completion for the request.
45
46
47

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
48
    - stream: whether to stream the results or not.
49
50
    - other fields: the sampling parameters (See `SamplingParams` for details).
    """
Zhuohan Li's avatar
Zhuohan Li committed
51
    request_dict = await request.json()
52
53
54
55
56
    return await _generate(request_dict, raw_request=request)


@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
    """Validate a parsed /generate payload and run it through the engine.

    ``request_dict`` is mutated. ``"prompt"`` (required) and ``"stream"``
    (optional, default ``False``) are popped off it. Every remaining key is
    passed to ``SamplingParams`` as a keyword argument.

    ``raw_request`` is not read in this body. ``generate`` passes it by
    keyword, presumably so that ``with_cancellation`` can watch for client
    disconnects (TODO confirm against ``vllm.entrypoints.utils``).

    Returns:
        * HTTP 400 with a JSON ``{"error": ...}`` body when ``"prompt"`` is
          missing from the payload.
        * When ``stream`` is truthy, a ``StreamingResponse`` that emits one
          UTF-8 encoded JSON line ``{"text": [...]}`` per engine output.
        * HTTP 499 if the task is cancelled while non-streamed output is
          being collected.
        * Otherwise a ``JSONResponse`` ``{"text": [...]}`` built from the
          last engine output.

        Every ``"text"`` entry is the engine-reported prompt followed by the
        text of one completion.
    """
    try:
        prompt = request_dict.pop("prompt")
    except KeyError:
        # A missing field is a client error. Previously the KeyError escaped
        # the handler and was reported as a 500 Internal Server Error.
        return JSONResponse({"error": "Missing required field: 'prompt'"},
                            status_code=400)
    stream = request_dict.pop("stream", False)
    # NOTE(review): unknown or invalid sampling fields still raise out of
    # SamplingParams and surface as a 500; consider mapping them to 400 too.
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()

    assert engine is not None
    results_generator = engine.generate(prompt, sampling_params, request_id)

    def prompt_plus_completions(request_output) -> list[str]:
        """Prefix each completion's text with the prompt on the output."""
        output_prompt = request_output.prompt
        assert output_prompt is not None
        return [
            output_prompt + output.text for output in request_output.outputs
        ]

    # Streaming case: one newline-terminated JSON object per engine output.
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            ret = {"text": prompt_plus_completions(request_output)}
            yield (json.dumps(ret) + "\n").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case: drain the generator and keep only the last output.
    final_output = None
    try:
        async for request_output in results_generator:
            final_output = request_output
    except asyncio.CancelledError:
        # Cancelled while draining the engine output: reply with status 499.
        return Response(status_code=499)

    assert final_output is not None
    return JSONResponse({"text": prompt_plus_completions(final_output)})


def build_app(args: Namespace) -> FastAPI:
    """Apply CLI settings to the module-level FastAPI app and return it.

    Only ``args.root_path`` is applied. The same module-level ``app`` object
    is mutated and returned on every call.
    """
    # Setting an attribute does not rebind the module-level name, so no
    # ``global`` declaration is required here.
    configured_app = app
    configured_app.root_path = args.root_path
    return configured_app


async def init_app(
    args: Namespace,
    llm_engine: Optional[AsyncLLMEngine] = None,
) -> FastAPI:
    """Configure the FastAPI app and install the module-level engine.

    Args:
        args: Parsed CLI namespace. It is used both for the app settings
            (via ``build_app``) and for the engine arguments.
        llm_engine: A pre-built engine to serve. When ``None``, a new engine
            is created from ``args`` with the ``API_SERVER`` usage context.

    Returns:
        The configured FastAPI application.
    """
    global engine

    configured_app = build_app(args)

    # Engine args are derived from the CLI even when ``llm_engine`` is
    # supplied, matching the original evaluation order.
    cli_engine_args = AsyncEngineArgs.from_cli_args(args)
    if llm_engine is None:
        llm_engine = AsyncLLMEngine.from_engine_args(
            cli_engine_args, usage_context=UsageContext.API_SERVER)
    engine = llm_engine

    return configured_app


async def run_server(args: Namespace,
                     llm_engine: Optional[AsyncLLMEngine] = None,
                     **uvicorn_kwargs: Any) -> None:
    """Start the demo HTTP server and wait until it shuts down.

    Args:
        args: Parsed CLI namespace (host, port, log level, TLS options,
            root path, and engine options).
        llm_engine: Optional pre-built engine; when ``None``, ``init_app``
            builds one from ``args``.
        **uvicorn_kwargs: Extra keyword arguments forwarded to
            ``serve_http``. A key that duplicates one set here raises
            ``TypeError``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    # Process-level limit adjustment; see set_ulimit for the specifics.
    set_ulimit()

    server_app = await init_app(args, llm_engine)
    # init_app must have populated the module-level engine.
    assert engine is not None

    tls_options = {
        "ssl_keyfile": args.ssl_keyfile,
        "ssl_certfile": args.ssl_certfile,
        "ssl_ca_certs": args.ssl_ca_certs,
        "ssl_cert_reqs": args.ssl_cert_reqs,
    }
    shutdown = await serve_http(
        server_app,
        sock=None,
        enable_ssl_refresh=args.enable_ssl_refresh,
        host=args.host,
        port=args.port,
        log_level=args.log_level,
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        **tls_options,
        **uvicorn_kwargs,
    )

    # serve_http hands back an awaitable that completes on shutdown.
    await shutdown


if __name__ == "__main__":
148
    parser = FlexibleArgumentParser()
149
    parser.add_argument("--host", type=str, default=None)
150
    parser.add_argument("--port", type=parser.check_port, default=8000)
151
152
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
Dan Clark's avatar
Dan Clark committed
153
154
155
156
    parser.add_argument("--ssl-ca-certs",
                        type=str,
                        default=None,
                        help="The CA certificates file")
157
158
159
160
161
    parser.add_argument(
        "--enable-ssl-refresh",
        action="store_true",
        default=False,
        help="Refresh SSL Context when SSL certificate files change")
162
163
164
165
166
167
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see stdlib ssl module's)"
    )
168
169
170
171
172
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")
173
    parser.add_argument("--log-level", type=str, default="debug")
Zhuohan Li's avatar
Zhuohan Li committed
174
    parser = AsyncEngineArgs.add_cli_args(parser)
Zhuohan Li's avatar
Zhuohan Li committed
175
    args = parser.parse_args()
176

177
    asyncio.run(run_server(args))