Unverified commit 97ac42b6, authored by Yongtong Wu, committed by GitHub

[PD] NIXL backend Prefill TP & Decode TP+DP (#5681)

parent 1acca3a2
@@ -10,7 +10,7 @@ import threading
 import uuid
 from collections import defaultdict
 from functools import cache
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union

 import numpy as np
 import numpy.typing as npt
@@ -32,6 +32,38 @@ from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote

 logger = logging.getLogger(__name__)

+NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
+
+
+# From Mooncake backend.
+def group_concurrent_contiguous(
+    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
+) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
+    src_groups = []
+    dst_groups = []
+    current_src = [src_indices[0]]
+    current_dst = [dst_indices[0]]
+
+    for i in range(1, len(src_indices)):
+        src_contiguous = src_indices[i] == src_indices[i - 1] + 1
+        dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
+        if src_contiguous and dst_contiguous:
+            current_src.append(src_indices[i])
+            current_dst.append(dst_indices[i])
+        else:
+            src_groups.append(current_src)
+            dst_groups.append(current_dst)
+            current_src = [src_indices[i]]
+            current_dst = [dst_indices[i]]
+
+    src_groups.append(current_src)
+    dst_groups.append(current_dst)
+
+    return src_groups, dst_groups
+
+
+GUARD = "NixlMsgGuard".encode("ascii")
+
+
 @dataclasses.dataclass
 class TransferInfo:
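
For illustration (made-up values), the new group_concurrent_contiguous helper batches index pairs only while both the source and destination runs stay contiguous:

    src = np.array([0, 1, 2, 5, 6], dtype=np.int64)
    dst = np.array([10, 11, 12, 3, 4], dtype=np.int64)
    group_concurrent_contiguous(src, dst)
    # -> ([[0, 1, 2], [5, 6]], [[10, 11, 12], [3, 4]])
    # Two runs instead of five single pages, so far fewer transfer descriptors.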
@@ -45,8 +77,25 @@ class TransferInfo:
     dst_aux_index: int
     dst_gpu_id: int

+    def is_dummy(self):
+        return self.endpoint == ""
+
     @classmethod
     def from_zmq(cls, msg: List[bytes]):
+        if len(msg) == 1:
+            # dummy msg
+            return cls(
+                room=int(msg[0].decode("ascii")),
+                endpoint="",
+                dst_port=0,
+                agent_metadata=b"",
+                dst_kv_ptrs=[],
+                dst_kv_indices=np.array([], dtype=np.int64),
+                dst_aux_ptrs=[],
+                dst_aux_index=0,
+                dst_gpu_id=0,
+            )
+        else:
             return cls(
                 room=int(msg[0].decode("ascii")),
                 endpoint=msg[1].decode("ascii"),
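
A minimal sketch of the one-frame dummy path above (illustrative room id):

    info = TransferInfo.from_zmq([b"42"])
    assert info.room == 42
    assert info.is_dummy()  # endpoint == "", so no KV transfer is scheduled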
@@ -98,6 +147,19 @@ class NixlKVManager(BaseKVManager):
         # for p/d multi node infer
         self.bootstrap_port = server_args.disaggregation_bootstrap_port
         self.dist_init_addr = server_args.dist_init_addr
+        self.tp_size = server_args.tp_size
+        self.tp_rank = args.engine_rank
+        self.enable_dp_attention = server_args.enable_dp_attention
+        if self.enable_dp_attention:
+            assert (
+                server_args.dp_size > 1
+            ), "If dp_attention is enabled, dp size must be greater than 1 in disaggregation mode."
+            self.dp_size = server_args.dp_size
+            self.tp_size_of_dp = server_args.tp_size // server_args.dp_size
+            self.attn_tp_rank = args.engine_rank % self.tp_size_of_dp
+            self.dp_rank = args.engine_rank // self.tp_size_of_dp
+
         self.rank_port = None
         self.server_socket = zmq.Context().socket(zmq.PULL)
         self.register_buffer_to_engine()
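
A worked example of the rank arithmetic above, assuming tp_size=8 and dp_size=2:

    tp_size_of_dp = 8 // 2    # 4 TP ranks per DP attention group
    engine_rank = 5
    attn_tp_rank = 5 % 4      # 1: position inside its DP group
    dp_rank = 5 // 4          # 1: which DP group the rank belongs to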
@@ -110,7 +172,8 @@ class NixlKVManager(BaseKVManager):
             self._start_bootstrap_thread()
             self._register_to_bootstrap()
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
-            self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
+            # bootstrap key -> (remote engine rank -> possible remote source info)
+            self.prefill_peer_infos: Dict[str, list[Dict[int, NixlEngineInfo]]] = {}
             self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
                 TransferStatus
             )
@@ -126,6 +189,7 @@ class NixlKVManager(BaseKVManager):
         ):
             kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
         self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
+        logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
         if not self.kv_descs:
             raise Exception("NIXL memory registration failed for kv tensors")
         aux_addrs = []
aux_addrs = [] aux_addrs = []
...@@ -134,6 +198,7 @@ class NixlKVManager(BaseKVManager): ...@@ -134,6 +198,7 @@ class NixlKVManager(BaseKVManager):
): ):
aux_addrs.append((aux_data_ptr, aux_data_len, 0, "")) aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True) self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
if not self.aux_descs: if not self.aux_descs:
raise Exception("NIXL memory registration failed for aux tensors") raise Exception("NIXL memory registration failed for aux tensors")
@@ -157,6 +222,12 @@ class NixlKVManager(BaseKVManager):
         dst_gpu_id: int,
         notif: str,
     ):
+        # Group indices into runs that are contiguous on both sides.
+        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
+            prefill_kv_indices, dst_kv_indices
+        )
+
+        logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
         # Make descs
         num_layers = len(self.kv_args.kv_data_ptrs)
         src_addrs = []
@@ -166,12 +237,16 @@ class NixlKVManager(BaseKVManager):
             dst_ptr = dst_kv_ptrs[layer_id]
             item_len = self.kv_args.kv_item_lens[layer_id]

-            for prefill_index, decode_index in zip(prefill_kv_indices, dst_kv_indices):
-                src_addr = src_ptr + int(prefill_index) * item_len
-                dst_addr = dst_ptr + int(decode_index) * item_len
-                length = item_len
+            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
+                src_addr = src_ptr + int(prefill_index[0]) * item_len
+                dst_addr = dst_ptr + int(decode_index[0]) * item_len
+                length = item_len * len(prefill_index)
                 src_addrs.append((src_addr, length, self.kv_args.gpu_id))
                 dst_addrs.append((dst_addr, length, dst_gpu_id))
+
+        logger.debug(
+            f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
+        )
         src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
         dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
         # Transfer data
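
A sketch of the descriptor math this loop now performs, with hypothetical values: a run of three contiguous pages starting at page 7 becomes a single descriptor instead of three.

    src_ptr, item_len = 0x7F0000000000, 4096   # assumed base pointer and page size
    prefill_index = [7, 8, 9]                  # one contiguous group
    src_addr = src_ptr + int(prefill_index[0]) * item_len
    length = item_len * len(prefill_index)     # 12288 bytes in one transfer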
@@ -180,7 +255,7 @@ class NixlKVManager(BaseKVManager):
             src_descs,
             dst_descs,
             peer_name,
-            notif.encode("ascii"),
+            notif.encode("ascii"),  # type: ignore
         )
         if not xfer_handle:
             raise Exception("KVSender failed to create transfer")
@@ -213,7 +288,7 @@ class NixlKVManager(BaseKVManager):
             src_descs,
             dst_descs,
             peer_name,
-            notif.encode("ascii"),
+            notif.encode("ascii"),  # type: ignore
         )
         if not xfer_handle:
             raise Exception("KVSender failed to create transfer")
@@ -240,6 +315,9 @@ class NixlKVManager(BaseKVManager):
         req = self.transfer_infos[bootstrap_room]
         assert bootstrap_room == req.room

+        if req.is_dummy():
+            return []
+
         peer_name = self._add_remote(bootstrap_room, req.agent_metadata)
         chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
         assert len(chunked_dst_kv_indice) == len(kv_indices)
@@ -256,6 +334,7 @@ class NixlKVManager(BaseKVManager):
         handles = [kv_xfer_handle]
         # Only the last chunk we need to send the aux data.
         if is_last:
+            assert aux_index is not None
             aux_xfer_handle = self.send_aux(
                 peer_name,
                 aux_index,
@@ -325,6 +404,13 @@ class NixlKVManager(BaseKVManager):
         """This thread recvs transfer info from the decode engine"""
         while True:
             waiting_req_bytes = self.server_socket.recv_multipart()
+            logger.debug(
+                f"Received multipart with total byte size {sum(len(x) for x in waiting_req_bytes)}"
+            )
+            assert (
+                waiting_req_bytes[0] == GUARD
+            ), f"First message should be {GUARD}. Foreign traffic?"
+            waiting_req_bytes = waiting_req_bytes[1:]
             room = waiting_req_bytes[0].decode("ascii")
             if room == "None":
                 continue
@@ -372,14 +458,13 @@ class NixlKVSender(BaseKVSender):

     def poll(self) -> KVPoll:
         if not self.has_sent:
-            return KVPoll.WaitingForInput
+            return KVPoll.WaitingForInput  # type: ignore

         states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles]
         if all([x == "DONE" for x in states]):
-            return KVPoll.Success
+            return KVPoll.Success  # type: ignore
         if any([x == "ERR" for x in states]):
             raise Exception("KVSender transfer encountered an error.")
-        return KVPoll.WaitingForInput
+        return KVPoll.WaitingForInput  # type: ignore

     def failure_exception(self):
         raise Exception("Fake KVSender Exception")
@@ -401,7 +486,7 @@ class NixlKVReceiver(BaseKVReceiver):
         # NOTE: key distinguished by bootstrap_addr and engine_rank
         bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"

-        if bootstrap_key not in self.kv_mgr.connection_pool:
+        if bootstrap_key not in self.kv_mgr.prefill_peer_infos:
             self.bootstrap_info = self._get_bootstrap_info_from_server(
                 self.kv_mgr.kv_args.engine_rank
             )
@@ -410,20 +495,74 @@ class NixlKVReceiver(BaseKVReceiver):
                     f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
                 )
             else:
-                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
+                self.kv_mgr.prefill_peer_infos[bootstrap_key] = self.bootstrap_info
         else:
-            self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
+            self.bootstrap_info = self.kv_mgr.prefill_peer_infos[bootstrap_key]

         assert self.bootstrap_info is not None

-    def _get_bootstrap_info_from_server(self, engine_rank):
+    # Returns a list of dicts, [(remote_engine_rank -> NixlEngineInfo), ...].
+    # Each dict holds several interchangeable remotes, called "equal sources";
+    # we select exactly one per dict to split the traffic, so len(list) remotes
+    # are selected in total.
+    def _get_bootstrap_info_from_server(
+        self, engine_rank
+    ) -> Optional[List[Dict[int, NixlEngineInfo]]]:
         """Fetch the bootstrap info from the bootstrap server."""
         try:
+            if self.kv_mgr.enable_dp_attention:
+                url = f"http://{self.bootstrap_addr}/route"
+                response = requests.get(url)
+                if response.status_code != 200:
+                    logger.error(
+                        f"Failed to get prefill server info: {response.status_code}, {response.text}"
+                    )
+                    return None
+
+                bootstrap_info = response.json()
+                assert isinstance(bootstrap_info, dict)
+                bootstrap_info = {int(k): v for k, v in bootstrap_info.items()}
+
+                # Work out which prefill ranks need to send to this decode rank.
+                # For DeepSeek-style MLA models these ranks share the same latent
+                # cache, so any one of them can serve as the real source.
+                prefill_tp_size = len(bootstrap_info.keys())
+                assert (
+                    prefill_tp_size >= self.kv_mgr.tp_size_of_dp
+                ), f"Only support Prefill TP size >= Decode TP size of DP, now we have {prefill_tp_size} vs {self.kv_mgr.tp_size_of_dp}"
+                num_remote_tp_rank_we_managed = (
+                    prefill_tp_size // self.kv_mgr.tp_size_of_dp
+                )
+
+                # We handle [num * self.attn_tp_rank, num * self.attn_tp_rank + num).
+                remote_tp_ranks = list(range(0, prefill_tp_size))
+                # Split the ranks into tp_size_of_dp parts and take our part.
+                remote_tp_ranks_grouped = [
+                    remote_tp_ranks[i : i + num_remote_tp_rank_we_managed]
+                    for i in range(0, prefill_tp_size, self.kv_mgr.tp_size_of_dp)
+                ]
+                managed_ranks = remote_tp_ranks_grouped[self.kv_mgr.attn_tp_rank]
+                assert len(managed_ranks) == num_remote_tp_rank_we_managed
+                logger.debug(
+                    f"Rank {self.kv_mgr.kv_args.engine_rank} source can be {managed_ranks}"
+                )
+
+                return [
+                    {
+                        rk: bootstrap_info[rk]
+                        for rk in bootstrap_info.keys()
+                        if rk in managed_ranks
+                    }
+                ]
+            else:
                 url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
                 response = requests.get(url)
                 if response.status_code == 200:
                     bootstrap_info = response.json()
-                    return bootstrap_info
+                    return [{engine_rank: bootstrap_info}]
                 else:
                     logger.error(
                         f"Failed to get prefill server info: {response.status_code}, {response.text}"
@@ -440,11 +579,17 @@ class NixlKVReceiver(BaseKVReceiver):
         return socket

     def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
-        self.prefill_server_url = (
-            f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
-        )
+        assert self.bootstrap_info is not None
+        assert self.bootstrap_room is not None
+
+        for equal_sources in self.bootstrap_info:
+            remote_rank = list(equal_sources.keys())[
+                self.bootstrap_room % len(equal_sources)
+            ]
+            self.prefill_server_url = f"{equal_sources[remote_rank]['rank_ip']}:{equal_sources[remote_rank]['rank_port']}"
             logger.debug(
-                f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
+                f"Fetched bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}, source: {remote_rank}, all: {list(equal_sources.keys())}"
             )

             packed_kv_data_ptrs = b"".join(
@@ -453,8 +598,13 @@ class NixlKVReceiver(BaseKVReceiver):
             packed_aux_data_ptrs = b"".join(
                 struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
             )
+
+            logger.debug(
+                f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
+            )
             self._connect("tcp://" + self.prefill_server_url).send_multipart(
                 [
+                    GUARD,
                     str(self.bootstrap_room).encode("ascii"),
                     get_local_ip_by_remote().encode("ascii"),
                     str(self.kv_mgr.rank_port).encode("ascii"),
@@ -466,17 +616,30 @@ class NixlKVReceiver(BaseKVReceiver):
                     str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
                 ]
             )

+            for dummy_rank in equal_sources.keys():
+                if dummy_rank == remote_rank:
+                    continue
+                dummy_info = equal_sources[dummy_rank]
+                dummy_url = f"{dummy_info['rank_ip']}:{dummy_info['rank_port']}"
+                self._connect("tcp://" + dummy_url).send_multipart(
+                    [
+                        GUARD,
+                        str(self.bootstrap_room).encode("ascii"),
+                    ]
+                )
+
         self.started_transfer = True

     def poll(self) -> KVPoll:
         if not self.started_transfer:
-            return KVPoll.WaitingForInput
+            return KVPoll.WaitingForInput  # type: ignore

         self.kv_mgr.update_transfer_status()
-        if self.kv_mgr.check_transfer_done(self.bootstrap_room):
-            return KVPoll.Success
-        return KVPoll.WaitingForInput
+        if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
+            return KVPoll.Success  # type: ignore
+        return KVPoll.WaitingForInput  # type: ignore

     def failure_exception(self):
         raise Exception("Fake KVReceiver Exception")
@@ -484,6 +647,7 @@ class NixlKVReceiver(BaseKVReceiver):

 class NixlKVBootstrapServer(BaseKVBootstrapServer):
     def __init__(self, port: int):
+        logger.debug(f"NixlKVBootstrapServer started on port {port}")
         self.port = port
         self.app = web.Application()
         self.store = dict()
@@ -564,8 +728,8 @@ class NixlKVBootstrapServer(BaseKVBootstrapServer):
         engine_rank = int(data["engine_rank"])
         agent_name = data["agent_name"]

-        # Add lock to make sure thread-safe
         if role == "Prefill":
+            async with self.lock:
                 self.prefill_port_table[engine_rank] = {
                     "rank_ip": rank_ip,
                     "rank_port": rank_port,
@@ -580,7 +744,13 @@ class NixlKVBootstrapServer(BaseKVBootstrapServer):
     async def _handle_route_get(self, request: web.Request):
         engine_rank = request.query.get("engine_rank")
         if not engine_rank:
-            return web.Response(text="Missing rank", status=400)
+            logger.debug(
+                f"No engine_rank specified, return all {len(self.prefill_port_table)} engine infos as a dict"
+            )
+            # Return a dict covering all engine ranks.
+            async with self.lock:
+                bootstrap_info = self.prefill_port_table
+            return web.json_response(bootstrap_info, status=200)

         # Find corresponding prefill info
         async with self.lock:
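
A sketch of the new no-rank /route response, with placeholder addresses (JSON keys are prefill engine ranks serialized as strings; the decode side converts them back with int(k)):

    # GET http://<bootstrap_addr>/route
    # {
    #   "0": {"rank_ip": "10.0.0.1", "rank_port": 18000, ...},
    #   "1": {"rank_ip": "10.0.0.2", "rank_port": 18000, ...}
    # }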