api_server.py 17.8 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
import multiprocessing
5
import os
6
import re
7
import tempfile
8
from argparse import Namespace
9
from contextlib import asynccontextmanager
10
from http import HTTPStatus
11
from typing import AsyncIterator, Optional, Set
12

13
from fastapi import APIRouter, FastAPI, Request
Zhuohan Li's avatar
Zhuohan Li committed
14
15
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
16
from fastapi.responses import JSONResponse, Response, StreamingResponse
17
from starlette.routing import Mount
18
from typing_extensions import assert_never
Zhuohan Li's avatar
Zhuohan Li committed
19

20
import vllm.envs as envs
21
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
24
from vllm.engine.protocol import AsyncEngineClient
25
from vllm.entrypoints.launcher import serve_http
26
from vllm.entrypoints.logger import RequestLogger
27
from vllm.entrypoints.openai.cli_args import make_arg_parser
28
29
# yapf conflicts with isort for this block
# yapf: disable
30
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
31
                                              ChatCompletionResponse,
32
                                              CompletionRequest,
33
                                              CompletionResponse,
34
35
                                              DetokenizeRequest,
                                              DetokenizeResponse,
36
37
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
38
39
                                              TokenizeRequest,
                                              TokenizeResponse)
40
# yapf: enable
41
42
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
43
44
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
45
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
46
47
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
48
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
49
from vllm.usage.usage_lib import UsageContext
50
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
51
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
52

53
# Keep-alive timeout (in seconds) passed to the HTTP server in run_server().
TIMEOUT_KEEP_ALIVE = 5

# Module-level state. These names are only declared here; they are bound
# later by the functions in this module that declare them ``global``.
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization
prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')

# Background tasks started by lifespan() are tracked in this set.
_running_tasks: Set[asyncio.Task] = set()
67

68

69
def model_is_embedding(model_name: str, trust_remote_code: bool,
                       quantization: Optional[str]) -> bool:
    """Return the ``embedding_mode`` flag of a ModelConfig for ``model_name``.

    A ModelConfig is constructed only to read that flag; ``model_name`` is
    used for both the model and the tokenizer, and the remaining settings
    are fixed ("auto" tokenizer mode and dtype, seed 0).
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        quantization=quantization,
        seed=0,
        dtype="auto",
    )
    return probe_config.embedding_mode
78
79


80
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that periodically logs engine statistics.

    Unless ``engine_args.disable_log_stats`` is set, a background task is
    started that calls ``async_engine_client.do_log_stats()`` every 10
    seconds. The task is cancelled when the application shuts down so that
    it does not keep polling the engine while it is being torn down.
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    task: Optional[asyncio.Task] = None
    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        # Keep a strong reference so the task is not garbage collected
        # while it is still running; drop it once the task finishes.
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.discard)

    try:
        yield
    finally:
        # Stop the periodic logger on shutdown (also runs if startup of
        # the application body fails with an exception).
        if task is not None:
            task.cancel()


96
@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """Context manager owning the lifecycle of the engine client.

    Parses ``args`` into the module-level ``engine_args``, delegates client
    construction (and cleanup on error/exit) to
    ``build_async_engine_client_from_engine_args`` and yields whatever that
    context manager yields.
    """
    # Both names are published at module level: ``engine_args`` is read by
    # lifespan(), and the client is read by handlers such as /health.
    global engine_args, async_engine_client

    engine_args = AsyncEngineArgs.from_cli_args(args)

    client_cm = build_async_engine_client_from_engine_args(
        engine_args, args.disable_frontend_multiprocessing)
    async with client_cm as engine:
        async_engine_client = engine  # type: ignore[assignment]
        yield engine


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    disable_frontend_multiprocessing: bool = False,
) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """
    Yield an AsyncEngineClient built from ``engine_args``, either:
        - in-process, using the AsyncLLMEngine directly
        - multiprocess, using an RPC client that talks to an AsyncLLMEngine
          running in a spawned server process

    Yields None if the RPC server process dies before answering the
    readiness probe. Cleanup (engine loop shutdown, or process termination
    and client close) runs when the context exits.
    """
    global prometheus_multiproc_dir

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    use_in_process = (model_is_embedding(engine_args.model,
                                         engine_args.trust_remote_code,
                                         engine_args.quantization)
                      or disable_frontend_multiprocessing)
    if use_in_process:
        in_process_engine = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        try:
            yield in_process_engine
        finally:
            in_process_engine.shutdown_background_loop()
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
        logger.warning(
            "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
            "This directory must be wiped between vLLM runs or "
            "you will find inaccurate metrics. Unset the variable "
            "and vLLM will properly handle cleanup.")
    else:
        # Make TemporaryDirectory for prometheus multiprocessing.
        # Note: the global TemporaryDirectory will be automatically
        #   cleaned up upon exit.
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name

    # Select random path for IPC.
    rpc_path = get_open_zmq_ipc_path()
    logger.info("Multiprocessing frontend to use %s for RPC Path.", rpc_path)

    # Build RPCClient, which conforms to AsyncEngineClient Protocol.
    # NOTE: Actually, this is not true yet. We still need to support
    # embedding models via RPC (see TODO above)
    rpc_client = AsyncEngineRPCClient(rpc_path)

    # Start RPCServer in separate process (holds the AsyncLLMEngine).
    # The current process might have a CUDA context, so the server must be
    # started with the "spawn" method.
    spawn_ctx = multiprocessing.get_context("spawn")
    engine_proc = spawn_ctx.Process(
        target=run_rpc_server,
        args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path))
    engine_proc.start()
    logger.info("Started engine process with PID %d", engine_proc.pid)

    try:
        # Retry the readiness probe until it succeeds; give up only when
        # a probe times out and the server process is no longer alive.
        while True:
            try:
                await rpc_client.setup()
            except TimeoutError:
                if engine_proc.is_alive():
                    continue
                logger.error(
                    "RPCServer process died before responding "
                    "to readiness probe")
                yield None
                return
            else:
                break

        yield rpc_client  # type: ignore[misc]
    finally:
        # Ensure rpc server process was terminated
        engine_proc.terminate()

        # Close all open connections to the backend
        rpc_client.close()

        # Wait for server process to join
        engine_proc.join()

        # Lazy import for prometheus multiprocessing.
        # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
        # before prometheus_client is imported.
        # See https://prometheus.github.io/client_python/multiprocess/
        from prometheus_client import multiprocess
        multiprocess.mark_process_dead(engine_proc.pid)

210

Ethan Xu's avatar
Ethan Xu committed
211
# Router that the endpoint handlers in this module register on; build_app()
# includes it in the FastAPI application.
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
212

213

214
def mount_metrics(app: FastAPI):
    """Attach the Prometheus ASGI app to ``app`` under ``/metrics``.

    When ``PROMETHEUS_MULTIPROC_DIR`` is set in the environment, the metrics
    app is backed by a fresh CollectorRegistry with a MultiProcessCollector
    attached; otherwise the default ``make_asgi_app()`` is used.
    """
    # Lazy import for prometheus multiprocessing.
    # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    multiproc_dir = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if multiproc_dir is None:
        metrics_app = make_asgi_app()
    else:
        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                    multiproc_dir)
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)
        metrics_app = make_asgi_app(registry=registry)

    # Add prometheus asgi middleware to route /metrics requests
    metrics_route = Mount("/metrics", metrics_app)

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)
238
239


Ethan Xu's avatar
Ethan Xu committed
240
@router.get("/health")
async def health() -> Response:
    """Health check.

    Awaits the engine client's ``check_health()`` and, if that returns
    without raising, answers with an empty 200 response.
    """
    await async_engine_client.check_health()
    ok_response = Response(status_code=200)
    return ok_response


Ethan Xu's avatar
Ethan Xu committed
247
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    """Handle ``POST /tokenize``.

    Returns the serving layer's result as JSON; an ErrorResponse is sent
    with its own ``code`` as the HTTP status.
    """
    result = await openai_serving_tokenization.create_tokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, TokenizeResponse):
        return JSONResponse(content=result.model_dump())

    # Every expected result type is handled above.
    assert_never(result)

258

Ethan Xu's avatar
Ethan Xu committed
259
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    """Handle ``POST /detokenize``.

    Returns the serving layer's result as JSON; an ErrorResponse is sent
    with its own ``code`` as the HTTP status.
    """
    result = await openai_serving_tokenization.create_detokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, DetokenizeResponse):
        return JSONResponse(content=result.model_dump())

    # Every expected result type is handled above.
    assert_never(result)

270

Ethan Xu's avatar
Ethan Xu committed
271
@router.get("/v1/models")
async def show_available_models():
    """Handle ``GET /v1/models``: return the completion server's model list
    as JSON."""
    model_list = await openai_serving_completion.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
275
276


Ethan Xu's avatar
Ethan Xu committed
277
@router.get("/version")
async def show_version():
    """Handle ``GET /version``: return ``{"version": VLLM_VERSION}`` as
    JSON."""
    return JSONResponse(content={"version": VLLM_VERSION})


Ethan Xu's avatar
Ethan Xu committed
283
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """Handle ``POST /v1/chat/completions``.

    An ErrorResponse is returned as JSON with its own ``code`` as status, a
    ChatCompletionResponse as plain JSON, and any other result is sent as a
    ``text/event-stream`` streaming response.
    """
    result = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, ChatCompletionResponse):
        return JSONResponse(content=result.model_dump())

    # Neither an error nor a complete response: stream it to the client.
    return StreamingResponse(content=result, media_type="text/event-stream")

296

Ethan Xu's avatar
Ethan Xu committed
297
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Handle ``POST /v1/completions``.

    An ErrorResponse is returned as JSON with its own ``code`` as status, a
    CompletionResponse as plain JSON, and any other result is sent as a
    ``text/event-stream`` streaming response.
    """
    result = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, CompletionResponse):
        return JSONResponse(content=result.model_dump())

    # Neither an error nor a complete response: stream it to the client.
    return StreamingResponse(content=result, media_type="text/event-stream")

Zhuohan Li's avatar
Zhuohan Li committed
309

Ethan Xu's avatar
Ethan Xu committed
310
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    """Handle ``POST /v1/embeddings``.

    Returns the serving layer's result as JSON; an ErrorResponse is sent
    with its own ``code`` as the HTTP status.
    """
    result = await openai_serving_embedding.create_embedding(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, EmbeddingResponse):
        return JSONResponse(content=result.model_dump())

    # Every expected result type is handled above.
    assert_never(result)

322

323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
# The profiler endpoints are only registered when VLLM_TORCH_PROFILER_DIR is
# set, so they do not exist on the router otherwise.
if envs.VLLM_TORCH_PROFILER_DIR:
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "
        "used for local development!")

    @router.post("/start_profile")
    async def start_profile():
        """Ask the engine client to start profiling; reply 200 when done."""
        logger.info("Starting profiler...")
        await async_engine_client.start_profile()
        logger.info("Profiler started.")
        started = Response(status_code=200)
        return started

    @router.post("/stop_profile")
    async def stop_profile():
        """Ask the engine client to stop profiling; reply 200 when done."""
        logger.info("Stopping profiler...")
        await async_engine_client.stop_profile()
        logger.info("Profiler stopped.")
        stopped = Response(status_code=200)
        return stopped


343
344
def build_app(args: Namespace) -> FastAPI:
    """Create the FastAPI application and wire up routes and middleware.

    Steps, in order: include the module router, set the root path, mount
    the ``/metrics`` route, add CORS middleware from ``args``, install a
    handler that turns request validation errors into 400 responses, add
    bearer-token authentication for ``{root_path}/v1`` paths when an API key
    is configured (``envs.VLLM_API_KEY`` or ``args.api_key``), and finally
    add every user-supplied middleware listed in ``args.middleware``.

    Raises:
        ValueError: if an entry of ``args.middleware`` resolves to something
            that is neither a class nor a coroutine function.
    """
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:
        # Local import: only needed when authentication is enabled.
        import secrets

        # Compare as bytes so compare_digest accepts any header content.
        expected_auth = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            root_path = "" if args.root_path is None else args.root_path
            # CORS preflight requests are passed through unauthenticated.
            if request.method == "OPTIONS":
                return await call_next(request)
            # Only paths under {root_path}/v1 require the API key.
            if not request.url.path.startswith(f"{root_path}/v1"):
                return await call_next(request)
            supplied_auth = request.headers.get("Authorization",
                                                "").encode("utf-8")
            # Constant-time comparison: a plain != would leak, through
            # response timing, how much of the secret a guess got right.
            if not secrets.compare_digest(supplied_auth, expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        # Each entry is a dotted path "package.module.object".
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             "Must be a function or a class.")

    return app


392
async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    """Build the app and create the module-level OpenAI serving handlers.

    The chat, completion, embedding and tokenization handlers are all built
    on ``async_engine_client`` and its model config, then stored in the
    module globals that the route handlers read.
    """
    global openai_serving_chat, openai_serving_completion
    global openai_serving_embedding, openai_serving_tokenization

    app = build_app(args)

    # Fall back to the model path when no served names were given.
    served_model_names = (args.served_model_name
                          if args.served_model_name is not None else
                          [args.model])

    model_config = await async_engine_client.get_model_config()

    request_logger = (None if args.disable_log_requests else RequestLogger(
        max_log_len=args.max_log_len))

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )

    app.root_path = args.root_path
    return app
452
453


454
async def run_server(args, **uvicorn_kwargs) -> None:
    """Start the engine client, build the app and serve it over HTTP.

    Returns early, without serving, when the engine client could not be
    created. Extra keyword arguments are forwarded to ``serve_http``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    async with build_async_engine_client(args) as engine_client:
        # If None, creation of the client failed and we exit.
        if engine_client is None:
            return

        app = await init_app(engine_client, args)

        shutdown_task = await serve_http(
            app,
            engine=engine_client,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task
481

Ethan Xu's avatar
Ethan Xu committed
482
483
484
485
486
487
488
489

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args()

    # Run the server until it shuts down.
    asyncio.run(run_server(args))