Unverified Commit 2e8c4447 authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

refactor: clean up SGLang sleep/wake implementation (#5517)

parent 04cecda7
......@@ -124,23 +124,6 @@ async def init(runtime: DistributedRuntime, config: Config):
await _handle_non_leader_node(engine, generate_endpoint)
return
# Register engine routes for profiling
async def start_profile_handler(body: dict) -> dict:
"""Handle /engine/start_profile requests"""
await engine.tokenizer_manager.start_profile(**body)
return {"status": "ok", "message": "Profiling started"}
async def stop_profile_handler(body: dict) -> dict:
"""Handle /engine/stop_profile requests"""
await engine.tokenizer_manager.stop_profile()
return {"status": "ok", "message": "Profiling stopped"}
runtime.register_engine_route("start_profile", start_profile_handler)
runtime.register_engine_route("stop_profile", stop_profile_handler)
logging.info(
"Registered engine routes: /engine/start_profile, /engine/stop_profile"
)
# publisher instantiates the metrics and kv event publishers
publisher, metrics_task, metrics_labels = await setup_sgl_metrics(
engine, config, component, generate_endpoint
......@@ -156,17 +139,7 @@ async def init(runtime: DistributedRuntime, config: Config):
handler = DecodeWorkerHandler(
component, engine, config, publisher, generate_endpoint
)
# Register memory management routes using handler methods
runtime.register_engine_route(
"release_memory_occupation", handler.release_memory_occupation
)
runtime.register_engine_route(
"resume_memory_occupation", handler.resume_memory_occupation
)
logging.info(
"Registered engine routes: /engine/release_memory_occupation, /engine/resume_memory_occupation"
)
handler.register_engine_routes(runtime)
print(f"Config: {config}")
health_check_payload = SglangHealthCheckPayload(
......@@ -238,23 +211,6 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
await _handle_non_leader_node(engine, generate_endpoint)
return
# Register engine routes for profiling
async def start_profile_handler(body: dict) -> dict:
"""Handle /engine/start_profile requests"""
await engine.tokenizer_manager.start_profile(**body)
return {"status": "ok", "message": "Profiling started"}
async def stop_profile_handler(body: dict) -> dict:
"""Handle /engine/stop_profile requests"""
await engine.tokenizer_manager.stop_profile()
return {"status": "ok", "message": "Profiling stopped"}
runtime.register_engine_route("start_profile", start_profile_handler)
runtime.register_engine_route("stop_profile", stop_profile_handler)
logging.info(
"Registered engine routes: /engine/start_profile, /engine/stop_profile"
)
# Perform dummy warmup for prefill worker to avoid initial TTFT hit
# Only needed on leader node that handles requests
await _warmup_prefill_engine(engine, server_args)
......@@ -271,17 +227,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
handler = PrefillWorkerHandler(
component, engine, config, publisher, generate_endpoint
)
# Register memory management routes using handler methods
runtime.register_engine_route(
"release_memory_occupation", handler.release_memory_occupation
)
runtime.register_engine_route(
"resume_memory_occupation", handler.resume_memory_occupation
)
logging.info(
"Registered engine routes: /engine/release_memory_occupation, /engine/resume_memory_occupation"
)
handler.register_engine_routes(runtime)
health_check_payload = SglangPrefillHealthCheckPayload(engine).to_dict()
......
......@@ -17,8 +17,6 @@ from dynamo.common.utils.input_params import InputParamManager
from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher
logger = logging.getLogger(__name__)
class BaseWorkerHandler(ABC):
"""Abstract base class for SGLang worker handlers."""
......@@ -80,28 +78,23 @@ class BaseWorkerHandler(ABC):
# Step 1: Unregister endpoint from discovery FIRST
try:
await self.generate_endpoint.unregister_endpoint_instance()
logger.info(
"[ReleaseMemory] Unregistered endpoint from discovery - worker removed from routing pool"
)
except Exception as unreg_err:
logger.warning(
f"[ReleaseMemory] Failed to unregister endpoint from discovery: {unreg_err}"
logging.warning(
f"Failed to unregister endpoint from discovery: {unreg_err}"
)
# Step 2: Pause generation to drain in-flight requests
await self.engine.async_pause_generation()
logger.info("[ReleaseMemory] Generation paused")
# Step 3: Release memory now that it's safe
await self.engine.async_release_memory_occupation(tags)
logger.info(f"[ReleaseMemory] Released memory for tags: {tags}")
return {
"status": "ok",
"message": f"Memory released for tags: {tags}",
}
except Exception as e:
logger.error(f"Failed to release memory occupation: {e}")
logging.error(f"Failed to release memory occupation: {e}")
return {"status": "error", "message": str(e)}
async def resume_memory_occupation(self, body: dict) -> dict:
......@@ -123,21 +116,16 @@ class BaseWorkerHandler(ABC):
try:
# Step 1: Resume memory first - must be ready before accepting requests
await self.engine.async_resume_memory_occupation(tags)
logger.info(f"[ResumeMemory] Resumed memory for tags: {tags}")
# Step 2: Continue generation
await self.engine.async_continue_generation()
logger.info("[ResumeMemory] Generation continued")
# Step 3: Re-register to discovery so frontend can route to us
try:
await self.generate_endpoint.register_endpoint_instance()
logger.info(
"[ResumeMemory] Re-registered endpoint to discovery - worker added back to routing pool"
)
except Exception as reg_err:
logger.warning(
f"[ResumeMemory] Failed to re-register endpoint to discovery: {reg_err}"
logging.warning(
f"Failed to re-register endpoint to discovery: {reg_err}"
)
return {
......@@ -145,9 +133,42 @@ class BaseWorkerHandler(ABC):
"message": f"Memory resumed for tags: {tags}",
}
except Exception as e:
logger.error(f"Failed to resume memory occupation: {e}")
logging.error(f"Failed to resume memory occupation: {e}")
return {"status": "error", "message": str(e)}
async def start_profile(self, body: dict) -> dict:
"""Start profiling on the engine.
Args:
body: Dict with profiling parameters passed to start_profile.
"""
await self.engine.tokenizer_manager.start_profile(**body)
return {"status": "ok", "message": "Profiling started"}
async def stop_profile(self, body: dict) -> dict:
"""Stop profiling on the engine.
Args:
body: Unused, but required for handler signature.
"""
await self.engine.tokenizer_manager.stop_profile()
return {"status": "ok", "message": "Profiling stopped"}
def register_engine_routes(self, runtime) -> None:
"""Register all engine routes for this handler.
Args:
runtime: The DistributedRuntime instance to register routes on.
"""
runtime.register_engine_route("start_profile", self.start_profile)
runtime.register_engine_route("stop_profile", self.stop_profile)
runtime.register_engine_route(
"release_memory_occupation", self.release_memory_occupation
)
runtime.register_engine_route(
"resume_memory_occupation", self.resume_memory_occupation
)
@abstractmethod
async def generate(self, request: Dict[str, Any], context: Context):
"""Generate response from request.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment