api_server.py 13.9 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
import re
5
import signal
6
7
from contextlib import asynccontextmanager
from http import HTTPStatus
8
9
from multiprocessing import Process
from typing import AsyncIterator, Set
10

11
12
13
import fastapi
import uvicorn
from fastapi import APIRouter, Request
Zhuohan Li's avatar
Zhuohan Li committed
14
15
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
16
17
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
18
from starlette.routing import Mount
Zhuohan Li's avatar
Zhuohan Li committed
19

20
import vllm.envs as envs
21
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
22
23
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
24
from vllm.engine.protocol import AsyncEngineClient
25
from vllm.entrypoints.logger import RequestLogger
26
from vllm.entrypoints.openai.cli_args import make_arg_parser
27
28
# yapf conflicts with isort for this block
# yapf: disable
29
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
30
                                              ChatCompletionResponse,
31
                                              CompletionRequest,
32
33
34
35
36
                                              DetokenizeRequest,
                                              DetokenizeResponse,
                                              EmbeddingRequest, ErrorResponse,
                                              TokenizeRequest,
                                              TokenizeResponse)
37
38
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
39
# yapf: enable
40
41
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
42
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
43
44
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
45
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
46
from vllm.usage.usage_lib import UsageContext
47
from vllm.utils import FlexibleArgumentParser, get_open_port
48
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
49

50
# Passed to uvicorn as ``timeout_keep_alive`` when the server is configured.
TIMEOUT_KEEP_ALIVE = 5  # seconds

# Declared here, assigned later:
#   * ``async_engine_client`` and ``engine_args`` are set by
#     ``build_async_engine_client``.
#   * the ``openai_serving_*`` handlers are set by ``build_server``.
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization

logger = init_logger('vllm.entrypoints.openai.api_server')

# Holds references to background tasks created in ``lifespan``; each task
# removes itself from the set when it finishes.
_running_tasks: Set[asyncio.Task] = set()
62

63

64
65
66
67
68
69
70
71
72
def model_is_embedding(model_name: str) -> bool:
    """Report whether ``model_name`` is configured as an embedding model.

    A ``ModelConfig`` is built for the model (the model name doubles as the
    tokenizer name) and its ``embedding_mode`` attribute is returned.
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )
    return probe_config.embedding_mode


73
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    """FastAPI lifespan hook that starts periodic engine stats logging.

    Unless ``engine_args.disable_log_stats`` is set, a background task is
    created that calls ``async_engine_client.do_log_stats()`` after every
    10 second sleep. Control is then yielded back to the application.
    """

    async def _log_stats_forever():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    stats_logging_enabled = not engine_args.disable_log_stats
    if stats_logging_enabled:
        stats_task = asyncio.create_task(_log_stats_forever())
        # Keep a reference in the module-level set for as long as the task
        # runs; the done-callback drops it again.
        _running_tasks.add(stats_task)
        stats_task.add_done_callback(_running_tasks.remove)

    yield


89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
    """Create the engine client and manage its lifecycle.

    Sets the module-level ``engine_args`` (parsed from ``args``) and
    ``async_engine_client`` globals, then yields the client.

    Two backends are possible:

    * In-process ``AsyncLLMEngine`` when the model is an embedding model or
      ``args.disable_frontend_multiprocessing`` is set.
    * Otherwise an ``AsyncEngineRPCClient`` talking to an RPC server that is
      started in a separate process.

    In the multiprocessing case the RPC server process is terminated and
    joined on exit, including when client construction or ``setup()`` fails,
    so a failed startup does not leave the child process running.
    """
    # Context manager to handle async_engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    global engine_args
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Backend itself still global for the silly lil' health handler
    global async_engine_client

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(args.model)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    # Start RPCServer in separate process (holds the AsyncLLMEngine).
    port = get_open_port(envs.VLLM_RPC_PORT)
    rpc_server_process = Process(target=run_rpc_server,
                                 args=(engine_args,
                                       UsageContext.OPENAI_API_SERVER,
                                       port))
    rpc_server_process.start()

    # Everything after the process start runs under try/finally so that the
    # child process is always reaped, even if the client fails to come up.
    rpc_client = None
    try:
        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
        rpc_client = AsyncEngineRPCClient(port)
        async_engine_client = rpc_client
        await rpc_client.setup()

        yield rpc_client
    finally:
        # Ensure rpc server process was terminated
        rpc_server_process.terminate()

        # Close all open connections to the backend
        # NOTE(review): assumes close() is safe to call when setup() did not
        # complete -- confirm against AsyncEngineRPCClient.
        if rpc_client is not None:
            rpc_client.close()

        # Wait for server process to join
        rpc_server_process.join()


Ethan Xu's avatar
Ethan Xu committed
135
# Shared router: the endpoint handlers in this module register on it, and
# build_app() attaches it to the FastAPI application.
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
136

137

138
def mount_metrics(app: fastapi.FastAPI):
    """Append a ``/metrics`` mount serving the Prometheus ASGI app."""
    # Add prometheus asgi middleware to route /metrics requests
    prometheus_mount = Mount("/metrics", make_asgi_app())

    # Workaround for 307 Redirect for /metrics
    prometheus_mount.path_regex = re.compile('^/metrics(?P<path>.*)$')

    app.routes.append(prometheus_mount)
144
145


Ethan Xu's avatar
Ethan Xu committed
146
@router.get("/health")
async def health() -> Response:
    """Health check.

    Awaits ``async_engine_client.check_health()`` and, if that returns
    without raising, answers with an empty 200 response.
    """
    await async_engine_client.check_health()
    healthy = Response(status_code=200)
    return healthy


Ethan Xu's avatar
Ethan Xu committed
153
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    """Handle ``POST /tokenize``.

    Delegates to ``openai_serving_tokenization.create_tokenize``. An
    ``ErrorResponse`` is returned as JSON with its own ``code`` as the HTTP
    status; otherwise the ``TokenizeResponse`` is returned as JSON.
    """
    result = await openai_serving_tokenization.create_tokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    assert isinstance(result, TokenizeResponse)
    return JSONResponse(content=result.model_dump())


Ethan Xu's avatar
Ethan Xu committed
164
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    """Handle ``POST /detokenize``.

    Delegates to ``openai_serving_tokenization.create_detokenize``. An
    ``ErrorResponse`` is returned as JSON with its own ``code`` as the HTTP
    status; otherwise the ``DetokenizeResponse`` is returned as JSON.
    """
    result = await openai_serving_tokenization.create_detokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    assert isinstance(result, DetokenizeResponse)
    return JSONResponse(content=result.model_dump())


Ethan Xu's avatar
Ethan Xu committed
175
@router.get("/v1/models")
async def show_available_models():
    """Handle ``GET /v1/models``: return the served models as JSON."""
    model_list = await openai_serving_completion.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
179
180


Ethan Xu's avatar
Ethan Xu committed
181
@router.get("/version")
async def show_version():
    """Handle ``GET /version``: return ``{"version": VLLM_VERSION}``."""
    return JSONResponse(content={"version": VLLM_VERSION})


Ethan Xu's avatar
Ethan Xu committed
187
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """Handle ``POST /v1/chat/completions``.

    Delegates to ``openai_serving_chat.create_chat_completion``. The result
    is returned as:

    * a JSON error with the error's own ``code`` for an ``ErrorResponse``;
    * a ``text/event-stream`` streaming response when ``request.stream`` is
      truthy;
    * a JSON dump of the ``ChatCompletionResponse`` otherwise.
    """
    result = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if request.stream:
        return StreamingResponse(content=result,
                                 media_type="text/event-stream")

    assert isinstance(result, ChatCompletionResponse)
    return JSONResponse(content=result.model_dump())
201
202


Ethan Xu's avatar
Ethan Xu committed
203
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Handle ``POST /v1/completions``.

    Delegates to ``openai_serving_completion.create_completion``. An
    ``ErrorResponse`` becomes a JSON error with its own ``code``; a
    streaming request gets a ``text/event-stream`` response; anything else
    is dumped to JSON.
    """
    result = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if request.stream:
        return StreamingResponse(content=result,
                                 media_type="text/event-stream")

    return JSONResponse(content=result.model_dump())
Zhuohan Li's avatar
Zhuohan Li committed
215
216


Ethan Xu's avatar
Ethan Xu committed
217
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    """Handle ``POST /v1/embeddings``.

    Delegates to ``openai_serving_embedding.create_embedding``. An
    ``ErrorResponse`` is returned as JSON with its own ``code`` as the HTTP
    status; any other result is dumped to JSON.
    """
    result = await openai_serving_embedding.create_embedding(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    return JSONResponse(content=result.model_dump())


228
229
def build_app(args):
    """Build and configure the FastAPI application.

    The app is created with the module's ``lifespan`` hook and ``router``,
    gets the ``/metrics`` mount, CORS middleware configured from ``args``,
    and a handler that turns request validation errors into a 400 JSON error
    response.

    If an API key is configured (``envs.VLLM_API_KEY`` or ``args.api_key``),
    an HTTP middleware requires ``Authorization: Bearer <key>`` on every
    request whose path starts with ``{root_path}/v1``; ``OPTIONS`` requests
    and all other paths pass through unauthenticated.

    Each entry of ``args.middleware`` is a dotted path to either a middleware
    class or a coroutine function, which is imported and installed.

    Raises:
        ValueError: if an ``args.middleware`` entry resolves to something
            that is neither a class nor a coroutine function.
    """
    # Function-scope import: only needed for the API-key comparison below.
    import secrets

    app = fastapi.FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:
        # Pre-encode the expected header value once; bytes are used so that
        # compare_digest also accepts non-ASCII input.
        expected_header = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            root_path = "" if args.root_path is None else args.root_path
            if request.method == "OPTIONS":
                return await call_next(request)
            if not request.url.path.startswith(f"{root_path}/v1"):
                return await call_next(request)

            # Constant-time comparison so response timing does not reveal
            # how much of the supplied key matched. A missing header is
            # treated as an empty value and rejected.
            supplied_header = request.headers.get("Authorization") or ""
            if not secrets.compare_digest(supplied_header.encode("utf-8"),
                                          expected_header):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


277
async def build_server(
    async_engine_client: AsyncEngineClient,
    args,
    **uvicorn_kwargs,
) -> uvicorn.Server:
    """Assemble the app, the OpenAI serving handlers and a uvicorn server.

    The four module-level handlers (``openai_serving_chat``,
    ``openai_serving_completion``, ``openai_serving_embedding`` and
    ``openai_serving_tokenization``) are (re)assigned here from the given
    engine client and ``args``. The registered routes are logged, and a
    ``uvicorn.Server`` built from ``args`` plus ``uvicorn_kwargs`` is
    returned; this function does not start it.
    """
    global openai_serving_chat
    global openai_serving_completion
    global openai_serving_embedding
    global openai_serving_tokenization

    app = build_app(args)

    # Fall back to the model path when no explicit served names were given.
    served_model_names = (args.served_model_name
                          if args.served_model_name is not None else
                          [args.model])

    model_config = await async_engine_client.get_model_config()

    request_logger = (None if args.disable_log_requests else
                      RequestLogger(max_log_len=args.max_log_len))

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )
    app.root_path = args.root_path

    logger.info("Available routes are:")
    for route in app.routes:
        # Only routes that expose a ``methods`` attribute are listed.
        if hasattr(route, 'methods'):
            logger.info("Route: %s, Methods: %s", route.path,
                        ', '.join(route.methods))

    server_config = uvicorn.Config(
        app,
        host=args.host,
        port=args.port,
        log_level=args.uvicorn_log_level,
        timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
        ssl_keyfile=args.ssl_keyfile,
        ssl_certfile=args.ssl_certfile,
        ssl_ca_certs=args.ssl_ca_certs,
        ssl_cert_reqs=args.ssl_cert_reqs,
        **uvicorn_kwargs,
    )
    return uvicorn.Server(server_config)


360
async def run_server(args, **uvicorn_kwargs) -> None:
    """Run the API server until it exits or is interrupted.

    Builds the engine client and the uvicorn server, serves in a task on
    the running loop, and installs SIGINT/SIGTERM handlers that cancel that
    task. On cancellation the server's shutdown coroutine is created inside
    the engine-client context but only awaited after that context has been
    exited.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    pending_shutdown = None
    async with build_async_engine_client(args) as engine_client:
        server = await build_server(engine_client, args, **uvicorn_kwargs)

        event_loop = asyncio.get_running_loop()
        serve_task = event_loop.create_task(server.serve())

        def _cancel_serving() -> None:
            # prevents the uvicorn signal handler to exit early
            serve_task.cancel()

        for signum in (signal.SIGINT, signal.SIGTERM):
            event_loop.add_signal_handler(signum, _cancel_serving)

        try:
            await serve_task
        except asyncio.CancelledError:
            logger.info("Gracefully stopping http server")
            pending_shutdown = server.shutdown()

    if pending_shutdown:
        # NB: Await server shutdown only after the backend context is exited
        await pending_shutdown
393

Ethan Xu's avatar
Ethan Xu committed
394
395
396
397
398
399
400
401

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args()

    # Script entry point: run the server coroutine to completion.
    asyncio.run(run_server(args))