api_server.py 75.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import asyncio
4
import hashlib
5
6
import importlib
import inspect
7
import json
8
import logging
9
import multiprocessing
10
import multiprocessing.forkserver as forkserver
11
import os
12
import secrets
13
import signal
14
import socket
15
import tempfile
16
import uuid
17
from argparse import Namespace
18
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
19
from contextlib import asynccontextmanager
20
from http import HTTPStatus
21
from typing import Annotated, Any, Literal
22

23
import model_hosting_container_standards.sagemaker as sagemaker_standards
24
import prometheus_client
25
import pydantic
26
import regex as re
27
import uvloop
28
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Query, Request
Zhuohan Li's avatar
Zhuohan Li committed
29
30
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
31
from fastapi.responses import JSONResponse, Response, StreamingResponse
32
33
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
34
from starlette.concurrency import iterate_in_threadpool
35
from starlette.datastructures import URL, Headers, MutableHeaders, State
36
from starlette.routing import Mount
37
from starlette.types import ASGIApp, Message, Receive, Scope, Send
38
from typing_extensions import assert_never
Zhuohan Li's avatar
Zhuohan Li committed
39

40
import vllm.envs as envs
41
from vllm.config import VllmConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
42
from vllm.engine.arg_utils import AsyncEngineArgs
43
from vllm.engine.protocol import EngineClient
44
45
46
47
48
49
50
from vllm.entrypoints.anthropic.protocol import (
    AnthropicError,
    AnthropicErrorResponse,
    AnthropicMessagesRequest,
    AnthropicMessagesResponse,
)
from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
51
from vllm.entrypoints.launcher import serve_http
52
from vllm.entrypoints.logger import RequestLogger
53
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
54
from vllm.entrypoints.openai.orca_metrics import metrics_header
55
56
57
58
59
60
61
62
63
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ClassificationRequest,
    ClassificationResponse,
    CompletionRequest,
    CompletionResponse,
    DetokenizeRequest,
    DetokenizeResponse,
64
    EmbeddingBytesResponse,
65
66
67
68
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorInfo,
    ErrorResponse,
69
70
    GenerateRequest,
    GenerateResponse,
71
    IOProcessorResponse,
72
    PoolingBytesResponse,
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
    PoolingRequest,
    PoolingResponse,
    RerankRequest,
    RerankResponse,
    ResponsesRequest,
    ResponsesResponse,
    ScoreRequest,
    ScoreResponse,
    StreamingResponsesResponse,
    TokenizeRequest,
    TokenizeResponse,
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
    TranslationResponse,
)
89
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
90
from vllm.entrypoints.openai.serving_classification import ServingClassification
91
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
92
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
93
from vllm.entrypoints.openai.serving_engine import OpenAIServing
94
95
96
97
from vllm.entrypoints.openai.serving_models import (
    BaseModelPath,
    OpenAIServingModels,
)
98
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
99
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
100
from vllm.entrypoints.openai.serving_score import ServingScores
101
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
102
from vllm.entrypoints.openai.serving_tokens import ServingTokens
103
from vllm.entrypoints.openai.serving_transcription import (
104
105
106
    OpenAIServingTranscription,
    OpenAIServingTranslation,
)
107
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
108
109
110
111
112
from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer
from vllm.entrypoints.utils import (
    cli_env_setup,
    load_aware_call,
    log_non_default_args,
113
114
    process_chat_template,
    process_lora_modules,
115
116
    with_cancellation,
)
117
from vllm.logger import init_logger
118
from vllm.reasoning import ReasoningParserManager
119
from vllm.tasks import POOLING_TASKS
yhu422's avatar
yhu422 committed
120
from vllm.usage.usage_lib import UsageContext
Cyrus Leung's avatar
Cyrus Leung committed
121
from vllm.utils.argparse_utils import FlexibleArgumentParser
122
from vllm.utils.gc_utils import freeze_gc_heap
123
from vllm.utils.network_utils import is_valid_ipv6_address
124
from vllm.utils.system_utils import decorate_logs, set_ulimit
125
from vllm.v1.engine.exceptions import EngineDeadError
126
from vllm.v1.metrics.prometheus import get_prometheus_registry
127
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
128

129
# Annotation only: no value is bound to this name here. Presumably it is
# assigned where multiprocess Prometheus collection is set up.
# NOTE(review): confirm the assignment site.
prometheus_multiproc_dir: tempfile.TemporaryDirectory

# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
logger = init_logger("vllm.entrypoints.openai.api_server")

# Name of the request header that the chat/completion endpoints read and pass
# to ``metrics_header`` to choose how endpoint load metrics are reported.
ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format"

# Holds references to background asyncio tasks (e.g. the periodic stats logger
# started in ``lifespan``); each task removes itself from the set when done.
_running_tasks: set[asyncio.Task] = set()
137

138

139
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    On startup:
    - If ``app.state.log_stats`` is truthy, spawns a background task that
      calls ``do_log_stats()`` on the engine client every
      ``envs.VLLM_LOG_STATS_INTERVAL`` seconds.
    - Freezes the GC heap.

    On shutdown:
    - Cancels the stats task, if one was started.
    - Deletes ``app.state`` so the engine reference it holds can be collected.
    """
    try:
        log_task: asyncio.Task | None = None

        if app.state.log_stats:
            client: EngineClient = app.state.engine_client

            async def _force_log():
                while True:
                    await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
                    await client.do_log_stats()

            log_task = asyncio.create_task(_force_log())
            # Keep a reference so the task is not dropped while running.
            _running_tasks.add(log_task)
            log_task.add_done_callback(_running_tasks.remove)

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        freeze_gc_heap()

        try:
            yield
        finally:
            if log_task is not None:
                log_task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state
167
168


169
@asynccontextmanager
async def build_async_engine_client(
    args: Namespace,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool | None = None,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Build an engine client from parsed CLI ``args`` and manage its lifecycle.

    When the environment variable ``VLLM_WORKER_MULTIPROC_METHOD`` equals
    ``"forkserver"``, this function first sets up the forkserver:
    - sets the multiprocessing start method,
    - preloads ``vllm.v1.engine.async_llm``,
    - starts the forkserver.

    Args:
        args: Parsed CLI namespace, converted via
            ``AsyncEngineArgs.from_cli_args``.
        usage_context: Forwarded to the engine builder.
        disable_frontend_multiprocessing: When ``None``, taken from
            ``args.disable_frontend_multiprocessing``.
        client_config: Optional. When non-empty, its ``client_count`` /
            ``client_index`` entries (defaults 1 / 0) set the engine args'
            API process count / rank. It is also passed through unchanged.

    Yields:
        The client from ``build_async_engine_client_from_engine_args``, which
        handles shutdown and cleanup when the context exits.
    """
    if os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver":
        # The executor is expected to be mp.
        # Pre-import heavy modules in the forkserver process.
        logger.debug("Setup forkserver with pre-imports")
        multiprocessing.set_start_method("forkserver")
        multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
        forkserver.ensure_running()
        logger.debug("Forkserver setup complete!")

    engine_args = AsyncEngineArgs.from_cli_args(args)
    if client_config:
        engine_args._api_process_count = client_config.get("client_count", 1)
        engine_args._api_process_rank = client_config.get("client_index", 0)

    disable_mp = (
        bool(args.disable_frontend_multiprocessing)
        if disable_frontend_multiprocessing is None
        else disable_frontend_multiprocessing
    )

    async with build_async_engine_client_from_engine_args(
        engine_args,
        usage_context=usage_context,
        disable_frontend_multiprocessing=disable_mp,
        client_config=client_config,
    ) as client:
        yield client


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool = False,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Create an in-process ``AsyncLLM`` engine client and manage its lifetime.

    The engine config is built from ``engine_args`` and passed to
    ``AsyncLLM.from_vllm_config``. The resulting client is yielded. It is shut
    down when the context exits, whether normally or through an exception.

    Args:
        engine_args: Engine arguments used to create the ``VllmConfig``.
        usage_context: Forwarded to config creation and to the engine.
        disable_frontend_multiprocessing: Not honoured on this code path.
            When set, only a warning is logged.
        client_config: Optional mapping, which is not modified.
            - ``client_count`` (default 1) and ``client_index`` (default 0)
              are forwarded as separate arguments.
            - All remaining entries are passed as ``client_addresses``.

    Yields:
        The constructed ``AsyncLLM`` instance.

    Raises:
        Any exception raised while building the config or the engine
        propagates to the caller; nothing is yielded in that case.
    """
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    if disable_frontend_multiprocessing:
        logger.warning("V1 is enabled, but got --disable-frontend-multiprocessing.")

    from vllm.v1.engine.async_llm import AsyncLLM

    # Copy first so the pops below don't mutate the caller's client_config.
    client_config = dict(client_config) if client_config else {}
    client_count = client_config.pop("client_count", 1)
    client_index = client_config.pop("client_index", 0)

    async_llm: AsyncLLM | None = None
    try:
        async_llm = AsyncLLM.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            aggregate_engine_logging=engine_args.aggregate_engine_logging,
            disable_log_stats=engine_args.disable_log_stats,
            client_addresses=client_config,
            client_count=client_count,
            client_index=client_index,
        )

        # Don't keep the dummy data in memory
        await async_llm.reset_mm_cache()

        yield async_llm
    finally:
        # Only shut down an engine that was actually constructed.
        if async_llm is not None:
            async_llm.shutdown()
256
257


258
259
async def validate_json_request(raw_request: Request):
    """Reject requests whose Content-Type media type is not ``application/json``.

    How the media type is determined:
    - It is the part of the header before any ``;`` parameters.
    - It is compared case-insensitively, with surrounding whitespace ignored.
      HTTP allows optional whitespace before ``;``, as in
      ``application/json ; charset=utf-8``.
    - A missing header yields an empty media type, which is rejected.

    Raises:
        RequestValidationError: If the media type is not ``application/json``.
    """
    content_type = raw_request.headers.get("content-type", "").lower()
    media_type = content_type.split(";", maxsplit=1)[0].strip()
    if media_type != "application/json":
        raise RequestValidationError(
            errors=["Unsupported Media Type: Only 'application/json' is allowed"]
        )
265
266


Ethan Xu's avatar
Ethan Xu committed
267
# Module-level router: the endpoint functions in this module attach
# themselves to it through ``@router.get`` / ``@router.post``.
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
268

269

270
271
272
273
class PrometheusResponse(Response):
    """Response class whose media type is Prometheus' text exposition format.

    ``mount_metrics`` passes it as ``response_class`` so the exposed metrics
    are not labelled with the default ``application/json`` content type.
    """

    media_type = prometheus_client.CONTENT_TYPE_LATEST


274
def mount_metrics(app: FastAPI):
    """Mount prometheus metrics to a FastAPI app.

    Two things are attached to ``app``:
    - HTTP request instrumentation.
    - A ``/metrics`` ASGI mount backed by vLLM's Prometheus registry.
    """
    prom_registry = get_prometheus_registry()

    # Operational endpoints that should not themselves show up in the
    # HTTP request metrics.
    untracked_paths = [
        "/metrics",
        "/health",
        "/load",
        "/ping",
        "/version",
        "/server_info",
    ]
    instrumentator = Instrumentator(
        excluded_handlers=untracked_paths,
        registry=prom_registry,
    )
    # `response_class=PrometheusResponse` is needed to return an HTTP response
    # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
    # instead of the default "application/json" which is incorrect.
    # See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
    instrumentator.add().instrument(app).expose(
        app, response_class=PrometheusResponse
    )

    # Route /metrics requests to the prometheus ASGI app.
    metrics_mount = Mount("/metrics", make_asgi_app(registry=prom_registry))
    # Workaround for 307 Redirect for /metrics
    metrics_mount.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_mount)
301
302


303
304
305
306
307
def base(request: Request) -> OpenAIServing:
    """Return a generic ``OpenAIServing`` object.

    In this module it is used to build error responses. No dedicated
    instance exists; the tokenization handler is reused.
    """
    shared = tokenization(request)
    return shared


308
309
310
311
def models(request: Request) -> OpenAIServingModels:
    """Fetch the model-listing handler from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_models


312
def responses(request: Request) -> OpenAIServingResponses | None:
    """Fetch the Responses API handler; ``None`` when it is not available."""
    app_state = request.app.state
    return app_state.openai_serving_responses


316
317
318
319
def messages(request: Request) -> AnthropicServingMessages | None:
    """Fetch the Anthropic Messages handler from the application state.

    The value may be ``None``. ``create_messages`` treats that case as "the
    model does not support the Messages API", so the annotation includes
    ``None`` like the other optional handler accessors do.
    """
    return request.app.state.anthropic_serving_messages


320
def chat(request: Request) -> OpenAIServingChat | None:
    """Fetch the Chat Completions handler; ``None`` when it is not available."""
    app_state = request.app.state
    return app_state.openai_serving_chat


324
def completion(request: Request) -> OpenAIServingCompletion | None:
    """Fetch the Completions handler; ``None`` when it is not available."""
    app_state = request.app.state
    return app_state.openai_serving_completion


328
def pooling(request: Request) -> OpenAIServingPooling | None:
    """Fetch the Pooling handler; ``None`` when it is not available."""
    app_state = request.app.state
    return app_state.openai_serving_pooling


332
def embedding(request: Request) -> OpenAIServingEmbedding | None:
    """Fetch the Embeddings handler; ``None`` when it is not available."""
    app_state = request.app.state
    return app_state.openai_serving_embedding
334
335


336
def score(request: Request) -> ServingScores | None:
    """Fetch the scoring handler; ``None`` when it is not available."""
    app_state = request.app.state
    return app_state.openai_serving_scores


340
def classify(request: Request) -> ServingClassification | None:
    """Fetch the classification handler; ``None`` when it is not available."""
    app_state = request.app.state
    return app_state.openai_serving_classification


344
def rerank(request: Request) -> ServingScores | None:
    """Fetch the rerank handler; it is the same object as the scoring handler."""
    app_state = request.app.state
    return app_state.openai_serving_scores
346
347


348
349
def tokenization(request: Request) -> OpenAIServingTokenization:
    """Fetch the tokenize/detokenize handler from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_tokenization
350
351


352
353
354
355
def transcription(request: Request) -> OpenAIServingTranscription:
    """Fetch the audio transcription handler from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_transcription


356
357
358
359
def translation(request: Request) -> OpenAIServingTranslation:
    """Fetch the audio translation handler from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_translation


360
def engine_client(request: Request) -> EngineClient:
    """Fetch the engine client stored on the application state."""
    app_state = request.app.state
    return app_state.engine_client


364
365
366
367
def generate_tokens(request: Request) -> ServingTokens | None:
    """Fetch the token-generation handler; ``None`` when it is not available."""
    app_state = request.app.state
    return app_state.serving_tokens


368
369
@router.get("/health", response_class=Response)
async def health(raw_request: Request) -> Response:
    """Health check.

    Responds 200 when the engine health check passes and 503 when the engine
    raises ``EngineDeadError``. Any other exception propagates.
    """
    try:
        await engine_client(raw_request).check_health()
    except EngineDeadError:
        return Response(status_code=503)
    else:
        return Response(status_code=200)
376
377


378
379
380
381
382
383
384
@router.get("/load")
async def get_server_load_metrics(request: Request):
    # Reports ``app.state.server_load_metrics``, the count of requests
    # utilizing the GPU that came in through these routes:
    #   /v1/chat/completions, /v1/completions,
    #   /v1/audio/transcriptions, /v1/audio/translations,
    #   /v1/embeddings, /pooling, /classify,
    #   /score, /v1/score, /rerank, /v1/rerank, /v2/rerank
    current_load = request.app.state.server_load_metrics
    return JSONResponse(content={"server_load": current_load})
395
396


397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
@router.post("/pause")
async def pause_generation(
    raw_request: Request,
    wait_for_inflight_requests: bool = Query(False),
    clear_cache: bool = Query(True),
) -> JSONResponse:
    """Pause generation requests to allow weight updates.

    Args:
        wait_for_inflight_requests: If ``True``, wait for in-flight requests
            to finish before pausing. If ``False`` (default), abort in-flight
            requests immediately.
        clear_cache: Whether to clear KV/prefix caches after draining.

    Responses:
        400 if the engine raises ``ValueError``; 500 for any other exception.
    """
    client = engine_client(raw_request)

    try:
        await client.pause_generation(
            wait_for_inflight_requests=wait_for_inflight_requests,
            clear_cache=clear_cache,
        )
        ok_response = JSONResponse(
            content={"status": "paused"},
            status_code=HTTPStatus.OK.value,
        )
    except ValueError as exc:
        return JSONResponse(
            content={"error": str(exc)},
            status_code=HTTPStatus.BAD_REQUEST.value,
        )
    except Exception as exc:  # pragma: no cover - defensive
        logger.exception("Failed to pause generation")
        return JSONResponse(
            content={"error": f"Failed to pause generation: {exc}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )
    return ok_response


@router.post("/resume")
async def resume_generation(raw_request: Request) -> JSONResponse:
    """Resume generation after a pause (500 if the engine call fails)."""
    client = engine_client(raw_request)

    try:
        await client.resume_generation()
        ok_response = JSONResponse(
            content={"status": "resumed"},
            status_code=HTTPStatus.OK.value,
        )
    except Exception as exc:  # pragma: no cover - defensive
        logger.exception("Failed to resume generation")
        return JSONResponse(
            content={"error": f"Failed to resume generation: {exc}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )
    return ok_response


@router.get("/is_paused")
async def is_paused(raw_request: Request) -> JSONResponse:
    """Return the current pause status (500 if it cannot be fetched)."""
    client = engine_client(raw_request)

    try:
        paused_now = await client.is_paused()
    except Exception as exc:  # pragma: no cover - defensive
        logger.exception("Failed to fetch pause status")
        return JSONResponse(
            content={"error": f"Failed to fetch pause status: {exc}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )
    else:
        return JSONResponse(content={"is_paused": paused_now})


475
476
477
478
479
480
481
482
483
484
@router.post(
    "/tokenize",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
        HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
    # NotImplementedError from the handler maps to 501; any other exception
    # maps to 500.
    tokenizer_handler = tokenization(raw_request)

    try:
        result = await tokenizer_handler.create_tokenize(request, raw_request)
    except NotImplementedError as exc:
        raise HTTPException(
            status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(exc)
        ) from exc
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(), status_code=result.error.code)
    if isinstance(result, TokenizeResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)

509

510
511
512
513
514
515
516
517
518
@router.post(
    "/detokenize",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
    # OverflowError from the handler is surfaced as a request validation
    # error; any other exception maps to 500.
    tokenizer_handler = tokenization(raw_request)

    try:
        result = await tokenizer_handler.create_detokenize(request, raw_request)
    except OverflowError as exc:
        raise RequestValidationError(errors=[str(exc)]) from exc
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(content=result.model_dump(), status_code=result.error.code)
    if isinstance(result, DetokenizeResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)

541

542
543
def maybe_register_tokenizer_info_endpoint(args):
    """Register ``GET /tokenizer_info`` on ``router`` if enabled.

    Registration happens only when ``args.enable_tokenizer_info_endpoint`` is
    truthy. A missing attribute counts as disabled.
    """
    if not getattr(args, "enable_tokenizer_info_endpoint", False):
        return

    @router.get("/tokenizer_info")
    async def get_tokenizer_info(raw_request: Request):
        """Get comprehensive tokenizer information."""
        info = await tokenization(raw_request).get_tokenizer_info()
        status = info.error.code if isinstance(info, ErrorResponse) else 200
        return JSONResponse(content=info.model_dump(), status_code=status)
556
557


Ethan Xu's avatar
Ethan Xu committed
558
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    # Serialize whatever the model-listing handler reports.
    listing = await models(raw_request).show_available_models()
    return JSONResponse(content=listing.model_dump())
Zhuohan Li's avatar
Zhuohan Li committed
564
565


Ethan Xu's avatar
Ethan Xu committed
566
@router.get("/version")
async def show_version():
    # Report the running vLLM package version.
    return JSONResponse(content={"version": VLLM_VERSION})


572
async def _convert_stream_to_sse_events(
    generator: AsyncGenerator[StreamingResponsesResponse, None],
) -> AsyncGenerator[str, None]:
    """Re-emit each streamed event as one Server-Sent Events frame.

    Each frame has an ``event:`` line taken from the event's ``type``
    attribute (``"unknown"`` when absent) and a ``data:`` line with the event
    as compact JSON, terminated by a blank line.
    """
    # https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
    async for item in generator:
        kind = getattr(item, "type", "unknown")
        payload = item.model_dump_json(indent=None)
        yield f"event: {kind}\ndata: {payload}\n\n"


585
586
587
588
589
590
591
592
593
594
@router.post(
    "/v1/responses",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def create_responses(request: ResponsesRequest, raw_request: Request):
    # Possible results:
    # - JSON error body with the error's own status code;
    # - JSON ResponsesResponse;
    # - otherwise an SSE stream of the handler's events.
    # Exceptions raised by the handler are mapped to HTTP 500.
    handler = responses(raw_request)
    if handler is None:
        error = base(raw_request).create_error_response(
            message="The model does not support Responses API"
        )
        # Returning the bare ErrorResponse model would make FastAPI serialize
        # it with the route's default 200 status; send the error's own code.
        return JSONResponse(content=error.model_dump(), status_code=error.error.code)

    try:
        generator = await handler.create_responses(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, ResponsesResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(
        content=_convert_stream_to_sse_events(generator), media_type="text/event-stream"
    )
619
620
621


@router.get("/v1/responses/{response_id}")
async def retrieve_responses(
    response_id: str,
    raw_request: Request,
    starting_after: int | None = None,
    stream: bool | None = False,
):
    # Possible results:
    # - JSON error body with the error's own status code;
    # - JSON ResponsesResponse;
    # - otherwise an SSE stream of events.
    # ``starting_after`` and ``stream`` are forwarded to the handler.
    handler = responses(raw_request)
    if handler is None:
        error = base(raw_request).create_error_response(
            message="The model does not support Responses API"
        )
        # Returning the bare ErrorResponse model would make FastAPI serialize
        # it with the route's default 200 status; send the error's own code.
        return JSONResponse(content=error.model_dump(), status_code=error.error.code)

    try:
        response = await handler.retrieve_responses(
            response_id,
            starting_after=starting_after,
            stream=stream,
        )
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(response, ErrorResponse):
        return JSONResponse(
            content=response.model_dump(), status_code=response.error.code
        )
    elif isinstance(response, ResponsesResponse):
        return JSONResponse(content=response.model_dump())

    return StreamingResponse(
        content=_convert_stream_to_sse_events(response), media_type="text/event-stream"
    )
654
655
656
657
658
659
660


@router.post("/v1/responses/{response_id}/cancel")
async def cancel_responses(response_id: str, raw_request: Request):
    # Cancels a stored response. Errors are returned as JSON with the error's
    # own status code; handler exceptions are mapped to HTTP 500.
    handler = responses(raw_request)
    if handler is None:
        error = base(raw_request).create_error_response(
            message="The model does not support Responses API"
        )
        # Returning the bare ErrorResponse model would make FastAPI serialize
        # it with the route's default 200 status; send the error's own code.
        return JSONResponse(content=error.model_dump(), status_code=error.error.code)

    try:
        response = await handler.cancel_responses(response_id)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(response, ErrorResponse):
        return JSONResponse(
            content=response.model_dump(), status_code=response.error.code
        )

    return JSONResponse(content=response.model_dump())


678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
@router.post(
    "/v1/messages",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
    # Anthropic-compatible Messages endpoint. Every error is reshaped into
    # Anthropic's error envelope:
    # - handler exceptions become HTTP 500 with type "internal_error";
    # - ErrorResponse results keep their own status code.

    def to_anthropic_error(err: ErrorResponse) -> JSONResponse:
        # Carry type/message over and keep the original HTTP status code.
        envelope = AnthropicErrorResponse(
            error=AnthropicError(type=err.error.type, message=err.error.message)
        )
        return JSONResponse(status_code=err.error.code, content=envelope.model_dump())

    handler = messages(raw_request)
    if handler is None:
        unsupported = base(raw_request).create_error_response(
            message="The model does not support Messages API"
        )
        return to_anthropic_error(unsupported)

    try:
        generator = await handler.create_messages(request, raw_request)
    except Exception as e:
        logger.exception("Error in create_messages: %s", e)
        internal = AnthropicErrorResponse(
            error=AnthropicError(type="internal_error", message=str(e))
        )
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            content=internal.model_dump(),
        )

    if isinstance(generator, ErrorResponse):
        return to_anthropic_error(generator)
    if isinstance(generator, AnthropicMessagesResponse):
        resp = generator.model_dump(exclude_none=True)
        logger.debug("Anthropic Messages Response: %s", resp)
        return JSONResponse(content=resp)

    return StreamingResponse(content=generator, media_type="text/event-stream")


734
735
736
737
738
739
740
741
742
743
@router.post(
    "/v1/chat/completions",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
    # Possible results:
    # - JSON error body with the error's own status code;
    # - JSON ChatCompletionResponse, with load-metrics headers in the format
    #   requested via ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL;
    # - otherwise an SSE stream.
    # Handler exceptions are mapped to HTTP 500.
    metrics_header_format = raw_request.headers.get(
        ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, ""
    )

    handler = chat(raw_request)
    if handler is None:
        error = base(raw_request).create_error_response(
            message="The model does not support Chat Completions API"
        )
        # Returning the bare ErrorResponse model would make FastAPI serialize
        # it with the route's default 200 status; send the error's own code.
        return JSONResponse(content=error.model_dump(), status_code=error.error.code)

    try:
        generator = await handler.create_chat_completion(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, ChatCompletionResponse):
        return JSONResponse(
            content=generator.model_dump(),
            headers=metrics_header(metrics_header_format),
        )

    return StreamingResponse(content=generator, media_type="text/event-stream")

774

775
776
777
778
779
780
781
782
783
784
@router.post(
    "/v1/completions",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_completion(request: CompletionRequest, raw_request: Request):
    # Possible results:
    # - JSON error body with the error's own status code;
    # - JSON CompletionResponse, with load-metrics headers in the requested
    #   format;
    # - otherwise an SSE stream.
    # Handler exceptions: OverflowError maps to HTTP 400, anything else to 500.
    metrics_header_format = raw_request.headers.get(
        ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, ""
    )

    handler = completion(raw_request)
    if handler is None:
        error = base(raw_request).create_error_response(
            message="The model does not support Completions API"
        )
        # Returning the bare ErrorResponse model would make FastAPI serialize
        # it with the route's default 200 status; send the error's own code.
        return JSONResponse(content=error.model_dump(), status_code=error.error.code)

    try:
        generator = await handler.create_completion(request, raw_request)
    except OverflowError as e:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value, detail=str(e)
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, CompletionResponse):
        return JSONResponse(
            content=generator.model_dump(),
            headers=metrics_header(metrics_header_format),
        )

    return StreamingResponse(content=generator, media_type="text/event-stream")

Zhuohan Li's avatar
Zhuohan Li committed
820

821
822
823
824
825
826
827
828
@router.post(
    "/v1/embeddings",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_embedding(
    request: EmbeddingRequest,
    raw_request: Request,
):
    # Possible results:
    # - JSON error body with the error's own status code;
    # - JSON EmbeddingResponse;
    # - for EmbeddingBytesResponse, a StreamingResponse over its body with a
    #   "metadata" header.
    # Handler exceptions are mapped to HTTP 500.
    handler = embedding(raw_request)
    if handler is None:
        error = base(raw_request).create_error_response(
            message="The model does not support Embeddings API"
        )
        # Returning the bare ErrorResponse model would make FastAPI serialize
        # it with the route's default 200 status; send the error's own code.
        return JSONResponse(content=error.model_dump(), status_code=error.error.code)

    try:
        generator = await handler.create_embedding(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, EmbeddingResponse):
        return JSONResponse(content=generator.model_dump())
    elif isinstance(generator, EmbeddingBytesResponse):
        return StreamingResponse(
            content=generator.body,
            headers={"metadata": generator.metadata},
            media_type=generator.media_type,
        )

    assert_never(generator)

863

864
865
866
867
868
869
870
871
@router.post(
    "/pooling",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
872
@with_cancellation
873
@load_aware_call
874
875
876
877
async def create_pooling(request: PoolingRequest, raw_request: Request):
    handler = pooling(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
878
879
            message="The model does not support Pooling API"
        )
880
881
882
    try:
        generator = await handler.create_pooling(request, raw_request)
    except Exception as e:
883
884
885
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e
886
    if isinstance(generator, ErrorResponse):
887
888
889
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
890
    elif isinstance(generator, (PoolingResponse, IOProcessorResponse)):
891
        return JSONResponse(content=generator.model_dump())
892
893
894
895
896
897
    elif isinstance(generator, PoolingBytesResponse):
        return StreamingResponse(
            content=generator.body,
            headers={"metadata": generator.metadata},
            media_type=generator.media_type,
        )
898
899
900
901

    assert_never(generator)


902
903
904
@router.post(
    "/classify",
    dependencies=[Depends(validate_json_request)],
    # Declare the same error models as the sibling JSON endpoints so the
    # OpenAPI schema documents the 400/500 ErrorResponse bodies.
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_classify(request: ClassificationRequest, raw_request: Request):
    """Serve the ``/classify`` endpoint.

    Delegates to the classification handler. Its ``ErrorResponse`` is
    returned as JSON with the error's own status code, and its
    ``ClassificationResponse`` is returned as plain JSON.

    Raises:
        HTTPException: 500 if the handler raises.
    """
    handler = classify(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Classification API"
        )

    try:
        generator = await handler.create_classify(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, ClassificationResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


929
930
931
932
933
934
935
936
@router.post(
    "/score",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_score(request: ScoreRequest, raw_request: Request):
    """Serve the Score API at ``/score``.

    The legacy ``/v1/score`` route also delegates here. The score
    handler's ``ErrorResponse`` is sent as JSON with the error's status
    code, and its ``ScoreResponse`` is sent as plain JSON. Handler
    exceptions become HTTP 500.
    """
    score_handler = score(raw_request)
    if score_handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Score API"
        )

    try:
        scored = await score_handler.create_score(request, raw_request)
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(scored, ErrorResponse):
        return JSONResponse(content=scored.model_dump(), status_code=scored.error.code)
    if isinstance(scored, ScoreResponse):
        return JSONResponse(content=scored.model_dump())
    assert_never(scored)


962
963
964
965
966
967
968
969
@router.post(
    "/v1/score",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_score_v1(request: ScoreRequest, raw_request: Request):
    """Legacy ``/v1/score`` route: log a migration hint, then delegate.

    The hint is logged through ``warning_once``, as ``/v1/rerank`` does, so
    a client that keeps calling this path does not produce a warning on
    every request.
    """
    logger.warning_once(
        "To indicate that Score API is not part of standard OpenAI API, we "
        "have moved it to `/score`. Please update your client accordingly."
    )
    # NOTE(review): both this function and create_score carry
    # @load_aware_call (do_rerank_v1 omits it for its delegating call);
    # confirm the load accounting is not applied twice per request.
    return await create_score(request, raw_request)


981
982
983
984
985
986
987
988
989
@router.post(
    "/v1/audio/transcriptions",
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_transcriptions(
    raw_request: Request, request: Annotated[TranscriptionRequest, Form()]
):
    """Serve ``/v1/audio/transcriptions`` (multipart form upload).

    The uploaded file is read in full and passed to the transcription
    handler together with the parsed form.

    * ``ErrorResponse`` -> JSON with the error's status code.
    * ``TranscriptionResponse`` -> plain JSON.
    * Any other result is streamed as ``text/event-stream``.

    Handler exceptions become HTTP 500.
    """
    transcription_handler = transcription(raw_request)
    if transcription_handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support Transcriptions API"
        )

    audio_bytes = await request.file.read()
    try:
        outcome = await transcription_handler.create_transcription(
            audio_bytes, request, raw_request
        )
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(outcome, ErrorResponse):
        return JSONResponse(
            content=outcome.model_dump(), status_code=outcome.error.code
        )
    if isinstance(outcome, TranscriptionResponse):
        return JSONResponse(content=outcome.model_dump())
    # Non-final results are treated as an SSE stream.
    return StreamingResponse(content=outcome, media_type="text/event-stream")


1020
1021
1022
1023
1024
1025
1026
1027
1028
@router.post(
    "/v1/audio/translations",
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
1029
1030
@with_cancellation
@load_aware_call
1031
1032
1033
async def create_translations(
    request: Annotated[TranslationRequest, Form()], raw_request: Request
):
1034
1035
1036
    handler = translation(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
1037
1038
            message="The model does not support Translations API"
        )
1039
1040

    audio_data = await request.file.read()
1041
    try:
1042
        generator = await handler.create_translation(audio_data, request, raw_request)
1043
    except Exception as e:
1044
1045
1046
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e
1047
1048

    if isinstance(generator, ErrorResponse):
1049
1050
1051
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
1052
1053
1054
1055
1056
1057
1058

    elif isinstance(generator, TranslationResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")


1059
1060
1061
1062
1063
1064
1065
1066
@router.post(
    "/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
1067
@with_cancellation
1068
@load_aware_call
1069
1070
1071
1072
async def do_rerank(request: RerankRequest, raw_request: Request):
    handler = rerank(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
1073
1074
            message="The model does not support Rerank (Score) API"
        )
1075
1076
1077
    try:
        generator = await handler.do_rerank(request, raw_request)
    except Exception as e:
1078
1079
1080
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e
1081
    if isinstance(generator, ErrorResponse):
1082
1083
1084
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
1085
1086
1087
1088
1089
1090
    elif isinstance(generator, RerankResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


1091
1092
1093
1094
1095
1096
1097
1098
@router.post(
    "/v1/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
    """Legacy ``/v1/rerank`` route (JinaAI-style rerank API).

    Logs a one-time hint pointing clients to ``/rerank``, then delegates
    to ``do_rerank``.
    """
    relocation_hint = (
        "To indicate that the rerank API is not part of the standard OpenAI"
        " API, we have located it at `/rerank`. Please update your client "
        "accordingly. (Note: Conforms to JinaAI rerank API)"
    )
    logger.warning_once(relocation_hint)
    return await do_rerank(request, raw_request)


1110
1111
1112
1113
1114
1115
1116
1117
@router.post(
    "/v2/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
    """Serve ``/v2/rerank`` by delegating unchanged to ``do_rerank``."""
    reranked = await do_rerank(request, raw_request)
    return reranked


1123
if envs.VLLM_SERVER_DEV_MODE:
    logger.warning(
        "SECURITY WARNING: Development endpoints are enabled! "
        "This should NOT be used in production!"
    )

    # Built once at import time; /server_info uses it to dump the config as
    # JSON.
    PydanticVllmConfig = pydantic.TypeAdapter(VllmConfig)

    @router.get("/server_info")
    async def show_server_info(
        raw_request: Request,
        config_format: Annotated[Literal["text", "json"], Query()] = "text",
    ):
        """Return the app's ``VllmConfig`` as a string or as JSON.

        ``config_format="text"`` uses ``str(vllm_config)``. ``"json"`` uses
        the pydantic TypeAdapter.
        """
        vllm_config: VllmConfig = raw_request.app.state.vllm_config
        server_info = {
            "vllm_config": str(vllm_config)
            if config_format == "text"
            else PydanticVllmConfig.dump_python(vllm_config, mode="json", fallback=str)
            # fallback=str is needed to handle e.g. torch.dtype
        }
        return JSONResponse(content=server_info)

    @router.post("/reset_prefix_cache")
    async def reset_prefix_cache(raw_request: Request):
        """
        Reset the prefix cache. Note that we currently do not check if the
        prefix cache is successfully reset in the API server.
        """
        logger.info("Resetting prefix cache...")
        await engine_client(raw_request).reset_prefix_cache()
        return Response(status_code=200)

    @router.post("/reset_mm_cache")
    async def reset_mm_cache(raw_request: Request):
        """
        Reset the multi-modal cache. Note that we currently do not check if the
        multi-modal cache is successfully reset in the API server.
        """
        logger.info("Resetting multi-modal cache...")
        await engine_client(raw_request).reset_mm_cache()
        return Response(status_code=200)

    @router.post("/sleep")
    async def sleep(raw_request: Request):
        """Put the engine to sleep.

        The sleep level comes from the ``level`` query parameter (default
        ``"1"``) and must parse as an integer.

        Raises:
            HTTPException: 400 if ``level`` is not an integer. Previously an
                unparseable value escaped as ``ValueError``, which produced
                a 500.
        """
        level = raw_request.query_params.get("level", "1")
        try:
            sleep_level = int(level)
        except ValueError as e:
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail=f"Invalid sleep level: {level!r}",
            ) from e
        await engine_client(raw_request).sleep(sleep_level)
        # FIXME: in v0 with frontend multiprocessing, the sleep command
        # is sent but does not finish yet when we return a response.
        return Response(status_code=200)

    @router.post("/wake_up")
    async def wake_up(raw_request: Request):
        """Wake the engine, optionally only for the ``tags`` query values."""
        tags = raw_request.query_params.getlist("tags")
        if tags == []:
            # set to None to wake up all tags if no tags are provided
            tags = None
        logger.info("wake up the engine with tags: %s", tags)
        await engine_client(raw_request).wake_up(tags)
        # FIXME: in v0 with frontend multiprocessing, the wake-up command
        # is sent but does not finish yet when we return a response.
        return Response(status_code=200)

    @router.get("/is_sleeping")
    async def is_sleeping(raw_request: Request):
        """Report whether the engine is currently sleeping."""
        logger.info("check whether the engine is sleeping")
        is_sleeping = await engine_client(raw_request).is_sleeping()
        return JSONResponse(content={"is_sleeping": is_sleeping})

    @router.post("/collective_rpc")
    async def collective_rpc(raw_request: Request):
        """Invoke ``method`` through the engine client's ``collective_rpc``.

        The body must be a JSON object with a required ``method`` and
        optional ``args`` (list), ``kwargs`` (dict) and ``timeout``.
        Each result is returned unchanged if it is ``None``, a dict or a
        list, and passed through ``str()`` otherwise.

        Raises:
            HTTPException: 400 if the body is not valid JSON, is not a JSON
                object, or lacks ``method``.
        """
        try:
            body = await raw_request.json()
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail=f"JSON decode error: {e}",
            ) from e
        # A JSON array/scalar has no .get(); reject it as a client error
        # instead of letting an AttributeError turn into a 500.
        if not isinstance(body, dict):
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail="Request body must be a JSON object",
            )
        method = body.get("method")
        if method is None:
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail="Missing 'method' in request body",
            )
        # For security reason, only serialized string args/kwargs are passed.
        # User-defined `method` is responsible for deserialization if needed.
        args: list[str] = body.get("args", [])
        kwargs: dict[str, str] = body.get("kwargs", {})
        timeout: float | None = body.get("timeout")
        results = await engine_client(raw_request).collective_rpc(
            method=method, timeout=timeout, args=tuple(args), kwargs=kwargs
        )
        if results is None:
            return Response(status_code=200)
        response: list[Any] = []
        for result in results:
            if result is None or isinstance(result, (dict, list)):
                response.append(result)
            else:
                response.append(str(result))
        return JSONResponse(content={"results": response})

1225

1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
@router.post(
    "/scale_elastic_ep",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"model": dict},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
async def scale_elastic_ep(raw_request: Request):
    """Rescale the number of data-parallel engines.

    Expects a JSON object with ``new_data_parallel_size`` (required,
    positive int) and ``drain_timeout`` (optional, positive int, default
    120). While the engine client performs the rescale, the module-level
    ``_scaling_elastic_ep`` flag is True. ScalingMiddleware checks that flag
    and answers other HTTP requests with 503. The flag is always reset
    afterwards.

    Raises:
        HTTPException: 400 for a malformed body or invalid parameters, 408
            on ``TimeoutError`` from the client, 500 for any other failure.
    """
    global _scaling_elastic_ep

    try:
        body = await raw_request.json()
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail="Invalid JSON format") from e
    # A JSON array/scalar has no .get(); treat it as a client error.
    if not isinstance(body, dict):
        raise HTTPException(status_code=400, detail="Invalid JSON format")

    new_data_parallel_size = body.get("new_data_parallel_size")
    drain_timeout = body.get("drain_timeout", 120)  # Default 2 minutes

    if new_data_parallel_size is None:
        raise HTTPException(
            status_code=400, detail="new_data_parallel_size is required"
        )

    # bool subclasses int, so JSON `true` would otherwise be accepted as 1.
    if (
        isinstance(new_data_parallel_size, bool)
        or not isinstance(new_data_parallel_size, int)
        or new_data_parallel_size <= 0
    ):
        raise HTTPException(
            status_code=400, detail="new_data_parallel_size must be a positive integer"
        )

    if (
        isinstance(drain_timeout, bool)
        or not isinstance(drain_timeout, int)
        or drain_timeout <= 0
    ):
        raise HTTPException(
            status_code=400, detail="drain_timeout must be a positive integer"
        )

    # Resolve the client BEFORE raising the flag. If this lookup failed after
    # the flag was set, nothing would reset it, and every later request would
    # be rejected with 503.
    client = engine_client(raw_request)

    # Set scaling flag to prevent new requests
    _scaling_elastic_ep = True
    try:
        await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
        return JSONResponse(
            {
                "message": f"Scaled to {new_data_parallel_size} data parallel engines",
            }
        )
    except TimeoutError as e:
        raise HTTPException(
            status_code=408,
            detail="Scale failed due to request drain timeout "
            f"after {drain_timeout} seconds",
        ) from e
    except Exception as e:
        logger.error("Scale failed: %s", e)
        raise HTTPException(status_code=500, detail="Scale failed") from e
    finally:
        _scaling_elastic_ep = False


@router.post("/is_scaling_elastic_ep")
async def is_scaling_elastic_ep(raw_request: Request):
    """Report whether an elastic-EP rescale is currently in progress."""
    status = {"is_scaling_elastic_ep": _scaling_elastic_ep}
    return JSONResponse(status)


1289
1290
1291
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
RequestType = Any
# Looks up the serving handler for a request; the endpoints above treat a
# None result as "API not supported by this model".
GetHandlerFn = Callable[[Request], OpenAIServing | None]
# An endpoint coroutine called as endpoint(parsed_request, raw_request).
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]

# NOTE: Items defined earlier take higher priority
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [
    (ChatCompletionRequest, (chat, create_chat_completion)),
    (CompletionRequest, (completion, create_completion)),
    (EmbeddingRequest, (embedding, create_embedding)),
    (ClassificationRequest, (classify, create_classify)),
    (ScoreRequest, (score, create_score)),
    (RerankRequest, (rerank, do_rerank)),
    (PoolingRequest, (pooling, create_pooling)),
]

# NOTE: Construct the TypeAdapters only once
# Each entry pairs a TypeAdapter for the request model with its
# (handler lookup, endpoint) pair, in the same priority order as above.
INVOCATION_VALIDATORS = [
    (pydantic.TypeAdapter(req_model), (lookup_fn, endpoint_fn))
    for req_model, (lookup_fn, endpoint_fn) in INVOCATION_TYPES
]


1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
@router.post(
    "/inference/v1/generate",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def generate(request: GenerateRequest, raw_request: Request):
    handler = generate_tokens(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support generate tokens API"
        )
    try:
        generator = await handler.serve_tokens(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e
    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )

    elif isinstance(generator, GenerateResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")


1348
if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
    # When both are configured, only the torch-profiler warning is emitted.
    if envs.VLLM_TORCH_PROFILER_DIR:
        logger.warning_once(
            "Torch Profiler is enabled in the API server. This should ONLY be "
            "used for local development!"
        )
    else:
        logger.warning_once(
            "CUDA Profiler is enabled in the API server. This should ONLY be "
            "used for local development!"
        )

    @router.post("/start_profile")
    async def start_profile(raw_request: Request):
        """Ask the engine client to start profiling; replies with empty 200."""
        logger.info("Starting profiler...")
        await engine_client(raw_request).start_profile()
        logger.info("Profiler started.")
        return Response(status_code=200)

    @router.post("/stop_profile")
    async def stop_profile(raw_request: Request):
        """Ask the engine client to stop profiling; replies with empty 200."""
        logger.info("Stopping profiler...")
        await engine_client(raw_request).stop_profile()
        logger.info("Profiler stopped.")
        return Response(status_code=200)


1375
def load_log_config(log_config_file: str | None) -> dict | None:
1376
1377
1378
1379
1380
1381
    if not log_config_file:
        return None
    try:
        with open(log_config_file) as f:
            return json.load(f)
    except Exception as e:
1382
1383
1384
        logger.warning(
            "Failed to load log config from file %s: error %s", log_config_file, e
        )
1385
1386
1387
        return None


1388
1389
1390
class AuthenticationMiddleware:
    """
    Pure ASGI middleware that authenticates each request by checking
    if the Authorization Bearer token exists and equals any of "{api_key}".

    Notes
    -----
    There are two cases in which authentication is skipped:
        1. The HTTP method is OPTIONS.
        2. The request path doesn't start with /v1 (e.g. /health).
    Scopes other than "http" and "websocket" (e.g. "lifespan") are passed
    to the wrapped app untouched.
    """

    def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
        self.app = app
        # Keep only SHA-256 digests so every comparison in verify_token is
        # between equal-length values.
        self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]

    def verify_token(self, headers: Headers) -> bool:
        """Return True if ``headers`` carry ``Authorization: Bearer <token>``
        whose token matches one of the configured tokens.

        The scheme comparison is case-insensitive.
        """
        authorization_header_value = headers.get("Authorization")
        if not authorization_header_value:
            return False

        scheme, _, param = authorization_header_value.partition(" ")
        if scheme.lower() != "bearer":
            return False

        param_hash = hashlib.sha256(param.encode("utf-8")).digest()

        # Check every configured token without short-circuiting, using a
        # constant-time comparison, so timing does not reveal which matched.
        token_match = False
        for token_hash in self.api_tokens:
            token_match |= secrets.compare_digest(param_hash, token_hash)

        return token_match

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        # scope["type"] can be "lifespan" or "startup" for example, in which
        # case we don't need to do anything. WebSocket scopes carry no
        # "method" key in ASGI, so read it with .get() rather than indexing,
        # which would raise KeyError.
        is_http_like = scope["type"] in ("http", "websocket")
        if not is_http_like or scope.get("method") == "OPTIONS":
            return self.app(scope, receive, send)
        root_path = scope.get("root_path", "")
        url_path = URL(scope=scope).path.removeprefix(root_path)
        headers = Headers(scope=scope)
        if url_path.startswith("/v1") and not self.verify_token(headers):
            response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
            return response(scope, receive, send)
        return self.app(scope, receive, send)


class XRequestIdMiddleware:
    """
    Middleware that sets the X-Request-Id header on each response. It echoes
    the request's X-Request-Id header when present; otherwise it uses a
    random uuid4 (hex) value.

    Only "http" and "websocket" scopes are handled; other scope types are
    passed through unchanged.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        if scope["type"] not in ("http", "websocket"):
            return self.app(scope, receive, send)

        # Extract the request headers.
        request_headers = Headers(scope=scope)

        async def send_with_request_id(message: Message) -> None:
            """
            Custom send function to mutate the response headers
            and append X-Request-Id to it.
            """
            if message["type"] == "http.response.start":
                response_headers = MutableHeaders(raw=message["headers"])
                # Generate a uuid only when the client did not supply an id,
                # instead of building one eagerly for every response.
                request_id = request_headers.get("X-Request-Id")
                if request_id is None:
                    request_id = uuid.uuid4().hex
                response_headers.append("X-Request-Id", request_id)
            await send(message)

        return self.app(scope, receive, send_with_request_id)


1467
1468
1469
1470
1471
1472
1473
1474
# Global variable to track scaling state.
# scale_elastic_ep() sets it to True while a rescale runs and resets it in a
# `finally`; ScalingMiddleware and /is_scaling_elastic_ep only read it.
_scaling_elastic_ep: bool = False


class ScalingMiddleware:
    """
    ASGI middleware that answers HTTP requests with
    503 Service Unavailable while the model is being rescaled.

    Only "http" scopes are inspected; every other scope type is forwarded
    to the wrapped app unchanged. The module-level ``_scaling_elastic_ep``
    flag is read on every HTTP call.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        # Reading a module global needs no `global` statement.
        target = self.app
        if scope["type"] == "http" and _scaling_elastic_ep:
            target = JSONResponse(
                content={
                    "error": "The model is currently scaling. Please try again later."
                },
                status_code=503,
            )
        return target(scope, receive, send)


1502
1503
1504
1505
def _extract_content_from_chunk(chunk_data: dict) -> str:
    """Return the text carried by a single streaming-response chunk.

    * Chat chunks (``object == "chat.completion.chunk"``) yield the first
      choice's ``delta.content``.
    * Text-completion chunks (``object == "text_completion"``) yield the
      first choice's ``text``.

    If pydantic validation fails, the raw dict is inspected instead. An
    empty string is returned when no content is found.
    """
    try:
        from vllm.entrypoints.openai.protocol import (
            ChatCompletionStreamResponse,
            CompletionStreamResponse,
        )

        # Try using Completion types for type-safe parsing
        object_kind = chunk_data.get("object")
        if object_kind == "chat.completion.chunk":
            parsed_chat = ChatCompletionStreamResponse.model_validate(chunk_data)
            if parsed_chat.choices:
                delta_text = parsed_chat.choices[0].delta.content
                if delta_text:
                    return delta_text
        elif object_kind == "text_completion":
            parsed_text = CompletionStreamResponse.model_validate(chunk_data)
            if parsed_text.choices:
                piece = parsed_text.choices[0].text
                if piece:
                    return piece
    except pydantic.ValidationError:
        # Fallback: walk the raw dict when it doesn't match the schema.
        raw_choices = chunk_data.get("choices")
        if raw_choices:
            head = raw_choices[0]
            if "delta" in head and head["delta"].get("content"):
                return head["delta"]["content"]
            if head.get("text"):
                return head["text"]
    return ""


class SSEDecoder:
    """Robust Server-Sent Events decoder for streaming responses.

    Bytes are decoded incrementally, so a multi-byte UTF-8 character split
    across two chunks is reassembled instead of causing both chunks to be
    dropped. Text is buffered until a full line is available, and only
    ``data:`` lines are interpreted. Extracted text can be accumulated with
    :meth:`add_content` and read back with :meth:`get_complete_content`.
    """

    def __init__(self):
        import codecs

        # Decoded text that has not yet formed a complete line.
        self.buffer = ""
        # Text fragments collected through add_content().
        self.content_buffer: list[str] = []
        # Holds an incomplete UTF-8 sequence between decode_chunk() calls.
        self._decoder = codecs.getincrementaldecoder("utf-8")()

    def decode_chunk(self, chunk: bytes) -> list[dict]:
        """Decode a chunk of SSE bytes and return the events it completes.

        Returns a list of events:

        * ``{"type": "done"}`` for a ``data: [DONE]`` line;
        * ``{"type": "data", "data": <parsed JSON>}`` for any other
          non-empty ``data:`` line whose payload is valid JSON.

        Lines with malformed JSON are skipped. A chunk containing invalid
        UTF-8 is skipped entirely.
        """
        import json

        try:
            chunk_str = self._decoder.decode(chunk)
        except UnicodeDecodeError:
            # Skip malformed chunks, and discard any pending partial sequence
            # so the next chunk starts from a clean decoder state.
            self._decoder.reset()
            return []

        self.buffer += chunk_str
        events = []

        # Process complete lines
        while "\n" in self.buffer:
            line, self.buffer = self.buffer.split("\n", 1)
            line = line.rstrip("\r")  # Handle CRLF

            # The SSE spec makes the space after "data:" optional; the
            # payload is stripped below either way.
            if not line.startswith("data:"):
                continue
            data_str = line[5:].strip()
            if data_str == "[DONE]":
                events.append({"type": "done"})
            elif data_str:
                try:
                    event_data = json.loads(data_str)
                except json.JSONDecodeError:
                    # Skip malformed JSON
                    continue
                events.append({"type": "data", "data": event_data})

        return events

    def extract_content(self, event_data: dict) -> str:
        """Extract content from event data."""
        return _extract_content_from_chunk(event_data)

    def add_content(self, content: str) -> None:
        """Add content to the buffer (empty strings are ignored)."""
        if content:
            self.content_buffer.append(content)

    def get_complete_content(self) -> str:
        """Get the complete buffered content."""
        return "".join(self.content_buffer)
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600


def _log_streaming_response(response, response_body: list) -> None:
    """Log a streaming response with robust SSE parsing.

    Replaces ``response.body_iterator`` with a wrapper that re-yields every
    chunk of ``response_body`` unchanged while decoding it as SSE. On the
    first ``[DONE]`` event, the accumulated text is logged once, truncated
    to ``max_logged_chars`` characters with a ``...[truncated]`` marker,
    together with the number of chunks seen so far. Logging must never
    change the payload, so any chunks after ``[DONE]`` are still forwarded.
    """
    from starlette.concurrency import iterate_in_threadpool

    max_logged_chars = 2048
    sse_decoder = SSEDecoder()
    chunk_count = 0

    def _log_completion() -> None:
        """Emit the final log line for the accumulated stream content."""
        full_content = sse_decoder.get_complete_content()
        if full_content:
            # Truncate if too long. The marker is part of the same expression;
            # a dangling string-literal statement would be a silent no-op.
            if len(full_content) > max_logged_chars:
                full_content = full_content[:max_logged_chars] + "...[truncated]"
            logger.info(
                "response_body={streaming_complete: content=%r, chunks=%d}",
                full_content,
                chunk_count,
            )
        else:
            logger.info(
                "response_body={streaming_complete: no_content, chunks=%d}",
                chunk_count,
            )

    def buffered_iterator():
        nonlocal chunk_count
        done_seen = False

        for chunk in response_body:
            chunk_count += 1
            yield chunk

            if done_seen:
                # Keep forwarding trailing chunks; only parsing/logging stops.
                continue

            # Parse SSE events from chunk
            for event in sse_decoder.decode_chunk(chunk):
                if event["type"] == "data":
                    content = sse_decoder.extract_content(event["data"])
                    sse_decoder.add_content(content)
                elif event["type"] == "done":
                    # Log complete content when done
                    _log_completion()
                    done_seen = True
                    break

    response.body_iterator = iterate_in_threadpool(buffered_iterator())
    logger.info("response_body={streaming_started: chunks=%d}", len(response_body))
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636


def _log_non_streaming_response(response_body: list) -> None:
    """Log a buffered, non-streaming response body.

    All chunks are concatenated before decoding, so a body delivered in
    several chunks is logged in full rather than only its first chunk.
    Bodies that are not valid UTF-8 are logged as ``<binary_data>``.

    Args:
        response_body: The body chunks collected by the logging middleware.
            Expected to be ``bytes``; ``str`` chunks are UTF-8 encoded first.
    """
    raw_body = b"".join(
        chunk.encode() if isinstance(chunk, str) else chunk for chunk in response_body
    )
    try:
        decoded_body = raw_body.decode()
    except UnicodeDecodeError:
        logger.info("response_body={<binary_data>}")
        return
    logger.info("response_body={%s}", decoded_body)


1637
def build_app(args: Namespace) -> FastAPI:
    """Create and configure the FastAPI application for the API server.

    Registers the module-level ``router`` (plus the optional dynamic LoRA
    routes and the SageMaker routes), mounts metrics, installs CORS, the
    exception handlers and the middlewares (auth, request-id, scaling,
    optional response logging and any user-supplied ``args.middleware``),
    applies SageMaker bootstrapping and, when ``args.tokens_only`` is set,
    adds the ``/abort_requests`` endpoint.

    Args:
        args: Parsed CLI arguments.

    Returns:
        The configured application.

    Raises:
        ValueError: If an entry of ``args.middleware`` resolves to something
            that is neither a class nor a coroutine function.
    """
    if args.disable_fastapi_docs:
        app = FastAPI(
            openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan
        )
    else:
        app = FastAPI(lifespan=lifespan)

    if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
        logger.warning(
            "LoRA dynamic loading & unloading is enabled in the API server. "
            "This should ONLY be used for local development!"
        )
        from vllm.entrypoints.dynamic_lora import register_dynamic_lora_routes

        register_dynamic_lora_routes(router)

    from vllm.entrypoints.sagemaker.routes import register_sagemaker_routes

    register_sagemaker_routes(router)

    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(HTTPException)
    async def http_exception_handler(_: Request, exc: HTTPException):
        err = ErrorResponse(
            error=ErrorInfo(
                message=exc.detail,
                type=HTTPStatus(exc.status_code).phrase,
                code=exc.status_code,
            )
        )
        return JSONResponse(err.model_dump(), status_code=exc.status_code)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_: Request, exc: RequestValidationError):
        exc_str = str(exc)
        errors_str = str(exc.errors())

        # Append the structured error list only when it adds information.
        if exc.errors() and errors_str and errors_str != exc_str:
            message = f"{exc_str} {errors_str}"
        else:
            message = exc_str

        err = ErrorResponse(
            error=ErrorInfo(
                message=message,
                type=HTTPStatus.BAD_REQUEST.phrase,
                code=HTTPStatus.BAD_REQUEST,
            )
        )
        return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)

    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
    if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]:
        app.add_middleware(AuthenticationMiddleware, tokens=tokens)

    if args.enable_request_id_headers:
        app.add_middleware(XRequestIdMiddleware)

    # Add scaling middleware to check for scaling state
    app.add_middleware(ScalingMiddleware)

    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
        logger.warning(
            "CAUTION: Enabling log response in the API Server. "
            "This can include sensitive information and should be "
            "avoided in production."
        )

        @app.middleware("http")
        async def log_response(request: Request, call_next):
            response = await call_next(request)
            # Buffer the whole body so it can be logged, then replay it.
            response_body = [section async for section in response.body_iterator]
            response.body_iterator = iterate_in_threadpool(iter(response_body))

            # Detect SSE by media type, ignoring parameters such as charset.
            content_type = response.headers.get("content-type", "")
            media_type = content_type.split(";", 1)[0].strip().lower()
            is_streaming = media_type == "text/event-stream"

            # Log response body based on type
            if not response_body:
                logger.info("response_body={<empty>}")
            elif is_streaming:
                _log_streaming_response(response, response_body)
            else:
                _log_non_streaming_response(response_body)
            return response

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)  # type: ignore[arg-type]
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(
                f"Invalid middleware {middleware}. Must be a function or a class."
            )

    app = sagemaker_standards.bootstrap(app)

    # Optional endpoints
    if args.tokens_only:
        # Strong references to in-flight abort tasks: the event loop only keeps
        # weak references to tasks, so an unreferenced task may be
        # garbage-collected before it completes.
        abort_tasks: set[asyncio.Task] = set()

        @app.post("/abort_requests")
        async def abort_requests(raw_request: Request):
            """
            Abort one or more requests. To be used in a
            Disaggregated Everything setup.
            """
            try:
                body = await raw_request.json()
            except json.JSONDecodeError as e:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail=f"JSON decode error: {e}",
                ) from e
            if not isinstance(body, dict):
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail="Request body must be a JSON object",
                )
            request_ids = body.get("request_ids")
            if request_ids is None:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail="Missing 'request_ids' in request body",
                )
            # Abort requests in background, keeping the task referenced
            # until it finishes.
            task = asyncio.create_task(engine_client(raw_request).abort(request_ids))
            abort_tasks.add(task)
            task.add_done_callback(abort_tasks.discard)
            return Response(status_code=200)

    return app


1778
async def init_app_state(
    engine_client: EngineClient,
    state: State,
    args: Namespace,
) -> None:
    """Populate ``state`` with the engine client and every serving handler.

    Each task-specific handler is created only when the engine reports a
    matching supported task and is set to ``None`` otherwise; the
    tokenization handler is always created.

    Args:
        engine_client: Client for the running engine.
        state: The application state object to populate.
        args: Parsed CLI arguments.
    """
    vllm_config = engine_client.vllm_config

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.enable_log_requests:
        request_logger = RequestLogger(max_log_len=args.max_log_len)
    else:
        request_logger = None

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model) for name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.vllm_config = vllm_config

    supported_tasks = await engine_client.get_supported_tasks()
    logger.info("Supported tasks: %s", supported_tasks)

    resolved_chat_template = await process_chat_template(
        args.chat_template, engine_client, vllm_config.model_config
    )

    tool_server: ToolServer | None
    if args.tool_server == "demo":
        demo_tool_server = DemoToolServer()
        await demo_tool_server.init_and_validate()
        tool_server = demo_tool_server
    elif args.tool_server:
        mcp_tool_server = MCPToolServer()
        await mcp_tool_server.add_tool_server(args.tool_server)
        tool_server = mcp_tool_server
    else:
        tool_server = None

    # Merge default_mm_loras into the static lora_modules
    default_mm_loras = (
        vllm_config.lora_config.default_mm_loras
        if vllm_config.lora_config is not None
        else {}
    )
    lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)

    state.openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=base_model_paths,
        lora_modules=lora_modules,
    )
    await state.openai_serving_models.init_static_loras()
    state.openai_serving_responses = (
        OpenAIServingResponses(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            tool_server=tool_server,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            log_error_stack=args.log_error_stack,
        )
        if "generate" in supported_tasks
        else None
    )
    state.openai_serving_chat = (
        OpenAIServingChat(
            engine_client,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            log_error_stack=args.log_error_stack,
        )
        if "generate" in supported_tasks
        else None
    )
    state.openai_serving_completion = (
        OpenAIServingCompletion(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            log_error_stack=args.log_error_stack,
        )
        if "generate" in supported_tasks
        else None
    )
    state.openai_serving_pooling = (
        OpenAIServingPooling(
            engine_client,
            state.openai_serving_models,
            supported_tasks=supported_tasks,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if any(task in POOLING_TASKS for task in supported_tasks)
        else None
    )
    state.openai_serving_embedding = (
        OpenAIServingEmbedding(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if "embed" in supported_tasks
        else None
    )
    state.openai_serving_classification = (
        ServingClassification(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if "classify" in supported_tasks
        else None
    )
    state.openai_serving_scores = (
        ServingScores(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
        )
        if ("embed" in supported_tasks or "score" in supported_tasks)
        else None
    )
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        state.openai_serving_models,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
        trust_request_chat_template=args.trust_request_chat_template,
        log_error_stack=args.log_error_stack,
    )
    state.openai_serving_transcription = (
        OpenAIServingTranscription(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )
    # Translation shares the "transcription" task gate.
    state.openai_serving_translation = (
        OpenAIServingTranslation(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )
    state.anthropic_serving_messages = (
        AnthropicServingMessages(
            engine_client,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "generate" in supported_tasks
        else None
    )
    state.serving_tokens = (
        ServingTokens(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            log_error_stack=args.log_error_stack,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_log_outputs=args.enable_log_outputs,
            force_no_detokenize=args.tokens_only,
        )
        if "generate" in supported_tasks
        else None
    )

    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0

2014

2015
def create_server_socket(addr: tuple[str, int]) -> socket.socket:
    """Create a TCP stream socket bound to ``addr``.

    An IPv6 socket is created when the host is a valid IPv6 address, an IPv4
    socket otherwise. ``SO_REUSEADDR`` is always set; ``SO_REUSEPORT`` is set
    when the platform's ``socket`` module provides it. The socket is only
    bound here; this function does not call ``listen()``.

    Args:
        addr: ``(host, port)`` pair to bind.

    Returns:
        The bound socket.

    Raises:
        OSError: If the socket cannot be configured or bound. The partially
            set-up socket is closed before the error propagates.
    """
    family = socket.AF_INET6 if is_valid_ipv6_address(addr[0]) else socket.AF_INET

    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # SO_REUSEPORT is not defined on every platform (e.g. Windows).
        if hasattr(socket, "SO_REUSEPORT"):
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(addr)
    except BaseException:
        # Do not leak the file descriptor when configuration/bind fails.
        sock.close()
        raise

    return sock


2028
2029
2030
2031
2032
2033
def create_server_unix_socket(path: str) -> socket.socket:
    """Create a Unix-domain stream socket bound to ``path``.

    Args:
        path: Filesystem path to bind the socket to.

    Returns:
        The bound socket.

    Raises:
        OSError: If binding fails (for example when ``path`` already exists).
            The socket is closed before the error propagates.
    """
    sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
    try:
        sock.bind(path)
    except BaseException:
        # Do not leak the file descriptor when bind fails.
        sock.close()
        raise
    return sock


2034
def validate_api_server_args(args):
    """Check the configured tool-call and reasoning parsers are registered.

    Raises:
        KeyError: If ``args.enable_auto_tool_choice`` is set and
            ``args.tool_call_parser`` is not a registered tool parser, or if
            ``args.structured_outputs_config.reasoning_parser`` is set to a
            name that is not a registered reasoning parser.
    """
    valid_tool_parsers = ToolParserManager.list_registered()
    if args.enable_auto_tool_choice and args.tool_call_parser not in valid_tool_parsers:
        raise KeyError(
            f"invalid tool call parser: {args.tool_call_parser} "
            f"(choose from {{ {','.join(valid_tool_parsers)} }})"
        )

    valid_reasoning_parsers = ReasoningParserManager.list_registered()
    if (
        reasoning_parser := args.structured_outputs_config.reasoning_parser
    ) and reasoning_parser not in valid_reasoning_parsers:
        raise KeyError(
            f"invalid reasoning parser: {reasoning_parser} "
            f"(choose from {{ {','.join(valid_reasoning_parsers)} }})"
        )

2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061

def setup_server(args):
    """Prepare the API server process before the engine is started.

    Logs the server version and non-default arguments, imports any tool or
    reasoning parser plugins, validates the arguments, binds the listening
    socket, calls ``set_ulimit()`` and installs a SIGTERM handler that
    interrupts initialization.

    Returns:
        ``(listen_address, sock)`` where ``listen_address`` is a display
        string (``unix:<path>`` or ``http[s]://<host>:<port>``) and ``sock``
        is the bound socket.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    log_non_default_args(args)

    tool_plugin = args.tool_parser_plugin
    if tool_plugin and len(tool_plugin) > 3:
        ToolParserManager.import_tool_parser(tool_plugin)

    reasoning_plugin = args.reasoning_parser_plugin
    if reasoning_plugin and len(reasoning_plugin) > 3:
        ReasoningParserManager.import_reasoning_parser(reasoning_plugin)

    validate_api_server_args(args)

    # Bind the port before the engine is set up to avoid races with ray.
    # See https://github.com/vllm-project/vllm/issues/8204
    uds_path = args.uds
    if uds_path:
        sock = create_server_unix_socket(uds_path)
    else:
        bind_addr = (args.host or "", args.port)
        sock = create_server_socket(bind_addr)

    # Workaround for uvicorn dropping requests when many concurrent
    # requests are active.
    set_ulimit()

    def _abort_on_sigterm(*_) -> None:
        # A SIGTERM received while initializing interrupts the server.
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, _abort_on_sigterm)

    if uds_path:
        return f"unix:{uds_path}", sock

    host, port = bind_addr
    scheme = "https" if (args.ssl_keyfile and args.ssl_certfile) else "http"
    if is_valid_ipv6_address(host):
        host_part = f"[{host}]"
    else:
        host_part = host or "0.0.0.0"
    return f"{scheme}://{host_part}:{port}", sock


async def run_server(args, **uvicorn_kwargs) -> None:
    """Set up and run a single-worker API server.

    Args:
        args: Parsed CLI arguments.
        **uvicorn_kwargs: Extra keyword arguments forwarded to the worker.
    """
    # Tag this process's stdout/stderr output with an "APIServer" prefix.
    decorate_logs("APIServer")

    # Only surface errors from the verbose model_hosting_container_standards.
    container_logger = logging.getLogger("model_hosting_container_standards")
    container_logger.setLevel(logging.ERROR)

    address, server_sock = setup_server(args)
    await run_server_worker(address, server_sock, args, **uvicorn_kwargs)


2109
2110
2111
async def run_server_worker(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Run a single API server worker on an already-bound socket.

    Imports parser plugins, loads the optional uvicorn log config, starts the
    engine client, builds the app, initializes its state and serves HTTP on
    ``sock`` until shutdown.

    ``sock`` is closed when this coroutine returns or raises, including when
    engine start-up, app construction or state initialization fails.

    Args:
        listen_address: Human-readable address used for logging.
        sock: The bound listening socket to serve on.
        args: Parsed CLI arguments.
        client_config: Optional engine-client configuration.
        **uvicorn_kwargs: Extra keyword arguments forwarded to ``serve_http``.
    """
    try:
        if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
            ToolParserManager.import_tool_parser(args.tool_parser_plugin)

        if args.reasoning_parser_plugin and len(args.reasoning_parser_plugin) > 3:
            ReasoningParserManager.import_reasoning_parser(
                args.reasoning_parser_plugin
            )

        # Load logging config for uvicorn if specified
        log_config = load_log_config(args.log_config_file)
        if log_config is not None:
            uvicorn_kwargs["log_config"] = log_config

        async with build_async_engine_client(
            args,
            client_config=client_config,
        ) as engine_client:
            maybe_register_tokenizer_info_endpoint(args)
            app = build_app(args)

            await init_app_state(engine_client, app.state, args)

            logger.info(
                "Starting vLLM API server %d on %s",
                engine_client.vllm_config.parallel_config._api_process_rank,
                listen_address,
            )
            shutdown_task = await serve_http(
                app,
                sock=sock,
                enable_ssl_refresh=args.enable_ssl_refresh,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                # NOTE: When the 'disable_uvicorn_access_log' value is True,
                # no access log will be output.
                access_log=not args.disable_uvicorn_access_log,
                timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs,
                h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
                h11_max_header_count=args.h11_max_header_count,
                **uvicorn_kwargs,
            )

        # NB: Await server shutdown only after the backend context is exited
        await shutdown_task
    finally:
        # Close the socket on every exit path, not only after a clean serve.
        sock.close()

Ethan Xu's avatar
Ethan Xu committed
2165
2166
2167

if __name__ == "__main__":
    # NOTE(simon):
    # This section should be in sync with vllm/entrypoints/cli/main.py for CLI
    # entrypoints.
    cli_env_setup()
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."
        )
    )
    args = parser.parse_args()
    validate_parsed_serve_args(args)
    uvloop.run(run_server(args))