Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
52f248cd
Unverified
Commit
52f248cd
authored
Sep 18, 2025
by
shaharmor98
Committed by
GitHub
Sep 18, 2025
Browse files
Feat/add heartbeat mechanism for nixl conn (#10222)
Signed-off-by:
Shahar Mor
<
smor@nvidia.com
>
parent
93f75778
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
139 additions
and
2 deletions
+139
-2
python/sglang/srt/disaggregation/nixl/conn.py
python/sglang/srt/disaggregation/nixl/conn.py
+139
-2
No files found.
python/sglang/srt/disaggregation/nixl/conn.py
View file @
52f248cd
...
@@ -2,14 +2,17 @@ from __future__ import annotations
...
@@ -2,14 +2,17 @@ from __future__ import annotations
import
dataclasses
import
dataclasses
import
logging
import
logging
import
os
import
struct
import
struct
import
threading
import
threading
import
time
import
uuid
import
uuid
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Dict
,
List
,
Optional
,
Set
from
typing
import
Dict
,
List
,
Optional
,
Set
import
numpy
as
np
import
numpy
as
np
import
numpy.typing
as
npt
import
numpy.typing
as
npt
import
requests
from
sglang.srt.disaggregation.base.conn
import
KVArgs
,
KVPoll
from
sglang.srt.disaggregation.base.conn
import
KVArgs
,
KVPoll
from
sglang.srt.disaggregation.common.conn
import
(
from
sglang.srt.disaggregation.common.conn
import
(
...
@@ -21,6 +24,7 @@ from sglang.srt.disaggregation.common.conn import (
...
@@ -21,6 +24,7 @@ from sglang.srt.disaggregation.common.conn import (
from
sglang.srt.disaggregation.common.utils
import
group_concurrent_contiguous
from
sglang.srt.disaggregation.common.utils
import
group_concurrent_contiguous
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
from
sglang.srt.disaggregation.utils
import
DisaggregationMode
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
get_int_env_var
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -102,8 +106,14 @@ class TransferStatus:
...
@@ -102,8 +106,14 @@ class TransferStatus:
def is_done(self):
    """Report whether this transfer has reached a terminal state.

    Returns:
        False while ``num_kvs_expected`` is still ``None``; True when it is
        ``-1`` (the failure marker); otherwise the result of
        ``num_kvs_expected == len(received_kvs) and received_aux``.
    """
    expected = self.num_kvs_expected

    # Nothing can be concluded until the expected KV count has been set.
    if expected is None:
        return False

    # Failed transfers are considered "done" so that pollers stop waiting.
    if expected == -1:
        return True

    return expected == len(self.received_kvs) and self.received_aux
def is_failed(self):
    """Return True when ``num_kvs_expected`` holds the failure marker (-1)."""
    failure_marker = -1
    return self.num_kvs_expected == failure_marker
class
NixlKVManager
(
CommonKVManager
):
class
NixlKVManager
(
CommonKVManager
):
def
__init__
(
def
__init__
(
...
@@ -131,11 +141,125 @@ class NixlKVManager(CommonKVManager):
...
@@ -131,11 +141,125 @@ class NixlKVManager(CommonKVManager):
self
.
transfer_statuses
:
Dict
[
int
,
TransferStatus
]
=
defaultdict
(
self
.
transfer_statuses
:
Dict
[
int
,
TransferStatus
]
=
defaultdict
(
TransferStatus
TransferStatus
)
)
self
.
heartbeat_failures
=
{}
self
.
session_pool
=
defaultdict
(
requests
.
Session
)
self
.
session_pool_lock
=
threading
.
Lock
()
self
.
addr_to_rooms_tracker
=
defaultdict
(
set
)
self
.
connection_lock
=
threading
.
Lock
()
# Heartbeat interval should be at least 2 seconds
self
.
heartbeat_interval
=
max
(
float
(
os
.
getenv
(
"SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL"
,
5.0
)),
2.0
)
# Heartbeat failure should be at least 1
self
.
max_failures
=
max
(
get_int_env_var
(
"SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE"
,
2
),
1
)
self
.
_start_heartbeat_checker_thread
()
else
:
else
:
raise
ValueError
(
raise
ValueError
(
f
"Unsupported DisaggregationMode:
{
self
.
disaggregation_mode
}
"
f
"Unsupported DisaggregationMode:
{
self
.
disaggregation_mode
}
"
)
)
def _start_heartbeat_checker_thread(self):
    """
    Start the heartbeat checker thread for Decode worker.

    A daemon thread wakes every ``self.heartbeat_interval`` seconds and sends
    ``GET http://<bootstrap_addr>/health`` to each address currently present
    in ``self.prefill_dp_size_table``. A 200 response resets that address's
    failure counter and prunes tracked rooms that no longer have an entry in
    ``self.transfer_statuses``. Any other response, or an exception, bumps
    the counter. Once the counter reaches ``self.max_failures`` the address
    is handed to ``self._handle_node_failure``.

    TODO (smor): unite nixl heartbeat checker with mooncake's.
    """

    def discard_session(bootstrap_addr):
        # Close the session before dropping it so its pooled connections are
        # released right away instead of waiting for garbage collection.
        with self.session_pool_lock:
            session = self.session_pool.pop(bootstrap_addr, None)
        if session is not None:
            session.close()

    def record_failure(bootstrap_addr):
        logger.info(f"Attempting to reconnect to {bootstrap_addr}...")
        self.heartbeat_failures[bootstrap_addr] = (
            self.heartbeat_failures.get(bootstrap_addr, 0) + 1
        )

    def prune_finished_rooms(bootstrap_addr):
        tracked_rooms = self.addr_to_rooms_tracker[bootstrap_addr]
        # Iterate over a copy: the set is mutated while we walk it.
        for bootstrap_room in tracked_rooms.copy():
            # Remove successful transfers from the tracker
            if bootstrap_room not in self.transfer_statuses:
                tracked_rooms.discard(bootstrap_room)

    def check_address(bootstrap_addr):
        # The broad ``except`` is deliberate: an unexpected error must count
        # as a failed heartbeat rather than kill the checker thread.
        try:
            # Hold the pool lock only for the lookup, never during the request.
            with self.session_pool_lock:
                session = self.session_pool[bootstrap_addr]
            response = session.get(
                f"http://{bootstrap_addr}/health",
                timeout=(2, 3),
                headers={"Connection": "keep-alive"},
            )
            if response.status_code == 200:
                self.heartbeat_failures[bootstrap_addr] = 0
                prune_finished_rooms(bootstrap_addr)
            else:
                record_failure(bootstrap_addr)
                discard_session(bootstrap_addr)
        except Exception as exc:
            logger.debug(f"Heartbeat to {bootstrap_addr} failed: {exc!r}")
            record_failure(bootstrap_addr)

        if self.heartbeat_failures.get(bootstrap_addr, 0) >= self.max_failures:
            # Must be called without holding connection_lock, because
            # _handle_node_failure acquires it itself.
            self._handle_node_failure(bootstrap_addr)
            discard_session(bootstrap_addr)
            # Forget the counter so that a node which later re-registers
            # under the same address starts from a clean slate.
            self.heartbeat_failures.pop(bootstrap_addr, None)

    def heartbeat_checker():
        while True:
            time.sleep(self.heartbeat_interval)
            # Snapshot the addresses under the lock; probe them outside it.
            with self.connection_lock:
                addresses = list(self.prefill_dp_size_table.keys())

            for bootstrap_addr in addresses:
                check_address(bootstrap_addr)

    threading.Thread(target=heartbeat_checker, daemon=True).start()
def _handle_node_failure(self, failed_bootstrap_addr):
    """Handle failure of a prefill node.

    Drops all cached state keyed by the failed bootstrap address: matching
    ``connection_pool`` entries, the prefill TP/DP/PP size-table entries and
    the address-to-rooms tracker entry. It then marks every tracked transfer
    that is not yet done as failed by setting ``num_kvs_expected`` to ``-1``.

    Args:
        failed_bootstrap_addr: Bootstrap address of the unreachable prefill
            node. An unknown address is a no-op apart from the log line.
    """
    with self.connection_lock:
        # NOTE(review): prefix matching assumes no other bootstrap address
        # begins with this one (e.g. "host:80" vs "host:8000") -- confirm
        # against the connection_pool key format.
        stale_keys = [
            k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)
        ]
        for k in stale_keys:
            del self.connection_pool[k]

        for size_table in (
            self.prefill_tp_size_table,
            self.prefill_dp_size_table,
            self.prefill_pp_size_table,
        ):
            size_table.pop(failed_bootstrap_addr, None)

        # pop() with a default never triggers the defaultdict factory, so an
        # unknown address does not leave an empty entry behind.
        possible_affected_rooms = self.addr_to_rooms_tracker.pop(
            failed_bootstrap_addr, ()
        )

    # Mark all pending transfers associated with the failed node as failed
    affected_rooms = []
    for room in possible_affected_rooms:
        # Use get() rather than "in" followed by indexing: transfer_statuses
        # is a defaultdict, so indexing a room that was deleted in between
        # would silently re-create a stale entry.
        status = self.transfer_statuses.get(room)
        if status is None or status.is_done():
            continue
        # Mark the transfer as failed by setting a special state
        status.num_kvs_expected = -1  # Indicates failure
        affected_rooms.append(room)

    logger.error(
        f"Lost connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), "
        f"{len(affected_rooms)} transfers affected"
    )
def
check_status
(
self
,
bootstrap_room
:
int
):
def
check_status
(
self
,
bootstrap_room
:
int
):
return
self
.
request_status
[
bootstrap_room
]
return
self
.
request_status
[
bootstrap_room
]
...
@@ -593,6 +717,12 @@ class NixlKVReceiver(CommonKVReceiver):
...
@@ -593,6 +717,12 @@ class NixlKVReceiver(CommonKVReceiver):
self
.
conclude_state
=
None
self
.
conclude_state
=
None
super
().
__init__
(
mgr
,
bootstrap_addr
,
bootstrap_room
,
prefill_dp_rank
)
super
().
__init__
(
mgr
,
bootstrap_addr
,
bootstrap_room
,
prefill_dp_rank
)
# Track this room with its bootstrap address for heartbeat monitoring
if
hasattr
(
self
.
kv_mgr
,
"addr_to_rooms_tracker"
):
self
.
kv_mgr
.
addr_to_rooms_tracker
[
self
.
bootstrap_addr
].
add
(
self
.
bootstrap_room
)
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int32
],
aux_index
:
Optional
[
int
]
=
None
):
def
init
(
self
,
kv_indices
:
npt
.
NDArray
[
np
.
int32
],
aux_index
:
Optional
[
int
]
=
None
):
for
bootstrap_info
in
self
.
bootstrap_infos
:
for
bootstrap_info
in
self
.
bootstrap_infos
:
logger
.
debug
(
logger
.
debug
(
...
@@ -627,9 +757,16 @@ class NixlKVReceiver(CommonKVReceiver):
...
@@ -627,9 +757,16 @@ class NixlKVReceiver(CommonKVReceiver):
self
.
kv_mgr
.
update_transfer_status
()
self
.
kv_mgr
.
update_transfer_status
()
if
self
.
kv_mgr
.
check_transfer_done
(
self
.
bootstrap_room
):
# type: ignore
if
self
.
kv_mgr
.
check_transfer_done
(
self
.
bootstrap_room
):
# type: ignore
self
.
conclude_state
=
KVPoll
.
Success
# Check if the transfer failed
if
self
.
kv_mgr
.
transfer_statuses
[
self
.
bootstrap_room
].
is_failed
():
self
.
conclude_state
=
KVPoll
.
Failed
logger
.
error
(
f
"Transfer for room
{
self
.
bootstrap_room
}
failed due to node failure"
)
else
:
self
.
conclude_state
=
KVPoll
.
Success
del
self
.
kv_mgr
.
transfer_statuses
[
self
.
bootstrap_room
]
del
self
.
kv_mgr
.
transfer_statuses
[
self
.
bootstrap_room
]
return
KVPoll
.
Success
# type: ignore
return
self
.
conclude_state
# type: ignore
return
KVPoll
.
WaitingForInput
# type: ignore
return
KVPoll
.
WaitingForInput
# type: ignore
def
_register_kv_args
(
self
):
def
_register_kv_args
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment