"vscode:/vscode.git/clone" did not exist on "ffde65a0942b4d8f4d7ee913bb4d053bae213486"
Unverified Commit 016fd251 authored by Shangming Cai Committed by GitHub
Browse files

[PD] Use batch transfer for rdma transport and add notes for mnnvl usage (#8595)


Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
parent 02328864
...@@ -37,6 +37,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode ...@@ -37,6 +37,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import ( from sglang.srt.utils import (
format_tcp_address, format_tcp_address,
get_bool_env_var,
get_free_port, get_free_port,
get_int_env_var, get_int_env_var,
get_ip, get_ip,
...@@ -198,6 +199,10 @@ class MooncakeKVManager(BaseKVManager): ...@@ -198,6 +199,10 @@ class MooncakeKVManager(BaseKVManager):
self.bootstrap_timeout = get_int_env_var( self.bootstrap_timeout = get_int_env_var(
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300 "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300
) )
self.enable_custom_mem_pool = get_bool_env_var(
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
)
elif self.disaggregation_mode == DisaggregationMode.DECODE: elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.heartbeat_failures = {} self.heartbeat_failures = {}
self.session_pool = defaultdict(requests.Session) self.session_pool = defaultdict(requests.Session)
...@@ -258,6 +263,26 @@ class MooncakeKVManager(BaseKVManager): ...@@ -258,6 +263,26 @@ class MooncakeKVManager(BaseKVManager):
socket.connect(endpoint) socket.connect(endpoint)
return socket return socket
def _transfer_data(self, mooncake_session_id, transfer_blocks):
    """Send a list of memory blocks through the transfer engine.

    Args:
        mooncake_session_id: Session identifier forwarded unchanged to the
            engine calls.
        transfer_blocks: Sequence of ``(src_addr, dst_addr, length)`` triples.

    Returns:
        ``0`` when ``transfer_blocks`` is empty. With
        ``self.enable_custom_mem_pool`` set, the first non-zero status
        returned by ``engine.transfer_sync`` (remaining blocks are skipped),
        or ``0`` if every block reported ``0``. Otherwise, the value returned
        by a single ``engine.batch_transfer_sync`` call.
    """
    # Nothing to send; this also keeps zip(*...) below from unpacking nothing.
    if not transfer_blocks:
        return 0

    if not self.enable_custom_mem_pool:
        # Regular path: split the triples into three parallel lists and
        # submit them to the engine in one batched call.
        sources, destinations, sizes = (
            list(column) for column in zip(*transfer_blocks)
        )
        return self.engine.batch_transfer_sync(
            mooncake_session_id, sources, destinations, sizes
        )

    # TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
    # batch_transfer_sync has a higher chance to trigger an accuracy drop for
    # MNNVL, fallback to transfer_sync temporarily
    for source, destination, size in transfer_blocks:
        ret = self.engine.transfer_sync(
            mooncake_session_id, source, destination, size
        )
        if ret != 0:
            # Stop at the first failing block and report its status.
            return ret
    return 0
def send_kvcache( def send_kvcache(
self, self,
mooncake_session_id: str, mooncake_session_id: str,
...@@ -283,17 +308,14 @@ class MooncakeKVManager(BaseKVManager): ...@@ -283,17 +308,14 @@ class MooncakeKVManager(BaseKVManager):
# Worker function for processing a single layer # Worker function for processing a single layer
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int: def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
transfer_blocks = []
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks): for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
src_addr = src_ptr + int(prefill_index[0]) * item_len src_addr = src_ptr + int(prefill_index[0]) * item_len
dst_addr = dst_ptr + int(decode_index[0]) * item_len dst_addr = dst_ptr + int(decode_index[0]) * item_len
length = item_len * len(prefill_index) length = item_len * len(prefill_index)
transfer_blocks.append((src_addr, dst_addr, length))
status = self.engine.transfer_sync( return self._transfer_data(mooncake_session_id, transfer_blocks)
mooncake_session_id, src_addr, dst_addr, length
)
if status != 0:
return status
return 0
futures = [ futures = [
executor.submit( executor.submit(
...@@ -465,21 +487,17 @@ class MooncakeKVManager(BaseKVManager): ...@@ -465,21 +487,17 @@ class MooncakeKVManager(BaseKVManager):
dst_aux_ptrs: list[int], dst_aux_ptrs: list[int],
dst_aux_index: int, dst_aux_index: int,
): ):
src_addr_list = [] transfer_blocks = []
dst_addr_list = []
length_list = []
prefill_aux_ptrs = self.kv_args.aux_data_ptrs prefill_aux_ptrs = self.kv_args.aux_data_ptrs
prefill_aux_item_lens = self.kv_args.aux_item_lens prefill_aux_item_lens = self.kv_args.aux_item_lens
for i, dst_aux_ptr in enumerate(dst_aux_ptrs): for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
length = prefill_aux_item_lens[i] length = prefill_aux_item_lens[i]
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
dst_addr = dst_aux_ptrs[i] + length * dst_aux_index dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
src_addr_list.append(src_addr) transfer_blocks.append((src_addr, dst_addr, length))
dst_addr_list.append(dst_addr)
length_list.append(length) return self._transfer_data(mooncake_session_id, transfer_blocks)
return self.engine.batch_transfer_sync(
mooncake_session_id, src_addr_list, dst_addr_list, length_list
)
def sync_status_to_decode_endpoint( def sync_status_to_decode_endpoint(
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment