api_server.py 6.15 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
5
6
import os
from contextlib import asynccontextmanager
from http import HTTPStatus
7

Zhuohan Li's avatar
Zhuohan Li committed
8
import fastapi
9
import uvicorn
10
from fastapi import Request
Zhuohan Li's avatar
Zhuohan Li committed
11
12
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
13
14
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
Zhuohan Li's avatar
Zhuohan Li committed
15

16
import vllm
Woosuk Kwon's avatar
Woosuk Kwon committed
17
18
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
19
from vllm.entrypoints.openai.cli_args import make_arg_parser
20
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
21
                                              ChatCompletionResponse,
22
                                              CompletionRequest, ErrorResponse)
23
24
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
25
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
26
from vllm.usage.usage_lib import UsageContext
Zhuohan Li's avatar
Zhuohan Li committed
27

28
# Passed to ``uvicorn.run`` as ``timeout_keep_alive`` in the ``__main__`` block.
TIMEOUT_KEEP_ALIVE = 5  # seconds
Zhuohan Li's avatar
Zhuohan Li committed
29

30
31
# Serving handlers used by the route functions in this module. They are only
# annotated here; the instances are assigned in the ``__main__`` block.
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
Zhuohan Li's avatar
Zhuohan Li committed
32
# Module-level logger obtained from vLLM's ``init_logger`` for this module.
logger = init_logger(__name__)
33
34


35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    """Run the periodic engine stats logger for the lifetime of the app.

    Unless ``engine_args.disable_log_stats`` is set, a background task calls
    ``engine.do_log_stats()`` every 10 seconds while the application runs.
    The task is cancelled when the application shuts down. ``engine`` and
    ``engine_args`` are module globals assigned in the ``__main__`` block.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    # The event loop keeps only weak references to tasks, so hold a strong
    # reference here; otherwise the logger task could be garbage-collected
    # while it is still supposed to be running.
    stats_task = None
    if not engine_args.disable_log_stats:
        stats_task = asyncio.create_task(_force_log())

    try:
        yield
    finally:
        # The logging loop never ends on its own; stop it on shutdown.
        if stats_task is not None:
            stats_task.cancel()


# The FastAPI application; ``lifespan`` is passed as its lifespan handler.
app = fastapi.FastAPI(lifespan=lifespan)


52
def parse_args():
    """Parse the command line with the parser built by ``make_arg_parser``.

    Returns:
        The parsed arguments object returned by the parser's ``parse_args()``.
    """
    return make_arg_parser().parse_args()
Zhuohan Li's avatar
Zhuohan Li committed
55
56


57
58
59
# Mount the Prometheus ASGI application (built by ``make_asgi_app``) as a
# sub-application under /metrics, so requests to that path are served by it.
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)
60
61


Zhuohan Li's avatar
Zhuohan Li committed
62
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
    """Turn a request validation failure into a 400 JSON error response.

    The payload is built by ``openai_serving_chat.create_error_response``
    from the string form of ``exc``; the request argument is ignored.
    """
    error_body = openai_serving_chat.create_error_response(
        message=str(exc)).model_dump()
    return JSONResponse(error_body, status_code=HTTPStatus.BAD_REQUEST)
66
67


68
69
70
@app.get("/health")
async def health() -> Response:
    """Health check.

    Awaits ``check_health()`` on the engine held by ``openai_serving_chat``
    and, once that completes, replies with an empty 200 response.
    """
    llm_engine = openai_serving_chat.engine
    await llm_engine.check_health()
    return Response(status_code=200)


Zhuohan Li's avatar
Zhuohan Li committed
75
76
@app.get("/v1/models")
async def show_available_models():
    """Return the model list from ``openai_serving_chat`` as a JSON response."""
    model_list = await openai_serving_chat.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
79
80


81
82
83
84
85
86
@app.get("/version")
async def show_version():
    """Report ``vllm.__version__`` as a JSON object under the key "version"."""
    return JSONResponse(content={"version": vllm.__version__})


87
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """Serve a chat completion request.

    Delegates to ``openai_serving_chat.create_chat_completion`` and wraps
    what it returns:

    * an ``ErrorResponse`` becomes a JSON response using the error's own
      ``code`` as the HTTP status;
    * when ``request.stream`` is set, the result is streamed as
      ``text/event-stream``;
    * otherwise the ``ChatCompletionResponse`` is dumped to JSON.
    """
    result = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    # Errors are reported first, regardless of the streaming flag.
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if request.stream:
        return StreamingResponse(content=result,
                                 media_type="text/event-stream")

    assert isinstance(result, ChatCompletionResponse)
    return JSONResponse(content=result.model_dump())
101
102


Zhuohan Li's avatar
Zhuohan Li committed
103
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Serve a text completion request.

    Delegates to ``openai_serving_completion.create_completion`` and wraps
    what it returns:

    * an ``ErrorResponse`` becomes a JSON response using the error's own
      ``code`` as the HTTP status;
    * when ``request.stream`` is set, the result is streamed as
      ``text/event-stream``;
    * otherwise the result is dumped to JSON.
    """
    result = await openai_serving_completion.create_completion(
        request, raw_request)

    # Errors are reported first, regardless of the streaming flag.
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if request.stream:
        return StreamingResponse(content=result,
                                 media_type="text/event-stream")

    return JSONResponse(content=result.model_dump())
Zhuohan Li's avatar
Zhuohan Li committed
115
116
117


if __name__ == "__main__":
    # Script entry point: parse CLI arguments, install middleware, build the
    # engine and serving handlers, then run the app under uvicorn.
    args = parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    # The VLLM_API_KEY environment variable takes precedence over --api-key.
    if token := os.environ.get("VLLM_API_KEY") or args.api_key:
        # Imported locally: only needed when API-key authentication is on.
        import secrets

        # Both values are fixed for the life of the process, so compute them
        # once instead of on every request.
        root_path = "" if args.root_path is None else args.root_path
        protected_prefix = f"{root_path}/v1"
        expected_auth = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            """Require ``Authorization: Bearer <token>`` on ``/v1`` routes.

            Requests whose path does not start with ``{root_path}/v1`` pass
            through unauthenticated; others get a 401 JSON response unless
            the header matches the configured token exactly.
            """
            if not request.url.path.startswith(protected_prefix):
                return await call_next(request)
            provided_auth = request.headers.get("Authorization") or ""
            # Constant-time comparison so response timing does not reveal
            # how much of the token matched. Compare as bytes because
            # compare_digest rejects non-ASCII str values.
            if not secrets.compare_digest(provided_auth.encode("utf-8"),
                                          expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    # Extra middleware given as dotted import paths: a class is added via
    # add_middleware, a coroutine function is registered as HTTP middleware.
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    logger.info("vLLM API server version %s", vllm.__version__)
    logger.info("args: %s", args)

    # Fall back to the model path/name when no served name was given.
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    # `engine` and `engine_args` are module globals read by `lifespan`.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
    openai_serving_chat = OpenAIServingChat(engine, served_model_names,
                                            args.response_role,
                                            args.lora_modules,
                                            args.chat_template)
    openai_serving_completion = OpenAIServingCompletion(
        engine, served_model_names, args.lora_modules)

    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs)