# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
4
5
6
7
8
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
9
"""
10

11
import asyncio
import json
import ssl
from argparse import Namespace
from collections.abc import AsyncGenerator
from http import HTTPStatus
from typing import Any

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.utils import with_cancellation
from vllm.logger import init_logger
from vllm.sampling_params import SamplingParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils import random_uuid
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.system_utils import set_ulimit
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger("vllm.entrypoints.api_server")

# Module-level singletons shared by the route handlers in this file.
# ``app`` is configured in place by ``build_app``; ``engine`` stays ``None``
# until ``init_app`` assigns it (``_generate`` checks it before use).
app = FastAPI()
engine: AsyncLLMEngine | None = None


@app.get("/health")
async def health() -> Response:
    """Report liveness with an empty-bodied ``200 OK``.

    The handler never consults the engine, so a successful reply only shows
    that the HTTP server is accepting requests.
    """
    response = Response(status_code=200)
    return response


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).

    A body that is not valid JSON is rejected with 400 Bad Request instead of
    surfacing as an unhandled server error.
    """
    try:
        request_dict = await request.json()
    except ValueError:
        # ``request.json()`` presumably decodes with ``json.loads``; malformed
        # input then raises ``json.JSONDecodeError`` (or ``UnicodeDecodeError``
        # for undecodable bytes), both of which subclass ``ValueError``.
        return JSONResponse(
            {"error": "request body must be valid JSON"},
            status_code=HTTPStatus.BAD_REQUEST.value,
        )
    return await _generate(request_dict, raw_request=request)


def _bad_request(message: str) -> JSONResponse:
    """Build a 400 Bad Request response whose body is ``{"error": message}``."""
    return JSONResponse({"error": message}, status_code=HTTPStatus.BAD_REQUEST.value)


def _texts_with_prompt(request_output: Any) -> list[str]:
    """Return ``prompt + generated text`` for each sequence in *request_output*.

    Shared by the streaming and non-streaming paths so both emit the same
    ``{"text": [...]}`` payload. *request_output* is presumably a vLLM
    ``RequestOutput`` (it needs ``.prompt`` and ``.outputs[i].text``).
    """
    prompt = request_output.prompt
    # The payload echoes the prompt, so it must be available here.
    assert prompt is not None
    return [prompt + output.text for output in request_output.outputs]


@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
    """Run one generation request and build its HTTP response.

    Args:
        request_dict: Decoded JSON body. ``prompt`` (required) and ``stream``
            (optional, default ``False``) are popped off; every remaining key
            is forwarded to ``SamplingParams``.
        raw_request: Not used in the body; passed for the ``with_cancellation``
            decorator (presumably to detect client disconnects).

    Returns:
        A ``StreamingResponse`` of newline-terminated JSON objects when
        streaming, otherwise a ``JSONResponse`` holding the final output.
        Malformed requests get a 400 ``{"error": ...}`` response. A
        cancelled non-streaming request gets status 499.

    Raises:
        RuntimeError: if the module-level engine has not been initialized.
    """
    # Validate the request shape up front so client mistakes become a 400
    # instead of a KeyError/TypeError surfacing as a 500.
    if not isinstance(request_dict, dict):
        return _bad_request("request body must be a JSON object")
    if "prompt" not in request_dict:
        return _bad_request("missing required field 'prompt'")

    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    try:
        # Since SamplingParams is created fresh per request, safe to skip clone
        sampling_params = SamplingParams(**request_dict, skip_clone=True)
    except (TypeError, ValueError) as exc:
        # Unknown or duplicate keyword fields raise TypeError; invalid values
        # presumably raise ValueError from SamplingParams' own validation.
        return _bad_request(f"invalid sampling parameters: {exc}")
    request_id = random_uuid()

    # Explicit check rather than ``assert`` so it still runs under ``python -O``.
    if engine is None:
        raise RuntimeError("engine is not initialized; call init_app() first")
    results_generator = engine.generate(prompt, sampling_params, request_id)

    if stream:
        # Streaming case: one newline-terminated JSON object per engine update.
        async def stream_results() -> AsyncGenerator[bytes, None]:
            async for request_output in results_generator:
                payload = {"text": _texts_with_prompt(request_output)}
                yield (json.dumps(payload) + "\n").encode("utf-8")

        return StreamingResponse(stream_results())

    # Non-streaming case: drain the generator and keep only the last output.
    final_output = None
    try:
        async for request_output in results_generator:
            final_output = request_output
    except asyncio.CancelledError:
        # Handler was cancelled (presumably a client disconnect); 499 is the
        # non-standard nginx-style "client closed request" status.
        return Response(status_code=499)

    assert final_output is not None
    return JSONResponse({"text": _texts_with_prompt(final_output)})


def build_app(args: Namespace) -> FastAPI:
    """Apply CLI settings to the module-level FastAPI app and return it.

    Only ``root_path`` is applied. Per the ``--root-path`` flag, this is the
    FastAPI root path to use when the app sits behind a path-based routing proxy.
    """
    # ``app`` is mutated in place, never rebound, so no ``global`` is needed.
    configured = app
    configured.root_path = args.root_path
    return configured


async def init_app(
    args: Namespace,
    llm_engine: AsyncLLMEngine | None = None,
) -> FastAPI:
    """Prepare the FastAPI app and install the module-level ``engine``.

    Args:
        args: Parsed CLI namespace. It supplies ``root_path`` (via
            ``build_app``) and the engine configuration.
        llm_engine: Pre-built engine to serve with. When ``None``, a new
            ``AsyncLLMEngine`` is created from ``args``.

    Returns:
        The configured app, with the engine on ``app.state.engine_client``
        and the CLI args on ``app.state.args``.
    """
    global engine

    configured_app = build_app(args)

    # NOTE(review): parsed even when ``llm_engine`` is supplied; kept in case
    # ``from_cli_args`` validates ``args`` -- confirm before moving it.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    if llm_engine is None:
        llm_engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.API_SERVER
        )
    engine = llm_engine

    configured_app.state.engine_client = engine
    configured_app.state.args = args
    return configured_app


async def run_server(
    args: Namespace, llm_engine: AsyncLLMEngine | None = None, **uvicorn_kwargs: Any
) -> None:
    """Build the app, start the HTTP server, and wait until it shuts down.

    Args:
        args: Parsed CLI namespace (server, TLS and engine options).
        llm_engine: Optional pre-built engine, forwarded to ``init_app``.
        **uvicorn_kwargs: Extra keyword arguments forwarded verbatim to
            ``serve_http`` after the options derived from ``args``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    set_ulimit()

    server_app = await init_app(args, llm_engine)
    # init_app() is expected to have populated the module-level engine.
    assert engine is not None

    tls_options = {
        "ssl_keyfile": args.ssl_keyfile,
        "ssl_certfile": args.ssl_certfile,
        "ssl_ca_certs": args.ssl_ca_certs,
        "ssl_cert_reqs": args.ssl_cert_reqs,
    }
    shutdown_task = await serve_http(
        server_app,
        sock=None,
        enable_ssl_refresh=args.enable_ssl_refresh,
        host=args.host,
        port=args.port,
        log_level=args.log_level,
        timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
        **tls_options,
        **uvicorn_kwargs,
    )

    await shutdown_task


if __name__ == "__main__":
    # CLI entry point: the server/TLS flags are defined here, the engine flags
    # are added by AsyncEngineArgs.add_cli_args, then the server runs until
    # it shuts down.
    parser = FlexibleArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=parser.check_port, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument(
        "--ssl-ca-certs", type=str, default=None, help="The CA certificates file"
    )
    parser.add_argument(
        "--enable-ssl-refresh",
        action="store_true",
        default=False,
        help="Refresh SSL Context when SSL certificate files change",
    )
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        # Previously truncated ("see stdlib ssl module's)"); spell out the
        # integer values of the ssl.CERT_* verify modes it accepts.
        help=(
            "Whether client certificate is required (see the stdlib ssl "
            "module's CERT_* constants: 0=CERT_NONE, 1=CERT_OPTIONAL, "
            "2=CERT_REQUIRED)"
        ),
    )
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy",
    )
    parser.add_argument("--log-level", type=str, default="debug")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    asyncio.run(run_server(args))