api_server.py 9.4 KB
Newer Older
1
import asyncio
2
3
import importlib
import inspect
4
import re
5
6
from contextlib import asynccontextmanager
from http import HTTPStatus
7
from typing import Optional, Set
8

Zhuohan Li's avatar
Zhuohan Li committed
9
import fastapi
10
import uvicorn
11
from fastapi import Request
Zhuohan Li's avatar
Zhuohan Li committed
12
13
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
14
15
from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app
16
from starlette.routing import Mount
Zhuohan Li's avatar
Zhuohan Li committed
17

18
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
19
20
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
21
from vllm.entrypoints.openai.cli_args import make_arg_parser
22
23
# yapf conflicts with isort for this block
# yapf: disable
24
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
25
                                              ChatCompletionResponse,
26
                                              CompletionRequest,
27
28
29
30
31
32
                                              DetokenizeRequest,
                                              DetokenizeResponse,
                                              EmbeddingRequest, ErrorResponse,
                                              TokenizeRequest,
                                              TokenizeResponse)
# yapf: enable
33
34
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
35
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
36
from vllm.logger import init_logger
yhu422's avatar
yhu422 committed
37
from vllm.usage.usage_lib import UsageContext
38
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
39

40
# Idle keep-alive timeout handed to uvicorn.run(timeout_keep_alive=...).
TIMEOUT_KEEP_ALIVE = 5  # seconds

# Serving handlers used by the route functions. They are only declared here;
# the __main__ section assigns them before the server starts.
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding

logger = init_logger('vllm.entrypoints.openai.api_server')

# Holds references to background tasks created in lifespan(); each task
# removes itself from this set via a done-callback.
_running_tasks: Set[asyncio.Task] = set()
49

50

51
52
53
54
55
56
57
58
59
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    """FastAPI lifespan hook managing the periodic stats-logging task.

    Unless ``engine_args.disable_log_stats`` is set, a background task calls
    ``engine.do_log_stats()`` every 10 seconds while the app is running. The
    task is cancelled when the app shuts down so it does not outlive the
    server.

    ``engine`` and ``engine_args`` are module-level globals assigned in the
    ``__main__`` section before the server starts.
    """

    async def _force_log():
        # Emit engine stats on a fixed cadence, independent of request flow.
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    task: Optional[asyncio.Task] = None
    if not engine_args.disable_log_stats:
        task = asyncio.create_task(_force_log())
        # Hold a strong reference while the task is pending; the callback
        # drops it again once the task finishes (including on cancel).
        _running_tasks.add(task)
        task.add_done_callback(_running_tasks.remove)

    try:
        yield
    finally:
        # Shutdown: stop the periodic logger instead of leaving it pending.
        # cancel() is a no-op if the task already finished.
        if task is not None:
            task.cancel()


app = fastapi.FastAPI(lifespan=lifespan)


70
def parse_args():
    """Parse the command line using the OpenAI server's argument parser."""
    return make_arg_parser().parse_args()
Zhuohan Li's avatar
Zhuohan Li committed
73
74


75
# Serve Prometheus metrics by mounting prometheus_client's ASGI app under
# /metrics (a mounted sub-application, not middleware).
route = Mount("/metrics", make_asgi_app())
# Workaround for 307 Redirect for /metrics: replace the Mount's match
# pattern so a bare "/metrics" is served directly, not only "/metrics/...".
route.path_regex = re.compile('^/metrics(?P<path>.*)$')
app.routes.append(route)
80
81


Zhuohan Li's avatar
Zhuohan Li committed
82
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
    """Report request-validation failures as an OpenAI-style 400 error."""
    error_body = openai_serving_chat.create_error_response(message=str(exc))
    return JSONResponse(status_code=HTTPStatus.BAD_REQUEST,
                        content=error_body.model_dump())
86
87


88
89
90
@app.get("/health")
async def health() -> Response:
    """Health check endpoint.

    Awaits the engine's ``check_health()``; any exception it raises
    propagates. On success, replies with an empty 200 response.
    """
    serving_engine = openai_serving_chat.engine
    await serving_engine.check_health()
    return Response(status_code=200)


95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
@app.post("/tokenize")
async def tokenize(request: TokenizeRequest):
    """Tokenize text via the completion serving layer.

    Serving-layer errors are returned as JSON with the error's status code.
    """
    result = await openai_serving_completion.create_tokenize(request)
    if not isinstance(result, ErrorResponse):
        assert isinstance(result, TokenizeResponse)
        return JSONResponse(content=result.model_dump())
    return JSONResponse(content=result.model_dump(), status_code=result.code)


@app.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
    """Detokenize token ids via the completion serving layer.

    Serving-layer errors are returned as JSON with the error's status code.
    """
    result = await openai_serving_completion.create_detokenize(request)
    if not isinstance(result, ErrorResponse):
        assert isinstance(result, DetokenizeResponse)
        return JSONResponse(content=result.model_dump())
    return JSONResponse(content=result.model_dump(), status_code=result.code)


Zhuohan Li's avatar
Zhuohan Li committed
117
118
@app.get("/v1/models")
async def show_available_models():
    """Return the model list reported by the completion serving layer."""
    model_list = await openai_serving_completion.show_available_models()
    payload = model_list.model_dump()
    return JSONResponse(content=payload)
Zhuohan Li's avatar
Zhuohan Li committed
121
122


123
124
@app.get("/version")
async def show_version():
    """Report the running vLLM version as a JSON object."""
    return JSONResponse(content={"version": VLLM_VERSION})


129
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    """OpenAI-compatible chat completions endpoint.

    Serving-layer errors are returned as JSON with the error's status code.
    Streaming requests get a ``text/event-stream`` response fed by the
    returned generator; other requests get the full response as JSON.
    """
    outcome = await openai_serving_chat.create_chat_completion(
        request, raw_request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)

    if not request.stream:
        assert isinstance(outcome, ChatCompletionResponse)
        return JSONResponse(content=outcome.model_dump())

    return StreamingResponse(content=outcome,
                             media_type="text/event-stream")
143
144


Zhuohan Li's avatar
Zhuohan Li committed
145
@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    """OpenAI-compatible (legacy) completions endpoint.

    Serving-layer errors are returned as JSON with the error's status code.
    Streaming requests get a ``text/event-stream`` response fed by the
    returned generator; other requests get the full response as JSON.
    """
    outcome = await openai_serving_completion.create_completion(
        request, raw_request)

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump(),
                            status_code=outcome.code)

    if not request.stream:
        return JSONResponse(content=outcome.model_dump())

    return StreamingResponse(content=outcome,
                             media_type="text/event-stream")
Zhuohan Li's avatar
Zhuohan Li committed
157
158


159
160
161
162
163
164
165
166
167
168
169
@app.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
    """OpenAI-compatible embeddings endpoint.

    Serving-layer errors are returned as JSON with the error's status code.
    """
    outcome = await openai_serving_embedding.create_embedding(
        request, raw_request)
    if not isinstance(outcome, ErrorResponse):
        return JSONResponse(content=outcome.model_dump())
    return JSONResponse(content=outcome.model_dump(),
                        status_code=outcome.code)


Zhuohan Li's avatar
Zhuohan Li committed
170
if __name__ == "__main__":
    # Stdlib modules needed only by the script entry point.
    import argparse
    import secrets

    args = parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    if token := envs.VLLM_API_KEY or args.api_key:
        # The key (VLLM_API_KEY takes precedence over args.api_key) guards
        # only paths under "<root_path>/v1". args is fixed after parsing, so
        # build the prefix and expected header once, not on every request.
        root_path = "" if args.root_path is None else args.root_path
        protected_prefix = f"{root_path}/v1"
        expected_auth = ("Bearer " + token).encode("utf-8")

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            # OPTIONS requests (e.g. CORS preflight) are never checked.
            if request.method == "OPTIONS":
                return await call_next(request)
            if not request.url.path.startswith(protected_prefix):
                return await call_next(request)
            provided_auth = request.headers.get("Authorization") or ""
            # Constant-time comparison: a plain `!=` may stop at the first
            # differing character and leak timing information about the key.
            # Compare bytes so non-ASCII header values cannot make
            # compare_digest raise TypeError.
            if not secrets.compare_digest(provided_auth.encode("utf-8"),
                                          expected_auth):
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)

    # User-supplied middleware given as dotted "module.attr" paths: classes
    # go through app.add_middleware, coroutine functions through
    # app.middleware("http"); anything else is rejected.
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(f"Invalid middleware {middleware}. "
                             f"Must be a function or a class.")

    logger.info("vLLM API server version %s", VLLM_VERSION)
    # Log the parsed arguments, but never the API key itself.
    logged_args = args
    if args.api_key:
        logged_args = argparse.Namespace(**{
            **vars(args), "api_key": "<redacted>"
        })
    logger.info("args: %s", logged_args)

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    # engine_args and engine are module globals read by lifespan().
    engine_args = AsyncEngineArgs.from_cli_args(args)

    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_API_SERVER)

    event_loop: Optional[asyncio.AbstractEventLoop]
    try:
        event_loop = asyncio.get_running_loop()
    except RuntimeError:
        event_loop = None

    if event_loop is not None and event_loop.is_running():
        # If the current is instanced by Ray Serve,
        # there is already a running event loop
        # NOTE(review): run_until_complete() raises RuntimeError on a loop
        # that is already running, so this branch looks like it cannot
        # succeed as written -- confirm how the Ray Serve path reaches here.
        model_config = event_loop.run_until_complete(engine.get_model_config())
    else:
        # When using single vLLM without engine_use_ray
        model_config = asyncio.run(engine.get_model_config())

    # Assign the module-level serving handlers used by the route functions.
    openai_serving_chat = OpenAIServingChat(engine, model_config,
                                            served_model_names,
                                            args.response_role,
                                            args.lora_modules,
                                            args.chat_template)
    openai_serving_completion = OpenAIServingCompletion(
        engine, model_config, served_model_names, args.lora_modules,
        args.prompt_adapters)
    openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
                                                      served_model_names)
    app.root_path = args.root_path

    logger.info("Available routes are:")
    # Use a distinct loop name so the module-level `route` (the /metrics
    # Mount) is not rebound. Mounts have no `methods` attribute, and plain
    # Starlette routes may have methods=None; skip both.
    for app_route in app.routes:
        methods = getattr(app_route, "methods", None)
        if methods is None:
            continue
        logger.info("Route: %s, Methods: %s", app_route.path,
                    ', '.join(methods))

    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs)