api_server.py 44.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import hashlib
5
6
import importlib
import inspect
7
import json
8
import multiprocessing
9
import multiprocessing.forkserver as forkserver
10
import os
11
import secrets
12
import signal
13
import socket
14
import tempfile
15
import uuid
16
from argparse import Namespace
17
from collections.abc import AsyncIterator, Awaitable
18
from contextlib import asynccontextmanager
19
from http import HTTPStatus
20
from typing import Annotated, Any
21

22
import model_hosting_container_standards.sagemaker as sagemaker_standards
23
import pydantic
24
import uvloop
25
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
Zhuohan Li's avatar
Zhuohan Li committed
26
27
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
28
from fastapi.responses import JSONResponse, StreamingResponse
29
from starlette.concurrency import iterate_in_threadpool
30
31
from starlette.datastructures import URL, Headers, MutableHeaders, State
from starlette.types import ASGIApp, Message, Receive, Scope, Send
Zhuohan Li's avatar
Zhuohan Li committed
32

33
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.engine.arg_utils import AsyncEngineArgs
35
from vllm.engine.protocol import EngineClient
36
37
38
39
40
41
42
from vllm.entrypoints.anthropic.protocol import (
    AnthropicError,
    AnthropicErrorResponse,
    AnthropicMessagesRequest,
    AnthropicMessagesResponse,
)
from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
43
from vllm.entrypoints.launcher import serve_http
44
from vllm.entrypoints.logger import RequestLogger
45
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
46
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
47
from vllm.entrypoints.openai.engine.protocol import (
48
49
50
51
52
    CompletionRequest,
    CompletionResponse,
    ErrorInfo,
    ErrorResponse,
    TranscriptionRequest,
53
    TranscriptionResponseVariant,
54
    TranslationRequest,
55
    TranslationResponseVariant,
56
)
57
58
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.orca_metrics import metrics_header
59
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
60
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
61
62
63
64
from vllm.entrypoints.openai.serving_models import (
    BaseModelPath,
    OpenAIServingModels,
)
65
from vllm.entrypoints.openai.serving_transcription import (
66
67
68
    OpenAIServingTranscription,
    OpenAIServingTranslation,
)
69
70
71
72
73
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.pooling.score.serving import ServingScores
74
75
76
77
78
from vllm.entrypoints.serve.disagg.serving import ServingTokens
from vllm.entrypoints.serve.elastic_ep.middleware import (
    ScalingMiddleware,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
79
80
81
82
83
from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer
from vllm.entrypoints.utils import (
    cli_env_setup,
    load_aware_call,
    log_non_default_args,
84
85
    process_chat_template,
    process_lora_modules,
86
    sanitize_message,
87
88
    with_cancellation,
)
89
from vllm.exceptions import VLLMValidationError
90
from vllm.logger import init_logger
91
from vllm.reasoning import ReasoningParserManager
92
from vllm.tasks import POOLING_TASKS
93
from vllm.tool_parsers import ToolParserManager
yhu422's avatar
yhu422 committed
94
from vllm.usage.usage_lib import UsageContext
Cyrus Leung's avatar
Cyrus Leung committed
95
from vllm.utils.argparse_utils import FlexibleArgumentParser
96
from vllm.utils.gc_utils import freeze_gc_heap
97
from vllm.utils.network_utils import is_valid_ipv6_address
98
from vllm.utils.system_utils import decorate_logs, set_ulimit
99
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
100

101
# Annotation-only declaration: no value is bound to this name here.
# NOTE(review): presumably assigned elsewhere when a Prometheus multiprocess
# directory is created - confirm against the rest of the file.
prometheus_multiproc_dir: tempfile.TemporaryDirectory
102

103
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
104
logger = init_logger("vllm.entrypoints.openai.api_server")
105

106
107
# Request header read by create_completion(); its value is forwarded to
# metrics_header() to build the headers of non-streaming completion responses.
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format"

108
# Background tasks started by lifespan(); each task is added on creation and
# removes itself from this set through its done-callback.
_running_tasks: set[asyncio.Task] = set()
109

110

111
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    On startup, optionally starts a background task that periodically asks
    the engine client to log its stats, then freezes the GC heap. On
    shutdown, cancels that task (if any) and deletes ``app.state``.
    """
    try:
        stats_task: asyncio.Task | None = None

        if app.state.log_stats:
            client: EngineClient = app.state.engine_client

            async def _log_stats_forever():
                # Runs until cancelled on shutdown.
                while True:
                    await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
                    await client.do_log_stats()

            stats_task = asyncio.create_task(_log_stats_forever())
            # Track the task in the module-level set; it unregisters itself
            # once it finishes or is cancelled.
            _running_tasks.add(stats_task)
            stats_task.add_done_callback(_running_tasks.remove)

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()

        try:
            yield
        finally:
            if stats_task is not None:
                stats_task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state
139
140


141
@asynccontextmanager
async def build_async_engine_client(
    args: Namespace,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool | None = None,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Build an engine client from parsed CLI arguments.

    Args:
        args: Parsed CLI namespace used to construct ``AsyncEngineArgs``.
        usage_context: Forwarded to the engine-args based builder.
        disable_frontend_multiprocessing: When ``None``, the value is taken
            from ``args.disable_frontend_multiprocessing``.
        client_config: Optional mapping; its ``client_count`` and
            ``client_index`` entries (defaults 1 and 0) are copied onto the
            engine args before the mapping is forwarded.

    Yields:
        The engine client produced by
        ``build_async_engine_client_from_engine_args``.
    """
    use_forkserver = os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver"
    if use_forkserver:
        # The executor is expected to be mp.
        # Pre-import heavy modules in the forkserver process
        logger.debug("Setup forkserver with pre-imports")
        multiprocessing.set_start_method("forkserver")
        multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
        forkserver.ensure_running()
        logger.debug("Forkserver setup complete!")

    engine_args = AsyncEngineArgs.from_cli_args(args)
    if client_config:
        engine_args._api_process_count = client_config.get("client_count", 1)
        engine_args._api_process_rank = client_config.get("client_index", 0)

    if disable_frontend_multiprocessing is None:
        frontend_mp_disabled = bool(args.disable_frontend_multiprocessing)
    else:
        frontend_mp_disabled = disable_frontend_multiprocessing

    # The nested context manager owns the engine lifecycle and ensures
    # everything is shut down and cleaned up on error/exit.
    engine_cm = build_async_engine_client_from_engine_args(
        engine_args,
        usage_context=usage_context,
        disable_frontend_multiprocessing=frontend_mp_disabled,
        client_config=client_config,
    )
    async with engine_cm as engine:
        yield engine


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool = False,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """
    Create an ``AsyncLLM`` engine client from engine args and yield it.

    The caller's ``client_config`` is copied before ``client_count`` and
    ``client_index`` are popped from it; the remaining entries are passed
    on as ``client_addresses``. The engine is shut down when the context
    exits, including on error.
    """

    # Create the EngineConfig (determines if we can use V1).
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    if disable_frontend_multiprocessing:
        logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.")

    from vllm.v1.engine.async_llm import AsyncLLM

    engine: AsyncLLM | None = None

    # Don't mutate the input client_config
    addresses = dict(client_config or {})
    count = addresses.pop("client_count", 1)
    index = addresses.pop("client_index", 0)

    try:
        engine = AsyncLLM.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            aggregate_engine_logging=engine_args.aggregate_engine_logging,
            disable_log_stats=engine_args.disable_log_stats,
            client_addresses=addresses,
            client_count=count,
            client_index=index,
        )

        # Don't keep the dummy data in memory
        assert engine is not None
        await engine.reset_mm_cache()

        yield engine
    finally:
        if engine:
            engine.shutdown()
228
229


Ethan Xu's avatar
Ethan Xu committed
230
# Router that collects the endpoints declared in this module; build_app()
# attaches it to the application via app.include_router(router).
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
231

232

233
234
235
236
237
def base(request: Request) -> OpenAIServing:
    """Return a generic serving handler by reusing the tokenization handler."""
    # Reuse the existing instance
    shared_handler = tokenization(request)
    return shared_handler


238
239
240
241
def models(request: Request) -> OpenAIServingModels:
    """Return ``openai_serving_models`` from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_models


242
def responses(request: Request) -> OpenAIServingResponses | None:
    """Return ``openai_serving_responses`` from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_responses


246
247
248
249
def messages(request: Request) -> AnthropicServingMessages:
    """Return ``anthropic_serving_messages`` from the application state.

    NOTE(review): create_messages() checks the result for ``None``, so the
    stored value may be optional even though the annotation is not - confirm.
    """
    app_state = request.app.state
    return app_state.anthropic_serving_messages


250
def chat(request: Request) -> OpenAIServingChat | None:
    """Return ``openai_serving_chat`` from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_chat


254
def completion(request: Request) -> OpenAIServingCompletion | None:
    """Return ``openai_serving_completion`` from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_completion


258
259
def tokenization(request: Request) -> OpenAIServingTokenization:
    """Return ``openai_serving_tokenization`` from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_tokenization
260
261


262
263
264
265
def transcription(request: Request) -> OpenAIServingTranscription:
    """Return ``openai_serving_transcription`` from the application state.

    NOTE(review): create_transcriptions() checks the result for ``None``, so
    the stored value may be optional despite the annotation - confirm.
    """
    app_state = request.app.state
    return app_state.openai_serving_transcription


266
267
268
269
def translation(request: Request) -> OpenAIServingTranslation:
    """Return ``openai_serving_translation`` from the application state.

    NOTE(review): create_translations() checks the result for ``None``, so
    the stored value may be optional despite the annotation - confirm.
    """
    app_state = request.app.state
    return app_state.openai_serving_translation


270
def engine_client(request: Request) -> EngineClient:
    """Return ``engine_client`` from the application state."""
    app_state = request.app.state
    return app_state.engine_client


274
275
276
277
def generate_tokens(request: Request) -> ServingTokens | None:
    """Return ``serving_tokens`` from the application state."""
    app_state = request.app.state
    return app_state.serving_tokens


278
279
280
281
282
283
284
@router.get("/load")
async def get_server_load_metrics(request: Request):
    """Return the current server load metrics as ``{"server_load": ...}``."""
    # This endpoint returns the current server load metrics.
    # It tracks requests utilizing the GPU from the following routes:
    # - /v1/chat/completions
    # - /v1/completions
    # - /v1/audio/transcriptions
    # - /v1/audio/translations
    # - /v1/embeddings
    # - /pooling
    # - /classify
    # - /score
    # - /v1/score
    # - /rerank
    # - /v1/rerank
    # - /v2/rerank
    payload = {"server_load": request.app.state.server_load_metrics}
    return JSONResponse(content=payload)
295
296


Ethan Xu's avatar
Ethan Xu committed
297
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    """Return the model list produced by the models handler as JSON."""
    model_list = await models(raw_request).show_available_models()
    return JSONResponse(content=model_list.model_dump())
Zhuohan Li's avatar
Zhuohan Li committed
303
304


Ethan Xu's avatar
Ethan Xu committed
305
@router.get("/version")
async def show_version():
    """Return the vLLM version as ``{"version": VLLM_VERSION}``."""
    return JSONResponse(content={"version": VLLM_VERSION})


311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
@router.post(
    "/v1/messages",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
    """Handle POST /v1/messages (Anthropic Messages API).

    Errors are always reported in the Anthropic error envelope: an
    ``ErrorResponse`` is translated with its own status code, and an
    exception raised by the handler is logged and reported as a 500
    ``internal_error``. A complete ``AnthropicMessagesResponse`` is returned
    as JSON; anything else is streamed as ``text/event-stream``.
    """

    def _to_anthropic_json(response: ErrorResponse) -> JSONResponse:
        # Re-wrap an ErrorResponse in the Anthropic error schema, keeping
        # the original error code as the HTTP status.
        err = response.error
        body = AnthropicErrorResponse(
            error=AnthropicError(type=err.type, message=err.message)
        )
        return JSONResponse(status_code=err.code, content=body.model_dump())

    handler = messages(raw_request)
    if handler is None:
        unsupported = base(raw_request).create_error_response(
            message="The model does not support Messages API"
        )
        return _to_anthropic_json(unsupported)

    try:
        result = await handler.create_messages(request, raw_request)
    except Exception as exc:
        logger.exception("Error in create_messages: %s", exc)
        failure = AnthropicErrorResponse(
            error=AnthropicError(type="internal_error", message=str(exc))
        )
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            content=failure.model_dump(),
        )

    if isinstance(result, ErrorResponse):
        return _to_anthropic_json(result)

    if isinstance(result, AnthropicMessagesResponse):
        payload = result.model_dump(exclude_none=True)
        logger.debug("Anthropic Messages Response: %s", payload)
        return JSONResponse(content=payload)

    return StreamingResponse(content=result, media_type="text/event-stream")


367
368
369
370
371
372
373
374
375
376
@router.post(
    "/v1/completions",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Handle POST /v1/completions.

    Returns:
        A JSON error response carrying the error's own status code when the
        Completions API is unavailable or the handler returns an
        ``ErrorResponse``; a JSON body (with load-metrics headers) for a
        complete ``CompletionResponse``; otherwise a ``text/event-stream``
        streaming response.

    Raises:
        HTTPException: 400 when the handler raises ``OverflowError``,
            500 for any other exception raised by the handler.
    """
    metrics_header_format = raw_request.headers.get(
        ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, ""
    )
    handler = completion(raw_request)
    if handler is None:
        error = base(raw_request).create_error_response(
            message="The model does not support Completions API"
        )
        # Wrap explicitly so the HTTP status is the error's code, exactly
        # like the ErrorResponse branch below, instead of returning the
        # bare model.
        return JSONResponse(content=error.model_dump(), status_code=error.error.code)

    try:
        generator = await handler.create_completion(request, raw_request)
    except OverflowError as e:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value, detail=str(e)
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, CompletionResponse):
        return JSONResponse(
            content=generator.model_dump(),
            headers=metrics_header(metrics_header_format),
        )

    return StreamingResponse(content=generator, media_type="text/event-stream")

Zhuohan Li's avatar
Zhuohan Li committed
412

413
414
415
416
417
418
419
420
421
@router.post(
    "/v1/audio/transcriptions",
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_transcriptions(
    raw_request: Request, request: Annotated[TranscriptionRequest, Form()]
):
    """Handle POST /v1/audio/transcriptions (multipart form upload).

    Returns:
        A JSON error response carrying the error's own status code when the
        Transcriptions API is unavailable or the handler returns an
        ``ErrorResponse``; a JSON body for a complete transcription result;
        otherwise a ``text/event-stream`` streaming response.

    Raises:
        HTTPException: 500 for any exception raised by the handler.
    """
    handler = transcription(raw_request)
    if handler is None:
        error = base(raw_request).create_error_response(
            message="The model does not support Transcriptions API"
        )
        # Wrap explicitly so the HTTP status is the error's code, exactly
        # like the ErrorResponse branch below, instead of returning the
        # bare model.
        return JSONResponse(content=error.model_dump(), status_code=error.error.code)

    # The uploaded file is read in full before it is handed to the handler.
    audio_data = await request.file.read()
    try:
        generator = await handler.create_transcription(audio_data, request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )

    elif isinstance(generator, TranscriptionResponseVariant):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")


452
453
454
455
456
457
458
459
460
@router.post(
    "/v1/audio/translations",
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_translations(
    request: Annotated[TranslationRequest, Form()], raw_request: Request
):
    """Handle POST /v1/audio/translations (multipart form upload).

    Returns:
        A JSON error response carrying the error's own status code when the
        Translations API is unavailable or the handler returns an
        ``ErrorResponse``; a JSON body for a complete translation result;
        otherwise a ``text/event-stream`` streaming response.

    Raises:
        HTTPException: 500 for any exception raised by the handler.
    """
    handler = translation(raw_request)
    if handler is None:
        error = base(raw_request).create_error_response(
            message="The model does not support Translations API"
        )
        # Wrap explicitly so the HTTP status is the error's code, exactly
        # like the ErrorResponse branch below, instead of returning the
        # bare model.
        return JSONResponse(content=error.model_dump(), status_code=error.error.code)

    # The uploaded file is read in full before it is handed to the handler.
    audio_data = await request.file.read()
    try:
        generator = await handler.create_translation(audio_data, request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )

    elif isinstance(generator, TranslationResponseVariant):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")


491
def load_log_config(log_config_file: str | None) -> dict | None:
492
493
494
495
496
497
    if not log_config_file:
        return None
    try:
        with open(log_config_file) as f:
            return json.load(f)
    except Exception as e:
498
499
500
        logger.warning(
            "Failed to load log config from file %s: error %s", log_config_file, e
        )
501
502
503
        return None


504
505
506
class AuthenticationMiddleware:
    """
    Pure ASGI middleware that authenticates each request by checking
    if the Authorization Bearer token exists and equals any of the
    configured API keys.

    Notes
    -----
    There are two cases in which authentication is skipped:
        1. The HTTP method is OPTIONS.
        2. The request path doesn't start with /v1 (e.g. /health).
    """

    def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
        """Wrap ``app`` and remember the SHA-256 digest of every token."""
        self.app = app
        # Only digests are stored; incoming tokens are hashed the same way
        # before comparison, so all compared values have equal length.
        self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]

    def verify_token(self, headers: Headers) -> bool:
        """Return True if ``headers`` carry a valid ``Bearer`` token."""
        authorization_header_value = headers.get("Authorization")
        if not authorization_header_value:
            return False

        scheme, _, param = authorization_header_value.partition(" ")
        if scheme.lower() != "bearer":
            return False

        param_hash = hashlib.sha256(param.encode("utf-8")).digest()

        # Compare against every configured token (no early exit) using a
        # constant-time comparison.
        token_match = False
        for token_hash in self.api_tokens:
            token_match |= secrets.compare_digest(param_hash, token_hash)

        return token_match

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        """ASGI entry point: reject unauthenticated /v1 requests with 401."""
        if scope["type"] not in ("http", "websocket"):
            # scope["type"] can be "lifespan" or "startup" for example,
            # in which case we don't need to do anything
            return self.app(scope, receive, send)
        # WebSocket scopes carry no "method" key, so use .get() rather than
        # indexing to avoid a KeyError on websocket connections.
        if scope.get("method") == "OPTIONS":
            return self.app(scope, receive, send)
        root_path = scope.get("root_path", "")
        url_path = URL(scope=scope).path.removeprefix(root_path)
        headers = Headers(scope=scope)
        if url_path.startswith("/v1") and not self.verify_token(headers):
            response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
            return response(scope, receive, send)
        return self.app(scope, receive, send)


class XRequestIdMiddleware:
    """
    Middleware that sets the X-Request-Id header on each response:
    the value provided in the request is echoed back when present,
    otherwise a random uuid4 (hex) value is used.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        handled_types = ("http", "websocket")
        if scope["type"] not in handled_types:
            # Other scope types are passed through untouched.
            return self.app(scope, receive, send)

        incoming = Headers(scope=scope)

        async def _send_tagged(message: Message) -> None:
            """
            Wrap ``send`` so that the response-start message gets an
            X-Request-Id header appended before it is forwarded.
            """
            if message["type"] == "http.response.start":
                outgoing = MutableHeaders(raw=message["headers"])
                outgoing.append(
                    "X-Request-Id",
                    incoming.get("X-Request-Id", uuid.uuid4().hex),
                )
            await send(message)

        return self.app(scope, receive, _send_tagged)


583
584
585
def _extract_content_from_chunk(chunk_data: dict) -> str:
    """Return the text carried by one streaming chunk, or ``""`` if none.

    Chunks whose ``object`` is ``chat.completion.chunk`` or
    ``text_completion`` are validated with the matching stream-response
    model; if validation fails, the first entry of ``choices`` is inspected
    by hand instead.
    """
    try:
        from vllm.entrypoints.openai.chat_completion.protocol import (
            ChatCompletionStreamResponse,
        )
        from vllm.entrypoints.openai.engine.protocol import (
            CompletionStreamResponse,
        )

        # Try using Completion types for type-safe parsing
        kind = chunk_data.get("object")
        if kind == "chat.completion.chunk":
            parsed_chat = ChatCompletionStreamResponse.model_validate(chunk_data)
            if parsed_chat.choices and parsed_chat.choices[0].delta.content:
                return parsed_chat.choices[0].delta.content
        elif kind == "text_completion":
            parsed_text = CompletionStreamResponse.model_validate(chunk_data)
            if parsed_text.choices and parsed_text.choices[0].text:
                return parsed_text.choices[0].text
    except pydantic.ValidationError:
        # Fallback to manual parsing
        choices = chunk_data.get("choices")
        if choices:
            first = choices[0]
            if "delta" in first and first["delta"].get("content"):
                return first["delta"]["content"]
            if first.get("text"):
                return first["text"]
    return ""


class SSEDecoder:
    """Robust Server-Sent Events decoder for streaming responses.

    Bytes are decoded incrementally, so a multi-byte UTF-8 character that is
    split across two chunks is reassembled instead of the chunk being
    dropped. Invalid bytes are replaced with U+FFFD rather than raising.
    """

    def __init__(self):
        import codecs

        # Text received so far that does not yet end in a newline.
        self.buffer = ""
        # Content fragments collected through add_content().
        self.content_buffer: list[str] = []
        # Stateful decoder: keeps incomplete byte sequences between chunks.
        self._utf8_decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

    def decode_chunk(self, chunk: bytes) -> list[dict]:
        """Decode a chunk of SSE data and return parsed events.

        Returns a list of ``{"type": "done"}`` and
        ``{"type": "data", "data": <parsed JSON>}`` entries for every
        complete ``data:`` line seen so far. Lines with malformed JSON and
        non-``data`` lines are skipped; a trailing partial line is kept for
        the next call.
        """
        import json

        self.buffer += self._utf8_decoder.decode(chunk)

        # Everything before the last newline is complete; the remainder
        # (possibly empty) stays buffered.
        lines = self.buffer.split("\n")
        self.buffer = lines.pop()

        events = []
        for raw_line in lines:
            line = raw_line.rstrip("\r")  # Handle CRLF
            # Accept both "data: x" and "data:x".
            if not line.startswith("data:"):
                continue

            data_str = line[5:].strip()
            if data_str == "[DONE]":
                events.append({"type": "done"})
            elif data_str:
                try:
                    events.append({"type": "data", "data": json.loads(data_str)})
                except json.JSONDecodeError:
                    # Skip malformed JSON
                    pass

        return events

    def extract_content(self, event_data: dict) -> str:
        """Extract content from event data."""
        return _extract_content_from_chunk(event_data)

    def add_content(self, content: str) -> None:
        """Add content to the buffer (empty content is ignored)."""
        if content:
            self.content_buffer.append(content)

    def get_complete_content(self) -> str:
        """Get the complete buffered content."""
        return "".join(self.content_buffer)
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683


def _log_streaming_response(response, response_body: list) -> None:
    """Log streaming response with robust SSE parsing.

    Replaces ``response.body_iterator`` with an iterator that replays the
    already-collected ``response_body`` chunks and, once a ``[DONE]`` event
    is seen, logs the accumulated content (truncated to 2048 characters,
    with a ``...[truncated]`` marker) together with the chunk count.

    Args:
        response: Response object whose ``body_iterator`` is replaced.
        response_body: Body chunks already read from the response.
    """
    from starlette.concurrency import iterate_in_threadpool

    sse_decoder = SSEDecoder()
    chunk_count = 0

    def buffered_iterator():
        nonlocal chunk_count

        for chunk in response_body:
            chunk_count += 1
            # Forward the chunk first; parsing is only for logging.
            yield chunk

            # Parse SSE events from chunk
            for event in sse_decoder.decode_chunk(chunk):
                if event["type"] == "data":
                    content = sse_decoder.extract_content(event["data"])
                    sse_decoder.add_content(content)
                elif event["type"] == "done":
                    # Log complete content when done
                    full_content = sse_decoder.get_complete_content()
                    if full_content:
                        # Truncate if too long
                        if len(full_content) > 2048:
                            full_content = full_content[:2048] + "...[truncated]"
                        logger.info(
                            "response_body={streaming_complete: content=%r, chunks=%d}",
                            full_content,
                            chunk_count,
                        )
                    else:
                        logger.info(
                            "response_body={streaming_complete: no_content, chunks=%d}",
                            chunk_count,
                        )
                    # Nothing after the done event is replayed.
                    return

    response.body_iterator = iterate_in_threadpool(buffered_iterator())
    logger.info("response_body={streaming_started: chunks=%d}", len(response_body))
709
710
711
712
713
714
715
716
717
718
719


def _log_non_streaming_response(response_body: list) -> None:
    """Log non-streaming response.

    All collected body sections are joined before decoding, so a body that
    arrived in several sections is logged in full. Bodies that cannot be
    decoded as text are logged as ``<binary_data>``.

    Args:
        response_body: Body sections (bytes) already read from the response.
    """
    try:
        decoded_body = b"".join(response_body).decode()
        logger.info("response_body={%s}", decoded_body)
    except UnicodeDecodeError:
        logger.info("response_body={<binary_data>}")


720
def build_app(args: Namespace) -> FastAPI:
    """Create and configure the FastAPI application.

    Registers the API routers, CORS, exception handlers and the optional
    authentication / request-id / response-logging / user-supplied
    middlewares, then passes the app through
    ``sagemaker_standards.bootstrap``.

    Args:
        args: Parsed CLI namespace; the attributes read here include the
            docs flags, ``root_path``, the CORS settings, ``api_key``,
            ``enable_request_id_headers`` and ``middleware``.

    Returns:
        The application object returned by ``sagemaker_standards.bootstrap``.

    Raises:
        ValueError: If an entry of ``args.middleware`` resolves to something
            that is neither a class nor a coroutine function.
    """
    # The docs flags only decide which documentation URLs FastAPI exposes.
    if args.disable_fastapi_docs:
        app = FastAPI(
            openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan
        )
    elif args.enable_offline_docs:
        app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
    else:
        app = FastAPI(lifespan=lifespan)
    app.state.args = args
    from vllm.entrypoints.serve import register_vllm_serve_api_routers

    register_vllm_serve_api_routers(app)
    from vllm.entrypoints.openai.chat_completion.api_router import (
        attach_router as register_chat_api_router,
    )

    register_chat_api_router(app)

    from vllm.entrypoints.openai.responses.api_router import (
        attach_router as register_responses_api_router,
    )

    register_responses_api_router(app)
    from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes

    # SageMaker routes are added to this module's router before the router
    # itself is attached to the app.
    register_sagemaker_routes(router)
    app.include_router(router)

    app.root_path = args.root_path

    from vllm.entrypoints.pooling import register_pooling_api_routers

    register_pooling_api_routers(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(HTTPException)
    async def http_exception_handler(_: Request, exc: HTTPException):
        # Convert HTTPException into the ErrorResponse schema, keeping the
        # exception's status code; the detail text is sanitized first.
        err = ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(exc.detail),
                type=HTTPStatus(exc.status_code).phrase,
                code=exc.status_code,
            )
        )
        return JSONResponse(err.model_dump(), status_code=exc.status_code)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_: Request, exc: RequestValidationError):
        # Report the offending parameter when the validation error wraps a
        # VLLMValidationError (first match wins).
        param = None
        for error in exc.errors():
            if "ctx" in error and "error" in error["ctx"]:
                ctx_error = error["ctx"]["error"]
                if isinstance(ctx_error, VLLMValidationError):
                    param = ctx_error.parameter
                    break

        exc_str = str(exc)
        errors_str = str(exc.errors())

        # Append the error list only when it adds information beyond
        # str(exc).
        if exc.errors() and errors_str and errors_str != exc_str:
            message = f"{exc_str} {errors_str}"
        else:
            message = exc_str

        # Request validation failures are always reported as 400.
        err = ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(message),
                type=HTTPStatus.BAD_REQUEST.phrase,
                code=HTTPStatus.BAD_REQUEST,
                param=param,
            )
        )
        return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)

    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
    if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]:
        app.add_middleware(AuthenticationMiddleware, tokens=tokens)

    if args.enable_request_id_headers:
        app.add_middleware(XRequestIdMiddleware)

    # Add scaling middleware to check for scaling state
    app.add_middleware(ScalingMiddleware)

    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
        logger.warning(
            "CAUTION: Enabling log response in the API Server. "
            "This can include sensitive information and should be "
            "avoided in production."
        )

        @app.middleware("http")
        async def log_response(request: Request, call_next):
            response = await call_next(request)
            # The whole body is read into memory so it can be logged, then
            # replayed to the client from the collected sections.
            response_body = [section async for section in response.body_iterator]
            response.body_iterator = iterate_in_threadpool(iter(response_body))
            # Check if this is a streaming response by looking at content-type
            content_type = response.headers.get("content-type", "")
            # NOTE(review): exact string match; a different or missing
            # charset suffix would be treated as non-streaming - confirm
            # this is intended.
            is_streaming = content_type == "text/event-stream; charset=utf-8"

            # Log response body based on type
            if not response_body:
                logger.info("response_body={<empty>}")
            elif is_streaming:
                _log_streaming_response(response, response_body)
            else:
                _log_non_streaming_response(response_body)
            return response

    # User-supplied middlewares are given as dotted "module.object" paths.
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)  # type: ignore[arg-type]
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(
                f"Invalid middleware {middleware}. Must be a function or a class."
            )

    app = sagemaker_standards.bootstrap(app)

    return app


854
async def init_app_state(
    engine_client: EngineClient,
    state: State,
    args: Namespace,
) -> None:
    """Populate ``state`` with the engine client and the serving handlers.

    Each handler attribute is set to ``None`` when the tasks reported by
    ``engine_client.get_supported_tasks()`` do not include the task that
    handler requires.

    Args:
        engine_client: Client used to talk to the engine.
        state: Application state object the attributes are stored on.
        args: Parsed server arguments.
    """
    vllm_config = engine_client.vllm_config

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.enable_log_requests:
        request_logger = RequestLogger(max_log_len=args.max_log_len)
    else:
        request_logger = None

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model) for name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.vllm_config = vllm_config
    state.args = args
    supported_tasks = await engine_client.get_supported_tasks()
    logger.info("Supported tasks: %s", supported_tasks)

    resolved_chat_template = await process_chat_template(
        args.chat_template, engine_client, vllm_config.model_config
    )

    if args.tool_server == "demo":
        tool_server: ToolServer | None = DemoToolServer()
        # Narrow the declared type so the checker accepts the call below.
        assert isinstance(tool_server, DemoToolServer)
        await tool_server.init_and_validate()
    elif args.tool_server:
        tool_server = MCPToolServer()
        await tool_server.add_tool_server(args.tool_server)
    else:
        tool_server = None

    # Merge default_mm_loras into the static lora_modules
    default_mm_loras = (
        vllm_config.lora_config.default_mm_loras
        if vllm_config.lora_config is not None
        else {}
    )
    lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)

    state.openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=base_model_paths,
        lora_modules=lora_modules,
    )
    await state.openai_serving_models.init_static_loras()
    state.openai_serving_responses = (
        OpenAIServingResponses(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            tool_server=tool_server,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            log_error_stack=args.log_error_stack,
        )
        if "generate" in supported_tasks
        else None
    )
    state.openai_serving_chat = (
        OpenAIServingChat(
            engine_client,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            default_chat_template_kwargs=args.default_chat_template_kwargs,
            trust_request_chat_template=args.trust_request_chat_template,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            enable_log_deltas=args.enable_log_deltas,
            log_error_stack=args.log_error_stack,
        )
        if "generate" in supported_tasks
        else None
    )
    # Warm up chat template processing to avoid first-request latency
    if state.openai_serving_chat is not None:
        await state.openai_serving_chat.warmup()
    state.openai_serving_completion = (
        OpenAIServingCompletion(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            log_error_stack=args.log_error_stack,
        )
        if "generate" in supported_tasks
        else None
    )
    state.openai_serving_pooling = (
        OpenAIServingPooling(
            engine_client,
            state.openai_serving_models,
            supported_tasks=supported_tasks,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if any(task in POOLING_TASKS for task in supported_tasks)
        else None
    )
    state.openai_serving_embedding = (
        OpenAIServingEmbedding(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if "embed" in supported_tasks
        else None
    )
    state.openai_serving_classification = (
        ServingClassification(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if "classify" in supported_tasks
        else None
    )
    state.openai_serving_scores = (
        ServingScores(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            score_template=resolved_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if ("embed" in supported_tasks or "score" in supported_tasks)
        else None
    )
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        state.openai_serving_models,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
        trust_request_chat_template=args.trust_request_chat_template,
        log_error_stack=args.log_error_stack,
    )
    state.openai_serving_transcription = (
        OpenAIServingTranscription(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )
    # Translation is gated on the same "transcription" task as above.
    state.openai_serving_translation = (
        OpenAIServingTranslation(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )
    state.anthropic_serving_messages = (
        AnthropicServingMessages(
            engine_client,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "generate" in supported_tasks
        else None
    )
    state.serving_tokens = (
        ServingTokens(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            log_error_stack=args.log_error_stack,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_log_outputs=args.enable_log_outputs,
            force_no_detokenize=args.tokens_only,
        )
        if "generate" in supported_tasks
        else None
    )

    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0

1096

1097
def create_server_socket(addr: tuple[str, int]) -> socket.socket:
    """Create a TCP socket bound to ``addr``.

    The socket uses ``AF_INET6`` when the host part of ``addr`` is a valid
    IPv6 address and ``AF_INET`` otherwise. ``SO_REUSEADDR`` and
    ``SO_REUSEPORT`` are enabled before binding.

    Args:
        addr: ``(host, port)`` pair to bind to.

    Returns:
        The bound socket. The caller is responsible for closing it.

    Raises:
        OSError: If setting a socket option or binding fails. The socket is
            closed before the error propagates.
    """
    family = socket.AF_INET
    if is_valid_ipv6_address(addr[0]):
        family = socket.AF_INET6

    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(addr)
    except BaseException:
        # Do not leak the file descriptor when configuration or bind fails.
        sock.close()
        raise

    return sock


1110
1111
1112
1113
1114
1115
def create_server_unix_socket(path: str) -> socket.socket:
    """Create a stream Unix domain socket bound to ``path``.

    Args:
        path: Filesystem path to bind the socket to.

    Returns:
        The bound socket. The caller is responsible for closing it.

    Raises:
        OSError: If binding fails. The socket is closed before the error
            propagates.
    """
    sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
    try:
        sock.bind(path)
    except BaseException:
        # Do not leak the file descriptor when bind fails.
        sock.close()
        raise
    return sock


1116
def validate_api_server_args(args):
    """Check that the selected tool-call and reasoning parsers are registered.

    Args:
        args: Parsed server arguments.

    Raises:
        KeyError: If auto tool choice is enabled and ``args.tool_call_parser``
            is not registered, or if a reasoning parser is set and it is not
            registered.
    """
    valid_tool_parsers = ToolParserManager.list_registered()
    if args.enable_auto_tool_choice and args.tool_call_parser not in valid_tool_parsers:
        raise KeyError(
            f"invalid tool call parser: {args.tool_call_parser} "
            f"(chose from {{ {','.join(valid_tool_parsers)} }})"
        )

    valid_reasoning_parsers = ReasoningParserManager.list_registered()
    if (
        reasoning_parser := args.structured_outputs_config.reasoning_parser
    ) and reasoning_parser not in valid_reasoning_parsers:
        raise KeyError(
            f"invalid reasoning parser: {reasoning_parser} "
            f"(chose from {{ {','.join(valid_reasoning_parsers)} }})"
        )
1132

1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143

def setup_server(args):
    """Validate API server args, set up signal handler, create socket
    ready to serve.

    Args:
        args: Parsed server arguments.

    Returns:
        A ``(listen_address, sock)`` tuple: a printable address string and
        the bound socket. The caller is responsible for closing the socket.
    """

    logger.info("vLLM API server version %s", VLLM_VERSION)
    log_non_default_args(args)

    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
        ReasoningParserManager.import_reasoning_parser(args.reasoning_parser_plugin)

    validate_api_server_args(args)

    # workaround to make sure that we bind the port before the engine is set up.
    # This avoids race conditions with ray.
    # see https://github.com/vllm-project/vllm/issues/8204
    if args.uds:
        sock = create_server_unix_socket(args.uds)
    else:
        sock_addr = (args.host or "", args.port)
        sock = create_server_socket(sock_addr)

    try:
        # workaround to avoid footguns where uvicorn drops requests with too
        # many concurrent requests active
        set_ulimit()

        def signal_handler(*_) -> None:
            # Interrupt server on sigterm while initializing
            raise KeyboardInterrupt("terminated")

        signal.signal(signal.SIGTERM, signal_handler)

        if args.uds:
            listen_address = f"unix:{args.uds}"
        else:
            addr, port = sock_addr
            is_ssl = args.ssl_keyfile and args.ssl_certfile
            # IPv6 literals are bracketed; an empty host is shown as 0.0.0.0.
            host_part = (
                f"[{addr}]" if is_valid_ipv6_address(addr) else addr or "0.0.0.0"
            )
            listen_address = f"http{'s' if is_ssl else ''}://{host_part}:{port}"
    except BaseException:
        # The socket is not handed to the caller on failure, so release it.
        sock.close()
        raise

    return listen_address, sock


async def run_server(args, **uvicorn_kwargs) -> None:
    """Run a single-worker API server.

    Args:
        args: Parsed server arguments.
        **uvicorn_kwargs: Extra keyword arguments passed through to
            ``run_server_worker``.
    """

    # Add process-specific prefix to stdout and stderr.
    decorate_logs("APIServer")

    listen_address, sock = setup_server(args)
    await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)


1188
1189
1190
async def run_server_worker(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Run a single API server worker.

    Args:
        listen_address: Printable address, used only for logging.
        sock: Already-bound socket to serve on. It is closed when this
            coroutine finishes, whether it succeeds or fails.
        args: Parsed server arguments.
        client_config: Optional configuration passed to
            ``build_async_engine_client``.
        **uvicorn_kwargs: Extra keyword arguments passed to ``serve_http``.
    """
    # A single try/finally covers the whole body so the socket is released
    # even when plugin import, engine start-up or app initialisation fails.
    try:
        if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
            ToolParserManager.import_tool_parser(args.tool_parser_plugin)

        if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
            ReasoningParserManager.import_reasoning_parser(
                args.reasoning_parser_plugin
            )

        # Load logging config for uvicorn if specified
        log_config = load_log_config(args.log_config_file)
        if log_config is not None:
            uvicorn_kwargs["log_config"] = log_config

        async with build_async_engine_client(
            args,
            client_config=client_config,
        ) as engine_client:
            app = build_app(args)

            await init_app_state(engine_client, app.state, args)

            logger.info(
                "Starting vLLM API server %d on %s",
                engine_client.vllm_config.parallel_config._api_process_rank,
                listen_address,
            )
            shutdown_task = await serve_http(
                app,
                sock=sock,
                enable_ssl_refresh=args.enable_ssl_refresh,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                # NOTE: When the 'disable_uvicorn_access_log' value is True,
                # no access log will be output.
                access_log=not args.disable_uvicorn_access_log,
                timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs,
                h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
                h11_max_header_count=args.h11_max_header_count,
                **uvicorn_kwargs,
            )

        # NB: Await server shutdown only after the backend context is exited
        await shutdown_task
    finally:
        sock.close()
1242

Ethan Xu's avatar
Ethan Xu committed
1243
1244
1245

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
    # entrypoints.
    cli_env_setup()
    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server."
    )
    parser = make_arg_parser(parser)
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))