api_server.py 5.65 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""
4
5
6
7
8
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
9
"""
10

11
import asyncio
Zhuohan Li's avatar
Zhuohan Li committed
12
import json
13
import ssl
14
from argparse import Namespace
15
from collections.abc import AsyncGenerator
16
from typing import Any
Zhuohan Li's avatar
Zhuohan Li committed
17

18
from fastapi import FastAPI, Request
19
from fastapi.responses import JSONResponse, Response, StreamingResponse
Zhuohan Li's avatar
Zhuohan Li committed
20

21
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
24
from vllm.entrypoints.launcher import serve_http
25
from vllm.entrypoints.utils import with_cancellation
26
from vllm.logger import init_logger
Woosuk Kwon's avatar
Woosuk Kwon committed
27
from vllm.sampling_params import SamplingParams
yhu422's avatar
yhu422 committed
28
from vllm.usage.usage_lib import UsageContext
29
from vllm.utils import random_uuid
30
from vllm.utils.argparse_utils import FlexibleArgumentParser
31
from vllm.utils.system_utils import set_ulimit
32
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
33

34
35
# Module logger, registered under this module's dotted import path.
logger = init_logger("vllm.entrypoints.api_server")

# Single FastAPI application that the route decorators in this file attach to.
app = FastAPI()
# Engine used by the request handlers; stays None until init_app() assigns it.
engine = None
Zhuohan Li's avatar
Zhuohan Li committed
38
39


40
41
42
43
44
45
@app.get("/health")
async def health() -> Response:
    """Liveness probe: always answers with an empty HTTP 200 response."""
    ok_response = Response(status_code=200)
    return ok_response


Zhuohan Li's avatar
Zhuohan Li committed
46
@app.post("/generate")
async def generate(request: Request) -> Response:
    """Handle a text-generation request.

    The request body is expected to be a JSON object containing:
    - prompt: the text to generate a completion for.
    - stream: optional flag; when truthy the results are streamed.
    - any remaining fields: forwarded as `SamplingParams` arguments.
    """
    payload = await request.json()
    # The request is handed over by keyword as `raw_request`, matching the
    # parameter name declared by `_generate`.
    return await _generate(payload, raw_request=request)


def _completions_payload(request_output: Any) -> dict[str, list[str]]:
    """Build the JSON-serializable response body for one engine output.

    Every completion in ``request_output.outputs`` is prefixed with the
    original prompt; the resulting strings are returned under ``"text"``.
    """
    prompt = request_output.prompt
    assert prompt is not None
    return {"text": [prompt + output.text for output in request_output.outputs]}


@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
    """Run one generation request against the module-level engine.

    Args:
        request_dict: Decoded JSON body. ``prompt`` (required) and ``stream``
            (optional, defaults to False) are popped from it; every remaining
            key is passed to ``SamplingParams`` as a keyword argument.
        raw_request: The incoming request. It is not read in this body;
            presumably the ``with_cancellation`` decorator uses it to watch
            for client disconnects (verify against its implementation).

    Returns:
        - a 400 ``JSONResponse`` when the body is not a JSON object or has no
          ``prompt`` field;
        - a ``StreamingResponse`` of newline-delimited JSON documents when
          ``stream`` is truthy;
        - otherwise a ``JSONResponse`` built from the last engine output, or
          an empty 499 response if the task is cancelled while waiting.
    """
    # Reject malformed bodies explicitly instead of letting `.pop("prompt")`
    # raise and surface as an opaque 500.
    if not isinstance(request_dict, dict) or "prompt" not in request_dict:
        return JSONResponse(
            {"error": "Request body must be a JSON object with a 'prompt' field."},
            status_code=400,
        )

    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    # Whatever is left in the body is treated as sampling parameters.
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()

    assert engine is not None
    results_generator = engine.generate(prompt, sampling_params, request_id)

    # Streaming case: one JSON line per engine update.
    if stream:

        async def stream_results() -> AsyncGenerator[bytes, None]:
            async for request_output in results_generator:
                line = json.dumps(_completions_payload(request_output)) + "\n"
                yield line.encode("utf-8")

        return StreamingResponse(stream_results())

    # Non-streaming case: drain the generator and keep only the last output.
    final_output = None
    try:
        async for request_output in results_generator:
            final_output = request_output
    except asyncio.CancelledError:
        # 499 is the non-standard "client closed request" status code.
        return Response(status_code=499)

    assert final_output is not None
    return JSONResponse(_completions_payload(final_output))
Zhuohan Li's avatar
Zhuohan Li committed
95
96


97
98
99
100
101
102
103
104
105
def build_app(args: Namespace) -> FastAPI:
    """Configure the module-level FastAPI app from parsed CLI arguments.

    Copies ``args.root_path`` onto the shared ``app`` and returns that same
    instance.
    """
    # `app` is only mutated here, never rebound, so the module-level name is
    # reachable without a `global` declaration.
    app.root_path = args.root_path
    return app


async def init_app(
    args: Namespace,
    llm_engine: AsyncLLMEngine | None = None,
) -> FastAPI:
    """Prepare the FastAPI app and bind an engine to it.

    Args:
        args: Parsed CLI arguments; used to configure the app and to build
            the engine arguments.
        llm_engine: Engine to use as-is. When None, a new ``AsyncLLMEngine``
            is created from ``args``.

    Returns:
        The configured app, with the engine stored both in the module-level
        ``engine`` variable and in ``app.state.engine_client``.
    """
    global engine

    server_app = build_app(args)

    # Engine arguments are parsed unconditionally, even when a ready-made
    # engine is supplied by the caller.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    if llm_engine is None:
        engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.API_SERVER
        )
    else:
        engine = llm_engine

    server_app.state.engine_client = engine
    return server_app


124
async def run_server(
    args: Namespace, llm_engine: AsyncLLMEngine | None = None, **uvicorn_kwargs: Any
) -> None:
    """Initialize the app and serve it over HTTP until shutdown completes.

    Args:
        args: Parsed CLI arguments (host, port, log level, SSL settings, ...).
        llm_engine: Optional pre-built engine forwarded to ``init_app``.
        **uvicorn_kwargs: Extra keyword arguments forwarded to ``serve_http``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    set_ulimit()

    server_app = await init_app(args, llm_engine)
    # init_app assigns the module-level engine; make sure it really did.
    assert engine is not None

    # TLS-related settings, grouped so the call below stays readable.
    tls_options = {
        "ssl_keyfile": args.ssl_keyfile,
        "ssl_certfile": args.ssl_certfile,
        "ssl_ca_certs": args.ssl_ca_certs,
        "ssl_cert_reqs": args.ssl_cert_reqs,
    }

    # serve_http hands back an awaitable that is awaited below before this
    # coroutine returns.
    pending_shutdown = await serve_http(
        server_app,
        sock=None,
        enable_ssl_refresh=args.enable_ssl_refresh,
        host=args.host,
        port=args.port,
        log_level=args.log_level,
        timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
        **tls_options,
        **uvicorn_kwargs,
    )

    await pending_shutdown


Zhuohan Li's avatar
Zhuohan Li committed
153
if __name__ == "__main__":
    parser = FlexibleArgumentParser()

    # Server-level flags as (flag, add_argument keyword arguments), listed in
    # the order they are registered on the parser.
    server_flags: list[tuple[str, dict[str, Any]]] = [
        ("--host", {"type": str, "default": None}),
        ("--port", {"type": parser.check_port, "default": 8000}),
        ("--ssl-keyfile", {"type": str, "default": None}),
        ("--ssl-certfile", {"type": str, "default": None}),
        (
            "--ssl-ca-certs",
            {"type": str, "default": None, "help": "The CA certificates file"},
        ),
        (
            "--enable-ssl-refresh",
            {
                "action": "store_true",
                "default": False,
                "help": "Refresh SSL Context when SSL certificate files change",
            },
        ),
        (
            "--ssl-cert-reqs",
            {
                "type": int,
                "default": int(ssl.CERT_NONE),
                "help": "Whether client certificate is required (see stdlib ssl module's)",
            },
        ),
        (
            "--root-path",
            {
                "type": str,
                "default": None,
                "help": "FastAPI root_path when app is behind a path based routing proxy",
            },
        ),
        ("--log-level", {"type": str, "default": "debug"}),
    ]
    for flag_name, flag_options in server_flags:
        parser.add_argument(flag_name, **flag_options)

    # Engine flags are added after the server flags above.
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    asyncio.run(run_server(args))