Unverified Commit c09a9aad authored by Schwinn Saereesitthipitak's avatar Schwinn Saereesitthipitak Committed by GitHub
Browse files

fix: guard SGLang/vLLM memory occupation control endpoints (#6967)

parent 9b2b44e3
...@@ -19,6 +19,10 @@ from dynamo.runtime import DistributedRuntime ...@@ -19,6 +19,10 @@ from dynamo.runtime import DistributedRuntime
from dynamo.sglang.args import Config from dynamo.sglang.args import Config
from dynamo.sglang.publisher import DynamoSglangPublisher from dynamo.sglang.publisher import DynamoSglangPublisher
# Keep default tags minimal and safe for general use.
# "cuda_graph" can still be requested explicitly, but it requires LD_PRELOAD setup.
DEFAULT_MEMORY_OCCUPATION_TAGS = ["kv_cache", "weights"]
class BaseGenerativeHandler(ABC): class BaseGenerativeHandler(ABC):
"""Minimal base class for all generative handlers (LLM, diffusion, etc.). """Minimal base class for all generative handlers (LLM, diffusion, etc.).
...@@ -144,6 +148,8 @@ class BaseWorkerHandler(BaseGenerativeHandler): ...@@ -144,6 +148,8 @@ class BaseWorkerHandler(BaseGenerativeHandler):
# have an sgl.Engine. # have an sgl.Engine.
self.input_param_manager = InputParamManager(None) self.input_param_manager = InputParamManager(None)
self._engine_supports_priority = False self._engine_supports_priority = False
self._memory_occupation_lock = asyncio.Lock()
self._memory_released = False
def _priority_kwargs(self, priority: Any) -> Dict[str, Any]: def _priority_kwargs(self, priority: Any) -> Dict[str, Any]:
if priority is not None and self._engine_supports_priority: if priority is not None and self._engine_supports_priority:
...@@ -154,8 +160,7 @@ class BaseWorkerHandler(BaseGenerativeHandler): ...@@ -154,8 +160,7 @@ class BaseWorkerHandler(BaseGenerativeHandler):
"""Release GPU memory occupation and unregister from discovery. """Release GPU memory occupation and unregister from discovery.
Args: Args:
body: Dict with optional 'tags' key for which memory to release. body: Unused. Release always targets default tags.
Default: ["kv_cache", "weights", "cuda_graph"]
Order of operations: Order of operations:
1. Unregister from discovery - stop accepting new requests 1. Unregister from discovery - stop accepting new requests
...@@ -167,43 +172,50 @@ class BaseWorkerHandler(BaseGenerativeHandler): ...@@ -167,43 +172,50 @@ class BaseWorkerHandler(BaseGenerativeHandler):
ReleaseMemoryOccupationReqInput, ReleaseMemoryOccupationReqInput,
) )
tags = body.get("tags", body.get("tag", None)) tags = list(DEFAULT_MEMORY_OCCUPATION_TAGS)
if tags is None: tokenizer_manager = (
tags = ["kv_cache", "weights", "cuda_graph"] getattr(self.engine, "tokenizer_manager", None)
if self.engine is not None
else None
)
if tokenizer_manager is None:
return {
"status": "error",
"message": "memory control not supported on this worker",
}
async with self._memory_occupation_lock:
if self._memory_released:
return {
"status": "ok",
"message": "Memory already released",
}
try:
# Step 1: Unregister endpoint from discovery FIRST
try: try:
await self.generate_endpoint.unregister_endpoint_instance() # Stop new requests and drain in-flight work before releasing memory.
except Exception as unreg_err: if self.generate_endpoint is not None:
logging.warning( await self.generate_endpoint.unregister_endpoint_instance()
f"Failed to unregister endpoint from discovery: {unreg_err}"
)
# Step 2: Pause generation to drain in-flight requests pause_req = PauseGenerationReqInput()
pause_req = PauseGenerationReqInput() await tokenizer_manager.pause_generation(pause_req)
await self.engine.tokenizer_manager.pause_generation(pause_req)
# Step 3: Release memory now that it's safe release_req = ReleaseMemoryOccupationReqInput(tags=tags)
release_req = ReleaseMemoryOccupationReqInput(tags=tags) await tokenizer_manager.release_memory_occupation(release_req, None)
await self.engine.tokenizer_manager.release_memory_occupation( self._memory_released = True
release_req, None
)
return { return {
"status": "ok", "status": "ok",
"message": f"Memory released for tags: {tags}", "message": f"Memory released for tags: {tags}",
} }
except Exception as e: except Exception as e:
logging.error(f"Failed to release memory occupation: {e}") logging.error(f"Failed to release memory occupation: {e}")
return {"status": "error", "message": str(e)} return {"status": "error", "message": str(e)}
async def resume_memory_occupation(self, body: dict) -> dict: async def resume_memory_occupation(self, body: dict) -> dict:
"""Resume GPU memory occupation and re-register to discovery. """Resume GPU memory occupation and re-register to discovery.
Args: Args:
body: Dict with optional 'tags' key for which memory to resume. body: Unused. Resume always targets default tags.
Default: ["kv_cache", "weights", "cuda_graph"]
Order of operations: Order of operations:
1. Resume memory - restore GPU allocations 1. Resume memory - restore GPU allocations
...@@ -215,36 +227,43 @@ class BaseWorkerHandler(BaseGenerativeHandler): ...@@ -215,36 +227,43 @@ class BaseWorkerHandler(BaseGenerativeHandler):
ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqInput,
) )
tags = body.get("tags", body.get("tag", None)) tags = list(DEFAULT_MEMORY_OCCUPATION_TAGS)
if tags is None: tokenizer_manager = (
tags = ["kv_cache", "weights", "cuda_graph"] getattr(self.engine, "tokenizer_manager", None)
if self.engine is not None
try: else None
# Step 1: Resume memory first - must be ready before accepting requests )
resume_req = ResumeMemoryOccupationReqInput(tags=tags) if tokenizer_manager is None:
await self.engine.tokenizer_manager.resume_memory_occupation( return {
resume_req, None "status": "error",
) "message": "memory control not supported on this worker",
}
# Step 2: Continue generation async with self._memory_occupation_lock:
continue_req = ContinueGenerationReqInput() if not self._memory_released:
await self.engine.tokenizer_manager.continue_generation(continue_req) return {
"status": "ok",
"message": "Memory already resumed",
}
# Step 3: Re-register to discovery so frontend can route to us
try: try:
await self.generate_endpoint.register_endpoint_instance() resume_req = ResumeMemoryOccupationReqInput(tags=tags)
except Exception as reg_err: await tokenizer_manager.resume_memory_occupation(resume_req, None)
logging.warning( continue_req = ContinueGenerationReqInput()
f"Failed to re-register endpoint to discovery: {reg_err}" await tokenizer_manager.continue_generation(continue_req)
)
if self.generate_endpoint is not None:
return { await self.generate_endpoint.register_endpoint_instance()
"status": "ok",
"message": f"Memory resumed for tags: {tags}", self._memory_released = False
}
except Exception as e: return {
logging.error(f"Failed to resume memory occupation: {e}") "status": "ok",
return {"status": "error", "message": str(e)} "message": f"Memory resumed for tags: {tags}",
}
except Exception as e:
logging.error(f"Failed to resume memory occupation: {e}")
return {"status": "error", "message": str(e)}
async def start_profile(self, body: dict) -> dict: async def start_profile(self, body: dict) -> dict:
"""Start profiling on the engine. """Start profiling on the engine.
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
import sys
import types
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from dynamo.sglang.request_handlers.handler_base import (
DEFAULT_MEMORY_OCCUPATION_TAGS,
BaseWorkerHandler,
)
pytestmark = [
pytest.mark.unit,
pytest.mark.sglang,
pytest.mark.gpu_0,
pytest.mark.pre_merge,
]
@pytest.fixture(autouse=True)
def _stub_sglang_io_struct(monkeypatch):
"""Keep unit tests independent from CUDA-only sglang imports."""
io_struct = types.ModuleType("sglang.srt.managers.io_struct")
class _Req:
def __init__(self, tags=None):
self.tags = tags
io_struct.PauseGenerationReqInput = _Req
io_struct.ReleaseMemoryOccupationReqInput = _Req
io_struct.ResumeMemoryOccupationReqInput = _Req
io_struct.ContinueGenerationReqInput = _Req
monkeypatch.setitem(sys.modules, "sglang.srt.managers.io_struct", io_struct)
class _TestWorkerHandler(BaseWorkerHandler):
async def generate(self, request, context):
yield {}
def _make_handler() -> _TestWorkerHandler:
handler = _TestWorkerHandler.__new__(_TestWorkerHandler)
handler.engine = SimpleNamespace(
tokenizer_manager=SimpleNamespace(
pause_generation=AsyncMock(),
release_memory_occupation=AsyncMock(),
resume_memory_occupation=AsyncMock(),
continue_generation=AsyncMock(),
)
)
handler.generate_endpoint = SimpleNamespace(
unregister_endpoint_instance=AsyncMock(),
register_endpoint_instance=AsyncMock(),
)
handler._memory_occupation_lock = asyncio.Lock()
handler._memory_released = False
return handler
@pytest.mark.asyncio
async def test_resume_before_release_is_noop():
handler = _make_handler()
result = await handler.resume_memory_occupation({})
assert result["status"] == "ok"
assert result["message"] == "Memory already resumed"
handler.engine.tokenizer_manager.resume_memory_occupation.assert_not_awaited()
handler.engine.tokenizer_manager.continue_generation.assert_not_awaited()
handler.generate_endpoint.register_endpoint_instance.assert_not_awaited()
@pytest.mark.asyncio
async def test_release_and_resume_are_idempotent():
handler = _make_handler()
first_release = await handler.release_memory_occupation({})
second_release = await handler.release_memory_occupation({})
first_resume = await handler.resume_memory_occupation({})
second_resume = await handler.resume_memory_occupation({})
assert first_release["status"] == "ok"
assert second_release["status"] == "ok"
assert first_resume["status"] == "ok"
assert second_resume["status"] == "ok"
assert second_release["message"] == "Memory already released"
assert second_resume["message"] == "Memory already resumed"
assert DEFAULT_MEMORY_OCCUPATION_TAGS == ["kv_cache", "weights"]
release_req = (
handler.engine.tokenizer_manager.release_memory_occupation.await_args.args[0]
)
resume_req = (
handler.engine.tokenizer_manager.resume_memory_occupation.await_args.args[0]
)
assert release_req.tags == DEFAULT_MEMORY_OCCUPATION_TAGS
assert resume_req.tags == DEFAULT_MEMORY_OCCUPATION_TAGS
handler.engine.tokenizer_manager.pause_generation.assert_awaited_once()
handler.engine.tokenizer_manager.release_memory_occupation.assert_awaited_once()
handler.generate_endpoint.unregister_endpoint_instance.assert_awaited_once()
handler.engine.tokenizer_manager.resume_memory_occupation.assert_awaited_once()
handler.engine.tokenizer_manager.continue_generation.assert_awaited_once()
handler.generate_endpoint.register_endpoint_instance.assert_awaited_once()
@pytest.mark.asyncio
async def test_resume_uses_default_tags_even_when_request_specifies_subset():
handler = _make_handler()
await handler.release_memory_occupation({"tags": ["weights"]})
resume_result = await handler.resume_memory_occupation({"tags": ["weights"]})
assert resume_result["status"] == "ok"
resume_req = (
handler.engine.tokenizer_manager.resume_memory_occupation.await_args.args[0]
)
assert resume_req.tags == DEFAULT_MEMORY_OCCUPATION_TAGS
handler.engine.tokenizer_manager.continue_generation.assert_awaited_once()
handler.generate_endpoint.register_endpoint_instance.assert_awaited_once()
@pytest.mark.asyncio
async def test_resume_with_no_sleeping_state_is_noop():
handler = _make_handler()
result = await handler.resume_memory_occupation({})
assert result["status"] == "ok"
assert result["message"] == "Memory already resumed"
handler.engine.tokenizer_manager.resume_memory_occupation.assert_not_awaited()
handler.engine.tokenizer_manager.continue_generation.assert_not_awaited()
handler.generate_endpoint.register_endpoint_instance.assert_not_awaited()
@pytest.mark.asyncio
async def test_release_returns_error_when_worker_has_no_tokenizer_manager():
handler = _make_handler()
handler.engine = None
result = await handler.release_memory_occupation({})
assert result == {
"status": "error",
"message": "memory control not supported on this worker",
}
handler.generate_endpoint.unregister_endpoint_instance.assert_not_awaited()
@pytest.mark.asyncio
async def test_resume_returns_error_when_worker_has_no_tokenizer_manager():
handler = _make_handler()
handler.engine = None
result = await handler.resume_memory_occupation({})
assert result == {
"status": "error",
"message": "memory control not supported on this worker",
}
handler.generate_endpoint.register_endpoint_instance.assert_not_awaited()
...@@ -330,6 +330,8 @@ class BaseWorkerHandler(ABC): ...@@ -330,6 +330,8 @@ class BaseWorkerHandler(ABC):
self.use_vllm_tokenizer = use_vllm_tokenizer self.use_vllm_tokenizer = use_vllm_tokenizer
self.dp_range = get_dp_range_for_worker(self.engine_client.vllm_config) self.dp_range = get_dp_range_for_worker(self.engine_client.vllm_config)
self._sleep_wake_lock = asyncio.Lock()
self._engine_is_sleeping = False
# Initialize InputParamManager for text-in-text-out mode # Initialize InputParamManager for text-in-text-out mode
tokenizer = None tokenizer = None
...@@ -351,64 +353,74 @@ class BaseWorkerHandler(ABC): ...@@ -351,64 +353,74 @@ class BaseWorkerHandler(ABC):
2. Abort and drain in-flight requests 2. Abort and drain in-flight requests
3. Sleep engine - safe now that GPU is quiesced 3. Sleep engine - safe now that GPU is quiesced
""" """
body = body or {}
level = body.get("level", 1) level = body.get("level", 1)
try: async with self._sleep_wake_lock:
# Step 1: Unregister endpoint instance FIRST to stop new requests from arriving if self._engine_is_sleeping:
return {
"status": "ok",
"message": "Engine already sleeping",
}
try: try:
await self.generate_endpoint.unregister_endpoint_instance() # Step 1: Unregister endpoint instance before memory transitions.
logger.info( if self.generate_endpoint is not None:
"[Sleep] Unregistered endpoint from discovery - worker removed from routing pool" await self.generate_endpoint.unregister_endpoint_instance()
) logger.info(
except Exception as unreg_err: "[Sleep] Unregistered endpoint from discovery - worker removed from routing pool"
logger.warning( )
f"[Sleep] Failed to unregister endpoint from discovery: {unreg_err}"
)
# Step 2: Abort in-flight requests and wait for them to drain so the # Step 2: Abort in-flight requests and wait for them to drain so the
# GPU is fully quiesced before unmapping memory. # GPU is fully quiesced before unmapping memory.
await self.engine_client.pause_generation() await self.engine_client.pause_generation()
# Step 3: Now safe to sleep - no in-flight GPU work # Step 3: Now safe to sleep - no in-flight GPU work
await self.engine_client.sleep(level) await self.engine_client.sleep(level)
self._engine_is_sleeping = True
return {"status": "ok", "message": f"Engine slept (level={level})"} return {
except Exception as e: "status": "ok",
logger.error(f"Failed to sleep engine: {e}") "message": f"Engine slept (level={level})",
return {"status": "error", "message": str(e)} }
except Exception as e:
logger.error(f"Failed to sleep engine: {e}")
return {"status": "error", "message": str(e)}
async def wake_up(self, body: dict) -> dict: async def wake_up(self, body: dict) -> dict:
"""Wake the engine to restore GPU memory and re-register to discovery. """Wake the engine to restore GPU memory and re-register to discovery.
Args: Args:
body: Dict with optional 'tags' key (e.g., ["weights", "kv_cache"]). None wakes all. body: Unused. Wake always restores all sleep-managed memory.
Order of operations: Order of operations:
1. Wake engine - restore GPU memory 1. Wake engine - restore GPU memory
2. Re-register endpoint instance - allow frontend to route requests here again 2. Re-register endpoint instance - allow frontend to route requests here again
""" """
tags = body.get("tags") async with self._sleep_wake_lock:
try: if not self._engine_is_sleeping:
# Step 1: Wake engine first - must be ready before accepting requests return {"status": "ok", "message": "Engine already awake"}
await self.engine_client.wake_up(tags)
# Step 2: Resume generation so new requests can be processed
await self.engine_client.resume_generation()
# Step 3: Re-register endpoint instance to discovery so frontend can route to us again
try: try:
await self.generate_endpoint.register_endpoint_instance() # Step 1: Wake engine first - must be ready before accepting requests
logger.info( await self.engine_client.wake_up()
"[Wake] Re-registered endpoint to discovery - worker added back to routing pool"
)
except Exception as reg_err:
logger.warning(
f"[Wake] Failed to re-register endpoint to discovery: {reg_err}"
)
return {"status": "ok", "message": f"Engine woke (tags={tags})"} # Step 2: Resume generation and re-register.
except Exception as e: await self.engine_client.resume_generation()
logger.error(f"Failed to wake up engine: {e}") if self.generate_endpoint is not None:
return {"status": "error", "message": str(e)} await self.generate_endpoint.register_endpoint_instance()
logger.info(
"[Wake] Re-registered endpoint to discovery - worker added back to routing pool"
)
self._engine_is_sleeping = False
return {
"status": "ok",
"message": "Engine woke",
}
except Exception as e:
logger.error(f"Failed to wake up engine: {e}")
return {"status": "error", "message": str(e)}
@abstractmethod @abstractmethod
async def generate(self, request, context) -> AsyncGenerator[dict, None]: async def generate(self, request, context) -> AsyncGenerator[dict, None]:
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock
import pytest
from dynamo.vllm.handlers import BaseWorkerHandler
pytestmark = [
pytest.mark.unit,
pytest.mark.vllm,
pytest.mark.gpu_0,
pytest.mark.pre_merge,
]
class _TestWorkerHandler(BaseWorkerHandler):
async def generate(self, request, context):
yield {}
def _make_handler() -> _TestWorkerHandler:
handler = _TestWorkerHandler.__new__(_TestWorkerHandler)
handler.engine_client = SimpleNamespace(
pause_generation=AsyncMock(),
sleep=AsyncMock(),
wake_up=AsyncMock(),
resume_generation=AsyncMock(),
)
handler.generate_endpoint = SimpleNamespace(
unregister_endpoint_instance=AsyncMock(),
register_endpoint_instance=AsyncMock(),
)
handler._sleep_wake_lock = asyncio.Lock()
handler._engine_is_sleeping = False
return handler
@pytest.mark.asyncio
async def test_wake_up_before_sleep_is_noop():
handler = _make_handler()
result = await handler.wake_up({})
assert result["status"] == "ok"
handler.engine_client.wake_up.assert_not_awaited()
handler.engine_client.resume_generation.assert_not_awaited()
handler.generate_endpoint.register_endpoint_instance.assert_not_awaited()
@pytest.mark.asyncio
async def test_sleep_and_wake_are_idempotent():
handler = _make_handler()
first_sleep = await handler.sleep({"level": 2})
second_sleep = await handler.sleep({"level": 2})
first_wake = await handler.wake_up({})
second_wake = await handler.wake_up({})
assert first_sleep["status"] == "ok"
assert second_sleep["status"] == "ok"
assert first_wake["status"] == "ok"
assert second_wake["status"] == "ok"
handler.engine_client.pause_generation.assert_awaited_once()
handler.engine_client.sleep.assert_awaited_once_with(2)
handler.generate_endpoint.unregister_endpoint_instance.assert_awaited_once()
handler.engine_client.wake_up.assert_awaited_once_with()
handler.engine_client.resume_generation.assert_awaited_once()
handler.generate_endpoint.register_endpoint_instance.assert_awaited_once()
@pytest.mark.asyncio
async def test_sleep_returns_error_for_unregister_failure():
handler = _make_handler()
handler.generate_endpoint.unregister_endpoint_instance = AsyncMock(
side_effect=RuntimeError("discovery backend down")
)
result = await handler.sleep({"level": 1})
assert result["status"] == "error"
handler.engine_client.pause_generation.assert_not_awaited()
handler.engine_client.sleep.assert_not_awaited()
@pytest.mark.asyncio
async def test_wake_up_returns_error_for_register_failure():
handler = _make_handler()
handler._engine_is_sleeping = True
handler.generate_endpoint.register_endpoint_instance = AsyncMock(
side_effect=RuntimeError("discovery write timeout")
)
result = await handler.wake_up({})
assert result["status"] == "error"
handler.engine_client.wake_up.assert_awaited_once_with()
handler.engine_client.resume_generation.assert_awaited_once()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment