api_server.py 71.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import gc
6
import hashlib
7
8
import importlib
import inspect
9
import json
10
import multiprocessing
11
import multiprocessing.forkserver as forkserver
12
import os
13
import secrets
14
import signal
15
import socket
16
import tempfile
17
import uuid
18
from argparse import Namespace
19
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
20
from contextlib import asynccontextmanager
21
from http import HTTPStatus
22
from typing import Annotated, Any, Literal
23

24
import prometheus_client
25
import pydantic
26
import regex as re
27
import uvloop
28
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Query, Request
Zhuohan Li's avatar
Zhuohan Li committed
29
30
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
31
from fastapi.responses import JSONResponse, Response, StreamingResponse
32
33
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
34
from starlette.concurrency import iterate_in_threadpool
35
from starlette.datastructures import URL, Headers, MutableHeaders, State
36
from starlette.routing import Mount
37
from starlette.types import ASGIApp, Message, Receive, Scope, Send
38
from typing_extensions import assert_never
Zhuohan Li's avatar
Zhuohan Li committed
39

40
import vllm.envs as envs
41
from vllm.config import VllmConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
42
from vllm.engine.arg_utils import AsyncEngineArgs
43
from vllm.engine.protocol import Device, EngineClient
44
45
46
47
48
49
50
from vllm.entrypoints.anthropic.protocol import (
    AnthropicError,
    AnthropicErrorResponse,
    AnthropicMessagesRequest,
    AnthropicMessagesResponse,
)
from vllm.entrypoints.anthropic.serving_messages import AnthropicServingMessages
51
from vllm.entrypoints.launcher import serve_http
52
from vllm.entrypoints.logger import RequestLogger
53
54
55
56
57
58
59
60
61
62
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ClassificationRequest,
    ClassificationResponse,
    CompletionRequest,
    CompletionResponse,
    DetokenizeRequest,
    DetokenizeResponse,
63
    EmbeddingBytesResponse,
64
65
66
67
68
69
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorInfo,
    ErrorResponse,
    IOProcessorResponse,
    LoadLoRAAdapterRequest,
70
    PoolingBytesResponse,
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
    PoolingRequest,
    PoolingResponse,
    RerankRequest,
    RerankResponse,
    ResponsesRequest,
    ResponsesResponse,
    ScoreRequest,
    ScoreResponse,
    StreamingResponsesResponse,
    TokenizeRequest,
    TokenizeResponse,
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
    TranslationResponse,
    UnloadLoRAAdapterRequest,
)
88
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
89
from vllm.entrypoints.openai.serving_classification import ServingClassification
90
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
91
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
92
from vllm.entrypoints.openai.serving_engine import OpenAIServing
93
94
95
96
from vllm.entrypoints.openai.serving_models import (
    BaseModelPath,
    OpenAIServingModels,
)
97
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
98
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
99
from vllm.entrypoints.openai.serving_score import ServingScores
100
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
101
from vllm.entrypoints.openai.serving_transcription import (
102
103
104
    OpenAIServingTranscription,
    OpenAIServingTranslation,
)
105
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
106
107
108
109
110
from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer
from vllm.entrypoints.utils import (
    cli_env_setup,
    load_aware_call,
    log_non_default_args,
111
112
    process_chat_template,
    process_lora_modules,
113
114
    with_cancellation,
)
115
from vllm.logger import init_logger
116
from vllm.reasoning import ReasoningParserManager
117
from vllm.tasks import POOLING_TASKS
yhu422's avatar
yhu422 committed
118
from vllm.usage.usage_lib import UsageContext
Cyrus Leung's avatar
Cyrus Leung committed
119
from vllm.utils.argparse_utils import FlexibleArgumentParser
120
from vllm.utils.network_utils import is_valid_ipv6_address
121
from vllm.utils.system_utils import decorate_logs, set_ulimit
122
from vllm.v1.engine.exceptions import EngineDeadError
123
from vllm.v1.metrics.prometheus import get_prometheus_registry
124
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
125

126
# Module-level name declared by annotation only; no value is assigned here.
prometheus_multiproc_dir: tempfile.TemporaryDirectory
127

128
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
129
# Module logger; the logger name is spelled out literally instead of being
# derived from __name__.
logger = init_logger("vllm.entrypoints.openai.api_server")
130

131
# Module-level set of asyncio tasks. `lifespan` adds its periodic
# stats-logging task here and removes it again through a done-callback.
_running_tasks: set[asyncio.Task] = set()
132

133

134
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook.

    On startup, optionally launches a background task that calls
    ``engine_client.do_log_stats()`` every ``envs.VLLM_LOG_STATS_INTERVAL``
    seconds (only when ``app.state.log_stats`` is truthy), then collects and
    freezes the GC heap. On shutdown the background task is cancelled and
    ``app.state`` is deleted.
    """
    try:
        stats_task = None
        if app.state.log_stats:
            client: EngineClient = app.state.engine_client

            async def _force_log():
                # Runs until the task is cancelled on shutdown.
                while True:
                    await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
                    await client.do_log_stats()

            stats_task = asyncio.create_task(_force_log())
            # Track the task in the module-level set until it completes.
            _running_tasks.add(stats_task)
            stats_task.add_done_callback(_running_tasks.remove)

        # Mark the startup heap as static so that it's ignored by GC.
        # Reduces pause times of oldest generation collections.
        gc.collect()
        gc.freeze()

        try:
            yield
        finally:
            if stats_task is not None:
                stats_task.cancel()
    finally:
        # Ensure app state including engine ref is gc'd
        del app.state
163
164


165
@asynccontextmanager
async def build_async_engine_client(
    args: Namespace,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool | None = None,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Build an engine client from parsed CLI arguments and yield it.

    When the ``VLLM_WORKER_MULTIPROC_METHOD`` environment variable equals
    ``"forkserver"``, the forkserver start method is selected and started
    first with ``vllm.v1.engine.async_llm`` preloaded. The client itself comes
    from ``build_async_engine_client_from_engine_args``.

    Args:
        args: Parsed command-line arguments.
        usage_context: Forwarded unchanged to the nested builder.
        disable_frontend_multiprocessing: Explicit override; when ``None``,
            ``bool(args.disable_frontend_multiprocessing)`` is used.
        client_config: Optional mapping. If non-empty, its ``client_count``
            and ``client_index`` entries (defaulting to 1 and 0) are copied
            onto the engine args; the mapping is also forwarded as-is.

    Yields:
        The engine client produced by the nested context manager.
    """
    start_method = os.environ.get("VLLM_WORKER_MULTIPROC_METHOD")
    if start_method == "forkserver":
        # The executor is expected to be mp.
        # Pre-import heavy modules in the forkserver process
        logger.debug("Setup forkserver with pre-imports")
        multiprocessing.set_start_method("forkserver")
        multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
        forkserver.ensure_running()
        logger.debug("Forkserver setup complete!")

    # Context manager to handle engine_client lifecycle
    # Ensures everything is shutdown and cleaned up on error/exit
    engine_args = AsyncEngineArgs.from_cli_args(args)
    if client_config:
        engine_args._api_process_count = client_config.get("client_count", 1)
        engine_args._api_process_rank = client_config.get("client_index", 0)

    if disable_frontend_multiprocessing is None:
        disable_frontend_multiprocessing = bool(args.disable_frontend_multiprocessing)

    client_cm = build_async_engine_client_from_engine_args(
        engine_args,
        usage_context=usage_context,
        disable_frontend_multiprocessing=disable_frontend_multiprocessing,
        client_config=client_config,
    )
    async with client_cm as client:
        yield client


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool = False,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """
    Create a V1 ``AsyncLLM`` engine client and yield it.

    The client is built with ``AsyncLLM.from_vllm_config``. If creation or the
    body of the ``async with`` raises, the exception propagates to the caller
    (``None`` is never yielded). Whenever a client was created, it is shut
    down on exit.

    Args:
        engine_args: Engine arguments used to create the engine config and to
            read the logging-related flags.
        usage_context: Passed to the config creation and to ``AsyncLLM``.
        disable_frontend_multiprocessing: Only triggers a warning here; it
            does not change how the client is created.
        client_config: Optional mapping. A copy is taken; ``client_count``
            (default 1) and ``client_index`` (default 0) are popped from the
            copy and the remainder is passed as ``client_addresses``.

    Yields:
        The created ``AsyncLLM`` instance.

    Raises:
        AssertionError: If ``envs.VLLM_USE_V1`` is falsy.
    """

    # Create the EngineConfig (determines if we can use V1).
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    # V1 AsyncLLM.
    assert envs.VLLM_USE_V1

    if disable_frontend_multiprocessing:
        logger.warning(
            "V1 is enabled, but got --disable-frontend-multiprocessing. "
            "To disable frontend multiprocessing, set VLLM_USE_V1=0."
        )

    from vllm.v1.engine.async_llm import AsyncLLM

    async_llm: AsyncLLM | None = None

    # Don't mutate the input client_config
    client_config = dict(client_config) if client_config else {}
    client_count = client_config.pop("client_count", 1)
    client_index = client_config.pop("client_index", 0)

    try:
        async_llm = AsyncLLM.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            aggregate_engine_logging=engine_args.aggregate_engine_logging,
            disable_log_stats=engine_args.disable_log_stats,
            client_addresses=client_config,
            client_count=client_count,
            client_index=client_index,
        )

        # Don't keep the dummy data in memory
        assert async_llm is not None
        await async_llm.reset_mm_cache()

        yield async_llm
    finally:
        # Identity check rather than truthiness: cleanup must not depend on
        # how the engine object evaluates in a boolean context.
        if async_llm is not None:
            async_llm.shutdown()
258
259


260
261
async def validate_json_request(raw_request: Request):
    """Reject requests whose Content-Type media type is not JSON.

    The ``content-type`` header is lower-cased, any parameters after the
    first ``;`` are dropped, and surrounding whitespace is stripped before
    comparison, so values such as ``application/json ; charset=utf-8`` are
    accepted.

    Args:
        raw_request: Incoming request; only its ``headers`` mapping is read.

    Raises:
        RequestValidationError: If the media type is not
            ``application/json`` (including when the header is missing).
    """
    content_type = raw_request.headers.get("content-type", "").lower()
    # Optional whitespace is permitted around the ";" parameter separator,
    # so strip it before comparing.
    media_type = content_type.split(";", maxsplit=1)[0].strip()
    if media_type != "application/json":
        raise RequestValidationError(
            errors=["Unsupported Media Type: Only 'application/json' is allowed"]
        )
267
268


Ethan Xu's avatar
Ethan Xu committed
269
# Router on which the endpoint functions in this module are registered.
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
270

271

272
273
274
275
class PrometheusResponse(Response):
    """Response whose media type is ``prometheus_client.CONTENT_TYPE_LATEST``."""

    media_type = prometheus_client.CONTENT_TYPE_LATEST


276
def mount_metrics(app: FastAPI):
    """Mount prometheus metrics to a FastAPI app."""

    registry = get_prometheus_registry()

    # Paths that the instrumentator should not record.
    uninstrumented_paths = [
        "/metrics",
        "/health",
        "/load",
        "/ping",
        "/version",
        "/server_info",
    ]

    # `response_class=PrometheusResponse` is needed to return an HTTP response
    # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
    # instead of the default "application/json" which is incorrect.
    # See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
    instrumentator = Instrumentator(
        excluded_handlers=uninstrumented_paths,
        registry=registry,
    )
    instrumentator.add().instrument(app).expose(
        app, response_class=PrometheusResponse
    )

    # Add prometheus asgi middleware to route /metrics requests
    metrics_app = make_asgi_app(registry=registry)
    metrics_route = Mount("/metrics", metrics_app)

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)
303
304


305
306
307
308
309
def base(request: Request) -> OpenAIServing:
    """Return a generic serving handler by reusing the tokenization handler."""
    # Reuse the existing instance
    shared_handler = tokenization(request)
    return shared_handler


310
311
312
313
def models(request: Request) -> OpenAIServingModels:
    """Look up ``openai_serving_models`` on the app state of ``request``."""
    app_state = request.app.state
    return app_state.openai_serving_models


314
def responses(request: Request) -> OpenAIServingResponses | None:
    """Look up ``openai_serving_responses`` on the app state of ``request``."""
    app_state = request.app.state
    return app_state.openai_serving_responses


318
319
320
321
def messages(request: Request) -> AnthropicServingMessages:
    """Look up ``anthropic_serving_messages`` on the app state of ``request``."""
    app_state = request.app.state
    return app_state.anthropic_serving_messages


322
def chat(request: Request) -> OpenAIServingChat | None:
    """Look up ``openai_serving_chat`` on the app state of ``request``."""
    app_state = request.app.state
    return app_state.openai_serving_chat


326
def completion(request: Request) -> OpenAIServingCompletion | None:
    """Look up ``openai_serving_completion`` on the app state of ``request``."""
    app_state = request.app.state
    return app_state.openai_serving_completion


330
def pooling(request: Request) -> OpenAIServingPooling | None:
    """Look up ``openai_serving_pooling`` on the app state of ``request``."""
    app_state = request.app.state
    return app_state.openai_serving_pooling


334
def embedding(request: Request) -> OpenAIServingEmbedding | None:
    """Look up ``openai_serving_embedding`` on the app state of ``request``."""
    app_state = request.app.state
    return app_state.openai_serving_embedding
336
337


338
def score(request: Request) -> ServingScores | None:
    """Look up ``openai_serving_scores`` on the app state of ``request``."""
    app_state = request.app.state
    return app_state.openai_serving_scores


342
def classify(request: Request) -> ServingClassification | None:
    """Look up ``openai_serving_classification`` on the request's app state."""
    app_state = request.app.state
    return app_state.openai_serving_classification


346
def rerank(request: Request) -> ServingScores | None:
    """Look up ``openai_serving_scores`` on the app state of ``request``.

    This reads the same state attribute as ``score``.
    """
    app_state = request.app.state
    return app_state.openai_serving_scores
348
349


350
351
def tokenization(request: Request) -> OpenAIServingTokenization:
    """Look up ``openai_serving_tokenization`` on the request's app state."""
    app_state = request.app.state
    return app_state.openai_serving_tokenization
352
353


354
355
356
357
def transcription(request: Request) -> OpenAIServingTranscription:
    """Look up ``openai_serving_transcription`` on the request's app state."""
    app_state = request.app.state
    return app_state.openai_serving_transcription


358
359
360
361
def translation(request: Request) -> OpenAIServingTranslation:
    """Look up ``openai_serving_translation`` on the request's app state."""
    app_state = request.app.state
    return app_state.openai_serving_translation


362
def engine_client(request: Request) -> EngineClient:
    """Look up ``engine_client`` on the app state of ``request``."""
    app_state = request.app.state
    return app_state.engine_client


366
367
@router.get("/health", response_class=Response)
async def health(raw_request: Request) -> Response:
    """Health check."""
    # 200 when the engine health check completes, 503 when it reports the
    # engine as dead; any other exception propagates.
    try:
        await engine_client(raw_request).check_health()
    except EngineDeadError:
        return Response(status_code=HTTPStatus.SERVICE_UNAVAILABLE.value)
    else:
        return Response(status_code=HTTPStatus.OK.value)
374
375


376
377
378
379
380
381
382
@router.get("/load")
async def get_server_load_metrics(request: Request):
    # This endpoint returns the current server load metrics.
    # It tracks requests utilizing the GPU from the following routes:
    # - /v1/chat/completions
    # - /v1/completions
    # - /v1/audio/transcriptions
    # - /v1/audio/translations
    # - /v1/embeddings
    # - /pooling
    # - /classify
    # - /score
    # - /v1/score
    # - /rerank
    # - /v1/rerank
    # - /v2/rerank
    current_load = request.app.state.server_load_metrics
    payload = {"server_load": current_load}
    return JSONResponse(content=payload)
393
394


395
396
397
@router.get("/ping", response_class=Response)
@router.post("/ping", response_class=Response)
async def ping(raw_request: Request) -> Response:
    """Ping check. Endpoint required for SageMaker"""
    # Delegates entirely to the /health handler.
    health_response = await health(raw_request)
    return health_response


402
403
404
405
406
407
408
409
410
411
@router.post(
    "/tokenize",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
        HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
    # Delegate to the tokenization handler. NotImplementedError maps to 501,
    # any other exception to 500; an ErrorResponse result is returned with
    # its own status code.
    serving = tokenization(raw_request)

    try:
        result = await serving.create_tokenize(request, raw_request)
    except NotImplementedError as exc:
        raise HTTPException(
            status_code=HTTPStatus.NOT_IMPLEMENTED.value, detail=str(exc)
        ) from exc
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(), status_code=result.error.code
        )
    if isinstance(result, TokenizeResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)

436

437
438
439
440
441
442
443
444
445
@router.post(
    "/detokenize",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
    # Delegate to the tokenization handler. OverflowError is re-raised as a
    # request validation error, any other exception as a 500; an
    # ErrorResponse result is returned with its own status code.
    serving = tokenization(raw_request)

    try:
        result = await serving.create_detokenize(request, raw_request)
    except OverflowError as exc:
        raise RequestValidationError(errors=[str(exc)]) from exc
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(), status_code=result.error.code
        )
    if isinstance(result, DetokenizeResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)

468

469
470
def maybe_register_tokenizer_info_endpoint(args):
    """Conditionally register the tokenizer info endpoint if enabled."""
    # Nothing is registered unless the flag is present and truthy on `args`.
    if not getattr(args, "enable_tokenizer_info_endpoint", False):
        return

    @router.get("/tokenizer_info")
    async def get_tokenizer_info(raw_request: Request):
        """Get comprehensive tokenizer information."""
        info = await tokenization(raw_request).get_tokenizer_info()
        # Error results carry their own status code; everything else is 200.
        if isinstance(info, ErrorResponse):
            status = info.error.code
        else:
            status = 200
        return JSONResponse(content=info.model_dump(), status_code=status)
483
484


Ethan Xu's avatar
Ethan Xu committed
485
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    # Ask the models handler for its listing and return it as JSON.
    listing = await models(raw_request).show_available_models()
    return JSONResponse(content=listing.model_dump())
Zhuohan Li's avatar
Zhuohan Li committed
491
492


Ethan Xu's avatar
Ethan Xu committed
493
@router.get("/version")
async def show_version():
    # Respond with {"version": <VLLM_VERSION>}.
    return JSONResponse(content={"version": VLLM_VERSION})


499
async def _convert_stream_to_sse_events(
    generator: AsyncGenerator[StreamingResponsesResponse, None],
) -> AsyncGenerator[str, None]:
    """Convert the generator to a stream of events in SSE format

    Each yielded string has an ``event:`` line (the item's ``type``
    attribute, or ``"unknown"`` when it has none), a ``data:`` line holding
    the item's ``model_dump_json(indent=None)`` output, and a blank line.
    """
    async for item in generator:
        kind = getattr(item, "type", "unknown")
        # https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
        payload = item.model_dump_json(indent=None)
        yield f"event: {kind}\ndata: {payload}\n\n"


512
513
514
515
516
517
518
519
520
521
@router.post(
    "/v1/responses",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def create_responses(request: ResponsesRequest, raw_request: Request):
    # Delegate to the Responses handler. The result is returned as a JSON
    # error (with the error's own status code), a JSON ResponsesResponse, or
    # otherwise streamed as server-sent events. Exceptions raised by the
    # handler are converted to HTTP 500.
    handler = responses(raw_request)
    if handler is None:
        # Send the error with its own status code, like the ErrorResponse
        # branch below; returning the bare model would be served as HTTP 200.
        error = base(raw_request).create_error_response(
            message="The model does not support Responses API"
        )
        return JSONResponse(
            content=error.model_dump(), status_code=error.error.code
        )
    try:
        generator = await handler.create_responses(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, ResponsesResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(
        content=_convert_stream_to_sse_events(generator), media_type="text/event-stream"
    )
546
547
548


@router.get("/v1/responses/{response_id}")
async def retrieve_responses(
    response_id: str,
    raw_request: Request,
    starting_after: int | None = None,
    stream: bool | None = False,
):
    # Fetch a response by id through the Responses handler, forwarding
    # `starting_after` and `stream` unchanged. The result is returned as a
    # JSON error (with the error's own status code), a JSON
    # ResponsesResponse, or otherwise streamed as server-sent events.
    handler = responses(raw_request)
    if handler is None:
        # Send the error with its own status code, like the ErrorResponse
        # branch below; returning the bare model would be served as HTTP 200.
        error = base(raw_request).create_error_response(
            message="The model does not support Responses API"
        )
        return JSONResponse(
            content=error.model_dump(), status_code=error.error.code
        )

    try:
        response = await handler.retrieve_responses(
            response_id,
            starting_after=starting_after,
            stream=stream,
        )
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(response, ErrorResponse):
        return JSONResponse(
            content=response.model_dump(), status_code=response.error.code
        )
    elif isinstance(response, ResponsesResponse):
        return JSONResponse(content=response.model_dump())
    return StreamingResponse(
        content=_convert_stream_to_sse_events(response), media_type="text/event-stream"
    )
581
582
583
584
585
586
587


@router.post("/v1/responses/{response_id}/cancel")
async def cancel_responses(response_id: str, raw_request: Request):
    # Cancel a response by id through the Responses handler. Errors are
    # returned as JSON with the error's own status code; exceptions raised
    # by the handler are converted to HTTP 500.
    handler = responses(raw_request)
    if handler is None:
        # Send the error with its own status code, like the ErrorResponse
        # branch below; returning the bare model would be served as HTTP 200.
        error = base(raw_request).create_error_response(
            message="The model does not support Responses API"
        )
        return JSONResponse(
            content=error.model_dump(), status_code=error.error.code
        )

    try:
        response = await handler.cancel_responses(response_id)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(response, ErrorResponse):
        return JSONResponse(
            content=response.model_dump(), status_code=response.error.code
        )
    return JSONResponse(content=response.model_dump())


605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
@router.post(
    "/v1/messages",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": AnthropicErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": AnthropicErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": AnthropicErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_messages(request: AnthropicMessagesRequest, raw_request: Request):
    # Delegate to the Anthropic Messages handler. Every error path is
    # returned as an AnthropicErrorResponse JSON body; a non-error,
    # non-AnthropicMessagesResponse result is streamed as server-sent events.

    def translate_error_response(response: ErrorResponse) -> JSONResponse:
        # Re-wrap the error's type and message in the Anthropic error schema
        # while keeping its status code.
        source = response.error
        wrapped = AnthropicErrorResponse(
            error=AnthropicError(type=source.type, message=source.message)
        )
        return JSONResponse(status_code=source.code, content=wrapped.model_dump())

    handler = messages(raw_request)
    if handler is None:
        return translate_error_response(
            base(raw_request).create_error_response(
                message="The model does not support Messages API"
            )
        )

    try:
        result = await handler.create_messages(request, raw_request)
    except Exception as exc:
        logger.exception("Error in create_messages: %s", exc)
        failure = AnthropicErrorResponse(
            error=AnthropicError(type="internal_error", message=str(exc))
        )
        return JSONResponse(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            content=failure.model_dump(),
        )

    if isinstance(result, ErrorResponse):
        return translate_error_response(result)

    if isinstance(result, AnthropicMessagesResponse):
        # The same dump is logged and returned.
        body = result.model_dump(exclude_none=True)
        logger.debug("Anthropic Messages Response: %s", body)
        return JSONResponse(content=body)

    return StreamingResponse(content=result, media_type="text/event-stream")


662
663
664
665
666
667
668
669
670
671
@router.post(
    "/v1/chat/completions",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
    # Delegate to the chat handler. The result is returned as a JSON error
    # (with the error's own status code), a JSON ChatCompletionResponse, or
    # otherwise streamed as server-sent events. Exceptions raised by the
    # handler are converted to HTTP 500.
    handler = chat(raw_request)
    if handler is None:
        # Send the error with its own status code, like the ErrorResponse
        # branch below; returning the bare model would be served as HTTP 200.
        error = base(raw_request).create_error_response(
            message="The model does not support Chat Completions API"
        )
        return JSONResponse(
            content=error.model_dump(), status_code=error.error.code
        )
    try:
        generator = await handler.create_chat_completion(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e
    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )

    elif isinstance(generator, ChatCompletionResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")

696

697
698
699
700
701
702
703
704
705
706
@router.post(
    "/v1/completions",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_completion(request: CompletionRequest, raw_request: Request):
    # Delegate to the completion handler. OverflowError maps to HTTP 400 and
    # any other exception to HTTP 500. The result is returned as a JSON error
    # (with the error's own status code), a JSON CompletionResponse, or
    # otherwise streamed as server-sent events.
    handler = completion(raw_request)
    if handler is None:
        # Send the error with its own status code, like the ErrorResponse
        # branch below; returning the bare model would be served as HTTP 200.
        error = base(raw_request).create_error_response(
            message="The model does not support Completions API"
        )
        return JSONResponse(
            content=error.model_dump(), status_code=error.error.code
        )

    try:
        generator = await handler.create_completion(request, raw_request)
    except OverflowError as e:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value, detail=str(e)
        ) from e
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, CompletionResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")

Zhuohan Li's avatar
Zhuohan Li committed
736

737
738
739
740
741
742
743
744
@router.post(
    "/v1/embeddings",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_embedding(
    request: EmbeddingRequest,
    raw_request: Request,
):
    # Delegate to the embedding handler. The result is returned as a JSON
    # error (with the error's own status code), a JSON EmbeddingResponse, or
    # a streamed EmbeddingBytesResponse body whose metadata is sent in a
    # "metadata" header. Exceptions raised by the handler become HTTP 500.
    handler = embedding(raw_request)
    if handler is None:
        # Send the error with its own status code, like the ErrorResponse
        # branch below; returning the bare model would be served as HTTP 200.
        error = base(raw_request).create_error_response(
            message="The model does not support Embeddings API"
        )
        return JSONResponse(
            content=error.model_dump(), status_code=error.error.code
        )

    try:
        generator = await handler.create_embedding(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, EmbeddingResponse):
        return JSONResponse(content=generator.model_dump())
    elif isinstance(generator, EmbeddingBytesResponse):
        return StreamingResponse(
            content=generator.body,
            headers={"metadata": generator.metadata},
            media_type=generator.media_type,
        )

    assert_never(generator)

779

780
781
782
783
784
785
786
787
@router.post(
    "/pooling",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_pooling(request: PoolingRequest, raw_request: Request):
    # Delegate to the pooling handler. The result is returned as a JSON error
    # (with the error's own status code), a JSON PoolingResponse or
    # IOProcessorResponse, or a streamed PoolingBytesResponse body whose
    # metadata is sent in a "metadata" header. Exceptions raised by the
    # handler become HTTP 500.
    handler = pooling(raw_request)
    if handler is None:
        # Send the error with its own status code, like the ErrorResponse
        # branch below; returning the bare model would be served as HTTP 200.
        error = base(raw_request).create_error_response(
            message="The model does not support Pooling API"
        )
        return JSONResponse(
            content=error.model_dump(), status_code=error.error.code
        )
    try:
        generator = await handler.create_pooling(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e
    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, (PoolingResponse, IOProcessorResponse)):
        return JSONResponse(content=generator.model_dump())
    elif isinstance(generator, PoolingBytesResponse):
        return StreamingResponse(
            content=generator.body,
            headers={"metadata": generator.metadata},
            media_type=generator.media_type,
        )

    assert_never(generator)


818
819
820
@router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_classify(request: ClassificationRequest, raw_request: Request):
    # Delegate to the classification handler. The result is returned as a
    # JSON error (with the error's own status code) or a JSON
    # ClassificationResponse. Exceptions raised by the handler become
    # HTTP 500.
    handler = classify(raw_request)
    if handler is None:
        # Send the error with its own status code, like the ErrorResponse
        # branch below; returning the bare model would be served as HTTP 200.
        error = base(raw_request).create_error_response(
            message="The model does not support Classification API"
        )
        return JSONResponse(
            content=error.model_dump(), status_code=error.error.code
        )

    try:
        generator = await handler.create_classify(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e
    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )

    elif isinstance(generator, ClassificationResponse):
        return JSONResponse(content=generator.model_dump())

    assert_never(generator)


845
846
847
848
849
850
851
852
@router.post(
    "/score",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_score(request: ScoreRequest, raw_request: Request):
    # Fetch the score handler for this app, if the model provides one.
    serving = score(raw_request)
    if serving is None:
        return base(raw_request).create_error_response(
            message="The model does not support Score API"
        )

    # Failures inside the handler are reported to the client as HTTP 500.
    try:
        result = await serving.create_score(request, raw_request)
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(result, ErrorResponse):
        # Use the error's own code as the HTTP status.
        return JSONResponse(
            content=result.model_dump(), status_code=result.error.code
        )
    if isinstance(result, ScoreResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)


878
879
880
881
882
883
884
885
@router.post(
    "/v1/score",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_score_v1(request: ScoreRequest, raw_request: Request):
    # Legacy alias for `/score`: emit the relocation notice and delegate.
    # `warning_once` is used (as `do_rerank_v1` does for its own notice) so
    # the message is not repeated in the log for every single request.
    logger.warning_once(
        "To indicate that Score API is not part of standard OpenAI API, we "
        "have moved it to `/score`. Please update your client accordingly."
    )

    return await create_score(request, raw_request)


897
898
899
900
901
902
903
904
905
@router.post(
    "/v1/audio/transcriptions",
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_transcriptions(
    raw_request: Request, request: Annotated[TranscriptionRequest, Form()]
):
    # The request arrives as form data; the uploaded file is read in full
    # before being handed to the transcription handler.
    serving = transcription(raw_request)
    if serving is None:
        return base(raw_request).create_error_response(
            message="The model does not support Transcriptions API"
        )

    audio_bytes = await request.file.read()
    try:
        result = await serving.create_transcription(audio_bytes, request, raw_request)
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(), status_code=result.error.code
        )
    if isinstance(result, TranscriptionResponse):
        return JSONResponse(content=result.model_dump())

    # Anything else is streamed back as server-sent events.
    return StreamingResponse(content=result, media_type="text/event-stream")


936
937
938
939
940
941
942
943
944
@router.post(
    "/v1/audio/translations",
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_translations(
    request: Annotated[TranslationRequest, Form()], raw_request: Request
):
    # The request arrives as form data; the uploaded file is read in full
    # before being handed to the translation handler.
    serving = translation(raw_request)
    if serving is None:
        return base(raw_request).create_error_response(
            message="The model does not support Translations API"
        )

    audio_bytes = await request.file.read()
    try:
        result = await serving.create_translation(audio_bytes, request, raw_request)
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(), status_code=result.error.code
        )
    if isinstance(result, TranslationResponse):
        return JSONResponse(content=result.model_dump())

    # Anything else is streamed back as server-sent events.
    return StreamingResponse(content=result, media_type="text/event-stream")


975
976
977
978
979
980
981
982
@router.post(
    "/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def do_rerank(request: RerankRequest, raw_request: Request):
    # Resolve the rerank handler; without it the endpoint is unsupported.
    serving = rerank(raw_request)
    if serving is None:
        return base(raw_request).create_error_response(
            message="The model does not support Rerank (Score) API"
        )

    # Handler exceptions become an HTTP 500 with the exception text.
    try:
        result = await serving.do_rerank(request, raw_request)
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(), status_code=result.error.code
        )
    if isinstance(result, RerankResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)


1007
1008
1009
1010
1011
1012
1013
1014
@router.post(
    "/v1/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
    # Alias of `/rerank`: log the relocation notice, then forward the call.
    relocation_notice = (
        "To indicate that the rerank API is not part of the standard OpenAI"
        " API, we have located it at `/rerank`. Please update your client "
        "accordingly. (Note: Conforms to JinaAI rerank API)"
    )
    logger.warning_once(relocation_notice)

    response = await do_rerank(request, raw_request)
    return response


1026
1027
1028
1029
1030
1031
1032
1033
@router.post(
    "/v2/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
    # Alias of `/rerank`; forwards the request unchanged.
    response = await do_rerank(request, raw_request)
    return response


1039
if envs.VLLM_SERVER_DEV_MODE:
    # Development-only endpoints. They are registered on the router only
    # when VLLM_SERVER_DEV_MODE is set.
    logger.warning(
        "SECURITY WARNING: Development endpoints are enabled! "
        "This should NOT be used in production!"
    )

    # Built once and reused by `show_server_info` for JSON dumps.
    PydanticVllmConfig = pydantic.TypeAdapter(VllmConfig)

    @router.get("/server_info")
    async def show_server_info(
        raw_request: Request,
        config_format: Annotated[Literal["text", "json"], Query()] = "text",
    ):
        # Return the app's vLLM config either as its `str()` form ("text")
        # or as a JSON-compatible dump ("json").
        vllm_config: VllmConfig = raw_request.app.state.vllm_config
        server_info = {
            "vllm_config": str(vllm_config)
            if config_format == "text"
            else PydanticVllmConfig.dump_python(vllm_config, mode="json", fallback=str)
            # fallback=str is needed to handle e.g. torch.dtype
        }
        return JSONResponse(content=server_info)

    @router.post("/reset_prefix_cache")
    async def reset_prefix_cache(raw_request: Request):
        """
        Reset the prefix cache. Note that we currently do not check if the
        prefix cache is successfully reset in the API server.
        """
        device = None
        device_str = raw_request.query_params.get("device")
        if device_str is not None:
            # An unknown device name is a client error, not a server crash.
            try:
                device = Device[device_str.upper()]
            except KeyError as e:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail=f"Invalid device: {device_str!r}",
                ) from e
        logger.info("Resetting prefix cache with specific %s...", str(device))
        await engine_client(raw_request).reset_prefix_cache(device)
        return Response(status_code=200)

    @router.post("/reset_mm_cache")
    async def reset_mm_cache(raw_request: Request):
        """
        Reset the multi-modal cache. Note that we currently do not check if the
        multi-modal cache is successfully reset in the API server.
        """
        logger.info("Resetting multi-modal cache...")
        await engine_client(raw_request).reset_mm_cache()
        return Response(status_code=200)

    @router.post("/sleep")
    async def sleep(raw_request: Request):
        # The sleep level comes from the `level` query parameter (default "1").
        level_str = raw_request.query_params.get("level", "1")
        # A non-integer level is a client error, not a server crash.
        try:
            level = int(level_str)
        except ValueError as e:
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail=f"Invalid sleep level: {level_str!r}",
            ) from e
        await engine_client(raw_request).sleep(level)
        # FIXME: in v0 with frontend multiprocessing, the sleep command
        # is sent but does not finish yet when we return a response.
        return Response(status_code=200)

    @router.post("/wake_up")
    async def wake_up(raw_request: Request):
        # No `tags` query parameters -> pass None to wake up all tags.
        tags = raw_request.query_params.getlist("tags") or None
        logger.info("wake up the engine with tags: %s", tags)
        await engine_client(raw_request).wake_up(tags)
        # FIXME: in v0 with frontend multiprocessing, the wake-up command
        # is sent but does not finish yet when we return a response.
        return Response(status_code=200)

    @router.get("/is_sleeping")
    async def is_sleeping(raw_request: Request):
        logger.info("check whether the engine is sleeping")
        sleeping = await engine_client(raw_request).is_sleeping()
        return JSONResponse(content={"is_sleeping": sleeping})

    @router.post("/collective_rpc")
    async def collective_rpc(raw_request: Request):
        # Forward `method`/`args`/`kwargs`/`timeout` from the JSON body to
        # the engine client's `collective_rpc` and return its results.
        try:
            body = await raw_request.json()
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail=f"JSON decode error: {e}",
            ) from e
        # Valid JSON need not be an object; `.get` below requires a dict.
        if not isinstance(body, dict):
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail="Request body must be a JSON object",
            )
        method = body.get("method")
        if method is None:
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail="Missing 'method' in request body",
            )
        # For security reason, only serialized string args/kwargs are passed.
        # User-defined `method` is responsible for deserialization if needed.
        args: list[str] = body.get("args", [])
        kwargs: dict[str, str] = body.get("kwargs", {})
        # Reject wrongly shaped containers up front instead of failing (or
        # silently splitting a string into characters) in `tuple(args)`.
        if not isinstance(args, list) or not isinstance(kwargs, dict):
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail="'args' must be a list and 'kwargs' must be an object",
            )
        timeout: float | None = body.get("timeout")
        results = await engine_client(raw_request).collective_rpc(
            method=method, timeout=timeout, args=tuple(args), kwargs=kwargs
        )
        if results is None:
            return Response(status_code=200)
        # None, dicts and lists are returned as-is; anything else is
        # stringified so the response can be JSON-encoded.
        response: list[Any] = [
            result
            if result is None or isinstance(result, (dict, list))
            else str(result)
            for result in results
        ]
        return JSONResponse(content={"results": response})

1145

1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
@router.post(
    "/scale_elastic_ep",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"model": dict},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
async def scale_elastic_ep(raw_request: Request):
    # Validates `new_data_parallel_size` (required) and `drain_timeout`
    # (optional, default 120) from the JSON body, then asks the engine
    # client to scale. Responds 400 on invalid input, 408 when the engine
    # client raises TimeoutError, and 500 on any other failure.
    try:
        body = await raw_request.json()
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail="Invalid JSON format") from e

    # Valid JSON need not be an object; `.get` below requires a dict.
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    new_data_parallel_size = body.get("new_data_parallel_size")
    drain_timeout = body.get("drain_timeout", 120)  # Default 2 minutes

    if new_data_parallel_size is None:
        raise HTTPException(
            status_code=400, detail="new_data_parallel_size is required"
        )

    # `bool` is a subclass of `int`, so JSON `true`/`false` must be
    # excluded explicitly.
    if (
        isinstance(new_data_parallel_size, bool)
        or not isinstance(new_data_parallel_size, int)
        or new_data_parallel_size <= 0
    ):
        raise HTTPException(
            status_code=400, detail="new_data_parallel_size must be a positive integer"
        )

    if (
        isinstance(drain_timeout, bool)
        or not isinstance(drain_timeout, int)
        or drain_timeout <= 0
    ):
        raise HTTPException(
            status_code=400, detail="drain_timeout must be a positive integer"
        )

    # Set scaling flag to prevent new requests
    global _scaling_elastic_ep
    _scaling_elastic_ep = True
    client = engine_client(raw_request)
    try:
        await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
        return JSONResponse(
            {
                "message": f"Scaled to {new_data_parallel_size} data parallel engines",
            }
        )
    except TimeoutError as e:
        raise HTTPException(
            status_code=408,
            detail="Scale failed due to request drain timeout "
            f"after {drain_timeout} seconds",
        ) from e
    except Exception as e:
        logger.error("Scale failed: %s", e)
        raise HTTPException(status_code=500, detail="Scale failed") from e
    finally:
        # Always clear the flag, whether scaling succeeded or failed.
        _scaling_elastic_ep = False


@router.post("/is_scaling_elastic_ep")
async def is_scaling_elastic_ep(raw_request: Request):
    # Report the current value of the module-level scaling flag.
    payload = {"is_scaling_elastic_ep": _scaling_elastic_ep}
    return JSONResponse(payload)


1209
1210
1211
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
RequestType = Any
# Given the raw request, returns the serving handler or None.
GetHandlerFn = Callable[[Request], OpenAIServing | None]
# Route coroutine taking the validated request and the raw request.
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]

# Request model -> (handler getter, endpoint) pairs used by `/invocations`.
# NOTE: Items defined earlier take higher priority
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [
    (ChatCompletionRequest, (chat, create_chat_completion)),
    (CompletionRequest, (completion, create_completion)),
    (EmbeddingRequest, (embedding, create_embedding)),
    (ClassificationRequest, (classify, create_classify)),
    (ScoreRequest, (score, create_score)),
    (RerankRequest, (rerank, do_rerank)),
    (PoolingRequest, (pooling, create_pooling)),
]

# NOTE: Construct the TypeAdapters only once
INVOCATION_VALIDATORS = [
    (pydantic.TypeAdapter(request_model), route_pair)
    for request_model, route_pair in INVOCATION_TYPES
]


1233
1234
1235
1236
1237
1238
1239
1240
1241
@router.post(
    "/invocations",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
async def invocations(raw_request: Request):
    """For SageMaker, routes requests based on the request type."""
    try:
        payload = await raw_request.json()
    except json.JSONDecodeError as exc:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value, detail=f"JSON decode error: {exc}"
        ) from exc

    # Keep only the endpoints whose handler is available for this app.
    candidates = []
    for adapter, (get_handler, endpoint) in INVOCATION_VALIDATORS:
        if get_handler(raw_request) is not None:
            candidates.append((adapter, endpoint))

    # Dispatch to the first candidate whose request model accepts the body.
    for adapter, endpoint in candidates:
        try:
            parsed = adapter.validate_python(payload)
        except pydantic.ValidationError:
            continue

        return await endpoint(parsed, raw_request)

    # Nothing matched: report which request types were tried.
    type_names = []
    for adapter, _ in candidates:
        model = adapter._type
        type_names.append(model.__name__ if isinstance(model, type) else str(model))

    msg = f"Cannot find suitable handler for request. Expected one of: {type_names}"
    res = base(raw_request).create_error_response(message=msg)
    return JSONResponse(content=res.model_dump(), status_code=res.error.code)
1272
1273


1274
1275
1276
if envs.VLLM_TORCH_PROFILER_DIR:
    # Profiler control endpoints are registered only when
    # VLLM_TORCH_PROFILER_DIR is set.
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "
        "used for local development!"
    )

    @router.post("/start_profile")
    async def start_profile(raw_request: Request):
        # Log, start the profiler through the engine client, log again.
        logger.info("Starting profiler...")
        client = engine_client(raw_request)
        await client.start_profile()
        logger.info("Profiler started.")
        return Response(status_code=200)

    @router.post("/stop_profile")
    async def stop_profile(raw_request: Request):
        # Log, stop the profiler through the engine client, log again.
        logger.info("Stopping profiler...")
        client = engine_client(raw_request)
        await client.stop_profile()
        logger.info("Profiler stopped.")
        return Response(status_code=200)


1295
1296
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
    # Runtime LoRA load/unload endpoints are registered only when
    # VLLM_ALLOW_RUNTIME_LORA_UPDATING is set.
    logger.warning(
        "LoRA dynamic loading & unloading is enabled in the API server. "
        "This should ONLY be used for local development!"
    )

    @router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)])
    async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request):
        # Delegate to the models handler; an ErrorResponse is returned as
        # JSON with its own status code, anything else as a plain 200 body.
        serving_models = models(raw_request)
        outcome = await serving_models.load_lora_adapter(request)
        if not isinstance(outcome, ErrorResponse):
            return Response(status_code=200, content=outcome)

        return JSONResponse(
            content=outcome.model_dump(), status_code=outcome.error.code
        )

    @router.post(
        "/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)]
    )
    async def unload_lora_adapter(
        request: UnloadLoRAAdapterRequest, raw_request: Request
    ):
        # Delegate to the models handler; an ErrorResponse is returned as
        # JSON with its own status code, anything else as a plain 200 body.
        serving_models = models(raw_request)
        outcome = await serving_models.unload_lora_adapter(request)
        if not isinstance(outcome, ErrorResponse):
            return Response(status_code=200, content=outcome)

        return JSONResponse(
            content=outcome.model_dump(), status_code=outcome.error.code
        )


1328
def load_log_config(log_config_file: str | None) -> dict | None:
1329
1330
1331
1332
1333
1334
    if not log_config_file:
        return None
    try:
        with open(log_config_file) as f:
            return json.load(f)
    except Exception as e:
1335
1336
1337
        logger.warning(
            "Failed to load log config from file %s: error %s", log_config_file, e
        )
1338
1339
1340
        return None


1341
1342
1343
class AuthenticationMiddleware:
    """
    Pure ASGI middleware that authenticates each request by checking
    if the Authorization Bearer token exists and equals any of "{api_key}".

    Notes
    -----
    There are two cases in which authentication is skipped:
        1. The HTTP method is OPTIONS.
        2. The request path doesn't start with /v1 (e.g. /health).
    """

    def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
        self.app = app
        # Only SHA-256 digests of the accepted tokens are kept.
        self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]

    def verify_token(self, headers: Headers) -> bool:
        """Return True if *headers* carries a Bearer token matching any key."""
        authorization_header_value = headers.get("Authorization")
        if not authorization_header_value:
            return False

        scheme, _, param = authorization_header_value.partition(" ")
        if scheme.lower() != "bearer":
            return False

        param_hash = hashlib.sha256(param.encode("utf-8")).digest()

        # Every stored digest is compared (no early exit on a match), using
        # `secrets.compare_digest` for each comparison.
        token_match = False
        for token_hash in self.api_tokens:
            token_match |= secrets.compare_digest(param_hash, token_hash)

        return token_match

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        if scope["type"] not in ("http", "websocket"):
            # scope["type"] can be "lifespan" or "startup" for example,
            # in which case we don't need to do anything
            return self.app(scope, receive, send)
        # Websocket scopes carry no "method" key, so use `.get` here;
        # indexing with scope["method"] raises KeyError for them.
        if scope.get("method") == "OPTIONS":
            return self.app(scope, receive, send)
        root_path = scope.get("root_path", "")
        url_path = URL(scope=scope).path.removeprefix(root_path)
        headers = Headers(scope=scope)
        if url_path.startswith("/v1") and not self.verify_token(headers):
            # NOTE(review): an unauthorized websocket is also answered with
            # this HTTP-style JSONResponse - confirm the ASGI server in use
            # handles that as intended.
            response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
            return response(scope, receive, send)
        return self.app(scope, receive, send)


class XRequestIdMiddleware:
    """
    Middleware that sets the X-Request-Id header on each response.

    The value is the X-Request-Id sent with the request when present,
    otherwise a random uuid4 (hex) value.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        is_web_scope = scope["type"] in ("http", "websocket")
        if not is_web_scope:
            return self.app(scope, receive, send)

        # Headers of the incoming request, consulted when the response starts.
        incoming_headers = Headers(scope=scope)

        async def tagging_send(message: Message) -> None:
            """
            Replacement send callable: appends X-Request-Id to the
            response headers, then forwards the message unchanged.
            """
            if message["type"] == "http.response.start":
                outgoing_headers = MutableHeaders(raw=message["headers"])
                generated_id = uuid.uuid4().hex
                request_id = incoming_headers.get("X-Request-Id", generated_id)
                outgoing_headers.append("X-Request-Id", request_id)
            await send(message)

        return self.app(scope, receive, tagging_send)


1420
1421
1422
1423
1424
1425
1426
1427
# Module-level flag recording whether a scaling operation is in progress
_scaling_elastic_ep = False


class ScalingMiddleware:
    """
    Middleware that answers with 503 Service Unavailable while the
    model is scaling.

    Only HTTP scopes are checked; every other scope type is passed
    straight through to the wrapped application.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        is_http = scope["type"] == "http"

        # Short-circuit HTTP requests while the global scaling flag is set.
        if is_http and _scaling_elastic_ep:
            unavailable = JSONResponse(
                content={
                    "error": "The model is currently scaling. Please try again later."
                },
                status_code=503,
            )
            return unavailable(scope, receive, send)

        return self.app(scope, receive, send)


1455
1456
1457
1458
def _extract_content_from_chunk(chunk_data: dict) -> str:
    """Return the text carried by one streaming chunk, or "" if none.

    The chunk is first validated against the typed stream-response models
    according to its "object" field; if validation fails, the first choice
    is inspected by hand instead.
    """
    try:
        from vllm.entrypoints.openai.protocol import (
            ChatCompletionStreamResponse,
            CompletionStreamResponse,
        )

        # Typed parsing, selected by the chunk's "object" field.
        if chunk_data.get("object") == "chat.completion.chunk":
            parsed_chat = ChatCompletionStreamResponse.model_validate(chunk_data)
            if parsed_chat.choices:
                delta_text = parsed_chat.choices[0].delta.content
                if delta_text:
                    return delta_text
        elif chunk_data.get("object") == "text_completion":
            parsed_completion = CompletionStreamResponse.model_validate(chunk_data)
            if parsed_completion.choices:
                completion_text = parsed_completion.choices[0].text
                if completion_text:
                    return completion_text
    except pydantic.ValidationError:
        # Validation failed: look at the first choice manually.
        if "choices" in chunk_data and chunk_data["choices"]:
            first_choice = chunk_data["choices"][0]
            if "delta" in first_choice and first_choice["delta"].get("content"):
                return first_choice["delta"]["content"]
            if first_choice.get("text"):
                return first_choice["text"]
    return ""


class SSEDecoder:
    """Robust Server-Sent Events decoder for streaming responses."""

    def __init__(self):
        import codecs

        # Text received so far that does not yet end in a newline.
        self.buffer = ""
        # Content fragments collected through `add_content`.
        self.content_buffer = []
        # Incremental decoder: a multi-byte UTF-8 character split across
        # two chunks is held back and completed by the next chunk instead
        # of making both chunks undecodable.
        self._utf8_decoder = codecs.getincrementaldecoder("utf-8")()

    def decode_chunk(self, chunk: bytes) -> list[dict]:
        """Decode a chunk of SSE data and return parsed events.

        Each complete ``data:`` line yields ``{"type": "done"}`` for the
        ``[DONE]`` sentinel or ``{"type": "data", "data": <parsed JSON>}``.
        Lines with invalid JSON and chunks with invalid UTF-8 are skipped.
        """
        try:
            chunk_str = self._utf8_decoder.decode(chunk)
        except UnicodeDecodeError:
            # Skip malformed chunks, and drop any partial character state
            # so that following chunks decode cleanly.
            self._utf8_decoder.reset()
            return []

        self.buffer += chunk_str
        events = []

        # Process complete lines
        while "\n" in self.buffer:
            line, self.buffer = self.buffer.split("\n", 1)
            line = line.rstrip("\r")  # Handle CRLF

            # The space after "data:" is optional in the SSE format.
            if not line.startswith("data:"):
                continue

            data_str = line[len("data:"):].strip()
            if data_str == "[DONE]":
                events.append({"type": "done"})
            elif data_str:
                try:
                    event_data = json.loads(data_str)
                except json.JSONDecodeError:
                    # Skip malformed JSON
                    continue
                events.append({"type": "data", "data": event_data})

        return events

    def extract_content(self, event_data: dict) -> str:
        """Extract content from event data."""
        return _extract_content_from_chunk(event_data)

    def add_content(self, content: str) -> None:
        """Add content to the buffer."""
        if content:
            self.content_buffer.append(content)

    def get_complete_content(self) -> str:
        """Get the complete buffered content."""
        return "".join(self.content_buffer)
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553


def _log_streaming_response(response, response_body: list) -> None:
    """Log streaming response with robust SSE parsing.

    Replaces ``response.body_iterator`` with an iterator that re-yields
    the chunks in *response_body* while decoding them as SSE; when a
    ``[DONE]`` event is seen, the accumulated content (truncated to 2048
    characters) is logged and the iterator stops.
    """
    from starlette.concurrency import iterate_in_threadpool

    max_logged_chars = 2048
    sse_decoder = SSEDecoder()
    chunk_count = 0

    def buffered_iterator():
        nonlocal chunk_count

        for chunk in response_body:
            chunk_count += 1
            # Forward the chunk first; the parsing below is for logging only.
            yield chunk

            # Parse SSE events from chunk
            for event in sse_decoder.decode_chunk(chunk):
                if event["type"] == "data":
                    content = sse_decoder.extract_content(event["data"])
                    sse_decoder.add_content(content)
                elif event["type"] == "done":
                    # Log complete content when done
                    full_content = sse_decoder.get_complete_content()
                    if full_content:
                        # Truncate if too long, marking the cut explicitly.
                        if len(full_content) > max_logged_chars:
                            full_content = (
                                full_content[:max_logged_chars] + "...[truncated]"
                            )
                        logger.info(
                            "response_body={streaming_complete: "
                            "content='%s', chunks=%d}",
                            full_content,
                            chunk_count,
                        )
                    else:
                        logger.info(
                            "response_body={streaming_complete: no_content, chunks=%d}",
                            chunk_count,
                        )
                    return

    response.body_iterator = iterate_in_threadpool(buffered_iterator())
    logger.info("response_body={streaming_started: chunks=%d}", len(response_body))
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590


def _log_non_streaming_response(response_body: list) -> None:
    """Log non-streaming response.

    All chunks in *response_body* are joined and decoded for the log line;
    a body that cannot be decoded as text is logged as a binary
    placeholder. An empty list is logged as an empty body.
    """
    try:
        # Join every chunk so an empty list does not raise IndexError and
        # a body split over several chunks is logged in full.
        decoded_body = b"".join(response_body).decode()
        logger.info("response_body={%s}", decoded_body)
    except UnicodeDecodeError:
        logger.info("response_body={<binary_data>}")


1591
def build_app(args: Namespace) -> FastAPI:
    """Create the FastAPI application for the API server.

    The app gets the module-level ``router``, the metrics mount, CORS
    settings taken from ``args``, JSON error handlers for ``HTTPException``
    and ``RequestValidationError``, and the optional authentication,
    request-id, scaling, response-logging and user-supplied middlewares.

    Args:
        args: Parsed server arguments.

    Returns:
        The configured ``FastAPI`` instance.

    Raises:
        ValueError: If an entry of ``args.middleware`` resolves to something
            that is neither a class nor a coroutine function.
    """
    if not args.disable_fastapi_docs:
        app = FastAPI(lifespan=lifespan)
    else:
        # Docs disabled: turn off the OpenAPI schema and both doc UIs.
        app = FastAPI(
            lifespan=lifespan, openapi_url=None, docs_url=None, redoc_url=None
        )
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(HTTPException)
    async def http_exception_handler(_: Request, exc: HTTPException):
        # Wrap the HTTP error in the server's ErrorResponse envelope.
        status = exc.status_code
        info = ErrorInfo(
            message=exc.detail,
            type=HTTPStatus(status).phrase,
            code=status,
        )
        payload = ErrorResponse(error=info)
        return JSONResponse(payload.model_dump(), status_code=status)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_: Request, exc: RequestValidationError):
        details = exc.errors()
        exc_str = str(exc)
        errors_str = str(details)

        # Append the structured error list only when it adds information
        # beyond the exception's own string form.
        message = exc_str
        if details and errors_str and errors_str != exc_str:
            message = f"{exc_str} {errors_str}"

        bad_request = HTTPStatus.BAD_REQUEST
        info = ErrorInfo(
            message=message,
            type=bad_request.phrase,
            code=bad_request,
        )
        payload = ErrorResponse(error=info)
        return JSONResponse(payload.model_dump(), status_code=bad_request)

    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
    candidate_keys = args.api_key or [envs.VLLM_API_KEY]
    tokens = [key for key in candidate_keys if key]
    if tokens:
        app.add_middleware(AuthenticationMiddleware, tokens=tokens)

    if args.enable_request_id_headers:
        app.add_middleware(XRequestIdMiddleware)

    # Add scaling middleware to check for scaling state
    app.add_middleware(ScalingMiddleware)

    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
        logger.warning(
            "CAUTION: Enabling log response in the API Server. "
            "This can include sensitive information and should be "
            "avoided in production."
        )

        @app.middleware("http")
        async def log_response(request: Request, call_next):
            response = await call_next(request)

            # Drain the body so it can be logged, then hand the client a
            # fresh iterator over the buffered sections.
            response_body = []
            async for section in response.body_iterator:
                response_body.append(section)
            response.body_iterator = iterate_in_threadpool(iter(response_body))

            # Streaming responses are recognised by their content-type.
            content_type = response.headers.get("content-type", "")
            is_streaming = content_type == "text/event-stream; charset=utf-8"

            if not response_body:
                logger.info("response_body={<empty>}")
            elif is_streaming:
                _log_streaming_response(response, response_body)
            else:
                _log_non_streaming_response(response_body)
            return response

    # User-supplied middlewares are given as dotted "module.object" paths.
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        module = importlib.import_module(module_path)
        imported = getattr(module, object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)  # type: ignore[arg-type]
            continue
        if inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
            continue
        raise ValueError(
            f"Invalid middleware {middleware}. Must be a function or a class."
        )

    return app


1691
async def init_app_state(
    engine_client: EngineClient,
    state: State,
    args: Namespace,
) -> None:
    """Populate ``state`` with the engine client and the serving handlers.

    Each handler attribute is set to a handler instance when the engine
    reports a matching supported task, and to ``None`` otherwise. The
    tokenization handler is always created.

    Args:
        engine_client: Client used to talk to the engine; also provides the
            ``vllm_config`` and the list of supported tasks.
        state: Application state object that receives the attributes.
        args: Parsed server arguments.
    """
    vllm_config = engine_client.vllm_config

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.enable_log_requests:
        request_logger = RequestLogger(max_log_len=args.max_log_len)
    else:
        request_logger = None

    # Every served name maps to the same underlying model path.
    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model) for name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.vllm_config = vllm_config

    supported_tasks = await engine_client.get_supported_tasks()
    logger.info("Supported tasks: %s", supported_tasks)

    resolved_chat_template = await process_chat_template(
        args.chat_template, engine_client, vllm_config.model_config
    )

    if args.tool_server == "demo":
        tool_server: ToolServer | None = DemoToolServer()
        assert isinstance(tool_server, DemoToolServer)
        await tool_server.init_and_validate()
    elif args.tool_server:
        tool_server = MCPToolServer()
        await tool_server.add_tool_server(args.tool_server)
    else:
        tool_server = None

    # Merge default_mm_loras into the static lora_modules
    default_mm_loras = (
        vllm_config.lora_config.default_mm_loras
        if vllm_config.lora_config is not None
        else {}
    )
    lora_modules = process_lora_modules(args.lora_modules, default_mm_loras)

    state.openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=base_model_paths,
        lora_modules=lora_modules,
    )
    await state.openai_serving_models.init_static_loras()
    state.openai_serving_responses = (
        OpenAIServingResponses(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            tool_server=tool_server,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            log_error_stack=args.log_error_stack,
        )
        if "generate" in supported_tasks
        else None
    )
    state.openai_serving_chat = (
        OpenAIServingChat(
            engine_client,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            log_error_stack=args.log_error_stack,
        )
        if "generate" in supported_tasks
        else None
    )
    state.openai_serving_completion = (
        OpenAIServingCompletion(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            log_error_stack=args.log_error_stack,
        )
        if "generate" in supported_tasks
        else None
    )
    state.openai_serving_pooling = (
        (
            OpenAIServingPooling(
                engine_client,
                state.openai_serving_models,
                supported_tasks=supported_tasks,
                request_logger=request_logger,
                chat_template=resolved_chat_template,
                chat_template_content_format=args.chat_template_content_format,
                trust_request_chat_template=args.trust_request_chat_template,
                log_error_stack=args.log_error_stack,
            )
        )
        if any(task in POOLING_TASKS for task in supported_tasks)
        else None
    )
    state.openai_serving_embedding = (
        OpenAIServingEmbedding(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if "embed" in supported_tasks
        else None
    )
    state.openai_serving_classification = (
        ServingClassification(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
        )
        if "classify" in supported_tasks
        else None
    )
    state.openai_serving_scores = (
        ServingScores(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
        )
        if ("embed" in supported_tasks or "score" in supported_tasks)
        else None
    )
    # Tokenization is set up regardless of the supported tasks.
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        state.openai_serving_models,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
        trust_request_chat_template=args.trust_request_chat_template,
        log_error_stack=args.log_error_stack,
    )
    state.openai_serving_transcription = (
        OpenAIServingTranscription(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )
    # Translation is gated on the same "transcription" task as above.
    state.openai_serving_translation = (
        OpenAIServingTranslation(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )
    state.anthropic_serving_messages = (
        AnthropicServingMessages(
            engine_client,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "generate" in supported_tasks
        else None
    )

    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0

1910

1911
def create_server_socket(addr: tuple[str, int]) -> socket.socket:
    """Create a TCP socket bound to ``addr``.

    The address family is IPv6 when ``addr[0]`` is a valid IPv6 address and
    IPv4 otherwise. ``SO_REUSEADDR`` and ``SO_REUSEPORT`` are set before
    binding.

    Args:
        addr: ``(host, port)`` pair to bind to.

    Returns:
        The bound socket. The caller is responsible for closing it.

    Raises:
        OSError: If configuring or binding the socket fails. The socket is
            closed before the error propagates.
    """
    family = socket.AF_INET
    if is_valid_ipv6_address(addr[0]):
        family = socket.AF_INET6

    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(addr)
    except BaseException:
        # Do not leak the file descriptor when setup fails.
        sock.close()
        raise

    return sock


1924
1925
1926
1927
1928
1929
def create_server_unix_socket(path: str) -> socket.socket:
    """Create a Unix domain stream socket bound to ``path``.

    Args:
        path: Filesystem path to bind the socket to.

    Returns:
        The bound socket. The caller is responsible for closing it.

    Raises:
        OSError: If binding fails. The socket is closed before the error
            propagates.
    """
    sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
    try:
        sock.bind(path)
    except BaseException:
        # Do not leak the file descriptor when the bind fails.
        sock.close()
        raise
    return sock


1930
def validate_api_server_args(args):
    """Check that the parser names given in ``args`` are registered.

    Args:
        args: Parsed server arguments.

    Raises:
        KeyError: If auto tool choice is enabled and ``args.tool_call_parser``
            is not a registered tool parser, or if a reasoning parser is set
            and is not a registered reasoning parser.
    """
    known_tool_parsers = ToolParserManager.tool_parsers.keys()
    if args.enable_auto_tool_choice:
        if args.tool_call_parser not in known_tool_parsers:
            tool_choices = ",".join(known_tool_parsers)
            raise KeyError(
                f"invalid tool call parser: {args.tool_call_parser} "
                f"(chose from {{ {tool_choices} }})"
            )

    known_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys()
    reasoning_parser = args.structured_outputs_config.reasoning_parser
    # An unset (falsy) reasoning parser is allowed and skips the check.
    if reasoning_parser and reasoning_parser not in known_reasoning_parsers:
        reasoning_choices = ",".join(known_reasoning_parsers)
        raise KeyError(
            f"invalid reasoning parser: {reasoning_parser} "
            f"(chose from {{ {reasoning_choices} }})"
        )
1946

1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959

def setup_server(args):
    """Validate API server args, set up signal handler, create socket
    ready to serve.

    Args:
        args: Parsed server arguments.

    Returns:
        A ``(listen_address, sock)`` tuple: a printable address string and
        the bound server socket.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    log_non_default_args(args)

    plugin = args.tool_parser_plugin
    if plugin and len(plugin) > 3:
        ToolParserManager.import_tool_parser(plugin)

    validate_api_server_args(args)

    # workaround to make sure that we bind the port before the engine is set up.
    # This avoids race conditions with ray.
    # see https://github.com/vllm-project/vllm/issues/8204
    if args.uds:
        sock = create_server_unix_socket(args.uds)
    else:
        sock_addr = (args.host or "", args.port)
        sock = create_server_socket(sock_addr)

    # workaround to avoid footguns where uvicorn drops requests with too
    # many concurrent requests active
    set_ulimit()

    def _interrupt_on_sigterm(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, _interrupt_on_sigterm)

    if args.uds:
        return f"unix:{args.uds}", sock

    addr, port = sock_addr
    scheme = "https" if (args.ssl_keyfile and args.ssl_certfile) else "http"
    if is_valid_ipv6_address(addr):
        # IPv6 literals are bracketed inside a URL.
        host_part = f"[{addr}]"
    else:
        # An empty host means "all interfaces"; show it as 0.0.0.0.
        host_part = addr or "0.0.0.0"
    listen_address = f"{scheme}://{host_part}:{port}"
    return listen_address, sock


async def run_server(args, **uvicorn_kwargs) -> None:
    """Run a single-worker API server.

    Args:
        args: Parsed server arguments.
        **uvicorn_kwargs: Extra keyword arguments forwarded to
            ``run_server_worker``.
    """
    # Add process-specific prefix to stdout and stderr.
    decorate_logs("APIServer")

    address, server_sock = setup_server(args)
    await run_server_worker(address, server_sock, args, **uvicorn_kwargs)


1999
2000
2001
async def run_server_worker(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Run a single API server worker.

    Args:
        listen_address: Printable address, used only for the startup log.
        sock: Already-bound server socket; closed when the server has shut
            down.
        args: Parsed server arguments.
        client_config: Optional configuration passed through to
            ``build_async_engine_client``.
        **uvicorn_kwargs: Extra keyword arguments forwarded to
            ``serve_http``.
    """
    plugin = args.tool_parser_plugin
    if plugin and len(plugin) > 3:
        ToolParserManager.import_tool_parser(plugin)

    # Load logging config for uvicorn if specified
    uvicorn_log_config = load_log_config(args.log_config_file)
    if uvicorn_log_config is not None:
        uvicorn_kwargs["log_config"] = uvicorn_log_config

    engine_context = build_async_engine_client(args, client_config=client_config)
    async with engine_context as engine_client:
        maybe_register_tokenizer_info_endpoint(args)
        app = build_app(args)

        await init_app_state(engine_client, app.state, args)

        api_rank = engine_client.vllm_config.parallel_config._api_process_rank
        logger.info("Starting vLLM API server %d on %s", api_rank, listen_address)

        shutdown_task = await serve_http(
            app,
            sock=sock,
            enable_ssl_refresh=args.enable_ssl_refresh,
            host=args.host,
            port=args.port,
            log_level=args.uvicorn_log_level,
            # NOTE: When the 'disable_uvicorn_access_log' value is True,
            # no access log will be output.
            access_log=not args.disable_uvicorn_access_log,
            timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            ssl_ca_certs=args.ssl_ca_certs,
            ssl_cert_reqs=args.ssl_cert_reqs,
            h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
            h11_max_header_count=args.h11_max_header_count,
            **uvicorn_kwargs,
        )

    # NB: Await server shutdown only after the backend context is exited
    try:
        await shutdown_task
    finally:
        sock.close()
2051

Ethan Xu's avatar
Ethan Xu committed
2052
2053
2054

if __name__ == "__main__":
    # NOTE(simon):
    # Keep this entry point in sync with vllm/entrypoints/cli/main.py, which
    # provides the CLI entrypoints.
    cli_env_setup()
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."
        )
    )
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    uvloop.run(run_server(args))