Unverified Commit e21aa1df authored by Hongbo Xu, committed by GitHub

[PD] Add different TP sizes support for no-MLA models (#6793)


Co-authored-by: shangmingc <csmthu@gmail.com>
Co-authored-by: Shangming Cai <caishangming@linux.alibaba.com>
parent f3cbd245
@@ -27,6 +27,8 @@ class KVArgs:
     decode_tp_size: int
     # for pp prefill
     prefill_pp_size: int
+    kv_head_num: int
+    page_size: int


 class KVPoll:
...
@@ -103,6 +103,9 @@ class KVArgsRegisterInfo:
     mooncake_session_id: str
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
+    dst_tp_rank: int
+    dst_tp_size: int
+    dst_kv_item_len: int

     @classmethod
     def from_zmq(cls, msg: List[bytes]):
@@ -113,6 +116,9 @@ class KVArgsRegisterInfo:
             mooncake_session_id=msg[3].decode("ascii"),
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
+            dst_tp_rank=int(msg[6].decode("ascii")),
+            dst_tp_size=int(msg[7].decode("ascii")),
+            dst_kv_item_len=int(msg[8].decode("ascii")),
         )
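The decode-to-prefill registration message grows from six to nine ZMQ frames; `KVArgsRegisterInfo.from_zmq` above reads the three new TP metadata frames at indices 6-8. A minimal sketch of the framing, assuming illustrative values and treating the first three frames as opaque pre-existing fields (not specified in this hunk):

```python
import struct

# Illustrative values only (not taken from the diff).
room, endpoint, dst_port = 42, "10.0.0.2", 17000
session_id = "decode-session-0"
kv_ptrs = [0x7F0000000000, 0x7F0000100000]
aux_ptrs = [0x7F0000200000]
tp_rank, tp_size, kv_item_len = 1, 4, 1048576

msg = [
    str(room).encode("ascii"),        # msg[0..2]: pre-existing fields; their exact
    endpoint.encode("ascii"),         # order is defined elsewhere in the file
    str(dst_port).encode("ascii"),
    session_id.encode("ascii"),       # msg[3]: mooncake_session_id
    b"".join(struct.pack("Q", p) for p in kv_ptrs),   # msg[4]: dst_kv_ptrs
    b"".join(struct.pack("Q", p) for p in aux_ptrs),  # msg[5]: dst_aux_ptrs
    str(tp_rank).encode("ascii"),     # msg[6]: dst_tp_rank (new)
    str(tp_size).encode("ascii"),     # msg[7]: dst_tp_size (new)
    str(kv_item_len).encode("ascii"), # msg[8]: dst_kv_item_len (new)
]

# Parsing mirrors KVArgsRegisterInfo.from_zmq above.
assert int(msg[6].decode("ascii")) == tp_rank
assert int(msg[7].decode("ascii")) == tp_size
assert int(msg[8].decode("ascii")) == kv_item_len
assert list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])) == kv_ptrs
```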
@@ -189,6 +195,8 @@ class MooncakeKVManager(BaseKVManager):
             self.session_pool_lock = threading.Lock()
             self.addr_to_rooms_tracker = defaultdict(set)
             self.connection_lock = threading.Lock()
+            self.required_prefill_info_num_map: Dict[int, int] = {}
+            self.decode_kv_arrive_state: Dict[int, Set[int]] = defaultdict(set)
             # Heartbeat interval should be at least 2 seconds
             self.heartbeat_interval = max(
                 float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
@@ -284,6 +292,163 @@ class MooncakeKVManager(BaseKVManager):
         return 0

+    def send_kvcache_slice(
+        self,
+        mooncake_session_id: str,
+        prefill_kv_indices: npt.NDArray[np.int64],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int64],
+        dst_tp_rank: int,
+        dst_tp_size: int,
+        dst_kv_item_len: int,
+        executor: concurrent.futures.ThreadPoolExecutor,
+    ):
+        """
+        Sends KV cache slices from this Prefill rank to a target Decode rank,
+        supporting generic M-to-N TP size configurations.
+
+        NOTE: This implementation calls the transfer engine for each token slot within
+        each page to ensure correctness for any page_size and head-slicing configuration.
+        This may introduce performance overhead (increased TTFT) for long sequences.
+        """
+        # rank/kv_head config
+        local_tp_rank = self.kv_args.engine_rank
+        local_tp_size = self.tp_size // self.dp_size
+        num_kv_heads = self.kv_args.kv_head_num
+        num_layers = len(self.kv_args.kv_data_ptrs)
+        page_size = self.kv_args.page_size
+
+        heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size
+        heads_per_prefill_rank = num_kv_heads
+        decode_global_head_start = dst_tp_rank * heads_per_decode_rank
+        prefill_global_head_start = local_tp_rank * heads_per_prefill_rank
+        bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size
+
+        # decode config
+        decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)]
+
+        if local_tp_size > dst_tp_size:
+            src_head_offset = 0
+            num_heads_to_send = heads_per_prefill_rank
+            dst_head_offset = prefill_global_head_start - decode_global_head_start
+        else:
+            src_head_offset = decode_global_head_start - prefill_global_head_start
+            num_heads_to_send = heads_per_decode_rank
+            dst_head_offset = 0
+
+        layer_transfer_params = []
+        for layer_id in range(num_layers):
+            item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id]
+            # Page stride on the target Decode rank for its slice pages
+            item_len_of_decode_rank_page = decode_rank_item_lens[layer_id]
+
+            if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0:
+                logger.error(
+                    f"Invalid item_len_of_prefill_rank_page or num_kv_heads for layer {layer_id}"
+                )
+                return -1
+
+            # Calculate the precise byte offset and length of the sub-slice within the Prefill page data
+            src_slice_offset = src_head_offset * bytes_per_head
+            dst_slice_offset = dst_head_offset * bytes_per_head
+            slice_lens_per_page = num_heads_to_send * bytes_per_head
+
+            # Sanity check: the data sub-slice we intend to send must fit into the Decode rank's page,
+            # i.e. slice_lens_per_page <= item_len_of_decode_rank_page
+            if slice_lens_per_page > item_len_of_decode_rank_page:
+                logger.error(
+                    f"[{mooncake_session_id}] Layer {layer_id}: "
+                    f"slice size ({slice_lens_per_page}) exceeds "
+                    f"target page size ({item_len_of_decode_rank_page})"
+                )
+                return -1
+
+            layer_transfer_params.append(
+                (
+                    self.kv_args.kv_data_ptrs[layer_id],  # Prefill base ptr (all heads)
+                    dst_kv_ptrs[
+                        layer_id
+                    ],  # Decode base ptr (for its slice for this layer)
+                    item_len_of_prefill_rank_page,  # Prefill page size (all heads)
+                    item_len_of_decode_rank_page,  # Decode page stride (for its slice page)
+                    src_slice_offset,  # Offset to slice data in Prefill page
+                    dst_slice_offset,  # Offset to slice data in Decode page
+                    slice_lens_per_page,  # Length of slice data per page (actual data to send)
+                )
+            )
+
+        def process_layer_tp_aware(layer_params):
+            (
+                src_ptr,
+                dst_ptr,
+                src_item_len,
+                dst_item_len,
+                src_offset,
+                dst_offset,
+                slice_lens_per_page,
+            ) = layer_params
+
+            src_addr_list = []
+            dst_addr_list = []
+            length_list = []
+
+            # Calculate strides for a single token slot
+            bytes_per_token_on_prefill = src_item_len // page_size
+            bytes_per_token_on_decode = dst_item_len // page_size
+
+            for i in range(len(prefill_kv_indices)):
+                prefill_page_idx = int(prefill_kv_indices[i])
+                decode_page_idx = int(dst_kv_indices[i])
+
+                # Get the starting memory address for the current source and destination pages
+                src_page_start_addr = src_ptr + prefill_page_idx * src_item_len
+                dst_page_start_addr = dst_ptr + decode_page_idx * dst_item_len
+
+                # Iterate through each valid token slot within the current page
+                for token_slot_in_page in range(page_size):
+                    # Calculate the start address of the current token slot
+                    src_token_slot_start_addr = (
+                        src_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_prefill
+                    )
+                    dst_token_slot_start_addr = (
+                        dst_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_decode
+                    )
+
+                    # Calculate final source and destination addresses by applying head-slice offsets
+                    src_slice_addr = src_token_slot_start_addr + src_offset
+                    dst_slice_addr = dst_token_slot_start_addr + dst_offset
+
+                    src_addr_list.append(src_slice_addr)
+                    dst_addr_list.append(dst_slice_addr)
+                    length_list.append(slice_lens_per_page)
+                    logger.debug(
+                        f"SYNC: sid={mooncake_session_id}, "
+                        f"src={src_slice_addr}, dst={dst_slice_addr}, len={slice_lens_per_page}"
+                    )
+            return self.engine.batch_transfer_sync(
+                mooncake_session_id, src_addr_list, dst_addr_list, length_list
+            )
+
+        futures = [
+            executor.submit(
+                process_layer_tp_aware,
+                layer_params,
+            )
+            for layer_params in layer_transfer_params
+        ]
+
+        for future in concurrent.futures.as_completed(futures):
+            status = future.result()
+            if status != 0:
+                for f in futures:
+                    f.cancel()
+                return status
+        return 0
+
     def send_aux(
         self,
         mooncake_session_id: str,
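To make the head-slicing arithmetic in `send_kvcache_slice` concrete, here is a standalone sketch of the offset computation. All numbers are illustrative, and it assumes each per-layer KV page is laid out head-major per token slot (page_size x heads x head_dim x dtype bytes); neither the numbers nor the helper name comes from the commit:

```python
# Standalone sketch of the head-slicing offset math; illustrative values only.

def slice_params(local_tp_rank, local_tp_size, dst_tp_rank, dst_tp_size,
                 num_kv_heads_per_prefill_rank, dst_kv_item_len, page_size):
    heads_per_decode_rank = num_kv_heads_per_prefill_rank * local_tp_size // dst_tp_size
    decode_global_head_start = dst_tp_rank * heads_per_decode_rank
    prefill_global_head_start = local_tp_rank * num_kv_heads_per_prefill_rank
    bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size

    if local_tp_size > dst_tp_size:
        # Prefill shard is narrower than the decode shard: copy all local heads
        # into the right position inside the (wider) decode page.
        src_head_offset = 0
        num_heads_to_send = num_kv_heads_per_prefill_rank
        dst_head_offset = prefill_global_head_start - decode_global_head_start
    else:
        # Prefill shard is wider: send only the heads owned by this decode rank.
        src_head_offset = decode_global_head_start - prefill_global_head_start
        num_heads_to_send = heads_per_decode_rank
        dst_head_offset = 0

    return (src_head_offset * bytes_per_head,
            dst_head_offset * bytes_per_head,
            num_heads_to_send * bytes_per_head)

# Example: 4 KV heads per prefill rank, prefill TP 2 -> decode TP 4,
# head_dim 128, FP16 (2 bytes), page_size 32.
page_size, head_dim, dtype_size = 32, 128, 2
heads_per_decode_rank = 4 * 2 // 4                           # = 2
dst_kv_item_len = page_size * heads_per_decode_rank * head_dim * dtype_size

# Prefill rank 0 (global heads 0-3) sending to decode rank 1 (global heads 2-3):
src_off, dst_off, length = slice_params(0, 2, 1, 4, 4, dst_kv_item_len, page_size)
assert (src_off, dst_off, length) == (2 * 256, 0, 2 * 256)   # 256 bytes per head
```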
@@ -308,7 +473,7 @@ class MooncakeKVManager(BaseKVManager):
         )

     def sync_status_to_decode_endpoint(
-        self, remote: str, dst_port: int, room: int, status: int
+        self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
     ):
         if ":" in remote:
             remote = remote.split(":")[0]
@@ -316,6 +481,7 @@ class MooncakeKVManager(BaseKVManager):
            [
                str(room).encode("ascii"),
                str(status).encode("ascii"),
+               str(prefill_rank).encode("ascii"),
            ]
        )
@@ -332,6 +498,7 @@ class MooncakeKVManager(BaseKVManager):
                 )
                 polls = []
                 dst_ranks_infos = []
+                local_rank = self.kv_args.engine_rank
                 for req in reqs_to_be_processed:
                     if not req.is_dummy:
                         # Early exit if the request has failed
@@ -347,6 +514,7 @@ class MooncakeKVManager(BaseKVManager):
                                     req.dst_port,
                                     req.room,
                                     KVPoll.Failed,
+                                    local_rank,
                                 )
                                 break
@@ -364,15 +532,31 @@ class MooncakeKVManager(BaseKVManager):
                            f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
                        )

-                        ret = self.send_kvcache(
-                            req.mooncake_session_id,
-                            kv_chunk.prefill_kv_indices,
-                            self.decode_kv_args_table[
-                                req.mooncake_session_id
-                            ].dst_kv_ptrs,
-                            chunked_dst_kv_indice,
-                            executor,
-                        )
+                        target_rank_registration_info: KVArgsRegisterInfo = (
+                            self.decode_kv_args_table[req.mooncake_session_id]
+                        )
+                        local_tp_size = self.tp_size // self.dp_size
+                        if self.is_mla_backend or (
+                            local_tp_size == target_rank_registration_info.dst_tp_size
+                        ):
+                            ret = self.send_kvcache(
+                                req.mooncake_session_id,
+                                kv_chunk.prefill_kv_indices,
+                                target_rank_registration_info.dst_kv_ptrs,
+                                chunked_dst_kv_indice,
+                                executor,
+                            )
+                        else:
+                            ret = self.send_kvcache_slice(
+                                req.mooncake_session_id,
+                                kv_chunk.prefill_kv_indices,
+                                target_rank_registration_info.dst_kv_ptrs,
+                                chunked_dst_kv_indice,
+                                target_rank_registration_info.dst_tp_rank,
+                                target_rank_registration_info.dst_tp_size,
+                                target_rank_registration_info.dst_kv_item_len,
+                                executor,
+                            )
                         if ret != 0:
                             with self.session_lock:
                                 self.session_failures[req.mooncake_session_id] += 1
@@ -388,7 +572,11 @@ class MooncakeKVManager(BaseKVManager):
                             )
                             self.update_status(kv_chunk.room, KVPoll.Failed)
                             self.sync_status_to_decode_endpoint(
-                                req.endpoint, req.dst_port, req.room, KVPoll.Failed
+                                req.endpoint,
+                                req.dst_port,
+                                req.room,
+                                KVPoll.Failed,
+                                local_rank,
                             )
                             break
@@ -413,7 +601,7 @@ class MooncakeKVManager(BaseKVManager):
                            self.update_status(req.room, status)
                            for endpoint, dst_port, room in dst_ranks_infos:
                                self.sync_status_to_decode_endpoint(
-                                    endpoint, dst_port, room, status
+                                    endpoint, dst_port, room, status, local_rank
                                )
                    else:
                        # Dummy request means the decode instance is not used, so its status can be marked as success directly
@@ -479,10 +667,28 @@ class MooncakeKVManager(BaseKVManager):
         def decode_thread():
             while True:
-                (bootstrap_room, status) = self.server_socket.recv_multipart()
+                (bootstrap_room, status, prefill_rank) = (
+                    self.server_socket.recv_multipart()
+                )
                 status = int(status.decode("ascii"))
                 bootstrap_room = int(bootstrap_room.decode("ascii"))
-                if status == KVPoll.Failed:
+                prefill_rank = int(prefill_rank.decode("ascii"))
+
+                if status == KVPoll.Success:
+                    # record arrived prefill_rank
+                    self.decode_kv_arrive_state[bootstrap_room].add(prefill_rank)
+                    expected_prefill_num = self.required_prefill_info_num_map[
+                        bootstrap_room
+                    ]
+                    arrived_prefill_num = len(
+                        self.decode_kv_arrive_state[bootstrap_room]
+                    )
+                    if (
+                        self.is_mla_backend
+                        or arrived_prefill_num == expected_prefill_num
+                    ):
+                        self.update_status(bootstrap_room, KVPoll.Success)
+                elif status == KVPoll.Failed:
                     self.record_failure(
                         bootstrap_room,
                         f"Failed to get kvcache from prefill instance, it might be dead",
@@ -713,7 +919,10 @@ class MooncakeKVSender(BaseKVSender):
         if not is_last:
             self.kv_mgr.add_transfer_request(
-                self.bootstrap_room, kv_indices, index_slice, False
+                self.bootstrap_room,
+                kv_indices,
+                index_slice,
+                False,
             )
         else:
             self.kv_mgr.add_transfer_request(
@@ -822,23 +1031,26 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
             )
             self.required_dst_info_num = 1
+            self.required_prefill_info_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
         elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
-            assert (
-                self.kv_mgr.is_mla_backend
-            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
+            if not self.kv_mgr.is_mla_backend:
+                logger.warning_once(
+                    "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
+                )
             self.target_tp_rank = (
                 self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
             ) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
             self.required_dst_info_num = (
                 local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
             )
+            self.required_prefill_info_num = 1
             self.target_tp_ranks = [self.target_tp_rank]
         else:
-            assert (
-                self.kv_mgr.is_mla_backend
-            ), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
+            if not self.kv_mgr.is_mla_backend:
+                logger.warning_once(
+                    "Performance is NOT guaranteed when using different TP sizes for non-MLA models. "
+                )
             # For non-MLA models, one decode rank needs to retrieve KVCache from multiple prefill ranks;
             self.target_tp_ranks = [
                 rank
@@ -855,6 +1067,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
             # or the KVPoll will never be set correctly
             self.target_tp_rank = self.target_tp_ranks[0]
             self.required_dst_info_num = 1
+            self.required_prefill_info_num = (
+                prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank
+            )

         if self.data_parallel_rank is not None:
             logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
@@ -862,6 +1077,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
         else:
             self.target_dp_group = bootstrap_room % self.prefill_dp_size

+        self.kv_mgr.required_prefill_info_num_map[self.bootstrap_room] = (
+            self.required_prefill_info_num
+        )
         # NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
         bootstrap_key = (
             f"{self.bootstrap_addr}_{self.target_dp_group}_{self.target_tp_rank}"
@@ -875,11 +1093,15 @@ class MooncakeKVReceiver(BaseKVReceiver):
                 self.target_dp_group,
             )
             if bootstrap_info is not None:
-                # NOTE: only support MLA for now: select one prefill rank as real rank
-                bootstrap_info["is_dummy"] = not bool(
-                    target_tp_rank == self.target_tp_rank
-                    or self.target_tp_rank is None
-                )
+                if self.kv_mgr.is_mla_backend:
+                    # MLA: select one prefill rank as the real rank
+                    bootstrap_info["is_dummy"] = not bool(
+                        target_tp_rank == self.target_tp_rank
+                        or self.target_tp_rank is None
+                    )
+                else:
+                    # no-MLA: select all prefill ranks
+                    bootstrap_info["is_dummy"] = False
                 logger.debug(
                     f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}"
                 )
@@ -951,6 +1173,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
         packed_aux_data_ptrs = b"".join(
             struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
         )
+        tp_rank = self.kv_mgr.kv_args.engine_rank
+        tp_size = self.kv_mgr.tp_size // self.kv_mgr.dp_size
+        kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
+        dst_tp_rank = str(tp_rank).encode("ascii")
+        dst_tp_size = str(tp_size).encode("ascii")
+        dst_kv_item_len = str(kv_item_len).encode("ascii")

         sock, lock = self._connect("tcp://" + self.prefill_server_url)
         with lock:
@@ -962,6 +1190,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
                     self.session_id.encode("ascii"),
                     packed_kv_data_ptrs,
                     packed_aux_data_ptrs,
+                    dst_tp_rank,
+                    dst_tp_size,
+                    dst_kv_item_len,
                 ]
             )
@@ -1009,6 +1240,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
     def clear(self) -> None:
         if self.bootstrap_room in self.kv_mgr.request_status:
             self.kv_mgr.request_status.pop(self.bootstrap_room)
+            self.kv_mgr.required_prefill_info_num_map.pop(self.bootstrap_room)
+            self.kv_mgr.decode_kv_arrive_state.pop(self.bootstrap_room)

     def failure_exception(self):
         # Explicitly set the status to failure since this request has failed in another rank
...
@@ -122,6 +122,9 @@ class PrefillBootstrapQueue:
         kv_args.kv_data_ptrs = kv_data_ptrs
         kv_args.kv_data_lens = kv_data_lens
         kv_args.kv_item_lens = kv_item_lens
+        if not self.is_mla_backend:
+            kv_args.kv_head_num = self.token_to_kv_pool.head_num
+            kv_args.page_size = self.token_to_kv_pool.page_size
         kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
             self.metadata_buffers.get_buf_infos()
...
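For the slicing math in `send_kvcache_slice` to be consistent, each per-layer kv_item_len should equal page_size x head_num x head_dim x dtype-size for that rank's shard; exporting kv_head_num and page_size here gives the prefill side what it needs to derive per-head byte offsets. A small consistency check under that (assumed) layout, with illustrative numbers only:

```python
# Layout assumption (illustrative): one KV pool page stores page_size token slots,
# each holding head_num KV heads of head_dim elements in the pool's dtype.
def kv_item_len(page_size: int, head_num: int, head_dim: int, dtype_size: int) -> int:
    return page_size * head_num * head_dim * dtype_size

page_size, head_dim, dtype_size = 32, 128, 2          # e.g. FP16 KV cache

# Prefill runs TP 2 (4 KV heads per rank), decode runs TP 4 (2 KV heads per rank).
prefill_item_len = kv_item_len(page_size, head_num=4, head_dim=head_dim, dtype_size=dtype_size)
decode_item_len = kv_item_len(page_size, head_num=2, head_dim=head_dim, dtype_size=dtype_size)

# bytes_per_head as send_kvcache_slice derives it from the decode side's metadata:
bytes_per_head = decode_item_len // 2 // page_size
assert bytes_per_head == head_dim * dtype_size        # 256 bytes per head per token slot
assert prefill_item_len == 2 * decode_item_len        # prefill pages are twice as wide
```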