api_server.py 13.9 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
import multiprocessing
5
import re
6
from argparse import Namespace
7
8
from contextlib import asynccontextmanager
from http import HTTPStatus
9
from typing import AsyncIterator, Set
10

11
from fastapi import APIRouter, FastAPI, Request
Zhuohan Li's avatar
Zhuohan Li committed
12
13
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
14
15
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
16
from starlette.routing import Mount
Zhuohan Li's avatar
Zhuohan Li committed
17

18
import vllm.envs as envs
19
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
20
21
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
22
from vllm.engine.protocol import AsyncEngineClient
23
from vllm.entrypoints.launcher import serve_http
24
from vllm.entrypoints.logger import RequestLogger
25
from vllm.entrypoints.openai.cli_args import make_arg_parser
26
27
# yapf conflicts with isort for this block
# yapf: disable
28
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
29
                                              ChatCompletionResponse,
30
                                              CompletionRequest,
31
32
33
34
35
                                              DetokenizeRequest,
                                              DetokenizeResponse,
                                              EmbeddingRequest, ErrorResponse,
                                              TokenizeRequest,
                                              TokenizeResponse)
36
37
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
38
# yapf: enable
39
40
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
41
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
42
43
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
44
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
45
from vllm.usage.usage_lib import UsageContext
46
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
47
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
48

49
# Keep-alive timeout in seconds; passed to serve_http() as timeout_keep_alive.
TIMEOUT_KEEP_ALIVE = 5

# Module-level handles that are only *declared* here (no value is bound).
# build_async_engine_client() assigns the two engine globals and init_app()
# assigns the four serving handlers; route handlers read them at call time.
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs

openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization

# Explicit logger name instead of __name__: this file can also be executed
# directly (see the __main__ guard), where __name__ would be "__main__".
logger = init_logger('vllm.entrypoints.openai.api_server')

# Strong references to background tasks started from lifespan(); asyncio
# only keeps weak references to tasks, so they are held here until done.
_running_tasks: Set[asyncio.Task] = set()
61

62

63
64
def model_is_embedding(model_name: str, trust_remote_code: bool,
                       quantization: str) -> bool:
    """Return whether ``model_name`` would be loaded in embedding mode.

    A throwaway ``ModelConfig`` is built (tokenizer taken from the model
    itself, automatic tokenizer mode and dtype, seed 0) purely so that its
    ``embedding_mode`` flag can be read.

    Args:
        model_name: Model name or path, also used as the tokenizer.
        trust_remote_code: Forwarded to ``ModelConfig``.
        quantization: Forwarded to ``ModelConfig``.

    Returns:
        The ``embedding_mode`` attribute of the probe config.
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        quantization=quantization,
        seed=0,
        dtype="auto",
    )
    return probe_config.embedding_mode
72
73


74
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that runs periodic stats logging while serving.

    Unless ``engine_args.disable_log_stats`` is set, a background task calls
    ``async_engine_client.do_log_stats()`` every 10 seconds while the
    application is running. The task is cancelled when the lifespan context
    exits, so it does not outlive the server.

    Args:
        app: The FastAPI application (unused; required by the lifespan API).
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    task = None
    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        # Hold a strong reference so the task is not garbage-collected while
        # running; `discard` (not `remove`) so the callback can never raise.
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.discard)

    try:
        yield
    finally:
        # Stop the periodic logger on shutdown. Without this the infinite
        # loop above keeps calling into the engine client after the app
        # has stopped.
        if task is not None:
            task.cancel()


90
91
92
93
94
95
96
97
98
99
100
101
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
    """Create the server's engine client and clean it up on exit or error.

    Also assigns the module-level ``engine_args`` and ``async_engine_client``
    globals. Both are read by the lifespan stats logger, and the client is
    also read by the ``/health`` handler.

    There are two modes:

    * In-process: for embedding models, or when
      ``args.disable_frontend_multiprocessing`` is set, an
      ``AsyncLLMEngine`` is built directly in this process.
    * Multiprocess: otherwise the engine runs inside a spawned
      ``run_rpc_server`` process. This process talks to it through an
      ``AsyncEngineRPCClient`` on the path returned by
      ``get_open_zmq_ipc_path()``.

    Args:
        args: Parsed CLI arguments.

    Yields:
        The ready-to-use engine client.

    Raises:
        RuntimeError: If the RPC server process dies before it answers the
            client's readiness probe.
    """
    global engine_args
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Backend itself still global for the silly lil' health handler
    global async_engine_client

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(args.model, args.trust_remote_code,
                           args.quantization)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    # Select random path for IPC.
    rpc_path = get_open_zmq_ipc_path()
    logger.info("Multiprocessing frontend to use %s for RPC Path.",
                rpc_path)

    # Start RPCServer in separate process (holds the AsyncLLMEngine).
    # The current process might have a CUDA context, so the engine process
    # is spawned rather than forked.
    context = multiprocessing.get_context("spawn")
    rpc_server_process = context.Process(
        target=run_rpc_server,
        args=(engine_args, UsageContext.OPENAI_API_SERVER, rpc_path))
    rpc_server_process.start()
    logger.info("Started engine process with PID %d",
                rpc_server_process.pid)

    rpc_client = None
    try:
        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
        # It is created inside the `try` so that the spawned process is still
        # terminated and joined if construction fails.
        rpc_client = AsyncEngineRPCClient(rpc_path)
        async_engine_client = rpc_client

        # Keep probing until the server answers. A timeout is only fatal once
        # the server process is no longer alive.
        while True:
            try:
                await rpc_client.setup()
                break
            except TimeoutError as e:
                if not rpc_server_process.is_alive():
                    raise RuntimeError(
                        "The server process died before "
                        "responding to the readiness probe") from e

        yield rpc_client
    finally:
        # Ensure rpc server process was terminated
        rpc_server_process.terminate()

        # Close all open connections to the backend, if the client was
        # ever created.
        if rpc_client is not None:
            rpc_client.close()

        # Wait for server process to join.
        # NOTE(review): Process.join() is blocking and stalls the event loop
        # until the engine process exits.
        rpc_server_process.join()


Ethan Xu's avatar
Ethan Xu committed
153
# The route handlers in this module are registered on this router;
# build_app() attaches it to the FastAPI application via include_router().
router = APIRouter()


def mount_metrics(app: FastAPI):
    """Expose the Prometheus ASGI app under ``/metrics`` on ``app``.

    Args:
        app: Application whose route table receives the mount.
    """
    prometheus_app = make_asgi_app()
    mount = Mount("/metrics", prometheus_app)
    # Replace the mount's match pattern with one that does not require a
    # slash after "/metrics". This is a workaround for a 307 redirect on
    # /metrics.
    mount.path_regex = re.compile(r'^/metrics(?P<path>.*)$')
    app.routes.append(mount)
162
163


Ethan Xu's avatar
Ethan Xu committed
164
@router.get("/health")
async def health() -> Response:
    """Health check.

    Awaits the engine client's ``check_health()``. If it completes, an
    empty 200 response is returned; if it raises, the exception propagates.
    """
    await async_engine_client.check_health()
    healthy = Response(status_code=200)
    return healthy


Ethan Xu's avatar
Ethan Xu committed
171
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    """Tokenize the request's input with the tokenization service.

    Returns:
        A ``TokenizeResponse`` serialized as JSON. If the service reports an
        ``ErrorResponse``, that response is returned instead, using its own
        status code.
    """
    result = await openai_serving_tokenization.create_tokenize(request)
    if not isinstance(result, ErrorResponse):
        assert isinstance(result, TokenizeResponse)
        return JSONResponse(content=result.model_dump())
    return JSONResponse(content=result.model_dump(),
                        status_code=result.code)


Ethan Xu's avatar
Ethan Xu committed
182
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    """Detokenize the request's token ids with the tokenization service.

    Returns:
        A ``DetokenizeResponse`` serialized as JSON. If the service reports
        an ``ErrorResponse``, that response is returned instead, using its
        own status code.
    """
    result = await openai_serving_tokenization.create_detokenize(request)
    if not isinstance(result, ErrorResponse):
        assert isinstance(result, DetokenizeResponse)
        return JSONResponse(content=result.model_dump())
    return JSONResponse(content=result.model_dump(),
                        status_code=result.code)


Ethan Xu's avatar
Ethan Xu committed
193
@router.get("/v1/models")
async def show_available_models():
    """List the served models, as reported by the completion handler."""
    model_list = await openai_serving_completion.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
197
198


Ethan Xu's avatar
Ethan Xu committed
199
@router.get("/version")
async def show_version():
    """Return the running vLLM version as ``{"version": <version>}``."""
    return JSONResponse(content=dict(version=VLLM_VERSION))


Ethan Xu's avatar
Ethan Xu committed
205
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """OpenAI-compatible chat completions endpoint.

    Returns one of three responses:

    * an ``ErrorResponse`` as JSON, using its status code;
    * a ``text/event-stream`` streaming response when ``request.stream``
      is set;
    * a single ``ChatCompletionResponse`` as JSON otherwise.
    """
    result = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if not request.stream:
        assert isinstance(result, ChatCompletionResponse)
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result,
                             media_type="text/event-stream")
219
220


Ethan Xu's avatar
Ethan Xu committed
221
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """OpenAI-compatible (legacy) completions endpoint.

    Returns one of three responses:

    * an ``ErrorResponse`` as JSON, using its status code;
    * a ``text/event-stream`` streaming response when ``request.stream``
      is set;
    * the completion result as JSON otherwise.
    """
    result = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if not request.stream:
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result,
                             media_type="text/event-stream")
Zhuohan Li's avatar
Zhuohan Li committed
233
234


Ethan Xu's avatar
Ethan Xu committed
235
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    """OpenAI-compatible embeddings endpoint.

    Returns the embedding result as JSON. If the service reports an
    ``ErrorResponse``, that response is returned instead, using its own
    status code.
    """
    result = await openai_serving_embedding.create_embedding(
        request, raw_request)
    if not isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump())
    return JSONResponse(content=result.model_dump(),
                        status_code=result.code)


246
247
def build_app(args: Namespace) -> FastAPI:
    """Create the FastAPI application and install routes and middleware.

    Installs, in order:

    * the module router;
    * the Prometheus ``/metrics`` mount;
    * CORS;
    * a ``RequestValidationError`` handler that returns an OpenAI-style
      400 error;
    * optional bearer-token authentication for ``/v1`` routes;
    * any user-supplied middleware listed in ``args.middleware``.

    Args:
        args: Parsed CLI arguments.

    Returns:
        The configured application. Serving handlers are created later by
        ``init_app``.

    Raises:
        ValueError: If an ``args.middleware`` entry is not a dotted
            ``module.name`` path, or if it resolves to something that is
            neither a class nor a coroutine function.
    """
    import hmac

    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:
        # These values do not change per request, so compute them once.
        root_path = "" if args.root_path is None else args.root_path
        protected_prefix = f"{root_path}/v1"
        expected_auth = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            # OPTIONS requests bypass the key check (presumably so CORS
            # preflight requests are not rejected).
            if request.method == "OPTIONS":
                return await call_next(request)
            # Only the /v1 API routes require the key.
            if not request.url.path.startswith(protected_prefix):
                return await call_next(request)
            provided_auth = request.headers.get("Authorization", "")
            # Constant-time comparison so response timing does not reveal
            # how much of the expected header value matched.
            if not hmac.compare_digest(provided_auth.encode("utf-8"),
                                       expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        module_path, sep, object_name = middleware.rpartition(".")
        if not sep:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a fully qualified 'module.name' path.")
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


295
async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    """Build the FastAPI app and create the OpenAI serving handlers.

    Assigns the module-level ``openai_serving_*`` globals used by the route
    handlers. All four share the engine client, the engine's model config,
    the served model names and, unless request logging is disabled, a
    single ``RequestLogger``.

    Args:
        async_engine_client: Engine client the handlers submit work to.
        args: Parsed CLI arguments.

    Returns:
        The configured application.
    """
    global openai_serving_chat
    global openai_serving_completion
    global openai_serving_embedding
    global openai_serving_tokenization

    app = build_app(args)

    # Fall back to the model path when no explicit served name is given.
    served_model_names = (args.served_model_name
                          if args.served_model_name is not None else
                          [args.model])

    model_config = await async_engine_client.get_model_config()

    request_logger = (None if args.disable_log_requests else
                      RequestLogger(max_log_len=args.max_log_len))

    # Leading positional arguments shared by every serving handler.
    shared = (async_engine_client, model_config, served_model_names)

    openai_serving_chat = OpenAIServingChat(
        *shared,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        *shared,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        *shared,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        *shared,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )

    # build_app() already assigned this same value; re-applied here.
    app.root_path = args.root_path

    return app
355
356


357
async def run_server(args, **uvicorn_kwargs) -> None:
    """Serve the OpenAI-compatible API until the HTTP server shuts down.

    Args:
        args: Parsed CLI arguments.
        **uvicorn_kwargs: Extra keyword arguments forwarded to
            ``serve_http``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    async with build_async_engine_client(args) as engine_client:
        app = await init_app(engine_client, args)

        http_options = dict(
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
        )
        shutdown_task = await serve_http(app,
                                         engine=engine_client,
                                         **http_options,
                                         **uvicorn_kwargs)

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task
380

Ethan Xu's avatar
Ethan Xu committed
381
382
383
384
385
386
387
388

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    cli_parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    asyncio.run(run_server(cli_parser.parse_args()))