Unverified commit 52f248cd, authored by shaharmor98 and committed by GitHub
Browse files

Feat/add heartbeat mechanism for nixl conn (#10222)


Signed-off-by: Shahar Mor <smor@nvidia.com>
parent 93f75778
...@@ -2,14 +2,17 @@ from __future__ import annotations ...@@ -2,14 +2,17 @@ from __future__ import annotations
import dataclasses import dataclasses
import logging import logging
import os
import struct import struct
import threading import threading
import time
import uuid import uuid
from collections import defaultdict from collections import defaultdict
from typing import Dict, List, Optional, Set from typing import Dict, List, Optional, Set
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
import requests
from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
from sglang.srt.disaggregation.common.conn import ( from sglang.srt.disaggregation.common.conn import (
...@@ -21,6 +24,7 @@ from sglang.srt.disaggregation.common.conn import ( ...@@ -21,6 +24,7 @@ from sglang.srt.disaggregation.common.conn import (
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
from sglang.srt.disaggregation.utils import DisaggregationMode from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_int_env_var
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -102,8 +106,14 @@ class TransferStatus: ...@@ -102,8 +106,14 @@ class TransferStatus:
def is_done(self):
    """Tell whether this transfer has reached a terminal state.

    Returns:
        False while ``num_kvs_expected`` is still ``None``; True when it
        holds the failure value -1; otherwise a truthy value only when the
        number of received KV chunks equals ``num_kvs_expected`` and
        ``received_aux`` is set.
    """
    if self.num_kvs_expected is None:
        return False
    # Check for failure state
    if self.num_kvs_expected == -1:
        return True  # Failed transfers are considered "done"
    return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
def is_failed(self):
    """Report whether this transfer has been marked as failed.

    Returns:
        True when ``num_kvs_expected`` holds the failure value -1,
        False for any other value (including ``None``).
    """
    failure_marker = -1
    return self.num_kvs_expected == failure_marker
class NixlKVManager(CommonKVManager): class NixlKVManager(CommonKVManager):
def __init__( def __init__(
...@@ -131,11 +141,125 @@ class NixlKVManager(CommonKVManager): ...@@ -131,11 +141,125 @@ class NixlKVManager(CommonKVManager):
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict( self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
TransferStatus TransferStatus
) )
self.heartbeat_failures = {}
self.session_pool = defaultdict(requests.Session)
self.session_pool_lock = threading.Lock()
self.addr_to_rooms_tracker = defaultdict(set)
self.connection_lock = threading.Lock()
# Heartbeat interval should be at least 2 seconds
self.heartbeat_interval = max(
float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
)
# Heartbeat failure should be at least 1
self.max_failures = max(
get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
)
self._start_heartbeat_checker_thread()
else: else:
raise ValueError( raise ValueError(
f"Unsupported DisaggregationMode: {self.disaggregation_mode}" f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
) )
def _start_heartbeat_checker_thread(self):
    """
    Start the heartbeat checker thread for Decode worker.

    A daemon thread wakes up every ``self.heartbeat_interval`` seconds and
    sends ``GET http://<bootstrap_addr>/health`` to every address currently
    listed in ``self.prefill_dp_size_table``:

    * HTTP 200 resets the failure counter of that address and drops rooms
      that no longer have an entry in ``self.transfer_statuses`` from
      ``self.addr_to_rooms_tracker``.
    * Any other status code, or any exception, increments the failure
      counter. A non-200 response additionally discards the pooled session.
    * Once the counter reaches ``self.max_failures`` the address is handed
      to ``self._handle_node_failure`` and its session and counter are
      discarded.

    TODO (smor): unite nixl heartbeat checker with mooncake's.
    """

    def drop_session(bootstrap_addr):
        # Pop instead of a bare ``del`` so the session can be closed; a
        # merely deleted session keeps its pooled sockets until GC.
        with self.session_pool_lock:
            stale_session = self.session_pool.pop(bootstrap_addr, None)
        # Close outside the lock so socket teardown never blocks other users.
        if stale_session is not None:
            stale_session.close()

    def record_failure(bootstrap_addr):
        logger.info(f"Attempting to reconnect to {bootstrap_addr}...")
        self.heartbeat_failures[bootstrap_addr] = (
            self.heartbeat_failures.get(bootstrap_addr, 0) + 1
        )

    def heartbeat_checker():
        while True:
            time.sleep(self.heartbeat_interval)
            # Snapshot the addresses under the lock; the table may be
            # modified (e.g. by _handle_node_failure) while we iterate.
            with self.connection_lock:
                addresses = list(self.prefill_dp_size_table.keys())

            for bootstrap_addr in addresses:
                try:
                    with self.session_pool_lock:
                        # session_pool is a defaultdict: a missing entry
                        # is created on first access.
                        session = self.session_pool[bootstrap_addr]
                    response = session.get(
                        f"http://{bootstrap_addr}/health",
                        timeout=(2, 3),
                        headers={"Connection": "keep-alive"},
                    )
                    if response.status_code == 200:
                        self.heartbeat_failures[bootstrap_addr] = 0

                        tracked_rooms = self.addr_to_rooms_tracker[bootstrap_addr]
                        # Iterate over a copy: the set is mutated below.
                        for bootstrap_room in tracked_rooms.copy():
                            # Remove successful transfers from the tracker
                            if bootstrap_room not in self.transfer_statuses:
                                tracked_rooms.discard(bootstrap_room)
                    else:
                        record_failure(bootstrap_addr)
                        drop_session(bootstrap_addr)
                except Exception:
                    # Broad on purpose: this loop is the top-level boundary
                    # of a daemon thread and must survive any probe error.
                    record_failure(bootstrap_addr)

                if (
                    self.heartbeat_failures.get(bootstrap_addr, 0)
                    >= self.max_failures
                ):
                    self._handle_node_failure(bootstrap_addr)
                    drop_session(bootstrap_addr)
                    # Forget the counter so that, if this address registers
                    # again later, it starts from zero instead of being
                    # declared dead on its first missed heartbeat.
                    self.heartbeat_failures.pop(bootstrap_addr, None)

    threading.Thread(target=heartbeat_checker, daemon=True).start()
def _handle_node_failure(self, failed_bootstrap_addr):
    """Handle failure of a prefill node.

    Under ``self.connection_lock`` this removes every ``connection_pool``
    entry whose key starts with ``failed_bootstrap_addr``, removes the
    address from the prefill tp/dp/pp size tables and detaches its set of
    tracked rooms. Afterwards every tracked room that still has an
    unfinished entry in ``self.transfer_statuses`` is marked as failed by
    setting ``num_kvs_expected`` to -1, and a summary is logged.

    Args:
        failed_bootstrap_addr: Bootstrap address of the prefill node that
            is considered lost.
    """
    with self.connection_lock:
        # NOTE(review): this is a plain prefix match; confirm the key format
        # guarantees one address cannot be a prefix of another address's key.
        keys_to_remove = [
            k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)
        ]
        for k in keys_to_remove:
            del self.connection_pool[k]

        self.prefill_tp_size_table.pop(failed_bootstrap_addr, None)
        self.prefill_dp_size_table.pop(failed_bootstrap_addr, None)
        self.prefill_pp_size_table.pop(failed_bootstrap_addr, None)

        # pop() with a default never inserts a new entry, even though the
        # tracker is a defaultdict.
        possible_affected_rooms = self.addr_to_rooms_tracker.pop(
            failed_bootstrap_addr, []
        )

    # Mark all pending transfers associated with the failed node as failed
    affected_count = 0
    for room in possible_affected_rooms:
        # Single .get() lookup: transfer_statuses is a defaultdict, so
        # indexing after a separate membership test could re-create an entry
        # that another thread deleted in between.
        status = self.transfer_statuses.get(room)
        if status is None or status.is_done():
            continue
        # Mark the transfer as failed by setting a special state
        status.num_kvs_expected = -1  # Indicates failure
        affected_count += 1

    logger.error(
        f"Lost connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), "
        f"{affected_count} transfers affected"
    )
def check_status(self, bootstrap_room: int):
    """Return the entry stored in ``self.request_status`` for a room.

    Args:
        bootstrap_room: Identifier of the room to look up.

    Returns:
        Whatever ``self.request_status[bootstrap_room]`` holds.
    """
    return self.request_status[bootstrap_room]
...@@ -593,6 +717,12 @@ class NixlKVReceiver(CommonKVReceiver): ...@@ -593,6 +717,12 @@ class NixlKVReceiver(CommonKVReceiver):
self.conclude_state = None self.conclude_state = None
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank) super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
# Track this room with its bootstrap address for heartbeat monitoring
if hasattr(self.kv_mgr, "addr_to_rooms_tracker"):
self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(
self.bootstrap_room
)
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos: for bootstrap_info in self.bootstrap_infos:
logger.debug( logger.debug(
...@@ -627,9 +757,16 @@ class NixlKVReceiver(CommonKVReceiver): ...@@ -627,9 +757,16 @@ class NixlKVReceiver(CommonKVReceiver):
self.kv_mgr.update_transfer_status() self.kv_mgr.update_transfer_status()
if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
self.conclude_state = KVPoll.Success # Check if the transfer failed
if self.kv_mgr.transfer_statuses[self.bootstrap_room].is_failed():
self.conclude_state = KVPoll.Failed
logger.error(
f"Transfer for room {self.bootstrap_room} failed due to node failure"
)
else:
self.conclude_state = KVPoll.Success
del self.kv_mgr.transfer_statuses[self.bootstrap_room] del self.kv_mgr.transfer_statuses[self.bootstrap_room]
return KVPoll.Success # type: ignore return self.conclude_state # type: ignore
return KVPoll.WaitingForInput # type: ignore return KVPoll.WaitingForInput # type: ignore
def _register_kv_args(self): def _register_kv_args(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment