"docs/vscode:/vscode.git/clone" did not exist on "1dfea5f4a95df8d14b46433a479a28d56e60494c"
api_server.py 5.61 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""
4
5
6
7
8
NOTE: This API server is used only for demonstrating usage of AsyncEngine
and simple performance benchmarks. It is not intended for production use.
For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file, please
change `vllm/entrypoints/openai/api_server.py` instead.
9
"""
10

11
import asyncio
Zhuohan Li's avatar
Zhuohan Li committed
12
import json
13
import ssl
14
from argparse import Namespace
15
from collections.abc import AsyncGenerator
16
from typing import Any
Zhuohan Li's avatar
Zhuohan Li committed
17

18
from fastapi import FastAPI, Request
19
from fastapi.responses import JSONResponse, Response, StreamingResponse
Zhuohan Li's avatar
Zhuohan Li committed
20

21
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
24
from vllm.entrypoints.launcher import serve_http
25
from vllm.entrypoints.utils import with_cancellation
26
from vllm.logger import init_logger
Woosuk Kwon's avatar
Woosuk Kwon committed
27
from vllm.sampling_params import SamplingParams
yhu422's avatar
yhu422 committed
28
from vllm.usage.usage_lib import UsageContext
29
30
from vllm.utils import random_uuid, set_ulimit
from vllm.utils.argparse_utils import FlexibleArgumentParser
31
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
32

33
34
# Module-level logger, registered under this entrypoint's dotted module path.
logger = init_logger("vllm.entrypoints.api_server")

# The single FastAPI application; the route handlers below attach to it.
app = FastAPI()

# Engine shared by the request handlers. It stays None until init_app()
# assigns it, which is why the handlers assert on it before use.
engine: AsyncLLMEngine | None = None
Zhuohan Li's avatar
Zhuohan Li committed
37
38


39
40
41
42
43
44
@app.get("/health")
async def health() -> Response:
    """Liveness probe: always answers with an empty HTTP 200 response."""
    ok_status = 200
    return Response(status_code=ok_status)


Zhuohan Li's avatar
Zhuohan Li committed
45
@app.post("/generate")
async def generate(request: Request) -> Response:
    """Handle ``POST /generate`` by producing a completion for the request.

    The request body must be a JSON object with these fields:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).

    The decoded body is handed to ``_generate`` together with the raw
    request object, and whatever ``_generate`` returns is sent back.
    """
    payload = await request.json()
    # The raw request is forwarded by keyword, matching _generate's signature.
    response = await _generate(payload, raw_request=request)
    return response


@with_cancellation
async def _generate(request_dict: dict, raw_request: Request) -> Response:
    """Run one generation request against the module-level engine.

    Args:
        request_dict: Decoded JSON body. ``prompt`` (required) and ``stream``
            (optional, defaults to ``False``) are popped from it; every
            remaining key is passed to ``SamplingParams`` as a keyword
            argument. The dict is mutated in place.
        raw_request: The incoming request. It is not read in this body; it is
            part of the signature wrapped by ``with_cancellation``
            (NOTE(review): presumably used there to detect client
            disconnects -- confirm in ``vllm.entrypoints.utils``).

    Returns:
        - HTTP 400 with a JSON ``{"error": ...}`` body when ``prompt`` is
          missing.
        - When ``stream`` is truthy, a ``StreamingResponse`` emitting one
          newline-terminated, UTF-8 encoded JSON object per engine output.
        - Otherwise a ``JSONResponse`` built from the last engine output, or
          an empty response with status 499 if the request is cancelled
          while waiting for the engine.

        Successful payloads have the shape ``{"text": [...]}``, where each
        entry is the prompt concatenated with one output's text.
    """
    if "prompt" not in request_dict:
        # A missing prompt is a client mistake; report it as a 400 instead of
        # letting the KeyError from pop() surface as an internal server error.
        return JSONResponse(
            {"error": "Missing required field: 'prompt'"}, status_code=400
        )

    prompt = request_dict.pop("prompt")
    stream = request_dict.pop("stream", False)
    # Whatever is left in the body is interpreted as sampling parameters.
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()

    assert engine is not None
    results_generator = engine.generate(prompt, sampling_params, request_id)

    def render(request_output: Any) -> dict[str, list[str]]:
        """Build the response payload: the prompt prepended to each output text."""
        echoed_prompt = request_output.prompt
        assert echoed_prompt is not None
        return {
            "text": [echoed_prompt + output.text for output in request_output.outputs]
        }

    # Streaming case: emit one JSON line per engine output as it arrives.
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            yield (json.dumps(render(request_output)) + "\n").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case: drain the generator and keep only the last output.
    final_output = None
    try:
        async for request_output in results_generator:
            final_output = request_output
    except asyncio.CancelledError:
        # 499 is the non-standard "client closed request" status code.
        return Response(status_code=499)

    assert final_output is not None
    return JSONResponse(render(final_output))
Zhuohan Li's avatar
Zhuohan Li committed
94
95


96
97
98
99
100
101
102
103
104
def build_app(args: Namespace) -> FastAPI:
    """Apply CLI settings to the module-level FastAPI app and return it.

    Only ``args.root_path`` is read; it is stored as the app's ``root_path``.
    """
    # The module-level ``app`` is mutated in place and never rebound, so no
    # ``global`` declaration is needed to reach it.
    fastapi_app = app
    fastapi_app.root_path = args.root_path
    return fastapi_app


async def init_app(
    args: Namespace,
    llm_engine: AsyncLLMEngine | None = None,
) -> FastAPI:
    """Configure the FastAPI app and install the engine the handlers use.

    Args:
        args: Parsed command-line arguments. Used to configure the app via
            ``build_app`` and to build ``AsyncEngineArgs``.
        llm_engine: Optional pre-built engine. When given it is used as-is;
            otherwise a new ``AsyncLLMEngine`` is created from the engine
            arguments with ``UsageContext.API_SERVER``.

    Returns:
        The configured app, with the engine also exposed as
        ``app.state.engine_client``.
    """
    global engine

    fastapi_app = build_app(args)

    # Engine args are derived from the CLI namespace unconditionally, even
    # when a ready-made engine was supplied.
    engine_args = AsyncEngineArgs.from_cli_args(args)

    if llm_engine is None:
        engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.API_SERVER
        )
    else:
        engine = llm_engine

    fastapi_app.state.engine_client = engine
    return fastapi_app


123
async def run_server(
    args: Namespace, llm_engine: AsyncLLMEngine | None = None, **uvicorn_kwargs: Any
) -> None:
    """Initialise the app and serve it over HTTP until shutdown completes.

    Args:
        args: Parsed command-line arguments (host, port, log level, SSL
            options and engine options are read from it).
        llm_engine: Optional pre-built engine, forwarded to ``init_app``.
        **uvicorn_kwargs: Extra keyword arguments passed through to
            ``serve_http`` unchanged.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    set_ulimit()

    fastapi_app = await init_app(args, llm_engine)
    # init_app() is expected to have populated the module-level engine.
    assert engine is not None

    # Keyword arguments for serve_http, listed in the order they are read
    # from the CLI namespace.
    serve_options: dict[str, Any] = {
        "sock": None,
        "enable_ssl_refresh": args.enable_ssl_refresh,
        "host": args.host,
        "port": args.port,
        "log_level": args.log_level,
        "timeout_keep_alive": envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
        "ssl_keyfile": args.ssl_keyfile,
        "ssl_certfile": args.ssl_certfile,
        "ssl_ca_certs": args.ssl_ca_certs,
        "ssl_cert_reqs": args.ssl_cert_reqs,
    }

    # serve_http yields an awaitable that is awaited in turn before returning.
    shutdown_task = await serve_http(fastapi_app, **serve_options, **uvicorn_kwargs)
    await shutdown_task


Zhuohan Li's avatar
Zhuohan Li committed
152
if __name__ == "__main__":
    # Server-level flags are registered first; the engine's own flags are
    # appended afterwards by AsyncEngineArgs.add_cli_args().
    parser = FlexibleArgumentParser()

    # Network binding.
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=parser.check_port, default=8000)

    # TLS options; run_server() forwards these to serve_http().
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument(
        "--ssl-ca-certs",
        type=str,
        default=None,
        help="The CA certificates file",
    )
    parser.add_argument(
        "--enable-ssl-refresh",
        action="store_true",
        default=False,
        help="Refresh SSL Context when SSL certificate files change",
    )
    parser.add_argument(
        "--ssl-cert-reqs",
        type=int,
        default=int(ssl.CERT_NONE),
        help="Whether client certificate is required (see stdlib ssl module's)",
    )

    # Deployment and logging.
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy",
    )
    parser.add_argument("--log-level", type=str, default="debug")

    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    asyncio.run(run_server(args))