"vscode:/vscode.git/clone" did not exist on "b2f78cbad4e86c484a0fcc518c9e8e3cded0e2cd"
api_server.py 36.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4

5
import hashlib
6
7
import importlib
import inspect
8
import json
9
import multiprocessing
10
import multiprocessing.forkserver as forkserver
11
import os
12
import secrets
13
import signal
14
import socket
15
import tempfile
16
import uuid
17
from argparse import Namespace
18
from collections.abc import AsyncIterator, Awaitable
19
from contextlib import asynccontextmanager
20
from http import HTTPStatus
21
from typing import Any
22

23
import model_hosting_container_standards.sagemaker as sagemaker_standards
24
import pydantic
25
import uvloop
26
from fastapi import APIRouter, FastAPI, HTTPException, Request
Zhuohan Li's avatar
Zhuohan Li committed
27
28
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
29
from fastapi.responses import JSONResponse
30
from starlette.concurrency import iterate_in_threadpool
31
32
from starlette.datastructures import URL, Headers, MutableHeaders, State
from starlette.types import ASGIApp, Message, Receive, Scope, Send
Zhuohan Li's avatar
Zhuohan Li committed
33

34
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
35
from vllm.engine.arg_utils import AsyncEngineArgs
36
from vllm.engine.protocol import EngineClient
37
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
38
from vllm.entrypoints.chat_utils import load_chat_template
39
from vllm.entrypoints.launcher import serve_http
40
from vllm.entrypoints.logger import RequestLogger
41
from vllm.entrypoints.mcp.tool_server import DemoToolServer, MCPToolServer, ToolServer
42
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
43
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
44
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
45
from vllm.entrypoints.openai.engine.protocol import (
46
47
48
    ErrorInfo,
    ErrorResponse,
)
49
from vllm.entrypoints.openai.engine.serving import OpenAIServing
50
51
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import (
52
53
    OpenAIServingModels,
)
54
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
55
from vllm.entrypoints.openai.translations.serving import (
56
57
58
    OpenAIServingTranscription,
    OpenAIServingTranslation,
)
59
60
61
62
63
from vllm.entrypoints.serve.disagg.serving import ServingTokens
from vllm.entrypoints.serve.elastic_ep.middleware import (
    ScalingMiddleware,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
64
65
66
from vllm.entrypoints.utils import (
    cli_env_setup,
    log_non_default_args,
67
    log_version_and_model,
68
    process_lora_modules,
69
    sanitize_message,
70
)
71
from vllm.exceptions import VLLMValidationError
72
from vllm.logger import init_logger
73
from vllm.reasoning import ReasoningParserManager
74
from vllm.tool_parsers import ToolParserManager
yhu422's avatar
yhu422 committed
75
from vllm.usage.usage_lib import UsageContext
Cyrus Leung's avatar
Cyrus Leung committed
76
from vllm.utils.argparse_utils import FlexibleArgumentParser
77
from vllm.utils.gc_utils import freeze_gc_heap
78
from vllm.utils.network_utils import is_valid_ipv6_address
79
from vllm.utils.system_utils import decorate_logs, set_ulimit
80
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
81

82
# Annotation only: no value is bound to this name at import time.
# NOTE(review): presumably assigned elsewhere in this module before use — confirm.
prometheus_multiproc_dir: tempfile.TemporaryDirectory
83

84
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
85
# The logger name is pinned to this module's dotted import path instead of
# being derived from __name__ at runtime.
logger = init_logger("vllm.entrypoints.openai.api_server")
86
87


88
# Strong references to fire-and-forget background tasks (e.g. the periodic
# stats logger started in lifespan()). asyncio keeps only weak references to
# tasks, so holding them here prevents a running task from being
# garbage-collected; lifespan() removes its task via a done callback.
_running_tasks: set[asyncio.Task] = set()
89

90

91
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook for the API server.

    Startup: when ``app.state.log_stats`` is truthy, start a background task
    that calls ``engine_client.do_log_stats()`` every
    ``envs.VLLM_LOG_STATS_INTERVAL`` seconds; then freeze the GC heap.
    Shutdown: cancel that task (if any) and delete ``app.state`` so the
    engine reference it holds can be garbage-collected.
    """
    stats_task = None
    try:
        if app.state.log_stats:
            client: EngineClient = app.state.engine_client

            async def _log_stats_forever():
                while True:
                    await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
                    await client.do_log_stats()

            stats_task = asyncio.create_task(_log_stats_forever())
            # Keep a strong reference until the task is done.
            _running_tasks.add(stats_task)
            stats_task.add_done_callback(_running_tasks.remove)

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()

        try:
            yield
        finally:
            if stats_task is not None:
                stats_task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state
119
120


121
@asynccontextmanager
async def build_async_engine_client(
    args: Namespace,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool | None = None,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Build an engine client from parsed CLI ``args`` and yield it.

    When ``VLLM_WORKER_MULTIPROC_METHOD`` is ``"forkserver"``, the forkserver
    start method is selected and ``vllm.v1.engine.async_llm`` is preloaded in
    it first. ``client_config`` (if given) sets the API process count/rank on
    the engine args. ``disable_frontend_multiprocessing`` falls back to the CLI
    flag when left as ``None``. Lifecycle and cleanup are delegated to
    ``build_async_engine_client_from_engine_args``.
    """
    if os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver":
        # The executor is expected to be mp.
        # Pre-import heavy modules in the forkserver process
        logger.debug("Setup forkserver with pre-imports")
        multiprocessing.set_start_method("forkserver")
        multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
        forkserver.ensure_running()
        logger.debug("Forkserver setup complete!")

    engine_args = AsyncEngineArgs.from_cli_args(args)
    if client_config:
        engine_args._api_process_count = client_config.get("client_count", 1)
        engine_args._api_process_rank = client_config.get("client_index", 0)

    frontend_mp_disabled = (
        bool(args.disable_frontend_multiprocessing)
        if disable_frontend_multiprocessing is None
        else disable_frontend_multiprocessing
    )

    # The inner context manager shuts everything down on error/exit.
    async with build_async_engine_client_from_engine_args(
        engine_args,
        usage_context=usage_context,
        disable_frontend_multiprocessing=frontend_mp_disabled,
        client_config=client_config,
    ) as engine:
        yield engine


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool = False,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """
    Create a V1 ``AsyncLLM`` engine client and yield it.

    The engine is shut down when the context exits, including on error.
    ``disable_frontend_multiprocessing`` is not honored on this path; only a
    warning is logged when it is set.

    Args:
        engine_args: Engine arguments used to build the ``VllmConfig``.
        usage_context: Usage context forwarded to config creation and the
            engine.
        disable_frontend_multiprocessing: Legacy flag; triggers a warning only.
        client_config: Optional mapping. ``client_count`` and
            ``client_index`` are popped off a copy and passed separately; the
            remaining entries are forwarded as ``client_addresses``. The
            caller's dict is never mutated.

    Yields:
        The initialized ``AsyncLLM`` instance.

    Raises:
        Any exception from config or engine creation propagates; a ``None``
        client is never yielded.
    """
    # Create the EngineConfig (determines if we can use V1).
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    if disable_frontend_multiprocessing:
        logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.")

    from vllm.v1.engine.async_llm import AsyncLLM

    async_llm: AsyncLLM | None = None

    # Don't mutate the input client_config
    client_config = dict(client_config) if client_config else {}
    client_count = client_config.pop("client_count", 1)
    client_index = client_config.pop("client_index", 0)

    try:
        async_llm = AsyncLLM.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            aggregate_engine_logging=engine_args.aggregate_engine_logging,
            disable_log_stats=engine_args.disable_log_stats,
            client_addresses=client_config,
            client_count=client_count,
            client_index=client_index,
        )
        assert async_llm is not None

        # Don't keep the dummy data in memory
        await async_llm.reset_mm_cache()

        yield async_llm
    finally:
        # Explicit None check: shut down whenever an engine was created,
        # independent of the object's truthiness.
        if async_llm is not None:
            async_llm.shutdown()
208
209


Ethan Xu's avatar
Ethan Xu committed
210
# Module-level router holding the routes declared in this file (e.g. /load,
# /version). build_app() also registers the SageMaker routes on it and then
# includes it in the FastAPI app.
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
211

212

213
214
215
216
217
218
219
def base(request: Request) -> OpenAIServing:
    """Return a generic ``OpenAIServing`` handler for ``request``.

    No dedicated instance exists; the tokenization handler stored on the
    application state is reused.
    """
    return request.app.state.openai_serving_tokenization


def tokenization(request: Request) -> OpenAIServingTokenization:
    """Fetch the tokenization handler stored on the application state."""
    app_state = request.app.state
    return app_state.openai_serving_tokenization
220
221


222
def engine_client(request: Request) -> EngineClient:
    """Fetch the engine client stored on the application state."""
    app_state = request.app.state
    return app_state.engine_client


226
227
228
229
@router.get("/load")
async def get_server_load_metrics(request: Request):
    # Report the current server load counter kept on app.state.
    # It tracks requests utilizing the GPU from the following routes:
    #   /v1/responses, /v1/responses/{response_id},
    #   /v1/responses/{response_id}/cancel, /v1/messages,
    #   /v1/chat/completions, /v1/completions,
    #   /v1/audio/transcriptions, /v1/audio/translations,
    #   /v1/embeddings, /pooling, /classify,
    #   /score, /v1/score, /rerank, /v1/rerank, /v2/rerank
    # (Comments rather than a docstring: FastAPI would publish a docstring
    # as the endpoint's OpenAPI description.)
    current_load = request.app.state.server_load_metrics
    payload = {"server_load": current_load}
    return JSONResponse(content=payload)
247
248


Ethan Xu's avatar
Ethan Xu committed
249
@router.get("/version")
async def show_version():
    # Report the installed vLLM version as {"version": VLLM_VERSION}.
    # (No docstring, to keep the generated OpenAPI description unchanged.)
    return JSONResponse(content={"version": VLLM_VERSION})


255
def load_log_config(log_config_file: str | None) -> dict | None:
256
257
258
259
260
261
    if not log_config_file:
        return None
    try:
        with open(log_config_file) as f:
            return json.load(f)
    except Exception as e:
262
263
264
        logger.warning(
            "Failed to load log config from file %s: error %s", log_config_file, e
        )
265
266
267
        return None


268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
def get_uvicorn_log_config(args: Namespace) -> dict | None:
    """
    Resolve the uvicorn log config from the CLI arguments.

    Precedence:
    1. A config loaded from ``args.log_config_file``, when it loads.
    2. A generated config with an access-log filter, when
       ``args.disable_access_log_for_endpoints`` (comma-separated paths)
       is set.
    3. ``None``, meaning uvicorn's defaults apply.
    """
    file_config = load_log_config(args.log_config_file)
    if file_config is not None:
        return file_config

    if not args.disable_access_log_for_endpoints:
        return None

    from vllm.logging_utils import create_uvicorn_log_config

    # Split the comma-separated string, dropping blank entries.
    excluded_paths = []
    for raw_path in args.disable_access_log_for_endpoints.split(","):
        cleaned = raw_path.strip()
        if cleaned:
            excluded_paths.append(cleaned)

    return create_uvicorn_log_config(
        excluded_paths=excluded_paths,
        log_level=args.uvicorn_log_level,
    )


301
302
303
class AuthenticationMiddleware:
    """
    Pure ASGI middleware that authenticates each request by checking
    if the Authorization Bearer token exists and equals any of "{api_key}".

    Notes
    -----
    There are two cases in which authentication is skipped:
        1. The HTTP method is OPTIONS.
        2. The request path doesn't start with /v1 (e.g. /health).
    """

    def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
        self.app = app
        # Only SHA-256 digests are kept: every comparison is then between
        # equal-length values, and the raw keys are not stored on the instance.
        self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]

    def verify_token(self, headers: Headers) -> bool:
        """Return True if ``headers`` carry a valid ``Bearer`` API key."""
        authorization_header_value = headers.get("Authorization")
        if not authorization_header_value:
            return False

        scheme, _, param = authorization_header_value.partition(" ")
        if scheme.lower() != "bearer":
            return False

        param_hash = hashlib.sha256(param.encode("utf-8")).digest()

        token_match = False
        for token_hash in self.api_tokens:
            # Non-short-circuiting: compare against every configured key so
            # timing does not reveal which key (if any) matched.
            token_match |= secrets.compare_digest(param_hash, token_hash)

        return token_match

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        # Non-HTTP scopes (e.g. "lifespan") pass straight through. WebSocket
        # scopes have no "method" key in ASGI, so it must be read with .get();
        # indexing it raised KeyError for every WebSocket connection.
        if scope["type"] not in ("http", "websocket") or (
            scope.get("method") == "OPTIONS"
        ):
            return self.app(scope, receive, send)
        root_path = scope.get("root_path", "")
        url_path = URL(scope=scope).path.removeprefix(root_path)
        headers = Headers(scope=scope)
        if url_path.startswith("/v1") and not self.verify_token(headers):
            response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
            return response(scope, receive, send)
        return self.app(scope, receive, send)


class XRequestIdMiddleware:
    """
    Middleware that sets the X-Request-Id header on each response: the
    request's own X-Request-Id when present, otherwise a random uuid4 (hex)
    value.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        if scope["type"] not in ("http", "websocket"):
            return self.app(scope, receive, send)

        # Resolve the id once per request; a uuid is generated only when the
        # client did not supply one (previously it was generated eagerly).
        request_id = Headers(scope=scope).get("X-Request-Id")
        if request_id is None:
            request_id = uuid.uuid4().hex

        async def send_with_request_id(message: Message) -> None:
            """
            Custom send function to mutate the response headers
            and append X-Request-Id to it.
            """
            if message["type"] == "http.response.start":
                response_headers = MutableHeaders(raw=message["headers"])
                response_headers.append("X-Request-Id", request_id)
            await send(message)

        return self.app(scope, receive, send_with_request_id)


380
381
382
def _extract_content_from_chunk(chunk_data: dict) -> str:
    """Return the generated text carried by one streaming response chunk.

    Chat-completion chunks yield ``choices[0].delta.content`` and completion
    chunks yield ``choices[0].text``, both validated via the protocol models.
    If validation fails, the raw dict is inspected instead. Returns ``""``
    when no text is found.
    """
    try:
        from vllm.entrypoints.openai.chat_completion.protocol import (
            ChatCompletionStreamResponse,
        )
        from vllm.entrypoints.openai.completion.protocol import (
            CompletionStreamResponse,
        )

        object_type = chunk_data.get("object")
        if object_type == "chat.completion.chunk":
            chat_chunk = ChatCompletionStreamResponse.model_validate(chunk_data)
            if chat_chunk.choices and chat_chunk.choices[0].delta.content:
                return chat_chunk.choices[0].delta.content
        elif object_type == "text_completion":
            text_chunk = CompletionStreamResponse.model_validate(chunk_data)
            if text_chunk.choices and text_chunk.choices[0].text:
                return text_chunk.choices[0].text
    except pydantic.ValidationError:
        # Fall back to reading the raw dict.
        choices = chunk_data["choices"] if "choices" in chunk_data else None
        if choices:
            first_choice = choices[0]
            if "delta" in first_choice and first_choice["delta"].get("content"):
                return first_choice["delta"]["content"]
            if first_choice.get("text"):
                return first_choice["text"]

    return ""


class SSEDecoder:
    """Robust Server-Sent Events decoder for streaming responses.

    Raw body chunks are fed to :meth:`decode_chunk`. Partial lines, and UTF-8
    characters split across chunk boundaries, are buffered until complete.
    The ``data:`` events completed by each chunk are returned. Extracted text
    can be accumulated with :meth:`add_content` and read back with
    :meth:`get_complete_content`.
    """

    def __init__(self):
        import codecs

        self.buffer = ""
        self.content_buffer: list[str] = []
        # An incremental decoder holds a trailing partial multi-byte UTF-8
        # sequence until the next chunk. A plain bytes.decode() raises on it,
        # and the whole chunk would be dropped.
        self._utf8_decoder = codecs.getincrementaldecoder("utf-8")()

    def decode_chunk(self, chunk: bytes) -> list[dict]:
        """Decode a chunk of SSE data and return parsed events.

        Returns one dict per complete ``data:`` line, in order: either
        ``{"type": "data", "data": <parsed JSON>}`` or ``{"type": "done"}`` for
        the ``[DONE]`` sentinel. Chunks containing invalid UTF-8 are skipped,
        as are data lines whose payload is not valid JSON. Other SSE fields
        and blank lines are ignored.
        """
        try:
            chunk_str = self._utf8_decoder.decode(chunk)
        except UnicodeDecodeError:
            # Skip malformed chunks and discard any pending partial character.
            self._utf8_decoder.reset()
            return []

        self.buffer += chunk_str
        events = []

        # Process complete lines
        while "\n" in self.buffer:
            line, self.buffer = self.buffer.split("\n", 1)
            line = line.rstrip("\r")  # Handle CRLF

            # The space after "data:" is optional per the SSE spec.
            if not line.startswith("data:"):
                continue
            data_str = line[5:].strip()
            if data_str == "[DONE]":
                events.append({"type": "done"})
            elif data_str:
                try:
                    event_data = json.loads(data_str)
                except json.JSONDecodeError:
                    # Skip malformed JSON
                    continue
                events.append({"type": "data", "data": event_data})

        return events

    def extract_content(self, event_data: dict) -> str:
        """Extract content from event data."""
        return _extract_content_from_chunk(event_data)

    def add_content(self, content: str) -> None:
        """Append non-empty ``content`` to the content buffer."""
        if content:
            self.content_buffer.append(content)

    def get_complete_content(self) -> str:
        """Return all buffered content joined into one string."""
        return "".join(self.content_buffer)
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480


def _log_streaming_response(response, response_body: list) -> None:
    """Replay a buffered SSE response and log its content once complete.

    ``response_body`` holds the chunks already drained from
    ``response.body_iterator``. They are re-emitted unchanged through a new
    iterator. While they stream out, each chunk is fed to an ``SSEDecoder``,
    and the accumulated text is logged when the ``[DONE]`` sentinel is seen.

    Args:
        response: Response whose ``body_iterator`` is replaced.
        response_body: Buffered body chunks of the response.
    """
    from starlette.concurrency import iterate_in_threadpool

    max_logged_chars = 2048
    sse_decoder = SSEDecoder()
    chunk_count = 0

    def buffered_iterator():
        nonlocal chunk_count
        done_logged = False

        for chunk in response_body:
            chunk_count += 1
            yield chunk

            # Once [DONE] has been logged, keep forwarding the remaining
            # chunks without parsing them. Previously the generator returned
            # here and silently dropped them from the client response.
            if done_logged:
                continue

            # Parse SSE events from chunk
            for event in sse_decoder.decode_chunk(chunk):
                if event["type"] == "data":
                    content = sse_decoder.extract_content(event["data"])
                    sse_decoder.add_content(content)
                elif event["type"] == "done":
                    # Log complete content when done
                    full_content = sse_decoder.get_complete_content()
                    if full_content:
                        # Truncate if too long. The marker must be part of the
                        # same expression; a string on its own line is a no-op.
                        if len(full_content) > max_logged_chars:
                            full_content = (
                                full_content[:max_logged_chars] + "...[truncated]"
                            )
                        logger.info(
                            "response_body={streaming_complete: content=%r, chunks=%d}",
                            full_content,
                            chunk_count,
                        )
                    else:
                        logger.info(
                            "response_body={streaming_complete: no_content, chunks=%d}",
                            chunk_count,
                        )
                    done_logged = True
                    break

    response.body_iterator = iterate_in_threadpool(buffered_iterator())
    logger.info("response_body={streaming_started: chunks=%d}", len(response_body))
506
507
508
509
510
511
512
513
514
515
516


def _log_non_streaming_response(response_body: list) -> None:
    """Log the full body of a buffered non-streaming response.

    All sections are joined before decoding. Logging only the first section
    truncated multi-section bodies, and could split a UTF-8 character.
    ``str`` sections are encoded rather than crashing.

    Args:
        response_body: Non-empty list of body sections (``bytes``-like or
            ``str``) drained from the response iterator.
    """
    body = b"".join(
        section.encode() if isinstance(section, str) else section
        for section in response_body
    )
    try:
        decoded_body = body.decode()
    except UnicodeDecodeError:
        logger.info("response_body={<binary_data>}")
        return
    logger.info("response_body={%s}", decoded_body)


517
def build_app(args: Namespace) -> FastAPI:
    """Construct the FastAPI application for the OpenAI-compatible server.

    Registers the API routers, exception handlers and middleware selected by
    ``args``, then wraps the app with ``sagemaker_standards.bootstrap``.

    Args:
        args: Parsed ``vllm serve`` CLI arguments.

    Returns:
        The configured application.

    Raises:
        ValueError: If an entry of ``args.middleware`` is neither a class nor
            a coroutine function.
    """
    if args.disable_fastapi_docs:
        app = FastAPI(
            openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan
        )
    elif args.enable_offline_docs:
        app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
    else:
        app = FastAPI(lifespan=lifespan)
    app.state.args = args

    from vllm.entrypoints.serve import register_vllm_serve_api_routers

    register_vllm_serve_api_routers(app)

    from vllm.entrypoints.openai.chat_completion.api_router import (
        attach_router as register_chat_api_router,
    )

    register_chat_api_router(app)

    from vllm.entrypoints.openai.responses.api_router import (
        attach_router as register_responses_api_router,
    )

    register_responses_api_router(app)

    from vllm.entrypoints.openai.translations.api_router import (
        attach_router as register_translations_api_router,
    )

    register_translations_api_router(app)

    from vllm.entrypoints.openai.completion.api_router import (
        attach_router as register_completion_api_router,
    )

    register_completion_api_router(app)

    from vllm.entrypoints.anthropic.api_router import (
        attach_router as register_anthropic_api_router,
    )

    register_anthropic_api_router(app)

    from vllm.entrypoints.openai.models.api_router import (
        attach_router as register_models_api_router,
    )

    register_models_api_router(app)

    from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes

    # NOTE: registers onto the module-level ``router`` shared across calls.
    register_sagemaker_routes(router)
    app.include_router(router)

    app.root_path = args.root_path

    from vllm.entrypoints.pooling import register_pooling_api_routers

    register_pooling_api_routers(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(HTTPException)
    async def http_exception_handler(_: Request, exc: HTTPException):
        err = ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(exc.detail),
                type=HTTPStatus(exc.status_code).phrase,
                code=exc.status_code,
            )
        )
        return JSONResponse(err.model_dump(), status_code=exc.status_code)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_: Request, exc: RequestValidationError):
        # Surface the offending parameter when a VLLMValidationError caused it.
        param = None
        errors = exc.errors()
        for error in errors:
            if "ctx" in error and "error" in error["ctx"]:
                ctx_error = error["ctx"]["error"]
                if isinstance(ctx_error, VLLMValidationError):
                    param = ctx_error.parameter
                    break

        exc_str = str(exc)
        errors_str = str(errors)

        if errors and errors_str and errors_str != exc_str:
            message = f"{exc_str} {errors_str}"
        else:
            message = exc_str

        err = ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(message),
                type=HTTPStatus.BAD_REQUEST.phrase,
                code=HTTPStatus.BAD_REQUEST,
                param=param,
            )
        )
        return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)

    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
    if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]:
        app.add_middleware(AuthenticationMiddleware, tokens=tokens)

    if args.enable_request_id_headers:
        app.add_middleware(XRequestIdMiddleware)

    # Add scaling middleware to check for scaling state
    app.add_middleware(ScalingMiddleware)

    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
        logger.warning(
            "CAUTION: Enabling log response in the API Server. "
            "This can include sensitive information and should be "
            "avoided in production."
        )

        @app.middleware("http")
        async def log_response(request: Request, call_next):
            response = await call_next(request)
            # Drain the body so it can be logged, then re-attach a replayable
            # iterator. NOTE: this buffers streaming responses in full.
            response_body = [section async for section in response.body_iterator]
            response.body_iterator = iterate_in_threadpool(iter(response_body))

            # Compare only the media type, so detection does not depend on
            # the exact "; charset=utf-8" parameter suffix.
            content_type = response.headers.get("content-type", "")
            media_type = content_type.split(";", 1)[0].strip().lower()
            is_streaming = media_type == "text/event-stream"

            # Log response body based on type
            if not response_body:
                logger.info("response_body={<empty>}")
            elif is_streaming:
                _log_streaming_response(response, response_body)
            else:
                _log_non_streaming_response(response_body)
            return response

    # User-supplied middleware, given as dotted "module.object" paths.
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)  # type: ignore[arg-type]
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(
                f"Invalid middleware {middleware}. Must be a function or a class."
            )

    app = sagemaker_standards.bootstrap(app)

    return app


673
async def init_app_state(
    engine_client: EngineClient,
    state: State,
    args: Namespace,
) -> None:
    """Populate ``state`` with the engine client and all serving handlers.

    Handlers for responses, chat, completion, Anthropic messages and tokens
    are created only when the engine reports the ``"generate"`` task.
    Transcription/translation handlers are created only for
    ``"transcription"``. Otherwise the attribute is set to ``None``. The
    models and tokenization handlers are always created; pooling state is
    set up by ``init_pooling_state``.

    Args:
        engine_client: The running engine client.
        state: Application state object to populate.
        args: Parsed ``vllm serve`` CLI arguments.
    """
    vllm_config = engine_client.vllm_config

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.enable_log_requests:
        request_logger = RequestLogger(max_log_len=args.max_log_len)
    else:
        request_logger = None

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model) for name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.vllm_config = vllm_config
    state.args = args
    supported_tasks = await engine_client.get_supported_tasks()
    logger.info("Supported tasks: %s", supported_tasks)

    resolved_chat_template = load_chat_template(args.chat_template)

    if args.tool_server == "demo":
        tool_server: ToolServer | None = DemoToolServer()
        assert isinstance(tool_server, DemoToolServer)
        await tool_server.init_and_validate()
    elif args.tool_server:
        tool_server = MCPToolServer()
        await tool_server.add_tool_server(args.tool_server)
    else:
        tool_server = None

    # Merge default_mm_loras into the static lora_modules (computed once).
    default_mm_loras = (
        vllm_config.lora_config.default_mm_loras
        if vllm_config.lora_config is not None
        else {}
    )
    lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)

    state.openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=base_model_paths,
        lora_modules=lora_modules,
    )
    await state.openai_serving_models.init_static_loras()

    state.openai_serving_responses = (
        OpenAIServingResponses(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            tool_server=tool_server,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            log_error_stack=args.log_error_stack,
        )
        if "generate" in supported_tasks
        else None
    )
    state.openai_serving_chat = (
        OpenAIServingChat(
            engine_client,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            default_chat_template_kwargs=args.default_chat_template_kwargs,
            trust_request_chat_template=args.trust_request_chat_template,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            enable_log_deltas=args.enable_log_deltas,
            log_error_stack=args.log_error_stack,
        )
        if "generate" in supported_tasks
        else None
    )
    # Warm up chat template processing to avoid first-request latency
    if state.openai_serving_chat is not None:
        await state.openai_serving_chat.warmup()

    state.openai_serving_completion = (
        OpenAIServingCompletion(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            log_error_stack=args.log_error_stack,
        )
        if "generate" in supported_tasks
        else None
    )
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        state.openai_serving_models,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
        trust_request_chat_template=args.trust_request_chat_template,
        log_error_stack=args.log_error_stack,
    )
    state.openai_serving_transcription = (
        OpenAIServingTranscription(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )
    state.openai_serving_translation = (
        OpenAIServingTranslation(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )
    state.anthropic_serving_messages = (
        AnthropicServingMessages(
            engine_client,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "generate" in supported_tasks
        else None
    )
    state.serving_tokens = (
        ServingTokens(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            log_error_stack=args.log_error_stack,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_log_outputs=args.enable_log_outputs,
            force_no_detokenize=args.tokens_only,
        )
        if "generate" in supported_tasks
        else None
    )

    from vllm.entrypoints.pooling import init_pooling_state

    await init_pooling_state(engine_client, state, args)

    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0

864

865
def create_server_socket(addr: tuple[str, int]) -> socket.socket:
    """Create a TCP stream socket bound to ``addr`` for the API server.

    ``SO_REUSEADDR`` is always set; ``SO_REUSEPORT`` is set when the
    platform's ``socket`` module defines it. The socket is only bound here,
    ``listen()`` is not called.

    Args:
        addr: ``(host, port)`` tuple. A host that is a valid IPv6 address
            selects ``AF_INET6``; otherwise ``AF_INET`` is used.

    Returns:
        The bound socket. The caller is responsible for closing it.

    Raises:
        OSError: If setting an option or binding fails (for example the
            port is already in use). The socket is closed before the error
            propagates, so no file descriptor is leaked.
    """
    family = socket.AF_INET6 if is_valid_ipv6_address(addr[0]) else socket.AF_INET

    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # SO_REUSEPORT is not available on every platform (e.g. Windows).
        if hasattr(socket, "SO_REUSEPORT"):
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(addr)
    except BaseException:
        # Release the descriptor instead of leaking it on a failed bind.
        sock.close()
        raise

    return sock


def create_server_unix_socket(path: str) -> socket.socket:
    """Create a Unix domain stream socket bound to ``path``.

    Args:
        path: Filesystem path to bind the socket to.

    Returns:
        The bound socket. The caller is responsible for closing it.

    Raises:
        OSError: If binding fails (for example the path already exists or
            its parent directory is missing). The socket is closed before
            the error propagates, so no file descriptor is leaked.
    """
    sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
    try:
        sock.bind(path)
    except BaseException:
        # Release the descriptor instead of leaking it on a failed bind.
        sock.close()
        raise
    return sock


def validate_api_server_args(args):
    """Check that the configured tool-call and reasoning parsers are registered.

    Args:
        args: Parsed server arguments. Reads ``enable_auto_tool_choice``,
            ``tool_call_parser`` and
            ``structured_outputs_config.reasoning_parser``.

    Raises:
        KeyError: If auto tool choice is enabled and ``tool_call_parser`` is
            not a registered tool parser, or if a reasoning parser is set
            but is not registered. The message lists the valid choices.
    """
    valid_tool_parsers = ToolParserManager.list_registered()
    # The tool parser is only checked when auto tool choice is enabled.
    if (
        args.enable_auto_tool_choice
        and args.tool_call_parser not in valid_tool_parsers
    ):
        raise KeyError(
            f"invalid tool call parser: {args.tool_call_parser} "
            f"(choose from {{ {','.join(valid_tool_parsers)} }})"
        )

    valid_reasoning_parsers = ReasoningParserManager.list_registered()
    # An unset/empty reasoning parser is allowed; only a set one is checked.
    if (
        reasoning_parser := args.structured_outputs_config.reasoning_parser
    ) and reasoning_parser not in valid_reasoning_parsers:
        raise KeyError(
            f"invalid reasoning parser: {reasoning_parser} "
            f"(choose from {{ {','.join(valid_reasoning_parsers)} }})"
        )


def setup_server(args):
    """Perform the pre-engine server setup and report where it will listen.

    In order:
      1. log the vLLM version/model and the non-default arguments;
      2. load optional tool-call / reasoning parser plugins;
      3. validate the arguments;
      4. bind the listening socket: Unix domain when ``args.uds`` is set,
         otherwise TCP on ``(args.host or "", args.port)``;
      5. call ``set_ulimit()``;
      6. install a SIGTERM handler that raises ``KeyboardInterrupt``.

    Returns:
        ``(listen_address, sock)``. ``listen_address`` is a display string,
        either ``unix:<path>`` or ``http(s)://<host>:<port>``, and ``sock``
        is the already-bound socket.
    """
    log_version_and_model(logger, VLLM_VERSION, args.model)
    log_non_default_args(args)

    tool_plugin = args.tool_parser_plugin
    if tool_plugin and len(tool_plugin) > 3:
        ToolParserManager.import_tool_parser(tool_plugin)

    reasoning_plugin = args.reasoning_parser_plugin
    if reasoning_plugin and len(reasoning_plugin) > 3:
        ReasoningParserManager.import_reasoning_parser(reasoning_plugin)

    validate_api_server_args(args)

    # Bind the port *before* the engine is set up, which avoids race
    # conditions with ray; see https://github.com/vllm-project/vllm/issues/8204
    uds_path = args.uds
    if uds_path:
        sock = create_server_unix_socket(uds_path)
    else:
        host = args.host or ""
        port = args.port
        sock = create_server_socket((host, port))

    # Workaround for uvicorn dropping requests when too many concurrent
    # requests are active.
    set_ulimit()

    def _interrupt_on_sigterm(*_) -> None:
        # While initializing, turn SIGTERM into an interrupt of the server.
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, _interrupt_on_sigterm)

    if uds_path:
        return f"unix:{uds_path}", sock

    is_ssl = args.ssl_keyfile and args.ssl_certfile
    if is_valid_ipv6_address(host):
        # IPv6 literals must be bracketed inside a URL.
        host_part = f"[{host}]"
    else:
        host_part = host or "0.0.0.0"
    return f"http{'s' if is_ssl else ''}://{host_part}:{port}", sock


async def run_server(args, **uvicorn_kwargs) -> None:
    """Run the API server as a single worker in the current process.

    Tags this process's stdout/stderr with an ``APIServer`` prefix, runs the
    pre-engine setup (``setup_server``: validation, socket binding, signal
    handling) and then serves on the bound socket via ``run_server_worker``.
    Extra keyword arguments are forwarded to ``run_server_worker``.
    """
    decorate_logs("APIServer")

    address, bound_sock = setup_server(args)
    await run_server_worker(address, bound_sock, args, **uvicorn_kwargs)


async def run_server_worker(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Run a single API server worker on an already-bound socket.

    Loads the optional tool-call / reasoning parser plugins, resolves the
    uvicorn log config, builds the engine client and FastAPI app, serves
    HTTP on ``sock`` and waits for the server to shut down.

    Args:
        listen_address: Display address; here it is only used in the
            startup log message.
        sock: Pre-bound socket handed to ``serve_http``. It is closed when
            this coroutine exits, on success and on failure alike.
        args: Parsed server arguments.
        client_config: Optional config forwarded to
            ``build_async_engine_client``.
        **uvicorn_kwargs: Extra keyword arguments forwarded to
            ``serve_http`` (``log_config`` may be set here).
    """
    try:
        # NOTE(review): presumably repeated here (setup_server also does it)
        # so that worker processes which did not run setup_server also get
        # the plugins — confirm.
        if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
            ToolParserManager.import_tool_parser(args.tool_parser_plugin)

        if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
            ReasoningParserManager.import_reasoning_parser(
                args.reasoning_parser_plugin
            )

        # Get uvicorn log config (from file or with endpoint filter)
        log_config = get_uvicorn_log_config(args)
        if log_config is not None:
            uvicorn_kwargs["log_config"] = log_config

        async with build_async_engine_client(
            args,
            client_config=client_config,
        ) as engine_client:
            app = build_app(args)

            await init_app_state(engine_client, app.state, args)

            logger.info(
                "Starting vLLM API server %d on %s",
                engine_client.vllm_config.parallel_config._api_process_rank,
                listen_address,
            )

            shutdown_task = await serve_http(
                app,
                sock=sock,
                enable_ssl_refresh=args.enable_ssl_refresh,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                # NOTE: When the 'disable_uvicorn_access_log' value is True,
                # no access log will be output.
                access_log=not args.disable_uvicorn_access_log,
                timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs,
                ssl_ciphers=args.ssl_ciphers,
                h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
                h11_max_header_count=args.h11_max_header_count,
                **uvicorn_kwargs,
            )

        # NB: Await server shutdown only after the backend context is exited
        await shutdown_task
    finally:
        # Close the socket on every exit path. Previously it leaked if
        # engine startup, app initialization or serve_http raised before
        # shutdown_task was awaited.
        sock.close()


if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
    # entrypoints.
    cli_env_setup()

    # Extend a FlexibleArgumentParser with the server's options, then parse
    # and validate the command line before starting the uvloop event loop.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."
        )
    )
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))