Unverified Commit 3ce94f71 authored by shangmingc's avatar shangmingc Committed by GitHub
Browse files

[PD] Handle P/D failure and reconnect without affecting other instances (#6263)


Signed-off-by: default avatarShangming Cai <caishangming@linux.alibaba.com>
parent ca95556c
......@@ -361,7 +361,7 @@ class DecodeTransferQueue:
indices_to_remove = set()
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
if poll == KVPoll.Failed:
error_message = f"Decode transfer failed for request {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
error_message = f"Decode transfer failed for request rank={self.scheduler.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
try:
decode_req.kv_receiver.failure_exception()
except Exception as e:
......@@ -409,7 +409,8 @@ class DecodeTransferQueue:
: decode_req.req.top_logprobs_num
].tolist()
)
if hasattr(decode_req.kv_receiver, "clear"):
decode_req.kv_receiver.clear()
transferred_reqs.append(decode_req.req)
indices_to_remove.add(i)
elif poll in [
......
......@@ -30,16 +30,24 @@ class MooncakeTransferEngine:
self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
def register(self, ptr, length):
ret_value = self.engine.register_memory(ptr, length)
try:
ret_value = self.engine.register_memory(ptr, length)
except Exception:
# Mark register as failed
ret_value = -1
if ret_value != 0:
logger.error("Mooncake memory registration failed.")
raise RuntimeError("Mooncake memory registration failed.")
logger.debug("Mooncake memory registration %s failed.", ptr)
def deregister(self, ptr):
ret_value = self.engine.unregister_memory(ptr)
try:
ret_value = self.engine.unregister_memory(ptr)
except Exception:
# Mark deregister as failed
ret_value = -1
if ret_value != 0:
logger.error("Mooncake memory deregistration failed.")
raise RuntimeError("Mooncake memory deregistration failed.")
logger.debug("Mooncake memory deregistration %s failed.", ptr)
def initialize(
self,
......@@ -61,18 +69,26 @@ class MooncakeTransferEngine:
self, session_id: str, buffer: int, peer_buffer_address: int, length: int
) -> int:
"""Synchronously transfer data to the specified address."""
# the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair
# later: based on the cached queue pair to send data
ret = self.engine.transfer_sync_write(
session_id, buffer, peer_buffer_address, length
)
try:
# the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair
# later: based on the cached queue pair to send data
ret = self.engine.transfer_sync_write(
session_id, buffer, peer_buffer_address, length
)
except Exception:
# Mark transfer request as failed
ret = -1
if ret < 0:
logger.error("Mooncake Transfer Engine Return Error.")
raise RuntimeError("Mooncake Transfer Engine Return Error.")
return ret
# Do not raise an exception here, since some transfer requests fail should be accepted and the execution thread should not be stopped.
logger.debug(
"Failed to transfer data from %s to %s - %s.",
buffer,
session_id,
peer_buffer_address,
)
def get_localhost(self):
return self.hostname
return ret
def get_session_id(self):
return self.session_id
......@@ -417,6 +417,8 @@ class SchedulerDisaggregationPrefillMixin:
self.tree_cache.cache_finished_req(req) # unlock the tree
req.finished_reason = FINISH_LENGTH(length=0)
# FIXME: clean up req's data in transfer engine
if hasattr(req.disagg_kv_sender, "clear"):
req.disagg_kv_sender.clear()
done_reqs.append(req)
elif poll == KVPoll.Failed:
error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment