Unverified Commit 4dce1cc6 authored by Trevor Morris's avatar Trevor Morris Committed by GitHub
Browse files

[PD] Add NIXL transfer backend (#5477)

parent deded17f
from .conn import NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender
from __future__ import annotations
import asyncio
import dataclasses
import logging
import queue
import socket
import struct
import threading
import uuid
from collections import defaultdict
from functools import cache
from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import numpy.typing as npt
import requests
import zmq
from aiohttp import web
from sglang.srt.disaggregation.base.conn import (
BaseKVBootstrapServer,
BaseKVManager,
BaseKVReceiver,
BaseKVSender,
KVArgs,
KVPoll,
)
from sglang.srt.disaggregation.utils import DisaggregationMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
logger = logging.getLogger(__name__)
@dataclasses.dataclass
class TransferInfo:
room: int
endpoint: str
dst_port: int
agent_metadata: bytes
dst_kv_ptrs: list[int]
dst_kv_indices: npt.NDArray[np.int64]
dst_aux_ptrs: list[int]
dst_aux_index: int
dst_gpu_id: int
@classmethod
def from_zmq(cls, msg: List[bytes]):
return cls(
room=int(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"),
dst_port=int(msg[2].decode("ascii")),
agent_metadata=msg[3],
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
dst_kv_indices=np.frombuffer(msg[5], dtype=np.int64),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
dst_aux_index=int(msg[7].decode("ascii")),
dst_gpu_id=int(msg[8].decode("ascii")),
)
@dataclasses.dataclass
class TransferStatus:
"""Used by KV Receiver to know when a transfer is done."""
# KV chunk IDs that have been received.
received_kvs: Set[int] = dataclasses.field(default_factory=set)
# Number of kv chunks to expect, will know this after last chunk is received.
num_kvs_expected: Optional[int] = None
# Whether aux data has been received.
received_aux: bool = False
def is_done(self):
if self.num_kvs_expected is None:
return False
return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
class NixlKVManager(BaseKVManager):
def __init__(
self,
args: KVArgs,
disaggregation_mode: DisaggregationMode,
server_args: ServerArgs,
):
try:
from nixl._api import nixl_agent
except ImportError as e:
raise ImportError(
"Please install NIXL by following the instructions at "
"https://github.com/ai-dynamo/nixl/blob/main/README.md "
"to run SGLang with NixlTransferEngine."
) from e
self.agent = nixl_agent(str(uuid.uuid4()))
self.kv_args = args
self.disaggregation_mode = disaggregation_mode
# for p/d multi node infer
self.bootstrap_port = server_args.disaggregation_bootstrap_port
self.dist_init_addr = server_args.dist_init_addr
self.rank_port = None
self.server_socket = zmq.Context().socket(zmq.PULL)
self.register_buffer_to_engine()
self.rank_port = get_free_port()
if self.disaggregation_mode == DisaggregationMode.PREFILL:
self.transfer_infos: Dict[int, TransferInfo] = {}
self.condition = threading.Condition()
self.peer_names: Dict[int, str] = {}
self._start_bootstrap_thread()
self._register_to_bootstrap()
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
TransferStatus
)
else:
raise ValueError(
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
)
def register_buffer_to_engine(self):
kv_addrs = []
for kv_data_ptr, kv_data_len in zip(
self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
):
kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
if not self.kv_descs:
raise Exception("NIXL memory registration failed for kv tensors")
aux_addrs = []
for aux_data_ptr, aux_data_len in zip(
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
):
aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
if not self.aux_descs:
raise Exception("NIXL memory registration failed for aux tensors")
@cache
def _connect(self, endpoint: str):
socket = zmq.Context().socket(zmq.PUSH)
socket.connect(endpoint)
return socket
def _add_remote(self, room: int, agent_metadata: bytes):
if room not in self.peer_names:
self.peer_names[room] = self.agent.add_remote_agent(agent_metadata)
return self.peer_names[room]
def send_kvcache(
self,
peer_name: str,
prefill_kv_indices: npt.NDArray[np.int64],
dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int64],
dst_gpu_id: int,
notif: str,
):
# Make descs
num_layers = len(self.kv_args.kv_data_ptrs)
src_addrs = []
dst_addrs = []
for layer_id in range(num_layers):
src_ptr = self.kv_args.kv_data_ptrs[layer_id]
dst_ptr = dst_kv_ptrs[layer_id]
item_len = self.kv_args.kv_item_lens[layer_id]
for prefill_index, decode_index in zip(prefill_kv_indices, dst_kv_indices):
src_addr = src_ptr + int(prefill_index) * item_len
dst_addr = dst_ptr + int(decode_index) * item_len
length = item_len
src_addrs.append((src_addr, length, self.kv_args.gpu_id))
dst_addrs.append((dst_addr, length, dst_gpu_id))
src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
# Transfer data
xfer_handle = self.agent.initialize_xfer(
"WRITE",
src_descs,
dst_descs,
peer_name,
notif.encode("ascii"),
)
if not xfer_handle:
raise Exception("KVSender failed to create transfer")
state = self.agent.transfer(xfer_handle)
if state == "ERR":
raise Exception("KVSender failed to post transfer")
return xfer_handle
def send_aux(
self,
peer_name: str,
prefill_aux_index: int,
dst_aux_ptrs: list[int],
dst_aux_index: int,
notif: str,
):
# Make descs
aux_item_len = self.kv_args.aux_item_lens[0]
prefill_aux_addr = (
self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
)
decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=True)
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=True)
# Transfer data
xfer_handle = self.agent.initialize_xfer(
"WRITE",
src_descs,
dst_descs,
peer_name,
notif.encode("ascii"),
)
if not xfer_handle:
raise Exception("KVSender failed to create transfer")
state = self.agent.transfer(xfer_handle)
if state == "ERR":
raise Exception("KVSender failed to post transfer")
return xfer_handle
def add_transfer_request(
self,
bootstrap_room: int,
kv_indices: npt.NDArray[np.int64],
index_slice: slice,
is_last: bool,
chunk_id: int,
aux_index: Optional[int] = None,
):
assert self.disaggregation_mode == DisaggregationMode.PREFILL
assert not is_last or (is_last and aux_index is not None)
# Wait for transfer info to be populated by bootstrap thread.
with self.condition:
self.condition.wait_for(lambda: bootstrap_room in self.transfer_infos)
req = self.transfer_infos[bootstrap_room]
assert bootstrap_room == req.room
peer_name = self._add_remote(bootstrap_room, req.agent_metadata)
chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
assert len(chunked_dst_kv_indice) == len(kv_indices)
notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
kv_xfer_handle = self.send_kvcache(
peer_name,
kv_indices,
req.dst_kv_ptrs,
chunked_dst_kv_indice,
req.dst_gpu_id,
notif,
)
handles = [kv_xfer_handle]
# Only the last chunk we need to send the aux data.
if is_last:
aux_xfer_handle = self.send_aux(
peer_name,
aux_index,
req.dst_aux_ptrs,
req.dst_aux_index,
str(req.room) + "_aux",
)
handles.append(aux_xfer_handle)
return handles
def update_transfer_status(self):
# Process notifications from received transfers.
notif_map = self.agent.get_new_notifs()
for peer_name, messages in notif_map.items():
# We could also check that self.bootstrap_info['agent_name'] matches
# the message sender. But the bootstrap room alone should be
# sufficient to map the status.
for msg in messages:
components = msg.decode("ascii").split("_")
room = int(components[0])
if components[1] == "kv":
chunk_id = int(components[2])
is_last = bool(components[3])
self.transfer_statuses[room].received_kvs.add(chunk_id)
if is_last:
self.transfer_statuses[room].num_kvs_expected = chunk_id + 1
elif components[1] == "aux":
self.transfer_statuses[room].received_aux = True
def check_transfer_done(self, room: int):
if room not in self.transfer_statuses:
return False
return self.transfer_statuses[room].is_done()
def _register_to_bootstrap(self):
"""Register KVSender to bootstrap server via HTTP POST."""
if self.dist_init_addr:
ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
else:
ip_address = get_ip()
bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
url = f"http://{bootstrap_server_url}/route"
payload = {
"role": "Prefill",
"rank_ip": get_local_ip_by_remote(),
"rank_port": self.rank_port,
"engine_rank": self.kv_args.engine_rank,
"agent_name": self.agent.name,
}
try:
response = requests.put(url, json=payload)
if response.status_code == 200:
logger.debug("Prefill successfully registered to bootstrap server.")
else:
logger.error(
f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
)
except Exception as e:
logger.error(f"Prefill Failed to register to bootstrap server: {e}")
def _start_bootstrap_thread(self):
self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
def bootstrap_thread():
"""This thread recvs transfer info from the decode engine"""
while True:
waiting_req_bytes = self.server_socket.recv_multipart()
room = waiting_req_bytes[0].decode("ascii")
if room == "None":
continue
room = int(room)
with self.condition:
self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)
self.condition.notify_all()
threading.Thread(target=bootstrap_thread).start()
class NixlKVSender(BaseKVSender):
def __init__(self, mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: int):
self.kv_mgr = mgr
self.bootstrap_room = bootstrap_room
self.aux_index = None
self.bootstrap_server_url = bootstrap_addr
self.xfer_handles = []
self.has_sent = False
self.chunk_id = 0
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
self.num_kv_indices = num_kv_indices
self.aux_index = aux_index
def send(
self,
kv_indices: npt.NDArray[np.int64],
index_slice: slice,
is_last: bool,
):
new_xfer_handles = self.kv_mgr.add_transfer_request(
self.bootstrap_room,
kv_indices,
index_slice,
is_last,
self.chunk_id,
self.aux_index,
)
self.xfer_handles.extend(new_xfer_handles)
self.chunk_id += 1
if is_last:
self.has_sent = True
def poll(self) -> KVPoll:
if not self.has_sent:
return KVPoll.WaitingForInput
states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles]
if all([x == "DONE" for x in states]):
return KVPoll.Success
if any([x == "ERR" for x in states]):
raise Exception("KVSender transfer encountered an error.")
return KVPoll.WaitingForInput
def failure_exception(self):
raise Exception("Fake KVSender Exception")
class NixlKVReceiver(BaseKVReceiver):
def __init__(
self,
mgr: NixlKVManager,
bootstrap_addr: str,
bootstrap_room: Optional[int] = None,
):
self.bootstrap_room = bootstrap_room
self.bootstrap_addr = bootstrap_addr
self.kv_mgr = mgr
self.started_transfer = False
# NOTE: key distinguished by bootstrap_addr and engine_rank
bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
if bootstrap_key not in self.kv_mgr.connection_pool:
self.bootstrap_info = self._get_bootstrap_info_from_server(
self.kv_mgr.kv_args.engine_rank
)
if self.bootstrap_info is None:
logger.error(
f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)
else:
self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
else:
self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
assert self.bootstrap_info is not None
def _get_bootstrap_info_from_server(self, engine_rank):
"""Fetch the bootstrap info from the bootstrap server."""
try:
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
response = requests.get(url)
if response.status_code == 200:
bootstrap_info = response.json()
return bootstrap_info
else:
logger.error(
f"Failed to get prefill server info: {response.status_code}, {response.text}"
)
return None
except Exception as e:
logger.error(f"Error fetching prefill info from bootstrap: {e}")
return None
@cache
def _connect(self, endpoint: str):
socket = zmq.Context().socket(zmq.PUSH)
socket.connect(endpoint)
return socket
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
self.prefill_server_url = (
f"{self.bootstrap_info['rank_ip']}:{self.bootstrap_info['rank_port']}"
)
logger.debug(
f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
)
packed_kv_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
)
packed_aux_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
)
self._connect("tcp://" + self.prefill_server_url).send_multipart(
[
str(self.bootstrap_room).encode("ascii"),
get_local_ip_by_remote().encode("ascii"),
str(self.kv_mgr.rank_port).encode("ascii"),
self.kv_mgr.agent.get_agent_metadata(),
packed_kv_data_ptrs,
kv_indices.tobytes(),
packed_aux_data_ptrs,
str(aux_index).encode("ascii"),
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
]
)
self.started_transfer = True
def poll(self) -> KVPoll:
if not self.started_transfer:
return KVPoll.WaitingForInput
self.kv_mgr.update_transfer_status()
if self.kv_mgr.check_transfer_done(self.bootstrap_room):
return KVPoll.Success
return KVPoll.WaitingForInput
def failure_exception(self):
raise Exception("Fake KVReceiver Exception")
class NixlKVBootstrapServer(BaseKVBootstrapServer):
def __init__(self, port: int):
self.port = port
self.app = web.Application()
self.store = dict()
self.lock = asyncio.Lock()
self._setup_routes()
self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {}
# Start bootstrap server
self.thread = threading.Thread(target=self._run_server, daemon=True)
self.run()
def run(self):
self.thread.start()
def _setup_routes(self):
self.app.router.add_route("*", "/metadata", self._handle_metadata)
self.app.router.add_route("*", "/route", self._handle_route)
async def _handle_metadata(self, request: web.Request):
key = request.query.get("key", "")
if request.method == "GET":
return await self._handle_metadata_get(key)
elif request.method == "PUT":
return await self._handle_metadata_put(key, request)
elif request.method == "DELETE":
return await self._handle_metadata_delete(key)
return web.Response(
text="Method not allowed", status=405, content_type="application/json"
)
async def _handle_metadata_get(self, key):
async with self.lock:
value = self.store.get(key)
if value is None:
return web.Response(
text="metadata not found", status=404, content_type="application/json"
)
return web.Response(body=value, status=200, content_type="application/json")
async def _handle_metadata_put(self, key, request):
data = await request.read()
async with self.lock:
self.store[key] = data
return web.Response(
text="metadata updated", status=200, content_type="application/json"
)
async def _handle_metadata_delete(self, key):
async with self.lock:
if key not in self.store:
return web.Response(
text="metadata not found",
status=404,
content_type="application/json",
)
del self.store[key]
return web.Response(
text="metadata deleted", status=200, content_type="application/json"
)
async def _handle_route(self, request: web.Request):
method = request.method
if method == "PUT":
return await self._handle_route_put(request)
elif method == "GET":
return await self._handle_route_get(request)
else:
return web.Response(
text="Method not allowed", status=405, content_type="application/json"
)
async def _handle_route_put(self, request: web.Request):
data = await request.json()
role = data["role"]
rank_ip = data["rank_ip"]
rank_port = int(data["rank_port"])
engine_rank = int(data["engine_rank"])
agent_name = data["agent_name"]
# Add lock to make sure thread-safe
if role == "Prefill":
self.prefill_port_table[engine_rank] = {
"rank_ip": rank_ip,
"rank_port": rank_port,
"agent_name": agent_name,
}
logger.info(
f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port} and name: {agent_name}"
)
return web.Response(text="OK", status=200)
async def _handle_route_get(self, request: web.Request):
engine_rank = request.query.get("engine_rank")
if not engine_rank:
return web.Response(text="Missing rank", status=400)
# Find corresponding prefill info
async with self.lock:
bootstrap_info = self.prefill_port_table.get(int(engine_rank))
if bootstrap_info is not None:
return web.json_response(bootstrap_info, status=200)
else:
return web.Response(text="Not Found", status=404)
def _run_server(self):
try:
# Event Loop
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
self._runner = web.AppRunner(self.app)
self._loop.run_until_complete(self._runner.setup())
site = web.TCPSite(self._runner, port=self.port)
self._loop.run_until_complete(site.start())
self._loop.run_forever()
except Exception as e:
logger.error(f"Server error: {str(e)}")
finally:
# Cleanup
self._loop.run_until_complete(self._runner.cleanup())
self._loop.close()
def close(self):
"""Shutdown"""
if self._loop is not None and self._loop.is_running():
self._loop.call_soon_threadsafe(self._loop.stop)
logger.info("Stopping server loop...")
if self.thread.is_alive():
self.thread.join(timeout=2)
logger.info("Server thread stopped")
def poll(self) -> KVPoll: ...
......@@ -47,6 +47,7 @@ class ReqToMetadataIdxAllocator:
class TransferBackend(Enum):
MOONCAKE = "mooncake"
NIXL = "nixl"
FAKE = "fake"
......@@ -73,6 +74,21 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
}
return class_mapping.get(class_type)
if transfer_backend == TransferBackend.NIXL:
from sglang.srt.disaggregation.nixl import (
NixlKVBootstrapServer,
NixlKVManager,
NixlKVReceiver,
NixlKVSender,
)
class_mapping = {
KVClassType.MANAGER: NixlKVManager,
KVClassType.SENDER: NixlKVSender,
KVClassType.RECEIVER: NixlKVReceiver,
KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
}
return class_mapping.get(class_type)
raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
......
......@@ -1211,6 +1211,7 @@ class ServerArgs:
"--disaggregation-transfer-backend",
type=str,
default=ServerArgs.disaggregation_transfer_backend,
choices=["mooncake", "nixl"],
help="The backend for disaggregation transfer. Default is mooncake.",
)
parser.add_argument(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment