api_server.py 6.29 KB
Newer Older
1
import asyncio
import copy
import importlib
import inspect
import re
import secrets
from contextlib import asynccontextmanager
from http import HTTPStatus

import fastapi
import uvicorn
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
from starlette.routing import Mount

import vllm
import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionResponse,
                                              CompletionRequest, ErrorResponse)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext

# Keep-alive timeout, in seconds, handed to
# uvicorn.run(timeout_keep_alive=...) in the __main__ block.
TIMEOUT_KEEP_ALIVE: int = 5

# Serving objects for the OpenAI-compatible routes. They are only declared
# here. The __main__ block assigns them before uvicorn starts, and the route
# handlers read them as module globals.
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion

logger = init_logger(__name__)


@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    """Run background work for as long as the FastAPI app is serving.

    Unless ``engine_args.disable_log_stats`` is set, a background task calls
    ``engine.do_log_stats()`` every 10 seconds. The task is cancelled when
    the app shuts down. ``engine`` and ``engine_args`` are module globals
    assigned in the ``__main__`` block before the server starts.
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    log_stats_task = None
    if not engine_args.disable_log_stats:
        # Keep a strong reference for the app's lifetime. The event loop only
        # holds weak references to tasks, so an unreferenced task can be
        # garbage-collected before it finishes.
        log_stats_task = asyncio.create_task(_force_log())

    try:
        yield
    finally:
        # Stop the periodic logger at shutdown rather than leaving a pending
        # task behind when the loop closes.
        if log_stats_task is not None:
            log_stats_task.cancel()


app = fastapi.FastAPI(lifespan=lifespan)


def parse_args():
    """Parse the server's command-line options.

    Returns:
        The result of ``parse_args()`` on the parser built by
        ``make_arg_parser()``.
    """
    return make_arg_parser().parse_args()


# Add prometheus asgi middleware to route /metrics requests: the
# prometheus_client ASGI app is mounted under /metrics.
route = Mount("/metrics", app=make_asgi_app())
# Workaround for 307 Redirect for /metrics: replace the mount's path pattern
# with one that also matches the bare "/metrics" path (no trailing slash).
route.path_regex = re.compile('^/metrics(?P<path>.*)$')
app.routes.append(route)


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
    # Turn FastAPI request-validation failures into the OpenAI-style error
    # body (built by the chat serving object), sent with HTTP 400.
    error_response = openai_serving_chat.create_error_response(
        message=str(exc))
    return JSONResponse(error_response.model_dump(),
                        status_code=HTTPStatus.BAD_REQUEST)


@app.get("/health")
async def health() -> Response:
    """Health check."""
    # Delegate to the engine's own check. If check_health() raises, the
    # exception propagates out of this handler and no 200 is sent.
    serving_engine = openai_serving_chat.engine
    await serving_engine.check_health()
    return Response(status_code=HTTPStatus.OK)


@app.get("/v1/models")
async def show_available_models():
    # OpenAI-compatible model listing, obtained from the chat serving object.
    model_list = await openai_serving_chat.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)


@app.get("/version")
async def show_version():
    # Report which vLLM release is serving requests.
    return JSONResponse(content={"version": vllm.__version__})


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    # The chat serving layer returns one of three things:
    # - an ErrorResponse, sent as JSON with the error's own status code;
    # - for stream=True requests, an iterable streamed as text/event-stream;
    # - otherwise a complete ChatCompletionResponse, sent as JSON.
    result = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if not request.stream:
        assert isinstance(result, ChatCompletionResponse)
        return JSONResponse(content=result.model_dump())
    return StreamingResponse(content=result, media_type="text/event-stream")


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    # The completion serving layer returns one of three things:
    # - an ErrorResponse, sent as JSON with the error's own status code;
    # - for stream=True requests, an iterable streamed as text/event-stream;
    # - otherwise a response object, serialized to JSON via model_dump().
    result = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if not request.stream:
        return JSONResponse(content=result.model_dump())
    return StreamingResponse(content=result, media_type="text/event-stream")


if __name__ == "__main__":
    # Script entry point: parse options, install middleware, build the
    # engine and serving objects, then hand the app to uvicorn. engine,
    # engine_args, openai_serving_chat and openai_serving_completion are
    # assigned here at module scope; lifespan() and the route handlers
    # read them as globals.
    args = parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    if token := envs.VLLM_API_KEY or args.api_key:
        # Requests under "<root_path>/v1" must send
        # "Authorization: Bearer <token>". args.root_path does not change
        # after startup, so the protected prefix and the expected header
        # value are computed once instead of on every request.
        root_path = "" if args.root_path is None else args.root_path
        protected_prefix = f"{root_path}/v1"
        expected_auth = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            # This middleware is registered after CORSMiddleware, so it wraps
            # it and sees every request first. Browsers send CORS preflight
            # (OPTIONS) requests without the Authorization header. Let them
            # through so CORSMiddleware can answer them, instead of returning
            # a 401 that lacks CORS headers.
            if request.method == "OPTIONS":
                return await call_next(request)
            if not request.url.path.startswith(protected_prefix):
                return await call_next(request)
            supplied = request.headers.get("Authorization", "")
            # Constant-time comparison, so response timing does not reveal
            # how much of the key matched. Both sides are UTF-8 encoded, so
            # this is equivalent to the plain string equality check.
            if not secrets.compare_digest(supplied.encode("utf-8"),
                                          expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        # Each entry is a dotted path "package.module.name" naming either a
        # middleware class or an async "http" middleware function.
        module_path, _, object_name = middleware.rpartition(".")
        if not module_path:
            raise ValueError(f"Invalid middleware {middleware}. Must be a "
                             "dotted import path such as 'pkg.module.name'.")
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    logger.info("vLLM API server version %s", vllm.__version__)
    # Log a copy of the arguments with the API key masked so the secret is
    # not written to log output.
    logged_args = copy.copy(args)
    if logged_args.api_key:
        logged_args.api_key = "<redacted>"
    logger.info("args: %s", logged_args)

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
    openai_serving_chat = OpenAIServingChat(engine, served_model_names,
                                            args.response_role,
                                            args.lora_modules,
                                            args.chat_template)
    openai_serving_completion = OpenAIServingCompletion(
        engine, served_model_names, args.lora_modules)

    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs)