api_server.py 35.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import hashlib
5
6
import importlib
import inspect
7
import json
8
import multiprocessing
9
import multiprocessing.forkserver as forkserver
10
import os
11
import secrets
12
import signal
13
import socket
14
import tempfile
15
import uuid
16
from argparse import Namespace
17
from collections.abc import AsyncIterator, Awaitable
18
from contextlib import asynccontextmanager
19
from http import HTTPStatus
20
from typing import Any
21

22
import model_hosting_container_standards.sagemaker as sagemaker_standards
23
import pydantic
24
import uvloop
25
from fastapi import APIRouter, FastAPI, HTTPException, Request
Zhuohan Li's avatar
Zhuohan Li committed
26
27
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
28
from fastapi.responses import JSONResponse
29
from starlette.concurrency import iterate_in_threadpool
30
31
from starlette.datastructures import URL, Headers, MutableHeaders, State
from starlette.types import ASGIApp, Message, Receive, Scope, Send
Zhuohan Li's avatar
Zhuohan Li committed
32

33
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.engine.arg_utils import AsyncEngineArgs
35
from vllm.engine.protocol import EngineClient
36
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
37
from vllm.entrypoints.chat_utils import load_chat_template
38
from vllm.entrypoints.launcher import serve_http
39
from vllm.entrypoints.logger import RequestLogger
40
from vllm.entrypoints.mcp.tool_server import DemoToolServer, MCPToolServer, ToolServer
41
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
42
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
43
from vllm.entrypoints.openai.completion.serving import OpenAIServingCompletion
44
from vllm.entrypoints.openai.engine.protocol import (
45
46
47
    ErrorInfo,
    ErrorResponse,
)
48
from vllm.entrypoints.openai.engine.serving import OpenAIServing
49
50
from vllm.entrypoints.openai.models.protocol import BaseModelPath
from vllm.entrypoints.openai.models.serving import (
51
52
    OpenAIServingModels,
)
53
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
54
from vllm.entrypoints.openai.translations.serving import (
55
56
57
    OpenAIServingTranscription,
    OpenAIServingTranslation,
)
58
59
60
61
62
from vllm.entrypoints.serve.disagg.serving import ServingTokens
from vllm.entrypoints.serve.elastic_ep.middleware import (
    ScalingMiddleware,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
63
64
65
from vllm.entrypoints.utils import (
    cli_env_setup,
    log_non_default_args,
66
    process_lora_modules,
67
    sanitize_message,
68
)
69
from vllm.exceptions import VLLMValidationError
70
from vllm.logger import init_logger
71
from vllm.reasoning import ReasoningParserManager
72
from vllm.tool_parsers import ToolParserManager
yhu422's avatar
yhu422 committed
73
from vllm.usage.usage_lib import UsageContext
Cyrus Leung's avatar
Cyrus Leung committed
74
from vllm.utils.argparse_utils import FlexibleArgumentParser
75
from vllm.utils.gc_utils import freeze_gc_heap
76
from vllm.utils.network_utils import is_valid_ipv6_address
77
from vllm.utils.system_utils import decorate_logs, set_ulimit
78
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
79

80
prometheus_multiproc_dir: tempfile.TemporaryDirectory
81

82
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
83
logger = init_logger("vllm.entrypoints.openai.api_server")
84

85

86
_running_tasks: set[asyncio.Task] = set()
87

88

89
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook for the API server.

    Startup: when ``app.state.log_stats`` is truthy, start a background task
    that calls ``engine_client.do_log_stats()`` every
    ``envs.VLLM_LOG_STATS_INTERVAL`` seconds; then freeze the GC heap.
    Shutdown: cancel that task (if any) and delete ``app.state``.
    """
    try:
        stats_task = None
        if app.state.log_stats:
            client: EngineClient = app.state.engine_client

            async def _periodic_stats_logger():
                while True:
                    await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
                    await client.do_log_stats()

            stats_task = asyncio.create_task(_periodic_stats_logger())
            # Hold a reference until the task finishes (asyncio itself keeps
            # only weak references to tasks).
            _running_tasks.add(stats_task)
            stats_task.add_done_callback(_running_tasks.remove)

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()

        try:
            yield
        finally:
            if stats_task is not None:
                stats_task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state
117
118


119
@asynccontextmanager
async def build_async_engine_client(
    args: Namespace,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool | None = None,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Build an engine client from parsed CLI ``args`` (async context manager).

    If ``VLLM_WORKER_MULTIPROC_METHOD`` is ``"forkserver"``, the forkserver is
    started first with ``vllm.v1.engine.async_llm`` preloaded. Engine args
    are derived from ``args``. When ``client_config`` is given, its
    ``client_count`` and ``client_index`` become the API process count and
    rank. ``disable_frontend_multiprocessing`` falls back to the value in
    ``args`` when not passed explicitly.

    Creation and cleanup are delegated to
    ``build_async_engine_client_from_engine_args``.
    """
    if os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver":
        # The executor is expected to be mp.
        # Pre-import heavy modules in the forkserver process
        logger.debug("Setup forkserver with pre-imports")
        multiprocessing.set_start_method("forkserver")
        multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
        forkserver.ensure_running()
        logger.debug("Forkserver setup complete!")

    engine_args = AsyncEngineArgs.from_cli_args(args)
    if client_config:
        engine_args._api_process_count = client_config.get("client_count", 1)
        engine_args._api_process_rank = client_config.get("client_index", 0)

    resolved_disable_mp = (
        bool(args.disable_frontend_multiprocessing)
        if disable_frontend_multiprocessing is None
        else disable_frontend_multiprocessing
    )

    # The inner context manager owns the engine lifecycle and shuts the
    # engine down on error or exit.
    async with build_async_engine_client_from_engine_args(
        engine_args,
        usage_context=usage_context,
        disable_frontend_multiprocessing=resolved_disable_mp,
        client_config=client_config,
    ) as engine:
        yield engine


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool = False,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Create an ``AsyncLLM`` engine client and manage its lifetime.

    Yields the ready client and calls ``shutdown()`` on it when the context
    exits, including on error.

    Args:
        engine_args: Engine arguments used to build the vLLM config.
        usage_context: Usage context forwarded to config creation and the
            engine.
        disable_frontend_multiprocessing: Only triggers a warning here; it
            does not change how the engine is created.
        client_config: Optional mapping. ``client_count`` (default 1) and
            ``client_index`` (default 0) are extracted. All remaining entries
            are passed as ``client_addresses``. The caller's dict is not
            modified.

    Raises:
        Any exception from config creation or engine construction propagates
        unchanged; nothing is yielded in that case.
    """
    # Create the EngineConfig.
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    if disable_frontend_multiprocessing:
        logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.")

    from vllm.v1.engine.async_llm import AsyncLLM

    async_llm: AsyncLLM | None = None

    # Don't mutate the input client_config
    client_config = dict(client_config) if client_config else {}
    client_count = client_config.pop("client_count", 1)
    client_index = client_config.pop("client_index", 0)

    try:
        async_llm = AsyncLLM.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            aggregate_engine_logging=engine_args.aggregate_engine_logging,
            disable_log_stats=engine_args.disable_log_stats,
            client_addresses=client_config,
            client_count=client_count,
            client_index=client_index,
        )
        assert async_llm is not None

        # Don't keep the dummy data in memory (presumably the multimodal
        # inputs cached during startup/profiling -- TODO confirm).
        await async_llm.reset_mm_cache()

        yield async_llm
    finally:
        # Only shut down an engine that was actually created.
        if async_llm is not None:
            async_llm.shutdown()
206
207


Ethan Xu's avatar
Ethan Xu committed
208
# Module-level router: the /load and /version endpoints below are registered
# on it, and build_app() adds the SageMaker routes to it before including it
# in the FastAPI app.
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
209

210

211
212
213
214
215
216
217
def base(request: Request) -> OpenAIServing:
    """Return a generic serving handler for the request's app.

    No dedicated instance is kept for this; the tokenization handler is
    reused.
    """
    serving: OpenAIServing = tokenization(request)
    return serving


def tokenization(request: Request) -> OpenAIServingTokenization:
    """Return the tokenization handler stored on the app state."""
    app_state = request.app.state
    return app_state.openai_serving_tokenization
218
219


220
def engine_client(request: Request) -> EngineClient:
    """Return the engine client stored on the app state."""
    app_state = request.app.state
    return app_state.engine_client


224
225
226
227
@router.get("/load")
async def get_server_load_metrics(request: Request):
    """Return ``{"server_load": app.state.server_load_metrics}`` as JSON.

    The metric tracks requests utilizing the GPU from these routes:
    /v1/responses, /v1/responses/{response_id},
    /v1/responses/{response_id}/cancel, /v1/messages, /v1/chat/completions,
    /v1/completions, /v1/audio/transcriptions, /v1/audio/translations,
    /v1/embeddings, /pooling, /classify, /score, /v1/score, /rerank,
    /v1/rerank, /v2/rerank.
    """
    current_load = request.app.state.server_load_metrics
    payload = {"server_load": current_load}
    return JSONResponse(content=payload)
245
246


Ethan Xu's avatar
Ethan Xu committed
247
@router.get("/version")
async def show_version():
    """Return the running vLLM version as ``{"version": ...}``."""
    return JSONResponse(content={"version": VLLM_VERSION})


253
def load_log_config(log_config_file: str | None) -> dict | None:
254
255
256
257
258
259
    if not log_config_file:
        return None
    try:
        with open(log_config_file) as f:
            return json.load(f)
    except Exception as e:
260
261
262
        logger.warning(
            "Failed to load log config from file %s: error %s", log_config_file, e
        )
263
264
265
        return None


266
267
268
class AuthenticationMiddleware:
    """
    Pure ASGI middleware that authenticates each request by checking
    if the Authorization Bearer token exists and equals any of "{api_key}".

    Notes
    -----
    Authentication is skipped when:
        1. The scope type is neither "http" nor "websocket" (e.g. "lifespan").
        2. The HTTP method is OPTIONS.
        3. The request path doesn't start with /v1 (e.g. /health).
    """

    def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
        self.app = app
        # Only SHA-256 digests are kept, so every comparison below is between
        # equal-length values.
        self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]

    def verify_token(self, headers: Headers) -> bool:
        """Return True if the Authorization header carries a configured token.

        The header must use the ``Bearer`` scheme (case-insensitive). The
        supplied token's digest is compared against every configured digest
        with ``secrets.compare_digest``, without short-circuiting.
        """
        authorization_header_value = headers.get("Authorization")
        if not authorization_header_value:
            return False

        scheme, _, param = authorization_header_value.partition(" ")
        if scheme.lower() != "bearer":
            return False

        param_hash = hashlib.sha256(param.encode("utf-8")).digest()

        token_match = False
        for token_hash in self.api_tokens:
            token_match |= secrets.compare_digest(param_hash, token_hash)

        return token_match

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        # Websocket scopes carry no "method" key, so look it up with .get()
        # instead of indexing (indexing raised KeyError for every websocket).
        if (
            scope["type"] not in ("http", "websocket")
            or scope.get("method") == "OPTIONS"
        ):
            # scope["type"] can be "lifespan" or "startup" for example,
            # in which case we don't need to do anything
            return self.app(scope, receive, send)
        root_path = scope.get("root_path", "")
        url_path = URL(scope=scope).path.removeprefix(root_path)
        headers = Headers(scope=scope)
        if url_path.startswith("/v1") and not self.verify_token(headers):
            response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
            return response(scope, receive, send)
        return self.app(scope, receive, send)


class XRequestIdMiddleware:
    """
    Middleware that sets the X-Request-Id header on each response: the value
    of the request's own X-Request-Id header when present, otherwise a random
    uuid4 (hex) value.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        if scope["type"] in ("http", "websocket"):
            incoming_headers = Headers(scope=scope)

            async def _send_tagged(message: Message) -> None:
                """Append X-Request-Id to the response-start message, then forward it."""
                if message["type"] == "http.response.start":
                    request_id = incoming_headers.get("X-Request-Id")
                    if request_id is None:
                        request_id = uuid.uuid4().hex
                    outgoing_headers = MutableHeaders(raw=message["headers"])
                    outgoing_headers.append("X-Request-Id", request_id)
                await send(message)

            return self.app(scope, receive, _send_tagged)

        return self.app(scope, receive, send)


345
346
347
def _extract_content_from_chunk(chunk_data: dict) -> str:
    """Return the text carried by one streaming response chunk, or ``""``.

    Chunks whose ``object`` is ``"chat.completion.chunk"`` or
    ``"text_completion"`` are validated with the matching pydantic stream
    model. If validation fails, the raw dict's first choice is inspected
    instead (``delta.content`` first, then ``text``).
    """
    try:
        from vllm.entrypoints.openai.chat_completion.protocol import (
            ChatCompletionStreamResponse,
        )
        from vllm.entrypoints.openai.completion.protocol import (
            CompletionStreamResponse,
        )

        object_kind = chunk_data.get("object")
        if object_kind == "chat.completion.chunk":
            parsed_chat = ChatCompletionStreamResponse.model_validate(chunk_data)
            if parsed_chat.choices:
                delta_text = parsed_chat.choices[0].delta.content
                if delta_text:
                    return delta_text
        elif object_kind == "text_completion":
            parsed_completion = CompletionStreamResponse.model_validate(chunk_data)
            if parsed_completion.choices:
                completion_text = parsed_completion.choices[0].text
                if completion_text:
                    return completion_text
    except pydantic.ValidationError:
        # Schema mismatch: read the raw dict directly.
        if "choices" not in chunk_data or not chunk_data["choices"]:
            return ""
        first_choice = chunk_data["choices"][0]
        if "delta" in first_choice and first_choice["delta"].get("content"):
            return first_choice["delta"]["content"]
        if first_choice.get("text"):
            return first_choice["text"]

    return ""


class SSEDecoder:
    """Incremental Server-Sent Events decoder used to log streaming responses.

    Raw bytes are fed in arbitrary chunks via :meth:`decode_chunk`. Complete
    ``data:`` lines are parsed into event dicts. An incomplete trailing line,
    and any incomplete multi-byte UTF-8 sequence, is kept until the next
    chunk arrives.
    """

    def __init__(self):
        import codecs

        # Decoded text not yet terminated by a newline.
        self.buffer = ""
        # Content fragments collected via add_content().
        self.content_buffer: list[str] = []
        # Incremental decoding keeps a UTF-8 character that is split across
        # two chunks intact instead of failing on both chunks. Genuinely
        # invalid bytes become U+FFFD rather than dropping the whole chunk.
        self._decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")

    def decode_chunk(self, chunk: bytes) -> list[dict]:
        """Feed one chunk of raw SSE bytes and return the events it completes.

        Each event is either ``{"type": "data", "data": <parsed JSON>}`` or
        ``{"type": "done"}`` for the ``[DONE]`` sentinel. ``data:`` lines with
        an empty payload or invalid JSON are skipped; other lines are ignored.
        """
        text = self.buffer + self._decoder.decode(chunk)

        # Split once: everything up to the last newline is complete; the rest
        # is a partial line carried over to the next call.
        complete, newline, remainder = text.rpartition("\n")
        if not newline:
            self.buffer = text
            return []
        self.buffer = remainder

        events: list[dict] = []
        for line in complete.split("\n"):
            line = line.rstrip("\r")  # Handle CRLF
            if not line.startswith("data:"):
                continue
            # The space after "data:" is optional per the SSE format.
            data_str = line[len("data:") :].strip()
            if data_str == "[DONE]":
                events.append({"type": "done"})
            elif data_str:
                try:
                    event_data = json.loads(data_str)
                except json.JSONDecodeError:
                    # Skip malformed JSON
                    continue
                events.append({"type": "data", "data": event_data})

        return events

    def extract_content(self, event_data: dict) -> str:
        """Extract content from event data."""
        return _extract_content_from_chunk(event_data)

    def add_content(self, content: str) -> None:
        """Append ``content`` to the buffer; empty strings are ignored."""
        if content:
            self.content_buffer.append(content)

    def get_complete_content(self) -> str:
        """Return all buffered content joined into one string."""
        return "".join(self.content_buffer)
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445


def _log_streaming_response(response, response_body: list) -> None:
    """Replay a buffered SSE response body and log its assembled text.

    ``response_body`` holds every body chunk already read from ``response``.
    The chunks are re-attached to ``response.body_iterator`` unchanged. While
    they are sent, SSE events are decoded and the extracted text is
    accumulated. When the ``[DONE]`` event is seen, the text (truncated to
    2048 characters) and the number of chunks sent so far are logged. If no
    ``[DONE]`` event arrives, only the "streaming_started" line is logged.
    """
    from starlette.concurrency import iterate_in_threadpool

    max_logged_chars = 2048
    sse_decoder = SSEDecoder()
    chunk_count = 0

    def log_completion() -> None:
        """Log the accumulated content (or its absence) and the chunk count."""
        full_content = sse_decoder.get_complete_content()
        if full_content:
            if len(full_content) > max_logged_chars:
                # Previously the marker sat on its own line as a no-op
                # expression statement and was never appended.
                full_content = full_content[:max_logged_chars] + "...[truncated]"
            logger.info(
                "response_body={streaming_complete: content=%r, chunks=%d}",
                full_content,
                chunk_count,
            )
        else:
            logger.info(
                "response_body={streaming_complete: no_content, chunks=%d}",
                chunk_count,
            )

    def buffered_iterator():
        nonlocal chunk_count
        done_logged = False

        for chunk in response_body:
            chunk_count += 1
            yield chunk

            if done_logged:
                # Keep forwarding any bytes after [DONE]; only parsing stops.
                continue

            for event in sse_decoder.decode_chunk(chunk):
                if event["type"] == "data":
                    content = sse_decoder.extract_content(event["data"])
                    sse_decoder.add_content(content)
                elif event["type"] == "done":
                    log_completion()
                    done_logged = True
                    break

    response.body_iterator = iterate_in_threadpool(buffered_iterator())
    logger.info("response_body={streaming_started: chunks=%d}", len(response_body))
471
472
473
474
475
476
477
478
479
480
481


def _log_non_streaming_response(response_body: list) -> None:
    """Log a fully buffered, non-SSE response body.

    All body sections are concatenated before decoding, so multi-section
    bodies are logged in full. Previously only the first section was logged,
    and a multi-byte character split across sections made the whole body
    look binary. ``str`` sections are UTF-8 encoded first. Bodies that are
    not valid UTF-8 are logged as ``<binary_data>``.
    """
    raw_body = b"".join(
        section.encode() if isinstance(section, str) else section
        for section in response_body
    )
    try:
        decoded_body = raw_body.decode()
    except UnicodeDecodeError:
        logger.info("response_body={<binary_data>}")
        return
    logger.info("response_body={%s}", decoded_body)


482
def build_app(args: Namespace) -> FastAPI:
    """Construct the FastAPI application that serves the vLLM HTTP API.

    The app is built in this order:
    - API routers, including the module-level ``router`` after SageMaker
      routes are added to it.
    - CORS middleware and the HTTP / request-validation exception handlers.
    - Optional API-key authentication, X-Request-Id, and scaling middlewares.
    - Optional debug response logging.
    - Any user middlewares listed in ``args.middleware``.
    Finally the app is passed through ``sagemaker_standards.bootstrap``.

    Args:
        args: Parsed CLI arguments for the API server.

    Returns:
        The configured application returned by the SageMaker bootstrap.

    Raises:
        ValueError: If an entry of ``args.middleware`` resolves to neither a
            class nor a coroutine function.
    """
    if args.disable_fastapi_docs:
        app = FastAPI(
            openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan
        )
    elif args.enable_offline_docs:
        app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
    else:
        app = FastAPI(lifespan=lifespan)
    app.state.args = args

    # ---- Routers ----------------------------------------------------------
    from vllm.entrypoints.serve import register_vllm_serve_api_routers

    register_vllm_serve_api_routers(app)

    from vllm.entrypoints.openai.chat_completion.api_router import (
        attach_router as register_chat_api_router,
    )

    register_chat_api_router(app)

    from vllm.entrypoints.openai.responses.api_router import (
        attach_router as register_responses_api_router,
    )

    register_responses_api_router(app)

    from vllm.entrypoints.openai.translations.api_router import (
        attach_router as register_translations_api_router,
    )

    register_translations_api_router(app)

    from vllm.entrypoints.openai.completion.api_router import (
        attach_router as register_completion_api_router,
    )

    register_completion_api_router(app)

    from vllm.entrypoints.anthropic.api_router import (
        attach_router as register_anthropic_api_router,
    )

    register_anthropic_api_router(app)

    from vllm.entrypoints.openai.models.api_router import (
        attach_router as register_models_api_router,
    )

    register_models_api_router(app)

    from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes

    register_sagemaker_routes(router)
    app.include_router(router)

    app.root_path = args.root_path

    from vllm.entrypoints.pooling import register_pooling_api_routers

    register_pooling_api_routers(app)

    # ---- CORS and exception handlers --------------------------------------
    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(HTTPException)
    async def http_exception_handler(_: Request, exc: HTTPException):
        err = ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(exc.detail),
                type=HTTPStatus(exc.status_code).phrase,
                code=exc.status_code,
            )
        )
        return JSONResponse(err.model_dump(), status_code=exc.status_code)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_: Request, exc: RequestValidationError):
        # Report the offending parameter when a VLLMValidationError raised
        # during validation carries one.
        param = None
        for error in exc.errors():
            if "ctx" in error and "error" in error["ctx"]:
                ctx_error = error["ctx"]["error"]
                if isinstance(ctx_error, VLLMValidationError):
                    param = ctx_error.parameter
                    break

        exc_str = str(exc)
        errors_str = str(exc.errors())

        # Append the structured error list only when it adds information.
        if exc.errors() and errors_str and errors_str != exc_str:
            message = f"{exc_str} {errors_str}"
        else:
            message = exc_str

        err = ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(message),
                type=HTTPStatus.BAD_REQUEST.phrase,
                code=HTTPStatus.BAD_REQUEST,
                param=param,
            )
        )
        return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)

    # ---- Optional middlewares ---------------------------------------------
    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
    if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]:
        app.add_middleware(AuthenticationMiddleware, tokens=tokens)

    if args.enable_request_id_headers:
        app.add_middleware(XRequestIdMiddleware)

    # Add scaling middleware to check for scaling state
    app.add_middleware(ScalingMiddleware)

    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
        logger.warning(
            "CAUTION: Enabling log response in the API Server. "
            "This can include sensitive information and should be "
            "avoided in production."
        )

        @app.middleware("http")
        async def log_response(request: Request, call_next):
            response = await call_next(request)
            # Drain the body so it can be logged, then re-attach it so the
            # client still receives it.
            response_body = [section async for section in response.body_iterator]
            response.body_iterator = iterate_in_threadpool(iter(response_body))

            # Detect SSE by media type only. The old exact-string comparison
            # missed "text/event-stream" with no or a differently spelled
            # charset parameter.
            content_type = response.headers.get("content-type", "")
            media_type = content_type.split(";", 1)[0].strip().lower()
            is_streaming = media_type == "text/event-stream"

            # Log response body based on type
            if not response_body:
                logger.info("response_body={<empty>}")
            elif is_streaming:
                _log_streaming_response(response, response_body)
            else:
                _log_non_streaming_response(response_body)
            return response

    # User-supplied middlewares given as dotted import paths.
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)  # type: ignore[arg-type]
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(
                f"Invalid middleware {middleware}. Must be a function or a class."
            )

    app = sagemaker_standards.bootstrap(app)

    return app


637
async def init_app_state(
    engine_client: EngineClient,
    state: State,
    args: Namespace,
) -> None:
    """Populate ``state`` with the engine client and all serving handlers.

    Handler availability depends on the engine's supported tasks:
    - The model registry and tokenization handler are always created.
    - Responses, chat, completion, Anthropic messages, and token serving
      handlers are created only when the engine reports the ``"generate"``
      task.
    - Transcription and translation handlers are created only when it
      reports ``"transcription"``.
    - Handlers that are not created are set to ``None``.

    Pooling state is initialised by ``vllm.entrypoints.pooling.init_pooling_state``.
    Server-load tracking starts at 0.

    Args:
        engine_client: The engine client shared by all handlers.
        state: The application state object to populate.
        args: Parsed CLI arguments for the API server.
    """
    vllm_config = engine_client.vllm_config

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.enable_log_requests:
        request_logger = RequestLogger(max_log_len=args.max_log_len)
    else:
        request_logger = None

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model) for name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.vllm_config = vllm_config
    state.args = args
    supported_tasks = await engine_client.get_supported_tasks()
    logger.info("Supported tasks: %s", supported_tasks)

    supports_generate = "generate" in supported_tasks
    # Both the transcription and translation handlers are gated on this task.
    supports_transcription = "transcription" in supported_tasks

    resolved_chat_template = load_chat_template(args.chat_template)

    if args.tool_server == "demo":
        tool_server: ToolServer | None = DemoToolServer()
        assert isinstance(tool_server, DemoToolServer)
        await tool_server.init_and_validate()
    elif args.tool_server:
        tool_server = MCPToolServer()
        await tool_server.add_tool_server(args.tool_server)
    else:
        tool_server = None

    # Merge default_mm_loras into the static lora_modules (computed once; the
    # same expression was previously evaluated twice).
    default_mm_loras = (
        vllm_config.lora_config.default_mm_loras
        if vllm_config.lora_config is not None
        else {}
    )
    lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)

    state.openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=base_model_paths,
        lora_modules=lora_modules,
    )
    await state.openai_serving_models.init_static_loras()

    state.openai_serving_responses = (
        OpenAIServingResponses(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            tool_server=tool_server,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            log_error_stack=args.log_error_stack,
        )
        if supports_generate
        else None
    )
    state.openai_serving_chat = (
        OpenAIServingChat(
            engine_client,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            default_chat_template_kwargs=args.default_chat_template_kwargs,
            trust_request_chat_template=args.trust_request_chat_template,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            enable_log_deltas=args.enable_log_deltas,
            log_error_stack=args.log_error_stack,
        )
        if supports_generate
        else None
    )
    # Warm up chat template processing to avoid first-request latency
    if state.openai_serving_chat is not None:
        await state.openai_serving_chat.warmup()

    state.openai_serving_completion = (
        OpenAIServingCompletion(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            log_error_stack=args.log_error_stack,
        )
        if supports_generate
        else None
    )
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        state.openai_serving_models,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
        trust_request_chat_template=args.trust_request_chat_template,
        log_error_stack=args.log_error_stack,
    )
    state.openai_serving_transcription = (
        OpenAIServingTranscription(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if supports_transcription
        else None
    )
    state.openai_serving_translation = (
        OpenAIServingTranslation(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if supports_transcription
        else None
    )
    state.anthropic_serving_messages = (
        AnthropicServingMessages(
            engine_client,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if supports_generate
        else None
    )
    state.serving_tokens = (
        ServingTokens(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            log_error_stack=args.log_error_stack,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_log_outputs=args.enable_log_outputs,
            force_no_detokenize=args.tokens_only,
        )
        if supports_generate
        else None
    )

    from vllm.entrypoints.pooling import init_pooling_state

    await init_pooling_state(engine_client, state, args)

    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0

828

829
def create_server_socket(addr: tuple[str, int]) -> socket.socket:
    """Create and bind a TCP stream socket for the API server.

    Args:
        addr: ``(host, port)`` to bind. A host accepted by
            ``is_valid_ipv6_address`` selects ``AF_INET6``; any other host
            (including ``""``) uses ``AF_INET``.

    Returns:
        The bound socket (``listen()`` is not called here). The caller owns
        it and is responsible for closing it.

    Raises:
        OSError: if setting socket options or binding fails. The socket is
            closed before the error propagates, so no descriptor is leaked.
    """
    family = socket.AF_INET6 if is_valid_ipv6_address(addr[0]) else socket.AF_INET

    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    try:
        # SO_REUSEADDR permits rebinding an address with lingering
        # connections; SO_REUSEPORT lets other sockets bind the same port.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(addr)
    except BaseException:
        # Don't leak the descriptor when e.g. the port is already in use.
        sock.close()
        raise

    return sock


def create_server_unix_socket(path: str) -> socket.socket:
    """Create a Unix-domain stream socket bound to ``path``.

    A socket file left behind at ``path`` by an earlier run is not removed
    here, so binding to an existing path fails.

    Args:
        path: Filesystem path to bind the socket to.

    Returns:
        The bound socket. The caller owns it and must close it.

    Raises:
        OSError: if binding fails. The socket is closed before the error
            propagates, so no descriptor is leaked.
    """
    sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
    try:
        sock.bind(path)
    except BaseException:
        sock.close()
        raise
    return sock


def validate_api_server_args(args):
    """Check that the parser names requested in ``args`` are registered.

    Args:
        args: Parsed CLI arguments. Reads ``enable_auto_tool_choice``,
            ``tool_call_parser`` and
            ``structured_outputs_config.reasoning_parser``.

    Raises:
        KeyError: if auto tool choice is enabled and ``tool_call_parser`` is
            not a registered tool parser, or if a reasoning parser is set and
            is not registered. The message lists the registered names.
    """
    valid_tool_parsers = ToolParserManager.list_registered()
    if args.enable_auto_tool_choice and args.tool_call_parser not in valid_tool_parsers:
        raise KeyError(
            f"invalid tool call parser: {args.tool_call_parser} "
            f"(choose from {{ {','.join(valid_tool_parsers)} }})"
        )

    valid_reasoning_parsers = ReasoningParserManager.list_registered()
    # Only validate when a reasoning parser was actually requested.
    if (
        reasoning_parser := args.structured_outputs_config.reasoning_parser
    ) and reasoning_parser not in valid_reasoning_parsers:
        raise KeyError(
            f"invalid reasoning parser: {reasoning_parser} "
            f"(choose from {{ {','.join(valid_reasoning_parsers)} }})"
        )


def setup_server(args):
    """Do the per-process setup that must happen before the engine starts.

    Logs the server version and any non-default arguments, imports optional
    tool-call / reasoning parser plugins, validates the arguments, binds the
    listening socket, calls ``set_ulimit`` and installs a SIGTERM handler
    that interrupts initialization.

    Returns:
        ``(listen_address, sock)``. ``listen_address`` is a display string,
        either ``unix:<path>`` or ``http[s]://<host>:<port>``. ``sock`` is
        the bound server socket.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    log_non_default_args(args)

    # Import parser plugins before validation, since validate_api_server_args
    # checks names against the registered parsers (assumes importing a plugin
    # registers it; TODO confirm).
    tool_plugin = args.tool_parser_plugin
    if tool_plugin and len(tool_plugin) > 3:
        ToolParserManager.import_tool_parser(tool_plugin)

    reasoning_plugin = args.reasoning_parser_plugin
    if reasoning_plugin and len(reasoning_plugin) > 3:
        ReasoningParserManager.import_reasoning_parser(reasoning_plugin)

    validate_api_server_args(args)

    # Workaround: bind the port before the engine is set up, to avoid race
    # conditions with ray. See https://github.com/vllm-project/vllm/issues/8204
    uds_path = args.uds
    if uds_path:
        sock = create_server_unix_socket(uds_path)
    else:
        sock_addr = (args.host or "", args.port)
        sock = create_server_socket(sock_addr)

    # Workaround for uvicorn dropping requests when too many concurrent
    # requests are active.
    set_ulimit()

    def signal_handler(*_) -> None:
        # Turn SIGTERM into KeyboardInterrupt so it interrupts the server
        # while it is still initializing.
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, signal_handler)

    if uds_path:
        return f"unix:{uds_path}", sock

    host, port = sock_addr
    is_ssl = args.ssl_keyfile and args.ssl_certfile
    if is_valid_ipv6_address(host):
        # IPv6 literals must be bracketed inside a URL.
        host_part = f"[{host}]"
    else:
        # An empty host binds all interfaces; show it as 0.0.0.0.
        host_part = host or "0.0.0.0"
    return f"http{'s' if is_ssl else ''}://{host_part}:{port}", sock


async def run_server(args, **uvicorn_kwargs) -> None:
    """Run the API server in this process as its single worker.

    Prefixes this process's stdout/stderr with ``APIServer``, runs
    ``setup_server`` to validate args and bind the socket, then serves on
    that socket via ``run_server_worker``.

    Args:
        args: Parsed CLI arguments.
        **uvicorn_kwargs: Extra options forwarded to ``run_server_worker``.
    """
    decorate_logs("APIServer")

    address, server_sock = setup_server(args)
    await run_server_worker(address, server_sock, args, **uvicorn_kwargs)


async def run_server_worker(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Run a single API server worker.

    Builds the engine client and FastAPI app, then serves HTTP on the
    already-bound ``sock`` until the server shuts down.

    Args:
        listen_address: Display address of the server, used for logging.
        sock: Pre-bound server socket. It is closed when this coroutine
            exits, whether it returns normally or raises (including failures
            during engine/app startup).
        args: Parsed CLI arguments.
        client_config: Optional config forwarded to
            ``build_async_engine_client``.
        **uvicorn_kwargs: Extra keyword arguments forwarded to ``serve_http``.
    """
    try:
        # Plugins are imported here as well as in setup_server, presumably
        # because workers may run in separate processes (TODO confirm).
        tool_plugin = args.tool_parser_plugin
        if tool_plugin and len(tool_plugin) > 3:
            ToolParserManager.import_tool_parser(tool_plugin)

        reasoning_plugin = args.reasoning_parser_plugin
        if reasoning_plugin and len(reasoning_plugin) > 3:
            ReasoningParserManager.import_reasoning_parser(reasoning_plugin)

        # Load logging config for uvicorn if specified
        log_config = load_log_config(args.log_config_file)
        if log_config is not None:
            uvicorn_kwargs["log_config"] = log_config

        async with build_async_engine_client(
            args,
            client_config=client_config,
        ) as engine_client:
            app = build_app(args)

            await init_app_state(engine_client, app.state, args)

            logger.info(
                "Starting vLLM API server %d on %s",
                engine_client.vllm_config.parallel_config._api_process_rank,
                listen_address,
            )

            shutdown_task = await serve_http(
                app,
                sock=sock,
                enable_ssl_refresh=args.enable_ssl_refresh,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                # NOTE: When the 'disable_uvicorn_access_log' value is True,
                # no access log will be output.
                access_log=not args.disable_uvicorn_access_log,
                timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs,
                ssl_ciphers=args.ssl_ciphers,
                h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
                h11_max_header_count=args.h11_max_header_count,
                **uvicorn_kwargs,
            )

        # NB: Await server shutdown only after the backend context is exited
        await shutdown_task
    finally:
        # Close on every exit path, not only after a clean shutdown, so a
        # failure during engine/app startup does not leak the socket.
        sock.close()


if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
    # entrypoints.
    #
    # Script entry point: set up the CLI environment, parse and validate the
    # serve arguments, then run a single-process API server on uvloop.
    cli_env_setup()

    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."
        )
    )
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))