api_server.py 17.9 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
import multiprocessing
5
import os
6
import re
7
import tempfile
8
from argparse import Namespace
9
from contextlib import asynccontextmanager
10
from http import HTTPStatus
11
from typing import AsyncIterator, Optional, Set
12

13
from fastapi import APIRouter, FastAPI, Request
Zhuohan Li's avatar
Zhuohan Li committed
14
15
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
16
from fastapi.responses import JSONResponse, Response, StreamingResponse
17
from starlette.routing import Mount
18
from typing_extensions import assert_never
Zhuohan Li's avatar
Zhuohan Li committed
19

20
import vllm.envs as envs
21
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
24
from vllm.engine.protocol import AsyncEngineClient
25
from vllm.entrypoints.launcher import serve_http
26
from vllm.entrypoints.logger import RequestLogger
27
from vllm.entrypoints.openai.cli_args import make_arg_parser
28
29
# yapf conflicts with isort for this block
# yapf: disable
30
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
31
                                              ChatCompletionResponse,
32
                                              CompletionRequest,
33
                                              CompletionResponse,
34
35
                                              DetokenizeRequest,
                                              DetokenizeResponse,
36
37
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
38
39
                                              TokenizeRequest,
                                              TokenizeResponse)
40
# yapf: enable
41
42
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
43
44
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
45
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
46
47
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
48
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
49
from vllm.usage.usage_lib import UsageContext
50
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
51
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
52

53
TIMEOUT_KEEP_ALIVE = 5  # seconds

# Module-level state. These names are only declared here; they are assigned
# during startup (build_async_engine_client,
# build_async_engine_client_from_engine_args, init_app) and then read by the
# route handlers and the lifespan hook.
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization
prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')

# Strong references to background tasks created in this module. The event
# loop only holds weak references to tasks, so a task with no other
# reference could be garbage-collected mid-execution.
_running_tasks: Set[asyncio.Task] = set()
67

68

69
def model_is_embedding(model_name: str, trust_remote_code: bool,
                       quantization: Optional[str]) -> bool:
    """Report whether ``model_name`` is served in embedding mode.

    Builds a throwaway ``ModelConfig`` for the model and returns its
    ``embedding_mode`` flag. The config uses the model name as the tokenizer,
    "auto" tokenizer mode and dtype, and seed 0.
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        quantization=quantization,
        seed=0,
        dtype="auto",
    )
    return probe_config.embedding_mode
78
79


80
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Unless ``engine_args.disable_log_stats`` is set, starts a background task
    that calls ``do_log_stats()`` on the engine client every 10 seconds while
    the app runs. The task is cancelled when the app shuts down, so it does
    not keep polling the engine client after the server has stopped.
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    task: Optional[asyncio.Task] = None
    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        # Hold a strong reference until the task finishes; discard (rather
        # than remove) keeps the callback safe if the entry is already gone.
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.discard)

    try:
        yield
    finally:
        # _force_log loops forever, so it must be stopped explicitly.
        if task is not None:
            task.cancel()


96
@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """Create the engine client from CLI ``args`` and manage its lifetime.

    Side effects: assigns the module globals ``engine_args`` (parsed from
    ``args``) and ``async_engine_client`` (the yielded client, which the
    /health and profiler handlers read). Shutdown and cleanup on error or
    exit are handled by ``build_async_engine_client_from_engine_args``.
    """
    global engine_args, async_engine_client

    engine_args = AsyncEngineArgs.from_cli_args(args)
    client_context = build_async_engine_client_from_engine_args(
        engine_args, args.disable_frontend_multiprocessing)

    async with client_context as client:
        # The backend itself stays global for the /health handler.
        async_engine_client = client  # type: ignore[assignment]
        yield client


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """
    Create AsyncEngineClient, either:
        - in-process using the AsyncLLMEngine directly (for embedding models,
          or when ``disable_frontend_multiprocessing`` is set)
        - multiprocess using AsyncLLMEngine RPC (an ``AsyncEngineRPCClient``
          talking to ``run_rpc_server`` in a spawned process)

    Yields the client, or None if the RPC server process died before it
    answered the readiness probe.

    On exit, the in-process engine's background loop is shut down. In the
    RPC case, the server process is terminated, client connections are
    closed, and the process is joined. It is force-killed if it does not
    exit within a short grace period.
    """
    global prometheus_multiproc_dir

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
                           engine_args.quantization)
            or disable_frontend_multiprocessing):
        engine_client = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        try:
            yield engine_client
        finally:
            engine_client.shutdown_background_loop()
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
        # Make TemporaryDirectory for prometheus multiprocessing
        # Note: global TemporaryDirectory will be automatically
        #   cleaned up upon exit.
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ[
            "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
    else:
        logger.warning(
            "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
            "This directory must be wiped between vLLM runs or "
            "you will find inaccurate metrics. Unset the variable "
            "and vLLM will properly handle cleanup.")

    # Select random path for IPC.
    rpc_path = get_open_zmq_ipc_path()
    logger.info("Multiprocessing frontend to use %s for RPC Path.",
                rpc_path)

    # Build RPCClient, which conforms to AsyncEngineClient Protocol.
    # NOTE: Actually, this is not true yet. We still need to support
    # embedding models via RPC (see TODO above)
    rpc_client = AsyncEngineRPCClient(rpc_path)

    # Start RPCServer in separate process (holds the AsyncLLMEngine).
    # The current process might have a CUDA context, so the server process
    # is spawned rather than forked.
    context = multiprocessing.get_context("spawn")
    rpc_server_process = context.Process(
        target=run_rpc_server,
        args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path))
    rpc_server_process.start()
    logger.info("Started engine process with PID %d",
                rpc_server_process.pid)

    # Seconds to wait for the server process to exit after terminate()
    # before it is force-killed.
    join_timeout_s = 4

    try:
        while True:
            try:
                await rpc_client.setup()
                break
            except TimeoutError:
                # Keep retrying while the server is alive; give up if it died.
                if not rpc_server_process.is_alive():
                    logger.error(
                        "RPCServer process died before responding "
                        "to readiness probe")
                    yield None
                    return

        yield rpc_client  # type: ignore[misc]
    finally:
        # Ensure rpc server process was terminated
        rpc_server_process.terminate()

        # Close all open connections to the backend
        rpc_client.close()

        # Wait for the server process to exit. A join() without a timeout
        # would hang shutdown forever if the process ignores terminate(), so
        # bound the wait, then kill and reap a process that is still running.
        rpc_server_process.join(join_timeout_s)
        if rpc_server_process.exitcode is None:
            logger.warning(
                "RPCServer process (PID %d) did not exit within %d seconds "
                "of terminate(); killing it.", rpc_server_process.pid,
                join_timeout_s)
            rpc_server_process.kill()
            rpc_server_process.join()

        # Lazy import for prometheus multiprocessing.
        # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
        # before prometheus_client is imported.
        # See https://prometheus.github.io/client_python/multiprocess/
        from prometheus_client import multiprocess
        multiprocess.mark_process_dead(rpc_server_process.pid)

210

Ethan Xu's avatar
Ethan Xu committed
211
# Module-level router: the decorated handlers below register their routes on
# it, and build_app() attaches it to the FastAPI app via include_router().
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
212

213

214
def mount_metrics(app: FastAPI):
    """Mount the Prometheus metrics ASGI app at ``/metrics`` on ``app``.

    If PROMETHEUS_MULTIPROC_DIR is set, metrics are collected through a
    dedicated multiprocess registry. Otherwise the default registry is used.
    """
    # prometheus_client must be imported lazily: PROMETHEUS_MULTIPROC_DIR has
    # to be set before the module is first imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    multiproc_dir = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if multiproc_dir is None:
        metrics_app = make_asgi_app()
    else:
        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                    multiproc_dir)
        multiproc_registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(multiproc_registry)
        metrics_app = make_asgi_app(registry=multiproc_registry)

    # Route /metrics requests to the prometheus ASGI app.
    route = Mount("/metrics", metrics_app)
    # Widen the match so /metrics is served directly instead of being
    # answered with a 307 redirect.
    route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(route)
238
239


Ethan Xu's avatar
Ethan Xu committed
240
@router.get("/health")
async def health() -> Response:
    """Health check.

    Awaits ``check_health()`` on the engine client, then returns an empty
    200 response. Exceptions from the check are not caught here.
    """
    await async_engine_client.check_health()
    healthy = Response(status_code=200)
    return healthy


Ethan Xu's avatar
Ethan Xu committed
247
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    """Handle a tokenize request.

    Returns a serialized ``ErrorResponse`` (its ``code`` as the HTTP status)
    or a serialized ``TokenizeResponse``.
    """
    outcome = await openai_serving_tokenization.create_tokenize(request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)
    if isinstance(outcome, TokenizeResponse):
        return JSONResponse(content=outcome.model_dump())

    # Statically unreachable: the two types above are exhaustive.
    assert_never(outcome)

258

Ethan Xu's avatar
Ethan Xu committed
259
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    """Handle a detokenize request.

    Returns a serialized ``ErrorResponse`` (its ``code`` as the HTTP status)
    or a serialized ``DetokenizeResponse``.
    """
    outcome = await openai_serving_tokenization.create_detokenize(request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)
    if isinstance(outcome, DetokenizeResponse):
        return JSONResponse(content=outcome.model_dump())

    # Statically unreachable: the two types above are exhaustive.
    assert_never(outcome)

270

Ethan Xu's avatar
Ethan Xu committed
271
@router.get("/v1/models")
async def show_available_models():
    """List the served models (OpenAI-compatible ``GET /v1/models``)."""
    model_list = await openai_serving_completion.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
275
276


Ethan Xu's avatar
Ethan Xu committed
277
@router.get("/version")
async def show_version():
    """Return the running vLLM version as ``{"version": <version>}``."""
    return JSONResponse(content={"version": VLLM_VERSION})


Ethan Xu's avatar
Ethan Xu committed
283
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """OpenAI-compatible chat completions endpoint.

    An ``ErrorResponse`` is returned as JSON with its ``code`` as the HTTP
    status. A complete ``ChatCompletionResponse`` is returned as JSON. Any
    other result is sent as a ``text/event-stream`` streaming response.
    """
    outcome = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)
    if isinstance(outcome, ChatCompletionResponse):
        return JSONResponse(content=outcome.model_dump())

    # Remaining case: a streamed (server-sent events) result.
    return StreamingResponse(content=outcome, media_type="text/event-stream")

299

Ethan Xu's avatar
Ethan Xu committed
300
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """OpenAI-compatible (legacy) completions endpoint.

    An ``ErrorResponse`` is returned as JSON with its ``code`` as the HTTP
    status. A complete ``CompletionResponse`` is returned as JSON. Any other
    result is sent as a ``text/event-stream`` streaming response.
    """
    outcome = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)
    if isinstance(outcome, CompletionResponse):
        return JSONResponse(content=outcome.model_dump())

    # Remaining case: a streamed (server-sent events) result.
    return StreamingResponse(content=outcome, media_type="text/event-stream")

Zhuohan Li's avatar
Zhuohan Li committed
312

Ethan Xu's avatar
Ethan Xu committed
313
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    """OpenAI-compatible embeddings endpoint.

    Returns a serialized ``ErrorResponse`` (its ``code`` as the HTTP status)
    or a serialized ``EmbeddingResponse``.
    """
    outcome = await openai_serving_embedding.create_embedding(
        request, raw_request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)
    if isinstance(outcome, EmbeddingResponse):
        return JSONResponse(content=outcome.model_dump())

    # Statically unreachable: the two types above are exhaustive.
    assert_never(outcome)

325

326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
if envs.VLLM_TORCH_PROFILER_DIR:
    # The profiling endpoints exist only when VLLM_TORCH_PROFILER_DIR is
    # configured; they are registered on the shared router at import time.
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "
        "used for local development!")

    @router.post("/start_profile")
    async def start_profile():
        """Ask the engine client to start profiling; reply 200 when done."""
        logger.info("Starting profiler...")
        await async_engine_client.start_profile()
        logger.info("Profiler started.")
        started = Response(status_code=200)
        return started

    @router.post("/stop_profile")
    async def stop_profile():
        """Ask the engine client to stop profiling; reply 200 when done."""
        logger.info("Stopping profiler...")
        await async_engine_client.stop_profile()
        logger.info("Profiler stopped.")
        stopped = Response(status_code=200)
        return stopped


346
347
def build_app(args: Namespace) -> FastAPI:
    """Create the FastAPI application and wire up routes and middleware.

    The app gets, in order:
      - the module-level ``router``, with ``root_path`` set from ``args``
      - the ``/metrics`` mount
      - CORS middleware
      - a handler that turns request-validation errors into a 400 error body
      - if an API key is configured (``VLLM_API_KEY`` env or
        ``args.api_key``), a bearer-token check on ``/v1`` routes
      - any extra middleware named in ``args.middleware`` as dotted
        ``module.attribute`` import paths

    Raises:
        ValueError: if a configured middleware is not a dotted import path,
            or resolves to something that is neither a class nor a coroutine
            function.
    """
    import secrets

    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:
        # Computed once; the token does not change per request.
        expected_auth = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            root_path = "" if args.root_path is None else args.root_path
            # CORS preflight requests are passed through unauthenticated.
            if request.method == "OPTIONS":
                return await call_next(request)
            # Only the OpenAI-compatible /v1 API is protected.
            if not request.url.path.startswith(f"{root_path}/v1"):
                return await call_next(request)
            provided_auth = request.headers.get("Authorization", "")
            # Constant-time comparison so response timing does not reveal
            # how much of the token matched. Compare bytes because
            # compare_digest rejects non-ASCII str arguments.
            if not secrets.compare_digest(provided_auth.encode("utf-8"),
                                          expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        module_path, sep, object_name = middleware.rpartition(".")
        if not sep:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Expected a dotted import path "
                             f"'module.attribute'.")
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


395
async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    """Build the app and initialise the OpenAI serving handlers.

    Creates the FastAPI app with ``build_app`` and fetches the model config
    from ``async_engine_client``. It then constructs the chat, completion,
    embedding and tokenization handlers and stores them in module globals,
    where the route functions read them.
    """
    global openai_serving_chat, openai_serving_completion
    global openai_serving_embedding, openai_serving_tokenization

    app = build_app(args)

    served_model_names = (args.served_model_name
                          if args.served_model_name is not None else
                          [args.model])

    model_config = await async_engine_client.get_model_config()

    request_logger = (None if args.disable_log_requests else
                      RequestLogger(max_log_len=args.max_log_len))

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser,
    )
    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )

    app.root_path = args.root_path
    return app
456
457


458
async def run_server(args, **uvicorn_kwargs) -> None:
    """Start the engine client and serve the OpenAI-compatible HTTP API.

    Returns without serving if the engine client could not be created.
    Extra keyword arguments are forwarded to ``serve_http``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    async with build_async_engine_client(args) as engine_client:
        # A None client means creation failed; there is nothing to serve.
        if engine_client is None:
            return

        app = await init_app(engine_client, args)
        shutdown_task = await serve_http(
            app,
            engine=engine_client,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task
485

Ethan Xu's avatar
Ethan Xu committed
486
487
488
489
490
491
492
493

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args()
    asyncio.run(run_server(args))