api_server.py 5.67 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file; please
change `vllm/entrypoints/openai/api_server.py` instead.
"""
10

11
import asyncio
Zhuohan Li's avatar
Zhuohan Li committed
12
import json
13
import ssl
14
from argparse import Namespace
15
from collections.abc import AsyncGenerator
16
from typing import Any
Zhuohan Li's avatar
Zhuohan Li committed
17

18
from fastapi import FastAPI, Request
19
from fastapi.responses import JSONResponse, Response, StreamingResponse
Zhuohan Li's avatar
Zhuohan Li committed
20

21
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
24
from vllm.entrypoints.launcher import serve_http
25
from vllm.entrypoints.utils import with_cancellation
26
from vllm.logger import init_logger
Woosuk Kwon's avatar
Woosuk Kwon committed
27
from vllm.sampling_params import SamplingParams
yhu422's avatar
yhu422 committed
28
from vllm.usage.usage_lib import UsageContext
29
from vllm.utils import random_uuid
30
from vllm.utils.argparse_utils import FlexibleArgumentParser
31
from vllm.utils.system_utils import set_ulimit
32
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
33

34
35
# Module-level logger, registered under this entrypoint's dotted name.
logger = init_logger("vllm.entrypoints.api_server")

# The FastAPI application; the route handlers below register themselves on it.
app = FastAPI()
# Engine shared by the request handlers. It stays None until init_app()
# assigns it.
engine: AsyncLLMEngine | None = None
Zhuohan Li's avatar
Zhuohan Li committed
38
39


40
41
42
43
44
45
@app.get("/health")
async def health() -> Response:
    """Answer the health probe with an empty HTTP 200 response."""
    ok_response = Response(status_code=200)
    return ok_response


Zhuohan Li's avatar
Zhuohan Li committed
46
@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).

    The parsed body is handed to `_generate` together with the raw request,
    and whatever response that coroutine produces is returned unchanged.
    """
    request_dict = await request.json()
    # `raw_request` is passed by keyword, matching `_generate`'s signature.
    return await _generate(request_dict, raw_request=request)


@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
    """Run one generation request against the module-level engine.

    Args:
        request_dict: Parsed JSON body of the request. "prompt" is required
            and "stream" is optional (defaults to False). Both keys are popped
            from the dict in place; every remaining entry is forwarded as a
            keyword argument to `SamplingParams`.
        raw_request: The incoming HTTP request. It is not read in this body;
            presumably it is consumed by the `with_cancellation` decorator --
            TODO confirm against that helper.

    Returns:
        When "stream" is truthy, a `StreamingResponse` that emits one
        newline-terminated, UTF-8 encoded JSON object per engine output, each
        of the form {"text": [prompt + completion, ...]}.
        Otherwise a `JSONResponse` of the same form built from the last output
        the engine produced, or an empty response with status 499 if
        `asyncio.CancelledError` is raised while the outputs are collected.

    Raises:
        KeyError: If `request_dict` has no "prompt" entry.
    """
    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()

    assert engine is not None
    results_generator = engine.generate(prompt, sampling_params, request_id)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            # Use the prompt reported on the output itself; a distinct name
            # avoids shadowing the enclosing function's `prompt`.
            echoed_prompt = request_output.prompt
            assert echoed_prompt is not None
            text_outputs = [
                echoed_prompt + output.text for output in request_output.outputs
            ]
            ret = {"text": text_outputs}
            yield (json.dumps(ret) + "\n").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case: drain the generator and keep only the last output.
    final_output = None
    try:
        async for request_output in results_generator:
            final_output = request_output
    except asyncio.CancelledError:
        return Response(status_code=499)

    assert final_output is not None
    prompt = final_output.prompt
    assert prompt is not None
    text_outputs = [prompt + output.text for output in final_output.outputs]
    ret = {"text": text_outputs}
    return JSONResponse(ret)
Zhuohan Li's avatar
Zhuohan Li committed
95
96


97
98
99
100
101
102
103
104
105
def build_app(args: Namespace) -> FastAPI:
    """Set the module-level app's `root_path` from `args` and return the app."""
    # `app` is only mutated here, never rebound, so it resolves to the
    # module-level object without a `global` declaration.
    configured_root = args.root_path
    app.root_path = configured_root
    return app


async def init_app(
    args: Namespace,
    llm_engine: AsyncLLMEngine | None = None,
) -> FastAPI:
    """Configure the app and install the engine used by the request handlers.

    Args:
        args: Parsed command-line arguments. They are used to configure the
            app (via `build_app`) and to derive the engine arguments.
        llm_engine: An already constructed engine to use. When None, a new
            `AsyncLLMEngine` is created from `args`.

    Returns:
        The configured FastAPI app, with the engine stored on
        `app.state.engine_client` and `args` stored on `app.state.args`.
    """
    global engine

    app = build_app(args)

    # Engine arguments are derived from `args` even when `llm_engine` is
    # supplied; they are only consumed when a new engine has to be built.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = (
        llm_engine
        if llm_engine is not None
        else AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.API_SERVER
        )
    )
    app.state.engine_client = engine
    app.state.args = args
    return app


125
async def run_server(
    args: Namespace, llm_engine: AsyncLLMEngine | None = None, **uvicorn_kwargs: Any
) -> None:
    """Initialise the app and serve it over HTTP until shutdown completes.

    Args:
        args: Parsed command-line arguments; supplies host, port, log level
            and the SSL settings passed to `serve_http`.
        llm_engine: Optional pre-built engine, forwarded to `init_app`.
        **uvicorn_kwargs: Extra keyword arguments forwarded verbatim to
            `serve_http`.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    set_ulimit()

    app = await init_app(args, llm_engine)
    # init_app() assigns the module-level engine; check it before serving.
    assert engine is not None

    shutdown_task = await serve_http(
        app,
        sock=None,
        enable_ssl_refresh=args.enable_ssl_refresh,
        host=args.host,
        port=args.port,
        log_level=args.log_level,
        timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
        ssl_ca_certs=args.ssl_ca_certs,
        ssl_cert_reqs=args.ssl_cert_reqs,
        **uvicorn_kwargs,
    )

    # serve_http() hands back an awaitable; wait on it before returning.
    await shutdown_task


Zhuohan Li's avatar
Zhuohan Li committed
154
if __name__ == "__main__":
    # Script entry point: declare the server-level options, let
    # AsyncEngineArgs add its own options to the same parser, then run the
    # server.
    parser = FlexibleArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=parser.check_port, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument(
        "--ssl-ca-certs", type=str, default=None, help="The CA certificates file"
    )
    parser.add_argument(
        "--enable-ssl-refresh",
        action="store_true",
        default=False,
        help="Refresh SSL Context when SSL certificate files change",
    )
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see stdlib ssl module's)",
    )
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy",
    )
    parser.add_argument("--log-level", type=str, default="debug")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    asyncio.run(run_server(args))