# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
NOTE: This API server is used only for demonstrating usage of AsyncLLMEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI-compatible server.
We are also not going to accept PRs modifying this file; please
change `vllm/entrypoints/openai/api_server.py` instead.
"""

import asyncio
import json
import ssl
from argparse import Namespace
from collections.abc import AsyncGenerator
from typing import Any
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid, set_ulimit
from vllm.version import __version__ as VLLM_VERSION


logger = init_logger("vllm.entrypoints.api_server")

# FastAPI application; the route handlers below register themselves on it
# through the ``@app.get`` / ``@app.post`` decorators.
app = FastAPI()

# Engine shared by the request handlers. It stays ``None`` until
# ``init_app()`` assigns it.
engine: AsyncLLMEngine | None = None


@app.get("/health")
async def health() -> Response:
    """Report liveness by answering with an HTTP 200 response that has no body."""
    status_ok = 200
    return Response(status_code=status_ok)


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).

    A body that is not valid JSON, or that decodes to something other than a
    JSON object, is rejected with HTTP 400 instead of surfacing as an
    unhandled server error (HTTP 500).
    """
    try:
        request_dict = await request.json()
    except ValueError:
        # Starlette decodes the body with `json.loads`; both
        # `json.JSONDecodeError` and `UnicodeDecodeError` (undecodable bytes)
        # are subclasses of ValueError.
        return JSONResponse(
            {"error": "Request body must be valid JSON."}, status_code=400
        )
    if not isinstance(request_dict, dict):
        # `_generate` pops fields off the payload, so it must be a mapping.
        return JSONResponse(
            {"error": "Request body must be a JSON object."}, status_code=400
        )
    return await _generate(request_dict, raw_request=request)


def _text_outputs(request_output) -> list[str]:
    """Return ``prompt + completion text`` for every completion in *request_output*.

    Used by both the streaming and non-streaming paths of ``_generate`` so
    that both emit the same ``{"text": [...]}`` payload.
    """
    prompt = request_output.prompt
    assert prompt is not None
    return [prompt + output.text for output in request_output.outputs]


@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
    """Run one generation request against the module-level ``engine``.

    Args:
        request_dict: decoded request body. ``prompt`` and ``stream`` are
            popped from it (the dict is mutated); every remaining key is
            passed to ``SamplingParams``.
        raw_request: the incoming HTTP request. It is not read here;
            presumably ``with_cancellation`` uses it to watch for client
            disconnects (its definition is not visible in this file).

    Returns:
        - HTTP 400 if ``prompt`` is missing or the sampling fields are rejected.
        - A ``StreamingResponse`` of newline-delimited JSON documents when
          ``stream`` is truthy.
        - Otherwise a ``JSONResponse`` carrying the final ``{"text": [...]}``
          payload, or HTTP 499 if the wait for that result is cancelled.
    """
    if "prompt" not in request_dict:
        return JSONResponse(
            {"error": "Missing required field 'prompt'."}, status_code=400
        )
    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    try:
        sampling_params = SamplingParams(**request_dict)
    except (TypeError, ValueError) as err:
        # Unknown or invalid sampling fields are a client error, not a server
        # fault. NOTE(review): assumes SamplingParams reports unexpected
        # keywords as TypeError and bad values as ValueError — confirm.
        return JSONResponse({"error": str(err)}, status_code=400)
    request_id = random_uuid()

    assert engine is not None
    results_generator = engine.generate(prompt, sampling_params, request_id)

    # Streaming case: one JSON document per line, one line per engine output.
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            ret = {"text": _text_outputs(request_output)}
            yield (json.dumps(ret) + "\n").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case: drain the generator and keep only the last output.
    final_output = None
    try:
        async for request_output in results_generator:
            final_output = request_output
    except asyncio.CancelledError:
        # 499 is the non-standard "client closed request" status.
        return Response(status_code=499)

    assert final_output is not None
    return JSONResponse({"text": _text_outputs(final_output)})


def build_app(args: Namespace) -> FastAPI:
    """Configure and return the module-level FastAPI ``app``.

    Only ``root_path`` is taken from *args* (the ``--root-path`` option,
    used when the app sits behind a path-based routing proxy).
    """
    # The module-level `app` is only mutated (attribute assignment), never
    # rebound, so no `global` declaration is needed to reach it.
    app.root_path = args.root_path
    return app


async def init_app(
    args: Namespace,
    llm_engine: AsyncLLMEngine | None = None,
) -> FastAPI:
    """Configure the FastAPI app and install the engine it serves.

    Args:
        args: parsed command-line namespace. ``build_app`` reads
            ``root_path`` from it. When no engine is supplied, the engine
            is built from it via ``AsyncEngineArgs.from_cli_args``.
        llm_engine: an already-constructed engine to reuse. When ``None``, a
            new ``AsyncLLMEngine`` is created with
            ``UsageContext.API_SERVER``.

    Returns:
        The configured app. The chosen engine is stored in the module-level
        ``engine`` global and in ``app.state.engine_client``.
    """
    global engine

    app = build_app(args)

    if llm_engine is None:
        # Only translate CLI args into engine args when an engine must
        # actually be built; a caller-supplied engine makes them irrelevant.
        engine_args = AsyncEngineArgs.from_cli_args(args)
        llm_engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.API_SERVER
        )
    engine = llm_engine
    app.state.engine_client = engine
    return app


async def run_server(
    args: Namespace, llm_engine: AsyncLLMEngine | None = None, **uvicorn_kwargs: Any
) -> None:
    """Start the demo HTTP server and wait until it has shut down.

    Steps:
    - Log the vLLM version and the parsed arguments.
    - Call ``set_ulimit()`` (presumably to raise the process's open-file
      limit; the helper is not visible here).
    - Build the app and engine with ``init_app``.
    - Hand the app to ``serve_http`` with the host, port, logging and SSL
      options taken from *args*.

    Any extra keyword arguments are forwarded unchanged to ``serve_http``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    set_ulimit()

    server_app = await init_app(args, llm_engine)
    # init_app always installs the global engine; guard that invariant.
    assert engine is not None

    # TLS-related settings, passed through to serve_http alongside the rest.
    ssl_options = dict(
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
        ssl_ca_certs=args.ssl_ca_certs,
        ssl_cert_reqs=args.ssl_cert_reqs,
    )

    serve_done = await serve_http(
        server_app,
        sock=None,
        enable_ssl_refresh=args.enable_ssl_refresh,
        host=args.host,
        port=args.port,
        log_level=args.log_level,
        timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
        **ssl_options,
        **uvicorn_kwargs,
    )

    # serve_http returns an awaitable; waiting on it keeps us alive until shutdown.
    await serve_done


if __name__ == "__main__":
    # Options for this demo server itself; AsyncEngineArgs.add_cli_args then
    # extends the parser (its return value replaces `parser`).
    parser = FlexibleArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=parser.check_port, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument(
        "--ssl-ca-certs", type=str, default=None, help="The CA certificates file"
    )
    parser.add_argument(
        "--enable-ssl-refresh",
        action="store_true",
        default=False,
        help="Refresh SSL Context when SSL certificate files change",
    )
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        # The previous help text was truncated ("see stdlib ssl module's)");
        # spell out the accepted integer values instead.
        help=(
            "Whether client certificate is required: 0 = ssl.CERT_NONE, "
            "1 = ssl.CERT_OPTIONAL, 2 = ssl.CERT_REQUIRED "
            "(see the stdlib ssl module)"
        ),
    )
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy",
    )
    parser.add_argument("--log-level", type=str, default="debug")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    asyncio.run(run_server(args))