Unverified Commit 97ac42b6 authored by Yongtong Wu, committed by GitHub

[PD] NIXL backend Prefill TP & Decode TP+DP (#5681)

parent 1acca3a2
......@@ -10,7 +10,7 @@ import threading
import uuid
from collections import defaultdict
from functools import cache
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
import numpy as np
import numpy.typing as npt
......@@ -32,6 +32,38 @@ from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
logger = logging.getLogger(__name__)
NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]
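# A typical NixlEngineInfo value (hypothetical addresses), mirroring what the
# bootstrap server stores per rank:
# {"rank_ip": "10.0.0.1", "rank_port": 18000, "agent_name": "prefill-agent-0"}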
# From Mooncake backend.
def group_concurrent_contiguous(
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
src_groups = []
dst_groups = []
current_src = [src_indices[0]]
current_dst = [dst_indices[0]]
for i in range(1, len(src_indices)):
src_contiguous = src_indices[i] == src_indices[i - 1] + 1
dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
if src_contiguous and dst_contiguous:
current_src.append(src_indices[i])
current_dst.append(dst_indices[i])
else:
src_groups.append(current_src)
dst_groups.append(current_dst)
current_src = [src_indices[i]]
current_dst = [dst_indices[i]]
src_groups.append(current_src)
dst_groups.append(current_dst)
return src_groups, dst_groups
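# A worked example with hypothetical indices: contiguous runs in both arrays
# are merged so each run can later become a single transfer descriptor:
#   src = [0, 1, 2, 7, 8], dst = [10, 11, 12, 20, 21]
#   -> src_groups = [[0, 1, 2], [7, 8]], dst_groups = [[10, 11, 12], [20, 21]]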
GUARD = "NixlMsgGuard".encode("ascii")
@dataclasses.dataclass
class TransferInfo:
......@@ -45,19 +77,36 @@ class TransferInfo:
dst_aux_index: int
dst_gpu_id: int
def is_dummy(self):
return self.endpoint == ""
@classmethod
def from_zmq(cls, msg: List[bytes]):
return cls(
room=int(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"),
dst_port=int(msg[2].decode("ascii")),
agent_metadata=msg[3],
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
dst_aux_index=int(msg[7].decode("ascii")),
dst_gpu_id=int(msg[8].decode("ascii")),
)
if len(msg) == 1:
# Dummy message: room id only, no payload.
return cls(
room=int(msg[0].decode("ascii")),
endpoint="",
dst_port=0,
agent_metadata=b"",
dst_kv_ptrs=[],
dst_kv_indices=np.array([], dtype=np.int64),
dst_aux_ptrs=[],
dst_aux_index=0,
dst_gpu_id=0,
)
else:
return cls(
room=int(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"),
dst_port=int(msg[2].decode("ascii")),
agent_metadata=msg[3],
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
dst_aux_index=int(msg[7].decode("ascii")),
dst_gpu_id=int(msg[8].decode("ascii")),
)
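# Wire format after the GUARD frame is stripped (values hypothetical):
#   full request: [b"42", b"10.0.0.2", b"18001", agent_metadata,
#                  packed_kv_ptrs, kv_indices_bytes, packed_aux_ptrs, b"3", b"0"]
#   dummy request: [b"42"]  (room id only; this peer is not the real source)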
@dataclasses.dataclass
......@@ -98,6 +147,19 @@ class NixlKVManager(BaseKVManager):
# for p/d multi node infer
self.bootstrap_port = server_args.disaggregation_bootstrap_port
self.dist_init_addr = server_args.dist_init_addr
self.tp_size = server_args.tp_size
self.tp_rank = args.engine_rank
self.enable_dp_attention = server_args.enable_dp_attention
if self.enable_dp_attention:
assert (
server_args.dp_size > 1
), "If dp_attention is enabled, dp size must be greater than 1 in disaggregation mode."
self.dp_size = server_args.dp_size
self.tp_size_of_dp = server_args.tp_size // server_args.dp_size
self.attn_tp_rank = args.engine_rank % self.tp_size_of_dp
self.dp_rank = args.engine_rank // self.tp_size_of_dp
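# Worked example with hypothetical sizes: tp_size=8, dp_size=2 gives
# tp_size_of_dp=4; engine_rank=5 then yields attn_tp_rank = 5 % 4 = 1
# and dp_rank = 5 // 4 = 1.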
self.rank_port = None
self.server_socket = zmq.Context().socket(zmq.PULL)
self.register_buffer_to_engine()
......@@ -110,7 +172,8 @@ class NixlKVManager(BaseKVManager):
self._start_bootstrap_thread()
self._register_to_bootstrap()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
# bootstrap key -> (remote_engine_rank -> possible remote source info)
self.prefill_peer_infos: Dict[str, list[Dict[int, NixlEngineInfo]]] = {}
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
TransferStatus
)
......@@ -126,6 +189,7 @@ class NixlKVManager(BaseKVManager):
):
kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
if not self.kv_descs:
raise Exception("NIXL memory registration failed for kv tensors")
aux_addrs = []
......@@ -134,6 +198,7 @@ class NixlKVManager(BaseKVManager):
):
aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
if not self.aux_descs:
raise Exception("NIXL memory registration failed for aux tensors")
......@@ -157,6 +222,12 @@ class NixlKVManager(BaseKVManager):
dst_gpu_id: int,
notif: str,
):
# Group contiguous indices into runs.
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
prefill_kv_indices, dst_kv_indices
)
logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
# Make descs
num_layers = len(self.kv_args.kv_data_ptrs)
src_addrs = []
......@@ -166,12 +237,16 @@ class NixlKVManager(BaseKVManager):
dst_ptr = dst_kv_ptrs[layer_id]
item_len = self.kv_args.kv_item_lens[layer_id]
for prefill_index, decode_index in zip(prefill_kv_indices, dst_kv_indices):
src_addr = src_ptr + int(prefill_index) * item_len
dst_addr = dst_ptr + int(decode_index) * item_len
length = item_len
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
src_addr = src_ptr + int(prefill_index[0]) * item_len
dst_addr = dst_ptr + int(decode_index[0]) * item_len
length = item_len * len(prefill_index)
src_addrs.append((src_addr, length, self.kv_args.gpu_id))
dst_addrs.append((dst_addr, length, dst_gpu_id))
logger.debug(
f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
)
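# Example with hypothetical values: with item_len=4096 and a contiguous run
# prefill_index=[5, 6, 7], a single descriptor of length 3 * 4096 starting at
# src_ptr + 5 * 4096 replaces three per-index descriptors.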
src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
# Transfer data
......@@ -180,7 +255,7 @@ class NixlKVManager(BaseKVManager):
src_descs,
dst_descs,
peer_name,
notif.encode("ascii"),
notif.encode("ascii"), # type: ignore
)
if not xfer_handle:
raise Exception("KVSender failed to create transfer")
......@@ -213,7 +288,7 @@ class NixlKVManager(BaseKVManager):
src_descs,
dst_descs,
peer_name,
notif.encode("ascii"),
notif.encode("ascii"), # type: ignore
)
if not xfer_handle:
raise Exception("KVSender failed to create transfer")
......@@ -240,6 +315,9 @@ class NixlKVManager(BaseKVManager):
req = self.transfer_infos[bootstrap_room]
assert bootstrap_room == req.room
if req.is_dummy():
return []
peer_name = self._add_remote(bootstrap_room, req.agent_metadata)
chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
assert len(chunked_dst_kv_indice) == len(kv_indices)
......@@ -256,6 +334,7 @@ class NixlKVManager(BaseKVManager):
handles = [kv_xfer_handle]
# Only for the last chunk do we need to send the aux data.
if is_last:
assert aux_index is not None
aux_xfer_handle = self.send_aux(
peer_name,
aux_index,
......@@ -325,6 +404,13 @@ class NixlKVManager(BaseKVManager):
"""This thread recvs transfer info from the decode engine"""
while True:
waiting_req_bytes = self.server_socket.recv_multipart()
logger.debug(
f"Received multipart with total byte size {sum(len(x) for x in waiting_req_bytes)}"
)
assert (
waiting_req_bytes[0] == GUARD
), f"First message should be {GUARD}. Foreign traffic?"
waiting_req_bytes = waiting_req_bytes[1:]
room = waiting_req_bytes[0].decode("ascii")
if room == "None":
continue
......@@ -372,14 +458,13 @@ class NixlKVSender(BaseKVSender):
def poll(self) -> KVPoll:
if not self.has_sent:
return KVPoll.WaitingForInput
return KVPoll.WaitingForInput # type: ignore
states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles]
if all([x == "DONE" for x in states]):
return KVPoll.Success
return KVPoll.Success # type: ignore
if any([x == "ERR" for x in states]):
raise Exception("KVSender transfer encountered an error.")
return KVPoll.WaitingForInput
return KVPoll.WaitingForInput # type: ignore
def failure_exception(self):
raise Exception("Fake KVSender Exception")
......@@ -401,7 +486,7 @@ class NixlKVReceiver(BaseKVReceiver):
# NOTE: key distinguished by bootstrap_addr and engine_rank
bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
if bootstrap_key not in self.kv_mgr.connection_pool:
if bootstrap_key not in self.kv_mgr.prefill_peer_infos:
self.bootstrap_info = self._get_bootstrap_info_from_server(
self.kv_mgr.kv_args.engine_rank
)
......@@ -410,25 +495,79 @@ class NixlKVReceiver(BaseKVReceiver):
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)
else:
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
self.kv_mgr.prefill_peer_infos[bootstrap_key] = self.bootstrap_info
else:
self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
self.bootstrap_info = self.kv_mgr.prefill_peer_infos[bootstrap_key]
assert self.bootstrap_info is not None
def _get_bootstrap_info_from_server(self, engine_rank):
# Returns a list of dicts, [(remote_engine_rank -> NixlEngineInfo), ...].
# Each dict holds several interchangeable remotes, called "equal sources";
# we select exactly one from each dict to spread the traffic, i.e. len(list) remotes in total.
def _get_bootstrap_info_from_server(
self, engine_rank
) -> Optional[List[Dict[int, NixlEngineInfo]]]:
"""Fetch the bootstrap info from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
response = requests.get(url)
if response.status_code == 200:
if self.kv_mgr.enable_dp_attention:
url = f"http://{self.bootstrap_addr}/route"
response = requests.get(url)
if response.status_code != 200:
logger.error(
f"Failed to get prefill server info: {response.status_code}, {response.text}"
)
return None
bootstrap_info = response.json()
return bootstrap_info
else:
logger.error(
f"Failed to get prefill server info: {response.status_code}, {response.text}"
assert isinstance(bootstrap_info, dict)
bootstrap_info = {int(k): v for k, v in bootstrap_info.items()}
# Work out which prefill ranks need to send to this decode rank.
# Currently, for DeepSeek MLA models, those ranks share the same latent cache;
# pick one of them as the real source.
prefill_tp_size = len(bootstrap_info.keys())
assert (
prefill_tp_size >= self.kv_mgr.tp_size_of_dp
), f"Only support Prefill TP size >= Decode TP size of DP, now we have {prefill_tp_size} vs {self.kv_mgr.tp_size_of_dp}"
num_remote_tp_rank_we_managed = (
prefill_tp_size // self.kv_mgr.tp_size_of_dp
)
# We handle [num * self.attn_tp_rank, num * self.attn_tp_rank + num)
remote_tp_ranks = list(range(0, prefill_tp_size))
# Split the ranks into tp_size_of_dp groups of
# num_remote_tp_rank_we_managed ranks each, then take our group.
remote_tp_ranks_grouped = [
remote_tp_ranks[i : i + num_remote_tp_rank_we_managed]
for i in range(0, prefill_tp_size, num_remote_tp_rank_we_managed)
]
managed_ranks = remote_tp_ranks_grouped[self.kv_mgr.attn_tp_rank]
assert len(managed_ranks) == num_remote_tp_rank_we_managed
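# Worked example with hypothetical sizes: prefill_tp_size=8 and
# tp_size_of_dp=4 give num_remote_tp_rank_we_managed=2 and groups
# [[0, 1], [2, 3], [4, 5], [6, 7]]; attn_tp_rank=1 then manages
# remote prefill ranks [2, 3].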
logger.debug(
f"Rank {self.kv_mgr.kv_args.engine_rank} source can be {managed_ranks}"
)
return None
return [
{
rk: bootstrap_info[rk]
for rk in bootstrap_info.keys()
if rk in managed_ranks
}
]
else:
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
response = requests.get(url)
if response.status_code == 200:
bootstrap_info = response.json()
return [{engine_rank: bootstrap_info}]
else:
logger.error(
f"Failed to get prefill server info: {response.status_code}, {response.text}"
)
return None
except Exception as e:
logger.error(f"Error fetching prefill info from bootstrap: {e}")
return None
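# Response shapes from the bootstrap server (values hypothetical):
#   GET /route?engine_rank=0 -> {"rank_ip": "10.0.0.1", "rank_port": 18000, "agent_name": "..."}
#   GET /route -> {"0": {...}, "1": {...}} (the full table, keyed by engine rank)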
......@@ -440,43 +579,67 @@ class NixlKVReceiver(BaseKVReceiver):
return socket
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
self.prefill_server_url = (
f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
)
logger.debug(
f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)
packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
packed_aux_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
)
self._connect("tcp://" + self.prefill_server_url).send_multipart(
[
str(self.bootstrap_room).encode("ascii"),
get_local_ip_by_remote().encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.kv_mgr.agent.get_agent_metadata(),
packed_kv_data_ptrs,
kv_indices.tobytes(),
packed_aux_data_ptrs,
str(aux_index).encode("ascii"),
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
assert self.bootstrap_info is not None
assert self.bootstrap_room is not None
for equal_sources in self.bootstrap_info:
remote_rank = list(equal_sources.keys())[
self.bootstrap_room % len(equal_sources)
]
)
self.prefill_server_url = f"{equal_sources[remote_rank]['rank_ip']}:{equal_sources[remote_rank]['rank_port']}"
logger.debug(
f"Fetched bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}, source: {remote_rank}, all: {list(equal_sources.keys())}"
)
packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
packed_aux_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
)
logger.debug(
f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room}"
)
self._connect("tcp://" + self.prefill_server_url).send_multipart(
[
GUARD,
str(self.bootstrap_room).encode("ascii"),
get_local_ip_by_remote().encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.kv_mgr.agent.get_agent_metadata(),
packed_kv_data_ptrs,
kv_indices.tobytes(),
packed_aux_data_ptrs,
str(aux_index).encode("ascii"),
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
]
)
for dummy_rank in equal_sources.keys():
if dummy_rank == remote_rank:
continue
dummy_info = equal_sources[dummy_rank]
dummy_url = f"{dummy_info['rank_ip']}:{dummy_info['rank_port']}"
self._connect("tcp://" + dummy_url).send_multipart(
[
GUARD,
str(self.bootstrap_room).encode("ascii"),
]
)
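# Only the selected source rank receives the full request; every other
# "equal source" gets just [GUARD, room], so its TransferInfo resolves as a
# dummy (see is_dummy) and no KV transfer is issued from that rank.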
self.started_transfer = True
def poll(self) -> KVPoll:
if not self.started_transfer:
return KVPoll.WaitingForInput
return KVPoll.WaitingForInput # type: ignore
self.kv_mgr.update_transfer_status()
if self.kv_mgr.check_transfer_done(self.bootstrap_room):
return KVPoll.Success
return KVPoll.WaitingForInput
if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
return KVPoll.Success # type: ignore
return KVPoll.WaitingForInput # type: ignore
def failure_exception(self):
raise Exception("Fake KVReceiver Exception")
......@@ -484,6 +647,7 @@ class NixlKVReceiver(BaseKVReceiver):
class NixlKVBootstrapServer(BaseKVBootstrapServer):
def __init__(self, port: int):
logger.debug(f"NixlKVBootstrapServer started on port {port}")
self.port = port
self.app = web.Application()
self.store = dict()
......@@ -564,13 +728,13 @@ class NixlKVBootstrapServer(BaseKVBootstrapServer):
engine_rank = int(data["engine_rank"])
agent_name = data["agent_name"]
# Hold the lock to keep registration thread-safe.
if role == "Prefill":
self.prefill_port_table[engine_rank] = {
"rank_ip": rank_ip,
"rank_port": rank_port,
"agent_name": agent_name,
}
async with self.lock:
self.prefill_port_table[engine_rank] = {
"rank_ip": rank_ip,
"rank_port": rank_port,
"agent_name": agent_name,
}
logger.info(
f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port} and name: {agent_name}"
)
......@@ -580,7 +744,13 @@ class NixlKVBootstrapServer(BaseKVBootstrapServer):
async def _handle_route_get(self, request: web.Request):
engine_rank = request.query.get("engine_rank")
if not engine_rank:
return web.Response(text="Missing rank", status=400)
logger.debug(
f"No engine_rank specified, returning all {len(self.prefill_port_table)} engine infos as a dict"
)
# Return a dict keyed by engine_rank
async with self.lock:
bootstrap_info = self.prefill_port_table
return web.json_response(bootstrap_info, status=200)
# Find corresponding prefill info
async with self.lock:
......