api_server.py 19.6 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
import multiprocessing
5
import os
6
import re
7
import tempfile
8
from argparse import Namespace
9
from contextlib import asynccontextmanager
10
from http import HTTPStatus
11
from typing import AsyncIterator, Optional, Set
12

13
from fastapi import APIRouter, FastAPI, Request
Zhuohan Li's avatar
Zhuohan Li committed
14
15
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
16
from fastapi.responses import JSONResponse, Response, StreamingResponse
17
from starlette.routing import Mount
18
from typing_extensions import assert_never
Zhuohan Li's avatar
Zhuohan Li committed
19

20
import vllm.envs as envs
21
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
24
from vllm.engine.protocol import AsyncEngineClient
25
from vllm.entrypoints.launcher import serve_http
26
from vllm.entrypoints.logger import RequestLogger
27
from vllm.entrypoints.openai.cli_args import make_arg_parser
28
29
# yapf conflicts with isort for this block
# yapf: disable
30
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
31
                                              ChatCompletionResponse,
32
                                              CompletionRequest,
33
                                              CompletionResponse,
34
35
                                              DetokenizeRequest,
                                              DetokenizeResponse,
36
37
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
38
                                              LoadLoraAdapterRequest,
39
                                              TokenizeRequest,
40
41
                                              TokenizeResponse,
                                              UnloadLoraAdapterRequest)
42
43
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
44
# yapf: enable
45
46
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
47
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
48
49
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
50
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
51
from vllm.usage.usage_lib import UsageContext
52
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
53
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
54

55
# Keep-alive timeout handed to serve_http() as `timeout_keep_alive`
# in run_server().
TIMEOUT_KEEP_ALIVE = 5  # seconds
Zhuohan Li's avatar
Zhuohan Li committed
56

57
# Module-level state read by the route handlers and the lifespan hook.
# These are bare annotations: the names are only bound at startup, by
# build_async_engine_client (async_engine_client, engine_args), by
# init_app (the four openai_serving_* objects) and by
# build_async_engine_client_from_engine_args (prometheus_multiproc_dir).
async_engine_client: AsyncEngineClient
Ethan Xu's avatar
Ethan Xu committed
58
engine_args: AsyncEngineArgs
59
60
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
61
openai_serving_embedding: OpenAIServingEmbedding
62
openai_serving_tokenization: OpenAIServingTokenization
63
prometheus_multiproc_dir: tempfile.TemporaryDirectory
64

65
# The logger name is spelled out instead of derived from __name__;
# __name__ cannot be used here, see
# https://github.com/vllm-project/vllm/pull/4765
66
logger = init_logger('vllm.entrypoints.openai.api_server')
67

68
# Background tasks started by lifespan(). Each task is added here when it
# is created and removes itself through its done-callback.
_running_tasks: Set[asyncio.Task] = set()
69

70

71
def model_is_embedding(model_name: str, trust_remote_code: bool,
                       quantization: Optional[str],
                       revision: Optional[str]) -> bool:
    """Report whether ``model_name`` is configured as an embedding model.

    A ``ModelConfig`` is built for the model (the model name doubles as the
    tokenizer name) and its ``embedding_mode`` attribute is returned.
    """
    config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        dtype="auto",
        seed=0,
        revision=revision,
        quantization=quantization,
    )
    return config.embedding_mode
82
83


84
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run the periodic stats logger for as long as the app is alive.

    Unless ``engine_args.disable_log_stats`` is set, a background task asks
    the engine client to log its stats every 10 seconds. The task is
    cancelled when the application shuts down so that it does not keep
    running (and calling into the engine client) after the server stopped.
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    stats_task: Optional[asyncio.Task] = None
    if not engine_args.disable_log_stats:
        stats_task = asyncio.create_task(_force_log())
        # Hold a reference in the module-level set while the task runs;
        # the done-callback drops it again once the task has finished.
        _running_tasks.add(stats_task)
        stats_task.add_done_callback(_running_tasks.remove)

    try:
        yield
    finally:
        # _force_log() loops forever, so it must be stopped explicitly.
        if stats_task is not None:
            stats_task.cancel()


100
@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """Build the engine client from CLI args and manage its lifecycle.

    The parsed ``AsyncEngineArgs`` and the resulting client are also stored
    in the module-level ``engine_args`` / ``async_engine_client`` names.
    Whatever build_async_engine_client_from_engine_args yields (a client,
    or None) is yielded unchanged; cleanup on error/exit is delegated to
    that context manager.
    """
    # The backend stays global because the health handler reads it; the
    # lifespan hook reads engine_args the same way.
    global engine_args, async_engine_client

    engine_args = AsyncEngineArgs.from_cli_args(args)
    client_manager = build_async_engine_client_from_engine_args(
        engine_args, args.disable_frontend_multiprocessing)

    async with client_manager as engine:
        async_engine_client = engine  # type: ignore[assignment]
        yield engine


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """
    Yield an AsyncEngineClient backed by one of two setups:
        - in-process: an AsyncLLMEngine created in this process
        - multiprocess: an RPC client talking to an AsyncLLMEngine that
          runs in a separately spawned server process

    Yields None instead of a client when the RPC server process dies
    before answering the readiness probe.
    """
    global prometheus_multiproc_dir

    # Embedding models and an explicit opt-out both use the in-process
    # AsyncLLMEngine.
    # TODO: support embedding model via RPC.
    run_in_process = (model_is_embedding(
        engine_args.model, engine_args.trust_remote_code,
        engine_args.quantization, engine_args.revision)
                      or disable_frontend_multiprocessing)
    if run_in_process:
        in_process_engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        try:
            yield in_process_engine
        finally:
            in_process_engine.shutdown_background_loop()
        return

    # Everything below is the multiprocessing frontend.
    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
        logger.warning(
            "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
            "This directory must be wiped between vLLM runs or "
            "you will find inaccurate metrics. Unset the variable "
            "and vLLM will properly handle cleanup.")
    else:
        # Create a TemporaryDirectory for prometheus multiprocessing.
        # Note: it is kept in a module-level global and will be
        #   cleaned up automatically upon exit.
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name

    # Pick a random IPC path shared by the client and the server.
    rpc_path = get_open_zmq_ipc_path()
    logger.info("Multiprocessing frontend to use %s for RPC Path.", rpc_path)

    # The RPC client is meant to conform to the AsyncEngineClient protocol.
    # NOTE: not fully true yet, embedding models are still unsupported via
    # RPC (see TODO above).
    rpc_client = AsyncEngineRPCClient(rpc_path)

    # The server process holds the AsyncLLMEngine. It is spawned rather
    # than forked because the current process might have a CUDA context.
    spawn_context = multiprocessing.get_context("spawn")
    engine_process = spawn_context.Process(
        target=run_rpc_server,
        args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path))
    engine_process.start()
    logger.info("Started engine process with PID %d", engine_process.pid)

    try:
        # Probe until the server answers; give up only if it has died.
        server_ready = False
        while not server_ready:
            try:
                await rpc_client.setup()
                server_ready = True
            except TimeoutError:
                if not engine_process.is_alive():
                    logger.error("RPCServer process died before responding "
                                 "to readiness probe")
                    yield None
                    return

        yield rpc_client  # type: ignore[misc]
    finally:
        # Make sure the rpc server process is terminated.
        engine_process.terminate()

        # Close all open connections to the backend.
        rpc_client.close()

        # Wait for the server process to exit.
        engine_process.join()

        # Lazy import for prometheus multiprocessing.
        # PROMETHEUS_MULTIPROC_DIR has to be set before prometheus_client
        # is imported.
        # See https://prometheus.github.io/client_python/multiprocess/
        from prometheus_client import multiprocess
        multiprocess.mark_process_dead(engine_process.pid)

214

Ethan Xu's avatar
Ethan Xu committed
215
# Router that collects the HTTP endpoints declared below via its
# decorators; build_app() attaches it to the FastAPI application.
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
216

217

218
def mount_metrics(app: FastAPI):
    """Expose the Prometheus metrics ASGI app under ``/metrics`` on ``app``.

    When PROMETHEUS_MULTIPROC_DIR is set, metrics are served from a fresh
    CollectorRegistry fed by a MultiProcessCollector; otherwise the ASGI
    app is created with its default registry.
    """
    # Lazy import for prometheus multiprocessing.
    # PROMETHEUS_MULTIPROC_DIR has to be set before prometheus_client is
    # imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    multiproc_dir = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if multiproc_dir is None:
        metrics_app = make_asgi_app()
    else:
        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                    multiproc_dir)
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)
        metrics_app = make_asgi_app(registry=registry)

    # Add prometheus asgi middleware to route /metrics requests
    metrics_route = Mount("/metrics", metrics_app)

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)
242
243


Ethan Xu's avatar
Ethan Xu committed
244
@router.get("/health")
async def health() -> Response:
    """Return an empty 200 response once the engine health check passes."""
    await async_engine_client.check_health()
    ok_response = Response(status_code=200)
    return ok_response


Ethan Xu's avatar
Ethan Xu committed
251
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    """Handle POST /tokenize via the tokenization serving object.

    An ErrorResponse is sent back as JSON with its own status code, a
    TokenizeResponse as plain JSON.
    """
    result = await openai_serving_tokenization.create_tokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, TokenizeResponse):
        return JSONResponse(content=result.model_dump())

    # No other result type is expected here.
    assert_never(result)

262

Ethan Xu's avatar
Ethan Xu committed
263
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    """Handle POST /detokenize via the tokenization serving object.

    An ErrorResponse is sent back as JSON with its own status code, a
    DetokenizeResponse as plain JSON.
    """
    result = await openai_serving_tokenization.create_detokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, DetokenizeResponse):
        return JSONResponse(content=result.model_dump())

    # No other result type is expected here.
    assert_never(result)

274

Ethan Xu's avatar
Ethan Xu committed
275
@router.get("/v1/models")
async def show_available_models():
    """Return the model list reported by the completion serving object."""
    available = await openai_serving_completion.show_available_models()
    payload = available.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
279
280


Ethan Xu's avatar
Ethan Xu committed
281
@router.get("/version")
async def show_version():
    """Return the vLLM version as ``{"version": ...}``."""
    return JSONResponse(content={"version": VLLM_VERSION})


Ethan Xu's avatar
Ethan Xu committed
287
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """Handle POST /v1/chat/completions.

    An ErrorResponse is returned as JSON with its own status code and a
    ChatCompletionResponse as plain JSON; any other result is streamed
    back as ``text/event-stream``.
    """
    outcome = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)
    if isinstance(outcome, ChatCompletionResponse):
        return JSONResponse(content=outcome.model_dump())

    return StreamingResponse(content=outcome, media_type="text/event-stream")

303

Ethan Xu's avatar
Ethan Xu committed
304
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Handle POST /v1/completions.

    An ErrorResponse is returned as JSON with its own status code and a
    CompletionResponse as plain JSON; any other result is streamed back
    as ``text/event-stream``.
    """
    outcome = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)
    if isinstance(outcome, CompletionResponse):
        return JSONResponse(content=outcome.model_dump())

    return StreamingResponse(content=outcome, media_type="text/event-stream")

Zhuohan Li's avatar
Zhuohan Li committed
316

Ethan Xu's avatar
Ethan Xu committed
317
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    """Handle POST /v1/embeddings.

    An ErrorResponse is returned as JSON with its own status code, an
    EmbeddingResponse as plain JSON.
    """
    result = await openai_serving_embedding.create_embedding(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, EmbeddingResponse):
        return JSONResponse(content=result.model_dump())

    # No other result type is expected here.
    assert_never(result)

329

330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
# The profiler endpoints are only registered when VLLM_TORCH_PROFILER_DIR
# is set.
if envs.VLLM_TORCH_PROFILER_DIR:
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "
        "used for local development!")

    @router.post("/start_profile")
    async def start_profile():
        """Ask the engine client to start profiling, then reply 200."""
        logger.info("Starting profiler...")
        await async_engine_client.start_profile()
        logger.info("Profiler started.")
        started = Response(status_code=200)
        return started

    @router.post("/stop_profile")
    async def stop_profile():
        """Ask the engine client to stop profiling, then reply 200."""
        logger.info("Stopping profiler...")
        await async_engine_client.stop_profile()
        logger.info("Profiler stopped.")
        stopped = Response(status_code=200)
        return stopped


350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
# The LoRA load/unload endpoints are only registered when
# VLLM_ALLOW_RUNTIME_LORA_UPDATING is set.
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
    logger.warning(
        "Lora dynamic loading & unloading is enabled in the API server. "
        "This should ONLY be used for local development!")

    @router.post("/v1/load_lora_adapter")
    async def load_lora_adapter(request: LoadLoraAdapterRequest):
        """Load a LoRA adapter on the chat, then the completion server.

        The first ErrorResponse is returned as JSON with its status code;
        otherwise the completion server's result is the 200 response body.
        """
        chat_result = await openai_serving_chat.load_lora_adapter(request)
        if isinstance(chat_result, ErrorResponse):
            return JSONResponse(content=chat_result.model_dump(),
                                status_code=chat_result.code)

        completion_result = (
            await openai_serving_completion.load_lora_adapter(request))
        if isinstance(completion_result, ErrorResponse):
            return JSONResponse(content=completion_result.model_dump(),
                                status_code=completion_result.code)

        return Response(status_code=200, content=completion_result)

    @router.post("/v1/unload_lora_adapter")
    async def unload_lora_adapter(request: UnloadLoraAdapterRequest):
        """Unload a LoRA adapter on the chat, then the completion server.

        The first ErrorResponse is returned as JSON with its status code;
        otherwise the completion server's result is the 200 response body.
        """
        chat_result = await openai_serving_chat.unload_lora_adapter(request)
        if isinstance(chat_result, ErrorResponse):
            return JSONResponse(content=chat_result.model_dump(),
                                status_code=chat_result.code)

        completion_result = (
            await openai_serving_completion.unload_lora_adapter(request))
        if isinstance(completion_result, ErrorResponse):
            return JSONResponse(content=completion_result.model_dump(),
                                status_code=completion_result.code)

        return Response(status_code=200, content=completion_result)


384
385
def build_app(args: Namespace) -> FastAPI:
    """Create the FastAPI application and wire up routes and middleware.

    Steps, in order: include the module router, set the root path, mount
    the metrics endpoint, add CORS middleware from the ``allowed_*`` /
    ``allow_credentials`` args, register a handler that turns request
    validation errors into a 400 error response, optionally add API-key
    authentication, and finally install the user-supplied middleware.

    Args:
        args: Parsed CLI arguments.

    Returns:
        The configured FastAPI application.

    Raises:
        ValueError: If an entry of ``args.middleware`` resolves to something
            that is neither a class nor a coroutine function.
    """
    # Function-scope import: only needed for the optional API-key check.
    import secrets

    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    # The environment variable takes precedence over the CLI argument.
    if token := envs.VLLM_API_KEY or args.api_key:
        # Encoded once up front; compare_digest needs bytes to accept
        # non-ASCII input.
        expected_auth = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            root_path = "" if args.root_path is None else args.root_path
            # OPTIONS requests pass through without a key check.
            if request.method == "OPTIONS":
                return await call_next(request)
            # Only paths under <root_path>/v1 are protected.
            if not request.url.path.startswith(f"{root_path}/v1"):
                return await call_next(request)
            # A missing header becomes "", which never matches the
            # non-empty expected value. The comparison is constant-time so
            # response timing does not reveal how much of the key matched.
            supplied_auth = request.headers.get("Authorization", "")
            if not secrets.compare_digest(supplied_auth.encode("utf-8"),
                                          expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    # Each entry is a dotted import path "package.module.object".
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


433
async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    """Build the app and create the module-level OpenAI serving objects.

    The chat, completion, embedding and tokenization serving objects are
    all constructed from the given engine client, its model config and the
    served model names, then stored in module globals for the handlers.
    """
    global openai_serving_chat, openai_serving_completion
    global openai_serving_embedding, openai_serving_tokenization

    app = build_app(args)

    # Without an explicit served name, serve under the model name itself.
    served_model_names = (args.served_model_name
                          if args.served_model_name is not None else
                          [args.model])

    model_config = await async_engine_client.get_model_config()

    request_logger = (None if args.disable_log_requests else RequestLogger(
        max_log_len=args.max_log_len))

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
        enable_auto_tools=args.enable_auto_tool_choice,
        tool_parser=args.tool_call_parser)

    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )

    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )

    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )

    app.root_path = args.root_path
    return app
494
495


496
async def run_server(args, **uvicorn_kwargs) -> None:
    """Start the engine client and serve the API until shutdown.

    Returns early, without serving, when the engine client could not be
    created. Extra keyword arguments are forwarded to ``serve_http``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    async with build_async_engine_client(args) as engine_client:
        # None means the client could not be created: nothing to serve.
        if engine_client is None:
            return

        app = await init_app(engine_client, args)

        shutdown_task = await serve_http(
            app,
            engine=engine_client,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: the server shutdown is awaited only after the backend context
    # has been exited.
    await shutdown_task
523

Ethan Xu's avatar
Ethan Xu committed
524
525
526
527
528
529
530
531

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args()

    asyncio.run(run_server(args))