api_server.py 13.6 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
import re
5
from argparse import Namespace
6
7
from contextlib import asynccontextmanager
from http import HTTPStatus
8
9
from multiprocessing import Process
from typing import AsyncIterator, Set
10

11
from fastapi import APIRouter, FastAPI, Request
Zhuohan Li's avatar
Zhuohan Li committed
12
13
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
14
15
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
16
from starlette.routing import Mount
Zhuohan Li's avatar
Zhuohan Li committed
17

18
import vllm.envs as envs
19
from vllm.config import ModelConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
20
21
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
22
from vllm.engine.protocol import AsyncEngineClient
23
from vllm.entrypoints.launcher import serve_http
24
from vllm.entrypoints.logger import RequestLogger
25
from vllm.entrypoints.openai.cli_args import make_arg_parser
26
27
# yapf conflicts with isort for this block
# yapf: disable
28
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
29
                                              ChatCompletionResponse,
30
                                              CompletionRequest,
31
32
33
34
35
                                              DetokenizeRequest,
                                              DetokenizeResponse,
                                              EmbeddingRequest, ErrorResponse,
                                              TokenizeRequest,
                                              TokenizeResponse)
36
37
from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
from vllm.entrypoints.openai.rpc.server import run_rpc_server
38
# yapf: enable
39
40
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
41
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
42
43
from vllm.entrypoints.openai.serving_tokenization import (
    OpenAIServingTokenization)
44
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
45
from vllm.usage.usage_lib import UsageContext
46
from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path
47
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
48

49
# Keep-alive timeout (in seconds) handed to the HTTP server in run_server().
TIMEOUT_KEEP_ALIVE = 5

# Module-level handles. They are only declared here; the values are bound
# through ``global`` statements in build_async_engine_client() and init_app().
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization

logger = init_logger("vllm.entrypoints.openai.api_server")

# References to background tasks started in lifespan(); each task is removed
# from the set again by its done-callback.
_running_tasks: Set[asyncio.Task] = set()
61

62

63
def model_is_embedding(model_name: str, trust_remote_code: bool) -> bool:
    """Report whether the model's config has ``embedding_mode`` set.

    A throwaway ``ModelConfig`` is built with ``model_name`` doubling as the
    tokenizer name and with fixed ``seed``/``dtype`` values; only its
    ``embedding_mode`` attribute is read.
    """
    probe_config = ModelConfig(
        model=model_name,
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=trust_remote_code,
        seed=0,
        dtype="float16",
    )
    return probe_config.embedding_mode


72
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook that drives periodic stats logging.

    Unless ``engine_args.disable_log_stats`` is set, a background task is
    started that asks the engine client to log its stats every 10 seconds.
    The task is cancelled when the lifespan exits so it cannot keep calling
    into the engine client after the application has shut down.
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await async_engine_client.do_log_stats()

    stats_task = None
    if not engine_args.disable_log_stats:
        stats_task = asyncio.create_task(_force_log())
        # Keep a reference while the task runs; it drops itself when done.
        _running_tasks.add(stats_task)
        stats_task.add_done_callback(_running_tasks.remove)

    try:
        yield
    finally:
        # Stop the periodic logger on shutdown (also on error exit).
        if stats_task is not None:
            stats_task.cancel()


88
89
90
91
92
93
94
95
96
97
98
99
@asynccontextmanager
async def build_async_engine_client(args) -> AsyncIterator[AsyncEngineClient]:
    """Yield an engine client for ``args`` and clean it up on exit.

    Binds the module-level ``engine_args`` and ``async_engine_client``.
    Embedding models, or runs with ``args.disable_frontend_multiprocessing``
    set, get an in-process ``AsyncLLMEngine``. Otherwise ``run_rpc_server``
    is started in a child process and an ``AsyncEngineRPCClient`` pointed at
    the IPC path from ``get_open_zmq_ipc_path()`` is yielded; on exit the
    child is terminated, the client closed and the child joined.

    Raises:
        RuntimeError: if a readiness probe timed out and the child process
            is no longer alive.
    """
    global engine_args
    # Backend itself still global for the silly lil' health handler
    global async_engine_client

    engine_args = AsyncEngineArgs.from_cli_args(args)

    # If manually triggered or embedding model, use AsyncLLMEngine in process.
    # TODO: support embedding model via RPC.
    use_in_process_engine = (
        model_is_embedding(args.model, args.trust_remote_code)
        or args.disable_frontend_multiprocessing)
    if use_in_process_engine:
        async_engine_client = AsyncLLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
        yield async_engine_client
        return

    # Otherwise, use the multiprocessing AsyncLLMEngine.
    # Select random path for IPC.
    rpc_path = get_open_zmq_ipc_path()
    logger.info("Multiprocessing frontend to use %s for RPC Path.", rpc_path)

    # Start RPCServer in separate process (holds the AsyncLLMEngine).
    server_args = (engine_args, UsageContext.OPENAI_API_SERVER, rpc_path)
    rpc_server_process = Process(target=run_rpc_server, args=server_args)
    rpc_server_process.start()

    # Build RPCClient, which conforms to AsyncEngineClient Protocol.
    async_engine_client = AsyncEngineRPCClient(rpc_path)

    try:
        # Keep probing until setup() succeeds; a timeout is only fatal once
        # the server process is no longer alive.
        server_ready = False
        while not server_ready:
            try:
                await async_engine_client.setup()
            except TimeoutError as e:
                if not rpc_server_process.is_alive():
                    raise RuntimeError(
                        "The server process died before "
                        "responding to the readiness probe") from e
            else:
                server_ready = True

        yield async_engine_client
    finally:
        # Ensure rpc server process was terminated
        rpc_server_process.terminate()

        # Close all open connections to the backend
        async_engine_client.close()

        # Wait for server process to join
        rpc_server_process.join()


Ethan Xu's avatar
Ethan Xu committed
147
# Router collecting the HTTP endpoints declared in this module; build_app()
# attaches it to the FastAPI application.
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
148

149

150
def mount_metrics(app: FastAPI):
    """Attach the Prometheus ASGI app to ``app`` under ``/metrics``."""
    # Add prometheus asgi middleware to route /metrics requests
    prometheus_mount = Mount("/metrics", make_asgi_app())

    # Workaround for 307 Redirect for /metrics
    prometheus_mount.path_regex = re.compile('^/metrics(?P<path>.*)$')

    app.routes.append(prometheus_mount)
156
157


Ethan Xu's avatar
Ethan Xu committed
158
@router.get("/health")
async def health() -> Response:
    """Health check."""
    # The empty 200 response is only built if check_health() returns
    # without raising.
    await async_engine_client.check_health()
    ok_response = Response(status_code=200)
    return ok_response


Ethan Xu's avatar
Ethan Xu committed
165
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    # Hand the request to the tokenization service; an ErrorResponse is
    # returned with its own status code, anything else as a plain JSON body.
    result = await openai_serving_tokenization.create_tokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    assert isinstance(result, TokenizeResponse)
    return JSONResponse(content=result.model_dump())


Ethan Xu's avatar
Ethan Xu committed
176
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    # Hand the request to the tokenization service; an ErrorResponse is
    # returned with its own status code, anything else as a plain JSON body.
    result = await openai_serving_tokenization.create_detokenize(request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    assert isinstance(result, DetokenizeResponse)
    return JSONResponse(content=result.model_dump())


Ethan Xu's avatar
Ethan Xu committed
187
@router.get("/v1/models")
async def show_available_models():
    # The model list is obtained from the completion serving object and
    # returned as its dumped JSON form.
    model_list = await openai_serving_completion.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
191
192


Ethan Xu's avatar
Ethan Xu committed
193
@router.get("/version")
async def show_version():
    # Report the vLLM version string under the "version" key.
    return JSONResponse(content={"version": VLLM_VERSION})


Ethan Xu's avatar
Ethan Xu committed
199
@router.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    # Three possible outcomes: an ErrorResponse (returned with its own status
    # code), a full response object, or a stream sent as server-sent events.
    result = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if not request.stream:
        assert isinstance(result, ChatCompletionResponse)
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
213
214


Ethan Xu's avatar
Ethan Xu committed
215
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    # Three possible outcomes: an ErrorResponse (returned with its own status
    # code), a full response object, or a stream sent as server-sent events.
    result = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if not request.stream:
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
Zhuohan Li's avatar
Zhuohan Li committed
227
228


Ethan Xu's avatar
Ethan Xu committed
229
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    # An ErrorResponse is returned with its own status code; any other
    # result is returned as a plain JSON body.
    result = await openai_serving_embedding.create_embedding(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    return JSONResponse(content=result.model_dump())


240
241
def build_app(args: Namespace) -> FastAPI:
    """Create the FastAPI application and wire up its routes and middleware.

    The app gets the module router, the ``/metrics`` mount, CORS middleware
    configured from ``args``, a handler that turns request validation errors
    into 400 responses, optional bearer-token authentication and any extra
    middleware named in ``args.middleware``.

    Args:
        args: Parsed CLI arguments. Reads ``root_path``, the CORS settings
            (``allowed_origins``, ``allow_credentials``, ``allowed_methods``,
            ``allowed_headers``), ``api_key`` and ``middleware``.

    Returns:
        The configured application.

    Raises:
        ValueError: if an entry of ``args.middleware`` resolves to something
            that is neither a class nor a coroutine function.
    """
    # Function-local import so this block does not depend on a change to the
    # module's import section.
    import secrets

    app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_, exc):
        err = openai_serving_chat.create_error_response(message=str(exc))
        return JSONResponse(err.model_dump(),
                            status_code=HTTPStatus.BAD_REQUEST)

    # The environment variable takes precedence over the CLI argument.
    if token := envs.VLLM_API_KEY or args.api_key:
        # Computed once instead of on every request.
        root_path = "" if args.root_path is None else args.root_path
        protected_prefix = f"{root_path}/v1"
        expected_header = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            # OPTIONS requests and paths outside the /v1 prefix are let
            # through without a token.
            if request.method == "OPTIONS":
                return await call_next(request)
            if not request.url.path.startswith(protected_prefix):
                return await call_next(request)

            # Constant-time comparison so response timing does not reveal
            # how much of the key matched. Comparing bytes also keeps
            # compare_digest from rejecting non-ASCII header values.
            provided_header = request.headers.get("Authorization", "")
            if not secrets.compare_digest(provided_header.encode("utf-8"),
                                          expected_header):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        # Each entry is a dotted path "package.module.object".
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    return app


289
async def init_app(
    async_engine_client: AsyncEngineClient,
    args: Namespace,
) -> FastAPI:
    """Build the app and create the module-level OpenAI serving objects.

    The served model names default to ``[args.model]`` when
    ``args.served_model_name`` is ``None``. Request logging is skipped when
    ``args.disable_log_requests`` is set. The four serving objects (chat,
    completion, embedding, tokenization) are stored in module globals so the
    route handlers can reach them.
    """
    global openai_serving_chat
    global openai_serving_completion
    global openai_serving_embedding
    global openai_serving_tokenization

    app = build_app(args)

    served_model_names = ([args.model] if args.served_model_name is None
                          else args.served_model_name)

    model_config = await async_engine_client.get_model_config()

    request_logger = (None if args.disable_log_requests else
                      RequestLogger(max_log_len=args.max_log_len))

    openai_serving_chat = OpenAIServingChat(
        async_engine_client,
        model_config,
        served_model_names,
        args.response_role,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        chat_template=args.chat_template,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )

    openai_serving_completion = OpenAIServingCompletion(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        prompt_adapters=args.prompt_adapters,
        request_logger=request_logger,
        return_tokens_as_token_ids=args.return_tokens_as_token_ids,
    )

    openai_serving_embedding = OpenAIServingEmbedding(
        async_engine_client,
        model_config,
        served_model_names,
        request_logger=request_logger,
    )

    openai_serving_tokenization = OpenAIServingTokenization(
        async_engine_client,
        model_config,
        served_model_names,
        lora_modules=args.lora_modules,
        request_logger=request_logger,
        chat_template=args.chat_template,
    )

    app.root_path = args.root_path
    return app
349
350


351
async def run_server(args, **uvicorn_kwargs) -> None:
    """Start the engine client, serve the app, and wait for shutdown.

    ``serve_http`` is awaited inside the engine-client context and hands back
    an awaitable; that awaitable is only awaited once the context has been
    exited. Extra keyword arguments are forwarded to ``serve_http``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    async with build_async_engine_client(args) as engine_client:
        app = await init_app(engine_client, args)

        shutdown_task = await serve_http(
            app,
            engine=engine_client,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    await shutdown_task
374

Ethan Xu's avatar
Ethan Xu committed
375
376
377
378
379
380
381
382

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/scripts.py for CLI entrypoints.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."))
    args = parser.parse_args()

    asyncio.run(run_server(args))