api_server.py 16.2 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
import multiprocessing
5
import os
6
import re
7
import tempfile
8
from argparse import Namespace
9
10
from contextlib import asynccontextmanager
from http import HTTPStatus
11
from typing import AsyncIterator, Set
12

13
from fastapi import APIRouter, FastAPI, Request
Zhuohan Li's avatar
Zhuohan Li committed
14
15
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
16
from fastapi.responses import JSONResponse, Response, StreamingResponse
17
from starlette.routing import Mount
18
from typing_extensions import assert_never
Zhuohan Li's avatar
Zhuohan Li committed
19

20
import vllm.envs as envs
21
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
24
from vllm.engine.protocol import AsyncEngineClient
25
from vllm.entrypoints.launcher import serve_http
26
from vllm.entrypoints.logger import RequestLogger
27
from vllm.entrypoints.openai.cli_args import make_arg_parser
28
29
# yapf conflicts with isort for this block
# yapf: disable
30
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
31
                                              ChatCompletionResponse,
32
                                              CompletionRequest,
33
                                              CompletionResponse,
34
35
                                              DetokenizeRequest,
                                              DetokenizeResponse,
36
37
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
38
39
                                              TokenizeRequest,
                                              TokenizeResponse)
40
# yapf: enable
41
42
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
43
44
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
45
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
46
47
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
48
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
49
from vllm.usage.usage_lib import UsageContext
50
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
51
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
52

53
# Keep-alive timeout handed to serve_http() by run_server().
TIMEOUT_KEEP_ALIVE = 5  # seconds

# Module-level state, declared here and bound at startup:
#   * engine_args / async_engine_client -> build_async_engine_client()
#   * openai_serving_*                   -> init_app()
#   * prometheus_multiproc_dir           -> build_async_engine_client(), only
#     when it has to create the temporary directory itself.
engine_args: AsyncEngineArgs
async_engine_client: AsyncEngineClient

openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization

prometheus_multiproc_dir: tempfile.TemporaryDirectory

logger = init_logger('vllm.entrypoints.openai.api_server')

# Strong references to background tasks started by lifespan(); each task
# removes itself via a done-callback when it finishes.
_running_tasks: Set[asyncio.Task] = set()
66

67

68
69
def model_is_embedding(model_name: str, trust_remote_code: bool,
                       quantization: str) -> bool:
    """Return whether *model_name* is configured as an embedding model.

    Builds a throwaway ``ModelConfig`` for the model (using the model itself
    as tokenizer, ``tokenizer_mode="auto"``, ``dtype="auto"``, ``seed=0``)
    and reports its ``embedding_mode`` flag.

    Args:
        model_name: Model identifier, also used as the tokenizer identifier.
        trust_remote_code: Forwarded to ``ModelConfig``.
        quantization: Forwarded to ``ModelConfig``.
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        quantization=quantization,
        seed=0,
        dtype="auto",
    )
    return probe_config.embedding_mode
77
78


79
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Unless ``engine_args.disable_log_stats`` is set, this starts a background
    task that calls ``async_engine_client.do_log_stats()`` every 10 seconds
    while the application is running. The task is cancelled when the
    lifespan ends, so it does not keep polling the engine client after the
    application has shut down.

    Args:
        app: The FastAPI application. It is unused here but required by the
            lifespan protocol.
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    task = None
    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        # Hold a strong reference so the running task is not garbage
        # collected; the done-callback drops it once the task finishes
        # (including after cancellation below).
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.remove)

    try:
        yield
    finally:
        # Previously the logging loop was never stopped and outlived the app.
        if task is not None:
            task.cancel()


95
@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[AsyncEngineClient]:
    """Create the engine client and manage its lifecycle.

    Binds the module globals ``engine_args`` and ``async_engine_client``.
    The latter is still read directly by module-level handlers such as
    ``/health``.

    * For embedding models, or when ``args.disable_frontend_multiprocessing``
      is set, an ``AsyncLLMEngine`` is built in this process and yielded.
    * Otherwise ``run_rpc_server`` is started in a separate ``spawn``-ed
      process, and an ``AsyncEngineRPCClient`` for the IPC path returned by
      ``get_open_zmq_ipc_path()`` is yielded. On exit (normal or error) the
      engine process is terminated and joined, the client is closed, and the
      process is marked dead for Prometheus multiprocess metrics.

    Args:
        args: Parsed CLI arguments.

    Yields:
        The engine client.

    Raises:
        RuntimeError: If the engine process exits while the client's
            ``setup()`` readiness probe is still timing out.
    """
    # Context manager to handle async_engine_client lifecycle.
    # Ensures everything is shut down and cleaned up on error/exit.
    global engine_args
    # Backend itself is still global for the /health handler.
    global async_engine_client
    global prometheus_multiproc_dir

    engine_args = AsyncEngineArgs.from_cli_args(args)

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(args.model, args.trust_remote_code,
                           args.quantization)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
        # Make a TemporaryDirectory for prometheus multiprocessing.
        # Note: the global TemporaryDirectory is automatically cleaned up
        # upon exit.
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ[
            "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
    else:
        logger.warning(
            "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
            "This directory must be wiped between vLLM runs or "
            "you will find inaccurate metrics. Unset the variable "
            "and vLLM will properly handle cleanup.")

    # Select random path for IPC.
    rpc_path = get_open_zmq_ipc_path()
    logger.info("Multiprocessing frontend to use %s for RPC Path.",
                rpc_path)

    # Start RPCServer in a separate process (holds the AsyncLLMEngine).
    # The current process might have a CUDA context, so we need to spawn
    # a new process rather than fork.
    context = multiprocessing.get_context("spawn")
    rpc_server_process = context.Process(
        target=run_rpc_server,
        args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path))
    rpc_server_process.start()
    logger.info("Started engine process with PID %d",
                rpc_server_process.pid)

    rpc_client = None
    try:
        # Build RPCClient, which conforms to the AsyncEngineClient Protocol.
        # NOTE: Actually, this is not true yet. We still need to support
        # embedding models via RPC (see TODO above).
        # Constructed inside the try so that the already-started engine
        # process is still terminated and joined if construction fails.
        rpc_client = AsyncEngineRPCClient(rpc_path)
        async_engine_client = rpc_client  # type: ignore

        # Retry the readiness probe on TimeoutError for as long as the
        # engine process is still alive.
        while True:
            try:
                await rpc_client.setup()
                break
            except TimeoutError as e:
                if not rpc_server_process.is_alive():
                    raise RuntimeError(
                        "The server process died before "
                        "responding to the readiness probe") from e

        yield async_engine_client
    finally:
        # Ensure rpc server process was terminated.
        rpc_server_process.terminate()

        # Close all open connections to the backend.
        if rpc_client is not None:
            rpc_client.close()

        # Wait for server process to join.
        rpc_server_process.join()

        # Lazy import for prometheus multiprocessing.
        # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
        # before prometheus_client is imported.
        # See https://prometheus.github.io/client_python/multiprocess/
        from prometheus_client import multiprocess
        multiprocess.mark_process_dead(rpc_server_process.pid)

183

Ethan Xu's avatar
Ethan Xu committed
184
# Router holding the HTTP endpoints registered below via @router.get /
# @router.post; build_app() includes it into the FastAPI application.
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
185

186

187
def mount_metrics(app: FastAPI):
    """Expose Prometheus metrics on *app* at ``/metrics``.

    If ``PROMETHEUS_MULTIPROC_DIR`` is set, metrics are served from a fresh
    ``CollectorRegistry`` backed by a ``MultiProcessCollector``. Otherwise
    ``make_asgi_app()`` is used with its default registry.
    """
    # Lazy import for prometheus multiprocessing.
    # PROMETHEUS_MULTIPROC_DIR must be set in the environment before
    # prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    multiproc_dir = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if multiproc_dir is None:
        metrics_app = make_asgi_app()
    else:
        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                    multiproc_dir)
        multiproc_registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(multiproc_registry)
        metrics_app = make_asgi_app(registry=multiproc_registry)

    # Route /metrics requests to the prometheus ASGI app.
    metrics_route = Mount("/metrics", metrics_app)

    # Workaround for the 307 redirect on /metrics: widen the mount's
    # pattern so a bare "/metrics" matches directly.
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)
211
212


Ethan Xu's avatar
Ethan Xu committed
213
@router.get("/health")
async def health() -> Response:
    """Health check.

    Awaits the engine client's ``check_health()``; if that returns, an empty
    200 response is sent. Exceptions from ``check_health()`` propagate.
    """
    engine = async_engine_client
    await engine.check_health()
    return Response(status_code=200)


Ethan Xu's avatar
Ethan Xu committed
220
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    """Tokenize a prompt with the tokenization handler.

    An ``ErrorResponse`` is returned as JSON using its ``code`` as the HTTP
    status. A ``TokenizeResponse`` is returned as JSON with the default
    status.
    """
    result = await openai_serving_tokenization.create_tokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, TokenizeResponse):
        return JSONResponse(content=result.model_dump())

    # Any other result type is a programming error.
    assert_never(result)

231

Ethan Xu's avatar
Ethan Xu committed
232
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    """Convert token ids back to text with the tokenization handler.

    An ``ErrorResponse`` is returned as JSON using its ``code`` as the HTTP
    status. A ``DetokenizeResponse`` is returned as JSON with the default
    status.
    """
    result = await openai_serving_tokenization.create_detokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, DetokenizeResponse):
        return JSONResponse(content=result.model_dump())

    # Any other result type is a programming error.
    assert_never(result)

243

Ethan Xu's avatar
Ethan Xu committed
244
@router.get("/v1/models")
async def show_available_models():
    """OpenAI ``GET /v1/models``: list the models this server serves."""
    model_list = await openai_serving_completion.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
248
249


Ethan Xu's avatar
Ethan Xu committed
250
@router.get("/version")
async def show_version():
    """Report the running vLLM version as ``{"version": <version>}``."""
    return JSONResponse(content={"version": VLLM_VERSION})


Ethan Xu's avatar
Ethan Xu committed
256
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """OpenAI ``POST /v1/chat/completions``.

    An ``ErrorResponse`` is returned as JSON with its ``code`` as the HTTP
    status. A complete ``ChatCompletionResponse`` is returned as JSON. Any
    other result is used as the body of a ``text/event-stream`` streaming
    response.
    """
    result = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, ChatCompletionResponse):
        return JSONResponse(content=result.model_dump())

    # Streaming case: presumably an async iterator of SSE chunks.
    return StreamingResponse(content=result, media_type="text/event-stream")

269

Ethan Xu's avatar
Ethan Xu committed
270
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """OpenAI ``POST /v1/completions``.

    An ``ErrorResponse`` is returned as JSON with its ``code`` as the HTTP
    status. A complete ``CompletionResponse`` is returned as JSON. Any other
    result is used as the body of a ``text/event-stream`` streaming response.
    """
    result = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, CompletionResponse):
        return JSONResponse(content=result.model_dump())

    # Streaming case: presumably an async iterator of SSE chunks.
    return StreamingResponse(content=result, media_type="text/event-stream")

Zhuohan Li's avatar
Zhuohan Li committed
282

Ethan Xu's avatar
Ethan Xu committed
283
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    """OpenAI ``POST /v1/embeddings``.

    An ``ErrorResponse`` is returned as JSON with its ``code`` as the HTTP
    status. An ``EmbeddingResponse`` is returned as JSON with the default
    status.
    """
    result = await openai_serving_embedding.create_embedding(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, EmbeddingResponse):
        return JSONResponse(content=result.model_dump())

    # Any other result type is a programming error.
    assert_never(result)

295

296
297
def build_app(args: Namespace) -> FastAPI:
    """Create and configure the FastAPI application.

    Includes the module-level ``router``, mounts ``/metrics``, adds CORS, and
    installs a request-validation error handler. It also adds bearer-token
    authentication for ``/v1`` routes when an API key is configured, plus
    any user-supplied middleware.

    Args:
        args: Parsed CLI arguments. Reads ``root_path``, the CORS options
            (``allowed_origins``, ``allow_credentials``, ``allowed_methods``,
            ``allowed_headers``), ``api_key`` and ``middleware``.

    Returns:
        The configured application.

    Raises:
        ValueError: If an ``args.middleware`` entry resolves to something
            that is neither a class nor a coroutine function.
    """
    import secrets

    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    # VLLM_API_KEY takes precedence over --api-key.
    if token := envs.VLLM_API_KEY or args.api_key:
        root_path = "" if args.root_path is None else args.root_path
        # Compare as UTF-8 bytes (an injective encoding, so this matches
        # plain string equality) using a constant-time check, so response
        # timing does not reveal how much of a guessed key was correct.
        expected_auth = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            # Let OPTIONS (e.g. CORS preflight) requests through.
            if request.method == "OPTIONS":
                return await call_next(request)
            # Only paths under {root_path}/v1 require the key; other
            # routes (e.g. /health, /version, /tokenize) stay open.
            if not request.url.path.startswith(f"{root_path}/v1"):
                return await call_next(request)
            provided_auth = request.headers.get("Authorization", "")
            if not secrets.compare_digest(provided_auth.encode("utf-8"),
                                          expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    # User middleware entries are dotted paths "package.module.Object".
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


345
async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    """Build the app and bind the module-level OpenAI serving handlers.

    Fetches the model config from *async_engine_client* and constructs the
    chat, completion, embedding and tokenization handlers. These are stored
    in the module globals read by the route functions.

    Args:
        async_engine_client: Engine client shared by all handlers.
        args: Parsed CLI arguments.

    Returns:
        The FastAPI application from ``build_app(args)``.
    """
    global openai_serving_chat, openai_serving_completion
    global openai_serving_embedding, openai_serving_tokenization

    app = build_app(args)

    # Fall back to the model path when no explicit served names are given.
    served_model_names = ([args.model] if args.served_model_name is None
                          else args.served_model_name)

    model_config = await async_engine_client.get_model_config()

    request_logger = (None if args.disable_log_requests else
                      RequestLogger(max_log_len=args.max_log_len))

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )

    app.root_path = args.root_path
    return app
405
406


407
async def run_server(args, **uvicorn_kwargs) -> None:
    """Serve the OpenAI-compatible API until the HTTP server shuts down.

    Builds the engine client and app, then starts serving via
    ``serve_http``. The task returned by ``serve_http`` is awaited only
    after the engine-client context has been exited.

    Args:
        args: Parsed CLI arguments.
        **uvicorn_kwargs: Extra keyword arguments forwarded to ``serve_http``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    async with build_async_engine_client(args) as engine_client:
        app = await init_app(engine_client, args)
        shutdown_task = await serve_http(
            app,
            engine=engine_client,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task
430

Ethan Xu's avatar
Ethan Xu committed
431
432
433
434
435
436
437
438

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args()
    asyncio.run(run_server(args))