api_server.py 17.4 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
import multiprocessing
5
import os
6
import re
7
import tempfile
8
from argparse import Namespace
9
from contextlib import asynccontextmanager, suppress
10
from http import HTTPStatus
11
from typing import AsyncIterator, Optional, Set
12

13
from fastapi import APIRouter, FastAPI, Request
Zhuohan Li's avatar
Zhuohan Li committed
14
15
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
16
from fastapi.responses import JSONResponse, Response, StreamingResponse
17
from starlette.routing import Mount
18
from typing_extensions import assert_never
Zhuohan Li's avatar
Zhuohan Li committed
19

20
import vllm.envs as envs
21
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
24
from vllm.engine.protocol import AsyncEngineClient
25
from vllm.entrypoints.launcher import serve_http
26
from vllm.entrypoints.logger import RequestLogger
27
from vllm.entrypoints.openai.cli_args import make_arg_parser
28
29
# yapf conflicts with isort for this block
# yapf: disable
30
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
31
                                              ChatCompletionResponse,
32
                                              CompletionRequest,
33
                                              CompletionResponse,
34
35
                                              DetokenizeRequest,
                                              DetokenizeResponse,
36
37
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
38
39
                                              TokenizeRequest,
                                              TokenizeResponse)
40
# yapf: enable
41
42
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
43
44
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
45
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
46
47
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
48
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
49
from vllm.usage.usage_lib import UsageContext
50
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
51
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
52

53
# Passed to serve_http() as `timeout_keep_alive` in run_server().
TIMEOUT_KEEP_ALIVE = 5  # seconds

# Module-level state read by the route handlers in this file. The names are
# only declared (annotated) here; they are bound at startup:
#   - build_async_engine_client() assigns `engine_args` and
#     `async_engine_client`, and may assign `prometheus_multiproc_dir`.
#   - init_app() assigns the four `openai_serving_*` objects.
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization
prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger('vllm.entrypoints.openai.api_server')

# References to background tasks created in lifespan(); each task removes
# itself from this set through a done-callback.
_running_tasks: Set[asyncio.Task] = set()
67

68

69
70
def model_is_embedding(model_name: str, trust_remote_code: bool,
                       quantization: str) -> bool:
    """Return the ``embedding_mode`` flag of a ModelConfig built for the model.

    A throwaway ModelConfig is constructed with ``model_name`` used both as
    the model and as the tokenizer, and its ``embedding_mode`` attribute is
    returned.
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        quantization=quantization,
        seed=0,
        dtype="auto",
    )
    return probe_config.embedding_mode
78
79


80
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that runs periodic engine stat logging.

    Unless ``engine_args.disable_log_stats`` is set, a background task is
    started that calls ``async_engine_client.do_log_stats()`` every 10
    seconds. The task is cancelled when the application shuts down so it is
    not left pending on the event loop.

    ``app`` is required by the lifespan signature but is not used here.
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            # Best-effort: a failed stats call must not stop the loop.
            # Cancellation is not an ``Exception`` subclass, so it still
            # propagates and ends the task.
            with suppress(Exception):
                await async_engine_client.do_log_stats()

    stats_task: Optional[asyncio.Task] = None
    if not engine_args.disable_log_stats:
        stats_task = asyncio.create_task(_force_log())
        # Keep a reference in the module-level set while the task is alive.
        _running_tasks.add(stats_task)
        # discard() rather than remove(): never raises if already gone.
        stats_task.add_done_callback(_running_tasks.discard)

    try:
        yield
    finally:
        # Stop the logging loop on shutdown instead of leaking the task.
        if stats_task is not None:
            stats_task.cancel()


97
@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]:
    """
    Async context manager that yields the engine client for the server.

    Two modes:
        - in-process: an AsyncLLMEngine built directly in this process
        - multiprocess: an AsyncEngineRPCClient talking to an RPC server
          started in a separate (spawned) process

    Yields the client, or None if the RPC server process died before it
    answered the readiness probe. In multiprocess mode the server process is
    terminated and joined, and the client is closed, when the context exits.
    """
    # The engine args and the client stay module-global: the /health handler
    # (and lifespan) read them directly.
    global engine_args, async_engine_client, prometheus_multiproc_dir

    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Embedding models and an explicit opt-out both run the engine in-process.
    # TODO: support embedding model via RPC.
    run_in_process = (model_is_embedding(args.model, args.trust_remote_code,
                                         args.quantization)
                      or args.disable_frontend_multiprocessing)
    if run_in_process:
        async_engine_client = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        yield async_engine_client
        return

    # ---- Multiprocess path ----

    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
        logger.warning(
            "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
            "This directory must be wiped between vLLM runs or "
            "you will find inaccurate metrics. Unset the variable "
            "and vLLM will properly handle cleanup.")
    else:
        # Create a TemporaryDirectory for prometheus multiprocessing.
        # Note: the global TemporaryDirectory is cleaned up automatically
        #   upon exit.
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name

    # Pick a random IPC path for the RPC channel.
    rpc_path = get_open_zmq_ipc_path()
    logger.info("Multiprocessing frontend to use %s for RPC Path.",
                rpc_path)

    # The RPC client is meant to conform to the AsyncEngineClient protocol.
    # NOTE: not fully true yet, embedding models are not supported via RPC
    # (see TODO above).
    rpc_client = AsyncEngineRPCClient(rpc_path)
    async_engine_client = rpc_client  # type: ignore

    # The RPC server (which holds the AsyncLLMEngine) runs in its own process.
    # The current process might already have a CUDA context, so the child
    # must be spawned rather than forked.
    spawn_ctx = multiprocessing.get_context("spawn")
    engine_proc = spawn_ctx.Process(
        target=run_rpc_server,
        args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path))
    engine_proc.start()
    logger.info("Started engine process with PID %d", engine_proc.pid)

    try:
        # Retry the readiness probe until it succeeds, giving up only when
        # the server process is no longer alive.
        engine_ready = False
        while not engine_ready:
            try:
                await rpc_client.setup()
            except TimeoutError:
                if not engine_proc.is_alive():
                    logger.error(
                        "RPCServer process died before responding "
                        "to readiness probe")
                    yield None
                    return
            else:
                engine_ready = True

        yield async_engine_client
    finally:
        # Make sure the rpc server process is terminated
        engine_proc.terminate()

        # Close all open connections to the backend
        rpc_client.close()

        # Wait for the server process to exit
        engine_proc.join()

        # Lazy import for prometheus multiprocessing.
        # PROMETHEUS_MULTIPROC_DIR must be set before prometheus_client
        # is imported.
        # See https://prometheus.github.io/client_python/multiprocess/
        from prometheus_client import multiprocess
        multiprocess.mark_process_dead(engine_proc.pid)

196

Ethan Xu's avatar
Ethan Xu committed
197
# Router collecting every HTTP endpoint defined in this module; it is attached
# to the application in build_app() via app.include_router(router).
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
198

199

200
def mount_metrics(app: FastAPI):
    """Append a ``/metrics`` route serving the prometheus ASGI app to ``app``.

    When PROMETHEUS_MULTIPROC_DIR is set, the ASGI app is backed by a fresh
    CollectorRegistry with a MultiProcessCollector attached; otherwise the
    default ``make_asgi_app()`` is used.
    """
    # Lazy import for prometheus multiprocessing.
    # PROMETHEUS_MULTIPROC_DIR must be set before prometheus_client
    # is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    multiproc_dir = os.getenv("PROMETHEUS_MULTIPROC_DIR")
    if multiproc_dir is None:
        metrics_app = make_asgi_app()
    else:
        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                    multiproc_dir)
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)
        metrics_app = make_asgi_app(registry=registry)

    # Add prometheus asgi middleware to route /metrics requests
    metrics_route = Mount("/metrics", metrics_app)

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)
224
225


Ethan Xu's avatar
Ethan Xu committed
226
@router.get("/health")
async def health() -> Response:
    """Health check."""
    # check_health() is awaited first; the empty 200 response is only built
    # if that call returns without raising.
    await async_engine_client.check_health()
    ok_response = Response(status_code=200)
    return ok_response


Ethan Xu's avatar
Ethan Xu committed
233
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    # Delegate to the tokenization service and turn its result into JSON.
    result = await openai_serving_tokenization.create_tokenize(request)

    if isinstance(result, ErrorResponse):
        # Errors carry their own HTTP status code.
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if isinstance(result, TokenizeResponse):
        return JSONResponse(content=result.model_dump())

    # Any other result type is unexpected.
    assert_never(result)

244

Ethan Xu's avatar
Ethan Xu committed
245
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    # Delegate to the tokenization service and turn its result into JSON.
    result = await openai_serving_tokenization.create_detokenize(request)

    if isinstance(result, ErrorResponse):
        # Errors carry their own HTTP status code.
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if isinstance(result, DetokenizeResponse):
        return JSONResponse(content=result.model_dump())

    # Any other result type is unexpected.
    assert_never(result)

256

Ethan Xu's avatar
Ethan Xu committed
257
@router.get("/v1/models")
async def show_available_models():
    # The model listing comes from the completion service; return it as JSON.
    model_list = await openai_serving_completion.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
261
262


Ethan Xu's avatar
Ethan Xu committed
263
@router.get("/version")
async def show_version():
    # Report the vLLM version as {"version": ...}.
    return JSONResponse(content={"version": VLLM_VERSION})


Ethan Xu's avatar
Ethan Xu committed
269
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    # The chat service returns an error, a complete response, or (any other
    # value) a stream of events.
    outcome = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(outcome, ErrorResponse):
        # Errors carry their own HTTP status code.
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)

    if isinstance(outcome, ChatCompletionResponse):
        return JSONResponse(content=outcome.model_dump())

    # Anything else is streamed back as server-sent events.
    return StreamingResponse(content=outcome, media_type="text/event-stream")

282

Ethan Xu's avatar
Ethan Xu committed
283
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    # The completion service returns an error, a complete response, or (any
    # other value) a stream of events.
    outcome = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(outcome, ErrorResponse):
        # Errors carry their own HTTP status code.
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)

    if isinstance(outcome, CompletionResponse):
        return JSONResponse(content=outcome.model_dump())

    # Anything else is streamed back as server-sent events.
    return StreamingResponse(content=outcome, media_type="text/event-stream")

Zhuohan Li's avatar
Zhuohan Li committed
295

Ethan Xu's avatar
Ethan Xu committed
296
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    # Delegate to the embedding service and turn its result into JSON.
    result = await openai_serving_embedding.create_embedding(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        # Errors carry their own HTTP status code.
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if isinstance(result, EmbeddingResponse):
        return JSONResponse(content=result.model_dump())

    # Any other result type is unexpected.
    assert_never(result)

308

309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
# The profiling endpoints are registered on the router only when
# VLLM_TORCH_PROFILER_DIR is truthy at the time this module is executed.
if envs.VLLM_TORCH_PROFILER_DIR:
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "
        "used for local development!")

    # POST /start_profile: ask the engine client to start profiling and
    # reply with an empty 200 once the call returns.
    @router.post("/start_profile")
    async def start_profile():
        logger.info("Starting profiler...")
        await async_engine_client.start_profile()
        logger.info("Profiler started.")
        return Response(status_code=200)

    # POST /stop_profile: ask the engine client to stop profiling and
    # reply with an empty 200 once the call returns.
    @router.post("/stop_profile")
    async def stop_profile():
        logger.info("Stopping profiler...")
        await async_engine_client.stop_profile()
        logger.info("Profiler stopped.")
        return Response(status_code=200)


329
330
def build_app(args: Namespace) -> FastAPI:
    """Create the FastAPI application with its routes and middleware.

    The app gets, in order: the module router, the ``/metrics`` mount, CORS
    middleware configured from ``args``, a handler that turns request
    validation errors into 400 error responses, optional bearer-token
    authentication, and any user-supplied middleware from ``args.middleware``.

    Authentication is installed only when ``envs.VLLM_API_KEY`` or
    ``args.api_key`` is set. It applies to paths starting with
    ``{root_path}/v1``; ``OPTIONS`` requests and all other paths pass through
    unauthenticated.

    Raises:
        ValueError: if an entry of ``args.middleware`` resolves to something
            that is neither a class nor a coroutine function.
    """
    # Local import: only the API-key check below needs it.
    import secrets

    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:
        # These values do not change between requests, so compute them once.
        root_path = "" if args.root_path is None else args.root_path
        protected_prefix = f"{root_path}/v1"
        # Compared as bytes: compare_digest() rejects non-ASCII str, and
        # encoding both sides the same way keeps plain equality semantics.
        expected_auth = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            if request.method == "OPTIONS":
                return await call_next(request)
            if not request.url.path.startswith(protected_prefix):
                return await call_next(request)
            # A missing header becomes "" and therefore never matches.
            supplied_auth = request.headers.get("Authorization", "")
            # Constant-time comparison so response timing does not leak how
            # much of the key matched.
            if not secrets.compare_digest(supplied_auth.encode("utf-8"),
                                          expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    # Each entry is a dotted path "package.module.object".
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


378
async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    """Build the app and create the module-level OpenAI serving objects.

    The chat, completion, embedding and tokenization services are all
    constructed from the same engine client, model config and served model
    names, and stored in this module's globals for the route handlers.
    Returns the FastAPI app created by build_app().
    """
    global openai_serving_chat, openai_serving_completion
    global openai_serving_embedding, openai_serving_tokenization

    app = build_app(args)

    # Fall back to the model path when no explicit served names were given.
    served_model_names = (args.served_model_name
                          if args.served_model_name is not None else
                          [args.model])

    model_config = await async_engine_client.get_model_config()

    request_logger = (None if args.disable_log_requests else
                      RequestLogger(max_log_len=args.max_log_len))

    # Leading positional arguments shared by every serving object.
    shared_args = (async_engine_client, model_config, served_model_names)

    openai_serving_chat = OpenAIServingChat(
        *shared_args,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        *shared_args,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        *shared_args,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        *shared_args,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )

    app.root_path = args.root_path
    return app
438
439


440
async def run_server(args, **uvicorn_kwargs) -> None:
    """Start the engine client and serve the HTTP app until shutdown.

    Returns early, without serving, when the engine client could not be
    created. Extra keyword arguments are forwarded to serve_http().
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    async with build_async_engine_client(args) as engine_client:
        # None means the client could not be created; nothing to serve.
        if engine_client is None:
            return

        app = await init_app(engine_client, args)

        http_options = dict(
            engine=engine_client,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
        )
        shutdown_task = await serve_http(app, **http_options,
                                         **uvicorn_kwargs)

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task
467

Ethan Xu's avatar
Ethan Xu committed
468
469
470
471
472
473
474
475

if __name__ == "__main__":
    # NOTE(simon):
    # Keep this section in sync with vllm/scripts.py for CLI entrypoints.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args()

    # Run the server coroutine to completion on a fresh event loop.
    asyncio.run(run_server(args))