from __future__ import annotations import dataclasses import os import random import warnings from collections import deque from enum import Enum from typing import TYPE_CHECKING, List, Optional import numpy as np import requests import torch import torch.distributed as dist from sglang.srt.utils import get_ip if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req FakeBootstrapHost = "2.2.2.2" # env var for testing failure, convert to float explicitly FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0)) class DisaggregationMode(Enum): NULL = "null" PREFILL = "prefill" DECODE = "decode" def poll_and_all_reduce(pollers, gloo_group): # at a certain prob, the poll is failed to simulate failure if FAILURE_PROB > 0: from sglang.srt.disaggregation.base import KVPoll polls = [ int(KVPoll.Failed) if random.random() < FAILURE_PROB else int(poller.poll()) for poller in pollers ] else: polls = [int(poller.poll()) for poller in pollers] tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu") dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group) return tensor_to_reduce.tolist() class ReqToMetadataIdxAllocator: """A memory pool that maps a request to its first output token location.""" def __init__( self, size: int, ): self.size = size self.free_slots = deque(list(range(size))) def available_size(self): return len(self.free_slots) def alloc(self) -> List[int]: if len(self.free_slots) == 0: return None return self.free_slots.popleft() def free(self, free_index: int): self.free_slots.append(free_index) class TransferBackend(Enum): MOONCAKE = "mooncake" NIXL = "nixl" FAKE = "fake" class KVClassType(Enum): MANAGER = "manager" SENDER = "sender" RECEIVER = "receiver" BOOTSTRAP_SERVER = "bootstrap_server" def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType): from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender if transfer_backend == TransferBackend.MOONCAKE: from sglang.srt.disaggregation.mooncake import ( MooncakeKVBootstrapServer, MooncakeKVManager, MooncakeKVReceiver, MooncakeKVSender, ) class_mapping = { KVClassType.MANAGER: MooncakeKVManager, KVClassType.SENDER: MooncakeKVSender, KVClassType.RECEIVER: (MooncakeKVReceiver), KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer, } return class_mapping.get(class_type) if transfer_backend == TransferBackend.NIXL: from sglang.srt.disaggregation.nixl import ( NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender, ) class_mapping = { KVClassType.MANAGER: NixlKVManager, KVClassType.SENDER: NixlKVSender, KVClassType.RECEIVER: (NixlKVReceiver), KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer, } return class_mapping.get(class_type) if transfer_backend == TransferBackend.FAKE: from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender class_mapping = { KVClassType.SENDER: FakeKVSender, KVClassType.RECEIVER: (FakeKVReceiver), } return class_mapping.get(class_type) raise ValueError(f"Unsupported transfer backend: {transfer_backend}") def kv_to_page_indices(kv_indices: np.ndarray, page_size: int): # 1. The page is guaranteed to be full except the last page. # 2. page index = kv_index // page_size # The return vector is kv_indices[::page_size] // page_size if page_size == 1: # shortcut return kv_indices return kv_indices[::page_size] // page_size def kv_to_page_num(num_kv_indices: int, page_size: int): # ceil(num_kv_indices / page_size) return (num_kv_indices + page_size - 1) // page_size @dataclasses.dataclass class PDRegistryRequest: """A request to register a machine itself to the LB.""" mode: str registry_url: str bootstrap_port: Optional[int] = None def __post_init__(self): if self.mode == "prefill" and self.bootstrap_port is None: raise ValueError("Bootstrap port must be set in PREFILL mode.") elif self.mode == "decode" and self.bootstrap_port is not None: raise ValueError("Bootstrap port must not be set in DECODE mode.") elif self.mode not in ["prefill", "decode"]: raise ValueError( f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'." ) def register_disaggregation_server( mode: str, server_port: int, bootstrap_port: int, pdlb_url: str ): boostrap_port = bootstrap_port if mode == "prefill" else None registry_request = PDRegistryRequest( mode=mode, registry_url=f"http://{get_ip()}:{server_port}", bootstrap_port=boostrap_port, ) res = requests.post( f"{pdlb_url}/register", json=dataclasses.asdict(registry_request), ) if res.status_code != 200: warnings.warn( f"Failed to register disaggregation server: {res.status_code} {res.text}" ) def is_mla_backend(target_kv_pool) -> bool: from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool return isinstance(target_kv_pool, MLATokenToKVPool) def prepare_abort(req: Req, error_message: str, status_code=None): from sglang.srt.managers.schedule_batch import FINISH_ABORT # populate finish metadata and stream output req.finished_reason = FINISH_ABORT(error_message, status_code) if req.return_logprob: req.input_token_logprobs_val = [] req.input_token_logprobs_idx = [] req.input_top_logprobs_val = [] req.input_top_logprobs_idx = [] req.input_token_ids_logprobs_val = [] req.input_token_ids_logprobs_idx = [] class MetadataBuffers: def __init__(self, size: int, max_top_logprobs_num: int = 128): # TODO: abort top_logprobs_num > 128 in PD # We transfer the metadata of first output token to decode # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu") self.output_token_logprobs_val = torch.zeros( (size, 16), dtype=torch.float32, device="cpu" ) self.output_token_logprobs_idx = torch.zeros( (size, 16), dtype=torch.int32, device="cpu" ) self.output_top_logprobs_val = torch.zeros( (size, max_top_logprobs_num), dtype=torch.float32, device="cpu" ) self.output_top_logprobs_idx = torch.zeros( (size, max_top_logprobs_num), dtype=torch.int32, device="cpu" ) def get_buf_infos(self): ptrs = [ self.output_ids.data_ptr(), self.output_token_logprobs_val.data_ptr(), self.output_token_logprobs_idx.data_ptr(), self.output_top_logprobs_val.data_ptr(), self.output_top_logprobs_idx.data_ptr(), ] data_lens = [ self.output_ids.nbytes, self.output_token_logprobs_val.nbytes, self.output_token_logprobs_idx.nbytes, self.output_top_logprobs_val.nbytes, self.output_top_logprobs_idx.nbytes, ] item_lens = [ self.output_ids[0].nbytes, self.output_token_logprobs_val[0].nbytes, self.output_token_logprobs_idx[0].nbytes, self.output_top_logprobs_val[0].nbytes, self.output_top_logprobs_idx[0].nbytes, ] return ptrs, data_lens, item_lens def get_buf(self, idx: int): return ( self.output_ids[idx], self.output_token_logprobs_val[idx], self.output_token_logprobs_idx[idx], self.output_top_logprobs_val[idx], self.output_top_logprobs_idx[idx], ) def set_buf(self, req: Req): self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0] if req.return_logprob: if req.output_token_logprobs_val: # not none or empty list self.output_token_logprobs_val[req.metadata_buffer_index][0] = ( req.output_token_logprobs_val[0] ) if req.output_token_logprobs_idx: # not none or empty list self.output_token_logprobs_idx[req.metadata_buffer_index][0] = ( req.output_token_logprobs_idx[0] ) if req.output_top_logprobs_val: # not none or empty list self.output_top_logprobs_val[req.metadata_buffer_index][ : len(req.output_top_logprobs_val[0]) ] = torch.tensor( req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu" ) if req.output_top_logprobs_idx: # not none or empty list self.output_top_logprobs_idx[req.metadata_buffer_index][ : len(req.output_top_logprobs_idx[0]) ] = torch.tensor( req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu" )