# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import codecs
import hashlib
import importlib
import inspect
import json
import multiprocessing
import multiprocessing.forkserver as forkserver
import os
import secrets
import signal
import socket
import tempfile
import uuid
from argparse import Namespace
from collections.abc import AsyncIterator, Awaitable
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Any

import model_hosting_container_standards.sagemaker as sagemaker_standards
import pydantic
import uvloop
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, Headers, MutableHeaders, State
from starlette.types import ASGIApp, Message, Receive, Scope, Send

import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.anthropic.protocol import (
    AnthropicError,
    AnthropicErrorResponse,
    AnthropicMessagesRequest,
    AnthropicMessagesResponse,
)
from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
from vllm.entrypoints.launcher import serve_http
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
from vllm.entrypoints.openai.engine.protocol import (
    CompletionRequest,
    CompletionResponse,
    ErrorInfo,
    ErrorResponse,
)
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.orca_metrics import metrics_header
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_models import (
    BaseModelPath,
    OpenAIServingModels,
)
from vllm.entrypoints.openai.translations.serving import (
    OpenAIServingTranscription,
    OpenAIServingTranslation,
)
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.pooling.score.serving import ServingScores
from vllm.entrypoints.serve.disagg.serving import ServingTokens
from vllm.entrypoints.serve.elastic_ep.middleware import (
    ScalingMiddleware,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer
from vllm.entrypoints.utils import (
    cli_env_setup,
    load_aware_call,
    log_non_default_args,
    process_chat_template,
    process_lora_modules,
    sanitize_message,
    with_cancellation,
)
from vllm.exceptions import VLLMValidationError
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParserManager
from vllm.tasks import POOLING_TASKS
from vllm.tool_parsers import ToolParserManager
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.gc_utils import freeze_gc_heap
from vllm.utils.network_utils import is_valid_ipv6_address
from vllm.utils.system_utils import decorate_logs, set_ulimit
from vllm.version import __version__ as VLLM_VERSION

# Annotation only; presumably assigned elsewhere in this module (not visible
# in this chunk) — TODO confirm.
prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger("vllm.entrypoints.openai.api_server")

# Request header whose value create_completion() passes to metrics_header()
# to build the load-metrics headers of its JSON response.
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format"

# Strong references to background tasks started by lifespan(); asyncio only
# keeps weak references to tasks, and each task removes itself when done.
_running_tasks: set[asyncio.Task] = set()

@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook for the API server.

    On startup, when ``app.state.log_stats`` is truthy, spawns a background
    task that calls ``engine_client.do_log_stats()`` every
    ``envs.VLLM_LOG_STATS_INTERVAL`` seconds; then freezes the GC heap.
    On shutdown the stats task is cancelled and ``app.state`` is deleted so
    the engine reference it holds can be garbage-collected.
    """
    try:
        if app.state.log_stats:
            engine_client: EngineClient = app.state.engine_client

            async def _force_log():
                while True:
                    await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
                    await engine_client.do_log_stats()

            task = asyncio.create_task(_force_log())
            # Hold a strong reference until the task finishes.
            _running_tasks.add(task)
            task.add_done_callback(_running_tasks.remove)
        else:
            task = None

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()

        try:
            yield
        finally:
            if task is not None:
                task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state


@asynccontextmanager
async def build_async_engine_client(
    args: Namespace,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool | None = None,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Build an engine client from parsed CLI ``args`` and yield it.

    Converts ``args`` to ``AsyncEngineArgs`` and delegates to
    ``build_async_engine_client_from_engine_args``, which shuts the client
    down on error/exit.

    Args:
        args: Parsed CLI namespace.
        usage_context: Usage context forwarded to engine creation.
        disable_frontend_multiprocessing: Overrides
            ``args.disable_frontend_multiprocessing`` when not ``None``.
        client_config: Optional client settings. When non-empty, its
            ``client_count``/``client_index`` entries (defaults 1/0) are
            copied to the engine args' ``_api_process_count`` /
            ``_api_process_rank``.
    """
    if os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver":
        # The executor is expected to be mp.
        # Pre-import heavy modules in the forkserver process
        logger.debug("Setup forkserver with pre-imports")
        multiprocessing.set_start_method("forkserver")
        multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
        forkserver.ensure_running()
        logger.debug("Forkserver setup complete!")

    engine_args = AsyncEngineArgs.from_cli_args(args)
    if client_config:
        engine_args._api_process_count = client_config.get("client_count", 1)
        engine_args._api_process_rank = client_config.get("client_index", 0)

    if disable_frontend_multiprocessing is None:
        disable_frontend_multiprocessing = bool(args.disable_frontend_multiprocessing)

    # The inner context manager owns the engine lifecycle (shutdown on exit).
    async with build_async_engine_client_from_engine_args(
        engine_args,
        usage_context=usage_context,
        disable_frontend_multiprocessing=disable_frontend_multiprocessing,
        client_config=client_config,
    ) as engine:
        yield engine


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool = False,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Create an ``AsyncLLM`` engine client and yield it.

    The engine is always created with ``AsyncLLM.from_vllm_config``. It is
    shut down when the context exits, including on error. Errors raised
    while creating it propagate to the caller; nothing is yielded then.

    Args:
        engine_args: Engine arguments used to build the ``VllmConfig``.
        usage_context: Usage context for config and engine creation.
        disable_frontend_multiprocessing: Not honoured by this code path;
            only a warning is logged when it is set.
        client_config: Optional mapping. ``client_count`` (default 1) and
            ``client_index`` (default 0) are extracted, and the remaining
            entries are passed as ``client_addresses``. The caller's dict
            is not mutated.
    """
    # Create the EngineConfig (determines if we can use V1).
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    if disable_frontend_multiprocessing:
        logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.")

    from vllm.v1.engine.async_llm import AsyncLLM

    async_llm: AsyncLLM | None = None

    # Don't mutate the input client_config
    client_config = dict(client_config) if client_config else {}
    client_count = client_config.pop("client_count", 1)
    client_index = client_config.pop("client_index", 0)

    try:
        async_llm = AsyncLLM.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            aggregate_engine_logging=engine_args.aggregate_engine_logging,
            disable_log_stats=engine_args.disable_log_stats,
            client_addresses=client_config,
            client_count=client_count,
            client_index=client_index,
        )
        assert async_llm is not None

        # Don't keep the dummy data in memory
        await async_llm.reset_mm_cache()

        yield async_llm
    finally:
        if async_llm is not None:
            async_llm.shutdown()


# Router for the endpoints declared in this module; build_app() registers the
# SageMaker routes on it and includes it in the FastAPI app.
router = APIRouter()

def base(request: Request) -> OpenAIServing:
    """Return a generic ``OpenAIServing`` instance for the app.

    Used in this module to build error responses when a route's handler is
    unavailable. It reuses the tokenization handler rather than keeping a
    dedicated instance.
    """
    serving = tokenization(request)
    return serving


def models(request: Request) -> OpenAIServingModels:
    """Return the ``OpenAIServingModels`` handler stored on ``app.state``."""
    return request.app.state.openai_serving_models


def messages(request: Request) -> AnthropicServingMessages | None:
    """Return the Anthropic Messages handler stored on ``app.state``.

    May be ``None``; ``create_messages`` checks for that case and returns
    an error.
    """
    return request.app.state.anthropic_serving_messages


def chat(request: Request) -> OpenAIServingChat | None:
    """Return the chat-completions handler from ``app.state`` (may be None)."""
    return request.app.state.openai_serving_chat


def completion(request: Request) -> OpenAIServingCompletion | None:
    """Return the completions handler from ``app.state`` (may be None)."""
    return request.app.state.openai_serving_completion


def tokenization(request: Request) -> OpenAIServingTokenization:
    """Return the tokenization handler stored on ``app.state``."""
    return request.app.state.openai_serving_tokenization


def engine_client(request: Request) -> EngineClient:
    """Return the engine client stored on ``app.state``."""
    return request.app.state.engine_client


@router.get("/load")
async def get_server_load_metrics(request: Request):
    """Return the current server load as ``{"server_load": <value>}``.

    The value is read from ``app.state.server_load_metrics``. It tracks
    requests utilizing the GPU from the following routes:

    - /v1/chat/completions
    - /v1/completions
    - /v1/messages (also wrapped with ``load_aware_call`` in this module)
    - /v1/audio/transcriptions
    - /v1/audio/translations
    - /v1/embeddings
    - /pooling
    - /classify
    - /score
    - /v1/score
    - /rerank
    - /v1/rerank
    - /v2/rerank
    """
    return JSONResponse(content={"server_load": request.app.state.server_load_metrics})


@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    """Return the served models as JSON (OpenAI ``GET /v1/models``)."""
    handler = models(raw_request)
    models_ = await handler.show_available_models()
    return JSONResponse(content=models_.model_dump())


@router.get("/version")
async def show_version():
    """Return the running vLLM version as ``{"version": ...}``."""
    ver = {"version": VLLM_VERSION}
    return JSONResponse(content=ver)


@router.post(
    "/v1/messages",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
    """Handle the Anthropic-compatible ``POST /v1/messages`` endpoint.

    Errors are returned in Anthropic's error schema, with the HTTP status
    taken from the error code (500 for unexpected exceptions). Non-streaming
    results are returned as JSON with ``None`` fields omitted. Streaming
    results are returned as an SSE ``StreamingResponse``.
    """

    def translate_error_response(response: ErrorResponse) -> JSONResponse:
        """Convert an OpenAI-style ``ErrorResponse`` to an Anthropic error."""
        anthropic_error = AnthropicErrorResponse(
            error=AnthropicError(
                type=response.error.type,
                message=response.error.message,
            )
        )
        return JSONResponse(
            status_code=response.error.code, content=anthropic_error.model_dump()
        )

    handler = messages(raw_request)
    if handler is None:
        error = base(raw_request).create_error_response(
            message="The model does not support Messages API"
        )
        return translate_error_response(error)

    try:
        generator = await handler.create_messages(request, raw_request)
    except Exception as e:
        logger.exception("Error in create_messages: %s", e)
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            content=AnthropicErrorResponse(
                error=AnthropicError(
                    type="internal_error",
                    message=str(e),
                )
            ).model_dump(),
        )

    if isinstance(generator, ErrorResponse):
        return translate_error_response(generator)

    elif isinstance(generator, AnthropicMessagesResponse):
        resp = generator.model_dump(exclude_none=True)
        logger.debug("Anthropic Messages Response: %s", resp)
        return JSONResponse(content=resp)

    return StreamingResponse(content=generator, media_type="text/event-stream")


@router.post(
    "/v1/completions",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Handle the OpenAI-compatible ``POST /v1/completions`` endpoint.

    All error paths are returned as a JSON ``ErrorResponse`` whose HTTP
    status is the error's ``code``. These paths are: no completions handler,
    an exception raised by the handler, or an error result from the handler.
    Non-streaming results are returned as JSON with load-metrics headers
    chosen by the ``endpoint-load-metrics-format`` request header. Streaming
    results are returned as an SSE ``StreamingResponse``.
    """
    metrics_header_format = raw_request.headers.get(
        ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, ""
    )

    handler = completion(raw_request)
    if handler is None:
        result = base(raw_request).create_error_response(
            message="The model does not support Completions API"
        )
    else:
        try:
            result = await handler.create_completion(request, raw_request)
        except Exception as e:
            result = handler.create_error_response(e)

    if isinstance(result, ErrorResponse):
        # Wrap in JSONResponse so the HTTP status matches the error code;
        # returning the bare model would make FastAPI reply with 200.
        return JSONResponse(content=result.model_dump(), status_code=result.error.code)
    if isinstance(result, CompletionResponse):
        return JSONResponse(
            content=result.model_dump(),
            headers=metrics_header(metrics_header_format),
        )

    return StreamingResponse(content=result, media_type="text/event-stream")

387
def load_log_config(log_config_file: str | None) -> dict | None:
388
389
390
391
392
393
    if not log_config_file:
        return None
    try:
        with open(log_config_file) as f:
            return json.load(f)
    except Exception as e:
394
395
396
        logger.warning(
            "Failed to load log config from file %s: error %s", log_config_file, e
        )
397
398
399
        return None


class AuthenticationMiddleware:
    """
    Pure ASGI middleware that authenticates each request by checking
    if the Authorization Bearer token exists and equals any of "{api_key}".

    Notes
    -----
    There are two cases in which authentication is skipped:
        1. The HTTP method is OPTIONS.
        2. The request path doesn't start with /v1 (e.g. /health).

    Scopes other than ``http``/``websocket`` (e.g. ``lifespan``) are passed
    through untouched. API keys are kept only as SHA-256 digests.
    """

    def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
        self.app = app
        # Hashing both sides to fixed-length digests means compare_digest
        # never compares inputs of differing lengths.
        self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]

    def verify_token(self, headers: Headers) -> bool:
        """Return True if ``headers`` carry a Bearer token matching any key."""
        authorization_header_value = headers.get("Authorization")
        if not authorization_header_value:
            return False

        scheme, _, param = authorization_header_value.partition(" ")
        if scheme.lower() != "bearer":
            return False

        param_hash = hashlib.sha256(param.encode("utf-8")).digest()

        # Compare against every key without short-circuiting, so timing does
        # not reveal which key (if any) matched.
        token_match = False
        for token_hash in self.api_tokens:
            token_match |= secrets.compare_digest(param_hash, token_hash)

        return token_match

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        if scope["type"] not in ("http", "websocket"):
            # scope["type"] can be "lifespan" or "startup" for example,
            # in which case we don't need to do anything
            return self.app(scope, receive, send)
        if scope.get("method") == "OPTIONS":
            # OPTIONS (e.g. CORS preflight) is always allowed through.
            # WebSocket scopes have no "method" key, hence .get().
            return self.app(scope, receive, send)

        root_path = scope.get("root_path", "")
        url_path = URL(scope=scope).path.removeprefix(root_path)
        headers = Headers(scope=scope)
        if url_path.startswith("/v1") and not self.verify_token(headers):
            response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
            return response(scope, receive, send)
        return self.app(scope, receive, send)


class XRequestIdMiddleware:
    """
    Middleware that sets the X-Request-Id header on each response to the
    request's X-Request-Id if present, otherwise to a random uuid4 (hex)
    value.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        if scope["type"] not in ("http", "websocket"):
            return self.app(scope, receive, send)

        # Extract the request headers.
        request_headers = Headers(scope=scope)

        async def send_with_request_id(message: Message) -> None:
            """
            Custom send function to mutate the response headers
            and append X-Request-Id to it.
            """
            if message["type"] == "http.response.start":
                response_headers = MutableHeaders(raw=message["headers"])
                # Only generate a UUID when the client didn't supply an id.
                request_id = request_headers.get("X-Request-Id")
                if request_id is None:
                    request_id = uuid.uuid4().hex
                response_headers.append("X-Request-Id", request_id)
            await send(message)

        return self.app(scope, receive, send_with_request_id)


def _extract_content_from_chunk(chunk_data: dict) -> str:
    """Extract the text content from one parsed streaming-response chunk.

    Chat chunks (``object == "chat.completion.chunk"``) yield the first
    choice's ``delta.content``. Completion chunks (``object ==
    "text_completion"``) yield the first choice's ``text``. Both are parsed
    with their pydantic models. If validation fails, the same fields are
    read directly from the raw dict. Returns ``""`` when nothing is found.
    """
    try:
        from vllm.entrypoints.openai.chat_completion.protocol import (
            ChatCompletionStreamResponse,
        )
        from vllm.entrypoints.openai.engine.protocol import (
            CompletionStreamResponse,
        )

        # Try using Completion types for type-safe parsing
        object_type = chunk_data.get("object")
        if object_type == "chat.completion.chunk":
            chat_response = ChatCompletionStreamResponse.model_validate(chunk_data)
            if chat_response.choices and chat_response.choices[0].delta.content:
                return chat_response.choices[0].delta.content
        elif object_type == "text_completion":
            completion_response = CompletionStreamResponse.model_validate(chunk_data)
            if completion_response.choices and completion_response.choices[0].text:
                return completion_response.choices[0].text
    except pydantic.ValidationError:
        # Fallback to manual parsing
        if "choices" in chunk_data and chunk_data["choices"]:
            choice = chunk_data["choices"][0]
            if "delta" in choice and choice["delta"].get("content"):
                return choice["delta"]["content"]
            elif choice.get("text"):
                return choice["text"]

    return ""


class SSEDecoder:
    """Robust Server-Sent Events decoder for streaming responses.

    Bytes are decoded with an incremental UTF-8 decoder, so a multi-byte
    character split across two chunks is reassembled instead of both chunks
    being dropped. Partial lines are buffered until their newline arrives.
    """

    def __init__(self):
        self.buffer = ""
        self.content_buffer = []
        # Holds an incomplete trailing UTF-8 sequence between decode calls.
        self._utf8_decoder = codecs.getincrementaldecoder("utf-8")()

    def decode_chunk(self, chunk: bytes) -> list[dict]:
        """Decode a chunk of SSE data and return parsed events.

        Each complete ``data: <json>`` line yields
        ``{"type": "data", "data": <parsed JSON>}`` and ``data: [DONE]``
        yields ``{"type": "done"}``. Lines with malformed JSON and chunks
        with invalid UTF-8 are skipped. Lines that don't start with
        ``data: `` are ignored.
        """
        try:
            chunk_str = self._utf8_decoder.decode(chunk)
        except UnicodeDecodeError:
            # Skip malformed chunks; reset so the next chunk decodes cleanly.
            self._utf8_decoder.reset()
            return []

        self.buffer += chunk_str
        events = []

        # Process complete lines; a trailing partial line stays buffered.
        while "\n" in self.buffer:
            line, self.buffer = self.buffer.split("\n", 1)
            line = line.rstrip("\r")  # Handle CRLF

            if not line.startswith("data: "):
                continue
            data_str = line[6:].strip()
            if data_str == "[DONE]":
                events.append({"type": "done"})
            elif data_str:
                try:
                    event_data = json.loads(data_str)
                except json.JSONDecodeError:
                    # Skip malformed JSON
                    continue
                events.append({"type": "data", "data": event_data})

        return events

    def extract_content(self, event_data: dict) -> str:
        """Extract content from event data."""
        return _extract_content_from_chunk(event_data)

    def add_content(self, content: str) -> None:
        """Add non-empty content to the buffer."""
        if content:
            self.content_buffer.append(content)

    def get_complete_content(self) -> str:
        """Get the complete buffered content."""
        return "".join(self.content_buffer)


def _log_streaming_response(response, response_body: list) -> None:
    """Log a streaming (SSE) response once its ``[DONE]`` event is seen.

    Replaces ``response.body_iterator`` with an iterator that re-yields every
    chunk of the already-collected ``response_body`` unchanged. While doing
    so, it decodes the chunks as SSE and accumulates their text content.
    When ``[DONE]`` arrives, the content (truncated to 2048 characters plus
    a ``...[truncated]`` marker) and the chunk count are logged once.
    """
    sse_decoder = SSEDecoder()
    chunk_count = 0

    def _log_complete() -> None:
        """Log the accumulated content and the number of chunks seen."""
        full_content = sse_decoder.get_complete_content()
        if not full_content:
            logger.info(
                "response_body={streaming_complete: no_content, chunks=%d}",
                chunk_count,
            )
            return
        if len(full_content) > 2048:
            # Truncate if too long.
            full_content = full_content[:2048] + "...[truncated]"
        logger.info(
            "response_body={streaming_complete: content=%r, chunks=%d}",
            full_content,
            chunk_count,
        )

    def buffered_iterator():
        nonlocal chunk_count
        done_logged = False

        for chunk in response_body:
            chunk_count += 1
            yield chunk
            if done_logged:
                # Already logged; keep forwarding any remaining chunks so
                # logging never alters what the client receives.
                continue

            # Parse SSE events from chunk
            for event in sse_decoder.decode_chunk(chunk):
                if event["type"] == "data":
                    content = sse_decoder.extract_content(event["data"])
                    sse_decoder.add_content(content)
                elif event["type"] == "done":
                    _log_complete()
                    done_logged = True
                    break

    response.body_iterator = iterate_in_threadpool(buffered_iterator())
    logger.info("response_body={streaming_started: chunks=%d}", len(response_body))


def _log_non_streaming_response(response_body: list) -> None:
    """Log a non-streaming response body.

    All collected chunks (assumed bytes-like, as the original ``.decode()``
    call required) are joined before decoding. This logs bodies sent in
    several chunks in full and avoids spurious decode failures when a
    multi-byte character spans chunks. A body that is not valid UTF-8 is
    logged as ``<binary_data>``.
    """
    try:
        decoded_body = b"".join(response_body).decode()
    except UnicodeDecodeError:
        logger.info("response_body={<binary_data>}")
        return
    logger.info("response_body={%s}", decoded_body)


def build_app(args: Namespace) -> FastAPI:
    """Build the FastAPI application for the API server.

    Creates the app (docs endpoints configured from ``args``), registers all
    API routers and exception handlers, and adds the CORS, authentication,
    request-id, scaling, optional response-logging and user-supplied
    middleware. Returns the app wrapped by ``sagemaker_standards.bootstrap``.

    Raises:
        ValueError: if an entry of ``args.middleware`` is neither a class nor
            a coroutine function.
    """
    if args.disable_fastapi_docs:
        app = FastAPI(
            openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan
        )
    elif args.enable_offline_docs:
        app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
    else:
        app = FastAPI(lifespan=lifespan)
    app.state.args = args

    from vllm.entrypoints.serve import register_vllm_serve_api_routers

    register_vllm_serve_api_routers(app)

    from vllm.entrypoints.openai.chat_completion.api_router import (
        attach_router as register_chat_api_router,
    )

    register_chat_api_router(app)

    from vllm.entrypoints.openai.responses.api_router import (
        attach_router as register_responses_api_router,
    )

    register_responses_api_router(app)

    from vllm.entrypoints.openai.translations.api_router import (
        attach_router as register_translations_api_router,
    )

    register_translations_api_router(app)

    from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes

    register_sagemaker_routes(router)
    app.include_router(router)

    app.root_path = args.root_path

    from vllm.entrypoints.pooling import register_pooling_api_routers

    register_pooling_api_routers(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(HTTPException)
    async def http_exception_handler(_: Request, exc: HTTPException):
        err = ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(exc.detail),
                type=HTTPStatus(exc.status_code).phrase,
                code=exc.status_code,
            )
        )
        return JSONResponse(err.model_dump(), status_code=exc.status_code)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_: Request, exc: RequestValidationError):
        errors = exc.errors()

        # Report the parameter name when a VLLMValidationError is attached to
        # one of the validation errors' ctx.
        param = None
        for error in errors:
            if "ctx" in error and "error" in error["ctx"]:
                ctx_error = error["ctx"]["error"]
                if isinstance(ctx_error, VLLMValidationError):
                    param = ctx_error.parameter
                    break

        exc_str = str(exc)
        errors_str = str(errors)

        if errors and errors_str and errors_str != exc_str:
            message = f"{exc_str} {errors_str}"
        else:
            message = exc_str

        err = ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(message),
                type=HTTPStatus.BAD_REQUEST.phrase,
                code=HTTPStatus.BAD_REQUEST,
                param=param,
            )
        )
        return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)

    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
    if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]:
        app.add_middleware(AuthenticationMiddleware, tokens=tokens)

    if args.enable_request_id_headers:
        app.add_middleware(XRequestIdMiddleware)

    # Add scaling middleware to check for scaling state
    app.add_middleware(ScalingMiddleware)

    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
        logger.warning(
            "CAUTION: Enabling log response in the API Server. "
            "This can include sensitive information and should be "
            "avoided in production."
        )

        @app.middleware("http")
        async def log_response(request: Request, call_next):
            response = await call_next(request)
            # Drain the body so it can be logged, then replay it to the client.
            response_body = [section async for section in response.body_iterator]
            response.body_iterator = iterate_in_threadpool(iter(response_body))
            # Check if this is a streaming response by looking at content-type;
            # tolerate any parameters (e.g. "; charset=utf-8") after the type.
            content_type = response.headers.get("content-type", "")
            is_streaming = content_type.startswith("text/event-stream")

            # Log response body based on type
            if not response_body:
                logger.info("response_body={<empty>}")
            elif is_streaming:
                _log_streaming_response(response, response_body)
            else:
                _log_non_streaming_response(response_body)
            return response

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)  # type: ignore[arg-type]
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(
                f"Invalid middleware {middleware}. Must be a function or a class."
            )

    app = sagemaker_standards.bootstrap(app)

    return app


async def init_app_state(
    engine_client: EngineClient,
    state: State,
    args: Namespace,
) -> None:
    """Populate ``state`` with the engine client and per-endpoint handlers.

    Each API family's handler is built only when the engine reports a task
    that family needs; otherwise the attribute is set to ``None``:

    * ``"generate"``: responses, chat, completions, Anthropic messages, tokens
    * any of ``POOLING_TASKS``: pooling
    * ``"embed"``: embeddings; ``"classify"``: classification
    * ``"embed"`` or ``"score"``: scores
    * ``"transcription"``: transcription and translation

    Tokenization is always constructed.

    Args:
        engine_client: Client for the running engine. It is queried for its
            config and for its supported tasks.
        state: Application state object to populate.
        args: Parsed CLI arguments for the server.
    """
    vllm_config = engine_client.vllm_config

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.enable_log_requests:
        request_logger = RequestLogger(max_log_len=args.max_log_len)
    else:
        request_logger = None

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model) for name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.vllm_config = vllm_config
    state.args = args
    supported_tasks = await engine_client.get_supported_tasks()
    logger.info("Supported tasks: %s", supported_tasks)
    # Most OpenAI-style endpoints are only meaningful for generative models.
    supports_generate = "generate" in supported_tasks

    resolved_chat_template = await process_chat_template(
        args.chat_template, engine_client, vllm_config.model_config
    )

    # Build the concrete server object under its own name so its specific
    # methods type-check without an `assert isinstance` narrowing.
    tool_server: ToolServer | None
    if args.tool_server == "demo":
        demo_tool_server = DemoToolServer()
        await demo_tool_server.init_and_validate()
        tool_server = demo_tool_server
    elif args.tool_server:
        mcp_tool_server = MCPToolServer()
        await mcp_tool_server.add_tool_server(args.tool_server)
        tool_server = mcp_tool_server
    else:
        tool_server = None

    # Merge default_mm_loras into the static lora_modules
    default_mm_loras = (
        vllm_config.lora_config.default_mm_loras
        if vllm_config.lora_config is not None
        else {}
    )
    lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)

    state.openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=base_model_paths,
        lora_modules=lora_modules,
    )
    await state.openai_serving_models.init_static_loras()

    state.openai_serving_responses = (
        OpenAIServingResponses(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            tool_server=tool_server,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            log_error_stack=args.log_error_stack,
        )
        if supports_generate
        else None
    )
    state.openai_serving_chat = (
        OpenAIServingChat(
            engine_client,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            default_chat_template_kwargs=args.default_chat_template_kwargs,
            trust_request_chat_template=args.trust_request_chat_template,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            enable_log_deltas=args.enable_log_deltas,
            log_error_stack=args.log_error_stack,
        )
        if supports_generate
        else None
    )

    # Warm up chat template processing to avoid first-request latency
    if state.openai_serving_chat is not None:
        await state.openai_serving_chat.warmup()

    state.openai_serving_completion = (
        OpenAIServingCompletion(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            log_error_stack=args.log_error_stack,
        )
        if supports_generate
        else None
    )
    state.openai_serving_pooling = (
        OpenAIServingPooling(
            engine_client,
            state.openai_serving_models,
            supported_tasks=supported_tasks,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if any(task in POOLING_TASKS for task in supported_tasks)
        else None
    )
    state.openai_serving_embedding = (
        OpenAIServingEmbedding(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if "embed" in supported_tasks
        else None
    )
    state.openai_serving_classification = (
        ServingClassification(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if "classify" in supported_tasks
        else None
    )
    state.openai_serving_scores = (
        ServingScores(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            score_template=resolved_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if ("embed" in supported_tasks or "score" in supported_tasks)
        else None
    )
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        state.openai_serving_models,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
        trust_request_chat_template=args.trust_request_chat_template,
        log_error_stack=args.log_error_stack,
    )
    state.openai_serving_transcription = (
        OpenAIServingTranscription(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )
    # Translation is gated on the same "transcription" task as transcription.
    state.openai_serving_translation = (
        OpenAIServingTranslation(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )
    state.anthropic_serving_messages = (
        AnthropicServingMessages(
            engine_client,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if supports_generate
        else None
    )
    state.serving_tokens = (
        ServingTokens(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            log_error_stack=args.log_error_stack,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_log_outputs=args.enable_log_outputs,
            force_no_detokenize=args.tokens_only,
        )
        if supports_generate
        else None
    )

    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0

def create_server_socket(addr: tuple[str, int]) -> socket.socket:
    """Create a TCP stream socket bound to ``addr``.

    ``AF_INET6`` is used when the host part of ``addr`` is a valid IPv6
    address, and ``AF_INET`` otherwise. ``SO_REUSEADDR`` and ``SO_REUSEPORT``
    are both set before binding. The returned socket is bound but not yet
    listening, and the caller owns it.

    Args:
        addr: ``(host, port)`` pair. For IPv4, an empty host binds all
            interfaces.

    Returns:
        The bound socket.

    Raises:
        OSError: If setting an option or binding fails (for example, the port
            is already in use). The socket is closed before the error
            propagates so its file descriptor is not leaked.
    """
    host = addr[0]
    family = socket.AF_INET6 if is_valid_ipv6_address(host) else socket.AF_INET

    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(addr)
    except BaseException:
        # Release the descriptor on any failure, then surface the error.
        sock.close()
        raise

    return sock


def create_server_unix_socket(path: str) -> socket.socket:
    """Create a Unix-domain stream socket bound to ``path``.

    The returned socket is bound but not yet listening, and the caller owns it.

    Args:
        path: Filesystem path to bind the socket to.

    Returns:
        The bound ``AF_UNIX`` / ``SOCK_STREAM`` socket.

    Raises:
        OSError: If binding fails (for example, ``path`` already exists). The
            socket is closed before the error propagates so its file
            descriptor is not leaked.
    """
    sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
    try:
        sock.bind(path)
    except BaseException:
        # Release the descriptor on any failure, then surface the error.
        sock.close()
        raise
    return sock


def validate_api_server_args(args):
    """Check that the requested tool-call and reasoning parsers are registered.

    Args:
        args: Parsed CLI namespace. Reads ``enable_auto_tool_choice``,
            ``tool_call_parser`` and
            ``structured_outputs_config.reasoning_parser``.

    Raises:
        KeyError: If auto tool choice is enabled with an unregistered
            tool-call parser, or if a reasoning parser is set but is not
            registered.
    """
    valid_tool_parsers = ToolParserManager.list_registered()
    if args.enable_auto_tool_choice and args.tool_call_parser not in valid_tool_parsers:
        raise KeyError(
            f"invalid tool call parser: {args.tool_call_parser} "
            f"(choose from {{ {','.join(valid_tool_parsers)} }})"
        )

    valid_reasoning_parsers = ReasoningParserManager.list_registered()
    # A falsy reasoning parser means "none requested" and is always valid.
    if (
        reasoning_parser := args.structured_outputs_config.reasoning_parser
    ) and reasoning_parser not in valid_reasoning_parsers:
        raise KeyError(
            f"invalid reasoning parser: {reasoning_parser} "
            f"(choose from {{ {','.join(valid_reasoning_parsers)} }})"
        )

def setup_server(args):
    """Prepare the API server process before any engine work begins.

    In order, this:
    * logs the version and non-default args;
    * loads optional tool/reasoning parser plugins;
    * validates the args;
    * binds the listening socket (Unix-domain if ``args.uds`` is set, TCP
      otherwise);
    * calls ``set_ulimit()``;
    * installs a SIGTERM handler that raises ``KeyboardInterrupt``.

    Returns:
        ``(listen_address, sock)``. ``listen_address`` is a display string
        (``unix:<path>`` or ``http[s]://host:port``); ``sock`` is the bound
        socket.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    log_non_default_args(args)

    tool_plugin = args.tool_parser_plugin
    if tool_plugin and len(tool_plugin) > 3:
        ToolParserManager.import_tool_parser(tool_plugin)

    reasoning_plugin = args.reasoning_parser_plugin
    if reasoning_plugin and len(reasoning_plugin) > 3:
        ReasoningParserManager.import_reasoning_parser(reasoning_plugin)

    validate_api_server_args(args)

    # Bind the port before the engine is set up; this avoids race
    # conditions with ray.
    # see https://github.com/vllm-project/vllm/issues/8204
    uds_path = args.uds
    if uds_path:
        sock = create_server_unix_socket(uds_path)
    else:
        sock_addr = (args.host or "", args.port)
        sock = create_server_socket(sock_addr)

    # Workaround for uvicorn dropping requests when too many concurrent
    # requests are active.
    set_ulimit()

    def _interrupt_on_sigterm(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, _interrupt_on_sigterm)

    if uds_path:
        return f"unix:{uds_path}", sock

    host, port = sock_addr
    scheme = "https" if (args.ssl_keyfile and args.ssl_certfile) else "http"
    if is_valid_ipv6_address(host):
        host_part = f"[{host}]"
    else:
        host_part = host or "0.0.0.0"
    return f"{scheme}://{host_part}:{port}", sock


async def run_server(args, **uvicorn_kwargs) -> None:
    """Run a single-worker API server.

    Adds the ``APIServer`` log prefix, performs the shared setup via
    ``setup_server``, then delegates to ``run_server_worker`` with the
    resulting address and socket.
    """
    # Add process-specific prefix to stdout and stderr.
    decorate_logs("APIServer")

    address, listen_sock = setup_server(args)
    await run_server_worker(address, listen_sock, args, **uvicorn_kwargs)


async def run_server_worker(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Run a single API server worker.

    Loads optional parser plugins and applies an optional uvicorn log config.
    It then opens the engine client, builds and initialises the FastAPI app,
    and starts serving HTTP on ``sock``. After the engine-client context has
    exited, it waits for the HTTP server to finish and then closes ``sock``.

    Args:
        listen_address: Display string for the address, used only for logging.
        sock: Already-bound listening socket to serve on.
        args: Parsed CLI arguments.
        client_config: Optional config forwarded to
            ``build_async_engine_client``.
        **uvicorn_kwargs: Extra keyword arguments forwarded to ``serve_http``.
    """
    plugin = args.tool_parser_plugin
    if plugin and len(plugin) > 3:
        ToolParserManager.import_tool_parser(plugin)

    plugin = args.reasoning_parser_plugin
    if plugin and len(plugin) > 3:
        ReasoningParserManager.import_reasoning_parser(plugin)

    # Load logging config for uvicorn if specified
    if (log_config := load_log_config(args.log_config_file)) is not None:
        uvicorn_kwargs["log_config"] = log_config

    async with build_async_engine_client(
        args, client_config=client_config
    ) as engine_client:
        app = build_app(args)
        await init_app_state(engine_client, app.state, args)

        api_rank = engine_client.vllm_config.parallel_config._api_process_rank
        logger.info("Starting vLLM API server %d on %s", api_rank, listen_address)

        shutdown_task = await serve_http(
            app,
            sock=sock,
            enable_ssl_refresh=args.enable_ssl_refresh,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            # When disable_uvicorn_access_log is True, no access log is emitted.
            access_log=not args.disable_uvicorn_access_log,
            timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
            h11_max_header_count=args.h11_max_header_count,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    try:
        await shutdown_task
    finally:
        sock.close()

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
    # entrypoints.
    cli_env_setup()
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."
        )
    )
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))