api_server.py 19.5 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
import multiprocessing
5
import os
6
import re
7
import tempfile
8
from argparse import Namespace
9
from contextlib import asynccontextmanager
10
from http import HTTPStatus
11
from typing import AsyncIterator, Optional, Set
12

13
from fastapi import APIRouter, FastAPI, Request
Zhuohan Li's avatar
Zhuohan Li committed
14
15
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
16
from fastapi.responses import JSONResponse, Response, StreamingResponse
17
from starlette.routing import Mount
18
from typing_extensions import assert_never
Zhuohan Li's avatar
Zhuohan Li committed
19

20
import vllm.envs as envs
21
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
24
from vllm.engine.protocol import AsyncEngineClient
25
from vllm.entrypoints.launcher import serve_http
26
from vllm.entrypoints.logger import RequestLogger
27
from vllm.entrypoints.openai.cli_args import make_arg_parser
28
29
# yapf conflicts with isort for this block
# yapf: disable
30
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
31
                                              ChatCompletionResponse,
32
                                              CompletionRequest,
33
                                              CompletionResponse,
34
35
                                              DetokenizeRequest,
                                              DetokenizeResponse,
36
37
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
38
                                              LoadLoraAdapterRequest,
39
                                              TokenizeRequest,
40
41
                                              TokenizeResponse,
                                              UnloadLoraAdapterRequest)
42
43
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
44
# yapf: enable
45
46
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
47
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
48
49
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
50
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
51
from vllm.usage.usage_lib import UsageContext
52
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
53
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
54

55
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')

# Passed to serve_http() as ``timeout_keep_alive`` in run_server().
TIMEOUT_KEEP_ALIVE = 5  # seconds

# Background tasks started in lifespan() are stored here while they run.
_running_tasks: Set[asyncio.Task] = set()

# Module-level handles. They are only declared here (no value is bound);
# build_async_engine_client(), build_async_engine_client_from_engine_args()
# and init_app() assign them through ``global`` statements.
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization
prometheus_multiproc_dir: tempfile.TemporaryDirectory
69

70

71
def model_is_embedding(model_name: str, trust_remote_code: bool,
                       quantization: Optional[str]) -> bool:
    """Report whether ``model_name`` is configured as an embedding model.

    A ``ModelConfig`` is built for the model, reusing ``model_name`` as the
    tokenizer, with ``tokenizer_mode`` and ``dtype`` set to ``"auto"`` and
    ``seed`` set to 0. Its ``embedding_mode`` attribute is returned.
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        quantization=quantization,
        seed=0,
        dtype="auto",
    )
    return probe_config.embedding_mode
80
81


82
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan handler for the FastAPI app.

    Unless ``engine_args.disable_log_stats`` is set, a background task is
    started that calls ``async_engine_client.do_log_stats()`` every 10
    seconds. The task is cancelled when the lifespan context exits, so it
    does not keep calling the engine client after the app has shut down.
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    stats_task: Optional[asyncio.Task] = None
    if not engine_args.disable_log_stats:
        stats_task = asyncio.create_task(_force_log())
        # Keep a strong reference while the task runs; the event loop
        # itself only holds weak references to tasks.
        _running_tasks.add(stats_task)
        stats_task.add_done_callback(_running_tasks.discard)

    try:
        yield
    finally:
        # Stop the periodic logger on shutdown. The done callback above
        # removes it from _running_tasks once cancellation completes.
        if stats_task is not None:
            stats_task.cancel()


98
@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """Async context manager yielding the engine client for CLI ``args``.

    The engine arguments are derived from ``args`` and stored in the
    module-level ``engine_args``. The client itself is created (and later
    cleaned up) by ``build_async_engine_client_from_engine_args``; whatever
    that context manager yields is stored in the module-level
    ``async_engine_client`` and yielded to the caller.
    """
    # The client is kept as a module global because the /health handler
    # reads it from there.
    global async_engine_client, engine_args

    engine_args = AsyncEngineArgs.from_cli_args(args)

    client_ctx = build_async_engine_client_from_engine_args(
        engine_args, args.disable_frontend_multiprocessing)
    async with client_ctx as client:
        async_engine_client = client  # type: ignore[assignment]
        yield client


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """
    Async context manager that creates an AsyncEngineClient, either:
        - in-process, using the AsyncLLMEngine directly
        - multiprocess, using an AsyncLLMEngine behind an RPC server

    Yields the client, or None if the RPC server process died before the
    client could be set up. Cleanup runs when the context exits.
    """
    global prometheus_multiproc_dir

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    run_in_process = (model_is_embedding(engine_args.model,
                                         engine_args.trust_remote_code,
                                         engine_args.quantization)
                      or disable_frontend_multiprocessing)
    if run_in_process:
        in_process_engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        try:
            yield in_process_engine
        finally:
            in_process_engine.shutdown_background_loop()
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
        logger.warning(
            "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
            "This directory must be wiped between vLLM runs or "
            "you will find inaccurate metrics. Unset the variable "
            "and vLLM will properly handle cleanup.")
    else:
        # Make TemporaryDirectory for prometheus multiprocessing
        # Note: global TemporaryDirectory will be automatically
        #   cleaned up upon exit.
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name

    # Select random path for IPC.
    rpc_path = get_open_zmq_ipc_path()
    logger.info("Multiprocessing frontend to use %s for RPC Path.", rpc_path)

    # Build RPCClient, which conforms to AsyncEngineClient Protocol.
    # NOTE: Actually, this is not true yet. We still need to support
    # embedding models via RPC (see TODO above)
    rpc_client = AsyncEngineRPCClient(rpc_path)

    # Start RPCServer in separate process (holds the AsyncLLMEngine).
    # The current process might have a CUDA context, so the server is
    # spawned as a new process.
    spawn_context = multiprocessing.get_context("spawn")
    engine_process = spawn_context.Process(
        target=run_rpc_server,
        args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path),
    )
    engine_process.start()
    logger.info("Started engine process with PID %d", engine_process.pid)

    try:
        # Retry the client setup on every timeout for as long as the
        # server process is still alive.
        while True:
            try:
                await rpc_client.setup()
                break
            except TimeoutError:
                if not engine_process.is_alive():
                    logger.error("RPCServer process died before responding "
                                 "to readiness probe")
                    yield None
                    return

        yield rpc_client  # type: ignore[misc]
    finally:
        # Ensure rpc server process was terminated
        engine_process.terminate()

        # Close all open connections to the backend
        rpc_client.close()

        # Wait for server process to join
        engine_process.join()

        # Lazy import for prometheus multiprocessing.
        # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
        # before prometheus_client is imported.
        # See https://prometheus.github.io/client_python/multiprocess/
        from prometheus_client import multiprocess
        multiprocess.mark_process_dead(engine_process.pid)

212

Ethan Xu's avatar
Ethan Xu committed
213
# Router that the endpoint decorators below register on; build_app()
# attaches it to the FastAPI application via include_router().
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
214

215

216
def mount_metrics(app: FastAPI):
    """Mount a Prometheus ASGI app at ``/metrics`` on ``app``.

    When the ``PROMETHEUS_MULTIPROC_DIR`` environment variable is set, the
    metrics app is backed by a fresh ``CollectorRegistry`` with a
    ``MultiProcessCollector`` attached; otherwise ``make_asgi_app()`` is
    called without a registry argument.
    """
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    multiproc_dir = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if multiproc_dir is None:
        metrics_app = make_asgi_app()
    else:
        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                    multiproc_dir)
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)
        metrics_app = make_asgi_app(registry=registry)

    # Add prometheus asgi middleware to route /metrics requests
    metrics_route = Mount("/metrics", metrics_app)

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)
240
241


Ethan Xu's avatar
Ethan Xu committed
242
@router.get("/health")
async def health() -> Response:
    """Health check."""
    # Awaits the engine client's own health check; the empty 200 response
    # below is only reached if that call returns without raising.
    await async_engine_client.check_health()
    ok_response = Response(status_code=200)
    return ok_response


Ethan Xu's avatar
Ethan Xu committed
249
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    # Delegates to the tokenization serving object and wraps its result in
    # a JSON response; an ErrorResponse carries its own HTTP status code.
    result = await openai_serving_tokenization.create_tokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, TokenizeResponse):
        return JSONResponse(content=result.model_dump())

    # No other result type is expected here.
    assert_never(result)

260

Ethan Xu's avatar
Ethan Xu committed
261
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    # Delegates to the tokenization serving object and wraps its result in
    # a JSON response; an ErrorResponse carries its own HTTP status code.
    result = await openai_serving_tokenization.create_detokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, DetokenizeResponse):
        return JSONResponse(content=result.model_dump())

    # No other result type is expected here.
    assert_never(result)

272

Ethan Xu's avatar
Ethan Xu committed
273
@router.get("/v1/models")
async def show_available_models():
    # The model list comes from the completion serving object and is
    # returned as JSON.
    model_list = await openai_serving_completion.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
277
278


Ethan Xu's avatar
Ethan Xu committed
279
@router.get("/version")
async def show_version():
    # Returns {"version": <VLLM_VERSION>} as JSON.
    return JSONResponse(content={"version": VLLM_VERSION})


Ethan Xu's avatar
Ethan Xu committed
285
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    # Delegates to the chat serving object. An ErrorResponse or a complete
    # ChatCompletionResponse is returned as JSON; any other result is
    # streamed back as server-sent events.
    result = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, ChatCompletionResponse):
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")

301

Ethan Xu's avatar
Ethan Xu committed
302
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    # Delegates to the completion serving object. An ErrorResponse or a
    # complete CompletionResponse is returned as JSON; any other result is
    # streamed back as server-sent events.
    result = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, CompletionResponse):
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")

Zhuohan Li's avatar
Zhuohan Li committed
314

Ethan Xu's avatar
Ethan Xu committed
315
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    # Delegates to the embedding serving object and wraps its result in a
    # JSON response; an ErrorResponse carries its own HTTP status code.
    result = await openai_serving_embedding.create_embedding(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, EmbeddingResponse):
        return JSONResponse(content=result.model_dump())

    # No other result type is expected here.
    assert_never(result)

327

328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
# The profiler endpoints are only registered when
# envs.VLLM_TORCH_PROFILER_DIR is truthy at import time.
if envs.VLLM_TORCH_PROFILER_DIR:
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "
        "used for local development!")

    @router.post("/start_profile")
    async def start_profile():
        # Ask the engine client to start profiling, logging before/after.
        logger.info("Starting profiler...")
        await async_engine_client.start_profile()
        logger.info("Profiler started.")
        started = Response(status_code=200)
        return started

    @router.post("/stop_profile")
    async def stop_profile():
        # Ask the engine client to stop profiling, logging before/after.
        logger.info("Stopping profiler...")
        await async_engine_client.stop_profile()
        logger.info("Profiler stopped.")
        stopped = Response(status_code=200)
        return stopped


348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
# The LoRA management endpoints are only registered when
# envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING is truthy at import time.
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
    logger.warning(
        "Lora dynamic loading & unloading is enabled in the API server. "
        "This should ONLY be used for local development!")

    @router.post("/v1/load_lora_adapter")
    async def load_lora_adapter(request: LoadLoraAdapterRequest):
        # Apply the request to the chat handler first, then the completion
        # handler, and stop at the first ErrorResponse.
        for serving in (openai_serving_chat, openai_serving_completion):
            response = await serving.load_lora_adapter(request)
            if isinstance(response, ErrorResponse):
                return JSONResponse(content=response.model_dump(),
                                    status_code=response.code)

        # `response` now holds what the completion handler returned.
        return Response(status_code=200, content=response)

    @router.post("/v1/unload_lora_adapter")
    async def unload_lora_adapter(request: UnloadLoraAdapterRequest):
        # Apply the request to the chat handler first, then the completion
        # handler, and stop at the first ErrorResponse.
        for serving in (openai_serving_chat, openai_serving_completion):
            response = await serving.unload_lora_adapter(request)
            if isinstance(response, ErrorResponse):
                return JSONResponse(content=response.model_dump(),
                                    status_code=response.code)

        # `response` now holds what the completion handler returned.
        return Response(status_code=200, content=response)


382
383
def build_app(args: Namespace) -> FastAPI:
    """Create and configure the FastAPI application.

    The app gets the module-level router, the ``/metrics`` mount, CORS
    middleware configured from ``args``, a handler that turns request
    validation errors into 400 responses, optional API-key authentication,
    and any extra middleware listed in ``args.middleware``.

    Args:
        args: Parsed CLI arguments. Reads ``root_path``,
            ``allowed_origins``, ``allow_credentials``, ``allowed_methods``,
            ``allowed_headers``, ``api_key`` and ``middleware``.

    Returns:
        The configured FastAPI application.

    Raises:
        ValueError: If an entry of ``args.middleware`` resolves to something
            that is neither a class nor a coroutine function.
    """
    # Local import: only the authentication middleware below needs it.
    import secrets

    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    # The environment variable takes precedence over the CLI argument.
    if token := envs.VLLM_API_KEY or args.api_key:
        # Compare as bytes: compare_digest() rejects non-ASCII str inputs.
        # "surrogatepass" keeps encoding from failing on unusual key text.
        expected_header = ("Bearer " + token).encode("utf-8",
                                                     "surrogatepass")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            root_path = "" if args.root_path is None else args.root_path
            # OPTIONS requests and paths outside {root_path}/v1 are not
            # authenticated.
            if request.method == "OPTIONS":
                return await call_next(request)
            if not request.url.path.startswith(f"{root_path}/v1"):
                return await call_next(request)
            # A missing header becomes "", which can never match.
            provided_header = request.headers.get("Authorization", "")
            # Constant-time comparison, so response timing does not reveal
            # how much of the key a guess got right.
            if not secrets.compare_digest(provided_header.encode("utf-8"),
                                          expected_header):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    # Each entry is a dotted path "package.module.object".
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


431
async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    """Build the app and create the module-level serving objects.

    The chat, completion, embedding and tokenization serving objects are
    constructed from the engine client, its model config and ``args``, and
    bound to the module-level names the route handlers read.

    Returns:
        The FastAPI application produced by ``build_app(args)``.
    """
    global openai_serving_chat, openai_serving_completion
    global openai_serving_embedding, openai_serving_tokenization

    app = build_app(args)

    # Fall back to the model name when no served name was given.
    served_model_names = (args.served_model_name
                          if args.served_model_name is not None else
                          [args.model])

    model_config = await async_engine_client.get_model_config()

    request_logger = (None if args.disable_log_requests else
                      RequestLogger(max_log_len=args.max_log_len))

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser,
    )
    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )

    app.root_path = args.root_path
    return app
492
493


494
async def run_server(args, **uvicorn_kwargs) -> None:
    """Run the API server for the parsed CLI ``args``.

    The engine client is built, the app is initialised and handed to
    ``serve_http``. Extra keyword arguments are forwarded to ``serve_http``
    unchanged. If the engine client could not be created, the function
    returns without serving.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    async with build_async_engine_client(args) as engine_client:
        if engine_client is None:
            # Client creation failed, so there is nothing to serve.
            return

        fastapi_app = await init_app(engine_client, args)

        shutdown_task = await serve_http(
            fastapi_app,
            engine=engine_client,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task
521

Ethan Xu's avatar
Ethan Xu committed
522
523
524
525
526
527
528
529

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args()

    # Script entry point: run the server until it shuts down.
    asyncio.run(run_server(args))