api_server.py 41.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import hashlib
5
6
import importlib
import inspect
7
import json
8
import multiprocessing
9
import multiprocessing.forkserver as forkserver
10
import os
11
import secrets
12
import signal
13
import socket
14
import tempfile
15
import uuid
16
from argparse import Namespace
17
from collections.abc import AsyncIterator, Awaitable
18
from contextlib import asynccontextmanager
19
from http import HTTPStatus
20
from typing import Any
21

22
import model_hosting_container_standards.sagemaker as sagemaker_standards
23
import pydantic
24
import uvloop
25
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
Zhuohan Li's avatar
Zhuohan Li committed
26
27
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
28
from fastapi.responses import JSONResponse, StreamingResponse
29
from starlette.concurrency import iterate_in_threadpool
30
31
from starlette.datastructures import URL, Headers, MutableHeaders, State
from starlette.types import ASGIApp, Message, Receive, Scope, Send
Zhuohan Li's avatar
Zhuohan Li committed
32

33
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.engine.arg_utils import AsyncEngineArgs
35
from vllm.engine.protocol import EngineClient
36
37
38
39
40
41
42
from vllm.entrypoints.anthropic.protocol import (
    AnthropicError,
    AnthropicErrorResponse,
    AnthropicMessagesRequest,
    AnthropicMessagesResponse,
)
from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
43
from vllm.entrypoints.launcher import serve_http
44
from vllm.entrypoints.logger import RequestLogger
45
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
46
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
47
from vllm.entrypoints.openai.engine.protocol import (
48
49
50
51
52
    CompletionRequest,
    CompletionResponse,
    ErrorInfo,
    ErrorResponse,
)
53
54
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.orca_metrics import metrics_header
55
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
56
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
57
58
59
60
from vllm.entrypoints.openai.serving_models import (
    BaseModelPath,
    OpenAIServingModels,
)
61
from vllm.entrypoints.openai.translations.serving import (
62
63
64
    OpenAIServingTranscription,
    OpenAIServingTranslation,
)
65
66
67
68
69
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.pooling.score.serving import ServingScores
70
71
72
73
74
from vllm.entrypoints.serve.disagg.serving import ServingTokens
from vllm.entrypoints.serve.elastic_ep.middleware import (
    ScalingMiddleware,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
75
76
77
78
79
from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer
from vllm.entrypoints.utils import (
    cli_env_setup,
    load_aware_call,
    log_non_default_args,
80
81
    process_chat_template,
    process_lora_modules,
82
    sanitize_message,
83
84
    with_cancellation,
)
85
from vllm.exceptions import VLLMValidationError
86
from vllm.logger import init_logger
87
from vllm.reasoning import ReasoningParserManager
88
from vllm.tasks import POOLING_TASKS
89
from vllm.tool_parsers import ToolParserManager
yhu422's avatar
yhu422 committed
90
from vllm.usage.usage_lib import UsageContext
Cyrus Leung's avatar
Cyrus Leung committed
91
from vllm.utils.argparse_utils import FlexibleArgumentParser
92
from vllm.utils.gc_utils import freeze_gc_heap
93
from vllm.utils.network_utils import is_valid_ipv6_address
94
from vllm.utils.system_utils import decorate_logs, set_ulimit
95
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
96

97
# Declared (annotation only) at module level; no value is assigned here.
# NOTE(review): presumably assigned elsewhere when Prometheus multiprocess
# mode is used — TODO confirm.
prometheus_multiproc_dir: tempfile.TemporaryDirectory

# The logger name is pinned explicitly instead of using __name__
# (see https://github.com/vllm-project/vllm/pull/4765).
logger = init_logger("vllm.entrypoints.openai.api_server")

# Request header whose value is handed to ``metrics_header`` to choose the
# format of load-metrics headers on completion responses.
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format"

# Strong references to background tasks started by ``lifespan``; each task
# removes itself from this set when it finishes.
_running_tasks: set[asyncio.Task] = set()
105

106

107
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook for the API server.

    On startup, when ``app.state.log_stats`` is truthy, a background task is
    started that calls ``engine_client.do_log_stats()`` every
    ``envs.VLLM_LOG_STATS_INTERVAL`` seconds. The startup heap is then frozen
    for the GC. On shutdown the stats task (if any) is cancelled and
    ``app.state`` is deleted so the engine reference can be collected.
    """
    stats_task: asyncio.Task | None = None
    try:
        if app.state.log_stats:
            client: EngineClient = app.state.engine_client

            async def _force_log():
                while True:
                    await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
                    await client.do_log_stats()

            stats_task = asyncio.create_task(_force_log())
            # Keep a strong reference until the task completes.
            _running_tasks.add(stats_task)
            stats_task.add_done_callback(_running_tasks.remove)

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()

        try:
            yield
        finally:
            if stats_task is not None:
                stats_task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state
135
136


137
@asynccontextmanager
async def build_async_engine_client(
    args: Namespace,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool | None = None,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Build an engine client from parsed CLI ``args`` and yield it.

    Args:
        args: Parsed server arguments; converted with
            ``AsyncEngineArgs.from_cli_args``.
        usage_context: Forwarded to the engine-args based builder.
        disable_frontend_multiprocessing: When ``None``, taken from
            ``args.disable_frontend_multiprocessing``.
        client_config: Optional dict. When non-empty, its ``client_count`` /
            ``client_index`` (defaults 1 / 0) are copied onto the engine args
            as ``_api_process_count`` / ``_api_process_rank``. The dict is
            also forwarded unchanged to the builder.

    Lifecycle and cleanup are delegated to
    ``build_async_engine_client_from_engine_args``.
    """
    if os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver":
        # The executor is expected to be mp. Start the forkserver now and
        # have it pre-import the AsyncLLM module so processes it forks later
        # inherit the already-imported state.
        logger.debug("Setup forkserver with pre-imports")
        multiprocessing.set_start_method("forkserver")
        multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
        forkserver.ensure_running()
        logger.debug("Forkserver setup complete!")

    engine_args = AsyncEngineArgs.from_cli_args(args)

    if client_config:
        cfg = client_config
        engine_args._api_process_count = cfg.get("client_count", 1)
        engine_args._api_process_rank = cfg.get("client_index", 0)

    disable_mp = (
        bool(args.disable_frontend_multiprocessing)
        if disable_frontend_multiprocessing is None
        else disable_frontend_multiprocessing
    )

    async with build_async_engine_client_from_engine_args(
        engine_args,
        usage_context=usage_context,
        disable_frontend_multiprocessing=disable_mp,
        client_config=client_config,
    ) as client:
        yield client


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool = False,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Create an in-process ``AsyncLLM`` engine client and yield it.

    Args:
        engine_args: Engine arguments; turned into a ``VllmConfig`` via
            ``create_engine_config``.
        usage_context: Passed to config creation and to ``AsyncLLM``.
        disable_frontend_multiprocessing: Only triggers a warning; the
            client is always an ``AsyncLLM``.
        client_config: Optional mapping. ``client_count`` and
            ``client_index`` (defaults 1 and 0) are extracted. The remaining
            entries are forwarded as ``client_addresses``. The caller's dict
            is not modified.

    Yields:
        The running ``AsyncLLM``. Exceptions raised while creating it
        propagate to the caller (nothing is yielded). On exit from the
        context, a created client is always shut down.
    """
    # Create the EngineConfig (determines if we can use V1).
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    if disable_frontend_multiprocessing:
        logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.")

    from vllm.v1.engine.async_llm import AsyncLLM

    # Don't mutate the input client_config
    client_config = dict(client_config) if client_config else {}
    client_count = client_config.pop("client_count", 1)
    client_index = client_config.pop("client_index", 0)

    async_llm: AsyncLLM | None = None
    try:
        async_llm = AsyncLLM.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            aggregate_engine_logging=engine_args.aggregate_engine_logging,
            disable_log_stats=engine_args.disable_log_stats,
            client_addresses=client_config,
            client_count=client_count,
            client_index=client_index,
        )
        assert async_llm is not None

        # Don't keep the dummy data in memory.
        # NOTE(review): presumably clears multimodal cache entries created
        # during startup — TODO confirm.
        await async_llm.reset_mm_cache()

        yield async_llm
    finally:
        # Compare against None explicitly rather than relying on truthiness.
        if async_llm is not None:
            async_llm.shutdown()
224
225


Ethan Xu's avatar
Ethan Xu committed
226
# Module-level router for the endpoints defined in this file. build_app
# passes it to register_sagemaker_routes and then includes it in the app.
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
227

228

229
230
231
232
233
def base(request: Request) -> OpenAIServing:
    """Return a generic ``OpenAIServing`` object, e.g. for building errors.

    No new instance is created; the tokenization handler is reused.
    """
    shared_serving = tokenization(request)
    return shared_serving


234
235
236
237
def models(request: Request) -> OpenAIServingModels:
    """Fetch the model-listing handler stored on the application state."""
    app_state = request.app.state
    return app_state.openai_serving_models


238
239
240
241
def messages(request: Request) -> AnthropicServingMessages | None:
    """Fetch the Anthropic Messages handler stored on the application state.

    The stored value may be ``None``; ``create_messages`` checks for that and
    replies with an error. The annotation reflects this, matching ``chat``
    and ``completion``.
    """
    return request.app.state.anthropic_serving_messages


242
def chat(request: Request) -> OpenAIServingChat | None:
    """Fetch the chat-completions handler from app state (may be ``None``)."""
    app_state = request.app.state
    return app_state.openai_serving_chat


246
def completion(request: Request) -> OpenAIServingCompletion | None:
    """Fetch the text-completions handler from app state (may be ``None``)."""
    app_state = request.app.state
    return app_state.openai_serving_completion


250
251
def tokenization(request: Request) -> OpenAIServingTokenization:
    """Fetch the tokenization handler (also reused by ``base``)."""
    app_state = request.app.state
    return app_state.openai_serving_tokenization
252
253


254
def engine_client(request: Request) -> EngineClient:
    """Fetch the engine client stored on the application state."""
    app_state = request.app.state
    return app_state.engine_client


258
259
260
261
@router.get("/load")
async def get_server_load_metrics(request: Request):
    """Return the current server load as ``{"server_load": <value>}``.

    The value is ``app.state.server_load_metrics``. It tracks requests
    utilizing the GPU on these routes:

    - /v1/responses, /v1/responses/{response_id},
      /v1/responses/{response_id}/cancel
    - /v1/messages
    - /v1/chat/completions, /v1/completions
    - /v1/audio/transcriptions, /v1/audio/translations
    - /v1/embeddings, /pooling, /classify
    - /score, /v1/score, /rerank, /v1/rerank, /v2/rerank
    """
    current_load = request.app.state.server_load_metrics
    return JSONResponse(content={"server_load": current_load})
279
280


Ethan Xu's avatar
Ethan Xu committed
281
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    """List the served models (OpenAI-compatible ``/v1/models``)."""
    listing = await models(raw_request).show_available_models()
    return JSONResponse(content=listing.model_dump())
Zhuohan Li's avatar
Zhuohan Li committed
287
288


Ethan Xu's avatar
Ethan Xu committed
289
@router.get("/version")
async def show_version():
    """Report the running vLLM version as ``{"version": ...}``."""
    return JSONResponse(content={"version": VLLM_VERSION})


295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
@router.post(
    "/v1/messages",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
    """Serve the Anthropic-compatible ``/v1/messages`` endpoint.

    Errors are returned in Anthropic's error envelope:
    - vLLM ``ErrorResponse``s keep their status code and type;
    - unexpected exceptions become HTTP 500 with type ``internal_error``.
    A complete ``AnthropicMessagesResponse`` is returned as JSON (``None``
    fields dropped). Anything else is streamed as ``text/event-stream``.
    """

    def _as_anthropic_json(err: ErrorResponse) -> JSONResponse:
        envelope = AnthropicErrorResponse(
            error=AnthropicError(type=err.error.type, message=err.error.message)
        )
        return JSONResponse(status_code=err.error.code, content=envelope.model_dump())

    handler = messages(raw_request)
    if handler is None:
        unsupported = base(raw_request).create_error_response(
            message="The model does not support Messages API"
        )
        return _as_anthropic_json(unsupported)

    try:
        outcome = await handler.create_messages(request, raw_request)
    except Exception as e:
        logger.exception("Error in create_messages: %s", e)
        internal = AnthropicErrorResponse(
            error=AnthropicError(type="internal_error", message=str(e))
        )
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            content=internal.model_dump(),
        )

    if isinstance(outcome, ErrorResponse):
        return _as_anthropic_json(outcome)
    if isinstance(outcome, AnthropicMessagesResponse):
        body = outcome.model_dump(exclude_none=True)
        logger.debug("Anthropic Messages Response: %s", body)
        return JSONResponse(content=body)
    return StreamingResponse(content=outcome, media_type="text/event-stream")


351
352
353
354
355
356
357
358
359
360
@router.post(
    "/v1/completions",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_completion(request: CompletionRequest, raw_request: Request):
    """Serve the OpenAI-compatible ``/v1/completions`` endpoint.

    Returns:
        - A JSON ``ErrorResponse`` sent with its own ``error.code`` as the
          HTTP status. This covers three cases: the model has no
          completions handler, the handler raises, or the handler returns
          an error.
        - A JSON ``CompletionResponse`` for non-streaming requests. Its
          load-metrics headers are formatted according to the
          ``endpoint-load-metrics-format`` request header.
        - Otherwise, an SSE ``text/event-stream`` response.
    """
    metrics_header_format = raw_request.headers.get(
        ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, ""
    )

    handler = completion(raw_request)
    if handler is None:
        # Route through the ErrorResponse branch below: returning the model
        # object directly would be serialized by FastAPI with HTTP 200.
        generator = base(raw_request).create_error_response(
            message="The model does not support Completions API"
        )
    else:
        try:
            generator = await handler.create_completion(request, raw_request)
        except Exception as e:
            # Same as above: convert, then send with the error's status code.
            generator = handler.create_error_response(e)

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    if isinstance(generator, CompletionResponse):
        return JSONResponse(
            content=generator.model_dump(),
            headers=metrics_header(metrics_header_format),
        )

    return StreamingResponse(content=generator, media_type="text/event-stream")

Zhuohan Li's avatar
Zhuohan Li committed
390

391
def load_log_config(log_config_file: str | None) -> dict | None:
392
393
394
395
396
397
    if not log_config_file:
        return None
    try:
        with open(log_config_file) as f:
            return json.load(f)
    except Exception as e:
398
399
400
        logger.warning(
            "Failed to load log config from file %s: error %s", log_config_file, e
        )
401
402
403
        return None


404
405
406
class AuthenticationMiddleware:
    """
    Pure ASGI middleware that authenticates each request. A request passes
    when its ``Authorization: Bearer <token>`` header carries one of the
    configured API keys.

    Keys are kept only as SHA-256 digests and compared with
    ``secrets.compare_digest``.

    Notes
    -----
    There are two cases in which authentication is skipped:
        1. The HTTP method is OPTIONS.
        2. The request path doesn't start with /v1 (e.g. /health).
    Scopes other than ``http``/``websocket`` (e.g. ``lifespan``) are passed
    through untouched.
    """

    def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
        self.app = app
        self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]

    def verify_token(self, headers: Headers) -> bool:
        """Return True if the Authorization header holds a configured key."""
        authorization_header_value = headers.get("Authorization")
        if not authorization_header_value:
            return False

        scheme, _, param = authorization_header_value.partition(" ")
        if scheme.lower() != "bearer":
            return False

        # Hashing the presented token makes every comparison between
        # equal-length digests. The loop checks every key without stopping
        # early, presumably so timing does not reveal which key matched.
        param_hash = hashlib.sha256(param.encode("utf-8")).digest()

        token_match = False
        for token_hash in self.api_tokens:
            token_match |= secrets.compare_digest(param_hash, token_hash)

        return token_match

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        if scope["type"] not in ("http", "websocket"):
            # e.g. "lifespan" scopes: nothing to authenticate.
            return self.app(scope, receive, send)
        # Websocket scopes have no "method" key in the ASGI spec, so use
        # .get(). Indexing raised KeyError for every websocket connection.
        if scope.get("method") == "OPTIONS":
            # OPTIONS requests (e.g. CORS preflight) are not authenticated.
            return self.app(scope, receive, send)

        root_path = scope.get("root_path", "")
        url_path = URL(scope=scope).path.removeprefix(root_path)
        headers = Headers(scope=scope)
        if url_path.startswith("/v1") and not self.verify_token(headers):
            response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
            return response(scope, receive, send)
        return self.app(scope, receive, send)


class XRequestIdMiddleware:
    """
    Pure ASGI middleware that adds an ``X-Request-Id`` header to each HTTP
    response.

    The request's own ``X-Request-Id`` value is echoed back when the client
    sent one. Otherwise a random uuid4 hex string is used.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        if scope["type"] not in ("http", "websocket"):
            return self.app(scope, receive, send)

        # Extract the request headers.
        request_headers = Headers(scope=scope)

        async def send_with_request_id(message: Message) -> None:
            """Append X-Request-Id to the response-start message, then forward it."""
            if message["type"] == "http.response.start":
                # "headers" is optional in http.response.start per the ASGI
                # spec. Create the list when absent; MutableHeaders mutates
                # it in place.
                response_headers = MutableHeaders(raw=message.setdefault("headers", []))
                request_id = request_headers.get("X-Request-Id")
                if request_id is None:
                    # Only generate an id when the client did not supply one.
                    request_id = uuid.uuid4().hex
                response_headers.append("X-Request-Id", request_id)
            await send(message)

        return self.app(scope, receive, send_with_request_id)


483
484
485
def _extract_content_from_chunk(chunk_data: dict) -> str:
    """Extract the generated text from one decoded streaming chunk.

    ``chat.completion.chunk`` objects yield ``choices[0].delta.content``, and
    ``text_completion`` objects yield ``choices[0].text``. Both are first
    validated with the protocol models. If validation fails, the first
    choice is read as a plain dict instead. Any other chunk, or one carrying
    no text, yields ``""``.

    In this module the result only feeds debug response logging (via
    ``SSEDecoder.extract_content``). The fallback therefore tolerates
    unexpected shapes instead of raising.
    """
    try:
        from vllm.entrypoints.openai.chat_completion.protocol import (
            ChatCompletionStreamResponse,
        )
        from vllm.entrypoints.openai.engine.protocol import (
            CompletionStreamResponse,
        )

        # Try using Completion types for type-safe parsing
        object_type = chunk_data.get("object")
        if object_type == "chat.completion.chunk":
            chat_response = ChatCompletionStreamResponse.model_validate(chunk_data)
            if chat_response.choices and chat_response.choices[0].delta.content:
                return chat_response.choices[0].delta.content
        elif object_type == "text_completion":
            completion_response = CompletionStreamResponse.model_validate(chunk_data)
            if completion_response.choices and completion_response.choices[0].text:
                return completion_response.choices[0].text
    except pydantic.ValidationError:
        # Fallback to manual parsing. Check each level's type: a malformed
        # chunk (e.g. "delta": null) previously raised AttributeError here.
        choices = chunk_data.get("choices")
        if isinstance(choices, list) and choices and isinstance(choices[0], dict):
            choice = choices[0]
            delta = choice.get("delta")
            if isinstance(delta, dict) and delta.get("content"):
                return delta["content"]
            if choice.get("text"):
                return choice["text"]
    return ""


class SSEDecoder:
    """Incremental Server-Sent Events decoder for streaming responses.

    Feed body chunks, split at arbitrary byte boundaries, to
    ``decode_chunk``; only complete ``data:`` lines become events. Text
    extracted from events can be accumulated with ``add_content`` and read
    back with ``get_complete_content``.
    """

    def __init__(self):
        import codecs

        # Decoded text after the last newline, waiting for the rest of its line.
        self.buffer = ""
        # Content fragments collected via add_content().
        self.content_buffer = []
        # An incremental decoder carries a multi-byte UTF-8 character that is
        # split across two chunks over to the next call. A plain
        # bytes.decode() would reject both halves.
        self._utf8_decoder = codecs.getincrementaldecoder("utf-8")()

    def decode_chunk(self, chunk: bytes) -> list[dict]:
        """Feed one chunk and return the events it completes, in order.

        Events are ``{"type": "data", "data": <parsed JSON>}`` or
        ``{"type": "done"}`` (for ``data: [DONE]``). A chunk containing
        invalid UTF-8 is skipped entirely. Lines whose payload is not valid
        JSON are ignored. Already-decoded ``str`` chunks are accepted too.
        """
        if isinstance(chunk, str):
            chunk_str = chunk
        else:
            try:
                chunk_str = self._utf8_decoder.decode(chunk)
            except UnicodeDecodeError:
                # Skip malformed chunks and discard any partial character.
                self._utf8_decoder.reset()
                return []

        self.buffer += chunk_str
        events = []

        # Process complete lines
        while "\n" in self.buffer:
            line, self.buffer = self.buffer.split("\n", 1)
            line = line.rstrip("\r")  # Handle CRLF

            # In the SSE format the space after "data:" is optional.
            if not line.startswith("data:"):
                continue
            data_str = line[5:].strip()
            if data_str == "[DONE]":
                events.append({"type": "done"})
            elif data_str:
                try:
                    event_data = json.loads(data_str)
                except json.JSONDecodeError:
                    # Skip malformed JSON
                    continue
                events.append({"type": "data", "data": event_data})

        return events

    def extract_content(self, event_data: dict) -> str:
        """Extract content from event data."""
        return _extract_content_from_chunk(event_data)

    def add_content(self, content: str) -> None:
        """Append non-empty content to the buffer."""
        if content:
            self.content_buffer.append(content)

    def get_complete_content(self) -> str:
        """Get the complete buffered content."""
        return "".join(self.content_buffer)
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583


def _log_streaming_response(response, response_body: list) -> None:
    """Replay an already-drained SSE response while logging its content.

    ``response.body_iterator`` is replaced with an iterator over
    ``response_body`` (run through ``iterate_in_threadpool``). Every chunk
    is forwarded to the client unchanged. Alongside, SSE events are decoded
    and their text accumulated. When a ``data: [DONE]`` event is seen, the
    accumulated content (cut to 2048 characters plus a marker) is logged
    once.

    Args:
        response: The streaming response whose body was drained.
        response_body: The drained body chunks.
    """
    max_logged_chars = 2048
    sse_decoder = SSEDecoder()
    chunk_count = 0
    completion_logged = False

    def _log_completion() -> None:
        """Log the accumulated content (or its absence) with the chunk count."""
        full_content = sse_decoder.get_complete_content()
        if full_content:
            if len(full_content) > max_logged_chars:
                # Previously "...[truncated]" sat on its own line as a no-op
                # string statement and was never appended.
                full_content = full_content[:max_logged_chars] + "...[truncated]"
            logger.info(
                "response_body={streaming_complete: content=%r, chunks=%d}",
                full_content,
                chunk_count,
            )
        else:
            logger.info(
                "response_body={streaming_complete: no_content, chunks=%d}",
                chunk_count,
            )

    def buffered_iterator():
        nonlocal chunk_count, completion_logged

        for chunk in response_body:
            chunk_count += 1
            yield chunk

            # Keep yielding after [DONE]; returning here used to drop any
            # remaining chunks from the client's response.
            if completion_logged:
                continue
            for event in sse_decoder.decode_chunk(chunk):
                if event["type"] == "data":
                    content = sse_decoder.extract_content(event["data"])
                    sse_decoder.add_content(content)
                elif event["type"] == "done":
                    _log_completion()
                    completion_logged = True
                    break

    response.body_iterator = iterate_in_threadpool(buffered_iterator())
    logger.info("response_body={streaming_started: chunks=%d}", len(response_body))
609
610
611
612
613
614
615
616
617
618
619


def _log_non_streaming_response(response_body: list) -> None:
    """Log the full body of a drained non-streaming response.

    All chunks are joined before decoding. Bodies split across several
    chunks are therefore logged in full, including a multi-byte UTF-8
    character split at a chunk boundary. Previously only the first chunk was
    logged. A body that is not valid UTF-8 is logged as a placeholder.
    """
    try:
        decoded_body = b"".join(response_body).decode()
    except UnicodeDecodeError:
        logger.info("response_body={<binary_data>}")
        return
    logger.info("response_body={%s}", decoded_body)


620
def build_app(args: Namespace) -> FastAPI:
    """Build the FastAPI application for the OpenAI-compatible server.

    Args:
        args: Parsed server arguments. This function reads the docs flags,
            ``root_path``, the CORS settings, ``api_key``,
            ``enable_request_id_headers`` and ``middleware``.

    Returns:
        The configured app, as returned by ``sagemaker_standards.bootstrap``.

    The registration order below matters. Starlette treats the most recently
    added middleware as the outermost layer, so do not reorder the
    ``add_middleware`` / ``app.middleware`` calls casually.
    """
    if args.disable_fastapi_docs:
        app = FastAPI(
            openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan
        )
    elif args.enable_offline_docs:
        app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
    else:
        app = FastAPI(lifespan=lifespan)
    app.state.args = args
    from vllm.entrypoints.serve import register_vllm_serve_api_routers

    register_vllm_serve_api_routers(app)

    from vllm.entrypoints.openai.chat_completion.api_router import (
        attach_router as register_chat_api_router,
    )

    register_chat_api_router(app)

    from vllm.entrypoints.openai.responses.api_router import (
        attach_router as register_responses_api_router,
    )

    register_responses_api_router(app)

    from vllm.entrypoints.openai.translations.api_router import (
        attach_router as register_translations_api_router,
    )

    register_translations_api_router(app)

    from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes

    # NOTE(review): this adds routes to the module-level ``router``; calling
    # build_app more than once per process registers them again.
    register_sagemaker_routes(router)
    app.include_router(router)

    app.root_path = args.root_path

    from vllm.entrypoints.pooling import register_pooling_api_routers

    register_pooling_api_routers(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(HTTPException)
    async def http_exception_handler(_: Request, exc: HTTPException):
        err = ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(exc.detail),
                type=HTTPStatus(exc.status_code).phrase,
                code=exc.status_code,
            )
        )
        return JSONResponse(err.model_dump(), status_code=exc.status_code)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_: Request, exc: RequestValidationError):
        errors = exc.errors()

        # Report the offending parameter when a validator raised a
        # VLLMValidationError that names one.
        param = None
        for error in errors:
            if "ctx" in error and "error" in error["ctx"]:
                ctx_error = error["ctx"]["error"]
                if isinstance(ctx_error, VLLMValidationError):
                    param = ctx_error.parameter
                    break

        exc_str = str(exc)
        errors_str = str(errors)

        # Append the structured error list only when it differs from str(exc).
        if errors and errors_str and errors_str != exc_str:
            message = f"{exc_str} {errors_str}"
        else:
            message = exc_str

        err = ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(message),
                type=HTTPStatus.BAD_REQUEST.phrase,
                code=HTTPStatus.BAD_REQUEST,
                param=param,
            )
        )
        return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)

    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
    if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]:
        app.add_middleware(AuthenticationMiddleware, tokens=tokens)

    if args.enable_request_id_headers:
        app.add_middleware(XRequestIdMiddleware)

    # Add scaling middleware to check for scaling state
    app.add_middleware(ScalingMiddleware)

    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
        logger.warning(
            "CAUTION: Enabling log response in the API Server. "
            "This can include sensitive information and should be "
            "avoided in production."
        )

        @app.middleware("http")
        async def log_response(request: Request, call_next):
            response = await call_next(request)
            # Drain the body so it can be inspected, then replay it. This
            # buffers the whole response, streams included.
            response_body = [section async for section in response.body_iterator]
            response.body_iterator = iterate_in_threadpool(iter(response_body))
            # Match on the media type only: the header normally carries a
            # "; charset=..." suffix whose exact spelling can vary.
            content_type = response.headers.get("content-type", "")
            is_streaming = content_type.startswith("text/event-stream")

            # Log response body based on type
            if not response_body:
                logger.info("response_body={<empty>}")
            elif is_streaming:
                _log_streaming_response(response, response_body)
            else:
                _log_non_streaming_response(response_body)
            return response

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)  # type: ignore[arg-type]
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(
                f"Invalid middleware {middleware}. Must be a function or a class."
            )

    app = sagemaker_standards.bootstrap(app)

    return app


759
async def init_app_state(
    engine_client: EngineClient,
    state: State,
    args: Namespace,
) -> None:
    """Populate the application ``state`` with the engine client and handlers.

    Stores the engine client, its ``vllm_config``, the parsed ``args`` and the
    stats flag on ``state``. Then constructs every serving handler. Handlers
    that depend on a task the engine does not report (via
    ``engine_client.get_supported_tasks()``) are stored as ``None``. The
    tokenization handler is always created.

    Args:
        engine_client: Client for the running engine.
        state: Application state object to populate (``app.state``).
        args: Parsed API-server CLI arguments.
    """
    vllm_config = engine_client.vllm_config

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.enable_log_requests:
        request_logger = RequestLogger(max_log_len=args.max_log_len)
    else:
        request_logger = None

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model) for name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.vllm_config = vllm_config
    state.args = args
    supported_tasks = await engine_client.get_supported_tasks()
    logger.info("Supported tasks: %s", supported_tasks)

    resolved_chat_template = await process_chat_template(
        args.chat_template, engine_client, vllm_config.model_config
    )

    # Use concretely-typed locals so the type-specific setup calls type-check
    # without a runtime ``assert isinstance(...)``.
    tool_server: ToolServer | None
    if args.tool_server == "demo":
        demo_tool_server = DemoToolServer()
        await demo_tool_server.init_and_validate()
        tool_server = demo_tool_server
    elif args.tool_server:
        mcp_tool_server = MCPToolServer()
        await mcp_tool_server.add_tool_server(args.tool_server)
        tool_server = mcp_tool_server
    else:
        tool_server = None

    # Merge default_mm_loras into the static lora_modules
    default_mm_loras = (
        vllm_config.lora_config.default_mm_loras
        if vllm_config.lora_config is not None
        else {}
    )
    lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)

    state.openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=base_model_paths,
        lora_modules=lora_modules,
    )
    await state.openai_serving_models.init_static_loras()

    # Text-generation handlers are only built when the engine can generate.
    supports_generate = "generate" in supported_tasks

    state.openai_serving_responses = (
        OpenAIServingResponses(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            tool_server=tool_server,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            log_error_stack=args.log_error_stack,
        )
        if supports_generate
        else None
    )
    state.openai_serving_chat = (
        OpenAIServingChat(
            engine_client,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            default_chat_template_kwargs=args.default_chat_template_kwargs,
            trust_request_chat_template=args.trust_request_chat_template,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            enable_log_deltas=args.enable_log_deltas,
            log_error_stack=args.log_error_stack,
        )
        if supports_generate
        else None
    )

    # Warm up chat template processing to avoid first-request latency
    if state.openai_serving_chat is not None:
        await state.openai_serving_chat.warmup()

    state.openai_serving_completion = (
        OpenAIServingCompletion(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            log_error_stack=args.log_error_stack,
        )
        if supports_generate
        else None
    )
    state.openai_serving_pooling = (
        OpenAIServingPooling(
            engine_client,
            state.openai_serving_models,
            supported_tasks=supported_tasks,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if any(task in POOLING_TASKS for task in supported_tasks)
        else None
    )
    state.openai_serving_embedding = (
        OpenAIServingEmbedding(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if "embed" in supported_tasks
        else None
    )
    state.openai_serving_classification = (
        ServingClassification(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if "classify" in supported_tasks
        else None
    )
    state.openai_serving_scores = (
        ServingScores(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            score_template=resolved_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if ("embed" in supported_tasks or "score" in supported_tasks)
        else None
    )
    # Tokenization is task-independent and therefore always available.
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        state.openai_serving_models,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
        trust_request_chat_template=args.trust_request_chat_template,
        log_error_stack=args.log_error_stack,
    )
    state.openai_serving_transcription = (
        OpenAIServingTranscription(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )
    # Translation shares the "transcription" task gate.
    state.openai_serving_translation = (
        OpenAIServingTranslation(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )
    state.anthropic_serving_messages = (
        AnthropicServingMessages(
            engine_client,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if supports_generate
        else None
    )
    state.serving_tokens = (
        ServingTokens(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            log_error_stack=args.log_error_stack,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_log_outputs=args.enable_log_outputs,
            force_no_detokenize=args.tokens_only,
        )
        if supports_generate
        else None
    )

    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0

1001

1002
def create_server_socket(addr: tuple[str, int]) -> socket.socket:
    """Create a TCP stream socket bound to ``addr`` (not yet listening).

    ``AF_INET6`` is used when the host part of ``addr`` is a valid IPv6
    address (per ``is_valid_ipv6_address``); otherwise ``AF_INET`` is used.
    ``SO_REUSEADDR`` and ``SO_REUSEPORT`` are enabled before binding.

    Args:
        addr: ``(host, port)`` pair. An empty host binds all interfaces.

    Returns:
        The bound socket.

    Raises:
        OSError: If setting an option or binding fails (e.g. the address is
            already in use). The socket is closed before the error propagates.
    """
    family = socket.AF_INET6 if is_valid_ipv6_address(addr[0]) else socket.AF_INET
    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(addr)
    except BaseException:
        # Don't leak the file descriptor if configuration or bind fails.
        sock.close()
        raise
    return sock


1015
1016
1017
1018
1019
1020
def create_server_unix_socket(path: str) -> socket.socket:
    """Create a Unix-domain stream socket bound to ``path`` (not yet listening).

    Args:
        path: Filesystem path for the socket. Binding typically fails if a
            file already exists there (e.g. a stale socket from a prior run).

    Returns:
        The bound socket.

    Raises:
        OSError: If binding fails. The socket is closed before the error
            propagates.
    """
    sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
    try:
        sock.bind(path)
    except BaseException:
        # Don't leak the file descriptor when bind fails.
        sock.close()
        raise
    return sock


1021
def validate_api_server_args(args):
    """Ensure the configured tool-call and reasoning parsers are registered.

    Raises:
        KeyError: If auto tool choice is enabled but ``args.tool_call_parser``
            is not a registered tool parser, or if
            ``args.structured_outputs_config.reasoning_parser`` is set but is
            not a registered reasoning parser.
    """
    registered_tool_parsers = ToolParserManager.list_registered()
    wants_auto_tools = args.enable_auto_tool_choice
    if wants_auto_tools and args.tool_call_parser not in registered_tool_parsers:
        tool_choices = ",".join(registered_tool_parsers)
        raise KeyError(
            f"invalid tool call parser: {args.tool_call_parser} "
            f"(chose from {{ {tool_choices} }})"
        )

    registered_reasoning_parsers = ReasoningParserManager.list_registered()
    reasoning_parser = args.structured_outputs_config.reasoning_parser
    if reasoning_parser and reasoning_parser not in registered_reasoning_parsers:
        reasoning_choices = ",".join(registered_reasoning_parsers)
        raise KeyError(
            f"invalid reasoning parser: {reasoning_parser} "
            f"(chose from {{ {reasoning_choices} }})"
        )
1037

1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048

def setup_server(args):
    """Prepare the API server before the engine is started.

    Steps:
    - Log the version and non-default arguments.
    - Import any tool/reasoning parser plugins.
    - Validate the arguments.
    - Bind the listening socket (Unix domain if ``args.uds`` is set,
      otherwise TCP).
    - Call ``set_ulimit()``.
    - Install a SIGTERM handler that raises ``KeyboardInterrupt`` while the
      server is still initializing.

    Returns:
        ``(listen_address, sock)``. ``listen_address`` is ``"unix:<path>"``
        or ``"http[s]://<host>:<port>"``; ``sock`` is the bound socket.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    log_non_default_args(args)

    tool_plugin = args.tool_parser_plugin
    if tool_plugin and len(tool_plugin) > 3:
        ToolParserManager.import_tool_parser(tool_plugin)

    reasoning_plugin = args.reasoning_parser_plugin
    if reasoning_plugin and len(reasoning_plugin) > 3:
        ReasoningParserManager.import_reasoning_parser(reasoning_plugin)

    validate_api_server_args(args)

    # Bind the port before the engine is set up, to avoid race conditions
    # with ray. See https://github.com/vllm-project/vllm/issues/8204
    uds_path = args.uds
    if uds_path:
        sock = create_server_unix_socket(uds_path)
    else:
        host = args.host or ""
        port = args.port
        sock = create_server_socket((host, port))

    # Workaround for uvicorn dropping requests when too many concurrent
    # requests are active.
    set_ulimit()

    def _interrupt_on_sigterm(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, _interrupt_on_sigterm)

    if uds_path:
        return f"unix:{uds_path}", sock

    scheme = "https" if (args.ssl_keyfile and args.ssl_certfile) else "http"
    if is_valid_ipv6_address(host):
        host_part = f"[{host}]"
    else:
        host_part = host or "0.0.0.0"
    return f"{scheme}://{host_part}:{port}", sock


async def run_server(args, **uvicorn_kwargs) -> None:
    """Run a single-worker API server in the current process.

    Tags this process's stdout/stderr with an ``APIServer`` prefix, prepares
    the listening socket with ``setup_server`` and delegates serving to
    ``run_server_worker``.
    """
    # Prefix this process's stdout/stderr output.
    decorate_logs("APIServer")

    address, server_sock = setup_server(args)
    await run_server_worker(address, server_sock, args, **uvicorn_kwargs)


1093
1094
1095
async def run_server_worker(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Run a single API server worker.

    Builds the engine client and the FastAPI app, serves HTTP on ``sock`` and
    waits for the server to shut down.

    Args:
        listen_address: Human-readable listen address; used only for logging.
        sock: Already-bound listening socket. It is closed on every exit path,
            including startup failures.
        args: Parsed API-server CLI arguments.
        client_config: Optional config forwarded to
            ``build_async_engine_client``.
        **uvicorn_kwargs: Extra keyword arguments forwarded to ``serve_http``.
    """
    try:
        # NOTE(review): plugins are also imported in setup_server; presumably
        # repeated here because this may run in a separate worker process.
        if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
            ToolParserManager.import_tool_parser(args.tool_parser_plugin)

        if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
            ReasoningParserManager.import_reasoning_parser(
                args.reasoning_parser_plugin
            )

        # Load logging config for uvicorn if specified
        log_config = load_log_config(args.log_config_file)
        if log_config is not None:
            uvicorn_kwargs["log_config"] = log_config

        async with build_async_engine_client(
            args,
            client_config=client_config,
        ) as engine_client:
            app = build_app(args)

            await init_app_state(engine_client, app.state, args)

            logger.info(
                "Starting vLLM API server %d on %s",
                engine_client.vllm_config.parallel_config._api_process_rank,
                listen_address,
            )
            shutdown_task = await serve_http(
                app,
                sock=sock,
                enable_ssl_refresh=args.enable_ssl_refresh,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                # NOTE: When the 'disable_uvicorn_access_log' value is True,
                # no access log will be output.
                access_log=not args.disable_uvicorn_access_log,
                timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs,
                h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
                h11_max_header_count=args.h11_max_header_count,
                **uvicorn_kwargs,
            )

        # NB: Await server shutdown only after the backend context is exited
        await shutdown_task
    finally:
        # Release the listening socket on success AND on any startup failure
        # (previously it leaked if anything before serve_http raised).
        sock.close()
1147

Ethan Xu's avatar
Ethan Xu committed
1148
1149
1150

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
    # entrypoints.
    cli_env_setup()

    # Build the CLI parser, parse and validate the arguments, then serve.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."
        )
    )
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))