Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
b13a440e
Unverified
Commit
b13a440e
authored
Apr 14, 2026
by
MatejKosec
Committed by
GitHub
Apr 14, 2026
Browse files
fix(trtllm): prevent decode worker segfault on prefill scale-down (#7933)
Signed-off-by:
Matej Kosec
<
mkosec@nvidia.com
>
parent
9a07ca15
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
459 additions
and
7 deletions
+459
-7
components/src/dynamo/common/utils/graceful_shutdown.py
components/src/dynamo/common/utils/graceful_shutdown.py
+39
-1
components/src/dynamo/common/utils/tests/test_graceful_shutdown.py
...s/src/dynamo/common/utils/tests/test_graceful_shutdown.py
+161
-0
components/src/dynamo/trtllm/main.py
components/src/dynamo/trtllm/main.py
+85
-2
components/src/dynamo/trtllm/workers/__init__.py
components/src/dynamo/trtllm/workers/__init__.py
+11
-1
components/src/dynamo/trtllm/workers/llm_worker.py
components/src/dynamo/trtllm/workers/llm_worker.py
+8
-0
lib/bindings/python/src/dynamo/nixl_connect/__init__.py
lib/bindings/python/src/dynamo/nixl_connect/__init__.py
+21
-3
lib/bindings/python/tests/test_nixl_connect_unit.py
lib/bindings/python/tests/test_nixl_connect_unit.py
+134
-0
No files found.
components/src/dynamo/common/utils/graceful_shutdown.py
View file @
b13a440e
...
@@ -5,7 +5,7 @@ import asyncio
...
@@ -5,7 +5,7 @@ import asyncio
import
logging
import
logging
import
os
import
os
import
signal
import
signal
from
typing
import
Any
,
Iterable
,
Optional
from
typing
import
Any
,
Callable
,
Coroutine
,
Iterable
,
Optional
from
dynamo._core
import
DistributedRuntime
from
dynamo._core
import
DistributedRuntime
...
@@ -13,6 +13,7 @@ logger = logging.getLogger(__name__)
...
@@ -13,6 +13,7 @@ logger = logging.getLogger(__name__)
# TODO: make this configurable via a CLI flag
# TODO: make this configurable via a CLI flag
_DEFAULT_GRACE_PERIOD_SECS
=
5.0
_DEFAULT_GRACE_PERIOD_SECS
=
5.0
_DEFAULT_DRAIN_TIMEOUT_SECS
=
30.0
_GRACE_PERIOD_ENV
=
"DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS"
_GRACE_PERIOD_ENV
=
"DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS"
_shutdown_started
=
asyncio
.
Event
()
_shutdown_started
=
asyncio
.
Event
()
...
@@ -68,7 +69,23 @@ async def graceful_shutdown_with_discovery(
...
@@ -68,7 +69,23 @@ async def graceful_shutdown_with_discovery(
endpoints
:
Iterable
,
endpoints
:
Iterable
,
shutdown_event
:
Optional
[
asyncio
.
Event
]
=
None
,
shutdown_event
:
Optional
[
asyncio
.
Event
]
=
None
,
grace_period_s
:
Optional
[
float
]
=
None
,
grace_period_s
:
Optional
[
float
]
=
None
,
drain_callback
:
Optional
[
Callable
[[],
Coroutine
]]
=
None
,
)
->
None
:
)
->
None
:
"""Perform graceful shutdown with endpoint unregistration and optional drain.
Args:
runtime: The distributed runtime to shut down.
endpoints: Endpoints to unregister from discovery before shutdown.
shutdown_event: Optional event to set before calling runtime.shutdown().
grace_period_s: Seconds to wait after unregistering before drain/shutdown.
Defaults to DYN_GRACEFUL_SHUTDOWN_GRACE_PERIOD_SECS env var or 5s.
drain_callback: Optional async callable awaited after the grace period
but *before* runtime.shutdown(). Use this on prefill workers to wait
for in-flight NIXL KV transfers to complete, preventing decode workers
from segfaulting due to use-after-free on freed GPU memory (#7319).
Any exception raised by drain_callback is logged and swallowed so that
shutdown still proceeds even if draining times out or fails.
"""
if
_shutdown_started
.
is_set
():
if
_shutdown_started
.
is_set
():
return
return
_shutdown_started
.
set
()
_shutdown_started
.
set
()
...
@@ -83,6 +100,25 @@ async def graceful_shutdown_with_discovery(
...
@@ -83,6 +100,25 @@ async def graceful_shutdown_with_discovery(
logger
.
info
(
"Grace period %.2fs before stopping endpoints"
,
grace_period_s
)
logger
.
info
(
"Grace period %.2fs before stopping endpoints"
,
grace_period_s
)
await
asyncio
.
sleep
(
grace_period_s
)
await
asyncio
.
sleep
(
grace_period_s
)
if
drain_callback
is
not
None
:
logger
.
info
(
"Draining in-flight transfers before shutdown (issue #7319 safeguard)"
)
try
:
await
asyncio
.
wait_for
(
drain_callback
(),
timeout
=
_DEFAULT_DRAIN_TIMEOUT_SECS
)
logger
.
info
(
"Drain complete"
)
except
asyncio
.
TimeoutError
:
logger
.
warning
(
"Drain callback timed out after %.0fs, proceeding with shutdown"
,
_DEFAULT_DRAIN_TIMEOUT_SECS
,
)
except
Exception
:
logger
.
exception
(
"Drain callback raised an exception; proceeding with shutdown"
)
if
shutdown_event
is
not
None
:
if
shutdown_event
is
not
None
:
shutdown_event
.
set
()
shutdown_event
.
set
()
...
@@ -96,6 +132,7 @@ def install_signal_handlers(
...
@@ -96,6 +132,7 @@ def install_signal_handlers(
endpoints
:
Iterable
,
endpoints
:
Iterable
,
shutdown_event
:
Optional
[
asyncio
.
Event
]
=
None
,
shutdown_event
:
Optional
[
asyncio
.
Event
]
=
None
,
grace_period_s
:
Optional
[
float
]
=
None
,
grace_period_s
:
Optional
[
float
]
=
None
,
drain_callback
:
Optional
[
Callable
[[],
Coroutine
]]
=
None
,
)
->
None
:
)
->
None
:
shutdown_task
:
Optional
[
asyncio
.
Task
[
None
]]
=
None
shutdown_task
:
Optional
[
asyncio
.
Task
[
None
]]
=
None
...
@@ -123,6 +160,7 @@ def install_signal_handlers(
...
@@ -123,6 +160,7 @@ def install_signal_handlers(
endpoints
,
endpoints
,
shutdown_event
=
shutdown_event
,
shutdown_event
=
shutdown_event
,
grace_period_s
=
grace_period_s
,
grace_period_s
=
grace_period_s
,
drain_callback
=
drain_callback
,
)
)
)
)
shutdown_task
.
add_done_callback
(
_on_shutdown_done
)
shutdown_task
.
add_done_callback
(
_on_shutdown_done
)
...
...
components/src/dynamo/common/utils/tests/test_graceful_shutdown.py
0 → 100644
View file @
b13a440e
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for graceful_shutdown.py
Tests the drain_callback mechanism added to prevent decode worker segfaults when
a prefill worker scales down before in-flight NIXL KV transfers complete (issue #7319).
These tests import graceful_shutdown directly (bypassing the dynamo package hierarchy)
so they work without GPU, NIXL, or TensorRT-LLM installed.
"""
import
asyncio
import
importlib.util
import
sys
import
types
from
pathlib
import
Path
from
unittest.mock
import
AsyncMock
,
MagicMock
import
pytest
# Module-wide pytest markers: every test in this file is tagged with both
# the `unit` and `pre_merge` marks.
pytestmark = [pytest.mark.unit, pytest.mark.pre_merge]
# ---------------------------------------------------------------------------
# Module loading: import graceful_shutdown without triggering the full dynamo
# package (which requires dynamo.llm, CUDA, etc.)
#
# We cannot do `from dynamo.common.utils import graceful_shutdown` because the
# dynamo package __init__ transitively imports dynamo._core, which is a native
# extension (PyO3) requiring CUDA/NIXL libraries that are not available in
# unit test environments. Instead, we stub dynamo._core and load the module
# directly from its file path via importlib.
# ---------------------------------------------------------------------------
# graceful_shutdown.py lives in the directory that contains this tests/ folder.
_GRACEFUL_SHUTDOWN_PATH = Path(__file__).parent.parent.joinpath("graceful_shutdown.py")

# Minimal stand-ins for `dynamo` and `dynamo._core` so that the module under
# test can execute `from dynamo._core import DistributedRuntime`.
# `setdefault` leaves any module that is already imported untouched.
_dynamo_stub = types.ModuleType("dynamo")
_dynamo_core_stub = types.ModuleType("dynamo._core")
setattr(_dynamo_core_stub, "DistributedRuntime", object)
for _stub_name, _stub_module in (
    ("dynamo", _dynamo_stub),
    ("dynamo._core", _dynamo_core_stub),
):
    sys.modules.setdefault(_stub_name, _stub_module)
del _stub_name, _stub_module
def _load_graceful_shutdown():
    """Load graceful_shutdown.py straight from its file path and return the module.

    The module is executed under the name
    ``dynamo.common.utils.graceful_shutdown`` but is not registered in
    ``sys.modules``; callers keep the returned module object themselves.

    Returns:
        The freshly executed module object.

    Raises:
        ImportError: If no import spec or loader can be built for
            ``_GRACEFUL_SHUTDOWN_PATH`` (``spec_from_file_location`` returns
            ``None`` for paths it does not recognise as Python sources).
    """
    spec = importlib.util.spec_from_file_location(
        "dynamo.common.utils.graceful_shutdown",
        _GRACEFUL_SHUTDOWN_PATH,
    )
    # Fail with a clear message instead of an AttributeError on `None`.
    if spec is None or spec.loader is None:
        raise ImportError(
            f"Cannot create an import spec for {_GRACEFUL_SHUTDOWN_PATH}"
        )
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)
    return mod
# Load the module under test once, at import time, and expose the two entry
# points under short module-level names for the tests below.
_gs = _load_graceful_shutdown()
graceful_shutdown_with_discovery = _gs.graceful_shutdown_with_discovery
install_signal_handlers = _gs.install_signal_handlers
# ---------------------------------------------------------------------------
# Helper: reset the module-level _shutdown_started event between tests
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def reset_shutdown_state():
    """Clear the module-level ``_shutdown_started`` event around every test.

    ``graceful_shutdown_with_discovery`` returns immediately when that event
    is already set, so a test that triggered shutdown would otherwise turn
    every later call in the same process into a no-op.
    """
    # Start from a clean slate even if something set the event before this test.
    _gs._shutdown_started.clear()
    yield
    # Undo whatever the test itself triggered.
    _gs._shutdown_started.clear()
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
def test_drain_callback_called_before_shutdown():
    """The drain callback has to run to completion before runtime.shutdown().

    Regression test for issue #7319: a prefill worker must let in-flight
    NIXL transfers drain before the runtime is shut down, otherwise decode
    workers can end up reading GPU memory that has already been freed.
    """
    events = []

    def _record_shutdown():
        events.append("shutdown")

    async def _record_drain():
        events.append("drain")

    runtime = MagicMock()
    runtime.shutdown = MagicMock(side_effect=_record_shutdown)

    async def _scenario():
        endpoint = AsyncMock()
        endpoint.unregister_endpoint_instance = AsyncMock(return_value=None)
        # grace_period_s=0 keeps the test fast.
        await graceful_shutdown_with_discovery(
            runtime=runtime,
            endpoints=[endpoint],
            shutdown_event=None,
            grace_period_s=0,
            drain_callback=_record_drain,
        )

    asyncio.run(_scenario())

    # Both steps must have happened ...
    assert "drain" in events, "drain_callback was not called"
    assert "shutdown" in events, "runtime.shutdown was not called"

    # ... and in the right order.
    first_drain = events.index("drain")
    first_shutdown = events.index("shutdown")
    assert first_drain < first_shutdown, (
        "drain_callback must be called before runtime.shutdown() to ensure "
        "in-flight NIXL transfers complete before GPU memory is freed"
    )
def test_no_drain_callback_still_shuts_down():
    """Without a drain_callback the shutdown path behaves as it did before."""
    runtime = MagicMock()

    async def _scenario():
        endpoint = AsyncMock()
        endpoint.unregister_endpoint_instance = AsyncMock(return_value=None)
        shutdown_kwargs = dict(
            runtime=runtime,
            endpoints=[endpoint],
            shutdown_event=None,
            grace_period_s=0,
            drain_callback=None,
        )
        await graceful_shutdown_with_discovery(**shutdown_kwargs)

    asyncio.run(_scenario())

    # The runtime must be shut down exactly once.
    runtime.shutdown.assert_called_once()
def test_drain_callback_exception_does_not_block_shutdown():
    """A failing drain callback must not prevent the runtime from shutting down.

    The callback raising (for example because draining timed out) is expected
    to be absorbed by graceful_shutdown_with_discovery, after which
    runtime.shutdown() is still invoked.
    """
    runtime = MagicMock()

    async def _always_fails():
        raise RuntimeError("drain timed out")

    async def _scenario():
        endpoint = AsyncMock()
        endpoint.unregister_endpoint_instance = AsyncMock(return_value=None)
        shutdown_kwargs = dict(
            runtime=runtime,
            endpoints=[endpoint],
            shutdown_event=None,
            grace_period_s=0,
            drain_callback=_always_fails,
        )
        await graceful_shutdown_with_discovery(**shutdown_kwargs)

    # The RuntimeError must not escape from the shutdown coroutine.
    asyncio.run(_scenario())

    runtime.shutdown.assert_called_once()
components/src/dynamo/trtllm/main.py
View file @
b13a440e
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
import
asyncio
import
asyncio
import
logging
import
logging
from
typing
import
Callable
,
Coroutine
import
uvloop
import
uvloop
...
@@ -10,11 +11,72 @@ from dynamo.common.utils.graceful_shutdown import install_signal_handlers
...
@@ -10,11 +11,72 @@ from dynamo.common.utils.graceful_shutdown import install_signal_handlers
from
dynamo.common.utils.runtime
import
create_runtime
from
dynamo.common.utils.runtime
import
create_runtime
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.runtime.logging
import
configure_dynamo_logging
from
dynamo.trtllm.args
import
parse_args
from
dynamo.trtllm.args
import
parse_args
from
dynamo.trtllm.constants
import
DisaggregationMode
from
dynamo.trtllm.workers
import
init_worker
from
dynamo.trtllm.workers
import
init_worker
configure_dynamo_logging
()
configure_dynamo_logging
()
shutdown_endpoints
:
list
=
[]
shutdown_endpoints
:
list
=
[]
# Maximum time (seconds) to wait for in-flight requests to drain during shutdown.
_DRAIN_TIMEOUT_S
=
30.0
_DRAIN_POLL_INTERVAL_S
=
0.5
def
_make_drain_callback
(
engine_holder
:
list
,
)
->
Callable
[[],
Coroutine
]:
"""Create a drain callback that polls the TRT-LLM engine until idle.
The engine_holder is a mutable list populated by init_llm_worker once the
engine is ready. If it is still empty when the signal fires (engine not yet
initialized), draining is skipped.
Returns None when the worker is not a prefill worker (drain is unnecessary).
The caller checks disaggregation_mode *before* calling this helper.
"""
async
def
_drain_in_flight_requests
():
if
not
engine_holder
:
logging
.
info
(
"Engine not yet initialized; skipping drain"
)
return
engine
=
engine_holder
[
0
]
logging
.
info
(
"Draining in-flight requests (timeout=%.1fs) to allow "
"NIXL KV transfers to complete before GPU memory is freed"
,
_DRAIN_TIMEOUT_S
,
)
deadline
=
asyncio
.
get_running_loop
().
time
()
+
_DRAIN_TIMEOUT_S
while
asyncio
.
get_running_loop
().
time
()
<
deadline
:
try
:
stats_iter
=
engine
.
llm
.
get_stats_async
(
timeout
=
2
)
stat
=
await
anext
(
stats_iter
)
active
=
stat
.
get
(
"numActiveRequests"
,
0
)
queued
=
stat
.
get
(
"numQueuedRequests"
,
0
)
total
=
active
+
queued
if
total
==
0
:
logging
.
info
(
"All in-flight requests drained"
)
return
logging
.
info
(
"Waiting for %d in-flight request(s) to complete "
"(active=%d, queued=%d)"
,
total
,
active
,
queued
,
)
except
Exception
as
e
:
# get_stats_async may fail if engine is already partially torn down
logging
.
debug
(
"Stats poll failed during drain: %s"
,
e
)
await
asyncio
.
sleep
(
_DRAIN_POLL_INTERVAL_S
)
logging
.
warning
(
"Drain timeout (%.1fs) reached; proceeding with shutdown. "
"Some NIXL transfers may still be in flight."
,
_DRAIN_TIMEOUT_S
,
)
return
_drain_in_flight_requests
async
def
worker
():
async
def
worker
():
config
=
parse_args
()
config
=
parse_args
()
...
@@ -26,10 +88,31 @@ async def worker():
...
@@ -26,10 +88,31 @@ async def worker():
event_plane
=
config
.
event_plane
,
event_plane
=
config
.
event_plane
,
)
)
install_signal_handlers
(
loop
,
runtime
,
shutdown_endpoints
,
shutdown_event
)
# Only prefill workers need a drain callback. When a prefill worker shuts
# down, decode workers may still be reading its GPU memory via NIXL RDMA.
# The drain callback waits for in-flight requests to finish so that GPU
# memory is not freed while transfers are active (issue #7319).
engine_holder
:
list
=
[]
drain_callback
=
None
if
config
.
disaggregation_mode
==
DisaggregationMode
.
PREFILL
:
drain_callback
=
_make_drain_callback
(
engine_holder
)
install_signal_handlers
(
loop
,
runtime
,
shutdown_endpoints
,
shutdown_event
,
drain_callback
=
drain_callback
,
)
logging
.
info
(
f
"Initializing the worker with config:
{
config
}
"
)
logging
.
info
(
f
"Initializing the worker with config:
{
config
}
"
)
await
init_worker
(
runtime
,
config
,
shutdown_event
,
shutdown_endpoints
)
await
init_worker
(
runtime
,
config
,
shutdown_event
,
shutdown_endpoints
,
engine_holder
=
engine_holder
,
)
def
main
():
def
main
():
...
...
components/src/dynamo/trtllm/workers/__init__.py
View file @
b13a440e
...
@@ -31,6 +31,7 @@ async def init_worker(
...
@@ -31,6 +31,7 @@ async def init_worker(
config
:
Config
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
,
shutdown_event
:
asyncio
.
Event
,
shutdown_endpoints
:
Optional
[
list
]
=
None
,
shutdown_endpoints
:
Optional
[
list
]
=
None
,
engine_holder
:
Optional
[
list
]
=
None
,
)
->
None
:
)
->
None
:
"""Initialize the appropriate worker based on modality.
"""Initialize the appropriate worker based on modality.
...
@@ -42,6 +43,9 @@ async def init_worker(
...
@@ -42,6 +43,9 @@ async def init_worker(
config: Configuration parsed from command line.
config: Configuration parsed from command line.
shutdown_event: Event to signal shutdown.
shutdown_event: Event to signal shutdown.
shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown.
shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown.
engine_holder: Optional mutable list; when provided, init_llm_worker will
append the TensorRTLLMEngine instance so that the drain callback
(installed earlier by main.py) can access it at signal time.
"""
"""
logging
.
info
(
f
"Initializing worker with modality=
{
config
.
modality
}
"
)
logging
.
info
(
f
"Initializing worker with modality=
{
config
.
modality
}
"
)
...
@@ -61,7 +65,13 @@ async def init_worker(
...
@@ -61,7 +65,13 @@ async def init_worker(
raise
ValueError
(
f
"Unsupported diffusion modality:
{
modality
}
"
)
raise
ValueError
(
f
"Unsupported diffusion modality:
{
modality
}
"
)
# LLM modalities (text, multimodal)
# LLM modalities (text, multimodal)
await
init_llm_worker
(
runtime
,
config
,
shutdown_event
,
shutdown_endpoints
)
await
init_llm_worker
(
runtime
,
config
,
shutdown_event
,
shutdown_endpoints
,
engine_holder
=
engine_holder
,
)
__all__
=
[
"init_worker"
]
__all__
=
[
"init_worker"
]
components/src/dynamo/trtllm/workers/llm_worker.py
View file @
b13a440e
...
@@ -132,6 +132,7 @@ async def init_llm_worker(
...
@@ -132,6 +132,7 @@ async def init_llm_worker(
config
:
Config
,
config
:
Config
,
shutdown_event
:
asyncio
.
Event
,
shutdown_event
:
asyncio
.
Event
,
shutdown_endpoints
:
Optional
[
list
]
=
None
,
shutdown_endpoints
:
Optional
[
list
]
=
None
,
engine_holder
:
Optional
[
list
]
=
None
,
)
->
None
:
)
->
None
:
"""Initialize and run the LLM worker.
"""Initialize and run the LLM worker.
...
@@ -142,6 +143,8 @@ async def init_llm_worker(
...
@@ -142,6 +143,8 @@ async def init_llm_worker(
config: Configuration parsed from command line.
config: Configuration parsed from command line.
shutdown_event: Event to signal shutdown.
shutdown_event: Event to signal shutdown.
shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown.
shutdown_endpoints: Optional list to populate with endpoints for graceful shutdown.
engine_holder: Optional mutable list; when provided, the TensorRTLLMEngine
is appended so that the drain callback can reference it at shutdown time.
"""
"""
encode_client
=
None
encode_client
=
None
...
@@ -384,6 +387,11 @@ async def init_llm_worker(
...
@@ -384,6 +387,11 @@ async def init_llm_worker(
config
.
disaggregation_mode
,
config
.
disaggregation_mode
,
component_gauges
=
component_gauges
,
component_gauges
=
component_gauges
,
)
as
engine
:
)
as
engine
:
# Expose engine to the drain callback installed by main.py (#7319).
# The callback uses this to poll active request count during shutdown.
if
engine_holder
is
not
None
:
engine_holder
.
append
(
engine
)
endpoint
=
runtime
.
endpoint
(
endpoint
=
runtime
.
endpoint
(
f
"
{
config
.
namespace
}
.
{
config
.
component
}
.
{
config
.
endpoint
}
"
f
"
{
config
.
namespace
}
.
{
config
.
component
}
.
{
config
.
endpoint
}
"
)
)
...
...
lib/bindings/python/src/dynamo/nixl_connect/__init__.py
View file @
b13a440e
...
@@ -463,7 +463,17 @@ class ActiveOperation(AbstractOperation):
...
@@ -463,7 +463,17 @@ class ActiveOperation(AbstractOperation):
case
OperationStatus
.
INITIALIZED
|
OperationStatus
.
IN_PROGRESS
:
case
OperationStatus
.
INITIALIZED
|
OperationStatus
.
IN_PROGRESS
:
await
asyncio
.
sleep
(
sleep_time
/
1000
)
await
asyncio
.
sleep
(
sleep_time
/
1000
)
sleep_time
=
min
(
sleep_time
*
backoff_factor
,
max_poll_ms
)
sleep_time
=
min
(
sleep_time
*
backoff_factor
,
max_poll_ms
)
# Any other state indicates completion or error.
# ERRORED indicates the remote agent may have disconnected or
# its memory may be invalid (e.g. prefill worker scaled down).
# Raise so the caller can surface this as a retryable error
# rather than silently returning stale/empty data.
case
OperationStatus
.
ERRORED
:
raise
RuntimeError
(
f
"NIXL transfer operation ERRORED for remote '
{
self
.
_remote
.
name
}
'. "
"The remote agent may have disconnected or its GPU memory may be "
"invalid (e.g. the prefill worker was scaled down mid-transfer)."
)
# Any other state (COMPLETE, CANCELLED) indicates the transfer is done.
case
_
:
case
_
:
return
return
...
@@ -489,7 +499,11 @@ class ActiveOperation(AbstractOperation):
...
@@ -489,7 +499,11 @@ class ActiveOperation(AbstractOperation):
"""
"""
# Early return if the operation is already complete, errored, or cancelled.
# Early return if the operation is already complete, errored, or cancelled.
match
self
.
_status
:
match
self
.
_status
:
case
OperationStatus
.
COMPLETE
|
OperationStatus
.
ERRORED
|
OperationStatus
.
CANCELLED
:
case
(
OperationStatus
.
COMPLETE
|
OperationStatus
.
ERRORED
|
OperationStatus
.
CANCELLED
):
return
self
.
_status
return
self
.
_status
if
self
.
_xfer_hndl
is
None
:
if
self
.
_xfer_hndl
is
None
:
...
@@ -1466,7 +1480,11 @@ class PassiveOperation(AbstractOperation):
...
@@ -1466,7 +1480,11 @@ class PassiveOperation(AbstractOperation):
"""
"""
# Early return if the operation is already complete, errored, or cancelled.
# Early return if the operation is already complete, errored, or cancelled.
match
self
.
_status
:
match
self
.
_status
:
case
OperationStatus
.
COMPLETE
|
OperationStatus
.
ERRORED
|
OperationStatus
.
CANCELLED
:
case
(
OperationStatus
.
COMPLETE
|
OperationStatus
.
ERRORED
|
OperationStatus
.
CANCELLED
):
return
self
.
_status
return
self
.
_status
old_status
=
self
.
_status
old_status
=
self
.
_status
...
...
lib/bindings/python/tests/test_nixl_connect_unit.py
0 → 100644
View file @
b13a440e
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Unit tests for dynamo.nixl_connect
Tests the ERRORED state handling in ActiveOperation._wait_for_completion_() added
to prevent decode workers from silently consuming bad data when a prefill worker
disappears mid-transfer (issue #7319).
NIXL and CUDA are mocked so these tests run on CPU-only machines.
"""
import
sys
from
unittest.mock
import
MagicMock
,
patch
import
pytest
# Module-wide pytest markers: every test in this file is tagged with both
# the `unit` and `pre_merge` marks.
pytestmark = [pytest.mark.unit, pytest.mark.pre_merge]
def _make_nixl_mocks():
    """Build stand-ins for ``nixl._api`` and ``nixl._bindings``.

    Returns:
        A ``(nixl_api_mock, nixl_bindings_mock, agent_instance)`` tuple, where
        ``agent_instance`` is the object returned by
        ``nixl_api_mock.nixl_agent(...)``.
    """
    # The agent handed out by nixl_api.nixl_agent(...), with canned return
    # values for the calls configured here.
    agent = MagicMock()
    agent.configure_mock(
        **{
            "get_agent_metadata.return_value": b"mock-metadata",
            "add_remote_agent.return_value": b"mock-remote-agent",
            "get_xfer_descs.return_value": MagicMock(),
            "initialize_xfer.return_value": MagicMock(),
            "register_memory.return_value": MagicMock(),
        }
    )

    api = MagicMock()
    api.nixl_agent.return_value = agent
    # The MagicMock class itself (not an instance) stands in for the handle type.
    api.nixl_xfer_handle = MagicMock

    bindings = MagicMock()
    return api, bindings, agent
@pytest.fixture
def nixl_mocks():
    """Install fake ``nixl`` and ``cupy`` modules in ``sys.modules`` for one test.

    Yields the ``(nixl_api_mock, nixl_bindings_mock, agent_instance)`` tuple
    from ``_make_nixl_mocks``; ``patch.dict`` restores ``sys.modules`` once
    the test is finished.
    """
    api, bindings, agent = _make_nixl_mocks()

    # nixl_connect tries to import cupy as well, so provide a fake whose
    # cuda.is_available() returns False.
    fake_cupy = MagicMock()
    fake_cupy.cuda = MagicMock()
    fake_cupy.cuda.is_available = MagicMock(return_value=False)
    fake_cupy.ndarray = type("ndarray", (), {})

    stubbed_modules = {
        "nixl": MagicMock(),
        "nixl._api": api,
        "nixl._bindings": bindings,
        "cupy": fake_cupy,
        "cupy_backends": MagicMock(),
        "cupy_backends.cuda": MagicMock(),
        "cupy_backends.cuda.api": MagicMock(),
        "cupy_backends.cuda.api.runtime": MagicMock(),
    }
    with patch.dict(sys.modules, stubbed_modules):
        yield api, bindings, agent
@pytest.fixture
def testable_active_op(nixl_mocks):
    """Factory fixture returning the ``_TestableActiveOp`` class.

    Instantiate the returned class with a sequence of statuses. Its
    ``__init__`` does not call ``ActiveOperation.__init__`` and only fills in
    mocked attributes, while ``_wait_for_completion_()`` is inherited
    unchanged from ``ActiveOperation`` and is therefore the real code under
    test.
    """
    from dynamo.nixl_connect import ActiveOperation, OperationStatus

    class _TestableActiveOp(ActiveOperation):
        def __init__(self, status_sequence):
            # The parent constructor is skipped on purpose; only attributes
            # are populated here, all of them with mocks or plain values.
            self._status_sequence = iter(status_sequence)
            self._status = OperationStatus.INITIALIZED

            remote = MagicMock()
            remote.name = "mock-prefill-worker"
            self._remote = remote

            self._xfer_hndl = MagicMock()
            self._connection = MagicMock()
            self._local_desc_list = MagicMock()
            self._local_desc_tlist = []
            self._remote_desc_tlist = []
            self._local_device_kind = MagicMock()
            self._remote_device_kind = MagicMock()
            self._notification_key = "test-key"
            self._operation_kind = MagicMock()

        @property
        def status(self):
            # Each read advances to the next scripted status; once the script
            # is exhausted the last status keeps being reported.
            try:
                upcoming = next(self._status_sequence)
            except StopIteration:
                return self._status
            self._status = upcoming
            return self._status

        def cancel(self):
            pass

        async def wait_for_completion(self):
            await self._wait_for_completion_()

        def _release(self):
            pass

    return _TestableActiveOp
@pytest.mark.asyncio
async def test_wait_for_completion_raises_on_errored_status(testable_active_op):
    """``_wait_for_completion_`` must raise RuntimeError once the status is ERRORED.

    Previously the method returned silently in that case, so the caller never
    learned that the transfer had failed. Raising lets the caller deal with
    the failure explicitly (for example by turning it into a retryable
    error). This is the decode-side half of the fix for issue #7319.
    """
    from dynamo.nixl_connect import OperationStatus

    # Scripted progression: INITIALIZED -> IN_PROGRESS -> ERRORED, i.e. the
    # remote agent disappears while the transfer is running.
    scripted_statuses = [
        OperationStatus.INITIALIZED,
        OperationStatus.IN_PROGRESS,
        OperationStatus.ERRORED,
    ]
    operation = testable_active_op(scripted_statuses)

    expected_failure = pytest.raises(RuntimeError, match=r"ERRORED|errored|error")
    with expected_failure:
        await operation.wait_for_completion()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment