api_server.py 13.2 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
import re
5
from argparse import Namespace
6
7
from contextlib import asynccontextmanager
from http import HTTPStatus
8
9
from multiprocessing import Process
from typing import AsyncIterator, Set
10

11
from fastapi import APIRouter, FastAPI, Request
Zhuohan Li's avatar
Zhuohan Li committed
12
13
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
14
15
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
16
from starlette.routing import Mount
Zhuohan Li's avatar
Zhuohan Li committed
17

18
import vllm.envs as envs
19
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
20
21
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
22
from vllm.engine.protocol import AsyncEngineClient
23
from vllm.entrypoints.launcher import serve_http
24
from vllm.entrypoints.logger import RequestLogger
25
from vllm.entrypoints.openai.cli_args import make_arg_parser
26
27
# yapf conflicts with isort for this block
# yapf: disable
28
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
29
                                              ChatCompletionResponse,
30
                                              CompletionRequest,
31
32
33
34
35
                                              DetokenizeRequest,
                                              DetokenizeResponse,
                                              EmbeddingRequest, ErrorResponse,
                                              TokenizeRequest,
                                              TokenizeResponse)
36
37
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
38
# yapf: enable
39
40
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
41
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
42
43
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
44
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
45
from vllm.usage.usage_lib import UsageContext
46
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
47
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
48

49
# Forwarded to serve_http() as ``timeout_keep_alive`` by run_server().
TIMEOUT_KEEP_ALIVE = 5  # seconds
Zhuohan Li's avatar
Zhuohan Li committed
50

51
# Module-level engine client; assigned in build_async_engine_client() and read
# by the /health handler and the lifespan stats logger.
async_engine_client: AsyncEngineClient
Ethan Xu's avatar
Ethan Xu committed
52
# Engine arguments parsed from the CLI; assigned in
# build_async_engine_client() and read by lifespan().
engine_args: AsyncEngineArgs
53
54
# Serving objects behind the chat and completion routes; both are assigned in
# init_app() before the app is served.
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
55
# Serving object behind the /v1/embeddings route; assigned in init_app().
openai_serving_embedding: OpenAIServingEmbedding
56
# Serving object behind the /tokenize and /detokenize routes; assigned in
# init_app().
openai_serving_tokenization: OpenAIServingTokenization
57

58
# Logger used throughout this module, registered under a fixed name.
logger = init_logger('vllm.entrypoints.openai.api_server')
59

60
# Background tasks started by lifespan(); each task is added on creation and
# removed again by its done-callback.
_running_tasks: Set[asyncio.Task] = set()
61

62

63
def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
    """Report whether ``model_name`` is configured as an embedding model.

    A ``ModelConfig`` is built with the model name doubling as the tokenizer
    name and with fixed ``tokenizer_mode``, ``seed`` and ``dtype`` values;
    its ``embedding_mode`` attribute is returned.
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        seed=0,
        dtype="float16",
    )
    return probe_config.embedding_mode


72
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that drives periodic engine stats logging.

    Unless ``engine_args.disable_log_stats`` is set, a background task calls
    ``async_engine_client.do_log_stats()`` every 10 seconds while the app is
    running. The task is cancelled when the application shuts down so that
    it does not outlive the server.
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    stats_task = None
    if not engine_args.disable_log_stats:
        stats_task = asyncio.create_task(_force_log())
        # Keep a strong reference for the task's lifetime; the done-callback
        # drops it again once the task finishes or is cancelled.
        _running_tasks.add(stats_task)
        stats_task.add_done_callback(_running_tasks.remove)

    try:
        yield
    finally:
        # Stop the endless logging loop on shutdown. cancel() is a no-op if
        # the task has already finished (e.g. do_log_stats raised).
        if stats_task is not None:
            stats_task.cancel()


88
89
90
91
92
93
94
95
96
97
98
99
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
    """Create the engine client for ``args`` and clean it up on exit.

    Context manager to handle the async_engine_client lifecycle. It sets the
    module-level ``engine_args`` and ``async_engine_client`` globals and
    yields the client.

    * For embedding models, or when ``args.disable_frontend_multiprocessing``
      is set, an in-process ``AsyncLLMEngine`` is yielded.
    * Otherwise an RPC server process holding the engine is started and an
      ``AsyncEngineRPCClient`` connected to it is yielded. On exit, or if
      client construction/setup fails, the server process is terminated, the
      client is closed and the process is joined.
    """
    global engine_args
    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Backend itself still global for the silly lil' health handler
    global async_engine_client

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    if (model_is_embedding(args.model, args.trust_remote_code)
            or args.disable_frontend_multiprocessing):
        async_engine_client = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.

    # Select random path for IPC.
    rpc_path = get_open_zmq_ipc_path()
    logger.info("Multiprocessing frontend to use %s for RPC Path.",
                rpc_path)

    # Start RPCServer in separate process (holds the AsyncLLMEngine).
    rpc_server_process = Process(target=run_rpc_server,
                                 args=(engine_args,
                                       UsageContext.OPENAI_API_SERVER,
                                       rpc_path))
    rpc_server_process.start()

    # Everything after start() runs inside the try so that a failure while
    # building or setting up the client cannot leak the server process.
    rpc_client = None
    try:
        # Build RPCClient, which conforms to AsyncEngineClient Protocol.
        rpc_client = AsyncEngineRPCClient(rpc_path)
        async_engine_client = rpc_client
        await rpc_client.setup()

        yield rpc_client
    finally:
        # Ensure rpc server process was terminated
        rpc_server_process.terminate()

        # Close all open connections to the backend.
        # NOTE(review): assumes close() is safe on a client whose setup()
        # did not complete -- confirm against AsyncEngineRPCClient.
        if rpc_client is not None:
            rpc_client.close()

        # Wait for server process to join
        rpc_server_process.join()


Ethan Xu's avatar
Ethan Xu committed
138
# Router collecting the HTTP endpoints declared below; build_app() attaches it
# to the FastAPI application.
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
139

140

141
def mount_metrics(app: FastAPI):
    """Attach the Prometheus ASGI app to ``app`` under ``/metrics``."""
    # Add prometheus asgi middleware to route /metrics requests
    prometheus_mount = Mount("/metrics", make_asgi_app())

    # Workaround for 307 Redirect for /metrics: replace the mount's path
    # regex with one that also matches the bare "/metrics" path.
    prometheus_mount.path_regex = re.compile('^/metrics(?P<path>.*)$')

    app.routes.append(prometheus_mount)
147
148


Ethan Xu's avatar
Ethan Xu committed
149
@router.get("/health")
async def health() -> Response:
    """Health check.

    Awaits the engine client's own health check and then answers with an
    empty 200 response. Exceptions raised by the check are not caught here.
    """
    await async_engine_client.check_health()
    ok_response = Response(status_code=200)
    return ok_response


Ethan Xu's avatar
Ethan Xu committed
156
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    """Handle ``POST /tokenize``.

    Delegates to ``openai_serving_tokenization.create_tokenize``. An
    ``ErrorResponse`` is returned as JSON with its own ``code`` as the HTTP
    status; otherwise the ``TokenizeResponse`` is returned as JSON.
    """
    result = await openai_serving_tokenization.create_tokenize(request)

    # Error path first: surface the error payload with its status code.
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    assert isinstance(result, TokenizeResponse)
    return JSONResponse(content=result.model_dump())


Ethan Xu's avatar
Ethan Xu committed
167
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    """Handle ``POST /detokenize``.

    Delegates to ``openai_serving_tokenization.create_detokenize``. An
    ``ErrorResponse`` is returned as JSON with its own ``code`` as the HTTP
    status; otherwise the ``DetokenizeResponse`` is returned as JSON.
    """
    result = await openai_serving_tokenization.create_detokenize(request)

    # Error path first: surface the error payload with its status code.
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    assert isinstance(result, DetokenizeResponse)
    return JSONResponse(content=result.model_dump())


Ethan Xu's avatar
Ethan Xu committed
178
@router.get("/v1/models")
async def show_available_models():
    """Handle ``GET /v1/models`` by returning the served model list as JSON."""
    available = await openai_serving_completion.show_available_models()
    payload = available.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
182
183


Ethan Xu's avatar
Ethan Xu committed
184
@router.get("/version")
async def show_version():
    """Handle ``GET /version`` by returning ``{"version": VLLM_VERSION}``."""
    return JSONResponse(content={"version": VLLM_VERSION})


Ethan Xu's avatar
Ethan Xu committed
190
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """Handle ``POST /v1/chat/completions``.

    Delegates to ``openai_serving_chat.create_chat_completion``. The result
    is returned in one of three forms:

    * an ``ErrorResponse`` becomes a JSON body with its ``code`` as status;
    * when ``request.stream`` is truthy, the result is streamed as
      ``text/event-stream``;
    * otherwise the ``ChatCompletionResponse`` is returned as JSON.
    """
    outcome = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)

    if not request.stream:
        assert isinstance(outcome, ChatCompletionResponse)
        return JSONResponse(content=outcome.model_dump())

    return StreamingResponse(content=outcome,
                             media_type="text/event-stream")
204
205


Ethan Xu's avatar
Ethan Xu committed
206
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Handle ``POST /v1/completions``.

    Delegates to ``openai_serving_completion.create_completion``. An
    ``ErrorResponse`` is returned as JSON with its ``code`` as status. When
    ``request.stream`` is truthy the result is streamed as
    ``text/event-stream``; otherwise it is returned as JSON.
    """
    outcome = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)

    if not request.stream:
        return JSONResponse(content=outcome.model_dump())

    return StreamingResponse(content=outcome,
                             media_type="text/event-stream")
Zhuohan Li's avatar
Zhuohan Li committed
218
219


Ethan Xu's avatar
Ethan Xu committed
220
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    """Handle ``POST /v1/embeddings``.

    Delegates to ``openai_serving_embedding.create_embedding`` and returns
    the result as JSON. An ``ErrorResponse`` additionally carries its own
    ``code`` as the HTTP status.
    """
    outcome = await openai_serving_embedding.create_embedding(
        request, raw_request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)

    return JSONResponse(content=outcome.model_dump())


231
232
def build_app(args: Namespace) -> FastAPI:
    """Build the FastAPI application and configure it from ``args``.

    The app is created with the module's ``lifespan`` hook and ``router``,
    then gets, in order: the ``/metrics`` mount, CORS middleware configured
    from ``args``, a handler turning request validation errors into 400
    responses, optional bearer-token authentication, and any user-supplied
    middleware listed in ``args.middleware``.

    Authentication is enabled when ``envs.VLLM_API_KEY`` or ``args.api_key``
    is set. It applies only to paths starting with ``{root_path}/v1``;
    ``OPTIONS`` requests always pass through.

    Raises:
        ValueError: if an entry of ``args.middleware`` resolves to something
            that is neither a class nor a coroutine function.
    """
    # Local import: only the authentication middleware below needs it.
    import secrets

    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    if token := envs.VLLM_API_KEY or args.api_key:
        # Compared as bytes so compare_digest accepts arbitrary header text.
        expected_header = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            root_path = "" if args.root_path is None else args.root_path
            if request.method == "OPTIONS":
                return await call_next(request)
            if not request.url.path.startswith(f"{root_path}/v1"):
                return await call_next(request)
            # A missing header becomes "", which can never match because
            # the expected value always starts with "Bearer ".
            supplied_header = request.headers.get("Authorization", "")
            # Constant-time comparison so response timing does not reveal
            # how much of the key a guess got right.
            if not secrets.compare_digest(supplied_header.encode("utf-8"),
                                          expected_header):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        # Each entry is a dotted path "package.module.object".
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


280
async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    """Build the app and create the module-level OpenAI serving objects.

    The ``async_engine_client`` parameter shadows the module-level global of
    the same name inside this function. The model config is fetched from the
    client, and the four ``openai_serving_*`` globals used by the route
    handlers are (re)assigned before the app is returned.

    Served model names come from ``args.served_model_name`` when it is not
    ``None`` and otherwise default to ``[args.model]``. Request logging is
    disabled when ``args.disable_log_requests`` is set.
    """
    global openai_serving_chat
    global openai_serving_completion
    global openai_serving_embedding
    global openai_serving_tokenization

    app = build_app(args)

    served_model_names = (args.served_model_name
                          if args.served_model_name is not None else
                          [args.model])

    model_config = await async_engine_client.get_model_config()

    request_logger = (None if args.disable_log_requests else RequestLogger(
        max_log_len=args.max_log_len))

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )
    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )
    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )

    app.root_path = args.root_path
    return app
340
341


342
async def run_server(args, **uvicorn_kwargs) -> None:
    """Start the engine client, build the app and serve it over HTTP.

    Host, port, log level, keep-alive timeout and SSL settings are taken
    from ``args``; any extra ``uvicorn_kwargs`` are forwarded to
    ``serve_http`` unchanged. The awaitable returned by ``serve_http`` is
    awaited only after the engine client context has been exited.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    async with build_async_engine_client(args) as engine_client:
        app = await init_app(engine_client, args)

        http_options = dict(
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
        )
        shutdown_task = await serve_http(app, **http_options,
                                         **uvicorn_kwargs)

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task
364

Ethan Xu's avatar
Ethan Xu committed
365
366
367
368
369
370
371
372

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    # Script entry point: build the CLI parser, parse argv, run the server.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args()

    asyncio.run(run_server(args))