api_server.py 17.3 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
import multiprocessing
5
import os
6
import re
7
import tempfile
8
from argparse import Namespace
9
10
from contextlib import asynccontextmanager
from http import HTTPStatus
11
from typing import AsyncIterator, Optional, Set
12

13
from fastapi import APIRouter, FastAPI, Request
Zhuohan Li's avatar
Zhuohan Li committed
14
15
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
16
from fastapi.responses import JSONResponse, Response, StreamingResponse
17
from starlette.routing import Mount
18
from typing_extensions import assert_never
Zhuohan Li's avatar
Zhuohan Li committed
19

20
import vllm.envs as envs
21
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
24
from vllm.engine.protocol import AsyncEngineClient
25
from vllm.entrypoints.launcher import serve_http
26
from vllm.entrypoints.logger import RequestLogger
27
from vllm.entrypoints.openai.cli_args import make_arg_parser
28
29
# yapf conflicts with isort for this block
# yapf: disable
30
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
31
                                              ChatCompletionResponse,
32
                                              CompletionRequest,
33
                                              CompletionResponse,
34
35
                                              DetokenizeRequest,
                                              DetokenizeResponse,
36
37
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
38
39
                                              TokenizeRequest,
                                              TokenizeResponse)
40
# yapf: enable
41
42
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
43
44
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
45
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
46
47
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
48
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
49
from vllm.usage.usage_lib import UsageContext
50
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
51
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
52

53
# HTTP keep-alive timeout in seconds; run_server() passes it to serve_http()
# as timeout_keep_alive.
TIMEOUT_KEEP_ALIVE = 5

# Module-level state. These names are only declared (annotated) here:
# build_async_engine_client() and init_app() assign them at startup, and
# lifespan() plus the route handlers read them.
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization
# Only assigned when this process creates the Prometheus multiprocess dir.
prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')

# Strong references to background tasks started with asyncio.create_task();
# the event loop itself only holds weak references to tasks.
_running_tasks: Set[asyncio.Task] = set()
67

68

69
70
def model_is_embedding(model_name: str, trust_remote_code: bool,
                       quantization: Optional[str]) -> bool:
    """Tell whether *model_name* would be served in embedding mode.

    A throwaway ModelConfig is constructed only to read its
    ``embedding_mode`` flag. The tokenizer is taken from the same path,
    tokenizer mode and dtype are "auto", and the seed is fixed at 0.
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        quantization=quantization,
        seed=0,
        dtype="auto",
    )
    return probe_config.embedding_mode
78
79


80
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Unless ``engine_args.disable_log_stats`` is set, this starts a background
    task that asks the engine client to log its stats every 10 seconds while
    the app runs. The task is cancelled when the app shuts down.
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    task = None
    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        # Hold a strong reference while the task runs; the done-callback
        # drops it again once the task finishes or is cancelled.
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.remove)

    try:
        yield
    finally:
        # Stop the periodic stats logger on shutdown instead of leaving it
        # polling an engine client that is being torn down.
        if task is not None:
            task.cancel()


96
@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """
    Context manager that creates an AsyncEngineClient and owns its lifetime.

    The client is one of two kinds:
        - An AsyncLLMEngine running in this process. This is used for
          embedding models or when
          ``args.disable_frontend_multiprocessing`` is set.
        - An AsyncEngineRPCClient that talks over an IPC path to an RPC
          server process, spawned here, which hosts the AsyncLLMEngine.

    Yields the client, or None if the RPC server process died before it
    answered the readiness probe. In the multiprocess case, leaving the
    context terminates and joins the server process and closes the client.
    """
    # Module globals: lifespan() reads engine_args and the /health handler
    # reads async_engine_client.
    global engine_args, async_engine_client, prometheus_multiproc_dir
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Embedding models are not served over RPC yet. They, and users who
    # opt out of frontend multiprocessing, get an in-process engine.
    # TODO: support embedding model via RPC.
    use_in_process_engine = (model_is_embedding(args.model,
                                                args.trust_remote_code,
                                                args.quantization)
                             or args.disable_frontend_multiprocessing)
    if use_in_process_engine:
        async_engine_client = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        yield async_engine_client
        return

    # From here on: multiprocess frontend, engine in a spawned process.

    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
        logger.warning(
            "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
            "This directory must be wiped between vLLM runs or "
            "you will find inaccurate metrics. Unset the variable "
            "and vLLM will properly handle cleanup.")
    else:
        # Prometheus multiprocess mode needs a scratch directory. Keeping
        # the TemporaryDirectory in a module global lets it be cleaned up
        # automatically when the interpreter exits.
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name

    # Select a random IPC path for the RPC channel.
    ipc_path = get_open_zmq_ipc_path()
    logger.info("Multiprocessing frontend to use %s for RPC Path.",
                ipc_path)

    # The RPC client is meant to satisfy the AsyncEngineClient protocol,
    # but does not fully do so yet (embedding models, see the TODO above).
    rpc_client = AsyncEngineRPCClient(ipc_path)
    async_engine_client = rpc_client  # type: ignore

    # This process may already hold a CUDA context, so the engine process
    # is started with "spawn" rather than being forked.
    spawn_ctx = multiprocessing.get_context("spawn")
    engine_process = spawn_ctx.Process(
        target=run_rpc_server,
        args=(engine_args, UsageContext.OPENAI_API_SERVER, ipc_path))
    engine_process.start()
    logger.info("Started engine process with PID %d", engine_process.pid)

    try:
        # Keep probing until the server answers. A timeout only becomes
        # fatal once the server process has exited.
        while True:
            try:
                await rpc_client.setup()
            except TimeoutError:
                if not engine_process.is_alive():
                    logger.error(
                        "RPCServer process died before responding "
                        "to readiness probe")
                    yield None
                    return
            else:
                break

        yield async_engine_client
    finally:
        engine_process.terminate()
        rpc_client.close()
        engine_process.join()

        # Lazy import: prometheus_client must not be imported before
        # PROMETHEUS_MULTIPROC_DIR has been set.
        # See https://prometheus.github.io/client_python/multiprocess/
        from prometheus_client import multiprocess
        multiprocess.mark_process_dead(engine_process.pid)

195

Ethan Xu's avatar
Ethan Xu committed
196
# Shared router for the HTTP endpoints defined in this module; build_app()
# attaches it to the FastAPI application via include_router().
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
197

198

199
def mount_metrics(app: FastAPI):
    """Serve Prometheus metrics from *app* under ``/metrics``.

    If ``PROMETHEUS_MULTIPROC_DIR`` is set, metrics are aggregated across
    processes through a dedicated CollectorRegistry. Otherwise the
    prometheus_client default registry is exposed.
    """
    # Import lazily: PROMETHEUS_MULTIPROC_DIR has to be in the environment
    # before prometheus_client is first imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    multiproc_dir = os.environ.get("PROMETHEUS_MULTIPROC_DIR")
    if multiproc_dir is None:
        metrics_app = make_asgi_app()
    else:
        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                    multiproc_dir)
        multiproc_registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(multiproc_registry)
        metrics_app = make_asgi_app(registry=multiproc_registry)

    metrics_route = Mount("/metrics", metrics_app)
    # Widen the route regex so that "/metrics" itself matches without
    # Starlette answering with a 307 redirect to "/metrics/".
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)
223
224


Ethan Xu's avatar
Ethan Xu committed
225
@router.get("/health")
async def health() -> Response:
    """Health check."""
    # check_health() is awaited first; if it raises, the exception
    # propagates instead of the empty 200 below being sent.
    await async_engine_client.check_health()
    return Response(status_code=HTTPStatus.OK.value)


Ethan Xu's avatar
Ethan Xu committed
232
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    # The tokenization service returns either an ErrorResponse (which
    # carries its own HTTP status code) or a TokenizeResponse.
    outcome = await openai_serving_tokenization.create_tokenize(request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)
    if isinstance(outcome, TokenizeResponse):
        return JSONResponse(content=outcome.model_dump())

    assert_never(outcome)

243

Ethan Xu's avatar
Ethan Xu committed
244
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    # The tokenization service returns either an ErrorResponse (which
    # carries its own HTTP status code) or a DetokenizeResponse.
    outcome = await openai_serving_tokenization.create_detokenize(request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)
    if isinstance(outcome, DetokenizeResponse):
        return JSONResponse(content=outcome.model_dump())

    assert_never(outcome)

255

Ethan Xu's avatar
Ethan Xu committed
256
@router.get("/v1/models")
async def show_available_models():
    # The model listing is served by the completion service object.
    model_list = await openai_serving_completion.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
260
261


Ethan Xu's avatar
Ethan Xu committed
262
@router.get("/version")
async def show_version():
    # Report the installed vLLM package version as {"version": ...}.
    return JSONResponse(content=dict(version=VLLM_VERSION))


Ethan Xu's avatar
Ethan Xu committed
268
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    # The chat service returns one of three things:
    #   - an ErrorResponse, sent with its own status code;
    #   - a complete ChatCompletionResponse, sent as JSON;
    #   - otherwise, streamed content, sent as text/event-stream.
    outcome = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)
    if isinstance(outcome, ChatCompletionResponse):
        return JSONResponse(content=outcome.model_dump())

    return StreamingResponse(content=outcome, media_type="text/event-stream")

281

Ethan Xu's avatar
Ethan Xu committed
282
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    # The completion service returns one of three things:
    #   - an ErrorResponse, sent with its own status code;
    #   - a complete CompletionResponse, sent as JSON;
    #   - otherwise, streamed content, sent as text/event-stream.
    outcome = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)
    if isinstance(outcome, CompletionResponse):
        return JSONResponse(content=outcome.model_dump())

    return StreamingResponse(content=outcome, media_type="text/event-stream")

Zhuohan Li's avatar
Zhuohan Li committed
294

Ethan Xu's avatar
Ethan Xu committed
295
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    # The embedding service returns either an ErrorResponse (which carries
    # its own HTTP status code) or an EmbeddingResponse.
    outcome = await openai_serving_embedding.create_embedding(
        request, raw_request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)
    if isinstance(outcome, EmbeddingResponse):
        return JSONResponse(content=outcome.model_dump())

    assert_never(outcome)

307

308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
# Profiler control endpoints. This check runs once, when the module is
# imported: /start_profile and /stop_profile are added to `router` only if
# VLLM_TORCH_PROFILER_DIR is set.
if envs.VLLM_TORCH_PROFILER_DIR:
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "
        "used for local development!")

    @router.post("/start_profile")
    async def start_profile():
        # Forward to the engine client's start_profile(); reply with an
        # empty 200 once that call returns.
        logger.info("Starting profiler...")
        await async_engine_client.start_profile()
        logger.info("Profiler started.")
        return Response(status_code=200)

    @router.post("/stop_profile")
    async def stop_profile():
        # Forward to the engine client's stop_profile(); reply with an
        # empty 200 once that call returns.
        logger.info("Stopping profiler...")
        await async_engine_client.stop_profile()
        logger.info("Profiler stopped.")
        return Response(status_code=200)


328
329
def build_app(args: Namespace) -> FastAPI:
    """Create the FastAPI application for the OpenAI-compatible server.

    Wires up the following:
      - the API router;
      - the Prometheus ``/metrics`` mount;
      - CORS;
      - a handler that turns request-validation errors into OpenAI-style
        400 responses;
      - optional bearer-token authentication for ``/v1`` routes;
      - any user-supplied middleware named in ``args.middleware``.

    Raises:
        ValueError: if an entry of ``args.middleware`` resolves to something
            that is neither a class nor a coroutine function.
    """
    import secrets

    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        # Report malformed requests in the OpenAI error format with HTTP 400.
        err = openai_serving_chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:
        # Precompute once instead of on every request.
        root_path = "" if args.root_path is None else args.root_path
        # Compare as bytes with secrets.compare_digest. It runs in constant
        # time, so response timing does not reveal how much of the key
        # matched. Encoding first also avoids its TypeError on non-ASCII str.
        expected_auth = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            # OPTIONS requests (e.g. CORS preflight) pass through untouched.
            if request.method == "OPTIONS":
                return await call_next(request)
            # Only the OpenAI API routes under /v1 require the key.
            if not request.url.path.startswith(f"{root_path}/v1"):
                return await call_next(request)
            provided_auth = request.headers.get("Authorization", "")
            if not secrets.compare_digest(provided_auth.encode("utf-8"),
                                          expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


377
async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    """Build the FastAPI app and the OpenAI serving objects behind it.

    The four serving objects are published as module globals, which is
    where the route handlers look them up.

    Args:
        async_engine_client: engine client shared by all serving objects.
        args: parsed command-line arguments.

    Returns:
        The configured FastAPI application.
    """
    global openai_serving_chat, openai_serving_completion
    global openai_serving_embedding, openai_serving_tokenization

    app = build_app(args)

    served_model_names = ([args.model] if args.served_model_name is None
                          else args.served_model_name)

    model_config = await async_engine_client.get_model_config()

    request_logger = (None if args.disable_log_requests else
                      RequestLogger(max_log_len=args.max_log_len))

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )

    # Same value build_app() already assigned; re-applied here.
    app.root_path = args.root_path
    return app
437
438


439
async def run_server(args, **uvicorn_kwargs) -> None:
    """Run the OpenAI-compatible HTTP server until it shuts down.

    The engine client is created, the app is built around it, and
    serve_http() is started. The shutdown task it returns is awaited only
    after the engine-client context has been exited.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    async with build_async_engine_client(args) as engine_client:
        if engine_client is None:
            # The engine backend failed to come up; nothing to serve.
            return

        app = await init_app(engine_client, args)

        shutdown_task = await serve_http(
            app,
            engine=engine_client,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: the backend context is exited (and cleaned up) before we wait
    # for the server to finish shutting down.
    await shutdown_task
466

Ethan Xu's avatar
Ethan Xu committed
467
468
469
470
471
472
473
474

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args()

    asyncio.run(run_server(args))