Unverified Commit 3f42b05f authored by Chauncey's avatar Chauncey Committed by GitHub
Browse files

[Refactor] [1/N] to simplify the vLLM serving architecture (#28040)


Signed-off-by: default avatarchaunceyjiang <chaunceyjiang@gmail.com>
parent 69520bc6
......@@ -232,7 +232,7 @@ async def test_server_load(server: RemoteOpenAIServer):
@pytest.mark.asyncio
async def test_health_check_engine_dead_error():
# Import the health function directly to test it in isolation
from vllm.entrypoints.openai.api_server import health
from vllm.entrypoints.serve.instrumentator.health import health
# Create a mock request that simulates what FastAPI would provide
mock_request = Mock(spec=Request)
......
......@@ -118,6 +118,7 @@ async def init_app(
)
)
app.state.engine_client = engine
app.state.args = args
return app
......
This diff is collapsed.
......@@ -74,8 +74,6 @@ from vllm.entrypoints.openai.protocol import (
ErrorResponse,
FunctionCall,
FunctionDefinition,
GenerateRequest,
GenerateResponse,
ResponsesRequest,
TokenizeChatRequest,
TokenizeCompletionRequest,
......@@ -87,6 +85,7 @@ from vllm.entrypoints.openai.protocol import (
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.entrypoints.renderer import BaseRenderer, CompletionRenderer, RenderConfig
from vllm.entrypoints.serve.disagg.protocol import GenerateRequest, GenerateResponse
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs.data import PromptType
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
......
......@@ -16,7 +16,6 @@ from vllm.entrypoints.openai.api_server import (
completion,
create_chat_completion,
create_completion,
health,
validate_json_request,
)
from vllm.entrypoints.openai.protocol import (
......@@ -38,6 +37,7 @@ from vllm.entrypoints.pooling.score.api_router import (
score,
)
from vllm.entrypoints.pooling.score.protocol import RerankRequest, ScoreRequest
from vllm.entrypoints.serve.instrumentator.health import health
# TODO: RequestType = TypeForm[BaseModel] when recognized by type checkers
# (requires typing_extensions >= 4.13)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi import FastAPI
def register_vllm_serve_api_routers(app: FastAPI):
from vllm.entrypoints.serve.lora.api_router import (
attach_router as attach_lora_router,
)
attach_lora_router(app)
from vllm.entrypoints.serve.elastic_ep.api_router import (
attach_router as attach_elastic_ep_router,
)
attach_elastic_ep_router(app)
from vllm.entrypoints.serve.profile.api_router import (
attach_router as attach_profile_router,
)
attach_profile_router(app)
from vllm.entrypoints.serve.sleep.api_router import (
attach_router as attach_sleep_router,
)
attach_sleep_router(app)
from vllm.entrypoints.serve.tokenize.api_router import (
attach_router as attach_tokenize_router,
)
attach_tokenize_router(app)
from vllm.entrypoints.serve.disagg.api_router import (
attach_router as attach_disagg_router,
)
attach_disagg_router(app)
from vllm.entrypoints.serve.rlhf.api_router import (
attach_router as attach_rlhf_router,
)
attach_rlhf_router(app)
from vllm.entrypoints.serve.instrumentator.metrics import (
attach_router as attach_metrics_router,
)
attach_metrics_router(app)
from vllm.entrypoints.serve.instrumentator.health import (
attach_router as attach_health_router,
)
attach_health_router(app)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import json
from http import HTTPStatus
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, Response
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.api_server import validate_json_request
from vllm.entrypoints.openai.protocol import (
ErrorResponse,
)
from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest,
GenerateResponse,
)
from vllm.entrypoints.serve.disagg.serving import (
ServingTokens,
)
from vllm.entrypoints.serve.tokenize.serving import OpenAIServingTokenization
from vllm.entrypoints.utils import (
load_aware_call,
with_cancellation,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
def generate_tokens(request: Request) -> ServingTokens | None:
return request.app.state.serving_tokens
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
router = APIRouter()
@router.post(
"/inference/v1/generate",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
@with_cancellation
@load_aware_call
async def generate(request: GenerateRequest, raw_request: Request):
handler = generate_tokens(raw_request)
if handler is None:
return tokenization(raw_request).create_error_response(
message="The model does not support generate tokens API"
)
try:
generator = await handler.serve_tokens(request, raw_request)
except Exception as e:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
) from e
if isinstance(generator, ErrorResponse):
return JSONResponse(
content=generator.model_dump(), status_code=generator.error.code
)
elif isinstance(generator, GenerateResponse):
return JSONResponse(content=generator.model_dump())
return StreamingResponse(content=generator, media_type="text/event-stream")
def attach_router(app: FastAPI):
if getattr(app.state.args, "tokens_only", False):
@router.post("/abort_requests")
async def abort_requests(raw_request: Request):
"""
Abort one or more requests. To be used in a
Disaggregated Everything setup.
"""
try:
body = await raw_request.json()
except json.JSONDecodeError as e:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail=f"JSON decode error: {e}",
) from e
request_ids = body.get("request_ids")
if request_ids is None:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST.value,
detail="Missing 'request_ids' in request body",
)
# Abort requests in background
asyncio.create_task(engine_client(raw_request).abort(request_ids))
return Response(status_code=200)
app.include_router(router)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from pydantic import BaseModel, Field
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProbs,
Logprob,
SamplingParams,
StreamOptions,
)
from vllm.utils import random_uuid
####### Tokens IN <> Tokens OUT #######
class GenerateRequest(BaseModel):
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
token_ids: list[int]
"""The token ids to generate text from."""
# features: MultiModalFeatureSpec
# TODO (NickLucche): implement once Renderer work is completed
features: str | None = None
"""The processed MM inputs for the model."""
sampling_params: SamplingParams
"""The sampling parameters for the model."""
model: str | None = None
stream: bool | None = False
stream_options: StreamOptions | None = None
cache_salt: str | None = Field(
default=None,
description=(
"If specified, the prefix cache will be salted with the provided "
"string to prevent an attacker to guess prompts in multi-user "
"environments. The salt should be random, protected from "
"access by 3rd parties, and long enough to be "
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
"to 256 bit)."
),
)
priority: int = Field(
default=0,
description=(
"The priority of the request (lower means earlier handling; "
"default: 0). Any priority other than 0 will raise an error "
"if the served model does not use priority scheduling."
),
)
kv_transfer_params: dict[str, Any] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
class GenerateResponseChoice(BaseModel):
index: int
logprobs: ChatCompletionLogProbs | None = None
# per OpenAI spec this is the default
finish_reason: str | None = "stop"
token_ids: list[int] | None = None
class GenerateResponse(BaseModel):
request_id: str = Field(
default_factory=lambda: f"{random_uuid()}",
description=(
"The request_id related to this request. If the caller does "
"not set it, a random_uuid will be generated. This id is used "
"through out the inference process and return in response."
),
)
choices: list[GenerateResponseChoice]
prompt_logprobs: list[dict[int, Logprob] | None] | None = None
kv_transfer_params: dict[str, Any] | None = Field(
default=None,
description="KVTransfer parameters used for disaggregated serving.",
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import time
from collections.abc import AsyncGenerator
......@@ -14,15 +16,17 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProbs,
ChatCompletionLogProbsContent,
ErrorResponse,
GenerateRequest,
GenerateResponse,
GenerateResponseChoice,
PromptTokenUsageInfo,
RequestResponseMetadata,
UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import OpenAIServing, clamp_prompt_logprobs
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.serve.disagg.protocol import (
GenerateRequest,
GenerateResponse,
GenerateResponseChoice,
)
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.logprobs import Logprob
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
from http import HTTPStatus
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.api_server import validate_json_request
from vllm.entrypoints.openai.protocol import (
ErrorResponse,
)
from vllm.entrypoints.serve.elastic_ep.middleware import (
get_scaling_elastic_ep,
set_scaling_elastic_ep,
)
from vllm.logger import init_logger
logger = init_logger(__name__)
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
router = APIRouter()
@router.post(
"/scale_elastic_ep",
dependencies=[Depends(validate_json_request)],
responses={
HTTPStatus.OK.value: {"model": dict},
HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
HTTPStatus.REQUEST_TIMEOUT.value: {"model": ErrorResponse},
HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
},
)
async def scale_elastic_ep(raw_request: Request):
try:
body = await raw_request.json()
except json.JSONDecodeError as e:
raise HTTPException(status_code=400, detail="Invalid JSON format") from e # noqa: B904
new_data_parallel_size = body.get("new_data_parallel_size")
drain_timeout = body.get("drain_timeout", 120) # Default 2 minutes
if new_data_parallel_size is None:
raise HTTPException(
status_code=400, detail="new_data_parallel_size is required"
)
if not isinstance(new_data_parallel_size, int) or new_data_parallel_size <= 0:
raise HTTPException(
status_code=400,
detail="new_data_parallel_size must be a positive integer",
)
if not isinstance(drain_timeout, int) or drain_timeout <= 0:
raise HTTPException(
status_code=400, detail="drain_timeout must be a positive integer"
)
# Set scaling flag to prevent new requests
set_scaling_elastic_ep(True)
client = engine_client(raw_request)
try:
await client.scale_elastic_ep(new_data_parallel_size, drain_timeout)
return JSONResponse(
{
"message": f"Scaled to {new_data_parallel_size} data parallel engines",
}
)
except TimeoutError as e:
raise HTTPException(
status_code=408,
detail="Scale failed due to request drain timeout "
f"after {drain_timeout} seconds",
) from e
except Exception as e:
logger.error("Scale failed: %s", e)
raise HTTPException(status_code=500, detail="Scale failed") from e
finally:
set_scaling_elastic_ep(False)
@router.post("/is_scaling_elastic_ep")
async def is_scaling_elastic_ep(raw_request: Request):
return JSONResponse({"is_scaling_elastic_ep": get_scaling_elastic_ep()})
def attach_router(app: FastAPI):
app.include_router(router)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Awaitable
from fastapi.responses import JSONResponse
from starlette.types import ASGIApp, Receive, Scope, Send
# Global variable to track scaling state
_scaling_elastic_ep = False
def get_scaling_elastic_ep():
return _scaling_elastic_ep
def set_scaling_elastic_ep(value):
global _scaling_elastic_ep
_scaling_elastic_ep = value
class ScalingMiddleware:
"""
Middleware that checks if the model is currently scaling and
returns a 503 Service Unavailable response if it is.
This middleware applies to all HTTP requests and prevents
processing when the model is in a scaling state.
"""
def __init__(self, app: ASGIApp) -> None:
self.app = app
def __call__(self, scope: Scope, receive: Receive, send: Send) -> Awaitable[None]:
if scope["type"] != "http":
return self.app(scope, receive, send)
# Check global scaling state
if get_scaling_elastic_ep():
# Return 503 Service Unavailable response
response = JSONResponse(
content={
"error": "The model is currently scaling. Please try again later."
},
status_code=503,
)
return response(scope, receive, send)
return self.app(scope, receive, send)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi import APIRouter, Request
from fastapi.responses import Response
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
from vllm.v1.engine.exceptions import EngineDeadError
logger = init_logger(__name__)
router = APIRouter()
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@router.get("/health", response_class=Response)
async def health(raw_request: Request) -> Response:
"""Health check."""
try:
await engine_client(raw_request).check_health()
return Response(status_code=200)
except EngineDeadError:
return Response(status_code=503)
def attach_router(app):
app.include_router(router)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import re
import prometheus_client
from fastapi import FastAPI, Response
from prometheus_client import make_asgi_app
from prometheus_fastapi_instrumentator import Instrumentator
from starlette.routing import Mount
from vllm.v1.metrics.prometheus import get_prometheus_registry
class PrometheusResponse(Response):
media_type = prometheus_client.CONTENT_TYPE_LATEST
def attach_router(app: FastAPI):
"""Mount prometheus metrics to a FastAPI app."""
registry = get_prometheus_registry()
# `response_class=PrometheusResponse` is needed to return an HTTP response
# with header "Content-Type: text/plain; version=0.0.4; charset=utf-8"
# instead of the default "application/json" which is incorrect.
# See https://github.com/trallnag/prometheus-fastapi-instrumentator/issues/163#issue-1296092364
Instrumentator(
excluded_handlers=[
"/metrics",
"/health",
"/load",
"/ping",
"/version",
"/server_info",
],
registry=registry,
).add().instrument(app).expose(app, response_class=PrometheusResponse)
# Add prometheus asgi middleware to route /metrics requests
metrics_route = Mount("/metrics", make_asgi_app(registry=registry))
# Workaround for 307 Redirect for /metrics
metrics_route.path_regex = re.compile("^/metrics(?P<path>.*)$")
app.routes.append(metrics_route)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import model_hosting_container_standards.sagemaker as sagemaker_standards
from fastapi import APIRouter, Depends, Request
from fastapi import APIRouter, Depends, FastAPI, Request
from fastapi.responses import JSONResponse, Response
from vllm import envs
from vllm.entrypoints.openai.api_server import models, validate_json_request
from vllm.entrypoints.openai.protocol import (
ErrorResponse,
......@@ -14,9 +17,18 @@ from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
logger = init_logger(__name__)
router = APIRouter()
def register_dynamic_lora_routes(router: APIRouter):
def attach_router(app: FastAPI):
if not envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
"""If LoRA dynamic loading & unloading is not enabled, do nothing."""
return
logger.warning(
"LoRA dynamic loading & unloading is enabled in the API server. "
"This should ONLY be used for local development!"
)
@sagemaker_standards.register_load_adapter_handler(
request_shape={
"lora_name": "body.name",
......@@ -54,4 +66,5 @@ def register_dynamic_lora_routes(router: APIRouter):
return Response(status_code=200, content=response)
return router
# register the router
app.include_router(router)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from fastapi import APIRouter, FastAPI, Request
from fastapi.responses import Response
import vllm.envs as envs
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger
logger = init_logger(__name__)
router = APIRouter()
def engine_client(request: Request) -> EngineClient:
return request.app.state.engine_client
@router.post("/start_profile")
async def start_profile(raw_request: Request):
logger.info("Starting profiler...")
await engine_client(raw_request).start_profile()
logger.info("Profiler started.")
return Response(status_code=200)
@router.post("/stop_profile")
async def stop_profile(raw_request: Request):
logger.info("Stopping profiler...")
await engine_client(raw_request).stop_profile()
logger.info("Profiler stopped.")
return Response(status_code=200)
def attach_router(app: FastAPI):
if envs.VLLM_TORCH_PROFILER_DIR:
logger.warning_once(
"Torch Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
)
elif envs.VLLM_TORCH_CUDA_PROFILE:
logger.warning_once(
"CUDA Profiler is enabled in the API server. This should ONLY be "
"used for local development!"
)
if envs.VLLM_TORCH_PROFILER_DIR or envs.VLLM_TORCH_CUDA_PROFILE:
app.include_router(router)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment