"docs/vscode:/vscode.git/clone" did not exist on "cfe74445fcd5440dd908214657d13d7254d1de5f"
Unverified Commit b13a440e authored by MatejKosec's avatar MatejKosec Committed by GitHub
Browse files

fix(trtllm): prevent decode worker segfault on prefill scale-down (#7933)


Signed-off-by: default avatarMatej Kosec <mkosec@nvidia.com>
parent 9a07ca15
...@@ -5,7 +5,7 @@ import asyncio ...@@ -5,7 +5,7 @@ import asyncio
import logging import logging
import os import os
import signal import signal
from typing import Any, Iterable, Optional from typing import Any, Callable, Coroutine, Iterable, Optional
from dynamo._core import DistributedRuntime from dynamo._core import DistributedRuntime
...@@ -13,6 +13,7 @@ logger = logging.getLogger(__name__) ...@@ -13,6 +13,7 @@ logger = logging.getLogger(__name__)
# TODO: make this using cli flag # TODO: make this using cli flag
_DEFAULT_GRACE_PERIOD_SECS = 5.0 _DEFAULT_GRACE_PERIOD_SECS = 5.0
_DEFAULT_DRAIN_TIMEOUT_SECS = 30.0
_GRACE_PERIOD_ENV = "DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS" _GRACE_PERIOD_ENV = "DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS"
_shutdown_started = asyncio.Event() _shutdown_started = asyncio.Event()
...@@ -68,7 +69,23 @@ async def graceful_shutdown_with_discovery( ...@@ -68,7 +69,23 @@ async def graceful_shutdown_with_discovery(
endpoints: Iterable, endpoints: Iterable,
shutdown_event: Optional[asyncio.Event] = None, shutdown_event: Optional[asyncio.Event] = None,
grace_period_s: Optional[float] = None, grace_period_s: Optional[float] = None,
drain_callback: Optional[Callable[[], Coroutine]] = None,
) -> None: ) -> None:
"""Perform graceful shutdown with endpoint unregistration and optional drain.
Args:
runtime: The distributed runtime to shut down.
endpoints: Endpoints to unregister from discovery before shutdown.
shutdown_event: Optional event to set before calling runtime.shutdown().
grace_period_s: Seconds to wait after unregistering before drain/shutdown.
Defaults to DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS env var or 5s.
drain_callback: Optional async callable awaited after the grace period
but *before* runtime.shutdown(). Use this on prefill workers to wait
for in-flight NIXL KV transfers to complete, preventing decode workers
from segfaulting due to use-after-free on freed GPU memory (#7319).
Any exception raised by drain_callback is logged and swallowed so that
shutdown still proceeds even if draining times out or fails.
"""
if _shutdown_started.is_set(): if _shutdown_started.is_set():
return return
_shutdown_started.set() _shutdown_started.set()
...@@ -83,6 +100,25 @@ async def graceful_shutdown_with_discovery( ...@@ -83,6 +100,25 @@ async def graceful_shutdown_with_discovery(
logger.info("Grace period %.2fs before stopping endpoints", grace_period_s) logger.info("Grace period %.2fs before stopping endpoints", grace_period_s)
await asyncio.sleep(grace_period_s) await asyncio.sleep(grace_period_s)
if drain_callback is not None:
logger.info(
"Draining in-flight transfers before shutdown (issue #7319 safeguard)"
)
try:
await asyncio.wait_for(
drain_callback(), timeout=_DEFAULT_DRAIN_TIMEOUT_SECS
)
logger.info("Drain complete")
except asyncio.TimeoutError:
logger.warning(
"Drain callback timed out after %.0fs, proceeding with shutdown",
_DEFAULT_DRAIN_TIMEOUT_SECS,
)
except Exception:
logger.exception(
"Drain callback raised an exception; proceeding with shutdown"
)
if shutdown_event is not None: if shutdown_event is not None:
shutdown_event.set() shutdown_event.set()
...@@ -96,6 +132,7 @@ def install_signal_handlers( ...@@ -96,6 +132,7 @@ def install_signal_handlers(
endpoints: Iterable, endpoints: Iterable,
shutdown_event: Optional[asyncio.Event] = None, shutdown_event: Optional[asyncio.Event] = None,
grace_period_s: Optional[float] = None, grace_period_s: Optional[float] = None,
drain_callback: Optional[Callable[[], Coroutine]] = None,
) -> None: ) -> None:
shutdown_task: Optional[asyncio.Task[None]] = None shutdown_task: Optional[asyncio.Task[None]] = None
...@@ -123,6 +160,7 @@ def install_signal_handlers( ...@@ -123,6 +160,7 @@ def install_signal_handlers(
endpoints, endpoints,
shutdown_event=shutdown_event, shutdown_event=shutdown_event,
grace_period_s=grace_period_s, grace_period_s=grace_period_s,
drain_callback=drain_callback,
) )
) )
shutdown_task.add_done_callback(_on_shutdown_done) shutdown_task.add_done_callback(_on_shutdown_done)
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for graceful_shutdown.py
Tests the drain_callback mechanism added to prevent decode worker segfaults when
a prefill worker scales down before in-flight NIXL KV transfers complete (issue #7319).
These tests import graceful_shutdown directly (bypassing the dynamo package hierarchy)
so they work without GPU, NIXL, or TensorRT-LLM installed.
"""
import asyncio
import importlib.util
import sys
import types
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
pytestmark = [pytest.mark.unit, pytest.mark.pre_merge]
# ---------------------------------------------------------------------------
# Module loading: import graceful_shutdown without triggering the full dynamo
# package (which requires dynamo.llm, CUDA, etc.)
#
# We cannot do `from dynamo.common.utils import graceful_shutdown` because the
# dynamo package __init__ transitively imports dynamo._core, which is a native
# extension (PyO3) requiring CUDA/NIXL libraries that are not available in
# unit test environments. Instead, we stub dynamo._core and load the module
# directly from its file path via importlib.
# ---------------------------------------------------------------------------
# graceful_shutdown.py is expected in the parent of this test file's directory.
_GRACEFUL_SHUTDOWN_PATH = Path(__file__).parent.parent / "graceful_shutdown.py"
# Provide a minimal dynamo._core stub so the module can be loaded
# DistributedRuntime is bound to `object` purely as a placeholder so that an
# import of that name from dynamo._core resolves.
_dynamo_stub = types.ModuleType("dynamo")
_dynamo_core_stub = types.ModuleType("dynamo._core")
_dynamo_core_stub.DistributedRuntime = object
# setdefault() registers a stub only when no module of that name is already in
# sys.modules, so an already-imported dynamo package is left untouched.
sys.modules.setdefault("dynamo", _dynamo_stub)
sys.modules.setdefault("dynamo._core", _dynamo_core_stub)
def _load_graceful_shutdown():
    """Execute graceful_shutdown.py from its file path and return the module.

    The module object is built from an importlib spec named
    ``dynamo.common.utils.graceful_shutdown`` that points at
    ``_GRACEFUL_SHUTDOWN_PATH`` and is then executed by the spec's loader.

    Returns:
        The freshly executed module object.
    """
    module_spec = importlib.util.spec_from_file_location(
        "dynamo.common.utils.graceful_shutdown",
        _GRACEFUL_SHUTDOWN_PATH,
    )
    loaded_module = importlib.util.module_from_spec(module_spec)
    # Run the module body so its top-level definitions become attributes.
    module_spec.loader.exec_module(loaded_module)
    return loaded_module
# Load the module once at import time and expose two of its functions as
# module-level names.
_gs = _load_graceful_shutdown()
graceful_shutdown_with_discovery = _gs.graceful_shutdown_with_discovery
install_signal_handlers = _gs.install_signal_handlers
# ---------------------------------------------------------------------------
# Helper: reset the module-level _shutdown_started event between tests
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def reset_shutdown_state():
    """Clear the module-level ``_shutdown_started`` event around every test.

    The event is cleared once before the test body runs and once more after
    it finishes, so a shutdown triggered by one test does not leak into the
    next one.
    """

    def _clear_shutdown_flag():
        _gs._shutdown_started.clear()

    _clear_shutdown_flag()
    yield
    _clear_shutdown_flag()
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_drain_callback_called_before_shutdown():
    """The drain callback has to run before ``runtime.shutdown()``.

    Regression test for issue #7319: the callback and the runtime's shutdown
    both record their invocation in a shared list, and the test checks that
    the drain entry comes first.
    """
    events = []

    def _record_shutdown():
        events.append("shutdown")

    async def _record_drain():
        events.append("drain")

    runtime = MagicMock()
    runtime.shutdown = MagicMock(side_effect=_record_shutdown)

    async def _exercise():
        # The endpoint only needs an awaitable unregister hook.
        endpoint = AsyncMock()
        endpoint.unregister_endpoint_instance = AsyncMock(return_value=None)
        await graceful_shutdown_with_discovery(
            runtime=runtime,
            endpoints=[endpoint],
            shutdown_event=None,
            grace_period_s=0,
            drain_callback=_record_drain,
        )

    asyncio.run(_exercise())

    assert "drain" in events, "drain_callback was not called"
    assert "shutdown" in events, "runtime.shutdown was not called"
    drain_position = events.index("drain")
    shutdown_position = events.index("shutdown")
    assert drain_position < shutdown_position, (
        "drain_callback must be called before runtime.shutdown() to ensure "
        "in-flight NIXL transfers complete before GPU memory is freed"
    )
def test_no_drain_callback_still_shuts_down():
    """Without a drain callback the runtime is still shut down exactly once."""
    runtime = MagicMock()

    async def _exercise():
        endpoint = AsyncMock()
        endpoint.unregister_endpoint_instance = AsyncMock(return_value=None)
        shutdown_kwargs = {
            "runtime": runtime,
            "endpoints": [endpoint],
            "shutdown_event": None,
            "grace_period_s": 0,
            # Explicit None mirrors callers that predate the drain hook.
            "drain_callback": None,
        }
        await graceful_shutdown_with_discovery(**shutdown_kwargs)

    asyncio.run(_exercise())

    runtime.shutdown.assert_called_once()
def test_drain_callback_exception_does_not_block_shutdown():
    """A raising drain callback must not stop the shutdown from happening.

    The callback raises ``RuntimeError``; the call under test is expected to
    return normally and still invoke ``runtime.shutdown()`` exactly once.
    """
    runtime = MagicMock()

    async def _always_fails():
        raise RuntimeError("drain timed out")

    async def _exercise():
        endpoint = AsyncMock()
        endpoint.unregister_endpoint_instance = AsyncMock(return_value=None)
        await graceful_shutdown_with_discovery(
            runtime=runtime,
            endpoints=[endpoint],
            shutdown_event=None,
            grace_period_s=0,
            drain_callback=_always_fails,
        )

    # asyncio.run() would re-raise here if the failure escaped the shutdown path.
    asyncio.run(_exercise())

    runtime.shutdown.assert_called_once()
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import asyncio import asyncio
import logging import logging
from typing import Callable, Coroutine
import uvloop import uvloop
...@@ -10,11 +11,72 @@ from dynamo.common.utils.graceful_shutdown import install_signal_handlers ...@@ -10,11 +11,72 @@ from dynamo.common.utils.graceful_shutdown import install_signal_handlers
from dynamo.common.utils.runtime import create_runtime from dynamo.common.utils.runtime import create_runtime
from dynamo.runtime.logging import configure_dynamo_logging from dynamo.runtime.logging import configure_dynamo_logging
from dynamo.trtllm.args import parse_args from dynamo.trtllm.args import parse_args
from dynamo.trtllm.constants import DisaggregationMode
from dynamo.trtllm.workers import init_worker from dynamo.trtllm.workers import init_worker
configure_dynamo_logging() configure_dynamo_logging()
shutdown_endpoints: list = [] shutdown_endpoints: list = []
# Maximum time (seconds) to wait for in-flight requests to drain during shutdown.
_DRAIN_TIMEOUT_S = 30.0
_DRAIN_POLL_INTERVAL_S = 0.5
def _make_drain_callback(
engine_holder: list,
) -> Callable[[], Coroutine]:
"""Create a drain callback that polls the TRT-LLM engine until idle.
The engine_holder is a mutable list populated by init_llm_worker once the
engine is ready. If it is still empty when the signal fires (engine not yet
initialized), draining is skipped.
Returns None when the worker is not a prefill worker (drain is unnecessary).
The caller checks disaggregation_mode *before* calling this helper.
"""
async def _drain_in_flight_requests():
if not engine_holder:
logging.info("Engine not yet initialized; skipping drain")
return
engine = engine_holder[0]
logging.info(
"Draining in-flight requests (timeout=%.1fs) to allow "
"NIXL KV transfers to complete before GPU memory is freed",
_DRAIN_TIMEOUT_S,
)
deadline = asyncio.get_running_loop().time() + _DRAIN_TIMEOUT_S
while asyncio.get_running_loop().time() < deadline:
try:
stats_iter = engine.llm.get_stats_async(timeout=2)
stat = await anext(stats_iter)
active = stat.get("numActiveRequests", 0)
queued = stat.get("numQueuedRequests", 0)
total = active + queued
if total == 0:
logging.info("All in-flight requests drained")
return
logging.info(
"Waiting for %d in-flight request(s) to complete "
"(active=%d, queued=%d)",
total,
active,
queued,
)
except Exception as e:
# get_stats_async may fail if engine is already partially torn down
logging.debug("Stats poll failed during drain: %s", e)
await asyncio.sleep(_DRAIN_POLL_INTERVAL_S)
logging.warning(
"Drain timeout (%.1fs) reached; proceeding with shutdown. "
"Some NIXL transfers may still be in flight.",
_DRAIN_TIMEOUT_S,
)
return _drain_in_flight_requests
async def worker(): async def worker():
config = parse_args() config = parse_args()
...@@ -26,10 +88,31 @@ async def worker(): ...@@ -26,10 +88,31 @@ async def worker():
event_plane=config.event_plane, event_plane=config.event_plane,
) )
install_signal_handlers(loop, runtime, shutdown_endpoints, shutdown_event) # Only prefill workers need a drain callback. When a prefill worker shuts
# down, decode workers may still be reading its GPU memory via NIXL RDMA.
# The drain callback waits for in-flight requests to finish so that GPU
# memory is not freed while transfers are active (issue #7319).
engine_holder: list = []
drain_callback = None
if config.disaggregation_mode == DisaggregationMode.PREFILL:
drain_callback = _make_drain_callback(engine_holder)
install_signal_handlers(
loop,
runtime,
shutdown_endpoints,
shutdown_event,
drain_callback=drain_callback,
)
logging.info(f"Initializing the worker with config: {config}") logging.info(f"Initializing the worker with config: {config}")
await init_worker(runtime, config, shutdown_event, shutdown_endpoints) await init_worker(
runtime,
config,
shutdown_event,
shutdown_endpoints,
engine_holder=engine_holder,
)
def main(): def main():
......
...@@ -31,6 +31,7 @@ async def init_worker( ...@@ -31,6 +31,7 @@ async def init_worker(
config: Config, config: Config,
shutdown_event: asyncio.Event, shutdown_event: asyncio.Event,
shutdown_endpoints: Optional[list] = None, shutdown_endpoints: Optional[list] = None,
engine_holder: Optional[list] = None,
) -> None: ) -> None:
"""Initialize the appropriate worker based on modality. """Initialize the appropriate worker based on modality.
...@@ -42,6 +43,9 @@ async def init_worker( ...@@ -42,6 +43,9 @@ async def init_worker(
config: Configuration parsed from command line. config: Configuration parsed from command line.
shutdown_event: Event to signal shutdown. shutdown_event: Event to signal shutdown.
shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown. shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown.
engine_holder: Optional mutable list; when provided, init_llm_worker will
append the TensorRTLLMEngine instance so that the drain callback
(installed earlier by main.py) can access it at signal time.
""" """
logging.info(f"Initializing worker with modality={config.modality}") logging.info(f"Initializing worker with modality={config.modality}")
...@@ -61,7 +65,13 @@ async def init_worker( ...@@ -61,7 +65,13 @@ async def init_worker(
raise ValueError(f"Unsupported diffusion modality: {modality}") raise ValueError(f"Unsupported diffusion modality: {modality}")
# LLM modalities (text, multimodal) # LLM modalities (text, multimodal)
await init_llm_worker(runtime, config, shutdown_event, shutdown_endpoints) await init_llm_worker(
runtime,
config,
shutdown_event,
shutdown_endpoints,
engine_holder=engine_holder,
)
__all__ = ["init_worker"] __all__ = ["init_worker"]
...@@ -132,6 +132,7 @@ async def init_llm_worker( ...@@ -132,6 +132,7 @@ async def init_llm_worker(
config: Config, config: Config,
shutdown_event: asyncio.Event, shutdown_event: asyncio.Event,
shutdown_endpoints: Optional[list] = None, shutdown_endpoints: Optional[list] = None,
engine_holder: Optional[list] = None,
) -> None: ) -> None:
"""Initialize and run the LLM worker. """Initialize and run the LLM worker.
...@@ -142,6 +143,8 @@ async def init_llm_worker( ...@@ -142,6 +143,8 @@ async def init_llm_worker(
config: Configuration parsed from command line. config: Configuration parsed from command line.
shutdown_event: Event to signal shutdown. shutdown_event: Event to signal shutdown.
shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown. shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown.
engine_holder: Optional mutable list; when provided, the TensorRTLLMEngine
is appended so that the drain callback can reference it at shutdown time.
""" """
encode_client = None encode_client = None
...@@ -384,6 +387,11 @@ async def init_llm_worker( ...@@ -384,6 +387,11 @@ async def init_llm_worker(
config.disaggregation_mode, config.disaggregation_mode,
component_gauges=component_gauges, component_gauges=component_gauges,
) as engine: ) as engine:
# Expose engine to the drain callback installed by main.py (#7319).
# The callback uses this to poll active request count during shutdown.
if engine_holder is not None:
engine_holder.append(engine)
endpoint = runtime.endpoint( endpoint = runtime.endpoint(
f"{config.namespace}.{config.component}.{config.endpoint}" f"{config.namespace}.{config.component}.{config.endpoint}"
) )
......
...@@ -463,7 +463,17 @@ class ActiveOperation(AbstractOperation): ...@@ -463,7 +463,17 @@ class ActiveOperation(AbstractOperation):
case OperationStatus.INITIALIZED | OperationStatus.IN_PROGRESS: case OperationStatus.INITIALIZED | OperationStatus.IN_PROGRESS:
await asyncio.sleep(sleep_time / 1000) await asyncio.sleep(sleep_time / 1000)
sleep_time = min(sleep_time * backoff_factor, max_poll_ms) sleep_time = min(sleep_time * backoff_factor, max_poll_ms)
# Any other state indicates completion or error. # ERRORED indicates the remote agent may have disconnected or
# its memory may be invalid (e.g. prefill worker scaled down).
# Raise so the caller can surface this as a retryable error
# rather than silently returning stale/empty data.
case OperationStatus.ERRORED:
raise RuntimeError(
f"NIXL transfer operation ERRORED for remote '{self._remote.name}'. "
"The remote agent may have disconnected or its GPU memory may be "
"invalid (e.g. the prefill worker was scaled down mid-transfer)."
)
# Any other state (COMPLETE, CANCELLED) indicates the transfer is done.
case _: case _:
return return
...@@ -489,7 +499,11 @@ class ActiveOperation(AbstractOperation): ...@@ -489,7 +499,11 @@ class ActiveOperation(AbstractOperation):
""" """
# Early return if the operation is already complete, errored, or cancelled. # Early return if the operation is already complete, errored, or cancelled.
match self._status: match self._status:
case OperationStatus.COMPLETE | OperationStatus.ERRORED | OperationStatus.CANCELLED: case (
OperationStatus.COMPLETE
| OperationStatus.ERRORED
| OperationStatus.CANCELLED
):
return self._status return self._status
if self._xfer_hndl is None: if self._xfer_hndl is None:
...@@ -1466,7 +1480,11 @@ class PassiveOperation(AbstractOperation): ...@@ -1466,7 +1480,11 @@ class PassiveOperation(AbstractOperation):
""" """
# Early return if the operation is already complete, errored, or cancelled. # Early return if the operation is already complete, errored, or cancelled.
match self._status: match self._status:
case OperationStatus.COMPLETE | OperationStatus.ERRORED | OperationStatus.CANCELLED: case (
OperationStatus.COMPLETE
| OperationStatus.ERRORED
| OperationStatus.CANCELLED
):
return self._status return self._status
old_status = self._status old_status = self._status
......
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for dynamo.nixl_connect
Tests the ERRORED state handling in ActiveOperation._wait_for_completion_() added
to prevent decode workers from silently consuming bad data when a prefill worker
disappears mid-transfer (issue #7319).
NIXL and CUDA are mocked so these tests run on CPU-only machines.
"""
import sys
from unittest.mock import MagicMock, patch
import pytest
pytestmark = [pytest.mark.unit, pytest.mark.pre_merge]
def _make_nixl_mocks():
    """Build stand-ins for ``nixl._api`` and ``nixl._bindings``.

    Returns:
        A ``(nixl_api_mock, nixl_bindings_mock, agent_instance)`` tuple, where
        ``agent_instance`` is the object handed back by
        ``nixl_api_mock.nixl_agent(...)``.
    """
    # Agent returned by nixl_api.nixl_agent(...), with canned return values.
    agent = MagicMock()
    agent.configure_mock(
        **{
            "get_agent_metadata.return_value": b"mock-metadata",
            "add_remote_agent.return_value": b"mock-remote-agent",
            "get_xfer_descs.return_value": MagicMock(),
            "initialize_xfer.return_value": MagicMock(),
            "register_memory.return_value": MagicMock(),
        }
    )

    api = MagicMock()
    api.nixl_agent.return_value = agent
    # The handle type is replaced by the MagicMock class itself, not an instance.
    api.nixl_xfer_handle = MagicMock

    bindings = MagicMock()
    return api, bindings, agent
@pytest.fixture
def nixl_mocks():
    """Patch ``sys.modules`` with NIXL and cupy stand-ins for one test.

    Yields:
        The ``(nixl_api_mock, nixl_bindings_mock, agent_instance)`` tuple from
        ``_make_nixl_mocks()``. The ``sys.modules`` patch is active only while
        the fixture is suspended at the ``yield``.
    """
    nixl_api_mock, nixl_bindings_mock, agent_instance = _make_nixl_mocks()

    # Patch cupy import too since nixl_connect tries to import it
    fake_cuda = MagicMock()
    fake_cuda.is_available = MagicMock(return_value=False)
    fake_cupy = MagicMock()
    fake_cupy.cuda = fake_cuda
    fake_cupy.ndarray = type("ndarray", (), {})

    stubbed_modules = {
        "nixl": MagicMock(),
        "nixl._api": nixl_api_mock,
        "nixl._bindings": nixl_bindings_mock,
        "cupy": fake_cupy,
    }
    for backend_name in (
        "cupy_backends",
        "cupy_backends.cuda",
        "cupy_backends.cuda.api",
        "cupy_backends.cuda.api.runtime",
    ):
        stubbed_modules[backend_name] = MagicMock()

    with patch.dict(sys.modules, stubbed_modules):
        yield nixl_api_mock, nixl_bindings_mock, agent_instance
@pytest.fixture
def testable_active_op(nixl_mocks):
    """Factory fixture returning the ``_TestableActiveOp`` class.

    Calling the returned class with a status sequence creates an operation
    whose ``status`` property replays that sequence. The subclass replaces
    ``ActiveOperation.__init__`` so no NIXL calls are made, and keeps the
    inherited ``_wait_for_completion_()`` as the logic under test.
    """
    from dynamo.nixl_connect import ActiveOperation, OperationStatus

    class _TestableActiveOp(ActiveOperation):
        def __init__(self, status_sequence):
            # Deliberately skips the parent constructor; attributes are set by hand.
            self._status = OperationStatus.INITIALIZED
            self._status_sequence = iter(status_sequence)

            remote = MagicMock()
            remote.name = "mock-prefill-worker"
            self._remote = remote

            self._xfer_hndl = MagicMock()
            self._connection = MagicMock()
            self._local_desc_list = MagicMock()
            self._local_desc_tlist = []
            self._remote_desc_tlist = []
            self._local_device_kind = MagicMock()
            self._remote_device_kind = MagicMock()
            self._notification_key = "test-key"
            self._operation_kind = MagicMock()

        @property
        def status(self):
            # Advance to the next scripted status; once the script is
            # exhausted the last status keeps being reported.
            self._status = next(self._status_sequence, self._status)
            return self._status

        def cancel(self):
            return None

        async def wait_for_completion(self):
            await self._wait_for_completion_()

        def _release(self):
            return None

    return _TestableActiveOp
@pytest.mark.asyncio
async def test_wait_for_completion_raises_on_errored_status(testable_active_op):
    """``_wait_for_completion_`` must raise ``RuntimeError`` on ERRORED.

    Decode-side regression test for issue #7319: an operation whose status
    ends in ERRORED has to surface a ``RuntimeError`` to the caller rather
    than return as if the transfer had finished.
    """
    from dynamo.nixl_connect import OperationStatus

    # INITIALIZED -> IN_PROGRESS -> ERRORED models a remote agent that
    # disappears while the transfer is still running.
    scripted_statuses = [
        OperationStatus.INITIALIZED,
        OperationStatus.IN_PROGRESS,
        OperationStatus.ERRORED,
    ]
    operation = testable_active_op(scripted_statuses)

    with pytest.raises(RuntimeError, match=r"ERRORED|errored|error"):
        await operation.wait_for_completion()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment