Unverified Commit 50eda839 authored by shangmingc's avatar shangmingc Committed by GitHub
Browse files

[PD] Add kvargs table and thread pool for kvcache sender of mooncake (#5738)


Signed-off-by: default avatarShangming Cai <caishangming@linux.alibaba.com>
parent c55550cb
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import concurrent.futures
import dataclasses import dataclasses
import logging import logging
import os
import queue import queue
import socket import socket
import struct import struct
...@@ -73,9 +75,7 @@ class TransferInfo: ...@@ -73,9 +75,7 @@ class TransferInfo:
endpoint: str endpoint: str
dst_port: int dst_port: int
mooncake_session_id: str mooncake_session_id: str
dst_kv_ptrs: list[int]
dst_kv_indices: npt.NDArray[np.int64] dst_kv_indices: npt.NDArray[np.int64]
dst_aux_ptrs: list[int]
dst_aux_index: int dst_aux_index: int
@classmethod @classmethod
...@@ -85,10 +85,29 @@ class TransferInfo: ...@@ -85,10 +85,29 @@ class TransferInfo:
endpoint=msg[1].decode("ascii"), endpoint=msg[1].decode("ascii"),
dst_port=int(msg[2].decode("ascii")), dst_port=int(msg[2].decode("ascii")),
mooncake_session_id=msg[3].decode("ascii"), mooncake_session_id=msg[3].decode("ascii"),
dst_kv_indices=np.frombuffer(msg[4], dtype=np.int64),
dst_aux_index=int(msg[5].decode("ascii")),
)
@dataclasses.dataclass
class KVArgsRegisterInfo:
    """One-time registration payload sent by a decode rank to the prefill
    KVManager.

    Carries the destination KV/aux buffer base pointers for a mooncake
    session so that subsequent per-request transfer messages only need to
    carry indices, not the (large, constant) pointer tables.
    """

    room: str  # registration messages use the sentinel room "None"
    endpoint: str
    dst_port: int
    mooncake_session_id: str
    dst_kv_ptrs: list[int]  # per-layer destination KV buffer base addresses
    dst_aux_ptrs: list[int]  # destination aux buffer base addresses

    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        """Decode a ZMQ multipart registration message.

        Expected frame layout:
            msg[0] room (ascii), msg[1] endpoint (ascii),
            msg[2] port (ascii int), msg[3] mooncake session id (ascii),
            msg[4] packed KV pointers, msg[5] packed aux pointers,
        where each pointer is packed as a native-endian uint64
        (``struct.pack("Q", ptr)``), i.e. 8 bytes per pointer.
        """
        return cls(
            room=str(msg[0].decode("ascii")),
            endpoint=msg[1].decode("ascii"),
            dst_port=int(msg[2].decode("ascii")),
            mooncake_session_id=msg[3].decode("ascii"),
            dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
            dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
        )
...@@ -123,8 +142,15 @@ class MooncakeKVManager(BaseKVManager): ...@@ -123,8 +142,15 @@ class MooncakeKVManager(BaseKVManager):
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.transfer_queue = queue.Queue() self.transfer_queue = queue.Queue()
self.transfer_infos: Dict[int, TransferInfo] = {} self.transfer_infos: Dict[int, TransferInfo] = {}
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
self.start_prefill_thread() self.start_prefill_thread()
self._register_to_bootstrap() self._register_to_bootstrap()
# Determine the number of threads to use for kv sender
cpu_count = os.cpu_count()
self.executor = concurrent.futures.ThreadPoolExecutor(
max_workers=cpu_count if cpu_count is not None else 64
)
elif self.disaggregation_mode == DisaggregationMode.DECODE: elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.start_decode_thread() self.start_decode_thread()
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {} self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
...@@ -158,28 +184,53 @@ class MooncakeKVManager(BaseKVManager): ...@@ -158,28 +184,53 @@ class MooncakeKVManager(BaseKVManager):
dst_kv_ptrs: list[int], dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int64], dst_kv_indices: npt.NDArray[np.int64],
): ):
# group by indices # Group by indices
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous( prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
prefill_kv_indices, dst_kv_indices prefill_kv_indices, dst_kv_indices
) )
num_layers = len(self.kv_args.kv_data_ptrs) num_layers = len(self.kv_args.kv_data_ptrs)
for layer_id in range(num_layers): layers_params = [
src_ptr = self.kv_args.kv_data_ptrs[layer_id] (
dst_ptr = dst_kv_ptrs[layer_id] self.kv_args.kv_data_ptrs[layer_id],
item_len = self.kv_args.kv_item_lens[layer_id] dst_kv_ptrs[layer_id],
self.kv_args.kv_item_lens[layer_id],
)
for layer_id in range(num_layers)
]
# Worker function for processing a single layer
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks): for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
src_addr = src_ptr + int(prefill_index[0]) * item_len src_addr = src_ptr + int(prefill_index[0]) * item_len
dst_addr = dst_ptr + int(decode_index[0]) * item_len dst_addr = dst_ptr + int(decode_index[0]) * item_len
length = item_len * len(prefill_index) length = item_len * len(prefill_index)
# TODO: make async later
status = self.engine.transfer_sync( status = self.engine.transfer_sync(
mooncake_session_id, src_addr, dst_addr, length mooncake_session_id, src_addr, dst_addr, length
) )
if status != 0: if status != 0:
return status return status
return 0
futures = [
self.executor.submit(
process_layer,
src_ptr,
dst_ptr,
item_len,
)
for (src_ptr, dst_ptr, item_len) in layers_params
]
for future in concurrent.futures.as_completed(futures):
status = future.result()
if status != 0:
# Immediate shutdown on first error (existing tasks will finish)
executor.shutdown(wait=False)
for f in futures:
f.cancel()
return status
return 0 return 0
...@@ -223,6 +274,13 @@ class MooncakeKVManager(BaseKVManager): ...@@ -223,6 +274,13 @@ class MooncakeKVManager(BaseKVManager):
waiting_req_bytes = self.server_socket.recv_multipart() waiting_req_bytes = self.server_socket.recv_multipart()
room = waiting_req_bytes[0].decode("ascii") room = waiting_req_bytes[0].decode("ascii")
if room == "None": if room == "None":
mooncake_session_id = waiting_req_bytes[3].decode("ascii")
self.decode_kv_args_table[mooncake_session_id] = (
KVArgsRegisterInfo.from_zmq(waiting_req_bytes)
)
logger.debug(
f"Register KVArgs from {mooncake_session_id} successfully"
)
continue continue
room = int(room) room = int(room)
self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes) self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
...@@ -244,7 +302,7 @@ class MooncakeKVManager(BaseKVManager): ...@@ -244,7 +302,7 @@ class MooncakeKVManager(BaseKVManager):
ret = self.send_kvcache( ret = self.send_kvcache(
req.mooncake_session_id, req.mooncake_session_id,
kv_chunk.prefill_kv_indices, kv_chunk.prefill_kv_indices,
req.dst_kv_ptrs, self.decode_kv_args_table[req.mooncake_session_id].dst_kv_ptrs,
chunked_dst_kv_indice, chunked_dst_kv_indice,
) )
if ret != 0: if ret != 0:
...@@ -259,7 +317,9 @@ class MooncakeKVManager(BaseKVManager): ...@@ -259,7 +317,9 @@ class MooncakeKVManager(BaseKVManager):
ret = self.send_aux( ret = self.send_aux(
req.mooncake_session_id, req.mooncake_session_id,
kv_chunk.prefill_aux_index, kv_chunk.prefill_aux_index,
req.dst_aux_ptrs, self.decode_kv_args_table[
req.mooncake_session_id
].dst_aux_ptrs,
req.dst_aux_index, req.dst_aux_index,
) )
self.request_status[req.room] = ( self.request_status[req.room] = (
...@@ -460,6 +520,8 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -460,6 +520,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
) )
else: else:
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
# Register kv_args only once to prefill KVManager according to the info fetched from the bootstrap server
self._register_kv_args()
else: else:
self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key] self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
...@@ -502,6 +564,30 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -502,6 +564,30 @@ class MooncakeKVReceiver(BaseKVReceiver):
logger.error(f"Error fetching prefill parallel info from bootstrap: {e}") logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
return None return None
def _register_kv_args(self):
    """Register this decode rank's KV/aux buffer pointers with the prefill
    KVManager, once per (prefill endpoint, session).

    Sends a multipart ZMQ message whose first frame is the sentinel room
    "None"; the prefill side recognizes that sentinel and stores the packed
    pointer tables in its per-session kv-args table instead of treating the
    message as a transfer request. The frame order here is a wire protocol
    and must match the receiver's parsing exactly — do not reorder.
    """
    self.prefill_server_url = (
        f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
    )
    # Pack each pointer as a native-endian uint64 (8 bytes each); the
    # receiver recovers the lists with struct.unpack(f"{len//8}Q", ...).
    packed_kv_data_ptrs = b"".join(
        struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
    )
    packed_aux_data_ptrs = b"".join(
        struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
    )

    # _connect returns a cached socket plus its lock; hold the lock so the
    # multipart frames are not interleaved with another thread's send.
    sock, lock = self._connect("tcp://" + self.prefill_server_url)
    with lock:
        sock.send_multipart(
            [
                "None".encode("ascii"),
                get_local_ip_by_remote().encode("ascii"),
                str(self.kv_mgr.rank_port).encode("ascii"),
                self.session_id.encode("ascii"),
                packed_kv_data_ptrs,
                packed_aux_data_ptrs,
            ]
        )
@classmethod @classmethod
def _connect(cls, endpoint: str): def _connect(cls, endpoint: str):
with cls._global_lock: with cls._global_lock:
...@@ -520,12 +606,6 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -520,12 +606,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}" f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
) )
packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
packed_aux_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
)
sock, lock = self._connect("tcp://" + self.prefill_server_url) sock, lock = self._connect("tcp://" + self.prefill_server_url)
with lock: with lock:
sock.send_multipart( sock.send_multipart(
...@@ -534,9 +614,7 @@ class MooncakeKVReceiver(BaseKVReceiver): ...@@ -534,9 +614,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
get_local_ip_by_remote().encode("ascii"), get_local_ip_by_remote().encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"), str(self.kv_mgr.rank_port).encode("ascii"),
self.session_id.encode("ascii"), self.session_id.encode("ascii"),
packed_kv_data_ptrs,
kv_indices.tobytes(), kv_indices.tobytes(),
packed_aux_data_ptrs,
str(aux_index).encode("ascii"), str(aux_index).encode("ascii"),
] ]
) )
...@@ -610,7 +688,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer): ...@@ -610,7 +688,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
"rank_port": rank_port, "rank_port": rank_port,
} }
logger.debug( logger.debug(
f"Registered Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}" f"Register Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
) )
return web.Response(text="OK", status=200) return web.Response(text="OK", status=200)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment