Unverified Commit 3ce94f71 authored by shangmingc, committed by GitHub

[PD] Handle P/D failure and reconnect without affecting other instances (#6263)


Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
parent ca95556c
@@ -361,7 +361,7 @@ class DecodeTransferQueue:
         indices_to_remove = set()
         for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
             if poll == KVPoll.Failed:
-                error_message = f"Decode transfer failed for request {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
+                error_message = f"Decode transfer failed for request rank={self.scheduler.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
                 try:
                     decode_req.kv_receiver.failure_exception()
                 except Exception as e:
@@ -409,7 +409,8 @@ class DecodeTransferQueue:
                         : decode_req.req.top_logprobs_num
                     ].tolist()
                 )
+                if hasattr(decode_req.kv_receiver, "clear"):
+                    decode_req.kv_receiver.clear()
                 transferred_reqs.append(decode_req.req)
                 indices_to_remove.add(i)
             elif poll in [
...
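The two hunks above are the decode-side half of the failure isolation: a failed poll is turned into a per-request exception that the queue catches and records, and successfully drained receivers are cleaned up through a duck-typed `clear()`. A self-contained sketch of that consumption pattern, assuming the usual `KVPoll` ordering from the base interface (`drain` and the queue shape are illustrative, not the scheduler's actual API):

```python
from enum import IntEnum


class KVPoll(IntEnum):
    # Assumed ordering; Failed being the lowest value is why update_status
    # later in this diff needs an explicit Failed override instead of max().
    Failed = 0
    Bootstrapping = 1
    WaitingForInput = 2
    Transferring = 3
    Success = 4


def drain(queue):
    """Pop finished or failed transfers; one failure never aborts the loop."""
    done, keep = [], []
    for req, receiver in queue:
        poll = receiver.poll()
        if poll == KVPoll.Failed:
            try:
                receiver.failure_exception()  # raises with the recorded reason
            except Exception as e:
                req.error = str(e)  # attach the error to this request only
            done.append(req)
        elif poll == KVPoll.Success:
            if hasattr(receiver, "clear"):  # duck-typed cleanup, as in the diff
                receiver.clear()
            done.append(req)
        else:
            keep.append((req, receiver))  # still bootstrapping / transferring
    queue[:] = keep
    return done
```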
@@ -9,6 +9,8 @@ import queue
 import socket
 import struct
 import threading
+import time
+from collections import defaultdict
 from functools import cache
 from typing import Dict, List, Optional, Tuple, Union
@@ -51,6 +53,16 @@ def group_concurrent_contiguous(
     return src_groups, dst_groups


+class KVTransferError(Exception):
+    def __init__(self, bootstrap_room: int, failure_reason: str):
+        super().__init__(failure_reason)
+        self.bootstrap_room = bootstrap_room
+        self.failure_reason = failure_reason
+
+    def __str__(self):
+        return f"KVTransferError(bootstrap_room={self.bootstrap_room}): {self.failure_reason}"
+
+
 # prefill
 @dataclasses.dataclass
 class TransferKVChunk:
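`KVTransferError` replaces the previous placeholder exceptions and carries both the room id and a human-readable reason, so callers can log or route on either field. A minimal, runnable demo of its contract (the class body mirrors the hunk above; the reason string is made up):

```python
class KVTransferError(Exception):
    def __init__(self, bootstrap_room: int, failure_reason: str):
        super().__init__(failure_reason)
        self.bootstrap_room = bootstrap_room
        self.failure_reason = failure_reason

    def __str__(self):
        return f"KVTransferError(bootstrap_room={self.bootstrap_room}): {self.failure_reason}"


try:
    raise KVTransferError(42, "Decode instance could be dead")
except KVTransferError as e:
    assert e.bootstrap_room == 42
    print(e)  # KVTransferError(bootstrap_room=42): Decode instance could be dead
```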
@@ -153,13 +165,34 @@ class MooncakeKVManager(BaseKVManager):
             self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
             self.start_prefill_thread()
             self._register_to_bootstrap()
+            self.session_failures = defaultdict(int)
+            self.failed_sessions = set()
+            self.session_lock = threading.Lock()
             # Determine the number of threads to use for kv sender
             cpu_count = os.cpu_count()
             self.executor = concurrent.futures.ThreadPoolExecutor(
-                min(cpu_count // 4, 16)
+                int(
+                    os.getenv(
+                        "DISAGGREGATION_THREAD_POOL_SIZE",
+                        min(max(1, cpu_count // 8), 8),
+                    )
+                )
             )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            self.heartbeat_failures = {}
+            self.session_pool = defaultdict(requests.Session)
+            self.session_pool_lock = threading.Lock()
+            self.addr_to_rooms_tracker = defaultdict(list)
+            self.connection_lock = threading.Lock()
+            # The heartbeat interval should be at least 2 seconds
+            self.heartbeat_interval = max(
+                float(os.getenv("DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
+            )
+            # The heartbeat failure threshold should be at least 1
+            self.max_failures = max(
+                int(os.getenv("DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2)), 1
+            )
             self.start_decode_thread()
             self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
             self.prefill_tp_size_table: Dict[str, int] = {}
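All the new knobs above are environment-driven with clamped defaults. A standalone sketch of the same arithmetic (the env var names come from this diff; everything else is illustrative):

```python
import os

cpu_count = os.cpu_count() or 1

# Prefill: the kv sender pool defaults to cpu_count // 8, clamped to [1, 8],
# and can be overridden with DISAGGREGATION_THREAD_POOL_SIZE.
pool_size = int(
    os.getenv("DISAGGREGATION_THREAD_POOL_SIZE", min(max(1, cpu_count // 8), 8))
)

# Decode: the heartbeat interval defaults to 5s and is floored at 2s;
# the failure threshold defaults to 2 and is floored at 1.
heartbeat_interval = max(float(os.getenv("DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0)
max_failures = max(int(os.getenv("DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2)), 1)

print(pool_size, heartbeat_interval, max_failures)  # e.g. 8 5.0 2 on a 64-core host
```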
@@ -169,6 +202,9 @@ class MooncakeKVManager(BaseKVManager):
                 f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
             )

+        self.failure_records: Dict[int, str] = {}
+        self.failure_lock = threading.Lock()
+
     def register_buffer_to_engine(self):
         for kv_data_ptr, kv_data_len in zip(
             self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
@@ -235,8 +271,6 @@ class MooncakeKVManager(BaseKVManager):
         for future in concurrent.futures.as_completed(futures):
             status = future.result()
             if status != 0:
-                # Immediate shutdown on first error (existing tasks will finish)
-                self.executor.shutdown(wait=False)
                 for f in futures:
                     f.cancel()
                 return status
@@ -255,20 +289,20 @@ class MooncakeKVManager(BaseKVManager):
             self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
         )
         decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
-        # TODO: mooncake transfer engine can do async transfer. Do async later
-        # Not sure about the amount of aux data, maybe transfer it by zmq is more effective
         status = self.engine.transfer_sync(
             mooncake_session_id, prefill_aux_addr, decode_aux_addr, aux_item_len
         )
         return status

-    def sync_status_to_decode_endpoint(self, remote: str, dst_port: int, room: int):
+    def sync_status_to_decode_endpoint(
+        self, remote: str, dst_port: int, room: int, status: int
+    ):
         if ":" in remote:
             remote = remote.split(":")[0]
         self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
             [
                 str(room).encode("ascii"),
-                str(self.check_status(room)).encode("ascii"),
+                str(status).encode("ascii"),
             ]
         )
@@ -287,6 +321,11 @@ class MooncakeKVManager(BaseKVManager):
                     self.decode_kv_args_table[mooncake_session_id] = (
                         KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
                     )
+                    with self.session_lock:
+                        if mooncake_session_id in self.failed_sessions:
+                            self.failed_sessions.remove(mooncake_session_id)
+                        if mooncake_session_id in self.session_failures:
+                            del self.session_failures[mooncake_session_id]
                     logger.debug(
                         f"Register KVArgs from {mooncake_session_id} successfully"
                     )
@@ -309,17 +348,48 @@ class MooncakeKVManager(BaseKVManager):
             while True:
                 try:
                     kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
-                    reqs_to_be_processed = self.transfer_infos[kv_chunk.room].values()
+                    reqs_to_be_processed = (
+                        self.transfer_infos[kv_chunk.room].values()
+                        if kv_chunk.room in self.transfer_infos
+                        else []
+                    )
                     polls = []
                     dst_ranks_infos = []
                     for req in reqs_to_be_processed:
                         if not req.is_dummy:
+                            # Early exit if the request has already failed
+                            with self.session_lock:
+                                if req.mooncake_session_id in self.failed_sessions:
+                                    self.record_failure(
+                                        kv_chunk.room,
+                                        f"Decode instance could be dead, {req.mooncake_session_id} failed due to multiple errors",
+                                    )
+                                    self.update_status(kv_chunk.room, KVPoll.Failed)
+                                    self.sync_status_to_decode_endpoint(
+                                        req.endpoint,
+                                        req.dst_port,
+                                        req.room,
+                                        KVPoll.Failed,
+                                    )
+                                    break
+
                             chunked_dst_kv_indice = req.dst_kv_indices[
                                 kv_chunk.index_slice
                             ]
-                            assert len(chunked_dst_kv_indice) == len(
-                                kv_chunk.prefill_kv_indices
-                            ), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
+
+                            # NOTE: This is temporarily a workaround for the case where prefill_kv_indices
+                            # is mismatched with dst_kv_indices when page size > 1; this should never happen.
+                            if len(chunked_dst_kv_indice) < len(
+                                kv_chunk.prefill_kv_indices
+                            ):
+                                kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
+                                    : len(chunked_dst_kv_indice)
+                                ]
+                                logger.warning(
+                                    f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
+                                )

                             ret = self.send_kvcache(
                                 req.mooncake_session_id,
@@ -330,11 +400,28 @@ class MooncakeKVManager(BaseKVManager):
                                 chunked_dst_kv_indice,
                             )
                             if ret != 0:
+                                with self.session_lock:
+                                    self.session_failures[req.mooncake_session_id] += 1
+                                    # Failures should never happen if the session is not dead; if the session fails once, mark it as failed
+                                    if (
+                                        self.session_failures[req.mooncake_session_id]
+                                        >= 1
+                                    ):
+                                        self.failed_sessions.add(
+                                            req.mooncake_session_id
+                                        )
+                                        logger.error(
+                                            f"Session {req.mooncake_session_id} failed."
+                                        )
+                                self.record_failure(
+                                    kv_chunk.room,
+                                    f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}",
+                                )
                                 self.update_status(kv_chunk.room, KVPoll.Failed)
                                 self.sync_status_to_decode_endpoint(
-                                    req.endpoint, req.dst_port, req.room
+                                    req.endpoint, req.dst_port, req.room, KVPoll.Failed
                                 )
-                                continue
+                                break

                             if kv_chunk.is_last:
                                 # Only the last chunk needs to send the aux data
@@ -353,25 +440,33 @@ class MooncakeKVManager(BaseKVManager):
                                 # Only sync status when all the dst ranks have received the kvcache
                                 if len(polls) == req.required_dst_info_num:
-                                    self.update_status(
-                                        req.room,
-                                        KVPoll.Success if all(polls) else KVPoll.Failed,
-                                    )
+                                    status = (
+                                        KVPoll.Success if all(polls) else KVPoll.Failed
+                                    )
+                                    self.update_status(req.room, status)
                                     for endpoint, dst_port, room in dst_ranks_infos:
                                         self.sync_status_to_decode_endpoint(
-                                            endpoint, dst_port, room
+                                            endpoint, dst_port, room, status
                                         )
                         else:
                             # A dummy request means the decode instance is not used, so its status can be marked as success directly
                             # A dummy request does not need to sync its status to the decode endpoint
-                            if kv_chunk.is_last:
+                            if kv_chunk.is_last and req.room in self.request_status:
                                 self.update_status(req.room, KVPoll.Success)

-                    if self.check_status(kv_chunk.room) == KVPoll.Success:
+                    if (
+                        kv_chunk.room not in self.request_status
+                        or self.check_status(kv_chunk.room) == KVPoll.Success
+                    ):
                         self.transfer_infos.pop(kv_chunk.room)

                 except queue.Empty:
                     continue
+                except Exception as e:
+                    # NOTE(shangming): Remove this when we make sure the transfer thread is bug-free
+                    raise RuntimeError(
+                        f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
+                    )

         threading.Thread(target=bootstrap_thread).start()
         threading.Thread(target=transfer_thread).start()
@@ -385,9 +480,67 @@ class MooncakeKVManager(BaseKVManager):
                     (bootstrap_room, status) = self.server_socket.recv_multipart()
                     status = int(status.decode("ascii"))
                     bootstrap_room = int(bootstrap_room.decode("ascii"))
+                    if status == KVPoll.Failed:
+                        self.record_failure(
+                            bootstrap_room,
+                            "Failed to get kvcache from the prefill instance, it might be dead",
+                        )
                     self.update_status(bootstrap_room, status)

+        def heartbeat_checker():
+            while True:
+                time.sleep(self.heartbeat_interval)
+                with self.connection_lock:
+                    addresses = list(self.prefill_dp_size_table.keys())
+
+                for bootstrap_addr in addresses:
+                    session = None
+                    try:
+                        with self.session_pool_lock:
+                            session = self.session_pool[bootstrap_addr]
+                        response = session.get(
+                            f"http://{bootstrap_addr}/health",
+                            timeout=(2, 3),
+                            headers={"Connection": "keep-alive"},
+                        )
+                        if response.status_code == 200:
+                            self.heartbeat_failures[bootstrap_addr] = 0
+                            for bootstrap_room in self.addr_to_rooms_tracker[
+                                bootstrap_addr
+                            ]:
+                                # Remove KVPoll.Success requests from the tracker
+                                if bootstrap_room not in self.request_status:
+                                    self.addr_to_rooms_tracker[bootstrap_addr].remove(
+                                        bootstrap_room
+                                    )
+                        else:
+                            logger.info(
+                                f"Attempting to reconnect to {bootstrap_addr}..."
+                            )
+                            self.heartbeat_failures[bootstrap_addr] = (
+                                self.heartbeat_failures.get(bootstrap_addr, 0) + 1
+                            )
+                            with self.session_pool_lock:
+                                if bootstrap_addr in self.session_pool:
+                                    del self.session_pool[bootstrap_addr]
+                    except Exception:
+                        logger.info(f"Attempting to reconnect to {bootstrap_addr}...")
+                        self.heartbeat_failures[bootstrap_addr] = (
+                            self.heartbeat_failures.get(bootstrap_addr, 0) + 1
+                        )
+
+                    if (
+                        self.heartbeat_failures.get(bootstrap_addr, 0)
+                        >= self.max_failures
+                    ):
+                        self._handle_node_failure(bootstrap_addr)
+                        with self.session_pool_lock:
+                            if bootstrap_addr in self.session_pool:
+                                del self.session_pool[bootstrap_addr]
+
         threading.Thread(target=decode_thread).start()
+        threading.Thread(target=heartbeat_checker).start()
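The heartbeat checker is a plain polling loop: one keep-alive `requests.Session` per bootstrap address, a counter that resets on HTTP 200 and increments on anything else, and a node-failure hook once the counter reaches the threshold. A trimmed, self-contained sketch of the same pattern (`on_dead` stands in for `_handle_node_failure`; the room-tracker pruning is omitted):

```python
import time
from collections import defaultdict

import requests


def heartbeat_loop(addresses, interval=5.0, max_failures=2, on_dead=print):
    sessions = defaultdict(requests.Session)  # keep-alive connection per address
    failures = defaultdict(int)
    while True:
        time.sleep(interval)
        for addr in list(addresses):
            try:
                resp = sessions[addr].get(
                    f"http://{addr}/health",
                    timeout=(2, 3),  # (connect, read) timeouts
                    headers={"Connection": "keep-alive"},
                )
                ok = resp.status_code == 200
            except requests.RequestException:
                ok = False
            if ok:
                failures[addr] = 0
            else:
                failures[addr] += 1
                sessions.pop(addr, None)  # drop the session so the next probe reconnects
            if failures[addr] >= max_failures:
                on_dead(addr)  # e.g. _handle_node_failure(addr)
```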
     def add_transfer_request(
         self,
@@ -400,6 +553,15 @@ class MooncakeKVManager(BaseKVManager):
         assert self.disaggregation_mode == DisaggregationMode.PREFILL
         assert not is_last or (is_last and aux_index is not None)

+        if (
+            bootstrap_room not in self.request_status
+            or self.check_status(bootstrap_room) == KVPoll.Failed
+        ):
+            logger.debug(
+                "Request with bootstrap_room=%s already failed", bootstrap_room
+            )
+            return
+
         self.transfer_queue.put(
             TransferKVChunk(
                 room=bootstrap_room,
@@ -418,10 +580,17 @@ class MooncakeKVManager(BaseKVManager):
         if bootstrap_room not in self.request_status:
             self.request_status[bootstrap_room] = status
         else:
-            # NOTE: The prefill engine could recv bootstrapping first
-            self.request_status[bootstrap_room] = max(
-                self.request_status[bootstrap_room], status
-            )
+            # NOTE: The status may only move forward, except that KVPoll.Failed always wins
+            if status == KVPoll.Failed:
+                self.request_status[bootstrap_room] = KVPoll.Failed
+            else:
+                self.request_status[bootstrap_room] = max(
+                    self.request_status[bootstrap_room], status
+                )
+
+    def record_failure(self, bootstrap_room: int, failure_reason: str):
+        with self.failure_lock:
+            self.failure_records[bootstrap_room] = failure_reason

     def get_session_id(self):
         return self.engine.get_session_id()
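The `update_status` change matters because `KVPoll.Failed` is the lowest value in the status ordering, so the old `max()`-only rule could never record a failure once a request had advanced; the new rule keeps the monotonic upgrade but lets `Failed` win unconditionally. The rule in isolation (enum values assumed to follow the `KVPoll` ordering sketched earlier):

```python
FAILED, BOOTSTRAPPING, WAITING_FOR_INPUT, TRANSFERRING, SUCCESS = range(5)


def merge_status(current: int, incoming: int) -> int:
    # Failed is sticky and overrides anything; otherwise only upgrades apply.
    return FAILED if incoming == FAILED else max(current, incoming)


assert merge_status(TRANSFERRING, WAITING_FOR_INPUT) == TRANSFERRING  # no downgrade
assert merge_status(TRANSFERRING, SUCCESS) == SUCCESS                 # normal upgrade
assert merge_status(SUCCESS, FAILED) == FAILED                        # failure is always recorded
```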
@@ -445,15 +614,51 @@ class MooncakeKVManager(BaseKVManager):
         }
         try:
-            response = requests.put(url, json=payload)
+            response = requests.put(url, json=payload, timeout=5)
             if response.status_code == 200:
                 logger.debug("Prefill successfully registered to bootstrap server.")
             else:
                 logger.error(
-                    f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
+                    f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}"
                 )
         except Exception as e:
-            logger.error(f"Prefill Failed to register to bootstrap server: {e}")
+            logger.error(
+                f"Prefill instance failed to register to bootstrap server: {e}"
+            )
+
+    def _handle_node_failure(self, failed_bootstrap_addr):
+        with self.connection_lock:
+            keys_to_remove = [
+                k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)
+            ]
+            for k in keys_to_remove:
+                del self.connection_pool[k]
+            if failed_bootstrap_addr in self.prefill_tp_size_table:
+                del self.prefill_tp_size_table[failed_bootstrap_addr]
+            if failed_bootstrap_addr in self.prefill_dp_size_table:
+                del self.prefill_dp_size_table[failed_bootstrap_addr]
+
+            possible_affected_rooms = self.addr_to_rooms_tracker.get(
+                failed_bootstrap_addr, []
+            )
+            del self.addr_to_rooms_tracker[failed_bootstrap_addr]
+
+        # Report the requests associated with the failed bootstrap addr and mark their status as KVPoll.Failed
+        affected_rooms = []
+        for room in possible_affected_rooms:
+            if (
+                room in self.request_status
+                and self.check_status(room) != KVPoll.Success
+            ):
+                self.record_failure(
+                    room,
+                    f"Lost connection with the prefill instance (bootstrap_addr: {failed_bootstrap_addr})",
+                )
+                self.update_status(room, KVPoll.Failed)
+                affected_rooms.append(room)
+        logger.error(
+            f"Lost connection with the prefill instance (bootstrap_addr: {failed_bootstrap_addr}); {len(affected_rooms)} requests affected"
+        )


 class MooncakeKVSender(BaseKVSender):
@@ -466,7 +671,7 @@ class MooncakeKVSender(BaseKVSender):
         self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
         self.aux_index = None
         self.bootstrap_server_url = bootstrap_addr
-        self.session_id = self.kv_mgr.get_session_id()
+        self.conclude_state = None
         # inner state
         self.curr_idx = 0
@@ -496,11 +701,30 @@ class MooncakeKVSender(BaseKVSender):
             )

     def poll(self) -> KVPoll:
-        return self.kv_mgr.check_status(self.bootstrap_room)
+        if self.conclude_state is None:
+            status = self.kv_mgr.check_status(self.bootstrap_room)
+            if status in (KVPoll.Success, KVPoll.Failed):
+                self.conclude_state = status
+            return status
+        else:
+            return self.conclude_state
+
+    def clear(self) -> None:
+        self.kv_mgr.request_status.pop(self.bootstrap_room)

     def failure_exception(self):
-        # TODO: raise a real exception
-        raise Exception("Fake KVSender Exception")
+        self.clear()
+
+        # Explicitly set the status to failure since this request has failed in another rank
+        if self.conclude_state is None:
+            self.conclude_state = KVPoll.Failed
+
+        with self.kv_mgr.failure_lock:
+            failure_reason = self.kv_mgr.failure_records.pop(
+                self.bootstrap_room, "Failed due to an unknown reason from another rank"
+            )
+        raise KVTransferError(self.bootstrap_room, failure_reason)
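`poll()` now latches the first terminal state because `failure_exception()` calls `clear()`, which pops the room out of the manager's `request_status`; without the cached `conclude_state`, a later `poll()` would have nothing left to look up. A minimal sketch of the latch, with a plain dict standing in for the manager:

```python
SUCCESS, FAILED = 4, 0  # terminal states, mirroring the assumed KVPoll values


class SenderLatch:
    def __init__(self, request_status: dict, room: int):
        self.request_status = request_status  # shared {room: status} table
        self.room = room
        self.conclude_state = None

    def poll(self) -> int:
        if self.conclude_state is None:
            status = self.request_status[self.room]
            if status in (SUCCESS, FAILED):
                self.conclude_state = status  # remember the terminal state
            return status
        return self.conclude_state

    def clear(self) -> None:
        self.request_status.pop(self.room)


table = {7: SUCCESS}
s = SenderLatch(table, 7)
assert s.poll() == SUCCESS  # latches the terminal state
s.clear()                   # room 7 is gone from the shared table
assert s.poll() == SUCCESS  # still answerable thanks to the latch
```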
 class MooncakeKVReceiver(BaseKVReceiver):
@@ -519,17 +743,24 @@ class MooncakeKVReceiver(BaseKVReceiver):
         self.bootstrap_addr = bootstrap_addr
         self.kv_mgr = mgr
         self.session_id = self.kv_mgr.get_session_id()
-        self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
+        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
+        self.conclude_state = None

         if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
             self.prefill_tp_size, self.prefill_dp_size = (
-                self._get_prefill_dp_size_from_server()
+                self._get_prefill_parallel_info_from_server()
             )
             if self.prefill_tp_size is None or self.prefill_dp_size is None:
-                logger.error(
-                    f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
-                )
+                self.kv_mgr.record_failure(
+                    self.bootstrap_room,
+                    f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
+                )
+                self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
+                return
             else:
+                logger.debug(
+                    f"Fetched prefill parallel info from [{self.bootstrap_addr}]: DP size: {self.prefill_dp_size}, TP size: {self.prefill_tp_size}"
+                )
                 self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
                     self.prefill_tp_size
                 )
@@ -587,7 +818,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1

-        self.target_dp_group = bootstrap_room % self.prefill_dp_size
+        self.target_dp_group = self.bootstrap_room % self.prefill_dp_size

         # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
         bootstrap_key = (
@@ -607,32 +838,37 @@ class MooncakeKVReceiver(BaseKVReceiver):
                         target_tp_rank == self.target_tp_rank
                         or self.target_tp_rank is None
                     )
+                    logger.debug(
+                        f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}"
+                    )
                     bootstrap_infos.append(bootstrap_info)
                 else:
-                    logger.error(
-                        f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
-                    )
+                    self.kv_mgr.record_failure(
+                        self.bootstrap_room,
+                        f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}",
+                    )
+                    self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
+                    return
+
             self.bootstrap_infos = bootstrap_infos
+            self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos

-            if len(self.bootstrap_infos) == 0:
-                logger.error(
-                    f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
-                )
-            else:
-                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
-                # Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
-                self._register_kv_args()
+            # Register kv_args only once to the prefill KVManager according to the info fetched from the bootstrap server
+            self._register_kv_args()
         else:
             self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]

         assert len(self.bootstrap_infos) > 0
-        self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
+        self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].append(
+            self.bootstrap_room
+        )
+        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
     def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
         """Fetch the bootstrap info from the bootstrap server."""
         try:
             url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
-            response = requests.get(url)
+            response = requests.get(url, timeout=5)
             if response.status_code == 200:
                 bootstrap_info = response.json()
                 return bootstrap_info
@@ -645,7 +881,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
             logger.error(f"Error fetching prefill info from bootstrap: {e}")
             return None

-    def _get_prefill_dp_size_from_server(self) -> int:
+    def _get_prefill_parallel_info_from_server(self) -> Tuple[int, int]:
         """Fetch the prefill parallel info from the bootstrap server."""
         try:
             url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
@@ -659,10 +895,10 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 logger.error(
                     f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
                 )
-                return None
+                return None, None
         except Exception as e:
             logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
-            return None
+            return None, None
     def _register_kv_args(self):
         for bootstrap_info in self.bootstrap_infos:
@@ -704,9 +940,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
             self.prefill_server_url = (
                 f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
             )
-            logger.debug(
-                f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
-            )
             is_dummy = bootstrap_info["is_dummy"]

             sock, lock = self._connect("tcp://" + self.prefill_server_url)
@@ -724,11 +957,30 @@ class MooncakeKVReceiver(BaseKVReceiver):
             )

     def poll(self) -> KVPoll:
-        return self.kv_mgr.check_status(self.bootstrap_room)
+        if self.conclude_state is None:
+            status = self.kv_mgr.check_status(self.bootstrap_room)
+            if status in (KVPoll.Success, KVPoll.Failed):
+                self.conclude_state = status
+            return status
+        else:
+            return self.conclude_state
+
+    def clear(self) -> None:
+        self.kv_mgr.request_status.pop(self.bootstrap_room)

     def failure_exception(self):
-        # TODO: raise a real exception
-        raise Exception("Fake KVReceiver Exception")
+        self.clear()
+
+        # Explicitly set the status to failure since this request has failed in another rank
+        if self.conclude_state is None:
+            self.conclude_state = KVPoll.Failed
+
+        with self.kv_mgr.failure_lock:
+            failure_reason = self.kv_mgr.failure_records.pop(
+                self.bootstrap_room, "Failed due to an unknown reason from another rank"
+            )
+        raise KVTransferError(self.bootstrap_room, failure_reason)


 class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
@@ -752,6 +1004,10 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
     def _setup_routes(self):
         self.app.router.add_route("*", "/route", self._handle_route)
+        self.app.router.add_get("/health", self._handle_health_check)
+
+    async def _handle_health_check(self, request):
+        return web.Response(text="OK", status=200)
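This `/health` route is what the decode-side heartbeat probes. It can also be exercised by hand; a quick sketch (the address below is a placeholder):

```python
import requests

bootstrap_addr = "127.0.0.1:8998"  # placeholder; use the real bootstrap server address
resp = requests.get(f"http://{bootstrap_addr}/health", timeout=(2, 3))
assert resp.status_code == 200 and resp.text == "OK"
```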
     async def _handle_route(self, request: web.Request):
         method = request.method
@@ -780,14 +1036,14 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
             self.dp_size = dp_size

         tp_size_per_dp_rank = tp_size // dp_size
-        if self.tp_size_per_dp_rank == None:
+        if self.tp_size_per_dp_rank is None:
             self.tp_size_per_dp_rank = tp_size_per_dp_rank

-        # Add lock to make sure thread-safe
         if role == "Prefill":
             dp_group = engine_rank // tp_size_per_dp_rank
             tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank

+            # Add a lock to keep the table update thread-safe
             async with self.lock:
                 if dp_group not in self.prefill_port_table:
                     self.prefill_port_table[dp_group] = {}
@@ -797,7 +1053,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
                     "rank_port": rank_port,
                 }
             logger.debug(
-                f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
+                f"Register prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
             )
             return web.Response(text="OK", status=200)
@@ -833,7 +1089,11 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
         self._loop = asyncio.new_event_loop()
         asyncio.set_event_loop(self._loop)

-        self._runner = web.AppRunner(self.app)
+        access_log = None
+        if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG:
+            access_log = self.app.logger
+
+        self._runner = web.AppRunner(self.app, access_log=access_log)
         self._loop.run_until_complete(self._runner.setup())

         site = web.TCPSite(self._runner, port=self.port)
...
@@ -30,16 +30,24 @@ class MooncakeTransferEngine:
         self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"

     def register(self, ptr, length):
-        ret_value = self.engine.register_memory(ptr, length)
+        try:
+            ret_value = self.engine.register_memory(ptr, length)
+        except Exception:
+            # Mark the registration as failed
+            ret_value = -1
+
         if ret_value != 0:
-            logger.error("Mooncake memory registration failed.")
-            raise RuntimeError("Mooncake memory registration failed.")
+            logger.debug("Mooncake memory registration %s failed.", ptr)

     def deregister(self, ptr):
-        ret_value = self.engine.unregister_memory(ptr)
+        try:
+            ret_value = self.engine.unregister_memory(ptr)
+        except Exception:
+            # Mark the deregistration as failed
+            ret_value = -1
+
         if ret_value != 0:
-            logger.error("Mooncake memory deregistration failed.")
-            raise RuntimeError("Mooncake memory deregistration failed.")
+            logger.debug("Mooncake memory deregistration %s failed.", ptr)

     def initialize(
         self,
@@ -61,18 +69,26 @@ class MooncakeTransferEngine:
         self, session_id: str, buffer: int, peer_buffer_address: int, length: int
     ) -> int:
         """Synchronously transfer data to the specified address."""
-        # the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair
-        # later: based on the cached queue pair to send data
-        ret = self.engine.transfer_sync_write(
-            session_id, buffer, peer_buffer_address, length
-        )
+        try:
+            # The first time: construct a queue pair based on session_id (which contains the remote_ip) and cache it
+            # Afterwards: send data over the cached queue pair
+            ret = self.engine.transfer_sync_write(
+                session_id, buffer, peer_buffer_address, length
+            )
+        except Exception:
+            # Mark the transfer request as failed
+            ret = -1

         if ret < 0:
-            logger.error("Mooncake Transfer Engine Return Error.")
-            raise RuntimeError("Mooncake Transfer Engine Return Error.")
+            # Do not raise here: some failed transfer requests are acceptable, and the execution thread should not be stopped
+            logger.debug(
+                "Failed to transfer data from %s to %s - %s.",
+                buffer,
+                session_id,
+                peer_buffer_address,
+            )
         return ret

-    def get_localhost(self):
-        return self.hostname
-
     def get_session_id(self):
         return self.session_id
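All three wrappers above now follow one convention: an exception from the native engine call is converted into a negative return code and left to the caller, so a single bad transfer marks one session or request as failed instead of killing the worker thread. The generic shape, as a hedged sketch (helper name and logging are illustrative):

```python
import logging

logger = logging.getLogger(__name__)


def call_engine(op, *args) -> int:
    """Run a fallible native-engine call; report failure as -1 instead of raising."""
    try:
        ret = op(*args)
    except Exception:
        ret = -1  # mark the call as failed
    if ret < 0:
        # Debug-level only: individual failures are expected and handled upstream
        logger.debug("engine call %s%r failed", getattr(op, "__name__", op), args)
    return ret
```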
...
@@ -417,6 +417,8 @@ class SchedulerDisaggregationPrefillMixin:
                 self.tree_cache.cache_finished_req(req)  # unlock the tree
                 req.finished_reason = FINISH_LENGTH(length=0)
                 # FIXME: clean up req's data in transfer engine
+                if hasattr(req.disagg_kv_sender, "clear"):
+                    req.disagg_kv_sender.clear()
                 done_reqs.append(req)
             elif poll == KVPoll.Failed:
                 error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
...