api_server.py 69.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import asyncio
5
import gc
6
import hashlib
7
8
import importlib
import inspect
9
import json
10
import multiprocessing
11
import multiprocessing.forkserver as forkserver
12
import os
13
import secrets
14
import signal
15
import socket
16
import tempfile
17
import uuid
18
from argparse import Namespace
19
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
20
from contextlib import asynccontextmanager
21
from http import HTTPStatus
22
from typing import Annotated, Any, Literal
23

24
import prometheus_client
25
import pydantic
26
import regex as re
27
import uvloop
28
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Query, Request
Zhuohan Li's avatar
Zhuohan Li committed
29
30
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
31
from fastapi.responses import JSONResponse, Response, StreamingResponse
32
33
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
34
from starlette.concurrency import iterate_in_threadpool
35
from starlette.datastructures import URL, Headers, MutableHeaders, State
36
from starlette.routing import Mount
37
from starlette.types import ASGIApp, Message, Receive, Scope, Send
38
from typing_extensions import assert_never
Zhuohan Li's avatar
Zhuohan Li committed
39

40
import vllm.envs as envs
41
from vllm.config import VllmConfig
Woosuk Kwon's avatar
Woosuk Kwon committed
42
from vllm.engine.arg_utils import AsyncEngineArgs
43
from vllm.engine.protocol import EngineClient
44
45
46
47
48
from vllm.entrypoints.chat_utils import (
    load_chat_template,
    resolve_hf_chat_template,
    resolve_mistral_chat_template,
)
49
from vllm.entrypoints.launcher import serve_http
50
from vllm.entrypoints.logger import RequestLogger
51
52
53
54
55
56
57
58
59
60
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ClassificationRequest,
    ClassificationResponse,
    CompletionRequest,
    CompletionResponse,
    DetokenizeRequest,
    DetokenizeResponse,
61
    EmbeddingBytesResponse,
62
63
64
65
66
67
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorInfo,
    ErrorResponse,
    IOProcessorResponse,
    LoadLoRAAdapterRequest,
68
    PoolingBytesResponse,
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
    PoolingRequest,
    PoolingResponse,
    RerankRequest,
    RerankResponse,
    ResponsesRequest,
    ResponsesResponse,
    ScoreRequest,
    ScoreResponse,
    StreamingResponsesResponse,
    TokenizeRequest,
    TokenizeResponse,
    TranscriptionRequest,
    TranscriptionResponse,
    TranslationRequest,
    TranslationResponse,
    UnloadLoRAAdapterRequest,
)
86
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
87
from vllm.entrypoints.openai.serving_classification import ServingClassification
88
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
89
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
90
from vllm.entrypoints.openai.serving_engine import OpenAIServing
91
92
93
94
95
from vllm.entrypoints.openai.serving_models import (
    BaseModelPath,
    LoRAModulePath,
    OpenAIServingModels,
)
96
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
97
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
98
from vllm.entrypoints.openai.serving_score import ServingScores
99
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
100
from vllm.entrypoints.openai.serving_transcription import (
101
102
103
    OpenAIServingTranscription,
    OpenAIServingTranslation,
)
104
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
105
106
107
108
109
110
111
from vllm.entrypoints.tool_server import DemoToolServer, MCPToolServer, ToolServer
from vllm.entrypoints.utils import (
    cli_env_setup,
    load_aware_call,
    log_non_default_args,
    with_cancellation,
)
112
from vllm.logger import init_logger
113
from vllm.reasoning import ReasoningParserManager
114
from vllm.transformers_utils.tokenizer import MistralTokenizer
yhu422's avatar
yhu422 committed
115
from vllm.usage.usage_lib import UsageContext
116
117
118
119
120
121
from vllm.utils import (
    Device,
    FlexibleArgumentParser,
    decorate_logs,
    set_ulimit,
)
122
from vllm.utils.network_utils import is_valid_ipv6_address
123
from vllm.v1.engine.exceptions import EngineDeadError
124
from vllm.v1.metrics.prometheus import get_prometheus_registry
125
from vllm.version import __version__ as VLLM_VERSION
Zhuohan Li's avatar
Zhuohan Li committed
126

127
# Module-level declaration only: the name is annotated here but no value is
# bound on this line.
prometheus_multiproc_dir: tempfile.TemporaryDirectory
128

129
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
130
# The logger name is spelled out as a literal rather than derived from __name__.
logger = init_logger("vllm.entrypoints.openai.api_server")
131

132
# Background tasks started by `lifespan` are added to this set and remove
# themselves from it through a done-callback.
_running_tasks: set[asyncio.Task] = set()
133

134

135
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup/shutdown work around the application's lifetime.

    On entry, if ``app.state.log_stats`` is truthy, a background task is
    started that repeatedly sleeps for ``envs.VLLM_LOG_STATS_INTERVAL``
    seconds and then awaits ``do_log_stats()`` on ``app.state.engine_client``.
    On exit that task (if one was started) is cancelled, and ``app.state``
    is deleted regardless of how the body finished.
    """
    try:
        stats_task: asyncio.Task | None = None

        if app.state.log_stats:
            client: EngineClient = app.state.engine_client

            async def _log_stats_forever():
                while True:
                    await asyncio.sleep(envs.VLLM_LOG_STATS_INTERVAL)
                    await client.do_log_stats()

            stats_task = asyncio.create_task(_log_stats_forever())
            # Track the task in the module-level set until it completes.
            _running_tasks.add(stats_task)
            stats_task.add_done_callback(_running_tasks.remove)

        # Collect once, then freeze everything that survived startup so the
        # garbage collector ignores it in later collections.
        gc.collect()
        gc.freeze()

        try:
            yield
        finally:
            if stats_task is not None:
                stats_task.cancel()
    finally:
        # Drop the app state (and the engine reference it holds).
        del app.state
164
165


166
@asynccontextmanager
async def build_async_engine_client(
    args: Namespace,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool | None = None,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """Build engine args from parsed CLI args and yield an engine client.

    Args:
        args: Parsed command-line namespace; converted with
            ``AsyncEngineArgs.from_cli_args``.
        usage_context: Forwarded to the engine-args based builder.
        disable_frontend_multiprocessing: When ``None``, the value is read
            from ``args.disable_frontend_multiprocessing``.
        client_config: Optional mapping; its ``client_count`` and
            ``client_index`` entries (defaults 1 and 0) are copied onto the
            engine args, and the mapping itself is forwarded unchanged.

    Yields:
        The engine client produced by
        ``build_async_engine_client_from_engine_args``; that context manager
        is exited when this one exits.
    """
    wants_forkserver = os.getenv("VLLM_WORKER_MULTIPROC_METHOD") == "forkserver"
    if wants_forkserver:
        # The executor is expected to be mp.
        # Pre-import heavy modules in the forkserver process.
        logger.debug("Setup forkserver with pre-imports")
        multiprocessing.set_start_method("forkserver")
        multiprocessing.set_forkserver_preload(["vllm.v1.engine.async_llm"])
        forkserver.ensure_running()
        logger.debug("Forkserver setup complete!")

    engine_args = AsyncEngineArgs.from_cli_args(args)
    if client_config:
        engine_args._api_process_count = client_config.get("client_count", 1)
        engine_args._api_process_rank = client_config.get("client_index", 0)

    if disable_frontend_multiprocessing is None:
        disable_frontend_multiprocessing = bool(args.disable_frontend_multiprocessing)

    # The nested context manager owns the engine client's lifecycle, so the
    # client is shut down on both normal exit and error.
    engine_cm = build_async_engine_client_from_engine_args(
        engine_args,
        usage_context=usage_context,
        disable_frontend_multiprocessing=disable_frontend_multiprocessing,
        client_config=client_config,
    )
    async with engine_cm as engine:
        yield engine


@asynccontextmanager
async def build_async_engine_client_from_engine_args(
    engine_args: AsyncEngineArgs,
    *,
    usage_context: UsageContext = UsageContext.OPENAI_API_SERVER,
    disable_frontend_multiprocessing: bool = False,
    client_config: dict[str, Any] | None = None,
) -> AsyncIterator[EngineClient]:
    """
    Create an ``AsyncLLM`` engine client and yield it.

    The engine config is built from ``engine_args``; ``envs.VLLM_USE_V1`` is
    asserted. ``disable_frontend_multiprocessing`` only triggers a warning
    here. ``client_count`` / ``client_index`` are taken out of a copy of
    ``client_config`` (defaults 1 and 0) and the remaining entries are passed
    as ``client_addresses``.

    Yields the ``AsyncLLM`` instance after resetting its multimodal cache.
    On exit, ``shutdown()`` is called on the instance if it was created.
    """

    # Create the EngineConfig (determines if we can use V1).
    vllm_config = engine_args.create_engine_config(usage_context=usage_context)

    # V1 AsyncLLM.
    assert envs.VLLM_USE_V1

    if disable_frontend_multiprocessing:
        logger.warning(
            "V1 is enabled, but got --disable-frontend-multiprocessing. "
            "To disable frontend multiprocessing, set VLLM_USE_V1=0."
        )

    from vllm.v1.engine.async_llm import AsyncLLM

    async_llm: AsyncLLM | None = None

    # Work on a copy so the caller's client_config is left untouched.
    addresses = dict(client_config) if client_config else {}
    client_count = addresses.pop("client_count", 1)
    client_index = addresses.pop("client_index", 0)

    try:
        async_llm = AsyncLLM.from_vllm_config(
            vllm_config=vllm_config,
            usage_context=usage_context,
            enable_log_requests=engine_args.enable_log_requests,
            aggregate_engine_logging=engine_args.aggregate_engine_logging,
            disable_log_stats=engine_args.disable_log_stats,
            client_addresses=addresses,
            client_count=client_count,
            client_index=client_index,
        )

        # Don't keep the dummy data in memory
        await async_llm.reset_mm_cache()

        yield async_llm
    finally:
        if async_llm:
            async_llm.shutdown()
258
259


260
261
async def validate_json_request(raw_request: Request):
    """Reject requests whose Content-Type media type is not JSON.

    Only the part of the ``Content-Type`` header before the first ``;`` is
    compared, case-insensitively, against ``application/json``. Surrounding
    whitespace is ignored, so a header such as
    ``application/json ; charset=utf-8`` is accepted.

    Raises:
        RequestValidationError: If the header is missing or names any other
            media type.
    """
    content_type = raw_request.headers.get("content-type", "")
    # Optional whitespace is allowed around the ";" that introduces
    # parameters, so strip before comparing.
    media_type = content_type.split(";", maxsplit=1)[0].strip().lower()
    if media_type != "application/json":
        raise RequestValidationError(
            errors=["Unsupported Media Type: Only 'application/json' is allowed"]
        )
267
268


Ethan Xu's avatar
Ethan Xu committed
269
# Router on which the endpoints in this module are registered through the
# @router.get / @router.post decorators.
router = APIRouter()
Zhuohan Li's avatar
Zhuohan Li committed
270

271

272
273
274
275
class PrometheusResponse(Response):
    """Response whose media type is ``prometheus_client.CONTENT_TYPE_LATEST``."""

    media_type = prometheus_client.CONTENT_TYPE_LATEST


276
def mount_metrics(app: FastAPI):
    """Instrument ``app`` and attach a ``/metrics`` route to it.

    The registry from ``get_prometheus_registry()`` is used both for the
    request instrumentation and for the ASGI app mounted at ``/metrics``.
    """
    registry = get_prometheus_registry()

    # Endpoints that are left out of the request instrumentation.
    uninstrumented = [
        "/metrics",
        "/health",
        "/load",
        "/ping",
        "/version",
        "/server_info",
    ]

    # `response_class=PrometheusResponse` is needed to return an HTTP response
    # with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
    # instead of the default "application/json" which is incorrect.
    # See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
    instrumentator = Instrumentator(
        excluded_handlers=uninstrumented,
        registry=registry,
    )
    instrumentator.add().instrument(app).expose(
        app, response_class=PrometheusResponse
    )

    # Add prometheus asgi middleware to route /metrics requests
    metrics_app = make_asgi_app(registry=registry)
    metrics_route = Mount("/metrics", metrics_app)

    # Workaround for 307 Redirect for /metrics
    metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
    app.routes.append(metrics_route)
303
304


305
306
307
308
309
def base(request: Request) -> OpenAIServing:
    """Return a generic serving handler by reusing the tokenization handler."""
    shared_handler = tokenization(request)
    return shared_handler


310
311
312
313
def models(request: Request) -> OpenAIServingModels:
    """Return ``openai_serving_models`` from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_models


314
def responses(request: Request) -> OpenAIServingResponses | None:
    """Return ``openai_serving_responses`` from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_responses


318
def chat(request: Request) -> OpenAIServingChat | None:
    """Return ``openai_serving_chat`` from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_chat


322
def completion(request: Request) -> OpenAIServingCompletion | None:
    """Return ``openai_serving_completion`` from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_completion


326
def pooling(request: Request) -> OpenAIServingPooling | None:
    """Return ``openai_serving_pooling`` from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_pooling


330
def embedding(request: Request) -> OpenAIServingEmbedding | None:
    """Return ``openai_serving_embedding`` from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_embedding
332
333


334
def score(request: Request) -> ServingScores | None:
    """Return ``openai_serving_scores`` from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_scores


338
def classify(request: Request) -> ServingClassification | None:
    """Return ``openai_serving_classification`` from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_classification


342
def rerank(request: Request) -> ServingScores | None:
    """Return ``openai_serving_scores`` (the same object ``score`` returns)."""
    app_state = request.app.state
    return app_state.openai_serving_scores
344
345


346
347
def tokenization(request: Request) -> OpenAIServingTokenization:
    """Return ``openai_serving_tokenization`` from the application state."""
    app_state = request.app.state
    return app_state.openai_serving_tokenization
348
349


350
351
352
353
def transcription(request: Request) -> OpenAIServingTranscription | None:
    """Return ``openai_serving_transcription`` from the application state.

    The return annotation includes ``None`` because ``create_transcriptions``
    checks the result against ``None`` before using it, in line with the
    other optional handler accessors in this module.
    """
    return request.app.state.openai_serving_transcription


354
355
356
357
def translation(request: Request) -> OpenAIServingTranslation | None:
    """Return ``openai_serving_translation`` from the application state.

    The return annotation includes ``None`` because ``create_translations``
    checks the result against ``None`` before using it, in line with the
    other optional handler accessors in this module.
    """
    return request.app.state.openai_serving_translation


358
def engine_client(request: Request) -> EngineClient:
    """Return ``engine_client`` from the application state."""
    app_state = request.app.state
    return app_state.engine_client


362
363
@router.get("/health", response_class=Response)
async def health(raw_request: Request) -> Response:
    """Health check."""
    # 200 when the engine's health check completes, 503 when it reports
    # that the engine is dead; any other exception propagates.
    try:
        await engine_client(raw_request).check_health()
    except EngineDeadError:
        status = 503
    else:
        status = 200
    return Response(status_code=status)
370
371


372
373
374
375
376
377
378
@router.get("/load")
async def get_server_load_metrics(request: Request):
    # Reports the current server load metrics stored on the app state.
    # Requests using the GPU are tracked for these routes:
    # - /v1/chat/completions
    # - /v1/completions
    # - /v1/audio/transcriptions
    # - /v1/audio/translations
    # - /v1/embeddings
    # - /pooling
    # - /classify
    # - /score
    # - /v1/score
    # - /rerank
    # - /v1/rerank
    # - /v2/rerank
    current_load = request.app.state.server_load_metrics
    payload = {"server_load": current_load}
    return JSONResponse(content=payload)
389
390


391
392
393
@router.get("/ping", response_class=Response)
@router.post("/ping", response_class=Response)
async def ping(raw_request: Request) -> Response:
    """Ping check. Endpoint required for SageMaker"""
    # Same result as the /health endpoint.
    health_response = await health(raw_request)
    return health_response


398
399
400
401
402
403
404
405
406
407
# POST /tokenize: forwards the request to the tokenization handler.
# NotImplementedError becomes HTTP 501, any other exception HTTP 500.
# An ErrorResponse is returned as JSON with its own error code.
@router.post(
    "/tokenize",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
        HTTPStatus.NOT_IMPLEMENTED.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def tokenize(request: TokenizeRequest, raw_request: Request):
    handler = tokenization(raw_request)

    try:
        result = await handler.create_tokenize(request, raw_request)
    except NotImplementedError as exc:
        raise HTTPException(
            status_code=HTTPStatus.NOT_IMPLEMENTED.value,
            detail=str(exc),
        ) from exc
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            detail=str(exc),
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(),
            status_code=result.error.code,
        )
    if isinstance(result, TokenizeResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)

432

433
434
435
436
437
438
439
440
441
# POST /detokenize: forwards the request to the tokenization handler.
# OverflowError is re-raised as a RequestValidationError, any other
# exception becomes HTTP 500. An ErrorResponse is returned as JSON with its
# own error code.
@router.post(
    "/detokenize",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def detokenize(request: DetokenizeRequest, raw_request: Request):
    handler = tokenization(raw_request)

    try:
        result = await handler.create_detokenize(request, raw_request)
    except OverflowError as exc:
        raise RequestValidationError(errors=[str(exc)]) from exc
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            detail=str(exc),
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(),
            status_code=result.error.code,
        )
    if isinstance(result, DetokenizeResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)

464

465
466
def maybe_register_tokenizer_info_endpoint(args):
    """Register GET /tokenizer_info on the router when the option is enabled.

    Nothing is registered unless ``args.enable_tokenizer_info_endpoint``
    exists and is truthy.
    """
    if not getattr(args, "enable_tokenizer_info_endpoint", False):
        return

    @router.get("/tokenizer_info")
    async def get_tokenizer_info(raw_request: Request):
        """Get comprehensive tokenizer information."""
        result = await tokenization(raw_request).get_tokenizer_info()
        body = result.model_dump()
        # Error results carry their own status code; everything else is 200.
        if isinstance(result, ErrorResponse):
            status_code = result.error.code
        else:
            status_code = 200
        return JSONResponse(content=body, status_code=status_code)
479
480


Ethan Xu's avatar
Ethan Xu committed
481
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
    # Returns the models handler's listing, dumped to JSON.
    model_list = await models(raw_request).show_available_models()
    body = model_list.model_dump()
    return JSONResponse(content=body)
Zhuohan Li's avatar
Zhuohan Li committed
487
488


Ethan Xu's avatar
Ethan Xu committed
489
@router.get("/version")
async def show_version():
    # Reports the vLLM version string under the "version" key.
    return JSONResponse(content={"version": VLLM_VERSION})


495
async def _convert_stream_to_sse_events(
    generator: AsyncGenerator[StreamingResponsesResponse, None],
) -> AsyncGenerator[str, None]:
    """Yield one server-sent-event string per item of ``generator``.

    Each string has an ``event:`` line holding the item's ``type`` attribute
    (``"unknown"`` when absent) and a ``data:`` line holding the item's
    compact JSON dump, followed by a blank line.
    """
    # https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
    async for item in generator:
        kind = getattr(item, "type", "unknown")
        payload = item.model_dump_json(indent=None)
        yield f"event: {kind}\ndata: {payload}\n\n"


508
509
510
511
512
513
514
515
516
517
# POST /v1/responses: forwards the request to the Responses handler.
# With no handler configured, the error object built by the base handler is
# returned. Handler exceptions become HTTP 500. An ErrorResponse or a
# ResponsesResponse is returned as JSON; anything else is streamed as SSE.
@router.post(
    "/v1/responses",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def create_responses(request: ResponsesRequest, raw_request: Request):
    handler = responses(raw_request)
    if handler is None:
        fallback = base(raw_request)
        return fallback.create_error_response(
            message="The model does not support Responses API"
        )

    try:
        result = await handler.create_responses(request, raw_request)
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            detail=str(exc),
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(),
            status_code=result.error.code,
        )
    if isinstance(result, ResponsesResponse):
        return JSONResponse(content=result.model_dump())

    sse_stream = _convert_stream_to_sse_events(result)
    return StreamingResponse(content=sse_stream, media_type="text/event-stream")
542
543
544


# GET /v1/responses/{response_id}: looks a response up through the Responses
# handler. With no handler configured, the error object built by the base
# handler is returned. Handler exceptions become HTTP 500. An ErrorResponse
# or a ResponsesResponse is returned as JSON; anything else is streamed as SSE.
@router.get("/v1/responses/{response_id}")
async def retrieve_responses(
    response_id: str,
    raw_request: Request,
    starting_after: int | None = None,
    stream: bool | None = False,
):
    handler = responses(raw_request)
    if handler is None:
        fallback = base(raw_request)
        return fallback.create_error_response(
            message="The model does not support Responses API"
        )

    try:
        result = await handler.retrieve_responses(
            response_id,
            starting_after=starting_after,
            stream=stream,
        )
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            detail=str(exc),
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(),
            status_code=result.error.code,
        )
    if isinstance(result, ResponsesResponse):
        return JSONResponse(content=result.model_dump())

    sse_stream = _convert_stream_to_sse_events(result)
    return StreamingResponse(content=sse_stream, media_type="text/event-stream")
577
578
579
580
581
582
583


# POST /v1/responses/{response_id}/cancel: asks the Responses handler to
# cancel a response. With no handler configured, the error object built by
# the base handler is returned. Handler exceptions become HTTP 500. The
# result is returned as JSON, using its own error code for an ErrorResponse.
@router.post("/v1/responses/{response_id}/cancel")
async def cancel_responses(response_id: str, raw_request: Request):
    handler = responses(raw_request)
    if handler is None:
        fallback = base(raw_request)
        return fallback.create_error_response(
            message="The model does not support Responses API"
        )

    try:
        result = await handler.cancel_responses(response_id)
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            detail=str(exc),
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(),
            status_code=result.error.code,
        )
    return JSONResponse(content=result.model_dump())


601
602
603
604
605
606
607
608
609
610
# POST /v1/chat/completions: forwards the request to the chat handler.
# With no handler configured, the error object built by the base handler is
# returned. Handler exceptions become HTTP 500. An ErrorResponse or a
# ChatCompletionResponse is returned as JSON; anything else is streamed as
# text/event-stream.
@router.post(
    "/v1/chat/completions",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
    handler = chat(raw_request)
    if handler is None:
        fallback = base(raw_request)
        return fallback.create_error_response(
            message="The model does not support Chat Completions API"
        )

    try:
        result = await handler.create_chat_completion(request, raw_request)
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            detail=str(exc),
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(),
            status_code=result.error.code,
        )
    if isinstance(result, ChatCompletionResponse):
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")

635

636
637
638
639
640
641
642
643
644
645
# POST /v1/completions: forwards the request to the completion handler.
# With no handler configured, the error object built by the base handler is
# returned. OverflowError becomes HTTP 400, any other exception HTTP 500.
# An ErrorResponse or a CompletionResponse is returned as JSON; anything
# else is streamed as text/event-stream.
@router.post(
    "/v1/completions",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_completion(request: CompletionRequest, raw_request: Request):
    handler = completion(raw_request)
    if handler is None:
        fallback = base(raw_request)
        return fallback.create_error_response(
            message="The model does not support Completions API"
        )

    try:
        result = await handler.create_completion(request, raw_request)
    except OverflowError as exc:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value,
            detail=str(exc),
        ) from exc
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            detail=str(exc),
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(),
            status_code=result.error.code,
        )
    if isinstance(result, CompletionResponse):
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")

Zhuohan Li's avatar
Zhuohan Li committed
675

676
677
678
679
680
681
682
683
# POST /v1/embeddings: forwards the request to the embedding handler.
# With no handler configured, the error object built by the base handler is
# returned. Handler exceptions become HTTP 500. An ErrorResponse or an
# EmbeddingResponse is returned as JSON; an EmbeddingBytesResponse is
# streamed with its own media type and a "metadata" header.
@router.post(
    "/v1/embeddings",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_embedding(
    request: EmbeddingRequest,
    raw_request: Request,
):
    handler = embedding(raw_request)
    if handler is None:
        fallback = base(raw_request)
        return fallback.create_error_response(
            message="The model does not support Embeddings API"
        )

    try:
        result = await handler.create_embedding(request, raw_request)
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            detail=str(exc),
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(),
            status_code=result.error.code,
        )
    if isinstance(result, EmbeddingResponse):
        return JSONResponse(content=result.model_dump())
    if isinstance(result, EmbeddingBytesResponse):
        return StreamingResponse(
            content=result.body,
            headers={"metadata": result.metadata},
            media_type=result.media_type,
        )

    assert_never(result)

718

719
720
721
722
723
724
725
726
# POST /pooling: forwards the request to the pooling handler.
# With no handler configured, the error object built by the base handler is
# returned. Handler exceptions become HTTP 500. An ErrorResponse,
# PoolingResponse or IOProcessorResponse is returned as JSON; a
# PoolingBytesResponse is streamed with its own media type and a "metadata"
# header.
@router.post(
    "/pooling",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_pooling(request: PoolingRequest, raw_request: Request):
    handler = pooling(raw_request)
    if handler is None:
        fallback = base(raw_request)
        return fallback.create_error_response(
            message="The model does not support Pooling API"
        )

    try:
        result = await handler.create_pooling(request, raw_request)
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            detail=str(exc),
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(),
            status_code=result.error.code,
        )
    if isinstance(result, (PoolingResponse, IOProcessorResponse)):
        return JSONResponse(content=result.model_dump())
    if isinstance(result, PoolingBytesResponse):
        return StreamingResponse(
            content=result.body,
            headers={"metadata": result.metadata},
            media_type=result.media_type,
        )

    assert_never(result)


757
758
759
# POST /classify: forwards the request to the classification handler.
# With no handler configured, the error object built by the base handler is
# returned. Handler exceptions become HTTP 500. An ErrorResponse or a
# ClassificationResponse is returned as JSON.
@router.post("/classify", dependencies=[Depends(validate_json_request)])
@with_cancellation
@load_aware_call
async def create_classify(request: ClassificationRequest, raw_request: Request):
    handler = classify(raw_request)
    if handler is None:
        fallback = base(raw_request)
        return fallback.create_error_response(
            message="The model does not support Classification API"
        )

    try:
        result = await handler.create_classify(request, raw_request)
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            detail=str(exc),
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(),
            status_code=result.error.code,
        )
    if isinstance(result, ClassificationResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)


784
785
786
787
788
789
790
791
# POST /score: forwards the request to the scoring handler.
# With no handler configured, the error object built by the base handler is
# returned. Handler exceptions become HTTP 500. An ErrorResponse or a
# ScoreResponse is returned as JSON.
@router.post(
    "/score",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_score(request: ScoreRequest, raw_request: Request):
    handler = score(raw_request)
    if handler is None:
        fallback = base(raw_request)
        return fallback.create_error_response(
            message="The model does not support Score API"
        )

    try:
        result = await handler.create_score(request, raw_request)
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            detail=str(exc),
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(),
            status_code=result.error.code,
        )
    if isinstance(result, ScoreResponse):
        return JSONResponse(content=result.model_dump())

    assert_never(result)


817
818
819
820
821
822
823
824
# POST /v1/score: logs a notice that the Score API now lives at /score, then
# delegates to create_score and returns its result.
@router.post(
    "/v1/score",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_score_v1(request: ScoreRequest, raw_request: Request):
    moved_notice = (
        "To indicate that Score API is not part of standard OpenAI API, we "
        "have moved it to `/score`. Please update your client accordingly."
    )
    logger.warning(moved_notice)

    score_response = await create_score(request, raw_request)
    return score_response


836
837
838
839
840
841
842
843
844
# POST /v1/audio/transcriptions: reads the uploaded file from the form
# request and forwards it to the transcription handler. With no handler
# configured, the error object built by the base handler is returned.
# Handler exceptions become HTTP 500. An ErrorResponse or a
# TranscriptionResponse is returned as JSON; anything else is streamed as
# text/event-stream.
@router.post(
    "/v1/audio/transcriptions",
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_transcriptions(
    raw_request: Request, request: Annotated[TranscriptionRequest, Form()]
):
    handler = transcription(raw_request)
    if handler is None:
        fallback = base(raw_request)
        return fallback.create_error_response(
            message="The model does not support Transcriptions API"
        )

    # The upload is read before the try block, so read errors are not
    # converted into the HTTP 500 below.
    audio_bytes = await request.file.read()
    try:
        result = await handler.create_transcription(audio_bytes, request, raw_request)
    except Exception as exc:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
            detail=str(exc),
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(),
            status_code=result.error.code,
        )
    if isinstance(result, TranscriptionResponse):
        return JSONResponse(content=result.model_dump())

    return StreamingResponse(content=result, media_type="text/event-stream")


875
876
877
878
879
880
881
882
883
@router.post(
    "/v1/audio/translations",
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.UNPROCESSABLE_ENTITY.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def create_translations(
    request: Annotated[TranslationRequest, Form()], raw_request: Request
):
    # Form-encoded endpoint: reads the uploaded audio, hands it to the
    # translation handler and maps the handler's result onto a response.
    serving = translation(raw_request)
    if serving is None:
        return base(raw_request).create_error_response(
            message="The model does not support Translations API"
        )

    audio_bytes = await request.file.read()
    try:
        result = await serving.create_translation(
            audio_bytes, request, raw_request
        )
    except Exception as exc:
        # Any handler failure is reported to the client as a 500.
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(), status_code=result.error.code
        )
    if isinstance(result, TranslationResponse):
        return JSONResponse(content=result.model_dump())

    # Anything else is treated as a stream of server-sent events.
    return StreamingResponse(content=result, media_type="text/event-stream")


914
915
916
917
918
919
920
921
@router.post(
    "/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def do_rerank(request: RerankRequest, raw_request: Request):
    # Canonical rerank endpoint; the versioned `/v1` and `/v2` routes
    # delegate here.
    serving = rerank(raw_request)
    if serving is None:
        return base(raw_request).create_error_response(
            message="The model does not support Rerank (Score) API"
        )

    try:
        result = await serving.do_rerank(request, raw_request)
    except Exception as exc:
        # Any handler failure is reported to the client as a 500.
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(exc)
        ) from exc

    if isinstance(result, ErrorResponse):
        return JSONResponse(
            content=result.model_dump(), status_code=result.error.code
        )
    if isinstance(result, RerankResponse):
        return JSONResponse(content=result.model_dump())

    # Exhaustiveness guard: the handler returned a type not handled above.
    assert_never(result)


946
947
948
949
950
951
952
953
@router.post(
    "/v1/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def do_rerank_v1(request: RerankRequest, raw_request: Request):
    # Versioned alias of `/rerank`: log a one-time relocation notice, then
    # delegate to `do_rerank`.
    relocation_notice = (
        "To indicate that the rerank API is not part of the standard OpenAI"
        " API, we have located it at `/rerank`. Please update your client "
        "accordingly. (Note: Conforms to JinaAI rerank API)"
    )
    logger.warning_once(relocation_notice)

    reranked = await do_rerank(request, raw_request)
    return reranked


965
966
967
968
969
970
971
972
@router.post(
    "/v2/rerank",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
async def do_rerank_v2(request: RerankRequest, raw_request: Request):
    # Versioned alias of `/rerank`; delegates to `do_rerank` unchanged.
    reranked = await do_rerank(request, raw_request)
    return reranked


978
if envs.VLLM_SERVER_DEV_MODE:
    # Development-only endpoints. They are registered on the router only when
    # the VLLM_SERVER_DEV_MODE environment flag is set.
    logger.warning(
        "SECURITY WARNING: Development endpoints are enabled! "
        "This should NOT be used in production!"
    )

    # Built once at import time so each /server_info call reuses the adapter.
    PydanticVllmConfig = pydantic.TypeAdapter(VllmConfig)

    @router.get("/server_info")
    async def show_server_info(
        raw_request: Request,
        config_format: Annotated[Literal["text", "json"], Query()] = "text",
    ):
        # Returns the engine configuration either as its `str()` form
        # ("text") or as a JSON-compatible structure ("json").
        vllm_config: VllmConfig = raw_request.app.state.vllm_config
        server_info = {
            "vllm_config": str(vllm_config)
            if config_format == "text"
            else PydanticVllmConfig.dump_python(vllm_config, mode="json", fallback=str)
            # fallback=str is needed to handle e.g. torch.dtype
        }
        return JSONResponse(content=server_info)

    @router.post("/reset_prefix_cache")
    async def reset_prefix_cache(raw_request: Request):
        """
        Reset the prefix cache. Note that we currently do not check if the
        prefix cache is successfully reset in the API server.
        """
        device = None
        device_str = raw_request.query_params.get("device")
        if device_str is not None:
            try:
                device = Device[device_str.upper()]
            except KeyError as e:
                # An unknown device name is a client error, not a server fault.
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail=f"Invalid device: {device_str!r}",
                ) from e
        logger.info("Resetting prefix cache with specific %s...", str(device))
        await engine_client(raw_request).reset_prefix_cache(device)
        return Response(status_code=200)

    @router.post("/reset_mm_cache")
    async def reset_mm_cache(raw_request: Request):
        """
        Reset the multi-modal cache. Note that we currently do not check if the
        multi-modal cache is successfully reset in the API server.
        """
        logger.info("Resetting multi-modal cache...")
        await engine_client(raw_request).reset_mm_cache()
        return Response(status_code=200)

    @router.post("/sleep")
    async def sleep(raw_request: Request):
        # The sleep level is read from the query string (default "1").
        level_str = raw_request.query_params.get("level", "1")
        try:
            level = int(level_str)
        except ValueError as e:
            # A non-integer level is a client error, not a server fault.
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail=f"Invalid sleep level: {level_str!r}",
            ) from e
        await engine_client(raw_request).sleep(level)
        # FIXME: in v0 with frontend multiprocessing, the sleep command
        # is sent but does not finish yet when we return a response.
        return Response(status_code=200)

    @router.post("/wake_up")
    async def wake_up(raw_request: Request):
        # An empty `tags` list is mapped to None, which wakes up all tags.
        tags = raw_request.query_params.getlist("tags") or None
        logger.info("wake up the engine with tags: %s", tags)
        await engine_client(raw_request).wake_up(tags)
        # FIXME: in v0 with frontend multiprocessing, the wake-up command
        # is sent but does not finish yet when we return a response.
        return Response(status_code=200)

    @router.get("/is_sleeping")
    async def is_sleeping(raw_request: Request):
        logger.info("check whether the engine is sleeping")
        is_sleeping = await engine_client(raw_request).is_sleeping()
        return JSONResponse(content={"is_sleeping": is_sleeping})

    @router.post("/collective_rpc")
    async def collective_rpc(raw_request: Request):
        # Forwards a `method` call (with optional `args`, `kwargs`, `timeout`
        # taken from the JSON body) to the engine client's `collective_rpc`.
        try:
            body = await raw_request.json()
        except json.JSONDecodeError as e:
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail=f"JSON decode error: {e}",
            ) from e
        if not isinstance(body, dict):
            # Valid JSON that is not an object (list, string, number, ...)
            # has no `.get`; reject it instead of crashing with a 500.
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail="Request body must be a JSON object",
            )
        method = body.get("method")
        if method is None:
            raise HTTPException(
                status_code=HTTPStatus.BAD_REQUEST.value,
                detail="Missing 'method' in request body",
            )
        # For security reason, only serialized string args/kwargs are passed.
        # User-defined `method` is responsible for deserialization if needed.
        args: list[str] = body.get("args", [])
        kwargs: dict[str, str] = body.get("kwargs", {})
        timeout: float | None = body.get("timeout")
        results = await engine_client(raw_request).collective_rpc(
            method=method, timeout=timeout, args=tuple(args), kwargs=kwargs
        )
        if results is None:
            return Response(status_code=200)
        # None, dicts and lists are passed through as-is; every other result
        # is stringified before being put into the JSON payload.
        response: list[Any] = [
            result
            if result is None or isinstance(result, (dict, list))
            else str(result)
            for result in results
        ]
        return JSONResponse(content={"results": response})

1084

1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
@router.post(
    "/scale_elastic_ep",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"model": dict},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
async def scale_elastic_ep(raw_request: Request):
    # Scales the engine to `new_data_parallel_size` data parallel engines.
    #
    # JSON body fields:
    #   new_data_parallel_size: required positive integer.
    #   drain_timeout: optional positive integer, seconds (default 120).
    #
    # Raises HTTPException 400 for a malformed body or invalid field, 408
    # when the client call raises TimeoutError, and 500 for any other
    # failure. The module-level `_scaling_elastic_ep` flag is set for the
    # duration of the scaling call and always cleared afterwards.
    global _scaling_elastic_ep

    try:
        body = await raw_request.json()
    except json.JSONDecodeError as e:
        raise HTTPException(status_code=400, detail="Invalid JSON format") from e

    if not isinstance(body, dict):
        # Valid JSON that is not an object has no `.get`; reject it with a
        # 400 instead of failing with an AttributeError (500).
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object"
        )

    new_data_parallel_size = body.get("new_data_parallel_size")
    drain_timeout = body.get("drain_timeout", 120)  # Default 2 minutes

    if new_data_parallel_size is None:
        raise HTTPException(
            status_code=400, detail="new_data_parallel_size is required"
        )

    # `bool` is a subclass of `int`, so JSON `true`/`false` must be rejected
    # explicitly or `true` would be accepted as a size/timeout of 1.
    if (
        isinstance(new_data_parallel_size, bool)
        or not isinstance(new_data_parallel_size, int)
        or new_data_parallel_size <= 0
    ):
        raise HTTPException(
            status_code=400, detail="new_data_parallel_size must be a positive integer"
        )

    if (
        isinstance(drain_timeout, bool)
        or not isinstance(drain_timeout, int)
        or drain_timeout <= 0
    ):
        raise HTTPException(
            status_code=400, detail="drain_timeout must be a positive integer"
        )

    # Set scaling flag to prevent new requests
    _scaling_elastic_ep = True
    client = engine_client(raw_request)
    try:
        await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
        return JSONResponse(
            {
                "message": f"Scaled to {new_data_parallel_size} data parallel engines",
            }
        )
    except TimeoutError as e:
        raise HTTPException(
            status_code=408,
            detail="Scale failed due to request drain timeout "
            f"after {drain_timeout} seconds",
        ) from e
    except Exception as e:
        logger.error("Scale failed: %s", e)
        raise HTTPException(status_code=500, detail="Scale failed") from e
    finally:
        # Always clear the flag, whether scaling succeeded or failed.
        _scaling_elastic_ep = False


@router.post("/is_scaling_elastic_ep")
async def is_scaling_elastic_ep(raw_request: Request):
    # Reports the current value of the module-level scaling flag.
    status = {"is_scaling_elastic_ep": _scaling_elastic_ep}
    return JSONResponse(status)


1148
1149
1150
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
RequestType = Any
# Looks up the serving handler for a request; None when unsupported.
GetHandlerFn = Callable[[Request], OpenAIServing | None]
# Endpoint coroutine taking the parsed request and the raw HTTP request.
EndpointFn = Callable[[RequestType, Request], Awaitable[Any]]

# Maps each request model to its (handler lookup, endpoint) pair.
# NOTE: Items defined earlier take higher priority
INVOCATION_TYPES: list[tuple[RequestType, tuple[GetHandlerFn, EndpointFn]]] = [
    (ChatCompletionRequest, (chat, create_chat_completion)),
    (CompletionRequest, (completion, create_completion)),
    (EmbeddingRequest, (embedding, create_embedding)),
    (ClassificationRequest, (classify, create_classify)),
    (ScoreRequest, (score, create_score)),
    (RerankRequest, (rerank, do_rerank)),
    (PoolingRequest, (pooling, create_pooling)),
]

# Same order as INVOCATION_TYPES, with each request model wrapped in a
# pydantic TypeAdapter.
# NOTE: Construct the TypeAdapters only once
INVOCATION_VALIDATORS = [
    (pydantic.TypeAdapter(request_model), (lookup_handler, endpoint_fn))
    for request_model, (lookup_handler, endpoint_fn) in INVOCATION_TYPES
]


1172
1173
1174
1175
1176
1177
1178
1179
1180
@router.post(
    "/invocations",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
async def invocations(raw_request: Request):
    """For SageMaker, routes requests based on the request type."""
    try:
        body = await raw_request.json()
    except json.JSONDecodeError as exc:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value, detail=f"JSON decode error: {exc}"
        ) from exc

    # Keep only the request types whose handler is available for this app.
    candidates = []
    for adapter, (get_handler, endpoint) in INVOCATION_VALIDATORS:
        if get_handler(raw_request) is not None:
            candidates.append((adapter, endpoint))

    # The first request type that validates the body wins; table order
    # therefore defines priority.
    for adapter, endpoint in candidates:
        try:
            parsed = adapter.validate_python(body)
        except pydantic.ValidationError:
            continue

        return await endpoint(parsed, raw_request)

    # Nothing matched: report the request types that were attempted.
    type_names = []
    for adapter, _ in candidates:
        expected = adapter._type
        if isinstance(expected, type):
            type_names.append(expected.__name__)
        else:
            type_names.append(str(expected))

    msg = f"Cannot find suitable handler for request. Expected one of: {type_names}"
    res = base(raw_request).create_error_response(message=msg)
    return JSONResponse(content=res.model_dump(), status_code=res.error.code)
1211
1212


1213
1214
1215
if envs.VLLM_TORCH_PROFILER_DIR:
    # Profiler control endpoints; registered only when
    # VLLM_TORCH_PROFILER_DIR is set.
    logger.warning(
        "Torch Profiler is enabled in the API server. This should ONLY be "
        "used for local development!"
    )

    @router.post("/start_profile")
    async def start_profile(raw_request: Request):
        # Asks the engine client to start profiling; replies 200 when the
        # call returns.
        client = engine_client(raw_request)
        logger.info("Starting profiler...")
        await client.start_profile()
        logger.info("Profiler started.")
        return Response(status_code=200)

    @router.post("/stop_profile")
    async def stop_profile(raw_request: Request):
        # Asks the engine client to stop profiling; replies 200 when the
        # call returns.
        client = engine_client(raw_request)
        logger.info("Stopping profiler...")
        await client.stop_profile()
        logger.info("Profiler stopped.")
        return Response(status_code=200)


1234
1235
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
    # Runtime LoRA management endpoints; registered only when
    # VLLM_ALLOW_RUNTIME_LORA_UPDATING is set.
    logger.warning(
        "LoRA dynamic loading & unloading is enabled in the API server. "
        "This should ONLY be used for local development!"
    )

    @router.post("/v1/load_lora_adapter", dependencies=[Depends(validate_json_request)])
    async def load_lora_adapter(request: LoadLoRAAdapterRequest, raw_request: Request):
        # Delegates to the models handler. An ErrorResponse is returned as
        # JSON with its own status code; any other result becomes the body
        # of a 200 response.
        outcome = await models(raw_request).load_lora_adapter(request)
        if not isinstance(outcome, ErrorResponse):
            return Response(status_code=200, content=outcome)

        return JSONResponse(
            content=outcome.model_dump(), status_code=outcome.error.code
        )

    @router.post(
        "/v1/unload_lora_adapter", dependencies=[Depends(validate_json_request)]
    )
    async def unload_lora_adapter(
        request: UnloadLoRAAdapterRequest, raw_request: Request
    ):
        # Delegates to the models handler. An ErrorResponse is returned as
        # JSON with its own status code; any other result becomes the body
        # of a 200 response.
        outcome = await models(raw_request).unload_lora_adapter(request)
        if not isinstance(outcome, ErrorResponse):
            return Response(status_code=200, content=outcome)

        return JSONResponse(
            content=outcome.model_dump(), status_code=outcome.error.code
        )


1267
def load_log_config(log_config_file: str | None) -> dict | None:
1268
1269
1270
1271
1272
1273
    if not log_config_file:
        return None
    try:
        with open(log_config_file) as f:
            return json.load(f)
    except Exception as e:
1274
1275
1276
        logger.warning(
            "Failed to load log config from file %s: error %s", log_config_file, e
        )
1277
1278
1279
        return None


1280
1281
1282
class AuthenticationMiddleware:
    """
    Pure ASGI middleware that authenticates each request by checking
    if the Authorization Bearer token exists and equals any of the
    configured API keys.

    Notes
    -----
    There are two cases in which authentication is skipped:
        1. The HTTP method is OPTIONS.
        2. The request path doesn't start with /v1 (e.g. /health).
    """

    def __init__(self, app: ASGIApp, tokens: list[str]) -> None:
        self.app = app
        # Only SHA-256 digests of the accepted tokens are kept, so every
        # comparison is done on fixed-length values.
        self.api_tokens = [hashlib.sha256(t.encode("utf-8")).digest() for t in tokens]

    def verify_token(self, headers: Headers) -> bool:
        """Return True if *headers* carry a Bearer token matching an API key."""
        authorization_header_value = headers.get("Authorization")
        if not authorization_header_value:
            return False

        scheme, _, param = authorization_header_value.partition(" ")
        if scheme.lower() != "bearer":
            return False

        param_hash = hashlib.sha256(param.encode("utf-8")).digest()

        # Compare against every configured token without an early exit, and
        # with a constant-time comparison, so timing does not reveal which
        # token (if any) matched.
        token_match = False
        for token_hash in self.api_tokens:
            token_match |= secrets.compare_digest(param_hash, token_hash)

        return token_match

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        # WebSocket scopes carry no "method" key, so it must be read with
        # `.get`; indexing would raise KeyError on every websocket connection.
        if (
            scope["type"] not in ("http", "websocket")
            or scope.get("method") == "OPTIONS"
        ):
            # scope["type"] can be "lifespan" or "startup" for example,
            # in which case we don't need to do anything
            return self.app(scope, receive, send)
        root_path = scope.get("root_path", "")
        url_path = URL(scope=scope).path.removeprefix(root_path)
        headers = Headers(scope=scope)
        if url_path.startswith("/v1") and not self.verify_token(headers):
            response = JSONResponse(content={"error": "Unauthorized"}, status_code=401)
            return response(scope, receive, send)
        return self.app(scope, receive, send)


class XRequestIdMiddleware:
    """
    Middleware that sets the X-Request-Id header on each response.

    The value is taken from the request's own X-Request-Id header when
    present; otherwise a random uuid4 (hex) value is used.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        if scope["type"] in ("http", "websocket"):
            incoming_headers = Headers(scope=scope)

            async def tagging_send(message: Message) -> None:
                """Append X-Request-Id to the response-start headers."""
                if message["type"] == "http.response.start":
                    outgoing_headers = MutableHeaders(raw=message["headers"])
                    fallback_id = uuid.uuid4().hex
                    request_id = incoming_headers.get("X-Request-Id", fallback_id)
                    outgoing_headers.append("X-Request-Id", request_id)
                await send(message)

            return self.app(scope, receive, tagging_send)

        # Other scope types (e.g. lifespan) pass through untouched.
        return self.app(scope, receive, send)


1359
1360
1361
1362
1363
1364
1365
1366
# Module-level flag tracking whether an elastic-EP scaling is in progress
_scaling_elastic_ep = False


class ScalingMiddleware:
    """
    ASGI middleware that rejects HTTP requests while the model is scaling.

    While the module-level scaling flag is set, every HTTP request gets a
    503 Service Unavailable JSON response instead of reaching the wrapped
    app. Non-HTTP scopes are always forwarded.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
        # The flag is only read here, so no `global` declaration is needed.
        if scope["type"] == "http" and _scaling_elastic_ep:
            unavailable = JSONResponse(
                content={
                    "error": "The model is currently scaling. Please try again later."
                },
                status_code=503,
            )
            return unavailable(scope, receive, send)

        return self.app(scope, receive, send)


1394
1395
1396
1397
def _extract_content_from_chunk(chunk_data: dict) -> str:
    """Return the text carried by one streaming response chunk.

    The chunk is first parsed with the typed stream-response models, chosen
    by its ``object`` field. If that validation fails, the first entry of
    ``choices`` is inspected directly. An empty string is returned when no
    content is found.
    """
    try:
        from vllm.entrypoints.openai.protocol import (
            ChatCompletionStreamResponse,
            CompletionStreamResponse,
        )

        # Typed parsing path.
        object_kind = chunk_data.get("object")
        if object_kind == "chat.completion.chunk":
            chat_chunk = ChatCompletionStreamResponse.model_validate(chunk_data)
            if chat_chunk.choices:
                delta_text = chat_chunk.choices[0].delta.content
                if delta_text:
                    return delta_text
        elif object_kind == "text_completion":
            text_chunk = CompletionStreamResponse.model_validate(chunk_data)
            if text_chunk.choices:
                completion_text = text_chunk.choices[0].text
                if completion_text:
                    return completion_text
    except pydantic.ValidationError:
        # Manual parsing path, used only when typed validation fails.
        if "choices" in chunk_data and chunk_data["choices"]:
            first_choice = chunk_data["choices"][0]
            if "delta" in first_choice and first_choice["delta"].get("content"):
                return first_choice["delta"]["content"]
            if first_choice.get("text"):
                return first_choice["text"]

    return ""


class SSEDecoder:
    """Robust Server-Sent Events decoder for streaming responses.

    Bytes are fed in with `decode_chunk`, which returns the events completed
    by that chunk. Incomplete trailing lines and incomplete multi-byte UTF-8
    sequences are carried over to the next call. Extracted text can be
    accumulated with `add_content` and read with `get_complete_content`.
    """

    def __init__(self):
        import codecs

        # Text of the current, not yet newline-terminated line.
        self.buffer = ""
        # Content fragments collected through `add_content`.
        self.content_buffer = []
        # Incremental decoder: a multi-byte character split across two
        # chunks is completed on the next call instead of being rejected.
        self._utf8_decoder = codecs.getincrementaldecoder("utf-8")()

    def decode_chunk(self, chunk: bytes) -> list[dict]:
        """Decode a chunk of SSE data and return parsed events.

        Returns a list of ``{"type": "data", "data": <parsed JSON>}`` and
        ``{"type": "done"}`` dicts, in stream order. Chunks containing
        invalid UTF-8 and ``data`` lines containing invalid JSON are skipped.
        """
        try:
            chunk_str = self._utf8_decoder.decode(chunk)
        except UnicodeDecodeError:
            # Skip malformed chunks; drop the decoder's partial state so the
            # next chunk starts clean.
            self._utf8_decoder.reset()
            return []

        # Everything before the last newline is complete; the remainder is
        # kept for the next call.
        *lines, self.buffer = (self.buffer + chunk_str).split("\n")

        events = []
        for line in lines:
            line = line.rstrip("\r")  # Handle CRLF

            # The space after the colon is optional in the SSE format.
            if not line.startswith("data:"):
                continue
            data_str = line[len("data:"):].strip()

            if data_str == "[DONE]":
                events.append({"type": "done"})
            elif data_str:
                try:
                    event_data = json.loads(data_str)
                except json.JSONDecodeError:
                    # Skip malformed JSON
                    continue
                events.append({"type": "data", "data": event_data})

        return events

    def extract_content(self, event_data: dict) -> str:
        """Extract content from event data."""
        return _extract_content_from_chunk(event_data)

    def add_content(self, content: str) -> None:
        """Add content to the buffer; empty content is ignored."""
        if content:
            self.content_buffer.append(content)

    def get_complete_content(self) -> str:
        """Get the complete buffered content."""
        return "".join(self.content_buffer)
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492


def _log_streaming_response(response, response_body: list) -> None:
    """Log streaming response with robust SSE parsing.

    Replaces ``response.body_iterator`` with an iterator that re-yields the
    already collected chunks in ``response_body`` and, when the SSE ``[DONE]``
    marker is seen, logs the accumulated content (truncated to 2048
    characters) once. Every chunk is still yielded to the client, including
    any that follow the marker.

    Args:
        response: Response object whose ``body_iterator`` is replaced.
        response_body: The fully collected body chunks of ``response``.
    """
    max_logged_chars = 2048
    sse_decoder = SSEDecoder()

    def log_completion(chunk_count: int) -> None:
        full_content = sse_decoder.get_complete_content()
        if full_content:
            # Truncate if too long
            if len(full_content) > max_logged_chars:
                full_content = full_content[:max_logged_chars] + "...[truncated]"
            logger.info(
                "response_body={streaming_complete: content='%s', chunks=%d}",
                full_content,
                chunk_count,
            )
        else:
            logger.info(
                "response_body={streaming_complete: no_content, chunks=%d}",
                chunk_count,
            )

    def buffered_iterator():
        completion_logged = False

        for chunk_count, chunk in enumerate(response_body, start=1):
            # Forward the chunk first: logging must never hold back or drop
            # data that belongs to the client.
            yield chunk

            if completion_logged:
                continue

            # Parse SSE events from chunk
            for event in sse_decoder.decode_chunk(chunk):
                if event["type"] == "data":
                    content = sse_decoder.extract_content(event["data"])
                    sse_decoder.add_content(content)
                elif event["type"] == "done":
                    # Log complete content when done
                    log_completion(chunk_count)
                    completion_logged = True
                    break

    response.body_iterator = iterate_in_threadpool(buffered_iterator())
    logger.info("response_body={streaming_started: chunks=%d}", len(response_body))
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529


def _log_non_streaming_response(response_body: list) -> None:
    """Log non-streaming response.

    All body sections are joined before decoding, so a body delivered in
    several sections is logged in full and a multi-byte character that
    straddles two sections still decodes. A body that is not valid text is
    logged as a ``<binary_data>`` placeholder.
    """
    try:
        decoded_body = b"".join(response_body).decode()
        logger.info("response_body={%s}", decoded_body)
    except UnicodeDecodeError:
        logger.info("response_body={<binary_data>}")


1530
def build_app(args: Namespace) -> FastAPI:
    """Create the FastAPI application and wire routes, handlers, middleware.

    Args:
        args: Parsed CLI arguments. The attributes read here are
            ``disable_fastapi_docs``, ``root_path``, ``allowed_origins``,
            ``allow_credentials``, ``allowed_methods``, ``allowed_headers``,
            ``api_key``, ``enable_request_id_headers`` and ``middleware``.

    Returns:
        The configured application.

    Raises:
        ValueError: If an entry of ``args.middleware`` resolves to something
            that is neither a class nor a coroutine function.
    """
    if args.disable_fastapi_docs:
        app = FastAPI(
            openapi_url=None, docs_url=None, redoc_url=None, lifespan=lifespan
        )
    else:
        app = FastAPI(lifespan=lifespan)
    app.include_router(router)
    app.root_path = args.root_path

    mount_metrics(app)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    @app.exception_handler(HTTPException)
    async def http_exception_handler(_: Request, exc: HTTPException):
        # Render HTTPExceptions in the ErrorResponse envelope.
        err = ErrorResponse(
            error=ErrorInfo(
                message=exc.detail,
                type=HTTPStatus(exc.status_code).phrase,
                code=exc.status_code,
            )
        )
        return JSONResponse(err.model_dump(), status_code=exc.status_code)

    @app.exception_handler(RequestValidationError)
    async def validation_exception_handler(_: Request, exc: RequestValidationError):
        # Request validation failures are reported as 400 Bad Request.
        exc_str = str(exc)
        errors = exc.errors()
        errors_str = str(errors)

        # Append the structured error list only when it adds information.
        if errors and errors_str and errors_str != exc_str:
            message = f"{exc_str} {errors_str}"
        else:
            message = exc_str

        err = ErrorResponse(
            error=ErrorInfo(
                message=message,
                type=HTTPStatus.BAD_REQUEST.phrase,
                code=HTTPStatus.BAD_REQUEST,
            )
        )
        return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)

    # Ensure --api-key option from CLI takes precedence over VLLM_API_KEY
    if tokens := [key for key in (args.api_key or [envs.VLLM_API_KEY]) if key]:
        app.add_middleware(AuthenticationMiddleware, tokens=tokens)

    if args.enable_request_id_headers:
        app.add_middleware(XRequestIdMiddleware)

    # Add scaling middleware to check for scaling state
    app.add_middleware(ScalingMiddleware)

    if envs.VLLM_DEBUG_LOG_API_SERVER_RESPONSE:
        logger.warning(
            "CAUTION: Enabling log response in the API Server. "
            "This can include sensitive information and should be "
            "avoided in production."
        )

        @app.middleware("http")
        async def log_response(request: Request, call_next):
            response = await call_next(request)
            # The whole body is collected here so it can be logged, then
            # re-exposed through a fresh iterator.
            response_body = [section async for section in response.body_iterator]
            response.body_iterator = iterate_in_threadpool(iter(response_body))

            # Check if this is a streaming response by looking at content-type.
            # Match on the media type only, so the check does not depend on
            # the exact parameters (e.g. charset) or letter case.
            content_type = response.headers.get("content-type", "")
            is_streaming = content_type.lower().startswith("text/event-stream")

            # Log response body based on type
            if not response_body:
                logger.info("response_body={<empty>}")
            elif is_streaming:
                _log_streaming_response(response, response_body)
            else:
                _log_non_streaming_response(response_body)
            return response

    # User-supplied middleware, given as dotted import paths.
    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)  # type: ignore[arg-type]
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(
                f"Invalid middleware {middleware}. Must be a function or a class."
            )

    return app


1630
async def init_app_state(
    engine_client: EngineClient,
    state: State,
    args: Namespace,
) -> None:
    """Populate ``state`` with the engine client and the serving handlers.

    Resolves the served model names, request logger, chat template, tool
    server and LoRA modules from ``args``, then attaches one handler per
    endpoint family to ``state``. A handler attribute is set to ``None``
    when the engine does not report the corresponding task in
    ``engine_client.get_supported_tasks()``.

    Args:
        engine_client: Client used to query the engine (config, tokenizer,
            supported tasks) and handed to every serving handler.
        state: Application state object that receives the attributes.
        args: Parsed server arguments. This function does not modify
            ``args.lora_modules``.
    """
    vllm_config = engine_client.vllm_config

    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    if args.enable_log_requests:
        request_logger = RequestLogger(max_log_len=args.max_log_len)
    else:
        request_logger = None

    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model) for name in served_model_names
    ]

    state.engine_client = engine_client
    state.log_stats = not args.disable_log_stats
    state.vllm_config = vllm_config

    supported_tasks = await engine_client.get_supported_tasks()
    logger.info("Supported tasks: %s", supported_tasks)

    resolved_chat_template = load_chat_template(args.chat_template)
    if resolved_chat_template is not None:
        # Get the tokenizer to check official template
        tokenizer = await engine_client.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            # The warning is logged in resolve_mistral_chat_template.
            resolved_chat_template = resolve_mistral_chat_template(
                chat_template=resolved_chat_template
            )
        else:
            # Resolve the template without the user override so it can be
            # compared against the supplied one.
            hf_chat_template = resolve_hf_chat_template(
                tokenizer=tokenizer,
                chat_template=None,
                tools=None,
                model_config=vllm_config.model_config,
            )

            if hf_chat_template != resolved_chat_template:
                logger.warning(
                    "Using supplied chat template: %s\n"
                    "It is different from official chat template '%s'. "
                    "This discrepancy may lead to performance degradation.",
                    resolved_chat_template,
                    args.model,
                )

    if args.tool_server == "demo":
        tool_server: ToolServer | None = DemoToolServer()
        # Narrow the declared type for the call below.
        assert isinstance(tool_server, DemoToolServer)
        await tool_server.init_and_validate()
    elif args.tool_server:
        tool_server = MCPToolServer()
        await tool_server.add_tool_server(args.tool_server)
    else:
        tool_server = None

    # Merge default_mm_loras into the static lora_modules
    default_mm_loras = (
        vllm_config.lora_config.default_mm_loras
        if vllm_config.lora_config is not None
        else {}
    )

    lora_modules = args.lora_modules
    if default_mm_loras:
        default_mm_lora_paths = [
            LoRAModulePath(
                name=modality,
                path=lora_path,
            )
            for modality, lora_path in default_mm_loras.items()
        ]
        if args.lora_modules is None:
            lora_modules = default_mm_lora_paths
        else:
            # Build a new list rather than using `+=`, which would extend
            # args.lora_modules in place and add the default entries again
            # on every later call with the same args.
            lora_modules = [*args.lora_modules, *default_mm_lora_paths]

    state.openai_serving_models = OpenAIServingModels(
        engine_client=engine_client,
        base_model_paths=base_model_paths,
        lora_modules=lora_modules,
    )
    await state.openai_serving_models.init_static_loras()

    state.openai_serving_responses = (
        OpenAIServingResponses(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            tool_parser=args.tool_call_parser,
            tool_server=tool_server,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            log_error_stack=args.log_error_stack,
        )
        if "generate" in supported_tasks
        else None
    )
    state.openai_serving_chat = (
        OpenAIServingChat(
            engine_client,
            state.openai_serving_models,
            args.response_role,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_auto_tools=args.enable_auto_tool_choice,
            exclude_tools_when_tool_choice_none=args.exclude_tools_when_tool_choice_none,
            tool_parser=args.tool_call_parser,
            reasoning_parser=args.structured_outputs_config.reasoning_parser,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            enable_log_outputs=args.enable_log_outputs,
            log_error_stack=args.log_error_stack,
        )
        if "generate" in supported_tasks
        else None
    )
    state.openai_serving_completion = (
        OpenAIServingCompletion(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_force_include_usage=args.enable_force_include_usage,
            log_error_stack=args.log_error_stack,
        )
        if "generate" in supported_tasks
        else None
    )
    state.openai_serving_pooling = (
        (
            OpenAIServingPooling(
                engine_client,
                state.openai_serving_models,
                supported_tasks=supported_tasks,
                request_logger=request_logger,
                chat_template=resolved_chat_template,
                chat_template_content_format=args.chat_template_content_format,
                trust_request_chat_template=args.trust_request_chat_template,
                log_error_stack=args.log_error_stack,
            )
        )
        if ("token_embed" in supported_tasks or "token_classify" in supported_tasks)
        else None
    )
    state.openai_serving_embedding = (
        OpenAIServingEmbedding(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            chat_template=resolved_chat_template,
            chat_template_content_format=args.chat_template_content_format,
            trust_request_chat_template=args.trust_request_chat_template,
            log_error_stack=args.log_error_stack,
        )
        if "embed" in supported_tasks
        else None
    )
    state.openai_serving_classification = (
        ServingClassification(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
        )
        if "classify" in supported_tasks
        else None
    )
    state.openai_serving_scores = (
        ServingScores(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
        )
        if ("embed" in supported_tasks or "score" in supported_tasks)
        else None
    )
    # Tokenization is set up unconditionally, whatever the supported tasks.
    state.openai_serving_tokenization = OpenAIServingTokenization(
        engine_client,
        state.openai_serving_models,
        request_logger=request_logger,
        chat_template=resolved_chat_template,
        chat_template_content_format=args.chat_template_content_format,
        trust_request_chat_template=args.trust_request_chat_template,
        log_error_stack=args.log_error_stack,
    )
    state.openai_serving_transcription = (
        OpenAIServingTranscription(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )
    # Translation is gated on the same "transcription" task as above.
    state.openai_serving_translation = (
        OpenAIServingTranslation(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            log_error_stack=args.log_error_stack,
            enable_force_include_usage=args.enable_force_include_usage,
        )
        if "transcription" in supported_tasks
        else None
    )

    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0

1861

1862
def create_server_socket(addr: tuple[str, int]) -> socket.socket:
    """Create a TCP socket bound to ``addr`` with address/port reuse enabled.

    Args:
        addr: ``(host, port)`` pair to bind. An IPv6 host selects
            ``AF_INET6``; anything else uses ``AF_INET``.

    Returns:
        The bound (not yet listening) socket. The caller owns it and is
        responsible for closing it.

    Raises:
        OSError: If setting a socket option or binding fails. The socket is
            closed before the error propagates.
    """
    family = socket.AF_INET
    if is_valid_ipv6_address(addr[0]):
        family = socket.AF_INET6

    sock = socket.socket(family=family, type=socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
        sock.bind(addr)
    except BaseException:
        # Do not leak the file descriptor when configuration or bind fails.
        sock.close()
        raise

    return sock


1875
1876
1877
1878
1879
1880
def create_server_unix_socket(path: str) -> socket.socket:
    """Create a stream Unix domain socket bound to ``path``.

    Args:
        path: Filesystem path to bind the socket to.

    Returns:
        The bound (not yet listening) socket. The caller owns it and is
        responsible for closing it.

    Raises:
        OSError: If binding fails (for example the directory does not exist
            or the path is already in use). The socket is closed before the
            error propagates.
    """
    sock = socket.socket(family=socket.AF_UNIX, type=socket.SOCK_STREAM)
    try:
        sock.bind(path)
    except BaseException:
        # Do not leak the file descriptor when bind fails.
        sock.close()
        raise
    return sock


1881
def validate_api_server_args(args):
    """Check that the parser names given in ``args`` are registered.

    Raises:
        KeyError: If auto tool choice is enabled and ``args.tool_call_parser``
            is not a registered tool parser, or if a reasoning parser is
            configured and is not a registered reasoning parser.
    """
    known_tool_parsers = ToolParserManager.tool_parsers.keys()
    if args.enable_auto_tool_choice:
        if args.tool_call_parser not in known_tool_parsers:
            tool_choices = ",".join(known_tool_parsers)
            raise KeyError(
                f"invalid tool call parser: {args.tool_call_parser} "
                f"(chose from {{ {tool_choices} }})"
            )

    known_reasoning_parsers = ReasoningParserManager.reasoning_parsers.keys()
    reasoning_parser = args.structured_outputs_config.reasoning_parser
    # An unset/empty reasoning parser is allowed and skips the check.
    if not reasoning_parser:
        return
    if reasoning_parser not in known_reasoning_parsers:
        reasoning_choices = ",".join(known_reasoning_parsers)
        raise KeyError(
            f"invalid reasoning parser: {reasoning_parser} "
            f"(chose from {{ {reasoning_choices} }})"
        )
1897

1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910

def setup_server(args):
    """Validate the server args, bind the serving socket and install the
    SIGTERM handler.

    Returns:
        A ``(listen_address, sock)`` tuple: a printable address
        (``unix:<path>`` or ``http[s]://host:port``) and the bound socket.
    """
    logger.info("vLLM API server version %s", VLLM_VERSION)
    log_non_default_args(args)

    parser_plugin = args.tool_parser_plugin
    if parser_plugin and len(parser_plugin) > 3:
        ToolParserManager.import_tool_parser(parser_plugin)

    validate_api_server_args(args)

    # Bind the port before the engine is set up; this works around race
    # conditions with ray.
    # see https://github.com/vllm-project/vllm/issues/8204
    uds_path = args.uds
    if uds_path:
        sock = create_server_unix_socket(uds_path)
    else:
        sock_addr = (args.host or "", args.port)
        sock = create_server_socket(sock_addr)

    # Workaround to avoid footguns where uvicorn drops requests with too
    # many concurrent requests active.
    set_ulimit()

    def _interrupt_on_sigterm(*_) -> None:
        # Interrupt server on sigterm while initializing
        raise KeyboardInterrupt("terminated")

    signal.signal(signal.SIGTERM, _interrupt_on_sigterm)

    if uds_path:
        return f"unix:{uds_path}", sock

    host, port = sock_addr
    scheme = "https" if args.ssl_keyfile and args.ssl_certfile else "http"
    if is_valid_ipv6_address(host):
        # IPv6 literals are bracketed inside URLs.
        display_host = f"[{host}]"
    else:
        display_host = host or "0.0.0.0"
    return f"{scheme}://{display_host}:{port}", sock


async def run_server(args, **uvicorn_kwargs) -> None:
    """Set up the server socket, then run one API server worker on it."""

    # Tag this process's stdout and stderr with an "APIServer" prefix.
    decorate_logs("APIServer")

    # setup_server returns (listen_address, sock), which are the first two
    # positional parameters of run_server_worker.
    server_endpoint = setup_server(args)
    await run_server_worker(*server_endpoint, args, **uvicorn_kwargs)


1950
1951
1952
async def run_server_worker(
    listen_address, sock, args, client_config=None, **uvicorn_kwargs
) -> None:
    """Run a single API server worker.

    Builds the engine client and the app, serves HTTP on the already-bound
    ``sock``, and waits for the server to shut down.

    Args:
        listen_address: Printable address, used only for logging.
        sock: Pre-bound socket to serve on. It is closed when this coroutine
            finishes, whether it returns normally or raises.
        args: Parsed server arguments.
        client_config: Optional configuration forwarded to
            ``build_async_engine_client``.
        **uvicorn_kwargs: Extra keyword arguments forwarded to ``serve_http``.
    """

    if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
        ToolParserManager.import_tool_parser(args.tool_parser_plugin)

    # Load logging config for uvicorn if specified
    log_config = load_log_config(args.log_config_file)
    if log_config is not None:
        uvicorn_kwargs["log_config"] = log_config

    # The try block covers start-up as well as shutdown, so the socket is
    # also released when engine/app initialisation or serve_http fails.
    try:
        async with build_async_engine_client(
            args,
            client_config=client_config,
        ) as engine_client:
            maybe_register_tokenizer_info_endpoint(args)
            app = build_app(args)

            await init_app_state(engine_client, app.state, args)

            logger.info(
                "Starting vLLM API server %d on %s",
                engine_client.vllm_config.parallel_config._api_process_rank,
                listen_address,
            )
            shutdown_task = await serve_http(
                app,
                sock=sock,
                enable_ssl_refresh=args.enable_ssl_refresh,
                host=args.host,
                port=args.port,
                log_level=args.uvicorn_log_level,
                # NOTE: When the 'disable_uvicorn_access_log' value is True,
                # no access log will be output.
                access_log=not args.disable_uvicorn_access_log,
                timeout_keep_alive=envs.VLLM_HTTP_TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile,
                ssl_ca_certs=args.ssl_ca_certs,
                ssl_cert_reqs=args.ssl_cert_reqs,
                h11_max_incomplete_event_size=args.h11_max_incomplete_event_size,
                h11_max_header_count=args.h11_max_header_count,
                **uvicorn_kwargs,
            )

        # NB: Await server shutdown only after the backend context is exited
        await shutdown_task
    finally:
        sock.close()
2002

Ethan Xu's avatar
Ethan Xu committed
2003
2004
2005

if __name__ == "__main__":
    # NOTE(simon):
    # Keep this section in sync with vllm/entrypoints/cli/main.py, which
    # provides the CLI entrypoints.
    cli_env_setup()

    # Build the parser, add the server arguments, then parse and validate.
    parser = make_arg_parser(
        FlexibleArgumentParser(
            description="vLLM OpenAI-Compatible RESTful API server."
        )
    )
    args = parser.parse_args()
    validate_parsed_serve_args(args)

    # Run the server coroutine to completion under uvloop.
    uvloop.run(run_server(args))