"vscode:/vscode.git/clone" did not exist on "1bb17ecb396f911beaa26ab0d3926d46154c7155"
api_server.py 5.82 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
5
6
import os
from contextlib import asynccontextmanager
from http import HTTPStatus
7

Zhuohan Li's avatar
Zhuohan Li committed
8
import fastapi
9
import uvicorn
10
from fastapi import Request
Zhuohan Li's avatar
Zhuohan Li committed
11
12
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
13
14
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
Zhuohan Li's avatar
Zhuohan Li committed
15

16
import vllm
Woosuk Kwon's avatar
Woosuk Kwon committed
17
18
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
19
from vllm.entrypoints.openai.cli_args import make_arg_parser
20
21
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest, ErrorResponse)
22
23
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
24
from vllm.logger import init_logger
Zhuohan Li's avatar
Zhuohan Li committed
25

26
# Passed to uvicorn.run() as `timeout_keep_alive` in the __main__ block.
TIMEOUT_KEEP_ALIVE = 5  # seconds
Zhuohan Li's avatar
Zhuohan Li committed
27

28
29
# Serving handlers used by the route functions in this module. Both are None
# at import time and are assigned in the __main__ block once the engine has
# been created.
openai_serving_chat: OpenAIServingChat = None
openai_serving_completion: OpenAIServingCompletion = None
Zhuohan Li's avatar
Zhuohan Li committed
30
# Module-level logger; the __main__ block uses it to log the version and args.
logger = init_logger(__name__)
31
32


33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    """Lifespan handler that runs periodic engine stats logging.

    Unless ``engine_args.disable_log_stats`` is set, a background task is
    started that calls ``engine.do_log_stats()`` every 10 seconds. The task
    is cancelled when the application shuts down.

    ``engine`` and ``engine_args`` are module globals assigned in the
    ``__main__`` block before the server starts.

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    # Keep a strong reference to the task: the event loop only holds weak
    # references, so an unreferenced task may be garbage-collected while it
    # is still running.
    stats_task = None
    if not engine_args.disable_log_stats:
        stats_task = asyncio.create_task(_force_log())

    try:
        yield
    finally:
        # Stop the logging loop on shutdown instead of leaving it pending.
        if stats_task is not None:
            stats_task.cancel()


# The FastAPI application that every route and middleware in this module is
# registered on; `lifespan` is passed as its lifespan handler.
app = fastapi.FastAPI(lifespan=lifespan)


50
def parse_args():
    """Parse the command line with the parser built by ``make_arg_parser``.

    Returns:
        The namespace produced by ``parse_args()`` on that parser.
    """
    return make_arg_parser().parse_args()
Zhuohan Li's avatar
Zhuohan Li committed
53
54


55
56
57
# Mount the Prometheus ASGI app (built by prometheus_client.make_asgi_app) as
# a sub-application at /metrics, so requests under that path are served by it.
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)
58
59


Zhuohan Li's avatar
Zhuohan Li committed
60
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
    """Answer a request-validation failure with an HTTP 400 JSON error.

    The error body is built by ``openai_serving_chat.create_error_response``
    from the string form of ``exc``.

    Args:
        _: The incoming request (unused).
        exc: The validation error that was raised.
    """
    error_response = openai_serving_chat.create_error_response(
        message=str(exc))
    payload = error_response.model_dump()
    return JSONResponse(payload, status_code=HTTPStatus.BAD_REQUEST)
64
65


66
67
68
@app.get("/health")
async def health() -> Response:
    """Health check.

    Awaits ``check_health()`` on the engine held by ``openai_serving_chat``
    and replies with an empty 200 response once it returns.
    """
    serving_engine = openai_serving_chat.engine
    await serving_engine.check_health()
    return Response(status_code=200)


Zhuohan Li's avatar
Zhuohan Li committed
73
74
@app.get("/v1/models")
async def show_available_models():
    """Return the model list from ``openai_serving_chat`` as a JSON response."""
    model_list = await openai_serving_chat.show_available_models()
    return JSONResponse(content=model_list.model_dump())
Zhuohan Li's avatar
Zhuohan Li committed
77
78


79
80
81
82
83
84
@app.get("/version")
async def show_version():
    """Report ``vllm.__version__`` as a JSON object with a ``version`` key."""
    return JSONResponse(content={"version": vllm.__version__})


85
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """Handle a chat completion request.

    Delegates to ``openai_serving_chat.create_chat_completion`` and wraps the
    result in the matching response type.

    Args:
        request: The parsed chat completion request body.
        raw_request: The underlying HTTP request, forwarded to the handler.

    Returns:
        A JSON error response carrying the handler's status code if the
        handler returned an ``ErrorResponse``; a ``text/event-stream``
        streaming response if ``request.stream`` is truthy; otherwise a JSON
        response with the dumped result.
    """
    result = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    # The handler reports failures as a value rather than by raising.
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if not request.stream:
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
98
99


Zhuohan Li's avatar
Zhuohan Li committed
100
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Handle a text completion request.

    Delegates to ``openai_serving_completion.create_completion`` and wraps the
    result in the matching response type.

    Args:
        request: The parsed completion request body.
        raw_request: The underlying HTTP request, forwarded to the handler.

    Returns:
        A JSON error response carrying the handler's status code if the
        handler returned an ``ErrorResponse``; a ``text/event-stream``
        streaming response if ``request.stream`` is truthy; otherwise a JSON
        response with the dumped result.
    """
    result = await openai_serving_completion.create_completion(
        request, raw_request)

    # The handler reports failures as a value rather than by raising.
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if not request.stream:
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
Zhuohan Li's avatar
Zhuohan Li committed
112
113
114


if __name__ == "__main__":
    # Script entry point: parse CLI args, install middleware, build the
    # engine and serving handlers, then run the app under uvicorn.

    # Local import: only the API-key check below needs it.
    import secrets

    args = parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    # The environment variable takes precedence over the --api-key argument.
    if token := os.environ.get("VLLM_API_KEY") or args.api_key:
        # Encoded once up front; comparing bytes lets compare_digest accept
        # header values that contain non-ASCII characters.
        expected_authorization = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            # Browsers send CORS preflight (OPTIONS) requests without an
            # Authorization header; let them through so the CORS middleware
            # can answer them instead of returning 401.
            if request.method == "OPTIONS":
                return await call_next(request)
            # Only the /v1 API routes are protected.
            if not request.url.path.startswith("/v1"):
                return await call_next(request)
            provided = request.headers.get("Authorization", "")
            # Constant-time comparison so response timing does not reveal
            # how much of the token matched.
            if not secrets.compare_digest(provided.encode("utf-8"),
                                          expected_authorization):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    # Each entry is a dotted path "package.module.object" naming either a
    # middleware class or an async function to use as HTTP middleware.
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    logger.info(f"vLLM API server version {vllm.__version__}")
    logger.info(f"args: {args}")

    # Fall back to the model argument when no served name was given.
    if args.served_model_name is not None:
        served_model = args.served_model_name
    else:
        served_model = args.model

    # `engine` and `engine_args` are module globals read by `lifespan`.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    openai_serving_chat = OpenAIServingChat(engine, served_model,
                                            args.response_role,
                                            args.lora_modules,
                                            args.chat_template)
    openai_serving_completion = OpenAIServingCompletion(
        engine, served_model, args.lora_modules)

    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs)