Unverified Commit a16923ef authored by Francis's avatar Francis Committed by GitHub
Browse files

[PD] optimize kv cache transfer directly using batch transfer (#9149)


Co-authored-by: default avatarShangming Cai <csmthu@gmail.com>
parent 6337d905
...@@ -356,33 +356,49 @@ class MooncakeKVManager(BaseKVManager): ...@@ -356,33 +356,49 @@ class MooncakeKVManager(BaseKVManager):
] ]
assert layers_params is not None assert layers_params is not None
# Worker function for processing a single layer def set_transfer_blocks(
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int: src_ptr: int, dst_ptr: int, item_len: int
) -> List[Tuple[int, int, int]]:
transfer_blocks = [] transfer_blocks = []
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks): for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
src_addr = src_ptr + int(prefill_index[0]) * item_len src_addr = src_ptr + int(prefill_index[0]) * item_len
dst_addr = dst_ptr + int(decode_index[0]) * item_len dst_addr = dst_ptr + int(decode_index[0]) * item_len
length = item_len * len(prefill_index) length = item_len * len(prefill_index)
transfer_blocks.append((src_addr, dst_addr, length)) transfer_blocks.append((src_addr, dst_addr, length))
return transfer_blocks
# Worker function for processing a single layer
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
return self._transfer_data(mooncake_session_id, transfer_blocks) return self._transfer_data(mooncake_session_id, transfer_blocks)
futures = [ # Worker function for processing all layers in a batch
executor.submit( def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
process_layer, transfer_blocks = []
src_ptr, for src_ptr, dst_ptr, item_len in layers_params:
dst_ptr, transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
item_len, return self._transfer_data(mooncake_session_id, transfer_blocks)
)
for (src_ptr, dst_ptr, item_len) in layers_params
]
for future in concurrent.futures.as_completed(futures): if self.enable_custom_mem_pool:
status = future.result() futures = [
if status != 0: executor.submit(
for f in futures: process_layer,
f.cancel() src_ptr,
return status dst_ptr,
item_len,
)
for (src_ptr, dst_ptr, item_len) in layers_params
]
for future in concurrent.futures.as_completed(futures):
status = future.result()
if status != 0:
for f in futures:
f.cancel()
return status
else:
# Combining all layers' params in one batch transfer is more efficient
# compared to using multiple threads
return process_layers(layers_params)
return 0 return 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment