api_server.py 15.9 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
import multiprocessing
5
import os
6
import re
7
import tempfile
8
from argparse import Namespace
9
10
from contextlib import asynccontextmanager
from http import HTTPStatus
11
from typing import AsyncIterator, Set
12

13
from fastapi import APIRouter, FastAPI, Request
Zhuohan Li's avatar
Zhuohan Li committed
14
15
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
16
from fastapi.responses import JSONResponse, Response, StreamingResponse
17
from starlette.routing import Mount
Zhuohan Li's avatar
Zhuohan Li committed
18

19
import vllm.envs as envs
20
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
21
22
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
23
from vllm.engine.protocol import AsyncEngineClient
24
from vllm.entrypoints.launcher import serve_http
25
from vllm.entrypoints.logger import RequestLogger
26
from vllm.entrypoints.openai.cli_args import make_arg_parser
27
28
# yapf conflicts with isort for this block
# yapf: disable
29
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
30
                                              ChatCompletionResponse,
31
                                              CompletionRequest,
32
33
34
35
36
                                              DetokenizeRequest,
                                              DetokenizeResponse,
                                              EmbeddingRequest, ErrorResponse,
                                              TokenizeRequest,
                                              TokenizeResponse)
37
38
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
39
# yapf: enable
40
41
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
42
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
43
44
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
45
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
46
from vllm.usage.usage_lib import UsageContext
47
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
48
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
49

50
# Passed to serve_http() as ``timeout_keep_alive`` in run_server().
TIMEOUT_KEEP_ALIVE = 5  # seconds

# Module-level state. These names are only declared (annotated) here; they
# are bound at startup:
#   * engine_args, async_engine_client -> build_async_engine_client()
#   * openai_serving_*                 -> init_app()
#   * prometheus_multiproc_dir         -> build_async_engine_client(), only
#     when PROMETHEUS_MULTIPROC_DIR was not already set in the environment.
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization
prometheus_multiproc_dir: tempfile.TemporaryDirectory

logger = init_logger('vllm.entrypoints.openai.api_server')

# Strong references to background tasks started by lifespan(); each task
# drops itself from the set via a done-callback. (asyncio itself keeps only
# weak references to running tasks.)
_running_tasks: Set[asyncio.Task] = set()
63

64

65
66
def model_is_embedding(model_name: str, trust_remote_code: bool,
                       quantization: str) -> bool:
    """Return whether ``model_name`` would be loaded as an embedding model.

    A throwaway ``ModelConfig`` is built only to read its ``embedding_mode``
    flag. The model doubles as its own tokenizer, with
    ``tokenizer_mode="auto"``, ``dtype="auto"`` and ``seed=0``.

    Args:
        model_name: Model name or path; also used as the tokenizer.
        trust_remote_code: Forwarded to ``ModelConfig``.
        quantization: Forwarded to ``ModelConfig``. The caller passes
            ``args.quantization``, which may presumably be None when no
            quantization is requested (TODO confirm and widen the hint).

    Returns:
        ``ModelConfig.embedding_mode`` for that probe configuration.
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        quantization=quantization,
        seed=0,
        dtype="auto",
    )
    return probe_config.embedding_mode
74
75


76
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that runs periodic engine stats logging.

    Unless ``engine_args.disable_log_stats`` is set, a background task calls
    ``async_engine_client.do_log_stats()`` every 10 seconds while the app is
    running. The task is cancelled when the app shuts down. Without that, it
    would keep running after the server has stopped.

    Args:
        app: The FastAPI application. It is not used here; the parameter is
            required by the lifespan protocol.
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    task = None
    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        # Keep a strong reference so the task is not garbage-collected while
        # it runs; the done-callback drops it again (discard: never raises).
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.discard)

    try:
        yield
    finally:
        # Stop the periodic logger on shutdown.
        if task is not None:
            task.cancel()


92
93
94
95
96
97
98
99
100
101
102
103
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
    """Async context manager that yields the engine client for the server.

    It rebinds the module globals ``engine_args`` and ``async_engine_client``
    and ensures everything is shut down and cleaned up on error or exit.

    There are two modes:

    * In-process. Used when ``model_is_embedding(...)`` is true or
      ``args.disable_frontend_multiprocessing`` is set. An ``AsyncLLMEngine``
      is created in this process and yielded; nothing is torn down on exit.
    * Out-of-process (otherwise). ``run_rpc_server`` is started in a spawned
      process, and an ``AsyncEngineRPCClient`` for the IPC path from
      ``get_open_zmq_ipc_path()`` is yielded once its readiness probe
      (``setup()``) succeeds. On exit, whether normal or via an error:
      the process is terminated, the client is closed, the process is
      joined, and it is marked dead for prometheus multiprocess metrics.

    Args:
        args: Parsed CLI namespace (see ``make_arg_parser``).

    Raises:
        RuntimeError: If the engine process dies before answering the
            readiness probe.
    """
    # `global` applies to the whole function body, so declare every module
    # global this function rebinds up front.
    global engine_args
    # Backend itself still global for the health handler.
    global async_engine_client
    global prometheus_multiproc_dir

    engine_args = AsyncEngineArgs.from_cli_args(args)

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(args.model, args.trust_remote_code,
                           args.quantization)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        yield async_engine_client
        return

    # Otherwise, run the AsyncLLMEngine in a separate process behind RPC.
    if "PROMETHEUS_MULTIPROC_DIR" not in os.environ:
        # Make TemporaryDirectory for prometheus multiprocessing.
        # Note: global TemporaryDirectory will be automatically
        #   cleaned up upon exit.
        prometheus_multiproc_dir = tempfile.TemporaryDirectory()
        os.environ[
            "PROMETHEUS_MULTIPROC_DIR"] = prometheus_multiproc_dir.name
    else:
        logger.warning(
            "Found PROMETHEUS_MULTIPROC_DIR was set by user. "
            "This directory must be wiped between vLLM runs or "
            "you will find inaccurate metrics. Unset the variable "
            "and vLLM will properly handle cleanup.")

    # Select random path for IPC.
    rpc_path = get_open_zmq_ipc_path()
    logger.info("Multiprocessing frontend to use %s for RPC Path.",
                rpc_path)

    # Start RPCServer in separate process (holds the AsyncLLMEngine).
    # The current process might have a CUDA context, so we need to spawn
    # a new process.
    context = multiprocessing.get_context("spawn")
    rpc_server_process = context.Process(
        target=run_rpc_server,
        args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path))
    rpc_server_process.start()
    logger.info("Started engine process with PID %d",
                rpc_server_process.pid)

    rpc_client = None
    try:
        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
        # This happens inside the try: if construction fails, the engine
        # process started above is still terminated and joined below.
        rpc_client = AsyncEngineRPCClient(rpc_path)
        async_engine_client = rpc_client

        # Wait for the readiness probe, retrying on timeout for as long as
        # the engine process is still alive.
        while True:
            try:
                await rpc_client.setup()
                break
            except TimeoutError as e:
                if not rpc_server_process.is_alive():
                    raise RuntimeError(
                        "The server process died before "
                        "responding to the readiness probe") from e

        yield rpc_client
    finally:
        # Ensure rpc server process was terminated
        rpc_server_process.terminate()

        # Close all open connections to the backend (only if the client
        # was actually created).
        if rpc_client is not None:
            rpc_client.close()

        # Wait for server process to join
        rpc_server_process.join()

        # Lazy import for prometheus multiprocessing.
        # We need to set PROMETHEUS_MULTIPROC_DIR environment variable
        # before prometheus_client is imported.
        # See https://prometheus.github.io/client_python/multiprocess/
        from prometheus_client import multiprocess
        multiprocess.mark_process_dead(rpc_server_process.pid)

176

Ethan Xu's avatar
Ethan Xu committed
177
# Router for the @router.get / @router.post handlers defined in this module;
# build_app() attaches it to the FastAPI app with app.include_router().
# (/metrics is mounted separately by mount_metrics().)
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
178

179

180
def mount_metrics(app: FastAPI):
    """Expose Prometheus metrics at ``/metrics`` on ``app``.

    When ``PROMETHEUS_MULTIPROC_DIR`` is set, metrics are served from a fresh
    ``CollectorRegistry`` with a ``MultiProcessCollector`` attached.
    Otherwise ``make_asgi_app()`` is used with its default registry.

    Args:
        app: Application whose route list gets the ``/metrics`` mount.
    """
    # Lazy import for prometheus multiprocessing.
    # PROMETHEUS_MULTIPROC_DIR must already be set in the environment
    # before prometheus_client is imported.
    # See https://prometheus.github.io/client_python/multiprocess/
    from prometheus_client import (CollectorRegistry, make_asgi_app,
                                   multiprocess)

    multiproc_dir = os.getenv("PROMETHEUS_MULTIPROC_DIR", None)
    if multiproc_dir is None:
        metrics_app = make_asgi_app()
    else:
        logger.info("vLLM to use %s as PROMETHEUS_MULTIPROC_DIR",
                    multiproc_dir)
        multiproc_registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(multiproc_registry)
        metrics_app = make_asgi_app(registry=multiproc_registry)

    # Add prometheus asgi middleware to route /metrics requests.
    route = Mount("/metrics", metrics_app)
    # Workaround for 307 Redirect for /metrics: this pattern also matches
    # the bare "/metrics" path with no trailing slash.
    route.path_regex = re.compile(r'^/metrics(?P<path>.*)$')
    app.routes.append(route)
204
205


Ethan Xu's avatar
Ethan Xu committed
206
@router.get("/health")
async def health() -> Response:
    """Health check: an empty 200 once the engine client's check passes.

    Any exception raised by ``check_health()`` propagates to FastAPI.
    """
    await async_engine_client.check_health()
    ok = Response(status_code=200)
    return ok


Ethan Xu's avatar
Ethan Xu committed
213
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    """Tokenize the input described by ``request``.

    Returns:
        The serialized ``TokenizeResponse``. On failure, the serialized
        ``ErrorResponse``, with its ``code`` used as the HTTP status.
    """
    result = await openai_serving_tokenization.create_tokenize(request)
    if not isinstance(result, ErrorResponse):
        assert isinstance(result, TokenizeResponse)
        return JSONResponse(content=result.model_dump())
    return JSONResponse(content=result.model_dump(),
                        status_code=result.code)


Ethan Xu's avatar
Ethan Xu committed
224
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    """Convert the token ids in ``request`` back to text.

    Returns:
        The serialized ``DetokenizeResponse``. On failure, the serialized
        ``ErrorResponse``, with its ``code`` used as the HTTP status.
    """
    result = await openai_serving_tokenization.create_detokenize(request)
    if not isinstance(result, ErrorResponse):
        assert isinstance(result, DetokenizeResponse)
        return JSONResponse(content=result.model_dump())
    return JSONResponse(content=result.model_dump(),
                        status_code=result.code)


Ethan Xu's avatar
Ethan Xu committed
235
@router.get("/v1/models")
async def show_available_models():
    """List the served models, as reported by the completion serving object."""
    model_list = await openai_serving_completion.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
239
240


Ethan Xu's avatar
Ethan Xu committed
241
@router.get("/version")
async def show_version():
    """Return the running vLLM version as ``{"version": ...}``."""
    return JSONResponse(content=dict(version=VLLM_VERSION))


Ethan Xu's avatar
Ethan Xu committed
247
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """OpenAI-compatible chat completions endpoint.

    Errors are returned as the JSON ``ErrorResponse``, with its ``code`` as
    the HTTP status. When ``request.stream`` is set, the result is sent as a
    ``text/event-stream``. Otherwise it is a JSON ``ChatCompletionResponse``.
    """
    result = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if not request.stream:
        assert isinstance(result, ChatCompletionResponse)
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
261
262


Ethan Xu's avatar
Ethan Xu committed
263
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """OpenAI-compatible (legacy) completions endpoint.

    Errors are returned as the JSON ``ErrorResponse``, with its ``code`` as
    the HTTP status. When ``request.stream`` is set, the result is sent as a
    ``text/event-stream``. Otherwise the result's ``model_dump()`` is sent
    as JSON.
    """
    result = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if not request.stream:
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
Zhuohan Li's avatar
Zhuohan Li committed
275
276


Ethan Xu's avatar
Ethan Xu committed
277
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    """OpenAI-compatible embeddings endpoint; always answers with JSON.

    If the result is an ``ErrorResponse``, its ``code`` is used as the HTTP
    status.
    """
    result = await openai_serving_embedding.create_embedding(
        request, raw_request)
    body = result.model_dump()
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=body, status_code=result.code)
    return JSONResponse(content=body)


288
289
def build_app(args: Namespace) -> FastAPI:
    """Create the FastAPI application and install its middleware.

    The app gets the module router, ``/metrics``, CORS, a handler that turns
    request-validation errors into 400 responses, optional API-key
    authentication for ``/v1`` routes, and any user-supplied middleware.

    Args:
        args: Parsed CLI namespace. This reads ``root_path``, the CORS
            options (``allowed_origins``, ``allow_credentials``,
            ``allowed_methods``, ``allowed_headers``), ``api_key`` and
            ``middleware``.

    Returns:
        The configured application. Its handlers still rely on module
        globals that ``init_app`` sets later.

    Raises:
        ValueError: If an entry of ``args.middleware`` is not a dotted
            ``module.object`` path, or resolves to something that is neither
            a class nor a coroutine function.
    """
    import secrets

    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:
        # Per-request invariants, computed once.
        root_path = "" if args.root_path is None else args.root_path
        expected_auth = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            # OPTIONS requests (e.g. CORS preflight) pass unauthenticated.
            if request.method == "OPTIONS":
                return await call_next(request)
            # Only the OpenAI-compatible /v1 routes are protected.
            if not request.url.path.startswith(f"{root_path}/v1"):
                return await call_next(request)
            provided_auth = request.headers.get("Authorization", "")
            # Constant-time comparison, so response timing does not reveal
            # how much of the key matched. Bytes are compared because
            # compare_digest rejects non-ASCII str arguments.
            if not secrets.compare_digest(provided_auth.encode("utf-8"),
                                          expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        module_path, sep, object_name = middleware.rpartition(".")
        if not sep or not module_path:
            raise ValueError(f"Invalid middleware {middleware}. "
                             "Expected a dotted path 'module.object'.")
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


337
async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    """Build the app and create the OpenAI serving objects it dispatches to.

    Rebinds the module globals ``openai_serving_chat``,
    ``openai_serving_completion``, ``openai_serving_embedding`` and
    ``openai_serving_tokenization``. All four share ``async_engine_client``,
    the engine's model config, the served model names and a single request
    logger, which is None when ``args.disable_log_requests`` is set.

    Args:
        async_engine_client: Client for the running engine.
        args: Parsed CLI namespace.

    Returns:
        The application from ``build_app(args)`` with ``root_path`` set.
    """
    global openai_serving_chat
    global openai_serving_completion
    global openai_serving_embedding
    global openai_serving_tokenization

    app = build_app(args)

    # Without --served-model-name, advertise the model under args.model.
    if args.served_model_name is None:
        served_model_names = [args.model]
    else:
        served_model_names = args.served_model_name

    model_config = await async_engine_client.get_model_config()

    request_logger = (None if args.disable_log_requests else
                      RequestLogger(max_log_len=args.max_log_len))

    # Leading positional arguments shared by every serving class below.
    common = (async_engine_client, model_config, served_model_names)

    openai_serving_chat = OpenAIServingChat(
        *common,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        *common,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        *common,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        *common,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )

    app.root_path = args.root_path
    return app
397
398


399
async def run_server(args, **uvicorn_kwargs) -> None:
    """Start the engine, build the app and serve it over HTTP.

    Args:
        args: Parsed CLI namespace.
        **uvicorn_kwargs: Extra keyword arguments forwarded to
            ``serve_http``. Repeating a key derived from ``args`` raises
            TypeError.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    async with build_async_engine_client(args) as engine_client:
        app = await init_app(engine_client, args)

        server_options = dict(
            engine=engine_client,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
        )
        shutdown_task = await serve_http(app, **server_options,
                                         **uvicorn_kwargs)

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task
422

Ethan Xu's avatar
Ethan Xu committed
423
424
425
426
427
428
429
430

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args()
    asyncio.run(run_server(args))