Unverified Commit 44afde82 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Fix PD disaggregation bugs (#5326)

parent 072df753
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import dataclasses
import logging import logging
import queue
import struct import struct
import threading import threading
from functools import cache from functools import cache
...@@ -52,10 +54,38 @@ def group_concurrent_contiguous( ...@@ -52,10 +54,38 @@ def group_concurrent_contiguous(
return src_groups, dst_groups return src_groups, dst_groups
RequestPoolType = Dict[int, Tuple[npt.NDArray[np.int64], Optional[int]]] @dataclasses.dataclass
WaitingPoolType = Dict[ class TransferKVChunk:
int, Tuple[str, list[int], npt.NDArray[np.int64], list[int], int] room: int
] prefill_kv_indices: npt.NDArray[np.int64]
index_slice: slice
is_last: bool
prefill_aux_index: Optional[int]
@dataclasses.dataclass
class TransferInfo:
    """Destination-side parameters for one room's KV transfer.

    Instances are built from a seven-frame ZMQ multipart message via
    :meth:`from_zmq`; see that method for the frame layout.
    """

    # Bootstrap room id parsed from frame 2.
    room: int
    # Endpoint string parsed from frame 0.
    endpoint: str
    # Mooncake session id parsed from frame 1.
    mooncake_session_id: str
    # Destination KV pointers unpacked from frame 3.
    dst_kv_ptrs: list[int]
    # Destination KV indices viewed as int64 from frame 4.
    dst_kv_indices: npt.NDArray[np.int64]
    # Destination aux pointers unpacked from frame 5.
    dst_aux_ptrs: list[int]
    # Destination aux index parsed from frame 6.
    dst_aux_index: int

    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        """Decode a multipart message into a ``TransferInfo``.

        Frame layout, by position in ``msg``:

        0. endpoint, ASCII text
        1. mooncake session id, ASCII text
        2. room, ASCII decimal integer
        3. destination KV pointers, packed ``Q`` values (8 bytes each)
        4. destination KV indices, raw ``int64`` buffer
        5. destination aux pointers, packed ``Q`` values (8 bytes each)
        6. destination aux index, ASCII decimal integer

        Raises whatever the underlying decoders raise on malformed input
        (e.g. ``IndexError`` for a short message, ``ValueError`` for a
        non-numeric room or aux index, ``struct.error`` for a pointer frame
        whose length is not a multiple of 8).
        """
        # Frames are decoded in positional order so that a malformed message
        # fails on the first bad frame.
        endpoint = msg[0].decode("ascii")
        session_id = msg[1].decode("ascii")
        room = int(msg[2].decode("ascii"))

        kv_ptr_count = len(msg[3]) // 8
        kv_ptrs = list(struct.unpack(f"{kv_ptr_count}Q", msg[3]))

        # np.frombuffer wraps the frame's bytes without copying them.
        kv_indices = np.frombuffer(msg[4], dtype=np.int64)

        aux_ptr_count = len(msg[5]) // 8
        aux_ptrs = list(struct.unpack(f"{aux_ptr_count}Q", msg[5]))

        aux_index = int(msg[6].decode("ascii"))

        return cls(
            room=room,
            endpoint=endpoint,
            mooncake_session_id=session_id,
            dst_kv_ptrs=kv_ptrs,
            dst_kv_indices=kv_indices,
            dst_aux_ptrs=aux_ptrs,
            dst_aux_index=aux_index,
        )
KVSENDER_POLLING_PORT = 17788 KVSENDER_POLLING_PORT = 17788
KVRECEIVER_POLLING_PORT = 27788 KVRECEIVER_POLLING_PORT = 27788
...@@ -65,13 +95,12 @@ class MooncakeKVManager(BaseKVManager): ...@@ -65,13 +95,12 @@ class MooncakeKVManager(BaseKVManager):
self.engine = MooncakeTransferEngine() self.engine = MooncakeTransferEngine()
self.kv_args = args self.kv_args = args
self.disaggregation_mode = disaggregation_mode self.disaggregation_mode = disaggregation_mode
self.request_pool: RequestPoolType = {}
self.request_status: Dict[int, KVPoll] = {} self.request_status: Dict[int, KVPoll] = {}
self.server_socket = zmq.Context().socket(zmq.PULL) self.server_socket = zmq.Context().socket(zmq.PULL)
self.register_buffer_to_engine() self.register_buffer_to_engine()
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.waiting_pool: WaitingPoolType = {} self.transfer_queue = queue.Queue()
self.transfer_event = threading.Event() self.transfer_infos: Dict[int, TransferInfo] = {}
self.start_prefill_thread() self.start_prefill_thread()
elif self.disaggregation_mode == DisaggregationMode.DECODE: elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.start_decode_thread() self.start_decode_thread()
...@@ -101,7 +130,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -101,7 +130,7 @@ class MooncakeKVManager(BaseKVManager):
self, self,
mooncake_session_id: str, mooncake_session_id: str,
prefill_kv_indices: npt.NDArray[np.int64], prefill_kv_indices: npt.NDArray[np.int64],
dst_ptrs: list[int], dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int64], dst_kv_indices: npt.NDArray[np.int64],
): ):
layer_num = int(len(self.kv_args.kv_data_ptrs) / 2) layer_num = int(len(self.kv_args.kv_data_ptrs) / 2)
...@@ -114,8 +143,8 @@ class MooncakeKVManager(BaseKVManager): ...@@ -114,8 +143,8 @@ class MooncakeKVManager(BaseKVManager):
prefill_value_layer_ptr = self.kv_args.kv_data_ptrs[layer_num + layer_id] prefill_value_layer_ptr = self.kv_args.kv_data_ptrs[layer_num + layer_id]
value_item_len = self.kv_args.kv_item_lens[layer_num + layer_id] value_item_len = self.kv_args.kv_item_lens[layer_num + layer_id]
decode_key_layer_ptr = dst_ptrs[layer_id] decode_key_layer_ptr = dst_kv_ptrs[layer_id]
decode_value_layer_ptr = dst_ptrs[layer_num + layer_id] decode_value_layer_ptr = dst_kv_ptrs[layer_num + layer_id]
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks): for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
prefill_key_addr = ( prefill_key_addr = (
...@@ -192,87 +221,60 @@ class MooncakeKVManager(BaseKVManager): ...@@ -192,87 +221,60 @@ class MooncakeKVManager(BaseKVManager):
sender_rank_port = KVSENDER_POLLING_PORT + self.kv_args.engine_rank sender_rank_port = KVSENDER_POLLING_PORT + self.kv_args.engine_rank
self.server_socket.bind("tcp://*:" + str(sender_rank_port)) self.server_socket.bind("tcp://*:" + str(sender_rank_port))
def prefill_thread(): def bootstrap_thread():
"""This thread recvs pre-alloc notification from the decode engine"""
# KVPoll.Bootstrapping -> KVPoll.WaitingForInput
while True: while True:
( waiting_req_bytes = self.server_socket.recv_multipart()
endpoint, room = waiting_req_bytes[2].decode("ascii")
mooncake_session_id, if room == "None":
bootstrap_room,
dst_ptrs,
dst_kv_indices,
dst_aux_ptrs,
dst_aux_index,
) = self.server_socket.recv_multipart()
if bootstrap_room.decode("ascii") == "None":
continue continue
endpoint = endpoint.decode("ascii") room = int(room)
mooncake_session_id = mooncake_session_id.decode("ascii") self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
bootstrap_room = int(bootstrap_room.decode("ascii"))
dst_ptrs = list(struct.unpack(f"{len(dst_ptrs)//8}Q", dst_ptrs))
dst_kv_indices = np.frombuffer(dst_kv_indices, dtype=np.int64)
dst_aux_ptrs = list(
struct.unpack(f"{len(dst_aux_ptrs)//8}Q", dst_aux_ptrs)
)
dst_aux_index = int(dst_aux_index.decode("ascii"))
self.waiting_pool[bootstrap_room] = (
endpoint,
mooncake_session_id,
dst_ptrs,
dst_kv_indices,
dst_aux_ptrs,
dst_aux_index,
)
self.transfer_event.set()
threading.Thread(target=prefill_thread).start() # NOTE: after bootstrapping we can mark the req as waiting for input
self.request_status[room] = KVPoll.WaitingForInput
def transfer_thread(): def transfer_thread():
# TODO: Shall we use KVPoll.Transferring state?
while True: while True:
self.transfer_event.wait() try:
self.transfer_event.clear() kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
bootstrap_room_ready = self.request_pool.keys() req = self.transfer_infos[kv_chunk.room]
bootstrap_room_request = self.waiting_pool.keys() chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
for room in list(bootstrap_room_request): assert len(chunked_dst_kv_indice) == len(
if room not in list(bootstrap_room_ready): kv_chunk.prefill_kv_indices
continue )
status = KVPoll.Transferring
self.request_status[room] = status
(
endpoint,
mooncake_session_id,
dst_ptrs,
dst_kv_indices,
dst_aux_ptrs,
dst_aux_index,
) = self.waiting_pool.pop(room)
self.sync_status_to_decode_endpoint(endpoint, room)
(
prefill_kv_indices,
prefill_aux_index,
) = self.request_pool.pop(room)
ret = self.send_kvcache( ret = self.send_kvcache(
mooncake_session_id, req.mooncake_session_id,
prefill_kv_indices, kv_chunk.prefill_kv_indices,
dst_ptrs, req.dst_kv_ptrs,
dst_kv_indices, chunked_dst_kv_indice,
) )
if ret != 0: if ret != 0:
status = KVPoll.Failed self.request_status[kv_chunk.room] = KVPoll.Failed
self.sync_status_to_decode_endpoint(endpoint, room) self.sync_status_to_decode_endpoint(req.endpoint, req.room)
continue continue
ret = self.send_aux(
mooncake_session_id,
prefill_aux_index,
dst_aux_ptrs,
dst_aux_index,
)
if ret != 0:
status = KVPoll.Failed
else:
status = KVPoll.Success
self.request_status[room] = status
self.sync_status_to_decode_endpoint(endpoint, room)
if kv_chunk.is_last:
# Only the last chunk we need to send the aux data
ret = self.send_aux(
req.mooncake_session_id,
kv_chunk.prefill_aux_index,
req.dst_aux_ptrs,
req.dst_aux_index,
)
self.request_status[req.room] = (
KVPoll.Success if ret == 0 else KVPoll.Failed
)
self.sync_status_to_decode_endpoint(req.endpoint, req.room)
self.transfer_infos.pop(req.room)
except queue.Empty:
continue
threading.Thread(target=bootstrap_thread).start()
threading.Thread(target=transfer_thread).start() threading.Thread(target=transfer_thread).start()
def start_decode_thread(self): def start_decode_thread(self):
...@@ -288,29 +290,41 @@ class MooncakeKVManager(BaseKVManager): ...@@ -288,29 +290,41 @@ class MooncakeKVManager(BaseKVManager):
threading.Thread(target=decode_thread).start() threading.Thread(target=decode_thread).start()
def enqueue_request( def add_transfer_request(
self, self,
bootstrap_room: int, bootstrap_room: int,
kv_indices: npt.NDArray[np.int64], kv_indices: npt.NDArray[np.int64],
aux_index: Optional[int], index_slice: slice,
is_last: bool,
aux_index: Optional[int] = None,
): ):
self.request_pool[bootstrap_room] = (kv_indices, aux_index) assert self.disaggregation_mode == DisaggregationMode.PREFILL
assert not is_last or (is_last and aux_index is not None)
self.transfer_queue.put(
TransferKVChunk(
room=bootstrap_room,
prefill_kv_indices=kv_indices,
index_slice=index_slice,
is_last=is_last,
prefill_aux_index=aux_index,
)
)
self.request_status[bootstrap_room] = KVPoll.WaitingForInput self.request_status[bootstrap_room] = KVPoll.WaitingForInput
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.transfer_event.set()
def check_status(self, bootstrap_room: int): def check_status(self, bootstrap_room: int):
if ( # TODO: do we really need the poll()?
self.disaggregation_mode == DisaggregationMode.DECODE
and self.request_status[bootstrap_room] == KVPoll.Success
):
if bootstrap_room in self.request_pool:
self.request_pool.pop(bootstrap_room)
return self.request_status[bootstrap_room] return self.request_status[bootstrap_room]
def set_status(self, bootstrap_room: int, status: KVPoll): def update_status(self, bootstrap_room: int, status: KVPoll):
self.request_status[bootstrap_room] = status if bootstrap_room not in self.request_status:
self.request_status[bootstrap_room] = status
else:
# NOTE: The prefill engine could recv bootstrapping first
self.request_status[bootstrap_room] = max(
self.request_status[bootstrap_room], status
)
def get_localhost(self): def get_localhost(self):
return self.engine.get_localhost() return self.engine.get_localhost()
...@@ -326,15 +340,31 @@ class MooncakeKVSender(BaseKVSender): ...@@ -326,15 +340,31 @@ class MooncakeKVSender(BaseKVSender):
): ):
self.kv_mgr = mgr self.kv_mgr = mgr
self.bootstrap_room = bootstrap_room self.bootstrap_room = bootstrap_room
self.kv_mgr.set_status(bootstrap_room, KVPoll.WaitingForInput) self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
self.aux_index = None self.aux_index = None
def init(self, num_kv_indices: int, aux_index: Optional[int] = None): def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.aux_index = aux_index
self.num_kv_indices = num_kv_indices self.num_kv_indices = num_kv_indices
self.aux_index = aux_index
def send(self, kv_indices: npt.NDArray[np.int64]): def send(
self.kv_mgr.enqueue_request(self.bootstrap_room, kv_indices, self.aux_index) self,
kv_indices: npt.NDArray[np.int64],
index_slice: slice,
is_last: bool,
):
if not is_last:
self.kv_mgr.add_transfer_request(
self.bootstrap_room, kv_indices, index_slice, False
)
else:
self.kv_mgr.add_transfer_request(
self.bootstrap_room,
kv_indices,
index_slice,
True,
aux_index=self.aux_index,
)
def poll(self) -> KVPoll: def poll(self) -> KVPoll:
return self.kv_mgr.check_status(self.bootstrap_room) return self.kv_mgr.check_status(self.bootstrap_room)
...@@ -361,7 +391,7 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -361,7 +391,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
) )
self.decode_ip = self.kv_mgr.get_localhost() self.decode_ip = self.kv_mgr.get_localhost()
self.session_id = self.kv_mgr.get_session_id() self.session_id = self.kv_mgr.get_session_id()
self.kv_mgr.set_status(bootstrap_room, KVPoll.WaitingForInput) self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
@cache @cache
def _connect(self, endpoint: str): def _connect(self, endpoint: str):
...@@ -370,7 +400,6 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -370,7 +400,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
return socket return socket
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None): def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
self.kv_mgr.enqueue_request(self.bootstrap_room, kv_indices, aux_index)
packed_kv_data_ptrs = b"".join( packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
) )
......
...@@ -81,7 +81,7 @@ class PrefillBootstrapQueue: ...@@ -81,7 +81,7 @@ class PrefillBootstrapQueue:
self.gloo_group = gloo_group self.gloo_group = gloo_group
self.bootstrap_port = bootstrap_port self.bootstrap_port = bootstrap_port
def allocate_token_id(self, idx: int, token_id: int): def store_prefill_results(self, idx: int, token_id: int):
assert token_id >= 0, f"token_id: {token_id} is negative" assert token_id >= 0, f"token_id: {token_id} is negative"
output_id_buffer = self.metadata_buffers[0] output_id_buffer = self.metadata_buffers[0]
output_id_buffer[idx] = token_id output_id_buffer[idx] = token_id
...@@ -146,7 +146,7 @@ class PrefillBootstrapQueue: ...@@ -146,7 +146,7 @@ class PrefillBootstrapQueue:
elif poll == KVPoll.Failed: elif poll == KVPoll.Failed:
raise Exception("Bootstrap failed") raise Exception("Bootstrap failed")
# KV.WaitingForInput - init here # KV.WaitingForInput
num_kv_indices = len(req.origin_input_ids) num_kv_indices = len(req.origin_input_ids)
if self.req_to_metadata_buffer_idx_allocator.available_size() == 0: if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
break break
...@@ -222,6 +222,7 @@ class SchedulerDisaggregationPrefillMixin: ...@@ -222,6 +222,7 @@ class SchedulerDisaggregationPrefillMixin:
elif poll == KVPoll.Success: # transfer done elif poll == KVPoll.Success: # transfer done
self.tree_cache.cache_finished_req(req) # unlock the tree self.tree_cache.cache_finished_req(req) # unlock the tree
req.finished_reason = FINISH_LENGTH(length=0) req.finished_reason = FINISH_LENGTH(length=0)
# FIXME: clean up req's data in transfer engine
done_reqs.append(req) done_reqs.append(req)
elif poll == KVPoll.Failed: elif poll == KVPoll.Failed:
raise Exception("Transferring failed") raise Exception("Transferring failed")
...@@ -256,14 +257,18 @@ class SchedulerDisaggregationPrefillMixin: ...@@ -256,14 +257,18 @@ class SchedulerDisaggregationPrefillMixin:
""" """
start_idx = req.start_send_idx start_idx = req.start_send_idx
end_idx = min(len(req.fill_ids), len(req.origin_input_ids)) end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
# Update next start_send_idx
req.start_send_idx = end_idx
kv_indices = ( kv_indices = (
self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx] self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx]
.cpu() .cpu()
.numpy() .numpy()
) )
req.start_send_idx = end_idx
if token_id is not None: if token_id is not None:
self.disagg_prefill_pending_queue.allocate_token_id( self.disagg_prefill_pending_queue.store_prefill_results(
req.metadata_buffer_index, token_id req.metadata_buffer_index, token_id
) )
req.disagg_kv_sender.send(kv_indices) is_last = token_id is not None
req.disagg_kv_sender.send(kv_indices, slice(start_idx, end_idx), is_last)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment