Unverified Commit 25ef53f0 authored by Shangming Cai, committed by GitHub

[PD] Fix nvlink transport accuracy through transferring metadata with tcp (#9261)


Signed-off-by: Shangming Cai <csmthu@gmail.com>
parent c674bf9c
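
Note on the change: when the NVLink-backed custom memory pool is enabled, this patch stops moving auxiliary/metadata buffers through Mooncake's nvlink transport and instead ships them over the existing ZMQ/TCP channel as a tagged multipart message. The standalone sketch below (hypothetical helper names; only the frame layout is taken from the diff) shows that AUX_DATA framing and the receive-side length check:

import struct
from typing import List, Tuple

AUX_DATA_HEADER = b"AUX_DATA"  # same tag the manager prepends to aux frames


def pack_aux_frame(room: int, buffer_index: int, aux_index: int, data: bytes) -> List[bytes]:
    # Integers travel as ASCII strings; the payload length as a big-endian uint32.
    return [
        AUX_DATA_HEADER,
        str(room).encode("ascii"),
        str(buffer_index).encode("ascii"),
        str(aux_index).encode("ascii"),
        struct.pack(">I", len(data)),
        data,
    ]


def parse_aux_frame(msg: List[bytes]) -> Tuple[int, int, int, bytes]:
    # Mirrors the receive-side checks: header match, then declared vs. actual length.
    assert msg[0] == AUX_DATA_HEADER
    room = int(msg[1].decode("ascii"))
    buffer_index = int(msg[2].decode("ascii"))
    aux_index = int(msg[3].decode("ascii"))
    (data_length,) = struct.unpack(">I", msg[4])
    data = msg[5]
    if len(data) != data_length:
        raise ValueError(f"AUX_DATA length mismatch for bootstrap_room {room}")
    return room, buffer_index, aux_index, data


if __name__ == "__main__":
    frame = pack_aux_frame(room=42, buffer_index=0, aux_index=3, data=b"\x01\x02\x03")
    print(parse_aux_frame(frame))  # (42, 0, 3, b'\x01\x02\x03')
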
@@ -2,6 +2,7 @@ from __future__ import annotations

 import asyncio
 import concurrent.futures
+import ctypes
 import dataclasses
 import logging
 import os
@@ -138,7 +139,29 @@ class KVArgsRegisterInfo:
         )


+class AuxDataCodec:
+    """Handles serialization and deserialization of auxiliary data buffers"""
+
+    @staticmethod
+    def serialize_data_from_buffer(src_addr, data_length):
+        """Serialize data from memory buffer to bytes"""
+        buffer = (ctypes.c_byte * data_length).from_address(src_addr)
+        return bytes(buffer)
+
+    @staticmethod
+    def deserialize_data_to_buffer(kv_args, buffer_index, aux_index, data):
+        """Deserialize bytes into target memory buffer"""
+        dst_aux_ptr = kv_args.aux_data_ptrs[buffer_index]
+        item_len = kv_args.aux_item_lens[buffer_index]
+        dst_addr = dst_aux_ptr + item_len * aux_index
+        buffer = (ctypes.c_byte * len(data)).from_address(dst_addr)
+        buffer[:] = data
+        return
+
+
 class MooncakeKVManager(BaseKVManager):
+    AUX_DATA_HEADER = b"AUX_DATA"
+
     def __init__(
         self,
         args: KVArgs,
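
AuxDataCodec above reads and writes raw buffer addresses via ctypes. A tiny self-contained illustration of that technique (nothing below is in the patch; the buffer here is a stand-in for an aux slot):

import ctypes

# Stand-in for one aux slot; in the patch the address would come from
# kv_args.aux_data_ptrs plus an item offset.
slot = ctypes.create_string_buffer(8)
slot_addr = ctypes.addressof(slot)

# serialize_data_from_buffer: view `length` bytes at an address and copy them out.
view = (ctypes.c_byte * 8).from_address(slot_addr)
snapshot = bytes(view)  # b"\x00" * 8 for a fresh buffer

# deserialize_data_to_buffer: write incoming bytes back through the same kind of view.
payload = b"ABCD"
dst = (ctypes.c_byte * len(payload)).from_address(slot_addr)
dst[:] = payload
assert slot.raw[:4] == b"ABCD"
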
@@ -283,21 +306,10 @@ class MooncakeKVManager(BaseKVManager):
         if not transfer_blocks:
             return 0

-        # TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
-        if self.enable_custom_mem_pool:
-            # batch_transfer_sync has a higher chance to trigger an accuracy drop for MNNVL, fallback to transfer_sync temporarily
-            for src_addr, dst_addr, length in transfer_blocks:
-                status = self.engine.transfer_sync(
-                    mooncake_session_id, src_addr, dst_addr, length
-                )
-                if status != 0:
-                    return status
-            return 0
-        else:
-            src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
-            return self.engine.batch_transfer_sync(
-                mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
-            )
+        src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
+        return self.engine.batch_transfer_sync(
+            mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
+        )

     def send_kvcache(
         self,
@@ -570,11 +582,14 @@ class MooncakeKVManager(BaseKVManager):

     def send_aux(
         self,
-        mooncake_session_id: str,
+        req: TransferInfo,
         prefill_aux_index: int,
         dst_aux_ptrs: list[int],
-        dst_aux_index: int,
     ):
+        # TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
+        if self.enable_custom_mem_pool:
+            return self.send_aux_tcp(req, prefill_aux_index, dst_aux_ptrs)
+
         transfer_blocks = []
         prefill_aux_ptrs = self.kv_args.aux_data_ptrs
         prefill_aux_item_lens = self.kv_args.aux_item_lens
@@ -582,10 +597,59 @@ class MooncakeKVManager(BaseKVManager):
         for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
             length = prefill_aux_item_lens[i]
             src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
-            dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
+            dst_addr = dst_aux_ptrs[i] + length * req.dst_aux_index
             transfer_blocks.append((src_addr, dst_addr, length))

-        return self._transfer_data(mooncake_session_id, transfer_blocks)
+        return self._transfer_data(req.mooncake_session_id, transfer_blocks)
+
+    def send_aux_tcp(
+        self,
+        req: TransferInfo,
+        prefill_aux_index: int,
+        dst_aux_ptrs: list[int],
+    ):
+        prefill_aux_ptrs = self.kv_args.aux_data_ptrs
+        prefill_aux_item_lens = self.kv_args.aux_item_lens
+
+        for i in range(len(prefill_aux_ptrs)):
+            length = prefill_aux_item_lens[i]
+            src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
+            data = AuxDataCodec.serialize_data_from_buffer(src_addr, length)
+
+            self.send_aux_data_to_endpoint(
+                remote=req.endpoint,
+                dst_port=req.dst_port,
+                room=req.room,
+                buffer_index=i,
+                aux_index=req.dst_aux_index,
+                data=data,
+            )
+
+        return 0
+
+    def send_aux_data_to_endpoint(
+        self,
+        remote: str,
+        dst_port: int,
+        room: int,
+        buffer_index: int,
+        aux_index: int,
+        data: bytes,
+    ):
+        socket = self._connect(
+            format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote)
+        )
+
+        socket.send_multipart(
+            [
+                MooncakeKVManager.AUX_DATA_HEADER,
+                str(room).encode("ascii"),
+                str(buffer_index).encode("ascii"),
+                str(aux_index).encode("ascii"),
+                struct.pack(">I", len(data)),
+                data,
+            ]
+        )

     def sync_status_to_decode_endpoint(
         self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
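
The two new sender methods serialize each aux slot and push it as one multipart frame through the manager's ZMQ connection. Roughly, and independent of the manager's _connect()/server_socket plumbing, the TCP hop looks like the pyzmq sketch below (socket types and addresses here are assumptions, not the patch's wiring):

import struct
import zmq

ctx = zmq.Context.instance()

# Decode side: bind, then dispatch on the first frame (as decode_thread now does).
receiver = ctx.socket(zmq.PULL)
port = receiver.bind_to_random_port("tcp://127.0.0.1")

# Prefill side: connect and push one AUX_DATA multipart frame.
payload = b"ABCD"
sender = ctx.socket(zmq.PUSH)
sender.connect(f"tcp://127.0.0.1:{port}")
sender.send_multipart(
    [
        b"AUX_DATA",
        b"42",  # bootstrap room
        b"0",   # buffer index
        b"3",   # aux index
        struct.pack(">I", len(payload)),
        payload,
    ]
)

msg = receiver.recv_multipart()
assert msg[0] == b"AUX_DATA" and msg[5] == payload
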
@@ -699,10 +763,9 @@ class MooncakeKVManager(BaseKVManager):
                if self.pp_group.is_last_rank:
                    # Only the last chunk we need to send the aux data
                    ret = self.send_aux(
-                        req.mooncake_session_id,
+                        req,
                        kv_chunk.prefill_aux_index,
                        target_rank_registration_info.dst_aux_ptrs,
-                        req.dst_aux_index,
                    )
                    polls.append(True if ret == 0 else False)
                    dst_ranks_infos.append(
@@ -778,15 +841,38 @@ class MooncakeKVManager(BaseKVManager):

         threading.Thread(target=bootstrap_thread).start()

+    def _handle_aux_data(self, msg: List[bytes]):
+        """Handle AUX_DATA messages received by the decode thread."""
+        room = int(msg[1].decode("ascii"))
+        buffer_index = int(msg[2].decode("ascii"))
+        aux_index = int(msg[3].decode("ascii"))
+        data_length = struct.unpack(">I", msg[4])[0]
+        data = msg[5]
+
+        if len(data) != data_length:
+            logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
+            return
+
+        AuxDataCodec.deserialize_data_to_buffer(
+            self.kv_args, buffer_index, aux_index, data
+        )
+
+        logger.debug(
+            f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
+        )
+
     def start_decode_thread(self):
         self.rank_port = get_free_port()
         self._bind_server_socket()

         def decode_thread():
             while True:
-                (bootstrap_room, status, prefill_rank) = (
-                    self.server_socket.recv_multipart()
-                )
+                msg = self.server_socket.recv_multipart()
+                if msg[0] == MooncakeKVManager.AUX_DATA_HEADER:
+                    self._handle_aux_data(msg)
+                    continue
+
+                (bootstrap_room, status, prefill_rank) = msg
                 status = int(status.decode("ascii"))
                 bootstrap_room = int(bootstrap_room.decode("ascii"))
                 prefill_rank = int(prefill_rank.decode("ascii"))
...
@@ -99,7 +99,8 @@ class MetadataBuffers:
             # For ascend backend, output tokens are placed in the NPU and will be transferred by D2D channel.
             device = "npu"
         elif self.custom_mem_pool:
-            device = "cuda"
+            # TODO(shangming): Fix me (use 'cuda') when nvlink_transport of Mooncake is bug-free
+            device = "cpu"
         with (
             torch.cuda.use_mem_pool(self.custom_mem_pool)
             if self.custom_mem_pool
...
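
The second file keeps the metadata buffers on the CPU even when a custom memory pool is configured, so they no longer land in the NVLink-backed pool. A hedged sketch of that allocation pattern (the helper name and shapes below are assumed, not the file's real constructor):

import contextlib
import torch


def alloc_aux_buffer(num_slots: int, item_len: int, custom_mem_pool=None):
    # Workaround from this commit: stay on CPU even when a custom
    # (NVLink-backed) pool exists, so the aux data travels over TCP instead.
    device = "cpu"
    pool_ctx = (
        torch.cuda.use_mem_pool(custom_mem_pool)
        if custom_mem_pool
        else contextlib.nullcontext()
    )
    with pool_ctx:
        return torch.zeros((num_slots, item_len), dtype=torch.uint8, device=device)
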