from __future__ import annotations import os import random from collections import deque from contextlib import nullcontext from enum import Enum from typing import TYPE_CHECKING, List, Optional, Type, Union import numpy as np import torch import torch.distributed as dist from sglang.srt.utils import is_npu if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import Req ######################### # Constants & Enums ######################### FAKE_BOOTSTRAP_HOST = "2.2.2.2" class DisaggregationMode(Enum): NULL = "null" PREFILL = "prefill" DECODE = "decode" ######################### # Synchronization ######################### # env var for testing failure, convert to float explicitly FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0)) def poll_and_all_reduce(pollers, gloo_group): # at a certain prob, the poll is failed to simulate failure if FAILURE_PROB > 0: from sglang.srt.disaggregation.base import KVPoll polls = [ int(KVPoll.Failed) if random.random() < FAILURE_PROB else int(poller.poll()) for poller in pollers ] else: polls = [int(poller.poll()) for poller in pollers] tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu") dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group) return tensor_to_reduce.tolist() ######################### # Metadata Buffers ######################### class ReqToMetadataIdxAllocator: """A memory pool that maps a request to its first output token location.""" def __init__( self, size: int, ): self.size = size self.free_slots = deque(list(range(size))) def available_size(self): return len(self.free_slots) def alloc(self) -> Optional[int]: if len(self.free_slots) == 0: return None return self.free_slots.popleft() def free(self, free_index: int): self.free_slots.append(free_index) class MetadataBuffers: def __init__( self, size: int, hidden_size: int, hidden_states_dtype: torch.dtype, max_top_logprobs_num: int = 128, custom_mem_pool: torch.cuda.MemPool = None, ): self.custom_mem_pool = custom_mem_pool device = "cpu" if is_npu(): # For ascend backend, output tokens are placed in the NPU and will be transferred by D2D channel. device = "npu" elif self.custom_mem_pool: # TODO(shangming): Fix me (use 'cuda') when nvlink_transport of Mooncake is bug-free device = "cpu" with ( torch.cuda.use_mem_pool(self.custom_mem_pool) if self.custom_mem_pool else nullcontext() ): # TODO: abort top_logprobs_num > 128 in PD # We transfer the metadata of first output token to decode # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device) self.cached_tokens = torch.zeros( (size, 16), dtype=torch.int32, device=device ) self.output_token_logprobs_val = torch.zeros( (size, 16), dtype=torch.float32, device=device ) self.output_token_logprobs_idx = torch.zeros( (size, 16), dtype=torch.int32, device=device ) self.output_top_logprobs_val = torch.zeros( (size, max_top_logprobs_num), dtype=torch.float32, device=device ) self.output_top_logprobs_idx = torch.zeros( (size, max_top_logprobs_num), dtype=torch.int32, device=device ) # For PD + spec decode self.output_topk_p = torch.zeros( (size, 16), dtype=torch.float32, device=device ) self.output_topk_index = torch.zeros( (size, 16), dtype=torch.int64, device=device ) self.output_hidden_states = torch.zeros( (size, hidden_size), dtype=hidden_states_dtype, device=device ) def get_buf_infos(self): ptrs = [ self.output_ids.data_ptr(), self.cached_tokens.data_ptr(), self.output_token_logprobs_val.data_ptr(), self.output_token_logprobs_idx.data_ptr(), self.output_top_logprobs_val.data_ptr(), self.output_top_logprobs_idx.data_ptr(), self.output_topk_p.data_ptr(), self.output_topk_index.data_ptr(), self.output_hidden_states.data_ptr(), ] data_lens = [ self.output_ids.nbytes, self.cached_tokens.nbytes, self.output_token_logprobs_val.nbytes, self.output_token_logprobs_idx.nbytes, self.output_top_logprobs_val.nbytes, self.output_top_logprobs_idx.nbytes, self.output_topk_p.nbytes, self.output_topk_index.nbytes, self.output_hidden_states.nbytes, ] item_lens = [ self.output_ids[0].nbytes, self.cached_tokens[0].nbytes, self.output_token_logprobs_val[0].nbytes, self.output_token_logprobs_idx[0].nbytes, self.output_top_logprobs_val[0].nbytes, self.output_top_logprobs_idx[0].nbytes, self.output_topk_p[0].nbytes, self.output_topk_index[0].nbytes, self.output_hidden_states[0].nbytes, ] return ptrs, data_lens, item_lens def get_buf(self, idx: int): return ( self.output_ids[idx], self.cached_tokens[idx], self.output_token_logprobs_val[idx], self.output_token_logprobs_idx[idx], self.output_top_logprobs_val[idx], self.output_top_logprobs_idx[idx], self.output_topk_p[idx], self.output_topk_index[idx], self.output_hidden_states[idx], ) def set_buf(self, req: Req): self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0] self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens if req.return_logprob: if req.output_token_logprobs_val: # not none or empty list self.output_token_logprobs_val[req.metadata_buffer_index][0] = ( req.output_token_logprobs_val[0] ) if req.output_token_logprobs_idx: # not none or empty list self.output_token_logprobs_idx[req.metadata_buffer_index][0] = ( req.output_token_logprobs_idx[0] ) if req.output_top_logprobs_val: # not none or empty list self.output_top_logprobs_val[req.metadata_buffer_index][ : len(req.output_top_logprobs_val[0]) ] = torch.tensor( req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu" ) if req.output_top_logprobs_idx: # not none or empty list self.output_top_logprobs_idx[req.metadata_buffer_index][ : len(req.output_top_logprobs_idx[0]) ] = torch.tensor( req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu" ) # For PD + spec decode if req.hidden_states_tensor is not None: # speculative_eagle_topk should not be greater than 16 currently topk = req.output_topk_p.size(0) self.output_topk_p[req.metadata_buffer_index, :topk].copy_( req.output_topk_p ) self.output_topk_index[req.metadata_buffer_index, :topk].copy_( req.output_topk_index ) self.output_hidden_states[req.metadata_buffer_index].copy_( req.hidden_states_tensor ) ######################### # Transfer Backend ######################### class TransferBackend(Enum): MOONCAKE = "mooncake" NIXL = "nixl" ASCEND = "ascend" FAKE = "fake" class KVClassType(Enum): KVARGS = "kvargs" MANAGER = "manager" SENDER = "sender" RECEIVER = "receiver" BOOTSTRAP_SERVER = "bootstrap_server" def get_kv_class( transfer_backend: TransferBackend, class_type: KVClassType ) -> Optional[Type]: from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender if transfer_backend == TransferBackend.MOONCAKE: from sglang.srt.disaggregation.base import KVArgs from sglang.srt.disaggregation.mooncake import ( MooncakeKVBootstrapServer, MooncakeKVManager, MooncakeKVReceiver, MooncakeKVSender, ) class_mapping = { KVClassType.KVARGS: KVArgs, KVClassType.MANAGER: MooncakeKVManager, KVClassType.SENDER: MooncakeKVSender, KVClassType.RECEIVER: (MooncakeKVReceiver), KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer, } return class_mapping.get(class_type) elif transfer_backend == TransferBackend.ASCEND: from sglang.srt.disaggregation.ascend import ( AscendKVBootstrapServer, AscendKVManager, AscendKVReceiver, AscendKVSender, ) from sglang.srt.disaggregation.base import KVArgs class_mapping = { KVClassType.KVARGS: KVArgs, KVClassType.MANAGER: AscendKVManager, KVClassType.SENDER: AscendKVSender, KVClassType.RECEIVER: (AscendKVReceiver), KVClassType.BOOTSTRAP_SERVER: AscendKVBootstrapServer, } return class_mapping.get(class_type) elif transfer_backend == TransferBackend.NIXL: from sglang.srt.disaggregation.base import KVArgs from sglang.srt.disaggregation.nixl import ( NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender, ) class_mapping = { KVClassType.KVARGS: KVArgs, KVClassType.MANAGER: NixlKVManager, KVClassType.SENDER: NixlKVSender, KVClassType.RECEIVER: (NixlKVReceiver), KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer, } return class_mapping.get(class_type) elif transfer_backend == TransferBackend.FAKE: from sglang.srt.disaggregation.base import KVArgs from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender class_mapping = { KVClassType.KVARGS: KVArgs, KVClassType.SENDER: FakeKVSender, KVClassType.RECEIVER: (FakeKVReceiver), } return class_mapping.get(class_type) raise ValueError(f"Unsupported transfer backend: {transfer_backend}") ######################### # KV Pages ######################### def kv_to_page_indices(kv_indices: np.ndarray, page_size: int): # 1. The page is guaranteed to be full except the last page. # 2. page index = kv_index // page_size # The return vector is kv_indices[::page_size] // page_size if page_size == 1: # shortcut return kv_indices return kv_indices[::page_size] // page_size def kv_to_page_num(num_kv_indices: int, page_size: int): # ceil(num_kv_indices / page_size) return (num_kv_indices + page_size - 1) // page_size ######################### # Misc ######################### def is_mla_backend(target_kv_pool) -> bool: from sglang.srt.mem_cache.memory_pool import MLATokenToKVPool return isinstance(target_kv_pool, MLATokenToKVPool) def prepare_abort(req: Req, error_message: str, status_code=None): from sglang.srt.managers.schedule_batch import FINISH_ABORT # populate finish metadata and stream output req.finished_reason = FINISH_ABORT(error_message, status_code) if req.return_logprob: req.input_token_logprobs_val = [] req.input_token_logprobs_idx = [] req.input_top_logprobs_val = [] req.input_top_logprobs_idx = [] req.input_token_ids_logprobs_val = [] req.input_token_ids_logprobs_idx = []