api_server.py 41.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import hashlib
5
6
import importlib
import inspect
7
import json
8
import multiprocessing
9
import multiprocessing.forkserver as forkserver
10
import os
11
import secrets
12
import signal
13
import socket
14
import tempfile
15
import uuid
16
from argparse import Namespace
17
from collections.abc import AsyncIterator, Awaitable
18
from contextlib import asynccontextmanager
19
from http import HTTPStatus
20
from typing import Any
21

22
import model_hosting_container_standards.sagemaker as sagemaker_standards
23
import pydantic
24
import uvloop
25
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
Zhuohan Li's avatar
Zhuohan Li committed
26
27
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
28
from fastapi.responses import JSONResponse, StreamingResponse
29
from starlette.concurrency import iterate_in_threadpool
30
31
from starlette.datastructures import URL, Headers, MutableHeaders, State
from starlette.types import ASGIApp, Message, Receive, Scope, Send
Zhuohan Li's avatar
Zhuohan Li committed
32

33
import vllm.envs as envs
Woosuk Kwon's avatar
Woosuk Kwon committed
34
from vllm.engine.arg_utils import AsyncEngineArgs
35
from vllm.engine.protocol import EngineClient
36
37
38
39
40
41
42
from vllm.entrypoints.anthropic.protocol import (
    AnthropicError,
    AnthropicErrorResponse,
    AnthropicMessagesRequest,
    AnthropicMessagesResponse,
)
from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
43
from vllm.entrypoints.launcher import serve_http
44
from vllm.entrypoints.logger import RequestLogger
45
from vllm.entrypoints.openai.chat_completion.serving import OpenAIServingChat
46
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
47
from vllm.entrypoints.openai.engine.protocol import (
48
49
50
51
52
    CompletionRequest,
    CompletionResponse,
    ErrorInfo,
    ErrorResponse,
)
53
54
from vllm.entrypoints.openai.engine.serving import OpenAIServing
from vllm.entrypoints.openai.orca_metrics import metrics_header
55
from vllm.entrypoints.openai.responses.serving import OpenAIServingResponses
56
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
57
58
59
60
from vllm.entrypoints.openai.serving_models import (
    BaseModelPath,
    OpenAIServingModels,
)
61
from vllm.entrypoints.openai.translations.serving import (
62
63
64
    OpenAIServingTranscription,
    OpenAIServingTranslation,
)
65
66
67
68
69
from vllm.entrypoints.openai.utils import validate_json_request
from vllm.entrypoints.pooling.classify.serving import ServingClassification
from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
from vllm.entrypoints.pooling.pooling.serving import OpenAIServingPooling
from vllm.entrypoints.pooling.score.serving import ServingScores
70
71
72
73
74
from vllm.entrypoints.serve.disagg.serving import ServingTokens
from vllm.entrypoints.serve.elastic_ep.middleware import (
    ScalingMiddleware,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
75
76
77
78
79
from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer
from vllm.entrypoints.utils import (
    cli_env_setup,
    load_aware_call,
    log_non_default_args,
80
81
    process_chat_template,
    process_lora_modules,
82
    sanitize_message,
83
84
    with_cancellation,
)
85
from vllm.exceptions import VLLMValidationError
86
from vllm.logger import init_logger
87
from vllm.reasoning import ReasoningParserManager
88
from vllm.tasks import POOLING_TASKS
89
from vllm.tool_parsers import ToolParserManager
yhu422's avatar
yhu422 committed
90
from vllm.usage.usage_lib import UsageContext
Cyrus Leung's avatar
Cyrus Leung committed
91
from vllm.utils.argparse_utils import FlexibleArgumentParser
92
from vllm.utils.gc_utils import freeze_gc_heap
93
from vllm.utils.network_utils import is_valid_ipv6_address
94
from vllm.utils.system_utils import decorate_logs, set_ulimit
95
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
96

97
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger("vllm.entrypoints.openai.api_server")

# Declared (annotation only) at module level; no value is assigned here.
prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Request header read by create_completion(); its value is passed to
# metrics_header() to choose the format of the load-metrics response headers.
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format"

# Strong references to background tasks started in lifespan(); each task
# is removed from the set by a done-callback when it finishes.
_running_tasks: set[asyncio.Task] = set()
105

106

107
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook for the API server.

    On startup, if ``app.state.log_stats`` is truthy, starts a background
    task that calls ``engine_client.do_log_stats()`` every
    ``envs.VLLM_LOG_STATS_INTERVAL`` seconds, then freezes the GC heap.
    On shutdown, or if startup fails part-way, cancels that task and deletes
    ``app.state`` so the engine reference it holds can be collected.
    """
    task: asyncio.Task | None = None
    try:
        if app.state.log_stats:
            engine_client: EngineClient = app.state.engine_client

            async def _force_log():
                while True:
                    await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
                    await engine_client.do_log_stats()

            task = asyncio.create_task(_force_log())
            # Keep a strong reference while the task runs (see the
            # asyncio.create_task docs); drop it once the task finishes.
            _running_tasks.add(task)
            task.add_done_callback(_running_tasks.discard)

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()

        yield
    finally:
        # Cancel the stats task even when startup failed after creating it,
        # so it cannot keep running with no owner.
        if task is not None:
            task.cancel()
        # Ensure app state including engine ref is gc'd
        del app.state
135
136


137
@asynccontextmanager
async def build_async_engine_client(
    args: Namespace,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool | None = None,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Build an engine client from parsed CLI ``args`` and yield it.

    Optionally prepares a multiprocessing forkserver, converts ``args`` into
    ``AsyncEngineArgs``, applies the API-process count/rank from
    ``client_config``, and then delegates to
    ``build_async_engine_client_from_engine_args``. That context manager
    shuts the engine down on error or exit.

    Args:
        args: Parsed CLI namespace.
        usage_context: Usage context forwarded to engine creation.
        disable_frontend_multiprocessing: Overrides
            ``args.disable_frontend_multiprocessing`` when not ``None``.
        client_config: Optional mapping. Its ``client_count`` and
            ``client_index`` entries (defaults 1 and 0) set the API process
            count and rank on the engine args.
    """
    use_forkserver = os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver"
    if use_forkserver:
        # The executor is expected to be mp.
        # Pre-import heavy modules in the forkserver process
        logger.debug("Setup forkserver with pre-imports")
        multiprocessing.set_start_method("forkserver")
        multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
        forkserver.ensure_running()
        logger.debug("Forkserver setup complete!")

    engine_args = AsyncEngineArgs.from_cli_args(args)

    if client_config:
        engine_args._api_process_count = client_config.get("client_count", 1)
        engine_args._api_process_rank = client_config.get("client_index", 0)

    # An explicit keyword argument wins over the CLI flag.
    disable_mp = (
        bool(args.disable_frontend_multiprocessing)
        if disable_frontend_multiprocessing is None
        else disable_frontend_multiprocessing
    )

    async with build_async_engine_client_from_engine_args(
        engine_args,
        usage_context=usage_context,
        disable_frontend_multiprocessing=disable_mp,
        client_config=client_config,
    ) as engine:
        yield engine


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool = False,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Create an ``AsyncLLM`` engine client from ``engine_args`` and yield it.

    ``disable_frontend_multiprocessing`` only triggers a warning here; it
    does not change how the client is built.

    ``client_config`` is copied and never mutated. Its ``client_count`` and
    ``client_index`` entries (defaults 1 and 0) are popped from the copy, and
    the remaining entries are passed as ``client_addresses``.

    Exceptions raised while creating the engine propagate to the caller.
    Once created, the engine is shut down when the context exits, whether
    normally or via an exception.
    """
    # Create the EngineConfig (determines if we can use V1).
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    if disable_frontend_multiprocessing:
        logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.")

    from vllm.v1.engine.async_llm import AsyncLLM

    # Work on a copy so the caller's dict is left untouched.
    addresses: dict[str, Any] = dict(client_config) if client_config else {}
    client_count = addresses.pop("client_count", 1)
    client_index = addresses.pop("client_index", 0)

    async_llm: AsyncLLM | None = None
    try:
        async_llm = AsyncLLM.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            aggregate_engine_logging=engine_args.aggregate_engine_logging,
            disable_log_stats=engine_args.disable_log_stats,
            client_addresses=addresses,
            client_count=client_count,
            client_index=client_index,
        )
        assert async_llm is not None

        # Don't keep the dummy data in memory
        await async_llm.reset_mm_cache()

        yield async_llm
    finally:
        if async_llm:
            async_llm.shutdown()
224
225


Ethan Xu's avatar
Ethan Xu committed
226
# Module-level router for the endpoints declared in this file; build_app()
# passes it to register_sagemaker_routes() and includes it in the app.
router: APIRouter = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
227

228

229
230
231
232
233
def base(request: Request) -> OpenAIServing:
    """Return a generic ``OpenAIServing`` instance for ``request``'s app.

    Route handlers in this module use it to build error responses. The
    existing tokenization handler is reused rather than creating a new one.
    """
    shared_serving = tokenization(request)
    return shared_serving


234
235
236
237
def models(request: Request) -> OpenAIServingModels:
    """Return the app's ``OpenAIServingModels`` handler."""
    state = request.app.state
    return state.openai_serving_models


238
239
240
241
def messages(request: Request) -> AnthropicServingMessages | None:
    """Return the app's Anthropic Messages handler.

    The stored value may be ``None``, since ``create_messages`` checks for
    that and returns an error. The return annotation therefore includes
    ``None``, matching ``chat`` and ``completion``.
    """
    return request.app.state.anthropic_serving_messages


242
def chat(request: Request) -> OpenAIServingChat | None:
    """Return the app's chat-completions handler (may be ``None``)."""
    state = request.app.state
    return state.openai_serving_chat


246
def completion(request: Request) -> OpenAIServingCompletion | None:
    """Return the app's completions handler (may be ``None``)."""
    state = request.app.state
    return state.openai_serving_completion


250
251
def tokenization(request: Request) -> OpenAIServingTokenization:
    """Return the app's ``OpenAIServingTokenization`` handler."""
    app_state = request.app.state
    return app_state.openai_serving_tokenization
252
253


254
def engine_client(request: Request) -> EngineClient:
    """Return the ``EngineClient`` stored on the app state."""
    state = request.app.state
    return state.engine_client


258
259
260
261
262
263
264
@router.get("/load")
async def get_server_load_metrics(request: Request):
    """Report the current server load as ``{"server_load": <value>}``.

    The value is read from ``app.state.server_load_metrics``. It tracks
    requests utilizing the GPU from these routes:

    - /v1/chat/completions
    - /v1/completions
    - /v1/audio/transcriptions
    - /v1/audio/translations
    - /v1/embeddings
    - /pooling
    - /classify
    - /score
    - /v1/score
    - /rerank
    - /v1/rerank
    - /v2/rerank
    """
    load = request.app.state.server_load_metrics
    return JSONResponse(content={"server_load": load})
275
276


Ethan Xu's avatar
Ethan Xu committed
277
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    """List the served models using the app's ``OpenAIServingModels``."""
    available = await models(raw_request).show_available_models()
    return JSONResponse(content=available.model_dump())
Zhuohan Li's avatar
Zhuohan Li committed
283
284


Ethan Xu's avatar
Ethan Xu committed
285
@router.get("/version")
async def show_version():
    """Return the running vLLM version as ``{"version": <version>}``."""
    return JSONResponse(content={"version": VLLM_VERSION})


291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
@router.post(
    "/v1/messages",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
    """Anthropic-compatible Messages endpoint.

    Return values:

    - A JSON body for a complete ``AnthropicMessagesResponse``.
    - An ``ErrorResponse`` from the handler, re-wrapped in Anthropic error
      shape with its own status code.
    - Otherwise, a ``text/event-stream`` ``StreamingResponse``.

    Exceptions raised by the handler become a 500 Anthropic error body.
    """

    def translate_error_response(response: ErrorResponse) -> JSONResponse:
        # Re-wrap an OpenAI-style ErrorResponse in the Anthropic error shape,
        # keeping its HTTP status code.
        anthropic_error = AnthropicErrorResponse(
            error=AnthropicError(
                type=response.error.type,
                message=response.error.message,
            )
        )
        return JSONResponse(
            status_code=response.error.code, content=anthropic_error.model_dump()
        )

    handler = messages(raw_request)
    if handler is None:
        error = base(raw_request).create_error_response(
            message="The model does not support Messages API"
        )
        return translate_error_response(error)

    try:
        generator = await handler.create_messages(request, raw_request)
    except Exception as e:
        logger.exception("Error in create_messages: %s", e)
        # The full exception is logged above. The client-facing text goes
        # through sanitize_message, as in the HTTPException handler
        # registered by build_app().
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            content=AnthropicErrorResponse(
                error=AnthropicError(
                    type="internal_error",
                    message=sanitize_message(str(e)),
                )
            ).model_dump(),
        )

    if isinstance(generator, ErrorResponse):
        return translate_error_response(generator)

    if isinstance(generator, AnthropicMessagesResponse):
        resp = generator.model_dump(exclude_none=True)
        logger.debug("Anthropic Messages Response: %s", resp)
        return JSONResponse(content=resp)

    return StreamingResponse(content=generator, media_type="text/event-stream")


347
348
349
350
351
352
353
354
355
356
@router.post(
    "/v1/completions",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_completion(request: CompletionRequest, raw_request: Request):
    """OpenAI-compatible ``/v1/completions`` endpoint.

    Return values:

    - A JSON ``CompletionResponse``, with load-metrics headers built from
      the ``ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL`` request header.
    - A JSON ``ErrorResponse`` whose HTTP status equals ``error.code``.
    - Otherwise, a ``text/event-stream`` ``StreamingResponse``.

    Raises:
        HTTPException: 400 for ``OverflowError`` and 500 for any other
            exception raised while creating the completion.
    """
    metrics_header_format = raw_request.headers.get(
        ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, ""
    )

    handler = completion(raw_request)
    if handler is None:
        error = base(raw_request).create_error_response(
            message="The model does not support Completions API"
        )
        # Returning the ErrorResponse model directly would be serialized with
        # HTTP 200. Send it with its own status code, as the ErrorResponse
        # branch below does.
        return JSONResponse(
            content=error.model_dump(), status_code=error.error.code
        )

    try:
        generator = await handler.create_completion(request, raw_request)
    except OverflowError as e:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value, detail=str(e)
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    if isinstance(generator, CompletionResponse):
        return JSONResponse(
            content=generator.model_dump(),
            headers=metrics_header(metrics_header_format),
        )

    return StreamingResponse(content=generator, media_type="text/event-stream")

Zhuohan Li's avatar
Zhuohan Li committed
392

393
def load_log_config(log_config_file: str | None) -> dict | None:
394
395
396
397
398
399
    if not log_config_file:
        return None
    try:
        with open(log_config_file) as f:
            return json.load(f)
    except Exception as e:
400
401
402
        logger.warning(
            "Failed to load log config from file %s: error %s", log_config_file, e
        )
403
404
405
        return None


406
407
408
class AuthenticationMiddleware:
    """
    Pure ASGI middleware that authenticates each request by checking that
    the Authorization Bearer token matches any of the configured API keys.

    Only SHA-256 digests of the keys are kept. The presented token is hashed
    and compared against every stored digest with ``secrets.compare_digest``.

    Notes
    -----
    There are two cases in which authentication is skipped:
        1. The HTTP method is OPTIONS.
        2. The request path doesn't start with /v1 (e.g. /health).
    Scopes other than "http" and "websocket" (e.g. "lifespan") are passed
    through untouched.
    """

    def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
        self.app = app
        # Store only digests of the accepted tokens.
        self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]

    def verify_token(self, headers: Headers) -> bool:
        """Return True if ``headers`` carry ``Authorization: Bearer <key>``
        (scheme matched case-insensitively) for one of the configured keys."""
        authorization_header_value = headers.get("Authorization")
        if not authorization_header_value:
            return False

        scheme, _, param = authorization_header_value.partition(" ")
        if scheme.lower() != "bearer":
            return False

        # Both sides are SHA-256 digests, so compare_digest always compares
        # values of equal length.
        param_hash = hashlib.sha256(param.encode("utf-8")).digest()

        # Check every stored token (no early exit on a match).
        token_match = False
        for token_hash in self.api_tokens:
            token_match |= secrets.compare_digest(param_hash, token_hash)

        return token_match

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        if scope["type"] not in ("http", "websocket"):
            # scope["type"] can be "lifespan" or "startup" for example,
            # in which case we don't need to do anything
            return self.app(scope, receive, send)
        # ASGI websocket scopes have no "method" key, so don't index it
        # directly (that raised KeyError for every websocket connection).
        if scope.get("method") == "OPTIONS":
            return self.app(scope, receive, send)
        root_path = scope.get("root_path", "")
        url_path = URL(scope=scope).path.removeprefix(root_path)
        headers = Headers(scope=scope)
        if url_path.startswith("/v1") and not self.verify_token(headers):
            response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
            return response(scope, receive, send)
        return self.app(scope, receive, send)


class XRequestIdMiddleware:
    """
    Middleware that sets the X-Request-Id header on each HTTP response. It
    uses the request's own X-Request-Id value when present, otherwise a
    random uuid4 (hex) value.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        if scope["type"] not in ("http", "websocket"):
            return self.app(scope, receive, send)

        # Extract the request headers.
        request_headers = Headers(scope=scope)

        async def send_with_request_id(message: Message) -> None:
            """
            Custom send function to mutate the response headers
            and append X-Request-Id to it.
            """
            if message["type"] == "http.response.start":
                response_headers = MutableHeaders(raw=message["headers"])
                request_id = request_headers.get("X-Request-Id")
                if request_id is None:
                    # Only generate an id when the client did not send one.
                    request_id = uuid.uuid4().hex
                response_headers.append("X-Request-Id", request_id)
            await send(message)

        return self.app(scope, receive, send_with_request_id)


485
486
487
def _extract_content_from_chunk(chunk_data: dict) -> str:
    """Extract content from a streaming response chunk."""
    try:
488
        from vllm.entrypoints.openai.chat_completion.protocol import (
489
            ChatCompletionStreamResponse,
490
491
        )
        from vllm.entrypoints.openai.engine.protocol import (
492
493
            CompletionStreamResponse,
        )
494
495

        # Try using Completion types for type-safe parsing
496
497
        if chunk_data.get("object") == "chat.completion.chunk":
            chat_response = ChatCompletionStreamResponse.model_validate(chunk_data)
498
499
            if chat_response.choices and chat_response.choices[0].delta.content:
                return chat_response.choices[0].delta.content
500
501
502
        elif chunk_data.get("object") == "text_completion":
            completion_response = CompletionStreamResponse.model_validate(chunk_data)
            if completion_response.choices and completion_response.choices[0].text:
503
504
505
                return completion_response.choices[0].text
    except pydantic.ValidationError:
        # Fallback to manual parsing
506
507
508
509
510
511
        if "choices" in chunk_data and chunk_data["choices"]:
            choice = chunk_data["choices"][0]
            if "delta" in choice and choice["delta"].get("content"):
                return choice["delta"]["content"]
            elif choice.get("text"):
                return choice["text"]
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
    return ""


class SSEDecoder:
    """Robust Server-Sent Events decoder for streaming responses."""

    def __init__(self):
        self.buffer = ""
        self.content_buffer = []

    def decode_chunk(self, chunk: bytes) -> list[dict]:
        """Decode a chunk of SSE data and return parsed events."""
        import json

        try:
527
            chunk_str = chunk.decode("utf-8")
528
529
530
531
532
533
534
535
        except UnicodeDecodeError:
            # Skip malformed chunks
            return []

        self.buffer += chunk_str
        events = []

        # Process complete lines
536
537
538
        while "\n" in self.buffer:
            line, self.buffer = self.buffer.split("\n", 1)
            line = line.rstrip("\r")  # Handle CRLF
539

540
            if line.startswith("data: "):
541
                data_str = line[6:].strip()
542
543
                if data_str == "[DONE]":
                    events.append({"type": "done"})
544
545
546
                elif data_str:
                    try:
                        event_data = json.loads(data_str)
547
                        events.append({"type": "data", "data": event_data})
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
                    except json.JSONDecodeError:
                        # Skip malformed JSON
                        continue

        return events

    def extract_content(self, event_data: dict) -> str:
        """Extract content from event data."""
        return _extract_content_from_chunk(event_data)

    def add_content(self, content: str) -> None:
        """Add content to the buffer."""
        if content:
            self.content_buffer.append(content)

    def get_complete_content(self) -> str:
        """Get the complete buffered content."""
565
        return "".join(self.content_buffer)
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585


def _log_streaming_response(response, response_body: list) -> None:
    """Wrap a buffered SSE response so its content is logged as it streams.

    ``response.body_iterator`` is replaced with an iterator that re-yields
    every chunk of ``response_body`` unchanged while decoding SSE events.
    When the ``[DONE]`` event is seen, the accumulated text is logged once,
    together with the number of chunks seen so far. Text longer than 2048
    characters is truncated and marked ``...[truncated]``.
    """
    from starlette.concurrency import iterate_in_threadpool

    max_logged_chars = 2048
    sse_decoder = SSEDecoder()
    chunk_count = 0

    def buffered_iterator():
        nonlocal chunk_count
        logged = False

        for chunk in response_body:
            chunk_count += 1
            yield chunk

            if logged:
                # Content was already logged; keep forwarding the remaining
                # chunks rather than cutting the response short.
                continue

            # Parse SSE events from chunk
            for event in sse_decoder.decode_chunk(chunk):
                if event["type"] == "data":
                    content = sse_decoder.extract_content(event["data"])
                    sse_decoder.add_content(content)
                elif event["type"] == "done":
                    # Log complete content when done
                    full_content = sse_decoder.get_complete_content()
                    if full_content:
                        # Truncate if too long
                        if len(full_content) > max_logged_chars:
                            full_content = (
                                full_content[:max_logged_chars] + "...[truncated]"
                            )
                        logger.info(
                            "response_body={streaming_complete: content=%r, chunks=%d}",
                            full_content,
                            chunk_count,
                        )
                    else:
                        logger.info(
                            "response_body={streaming_complete: no_content, chunks=%d}",
                            chunk_count,
                        )
                    logged = True
                    break

    response.body_iterator = iterate_in_threadpool(buffered_iterator())
    logger.info("response_body={streaming_started: chunks=%d}", len(response_body))
611
612
613
614
615
616
617
618
619
620
621


def _log_non_streaming_response(response_body: list) -> None:
    """Log the full body of a non-streaming response.

    All buffered body sections are joined before decoding, so a body
    delivered in more than one section is logged completely rather than
    only its first section. Bodies that are not valid UTF-8 are logged as
    ``<binary_data>``.
    """
    # Sections are expected to be bytes; str sections are encoded defensively.
    raw = b"".join(
        section.encode("utf-8") if isinstance(section, str) else section
        for section in response_body
    )
    try:
        decoded_body = raw.decode()
    except UnicodeDecodeError:
        logger.info("response_body={<binary_data>}")
        return
    logger.info("response_body={%s}", decoded_body)


622
def build_app(args: Namespace) -> FastAPI:
    """Construct the FastAPI application for the OpenAI-compatible server.

    Registration steps, in order:

    - API routers (including this module's ``router``).
    - CORS middleware and the HTTPException / RequestValidationError
      handlers.
    - Optional authentication, request-id and response-logging middleware.
    - The scaling middleware.
    - Any user-supplied middleware listed in ``args.middleware``.

    Finally the app is wrapped with ``sagemaker_standards.bootstrap``.

    Keep the middleware registration order when editing. Starlette wraps
    later-added middleware around earlier ones, so the order matters.

    Raises:
        ValueError: If an entry of ``args.middleware`` resolves to something
            that is neither a class nor a coroutine function.
    """
    if args.disable_fastapi_docs:
        app = FastAPI(
            openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan
        )
    elif args.enable_offline_docs:
        app = FastAPI(docs_url=None, redoc_url=None, lifespan=lifespan)
    else:
        app = FastAPI(lifespan=lifespan)
    app.state.args = args

    from vllm.entrypoints.serve import register_vllm_serve_api_routers

    register_vllm_serve_api_routers(app)

    from vllm.entrypoints.openai.chat_completion.api_router import (
        attach_router as register_chat_api_router,
    )

    register_chat_api_router(app)

    from vllm.entrypoints.openai.responses.api_router import (
        attach_router as register_responses_api_router,
    )

    register_responses_api_router(app)

    from vllm.entrypoints.openai.translations.api_router import (
        attach_router as register_translations_api_router,
    )

    register_translations_api_router(app)

    from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes

    register_sagemaker_routes(router)
    app.include_router(router)

    app.root_path = args.root_path

    from vllm.entrypoints.pooling import register_pooling_api_routers

    register_pooling_api_routers(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(HTTPException)
    async def http_exception_handler(_: Request, exc: HTTPException):
        err = ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(exc.detail),
                type=HTTPStatus(exc.status_code).phrase,
                code=exc.status_code,
            )
        )
        return JSONResponse(err.model_dump(), status_code=exc.status_code)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_: Request, exc: RequestValidationError):
        errors = exc.errors()

        # Report the offending parameter when a VLLMValidationError was
        # raised during validation.
        param = None
        for error in errors:
            if "ctx" in error and "error" in error["ctx"]:
                ctx_error = error["ctx"]["error"]
                if isinstance(ctx_error, VLLMValidationError):
                    param = ctx_error.parameter
                    break

        exc_str = str(exc)
        errors_str = str(errors)

        if errors and errors_str and errors_str != exc_str:
            message = f"{exc_str} {errors_str}"
        else:
            message = exc_str

        err = ErrorResponse(
            error=ErrorInfo(
                message=sanitize_message(message),
                type=HTTPStatus.BAD_REQUEST.phrase,
                code=HTTPStatus.BAD_REQUEST,
                param=param,
            )
        )
        return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)

    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
    if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]:
        app.add_middleware(AuthenticationMiddleware, tokens=tokens)

    if args.enable_request_id_headers:
        app.add_middleware(XRequestIdMiddleware)

    # Add scaling middleware to check for scaling state
    app.add_middleware(ScalingMiddleware)

    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
        logger.warning(
            "CAUTION: Enabling log response in the API Server. "
            "This can include sensitive information and should be "
            "avoided in production."
        )

        @app.middleware("http")
        async def log_response(request: Request, call_next):
            response = await call_next(request)
            # Buffer the whole body so it can be logged, then replay it.
            response_body = [section async for section in response.body_iterator]
            response.body_iterator = iterate_in_threadpool(iter(response_body))
            # Detect streaming responses by media type only, ignoring
            # parameters such as charset and their spelling/case.
            content_type = response.headers.get("content-type", "")
            media_type = content_type.split(";", 1)[0].strip().lower()
            is_streaming = media_type == "text/event-stream"

            # Log response body based on type
            if not response_body:
                logger.info("response_body={<empty>}")
            elif is_streaming:
                _log_streaming_response(response, response_body)
            else:
                _log_non_streaming_response(response_body)
            return response

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)  # type: ignore[arg-type]
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(
                f"Invalid middleware {middleware}. Must be a function or a class."
            )

    app = sagemaker_standards.bootstrap(app)

    return app


761
async def init_app_state(
    engine_client: EngineClient,
    state: State,
    args: Namespace,
) -> None:
    """Populate ``state`` with the engine handle and per-endpoint handlers.

    Stores ``engine_client``, ``vllm_config``, ``args`` and ``log_stats`` on
    ``state``, then builds one serving object per API family. A handler is only
    constructed when the engine reports the task it needs (via
    ``engine_client.get_supported_tasks()``); otherwise the attribute is set to
    ``None``. The tokenization handler is always constructed.

    Args:
        engine_client: Running engine client whose ``vllm_config`` and
            supported tasks drive which handlers are created.
        state: Application state object that receives the handlers.
        args: Parsed CLI arguments for the API server.
    """
    vllm_config = engine_client.vllm_config

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.enable_log_requests:
        request_logger = RequestLogger(max_log_len=args.max_log_len)
    else:
        request_logger = None

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model) for name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.vllm_config = vllm_config
    state.args = args
    supported_tasks = await engine_client.get_supported_tasks()
    logger.info("Supported tasks: %s", supported_tasks)
    # Several handlers below are gated on the same "generate" capability.
    supports_generate = "generate" in supported_tasks

    resolved_chat_template = await process_chat_template(
        args.chat_template, engine_client, vllm_config.model_config
    )

    if args.tool_server == "demo":
        tool_server: ToolServer | None = DemoToolServer()
        assert isinstance(tool_server, DemoToolServer)
        await tool_server.init_and_validate()
    elif args.tool_server:
        tool_server = MCPToolServer()
        await tool_server.add_tool_server(args.tool_server)
    else:
        tool_server = None

    # Merge default_mm_loras into the static lora_modules (computed once;
    # an empty mapping is used when LoRA is not configured).
    default_mm_loras = (
        vllm_config.lora_config.default_mm_loras
        if vllm_config.lora_config is not None
        else {}
    )
    lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)

    state.openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=base_model_paths,
        lora_modules=lora_modules,
    )
    await state.openai_serving_models.init_static_loras()

    state.openai_serving_responses = (
        OpenAIServingResponses(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            tool_server=tool_server,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            log_error_stack=args.log_error_stack,
        )
        if supports_generate
        else None
    )
    state.openai_serving_chat = (
        OpenAIServingChat(
            engine_client,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            default_chat_template_kwargs=args.default_chat_template_kwargs,
            trust_request_chat_template=args.trust_request_chat_template,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            enable_log_deltas=args.enable_log_deltas,
            log_error_stack=args.log_error_stack,
        )
        if supports_generate
        else None
    )

    # Warm up chat template processing to avoid first-request latency
    if state.openai_serving_chat is not None:
        await state.openai_serving_chat.warmup()

    state.openai_serving_completion = (
        OpenAIServingCompletion(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            log_error_stack=args.log_error_stack,
        )
        if supports_generate
        else None
    )
    state.openai_serving_pooling = (
        OpenAIServingPooling(
            engine_client,
            state.openai_serving_models,
            supported_tasks=supported_tasks,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if any(task in POOLING_TASKS for task in supported_tasks)
        else None
    )
    state.openai_serving_embedding = (
        OpenAIServingEmbedding(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if "embed" in supported_tasks
        else None
    )
    state.openai_serving_classification = (
        ServingClassification(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if "classify" in supported_tasks
        else None
    )
    state.openai_serving_scores = (
        ServingScores(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            score_template=resolved_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if ("embed" in supported_tasks or "score" in supported_tasks)
        else None
    )
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        state.openai_serving_models,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
        trust_request_chat_template=args.trust_request_chat_template,
        log_error_stack=args.log_error_stack,
    )
    state.openai_serving_transcription = (
        OpenAIServingTranscription(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )
    # Translation is gated on the same "transcription" task as transcription.
    state.openai_serving_translation = (
        OpenAIServingTranslation(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )
    state.anthropic_serving_messages = (
        AnthropicServingMessages(
            engine_client,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if supports_generate
        else None
    )
    state.serving_tokens = (
        ServingTokens(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            log_error_stack=args.log_error_stack,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_log_outputs=args.enable_log_outputs,
            force_no_detokenize=args.tokens_only,
        )
        if supports_generate
        else None
    )

    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0

1003

1004
def create_server_socket(addr: tuple[str, int]) -> socket.socket:
    """Create a TCP stream socket bound to ``addr`` (``listen`` is not called).

    The address family is ``AF_INET6`` when the host part is a valid IPv6
    address and ``AF_INET`` otherwise. ``SO_REUSEADDR`` is always set.
    ``SO_REUSEPORT`` is set only when the platform's ``socket`` module
    defines it.

    Args:
        addr: ``(host, port)`` pair; ``host`` may be ``""``.

    Returns:
        The bound socket. The caller owns it and is responsible for closing it.

    Raises:
        OSError: if configuring or binding the socket fails. The socket is
            closed before the error propagates, so no descriptor is leaked.
    """
    family = socket.AF_INET6 if is_valid_ipv6_address(addr[0]) else socket.AF_INET

    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # SO_REUSEPORT is not defined on every platform; skip it rather than
        # failing with AttributeError where it is unavailable.
        if hasattr(socket, "SO_REUSEPORT"):
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(addr)
    except BaseException:
        # e.g. "address already in use": release the descriptor, then re-raise.
        sock.close()
        raise

    return sock


1017
1018
1019
1020
1021
1022
def create_server_unix_socket(path: str) -> socket.socket:
    """Create a UNIX-domain stream socket bound to the filesystem ``path``.

    An existing file at ``path`` is not removed first, so binding is expected
    to fail if one is already present.

    Args:
        path: Filesystem path for the socket.

    Returns:
        The bound socket. The caller owns it and is responsible for closing it.

    Raises:
        OSError: if binding fails. The socket is closed before the error
            propagates, so no descriptor is leaked.
    """
    sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
    try:
        sock.bind(path)
    except BaseException:
        sock.close()
        raise
    return sock


1023
def validate_api_server_args(args):
    """Check that the configured tool-call and reasoning parsers are registered.

    The tool-call parser is only checked when ``args.enable_auto_tool_choice``
    is set. The reasoning parser
    (``args.structured_outputs_config.reasoning_parser``) is only checked when
    it is truthy.

    Raises:
        KeyError: if a checked parser name is not in the corresponding
            registry. The message lists the registered names.
    """
    valid_tool_parsers = ToolParserManager.list_registered()
    if (
        args.enable_auto_tool_choice
        and args.tool_call_parser not in valid_tool_parsers
    ):
        raise KeyError(
            f"invalid tool call parser: {args.tool_call_parser} "
            f"(choose from {{ {','.join(valid_tool_parsers)} }})"
        )

    valid_reasoning_parsers = ReasoningParserManager.list_registered()
    reasoning_parser = args.structured_outputs_config.reasoning_parser
    if reasoning_parser and reasoning_parser not in valid_reasoning_parsers:
        raise KeyError(
            f"invalid reasoning parser: {reasoning_parser} "
            f"(choose from {{ {','.join(valid_reasoning_parsers)} }})"
        )
1039

1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050

def setup_server(args):
    """Do the pre-engine startup work and return where the server will listen.

    Logs the version and non-default args, imports any tool/reasoning parser
    plugins, validates the args and binds the socket: a UNIX socket when
    ``args.uds`` is set, TCP on ``(args.host or "", args.port)`` otherwise.
    It then calls ``set_ulimit()`` and installs a SIGTERM handler that raises
    ``KeyboardInterrupt``.

    Returns:
        ``(listen_address, sock)``. ``listen_address`` is ``"unix:<path>"``
        or ``"http[s]://<host>:<port>"``.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    log_non_default_args(args)

    tool_plugin = args.tool_parser_plugin
    if tool_plugin and len(tool_plugin) > 3:
        ToolParserManager.import_tool_parser(tool_plugin)

    reasoning_plugin = args.reasoning_parser_plugin
    if reasoning_plugin and len(reasoning_plugin) > 3:
        ReasoningParserManager.import_reasoning_parser(reasoning_plugin)

    validate_api_server_args(args)

    # Bind the port before the engine is set up, to avoid race conditions
    # with ray. See https://github.com/vllm-project/vllm/issues/8204
    if args.uds:
        sock = create_server_unix_socket(args.uds)
        listen_address = f"unix:{args.uds}"
    else:
        host, port = args.host or "", args.port
        sock = create_server_socket((host, port))
        scheme = "https" if (args.ssl_keyfile and args.ssl_certfile) else "http"
        if is_valid_ipv6_address(host):
            host_part = f"[{host}]"
        else:
            host_part = host or "0.0.0.0"
        listen_address = f"{scheme}://{host_part}:{port}"

    # Workaround for uvicorn dropping requests when too many concurrent
    # requests are active.
    set_ulimit()

    def _interrupt_on_sigterm(*_) -> None:
        # A SIGTERM received while still initializing aborts startup.
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, _interrupt_on_sigterm)

    return listen_address, sock


async def run_server(args, **uvicorn_kwargs) -> None:
    """Run a single-worker API server: set it up, then serve with one worker.

    ``uvicorn_kwargs`` are forwarded unchanged to ``run_server_worker``.
    """
    # Tag this process's stdout/stderr output with an "APIServer" prefix.
    decorate_logs("APIServer")

    # setup_server returns (listen_address, sock), which are exactly the
    # first two positional parameters of run_server_worker.
    await run_server_worker(*setup_server(args), args, **uvicorn_kwargs)


1095
1096
1097
async def run_server_worker(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Run a single API server worker on an already-bound socket.

    Builds the engine client and the FastAPI app, initializes the app state,
    starts serving on ``sock`` and waits for the server to shut down.

    Args:
        listen_address: Address string, used only in the startup log line.
        sock: Bound socket to serve on. It is closed on every exit path,
            including failures before the server starts.
        args: Parsed CLI arguments.
        client_config: Forwarded to ``build_async_engine_client``.
        **uvicorn_kwargs: Extra keyword arguments forwarded to ``serve_http``.
            ``log_config`` is set here when ``args.log_config_file`` yields
            a config.
    """
    try:
        # Plugins are imported here as well as in setup_server; presumably a
        # worker can run in a different process than setup_server — confirm.
        if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
            ToolParserManager.import_tool_parser(args.tool_parser_plugin)

        if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
            ReasoningParserManager.import_reasoning_parser(
                args.reasoning_parser_plugin
            )

        # Load logging config for uvicorn if specified
        log_config = load_log_config(args.log_config_file)
        if log_config is not None:
            uvicorn_kwargs["log_config"] = log_config

        async with build_async_engine_client(
            args,
            client_config=client_config,
        ) as engine_client:
            app = build_app(args)

            await init_app_state(engine_client, app.state, args)

            logger.info(
                "Starting vLLM API server %d on %s",
                engine_client.vllm_config.parallel_config._api_process_rank,
                listen_address,
            )
            shutdown_task = await serve_http(
                app,
                sock=sock,
                enable_ssl_refresh=args.enable_ssl_refresh,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                # NOTE: When the 'disable_uvicorn_access_log' value is True,
                # no access log will be output.
                access_log=not args.disable_uvicorn_access_log,
                timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs,
                h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
                h11_max_header_count=args.h11_max_header_count,
                **uvicorn_kwargs,
            )

        # NB: Await server shutdown only after the backend context is exited
        await shutdown_task
    finally:
        # Release the listening socket whether startup failed or the server
        # shut down normally (previously only the latter closed it).
        sock.close()
1149

Ethan Xu's avatar
Ethan Xu committed
1150
1151
1152

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
    # entrypoints.
    cli_env_setup()

    # Start from a bare parser, then let make_arg_parser register the
    # serve options on it.
    base_parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server."
    )
    parser = make_arg_parser(base_parser)

    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))