api_server.py 8.41 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
import re
5
6
from contextlib import asynccontextmanager
from http import HTTPStatus
7
from typing import Optional, Set
8

Zhuohan Li's avatar
Zhuohan Li committed
9
import fastapi
10
import uvicorn
11
from fastapi import Request
Zhuohan Li's avatar
Zhuohan Li committed
12
13
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
14
15
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
16
from starlette.routing import Mount
Zhuohan Li's avatar
Zhuohan Li committed
17

18
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
19
20
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
21
from vllm.entrypoints.openai.cli_args import make_arg_parser
22
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
23
                                              ChatCompletionResponse,
24
25
                                              CompletionRequest,
                                              EmbeddingRequest, ErrorResponse)
26
27
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
28
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
29
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
30
from vllm.usage.usage_lib import UsageContext
31
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
32

33
logger = init_logger('vllm.entrypoints.openai.api_server')

# Passed to uvicorn as ``timeout_keep_alive`` (seconds).
TIMEOUT_KEEP_ALIVE = 5

# Serving handlers used by the route functions below. They are only declared
# here; the ``__main__`` block assigns them before the server is started.
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding

# Holds references to background asyncio tasks while they run; tasks are
# added by the lifespan hook and removed again by a done-callback.
_running_tasks: Set[asyncio.Task] = set()
42

43

44
45
46
47
48
49
50
51
52
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    """FastAPI lifespan hook that runs background work while the app is up.

    Unless ``engine_args.disable_log_stats`` is set, a task is started that
    calls ``engine.do_log_stats()`` every 10 seconds. That task is cancelled
    when the lifespan context exits (application shutdown or a failure
    propagating through it), so it does not keep running afterwards.

    ``engine`` and ``engine_args`` are module-level globals that the
    ``__main__`` block assigns before the application is started.
    """

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    task: Optional[asyncio.Task] = None
    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        # Keep a reference to the running task; the done-callback removes
        # it from the set again once the task finishes or is cancelled.
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.remove)

    try:
        yield
    finally:
        # Stop the periodic stats logger on shutdown instead of leaving it
        # running past the application's lifespan.
        if task is not None:
            task.cancel()


# The ASGI application. ``lifespan`` is entered on start-up and exited on
# shut-down.
app = fastapi.FastAPI(
    lifespan=lifespan,
)


63
def parse_args():
    """Parse the command line using the OpenAI server's argument parser.

    Returns:
        The namespace produced by ``make_arg_parser().parse_args()``.
    """
    arg_parser = make_arg_parser()
    parsed = arg_parser.parse_args()
    return parsed
Zhuohan Li's avatar
Zhuohan Li committed
66
67


68
# Serve Prometheus metrics by mounting prometheus_client's ASGI app at
# /metrics.
route = Mount(path="/metrics", app=make_asgi_app())
# Workaround for a 307 redirect on /metrics: replace the mount's path regex
# so that the bare "/metrics" path matches as well as "/metrics/...".
route.path_regex = re.compile(r'^/metrics(?P<path>.*)$')
app.router.routes.append(route)
73
74


Zhuohan Li's avatar
Zhuohan Li committed
75
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
    """Convert request-validation failures into a 400 JSON error response.

    The error body is built by ``openai_serving_chat.create_error_response``
    with the stringified validation exception as its message.
    """
    error = openai_serving_chat.create_error_response(message=str(exc))
    body = error.model_dump()
    return JSONResponse(content=body, status_code=HTTPStatus.BAD_REQUEST)
79
80


81
82
83
@app.get("/health")
async def health() -> Response:
    """Health check: return an empty 200 response after the engine check.

    If ``check_health()`` raises, the exception propagates out of this
    handler and no 200 response is produced.
    """
    engine_handle = openai_serving_chat.engine
    await engine_handle.check_health()
    return Response(status_code=200)


Zhuohan Li's avatar
Zhuohan Li committed
88
89
@app.get("/v1/models")
async def show_available_models():
    """Return the served model list as JSON (``GET /v1/models``)."""
    model_list = await openai_serving_chat.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
92
93


94
95
@app.get("/version")
async def show_version():
    """Return the running vLLM version as ``{"version": <version>}``."""
    return JSONResponse(content=dict(version=VLLM_VERSION))


100
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """Handle ``POST /v1/chat/completions``.

    An ``ErrorResponse`` from the serving layer is returned as JSON with its
    own ``code`` as the HTTP status. Streaming requests get a
    ``text/event-stream`` response wrapping the returned generator;
    otherwise the ``ChatCompletionResponse`` is returned as JSON.
    """
    result = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if not request.stream:
        assert isinstance(result, ChatCompletionResponse)
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
114
115


Zhuohan Li's avatar
Zhuohan Li committed
116
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Handle ``POST /v1/completions``.

    An ``ErrorResponse`` from the serving layer is returned as JSON with its
    own ``code`` as the HTTP status. Streaming requests get a
    ``text/event-stream`` response wrapping the returned generator;
    otherwise the result is serialized with ``model_dump()`` and returned as
    JSON.
    """
    result = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(),
                            status_code=result.code)

    if not request.stream:
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")
Zhuohan Li's avatar
Zhuohan Li committed
128
129


130
131
132
133
134
135
136
137
138
139
140
@app.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    """Handle ``POST /v1/embeddings``.

    The serving result is always returned as JSON; when it is an
    ``ErrorResponse`` its ``code`` is used as the HTTP status.
    """
    result = await openai_serving_embedding.create_embedding(
        request, raw_request)
    body = result.model_dump()
    if isinstance(result, ErrorResponse):
        return JSONResponse(content=body, status_code=result.code)
    return JSONResponse(content=body)


Zhuohan Li's avatar
Zhuohan Li committed
141
if __name__ == "__main__":
    # Script entry point: parse CLI args, configure middleware, build the
    # engine and serving handlers, then run the app with uvicorn. The names
    # assigned here (engine, engine_args, openai_serving_*) are module-level
    # globals read by the route handlers and the lifespan hook.
    import copy
    import secrets

    args = parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    if token := envs.VLLM_API_KEY or args.api_key:
        # Computed once here rather than on every request.
        auth_root_path = "" if args.root_path is None else args.root_path
        expected_auth = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            """Require ``Authorization: Bearer <token>`` on ``/v1`` paths.

            OPTIONS requests (e.g. CORS preflight) and paths that do not
            start with ``{root_path}/v1`` are passed through unchecked.
            """
            if request.method == "OPTIONS":
                return await call_next(request)
            if not request.url.path.startswith(f"{auth_root_path}/v1"):
                return await call_next(request)
            provided = request.headers.get("Authorization", "")
            # Constant-time comparison so response timing does not reveal
            # how much of the expected header value was matched.
            if not secrets.compare_digest(provided.encode("utf-8"),
                                          expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    logger.info("vLLM API server version %s", VLLM_VERSION)
    # Log the parsed arguments without writing the API key into the logs.
    logged_args = copy.copy(args)
    if logged_args.api_key:
        logged_args.api_key = "<redacted>"
    logger.info("args: %s", logged_args)

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    engine_args = AsyncEngineArgs.from_cli_args(args)

    # Enforce pixel values as image input type for vision language models
    # when serving with API server
    if engine_args.image_input_type is not None and \
        engine_args.image_input_type.upper() != "PIXEL_VALUES":
        raise ValueError(
            f"Invalid image_input_type: {engine_args.image_input_type}. "
            "Only --image-input-type 'pixel_values' is supported for serving "
            "vision language models with the vLLM API server.")

    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)

    event_loop: Optional[asyncio.AbstractEventLoop]
    try:
        event_loop = asyncio.get_running_loop()
    except RuntimeError:
        event_loop = None

    if event_loop is not None and event_loop.is_running():
        # If this server is started by Ray Serve, there is already a
        # running event loop.
        # NOTE(review): plain asyncio raises RuntimeError from
        # run_until_complete() on an already-running loop unless the loop
        # allows re-entrancy (e.g. nest_asyncio) -- confirm this path works.
        model_config = event_loop.run_until_complete(engine.get_model_config())
    else:
        # When using single vLLM without engine_use_ray
        model_config = asyncio.run(engine.get_model_config())

    openai_serving_chat = OpenAIServingChat(engine, model_config,
                                            served_model_names,
                                            args.response_role,
                                            args.lora_modules,
                                            args.chat_template)
    openai_serving_completion = OpenAIServingCompletion(
        engine, model_config, served_model_names, args.lora_modules)
    openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
                                                      served_model_names)
    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs)