from __future__ import annotations

import dataclasses
import logging
import struct
import threading
import uuid
from collections import defaultdict
from typing import Dict, List, Optional, Set, TypeAlias, Union

import numpy as np
import numpy.typing as npt
import zmq

from sglang.srt.disaggregation.base.conn import BaseKVSender, KVArgs, KVPoll
from sglang.srt.disaggregation.common.conn import (
    CommonKVBootstrapServer,
    CommonKVManager,
    CommonKVReceiver,
)
from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    group_concurrent_contiguous,
)
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import get_local_ip_by_remote

logger = logging.getLogger(__name__)

NixlEngineInfo: TypeAlias = Dict[str, Union[str, int]]

# Guard frame prepended to every bootstrap message so that foreign traffic on
# the bootstrap port can be detected and rejected.
GUARD = "NixlMsgGuard".encode("ascii")


@dataclasses.dataclass
class TransferInfo:
    """Destination-side transfer metadata sent by a decode rank at bootstrap."""

    room: int
    endpoint: str
    dst_port: int
    agent_metadata: bytes
    agent_name: str
    dst_kv_ptrs: list[int]
    dst_kv_indices: npt.NDArray[np.int64]
    dst_aux_ptrs: list[int]
    dst_aux_index: int
    dst_gpu_id: int
    required_dst_info_num: int

    def is_dummy(self):
        return self.dst_kv_indices.size == 0

    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        return cls(
            room=int(msg[0].decode("ascii")),
            endpoint=msg[1].decode("ascii"),
            dst_port=int(msg[2].decode("ascii")),
            agent_metadata=msg[3],
            agent_name=msg[4].decode("ascii"),
            dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
            dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64),
            dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
            dst_aux_index=int(msg[8].decode("ascii")),
            dst_gpu_id=int(msg[9].decode("ascii")),
            required_dst_info_num=int(msg[10].decode("ascii")),
        )


@dataclasses.dataclass
class TransferStatus:
    """Used by the KV receiver to know when a transfer is done."""

    # KV chunk IDs that have been received.
    received_kvs: Set[int] = dataclasses.field(default_factory=set)
    # Number of KV chunks to expect; known only after the last chunk arrives.
    num_kvs_expected: Optional[int] = None
    # Whether the aux data has been received.
    received_aux: bool = False

    def is_done(self):
        if self.num_kvs_expected is None:
            return False
        return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
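
# For intuition: a room's transfer is complete only once (a) the chunk marked
# is_last has arrived, which fixes num_kvs_expected, (b) every chunk id has
# been seen, and (c) the aux notification has arrived. A minimal sketch in
# pure Python (no NIXL required; the values are made up):
#
#   ts = TransferStatus()
#   ts.received_kvs.update({0, 1})
#   ts.received_aux = True
#   assert not ts.is_done()  # last chunk not yet seen
#   ts.received_kvs.add(2)
#   ts.num_kvs_expected = 3  # set when the chunk marked is_last arrives
#   assert ts.is_done()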


class NixlKVManager(CommonKVManager):
    def __init__(
        self,
        args: KVArgs,
        disaggregation_mode: DisaggregationMode,
        server_args: ServerArgs,
        is_mla_backend: Optional[bool] = False,
    ):
        super().__init__(args, disaggregation_mode, server_args, is_mla_backend)
        try:
            from nixl._api import nixl_agent
        except ImportError as e:
            raise ImportError(
                "Please install NIXL by following the instructions at "
                "https://github.com/ai-dynamo/nixl/blob/main/README.md "
                "to run SGLang with NixlTransferEngine."
            ) from e
        self.agent = nixl_agent(str(uuid.uuid4()))
        self.server_socket = zmq.Context().socket(zmq.PULL)
        self.register_buffer_to_engine()

        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self.request_status: Dict[int, KVPoll] = {}
            # Keyed by room, then by the decode agent's name, since several
            # decode ranks may attach to the same room.
            self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
            self.peer_names: Dict[str, str] = {}
            self._start_bootstrap_thread()
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
                TransferStatus
            )
        else:
            raise ValueError(
                f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
            )

    def check_status(self, bootstrap_room: int):
        return self.request_status[bootstrap_room]

    def update_status(self, bootstrap_room: int, status: KVPoll):
        if bootstrap_room not in self.request_status:
            self.request_status[bootstrap_room] = status
        else:
            # NOTE: The prefill engine could receive bootstrapping first.
            self.request_status[bootstrap_room] = max(
                self.request_status[bootstrap_room], status
            )

    def register_buffer_to_engine(self):
        kv_addrs = []
        for kv_data_ptr, kv_data_len in zip(
            self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
        ):
            kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=True)
        logger.debug(f"Register kv tensors, len(kv_addrs) = {len(kv_addrs)}")
        if not self.kv_descs:
            raise Exception("NIXL memory registration failed for kv tensors")

        aux_addrs = []
        for aux_data_ptr, aux_data_len in zip(
            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
        ):
            aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=True)
        logger.debug(f"Register aux tensors, len(aux_addrs) = {len(aux_addrs)}")
        if not self.aux_descs:
            raise Exception("NIXL memory registration failed for aux tensors")

    def _add_remote(self, agent_name: str, agent_metadata: bytes):
        # Register the remote agent once and cache the returned peer name.
        if agent_name not in self.peer_names:
            self.peer_names[agent_name] = self.agent.add_remote_agent(agent_metadata)
        return self.peer_names[agent_name]

    def send_kvcache(
        self,
        peer_name: str,
        prefill_kv_indices: npt.NDArray[np.int64],
        dst_kv_ptrs: list[int],
        dst_kv_indices: npt.NDArray[np.int64],
        dst_gpu_id: int,
        notif: str,
    ):
        # Group contiguous runs of indices so each run becomes one descriptor.
        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
            prefill_kv_indices, dst_kv_indices
        )
        logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
        # Make descs
        num_layers = len(self.kv_args.kv_data_ptrs)
        src_addrs = []
        dst_addrs = []
        for layer_id in range(num_layers):
            src_ptr = self.kv_args.kv_data_ptrs[layer_id]
            dst_ptr = dst_kv_ptrs[layer_id]
            item_len = self.kv_args.kv_item_lens[layer_id]

            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                src_addr = src_ptr + int(prefill_index[0]) * item_len
                dst_addr = dst_ptr + int(decode_index[0]) * item_len
                length = item_len * len(prefill_index)
                src_addrs.append((src_addr, length, self.kv_args.gpu_id))
                dst_addrs.append((dst_addr, length, dst_gpu_id))
        logger.debug(
            f"len(src_addrs): before group: {len(prefill_kv_indices)}, "
            f"after group: {len(src_addrs)}"
        )
        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=True)
        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=True)
        # Transfer data
        xfer_handle = self.agent.initialize_xfer(
            "WRITE",
            src_descs,
            dst_descs,
            peer_name,
            notif.encode("ascii"),  # type: ignore
        )
        if not xfer_handle:
            raise Exception("KVSender failed to create transfer")
        state = self.agent.transfer(xfer_handle)
        if state == "ERR":
            raise Exception("KVSender failed to post transfer")
        return xfer_handle
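
    # Worked example of the descriptor math above (values are made up): with
    # item_len = 4096, a grouped source run [7, 8, 9] starting at src_ptr
    # becomes the single descriptor (src_ptr + 7 * 4096, 3 * 4096, gpu_id)
    # instead of three per-page entries, which keeps the NIXL transfer list
    # short. This assumes group_concurrent_contiguous only merges runs that
    # are contiguous on both the prefill and decode sides.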

    def send_aux(
        self,
        peer_name: str,
        prefill_aux_index: int,
        dst_aux_ptrs: list[int],
        dst_aux_index: int,
        notif: str,
    ):
        # Make descs
        aux_item_len = self.kv_args.aux_item_lens[0]
        prefill_aux_addr = (
            self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
        )
        decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
        src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
        dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=True)
        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=True)
        # Transfer data
        xfer_handle = self.agent.initialize_xfer(
            "WRITE",
            src_descs,
            dst_descs,
            peer_name,
            notif.encode("ascii"),  # type: ignore
        )
        if not xfer_handle:
            raise Exception("KVSender failed to create transfer")
        state = self.agent.transfer(xfer_handle)
        if state == "ERR":
            raise Exception("KVSender failed to post transfer")
        return xfer_handle

    def add_transfer_request(
        self,
        bootstrap_room: int,
        kv_indices: npt.NDArray[np.int64],
        index_slice: slice,
        is_last: bool,
        chunk_id: int,
        aux_index: Optional[int] = None,
    ):
        assert self.disaggregation_mode == DisaggregationMode.PREFILL
        assert not is_last or (is_last and aux_index is not None)

        reqs_to_be_processed = self.transfer_infos[bootstrap_room].values()
        handles = []
        for req in reqs_to_be_processed:
            assert bootstrap_room == req.room
            if req.is_dummy():
                # Dummy receivers register only to complete bootstrapping;
                # they expect no data.
                continue

            peer_name = self._add_remote(req.agent_name, req.agent_metadata)
            chunked_dst_kv_indice = req.dst_kv_indices[index_slice]
            assert len(chunked_dst_kv_indice) == len(kv_indices)

            # Notification format: "<room>_kv_<chunk_id>_<is_last>".
            notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
            kv_xfer_handle = self.send_kvcache(
                peer_name,
                kv_indices,
                req.dst_kv_ptrs,
                chunked_dst_kv_indice,
                req.dst_gpu_id,
                notif,
            )
            handles.append(kv_xfer_handle)
            # The aux data only needs to be sent with the last chunk.
            if is_last:
                assert aux_index is not None
                aux_xfer_handle = self.send_aux(
                    peer_name,
                    aux_index,
                    req.dst_aux_ptrs,
                    req.dst_aux_index,
                    str(req.room) + "_aux",
                )
                handles.append(aux_xfer_handle)
        return handles

    def update_transfer_status(self):
        # Process notifications from received transfers.
        notif_map = self.agent.get_new_notifs()
        for peer_name, messages in notif_map.items():
            # We could also check that self.bootstrap_info["agent_name"]
            # matches the message sender, but the bootstrap room alone is
            # sufficient to map the status.
            for msg in messages:
                components = msg.decode("ascii").split("_")
                room = int(components[0])
                if components[1] == "kv":
                    chunk_id = int(components[2])
                    is_last = bool(int(components[3]))
                    self.transfer_statuses[room].received_kvs.add(chunk_id)
                    if is_last:
                        self.transfer_statuses[room].num_kvs_expected = chunk_id + 1
                elif components[1] == "aux":
                    self.transfer_statuses[room].received_aux = True

    def check_transfer_done(self, room: int):
        if room not in self.transfer_statuses:
            return False
        return self.transfer_statuses[room].is_done()

    def _start_bootstrap_thread(self):
        self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")

        def bootstrap_thread():
            """This thread receives transfer info from the decode engine."""
            while True:
                waiting_req_bytes = self.server_socket.recv_multipart()
                logger.debug(
                    f"Received multipart with total byte size "
                    f"{sum(len(x) for x in waiting_req_bytes)}"
                )
                assert (
                    waiting_req_bytes[0] == GUARD
                ), f"First message should be {GUARD}. Foreign traffic?"
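                # After dropping the guard frame, the remaining frames follow
                # the layout consumed by TransferInfo.from_zmq:
                #   0: room            1: endpoint IP      2: rank port
                #   3: agent metadata  4: agent name       5: packed kv ptrs
                #   6: kv indices      7: packed aux ptrs  8: aux index
                #   9: gpu id         10: required_dst_info_num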
                waiting_req_bytes = waiting_req_bytes[1:]
                room = int(waiting_req_bytes[0].decode("ascii"))
                agent_name = waiting_req_bytes[4].decode("ascii")
                required_dst_info_num = int(waiting_req_bytes[10].decode("ascii"))
                if room not in self.transfer_infos:
                    self.transfer_infos[room] = {}
                self.transfer_infos[room][agent_name] = TransferInfo.from_zmq(
                    waiting_req_bytes
                )

                logger.debug(f"got info {room=} {agent_name=} {required_dst_info_num=}")
                # The room is bootstrapped once every expected decode rank
                # (including dummy ones) has registered its transfer info.
                if len(self.transfer_infos[room]) == required_dst_info_num:
                    logger.debug(f"{room=} is bootstrapped")
                    self.update_status(room, KVPoll.WaitingForInput)

        threading.Thread(target=bootstrap_thread).start()


class NixlKVSender(BaseKVSender):

    def __init__(self, mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: int):
        self.kv_mgr = mgr
        self.bootstrap_room = bootstrap_room
        self.aux_index = None
        self.bootstrap_server_url = bootstrap_addr
        self.xfer_handles = []
        self.has_sent = False
        self.chunk_id = 0
        self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
        # Inner state: index of the next KV slot to send.
        self.curr_idx = 0

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        self.num_kv_indices = num_kv_indices
        self.aux_index = aux_index

    def send(
        self,
        kv_indices: npt.NDArray[np.int64],
    ):
        index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
        self.curr_idx += len(kv_indices)
        is_last = self.curr_idx == self.num_kv_indices

        new_xfer_handles = self.kv_mgr.add_transfer_request(
            self.bootstrap_room,
            kv_indices,
            index_slice,
            is_last,
            self.chunk_id,
            self.aux_index,
        )
        self.xfer_handles.extend(new_xfer_handles)
        self.chunk_id += 1
        if is_last:
            self.has_sent = True

    def poll(self) -> KVPoll:
        if not self.has_sent:
            return self.kv_mgr.check_status(self.bootstrap_room)

        states = [self.kv_mgr.agent.check_xfer_state(x) for x in self.xfer_handles]
        if all(x == "DONE" for x in states):
            return KVPoll.Success  # type: ignore
        if any(x == "ERR" for x in states):
            raise Exception("KVSender transfer encountered an error.")
        return KVPoll.WaitingForInput  # type: ignore

    def failure_exception(self):
        raise Exception("Fake KVSender Exception")


class NixlKVReceiver(CommonKVReceiver):

    def __init__(
        self,
        mgr: NixlKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
        data_parallel_rank: Optional[int] = None,
    ):
        self.started_transfer = False
        super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)

    def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
        for bootstrap_info in self.bootstrap_infos:
            self.prefill_server_url = (
                f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
            )
            logger.debug(
                f"Fetched bootstrap info: {bootstrap_info} for engine rank: "
                f"{self.kv_mgr.kv_args.engine_rank}"
            )
            is_dummy = bootstrap_info["is_dummy"]

            # TODO: send_kv_args earlier
            packed_kv_data_ptrs = b"".join(
                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
            )
            packed_aux_data_ptrs = b"".join(
                struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
            )
            logger.debug(
                f"Sending to {self.prefill_server_url} with bootstrap room "
                f"{self.bootstrap_room}"
            )
            sock, lock = self._connect("tcp://" + self.prefill_server_url)
            with lock:
                # Frame order must match TransferInfo.from_zmq on the prefill
                # side (after the guard frame is stripped).
                sock.send_multipart(
                    [
                        GUARD,
                        str(self.bootstrap_room).encode("ascii"),
                        get_local_ip_by_remote().encode("ascii"),
                        str(self.kv_mgr.rank_port).encode("ascii"),
                        self.kv_mgr.agent.get_agent_metadata(),
                        self.kv_mgr.agent.name.encode("ascii"),
                        packed_kv_data_ptrs,
                        # A dummy receiver sends empty indices: it expects no
                        # data but still counts toward required_dst_info_num.
                        kv_indices.tobytes() if not is_dummy else b"",
                        packed_aux_data_ptrs,
                        str(aux_index).encode("ascii"),
                        str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
                        str(self.required_dst_info_num).encode("ascii"),
                    ]
                )
        self.started_transfer = True
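
    # A receiver is typically polled by the decode scheduler until it reports
    # success, e.g. (illustrative only; the surrounding loop is an assumption):
    #
    #   receiver.init(kv_indices, aux_index)
    #   while receiver.poll() != KVPoll.Success:
    #       ...  # yield to the scheduler between polls
    #
    # poll() drains new NIXL notifications on every call, so it must be
    # invoked periodically for transfer_statuses to make progress.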

    def poll(self) -> KVPoll:
        if not self.started_transfer:
            return KVPoll.WaitingForInput  # type: ignore

        self.kv_mgr.update_transfer_status()
        if self.kv_mgr.check_transfer_done(self.bootstrap_room):  # type: ignore
            return KVPoll.Success  # type: ignore
        return KVPoll.WaitingForInput  # type: ignore

    def _register_kv_args(self):
        pass

    def failure_exception(self):
        raise Exception("Fake KVReceiver Exception")


class NixlKVBootstrapServer(CommonKVBootstrapServer):
    pass
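

# ---------------------------------------------------------------------------
# Illustrative sketch (not called anywhere in this module): how a prefill
# scheduler might drive NixlKVSender. The chunk size and the surrounding loop
# are assumptions for illustration; only the init/send/poll call pattern
# comes from the classes above.
# ---------------------------------------------------------------------------
def _example_prefill_send_loop(
    sender: NixlKVSender,
    kv_indices: npt.NDArray[np.int64],
    aux_index: int,
) -> None:
    # Declare the total number of indices and the aux slot up front.
    sender.init(num_kv_indices=len(kv_indices), aux_index=aux_index)
    # Stream the cache in fixed-size chunks; 256 is a hypothetical size.
    chunk_size = 256
    for start in range(0, len(kv_indices), chunk_size):
        sender.send(kv_indices[start : start + chunk_size])
    # Poll until every NIXL transfer handle reports DONE; a real scheduler
    # would interleave this with other work instead of busy-waiting.
    while sender.poll() != KVPoll.Success:
        pass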