Unverified Commit 868403f6 authored by Shangming Cai, committed by GitHub

[PD] Add PD support for hybrid model (Qwen3-Next, DeepSeek V3.2 Exp) (#10912)


Signed-off-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: hzh0425 <hzh0425@apache.org>
Co-authored-by: ZeldaHuang <hzm414167@alibaba-inc.com>
parent 97d857c0
...@@ -20,6 +20,10 @@ class KVArgs: ...@@ -20,6 +20,10 @@ class KVArgs:
aux_data_ptrs: List[int] aux_data_ptrs: List[int]
aux_data_lens: List[int] aux_data_lens: List[int]
aux_item_lens: List[int] aux_item_lens: List[int]
state_data_ptrs: List[int]
state_data_lens: List[int]
state_item_lens: List[int]
state_type: str # "none", "mamba", "swa", "nsa"
ib_device: str ib_device: str
ib_traffic_class: str ib_traffic_class: str
gpu_id: int gpu_id: int
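The four new `KVArgs` fields describe any extra per-request state a hybrid model keeps outside the regular KV pool. A minimal sketch of how the three parallel lists line up (buffer shapes and dtypes are hypothetical, not taken from this diff):

```python
import torch

# Hypothetical two-layer extra-state buffers; each buffer contributes one
# entry to each of the three parallel lists below.
state_buffers = [torch.zeros(128, 4096, dtype=torch.bfloat16) for _ in range(2)]

state_data_ptrs = [t.data_ptr() for t in state_buffers]  # base address to register
state_data_lens = [t.nbytes for t in state_buffers]      # total bytes of the buffer
state_item_lens = [t[0].nbytes for t in state_buffers]   # bytes of one request/page slot
state_type = "mamba"  # "none" | "mamba" | "swa" | "nsa", selects the transfer path
```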
...@@ -76,9 +80,13 @@ class BaseKVSender(ABC): ...@@ -76,9 +80,13 @@ class BaseKVSender(ABC):
... ...
@abstractmethod @abstractmethod
def send(self, kv_indices: npt.NDArray[np.int32]): def send(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
):
""" """
Send the kv cache at the given kv indices to the decoder server Send the kv cache at the given kv indices and the extra cache/state at the given indices to the decoder server
""" """
... ...
...@@ -108,9 +116,14 @@ class BaseKVReceiver(ABC): ...@@ -108,9 +116,14 @@ class BaseKVReceiver(ABC):
): ... ): ...
@abstractmethod @abstractmethod
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): def init(
self,
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
):
""" """
Notify the prefill server about the kv indices and aux index Notify the prefill server about the kv indices, aux index, and state_indices.
""" """
... ...
......
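With the widened signatures above, one transfer round trip now carries three things: KV page indices, one aux (metadata) slot, and an optional list of state indices. A minimal sketch with stand-in classes (not the real senders/receivers) showing how the two sides are expected to call each other:

```python
from typing import List, Optional
import numpy as np

class _Receiver:  # stands in for a BaseKVReceiver implementation
    def init(self, kv_indices: np.ndarray, aux_index: Optional[int] = None,
             state_indices: Optional[List[int]] = None):
        print("decode advertises:", kv_indices, aux_index, state_indices)

class _Sender:    # stands in for a BaseKVSender implementation
    def send(self, kv_indices: np.ndarray, state_indices: Optional[List[int]] = None):
        print("prefill pushes:", kv_indices, state_indices)

kv_pages = np.arange(8, dtype=np.int32)                      # page indices of the prefilled KV cache
_Receiver().init(kv_pages, aux_index=0, state_indices=[5])   # decode side, during bootstrap
_Sender().send(kv_pages, state_indices=[5])                  # prefill side, on the last chunk
```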
...@@ -201,6 +201,7 @@ class CommonKVSender(BaseKVSender): ...@@ -201,6 +201,7 @@ class CommonKVSender(BaseKVSender):
def send( def send(
self, self,
kv_indices: npt.NDArray[np.int32], kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
): ):
pass pass
......
...@@ -25,11 +25,12 @@ import time ...@@ -25,11 +25,12 @@ import time
from collections import deque from collections import deque
from dataclasses import dataclass from dataclasses import dataclass
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
import torch import torch
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
from sglang.srt.disaggregation.utils import ( from sglang.srt.disaggregation.utils import (
...@@ -47,9 +48,19 @@ from sglang.srt.disaggregation.utils import ( ...@@ -47,9 +48,19 @@ from sglang.srt.disaggregation.utils import (
) )
from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool from sglang.srt.mem_cache.memory_pool import (
HybridLinearKVPool,
HybridReqToTokenPool,
KVCache,
NSATokenToKVPool,
ReqToTokenPool,
SWAKVPool,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import get_int_env_var, require_mlp_sync from sglang.srt.utils import get_int_env_var, require_mlp_sync
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
...@@ -124,6 +135,35 @@ class DecodeReqToTokenPool: ...@@ -124,6 +135,35 @@ class DecodeReqToTokenPool:
self.free_slots = list(range(self.size + self.pre_alloc_size)) self.free_slots = list(range(self.size + self.pre_alloc_size))
class HybridMambaDecodeReqToTokenPool(HybridReqToTokenPool):
def __init__(
self,
size: int,
max_context_len: int,
device: str,
enable_memory_saver: bool,
cache_params: "Mamba2CacheParams",
speculative_num_draft_tokens: int,
pre_alloc_size: int,
):
DecodeReqToTokenPool.__init__(
self,
size=size,
max_context_len=max_context_len,
device=device,
enable_memory_saver=enable_memory_saver,
pre_alloc_size=pre_alloc_size,
)
self._init_mamba_pool(
size + pre_alloc_size, cache_params, device, speculative_num_draft_tokens
)
def clear(self):
self.free_slots = list(range(self.size + self.pre_alloc_size))
self.mamba_pool.clear()
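`HybridMambaDecodeReqToTokenPool` subclasses `HybridReqToTokenPool` for the mamba-slot mapping, but deliberately borrows `DecodeReqToTokenPool.__init__` so the request-slot bookkeeping (including the extra `pre_alloc_size` slots reserved for pre-allocated decode requests) matches the decode pool, then sizes the mamba pool to cover those extra slots as well. A stripped-down sketch of the same initialization pattern (all classes below are stand-ins, not the real ones):

```python
class _DecodeBase:                      # stands in for DecodeReqToTokenPool
    def __init__(self, size, pre_alloc_size):
        self.size, self.pre_alloc_size = size, pre_alloc_size
        self.free_slots = list(range(size + pre_alloc_size))

class _Hybrid:                          # stands in for HybridReqToTokenPool
    def _init_mamba_pool(self, size):
        self.mamba_slots = list(range(size))

class _HybridDecode(_Hybrid):
    def __init__(self, size, pre_alloc_size):
        _DecodeBase.__init__(self, size, pre_alloc_size)   # borrow decode-style slot bookkeeping
        self._init_mamba_pool(size + pre_alloc_size)       # mamba pool covers the pre-alloc slots

pool = _HybridDecode(size=4, pre_alloc_size=8)
assert len(pool.free_slots) == len(pool.mamba_slots) == 12
```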
@dataclass @dataclass
class DecodeRequest: class DecodeRequest:
req: Req req: Req
...@@ -217,6 +257,28 @@ class DecodePreallocQueue: ...@@ -217,6 +257,28 @@ class DecodePreallocQueue:
self.metadata_buffers.get_buf_infos() self.metadata_buffers.get_buf_infos()
) )
if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
state_data_ptrs, state_data_lens, state_item_lens = (
self.token_to_kv_pool.get_state_buf_infos()
)
kv_args.state_data_ptrs = state_data_ptrs
kv_args.state_data_lens = state_data_lens
kv_args.state_item_lens = state_item_lens
if isinstance(self.token_to_kv_pool, SWAKVPool):
kv_args.state_type = "swa"
elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
kv_args.state_type = "mamba"
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
kv_args.state_type = "nsa"
else:
kv_args.state_type = "none"
else:
kv_args.state_data_ptrs = []
kv_args.state_data_lens = []
kv_args.state_item_lens = []
kv_args.state_type = "none"
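Both the decode prealloc queue (here) and the prefill bootstrap queue (later in this diff) duck-type on `get_state_buf_infos()`: any pool exposing that method gets its extra buffers registered with the transfer engine, otherwise the state channel stays empty. A minimal sketch of the contract (the pool class is hypothetical):

```python
import torch

class _TinyStatePool:
    """Hypothetical pool exposing one extra [slots, dim] buffer."""
    def __init__(self):
        self.extra = torch.zeros(16, 256)

    def get_state_buf_infos(self):
        # one (ptr, total_bytes, bytes_per_slot) triple per buffer
        return [self.extra.data_ptr()], [self.extra.nbytes], [self.extra[0].nbytes]

pool = _TinyStatePool()
if hasattr(pool, "get_state_buf_infos"):
    state_ptrs, state_lens, state_item_lens = pool.get_state_buf_infos()
```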
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class: Type[BaseKVManager] = get_kv_class( kv_manager_class: Type[BaseKVManager] = get_kv_class(
...@@ -414,16 +476,56 @@ class DecodePreallocQueue: ...@@ -414,16 +476,56 @@ class DecodePreallocQueue:
.cpu() .cpu()
.numpy() .numpy()
) )
page_size = self.token_to_kv_pool_allocator.page_size
# Prepare extra pool indices for hybrid models
if isinstance(self.token_to_kv_pool, HybridLinearKVPool):
# Mamba hybrid model: single mamba state index
state_indices = [
self.req_to_token_pool.req_index_to_mamba_index_mapping[
decode_req.req.req_pool_idx
]
.cpu()
.numpy()
]
elif isinstance(self.token_to_kv_pool, SWAKVPool):
# SWA hybrid model: send decode-side SWA window indices
seq_len = len(decode_req.req.origin_input_ids)
window_size = self.scheduler.sliding_window_size
window_start = max(0, seq_len - window_size)
window_start = (window_start // page_size) * page_size
window_kv_indices_full = self.req_to_token_pool.req_to_token[
decode_req.req.req_pool_idx, window_start:seq_len
]
# Translate to SWA pool indices
window_kv_indices_swa = (
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
window_kv_indices_full
)
)
state_indices = window_kv_indices_swa.cpu().numpy()
state_indices = kv_to_page_indices(state_indices, page_size)
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
seq_len = len(decode_req.req.origin_input_ids)
kv_indices_full = self.req_to_token_pool.req_to_token[
decode_req.req.req_pool_idx, :seq_len
]
state_indices = kv_indices_full.cpu().numpy()
state_indices = kv_to_page_indices(state_indices, page_size)
else:
state_indices = None
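The SWA branch above only ships the pages that can still fall inside the decode-side sliding window, with the window start rounded down to a page boundary so whole pages are transferred. A worked example of that arithmetic with made-up numbers:

```python
# Assumed values, not from the diff.
seq_len = 4000       # len(decode_req.req.origin_input_ids)
window_size = 1024   # model's sliding window
page_size = 64

window_start = max(0, seq_len - window_size)              # 2976
window_start = (window_start // page_size) * page_size    # 2944, page-aligned
print(seq_len - window_start)                             # 1056 tokens' worth of SWA KV to pull
```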
decode_req.metadata_buffer_index = ( decode_req.metadata_buffer_index = (
self.req_to_metadata_buffer_idx_allocator.alloc() self.req_to_metadata_buffer_idx_allocator.alloc()
) )
assert decode_req.metadata_buffer_index is not None assert decode_req.metadata_buffer_index is not None
page_indices = kv_to_page_indices( page_indices = kv_to_page_indices(kv_indices, page_size)
kv_indices, self.token_to_kv_pool_allocator.page_size decode_req.kv_receiver.init(
page_indices, decode_req.metadata_buffer_index, state_indices
) )
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index) decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
preallocated_reqs.append(decode_req) preallocated_reqs.append(decode_req)
indices_to_remove.add(i) indices_to_remove.add(i)
decode_req.req.time_stats.decode_transfer_queue_entry_time = ( decode_req.req.time_stats.decode_transfer_queue_entry_time = (
...@@ -503,6 +605,9 @@ class DecodePreallocQueue: ...@@ -503,6 +605,9 @@ class DecodePreallocQueue:
def _pre_alloc(self, req: Req) -> torch.Tensor: def _pre_alloc(self, req: Req) -> torch.Tensor:
"""Pre-allocate the memory for req_to_token and token_kv_pool""" """Pre-allocate the memory for req_to_token and token_kv_pool"""
if isinstance(self.req_to_token_pool, HybridMambaDecodeReqToTokenPool):
req_pool_indices = self.req_to_token_pool.alloc(1, [req])
else:
req_pool_indices = self.req_to_token_pool.alloc(1) req_pool_indices = self.req_to_token_pool.alloc(1)
assert ( assert (
......
...@@ -48,9 +48,12 @@ class FakeKVSender(BaseKVSender): ...@@ -48,9 +48,12 @@ class FakeKVSender(BaseKVSender):
def send( def send(
self, self,
kv_indices: npt.NDArray[np.int32], kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
): ):
self.has_sent = True self.has_sent = True
logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}") logger.debug(
f"FakeKVSender send with kv_indices: {kv_indices}, state_indices: {state_indices}"
)
def failure_exception(self): def failure_exception(self):
raise Exception("Fake KVSender Exception") raise Exception("Fake KVSender Exception")
...@@ -75,10 +78,15 @@ class FakeKVReceiver(BaseKVReceiver): ...@@ -75,10 +78,15 @@ class FakeKVReceiver(BaseKVReceiver):
logger.debug("FakeKVReceiver poll success") logger.debug("FakeKVReceiver poll success")
return KVPoll.Success return KVPoll.Success
def init(self, kv_indices: list[int], aux_index: Optional[int] = None): def init(
self,
kv_indices: list[int],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
):
self.has_init = True self.has_init = True
logger.debug( logger.debug(
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}" f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}"
) )
def failure_exception(self): def failure_exception(self):
......
...@@ -58,6 +58,7 @@ class TransferKVChunk: ...@@ -58,6 +58,7 @@ class TransferKVChunk:
index_slice: slice index_slice: slice
is_last: bool is_last: bool
prefill_aux_index: Optional[int] prefill_aux_index: Optional[int]
state_indices: Optional[List[int]]
# decode # decode
...@@ -69,6 +70,7 @@ class TransferInfo: ...@@ -69,6 +70,7 @@ class TransferInfo:
mooncake_session_id: str mooncake_session_id: str
dst_kv_indices: npt.NDArray[np.int32] dst_kv_indices: npt.NDArray[np.int32]
dst_aux_index: int dst_aux_index: int
dst_state_indices: List[int]
required_dst_info_num: int required_dst_info_num: int
is_dummy: bool is_dummy: bool
...@@ -78,9 +80,14 @@ class TransferInfo: ...@@ -78,9 +80,14 @@ class TransferInfo:
is_dummy = True is_dummy = True
dst_kv_indices = np.array([], dtype=np.int32) dst_kv_indices = np.array([], dtype=np.int32)
dst_aux_index = None dst_aux_index = None
dst_state_indices = []
else: else:
dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32) dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
dst_aux_index = int(msg[5].decode("ascii")) dst_aux_index = int(msg[5].decode("ascii"))
if msg[6] == b"":
dst_state_indices = []
else:
dst_state_indices = list(np.frombuffer(msg[6], dtype=np.int32))
is_dummy = False is_dummy = False
return cls( return cls(
room=int(msg[0].decode("ascii")), room=int(msg[0].decode("ascii")),
...@@ -89,7 +96,8 @@ class TransferInfo: ...@@ -89,7 +96,8 @@ class TransferInfo:
mooncake_session_id=msg[3].decode("ascii"), mooncake_session_id=msg[3].decode("ascii"),
dst_kv_indices=dst_kv_indices, dst_kv_indices=dst_kv_indices,
dst_aux_index=dst_aux_index, dst_aux_index=dst_aux_index,
required_dst_info_num=int(msg[6].decode("ascii")), dst_state_indices=dst_state_indices,
required_dst_info_num=int(msg[7].decode("ascii")),
is_dummy=is_dummy, is_dummy=is_dummy,
) )
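The bootstrap message grows one slot: `msg[6]` now carries the decode-side state indices, and `required_dst_info_num` shifts to `msg[7]`. A small round trip of that slot, matching the int32 encoding used above:

```python
import numpy as np

state_indices = [3, 7, 11]   # hypothetical decode-side state slots
wire = np.array(state_indices, dtype=np.int32).tobytes() if state_indices else b""

decoded = [] if wire == b"" else list(np.frombuffer(wire, dtype=np.int32))
assert decoded == state_indices   # empty bytes <-> "no extra state for this request"
```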
...@@ -103,6 +111,7 @@ class KVArgsRegisterInfo: ...@@ -103,6 +111,7 @@ class KVArgsRegisterInfo:
mooncake_session_id: str mooncake_session_id: str
dst_kv_ptrs: list[int] dst_kv_ptrs: list[int]
dst_aux_ptrs: list[int] dst_aux_ptrs: list[int]
dst_state_data_ptrs: list[int]
dst_tp_rank: int dst_tp_rank: int
dst_attn_tp_size: int dst_attn_tp_size: int
dst_kv_item_len: int dst_kv_item_len: int
...@@ -116,9 +125,10 @@ class KVArgsRegisterInfo: ...@@ -116,9 +125,10 @@ class KVArgsRegisterInfo:
mooncake_session_id=msg[3].decode("ascii"), mooncake_session_id=msg[3].decode("ascii"),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])), dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])), dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
dst_tp_rank=int(msg[6].decode("ascii")), dst_state_data_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
dst_attn_tp_size=int(msg[7].decode("ascii")), dst_tp_rank=int(msg[7].decode("ascii")),
dst_kv_item_len=int(msg[8].decode("ascii")), dst_attn_tp_size=int(msg[8].decode("ascii")),
dst_kv_item_len=int(msg[9].decode("ascii")),
) )
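Pointer lists in the registration message travel as packed uint64 (`"Q"`) arrays, and the new `dst_state_data_ptrs` slot pushes the remaining positional fields back by one (`dst_tp_rank` is now `msg[7]`, and so on). A quick round trip of the packing used here:

```python
import struct

ptrs = [0x7F0000000000, 0x7F0000100000]   # hypothetical device pointers
packed = b"".join(struct.pack("Q", p) for p in ptrs)
unpacked = list(struct.unpack(f"{len(packed)//8}Q", packed))
assert unpacked == ptrs
```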
...@@ -180,6 +190,9 @@ class MooncakeKVManager(CommonKVManager): ...@@ -180,6 +190,9 @@ class MooncakeKVManager(CommonKVManager):
) )
for _ in range(transfer_queue_size) for _ in range(transfer_queue_size)
] ]
self.state_executors = concurrent.futures.ThreadPoolExecutor(
transfer_thread_pool_size // transfer_queue_size
)
for queue, executor in zip(self.transfer_queues, self.executors): for queue, executor in zip(self.transfer_queues, self.executors):
threading.Thread( threading.Thread(
target=self.transfer_worker, args=(queue, executor), daemon=True target=self.transfer_worker, args=(queue, executor), daemon=True
...@@ -239,6 +252,12 @@ class MooncakeKVManager(CommonKVManager): ...@@ -239,6 +252,12 @@ class MooncakeKVManager(CommonKVManager):
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
) )
# Batch register state/extra pool data buffers
if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens:
self.engine.batch_register(
self.kv_args.state_data_ptrs, self.kv_args.state_data_lens
)
def _transfer_data(self, mooncake_session_id, transfer_blocks): def _transfer_data(self, mooncake_session_id, transfer_blocks):
if not transfer_blocks: if not transfer_blocks:
return 0 return 0
...@@ -248,17 +267,23 @@ class MooncakeKVManager(CommonKVManager): ...@@ -248,17 +267,23 @@ class MooncakeKVManager(CommonKVManager):
mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths) mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
) )
def send_kvcache( def _send_kvcache_generic(
self, self,
mooncake_session_id: str, mooncake_session_id: str,
prefill_kv_indices: npt.NDArray[np.int32], src_data_ptrs: list[int],
dst_kv_ptrs: list[int], dst_data_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int32], item_lens: list[int],
prefill_data_indices: npt.NDArray[np.int32],
dst_data_indices: npt.NDArray[np.int32],
executor: concurrent.futures.ThreadPoolExecutor, executor: concurrent.futures.ThreadPoolExecutor,
): ) -> int:
# Group by indices """
Generic KV cache transfer supporting both MHA and MLA architectures.
This method is used by both send_kvcache (full pool) and maybe_send_extra.
"""
# Group by indices for optimization
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous( prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
prefill_kv_indices, dst_kv_indices prefill_data_indices, dst_data_indices
) )
layers_params = None layers_params = None
...@@ -266,9 +291,9 @@ class MooncakeKVManager(CommonKVManager): ...@@ -266,9 +291,9 @@ class MooncakeKVManager(CommonKVManager):
# pp is not supported on the decode side yet # pp is not supported on the decode side yet
if self.is_mla_backend: if self.is_mla_backend:
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = ( src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) self.get_mla_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
) )
kv_item_len = self.kv_args.kv_item_lens[0] kv_item_len = item_lens[0]
layers_params = [ layers_params = [
( (
src_kv_ptrs[layer_id], src_kv_ptrs[layer_id],
...@@ -279,9 +304,9 @@ class MooncakeKVManager(CommonKVManager): ...@@ -279,9 +304,9 @@ class MooncakeKVManager(CommonKVManager):
] ]
else: else:
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = ( src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs) self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
) )
kv_item_len = self.kv_args.kv_item_lens[0] kv_item_len = item_lens[0]
layers_params = [ layers_params = [
( (
src_k_ptrs[layer_id], src_k_ptrs[layer_id],
...@@ -345,6 +370,24 @@ class MooncakeKVManager(CommonKVManager): ...@@ -345,6 +370,24 @@ class MooncakeKVManager(CommonKVManager):
return 0 return 0
def send_kvcache(
self,
mooncake_session_id: str,
prefill_kv_indices: npt.NDArray[np.int32],
dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int32],
executor: concurrent.futures.ThreadPoolExecutor,
):
return self._send_kvcache_generic(
mooncake_session_id=mooncake_session_id,
src_data_ptrs=self.kv_args.kv_data_ptrs,
dst_data_ptrs=dst_kv_ptrs,
item_lens=self.kv_args.kv_item_lens,
prefill_data_indices=prefill_kv_indices,
dst_data_indices=dst_kv_indices,
executor=executor,
)
def send_kvcache_slice( def send_kvcache_slice(
self, self,
mooncake_session_id: str, mooncake_session_id: str,
...@@ -593,6 +636,58 @@ class MooncakeKVManager(CommonKVManager): ...@@ -593,6 +636,58 @@ class MooncakeKVManager(CommonKVManager):
f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}" f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
) )
def maybe_send_extra(
self,
req: TransferInfo,
prefill_state_indices: list[int],
dst_state_data_ptrs: list[int],
):
"""Send state or extra pool data with type-specific handling."""
state_type = getattr(self.kv_args, "state_type", "none")
if state_type == "mamba":
return self._send_mamba_state(
req,
prefill_state_indices,
dst_state_data_ptrs,
)
elif state_type in ["swa", "nsa"]:
# Reuse _send_kvcache_generic interface to send extra pool data
prefill_state_indices = np.array(prefill_state_indices, dtype=np.int32)
dst_state_indices = np.array(req.dst_state_indices, dtype=np.int32)
return self._send_kvcache_generic(
mooncake_session_id=req.mooncake_session_id,
src_data_ptrs=self.kv_args.state_data_ptrs,
dst_data_ptrs=dst_state_data_ptrs,
item_lens=self.kv_args.state_item_lens,
prefill_data_indices=prefill_state_indices,
dst_data_indices=dst_state_indices,
executor=self.state_executors,
)
else:
return 0
def _send_mamba_state(
self,
req: TransferInfo,
prefill_mamba_index: list[int],
dst_state_data_ptrs: list[int],
):
"""Transfer Mamba states."""
assert len(prefill_mamba_index) == 1, "Mamba should have single state index"
transfer_blocks = []
prefill_state_data_ptrs = self.kv_args.state_data_ptrs
prefill_state_item_lens = self.kv_args.state_item_lens
for i, dst_state_ptr in enumerate(dst_state_data_ptrs):
length = prefill_state_item_lens[i]
src_addr = prefill_state_data_ptrs[i] + length * int(prefill_mamba_index[0])
dst_addr = dst_state_ptr + length * int(req.dst_state_indices[0])
transfer_blocks.append((src_addr, dst_addr, length))
return self._transfer_data(req.mooncake_session_id, transfer_blocks)
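`_send_mamba_state` copies exactly one slot per registered state buffer, so the whole transfer reduces to one address computation per buffer. A worked example with hypothetical sizes:

```python
# Hypothetical single mamba state buffer laid out as [num_slots, ...].
prefill_state_data_ptrs = [0x7F0000000000]
prefill_state_item_lens = [2 * 1024 * 1024]    # bytes per request slot
prefill_mamba_index = [5]                      # prefill-side slot
dst_state_data_ptrs = [0x7E0000000000]
dst_state_index = 9                            # decode-side slot

length = prefill_state_item_lens[0]
src_addr = prefill_state_data_ptrs[0] + length * prefill_mamba_index[0]
dst_addr = dst_state_data_ptrs[0] + length * dst_state_index
transfer_blocks = [(src_addr, dst_addr, length)]   # handed to the transfer engine as one block
```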
def sync_status_to_decode_endpoint( def sync_status_to_decode_endpoint(
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
): ):
...@@ -702,6 +797,21 @@ class MooncakeKVManager(CommonKVManager): ...@@ -702,6 +797,21 @@ class MooncakeKVManager(CommonKVManager):
break break
if kv_chunk.is_last: if kv_chunk.is_last:
if kv_chunk.state_indices is not None:
if not self.is_mla_backend and (
self.attn_tp_size
!= target_rank_registration_info.dst_attn_tp_size
):
raise RuntimeError(
f"PD Disaggregation does NOT support PD different TP sizes for non-MLA hybrid models yet."
)
self.maybe_send_extra(
req,
kv_chunk.state_indices,
target_rank_registration_info.dst_state_data_ptrs,
)
if self.pp_group.is_last_rank: if self.pp_group.is_last_rank:
# Only the last chunk we need to send the aux data # Only the last chunk we need to send the aux data
ret = self.send_aux( ret = self.send_aux(
...@@ -765,7 +875,7 @@ class MooncakeKVManager(CommonKVManager): ...@@ -765,7 +875,7 @@ class MooncakeKVManager(CommonKVManager):
) )
continue continue
else: else:
required_dst_info_num = int(waiting_req_bytes[6].decode("ascii")) required_dst_info_num = int(waiting_req_bytes[7].decode("ascii"))
room = int(room) room = int(room)
if room not in self.transfer_infos: if room not in self.transfer_infos:
self.transfer_infos[room] = {} self.transfer_infos[room] = {}
...@@ -876,6 +986,7 @@ class MooncakeKVManager(CommonKVManager): ...@@ -876,6 +986,7 @@ class MooncakeKVManager(CommonKVManager):
index_slice: slice, index_slice: slice,
is_last: bool, is_last: bool,
aux_index: Optional[int] = None, aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
): ):
assert self.disaggregation_mode == DisaggregationMode.PREFILL assert self.disaggregation_mode == DisaggregationMode.PREFILL
assert not is_last or (is_last and aux_index is not None) assert not is_last or (is_last and aux_index is not None)
...@@ -909,6 +1020,7 @@ class MooncakeKVManager(CommonKVManager): ...@@ -909,6 +1020,7 @@ class MooncakeKVManager(CommonKVManager):
index_slice=index_slice, index_slice=index_slice,
is_last=is_last, is_last=is_last,
prefill_aux_index=aux_index, prefill_aux_index=aux_index,
state_indices=state_indices,
) )
) )
...@@ -989,6 +1101,7 @@ class MooncakeKVSender(CommonKVSender): ...@@ -989,6 +1101,7 @@ class MooncakeKVSender(CommonKVSender):
def send( def send(
self, self,
kv_indices: npt.NDArray[np.int32], kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
): ):
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices)) index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
self.curr_idx += len(kv_indices) self.curr_idx += len(kv_indices)
...@@ -1008,6 +1121,7 @@ class MooncakeKVSender(CommonKVSender): ...@@ -1008,6 +1121,7 @@ class MooncakeKVSender(CommonKVSender):
index_slice, index_slice,
True, True,
aux_index=self.aux_index, aux_index=self.aux_index,
state_indices=state_indices,
) )
def poll(self) -> KVPoll: def poll(self) -> KVPoll:
...@@ -1110,6 +1224,9 @@ class MooncakeKVReceiver(CommonKVReceiver): ...@@ -1110,6 +1224,9 @@ class MooncakeKVReceiver(CommonKVReceiver):
packed_aux_data_ptrs = b"".join( packed_aux_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
) )
packed_state_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs
)
# Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet # Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
tp_rank = self.kv_mgr.kv_args.engine_rank tp_rank = self.kv_mgr.kv_args.engine_rank
kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0] kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
...@@ -1127,13 +1244,19 @@ class MooncakeKVReceiver(CommonKVReceiver): ...@@ -1127,13 +1244,19 @@ class MooncakeKVReceiver(CommonKVReceiver):
self.session_id.encode("ascii"), self.session_id.encode("ascii"),
packed_kv_data_ptrs, packed_kv_data_ptrs,
packed_aux_data_ptrs, packed_aux_data_ptrs,
packed_state_data_ptrs,
dst_tp_rank, dst_tp_rank,
dst_attn_tp_size, dst_attn_tp_size,
dst_kv_item_len, dst_kv_item_len,
] ]
) )
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): def init(
self,
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
):
for bootstrap_info in self.bootstrap_infos: for bootstrap_info in self.bootstrap_infos:
sock, lock = self._connect_to_bootstrap_server(bootstrap_info) sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
is_dummy = bootstrap_info["is_dummy"] is_dummy = bootstrap_info["is_dummy"]
...@@ -1147,6 +1270,14 @@ class MooncakeKVReceiver(CommonKVReceiver): ...@@ -1147,6 +1270,14 @@ class MooncakeKVReceiver(CommonKVReceiver):
self.session_id.encode("ascii"), self.session_id.encode("ascii"),
kv_indices.tobytes() if not is_dummy else b"", kv_indices.tobytes() if not is_dummy else b"",
str(aux_index).encode("ascii") if not is_dummy else b"", str(aux_index).encode("ascii") if not is_dummy else b"",
(
np.array(
state_indices,
dtype=np.int32,
).tobytes()
if not is_dummy and state_indices is not None
else b""
),
str(self.required_dst_info_num).encode("ascii"), str(self.required_dst_info_num).encode("ascii"),
] ]
) )
......
...@@ -704,6 +704,7 @@ class NixlKVSender(CommonKVSender): ...@@ -704,6 +704,7 @@ class NixlKVSender(CommonKVSender):
def send( def send(
self, self,
kv_indices: npt.NDArray[np.int32], kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
): ):
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices)) index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
self.curr_idx += len(kv_indices) self.curr_idx += len(kv_indices)
...@@ -755,7 +756,12 @@ class NixlKVReceiver(CommonKVReceiver): ...@@ -755,7 +756,12 @@ class NixlKVReceiver(CommonKVReceiver):
self.bootstrap_room self.bootstrap_room
) )
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None): def init(
self,
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
):
for bootstrap_info in self.bootstrap_infos: for bootstrap_info in self.bootstrap_infos:
logger.debug( logger.debug(
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}" f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
......
...@@ -49,6 +49,11 @@ from sglang.srt.managers.schedule_batch import ( ...@@ -49,6 +49,11 @@ from sglang.srt.managers.schedule_batch import (
RequestStage, RequestStage,
ScheduleBatch, ScheduleBatch,
) )
from sglang.srt.mem_cache.memory_pool import (
HybridLinearKVPool,
NSATokenToKVPool,
SWAKVPool,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.utils import ( from sglang.srt.utils import (
DynamicGradMode, DynamicGradMode,
...@@ -146,6 +151,28 @@ class PrefillBootstrapQueue: ...@@ -146,6 +151,28 @@ class PrefillBootstrapQueue:
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id kv_args.gpu_id = self.scheduler.gpu_id
if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
state_data_ptrs, state_data_lens, state_item_lens = (
self.token_to_kv_pool.get_state_buf_infos()
)
kv_args.state_data_ptrs = state_data_ptrs
kv_args.state_data_lens = state_data_lens
kv_args.state_item_lens = state_item_lens
if isinstance(self.token_to_kv_pool, SWAKVPool):
kv_args.state_type = "swa"
elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
kv_args.state_type = "mamba"
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
kv_args.state_type = "nsa"
else:
kv_args.state_type = "none"
else:
kv_args.state_data_ptrs = []
kv_args.state_data_lens = []
kv_args.state_item_lens = []
kv_args.state_type = "none"
kv_manager_class: Type[BaseKVManager] = get_kv_class( kv_manager_class: Type[BaseKVManager] = get_kv_class(
self.transfer_backend, KVClassType.MANAGER self.transfer_backend, KVClassType.MANAGER
) )
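Summarizing the pool-class dispatch that appears both here on the prefill side and earlier on the decode side (the mapping below is a restatement for readability, not new behavior):

```python
# state_type selected from the concrete KV pool class:
STATE_TYPE_BY_POOL = {
    "SWAKVPool": "swa",             # sliding-window layers live in a second pool
    "HybridLinearKVPool": "mamba",  # Qwen3-Next style mamba/linear-attention state
    "NSATokenToKVPool": "nsa",      # DeepSeek V3.2 index-k (NSA) buffers
}
# Pools without get_state_buf_infos() fall back to state_type = "none".
```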
...@@ -618,15 +645,58 @@ class SchedulerDisaggregationPrefillMixin: ...@@ -618,15 +645,58 @@ class SchedulerDisaggregationPrefillMixin:
.numpy() .numpy()
) )
req.start_send_idx = end_idx req.start_send_idx = end_idx
state_indices = None
if last_chunk: if last_chunk:
self.disagg_metadata_buffers.set_buf(req) self.disagg_metadata_buffers.set_buf(req)
# Prepare extra pool indices for hybrid models
if isinstance(
self.token_to_kv_pool_allocator.get_kvcache(), HybridLinearKVPool
):
# Mamba hybrid model: send single mamba state index
state_indices = [
self.req_to_token_pool.req_index_to_mamba_index_mapping[
req.req_pool_idx
]
.cpu()
.numpy()
]
elif isinstance(self.token_to_kv_pool_allocator.get_kvcache(), SWAKVPool):
# SWA hybrid model: send last window KV indices
seq_len = len(req.fill_ids)
window_size = self.sliding_window_size
window_start = max(0, seq_len - window_size)
window_start = (window_start // page_size) * page_size
window_kv_indices_full = self.req_to_token_pool.req_to_token[
req.req_pool_idx, window_start:seq_len
]
# Translate to SWA pool indices
window_kv_indices_swa = (
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
window_kv_indices_full
)
)
state_indices = window_kv_indices_swa.cpu().numpy()
state_indices = kv_to_page_indices(state_indices, page_size)
elif isinstance(
self.token_to_kv_pool_allocator.get_kvcache(), NSATokenToKVPool
):
seq_len = len(req.fill_ids)
kv_indices_full = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :seq_len
]
state_indices = kv_indices_full.cpu().numpy()
state_indices = kv_to_page_indices(state_indices, page_size)
page_indices = kv_to_page_indices(kv_indices, page_size) page_indices = kv_to_page_indices(kv_indices, page_size)
if len(page_indices) == 0: if len(page_indices) == 0:
logger.info( logger.info(
f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty" f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
) )
return return
req.disagg_kv_sender.send(page_indices) req.disagg_kv_sender.send(page_indices, state_indices)
# PP # PP
@DynamicGradMode() @DynamicGradMode()
......
...@@ -807,9 +807,6 @@ class Scheduler( ...@@ -807,9 +807,6 @@ class Scheduler(
self.tree_cache.cache_controller.layer_done_counter self.tree_cache.cache_controller.layer_done_counter
) )
elif self.is_hybrid: elif self.is_hybrid:
assert (
self.server_args.disaggregation_mode == "null"
), "Hybrid mode does not support disaggregation yet"
self.tree_cache = SWARadixCache( self.tree_cache = SWARadixCache(
req_to_token_pool=self.req_to_token_pool, req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
...@@ -819,9 +816,6 @@ class Scheduler( ...@@ -819,9 +816,6 @@ class Scheduler(
is_eagle=self.spec_algorithm.is_eagle(), is_eagle=self.spec_algorithm.is_eagle(),
) )
elif self.is_hybrid_gdn: elif self.is_hybrid_gdn:
assert (
self.server_args.disaggregation_mode == "null"
), "Hybrid GDN mode does not support disaggregation yet"
self.tree_cache = MambaRadixCache( self.tree_cache = MambaRadixCache(
req_to_token_pool=self.req_to_token_pool, req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
......
...@@ -142,6 +142,24 @@ class MambaPool: ...@@ -142,6 +142,24 @@ class MambaPool:
ssm_dtype = cache_params.dtype.temporal ssm_dtype = cache_params.dtype.temporal
num_mamba_layers = len(cache_params.layers) num_mamba_layers = len(cache_params.layers)
# for disagg with nvlink
self.enable_custom_mem_pool = get_bool_env_var(
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
)
if self.enable_custom_mem_pool:
# TODO(shangming): abstract custom allocator class for more backends
from mooncake.allocator import NVLinkAllocator
allocator = NVLinkAllocator.get_allocator(self.device)
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
else:
self.custom_mem_pool = None
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.enable_custom_mem_pool
else nullcontext()
):
# assume conv_state = (dim, state_len) # assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1] assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.zeros( conv_state = torch.zeros(
...@@ -206,8 +224,11 @@ class MambaPool: ...@@ -206,8 +224,11 @@ class MambaPool:
) )
self.size = size self.size = size
self.device = device self.device = device
self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device) self.free_slots = torch.arange(
self.size, dtype=torch.int64, device=self.device
)
self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
self.num_mamba_layers = num_mamba_layers
def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState: def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
assert isinstance(self.mamba_cache, self.SpeculativeState) assert isinstance(self.mamba_cache, self.SpeculativeState)
...@@ -253,6 +274,22 @@ class MambaPool: ...@@ -253,6 +274,22 @@ class MambaPool:
self.copy_from(src_index, dst_index) self.copy_from(src_index, dst_index)
return dst_index return dst_index
def get_contiguous_buf_infos(self):
state_tensors = [
getattr(self.mamba_cache, field) for field in vars(self.mamba_cache)
]
data_ptrs, data_lens, item_lens = [], [], []
for _, state_tensor in enumerate(state_tensors):
data_ptrs += [
state_tensor[i].data_ptr() for i in range(self.num_mamba_layers)
]
data_lens += [state_tensor[i].nbytes for i in range(self.num_mamba_layers)]
item_lens += [
state_tensor[i][0].nbytes for i in range(self.num_mamba_layers)
]
return data_ptrs, data_lens, item_lens
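`get_contiguous_buf_infos` flattens every field of the mamba cache into per-layer entries, so the transfer layer can treat each layer's state tensor like one more registered buffer. A self-contained sketch of the shapes involved (dimensions hypothetical):

```python
import torch

num_mamba_layers, num_slots, dim = 2, 8, 16
conv_state = torch.zeros(num_mamba_layers, num_slots, dim)   # hypothetical cache field

data_ptrs = [conv_state[i].data_ptr() for i in range(num_mamba_layers)]  # one ptr per layer
data_lens = [conv_state[i].nbytes for i in range(num_mamba_layers)]      # num_slots*dim*itemsize
item_lens = [conv_state[i][0].nbytes for i in range(num_mamba_layers)]   # dim*itemsize = one slot
```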
class HybridReqToTokenPool(ReqToTokenPool): class HybridReqToTokenPool(ReqToTokenPool):
"""A memory pool that maps a request to its token locations.""" """A memory pool that maps a request to its token locations."""
...@@ -274,9 +311,22 @@ class HybridReqToTokenPool(ReqToTokenPool): ...@@ -274,9 +311,22 @@ class HybridReqToTokenPool(ReqToTokenPool):
device=device, device=device,
enable_memory_saver=enable_memory_saver, enable_memory_saver=enable_memory_saver,
) )
self._init_mamba_pool(
size=mamba_size,
cache_params=cache_params,
device=device,
speculative_num_draft_tokens=speculative_num_draft_tokens,
)
def _init_mamba_pool(
self,
size: int,
cache_params: "Mamba2CacheParams",
device: str,
speculative_num_draft_tokens: int = None,
):
self.mamba_pool = MambaPool( self.mamba_pool = MambaPool(
size=mamba_size, size=size,
cache_params=cache_params, cache_params=cache_params,
device=device, device=device,
speculative_num_draft_tokens=speculative_num_draft_tokens, speculative_num_draft_tokens=speculative_num_draft_tokens,
...@@ -375,6 +425,19 @@ class KVCache(abc.ABC): ...@@ -375,6 +425,19 @@ class KVCache(abc.ABC):
# default state for optional layer-wise transfer control # default state for optional layer-wise transfer control
self.layer_transfer_counter = None self.layer_transfer_counter = None
# for disagg with nvlink
self.enable_custom_mem_pool = get_bool_env_var(
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
)
if self.enable_custom_mem_pool:
# TODO(shangming): abstract custom allocator class for more backends
from mooncake.allocator import NVLinkAllocator
allocator = NVLinkAllocator.get_allocator(self.device)
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
else:
self.custom_mem_pool = None
def _finalize_allocation_log(self, num_tokens: int): def _finalize_allocation_log(self, num_tokens: int):
"""Common logging and mem_usage computation for KV cache allocation. """Common logging and mem_usage computation for KV cache allocation.
Supports both tuple (K, V) size returns and single KV size returns. Supports both tuple (K, V) size returns and single KV size returns.
...@@ -426,6 +489,9 @@ class KVCache(abc.ABC): ...@@ -426,6 +489,9 @@ class KVCache(abc.ABC):
def load_cpu_copy(self, kv_cache_cpu, indices): def load_cpu_copy(self, kv_cache_cpu, indices):
raise NotImplementedError() raise NotImplementedError()
def maybe_get_custom_mem_pool(self):
return self.custom_mem_pool
class MHATokenToKVPool(KVCache): class MHATokenToKVPool(KVCache):
...@@ -456,19 +522,6 @@ class MHATokenToKVPool(KVCache): ...@@ -456,19 +522,6 @@ class MHATokenToKVPool(KVCache):
self.head_num = head_num self.head_num = head_num
self.head_dim = head_dim self.head_dim = head_dim
# for disagg with nvlink
self.enable_custom_mem_pool = get_bool_env_var(
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
)
if self.enable_custom_mem_pool:
# TODO(shangming): abstract custom allocator class for more backends
from mooncake.allocator import NVLinkAllocator
allocator = NVLinkAllocator.get_allocator(self.device)
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
else:
self.custom_mem_pool = None
self._create_buffers() self._create_buffers()
self.device_module = torch.get_device_module(self.device) self.device_module = torch.get_device_module(self.device)
...@@ -611,9 +664,6 @@ class MHATokenToKVPool(KVCache): ...@@ -611,9 +664,6 @@ class MHATokenToKVPool(KVCache):
] ]
return kv_data_ptrs, kv_data_lens, kv_item_lens return kv_data_ptrs, kv_data_lens, kv_item_lens
def maybe_get_custom_mem_pool(self):
return self.custom_mem_pool
def get_cpu_copy(self, indices): def get_cpu_copy(self, indices):
torch.cuda.synchronize() torch.cuda.synchronize()
kv_cache_cpu = [] kv_cache_cpu = []
...@@ -756,12 +806,18 @@ class HybridLinearKVPool(KVCache): ...@@ -756,12 +806,18 @@ class HybridLinearKVPool(KVCache):
full_attention_layer_ids: List[int], full_attention_layer_ids: List[int],
enable_kvcache_transpose: bool, enable_kvcache_transpose: bool,
device: str, device: str,
mamba_pool: MambaPool,
): ):
self.size = size self.size = size
self.dtype = dtype self.dtype = dtype
self.device = device self.device = device
self.full_layer_nums = len(full_attention_layer_ids) self.full_layer_nums = len(full_attention_layer_ids)
self.page_size = page_size self.page_size = page_size
# TODO support pp?
self.start_layer = 0
self.head_num = head_num
self.head_dim = head_dim
self.mamba_pool = mamba_pool
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
assert not enable_kvcache_transpose assert not enable_kvcache_transpose
if _is_npu: if _is_npu:
...@@ -790,6 +846,15 @@ class HybridLinearKVPool(KVCache): ...@@ -790,6 +846,15 @@ class HybridLinearKVPool(KVCache):
def get_contiguous_buf_infos(self): def get_contiguous_buf_infos(self):
return self.full_kv_pool.get_contiguous_buf_infos() return self.full_kv_pool.get_contiguous_buf_infos()
def get_state_buf_infos(self):
mamba_data_ptrs, mamba_data_lens, mamba_item_lens = (
self.mamba_pool.get_contiguous_buf_infos()
)
return mamba_data_ptrs, mamba_data_lens, mamba_item_lens
def maybe_get_custom_mem_pool(self):
return self.full_kv_pool.maybe_get_custom_mem_pool()
def _transfer_full_attention_id(self, layer_id: int): def _transfer_full_attention_id(self, layer_id: int):
if layer_id not in self.full_attention_layer_id_mapping: if layer_id not in self.full_attention_layer_id_mapping:
raise ValueError( raise ValueError(
...@@ -841,22 +906,47 @@ class SWAKVPool(KVCache): ...@@ -841,22 +906,47 @@ class SWAKVPool(KVCache):
size: int, size: int,
size_swa: int, size_swa: int,
dtype: torch.dtype, dtype: torch.dtype,
head_num: int,
head_dim: int,
swa_attention_layer_ids: List[int], swa_attention_layer_ids: List[int],
full_attention_layer_ids: List[int], full_attention_layer_ids: List[int],
enable_kvcache_transpose: bool, enable_kvcache_transpose: bool,
device: str,
token_to_kv_pool_class: KVCache = MHATokenToKVPool, token_to_kv_pool_class: KVCache = MHATokenToKVPool,
**kwargs, **kwargs,
): ):
self.size = size self.size = size
self.size_swa = size_swa self.size_swa = size_swa
self.dtype = dtype self.dtype = dtype
self.head_num = head_num
self.head_dim = head_dim
self.device = device
self.swa_layer_nums = len(swa_attention_layer_ids) self.swa_layer_nums = len(swa_attention_layer_ids)
self.full_layer_nums = len(full_attention_layer_ids) self.full_layer_nums = len(full_attention_layer_ids)
self.start_layer = 0
self.page_size = 1
kwargs["page_size"] = 1 kwargs["page_size"] = 1
kwargs["enable_memory_saver"] = False kwargs["enable_memory_saver"] = False
kwargs["head_num"] = head_num
kwargs["head_dim"] = head_dim
kwargs["device"] = device
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
assert not enable_kvcache_transpose assert not enable_kvcache_transpose
# for disagg with nvlink
self.enable_custom_mem_pool = get_bool_env_var(
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
)
if self.enable_custom_mem_pool:
# TODO(shangming): abstract custom allocator class for more backends
from mooncake.allocator import NVLinkAllocator
allocator = NVLinkAllocator.get_allocator(self.device)
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
else:
self.custom_mem_pool = None
self.swa_kv_pool = token_to_kv_pool_class( self.swa_kv_pool = token_to_kv_pool_class(
size=size_swa, size=size_swa,
dtype=dtype, dtype=dtype,
...@@ -878,6 +968,9 @@ class SWAKVPool(KVCache): ...@@ -878,6 +968,9 @@ class SWAKVPool(KVCache):
k_size, v_size = self.get_kv_size_bytes() k_size, v_size = self.get_kv_size_bytes()
self.mem_usage = (k_size + v_size) / GB self.mem_usage = (k_size + v_size) / GB
logger.info(
f"SWAKVPool mem usage: {self.mem_usage} GB, swa size: {self.size_swa}, full size: {self.size}"
)
def get_kv_size_bytes(self): def get_kv_size_bytes(self):
k_size, v_size = self.full_kv_pool.get_kv_size_bytes() k_size, v_size = self.full_kv_pool.get_kv_size_bytes()
...@@ -888,15 +981,19 @@ class SWAKVPool(KVCache): ...@@ -888,15 +981,19 @@ class SWAKVPool(KVCache):
full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = ( full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
self.full_kv_pool.get_contiguous_buf_infos() self.full_kv_pool.get_contiguous_buf_infos()
) )
kv_data_ptrs = full_kv_data_ptrs
kv_data_lens = full_kv_data_lens
kv_item_lens = full_kv_item_lens
return kv_data_ptrs, kv_data_lens, kv_item_lens
def get_state_buf_infos(self):
swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = ( swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = (
self.swa_kv_pool.get_contiguous_buf_infos() self.swa_kv_pool.get_contiguous_buf_infos()
) )
kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs return swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens
kv_data_lens = full_kv_data_lens + swa_kv_data_lens
kv_item_lens = full_kv_item_lens + swa_kv_item_lens
return kv_data_ptrs, kv_data_lens, kv_item_lens
def get_key_buffer(self, layer_id: int): def get_key_buffer(self, layer_id: int):
layer_id_pool, is_swa = self.layers_mapping[layer_id] layer_id_pool, is_swa = self.layers_mapping[layer_id]
...@@ -1152,19 +1249,6 @@ class MLATokenToKVPool(KVCache): ...@@ -1152,19 +1249,6 @@ class MLATokenToKVPool(KVCache):
else (kv_lora_rank + qk_rope_head_dim) else (kv_lora_rank + qk_rope_head_dim)
) )
# for disagg with nvlink
self.enable_custom_mem_pool = get_bool_env_var(
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
)
if self.enable_custom_mem_pool:
# TODO(shangming): abstract custom allocator class for more backends
from mooncake.allocator import NVLinkAllocator
allocator = NVLinkAllocator.get_allocator(self.device)
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
else:
self.custom_mem_pool = None
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
with ( with (
torch.cuda.use_mem_pool(self.custom_mem_pool) torch.cuda.use_mem_pool(self.custom_mem_pool)
...@@ -1207,9 +1291,6 @@ class MLATokenToKVPool(KVCache): ...@@ -1207,9 +1291,6 @@ class MLATokenToKVPool(KVCache):
] ]
return kv_data_ptrs, kv_data_lens, kv_item_lens return kv_data_ptrs, kv_data_lens, kv_item_lens
def maybe_get_custom_mem_pool(self):
return self.custom_mem_pool
def get_key_buffer(self, layer_id: int): def get_key_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None: if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer) self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
...@@ -1346,6 +1427,11 @@ class NSATokenToKVPool(MLATokenToKVPool): ...@@ -1346,6 +1427,11 @@ class NSATokenToKVPool(MLATokenToKVPool):
assert index_head_dim == 128 assert index_head_dim == 128
assert self.page_size == 64 assert self.page_size == 64
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.custom_mem_pool
else nullcontext()
):
self.index_k_with_scale_buffer = [ self.index_k_with_scale_buffer = [
torch.zeros( torch.zeros(
# Layout: # Layout:
...@@ -1357,7 +1443,9 @@ class NSATokenToKVPool(MLATokenToKVPool): ...@@ -1357,7 +1443,9 @@ class NSATokenToKVPool(MLATokenToKVPool):
( (
(size + page_size + 1) // self.page_size, (size + page_size + 1) // self.page_size,
self.page_size self.page_size
* (index_head_dim + index_head_dim // self.quant_block_size * 4), * (
index_head_dim + index_head_dim // self.quant_block_size * 4
),
), ),
dtype=self.index_k_with_scale_buffer_dtype, dtype=self.index_k_with_scale_buffer_dtype,
device=device, device=device,
...@@ -1406,6 +1494,18 @@ class NSATokenToKVPool(MLATokenToKVPool): ...@@ -1406,6 +1494,18 @@ class NSATokenToKVPool(MLATokenToKVPool):
pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
) )
def get_state_buf_infos(self):
data_ptrs = [
self.index_k_with_scale_buffer[i].data_ptr() for i in range(self.layer_num)
]
data_lens = [
self.index_k_with_scale_buffer[i].nbytes for i in range(self.layer_num)
]
item_lens = [
self.index_k_with_scale_buffer[i][0].nbytes for i in range(self.layer_num)
]
return data_ptrs, data_lens, item_lens
def get_kv_size_bytes(self): def get_kv_size_bytes(self):
kv_size_bytes = super().get_kv_size_bytes() kv_size_bytes = super().get_kv_size_bytes()
for index_k_cache in self.index_k_with_scale_buffer: for index_k_cache in self.index_k_with_scale_buffer:
...@@ -1636,16 +1736,25 @@ class DoubleSparseTokenToKVPool(KVCache): ...@@ -1636,16 +1736,25 @@ class DoubleSparseTokenToKVPool(KVCache):
) )
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE): with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.enable_custom_mem_pool
else nullcontext()
):
# [size, head_num, head_dim] for each layer # [size, head_num, head_dim] for each layer
self.k_buffer = [ self.k_buffer = [
torch.zeros( torch.zeros(
(size + page_size, head_num, head_dim), dtype=dtype, device=device (size + page_size, head_num, head_dim),
dtype=dtype,
device=device,
) )
for _ in range(layer_num) for _ in range(layer_num)
] ]
self.v_buffer = [ self.v_buffer = [
torch.zeros( torch.zeros(
(size + page_size, head_num, head_dim), dtype=dtype, device=device (size + page_size, head_num, head_dim),
dtype=dtype,
device=device,
) )
for _ in range(layer_num) for _ in range(layer_num)
] ]
...@@ -1653,7 +1762,9 @@ class DoubleSparseTokenToKVPool(KVCache): ...@@ -1653,7 +1762,9 @@ class DoubleSparseTokenToKVPool(KVCache):
# [size, head_num, heavy_channel_num] for each layer # [size, head_num, heavy_channel_num] for each layer
self.label_buffer = [ self.label_buffer = [
torch.zeros( torch.zeros(
(size + 1, head_num, heavy_channel_num), dtype=dtype, device=device (size + 1, head_num, heavy_channel_num),
dtype=dtype,
device=device,
) )
for _ in range(layer_num) for _ in range(layer_num)
] ]
......
...@@ -1669,11 +1669,26 @@ class ModelRunner: ...@@ -1669,11 +1669,26 @@ class ModelRunner:
extra_max_context_len += self.server_args.speculative_num_draft_tokens extra_max_context_len += self.server_args.speculative_num_draft_tokens
if self.server_args.disaggregation_mode == "decode": if self.server_args.disaggregation_mode == "decode":
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool from sglang.srt.disaggregation.decode import (
DecodeReqToTokenPool,
HybridMambaDecodeReqToTokenPool,
)
# subscribe memory for pre-allocated requests # subscribe memory for pre-allocated requests
# if max_num_reqs <= 32, we pre-allocate 2x requests # if max_num_reqs <= 32, we pre-allocate 2x requests
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0 pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
if config := self.mambaish_config:
self.req_to_token_pool = HybridMambaDecodeReqToTokenPool(
size=max_num_reqs,
max_context_len=self.model_config.context_len
+ extra_max_context_len,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
cache_params=config.mamba2_cache_params,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
pre_alloc_size=pre_alloc_size,
)
else:
self.req_to_token_pool = DecodeReqToTokenPool( self.req_to_token_pool = DecodeReqToTokenPool(
size=max_num_reqs, size=max_num_reqs,
max_context_len=self.model_config.context_len max_context_len=self.model_config.context_len
...@@ -1807,6 +1822,7 @@ class ModelRunner: ...@@ -1807,6 +1822,7 @@ class ModelRunner:
), ),
enable_kvcache_transpose=False, enable_kvcache_transpose=False,
device=self.device, device=self.device,
mamba_pool=self.req_to_token_pool.mamba_pool,
) )
else: else:
self.token_to_kv_pool = MHATokenToKVPool( self.token_to_kv_pool = MHATokenToKVPool(
......
...@@ -163,6 +163,7 @@ suites = { ...@@ -163,6 +163,7 @@ suites = {
TestFile("test_deepseek_v3_basic.py", 275), TestFile("test_deepseek_v3_basic.py", 275),
TestFile("test_deepseek_v3_mtp.py", 275), TestFile("test_deepseek_v3_mtp.py", 275),
TestFile("test_disaggregation_different_tp.py", 600), TestFile("test_disaggregation_different_tp.py", 600),
TestFile("test_disaggregation_hybrid_attention.py", 200),
TestFile("test_disaggregation_pp.py", 140), TestFile("test_disaggregation_pp.py", 140),
], ],
"per-commit-4-gpu-b200": [ "per-commit-4-gpu-b200": [
......
import os
import unittest
from types import SimpleNamespace
from sglang.srt.environ import envs
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
popen_launch_pd_server,
)
class TestDisaggregationHybridAttentionMamba(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct"
# Non blocking start servers
cls.start_prefill()
cls.start_decode()
# Block until both servers are ready
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
cls.launch_lb()
@classmethod
def start_prefill(cls):
prefill_args = [
"--trust-remote-code",
"--disaggregation-mode",
"prefill",
"--tp",
"4",
]
prefill_args += cls.transfer_backend + cls.rdma_devices
cls.process_prefill = popen_launch_pd_server(
cls.model,
cls.prefill_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=prefill_args,
)
@classmethod
def start_decode(cls):
decode_args = [
"--trust-remote-code",
"--disaggregation-mode",
"decode",
"--tp",
"4",
"--base-gpu-id",
"4",
]
decode_args += cls.transfer_backend + cls.rdma_devices
cls.process_decode = popen_launch_pd_server(
cls.model,
cls.decode_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=decode_args,
)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host=f"http://{self.base_host}",
port=int(self.lb_port),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"Evaluation metrics: {metrics}")
self.assertGreater(metrics["accuracy"], 0.93)
if __name__ == "__main__":
unittest.main()
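This test presumably needs 8 GPUs (prefill on GPUs 0-3, decode on GPUs 4-7 via `--base-gpu-id 4`) plus the transfer backend and RDMA devices configured by `TestDisaggregationBase`. Assuming the file name registered in the CI suite above:

```python
# Hypothetical standalone invocation; relies on the unittest.main() entry point above:
#
#   python test_disaggregation_hybrid_attention.py
#
# or run it through the per-commit suite entry added earlier in this diff.
```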
...@@ -42,6 +42,7 @@ class TestMamba(unittest.TestCase): ...@@ -42,6 +42,7 @@ class TestMamba(unittest.TestCase):
full_attention_layer_ids=full_attention_layer_ids, full_attention_layer_ids=full_attention_layer_ids,
enable_kvcache_transpose=False, enable_kvcache_transpose=False,
device=device, device=device,
mamba_pool=None,
) )
assert pool._transfer_full_attention_id(global_interval - 1) == 0 assert pool._transfer_full_attention_id(global_interval - 1) == 0
assert pool._transfer_full_attention_id(2 * global_interval - 1) == 1 assert pool._transfer_full_attention_id(2 * global_interval - 1) == 1
...@@ -173,6 +174,7 @@ class TestMamba(unittest.TestCase): ...@@ -173,6 +174,7 @@ class TestMamba(unittest.TestCase):
full_attention_layer_ids=full_attention_layer_ids, full_attention_layer_ids=full_attention_layer_ids,
enable_kvcache_transpose=False, enable_kvcache_transpose=False,
device=device, device=device,
mamba_pool=req_to_token_pool.mamba_pool,
) )
# setup token to kv pool allocator # setup token to kv pool allocator
......