"vscode:/vscode.git/clone" did not exist on "2b22290ce01b033cc692e7dce159d74a43f6f2c5"
api_server.py 5.87 KB
Newer Older
1
import asyncio
2
from contextlib import asynccontextmanager
3
4
5
6
import os
import importlib
import inspect

7
from prometheus_client import make_asgi_app
Zhuohan Li's avatar
Zhuohan Li committed
8
import fastapi
9
import uvicorn
10
from http import HTTPStatus
11
from fastapi import Request
Zhuohan Li's avatar
Zhuohan Li committed
12
13
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
14
from fastapi.responses import JSONResponse, StreamingResponse, Response
Zhuohan Li's avatar
Zhuohan Li committed
15

16
import vllm
Woosuk Kwon's avatar
Woosuk Kwon committed
17
18
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
19
20
21
from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                              ChatCompletionRequest,
                                              ErrorResponse)
Woosuk Kwon's avatar
Woosuk Kwon committed
22
from vllm.logger import init_logger
23
from vllm.entrypoints.openai.cli_args import make_arg_parser
24
25
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
Zhuohan Li's avatar
Zhuohan Li committed
26

27
logger = init_logger(__name__)

# Passed to uvicorn.run(timeout_keep_alive=...) in the __main__ block.
TIMEOUT_KEEP_ALIVE = 5  # seconds

# Serving objects are left as None at import time. The __main__ block assigns
# them at module scope, and the route handlers read them as module globals.
openai_serving_chat: OpenAIServingChat = None
openai_serving_completion: OpenAIServingCompletion = None
32
33


34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    """Manage background work for the lifetime of the FastAPI application.

    Unless ``engine_args.disable_log_stats`` is set, a background task calls
    ``engine.do_log_stats()`` every 10 seconds while the app is running.
    That task is cancelled when the application shuts down.

    ``engine`` and ``engine_args`` are module globals assigned by the
    ``__main__`` block before the server starts.
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    log_task = None
    if not engine_args.disable_log_stats:
        # Keep a strong reference. The event loop only holds weak references
        # to tasks, so an unreferenced task may be garbage-collected mid-run.
        log_task = asyncio.create_task(_force_log())

    try:
        yield
    finally:
        # Stop the periodic logger on shutdown instead of leaving it pending.
        # A task that already finished (e.g. with an error) is left untouched.
        if log_task is not None and not log_task.done():
            log_task.cancel()
            try:
                await log_task
            except asyncio.CancelledError:
                # Expected: this is the cancellation requested just above.
                pass


app = fastapi.FastAPI(lifespan=lifespan)


51
def parse_args():
    """Parse the command line with vLLM's OpenAI-server argument parser."""
    arg_parser = make_arg_parser()
    parsed = arg_parser.parse_args()
    return parsed
Zhuohan Li's avatar
Zhuohan Li committed
54
55


56
57
58
# Expose Prometheus metrics: prometheus_client's ASGI app is mounted at
# /metrics, so requests under that path are handled by it.
metrics_app = make_asgi_app()
app.mount(path="/metrics", app=metrics_app)
59
60


Zhuohan Li's avatar
Zhuohan Li committed
61
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
    """Return request-validation failures as a 400 with an OpenAI-style body.

    The error body is built by the chat serving object's
    ``create_error_response`` from the validation error's string form.
    """
    error_response = openai_serving_chat.create_error_response(
        message=str(exc))
    payload = error_response.model_dump()
    return JSONResponse(payload, status_code=HTTPStatus.BAD_REQUEST)
65
66


67
68
69
@app.get("/health")
async def health() -> Response:
    """Health check."""
    # Delegate to the engine's own check. Any exception it raises propagates
    # out of this handler. Otherwise, reply with an empty 200.
    serving_engine = openai_serving_chat.engine
    await serving_engine.check_health()
    ok = Response(status_code=200)
    return ok


Zhuohan Li's avatar
Zhuohan Li committed
74
75
@app.get("/v1/models")
async def show_available_models():
    # The chat serving object builds the model listing; it is serialized
    # here via model_dump() into the JSON body.
    model_list = await openai_serving_chat.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
78
79


80
81
82
83
84
85
@app.get("/version")
async def show_version():
    # Report the version string of the installed vllm package.
    return JSONResponse(content=dict(version=vllm.__version__))


86
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    # The serving layer returns one of three things:
    # - an ErrorResponse (sent with its own status code);
    # - for streaming requests, something StreamingResponse can iterate;
    # - otherwise, a response model serialized with model_dump().
    result = await openai_serving_chat.create_chat_completion(
        request, raw_request)
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if not request.stream:
        return JSONResponse(content=result.model_dump())
    return StreamingResponse(content=result, media_type="text/event-stream")
99
100


Zhuohan Li's avatar
Zhuohan Li committed
101
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    # Same three-way dispatch as the chat endpoint: an ErrorResponse, a
    # stream for request.stream, or a model serialized with model_dump().
    result = await openai_serving_completion.create_completion(
        request, raw_request)
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)
    if not request.stream:
        return JSONResponse(content=result.model_dump())
    return StreamingResponse(content=result, media_type="text/event-stream")
Zhuohan Li's avatar
Zhuohan Li committed
113
114
115


if __name__ == "__main__":
    # Only the server entry point needs a constant-time string comparison.
    import secrets

    args = parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    # The VLLM_API_KEY environment variable takes precedence over --api-key.
    if token := os.environ.get("VLLM_API_KEY") or args.api_key:
        # Build the expected header once, as bytes, for compare_digest.
        expected_auth = ("Bearer " + token).encode()

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            # Only paths under /v1 require the key. Everything else
            # (e.g. /health, /version, /metrics) passes straight through.
            if not request.url.path.startswith("/v1"):
                return await call_next(request)
            # A missing header becomes b"", which never matches.
            provided = request.headers.get("Authorization", "").encode()
            # compare_digest avoids leaking, through response timing, how
            # much of a guessed key was correct (plain != short-circuits).
            if not secrets.compare_digest(provided, expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    # User-supplied middleware given as dotted paths "package.module.Name".
    # Classes are added as ASGI middleware; coroutine functions are added
    # as HTTP middleware.
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    # Lazy %-style arguments: the logger formats only if the record is emitted.
    logger.info("vLLM API server version %s", vllm.__version__)
    logger.info("args: %s", args)

    served_model = (args.served_model_name
                    if args.served_model_name is not None else args.model)

    # These are bound at module scope on purpose: `lifespan` reads `engine`
    # and `engine_args`, and the route handlers read `openai_serving_chat`
    # and `openai_serving_completion`, all as module globals.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    openai_serving_chat = OpenAIServingChat(engine, served_model,
                                            args.response_role,
                                            args.lora_modules,
                                            args.chat_template)
    openai_serving_completion = OpenAIServingCompletion(
        engine, served_model, args.lora_modules)

    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs)