api_server.py 6.46 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
import re
5
6
from contextlib import asynccontextmanager
from http import HTTPStatus
7
from typing import Any, Set
8

Zhuohan Li's avatar
Zhuohan Li committed
9
import fastapi
10
import uvicorn
11
from fastapi import Request
Zhuohan Li's avatar
Zhuohan Li committed
12
13
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
14
15
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
16
from starlette.routing import Mount
Zhuohan Li's avatar
Zhuohan Li committed
17

18
import vllm
19
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
20
21
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
22
from vllm.entrypoints.openai.cli_args import make_arg_parser
23
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
24
                                              ChatCompletionResponse,
25
                                              CompletionRequest, ErrorResponse)
26
27
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
28
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
29
from vllm.usage.usage_lib import UsageContext
Zhuohan Li's avatar
Zhuohan Li committed
30

31
logger = init_logger(__name__)

# Keep-alive timeout handed to uvicorn.run(timeout_keep_alive=...).
TIMEOUT_KEEP_ALIVE = 5  # seconds

# Serving handlers used by the route functions. They are only declared here;
# the ``__main__`` section assigns them once the engine has been built.
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion

# Strong references to background tasks started in ``lifespan``; each task
# removes itself via a done-callback when it finishes.
_running_tasks: Set[asyncio.Task[Any]] = set()

39

40
41
42
43
44
45
46
47
48
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    """Start and stop background work for the lifetime of the FastAPI app.

    Unless ``engine_args.disable_log_stats`` is set, a task is started that
    calls ``engine.do_log_stats()`` every 10 seconds. The task is cancelled
    when the app shuts down, so it does not keep running after the lifespan
    context exits.

    ``engine`` and ``engine_args`` are module globals assigned in the
    ``__main__`` section before the server is started.
    """

    async def _force_log():
        # Emit engine stats on a fixed 10-second cadence until cancelled.
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    task = None
    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        # The event loop keeps only weak references to tasks; holding one in
        # a module-level set prevents the task from being garbage-collected.
        _running_tasks.add(task)
        # discard (not remove) so a duplicate callback can never raise.
        task.add_done_callback(_running_tasks.discard)

    try:
        yield
    finally:
        # Stop the periodic logger on shutdown instead of leaving it pending.
        if task is not None and not task.done():
            task.cancel()


# The ASGI application; ``lifespan`` handles startup/shutdown work.
app = fastapi.FastAPI(lifespan=lifespan)


def parse_args():
    """Parse this server's command-line arguments via vLLM's arg parser."""
    return make_arg_parser().parse_args()


# Expose Prometheus metrics by mounting prometheus_client's ASGI app.
route = Mount("/metrics", make_asgi_app())
# Workaround for a 307 redirect on /metrics: widen the mount's match pattern
# so the bare path (no trailing slash) is routed to the metrics app directly.
route.path_regex = re.compile('^/metrics(?P<path>.*)$')
app.routes.append(route)
69
70


Zhuohan Li's avatar
Zhuohan Li committed
71
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
    """Convert a request-validation failure into a 400 error response.

    The body is built by the chat serving layer's ``create_error_response``
    from the exception's string form.
    """
    error = openai_serving_chat.create_error_response(message=str(exc))
    body = error.model_dump()
    return JSONResponse(body, status_code=HTTPStatus.BAD_REQUEST)
75
76


77
78
79
@app.get("/health")
async def health() -> Response:
    """Health check.

    Awaits the engine's ``check_health()``; if it raises, the exception
    propagates instead of a 200 response being returned.
    """
    serving_engine = openai_serving_chat.engine
    await serving_engine.check_health()
    return Response(status_code=200)


Zhuohan Li's avatar
Zhuohan Li committed
84
85
@app.get("/v1/models")
async def show_available_models():
    """Return the model list reported by the chat serving layer as JSON."""
    model_list = await openai_serving_chat.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
88
89


90
91
92
93
94
95
@app.get("/version")
async def show_version():
    """Report the installed vLLM version as ``{"version": <str>}``."""
    return JSONResponse(content={"version": vllm.__version__})


96
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """Handle a chat completion request.

    Returns one of three responses:
    - the serving layer's ``ErrorResponse``, with its own status code;
    - a ``text/event-stream`` response when ``request.stream`` is set;
    - otherwise the full ``ChatCompletionResponse`` as JSON.
    """
    outcome = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)

    if not request.stream:
        # Non-streaming requests must yield a complete response object.
        assert isinstance(outcome, ChatCompletionResponse)
        return JSONResponse(content=outcome.model_dump())

    return StreamingResponse(content=outcome,
                             media_type="text/event-stream")
110
111


Zhuohan Li's avatar
Zhuohan Li committed
112
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Handle a text completion request.

    Returns one of three responses:
    - the serving layer's ``ErrorResponse``, with its own status code;
    - a ``text/event-stream`` response when ``request.stream`` is set;
    - otherwise the result serialized as JSON.
    """
    outcome = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)

    if not request.stream:
        return JSONResponse(content=outcome.model_dump())

    return StreamingResponse(content=outcome,
                             media_type="text/event-stream")
Zhuohan Li's avatar
Zhuohan Li committed
124
125
126


if __name__ == "__main__":
    # Server entry point:
    #   1. parse CLI args;
    #   2. install CORS, optional API-key auth, and user middleware;
    #   3. build the engine and the serving handlers;
    #   4. run uvicorn.
    # Everything assigned here is a module global. ``lifespan`` and the route
    # handlers read engine, engine_args and openai_serving_*.
    import hmac

    args = parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    if token := envs.VLLM_API_KEY or args.api_key:
        # These depend only on startup configuration, so compute them once
        # rather than on every request.
        api_root = "" if args.root_path is None else args.root_path
        v1_prefix = f"{api_root}/v1"
        expected_auth = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            """Require ``Authorization: Bearer <token>`` on /v1 routes.

            Paths outside ``{root_path}/v1`` pass through unchecked. A
            missing or mismatched header on a /v1 path gets a 401.
            """
            if not request.url.path.startswith(v1_prefix):
                return await call_next(request)
            provided = request.headers.get("Authorization", "")
            # Constant-time comparison, so response timing does not reveal
            # how much of the secret a client has guessed correctly.
            if not hmac.compare_digest(provided.encode("utf-8"),
                                       expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    # User middleware given as dotted paths:
    # - classes are added with add_middleware;
    # - coroutine functions are registered as HTTP middleware.
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    logger.info("vLLM API server version %s", vllm.__version__)
    logger.info("args: %s", args)

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)

    openai_serving_chat = OpenAIServingChat(engine, served_model_names,
                                            args.response_role,
                                            args.lora_modules,
                                            args.chat_template)
    openai_serving_completion = OpenAIServingCompletion(
        engine, served_model_names, args.lora_modules)

    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs)