# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import json
from http import HTTPStatus
from typing import Any

from fastapi import APIRouter, FastAPI, HTTPException, Query, Request
from fastapi.responses import JSONResponse

import vllm.envs as envs
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
from vllm.engine.protocol import EngineClient
from vllm.logger import init_logger

logger = init_logger(__name__)


def engine_client(request: Request) -> EngineClient:
    """Return the engine client stored on the application state."""
    return request.app.state.engine_client


router = APIRouter()


async def _required_json_field(raw_request: Request, field: str) -> Any:
    """Parse the request body as a JSON object and return ``body[field]``.

    Args:
        raw_request: Incoming request whose body is expected to be JSON.
        field: Name of the top-level key that must be present and non-null.

    Returns:
        The value stored under ``field``.

    Raises:
        HTTPException: With status 400 when the body is not decodable JSON,
            is not a JSON object, or has no non-null value for ``field``.
    """
    try:
        body = await raw_request.json()
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # UnicodeDecodeError covers bodies that are not valid UTF-8/16/32;
        # json.loads raises it before it gets a chance to parse anything.
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value,
            detail="Invalid JSON format",
        ) from e

    # Valid JSON may still be a list, string, number or null; none of those
    # support ``.get`` and would otherwise surface as a 500.
    if not isinstance(body, dict):
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value,
            detail="Request body must be a JSON object",
        )

    value = body.get(field)
    if value is None:
        raise HTTPException(
            status_code=HTTPStatus.BAD_REQUEST.value,
            detail=f"Missing '{field}' in request body",
        )
    return value


@router.post("/pause")
async def pause_generation(
    raw_request: Request,
    wait_for_inflight_requests: bool = Query(False),
    clear_cache: bool = Query(True),
) -> JSONResponse:
    """Pause generation requests to allow weight updates.

    Args:
        wait_for_inflight_requests: When ``True`` waits for in-flight
            requests to finish before pausing. When ``False`` (default),
            aborts any in-flight requests immediately.
        clear_cache: Whether to clear KV/prefix caches after draining.

    Returns:
        ``{"status": "paused"}`` with 200 on success, ``{"error": ...}`` with
        400 when the engine raises ``ValueError``, or with 500 on any other
        exception.
    """
    engine = engine_client(raw_request)

    try:
        await engine.pause_generation(
            wait_for_inflight_requests=wait_for_inflight_requests,
            clear_cache=clear_cache,
        )
        return JSONResponse(
            content={"status": "paused"},
            status_code=HTTPStatus.OK.value,
        )
    except ValueError as err:
        return JSONResponse(
            content={"error": str(err)},
            status_code=HTTPStatus.BAD_REQUEST.value,
        )
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to pause generation")
        return JSONResponse(
            content={"error": f"Failed to pause generation: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )


@router.post("/resume")
async def resume_generation(raw_request: Request) -> JSONResponse:
    """Resume generation after a pause.

    Returns:
        ``{"status": "resumed"}`` with 200 on success, or ``{"error": ...}``
        with 500 when the engine raises.
    """
    engine = engine_client(raw_request)

    try:
        await engine.resume_generation()
        return JSONResponse(
            content={"status": "resumed"},
            status_code=HTTPStatus.OK.value,
        )
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to resume generation")
        return JSONResponse(
            content={"error": f"Failed to resume generation: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )


@router.get("/is_paused")
async def is_paused(raw_request: Request) -> JSONResponse:
    """Return the current pause status.

    Returns:
        ``{"is_paused": <value reported by the engine>}`` on success, or
        ``{"error": ...}`` with 500 when the engine raises.
    """
    engine = engine_client(raw_request)

    try:
        paused = await engine.is_paused()
    except Exception as err:  # pragma: no cover - defensive
        logger.exception("Failed to fetch pause status")
        return JSONResponse(
            content={"error": f"Failed to fetch pause status: {err}"},
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value,
        )

    return JSONResponse(content={"is_paused": paused})


@router.post("/init_weight_transfer_engine")
async def init_weight_transfer_engine(raw_request: Request) -> JSONResponse:
    """Initialize the weight transfer engine from the body's ``init_info``.

    The request body must be a JSON object with a non-null ``init_info``
    key; its value is forwarded to the engine wrapped in a
    ``WeightTransferInitRequest``.

    Raises:
        HTTPException: With status 400 when the body is not a JSON object or
            ``init_info`` is missing.
    """
    init_info = await _required_json_field(raw_request, "init_info")
    await engine_client(raw_request).init_weight_transfer_engine(
        WeightTransferInitRequest(init_info=init_info)
    )
    return JSONResponse(content={"message": "Weight transfer initialized"})


@router.post("/update_weights")
async def update_weights(raw_request: Request) -> JSONResponse:
    """Ask the engine to update weights using the body's ``update_info``.

    The request body must be a JSON object with a non-null ``update_info``
    key; its value is forwarded to the engine wrapped in a
    ``WeightTransferUpdateRequest``.

    Raises:
        HTTPException: With status 400 when the body is not a JSON object or
            ``update_info`` is missing.
    """
    update_info = await _required_json_field(raw_request, "update_info")
    await engine_client(raw_request).update_weights(
        request=WeightTransferUpdateRequest(update_info=update_info)
    )
    return JSONResponse(content={"message": "Weights updated"})


@router.get("/get_world_size")
async def get_world_size(
    raw_request: Request,
    include_dp: bool = Query(True),
) -> JSONResponse:
    """Get the world size from the parallel config.

    Args:
        include_dp: If True (default), returns the world size including
            data parallelism (TP * PP * DP). If False, returns the world
            size without data parallelism (TP * PP).

    Returns:
        ``{"world_size": <int>}`` read from the engine's parallel config.
    """
    parallel_config = engine_client(raw_request).vllm_config.parallel_config
    if include_dp:
        world_size = parallel_config.world_size_across_dp
    else:
        world_size = parallel_config.world_size
    return JSONResponse(content={"world_size": world_size})


def attach_router(app: FastAPI):
    """Mount this router on ``app`` only when ``VLLM_SERVER_DEV_MODE`` is set."""
    if not envs.VLLM_SERVER_DEV_MODE:
        return
    app.include_router(router)