"src/diffusers/pipelines/audioldm/pipeline_audioldm.py" did not exist on "b562b6611fb53dae9bcffcaaf44d944539ae22de"
Unverified Commit 3ce94f71 authored by shangmingc's avatar shangmingc Committed by GitHub
Browse files

[PD] Handle P/D failure and reconnect without affecting other instances (#6263)


Signed-off-by: default avatarShangming Cai <caishangming@linux.alibaba.com>
parent ca95556c
......@@ -361,7 +361,7 @@ class DecodeTransferQueue:
indices_to_remove = set()
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
if poll == KVPoll.Failed:
error_message = f"Decode transfer failed for request {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
error_message = f"Decode transfer failed for request rank={self.scheduler.tp_rank} {decode_req.req.rid=} {decode_req.req.bootstrap_room=}"
try:
decode_req.kv_receiver.failure_exception()
except Exception as e:
......@@ -409,7 +409,8 @@ class DecodeTransferQueue:
: decode_req.req.top_logprobs_num
].tolist()
)
if hasattr(decode_req.kv_receiver, "clear"):
decode_req.kv_receiver.clear()
transferred_reqs.append(decode_req.req)
indices_to_remove.add(i)
elif poll in [
......
......@@ -9,6 +9,8 @@ import queue
import socket
import struct
import threading
import time
from collections import defaultdict
from functools import cache
from typing import Dict, List, Optional, Tuple, Union
......@@ -51,6 +53,16 @@ def group_concurrent_contiguous(
return src_groups, dst_groups
class KVTransferError(Exception):
    """Error raised when a KV-cache transfer between a prefill and a decode
    instance fails.

    Carries the ``bootstrap_room`` identifying the affected request together
    with a human-readable ``failure_reason`` so the scheduler can report which
    request died and why.
    """

    def __init__(self, bootstrap_room: int, failure_reason: str):
        # Keep both fields on the instance; the reason also becomes the
        # standard Exception args payload.
        self.bootstrap_room = bootstrap_room
        self.failure_reason = failure_reason
        super().__init__(failure_reason)

    def __str__(self) -> str:
        return (
            "KVTransferError(bootstrap_room="
            f"{self.bootstrap_room}): {self.failure_reason}"
        )
# prefill
@dataclasses.dataclass
class TransferKVChunk:
......@@ -153,13 +165,34 @@ class MooncakeKVManager(BaseKVManager):
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
self.start_prefill_thread()
self._register_to_bootstrap()
self.session_failures = defaultdict(int)
self.failed_sessions = set()
self.session_lock = threading.Lock()
# Determine the number of threads to use for kv sender
cpu_count = os.cpu_count()
self.executor = concurrent.futures.ThreadPoolExecutor(
min(cpu_count // 4, 16)
int(
os.getenv(
"DISAGGREGATION_THREAD_POOL_SIZE",
min(max(1, cpu_count // 8), 8),
)
)
)
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.heartbeat_failures = {}
self.session_pool = defaultdict(requests.Session)
self.session_pool_lock = threading.Lock()
self.addr_to_rooms_tracker = defaultdict(list)
self.connection_lock = threading.Lock()
# Heartbeat interval should be at least 2 seconds
self.heartbeat_interval = max(
float(os.getenv("DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
)
# Heartbeat failure should be at least 1
self.max_failures = max(
int(os.getenv("DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2)), 1
)
self.start_decode_thread()
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
self.prefill_tp_size_table: Dict[str, int] = {}
......@@ -169,6 +202,9 @@ class MooncakeKVManager(BaseKVManager):
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
)
self.failure_records: Dict[int, str] = {}
self.failure_lock = threading.Lock()
def register_buffer_to_engine(self):
for kv_data_ptr, kv_data_len in zip(
self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
......@@ -235,8 +271,6 @@ class MooncakeKVManager(BaseKVManager):
for future in concurrent.futures.as_completed(futures):
status = future.result()
if status != 0:
# Immediate shutdown on first error (existing tasks will finish)
self.executor.shutdown(wait=False)
for f in futures:
f.cancel()
return status
......@@ -255,20 +289,20 @@ class MooncakeKVManager(BaseKVManager):
self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
)
decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
# TODO: mooncake transfer engine can do async transfer. Do async later
# Not sure about the amount of aux data, maybe transfer it by zmq is more effective
status = self.engine.transfer_sync(
mooncake_session_id, prefill_aux_addr, decode_aux_addr, aux_item_len
)
return status
def sync_status_to_decode_endpoint(self, remote: str, dst_port: int, room: int):
def sync_status_to_decode_endpoint(
self, remote: str, dst_port: int, room: int, status: int
):
if ":" in remote:
remote = remote.split(":")[0]
self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
[
str(room).encode("ascii"),
str(self.check_status(room)).encode("ascii"),
str(status).encode("ascii"),
]
)
......@@ -287,6 +321,11 @@ class MooncakeKVManager(BaseKVManager):
self.decode_kv_args_table[mooncake_session_id] = (
KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
)
with self.session_lock:
if mooncake_session_id in self.failed_sessions:
self.failed_sessions.remove(mooncake_session_id)
if mooncake_session_id in self.session_failures:
del self.session_failures[mooncake_session_id]
logger.debug(
f"Register KVArgs from {mooncake_session_id} successfully"
)
......@@ -309,17 +348,48 @@ class MooncakeKVManager(BaseKVManager):
while True:
try:
kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
reqs_to_be_processed = self.transfer_infos[kv_chunk.room].values()
reqs_to_be_processed = (
self.transfer_infos[kv_chunk.room].values()
if kv_chunk.room in self.transfer_infos
else []
)
polls = []
dst_ranks_infos = []
for req in reqs_to_be_processed:
if not req.is_dummy:
# Early exit if the request has failed
with self.session_lock:
if req.mooncake_session_id in self.failed_sessions:
self.record_failure(
kv_chunk.room,
f"Decode instance could be dead, {req.mooncake_session_id} failed due to multiple errors",
)
self.update_status(kv_chunk.room, KVPoll.Failed)
self.sync_status_to_decode_endpoint(
req.endpoint,
req.dst_port,
req.room,
KVPoll.Failed,
)
break
chunked_dst_kv_indice = req.dst_kv_indices[
kv_chunk.index_slice
]
assert len(chunked_dst_kv_indice) == len(
# NOTE: This is temporarily a workaround to deal with the case where the prefill_kv_indices
# is mismatched with the dst_kv_indices when page size > 1, this should never happen.
if len(chunked_dst_kv_indice) < len(
kv_chunk.prefill_kv_indices
), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
):
kv_chunk.prefill_kv_indices = (
kv_chunk.prefill_kv_indices[
len(chunked_dst_kv_indice)
]
)
logger.warning(
f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
)
ret = self.send_kvcache(
req.mooncake_session_id,
......@@ -330,11 +400,28 @@ class MooncakeKVManager(BaseKVManager):
chunked_dst_kv_indice,
)
if ret != 0:
with self.session_lock:
self.session_failures[req.mooncake_session_id] += 1
# Failures should never happen if the session is not dead, if the session fails once, mark it as failed
if (
self.session_failures[req.mooncake_session_id]
>= 1
):
self.failed_sessions.add(
req.mooncake_session_id
)
logger.error(
f"Session {req.mooncake_session_id} failed."
)
self.record_failure(
kv_chunk.room,
f"Failed to send kv chunk of {kv_chunk.room} to {req.endpoint}:{req.dst_port}",
)
self.update_status(kv_chunk.room, KVPoll.Failed)
self.sync_status_to_decode_endpoint(
req.endpoint, req.dst_port, req.room
req.endpoint, req.dst_port, req.room, KVPoll.Failed
)
continue
break
if kv_chunk.is_last:
# Only the last chunk we need to send the aux data
......@@ -353,25 +440,33 @@ class MooncakeKVManager(BaseKVManager):
# Only sync status when all the dst ranks have received the kvcache
if len(polls) == req.required_dst_info_num:
self.update_status(
req.room,
KVPoll.Success if all(polls) else KVPoll.Failed,
status = (
KVPoll.Success if all(polls) else KVPoll.Failed
)
self.update_status(req.room, status)
for endpoint, dst_port, room in dst_ranks_infos:
self.sync_status_to_decode_endpoint(
endpoint, dst_port, room
endpoint, dst_port, room, status
)
else:
# Dummy request means the decode instance is not used, so its status can be marked as success directly
# Dummy request does not need to sync status to decode endpoint
if kv_chunk.is_last:
if kv_chunk.is_last and req.room in self.request_status:
self.update_status(req.room, KVPoll.Success)
if self.check_status(kv_chunk.room) == KVPoll.Success:
if (
kv_chunk.room not in self.request_status
or self.check_status(kv_chunk.room) == KVPoll.Success
):
self.transfer_infos.pop(kv_chunk.room)
except queue.Empty:
continue
except Exception as e:
# NOTE(shangming): Remove this when we make sure the transfer thread is bug-free
raise RuntimeError(
f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
)
threading.Thread(target=bootstrap_thread).start()
threading.Thread(target=transfer_thread).start()
......@@ -385,9 +480,67 @@ class MooncakeKVManager(BaseKVManager):
(bootstrap_room, status) = self.server_socket.recv_multipart()
status = int(status.decode("ascii"))
bootstrap_room = int(bootstrap_room.decode("ascii"))
if status == KVPoll.Failed:
self.record_failure(
bootstrap_room,
f"Failed to get kvcache from prefill instance, it might be dead",
)
self.update_status(bootstrap_room, status)
def heartbeat_checker():
while True:
time.sleep(self.heartbeat_interval)
with self.connection_lock:
addresses = list(self.prefill_dp_size_table.keys())
for bootstrap_addr in addresses:
session = None
try:
with self.session_pool_lock:
session = self.session_pool[bootstrap_addr]
response = session.get(
f"http://{bootstrap_addr}/health",
timeout=(2, 3),
headers={"Connection": "keep-alive"},
)
if response.status_code == 200:
self.heartbeat_failures[bootstrap_addr] = 0
for bootstrap_room in self.addr_to_rooms_tracker[
bootstrap_addr
]:
# Remove KVPoll.Success requests from the map
if bootstrap_room not in self.request_status:
self.addr_to_rooms_tracker[bootstrap_addr].remove(
bootstrap_room
)
else:
logger.info(
f"Attempting to reconnect to {bootstrap_addr}..."
)
self.heartbeat_failures[bootstrap_addr] = (
self.heartbeat_failures.get(bootstrap_addr, 0) + 1
)
with self.session_pool_lock:
if bootstrap_addr in self.session_pool:
del self.session_pool[bootstrap_addr]
except Exception:
logger.info(f"Attempting to reconnect to {bootstrap_addr}...")
self.heartbeat_failures[bootstrap_addr] = (
self.heartbeat_failures.get(bootstrap_addr, 0) + 1
)
if (
self.heartbeat_failures.get(bootstrap_addr, 0)
>= self.max_failures
):
self._handle_node_failure(bootstrap_addr)
with self.session_pool_lock:
if bootstrap_addr in self.session_pool:
del self.session_pool[bootstrap_addr]
threading.Thread(target=decode_thread).start()
threading.Thread(target=heartbeat_checker).start()
def add_transfer_request(
self,
......@@ -400,6 +553,15 @@ class MooncakeKVManager(BaseKVManager):
assert self.disaggregation_mode == DisaggregationMode.PREFILL
assert not is_last or (is_last and aux_index is not None)
if (
bootstrap_room not in self.request_status
or self.check_status(bootstrap_room) == KVPoll.Failed
):
logger.debug(
"Request with bootstrap_room=%s already failed", bootstrap_room
)
return
self.transfer_queue.put(
TransferKVChunk(
room=bootstrap_room,
......@@ -418,10 +580,17 @@ class MooncakeKVManager(BaseKVManager):
if bootstrap_room not in self.request_status:
self.request_status[bootstrap_room] = status
else:
# NOTE: The prefill engine could recv bootstrapping first
self.request_status[bootstrap_room] = max(
self.request_status[bootstrap_room], status
)
# NOTE: status is only allowed to be incremented unless it is KVPoll.Failed
if status == KVPoll.Failed:
self.request_status[bootstrap_room] = KVPoll.Failed
else:
self.request_status[bootstrap_room] = max(
self.request_status[bootstrap_room], status
)
def record_failure(self, bootstrap_room: int, failure_reason: str):
    """Remember why the request identified by *bootstrap_room* failed.

    The stored reason is later popped by ``failure_exception()`` to build a
    ``KVTransferError``. Guarded by ``failure_lock`` because transfer and
    heartbeat threads may record failures while another thread reads them.
    """
    with self.failure_lock:
        # Last writer wins if the same room fails more than once.
        self.failure_records[bootstrap_room] = failure_reason
def get_session_id(self):
return self.engine.get_session_id()
......@@ -445,15 +614,51 @@ class MooncakeKVManager(BaseKVManager):
}
try:
response = requests.put(url, json=payload)
response = requests.put(url, json=payload, timeout=5)
if response.status_code == 200:
logger.debug("Prefill successfully registered to bootstrap server.")
else:
logger.error(
f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
f"Prefill instance failed to connect to bootstrap server: {response.status_code}, {response.text}"
)
except Exception as e:
logger.error(f"Prefill Failed to register to bootstrap server: {e}")
logger.error(
f"Prefill instance failed to register to bootstrap server: {e}"
)
def _handle_node_failure(self, failed_bootstrap_addr):
    """Clean up decode-side state tied to a prefill bootstrap server that the
    heartbeat checker has declared dead.

    Evicts cached connection/parallel-info entries for the address and marks
    every not-yet-successful request routed through it as ``KVPoll.Failed``,
    so only those requests — not other instances — are affected.
    """
    with self.connection_lock:
        # Connection-pool keys start with the bootstrap address (they are
        # presumably composed of addr + dp group + tp rank — see the
        # bootstrap_key construction; confirm format there).
        keys_to_remove = [
            k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)
        ]
        for k in keys_to_remove:
            del self.connection_pool[k]
        if failed_bootstrap_addr in self.prefill_tp_size_table:
            del self.prefill_tp_size_table[failed_bootstrap_addr]
        if failed_bootstrap_addr in self.prefill_dp_size_table:
            del self.prefill_dp_size_table[failed_bootstrap_addr]

        possible_affected_rooms = self.addr_to_rooms_tracker.get(
            failed_bootstrap_addr, []
        )
        del self.addr_to_rooms_tracker[failed_bootstrap_addr]

    # Report the requests associated with the failed bootstrap addr and mark
    # their status as KVPoll.Failed. Requests that already concluded with
    # Success (or were removed from request_status) are left alone.
    affected_rooms = []
    for room in possible_affected_rooms:
        if (
            room in self.request_status
            and self.check_status(room) != KVPoll.Success
        ):
            self.record_failure(
                room,
                f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr})",
            )
            self.update_status(room, KVPoll.Failed)
            affected_rooms.append(room)
    logger.error(
        f"Losing connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), affected {len(affected_rooms)} requests"
    )
class MooncakeKVSender(BaseKVSender):
......@@ -466,7 +671,7 @@ class MooncakeKVSender(BaseKVSender):
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
self.aux_index = None
self.bootstrap_server_url = bootstrap_addr
self.session_id = self.kv_mgr.get_session_id()
self.conclude_state = None
# inner state
self.curr_idx = 0
......@@ -496,11 +701,30 @@ class MooncakeKVSender(BaseKVSender):
)
def poll(self) -> KVPoll:
return self.kv_mgr.check_status(self.bootstrap_room)
if self.conclude_state is None:
status = self.kv_mgr.check_status(self.bootstrap_room)
if status in (KVPoll.Success, KVPoll.Failed):
self.conclude_state = status
return status
else:
return self.conclude_state
def clear(self) -> None:
    """Drop this request's status entry from the manager.

    Uses ``dict.pop`` with a default so that calling ``clear()`` more than
    once (e.g. once by the scheduler after a successful transfer and again
    from ``failure_exception()``) is a harmless no-op instead of raising
    ``KeyError``.
    """
    self.kv_mgr.request_status.pop(self.bootstrap_room, None)
def failure_exception(self):
# TODO: raise a real exception
raise Exception("Fake KVSender Exception")
self.clear()
# Explicitly set the status to failure since this request has failed in another rank
if self.conclude_state is None:
self.conclude_state = KVPoll.Failed
with self.kv_mgr.failure_lock:
failure_reason = self.kv_mgr.failure_records.pop(
self.bootstrap_room, "Failed due to an unknown reason from another rank"
)
raise KVTransferError(self.bootstrap_room, failure_reason)
class MooncakeKVReceiver(BaseKVReceiver):
......@@ -519,17 +743,24 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr
self.session_id = self.kv_mgr.get_session_id()
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
self.conclude_state = None
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
self.prefill_tp_size, self.prefill_dp_size = (
self._get_prefill_dp_size_from_server()
self._get_prefill_parallel_info_from_server()
)
if self.prefill_tp_size is None or self.prefill_dp_size is None:
logger.error(
f"Could not fetch prefill parallel info for bootstrap_addr: {self.bootstrap_addr}"
self.kv_mgr.record_failure(
self.bootstrap_room,
f"Could not fetch prefill parallel info from bootstrap_addr: {self.bootstrap_addr}",
)
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
return
else:
logger.debug(
f"Fetch prefill parallel info from [{self.bootstrap_addr}]: DP size:{self.prefill_dp_size}, TP size:{self.prefill_tp_size}"
)
self.kv_mgr.prefill_tp_size_table[self.bootstrap_addr] = (
self.prefill_tp_size
)
......@@ -587,7 +818,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.target_tp_rank = self.target_tp_ranks[0]
self.required_dst_info_num = 1
self.target_dp_group = bootstrap_room % self.prefill_dp_size
self.target_dp_group = self.bootstrap_room % self.prefill_dp_size
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
bootstrap_key = (
......@@ -607,32 +838,37 @@ class MooncakeKVReceiver(BaseKVReceiver):
target_tp_rank == self.target_tp_rank
or self.target_tp_rank is None
)
logger.debug(
f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}"
)
bootstrap_infos.append(bootstrap_info)
else:
logger.error(
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}"
self.kv_mgr.record_failure(
self.bootstrap_room,
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank} and target_dp_group: {self.target_dp_group}",
)
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Failed)
return
self.bootstrap_infos = bootstrap_infos
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
if len(self.bootstrap_infos) == 0:
logger.error(
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)
else:
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_infos
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
self._register_kv_args()
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
self._register_kv_args()
else:
self.bootstrap_infos = self.kv_mgr.connection_pool[bootstrap_key]
assert len(self.bootstrap_infos) > 0
self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].append(
self.bootstrap_room
)
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.WaitingForInput)
def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
"""Fetch the bootstrap info from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
response = requests.get(url)
response = requests.get(url, timeout=5)
if response.status_code == 200:
bootstrap_info = response.json()
return bootstrap_info
......@@ -645,7 +881,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
logger.error(f"Error fetching prefill info from bootstrap: {e}")
return None
def _get_prefill_dp_size_from_server(self) -> int:
def _get_prefill_parallel_info_from_server(self) -> Tuple[int, int]:
"""Fetch the prefill parallel info from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
......@@ -659,10 +895,10 @@ class MooncakeKVReceiver(BaseKVReceiver):
logger.error(
f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
)
return None
return None, None
except Exception as e:
logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
return None
return None, None
def _register_kv_args(self):
for bootstrap_info in self.bootstrap_infos:
......@@ -704,9 +940,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
self.prefill_server_url = (
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
)
logger.debug(
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)
is_dummy = bootstrap_info["is_dummy"]
sock, lock = self._connect("tcp://" + self.prefill_server_url)
......@@ -724,11 +957,30 @@ class MooncakeKVReceiver(BaseKVReceiver):
)
def poll(self) -> KVPoll:
return self.kv_mgr.check_status(self.bootstrap_room)
if self.conclude_state is None:
status = self.kv_mgr.check_status(self.bootstrap_room)
if status in (KVPoll.Success, KVPoll.Failed):
self.conclude_state = status
return status
else:
return self.conclude_state
def clear(self) -> None:
    """Drop this request's status entry from the manager.

    Uses ``dict.pop`` with a default so that a double ``clear()`` (e.g. from
    the decode transfer queue's success path and again from
    ``failure_exception()``) is a harmless no-op instead of raising
    ``KeyError``.
    """
    self.kv_mgr.request_status.pop(self.bootstrap_room, None)
def failure_exception(self):
# TODO: raise a real exception
raise Exception("Fake KVReceiver Exception")
self.clear()
# Explicitly set the status to failure since this request has failed in another rank
if self.conclude_state is None:
self.conclude_state = KVPoll.Failed
with self.kv_mgr.failure_lock:
failure_reason = self.kv_mgr.failure_records.pop(
self.bootstrap_room, "Failed due to an unknown reason from another rank"
)
raise KVTransferError(self.bootstrap_room, failure_reason)
class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
......@@ -752,6 +1004,10 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
def _setup_routes(self):
self.app.router.add_route("*", "/route", self._handle_route)
self.app.router.add_get("/health", self._handle_health_check)
async def _handle_health_check(self, request):
    """aiohttp handler for GET /health: always reply ``200 OK``.

    Polled by the decode-side heartbeat checker to detect whether this
    bootstrap server (and hence its prefill instance) is still alive.
    """
    return web.Response(text="OK", status=200)
async def _handle_route(self, request: web.Request):
method = request.method
......@@ -780,14 +1036,14 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
self.dp_size = dp_size
tp_size_per_dp_rank = tp_size // dp_size
if self.tp_size_per_dp_rank == None:
if self.tp_size_per_dp_rank is None:
self.tp_size_per_dp_rank = tp_size_per_dp_rank
# Add lock to make sure thread-safe
if role == "Prefill":
dp_group = engine_rank // tp_size_per_dp_rank
tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
# Add lock to make sure thread-safe
async with self.lock:
if dp_group not in self.prefill_port_table:
self.prefill_port_table[dp_group] = {}
......@@ -797,7 +1053,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
"rank_port": rank_port,
}
logger.debug(
f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
f"Register prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
)
return web.Response(text="OK", status=200)
......@@ -833,7 +1089,11 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._runner = web.AppRunner(self.app)
access_log = None
if logging.getLogger(__name__).getEffectiveLevel() <= logging.DEBUG:
access_log = self.app.logger
self._runner = web.AppRunner(self.app, access_log=access_log)
self._loop.run_until_complete(self._runner.setup())
site = web.TCPSite(self._runner, port=self.port)
......
......@@ -30,16 +30,24 @@ class MooncakeTransferEngine:
self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
def register(self, ptr, length):
ret_value = self.engine.register_memory(ptr, length)
try:
ret_value = self.engine.register_memory(ptr, length)
except Exception:
# Mark register as failed
ret_value = -1
if ret_value != 0:
logger.error("Mooncake memory registration failed.")
raise RuntimeError("Mooncake memory registration failed.")
logger.debug("Mooncake memory registration %s failed.", ptr)
def deregister(self, ptr):
ret_value = self.engine.unregister_memory(ptr)
try:
ret_value = self.engine.unregister_memory(ptr)
except Exception:
# Mark deregister as failed
ret_value = -1
if ret_value != 0:
logger.error("Mooncake memory deregistration failed.")
raise RuntimeError("Mooncake memory deregistration failed.")
logger.debug("Mooncake memory deregistration %s failed.", ptr)
def initialize(
self,
......@@ -61,18 +69,26 @@ class MooncakeTransferEngine:
self, session_id: str, buffer: int, peer_buffer_address: int, length: int
) -> int:
"""Synchronously transfer data to the specified address."""
# the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair
# later: based on the cached queue pair to send data
ret = self.engine.transfer_sync_write(
session_id, buffer, peer_buffer_address, length
)
try:
# the first time: based on session_id (which contains remote_ip) to construct a queue pair, and cache the queue pair
# later: based on the cached queue pair to send data
ret = self.engine.transfer_sync_write(
session_id, buffer, peer_buffer_address, length
)
except Exception:
# Mark transfer request as failed
ret = -1
if ret < 0:
logger.error("Mooncake Transfer Engine Return Error.")
raise RuntimeError("Mooncake Transfer Engine Return Error.")
return ret
# Do not raise an exception here, since some transfer requests fail should be accepted and the execution thread should not be stopped.
logger.debug(
"Failed to transfer data from %s to %s - %s.",
buffer,
session_id,
peer_buffer_address,
)
def get_localhost(self):
return self.hostname
return ret
def get_session_id(self):
return self.session_id
......@@ -417,6 +417,8 @@ class SchedulerDisaggregationPrefillMixin:
self.tree_cache.cache_finished_req(req) # unlock the tree
req.finished_reason = FINISH_LENGTH(length=0)
# FIXME: clean up req's data in transfer engine
if hasattr(req.disagg_kv_sender, "clear"):
req.disagg_kv_sender.clear()
done_reqs.append(req)
elif poll == KVPoll.Failed:
error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment