Unverified commit 1403ea56, authored by Shangming Cai, committed by GitHub

[PD] Support non-MLA models PD different TP with DP attention (#7931)


Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
parent b7e951a6
@@ -321,67 +321,60 @@ class MooncakeKVManager(BaseKVManager):
         This may introduce performance overhead (increased TTFT) for long sequences.
         """
         # Extract configuration
-        local_tp_rank = self.kv_args.engine_rank
         local_tp_size = self.tp_size // self.dp_size
+        local_tp_rank_in_group = self.kv_args.engine_rank % local_tp_size
+        src_kv_item_len = self.kv_args.kv_item_lens[0]
+        dst_tp_rank_in_group = dst_tp_rank % dst_tp_size
         num_kv_heads = self.kv_args.kv_head_num
         num_layers = len(self.kv_args.kv_data_ptrs)
         page_size = self.kv_args.page_size

         # Calculate head distribution
-        heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size
-        heads_per_prefill_rank = num_kv_heads
-        decode_global_head_start = dst_tp_rank * heads_per_decode_rank
-        prefill_global_head_start = local_tp_rank * heads_per_prefill_rank
-        bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size
-        decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)]
+        src_heads_per_rank = num_kv_heads
+        dst_heads_per_rank = num_kv_heads * local_tp_size // dst_tp_size
+        bytes_per_head_slice_to_send = (
+            dst_kv_item_len // page_size // dst_heads_per_rank
+        )

         # Determine slicing parameters based on TP configuration
         if local_tp_size > dst_tp_size:
-            src_head_offset = 0
-            num_heads_to_send = heads_per_prefill_rank
-            dst_head_offset = prefill_global_head_start - decode_global_head_start
+            # Send KVCache from multiple prefill instances to 1 decode instance
+            src_head_start_offset = 0
+            num_heads_to_send = src_heads_per_rank
+            dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
         else:
-            src_head_offset = decode_global_head_start - prefill_global_head_start
-            num_heads_to_send = heads_per_decode_rank
-            dst_head_offset = 0
+            # Send KVCache from 1 prefill instance to multiple decode instances
+            src_head_start_offset = dst_tp_rank_in_group * dst_heads_per_rank
+            num_heads_to_send = dst_heads_per_rank
+            dst_head_start_offset = 0

-        layer_transfer_params = []
+        layers_params = []
         for layer_id in range(num_layers):
-            item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id]
-            # Page stride on the target dst decode rank for its slice pages
-            item_len_of_decode_rank_page = decode_rank_item_lens[layer_id]
-            if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0:
-                logger.error(
-                    f"Invalid item_len_of_prefill_rank_page or num_kv_heads for layer {layer_id}"
-                )
-                return -1
-            # Calculate precise byte offset and length for the sub-slice within the prefill page data
-            src_slice_offset = src_head_offset * bytes_per_head
-            dst_slice_offset = dst_head_offset * bytes_per_head
-            slice_lens_per_page = num_heads_to_send * bytes_per_head
+            # Calculate precise byte offset and length for the sub-slice within the token
+            src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
+            dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
+            heads_bytes_per_token_to_send = (
+                num_heads_to_send * bytes_per_head_slice_to_send
+            )

-            # Sanity check: The data sub-slice to be sent should fit into the decode instance's page.
-            # This means slice_lens_per_page <= item_len_of_decode_rank_page
-            if slice_lens_per_page > item_len_of_decode_rank_page:
+            # Sanity check: The data sub-slice to be sent should fit into the dst buffer.
+            # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
+            if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
                 logger.error(
                     f"[{mooncake_session_id}] Layer {layer_id}: "
-                    f"slice size ({slice_lens_per_page}) exceeds "
-                    f"target page size ({item_len_of_decode_rank_page})"
+                    f"slice size ({heads_bytes_per_token_to_send}) exceeds "
+                    f"target token slot size ({dst_kv_item_len // page_size})"
                 )
                 return -1
-            layer_transfer_params.append(
+            layers_params.append(
                 (
                     self.kv_args.kv_data_ptrs[layer_id],
                     dst_kv_ptrs[layer_id],
-                    item_len_of_prefill_rank_page,
-                    item_len_of_decode_rank_page,
-                    src_slice_offset,
-                    dst_slice_offset,
-                    slice_lens_per_page,
+                    src_kv_item_len,
+                    dst_kv_item_len,
+                    src_head_slice_offset,
+                    dst_head_slice_offset,
+                    heads_bytes_per_token_to_send,
                 )
             )
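Note: to sanity-check the new head-distribution arithmetic, the following standalone sketch mirrors the slicing logic above outside the class. The variable names follow the diff; all concrete sizes (group sizes, 256 bytes per head per token, page size 1) are illustrative assumptions, not values from the repository.

```python
# Standalone mirror of the slicing-parameter computation in the hunk above.
def slicing_params(engine_rank, local_tp_size, dst_tp_rank, dst_tp_size,
                   num_kv_heads, dst_kv_item_len, page_size):
    local_tp_rank_in_group = engine_rank % local_tp_size
    dst_tp_rank_in_group = dst_tp_rank % dst_tp_size
    src_heads_per_rank = num_kv_heads  # KV heads held by this prefill rank
    dst_heads_per_rank = num_kv_heads * local_tp_size // dst_tp_size
    bytes_per_head_slice_to_send = dst_kv_item_len // page_size // dst_heads_per_rank
    if local_tp_size > dst_tp_size:
        # Several prefill ranks tile one decode token slot.
        src_head_start_offset = 0
        num_heads_to_send = src_heads_per_rank
        dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
    else:
        # One prefill rank feeds several decode ranks; send only the
        # head range the destination rank owns.
        src_head_start_offset = dst_tp_rank_in_group * dst_heads_per_rank
        num_heads_to_send = dst_heads_per_rank
        dst_head_start_offset = 0
    return (src_head_start_offset * bytes_per_head_slice_to_send,
            dst_head_start_offset * bytes_per_head_slice_to_send,
            num_heads_to_send * bytes_per_head_slice_to_send)

# Prefill TP 1 -> decode TP 4: decode rank 2 owns heads 4-5 of 8, so the
# source slice starts at head 4 and lands at offset 0 on the target.
assert slicing_params(0, 1, 2, 4, 8, 512, 1) == (1024, 0, 512)
# Prefill TP 4 -> decode TP 1: prefill rank 1 sends its 2 heads into the
# slot for heads 2-3 on the single decode rank.
assert slicing_params(1, 4, 0, 1, 2, 2048, 1) == (0, 512, 512)
```

Either branch yields a byte offset pair plus a per-token length, which is exactly what the per-token transfer loop below consumes.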
@@ -391,9 +384,9 @@ class MooncakeKVManager(BaseKVManager):
                 dst_ptr,
                 src_item_len,
                 dst_item_len,
-                src_offset,
-                dst_offset,
-                slice_lens_per_page,
+                src_head_slice_offset,
+                dst_head_slice_offset,
+                heads_bytes_per_token_to_send,
             ) = layer_params
             src_addr_list = []
             dst_addr_list = []
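The loop between this unpacking and the final address computation is collapsed in this view. As a rough, hypothetical reconstruction of what that per-token address arithmetic has to do: only `src_token_slot_start_addr`/`dst_token_slot_start_addr` and the head-slice offsets appear in the visible hunks, so the page-walk structure and helper names below are assumptions.

```python
# Hypothetical sketch of the per-token address computation; only the
# head-slice offset application at the end is visible in the diff.
def build_transfer_lists(src_ptr, dst_ptr, src_item_len, dst_item_len,
                         prefill_kv_indices, dst_kv_indices, page_size,
                         src_head_slice_offset, dst_head_slice_offset,
                         heads_bytes_per_token_to_send):
    src_addr_list, dst_addr_list, length_list = [], [], []
    for prefill_page, decode_page in zip(prefill_kv_indices, dst_kv_indices):
        # One item covers a page of `page_size` contiguous token slots.
        src_page_start_addr = src_ptr + int(prefill_page) * src_item_len
        dst_page_start_addr = dst_ptr + int(decode_page) * dst_item_len
        for slot in range(page_size):
            src_token_slot_start_addr = (
                src_page_start_addr + slot * (src_item_len // page_size)
            )
            dst_token_slot_start_addr = (
                dst_page_start_addr + slot * (dst_item_len // page_size)
            )
            # Apply head-slice offsets (the part visible in the diff).
            src_addr_list.append(src_token_slot_start_addr + src_head_slice_offset)
            dst_addr_list.append(dst_token_slot_start_addr + dst_head_slice_offset)
            length_list.append(heads_bytes_per_token_to_send)
    return src_addr_list, dst_addr_list, length_list
```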
@@ -424,17 +417,12 @@ class MooncakeKVManager(BaseKVManager):
                 )

                 # Calculate final src and dst addresses by applying head-slice offsets
-                src_slice_addr = src_token_slot_start_addr + src_offset
-                dst_slice_addr = dst_token_slot_start_addr + dst_offset
+                src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
+                dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset

                 src_addr_list.append(src_slice_addr)
                 dst_addr_list.append(dst_slice_addr)
-                length_list.append(slice_lens_per_page)
-
-                logger.debug(
-                    f"SYNC: sid={mooncake_session_id}, "
-                    f"src={src_slice_addr}, dst={dst_slice_addr}, len={slice_lens_per_page}"
-                )
+                length_list.append(heads_bytes_per_token_to_send)

             return self.engine.batch_transfer_sync(
                 mooncake_session_id, src_addr_list, dst_addr_list, length_list
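Whichever branch is taken, the per-token slices from all participating senders must exactly tile the destination token slot, or the sanity check above would reject the transfer. A quick accounting check using the same formulas (sizes again illustrative):

```python
# Byte accounting for the multiple-prefill -> one-decode case (assumed sizes).
num_kv_heads = 2                    # per prefill rank
local_tp_size, dst_tp_size = 4, 1   # 4 prefill ranks feed 1 decode rank
page_size, bytes_per_head = 1, 256

dst_heads_per_rank = num_kv_heads * local_tp_size // dst_tp_size   # 8
dst_kv_item_len = dst_heads_per_rank * bytes_per_head * page_size  # 2048
bytes_per_head_slice_to_send = dst_kv_item_len // page_size // dst_heads_per_rank

senders = local_tp_size // dst_tp_size
total_bytes_per_token = senders * num_kv_heads * bytes_per_head_slice_to_send
assert total_bytes_per_token == dst_kv_item_len // page_size       # 2048 == 2048
```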
@@ -445,7 +433,7 @@ class MooncakeKVManager(BaseKVManager):
                     process_layer_tp_aware,
                     layer_params,
                 )
-                for layer_params in layer_transfer_params
+                for layer_params in layers_params
             ]
             for future in concurrent.futures.as_completed(futures):
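For reference, the code around this hunk submits one transfer task per layer and drains them with `as_completed`. A minimal standalone sketch of that fan-out pattern, assuming a worker that returns 0 on success and a negative status on failure (the executor construction itself is outside the visible hunk):

```python
import concurrent.futures

def transfer_all_layers(process_layer_tp_aware, layers_params, max_workers=8):
    # Submit one transfer task per layer and fail fast on the first error.
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(process_layer_tp_aware, layer_params)
            for layer_params in layers_params
        ]
        for future in concurrent.futures.as_completed(futures):
            status = future.result()
            if status != 0:
                # Cancel tasks that have not started yet, propagate the error.
                for f in futures:
                    f.cancel()
                return status
    return 0
```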