api_server.py 21.1 KB
Newer Older
1
import asyncio
import importlib
import inspect
import multiprocessing
import os
import re
import secrets
import signal
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncIterator, Optional, Set

import uvloop
from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.datastructures import State
from starlette.routing import Mount
from typing_extensions import assert_never

import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionResponse,
                                              CompletionRequest,
                                              CompletionResponse,
                                              DetokenizeRequest,
                                              DetokenizeResponse,
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
                                              LoadLoraAdapterRequest,
                                              TokenizeRequest,
                                              TokenizeResponse,
                                              UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
58

59
# Passed to the HTTP server as its keep-alive timeout.
TIMEOUT_KEEP_ALIVE = 5  # seconds

# Declared here, assigned only when this process creates its own temporary
# PROMETHEUS_MULTIPROC_DIR (see the multiprocessing engine-client path).
prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')

# Holds references to in-flight background tasks; each task removes itself
# through a done-callback.
_running_tasks: Set[asyncio.Task] = set()
67

68

69
def model_is_embedding(model_name: str, trust_remote_code: bool,
                       quantization: Optional[str],
                       revision: Optional[str]) -> bool:
    """Return the ``embedding_mode`` flag of a ``ModelConfig`` for the model.

    A ``ModelConfig`` is built with ``model_name`` used as both model and
    tokenizer, ``tokenizer_mode`` and ``dtype`` set to ``"auto"`` and a
    fixed seed of 0; the remaining arguments are forwarded unchanged.
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        dtype="auto",
        seed=0,
        trust_remote_code=trust_remote_code,
        quantization=quantization,
        revision=revision,
    )
    return probe_config.embedding_mode
80
81


82
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    When ``app.state.log_stats`` is truthy, a background task is started
    that calls ``do_log_stats()`` on ``app.state.engine_client`` every 10
    seconds. The task is cancelled when the lifespan ends, and
    ``app.state`` is deleted afterwards in every case.
    """
    try:
        stats_task = None
        if app.state.log_stats:
            client = app.state.engine_client

            async def _force_log():
                while True:
                    await asyncio.sleep(10)
                    await client.do_log_stats()

            stats_task = asyncio.create_task(_force_log())
            # Keep a reference while the task runs; drop it once done.
            _running_tasks.add(stats_task)
            stats_task.add_done_callback(_running_tasks.remove)

        try:
            yield
        finally:
            if stats_task is not None:
                stats_task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state
106
107


108
@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """Yield an engine client built from parsed CLI arguments.

    Converts ``args`` to ``AsyncEngineArgs`` and delegates to
    ``build_async_engine_client_from_engine_args``, which owns the
    client's lifecycle and cleanup on error/exit.
    """
    cli_engine_args = AsyncEngineArgs.from_cli_args(args)
    frontend_mp_disabled = args.disable_frontend_multiprocessing

    async with build_async_engine_client_from_engine_args(
            cli_engine_args, frontend_mp_disabled) as client:
        yield client


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """
    Build an AsyncEngineClient in one of two ways:
        - in-process, using AsyncLLMEngine directly
        - multiprocess, using an RPC client that talks to an AsyncLLMEngine
          running in a spawned server process

    Yields the client, or None if the RPC server process died before it
    answered the readiness probe.
    """
    # Only assigned on the multiprocessing path below.
    global prometheus_multiproc_dir

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    run_in_process = (model_is_embedding(
        engine_args.model, engine_args.trust_remote_code,
        engine_args.quantization, engine_args.revision)
                      or disable_frontend_multiprocessing)

    if run_in_process:
        engine_config = engine_args.create_engine_config()
        executor_cls = AsyncLLMEngine._get_executor_cls(engine_config)
        uses_ray = getattr(executor_cls, "uses_ray", False)

        build_engine = partial(AsyncLLMEngine.from_engine_args,
                               engine_args=engine_args,
                               engine_config=engine_config,
                               usage_context=UsageContext.OPENAI_API_SERVER)
        if uses_ray:
            # Must run in main thread with ray for its signal handlers to work
            in_process_engine = build_engine()
        else:
            loop = asyncio.get_running_loop()
            in_process_engine = await loop.run_in_executor(None, build_engine)

        yield in_process_engine
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
        logger.warning(
            "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
            "This directory must be wiped between vLLM runs or "
            "you will find inaccurate metrics. Unset the variable "
            "and vLLM will properly handle cleanup.")
    else:
        # Make TemporaryDirectory for prometheus multiprocessing
        # Note: global TemporaryDirectory will be automatically
        #   cleaned up upon exit.
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name

    # Select random path for IPC.
    rpc_path = get_open_zmq_ipc_path()
    logger.info("Multiprocessing frontend to use %s for RPC Path.", rpc_path)

    # Build RPCClient, which conforms to AsyncEngineClient Protocol.
    # NOTE: Actually, this is not true yet. We still need to support
    # embedding models via RPC (see TODO above)
    rpc_client = AsyncEngineRPCClient(rpc_path)

    # Start RPCServer in separate process (holds the AsyncLLMEngine).
    # the current process might have CUDA context,
    # so we need to spawn a new process
    spawn_context = multiprocessing.get_context("spawn")
    rpc_server_process = spawn_context.Process(
        target=run_rpc_server,
        args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path))
    rpc_server_process.start()
    logger.info("Started engine process with PID %d", rpc_server_process.pid)

    try:
        # Retry the readiness probe until it succeeds, giving up only if
        # the server process is found dead after a timeout.
        server_ready = False
        while not server_ready:
            try:
                await rpc_client.setup()
            except TimeoutError:
                if not rpc_server_process.is_alive():
                    logger.error("RPCServer process died before responding "
                                 "to readiness probe")
                    yield None
                    return
            else:
                server_ready = True

        yield rpc_client  # type: ignore[misc]
    finally:
        # Ensure rpc server process was terminated
        rpc_server_process.terminate()

        # Close all open connections to the backend
        rpc_client.close()

        # Wait for server process to join
        rpc_server_process.join()

        # Lazy import for prometheus multiprocessing.
        # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
        # before prometheus_client is imported.
        # See https://prometheus.github.io/client_python/multiprocess/
        from prometheus_client import multiprocess
        multiprocess.mark_process_dead(rpc_server_process.pid)

226

Ethan Xu's avatar
Ethan Xu committed
227
# Module-level router: the endpoint handlers in this file register on it,
# and build_app() includes it in the FastAPI application.
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
228

229

230
def mount_metrics(app: FastAPI):
    """Append a ``/metrics`` route serving the Prometheus ASGI app.

    If ``PROMETHEUS_MULTIPROC_DIR`` is set in the environment, the ASGI
    app is backed by a fresh ``CollectorRegistry`` with a
    ``MultiProcessCollector`` attached; otherwise ``make_asgi_app()`` is
    called without a registry argument.
    """
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    multiproc_dir = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if multiproc_dir is None:
        metrics_app = make_asgi_app()
    else:
        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                    multiproc_dir)
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)
        metrics_app = make_asgi_app(registry=registry)

    # Add prometheus asgi middleware to route /metrics requests
    metrics_route = Mount("/metrics", metrics_app)

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)
254
255


256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
def chat(request: Request) -> OpenAIServingChat:
    """Return the chat serving handler stored on the application state."""
    app_state = request.app.state
    return app_state.openai_serving_chat


def completion(request: Request) -> OpenAIServingCompletion:
    """Return the completion serving handler stored on the application state."""
    app_state = request.app.state
    return app_state.openai_serving_completion


def tokenization(request: Request) -> OpenAIServingTokenization:
    """Return the tokenization serving handler stored on the application state."""
    app_state = request.app.state
    return app_state.openai_serving_tokenization


def embedding(request: Request) -> OpenAIServingEmbedding:
    """Return the embedding serving handler stored on the application state."""
    app_state = request.app.state
    return app_state.openai_serving_embedding


def engine_client(request: Request) -> AsyncEngineClient:
    """Return the engine client stored on the application state."""
    app_state = request.app.state
    return app_state.engine_client


Ethan Xu's avatar
Ethan Xu committed
276
@router.get("/health")
async def health(raw_request: Request) -> Response:
    """Health check."""
    # Respond 200 once the engine client's own health check has returned.
    client = engine_client(raw_request)
    await client.check_health()
    return Response(status_code=200)


Ethan Xu's avatar
Ethan Xu committed
283
# POST /tokenize: delegate to the tokenization handler and serialise its
# result; an ErrorResponse is returned with its own status code.
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest, raw_request: Request):
    handler = tokenization(raw_request)
    result = await handler.create_tokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, TokenizeResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)

294

Ethan Xu's avatar
Ethan Xu committed
295
# POST /detokenize: delegate to the tokenization handler and serialise its
# result; an ErrorResponse is returned with its own status code.
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest, raw_request: Request):
    handler = tokenization(raw_request)
    result = await handler.create_detokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, DetokenizeResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)

306

Ethan Xu's avatar
Ethan Xu committed
307
# GET /v1/models: return the model list reported by the completion handler.
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    handler = completion(raw_request)
    model_list = await handler.show_available_models()
    return JSONResponse(content=model_list.model_dump())
Zhuohan Li's avatar
Zhuohan Li committed
311
312


Ethan Xu's avatar
Ethan Xu committed
313
# GET /version: report the vLLM version string.
@router.get("/version")
async def show_version():
    return JSONResponse(content={"version": VLLM_VERSION})


Ethan Xu's avatar
Ethan Xu committed
319
# POST /v1/chat/completions: the handler returns an ErrorResponse, a
# complete ChatCompletionResponse, or something else, which is streamed
# as server-sent events.
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    handler = chat(raw_request)
    result = await handler.create_chat_completion(request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, ChatCompletionResponse):
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")

335

Ethan Xu's avatar
Ethan Xu committed
336
# POST /v1/completions: the handler returns an ErrorResponse, a complete
# CompletionResponse, or something else, which is streamed as
# server-sent events.
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    handler = completion(raw_request)
    result = await handler.create_completion(request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, CompletionResponse):
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")

Zhuohan Li's avatar
Zhuohan Li committed
348

Ethan Xu's avatar
Ethan Xu committed
349
# POST /v1/embeddings: delegate to the embedding handler and serialise its
# result; an ErrorResponse is returned with its own status code.
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    handler = embedding(raw_request)
    result = await handler.create_embedding(request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, EmbeddingResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)

361

362
363
364
365
366
367
# The profiler endpoints are registered only when VLLM_TORCH_PROFILER_DIR
# is set.
if envs.VLLM_TORCH_PROFILER_DIR:
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "
        "used for local development!")

    @router.post("/start_profile")
    async def start_profile(raw_request: Request):
        logger.info("Starting profiler...")
        client = engine_client(raw_request)
        await client.start_profile()
        logger.info("Profiler started.")
        return Response(status_code=200)

    @router.post("/stop_profile")
    async def stop_profile(raw_request: Request):
        logger.info("Stopping profiler...")
        client = engine_client(raw_request)
        await client.stop_profile()
        logger.info("Profiler stopped.")
        return Response(status_code=200)


382
383
384
385
386
387
# The LoRA load/unload endpoints are registered only when
# VLLM_ALLOW_RUNTIME_LORA_UPDATING is set.
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
    logger.warning(
        "Lora dynamic loading & unloading is enabled in the API server. "
        "This should ONLY be used for local development!")

    @router.post("/v1/load_lora_adapter")
    async def load_lora_adapter(request: LoadLoraAdapterRequest,
                                raw_request: Request):
        # Apply to the chat handler first, then the completion handler;
        # stop at the first ErrorResponse.
        for get_handler in (chat, completion):
            response = await get_handler(raw_request).load_lora_adapter(
                request)
            if isinstance(response, ErrorResponse):
                return JSONResponse(content=response.model_dump(),
                                    status_code=response.code)

        return Response(status_code=200, content=response)

    @router.post("/v1/unload_lora_adapter")
    async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
                                  raw_request: Request):
        # Apply to the chat handler first, then the completion handler;
        # stop at the first ErrorResponse.
        for get_handler in (chat, completion):
            response = await get_handler(raw_request).unload_lora_adapter(
                request)
            if isinstance(response, ErrorResponse):
                return JSONResponse(content=response.model_dump(),
                                    status_code=response.code)

        return Response(status_code=200, content=response)


418
419
def build_app(args: Namespace) -> FastAPI:
    """Construct the FastAPI application for the API server.

    Steps, in order: include the module router, set ``root_path``, mount
    the metrics route, add CORS middleware from ``args``, register a
    handler that turns request-validation errors into a 400 JSON error,
    optionally add API-key authentication, then install any
    user-specified middleware.

    Args:
        args: Parsed CLI arguments. Reads ``root_path``, the
            ``allowed_*`` / ``allow_credentials`` CORS settings,
            ``api_key`` and ``middleware`` (dotted import paths).

    Returns:
        The configured ``FastAPI`` instance.

    Raises:
        ValueError: If an entry of ``args.middleware`` resolves to
            something that is neither a class nor a coroutine function.
    """
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        chat = app.state.openai_serving_chat
        err = chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:
        # Computed once: neither value changes between requests.
        root_path = "" if args.root_path is None else args.root_path
        protected_prefix = f"{root_path}/v1"
        expected_header = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            # Only non-OPTIONS requests under the /v1 prefix need the key.
            if request.method == "OPTIONS":
                return await call_next(request)
            if not request.url.path.startswith(protected_prefix):
                return await call_next(request)

            # Constant-time comparison so response timing does not reveal
            # how much of the supplied key matched. Bytes are compared
            # because compare_digest rejects non-ASCII str arguments.
            supplied_header = request.headers.get("Authorization", "")
            if not secrets.compare_digest(supplied_header.encode("utf-8"),
                                          expected_header):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


468
def init_app_state(
    async_engine_client: AsyncEngineClient,
    model_config: ModelConfig,
    state: State,
    args: Namespace,
) -> None:
    """Populate ``state`` with the engine client and serving handlers.

    Sets ``engine_client``, ``log_stats`` and the four
    ``openai_serving_*`` handlers (chat, completion, embedding,
    tokenization). The served model names fall back to ``[args.model]``
    when ``args.served_model_name`` is None, and no request logger is
    created when ``args.disable_log_requests`` is set.
    """
    served_model_names = (args.served_model_name
                          if args.served_model_name is not None else
                          [args.model])

    request_logger = (None if args.disable_log_requests else RequestLogger(
        max_log_len=args.max_log_len))

    state.engine_client = async_engine_client
    state.log_stats = not args.disable_log_stats

    state.openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser,
    )
    state.openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    state.openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    state.openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )
522
523


524
async def run_server(args, **uvicorn_kwargs) -> None:
    """Build the engine client and app, then serve HTTP until shutdown.

    Installs a SIGTERM handler that raises ``KeyboardInterrupt``, builds
    the engine client from ``args``, and returns early if that yields
    None. Otherwise the app is built, its state initialised, and
    ``serve_http`` is awaited with the host/port/SSL settings from
    ``args`` plus ``uvicorn_kwargs``. The awaitable it returns is awaited
    only after the engine-client context has exited.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    def _interrupt_on_sigterm(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, _interrupt_on_sigterm)

    async with build_async_engine_client(args) as engine:
        # If None, creation of the client failed and we exit.
        if engine is None:
            return

        app = build_app(args)

        engine_model_config = await engine.get_model_config()
        init_app_state(engine, engine_model_config, app.state, args)

        shutdown_task = await serve_http(
            app,
            limit_concurrency=engine.limit_concurrency,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task
560

Ethan Xu's avatar
Ethan Xu committed
561
562
563
564
565
566
567
568

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args()

    # Run the server coroutine to completion via uvloop.
    uvloop.run(run_server(args))