Unverified Commit 868403f6 authored by Shangming Cai, committed by GitHub

[PD] Add PD support for hybrid model (Qwen3-Next, DeepSeek V3.2 Exp) (#10912)


Signed-off-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: hzh0425 <hzh0425@apache.org>
Co-authored-by: ZeldaHuang <hzm414167@alibaba-inc.com>
parent 97d857c0
@@ -20,6 +20,10 @@ class KVArgs:
     aux_data_ptrs: List[int]
     aux_data_lens: List[int]
     aux_item_lens: List[int]
+    state_data_ptrs: List[int]
+    state_data_lens: List[int]
+    state_item_lens: List[int]
+    state_type: str  # "none", "mamba", "swa"
     ib_device: str
     ib_traffic_class: str
     gpu_id: int
@@ -76,9 +80,13 @@ class BaseKVSender(ABC):
         ...

     @abstractmethod
-    def send(self, kv_indices: npt.NDArray[np.int32]):
+    def send(
+        self,
+        kv_indices: npt.NDArray[np.int32],
+        state_indices: Optional[List[int]] = None,
+    ):
         """
-        Send the kv cache at the given kv indices to the decoder server
+        Send the kv cache at the given kv indices and the extra cache/state at the given indices to the decoder server
         """
         ...
@@ -108,9 +116,14 @@ class BaseKVReceiver(ABC):
     ): ...

     @abstractmethod
-    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
+    def init(
+        self,
+        kv_indices: npt.NDArray[np.int32],
+        aux_index: Optional[int] = None,
+        state_indices: Optional[List[int]] = None,
+    ):
         """
-        Notify the prefill server about the kv indices and aux index
+        Notify the prefill server about the kv indices, aux index, and state_indices.
         """
         ...
...
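The widened interface is backward compatible: state_indices defaults to None, so non-hybrid backends keep their existing call sites, while hybrid backends pass the extra state alongside the last KV chunk. A minimal sketch of a conforming sender (the EchoKVSender class below is hypothetical, for illustration only):

from typing import List, Optional

import numpy as np
import numpy.typing as npt


class EchoKVSender:  # hypothetical, not part of this change
    def send(
        self,
        kv_indices: npt.NDArray[np.int32],
        state_indices: Optional[List[int]] = None,
    ):
        # Non-hybrid callers pass only kv_indices; hybrid callers also pass
        # mamba/SWA/NSA state indices, which travel with the final chunk.
        if state_indices is None:
            print(f"sending {len(kv_indices)} kv pages")
        else:
            print(f"sending {len(kv_indices)} kv pages + {len(state_indices)} state items")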
@@ -201,6 +201,7 @@ class CommonKVSender(BaseKVSender):
     def send(
         self,
         kv_indices: npt.NDArray[np.int32],
+        state_indices: Optional[List[int]] = None,
     ):
         pass
...
@@ -25,11 +25,12 @@ import time
 from collections import deque
 from dataclasses import dataclass
 from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union

 import torch
 from torch.distributed import ProcessGroup

+from sglang.srt.configs.mamba_utils import Mamba2CacheParams
 from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
 from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
 from sglang.srt.disaggregation.utils import (
@@ -47,9 +48,19 @@ from sglang.srt.disaggregation.utils import (
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
-from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.allocator import (
+    BaseTokenToKVPoolAllocator,
+    SWATokenToKVPoolAllocator,
+)
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
+from sglang.srt.mem_cache.memory_pool import (
+    HybridLinearKVPool,
+    HybridReqToTokenPool,
+    KVCache,
+    NSATokenToKVPool,
+    ReqToTokenPool,
+    SWAKVPool,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.utils import get_int_env_var, require_mlp_sync
 from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
@@ -124,6 +135,35 @@ class DecodeReqToTokenPool:
         self.free_slots = list(range(self.size + self.pre_alloc_size))
+
+
+class HybridMambaDecodeReqToTokenPool(HybridReqToTokenPool):
+    def __init__(
+        self,
+        size: int,
+        max_context_len: int,
+        device: str,
+        enable_memory_saver: bool,
+        cache_params: "Mamba2CacheParams",
+        speculative_num_draft_tokens: int,
+        pre_alloc_size: int,
+    ):
+        DecodeReqToTokenPool.__init__(
+            self,
+            size=size,
+            max_context_len=max_context_len,
+            device=device,
+            enable_memory_saver=enable_memory_saver,
+            pre_alloc_size=pre_alloc_size,
+        )
+        self._init_mamba_pool(
+            size + pre_alloc_size, cache_params, device, speculative_num_draft_tokens
+        )
+
+    def clear(self):
+        self.free_slots = list(range(self.size + self.pre_alloc_size))
+        self.mamba_pool.clear()
+

 @dataclass
 class DecodeRequest:
     req: Req
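Note the explicit DecodeReqToTokenPool.__init__ call above: the new pool inherits from HybridReqToTokenPool to get the mamba-pool machinery, but borrows the decode pool's slot bookkeeping (including pre_alloc_size, which HybridReqToTokenPool's own initializer does not know about). A reduced sketch of that pattern, with illustrative names rather than the real pools:

class SlotPool:  # stands in for DecodeReqToTokenPool
    def __init__(self, size: int, pre_alloc_size: int):
        self.size = size
        self.pre_alloc_size = pre_alloc_size
        self.free_slots = list(range(size + pre_alloc_size))


class StatefulPool:  # stands in for HybridReqToTokenPool
    def _init_state(self, num_slots: int):
        self.state = [None] * num_slots


class HybridSlotPool(StatefulPool):
    def __init__(self, size: int, pre_alloc_size: int):
        # Reuse SlotPool's initializer without inheriting from it; an
        # unbound __init__ can be applied to any compatible instance.
        SlotPool.__init__(self, size, pre_alloc_size)
        # Size the state storage to cover pre-allocated slots as well.
        self._init_state(size + pre_alloc_size)


pool = HybridSlotPool(4, 8)
assert len(pool.free_slots) == 12 and len(pool.state) == 12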
@@ -217,6 +257,28 @@ class DecodePreallocQueue:
             self.metadata_buffers.get_buf_infos()
         )

+        if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
+            state_data_ptrs, state_data_lens, state_item_lens = (
+                self.token_to_kv_pool.get_state_buf_infos()
+            )
+            kv_args.state_data_ptrs = state_data_ptrs
+            kv_args.state_data_lens = state_data_lens
+            kv_args.state_item_lens = state_item_lens
+
+            if isinstance(self.token_to_kv_pool, SWAKVPool):
+                kv_args.state_type = "swa"
+            elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
+                kv_args.state_type = "mamba"
+            elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
+                kv_args.state_type = "nsa"
+            else:
+                kv_args.state_type = "none"
+        else:
+            kv_args.state_data_ptrs = []
+            kv_args.state_data_lens = []
+            kv_args.state_item_lens = []
+            kv_args.state_type = "none"
+
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class: Type[BaseKVManager] = get_kv_class(
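The pool is probed with hasattr rather than isinstance, so any pool that exposes get_state_buf_infos participates, and the isinstance chain only selects the transfer-strategy label. A hedged sketch of the contract this code assumes the pool method satisfies, three parallel lists giving each state buffer's base pointer, total byte length, and per-item byte length (the TinyStatePool class is an illustrative stand-in, not the real pool):

import torch


class TinyStatePool:  # illustrative stand-in for a hybrid KV pool
    def __init__(self):
        # e.g. one conv-state buffer and one ssm-state buffer, 128 slots each
        self.conv = torch.zeros(128, 64, dtype=torch.float16)
        self.ssm = torch.zeros(128, 16, dtype=torch.float32)

    def get_state_buf_infos(self):
        bufs = [self.conv, self.ssm]
        ptrs = [b.data_ptr() for b in bufs]                          # state_data_ptrs
        lens = [b.numel() * b.element_size() for b in bufs]          # state_data_lens
        item_lens = [b[0].numel() * b.element_size() for b in bufs]  # state_item_lens
        return ptrs, lens, item_lens


ptrs, lens, item_lens = TinyStatePool().get_state_buf_infos()
assert lens[0] == 128 * item_lens[0]  # total size = slots * per-slot size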
@@ -414,16 +476,56 @@ class DecodePreallocQueue:
                 .cpu()
                 .numpy()
             )

+            page_size = self.token_to_kv_pool_allocator.page_size
+
+            # Prepare extra pool indices for hybrid models
+            if isinstance(self.token_to_kv_pool, HybridLinearKVPool):
+                # Mamba hybrid model: single mamba state index
+                state_indices = [
+                    self.req_to_token_pool.req_index_to_mamba_index_mapping[
+                        decode_req.req.req_pool_idx
+                    ]
+                    .cpu()
+                    .numpy()
+                ]
+            elif isinstance(self.token_to_kv_pool, SWAKVPool):
+                # SWA hybrid model: send decode-side SWA window indices
+                seq_len = len(decode_req.req.origin_input_ids)
+                window_size = self.scheduler.sliding_window_size
+                window_start = max(0, seq_len - window_size)
+                window_start = (window_start // page_size) * page_size
+                window_kv_indices_full = self.req_to_token_pool.req_to_token[
+                    decode_req.req.req_pool_idx, window_start:seq_len
+                ]
+                # Translate to SWA pool indices
+                window_kv_indices_swa = (
+                    self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                        window_kv_indices_full
+                    )
+                )
+                state_indices = window_kv_indices_swa.cpu().numpy()
+                state_indices = kv_to_page_indices(state_indices, page_size)
+            elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
+                seq_len = len(decode_req.req.origin_input_ids)
+                kv_indices_full = self.req_to_token_pool.req_to_token[
+                    decode_req.req.req_pool_idx, :seq_len
+                ]
+                state_indices = kv_indices_full.cpu().numpy()
+                state_indices = kv_to_page_indices(state_indices, page_size)
+            else:
+                state_indices = None
+
             decode_req.metadata_buffer_index = (
                 self.req_to_metadata_buffer_idx_allocator.alloc()
             )
             assert decode_req.metadata_buffer_index is not None
-            page_indices = kv_to_page_indices(
-                kv_indices, self.token_to_kv_pool_allocator.page_size
-            )
-            decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
+            page_indices = kv_to_page_indices(kv_indices, page_size)
+            decode_req.kv_receiver.init(
+                page_indices, decode_req.metadata_buffer_index, state_indices
+            )
+            decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
             preallocated_reqs.append(decode_req)
             indices_to_remove.add(i)
             decode_req.req.time_stats.decode_transfer_queue_entry_time = (
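The SWA branch floors window_start to a page boundary before slicing, so the transferred window always starts on a full page and kv_to_page_indices only ever sees whole pages. A worked example of that arithmetic (values are illustrative):

def swa_window_start(seq_len: int, window_size: int, page_size: int) -> int:
    window_start = max(0, seq_len - window_size)
    return (window_start // page_size) * page_size


# A 1000-token request with a 512-token window and 16-token pages:
# the raw start 488 is floored to 480, so the slice [480:1000) begins on
# a page boundary instead of splitting page 30 across the two servers.
assert swa_window_start(1000, 512, 16) == 480
# Short sequences keep the whole prefix.
assert swa_window_start(100, 512, 16) == 0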
@@ -503,7 +605,10 @@ class DecodePreallocQueue:
     def _pre_alloc(self, req: Req) -> torch.Tensor:
         """Pre-allocate the memory for req_to_token and token_kv_pool"""
-        req_pool_indices = self.req_to_token_pool.alloc(1)
+        if isinstance(self.req_to_token_pool, HybridMambaDecodeReqToTokenPool):
+            req_pool_indices = self.req_to_token_pool.alloc(1, [req])
+        else:
+            req_pool_indices = self.req_to_token_pool.alloc(1)
         assert (
             req_pool_indices is not None
...
@@ -48,9 +48,12 @@ class FakeKVSender(BaseKVSender):
     def send(
         self,
         kv_indices: npt.NDArray[np.int32],
+        state_indices: Optional[List[int]] = None,
     ):
         self.has_sent = True
-        logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
+        logger.debug(
+            f"FakeKVSender send with kv_indices: {kv_indices}, state_indices: {state_indices}"
+        )

     def failure_exception(self):
         raise Exception("Fake KVSender Exception")
@@ -75,10 +78,15 @@ class FakeKVReceiver(BaseKVReceiver):
         logger.debug("FakeKVReceiver poll success")
         return KVPoll.Success

-    def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
+    def init(
+        self,
+        kv_indices: list[int],
+        aux_index: Optional[int] = None,
+        state_indices: Optional[List[int]] = None,
+    ):
         self.has_init = True
         logger.debug(
-            f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
+            f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}"
         )

     def failure_exception(self):
...
@@ -58,6 +58,7 @@ class TransferKVChunk:
     index_slice: slice
     is_last: bool
     prefill_aux_index: Optional[int]
+    state_indices: Optional[List[int]]


 # decode
@@ -69,6 +70,7 @@ class TransferInfo:
     mooncake_session_id: str
     dst_kv_indices: npt.NDArray[np.int32]
     dst_aux_index: int
+    dst_state_indices: List[int]
     required_dst_info_num: int
     is_dummy: bool
@@ -78,9 +80,14 @@ class TransferInfo:
             is_dummy = True
             dst_kv_indices = np.array([], dtype=np.int32)
             dst_aux_index = None
+            dst_state_indices = []
         else:
             dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
             dst_aux_index = int(msg[5].decode("ascii"))
+            if msg[6] == b"":
+                dst_state_indices = []
+            else:
+                dst_state_indices = list(np.frombuffer(msg[6], dtype=np.int32))
             is_dummy = False
         return cls(
             room=int(msg[0].decode("ascii")),
@@ -89,7 +96,8 @@ class TransferInfo:
             mooncake_session_id=msg[3].decode("ascii"),
             dst_kv_indices=dst_kv_indices,
             dst_aux_index=dst_aux_index,
-            required_dst_info_num=int(msg[6].decode("ascii")),
+            dst_state_indices=dst_state_indices,
+            required_dst_info_num=int(msg[7].decode("ascii")),
             is_dummy=is_dummy,
         )
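The b"" sentinel keeps the handshake message fixed-width: frame 6 is now always present, carrying int32 state indices or an empty payload, and required_dst_info_num moves from frame 6 to frame 7. A small round trip mirroring the msg[6] handling above (the helper name is ours, for illustration):

import numpy as np


def parse_state_frame(frame: bytes) -> list:
    # Mirrors from_raw_msg: an empty frame means "no extra state".
    return [] if frame == b"" else list(np.frombuffer(frame, dtype=np.int32))


assert parse_state_frame(b"") == []
packed = np.array([3, 7], dtype=np.int32).tobytes()  # what the receiver's init sends
assert parse_state_frame(packed) == [3, 7]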
@@ -103,6 +111,7 @@ class KVArgsRegisterInfo:
     mooncake_session_id: str
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
+    dst_state_data_ptrs: list[int]
     dst_tp_rank: int
     dst_attn_tp_size: int
     dst_kv_item_len: int
@@ -116,9 +125,10 @@ class KVArgsRegisterInfo:
             mooncake_session_id=msg[3].decode("ascii"),
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
-            dst_tp_rank=int(msg[6].decode("ascii")),
-            dst_attn_tp_size=int(msg[7].decode("ascii")),
-            dst_kv_item_len=int(msg[8].decode("ascii")),
+            dst_state_data_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
+            dst_tp_rank=int(msg[7].decode("ascii")),
+            dst_attn_tp_size=int(msg[8].decode("ascii")),
+            dst_kv_item_len=int(msg[9].decode("ascii")),
         )
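Buffer base addresses are exchanged as unsigned 64-bit integers ("Q" in struct notation), so the element count on the unpacking side is simply len(bytes) // 8. A round trip of the packing used above:

import struct

ptrs = [0x7F0000000000, 0x7F0000100000]  # illustrative device addresses
packed = b"".join(struct.pack("Q", p) for p in ptrs)          # decode side packs
unpacked = list(struct.unpack(f"{len(packed)//8}Q", packed))  # prefill side unpacks
assert unpacked == ptrs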
@@ -180,6 +190,9 @@ class MooncakeKVManager(CommonKVManager):
             )
             for _ in range(transfer_queue_size)
         ]
+        self.state_executors = concurrent.futures.ThreadPoolExecutor(
+            transfer_thread_pool_size // transfer_queue_size
+        )
         for queue, executor in zip(self.transfer_queues, self.executors):
             threading.Thread(
                 target=self.transfer_worker, args=(queue, executor), daemon=True
@@ -239,6 +252,12 @@ class MooncakeKVManager(CommonKVManager):
             self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
         )

+        # Batch register state/extra pool data buffers
+        if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens:
+            self.engine.batch_register(
+                self.kv_args.state_data_ptrs, self.kv_args.state_data_lens
+            )
+
     def _transfer_data(self, mooncake_session_id, transfer_blocks):
         if not transfer_blocks:
             return 0
@@ -248,17 +267,23 @@ class MooncakeKVManager(CommonKVManager):
             mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
         )

-    def send_kvcache(
+    def _send_kvcache_generic(
         self,
         mooncake_session_id: str,
-        prefill_kv_indices: npt.NDArray[np.int32],
-        dst_kv_ptrs: list[int],
-        dst_kv_indices: npt.NDArray[np.int32],
+        src_data_ptrs: list[int],
+        dst_data_ptrs: list[int],
+        item_lens: list[int],
+        prefill_data_indices: npt.NDArray[np.int32],
+        dst_data_indices: npt.NDArray[np.int32],
         executor: concurrent.futures.ThreadPoolExecutor,
-    ):
-        # Group by indices
+    ) -> int:
+        """
+        Generic KV cache transfer supporting both MHA and MLA architectures.
+        This method is used by both send_kvcache (full pool) and maybe_send_extra.
+        """
+        # Group by indices for optimization
         prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
-            prefill_kv_indices, dst_kv_indices
+            prefill_data_indices, dst_data_indices
         )
         layers_params = None
@@ -266,9 +291,9 @@ class MooncakeKVManager(CommonKVManager):
         # pp is not supported on the decode side yet
         if self.is_mla_backend:
             src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
-                self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
+                self.get_mla_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
             )
-            kv_item_len = self.kv_args.kv_item_lens[0]
+            kv_item_len = item_lens[0]
             layers_params = [
                 (
                     src_kv_ptrs[layer_id],
@@ -279,9 +304,9 @@ class MooncakeKVManager(CommonKVManager):
             ]
         else:
             src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
-                self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
+                self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
             )
-            kv_item_len = self.kv_args.kv_item_lens[0]
+            kv_item_len = item_lens[0]
             layers_params = [
                 (
                     src_k_ptrs[layer_id],
@@ -345,6 +370,24 @@ class MooncakeKVManager(CommonKVManager):
         return 0

+    def send_kvcache(
+        self,
+        mooncake_session_id: str,
+        prefill_kv_indices: npt.NDArray[np.int32],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int32],
+        executor: concurrent.futures.ThreadPoolExecutor,
+    ):
+        return self._send_kvcache_generic(
+            mooncake_session_id=mooncake_session_id,
+            src_data_ptrs=self.kv_args.kv_data_ptrs,
+            dst_data_ptrs=dst_kv_ptrs,
+            item_lens=self.kv_args.kv_item_lens,
+            prefill_data_indices=prefill_kv_indices,
+            dst_data_indices=dst_kv_indices,
+            executor=executor,
+        )
+
     def send_kvcache_slice(
         self,
         mooncake_session_id: str,
@@ -593,6 +636,58 @@ class MooncakeKVManager(CommonKVManager):
             f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
         )

+    def maybe_send_extra(
+        self,
+        req: TransferInfo,
+        prefill_state_indices: list[int],
+        dst_state_data_ptrs: list[int],
+    ):
+        """Send state or extra pool data with type-specific handling."""
+        state_type = getattr(self.kv_args, "state_type", "none")
+        if state_type == "mamba":
+            return self._send_mamba_state(
+                req,
+                prefill_state_indices,
+                dst_state_data_ptrs,
+            )
+        elif state_type in ["swa", "nsa"]:
+            # Reuse _send_kvcache_generic interface to send extra pool data
+            prefill_state_indices = np.array(prefill_state_indices, dtype=np.int32)
+            dst_state_indices = np.array(req.dst_state_indices, dtype=np.int32)
+            return self._send_kvcache_generic(
+                mooncake_session_id=req.mooncake_session_id,
+                src_data_ptrs=self.kv_args.state_data_ptrs,
+                dst_data_ptrs=dst_state_data_ptrs,
+                item_lens=self.kv_args.state_item_lens,
+                prefill_data_indices=prefill_state_indices,
+                dst_data_indices=dst_state_indices,
+                executor=self.state_executors,
+            )
+        else:
+            return 0
+
+    def _send_mamba_state(
+        self,
+        req: TransferInfo,
+        prefill_mamba_index: list[int],
+        dst_state_data_ptrs: list[int],
+    ):
+        """Transfer Mamba states."""
+        assert len(prefill_mamba_index) == 1, "Mamba should have single state index"
+        transfer_blocks = []
+        prefill_state_data_ptrs = self.kv_args.state_data_ptrs
+        prefill_state_item_lens = self.kv_args.state_item_lens
+        for i, dst_state_ptr in enumerate(dst_state_data_ptrs):
+            length = prefill_state_item_lens[i]
+            src_addr = prefill_state_data_ptrs[i] + length * int(prefill_mamba_index[0])
+            dst_addr = dst_state_ptr + length * int(req.dst_state_indices[0])
+            transfer_blocks.append((src_addr, dst_addr, length))
+        return self._transfer_data(req.mooncake_session_id, transfer_blocks)
+
     def sync_status_to_decode_endpoint(
         self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
     ):
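Because a mamba request owns exactly one state slot per buffer, each transfer block is a single contiguous copy at base + item_len * slot, with no page grouping needed. A worked sketch of that address arithmetic (standalone, with made-up numbers):

def mamba_block(src_base: int, dst_base: int, item_len: int,
                src_slot: int, dst_slot: int) -> tuple:
    # One (src_addr, dst_addr, length) block per state buffer.
    return (src_base + item_len * src_slot,
            dst_base + item_len * dst_slot,
            item_len)


# A 2 MiB per-slot state, prefill slot 5 -> decode slot 12:
item_len = 2 << 20
assert mamba_block(0x1000, 0x9000, item_len, 5, 12) == (
    0x1000 + 5 * item_len,
    0x9000 + 12 * item_len,
    item_len,
)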
@@ -702,6 +797,21 @@ class MooncakeKVManager(CommonKVManager):
                         break

                 if kv_chunk.is_last:
+                    if kv_chunk.state_indices is not None:
+                        if not self.is_mla_backend and (
+                            self.attn_tp_size
+                            != target_rank_registration_info.dst_attn_tp_size
+                        ):
+                            raise RuntimeError(
+                                f"PD Disaggregation does NOT support PD different TP sizes for non-MLA hybrid models yet."
+                            )
+                        self.maybe_send_extra(
+                            req,
+                            kv_chunk.state_indices,
+                            target_rank_registration_info.dst_state_data_ptrs,
+                        )
+
                     if self.pp_group.is_last_rank:
                         # Only the last chunk we need to send the aux data
                         ret = self.send_aux(
@@ -765,7 +875,7 @@ class MooncakeKVManager(CommonKVManager):
                     )
                     continue
                 else:
-                    required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
+                    required_dst_info_num = int(waiting_req_bytes[7].decode("ascii"))
                 room = int(room)
                 if room not in self.transfer_infos:
                     self.transfer_infos[room] = {}
@@ -876,6 +986,7 @@ class MooncakeKVManager(CommonKVManager):
         index_slice: slice,
         is_last: bool,
         aux_index: Optional[int] = None,
+        state_indices: Optional[List[int]] = None,
     ):
         assert self.disaggregation_mode == DisaggregationMode.PREFILL
         assert not is_last or (is_last and aux_index is not None)
@@ -909,6 +1020,7 @@ class MooncakeKVManager(CommonKVManager):
                 index_slice=index_slice,
                 is_last=is_last,
                 prefill_aux_index=aux_index,
+                state_indices=state_indices,
             )
         )
@@ -989,6 +1101,7 @@ class MooncakeKVSender(CommonKVSender):
     def send(
         self,
         kv_indices: npt.NDArray[np.int32],
+        state_indices: Optional[List[int]] = None,
     ):
         index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
         self.curr_idx += len(kv_indices)
@@ -1008,6 +1121,7 @@ class MooncakeKVSender(CommonKVSender):
                 index_slice,
                 True,
                 aux_index=self.aux_index,
+                state_indices=state_indices,
             )

     def poll(self) -> KVPoll:
@@ -1110,6 +1224,9 @@ class MooncakeKVReceiver(CommonKVReceiver):
         packed_aux_data_ptrs = b"".join(
             struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
         )
+        packed_state_data_ptrs = b"".join(
+            struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs
+        )
         # Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
         tp_rank = self.kv_mgr.kv_args.engine_rank
         kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
@@ -1127,13 +1244,19 @@ class MooncakeKVReceiver(CommonKVReceiver):
                 self.session_id.encode("ascii"),
                 packed_kv_data_ptrs,
                 packed_aux_data_ptrs,
+                packed_state_data_ptrs,
                 dst_tp_rank,
                 dst_attn_tp_size,
                 dst_kv_item_len,
             ]
         )

-    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
+    def init(
+        self,
+        kv_indices: npt.NDArray[np.int32],
+        aux_index: Optional[int] = None,
+        state_indices: Optional[List[int]] = None,
+    ):
         for bootstrap_info in self.bootstrap_infos:
             sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
             is_dummy = bootstrap_info["is_dummy"]
@@ -1147,6 +1270,14 @@ class MooncakeKVReceiver(CommonKVReceiver):
                         self.session_id.encode("ascii"),
                         kv_indices.tobytes() if not is_dummy else b"",
                         str(aux_index).encode("ascii") if not is_dummy else b"",
+                        (
+                            np.array(
+                                state_indices,
+                                dtype=np.int32,
+                            ).tobytes()
+                            if not is_dummy and state_indices is not None
+                            else b""
+                        ),
                         str(self.required_dst_info_num).encode("ascii"),
                     ]
                 )
...
@@ -704,6 +704,7 @@ class NixlKVSender(CommonKVSender):
     def send(
         self,
         kv_indices: npt.NDArray[np.int32],
+        state_indices: Optional[List[int]] = None,
     ):
         index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
         self.curr_idx += len(kv_indices)
@@ -755,7 +756,12 @@ class NixlKVReceiver(CommonKVReceiver):
             self.bootstrap_room
         )

-    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
+    def init(
+        self,
+        kv_indices: npt.NDArray[np.int32],
+        aux_index: Optional[int] = None,
+        state_indices: Optional[List[int]] = None,
+    ):
         for bootstrap_info in self.bootstrap_infos:
             logger.debug(
                 f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
...
@@ -49,6 +49,11 @@ from sglang.srt.managers.schedule_batch import (
     RequestStage,
     ScheduleBatch,
 )
+from sglang.srt.mem_cache.memory_pool import (
+    HybridLinearKVPool,
+    NSATokenToKVPool,
+    SWAKVPool,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
 from sglang.srt.utils import (
     DynamicGradMode,
@@ -146,6 +151,28 @@ class PrefillBootstrapQueue:
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id

+        if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
+            state_data_ptrs, state_data_lens, state_item_lens = (
+                self.token_to_kv_pool.get_state_buf_infos()
+            )
+            kv_args.state_data_ptrs = state_data_ptrs
+            kv_args.state_data_lens = state_data_lens
+            kv_args.state_item_lens = state_item_lens
+
+            if isinstance(self.token_to_kv_pool, SWAKVPool):
+                kv_args.state_type = "swa"
+            elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
+                kv_args.state_type = "mamba"
+            elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
+                kv_args.state_type = "nsa"
+            else:
+                kv_args.state_type = "none"
+        else:
+            kv_args.state_data_ptrs = []
+            kv_args.state_data_lens = []
+            kv_args.state_item_lens = []
+            kv_args.state_type = "none"
+
         kv_manager_class: Type[BaseKVManager] = get_kv_class(
             self.transfer_backend, KVClassType.MANAGER
         )
@@ -618,15 +645,58 @@ class SchedulerDisaggregationPrefillMixin:
             .numpy()
         )
         req.start_send_idx = end_idx
+        state_indices = None
         if last_chunk:
             self.disagg_metadata_buffers.set_buf(req)
+
+            # Prepare extra pool indices for hybrid models
+            if isinstance(
+                self.token_to_kv_pool_allocator.get_kvcache(), HybridLinearKVPool
+            ):
+                # Mamba hybrid model: send single mamba state index
+                state_indices = [
+                    self.req_to_token_pool.req_index_to_mamba_index_mapping[
+                        req.req_pool_idx
+                    ]
+                    .cpu()
+                    .numpy()
+                ]
+            elif isinstance(self.token_to_kv_pool_allocator.get_kvcache(), SWAKVPool):
+                # SWA hybrid model: send last window KV indices
+                seq_len = len(req.fill_ids)
+                window_size = self.sliding_window_size
+                window_start = max(0, seq_len - window_size)
+                window_start = (window_start // page_size) * page_size
+                window_kv_indices_full = self.req_to_token_pool.req_to_token[
+                    req.req_pool_idx, window_start:seq_len
+                ]
+                # Translate to SWA pool indices
+                window_kv_indices_swa = (
+                    self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                        window_kv_indices_full
+                    )
+                )
+                state_indices = window_kv_indices_swa.cpu().numpy()
+                state_indices = kv_to_page_indices(state_indices, page_size)
+            elif isinstance(
+                self.token_to_kv_pool_allocator.get_kvcache(), NSATokenToKVPool
+            ):
+                seq_len = len(req.fill_ids)
+                kv_indices_full = self.req_to_token_pool.req_to_token[
+                    req.req_pool_idx, :seq_len
+                ]
+                state_indices = kv_indices_full.cpu().numpy()
+                state_indices = kv_to_page_indices(state_indices, page_size)
+
         page_indices = kv_to_page_indices(kv_indices, page_size)
         if len(page_indices) == 0:
             logger.info(
                 f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
             )
             return
-        req.disagg_kv_sender.send(page_indices)
+        req.disagg_kv_sender.send(page_indices, state_indices)

     # PP
     @DynamicGradMode()
...
@@ -807,9 +807,6 @@ class Scheduler(
                 self.tree_cache.cache_controller.layer_done_counter
             )
         elif self.is_hybrid:
-            assert (
-                self.server_args.disaggregation_mode == "null"
-            ), "Hybrid mode does not support disaggregation yet"
             self.tree_cache = SWARadixCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
@@ -819,9 +816,6 @@ class Scheduler(
                 is_eagle=self.spec_algorithm.is_eagle(),
             )
         elif self.is_hybrid_gdn:
-            assert (
-                self.server_args.disaggregation_mode == "null"
-            ), "Hybrid GDN mode does not support disaggregation yet"
             self.tree_cache = MambaRadixCache(
                 req_to_token_pool=self.req_to_token_pool,
                 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
...
This diff is collapsed.
@@ -1669,19 +1669,34 @@ class ModelRunner:
             extra_max_context_len += self.server_args.speculative_num_draft_tokens

         if self.server_args.disaggregation_mode == "decode":
-            from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
+            from sglang.srt.disaggregation.decode import (
+                DecodeReqToTokenPool,
+                HybridMambaDecodeReqToTokenPool,
+            )

             # subscribe memory for pre-allocated requests
             # if max_num_reqs <= 32, we pre-allocate 2x requests
             pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
-            self.req_to_token_pool = DecodeReqToTokenPool(
-                size=max_num_reqs,
-                max_context_len=self.model_config.context_len
-                + extra_max_context_len,
-                device=self.device,
-                enable_memory_saver=self.server_args.enable_memory_saver,
-                pre_alloc_size=pre_alloc_size,
-            )
+
+            if config := self.mambaish_config:
+                self.req_to_token_pool = HybridMambaDecodeReqToTokenPool(
+                    size=max_num_reqs,
+                    max_context_len=self.model_config.context_len
+                    + extra_max_context_len,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    cache_params=config.mamba2_cache_params,
+                    speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+                    pre_alloc_size=pre_alloc_size,
+                )
+            else:
+                self.req_to_token_pool = DecodeReqToTokenPool(
+                    size=max_num_reqs,
+                    max_context_len=self.model_config.context_len
+                    + extra_max_context_len,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    pre_alloc_size=pre_alloc_size,
+                )
         elif config := self.mambaish_config:
             self.req_to_token_pool = HybridReqToTokenPool(
                 size=max_num_reqs,
@@ -1807,6 +1822,7 @@ class ModelRunner:
                 ),
                 enable_kvcache_transpose=False,
                 device=self.device,
+                mamba_pool=self.req_to_token_pool.mamba_pool,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
...
@@ -163,6 +163,7 @@ suites = {
         TestFile("test_deepseek_v3_basic.py", 275),
         TestFile("test_deepseek_v3_mtp.py", 275),
         TestFile("test_disaggregation_different_tp.py", 600),
+        TestFile("test_disaggregation_hybrid_attention.py", 200),
         TestFile("test_disaggregation_pp.py", 140),
     ],
     "per-commit-4-gpu-b200": [
...
import os
import unittest
from types import SimpleNamespace
from sglang.srt.environ import envs
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
popen_launch_pd_server,
)
class TestDisaggregationHybridAttentionMamba(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct"
# Non-blocking: launch both servers
cls.start_prefill()
cls.start_decode()
# Block until both servers are ready
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
cls.launch_lb()
@classmethod
def start_prefill(cls):
prefill_args = [
"--trust-remote-code",
"--disaggregation-mode",
"prefill",
"--tp",
"4",
]
prefill_args += cls.transfer_backend + cls.rdma_devices
cls.process_prefill = popen_launch_pd_server(
cls.model,
cls.prefill_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=prefill_args,
)
@classmethod
def start_decode(cls):
decode_args = [
"--trust-remote-code",
"--disaggregation-mode",
"decode",
"--tp",
"4",
"--base-gpu-id",
"4",
]
decode_args += cls.transfer_backend + cls.rdma_devices
cls.process_decode = popen_launch_pd_server(
cls.model,
cls.decode_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=decode_args,
)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host=f"http://{self.base_host}",
port=int(self.lb_port),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"Evaluation metrics: {metrics}")
self.assertGreater(metrics["accuracy"], 0.93)
if __name__ == "__main__":
unittest.main()
@@ -42,6 +42,7 @@ class TestMamba(unittest.TestCase):
             full_attention_layer_ids=full_attention_layer_ids,
             enable_kvcache_transpose=False,
             device=device,
+            mamba_pool=None,
         )
         assert pool._transfer_full_attention_id(global_interval - 1) == 0
         assert pool._transfer_full_attention_id(2 * global_interval - 1) == 1
@@ -173,6 +174,7 @@ class TestMamba(unittest.TestCase):
             full_attention_layer_ids=full_attention_layer_ids,
             enable_kvcache_transpose=False,
             device=device,
+            mamba_pool=req_to_token_pool.mamba_pool,
         )
         # setup token to kv pool allocator
...