Unverified Commit 868403f6 authored by Shangming Cai, committed by GitHub

[PD] Add PD support for hybrid model (Qwen3-Next, DeepSeek V3.2 Exp) (#10912)
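This change adds a state/extra-cache transfer path alongside the regular KV path for PD disaggregation: `KVArgs` gains `state_data_ptrs`/`state_data_lens`/`state_item_lens` and a `state_type` tag ("none", "mamba", "swa", "nsa"); senders and receivers accept an optional `state_indices` argument; and the Mooncake backend registers the extra buffers and transfers either a single mamba state slot, the SWA window pages, or the NSA page indices with the last KV chunk. A `HybridMambaDecodeReqToTokenPool` is added for the decode side, and a 2x4-GPU GSM8K test with Qwen/Qwen3-Next-80B-A3B-Instruct exercises the path.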


Signed-off-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: hzh0425 <hzh0425@apache.org>
Co-authored-by: ZeldaHuang <hzm414167@alibaba-inc.com>
parent 97d857c0
......@@ -20,6 +20,10 @@ class KVArgs:
aux_data_ptrs: List[int]
aux_data_lens: List[int]
aux_item_lens: List[int]
state_data_ptrs: List[int]
state_data_lens: List[int]
state_item_lens: List[int]
state_type: str  # "none", "mamba", "swa", "nsa"
ib_device: str
ib_traffic_class: str
gpu_id: int
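For orientation, here is a minimal, self-contained sketch of how the new state fields relate to one another; the `KVArgsSketch` stand-in and all values are illustrative, not taken from the real pools:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class KVArgsSketch:
    state_data_ptrs: List[int] = field(default_factory=list)  # base pointer per state buffer
    state_data_lens: List[int] = field(default_factory=list)  # total bytes per buffer
    state_item_lens: List[int] = field(default_factory=list)  # bytes per request slot
    state_type: str = "none"  # "none" | "mamba" | "swa" | "nsa"

# A hypothetical mamba pool exposing a conv-state and an ssm-state buffer:
args = KVArgsSketch(
    state_data_ptrs=[0x7F00_0000_0000, 0x7F10_0000_0000],
    state_data_lens=[64 * 2**20, 256 * 2**20],
    state_item_lens=[64 * 2**10, 256 * 2**10],
    state_type="mamba",
)
# Each buffer must hold a whole number of per-request slots:
assert all(
    dlen % ilen == 0
    for dlen, ilen in zip(args.state_data_lens, args.state_item_lens)
)
```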
......@@ -76,9 +80,13 @@ class BaseKVSender(ABC):
...
@abstractmethod
-    def send(self, kv_indices: npt.NDArray[np.int32]):
    def send(
        self,
        kv_indices: npt.NDArray[np.int32],
        state_indices: Optional[List[int]] = None,
    ):
"""
-        Send the kv cache at the given kv indices to the decoder server
        Send the kv cache at the given kv indices, and the extra cache/state at the given state indices, to the decode server
"""
...
......@@ -108,9 +116,14 @@ class BaseKVReceiver(ABC):
): ...
@abstractmethod
-    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
    def init(
        self,
        kv_indices: npt.NDArray[np.int32],
        aux_index: Optional[int] = None,
        state_indices: Optional[List[int]] = None,
    ):
"""
-        Notify the prefill server about the kv indices and aux index
        Notify the prefill server about the kv indices, aux index, and state_indices.
"""
...
......
......@@ -201,6 +201,7 @@ class CommonKVSender(BaseKVSender):
def send(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
):
pass
......
......@@ -25,11 +25,12 @@ import time
from collections import deque
from dataclasses import dataclass
from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
import torch
from torch.distributed import ProcessGroup
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
from sglang.srt.disaggregation.utils import (
......@@ -47,9 +48,19 @@ from sglang.srt.disaggregation.utils import (
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
-from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.allocator import (
    BaseTokenToKVPoolAllocator,
    SWATokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
-from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
from sglang.srt.mem_cache.memory_pool import (
    HybridLinearKVPool,
    HybridReqToTokenPool,
    KVCache,
    NSATokenToKVPool,
    ReqToTokenPool,
    SWAKVPool,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import get_int_env_var, require_mlp_sync
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
......@@ -124,6 +135,35 @@ class DecodeReqToTokenPool:
self.free_slots = list(range(self.size + self.pre_alloc_size))
class HybridMambaDecodeReqToTokenPool(HybridReqToTokenPool):
def __init__(
self,
size: int,
max_context_len: int,
device: str,
enable_memory_saver: bool,
cache_params: "Mamba2CacheParams",
speculative_num_draft_tokens: int,
pre_alloc_size: int,
):
DecodeReqToTokenPool.__init__(
self,
size=size,
max_context_len=max_context_len,
device=device,
enable_memory_saver=enable_memory_saver,
pre_alloc_size=pre_alloc_size,
)
self._init_mamba_pool(
size + pre_alloc_size, cache_params, device, speculative_num_draft_tokens
)
def clear(self):
self.free_slots = list(range(self.size + self.pre_alloc_size))
self.mamba_pool.clear()
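Note the init pattern above: the class subclasses `HybridReqToTokenPool` but calls `DecodeReqToTokenPool.__init__` directly, borrowing the decode pool's slot layout without inheriting from it. A stripped-down illustration of that pattern, with stand-in classes (the real pools live in sglang.srt.mem_cache and sglang.srt.disaggregation.decode):

```python
class DecodePoolBase:
    def __init__(self, size: int):
        self.free_slots = list(range(size))

class HybridPoolBase:
    def _init_mamba_pool(self, size: int):
        self.mamba_free = list(range(size))

class HybridMambaDecodePool(HybridPoolBase):
    """Borrows DecodePoolBase.__init__ without inheriting from it, mirroring
    how HybridMambaDecodeReqToTokenPool reuses DecodeReqToTokenPool.__init__."""
    def __init__(self, size: int, pre_alloc: int):
        DecodePoolBase.__init__(self, size + pre_alloc)  # explicit call, not super()
        self._init_mamba_pool(size + pre_alloc)  # both pools sized identically

pool = HybridMambaDecodePool(size=4, pre_alloc=2)
assert len(pool.free_slots) == len(pool.mamba_free) == 6
```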
@dataclass
class DecodeRequest:
req: Req
......@@ -217,6 +257,28 @@ class DecodePreallocQueue:
self.metadata_buffers.get_buf_infos()
)
if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
state_data_ptrs, state_data_lens, state_item_lens = (
self.token_to_kv_pool.get_state_buf_infos()
)
kv_args.state_data_ptrs = state_data_ptrs
kv_args.state_data_lens = state_data_lens
kv_args.state_item_lens = state_item_lens
if isinstance(self.token_to_kv_pool, SWAKVPool):
kv_args.state_type = "swa"
elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
kv_args.state_type = "mamba"
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
kv_args.state_type = "nsa"
else:
kv_args.state_type = "none"
else:
kv_args.state_data_ptrs = []
kv_args.state_data_lens = []
kv_args.state_item_lens = []
kv_args.state_type = "none"
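The same pool-to-`state_type` dispatch appears again in `PrefillBootstrapQueue` below; a shared helper along these lines (a sketch with stand-in pool classes, not part of this PR) would capture the mapping in one place:

```python
# Stand-in pool classes; the real ones come from sglang.srt.mem_cache.memory_pool.
class SWAKVPool: ...
class HybridLinearKVPool: ...
class NSATokenToKVPool: ...

def infer_state_type(pool) -> str:
    """Map a KV pool to the wire-level state_type tag used by KVArgs."""
    if isinstance(pool, SWAKVPool):
        return "swa"    # sliding-window models: extra SWA-pool page indices
    if isinstance(pool, HybridLinearKVPool):
        return "mamba"  # mamba hybrids: one state slot per request
    if isinstance(pool, NSATokenToKVPool):
        return "nsa"    # DeepSeek V3.2 NSA: full-sequence page indices
    return "none"

assert infer_state_type(HybridLinearKVPool()) == "mamba"
assert infer_state_type(object()) == "none"
```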
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class: Type[BaseKVManager] = get_kv_class(
......@@ -414,16 +476,56 @@ class DecodePreallocQueue:
.cpu()
.numpy()
)
page_size = self.token_to_kv_pool_allocator.page_size
# Prepare extra pool indices for hybrid models
if isinstance(self.token_to_kv_pool, HybridLinearKVPool):
# Mamba hybrid model: single mamba state index
state_indices = [
self.req_to_token_pool.req_index_to_mamba_index_mapping[
decode_req.req.req_pool_idx
]
.cpu()
.numpy()
]
elif isinstance(self.token_to_kv_pool, SWAKVPool):
# SWA hybrid model: send decode-side SWA window indices
seq_len = len(decode_req.req.origin_input_ids)
window_size = self.scheduler.sliding_window_size
window_start = max(0, seq_len - window_size)
window_start = (window_start // page_size) * page_size
window_kv_indices_full = self.req_to_token_pool.req_to_token[
decode_req.req.req_pool_idx, window_start:seq_len
]
# Translate to SWA pool indices
window_kv_indices_swa = (
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
window_kv_indices_full
)
)
state_indices = window_kv_indices_swa.cpu().numpy()
state_indices = kv_to_page_indices(state_indices, page_size)
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
seq_len = len(decode_req.req.origin_input_ids)
kv_indices_full = self.req_to_token_pool.req_to_token[
decode_req.req.req_pool_idx, :seq_len
]
state_indices = kv_indices_full.cpu().numpy()
state_indices = kv_to_page_indices(state_indices, page_size)
else:
state_indices = None
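A worked example of the page-aligned window start computed in the SWA branch above, with illustrative numbers:

```python
# Illustrative numbers: page_size=16, sliding window=4096, sequence length=5000.
page_size, window_size, seq_len = 16, 4096, 5000
window_start = max(0, seq_len - window_size)             # 904
window_start = (window_start // page_size) * page_size   # 896: rounded down to a page boundary
assert window_start == 896 and window_start % page_size == 0
# The decode side thus exposes seq_len - window_start = 4104 token slots,
# which are translated to SWA-pool indices and then to page indices.
assert seq_len - window_start == 4104
```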
decode_req.metadata_buffer_index = (
self.req_to_metadata_buffer_idx_allocator.alloc()
)
assert decode_req.metadata_buffer_index is not None
-            page_indices = kv_to_page_indices(
-                kv_indices, self.token_to_kv_pool_allocator.page_size
-            )
-            decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
            page_indices = kv_to_page_indices(kv_indices, page_size)
            decode_req.kv_receiver.init(
                page_indices, decode_req.metadata_buffer_index, state_indices
            )
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
preallocated_reqs.append(decode_req)
indices_to_remove.add(i)
decode_req.req.time_stats.decode_transfer_queue_entry_time = (
......@@ -503,7 +605,10 @@ class DecodePreallocQueue:
def _pre_alloc(self, req: Req) -> torch.Tensor:
"""Pre-allocate the memory for req_to_token and token_kv_pool"""
-        req_pool_indices = self.req_to_token_pool.alloc(1)
        if isinstance(self.req_to_token_pool, HybridMambaDecodeReqToTokenPool):
            req_pool_indices = self.req_to_token_pool.alloc(1, [req])
        else:
            req_pool_indices = self.req_to_token_pool.alloc(1)
assert (
req_pool_indices is not None
......
......@@ -48,9 +48,12 @@ class FakeKVSender(BaseKVSender):
def send(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
):
self.has_sent = True
logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
logger.debug(
f"FakeKVSender send with kv_indices: {kv_indices}, state_indices: {state_indices}"
)
def failure_exception(self):
raise Exception("Fake KVSender Exception")
......@@ -75,10 +78,15 @@ class FakeKVReceiver(BaseKVReceiver):
logger.debug("FakeKVReceiver poll success")
return KVPoll.Success
-    def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
    def init(
        self,
        kv_indices: list[int],
        aux_index: Optional[int] = None,
        state_indices: Optional[List[int]] = None,
    ):
self.has_init = True
logger.debug(
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}"
)
def failure_exception(self):
......
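A minimal runnable demo of the widened interface, using stand-ins that mirror `FakeKVSender`/`FakeKVReceiver` (the real classes take manager and bootstrap arguments omitted here):

```python
import numpy as np
from typing import List, Optional

class MiniSender:
    def send(self, kv_indices: np.ndarray, state_indices: Optional[List[int]] = None):
        print(f"send kv={kv_indices.tolist()} state={state_indices}")

class MiniReceiver:
    def init(
        self,
        kv_indices: np.ndarray,
        aux_index: Optional[int] = None,
        state_indices: Optional[List[int]] = None,
    ):
        print(f"init kv={kv_indices.tolist()} aux={aux_index} state={state_indices}")

pages = np.array([0, 1, 2], dtype=np.int32)
MiniReceiver().init(pages, aux_index=7, state_indices=[3])  # decode side registers targets
MiniSender().send(pages, state_indices=[3])                 # prefill side pushes on last chunk
```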
......@@ -58,6 +58,7 @@ class TransferKVChunk:
index_slice: slice
is_last: bool
prefill_aux_index: Optional[int]
state_indices: Optional[List[int]]
# decode
......@@ -69,6 +70,7 @@ class TransferInfo:
mooncake_session_id: str
dst_kv_indices: npt.NDArray[np.int32]
dst_aux_index: int
dst_state_indices: List[int]
required_dst_info_num: int
is_dummy: bool
......@@ -78,9 +80,14 @@ class TransferInfo:
is_dummy = True
dst_kv_indices = np.array([], dtype=np.int32)
dst_aux_index = None
dst_state_indices = []
else:
dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
dst_aux_index = int(msg[5].decode("ascii"))
if msg[6] == b"":
dst_state_indices = []
else:
dst_state_indices = list(np.frombuffer(msg[6], dtype=np.int32))
is_dummy = False
return cls(
room=int(msg[0].decode("ascii")),
......@@ -89,7 +96,8 @@ class TransferInfo:
mooncake_session_id=msg[3].decode("ascii"),
dst_kv_indices=dst_kv_indices,
dst_aux_index=dst_aux_index,
-            required_dst_info_num=int(msg[6].decode("ascii")),
            dst_state_indices=dst_state_indices,
            required_dst_info_num=int(msg[7].decode("ascii")),
is_dummy=is_dummy,
)
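The decode-to-prefill init message therefore gains one frame: index 6 now carries the packed state indices, shifting `required_dst_info_num` from frame 6 to frame 7. A sketch of the layout with illustrative frame values (frame names are descriptive, not from the PR):

```python
import numpy as np

msg = [
    b"1234",                                        # 0: bootstrap room
    b"10.0.0.1",                                    # 1: decode endpoint
    b"5757",                                        # 2: decode zmq port
    b"session-abc",                                 # 3: mooncake session id
    np.array([0, 1, 2], dtype=np.int32).tobytes(),  # 4: dst kv page indices
    b"7",                                           # 5: dst aux index
    np.array([3, 4, 5], dtype=np.int32).tobytes(),  # 6: dst state indices (new)
    b"1",                                           # 7: required_dst_info_num (was frame 6)
]
assert np.frombuffer(msg[6], dtype=np.int32).tolist() == [3, 4, 5]
# An empty frame 6 (b"") means "no state to transfer", matching the parser above.
```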
......@@ -103,6 +111,7 @@ class KVArgsRegisterInfo:
mooncake_session_id: str
dst_kv_ptrs: list[int]
dst_aux_ptrs: list[int]
dst_state_data_ptrs: list[int]
dst_tp_rank: int
dst_attn_tp_size: int
dst_kv_item_len: int
......@@ -116,9 +125,10 @@ class KVArgsRegisterInfo:
mooncake_session_id=msg[3].decode("ascii"),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
-            dst_tp_rank=int(msg[6].decode("ascii")),
-            dst_attn_tp_size=int(msg[7].decode("ascii")),
-            dst_kv_item_len=int(msg[8].decode("ascii")),
            dst_state_data_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
            dst_tp_rank=int(msg[7].decode("ascii")),
            dst_attn_tp_size=int(msg[8].decode("ascii")),
            dst_kv_item_len=int(msg[9].decode("ascii")),
)
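The pointer lists travel as fixed-width unsigned 64-bit values; a round trip of the packing used above, with made-up addresses:

```python
import struct

state_data_ptrs = [0x7F00_0000_0000, 0x7F10_0000_0000]  # illustrative addresses
packed = b"".join(struct.pack("Q", ptr) for ptr in state_data_ptrs)
# The receiver recovers the pointer count from the byte length (8 bytes per "Q"):
unpacked = list(struct.unpack(f"{len(packed) // 8}Q", packed))
assert unpacked == state_data_ptrs
```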
......@@ -180,6 +190,9 @@ class MooncakeKVManager(CommonKVManager):
)
for _ in range(transfer_queue_size)
]
self.state_executors = concurrent.futures.ThreadPoolExecutor(
transfer_thread_pool_size // transfer_queue_size
)
for queue, executor in zip(self.transfer_queues, self.executors):
threading.Thread(
target=self.transfer_worker, args=(queue, executor), daemon=True
......@@ -239,6 +252,12 @@ class MooncakeKVManager(CommonKVManager):
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
)
# Batch register state/extra pool data buffers
if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens:
self.engine.batch_register(
self.kv_args.state_data_ptrs, self.kv_args.state_data_lens
)
def _transfer_data(self, mooncake_session_id, transfer_blocks):
if not transfer_blocks:
return 0
......@@ -248,17 +267,23 @@ class MooncakeKVManager(CommonKVManager):
mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
)
-    def send_kvcache(
    def _send_kvcache_generic(
        self,
        mooncake_session_id: str,
-        prefill_kv_indices: npt.NDArray[np.int32],
-        dst_kv_ptrs: list[int],
-        dst_kv_indices: npt.NDArray[np.int32],
        src_data_ptrs: list[int],
        dst_data_ptrs: list[int],
        item_lens: list[int],
        prefill_data_indices: npt.NDArray[np.int32],
        dst_data_indices: npt.NDArray[np.int32],
        executor: concurrent.futures.ThreadPoolExecutor,
-    ):
-        # Group by indices
    ) -> int:
        """
        Generic KV cache transfer supporting both MHA and MLA architectures.
        This method is used by both send_kvcache (full pool) and maybe_send_extra.
        """
        # Group by indices for optimization
        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
-            prefill_kv_indices, dst_kv_indices
            prefill_data_indices, dst_data_indices
        )
layers_params = None
......@@ -266,9 +291,9 @@ class MooncakeKVManager(CommonKVManager):
# pp is not supported on the decode side yet
if self.is_mla_backend:
            src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
-                self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
                self.get_mla_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
            )
-            kv_item_len = self.kv_args.kv_item_lens[0]
            kv_item_len = item_lens[0]
layers_params = [
(
src_kv_ptrs[layer_id],
......@@ -279,9 +304,9 @@ class MooncakeKVManager(CommonKVManager):
]
else:
            src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
-                self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
                self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
            )
-            kv_item_len = self.kv_args.kv_item_lens[0]
            kv_item_len = item_lens[0]
layers_params = [
(
src_k_ptrs[layer_id],
......@@ -345,6 +370,24 @@ class MooncakeKVManager(CommonKVManager):
return 0
def send_kvcache(
self,
mooncake_session_id: str,
prefill_kv_indices: npt.NDArray[np.int32],
dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int32],
executor: concurrent.futures.ThreadPoolExecutor,
):
return self._send_kvcache_generic(
mooncake_session_id=mooncake_session_id,
src_data_ptrs=self.kv_args.kv_data_ptrs,
dst_data_ptrs=dst_kv_ptrs,
item_lens=self.kv_args.kv_item_lens,
prefill_data_indices=prefill_kv_indices,
dst_data_indices=dst_kv_indices,
executor=executor,
)
def send_kvcache_slice(
self,
mooncake_session_id: str,
......@@ -593,6 +636,58 @@ class MooncakeKVManager(CommonKVManager):
f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
)
def maybe_send_extra(
self,
req: TransferInfo,
prefill_state_indices: list[int],
dst_state_data_ptrs: list[int],
):
"""Send state or extra pool data with type-specific handling."""
state_type = getattr(self.kv_args, "state_type", "none")
if state_type == "mamba":
return self._send_mamba_state(
req,
prefill_state_indices,
dst_state_data_ptrs,
)
elif state_type in ["swa", "nsa"]:
# Reuse _send_kvcache_generic interface to send extra pool data
prefill_state_indices = np.array(prefill_state_indices, dtype=np.int32)
dst_state_indices = np.array(req.dst_state_indices, dtype=np.int32)
return self._send_kvcache_generic(
mooncake_session_id=req.mooncake_session_id,
src_data_ptrs=self.kv_args.state_data_ptrs,
dst_data_ptrs=dst_state_data_ptrs,
item_lens=self.kv_args.state_item_lens,
prefill_data_indices=prefill_state_indices,
dst_data_indices=dst_state_indices,
executor=self.state_executors,
)
else:
return 0
def _send_mamba_state(
self,
req: TransferInfo,
prefill_mamba_index: list[int],
dst_state_data_ptrs: list[int],
):
"""Transfer Mamba states."""
assert len(prefill_mamba_index) == 1, "Mamba should have a single state index"
transfer_blocks = []
prefill_state_data_ptrs = self.kv_args.state_data_ptrs
prefill_state_item_lens = self.kv_args.state_item_lens
for i, dst_state_ptr in enumerate(dst_state_data_ptrs):
length = prefill_state_item_lens[i]
src_addr = prefill_state_data_ptrs[i] + length * int(prefill_mamba_index[0])
dst_addr = dst_state_ptr + length * int(req.dst_state_indices[0])
transfer_blocks.append((src_addr, dst_addr, length))
return self._transfer_data(req.mooncake_session_id, transfer_blocks)
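A worked example of the address arithmetic above, with made-up numbers: each state buffer is a dense array of per-request slots, so a slot's address is the base pointer plus `item_len` times the slot index:

```python
# Made-up numbers: a 256 KiB state slot, prefill slot 3, decode slot 11.
length = 256 * 1024                            # state_item_lens[i]
src_base, dst_base = 0x1000_0000, 0x2000_0000  # state_data_ptrs[i] on each side
src_addr = src_base + length * 3               # prefill_mamba_index[0] == 3
dst_addr = dst_base + length * 11              # req.dst_state_indices[0] == 11
transfer_block = (src_addr, dst_addr, length)  # one block per state buffer
assert (src_addr - src_base) // length == 3
assert (dst_addr - dst_base) // length == 11
```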
def sync_status_to_decode_endpoint(
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
):
......@@ -702,6 +797,21 @@ class MooncakeKVManager(CommonKVManager):
break
if kv_chunk.is_last:
if kv_chunk.state_indices is not None:
if not self.is_mla_backend and (
self.attn_tp_size
!= target_rank_registration_info.dst_attn_tp_size
):
raise RuntimeError(
    "PD disaggregation does not yet support different prefill/decode TP sizes for non-MLA hybrid models."
)
self.maybe_send_extra(
req,
kv_chunk.state_indices,
target_rank_registration_info.dst_state_data_ptrs,
)
if self.pp_group.is_last_rank:
# Aux data only needs to be sent with the last chunk
ret = self.send_aux(
......@@ -765,7 +875,7 @@ class MooncakeKVManager(CommonKVManager):
)
continue
else:
-                required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
                required_dst_info_num = int(waiting_req_bytes[7].decode("ascii"))
room = int(room)
if room not in self.transfer_infos:
self.transfer_infos[room] = {}
......@@ -876,6 +986,7 @@ class MooncakeKVManager(CommonKVManager):
index_slice: slice,
is_last: bool,
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
):
assert self.disaggregation_mode == DisaggregationMode.PREFILL
assert not is_last or (is_last and aux_index is not None)
......@@ -909,6 +1020,7 @@ class MooncakeKVManager(CommonKVManager):
index_slice=index_slice,
is_last=is_last,
prefill_aux_index=aux_index,
state_indices=state_indices,
)
)
......@@ -989,6 +1101,7 @@ class MooncakeKVSender(CommonKVSender):
def send(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
):
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
self.curr_idx += len(kv_indices)
......@@ -1008,6 +1121,7 @@ class MooncakeKVSender(CommonKVSender):
index_slice,
True,
aux_index=self.aux_index,
state_indices=state_indices,
)
def poll(self) -> KVPoll:
......@@ -1110,6 +1224,9 @@ class MooncakeKVReceiver(CommonKVReceiver):
packed_aux_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
)
packed_state_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs
)
# Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
tp_rank = self.kv_mgr.kv_args.engine_rank
kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
......@@ -1127,13 +1244,19 @@ class MooncakeKVReceiver(CommonKVReceiver):
self.session_id.encode("ascii"),
packed_kv_data_ptrs,
packed_aux_data_ptrs,
packed_state_data_ptrs,
dst_tp_rank,
dst_attn_tp_size,
dst_kv_item_len,
]
)
-    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
    def init(
        self,
        kv_indices: npt.NDArray[np.int32],
        aux_index: Optional[int] = None,
        state_indices: Optional[List[int]] = None,
    ):
for bootstrap_info in self.bootstrap_infos:
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
is_dummy = bootstrap_info["is_dummy"]
......@@ -1147,6 +1270,14 @@ class MooncakeKVReceiver(CommonKVReceiver):
self.session_id.encode("ascii"),
kv_indices.tobytes() if not is_dummy else b"",
str(aux_index).encode("ascii") if not is_dummy else b"",
(
np.array(
state_indices,
dtype=np.int32,
).tobytes()
if not is_dummy and state_indices is not None
else b""
),
str(self.required_dst_info_num).encode("ascii"),
]
)
......
......@@ -704,6 +704,7 @@ class NixlKVSender(CommonKVSender):
def send(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
):
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
self.curr_idx += len(kv_indices)
......@@ -755,7 +756,12 @@ class NixlKVReceiver(CommonKVReceiver):
self.bootstrap_room
)
-    def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
    def init(
        self,
        kv_indices: npt.NDArray[np.int32],
        aux_index: Optional[int] = None,
        state_indices: Optional[List[int]] = None,
    ):
for bootstrap_info in self.bootstrap_infos:
logger.debug(
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
......
......@@ -49,6 +49,11 @@ from sglang.srt.managers.schedule_batch import (
RequestStage,
ScheduleBatch,
)
from sglang.srt.mem_cache.memory_pool import (
HybridLinearKVPool,
NSATokenToKVPool,
SWAKVPool,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.utils import (
DynamicGradMode,
......@@ -146,6 +151,28 @@ class PrefillBootstrapQueue:
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id
if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
state_data_ptrs, state_data_lens, state_item_lens = (
self.token_to_kv_pool.get_state_buf_infos()
)
kv_args.state_data_ptrs = state_data_ptrs
kv_args.state_data_lens = state_data_lens
kv_args.state_item_lens = state_item_lens
if isinstance(self.token_to_kv_pool, SWAKVPool):
kv_args.state_type = "swa"
elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
kv_args.state_type = "mamba"
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
kv_args.state_type = "nsa"
else:
kv_args.state_type = "none"
else:
kv_args.state_data_ptrs = []
kv_args.state_data_lens = []
kv_args.state_item_lens = []
kv_args.state_type = "none"
kv_manager_class: Type[BaseKVManager] = get_kv_class(
self.transfer_backend, KVClassType.MANAGER
)
......@@ -618,15 +645,58 @@ class SchedulerDisaggregationPrefillMixin:
.numpy()
)
req.start_send_idx = end_idx
state_indices = None
if last_chunk:
self.disagg_metadata_buffers.set_buf(req)
# Prepare extra pool indices for hybrid models
if isinstance(
self.token_to_kv_pool_allocator.get_kvcache(), HybridLinearKVPool
):
# Mamba hybrid model: send single mamba state index
state_indices = [
self.req_to_token_pool.req_index_to_mamba_index_mapping[
req.req_pool_idx
]
.cpu()
.numpy()
]
elif isinstance(self.token_to_kv_pool_allocator.get_kvcache(), SWAKVPool):
# SWA hybrid model: send last window KV indices
seq_len = len(req.fill_ids)
window_size = self.sliding_window_size
window_start = max(0, seq_len - window_size)
window_start = (window_start // page_size) * page_size
window_kv_indices_full = self.req_to_token_pool.req_to_token[
req.req_pool_idx, window_start:seq_len
]
# Translate to SWA pool indices
window_kv_indices_swa = (
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
window_kv_indices_full
)
)
state_indices = window_kv_indices_swa.cpu().numpy()
state_indices = kv_to_page_indices(state_indices, page_size)
elif isinstance(
self.token_to_kv_pool_allocator.get_kvcache(), NSATokenToKVPool
):
seq_len = len(req.fill_ids)
kv_indices_full = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :seq_len
]
state_indices = kv_indices_full.cpu().numpy()
state_indices = kv_to_page_indices(state_indices, page_size)
page_indices = kv_to_page_indices(kv_indices, page_size)
if len(page_indices) == 0:
logger.info(
f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
)
return
-        req.disagg_kv_sender.send(page_indices)
        req.disagg_kv_sender.send(page_indices, state_indices)
# PP
@DynamicGradMode()
......
......@@ -807,9 +807,6 @@ class Scheduler(
self.tree_cache.cache_controller.layer_done_counter
)
elif self.is_hybrid:
-            assert (
-                self.server_args.disaggregation_mode == "null"
-            ), "Hybrid mode does not support disaggregation yet"
self.tree_cache = SWARadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
......@@ -819,9 +816,6 @@ class Scheduler(
is_eagle=self.spec_algorithm.is_eagle(),
)
elif self.is_hybrid_gdn:
-            assert (
-                self.server_args.disaggregation_mode == "null"
-            ), "Hybrid GDN mode does not support disaggregation yet"
self.tree_cache = MambaRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
......
This diff is collapsed.
......@@ -1669,19 +1669,34 @@ class ModelRunner:
extra_max_context_len += self.server_args.speculative_num_draft_tokens
if self.server_args.disaggregation_mode == "decode":
-            from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
            from sglang.srt.disaggregation.decode import (
                DecodeReqToTokenPool,
                HybridMambaDecodeReqToTokenPool,
            )
# subscribe memory for pre-allocated requests
# if max_num_reqs <= 32, we pre-allocate 2x requests
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
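Restating the pre-allocation rule above as a tiny function (a sketch, not code from this PR):

```python
def pre_alloc_size(max_num_reqs: int) -> int:
    """Extra request slots reserved for in-flight pre-allocation on decode."""
    return max_num_reqs * 2 if max_num_reqs <= 32 else 0

assert pre_alloc_size(16) == 32   # small instances reserve 2x extra slots
assert pre_alloc_size(256) == 0   # larger ones rely on the normal pool
```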
-            self.req_to_token_pool = DecodeReqToTokenPool(
-                size=max_num_reqs,
-                max_context_len=self.model_config.context_len
-                + extra_max_context_len,
-                device=self.device,
-                enable_memory_saver=self.server_args.enable_memory_saver,
-                pre_alloc_size=pre_alloc_size,
-            )
            if config := self.mambaish_config:
                self.req_to_token_pool = HybridMambaDecodeReqToTokenPool(
                    size=max_num_reqs,
                    max_context_len=self.model_config.context_len
                    + extra_max_context_len,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    cache_params=config.mamba2_cache_params,
                    speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
                    pre_alloc_size=pre_alloc_size,
                )
            else:
                self.req_to_token_pool = DecodeReqToTokenPool(
                    size=max_num_reqs,
                    max_context_len=self.model_config.context_len
                    + extra_max_context_len,
                    device=self.device,
                    enable_memory_saver=self.server_args.enable_memory_saver,
                    pre_alloc_size=pre_alloc_size,
                )
elif config := self.mambaish_config:
self.req_to_token_pool = HybridReqToTokenPool(
size=max_num_reqs,
......@@ -1807,6 +1822,7 @@ class ModelRunner:
),
enable_kvcache_transpose=False,
device=self.device,
mamba_pool=self.req_to_token_pool.mamba_pool,
)
else:
self.token_to_kv_pool = MHATokenToKVPool(
......
......@@ -163,6 +163,7 @@ suites = {
TestFile("test_deepseek_v3_basic.py", 275),
TestFile("test_deepseek_v3_mtp.py", 275),
TestFile("test_disaggregation_different_tp.py", 600),
TestFile("test_disaggregation_hybrid_attention.py", 200),
TestFile("test_disaggregation_pp.py", 140),
],
"per-commit-4-gpu-b200": [
......
import os
import unittest
from types import SimpleNamespace
from sglang.srt.environ import envs
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
popen_launch_pd_server,
)
class TestDisaggregationHybridAttentionMamba(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct"
# Start servers without blocking
cls.start_prefill()
cls.start_decode()
# Block until both servers are ready
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
cls.launch_lb()
@classmethod
def start_prefill(cls):
prefill_args = [
"--trust-remote-code",
"--disaggregation-mode",
"prefill",
"--tp",
"4",
]
prefill_args += cls.transfer_backend + cls.rdma_devices
cls.process_prefill = popen_launch_pd_server(
cls.model,
cls.prefill_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=prefill_args,
)
@classmethod
def start_decode(cls):
decode_args = [
"--trust-remote-code",
"--disaggregation-mode",
"decode",
"--tp",
"4",
"--base-gpu-id",
"4",
]
decode_args += cls.transfer_backend + cls.rdma_devices
cls.process_decode = popen_launch_pd_server(
cls.model,
cls.decode_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=decode_args,
)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host=f"http://{self.base_host}",
port=int(self.lb_port),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"Evaluation metrics: {metrics}")
self.assertGreater(metrics["accuracy"], 0.93)
if __name__ == "__main__":
unittest.main()
......@@ -42,6 +42,7 @@ class TestMamba(unittest.TestCase):
full_attention_layer_ids=full_attention_layer_ids,
enable_kvcache_transpose=False,
device=device,
mamba_pool=None,
)
assert pool._transfer_full_attention_id(global_interval - 1) == 0
assert pool._transfer_full_attention_id(2 * global_interval - 1) == 1
......@@ -173,6 +174,7 @@ class TestMamba(unittest.TestCase):
full_attention_layer_ids=full_attention_layer_ids,
enable_kvcache_transpose=False,
device=device,
mamba_pool=req_to_token_pool.mamba_pool,
)
# setup token to kv pool allocator
......