import argparse
import asyncio
import json
from contextlib import asynccontextmanager
from http import HTTPStatus

import fastapi
import uvicorn
from aioprometheus import MetricsMiddleware
from aioprometheus.asgi.starlette import metrics
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse, Response

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.metrics import add_global_metrics_labels
from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                              ChatCompletionRequest,
                                              ErrorResponse)
from vllm.logger import init_logger
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
# Keep-alive timeout (seconds) passed to uvicorn for idle HTTP connections.
TIMEOUT_KEEP_ALIVE = 5  # seconds

# Request handlers below read these module globals; they are populated in the
# `__main__` block before the server starts accepting traffic, so they are
# only None during import.
openai_serving_chat: OpenAIServingChat = None
openai_serving_completion: OpenAIServingCompletion = None
logger = init_logger(__name__)
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    """FastAPI lifespan: run periodic engine stats logging while serving.

    Starts a background task that forces the engine to emit its stats every
    10 seconds (unless disabled via engine args), and cancels it on shutdown.
    """

    async def _force_log():
        # Periodically force a stats log line even when no requests arrive.
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    log_task = None
    if not engine_args.disable_log_stats:
        # Keep a strong reference to the task: the event loop only holds
        # weak references to tasks, so a bare create_task() call may be
        # garbage-collected before it ever runs.
        log_task = asyncio.create_task(_force_log())

    yield

    # Shutdown: stop the background logger instead of leaving it running.
    if log_task is not None:
        log_task.cancel()


app = fastapi.FastAPI(lifespan=lifespan)
def parse_args():
    """Build the CLI parser (server flags + engine args) and parse argv."""
    parser = argparse.ArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    parser.add_argument("--host", type=str, default=None, help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument("--allow-credentials",
                        action="store_true",
                        help="allow credentials")
    # The three CORS list options are identical in shape: a JSON-encoded
    # list defaulting to a single wildcard entry.
    for flag, what in (("--allowed-origins", "origins"),
                       ("--allowed-methods", "methods"),
                       ("--allowed-headers", "headers")):
        parser.add_argument(flag,
                            type=json.loads,
                            default=["*"],
                            help=f"allowed {what}")
    parser.add_argument("--served-model-name",
                        type=str,
                        default=None,
                        help="The model name used in the API. If not "
                        "specified, the model name will be the same as "
                        "the huggingface name.")
    parser.add_argument("--chat-template",
                        type=str,
                        default=None,
                        help="The file path to the chat template, "
                        "or the template in single-line form "
                        "for the specified model")
    parser.add_argument("--response-role",
                        type=str,
                        default="assistant",
                        help="The role name to return if "
                        "`request.add_generation_prompt=true`.")
    parser.add_argument("--ssl-keyfile",
                        type=str,
                        default=None,
                        help="The file path to the SSL key file")
    parser.add_argument("--ssl-certfile",
                        type=str,
                        default=None,
                        help="The file path to the SSL cert file")
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")

    # Engine options are appended by vLLM itself.
    return AsyncEngineArgs.add_cli_args(parser).parse_args()
# Prometheus instrumentation: record per-request HTTP metrics and expose
# them on a scrape endpoint.
app.add_middleware(MetricsMiddleware)  # Trace HTTP server metrics
app.add_route("/metrics", metrics)  # Exposes HTTP metrics


Zhuohan Li's avatar
Zhuohan Li committed
106
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
    """Map FastAPI request-validation failures onto the OpenAI error schema."""
    error = openai_serving_chat.create_error_response(message=str(exc))
    payload = error.dict()
    return JSONResponse(payload, status_code=HTTPStatus.BAD_REQUEST)
@app.get("/health")
async def health() -> Response:
    """Liveness probe: always answers 200 with an empty body."""
    return Response(status_code=200)
@app.get("/v1/models")
async def show_available_models():
    """List the served models (OpenAI `/v1/models` response shape)."""
    model_list = await openai_serving_chat.show_available_models()
    return JSONResponse(content=model_list.dict())
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """OpenAI-compatible chat completion endpoint.

    Answers with a single JSON body when the engine returned an error or
    the client did not ask for streaming; otherwise streams SSE chunks.
    """
    result = await openai_serving_chat.create_chat_completion(
        request, raw_request)
    # Guard clause: errors and non-streaming requests get one JSON reply.
    if isinstance(result, ErrorResponse) or not request.stream:
        return JSONResponse(content=result.dict())
    return StreamingResponse(content=result, media_type="text/event-stream")
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """OpenAI-compatible text completion endpoint.

    Answers with a single JSON body when the engine returned an error or
    the client did not ask for streaming; otherwise streams SSE chunks.
    """
    result = await openai_serving_completion.create_completion(
        request, raw_request)
    # Guard clause: errors and non-streaming requests get one JSON reply.
    if isinstance(result, ErrorResponse) or not request.stream:
        return JSONResponse(content=result.dict())
    return StreamingResponse(content=result, media_type="text/event-stream")
if __name__ == "__main__":
    args = parse_args()
    logger.info(f"args: {args}")

    # CORS policy comes straight from the CLI flags.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    # Prefer the explicit API-facing name; fall back to the huggingface name.
    served_model = (args.served_model_name
                    if args.served_model_name is not None else args.model)

    # Build the engine, then the per-API serving layers on top of it.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    openai_serving_chat = OpenAIServingChat(engine, served_model,
                                            args.response_role,
                                            args.chat_template)
    openai_serving_completion = OpenAIServingCompletion(engine, served_model)

    # Register labels for metrics
    add_global_metrics_labels(model_name=engine_args.model)

    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="info",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile)