api_server.py 13 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
import re
5
from argparse import Namespace
6
7
from contextlib import asynccontextmanager
from http import HTTPStatus
8
9
from multiprocessing import Process
from typing import AsyncIterator, Set
10

11
from fastapi import APIRouter, FastAPI, Request
Zhuohan Li's avatar
Zhuohan Li committed
12
13
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
14
15
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
16
from starlette.routing import Mount
Zhuohan Li's avatar
Zhuohan Li committed
17

18
import vllm.envs as envs
19
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
20
21
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
22
from vllm.engine.protocol import AsyncEngineClient
23
from vllm.entrypoints.launcher import serve_http
24
from vllm.entrypoints.logger import RequestLogger
25
from vllm.entrypoints.openai.cli_args import make_arg_parser
26
27
# yapf conflicts with isort for this block
# yapf: disable
28
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
29
                                              ChatCompletionResponse,
30
                                              CompletionRequest,
31
32
33
34
35
                                              DetokenizeRequest,
                                              DetokenizeResponse,
                                              EmbeddingRequest, ErrorResponse,
                                              TokenizeRequest,
                                              TokenizeResponse)
36
37
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
38
# yapf: enable
39
40
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
41
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
42
43
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
44
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
45
from vllm.usage.usage_lib import UsageContext
46
from vllm.utils import FlexibleArgumentParser, get_open_port
47
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
48

49
# Passed to serve_http() as `timeout_keep_alive` in run_server().
TIMEOUT_KEEP_ALIVE = 5  # seconds
Zhuohan Li's avatar
Zhuohan Li committed
50

51
async_engine_client: AsyncEngineClient
Ethan Xu's avatar
Ethan Xu committed
52
engine_args: AsyncEngineArgs
53
54
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
55
openai_serving_embedding: OpenAIServingEmbedding
56
openai_serving_tokenization: OpenAIServingTokenization
57

58
logger = init_logger('vllm.entrypoints.openai.api_server')
59

60
# References to background tasks started by lifespan(); each task is added
# on creation and removes itself via a done-callback when it finishes.
_running_tasks: Set[asyncio.Task] = set()
61

62

63
64
65
66
67
68
69
70
71
def model_is_embedding(model_name: str) -> bool:
    """Return the ``embedding_mode`` flag of a ModelConfig for *model_name*.

    The model name is also used as the tokenizer name; every other
    ModelConfig argument is hard-coded here.
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )
    return probe_config.embedding_mode


72
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that drives periodic engine stats logging.

    Unless ``engine_args.disable_log_stats`` is set, a background task calls
    ``async_engine_client.do_log_stats()`` every 10 seconds while the app is
    running. The task is cancelled when the lifespan context exits so it
    does not outlive the application.
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    stats_task = None
    if not engine_args.disable_log_stats:
        stats_task = asyncio.create_task(_force_log())
        # Hold a reference in the module-level set for as long as the task
        # runs; the done-callback drops it again once the task finishes.
        _running_tasks.add(stats_task)
        stats_task.add_done_callback(_running_tasks.remove)

    try:
        yield
    finally:
        # The logging loop never ends on its own, so stop it on shutdown.
        if stats_task is not None:
            stats_task.cancel()


88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
    """Create the engine client for the server and clean it up on exit.

    Sets the module-level ``engine_args`` and ``async_engine_client``.

    * If the model is an embedding model or
      ``args.disable_frontend_multiprocessing`` is set, an in-process
      ``AsyncLLMEngine`` is yielded.
    * Otherwise ``run_rpc_server`` is started in a separate process and an
      ``AsyncEngineRPCClient`` connected to its port is yielded. On exit, or
      if client setup fails, the process is terminated, the client is closed
      and the process is joined.
    """
    # Context manager to handle async_engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    global engine_args
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Backend itself still global for the silly lil' health handler
    global async_engine_client

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(args.model)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    # Start RPCServer in separate process (holds the AsyncLLMEngine).
    port = get_open_port(envs.VLLM_RPC_PORT)
    rpc_server_process = Process(target=run_rpc_server,
                                 args=(engine_args,
                                       UsageContext.OPENAI_API_SERVER,
                                       port))
    rpc_server_process.start()

    # Build RPCClient, which conforms to AsyncEngineClient Protocol.
    async_engine_client = AsyncEngineRPCClient(port)

    try:
        # setup() runs inside the try block so that a failure while
        # connecting still tears down the server process started above.
        await async_engine_client.setup()

        yield async_engine_client
    finally:
        # Ensure rpc server process was terminated
        rpc_server_process.terminate()

        # Close all open connections to the backend
        async_engine_client.close()

        # Wait for server process to join
        rpc_server_process.join()


Ethan Xu's avatar
Ethan Xu committed
134
# Collects the endpoint handlers defined below; build_app() attaches it to
# the FastAPI application via include_router().
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
135

136

137
def mount_metrics(app: FastAPI):
    """Append a ``/metrics`` mount serving the Prometheus ASGI app to *app*."""
    prometheus_mount = Mount("/metrics", make_asgi_app())
    # Workaround for 307 Redirect for /metrics: replace the mount's path
    # regex with one that matches "/metrics" followed by any remainder.
    prometheus_mount.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(prometheus_mount)
143
144


Ethan Xu's avatar
Ethan Xu committed
145
@router.get("/health")
async def health() -> Response:
    """Health check.

    Awaits the engine client's ``check_health()`` and answers with an empty
    200 response once that call returns.
    """
    await async_engine_client.check_health()
    ok_response = Response(status_code=200)
    return ok_response


Ethan Xu's avatar
Ethan Xu committed
152
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    """Handle ``POST /tokenize``.

    Delegates to ``openai_serving_tokenization.create_tokenize`` and returns
    the result as JSON. An ``ErrorResponse`` is returned with its own
    ``code`` as the HTTP status.
    """
    result = await openai_serving_tokenization.create_tokenize(request)

    # Error path first: reply with the status carried by the error object.
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    assert isinstance(result, TokenizeResponse)
    return JSONResponse(content=result.model_dump())


Ethan Xu's avatar
Ethan Xu committed
163
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    """Handle ``POST /detokenize``.

    Delegates to ``openai_serving_tokenization.create_detokenize`` and
    returns the result as JSON. An ``ErrorResponse`` is returned with its
    own ``code`` as the HTTP status.
    """
    result = await openai_serving_tokenization.create_detokenize(request)

    # Error path first: reply with the status carried by the error object.
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    assert isinstance(result, DetokenizeResponse)
    return JSONResponse(content=result.model_dump())


Ethan Xu's avatar
Ethan Xu committed
174
@router.get("/v1/models")
async def show_available_models():
    """Handle ``GET /v1/models``.

    Returns, as JSON, whatever ``openai_serving_completion`` reports from
    ``show_available_models()``.
    """
    model_list = await openai_serving_completion.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
178
179


Ethan Xu's avatar
Ethan Xu committed
180
@router.get("/version")
async def show_version():
    """Handle ``GET /version``: return ``{"version": VLLM_VERSION}`` as JSON."""
    return JSONResponse(content={"version": VLLM_VERSION})


Ethan Xu's avatar
Ethan Xu committed
186
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """Handle ``POST /v1/chat/completions``.

    Delegates to ``openai_serving_chat.create_chat_completion``. The result
    is returned as:

    * a JSON error with the error's own status code for an ``ErrorResponse``;
    * a ``text/event-stream`` streaming response when ``request.stream`` is
      truthy;
    * a JSON body otherwise.
    """
    result = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if not request.stream:
        assert isinstance(result, ChatCompletionResponse)
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
200
201


Ethan Xu's avatar
Ethan Xu committed
202
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Handle ``POST /v1/completions``.

    Delegates to ``openai_serving_completion.create_completion``. The result
    is returned as:

    * a JSON error with the error's own status code for an ``ErrorResponse``;
    * a ``text/event-stream`` streaming response when ``request.stream`` is
      truthy;
    * a JSON body otherwise.
    """
    result = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if not request.stream:
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
Zhuohan Li's avatar
Zhuohan Li committed
214
215


Ethan Xu's avatar
Ethan Xu committed
216
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    """Handle ``POST /v1/embeddings``.

    Delegates to ``openai_serving_embedding.create_embedding`` and returns
    the result as JSON. An ``ErrorResponse`` is returned with its own
    ``code`` as the HTTP status.
    """
    result = await openai_serving_embedding.create_embedding(
        request, raw_request)

    # Error path first: reply with the status carried by the error object.
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    return JSONResponse(content=result.model_dump())


227
228
def build_app(args: Namespace) -> FastAPI:
    """Assemble the FastAPI application from parsed CLI arguments.

    The app gets, in order: the module-level ``router``, ``args.root_path``,
    the ``/metrics`` mount, CORS middleware configured from ``args``, a
    handler that turns request validation errors into 400 JSON errors,
    optional bearer-token authentication, and any extra middleware named in
    ``args.middleware``.

    Authentication is only installed when ``envs.VLLM_API_KEY`` or
    ``args.api_key`` is set. It applies to paths starting with
    ``{root_path}/v1``; ``OPTIONS`` requests and all other paths pass
    through unchecked.

    Args:
        args: Parsed CLI namespace.

    Returns:
        The configured FastAPI application.

    Raises:
        ValueError: If an entry of ``args.middleware`` resolves to something
            that is neither a class nor a coroutine function.
    """
    # Imported locally: only needed for the constant-time token comparison.
    import secrets

    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:
        # Pre-encode once; compare_digest only accepts non-ASCII input as
        # bytes, and header values come from untrusted clients.
        expected_auth = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            root_path = "" if args.root_path is None else args.root_path
            if request.method == "OPTIONS":
                return await call_next(request)
            if not request.url.path.startswith(f"{root_path}/v1"):
                return await call_next(request)
            # A missing header is treated as an empty value and rejected.
            supplied_auth = request.headers.get("Authorization", "")
            # Constant-time comparison so response timing does not leak how
            # much of the key a guess got right.
            if not secrets.compare_digest(supplied_auth.encode("utf-8"),
                                          expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        # Each entry is a dotted path "package.module.object".
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


276
async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    """Build the app and populate the module-level OpenAI serving handlers.

    Creates the FastAPI app with ``build_app``, fetches the model config
    from the engine client, and instantiates the chat, completion, embedding
    and tokenization handlers stored in the module globals that the route
    functions use.

    Args:
        async_engine_client: Engine client shared by all serving handlers.
        args: Parsed CLI namespace.

    Returns:
        The FastAPI application, with ``root_path`` set from ``args``.
    """
    global openai_serving_chat
    global openai_serving_completion
    global openai_serving_embedding
    global openai_serving_tokenization

    app = build_app(args)

    # Fall back to the model path when no explicit served names were given.
    served_model_names = ([args.model] if args.served_model_name is None
                          else args.served_model_name)

    model_config = await async_engine_client.get_model_config()

    request_logger = (None if args.disable_log_requests else
                      RequestLogger(max_log_len=args.max_log_len))

    # Leading positional arguments shared by every serving handler.
    common_args = (async_engine_client, model_config, served_model_names)

    openai_serving_chat = OpenAIServingChat(
        *common_args,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        *common_args,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        *common_args,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        *common_args,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )
    app.root_path = args.root_path

    return app
336
337


338
async def run_server(args, **uvicorn_kwargs) -> None:
    """Start the engine client, serve the HTTP app, and wait for shutdown.

    Args:
        args: Parsed CLI namespace (host, port, SSL options, log level, ...).
        **uvicorn_kwargs: Extra keyword arguments forwarded to ``serve_http``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    async with build_async_engine_client(args) as engine_client:
        app = await init_app(engine_client, args)

        # serve_http returns an awaitable that is awaited further down.
        server_shutdown = await serve_http(
            app,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await server_shutdown
360

Ethan Xu's avatar
Ethan Xu committed
361
362
363
364
365
366
367
368

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    # Build the CLI parser, parse the command line, then run the server.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args()

    asyncio.run(run_server(args))