Unverified Commit 52f248cd authored by shaharmor98's avatar shaharmor98 Committed by GitHub
Browse files

Feat/add heartbeat mechanism for nixl conn (#10222)


Signed-off-by: default avatarShahar Mor <smor@nvidia.com>
parent 93f75778
......@@ -2,14 +2,17 @@ from __future__ import annotations
import dataclasses
import logging
import os
import struct
import threading
import time
import uuid
from collections import defaultdict
from typing import Dict, List, Optional, Set
import numpy as np
import numpy.typing as npt
import requests
from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
from sglang.srt.disaggregation.common.conn import (
......@@ -21,6 +24,7 @@ from sglang.srt.disaggregation.common.conn import (
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_int_env_var
logger = logging.getLogger(__name__)
......@@ -102,8 +106,14 @@ class TransferStatus:
def is_done(self):
if self.num_kvs_expected is None:
return False
# Check for failure state
if self.num_kvs_expected == -1:
return True # Failed transfers are considered "done"
return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
def is_failed(self):
return self.num_kvs_expected == -1
class NixlKVManager(CommonKVManager):
def __init__(
......@@ -131,11 +141,125 @@ class NixlKVManager(CommonKVManager):
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
TransferStatus
)
self.heartbeat_failures = {}
self.session_pool = defaultdict(requests.Session)
self.session_pool_lock = threading.Lock()
self.addr_to_rooms_tracker = defaultdict(set)
self.connection_lock = threading.Lock()
# Heartbeat interval should be at least 2 seconds
self.heartbeat_interval = max(
float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
)
# Heartbeat failure should be at least 1
self.max_failures = max(
get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
)
self._start_heartbeat_checker_thread()
else:
raise ValueError(
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
)
def _start_heartbeat_checker_thread(self):
"""
Start the heartbeat checker thread for Decode worker.
TODO (smor): unite nixl heartbeat checker with mooncake's.
"""
def heartbeat_checker():
while True:
time.sleep(self.heartbeat_interval)
with self.connection_lock:
addresses = list(self.prefill_dp_size_table.keys())
for bootstrap_addr in addresses:
session = None
try:
with self.session_pool_lock:
session = self.session_pool[bootstrap_addr]
response = session.get(
f"http://{bootstrap_addr}/health",
timeout=(2, 3),
headers={"Connection": "keep-alive"},
)
if response.status_code == 200:
self.heartbeat_failures[bootstrap_addr] = 0
current_rooms = self.addr_to_rooms_tracker[
bootstrap_addr
].copy()
for bootstrap_room in current_rooms:
# Remove successful transfers from the tracker
if bootstrap_room not in self.transfer_statuses:
self.addr_to_rooms_tracker[bootstrap_addr].discard(
bootstrap_room
)
else:
logger.info(
f"Attempting to reconnect to {bootstrap_addr}..."
)
self.heartbeat_failures[bootstrap_addr] = (
self.heartbeat_failures.get(bootstrap_addr, 0) + 1
)
with self.session_pool_lock:
if bootstrap_addr in self.session_pool:
del self.session_pool[bootstrap_addr]
except Exception:
logger.info(f"Attempting to reconnect to {bootstrap_addr}...")
self.heartbeat_failures[bootstrap_addr] = (
self.heartbeat_failures.get(bootstrap_addr, 0) + 1
)
if (
self.heartbeat_failures.get(bootstrap_addr, 0)
>= self.max_failures
):
self._handle_node_failure(bootstrap_addr)
with self.session_pool_lock:
if bootstrap_addr in self.session_pool:
del self.session_pool[bootstrap_addr]
threading.Thread(target=heartbeat_checker, daemon=True).start()
def _handle_node_failure(self, failed_bootstrap_addr):
"""Handle failure of a prefill node."""
with self.connection_lock:
keys_to_remove = [
k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)
]
for k in keys_to_remove:
del self.connection_pool[k]
if failed_bootstrap_addr in self.prefill_tp_size_table:
del self.prefill_tp_size_table[failed_bootstrap_addr]
if failed_bootstrap_addr in self.prefill_dp_size_table:
del self.prefill_dp_size_table[failed_bootstrap_addr]
if failed_bootstrap_addr in self.prefill_pp_size_table:
del self.prefill_pp_size_table[failed_bootstrap_addr]
possible_affected_rooms = self.addr_to_rooms_tracker.get(
failed_bootstrap_addr, []
)
if failed_bootstrap_addr in self.addr_to_rooms_tracker:
del self.addr_to_rooms_tracker[failed_bootstrap_addr]
# Mark all pending transfers associated with the failed node as failed
affected_rooms = []
for room in possible_affected_rooms:
if (
room in self.transfer_statuses
and not self.transfer_statuses[room].is_done()
):
# Mark the transfer as failed by setting a special state
self.transfer_statuses[room].num_kvs_expected = -1 # Indicates failure
affected_rooms.append(room)
logger.error(
f"Lost connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), "
f"{len(affected_rooms)} transfers affected"
)
def check_status(self, bootstrap_room: int):
return self.request_status[bootstrap_room]
......@@ -593,6 +717,12 @@ class NixlKVReceiver(CommonKVReceiver):
self.conclude_state = None
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
# Track this room with its bootstrap address for heartbeat monitoring
if hasattr(self.kv_mgr, "addr_to_rooms_tracker"):
self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(
self.bootstrap_room
)
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
for bootstrap_info in self.bootstrap_infos:
logger.debug(
......@@ -627,9 +757,16 @@ class NixlKVReceiver(CommonKVReceiver):
self.kv_mgr.update_transfer_status()
if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
self.conclude_state = KVPoll.Success
# Check if the transfer failed
if self.kv_mgr.transfer_statuses[self.bootstrap_room].is_failed():
self.conclude_state = KVPoll.Failed
logger.error(
f"Transfer for room {self.bootstrap_room} failed due to node failure"
)
else:
self.conclude_state = KVPoll.Success
del self.kv_mgr.transfer_statuses[self.bootstrap_room]
return KVPoll.Success # type: ignore
return self.conclude_state # type: ignore
return KVPoll.WaitingForInput # type: ignore
def _register_kv_args(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment