Unverified commit 3ce94f71, authored by shangmingc, committed by GitHub
Browse files

[PD] Handle P/D failure and reconnect without affecting other instances (#6263)


Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
parent ca95556c
...@@ -361,7 +361,7 @@ class DecodeTransferQueue: ...@@ -361,7 +361,7 @@ class DecodeTransferQueue:
indices_to_remove = set() indices_to_remove = set()
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
if poll == KVPoll.Failed: if poll == KVPoll.Failed:
error_message = f"Decode transfer failed for request {decode_req.req.rid=} {decode_req.req.bootstrap_room=}" error_message = f"Decode transfer failed for request rank={self.scheduler.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
try: try:
decode_req.kv_receiver.failure_exception() decode_req.kv_receiver.failure_exception()
except Exception as e: except Exception as e:
...@@ -409,7 +409,8 @@ class DecodeTransferQueue: ...@@ -409,7 +409,8 @@ class DecodeTransferQueue:
: decode_req.req.top_logprobs_num : decode_req.req.top_logprobs_num
].tolist() ].tolist()
) )
if hasattr(decode_req.kv_receiver, "clear"):
decode_req.kv_receiver.clear()
transferred_reqs.append(decode_req.req) transferred_reqs.append(decode_req.req)
indices_to_remove.add(i) indices_to_remove.add(i)
elif poll in [ elif poll in [
......
...@@ -30,16 +30,24 @@ class MooncakeTransferEngine: ...@@ -30,16 +30,24 @@ class MooncakeTransferEngine:
self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}" self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
def register(self, ptr, length): def register(self, ptr, length):
ret_value = self.engine.register_memory(ptr, length) try:
ret_value = self.engine.register_memory(ptr, length)
except Exception:
# Mark register as failed
ret_value = -1
if ret_value != 0: if ret_value != 0:
logger.error("Mooncake memory registration failed.") logger.debug("Mooncake memory registration %s failed.", ptr)
raise RuntimeError("Mooncake memory registration failed.")
def deregister(self, ptr): def deregister(self, ptr):
ret_value = self.engine.unregister_memory(ptr) try:
ret_value = self.engine.unregister_memory(ptr)
except Exception:
# Mark deregister as failed
ret_value = -1
if ret_value != 0: if ret_value != 0:
logger.error("Mooncake memory deregistration failed.") logger.debug("Mooncake memory deregistration %s failed.", ptr)
raise RuntimeError("Mooncake memory deregistration failed.")
def initialize( def initialize(
self, self,
...@@ -61,18 +69,26 @@ class MooncakeTransferEngine: ...@@ -61,18 +69,26 @@ class MooncakeTransferEngine:
self, session_id: str, buffer: int, peer_buffer_address: int, length: int self, session_id: str, buffer: int, peer_buffer_address: int, length: int
) -> int: ) -> int:
"""Synchronously transfer data to the specified address.""" """Synchronously transfer data to the specified address."""
# the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair try:
# later: based on the cached queue pair to send data # the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair
ret = self.engine.transfer_sync_write( # later: based on the cached queue pair to send data
session_id, buffer, peer_buffer_address, length ret = self.engine.transfer_sync_write(
) session_id, buffer, peer_buffer_address, length
)
except Exception:
# Mark transfer request as failed
ret = -1
if ret < 0: if ret < 0:
logger.error("Mooncake Transfer Engine Return Error.") # Do not raise an exception here, since some transfer requests fail should be accepted and the execution thread should not be stopped.
raise RuntimeError("Mooncake Transfer Engine Return Error.") logger.debug(
return ret "Failed to transfer data from %s to %s - %s.",
buffer,
session_id,
peer_buffer_address,
)
def get_localhost(self): return ret
return self.hostname
def get_session_id(self): def get_session_id(self):
return self.session_id return self.session_id
...@@ -417,6 +417,8 @@ class SchedulerDisaggregationPrefillMixin: ...@@ -417,6 +417,8 @@ class SchedulerDisaggregationPrefillMixin:
self.tree_cache.cache_finished_req(req) # unlock the tree self.tree_cache.cache_finished_req(req) # unlock the tree
req.finished_reason = FINISH_LENGTH(length=0) req.finished_reason = FINISH_LENGTH(length=0)
# FIXME: clean up req's data in transfer engine # FIXME: clean up req's data in transfer engine
if hasattr(req.disagg_kv_sender, "clear"):
req.disagg_kv_sender.clear()
done_reqs.append(req) done_reqs.append(req)
elif poll == KVPoll.Failed: elif poll == KVPoll.Failed:
error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}" error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment