api_server.py 6.03 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
5
6
import os
from contextlib import asynccontextmanager
from http import HTTPStatus
7

Zhuohan Li's avatar
Zhuohan Li committed
8
import fastapi
9
import uvicorn
10
from fastapi import Request
Zhuohan Li's avatar
Zhuohan Li committed
11
12
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
13
14
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
Zhuohan Li's avatar
Zhuohan Li committed
15

16
import vllm
Woosuk Kwon's avatar
Woosuk Kwon committed
17
18
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
19
from vllm.entrypoints.openai.cli_args import make_arg_parser
20
21
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest, ErrorResponse)
22
23
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
24
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
25
from vllm.usage.usage_lib import UsageContext
Zhuohan Li's avatar
Zhuohan Li committed
26

27
# Keep-alive timeout handed to uvicorn.run() as `timeout_keep_alive`.
TIMEOUT_KEEP_ALIVE = 5  # seconds
Zhuohan Li's avatar
Zhuohan Li committed
28

29
30
# Request handlers used by the route functions in this module. They are None
# at import time and are assigned in the `__main__` block once the engine has
# been created.
openai_serving_chat: OpenAIServingChat = None
openai_serving_completion: OpenAIServingCompletion = None
Zhuohan Li's avatar
Zhuohan Li committed
31
# Module-level logger obtained from vLLM's init_logger helper.
logger = init_logger(__name__)
32
33


34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    """Manage background work tied to the application's lifetime.

    Unless ``engine_args.disable_log_stats`` is set, a background task is
    started that calls ``engine.do_log_stats()`` every 10 seconds. The task
    is cancelled when the application shuts down.

    ``engine`` and ``engine_args`` are module globals assigned in the
    ``__main__`` block, so this context manager only works after they have
    been set there.

    Args:
        app: The FastAPI application being started (unused).
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    # The event loop only keeps weak references to tasks. Hold a strong
    # reference for as long as the application runs so the periodic logger
    # cannot be garbage-collected mid-execution.
    log_task = None
    if not engine_args.disable_log_stats:
        log_task = asyncio.create_task(_force_log())

    try:
        yield
    finally:
        # Stop the periodic logger instead of leaving it pending at shutdown.
        if log_task is not None:
            log_task.cancel()


# The FastAPI application; `lifespan` above is installed as its lifespan
# handler.
app = fastapi.FastAPI(lifespan=lifespan)


51
def parse_args():
    """Parse the server's command-line arguments.

    Returns:
        Whatever ``parse_args()`` returns on the parser built by
        ``make_arg_parser()``.
    """
    return make_arg_parser().parse_args()
Zhuohan Li's avatar
Zhuohan Li committed
54
55


56
57
58
# Mount the Prometheus ASGI app under /metrics so that requests to that path
# are served by it.
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)
59
60


Zhuohan Li's avatar
Zhuohan Li committed
61
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
    """Turn a request validation failure into a 400 JSON error response.

    Args:
        _: First handler argument (unused).
        exc: The validation error; its string form becomes the error message.

    Returns:
        A ``JSONResponse`` carrying the dumped error model with status
        ``HTTPStatus.BAD_REQUEST``.
    """
    error_response = openai_serving_chat.create_error_response(
        message=str(exc))
    payload = error_response.model_dump()
    return JSONResponse(payload, status_code=HTTPStatus.BAD_REQUEST)
65
66


67
68
69
@app.get("/health")
async def health() -> Response:
    """Health check.

    Awaits the engine's ``check_health()``; when it returns normally an
    empty 200 response is sent. Any exception it raises propagates.
    """
    engine_client = openai_serving_chat.engine
    await engine_client.check_health()
    return Response(status_code=200)


Zhuohan Li's avatar
Zhuohan Li committed
74
75
@app.get("/v1/models")
async def show_available_models():
    """Return the models reported by the chat serving handler as JSON."""
    model_list = await openai_serving_chat.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
78
79


80
81
82
83
84
85
@app.get("/version")
async def show_version():
    """Return the vLLM version as a JSON object with a ``version`` key."""
    return JSONResponse(content={"version": vllm.__version__})


86
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """Handle a chat completion request.

    Delegates to ``openai_serving_chat.create_chat_completion`` and wraps
    what it returns.

    Args:
        request: The parsed chat completion request body.
        raw_request: The underlying HTTP request, passed through to the
            serving handler.

    Returns:
        A ``JSONResponse`` with the error's own status code when the handler
        returns an ``ErrorResponse``; a ``text/event-stream``
        ``StreamingResponse`` when ``request.stream`` is truthy; otherwise a
        ``JSONResponse`` of the dumped result.
    """
    result = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    # Errors are reported as JSON using the status code they carry.
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if not request.stream:
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
99
100


Zhuohan Li's avatar
Zhuohan Li committed
101
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Handle a text completion request.

    Delegates to ``openai_serving_completion.create_completion`` and wraps
    what it returns.

    Args:
        request: The parsed completion request body.
        raw_request: The underlying HTTP request, passed through to the
            serving handler.

    Returns:
        A ``JSONResponse`` with the error's own status code when the handler
        returns an ``ErrorResponse``; a ``text/event-stream``
        ``StreamingResponse`` when ``request.stream`` is truthy; otherwise a
        ``JSONResponse`` of the dumped result.
    """
    result = await openai_serving_completion.create_completion(
        request, raw_request)

    # Errors are reported as JSON using the status code they carry.
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if not request.stream:
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
Zhuohan Li's avatar
Zhuohan Li committed
113
114
115


if __name__ == "__main__":
    # Script entry point: parse arguments, install middleware, build the
    # engine and serving handlers, then run the app under uvicorn.

    # Local import: only the API-key check below needs it.
    import secrets

    args = parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    # The VLLM_API_KEY environment variable takes precedence over
    # args.api_key; authentication is only installed when one of them is set.
    if token := os.environ.get("VLLM_API_KEY") or args.api_key:
        # These values are fixed for the life of the process, so compute them
        # once instead of on every request.
        root_path = "" if args.root_path is None else args.root_path
        protected_prefix = f"{root_path}/v1"
        expected_header = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            # This middleware is registered after CORSMiddleware and so wraps
            # it. CORS preflight requests carry no Authorization header; let
            # them through so CORSMiddleware can answer them instead of the
            # client receiving a 401.
            if request.method == "OPTIONS":
                return await call_next(request)
            # Only paths under <root_path>/v1 require the API key.
            if not request.url.path.startswith(protected_prefix):
                return await call_next(request)
            provided_header = request.headers.get("Authorization", "")
            # Constant-time comparison so response timing does not reveal how
            # much of the key matched. Compare bytes because compare_digest
            # rejects non-ASCII str values.
            if not secrets.compare_digest(provided_header.encode("utf-8"),
                                          expected_header):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    # Each entry is a dotted path "package.module.object" naming either a
    # middleware class or an async "http" middleware function.
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    logger.info("vLLM API server version %s", vllm.__version__)
    logger.info("args: %s", args)

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    # `engine` and `engine_args` are module globals read by `lifespan`;
    # the two serving handlers are read by the route functions.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
    openai_serving_chat = OpenAIServingChat(engine, served_model_names,
                                            args.response_role,
                                            args.lora_modules,
                                            args.chat_template)
    openai_serving_completion = OpenAIServingCompletion(
        engine, served_model_names, args.lora_modules)

    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs)