api_server.py 6.74 KB
Newer Older
Zhuohan Li's avatar
Zhuohan Li committed
1
import argparse
2
import asyncio
Zhuohan Li's avatar
Zhuohan Li committed
3
import json
4
from contextlib import asynccontextmanager
5
6
from aioprometheus import MetricsMiddleware
from aioprometheus.asgi.starlette import metrics
Zhuohan Li's avatar
Zhuohan Li committed
7
import fastapi
8
import uvicorn
9
from http import HTTPStatus
10
from fastapi import Request
Zhuohan Li's avatar
Zhuohan Li committed
11
12
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
13
from fastapi.responses import JSONResponse, StreamingResponse, Response
Zhuohan Li's avatar
Zhuohan Li committed
14

Woosuk Kwon's avatar
Woosuk Kwon committed
15
16
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
17
from vllm.engine.metrics import add_global_metrics_labels
18
from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse
Woosuk Kwon's avatar
Woosuk Kwon committed
19
from vllm.logger import init_logger
20
21
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
Zhuohan Li's avatar
Zhuohan Li committed
22

23
# How long (seconds) uvicorn keeps idle keep-alive connections open.
TIMEOUT_KEEP_ALIVE = 5

# These are populated in the __main__ block once the engine has been
# constructed; they are None until then.
openai_serving_chat: OpenAIServingChat = None
openai_serving_completion: OpenAIServingCompletion = None

logger = init_logger(__name__)
28
29


30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    """FastAPI lifespan context: start the periodic stats logger on
    startup and run until the application shuts down."""

    async def _force_log():
        # Flush engine stats every 10s so they show up even when the
        # server receives no traffic.
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    if not engine_args.disable_log_stats:
        # Keep a strong reference to the task: the event loop only holds a
        # weak reference to tasks, so a bare create_task() call can be
        # garbage-collected before (or while) it runs. The reference lives
        # in this generator frame, which persists until shutdown.
        _log_task = asyncio.create_task(_force_log())  # noqa: F841

    yield


app = fastapi.FastAPI(lifespan=lifespan)


47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
def parse_args():
    """Build the CLI parser (server options + engine options) and parse
    the command line."""
    cli = argparse.ArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    cli.add_argument("--host", type=str, default=None, help="host name")
    cli.add_argument("--port", type=int, default=8000, help="port number")
    cli.add_argument("--allow-credentials",
                     action="store_true",
                     help="allow credentials")
    # The three CORS allow-lists share the same shape: a JSON-encoded list
    # defaulting to the wildcard.
    for cors_field in ("origins", "methods", "headers"):
        cli.add_argument(f"--allowed-{cors_field}",
                         type=json.loads,
                         default=["*"],
                         help=f"allowed {cors_field}")
    cli.add_argument("--served-model-name",
                     type=str,
                     default=None,
                     help="The model name used in the API. If not "
                     "specified, the model name will be the same as "
                     "the huggingface name.")
    cli.add_argument("--chat-template",
                     type=str,
                     default=None,
                     help="The file path to the chat template, "
                     "or the template in single-line form "
                     "for the specified model")
    cli.add_argument("--response-role",
                     type=str,
                     default="assistant",
                     help="The role name to return if "
                     "`request.add_generation_prompt=true`.")
    cli.add_argument("--ssl-keyfile",
                     type=str,
                     default=None,
                     help="The file path to the SSL key file")
    cli.add_argument("--ssl-certfile",
                     type=str,
                     default=None,
                     help="The file path to the SSL cert file")
    cli.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")

    # Engine-level options (model, parallelism, etc.) are contributed by
    # the engine itself.
    cli = AsyncEngineArgs.add_cli_args(cli)
    return cli.parse_args()
Zhuohan Li's avatar
Zhuohan Li committed
100
101


102
103
104
105
# Prometheus instrumentation (aioprometheus): the middleware records
# per-request HTTP metrics; the /metrics route exposes them for scraping.
app.add_middleware(MetricsMiddleware)  # Trace HTTP server metrics
app.add_route("/metrics", metrics)  # Exposes HTTP metrics


Zhuohan Li's avatar
Zhuohan Li committed
106
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
    """Translate request-validation failures into an OpenAI-style error
    body with HTTP 400."""
    error = openai_serving_chat.create_error_response(message=str(exc))
    return JSONResponse(error.model_dump(),
                        status_code=HTTPStatus.BAD_REQUEST)
110
111


112
113
114
115
116
117
@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


Zhuohan Li's avatar
Zhuohan Li committed
118
119
@app.get("/v1/models")
async def show_available_models():
120
    models = await openai_serving_chat.show_available_models()
121
    return JSONResponse(content=models.model_dump())
Zhuohan Li's avatar
Zhuohan Li committed
122
123


124
@app.post("/v1/chat/completions")
125
126
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
127
128
    generator = await openai_serving_chat.create_chat_completion(
        request, raw_request)
129
130
131
132
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
133
        return StreamingResponse(content=generator,
134
                                 media_type="text/event-stream")
135
    else:
136
        return JSONResponse(content=generator.model_dump())
137
138


Zhuohan Li's avatar
Zhuohan Li committed
139
@app.post("/v1/completions")
140
async def create_completion(request: CompletionRequest, raw_request: Request):
141
142
    generator = await openai_serving_completion.create_completion(
        request, raw_request)
143
144
145
146
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
147
        return StreamingResponse(content=generator,
Zhuohan Li's avatar
Zhuohan Li committed
148
                                 media_type="text/event-stream")
149
    else:
150
        return JSONResponse(content=generator.model_dump())
Zhuohan Li's avatar
Zhuohan Li committed
151
152
153


if __name__ == "__main__":
154
    args = parse_args()
Zhuohan Li's avatar
Zhuohan Li committed
155
156
157
158
159
160
161
162
163
164
165

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    logger.info(f"args: {args}")

166
167
168
169
170
    if args.served_model_name is not None:
        served_model = args.served_model_name
    else:
        served_model = args.model

Zhuohan Li's avatar
Zhuohan Li committed
171
    engine_args = AsyncEngineArgs.from_cli_args(args)
172
    engine = AsyncLLMEngine.from_engine_args(engine_args)
173
174
175
176
    openai_serving_chat = OpenAIServingChat(engine, served_model,
                                            args.response_role,
                                            args.chat_template)
    openai_serving_completion = OpenAIServingCompletion(engine, served_model)
Zhuohan Li's avatar
Zhuohan Li committed
177

178
179
180
    # Register labels for metrics
    add_global_metrics_labels(model_name=engine_args.model)

181
    app.root_path = args.root_path
182
183
184
185
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="info",
186
187
188
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile)