"cuda/moe.cpp" did not exist on "3a458fa7fc504640b5c5ec3394e7ef79d86a786c"
Unverified Commit 1403ea56 authored by Shangming Cai, committed by GitHub

[PD] Support non-MLA models PD different TP with DP attention (#7931)


Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
parent b7e951a6
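For non-MLA models the KV cache is partitioned across KV heads, so when prefill and decode run with different attention TP sizes (tp_size // dp_size on the prefill side vs. dst_tp_size on the decode side), a prefill rank can no longer copy whole pages one-to-one; it has to send a per-head byte slice of each token slot to the matching head range in the decode rank's buffer. The sketch below reproduces the head-slicing arithmetic this commit introduces as a standalone function; the helper name and all concrete numbers in the usage line (head counts, head_dim, page size, dtype width, ranks) are illustrative assumptions, not values taken from the patch.

# Standalone sketch of the head-slicing arithmetic added by this commit.
# The helper name and every numeric value below are illustrative assumptions.

def head_slice_params(
    num_kv_heads: int,            # KV heads held by this prefill rank
    local_tp_size: int,           # prefill attention TP size (tp_size // dp_size)
    local_tp_rank_in_group: int,  # this prefill rank within its attention TP group
    dst_tp_size: int,             # decode attention TP size
    dst_tp_rank_in_group: int,    # target decode rank within its attention TP group
    dst_kv_item_len: int,         # bytes of one KV page on the decode rank
    page_size: int,               # tokens per page
):
    src_heads_per_rank = num_kv_heads
    dst_heads_per_rank = num_kv_heads * local_tp_size // dst_tp_size
    bytes_per_head_slice_to_send = dst_kv_item_len // page_size // dst_heads_per_rank

    if local_tp_size > dst_tp_size:
        # Several prefill ranks fill one decode rank: send all local heads,
        # placed at this rank's head range inside the decode buffer.
        src_head_start_offset = 0
        num_heads_to_send = src_heads_per_rank
        dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
    else:
        # One prefill rank feeds several decode ranks: send only the head range
        # the target decode rank owns, starting at offset 0 on that rank.
        src_head_start_offset = dst_tp_rank_in_group * dst_heads_per_rank
        num_heads_to_send = dst_heads_per_rank
        dst_head_start_offset = 0

    return (
        src_head_start_offset * bytes_per_head_slice_to_send,  # src byte offset per token slot
        dst_head_start_offset * bytes_per_head_slice_to_send,  # dst byte offset per token slot
        num_heads_to_send * bytes_per_head_slice_to_send,      # bytes copied per token slot
    )

# Assumed config: 8 KV heads in total, prefill attention TP 2 (4 heads per rank),
# decode attention TP 4 (2 heads per rank), head_dim 128, fp16, 32-token pages.
dst_kv_item_len = 32 * 2 * 128 * 2  # tokens * heads * head_dim * bytes = 16384
print(head_slice_params(4, 2, 0, 4, 1, dst_kv_item_len, 32))
# -> (512, 0, 512): skip the first two heads of the prefill token slot and
#    copy two heads' worth of bytes into the start of the decode token slot.

None of the three returned values depends on the layer or on the token index, which is why the loop in the first hunk below can reuse them for every token slot of every layer.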
@@ -321,67 +321,60 @@ class MooncakeKVManager(BaseKVManager):
         This may introduce performance overhead (increased TTFT) for long sequences.
         """
         # Extract configuration
-        local_tp_rank = self.kv_args.engine_rank
         local_tp_size = self.tp_size // self.dp_size
+        local_tp_rank_in_group = self.kv_args.engine_rank % local_tp_size
+        src_kv_item_len = self.kv_args.kv_item_lens[0]
+        dst_tp_rank_in_group = dst_tp_rank % dst_tp_size
         num_kv_heads = self.kv_args.kv_head_num
         num_layers = len(self.kv_args.kv_data_ptrs)
         page_size = self.kv_args.page_size
 
         # Calculate head distribution
-        heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size
-        heads_per_prefill_rank = num_kv_heads
-        decode_global_head_start = dst_tp_rank * heads_per_decode_rank
-        prefill_global_head_start = local_tp_rank * heads_per_prefill_rank
-        bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size
-        decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)]
+        src_heads_per_rank = num_kv_heads
+        dst_heads_per_rank = num_kv_heads * local_tp_size // dst_tp_size
+        bytes_per_head_slice_to_send = (
+            dst_kv_item_len // page_size // dst_heads_per_rank
+        )
 
         # Determine slicing parameters based on TP configuration
         if local_tp_size > dst_tp_size:
-            src_head_offset = 0
-            num_heads_to_send = heads_per_prefill_rank
-            dst_head_offset = prefill_global_head_start - decode_global_head_start
+            # Send KVCache from multiple prefill instances to 1 decode instance
+            src_head_start_offset = 0
+            num_heads_to_send = src_heads_per_rank
+            dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
         else:
-            src_head_offset = decode_global_head_start - prefill_global_head_start
-            num_heads_to_send = heads_per_decode_rank
-            dst_head_offset = 0
+            # Send KVCache from 1 prefill instance to multiple decode instances
+            src_head_start_offset = dst_tp_rank_in_group * dst_heads_per_rank
+            num_heads_to_send = dst_heads_per_rank
+            dst_head_start_offset = 0
 
-        layer_transfer_params = []
+        layers_params = []
         for layer_id in range(num_layers):
-            item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id]
-
-            # Page stride on the target dst decode rank for its slice pages
-            item_len_of_decode_rank_page = decode_rank_item_lens[layer_id]
-
-            if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0:
-                logger.error(
-                    f"Invalid item_len_of_prefill_rank_page or num_kv_heads for layer {layer_id}"
-                )
-                return -1
-
-            # Calculate precise byte offset and length for the sub-slice within the prefill page data
-            src_slice_offset = src_head_offset * bytes_per_head
-            dst_slice_offset = dst_head_offset * bytes_per_head
-            slice_lens_per_page = num_heads_to_send * bytes_per_head
+            # Calculate precise byte offset and length for the sub-slice within the token
+            src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
+            dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
+            heads_bytes_per_token_to_send = (
+                num_heads_to_send * bytes_per_head_slice_to_send
+            )
 
-            # Sanity check: The data sub-slice to be sent should fit into the decode instance's page.
-            # This means slice_lens_per_page <= item_len_of_decode_rank_page
-            if slice_lens_per_page > item_len_of_decode_rank_page:
+            # Sanity check: The data sub-slice to be sent should fit into the dst buffer.
+            # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
+            if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
                 logger.error(
                     f"[{mooncake_session_id}] Layer {layer_id}: "
-                    f"slice size ({slice_lens_per_page}) exceeds "
-                    f"target page size ({item_len_of_decode_rank_page})"
+                    f"slice size ({heads_bytes_per_token_to_send}) exceeds "
+                    f"target token slot size ({dst_kv_item_len // page_size})"
                 )
                 return -1
-
-            layer_transfer_params.append(
+            layers_params.append(
                 (
                     self.kv_args.kv_data_ptrs[layer_id],
                     dst_kv_ptrs[layer_id],
-                    item_len_of_prefill_rank_page,
-                    item_len_of_decode_rank_page,
-                    src_slice_offset,
-                    dst_slice_offset,
-                    slice_lens_per_page,
+                    src_kv_item_len,
+                    dst_kv_item_len,
+                    src_head_slice_offset,
+                    dst_head_slice_offset,
+                    heads_bytes_per_token_to_send,
                 )
             )
 
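With the per-head slice width fixed up front, the rewritten loop no longer needs the per-layer item lengths or the old bytes_per_head quantity: every layer sends the same num_heads_to_send * bytes_per_head_slice_to_send bytes per token slot, and the sanity check only has to confirm that this fits into one decode-side token slot (dst_kv_item_len // page_size). Continuing the illustrative numbers from the sketch above:

# Illustrative check only; the numbers continue the assumed config above.
dst_kv_item_len, page_size = 16384, 32  # one decode page: 32 tokens * 2 heads * 256 B
bytes_per_head_slice_to_send = 256      # per head, per token
num_heads_to_send = 2

heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send  # 512
dst_token_slot_bytes = dst_kv_item_len // page_size                               # 512

# Mirrors the guard in the hunk above: the slice must fit in one decode token slot.
assert heads_bytes_per_token_to_send <= dst_token_slot_bytes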
@@ -391,9 +384,9 @@ class MooncakeKVManager(BaseKVManager):
                 dst_ptr,
                 src_item_len,
                 dst_item_len,
-                src_offset,
-                dst_offset,
-                slice_lens_per_page,
+                src_head_slice_offset,
+                dst_head_slice_offset,
+                heads_bytes_per_token_to_send,
             ) = layer_params
             src_addr_list = []
             dst_addr_list = []
@@ -424,17 +417,12 @@ class MooncakeKVManager(BaseKVManager):
                 )
 
                 # Calculate final src and dst addresses by applying head-slice offsets
-                src_slice_addr = src_token_slot_start_addr + src_offset
-                dst_slice_addr = dst_token_slot_start_addr + dst_offset
+                src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
+                dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset
                 src_addr_list.append(src_slice_addr)
                 dst_addr_list.append(dst_slice_addr)
-                length_list.append(slice_lens_per_page)
-
-                logger.debug(
-                    f"SYNC: sid={mooncake_session_id}, "
-                    f"src={src_slice_addr}, dst={dst_slice_addr}, len={slice_lens_per_page}"
-                )
+                length_list.append(heads_bytes_per_token_to_send)
 
             return self.engine.batch_transfer_sync(
                 mooncake_session_id, src_addr_list, dst_addr_list, length_list
            )
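The hunk above touches only the tail of process_layer_tp_aware: the hidden context computes src_token_slot_start_addr and dst_token_slot_start_addr for each token slot being transferred, and the visible lines add the head-slice offsets and queue one (src, dst, length) triple per token slot for a single batch_transfer_sync call. A hedged sketch of that per-token assembly follows; the helper name build_transfer_lists and the token-slot address derivation are assumptions written out for clarity and may differ from the hidden lines.

# Hedged sketch of the per-token transfer assembly; the helper name and the
# token-slot address derivation are assumptions, not the hidden diff context.
def build_transfer_lists(
    src_ptr, dst_ptr,                 # layer base pointers on prefill / decode
    src_item_len, dst_item_len,       # bytes of one KV page on each side
    page_size,                        # tokens per page
    prefill_token_indices, decode_token_indices,
    src_head_slice_offset, dst_head_slice_offset,
    heads_bytes_per_token_to_send,
):
    src_addr_list, dst_addr_list, length_list = [], [], []
    for src_tok, dst_tok in zip(prefill_token_indices, decode_token_indices):
        # Token-slot base address: page base plus the in-page token offset.
        src_token_slot_start_addr = (
            src_ptr
            + (src_tok // page_size) * src_item_len
            + (src_tok % page_size) * (src_item_len // page_size)
        )
        dst_token_slot_start_addr = (
            dst_ptr
            + (dst_tok // page_size) * dst_item_len
            + (dst_tok % page_size) * (dst_item_len // page_size)
        )
        # As in the hunk above: apply the head-slice offsets and record one
        # (src, dst, length) triple per token slot.
        src_addr_list.append(src_token_slot_start_addr + src_head_slice_offset)
        dst_addr_list.append(dst_token_slot_start_addr + dst_head_slice_offset)
        length_list.append(heads_bytes_per_token_to_send)
    return src_addr_list, dst_addr_list, length_list

The three lists are then handed to self.engine.batch_transfer_sync in a single call per layer, as in the return statement above.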
@@ -445,7 +433,7 @@ class MooncakeKVManager(BaseKVManager):
                 process_layer_tp_aware,
                 layer_params,
             )
-            for layer_params in layer_transfer_params
+            for layer_params in layers_params
         ]
 
         for future in concurrent.futures.as_completed(futures):
...
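The last hunk only renames the list the executor fans out over: one process_layer_tp_aware task is submitted per layer and the results are gathered with concurrent.futures.as_completed. A minimal, self-contained sketch of that fan-out/fan-in pattern follows; the wrapper name and the negative-status convention after the truncated context are assumptions, modeled on the return -1 error paths visible in this diff.

import concurrent.futures

# Hypothetical wrapper around the fan-out shown above; the negative-status
# convention is an assumption modeled on the `return -1` paths in this diff.
def transfer_all_layers(executor, process_layer_tp_aware, layers_params) -> int:
    futures = [
        executor.submit(process_layer_tp_aware, layer_params)
        for layer_params in layers_params
    ]
    for future in concurrent.futures.as_completed(futures):
        if future.result() < 0:  # treat a negative per-layer status as failure
            return -1
    return 0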