"vscode:/vscode.git/clone" did not exist on "0e6bb7bf406368c59ad08782c25e97ea6c052a96"
Unverified commit e2a0a4b5, authored by Schwinn Saereesitthipitak, committed by GitHub
Browse files

feat: Sleep/wake endpoints for vLLM runtime (#5339)


Signed-off-by: Schwinn Saereesitthipitak <schwinns@nvidia.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
parent fa66f654
...@@ -272,6 +272,68 @@ class BaseWorkerHandler(ABC): ...@@ -272,6 +272,68 @@ class BaseWorkerHandler(ABC):
tokenizer = engine.tokenizer tokenizer = engine.tokenizer
self.input_param_manager = InputParamManager(tokenizer) self.input_param_manager = InputParamManager(tokenizer)
async def sleep(self, body: dict) -> dict:
    """Remove this worker from discovery, then put the engine to sleep.

    Args:
        body: Request payload. An optional ``"level"`` key is forwarded
            unchanged to ``engine_client.sleep``; it defaults to 1.
            ``None`` (or any other falsy payload) is treated as ``{}``.
            NOTE(review): the earlier docstring described the levels as
            "1=weights only, 2=weights+buffers, 3=everything" - confirm
            against the engine's sleep-mode documentation.

    Returns:
        ``{"status": "ok", "message": ...}`` on success, or
        ``{"status": "error", "message": str(exc)}`` when the payload is
        malformed or the engine fails to sleep. ``Exception`` subclasses
        are never propagated to the caller.

    Order of operations:
        1. Unregister from discovery - stop accepting new requests.
        2. Sleep engine.
    """
    try:
        # Parse the payload inside the try so a malformed body yields the
        # same structured error response as an engine failure, and before
        # touching discovery so a bad request has no side effects.
        level = (body or {}).get("level", 1)

        # Step 1: Unregister endpoint instance FIRST to stop new requests
        # from arriving. This is best-effort: a failure is logged and the
        # engine is still put to sleep.
        try:
            await self.generate_endpoint.unregister_endpoint_instance()
            logger.info(
                "[Sleep] Unregistered endpoint from discovery - worker removed from routing pool"
            )
        except Exception as unreg_err:
            logger.warning(
                f"[Sleep] Failed to unregister endpoint from discovery: {unreg_err}"
            )

        # Step 2: Sleep the engine.
        # NOTE(review): if this raises, the endpoint stays unregistered
        # while the engine is presumably still awake - confirm whether a
        # rollback (re-register) is wanted here.
        await self.engine_client.sleep(level)
        return {"status": "ok", "message": f"Engine slept (level={level})"}
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Failed to sleep engine: {e}")
        return {"status": "error", "message": str(e)}
async def wake(self, body: dict) -> dict:
    """Wake the engine, then re-register this worker with discovery.

    Args:
        body: Request payload. An optional ``"tags"`` key
            (e.g., ``["weights", "kv_cache"]``) is forwarded unchanged to
            ``engine_client.wake_up``; when absent, ``None`` is passed.
            ``None`` (or any other falsy payload) is treated as ``{}``.

    Returns:
        ``{"status": "ok", "message": ...}`` once the engine has woken, or
        ``{"status": "error", "message": str(exc)}`` when the payload is
        malformed or the engine fails to wake. ``Exception`` subclasses
        are never propagated to the caller.

    Order of operations:
        1. Wake engine.
        2. Re-register endpoint instance - allow requests to be routed
           here again.
    """
    try:
        # Parse the payload inside the try so a malformed body yields the
        # same structured error response as an engine failure.
        tags = (body or {}).get("tags")

        # Step 1: Wake engine first - must be ready before accepting requests
        await self.engine_client.wake_up(tags)

        # Step 2: Re-register endpoint instance to discovery. This is
        # best-effort: a failure is logged and the call still reports "ok".
        # NOTE(review): in that case the worker is awake but was not
        # re-registered - confirm whether callers need to be told.
        try:
            await self.generate_endpoint.register_endpoint_instance()
            logger.info(
                "[Wake] Re-registered endpoint to discovery - worker added back to routing pool"
            )
        except Exception as reg_err:
            logger.warning(
                f"[Wake] Failed to re-register endpoint to discovery: {reg_err}"
            )

        return {"status": "ok", "message": f"Engine woke (tags={tags})"}
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"Failed to wake engine: {e}")
        return {"status": "error", "message": str(e)}
@abstractmethod @abstractmethod
async def generate(self, request, context) -> AsyncGenerator[dict, None]: async def generate(self, request, context) -> AsyncGenerator[dict, None]:
raise NotImplementedError raise NotImplementedError
......
...@@ -447,6 +447,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -447,6 +447,11 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
setup_metrics_collection(config, generate_endpoint, logger) setup_metrics_collection(config, generate_endpoint, logger)
# Register sleep/wake engine routes
runtime.register_engine_route("sleep", handler.sleep)
runtime.register_engine_route("wake", handler.wake)
logger.info("Registered engine routes: /engine/sleep, /engine/wake")
# Register prefill model with ModelType.Prefill # Register prefill model with ModelType.Prefill
if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register
model_input = ( model_input = (
...@@ -566,6 +571,11 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -566,6 +571,11 @@ async def init(runtime: DistributedRuntime, config: Config):
setup_metrics_collection(config, generate_endpoint, logger) setup_metrics_collection(config, generate_endpoint, logger)
# Register sleep/wake engine routes
runtime.register_engine_route("sleep", handler.sleep)
runtime.register_engine_route("wake", handler.wake)
logger.info("Registered engine routes: /engine/sleep, /engine/wake")
if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register
# Parse endpoint types from --dyn-endpoint-types flag # Parse endpoint types from --dyn-endpoint-types flag
model_type = parse_endpoint_types(config.dyn_endpoint_types) model_type = parse_endpoint_types(config.dyn_endpoint_types)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment