# api_server.py: vLLM OpenAI-Compatible RESTful API server.
import asyncio
import importlib
import inspect
import multiprocessing
import os
import re
import secrets
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import AsyncIterator, Set

from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.routing import Mount
from typing_extensions import assert_never

import vllm.envs as envs
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.protocol import AsyncEngineClient
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionResponse,
                                              CompletionRequest,
                                              CompletionResponse,
                                              DetokenizeRequest,
                                              DetokenizeResponse,
                                              EmbeddingRequest,
                                              EmbeddingResponse, ErrorResponse,
                                              TokenizeRequest,
                                              TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
from vllm.version import __version__ as VLLM_VERSION

logger = init_logger('vllm.entrypoints.openai.api_server')

# Forwarded to serve_http() as ``timeout_keep_alive`` in run_server().
TIMEOUT_KEEP_ALIVE = 5  # seconds

# --- Module-level state, assigned at startup -------------------------------
# Set by build_async_engine_client(); read by route handlers and lifespan().
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs

# Set by init_app(); the route handlers below dispatch to these.
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization

# Created by build_async_engine_client() only when the user did not set
# PROMETHEUS_MULTIPROC_DIR themselves.
prometheus_multiproc_dir: tempfile.TemporaryDirectory

# References to background tasks started in lifespan(); each task removes
# itself via a done-callback. (asyncio keeps only weak references to tasks,
# so holding them here presumably prevents premature collection.)
_running_tasks: Set[asyncio.Task] = set()

def model_is_embedding(model_name: str, trust_remote_code: bool,
                       quantization: str) -> bool:
    """Report whether ``model_name`` would run in embedding mode.

    A throwaway ``ModelConfig`` is built with the model doubling as its own
    tokenizer, ``tokenizer_mode="auto"``, ``dtype="auto"`` and ``seed=0``;
    its ``embedding_mode`` attribute is returned.

    Args:
        model_name: Model name/path passed as both model and tokenizer.
        trust_remote_code: Forwarded to ``ModelConfig``.
        quantization: Forwarded to ``ModelConfig`` (may be None when the
            CLI option is unset -- TODO confirm against the arg parser).
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        quantization=quantization,
        seed=0,
        dtype="auto",
    )
    return probe_config.embedding_mode

@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that periodically logs engine stats.

    Unless ``engine_args.disable_log_stats`` is set, a background task calls
    ``async_engine_client.do_log_stats()`` every 10 seconds while the app is
    running. The task is cancelled when the lifespan exits, so it does not
    outlive the application and keep calling into the engine client after
    shutdown.

    Args:
        app: The FastAPI application (unused; required by the lifespan
            protocol).
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    task = None
    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        # Hold a reference while the task runs; the done-callback drops it.
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.remove)

    try:
        yield
    finally:
        # Stop the infinite polling loop on shutdown.
        if task is not None:
            task.cancel()

@asynccontextmanager
async def build_async_engine_client(
        args: Namespace) -> AsyncIterator[AsyncEngineClient]:
    """Create the engine client for the API server and manage its lifetime.

    Sets the module globals ``engine_args`` and ``async_engine_client`` (and
    ``prometheus_multiproc_dir`` when a temporary directory is created).

    Two modes:

    * In-process: used for embedding models or when
      ``args.disable_frontend_multiprocessing`` is true. An
      ``AsyncLLMEngine`` is created and yielded directly.
      NOTE(review): nothing is torn down after the yield in this mode.
    * Multiprocess: an RPC server holding the engine is spawned in a
      separate process and an ``AsyncEngineRPCClient`` is yielded. Setup is
      retried for as long as it raises ``TimeoutError`` and the server
      process is alive. On exit the process is terminated, the client is
      closed, the process is joined, and it is marked dead for
      prometheus_client's multiprocess mode.

    Raises:
        RuntimeError: If the server process dies before setup succeeds.
    """
    # The backend stays global for the simple /health handler.
    global engine_args, async_engine_client, prometheus_multiproc_dir

    engine_args = AsyncEngineArgs.from_cli_args(args)

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    run_in_process = (model_is_embedding(args.model, args.trust_remote_code,
                                         args.quantization)
                      or args.disable_frontend_multiprocessing)
    if run_in_process:
        async_engine_client = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    if "PROMETHEUS_MULTIPROC_DIR" in os.environ:
        logger.warning(
            "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
            "This directory must be wiped between vLLM runs or "
            "you will find inaccurate metrics. Unset the variable "
            "and vLLM will properly handle cleanup.")
    else:
        # A module-global TemporaryDirectory is cleaned up automatically
        # when the interpreter exits.
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ["PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name

    # Random IPC path shared by the RPC client and server.
    rpc_path = get_open_zmq_ipc_path()
    logger.info("Multiprocessing frontend to use %s for RPC Path.",
                rpc_path)

    # The RPC client is meant to conform to the AsyncEngineClient protocol,
    # but does not fully do so yet (embedding models are unsupported, see
    # the TODO above) -- hence the type: ignore.
    rpc_client = AsyncEngineRPCClient(rpc_path)
    async_engine_client = rpc_client  # type: ignore

    # This process might already hold a CUDA context, so the engine process
    # is started with "spawn" rather than forked.
    spawn_ctx = multiprocessing.get_context("spawn")
    engine_process = spawn_ctx.Process(
        target=run_rpc_server,
        args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path))
    engine_process.start()
    logger.info("Started engine process with PID %d", engine_process.pid)

    try:
        server_ready = False
        while not server_ready:
            try:
                await rpc_client.setup()
            except TimeoutError as e:
                if not engine_process.is_alive():
                    raise RuntimeError(
                        "The server process died before "
                        "responding to the readiness probe") from e
            else:
                server_ready = True

        yield async_engine_client
    finally:
        # Shut down in order: stop the server, drop client connections,
        # then wait for the process to exit.
        engine_process.terminate()
        rpc_client.close()
        engine_process.join()

        # Lazy import: PROMETHEUS_MULTIPROC_DIR must be set before
        # prometheus_client is imported.
        # See https://prometheus.github.io/client_python/multiprocess/
        from prometheus_client import multiprocess
        multiprocess.mark_process_dead(engine_process.pid)

# Collects the HTTP endpoints defined below; attached to the app by
# build_app() via app.include_router().
router = APIRouter()

def mount_metrics(app: FastAPI):
    """Mount a Prometheus metrics ASGI app at ``/metrics`` on ``app``.

    If PROMETHEUS_MULTIPROC_DIR is set, a dedicated ``CollectorRegistry``
    with a ``MultiProcessCollector`` is used (prometheus_client's
    multiprocess mode); otherwise the default registry is served.
    """
    # Lazy import: PROMETHEUS_MULTIPROC_DIR must be set before
    # prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    multiproc_dir = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if multiproc_dir is None:
        metrics_app = make_asgi_app()
    else:
        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                    multiproc_dir)
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)
        metrics_app = make_asgi_app(registry=registry)

    # Route /metrics requests to the prometheus ASGI app.
    metrics_route = Mount("/metrics", metrics_app)

    # Workaround for 307 Redirect for /metrics: match the bare path too.
    metrics_route.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(metrics_route)

@router.get("/health")
async def health() -> Response:
    """Health check.

    Responds with an empty 200 once ``async_engine_client.check_health()``
    completes; any exception it raises propagates to the framework.
    """
    await async_engine_client.check_health()
    return Response(status_code=200)

@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    """Tokenize text via ``openai_serving_tokenization``.

    Returns the ``TokenizeResponse`` as JSON, or an ``ErrorResponse`` as
    JSON with its own ``code`` as the HTTP status.
    """
    result = await openai_serving_tokenization.create_tokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, TokenizeResponse):
        return JSONResponse(content=result.model_dump())

    # Statically unreachable: the handler returns one of the two types.
    assert_never(result)

@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    """Detokenize token ids via ``openai_serving_tokenization``.

    Returns the ``DetokenizeResponse`` as JSON, or an ``ErrorResponse`` as
    JSON with its own ``code`` as the HTTP status.
    """
    result = await openai_serving_tokenization.create_detokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, DetokenizeResponse):
        return JSONResponse(content=result.model_dump())

    # Statically unreachable: the handler returns one of the two types.
    assert_never(result)

@router.get("/v1/models")
async def show_available_models():
    """List served models, as reported by the completion handler."""
    model_list = await openai_serving_completion.show_available_models()
    return JSONResponse(content=model_list.model_dump())

@router.get("/version")
async def show_version():
    """Return the running vLLM version as ``{"version": <version>}``."""
    return JSONResponse(content={"version": VLLM_VERSION})

@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """OpenAI-compatible chat completions endpoint.

    Errors are returned as JSON with the error's ``code`` as status, a
    complete ``ChatCompletionResponse`` as JSON, and anything else is
    streamed back with media type ``text/event-stream`` (presumably an async
    generator of SSE chunks -- TODO confirm against OpenAIServingChat).
    """
    result = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, ChatCompletionResponse):
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")

@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """OpenAI-compatible (legacy) completions endpoint.

    Errors are returned as JSON with the error's ``code`` as status, a
    complete ``CompletionResponse`` as JSON, and anything else is streamed
    back with media type ``text/event-stream`` (presumably an async
    generator of SSE chunks -- TODO confirm).
    """
    result = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, CompletionResponse):
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")

@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    """OpenAI-compatible embeddings endpoint.

    Returns the ``EmbeddingResponse`` as JSON, or an ``ErrorResponse`` as
    JSON with its own ``code`` as the HTTP status.
    """
    result = await openai_serving_embedding.create_embedding(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if isinstance(result, EmbeddingResponse):
        return JSONResponse(content=result.model_dump())

    # Statically unreachable: the handler returns one of the two types.
    assert_never(result)

def build_app(args: Namespace) -> FastAPI:
    """Create the FastAPI application and install routes and middleware.

    Installs, in order:

    * the endpoints on ``router`` and ``args.root_path``;
    * the Prometheus ``/metrics`` mount;
    * CORS middleware configured from ``args``;
    * a ``RequestValidationError`` handler that answers 400 with an
      OpenAI-style error body;
    * when an API key is configured (``VLLM_API_KEY`` env or
      ``args.api_key``), bearer-token authentication for non-OPTIONS
      requests whose path starts with ``<root_path>/v1``;
    * user middleware from ``args.middleware`` (dotted paths to a class or
      an ``async`` function).

    Raises:
        ValueError: If an ``args.middleware`` entry is neither a class nor a
            coroutine function.
    """
    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:
        # Computed once instead of on every request.
        root_path = "" if args.root_path is None else args.root_path
        expected_auth = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            # CORS preflight requests carry no credentials.
            if request.method == "OPTIONS":
                return await call_next(request)
            # Only the OpenAI /v1 API is protected.
            if not request.url.path.startswith(f"{root_path}/v1"):
                return await call_next(request)
            # The header is attacker-controlled: compare in constant time.
            # Both sides are compared as UTF-8 bytes, because compare_digest
            # raises TypeError for non-ASCII str arguments.
            provided_auth = request.headers.get("Authorization", "")
            if not secrets.compare_digest(provided_auth.encode("utf-8"),
                                          expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app

async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    """Build the FastAPI app and create the OpenAI serving handlers.

    The handlers are stored in the module globals ``openai_serving_chat``,
    ``openai_serving_completion``, ``openai_serving_embedding`` and
    ``openai_serving_tokenization``, which the route functions read.

    Args:
        async_engine_client: Engine client handed to every handler; also
            queried once for the model config.
        args: Parsed CLI namespace.

    Returns:
        The configured FastAPI application.
    """
    global openai_serving_chat, openai_serving_completion
    global openai_serving_embedding, openai_serving_tokenization

    app = build_app(args)

    # Without explicit served names, the model name/path itself is served.
    if args.served_model_name is None:
        served_model_names = [args.model]
    else:
        served_model_names = args.served_model_name

    model_config = await async_engine_client.get_model_config()

    request_logger = (None if args.disable_log_requests else
                      RequestLogger(max_log_len=args.max_log_len))

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        request_logger=request_logger,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
        lora_modules=args.lora_modules,
        chat_template=args.chat_template,
    )

    app.root_path = args.root_path
    return app

async def run_server(args, **uvicorn_kwargs) -> None:
    """Start the engine client and HTTP server, then wait for shutdown.

    The app is built and handed to ``serve_http`` inside the engine-client
    context. The shutdown task that ``serve_http`` returns is awaited only
    after that context has exited.

    Args:
        args: Parsed CLI namespace.
        **uvicorn_kwargs: Extra keyword arguments forwarded to ``serve_http``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    async with build_async_engine_client(args) as engine_client:
        app = await init_app(engine_client, args)

        server_options = dict(
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
        )
        shutdown_task = await serve_http(app,
                                         engine=engine_client,
                                         **server_options,
                                         **uvicorn_kwargs)

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args()
    asyncio.run(run_server(args))