api_server.py 13.8 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
import multiprocessing
5
import re
6
from argparse import Namespace
7
8
from contextlib import asynccontextmanager
from http import HTTPStatus
9
from typing import AsyncIterator, Set
10

11
from fastapi import APIRouter, FastAPI, Request
Zhuohan Li's avatar
Zhuohan Li committed
12
13
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
14
15
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
16
from starlette.routing import Mount
Zhuohan Li's avatar
Zhuohan Li committed
17

18
import vllm.envs as envs
19
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
20
21
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
22
from vllm.engine.protocol import AsyncEngineClient
23
from vllm.entrypoints.launcher import serve_http
24
from vllm.entrypoints.logger import RequestLogger
25
from vllm.entrypoints.openai.cli_args import make_arg_parser
26
27
# yapf conflicts with isort for this block
# yapf: disable
28
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
29
                                              ChatCompletionResponse,
30
                                              CompletionRequest,
31
32
33
34
35
                                              DetokenizeRequest,
                                              DetokenizeResponse,
                                              EmbeddingRequest, ErrorResponse,
                                              TokenizeRequest,
                                              TokenizeResponse)
36
37
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
38
# yapf: enable
39
40
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
41
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
42
43
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
44
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
45
from vllm.usage.usage_lib import UsageContext
46
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
47
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
48

49
# Passed to serve_http() as `timeout_keep_alive` by run_server().
TIMEOUT_KEEP_ALIVE = 5  # seconds
Zhuohan Li's avatar
Zhuohan Li committed
50

51
# Module-level engine client. Assigned by build_async_engine_client();
# read by the /health handler and by the stats-logging task in lifespan().
async_engine_client: AsyncEngineClient
Ethan Xu's avatar
Ethan Xu committed
52
# Assigned by build_async_engine_client() from the CLI args; lifespan()
# reads `disable_log_stats` from it.
engine_args: AsyncEngineArgs
53
54
# Request handlers backing the chat and completion routes; both are
# assigned by init_app() before the server starts handling requests.
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
55
# Handler backing the /v1/embeddings route; assigned by init_app().
openai_serving_embedding: OpenAIServingEmbedding
56
# Handler backing the /tokenize and /detokenize routes; assigned by
# init_app().
openai_serving_tokenization: OpenAIServingTokenization
57

58
# Logger name is spelled out explicitly rather than derived from __name__.
logger = init_logger('vllm.entrypoints.openai.api_server')
59

60
# Holds references to background tasks created in lifespan(); each task
# removes itself from the set through a done-callback.
_running_tasks: Set[asyncio.Task] = set()
61

62

63
def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
    """Return the ``embedding_mode`` flag of a ModelConfig for *model_name*.

    A throwaway ``ModelConfig`` is built using *model_name* for both the
    model and the tokenizer, with "auto" tokenizer mode and dtype and a
    fixed seed of 0.

    Args:
        model_name: Model identifier, also used as the tokenizer name.
        trust_remote_code: Forwarded unchanged to ``ModelConfig``.

    Returns:
        The ``embedding_mode`` attribute of the constructed config.
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        seed=0,
        dtype="auto",
    )
    return probe_config.embedding_mode
70
71


72
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that runs periodic engine stats logging.

    Unless ``engine_args.disable_log_stats`` is set, a background task is
    started that calls ``async_engine_client.do_log_stats()`` every 10
    seconds. The task is cancelled when the application shuts down so it
    does not keep running (and calling the engine client) after the app
    is gone.

    Args:
        app: The FastAPI application (unused; required by the lifespan
            protocol).
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    log_task = None
    if not engine_args.disable_log_stats:
        log_task = asyncio.create_task(_force_log())
        # Keep a strong reference so the task is not garbage collected
        # while it is still pending.
        _running_tasks.add(log_task)
        # discard() instead of remove(): never raises inside the callback.
        log_task.add_done_callback(_running_tasks.discard)

    try:
        yield
    finally:
        # Stop the infinite logging loop on shutdown; the done-callback
        # then drops the task from _running_tasks.
        if log_task is not None:
            log_task.cancel()


88
89
90
91
92
93
94
95
96
97
98
99
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
    """Create the engine client and manage its lifecycle.

    Sets the module-level ``engine_args`` and ``async_engine_client``
    (the latter stays global because the health handler reads it).

    Two modes:

    * In-process: when the model is an embedding model or
      ``args.disable_frontend_multiprocessing`` is set, an
      ``AsyncLLMEngine`` is built directly and yielded.
    * Multiprocess: otherwise an RPC server holding the engine is started
      in a spawned process and an ``AsyncEngineRPCClient`` is yielded once
      its ``setup()`` succeeds. On exit (normal or error) the process is
      terminated, the client closed and the process joined.

    Args:
        args: Parsed CLI arguments.

    Yields:
        The engine client to be used by the HTTP frontend.

    Raises:
        RuntimeError: If the RPC server process dies before the client's
            ``setup()`` call succeeds.
    """
    global engine_args
    # Backend itself still global for the health handler.
    global async_engine_client

    engine_args = AsyncEngineArgs.from_cli_args(args)

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    run_in_process = (model_is_embedding(args.model, args.trust_remote_code)
                      or args.disable_frontend_multiprocessing)
    if run_in_process:
        async_engine_client = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    # Select random path for IPC.
    rpc_path = get_open_zmq_ipc_path()
    logger.info("Multiprocessing frontend to use %s for RPC Path.", rpc_path)

    # Start RPCServer in separate process (holds the AsyncLLMEngine).
    # The current process might have a CUDA context, so the new process
    # must be spawned rather than forked.
    spawn_ctx = multiprocessing.get_context("spawn")
    rpc_server_process = spawn_ctx.Process(
        target=run_rpc_server,
        args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path),
    )
    rpc_server_process.start()
    logger.info("Started engine process with PID %d", rpc_server_process.pid)

    # Build RPCClient, which conforms to AsyncEngineClient Protocol.
    async_engine_client = AsyncEngineRPCClient(rpc_path)

    try:
        # Retry setup() after each timeout for as long as the server
        # process is still alive.
        ready = False
        while not ready:
            try:
                await async_engine_client.setup()
            except TimeoutError as timeout_err:
                if not rpc_server_process.is_alive():
                    raise RuntimeError(
                        "The server process died before "
                        "responding to the readiness probe") from timeout_err
            else:
                ready = True

        yield async_engine_client
    finally:
        # Ensure rpc server process was terminated
        rpc_server_process.terminate()

        # Close all open connections to the backend
        async_engine_client.close()

        # Wait for server process to join
        rpc_server_process.join()


Ethan Xu's avatar
Ethan Xu committed
150
# Router collecting the HTTP endpoints defined below; build_app() attaches
# it to the FastAPI application.
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
151

152

153
def mount_metrics(app: FastAPI):
    """Expose the Prometheus ASGI app under ``/metrics`` on *app*.

    Args:
        app: Application whose route list receives the metrics mount.
    """
    # Add prometheus asgi middleware to route /metrics requests
    prometheus_mount = Mount("/metrics", make_asgi_app())
    # Workaround for 307 Redirect for /metrics: override the mount's path
    # regex with one that matches "/metrics" followed by anything.
    prometheus_mount.path_regex = re.compile('^/metrics(?P<path>.*)$')
    app.routes.append(prometheus_mount)
159
160


Ethan Xu's avatar
Ethan Xu committed
161
@router.get("/health")
async def health() -> Response:
    """Health check.

    Awaits ``check_health()`` on the global engine client and answers
    with an empty 200 response. Exceptions raised by the check are not
    caught here.
    """
    await async_engine_client.check_health()
    healthy = Response(status_code=200)
    return healthy


Ethan Xu's avatar
Ethan Xu committed
168
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    """Handle POST /tokenize.

    Delegates to ``openai_serving_tokenization.create_tokenize``. An
    ``ErrorResponse`` is returned as JSON with its own ``code`` as the
    HTTP status; any other result must be a ``TokenizeResponse`` and is
    returned as JSON.
    """
    result = await openai_serving_tokenization.create_tokenize(request)

    if not isinstance(result, ErrorResponse):
        assert isinstance(result, TokenizeResponse)
        return JSONResponse(content=result.model_dump())

    return JSONResponse(content=result.model_dump(),
                        status_code=result.code)


Ethan Xu's avatar
Ethan Xu committed
179
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    """Handle POST /detokenize.

    Delegates to ``openai_serving_tokenization.create_detokenize``. An
    ``ErrorResponse`` is returned as JSON with its own ``code`` as the
    HTTP status; any other result must be a ``DetokenizeResponse`` and is
    returned as JSON.
    """
    result = await openai_serving_tokenization.create_detokenize(request)

    if not isinstance(result, ErrorResponse):
        assert isinstance(result, DetokenizeResponse)
        return JSONResponse(content=result.model_dump())

    return JSONResponse(content=result.model_dump(),
                        status_code=result.code)


Ethan Xu's avatar
Ethan Xu committed
190
@router.get("/v1/models")
async def show_available_models():
    """Handle GET /v1/models.

    Returns, as JSON, the dump of whatever
    ``openai_serving_completion.show_available_models()`` produces.
    """
    model_list = await openai_serving_completion.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
194
195


Ethan Xu's avatar
Ethan Xu committed
196
@router.get("/version")
async def show_version():
    """Handle GET /version: report the vLLM version as ``{"version": ...}``."""
    return JSONResponse(content={"version": VLLM_VERSION})


Ethan Xu's avatar
Ethan Xu committed
202
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """Handle POST /v1/chat/completions.

    Delegates to ``openai_serving_chat.create_chat_completion``. The
    result is handled in this order:

    * ``ErrorResponse``: JSON body with the error's ``code`` as status.
    * ``request.stream`` truthy: the result is streamed back as
      ``text/event-stream``.
    * otherwise: the result must be a ``ChatCompletionResponse`` and is
      returned as JSON.
    """
    result = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if not request.stream:
        assert isinstance(result, ChatCompletionResponse)
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
216
217


Ethan Xu's avatar
Ethan Xu committed
218
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Handle POST /v1/completions.

    Delegates to ``openai_serving_completion.create_completion``. An
    ``ErrorResponse`` becomes a JSON body with the error's ``code`` as
    status; a streaming request gets the result back as
    ``text/event-stream``; anything else is dumped to JSON.
    """
    result = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if not request.stream:
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
Zhuohan Li's avatar
Zhuohan Li committed
230
231


Ethan Xu's avatar
Ethan Xu committed
232
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    """Handle POST /v1/embeddings.

    Delegates to ``openai_serving_embedding.create_embedding``. The
    result is always returned as JSON; for an ``ErrorResponse`` the HTTP
    status is additionally set from the error's ``code``.
    """
    result = await openai_serving_embedding.create_embedding(
        request, raw_request)

    if not isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump())

    return JSONResponse(content=result.model_dump(),
                        status_code=result.code)


243
244
def build_app(args: Namespace) -> FastAPI:
    """Build the FastAPI application and wire up routes and middleware.

    Steps, in order:

    1. Create the app with ``lifespan``, include ``router`` and set
       ``root_path`` from ``args.root_path``.
    2. Mount the Prometheus metrics endpoint.
    3. Add CORS middleware configured from ``args``.
    4. Register a handler that turns request-validation errors into a
       400 response built by ``openai_serving_chat``.
    5. If an API key is configured (``envs.VLLM_API_KEY`` or
       ``args.api_key``), add bearer-token authentication for paths
       starting with ``{root_path}/v1``. ``OPTIONS`` requests and all
       other paths bypass the check.
    6. Install each entry of ``args.middleware``, given as a dotted path
       ``module.object``: classes via ``add_middleware``, coroutine
       functions as ``http`` middleware.

    Args:
        args: Parsed CLI arguments.

    Returns:
        The configured FastAPI application.

    Raises:
        ValueError: If an entry of ``args.middleware`` resolves to
            something that is neither a class nor a coroutine function.
    """
    # Function-scope import: only the API-key check below needs it.
    import secrets

    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:
        # Computed once instead of on every request. Bytes are used
        # because compare_digest() rejects non-ASCII str arguments.
        expected_header = ("Bearer " + token).encode("utf-8")
        root_path = "" if args.root_path is None else args.root_path
        protected_prefix = f"{root_path}/v1"

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            if request.method == "OPTIONS":
                return await call_next(request)
            if not request.url.path.startswith(protected_prefix):
                return await call_next(request)
            # A missing header becomes "", which can never match because
            # the expected value always starts with "Bearer ".
            supplied = request.headers.get("Authorization", "")
            # Constant-time comparison so response timing does not reveal
            # how much of the key a guess got right.
            if not secrets.compare_digest(supplied.encode("utf-8"),
                                          expected_header):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


292
async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    """Build the app and initialise the module-level serving handlers.

    Creates the FastAPI app via ``build_app``, fetches the model config
    from the engine client, and assigns the four global handlers
    (``openai_serving_chat``, ``openai_serving_completion``,
    ``openai_serving_embedding``, ``openai_serving_tokenization``).

    Args:
        async_engine_client: Engine client shared by all handlers.
        args: Parsed CLI arguments.

    Returns:
        The FastAPI application with ``root_path`` set from ``args``.
    """
    global openai_serving_chat
    global openai_serving_completion
    global openai_serving_embedding
    global openai_serving_tokenization

    app = build_app(args)

    # Fall back to the model name when no served names were given.
    served_model_names = (args.served_model_name
                          if args.served_model_name is not None else
                          [args.model])

    model_config = await async_engine_client.get_model_config()

    request_logger = (None if args.disable_log_requests else RequestLogger(
        max_log_len=args.max_log_len))

    # Leading positional arguments shared by every handler constructor.
    shared = (async_engine_client, model_config, served_model_names)

    openai_serving_chat = OpenAIServingChat(
        *shared,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        *shared,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        *shared,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        *shared,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )

    app.root_path = args.root_path
    return app
352
353


354
async def run_server(args, **uvicorn_kwargs) -> None:
    """Start the engine client and serve the HTTP API until shutdown.

    Logs the version and arguments, builds the engine client, initialises
    the app and hands it to ``serve_http``. The awaitable returned by
    ``serve_http`` is awaited only after the engine-client context has
    been exited.

    Args:
        args: Parsed CLI arguments.
        **uvicorn_kwargs: Extra keyword arguments forwarded to
            ``serve_http``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    async with build_async_engine_client(args) as engine_client:
        application = await init_app(engine_client, args)

        server_shutdown = await serve_http(
            application,
            engine=engine_client,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await server_shutdown
377

Ethan Xu's avatar
Ethan Xu committed
378
379
380
381
382
383
384
385

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    # Build the base parser and let make_arg_parser() extend it.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args()

    asyncio.run(run_server(args))