Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
52f248cd
Unverified
Commit
52f248cd
authored
Sep 18, 2025
by
shaharmor98
Committed by
GitHub
Sep 18, 2025
Browse files
Feat/add heartbeat mechanism for nixl conn (#10222)
Signed-off-by:
Shahar Mor
<
smor@nvidia.com
>
parent
93f75778
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
139 additions
and
2 deletions
+139
-2
python/sglang/srt/disaggregation/nixl/conn.py
python/sglang/srt/disaggregation/nixl/conn.py
+139
-2
No files found.
python/sglang/srt/disaggregation/nixl/conn.py
View file @
52f248cd
...
...
@@ -2,14 +2,17 @@ from __future__ import annotations
import
dataclasses
import
logging
import
os
import
struct
import
threading
import
time
import
uuid
from
collections
import
defaultdict
from
typing
import
Dict
,
List
,
Optional
,
Set
import
numpy
as
np
import
numpy.typing
as
npt
import
requests
from
sglang.srt.disaggregation.base.conn
import
KVArgs
,
KVPoll
from
sglang.srt.disaggregation.common.conn
import
(
...
...
@@ -21,6 +24,7 @@ from sglang.srt.disaggregation.common.conn import (
from
sglang.srt.disaggregation.common.utils
import
group_concurrent_contiguous
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
get_int_env_var
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -102,8 +106,14 @@ class TransferStatus:
def
is_done
(
self
):
if
self
.
num_kvs_expected
is
None
:
return
False
# Check for failure state
if
self
.
num_kvs_expected
==
-
1
:
return
True
# Failed transfers are considered "done"
return
self
.
num_kvs_expected
==
len
(
self
.
received_kvs
)
and
self
.
received_aux
def
is_failed
(
self
):
return
self
.
num_kvs_expected
==
-
1
class
NixlKVManager
(
CommonKVManager
):
def
__init__
(
...
...
@@ -131,11 +141,125 @@ class NixlKVManager(CommonKVManager):
self
.
transfer_statuses
:
Dict
[
int
,
TransferStatus
]
=
defaultdict
(
TransferStatus
)
self
.
heartbeat_failures
=
{}
self
.
session_pool
=
defaultdict
(
requests
.
Session
)
self
.
session_pool_lock
=
threading
.
Lock
()
self
.
addr_to_rooms_tracker
=
defaultdict
(
set
)
self
.
connection_lock
=
threading
.
Lock
()
# Heartbeat interval should be at least 2 seconds
self
.
heartbeat_interval
=
max
(
float
(
os
.
getenv
(
"SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL"
,
5.0
)),
2.0
)
# Heartbeat failure should be at least 1
self
.
max_failures
=
max
(
get_int_env_var
(
"SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE"
,
2
),
1
)
self
.
_start_heartbeat_checker_thread
()
else
:
raise
ValueError
(
f
"Unsupported DisaggregationMode:
{
self
.
disaggregation_mode
}
"
)
def
_start_heartbeat_checker_thread
(
self
):
"""
Start the heartbeat checker thread for Decode worker.
TODO (smor): unite nixl heartbeat checker with mooncake's.
"""
def
heartbeat_checker
():
while
True
:
time
.
sleep
(
self
.
heartbeat_interval
)
with
self
.
connection_lock
:
addresses
=
list
(
self
.
prefill_dp_size_table
.
keys
())
for
bootstrap_addr
in
addresses
:
session
=
None
try
:
with
self
.
session_pool_lock
:
session
=
self
.
session_pool
[
bootstrap_addr
]
response
=
session
.
get
(
f
"http://
{
bootstrap_addr
}
/health"
,
timeout
=
(
2
,
3
),
headers
=
{
"Connection"
:
"keep-alive"
},
)
if
response
.
status_code
==
200
:
self
.
heartbeat_failures
[
bootstrap_addr
]
=
0
current_rooms
=
self
.
addr_to_rooms_tracker
[
bootstrap_addr
].
copy
()
for
bootstrap_room
in
current_rooms
:
# Remove successful transfers from the tracker
if
bootstrap_room
not
in
self
.
transfer_statuses
:
self
.
addr_to_rooms_tracker
[
bootstrap_addr
].
discard
(
bootstrap_room
)
else
:
logger
.
info
(
f
"Attempting to reconnect to
{
bootstrap_addr
}
..."
)
self
.
heartbeat_failures
[
bootstrap_addr
]
=
(
self
.
heartbeat_failures
.
get
(
bootstrap_addr
,
0
)
+
1
)
with
self
.
session_pool_lock
:
if
bootstrap_addr
in
self
.
session_pool
:
del
self
.
session_pool
[
bootstrap_addr
]
except
Exception
:
logger
.
info
(
f
"Attempting to reconnect to
{
bootstrap_addr
}
..."
)
self
.
heartbeat_failures
[
bootstrap_addr
]
=
(
self
.
heartbeat_failures
.
get
(
bootstrap_addr
,
0
)
+
1
)
if
(
self
.
heartbeat_failures
.
get
(
bootstrap_addr
,
0
)
>=
self
.
max_failures
):
self
.
_handle_node_failure
(
bootstrap_addr
)
with
self
.
session_pool_lock
:
if
bootstrap_addr
in
self
.
session_pool
:
del
self
.
session_pool
[
bootstrap_addr
]
threading
.
Thread
(
target
=
heartbeat_checker
,
daemon
=
True
).
start
()
def
_handle_node_failure
(
self
,
failed_bootstrap_addr
):
"""Handle failure of a prefill node."""
with
self
.
connection_lock
:
keys_to_remove
=
[
k
for
k
in
self
.
connection_pool
if
k
.
startswith
(
failed_bootstrap_addr
)
]
for
k
in
keys_to_remove
:
del
self
.
connection_pool
[
k
]
if
failed_bootstrap_addr
in
self
.
prefill_tp_size_table
:
del
self
.
prefill_tp_size_table
[
failed_bootstrap_addr
]
if
failed_bootstrap_addr
in
self
.
prefill_dp_size_table
:
del
self
.
prefill_dp_size_table
[
failed_bootstrap_addr
]
if
failed_bootstrap_addr
in
self
.
prefill_pp_size_table
:
del
self
.
prefill_pp_size_table
[
failed_bootstrap_addr
]
possible_affected_rooms
=
self
.
addr_to_rooms_tracker
.
get
(
failed_bootstrap_addr
,
[]
)
if
failed_bootstrap_addr
in
self
.
addr_to_rooms_tracker
:
del
self
.
addr_to_rooms_tracker
[
failed_bootstrap_addr
]
# Mark all pending transfers associated with the failed node as failed
affected_rooms
=
[]
for
room
in
possible_affected_rooms
:
if
(
room
in
self
.
transfer_statuses
and
not
self
.
transfer_statuses
[
room
].
is_done
()
):
# Mark the transfer as failed by setting a special state
self
.
transfer_statuses
[
room
].
num_kvs_expected
=
-
1
# Indicates failure
affected_rooms
.
append
(
room
)
logger
.
error
(
f
"Lost connection with prefill instance (bootstrap_addr:
{
failed_bootstrap_addr
}
), "
f
"
{
len
(
affected_rooms
)
}
transfers affected"
)
def
check_status
(
self
,
bootstrap_room
:
int
):
return
self
.
request_status
[
bootstrap_room
]
...
...
@@ -593,6 +717,12 @@ class NixlKVReceiver(CommonKVReceiver):
self
.
conclude_state
=
None
super
().
__init__
(
mgr
,
bootstrap_addr
,
bootstrap_room
,
prefill_dp_rank
)
# Track this room with its bootstrap address for heartbeat monitoring
if
hasattr
(
self
.
kv_mgr
,
"addr_to_rooms_tracker"
):
self
.
kv_mgr
.
addr_to_rooms_tracker
[
self
.
bootstrap_addr
].
add
(
self
.
bootstrap_room
)
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int32
],
aux_index
:
Optional
[
int
]
=
None
):
for
bootstrap_info
in
self
.
bootstrap_infos
:
logger
.
debug
(
...
...
@@ -627,9 +757,16 @@ class NixlKVReceiver(CommonKVReceiver):
self
.
kv_mgr
.
update_transfer_status
()
if
self
.
kv_mgr
.
check_transfer_done
(
self
.
bootstrap_room
):
# type: ignore
self
.
conclude_state
=
KVPoll
.
Success
# Check if the transfer failed
if
self
.
kv_mgr
.
transfer_statuses
[
self
.
bootstrap_room
].
is_failed
():
self
.
conclude_state
=
KVPoll
.
Failed
logger
.
error
(
f
"Transfer for room
{
self
.
bootstrap_room
}
failed due to node failure"
)
else
:
self
.
conclude_state
=
KVPoll
.
Success
del
self
.
kv_mgr
.
transfer_statuses
[
self
.
bootstrap_room
]
return
KVPoll
.
Success
# type: ignore
return
self
.
conclude_state
# type: ignore
return
KVPoll
.
WaitingForInput
# type: ignore
def
_register_kv_args
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment