Unverified Commit 0e0d6c16 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

feat: SGLang release/resume_memory_occupation endpoints (#5207)


Signed-off-by: Schwinn Saereesitthipitak <schwinns@nvidia.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
parent e2a0a4b5
......@@ -153,7 +153,21 @@ async def init(runtime: DistributedRuntime, config: Config):
# Readiness gate: requests wait until model is registered
ready_event = asyncio.Event()
handler = DecodeWorkerHandler(component, engine, config, publisher)
handler = DecodeWorkerHandler(
component, engine, config, publisher, generate_endpoint
)
# Register memory management routes using handler methods
runtime.register_engine_route(
"release_memory_occupation", handler.release_memory_occupation
)
runtime.register_engine_route(
"resume_memory_occupation", handler.resume_memory_occupation
)
logging.info(
"Registered engine routes: /engine/release_memory_occupation, /engine/resume_memory_occupation"
)
print(f"Config: {config}")
health_check_payload = SglangHealthCheckPayload(
engine, use_text_input=dynamo_args.use_sglang_tokenizer
......@@ -254,7 +268,20 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
if engine.server_args.enable_metrics:
setup_prometheus_registry(engine, generate_endpoint)
handler = PrefillWorkerHandler(component, engine, config, publisher)
handler = PrefillWorkerHandler(
component, engine, config, publisher, generate_endpoint
)
# Register memory management routes using handler methods
runtime.register_engine_route(
"release_memory_occupation", handler.release_memory_occupation
)
runtime.register_engine_route(
"resume_memory_occupation", handler.resume_memory_occupation
)
logging.info(
"Registered engine routes: /engine/release_memory_occupation, /engine/resume_memory_occupation"
)
health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
......
......@@ -17,6 +17,8 @@ from dynamo.common.utils.input_params import InputParamManager
from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
logger = logging.getLogger(__name__)
class BaseWorkerHandler(ABC):
"""Abstract base class for SGLang worker handlers."""
......@@ -27,6 +29,7 @@ class BaseWorkerHandler(ABC):
engine: sgl.Engine,
config: Config,
publisher: Optional[DynamoSglangPublisher] = None,
generate_endpoint=None,
) -> None:
"""Initialize base worker handler.
......@@ -35,10 +38,12 @@ class BaseWorkerHandler(ABC):
engine: The SGLang engine instance.
config: SGLang and Dynamo configuration.
publisher: Optional metrics publisher for the worker.
generate_endpoint: The endpoint handle for discovery registration.
"""
self.component = component
self.engine = engine
self.config = config
self.generate_endpoint = generate_endpoint
if publisher is not None:
self.metrics_publisher = publisher.metrics_publisher
self.kv_publisher = publisher.kv_publisher
......@@ -55,6 +60,94 @@ class BaseWorkerHandler(ABC):
else None
)
async def release_memory_occupation(self, body: dict) -> dict:
    """Release GPU memory occupation and unregister from discovery.

    Args:
        body: Request payload. An optional 'tags' key (or its alias 'tag')
            selects which memory to release. A single string is wrapped
            into a one-element list; a missing/``None`` payload is treated
            as empty.
            Default: ["kv_cache", "weights", "cuda_graph"]

    Returns:
        ``{"status": "ok", "message": ...}`` when generation was paused and
        memory released, otherwise ``{"status": "error", "message": ...}``
        carrying the text of the exception that was raised.

    Order of operations:
        1. Unregister from discovery - stop accepting new requests
        2. Pause generation - intended to drain in-flight requests
        3. Release memory - done only after generation is paused
    """
    # Tolerate a missing payload instead of failing on `None.get`.
    payload = body or {}
    tags = payload.get("tags", payload.get("tag", None))
    if tags is None:
        tags = ["kv_cache", "weights", "cuda_graph"]
    elif isinstance(tags, str):
        # The singular "tag" alias invites a bare string; the engine call
        # is handed a list of tag names everywhere else in this method.
        tags = [tags]

    try:
        # Step 1: Unregister endpoint from discovery FIRST.
        # This step is best-effort: a failure is logged and the release
        # still proceeds.
        if self.generate_endpoint is None:
            # `generate_endpoint` defaults to None in the constructor; say
            # so explicitly rather than surfacing an AttributeError text.
            logger.warning(
                "[ReleaseMemory] No generate endpoint configured - skipping discovery unregistration"
            )
        else:
            try:
                await self.generate_endpoint.unregister_endpoint_instance()
                logger.info(
                    "[ReleaseMemory] Unregistered endpoint from discovery - worker removed from routing pool"
                )
            except Exception as unreg_err:
                logger.warning(
                    f"[ReleaseMemory] Failed to unregister endpoint from discovery: {unreg_err}"
                )

        # Step 2: Pause generation to drain in-flight requests
        await self.engine.async_pause_generation()
        logger.info("[ReleaseMemory] Generation paused")

        # Step 3: Release memory now that generation is paused
        await self.engine.async_release_memory_occupation(tags)
        logger.info(f"[ReleaseMemory] Released memory for tags: {tags}")

        return {
            "status": "ok",
            "message": f"Memory released for tags: {tags}",
        }
    except Exception as e:
        # Same ERROR-level message as before, now with the traceback
        # attached so the failing step can be identified.
        logger.exception(f"Failed to release memory occupation: {e}")
        return {"status": "error", "message": str(e)}
async def resume_memory_occupation(self, body: dict) -> dict:
    """Resume GPU memory occupation and re-register to discovery.

    Args:
        body: Request payload. An optional 'tags' key (or its alias 'tag')
            selects which memory to resume. A single string is wrapped
            into a one-element list; a missing/``None`` payload is treated
            as empty.
            Default: ["kv_cache", "weights", "cuda_graph"]

    Returns:
        ``{"status": "ok", "message": ...}`` when memory was resumed and
        generation continued, otherwise
        ``{"status": "error", "message": ...}`` carrying the text of the
        exception that was raised. A failed re-registration is only logged
        as a warning and does not change the returned status.

    Order of operations:
        1. Resume memory - restore GPU allocations
        2. Continue generation - ready to serve requests
        3. Re-register to discovery - allow frontend to route here
    """
    # Tolerate a missing payload instead of failing on `None.get`.
    payload = body or {}
    tags = payload.get("tags", payload.get("tag", None))
    if tags is None:
        tags = ["kv_cache", "weights", "cuda_graph"]
    elif isinstance(tags, str):
        # The singular "tag" alias invites a bare string; the engine call
        # is handed a list of tag names everywhere else in this method.
        tags = [tags]

    try:
        # Step 1: Resume memory first - must be ready before accepting requests
        await self.engine.async_resume_memory_occupation(tags)
        logger.info(f"[ResumeMemory] Resumed memory for tags: {tags}")

        # Step 2: Continue generation
        await self.engine.async_continue_generation()
        logger.info("[ResumeMemory] Generation continued")

        # Step 3: Re-register to discovery so frontend can route to us.
        # This step is best-effort: a failure is logged and the call still
        # reports success for the memory/generation steps.
        if self.generate_endpoint is None:
            # `generate_endpoint` defaults to None in the constructor; say
            # so explicitly rather than surfacing an AttributeError text.
            logger.warning(
                "[ResumeMemory] No generate endpoint configured - skipping discovery re-registration"
            )
        else:
            try:
                await self.generate_endpoint.register_endpoint_instance()
                logger.info(
                    "[ResumeMemory] Re-registered endpoint to discovery - worker added back to routing pool"
                )
            except Exception as reg_err:
                logger.warning(
                    f"[ResumeMemory] Failed to re-register endpoint to discovery: {reg_err}"
                )

        return {
            "status": "ok",
            "message": f"Memory resumed for tags: {tags}",
        }
    except Exception as e:
        # Same ERROR-level message as before, now with the traceback
        # attached so the failing step can be identified.
        logger.exception(f"Failed to resume memory occupation: {e}")
        return {"status": "error", "message": str(e)}
@abstractmethod
async def generate(self, request: Dict[str, Any], context: Context):
"""Generate response from request.
......
......@@ -23,6 +23,7 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine: sgl.Engine,
config: Config,
publisher: DynamoSglangPublisher,
generate_endpoint=None,
) -> None:
"""Initialize decode worker handler.
......@@ -31,12 +32,14 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine: The SGLang engine instance.
config: SGLang and Dynamo configuration.
publisher: Metrics publisher for the worker.
generate_endpoint: The endpoint handle for discovery registration.
"""
super().__init__(
component,
engine,
config,
publisher,
generate_endpoint,
)
if self.serving_mode == DisaggregationMode.DECODE:
logging.info(
......
......@@ -22,6 +22,7 @@ class PrefillWorkerHandler(BaseWorkerHandler):
engine: sgl.Engine,
config: Config,
publisher: DynamoSglangPublisher,
generate_endpoint=None,
) -> None:
"""Initialize prefill worker handler.
......@@ -30,10 +31,11 @@ class PrefillWorkerHandler(BaseWorkerHandler):
engine: The SGLang engine instance.
config: SGLang and Dynamo configuration.
publisher: The SGLang publisher instance.
generate_endpoint: The endpoint handle for discovery registration.
"""
self.engine = engine
self.bootstrap_host, self.bootstrap_port = self._get_bootstrap_info(self.engine)
super().__init__(component, engine, config, publisher)
super().__init__(component, engine, config, publisher, generate_endpoint)
self._consume_tasks = set()
logging.info(
f"Prefill worker handler initialized - bootstrap host: {self.bootstrap_host}, bootstrap port: {self.bootstrap_port}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment