Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4ca13a86
Unverified
Commit
4ca13a86
authored
Oct 22, 2025
by
Mark McLoughlin
Committed by
GitHub
Oct 22, 2025
Browse files
[NIXL] Terminate handshake listener thread in shutdown (#26404)
Signed-off-by:
Mark McLoughlin
<
markmc@redhat.com
>
parent
675aa2ec
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
52 additions
and
11 deletions
+52
-11
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+22
-7
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
...distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+30
-4
No files found.
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
4ca13a86
...
@@ -938,6 +938,13 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
...
@@ -938,6 +938,13 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
monkeypatch
.
setenv
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"0"
)
monkeypatch
.
setenv
(
"VLLM_ENABLE_V1_MULTIPROCESSING"
,
"0"
)
monkeypatch
.
setenv
(
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT"
,
str
(
timeout
))
monkeypatch
.
setenv
(
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT"
,
str
(
timeout
))
def
run_test_and_cleanup
():
llm
=
LLM
(
**
llm_kwargs
)
try
:
_run_abort_timeout_test
(
llm
,
timeout
)
finally
:
llm
.
llm_engine
.
engine_core
.
shutdown
()
# Build runtime_env only if we're using Ray
# Build runtime_env only if we're using Ray
if
distributed_executor_backend
==
"ray"
:
if
distributed_executor_backend
==
"ray"
:
with
_make_fake_nixl_pkg
()
as
working_dir
:
with
_make_fake_nixl_pkg
()
as
working_dir
:
...
@@ -950,15 +957,16 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
...
@@ -950,15 +957,16 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
},
},
}
}
ray
.
init
(
runtime_env
=
runtime_env
)
ray
.
init
(
runtime_env
=
runtime_env
)
try
:
_run_abort_timeout_test
(
llm_kwargs
,
timeout
)
run_test_and_cleanup
()
finally
:
ray
.
shutdown
()
else
:
else
:
_
run_
abort_timeout_test
(
llm_kwargs
,
timeout
)
run_
test_and_cleanup
(
)
def
_run_abort_timeout_test
(
llm
_kwargs
:
dict
,
timeout
:
int
):
def
_run_abort_timeout_test
(
llm
:
LLM
,
timeout
:
int
):
"""Helper function to run the abort timeout test logic."""
"""Helper function to run the abort timeout test logic."""
llm
=
LLM
(
**
llm_kwargs
)
remote_prefill_opts
=
{
remote_prefill_opts
=
{
"do_remote_decode"
:
True
,
"do_remote_decode"
:
True
,
"do_remote_prefill"
:
False
,
"do_remote_prefill"
:
False
,
...
@@ -1042,7 +1050,7 @@ def test_register_kv_caches(dist_init):
...
@@ -1042,7 +1050,7 @@ def test_register_kv_caches(dist_init):
),
),
patch
(
patch
(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.threading.Thread"
),
)
as
mock_thread
,
):
# noqa: E501
):
# noqa: E501
# Create connector
# Create connector
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
connector
=
NixlConnector
(
vllm_config
,
KVConnectorRole
.
WORKER
)
...
@@ -1054,6 +1062,9 @@ def test_register_kv_caches(dist_init):
...
@@ -1054,6 +1062,9 @@ def test_register_kv_caches(dist_init):
mock_wrapper_instance
=
mock_nixl_wrapper
.
return_value
mock_wrapper_instance
=
mock_nixl_wrapper
.
return_value
connector
.
connector_worker
.
nixl_wrapper
=
mock_wrapper_instance
connector
.
connector_worker
.
nixl_wrapper
=
mock_wrapper_instance
# Reassure the shutdown() check that the thread is terminated
mock_thread
.
return_value
.
is_alive
.
return_value
=
False
# Execute register_kv_caches
# Execute register_kv_caches
connector
.
register_kv_caches
(
kv_caches
)
connector
.
register_kv_caches
(
kv_caches
)
...
@@ -1171,6 +1182,7 @@ def test_shutdown_cleans_up_resources(dist_init):
...
@@ -1171,6 +1182,7 @@ def test_shutdown_cleans_up_resources(dist_init):
with
(
with
(
patch
.
object
(
worker
,
"_handshake_initiation_executor"
)
as
mock_exec
,
patch
.
object
(
worker
,
"_handshake_initiation_executor"
)
as
mock_exec
,
patch
.
object
(
worker
,
"_nixl_handshake_listener_t"
)
as
mock_listener
,
patch
.
object
(
worker
,
"_nixl_handshake_listener_t"
)
as
mock_listener
,
patch
.
object
(
worker
,
"_nixl_handshake_listener_stop_event"
)
as
mock_event
,
patch
.
object
(
nixl_wrapper
,
"release_xfer_handle"
)
as
mock_rel_xfer
,
patch
.
object
(
nixl_wrapper
,
"release_xfer_handle"
)
as
mock_rel_xfer
,
patch
.
object
(
nixl_wrapper
,
"release_dlist_handle"
)
as
mock_rel_dlist
,
patch
.
object
(
nixl_wrapper
,
"release_dlist_handle"
)
as
mock_rel_dlist
,
patch
.
object
(
nixl_wrapper
,
"remove_remote_agent"
)
as
mock_rem_agent
,
patch
.
object
(
nixl_wrapper
,
"remove_remote_agent"
)
as
mock_rem_agent
,
...
@@ -1182,6 +1194,8 @@ def test_shutdown_cleans_up_resources(dist_init):
...
@@ -1182,6 +1194,8 @@ def test_shutdown_cleans_up_resources(dist_init):
worker
.
_remote_agents
=
{
"engine1"
:
{
0
:
"agent1"
}}
worker
.
_remote_agents
=
{
"engine1"
:
{
0
:
"agent1"
}}
worker
.
_registered_descs
=
[
"desc1"
,
"desc2"
]
worker
.
_registered_descs
=
[
"desc1"
,
"desc2"
]
mock_listener
.
is_alive
.
return_value
=
False
worker
.
shutdown
()
worker
.
shutdown
()
# Test idempotency
# Test idempotency
...
@@ -1189,7 +1203,8 @@ def test_shutdown_cleans_up_resources(dist_init):
...
@@ -1189,7 +1203,8 @@ def test_shutdown_cleans_up_resources(dist_init):
worker
.
shutdown
()
worker
.
shutdown
()
mock_exec
.
shutdown
.
assert_called_with
(
wait
=
False
)
mock_exec
.
shutdown
.
assert_called_with
(
wait
=
False
)
mock_listener
.
join
.
assert_called_once_with
(
timeout
=
0
)
mock_event
.
set
.
assert_called_once
()
mock_listener
.
join
.
assert_called_once_with
(
timeout
=
1.0
)
mock_rel_xfer
.
assert_called_once_with
(
123
)
mock_rel_xfer
.
assert_called_once_with
(
123
)
assert
mock_rel_dlist
.
call_count
==
2
assert
mock_rel_dlist
.
call_count
==
2
...
...
vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
View file @
4ca13a86
...
@@ -520,6 +520,8 @@ class NixlConnectorScheduler:
...
@@ -520,6 +520,8 @@ class NixlConnectorScheduler:
class
NixlConnectorWorker
:
class
NixlConnectorWorker
:
"""Implementation of Worker side methods"""
"""Implementation of Worker side methods"""
_POLL_TIMEOUT
=
0.1
# Handshake thread polls for stop event every 100ms
@
dataclass
@
dataclass
class
TpKVTopology
:
class
TpKVTopology
:
"""
"""
...
@@ -719,6 +721,7 @@ class NixlConnectorWorker:
...
@@ -719,6 +721,7 @@ class NixlConnectorWorker:
# Background thread for handling new handshake requests.
# Background thread for handling new handshake requests.
self
.
_nixl_handshake_listener_t
:
threading
.
Thread
|
None
=
None
self
.
_nixl_handshake_listener_t
:
threading
.
Thread
|
None
=
None
self
.
_nixl_handshake_listener_stop_event
:
threading
.
Event
|
None
=
None
# Background thread for initializing new NIXL handshakes.
# Background thread for initializing new NIXL handshakes.
self
.
_handshake_initiation_executor
=
ThreadPoolExecutor
(
self
.
_handshake_initiation_executor
=
ThreadPoolExecutor
(
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
# NIXL is not guaranteed to be thread-safe, limit 1 worker.
...
@@ -773,6 +776,7 @@ class NixlConnectorWorker:
...
@@ -773,6 +776,7 @@ class NixlConnectorWorker:
def
_nixl_handshake_listener
(
def
_nixl_handshake_listener
(
metadata
:
NixlAgentMetadata
,
metadata
:
NixlAgentMetadata
,
ready_event
:
threading
.
Event
,
ready_event
:
threading
.
Event
,
stop_event
:
threading
.
Event
,
base_port
:
int
,
base_port
:
int
,
tp_rank
:
int
,
tp_rank
:
int
,
):
):
...
@@ -791,7 +795,14 @@ class NixlConnectorWorker:
...
@@ -791,7 +795,14 @@ class NixlConnectorWorker:
logger
.
debug
(
"Starting listening on path: %s"
,
path
)
logger
.
debug
(
"Starting listening on path: %s"
,
path
)
with
zmq_ctx
(
zmq
.
ROUTER
,
path
)
as
sock
:
with
zmq_ctx
(
zmq
.
ROUTER
,
path
)
as
sock
:
ready_event
.
set
()
ready_event
.
set
()
while
True
:
poller
=
zmq
.
Poller
()
poller
.
register
(
sock
,
zmq
.
POLLIN
)
while
not
stop_event
.
is_set
():
events
=
dict
(
poller
.
poll
(
timeout
=
NixlConnectorWorker
.
_POLL_TIMEOUT
*
1000
)
)
if
sock
not
in
events
:
continue
identity
,
_
,
msg
=
sock
.
recv_multipart
()
identity
,
_
,
msg
=
sock
.
recv_multipart
()
if
msg
!=
GET_META_MSG
:
if
msg
!=
GET_META_MSG
:
logger
.
warning
(
"Connection listener got unexpected message %s"
,
msg
)
logger
.
warning
(
"Connection listener got unexpected message %s"
,
msg
)
...
@@ -1101,14 +1112,21 @@ class NixlConnectorWorker:
...
@@ -1101,14 +1112,21 @@ class NixlConnectorWorker:
attn_backend_name
=
self
.
backend_name
,
attn_backend_name
=
self
.
backend_name
,
kv_cache_layout
=
self
.
kv_cache_layout
,
kv_cache_layout
=
self
.
kv_cache_layout
,
)
)
ready_event
=
threading
.
Event
()
ready_event
,
stop_event
=
threading
.
Event
()
,
threading
.
Event
()
self
.
_nixl_handshake_listener_t
=
threading
.
Thread
(
self
.
_nixl_handshake_listener_t
=
threading
.
Thread
(
target
=
self
.
_nixl_handshake_listener
,
target
=
self
.
_nixl_handshake_listener
,
args
=
(
metadata
,
ready_event
,
self
.
side_channel_port
,
self
.
tp_rank
),
args
=
(
metadata
,
ready_event
,
stop_event
,
self
.
side_channel_port
,
self
.
tp_rank
,
),
daemon
=
True
,
daemon
=
True
,
name
=
"nixl_handshake_listener"
,
name
=
"nixl_handshake_listener"
,
)
)
self
.
_nixl_handshake_listener_t
.
start
()
self
.
_nixl_handshake_listener_t
.
start
()
self
.
_nixl_handshake_listener_stop_event
=
stop_event
ready_event
.
wait
()
# Wait for listener ZMQ socket to be ready.
ready_event
.
wait
()
# Wait for listener ZMQ socket to be ready.
def
add_remote_agent
(
def
add_remote_agent
(
...
@@ -1782,11 +1800,19 @@ class NixlConnectorWorker:
...
@@ -1782,11 +1800,19 @@ class NixlConnectorWorker:
self
.
_invalid_block_ids
=
set
()
self
.
_invalid_block_ids
=
set
()
return
result
return
result
def
__del__
(
self
):
self
.
shutdown
()
def
shutdown
(
self
):
def
shutdown
(
self
):
"""Shutdown the connector worker."""
"""Shutdown the connector worker."""
self
.
_handshake_initiation_executor
.
shutdown
(
wait
=
False
)
self
.
_handshake_initiation_executor
.
shutdown
(
wait
=
False
)
if
self
.
_nixl_handshake_listener_stop_event
is
not
None
:
self
.
_nixl_handshake_listener_stop_event
.
set
()
self
.
_nixl_handshake_listener_stop_event
=
None
if
self
.
_nixl_handshake_listener_t
is
not
None
:
if
self
.
_nixl_handshake_listener_t
is
not
None
:
self
.
_nixl_handshake_listener_t
.
join
(
timeout
=
0
)
# Generous timeout to allow the thread to exit
self
.
_nixl_handshake_listener_t
.
join
(
timeout
=
self
.
_POLL_TIMEOUT
*
10
)
assert
not
self
.
_nixl_handshake_listener_t
.
is_alive
()
self
.
_nixl_handshake_listener_t
=
None
self
.
_nixl_handshake_listener_t
=
None
for
handles
in
self
.
_recving_transfers
.
values
():
for
handles
in
self
.
_recving_transfers
.
values
():
for
handle
,
_
in
handles
:
for
handle
,
_
in
handles
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment