Unverified Commit 868403f6 authored by Shangming Cai, committed by GitHub

[PD] Add PD support for hybrid model (Qwen3-Next, DeepSeek V3.2 Exp) (#10912)


Signed-off-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: hzh0425 <hzh0425@apache.org>
Co-authored-by: ZeldaHuang <hzm414167@alibaba-inc.com>
parent 97d857c0
......@@ -20,6 +20,10 @@ class KVArgs:
aux_data_ptrs: List[int]
aux_data_lens: List[int]
aux_item_lens: List[int]
state_data_ptrs: List[int]
state_data_lens: List[int]
state_item_lens: List[int]
    state_type: str  # "none", "mamba", "swa", "nsa"
ib_device: str
ib_traffic_class: str
gpu_id: int
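The new state_* fields describe one extra buffer family alongside the regular KV buffers, discriminated by state_type. As a quick orientation, a hedged summary of how this diff uses the discriminator (the helper below is illustrative, not part of the PR):

def describe_state_type(state_type: str) -> str:
    # Values taken from this diff; the prose mapping is illustrative.
    return {
        "none": "no extra state to transfer",
        "mamba": "one conv/ssm state slot per request (HybridLinearKVPool)",
        "swa": "sliding-window KV pages (SWAKVPool)",
        "nsa": "index-k-with-scale pages (NSATokenToKVPool)",
    }[state_type]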
......@@ -76,9 +80,13 @@ class BaseKVSender(ABC):
...
@abstractmethod
def send(self, kv_indices: npt.NDArray[np.int32]):
def send(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
):
"""
Send the kv cache at the given kv indices to the decoder server
        Send the kv cache at the given kv indices, plus the extra cache/state at the given state indices, to the decode server
"""
...
......@@ -108,9 +116,14 @@ class BaseKVReceiver(ABC):
): ...
@abstractmethod
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
def init(
self,
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
):
"""
Notify the prefill server about the kv indices and aux index
Notify the prefill server about the kv indices, aux index, and state_indices.
"""
...
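Taken together, the two interfaces imply a simple bootstrap order. A minimal sketch with duck-typed stand-ins for a concrete backend (such as the Mooncake classes below); all names here are illustrative:

def bootstrap_and_transfer(sender, receiver, kv_page_indices, aux_index, state_indices):
    # Decode side first advertises its destination kv/aux/state indices
    # to the prefill server...
    receiver.init(kv_page_indices, aux_index=aux_index, state_indices=state_indices)
    # ...then the prefill side streams the kv pages plus the extra state.
    sender.send(kv_page_indices, state_indices=state_indices)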
......
......@@ -201,6 +201,7 @@ class CommonKVSender(BaseKVSender):
def send(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
):
pass
......
......@@ -25,11 +25,12 @@ import time
from collections import deque
from dataclasses import dataclass
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type, Union
import torch
from torch.distributed import ProcessGroup
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVPoll
from sglang.srt.disaggregation.utils import (
......@@ -47,9 +48,19 @@ from sglang.srt.disaggregation.utils import (
)
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import FINISH_ABORT, RequestStage, ScheduleBatch
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.allocator import (
BaseTokenToKVPoolAllocator,
SWATokenToKVPoolAllocator,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
from sglang.srt.mem_cache.memory_pool import (
HybridLinearKVPool,
HybridReqToTokenPool,
KVCache,
NSATokenToKVPool,
ReqToTokenPool,
SWAKVPool,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import get_int_env_var, require_mlp_sync
from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
......@@ -124,6 +135,35 @@ class DecodeReqToTokenPool:
self.free_slots = list(range(self.size + self.pre_alloc_size))
class HybridMambaDecodeReqToTokenPool(HybridReqToTokenPool):
def __init__(
self,
size: int,
max_context_len: int,
device: str,
enable_memory_saver: bool,
cache_params: "Mamba2CacheParams",
speculative_num_draft_tokens: int,
pre_alloc_size: int,
):
DecodeReqToTokenPool.__init__(
self,
size=size,
max_context_len=max_context_len,
device=device,
enable_memory_saver=enable_memory_saver,
pre_alloc_size=pre_alloc_size,
)
self._init_mamba_pool(
size + pre_alloc_size, cache_params, device, speculative_num_draft_tokens
)
def clear(self):
self.free_slots = list(range(self.size + self.pre_alloc_size))
self.mamba_pool.clear()
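A construction sketch mirroring the ModelRunner call further down in this diff; argument values are illustrative, and cache_params is assumed to be a Mamba2CacheParams instance supplied by the caller:

def build_decode_req_pool(cache_params):
    return HybridMambaDecodeReqToTokenPool(
        size=64,
        max_context_len=4096,
        device="cuda",
        enable_memory_saver=False,
        cache_params=cache_params,
        speculative_num_draft_tokens=None,
        pre_alloc_size=128,  # 2x pre-allocation applies when max_num_reqs <= 32
    )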
@dataclass
class DecodeRequest:
req: Req
......@@ -217,6 +257,28 @@ class DecodePreallocQueue:
self.metadata_buffers.get_buf_infos()
)
if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
state_data_ptrs, state_data_lens, state_item_lens = (
self.token_to_kv_pool.get_state_buf_infos()
)
kv_args.state_data_ptrs = state_data_ptrs
kv_args.state_data_lens = state_data_lens
kv_args.state_item_lens = state_item_lens
if isinstance(self.token_to_kv_pool, SWAKVPool):
kv_args.state_type = "swa"
elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
kv_args.state_type = "mamba"
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
kv_args.state_type = "nsa"
else:
kv_args.state_type = "none"
else:
kv_args.state_data_ptrs = []
kv_args.state_data_lens = []
kv_args.state_item_lens = []
kv_args.state_type = "none"
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id
kv_manager_class: Type[BaseKVManager] = get_kv_class(
......@@ -414,16 +476,56 @@ class DecodePreallocQueue:
.cpu()
.numpy()
)
page_size = self.token_to_kv_pool_allocator.page_size
# Prepare extra pool indices for hybrid models
if isinstance(self.token_to_kv_pool, HybridLinearKVPool):
# Mamba hybrid model: single mamba state index
state_indices = [
self.req_to_token_pool.req_index_to_mamba_index_mapping[
decode_req.req.req_pool_idx
]
.cpu()
.numpy()
]
elif isinstance(self.token_to_kv_pool, SWAKVPool):
# SWA hybrid model: send decode-side SWA window indices
seq_len = len(decode_req.req.origin_input_ids)
window_size = self.scheduler.sliding_window_size
window_start = max(0, seq_len - window_size)
window_start = (window_start // page_size) * page_size
window_kv_indices_full = self.req_to_token_pool.req_to_token[
decode_req.req.req_pool_idx, window_start:seq_len
]
# Translate to SWA pool indices
window_kv_indices_swa = (
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
window_kv_indices_full
)
)
state_indices = window_kv_indices_swa.cpu().numpy()
state_indices = kv_to_page_indices(state_indices, page_size)
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
seq_len = len(decode_req.req.origin_input_ids)
kv_indices_full = self.req_to_token_pool.req_to_token[
decode_req.req.req_pool_idx, :seq_len
]
state_indices = kv_indices_full.cpu().numpy()
state_indices = kv_to_page_indices(state_indices, page_size)
else:
state_indices = None
decode_req.metadata_buffer_index = (
self.req_to_metadata_buffer_idx_allocator.alloc()
)
assert decode_req.metadata_buffer_index is not None
page_indices = kv_to_page_indices(
kv_indices, self.token_to_kv_pool_allocator.page_size
page_indices = kv_to_page_indices(kv_indices, page_size)
decode_req.kv_receiver.init(
page_indices, decode_req.metadata_buffer_index, state_indices
)
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
decode_req.req.add_latency(RequestStage.DECODE_BOOTSTRAP)
preallocated_reqs.append(decode_req)
indices_to_remove.add(i)
decode_req.req.time_stats.decode_transfer_queue_entry_time = (
......@@ -503,7 +605,10 @@ class DecodePreallocQueue:
def _pre_alloc(self, req: Req) -> torch.Tensor:
"""Pre-allocate the memory for req_to_token and token_kv_pool"""
req_pool_indices = self.req_to_token_pool.alloc(1)
if isinstance(self.req_to_token_pool, HybridMambaDecodeReqToTokenPool):
req_pool_indices = self.req_to_token_pool.alloc(1, [req])
else:
req_pool_indices = self.req_to_token_pool.alloc(1)
assert (
req_pool_indices is not None
......
......@@ -48,9 +48,12 @@ class FakeKVSender(BaseKVSender):
def send(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
):
self.has_sent = True
logger.debug(f"FakeKVSender send with kv_indices: {kv_indices}")
logger.debug(
f"FakeKVSender send with kv_indices: {kv_indices}, state_indices: {state_indices}"
)
def failure_exception(self):
raise Exception("Fake KVSender Exception")
......@@ -75,10 +78,15 @@ class FakeKVReceiver(BaseKVReceiver):
logger.debug("FakeKVReceiver poll success")
return KVPoll.Success
def init(self, kv_indices: list[int], aux_index: Optional[int] = None):
def init(
self,
kv_indices: list[int],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
):
self.has_init = True
logger.debug(
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}"
f"FakeKVReceiver init with kv_indices: {kv_indices}, aux_index: {aux_index}, state_indices: {state_indices}"
)
def failure_exception(self):
......
......@@ -58,6 +58,7 @@ class TransferKVChunk:
index_slice: slice
is_last: bool
prefill_aux_index: Optional[int]
state_indices: Optional[List[int]]
# decode
......@@ -69,6 +70,7 @@ class TransferInfo:
mooncake_session_id: str
dst_kv_indices: npt.NDArray[np.int32]
dst_aux_index: int
dst_state_indices: List[int]
required_dst_info_num: int
is_dummy: bool
......@@ -78,9 +80,14 @@ class TransferInfo:
is_dummy = True
dst_kv_indices = np.array([], dtype=np.int32)
dst_aux_index = None
dst_state_indices = []
else:
dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
dst_aux_index = int(msg[5].decode("ascii"))
if msg[6] == b"":
dst_state_indices = []
else:
dst_state_indices = list(np.frombuffer(msg[6], dtype=np.int32))
is_dummy = False
return cls(
room=int(msg[0].decode("ascii")),
......@@ -89,7 +96,8 @@ class TransferInfo:
mooncake_session_id=msg[3].decode("ascii"),
dst_kv_indices=dst_kv_indices,
dst_aux_index=dst_aux_index,
required_dst_info_num=int(msg[6].decode("ascii")),
dst_state_indices=dst_state_indices,
required_dst_info_num=int(msg[7].decode("ascii")),
is_dummy=is_dummy,
)
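For reference, the frame layout from_zmq now expects: frame 6 carries the packed int32 state indices (or b"" when there are none), and required_dst_info_num shifts to frame 7. A minimal sketch with illustrative values (frames 1 and 2 are elided in the hunk above and assumed to be the endpoint and port):

import numpy as np

msg = [
    b"42",                                     # msg[0]: bootstrap room
    b"127.0.0.1",                              # msg[1]: endpoint (assumed)
    b"6000",                                   # msg[2]: dst port (assumed)
    b"session-0",                              # msg[3]: mooncake session id
    np.arange(8, dtype=np.int32).tobytes(),    # msg[4]: dst kv indices
    b"3",                                      # msg[5]: dst aux index
    np.array([0], dtype=np.int32).tobytes(),   # msg[6]: dst state indices
    b"1",                                      # msg[7]: required_dst_info_num
]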
......@@ -103,6 +111,7 @@ class KVArgsRegisterInfo:
mooncake_session_id: str
dst_kv_ptrs: list[int]
dst_aux_ptrs: list[int]
dst_state_data_ptrs: list[int]
dst_tp_rank: int
dst_attn_tp_size: int
dst_kv_item_len: int
......@@ -116,9 +125,10 @@ class KVArgsRegisterInfo:
mooncake_session_id=msg[3].decode("ascii"),
dst_kv_ptrs=list(struct.unpack(f"{len(msg[4])//8}Q", msg[4])),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
dst_tp_rank=int(msg[6].decode("ascii")),
dst_attn_tp_size=int(msg[7].decode("ascii")),
dst_kv_item_len=int(msg[8].decode("ascii")),
dst_state_data_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
dst_tp_rank=int(msg[7].decode("ascii")),
dst_attn_tp_size=int(msg[8].decode("ascii")),
dst_kv_item_len=int(msg[9].decode("ascii")),
)
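Each pointer frame is a fixed-width array of unsigned 64-bit values ("Q"). A self-contained round-trip sketch of the packing done by the receiver and the unpacking above (addresses are illustrative):

import struct

state_ptrs = [0x7F0000000000, 0x7F0000100000]
packed = b"".join(struct.pack("Q", p) for p in state_ptrs)
unpacked = list(struct.unpack(f"{len(packed) // 8}Q", packed))
assert unpacked == state_ptrs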
......@@ -180,6 +190,9 @@ class MooncakeKVManager(CommonKVManager):
)
for _ in range(transfer_queue_size)
]
        self.state_executors = concurrent.futures.ThreadPoolExecutor(
            # Guard against a zero worker count when the pool size is small
            max(1, transfer_thread_pool_size // transfer_queue_size)
        )
for queue, executor in zip(self.transfer_queues, self.executors):
threading.Thread(
target=self.transfer_worker, args=(queue, executor), daemon=True
......@@ -239,6 +252,12 @@ class MooncakeKVManager(CommonKVManager):
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
)
# Batch register state/extra pool data buffers
if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens:
self.engine.batch_register(
self.kv_args.state_data_ptrs, self.kv_args.state_data_lens
)
def _transfer_data(self, mooncake_session_id, transfer_blocks):
if not transfer_blocks:
return 0
......@@ -248,17 +267,23 @@ class MooncakeKVManager(CommonKVManager):
mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
)
def send_kvcache(
def _send_kvcache_generic(
self,
mooncake_session_id: str,
prefill_kv_indices: npt.NDArray[np.int32],
dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int32],
src_data_ptrs: list[int],
dst_data_ptrs: list[int],
item_lens: list[int],
prefill_data_indices: npt.NDArray[np.int32],
dst_data_indices: npt.NDArray[np.int32],
executor: concurrent.futures.ThreadPoolExecutor,
):
# Group by indices
) -> int:
"""
Generic KV cache transfer supporting both MHA and MLA architectures.
This method is used by both send_kvcache (full pool) and maybe_send_extra.
"""
        # Group contiguous runs of indices so each transfer moves larger blocks
prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
prefill_kv_indices, dst_kv_indices
prefill_data_indices, dst_data_indices
)
layers_params = None
......@@ -266,9 +291,9 @@ class MooncakeKVManager(CommonKVManager):
# pp is not supported on the decode side yet
if self.is_mla_backend:
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
self.get_mla_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
)
kv_item_len = self.kv_args.kv_item_lens[0]
kv_item_len = item_lens[0]
layers_params = [
(
src_kv_ptrs[layer_id],
......@@ -279,9 +304,9 @@ class MooncakeKVManager(CommonKVManager):
]
else:
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
self.get_mha_kv_ptrs_with_pp(src_data_ptrs, dst_data_ptrs)
)
kv_item_len = self.kv_args.kv_item_lens[0]
kv_item_len = item_lens[0]
layers_params = [
(
src_k_ptrs[layer_id],
......@@ -345,6 +370,24 @@ class MooncakeKVManager(CommonKVManager):
return 0
def send_kvcache(
self,
mooncake_session_id: str,
prefill_kv_indices: npt.NDArray[np.int32],
dst_kv_ptrs: list[int],
dst_kv_indices: npt.NDArray[np.int32],
executor: concurrent.futures.ThreadPoolExecutor,
):
return self._send_kvcache_generic(
mooncake_session_id=mooncake_session_id,
src_data_ptrs=self.kv_args.kv_data_ptrs,
dst_data_ptrs=dst_kv_ptrs,
item_lens=self.kv_args.kv_item_lens,
prefill_data_indices=prefill_kv_indices,
dst_data_indices=dst_kv_indices,
executor=executor,
)
def send_kvcache_slice(
self,
mooncake_session_id: str,
......@@ -593,6 +636,58 @@ class MooncakeKVManager(CommonKVManager):
f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
)
def maybe_send_extra(
self,
req: TransferInfo,
prefill_state_indices: list[int],
dst_state_data_ptrs: list[int],
):
"""Send state or extra pool data with type-specific handling."""
state_type = getattr(self.kv_args, "state_type", "none")
if state_type == "mamba":
return self._send_mamba_state(
req,
prefill_state_indices,
dst_state_data_ptrs,
)
elif state_type in ["swa", "nsa"]:
# Reuse _send_kvcache_generic interface to send extra pool data
prefill_state_indices = np.array(prefill_state_indices, dtype=np.int32)
dst_state_indices = np.array(req.dst_state_indices, dtype=np.int32)
return self._send_kvcache_generic(
mooncake_session_id=req.mooncake_session_id,
src_data_ptrs=self.kv_args.state_data_ptrs,
dst_data_ptrs=dst_state_data_ptrs,
item_lens=self.kv_args.state_item_lens,
prefill_data_indices=prefill_state_indices,
dst_data_indices=dst_state_indices,
executor=self.state_executors,
)
else:
return 0
def _send_mamba_state(
self,
req: TransferInfo,
prefill_mamba_index: list[int],
dst_state_data_ptrs: list[int],
):
"""Transfer Mamba states."""
        assert len(prefill_mamba_index) == 1, "Mamba should have a single state index"
transfer_blocks = []
prefill_state_data_ptrs = self.kv_args.state_data_ptrs
prefill_state_item_lens = self.kv_args.state_item_lens
for i, dst_state_ptr in enumerate(dst_state_data_ptrs):
length = prefill_state_item_lens[i]
src_addr = prefill_state_data_ptrs[i] + length * int(prefill_mamba_index[0])
dst_addr = dst_state_ptr + length * int(req.dst_state_indices[0])
transfer_blocks.append((src_addr, dst_addr, length))
return self._transfer_data(req.mooncake_session_id, transfer_blocks)
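Because each per-layer mamba state tensor is contiguous, slot j of layer i lives at base_ptr[i] + j * item_len[i]. A small numeric illustration of the address arithmetic above (all values illustrative):

item_len = 4096                              # bytes per state slot
src_base, dst_base = 0x1000, 0x9000
src_slot, dst_slot = 7, 2
src_addr = src_base + item_len * src_slot    # 0x8000
dst_addr = dst_base + item_len * dst_slot    # 0xb000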
def sync_status_to_decode_endpoint(
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
):
......@@ -702,6 +797,21 @@ class MooncakeKVManager(CommonKVManager):
break
if kv_chunk.is_last:
if kv_chunk.state_indices is not None:
if not self.is_mla_backend and (
self.attn_tp_size
!= target_rank_registration_info.dst_attn_tp_size
):
                        raise RuntimeError(
                            "PD disaggregation does not support different "
                            "prefill/decode TP sizes for non-MLA hybrid models yet."
                        )
self.maybe_send_extra(
req,
kv_chunk.state_indices,
target_rank_registration_info.dst_state_data_ptrs,
)
if self.pp_group.is_last_rank:
                    # Only for the last chunk do we need to send the aux data
ret = self.send_aux(
......@@ -765,7 +875,7 @@ class MooncakeKVManager(CommonKVManager):
)
continue
else:
required_dst_info_num = int(waiting_req_bytes[6].decode("ascii"))
required_dst_info_num = int(waiting_req_bytes[7].decode("ascii"))
room = int(room)
if room not in self.transfer_infos:
self.transfer_infos[room] = {}
......@@ -876,6 +986,7 @@ class MooncakeKVManager(CommonKVManager):
index_slice: slice,
is_last: bool,
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
):
assert self.disaggregation_mode == DisaggregationMode.PREFILL
assert not is_last or (is_last and aux_index is not None)
......@@ -909,6 +1020,7 @@ class MooncakeKVManager(CommonKVManager):
index_slice=index_slice,
is_last=is_last,
prefill_aux_index=aux_index,
state_indices=state_indices,
)
)
......@@ -989,6 +1101,7 @@ class MooncakeKVSender(CommonKVSender):
def send(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
):
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
self.curr_idx += len(kv_indices)
......@@ -1008,6 +1121,7 @@ class MooncakeKVSender(CommonKVSender):
index_slice,
True,
aux_index=self.aux_index,
state_indices=state_indices,
)
def poll(self) -> KVPoll:
......@@ -1110,6 +1224,9 @@ class MooncakeKVReceiver(CommonKVReceiver):
packed_aux_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
)
packed_state_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs
)
# Note(shangming): No need to add pp rank here since pp is not supported on the decode side yet
tp_rank = self.kv_mgr.kv_args.engine_rank
kv_item_len = self.kv_mgr.kv_args.kv_item_lens[0]
......@@ -1127,13 +1244,19 @@ class MooncakeKVReceiver(CommonKVReceiver):
self.session_id.encode("ascii"),
packed_kv_data_ptrs,
packed_aux_data_ptrs,
packed_state_data_ptrs,
dst_tp_rank,
dst_attn_tp_size,
dst_kv_item_len,
]
)
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
def init(
self,
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
):
for bootstrap_info in self.bootstrap_infos:
sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
is_dummy = bootstrap_info["is_dummy"]
......@@ -1147,6 +1270,14 @@ class MooncakeKVReceiver(CommonKVReceiver):
self.session_id.encode("ascii"),
kv_indices.tobytes() if not is_dummy else b"",
str(aux_index).encode("ascii") if not is_dummy else b"",
(
np.array(
state_indices,
dtype=np.int32,
).tobytes()
if not is_dummy and state_indices is not None
else b""
),
str(self.required_dst_info_num).encode("ascii"),
]
)
......
......@@ -704,6 +704,7 @@ class NixlKVSender(CommonKVSender):
def send(
self,
kv_indices: npt.NDArray[np.int32],
state_indices: Optional[List[int]] = None,
):
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
self.curr_idx += len(kv_indices)
......@@ -755,7 +756,12 @@ class NixlKVReceiver(CommonKVReceiver):
self.bootstrap_room
)
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
def init(
self,
kv_indices: npt.NDArray[np.int32],
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
):
for bootstrap_info in self.bootstrap_infos:
logger.debug(
f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
......
......@@ -49,6 +49,11 @@ from sglang.srt.managers.schedule_batch import (
RequestStage,
ScheduleBatch,
)
from sglang.srt.mem_cache.memory_pool import (
HybridLinearKVPool,
NSATokenToKVPool,
SWAKVPool,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.utils import (
DynamicGradMode,
......@@ -146,6 +151,28 @@ class PrefillBootstrapQueue:
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
kv_args.gpu_id = self.scheduler.gpu_id
if hasattr(self.token_to_kv_pool, "get_state_buf_infos"):
state_data_ptrs, state_data_lens, state_item_lens = (
self.token_to_kv_pool.get_state_buf_infos()
)
kv_args.state_data_ptrs = state_data_ptrs
kv_args.state_data_lens = state_data_lens
kv_args.state_item_lens = state_item_lens
if isinstance(self.token_to_kv_pool, SWAKVPool):
kv_args.state_type = "swa"
elif isinstance(self.token_to_kv_pool, HybridLinearKVPool):
kv_args.state_type = "mamba"
elif isinstance(self.token_to_kv_pool, NSATokenToKVPool):
kv_args.state_type = "nsa"
else:
kv_args.state_type = "none"
else:
kv_args.state_data_ptrs = []
kv_args.state_data_lens = []
kv_args.state_item_lens = []
kv_args.state_type = "none"
kv_manager_class: Type[BaseKVManager] = get_kv_class(
self.transfer_backend, KVClassType.MANAGER
)
......@@ -618,15 +645,58 @@ class SchedulerDisaggregationPrefillMixin:
.numpy()
)
req.start_send_idx = end_idx
state_indices = None
if last_chunk:
self.disagg_metadata_buffers.set_buf(req)
# Prepare extra pool indices for hybrid models
if isinstance(
self.token_to_kv_pool_allocator.get_kvcache(), HybridLinearKVPool
):
# Mamba hybrid model: send single mamba state index
state_indices = [
self.req_to_token_pool.req_index_to_mamba_index_mapping[
req.req_pool_idx
]
.cpu()
.numpy()
]
elif isinstance(self.token_to_kv_pool_allocator.get_kvcache(), SWAKVPool):
# SWA hybrid model: send last window KV indices
seq_len = len(req.fill_ids)
window_size = self.sliding_window_size
window_start = max(0, seq_len - window_size)
window_start = (window_start // page_size) * page_size
window_kv_indices_full = self.req_to_token_pool.req_to_token[
req.req_pool_idx, window_start:seq_len
]
# Translate to SWA pool indices
window_kv_indices_swa = (
self.token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
window_kv_indices_full
)
)
state_indices = window_kv_indices_swa.cpu().numpy()
state_indices = kv_to_page_indices(state_indices, page_size)
elif isinstance(
self.token_to_kv_pool_allocator.get_kvcache(), NSATokenToKVPool
):
seq_len = len(req.fill_ids)
kv_indices_full = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :seq_len
]
state_indices = kv_indices_full.cpu().numpy()
state_indices = kv_to_page_indices(state_indices, page_size)
page_indices = kv_to_page_indices(kv_indices, page_size)
if len(page_indices) == 0:
logger.info(
f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
)
return
req.disagg_kv_sender.send(page_indices)
req.disagg_kv_sender.send(page_indices, state_indices)
# PP
@DynamicGradMode()
......
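The SWA branches above (mirrored on the decode side) only ship the last attention window: page-align the window start, slice the full-pool token indices, translate them to SWA-pool locations, and reduce to page indices. A standalone rework of the index math, assuming kv_to_page_indices maps token locations to the unique pages covering them (the helper name below is illustrative):

import numpy as np

def swa_window_pages(req_to_token_row: np.ndarray, seq_len: int,
                     window_size: int, page_size: int) -> np.ndarray:
    window_start = max(0, seq_len - window_size)
    window_start = (window_start // page_size) * page_size  # page-align
    window_tokens = req_to_token_row[window_start:seq_len]
    # Stand-in for the translate + kv_to_page_indices steps: unique pages
    # covering the window tokens.
    return np.unique(window_tokens // page_size)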
......@@ -807,9 +807,6 @@ class Scheduler(
self.tree_cache.cache_controller.layer_done_counter
)
elif self.is_hybrid:
assert (
self.server_args.disaggregation_mode == "null"
), "Hybrid mode does not support disaggregation yet"
self.tree_cache = SWARadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
......@@ -819,9 +816,6 @@ class Scheduler(
is_eagle=self.spec_algorithm.is_eagle(),
)
elif self.is_hybrid_gdn:
assert (
self.server_args.disaggregation_mode == "null"
), "Hybrid GDN mode does not support disaggregation yet"
self.tree_cache = MambaRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
......
......@@ -142,72 +142,93 @@ class MambaPool:
ssm_dtype = cache_params.dtype.temporal
num_mamba_layers = len(cache_params.layers)
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.zeros(
size=(num_mamba_layers, size + 1) + conv_state_shape,
dtype=conv_dtype,
device=device,
)
temporal_state = torch.zeros(
size=(num_mamba_layers, size + 1) + temporal_state_shape,
dtype=ssm_dtype,
device=device,
# for disagg with nvlink
self.enable_custom_mem_pool = get_bool_env_var(
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
)
if speculative_num_draft_tokens is not None:
# Cache intermediate SSM states per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
intermediate_ssm_state_cache = torch.zeros(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
temporal_state_shape[0],
temporal_state_shape[1],
temporal_state_shape[2],
),
dtype=ssm_dtype,
device="cuda",
)
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
intermediate_conv_window_cache = torch.zeros(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
conv_state_shape[0],
conv_state_shape[1],
),
if self.enable_custom_mem_pool:
# TODO(shangming): abstract custom allocator class for more backends
from mooncake.allocator import NVLinkAllocator
allocator = NVLinkAllocator.get_allocator(self.device)
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
else:
self.custom_mem_pool = None
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.enable_custom_mem_pool
else nullcontext()
):
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.zeros(
size=(num_mamba_layers, size + 1) + conv_state_shape,
dtype=conv_dtype,
device="cuda",
)
self.mamba_cache = self.SpeculativeState(
conv=conv_state,
temporal=temporal_state,
intermediate_ssm=intermediate_ssm_state_cache,
intermediate_conv_window=intermediate_conv_window_cache,
device=device,
)
logger.info(
f"Mamba Cache is allocated. "
f"max_mamba_cache_size: {size}, "
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
temporal_state = torch.zeros(
size=(num_mamba_layers, size + 1) + temporal_state_shape,
dtype=ssm_dtype,
device=device,
)
else:
self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
logger.info(
f"Mamba Cache is allocated. "
f"max_mamba_cache_size: {size}, "
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
if speculative_num_draft_tokens is not None:
# Cache intermediate SSM states per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
intermediate_ssm_state_cache = torch.zeros(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
temporal_state_shape[0],
temporal_state_shape[1],
temporal_state_shape[2],
),
dtype=ssm_dtype,
device="cuda",
)
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
intermediate_conv_window_cache = torch.zeros(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
conv_state_shape[0],
conv_state_shape[1],
),
dtype=conv_dtype,
device="cuda",
)
self.mamba_cache = self.SpeculativeState(
conv=conv_state,
temporal=temporal_state,
intermediate_ssm=intermediate_ssm_state_cache,
intermediate_conv_window=intermediate_conv_window_cache,
)
logger.info(
f"Mamba Cache is allocated. "
f"max_mamba_cache_size: {size}, "
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
)
else:
self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
logger.info(
f"Mamba Cache is allocated. "
f"max_mamba_cache_size: {size}, "
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
)
self.size = size
self.device = device
self.free_slots = torch.arange(
self.size, dtype=torch.int64, device=self.device
)
self.size = size
self.device = device
self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
self.num_mamba_layers = num_mamba_layers
def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
assert isinstance(self.mamba_cache, self.SpeculativeState)
......@@ -253,6 +274,22 @@ class MambaPool:
self.copy_from(src_index, dst_index)
return dst_index
def get_contiguous_buf_infos(self):
state_tensors = [
getattr(self.mamba_cache, field) for field in vars(self.mamba_cache)
]
data_ptrs, data_lens, item_lens = [], [], []
        for state_tensor in state_tensors:
data_ptrs += [
state_tensor[i].data_ptr() for i in range(self.num_mamba_layers)
]
data_lens += [state_tensor[i].nbytes for i in range(self.num_mamba_layers)]
item_lens += [
state_tensor[i][0].nbytes for i in range(self.num_mamba_layers)
]
return data_ptrs, data_lens, item_lens
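The triplet convention matches the KV pools: per layer, a base device pointer, the total byte length, and the per-slot byte length. A toy CPU illustration of the same triplet on a small tensor:

import torch

num_layers, slots = 2, 3
state = torch.zeros(num_layers, slots, 4, dtype=torch.float32)
data_ptrs = [state[i].data_ptr() for i in range(num_layers)]
data_lens = [state[i].nbytes for i in range(num_layers)]     # 3 * 4 * 4 = 48 bytes
item_lens = [state[i][0].nbytes for i in range(num_layers)]  # 4 * 4 = 16 bytes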
class HybridReqToTokenPool(ReqToTokenPool):
"""A memory pool that maps a request to its token locations."""
......@@ -274,9 +311,22 @@ class HybridReqToTokenPool(ReqToTokenPool):
device=device,
enable_memory_saver=enable_memory_saver,
)
self._init_mamba_pool(
size=mamba_size,
cache_params=cache_params,
device=device,
speculative_num_draft_tokens=speculative_num_draft_tokens,
)
def _init_mamba_pool(
self,
size: int,
cache_params: "Mamba2CacheParams",
device: str,
speculative_num_draft_tokens: int = None,
):
self.mamba_pool = MambaPool(
size=mamba_size,
size=size,
cache_params=cache_params,
device=device,
speculative_num_draft_tokens=speculative_num_draft_tokens,
......@@ -375,6 +425,19 @@ class KVCache(abc.ABC):
# default state for optional layer-wise transfer control
self.layer_transfer_counter = None
# for disagg with nvlink
self.enable_custom_mem_pool = get_bool_env_var(
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
)
if self.enable_custom_mem_pool:
# TODO(shangming): abstract custom allocator class for more backends
from mooncake.allocator import NVLinkAllocator
allocator = NVLinkAllocator.get_allocator(self.device)
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
else:
self.custom_mem_pool = None
def _finalize_allocation_log(self, num_tokens: int):
"""Common logging and mem_usage computation for KV cache allocation.
Supports both tuple (K, V) size returns and single KV size returns.
......@@ -426,6 +489,9 @@ class KVCache(abc.ABC):
def load_cpu_copy(self, kv_cache_cpu, indices):
raise NotImplementedError()
def maybe_get_custom_mem_pool(self):
return self.custom_mem_pool
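The allocate-through-MemPool pattern now recurs across the pools in this diff. A minimal sketch of the pattern, assuming the buffer is created with torch.zeros (the NVLinkAllocator import used elsewhere in the diff requires Mooncake and is omitted here):

from contextlib import nullcontext

import torch

def alloc_buffer(shape, dtype, device, custom_mem_pool=None):
    # Route the allocation through the NVLink-shareable pool when present.
    ctx = (
        torch.cuda.use_mem_pool(custom_mem_pool)
        if custom_mem_pool is not None
        else nullcontext()
    )
    with ctx:
        return torch.zeros(shape, dtype=dtype, device=device)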
class MHATokenToKVPool(KVCache):
......@@ -456,19 +522,6 @@ class MHATokenToKVPool(KVCache):
self.head_num = head_num
self.head_dim = head_dim
# for disagg with nvlink
self.enable_custom_mem_pool = get_bool_env_var(
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
)
if self.enable_custom_mem_pool:
# TODO(shangming): abstract custom allocator class for more backends
from mooncake.allocator import NVLinkAllocator
allocator = NVLinkAllocator.get_allocator(self.device)
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
else:
self.custom_mem_pool = None
self._create_buffers()
self.device_module = torch.get_device_module(self.device)
......@@ -611,9 +664,6 @@ class MHATokenToKVPool(KVCache):
]
return kv_data_ptrs, kv_data_lens, kv_item_lens
def maybe_get_custom_mem_pool(self):
return self.custom_mem_pool
def get_cpu_copy(self, indices):
torch.cuda.synchronize()
kv_cache_cpu = []
......@@ -756,12 +806,18 @@ class HybridLinearKVPool(KVCache):
full_attention_layer_ids: List[int],
enable_kvcache_transpose: bool,
device: str,
mamba_pool: MambaPool,
):
self.size = size
self.dtype = dtype
self.device = device
self.full_layer_nums = len(full_attention_layer_ids)
self.page_size = page_size
# TODO support pp?
self.start_layer = 0
self.head_num = head_num
self.head_dim = head_dim
self.mamba_pool = mamba_pool
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
assert not enable_kvcache_transpose
if _is_npu:
......@@ -790,6 +846,15 @@ class HybridLinearKVPool(KVCache):
def get_contiguous_buf_infos(self):
return self.full_kv_pool.get_contiguous_buf_infos()
def get_state_buf_infos(self):
mamba_data_ptrs, mamba_data_lens, mamba_item_lens = (
self.mamba_pool.get_contiguous_buf_infos()
)
return mamba_data_ptrs, mamba_data_lens, mamba_item_lens
def maybe_get_custom_mem_pool(self):
return self.full_kv_pool.maybe_get_custom_mem_pool()
def _transfer_full_attention_id(self, layer_id: int):
if layer_id not in self.full_attention_layer_id_mapping:
raise ValueError(
......@@ -841,22 +906,47 @@ class SWAKVPool(KVCache):
size: int,
size_swa: int,
dtype: torch.dtype,
head_num: int,
head_dim: int,
swa_attention_layer_ids: List[int],
full_attention_layer_ids: List[int],
enable_kvcache_transpose: bool,
device: str,
token_to_kv_pool_class: KVCache = MHATokenToKVPool,
**kwargs,
):
self.size = size
self.size_swa = size_swa
self.dtype = dtype
self.head_num = head_num
self.head_dim = head_dim
self.device = device
self.swa_layer_nums = len(swa_attention_layer_ids)
self.full_layer_nums = len(full_attention_layer_ids)
self.start_layer = 0
self.page_size = 1
kwargs["page_size"] = 1
kwargs["enable_memory_saver"] = False
kwargs["head_num"] = head_num
kwargs["head_dim"] = head_dim
kwargs["device"] = device
# TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
assert not enable_kvcache_transpose
# for disagg with nvlink
self.enable_custom_mem_pool = get_bool_env_var(
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
)
if self.enable_custom_mem_pool:
# TODO(shangming): abstract custom allocator class for more backends
from mooncake.allocator import NVLinkAllocator
allocator = NVLinkAllocator.get_allocator(self.device)
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
else:
self.custom_mem_pool = None
self.swa_kv_pool = token_to_kv_pool_class(
size=size_swa,
dtype=dtype,
......@@ -878,6 +968,9 @@ class SWAKVPool(KVCache):
k_size, v_size = self.get_kv_size_bytes()
self.mem_usage = (k_size + v_size) / GB
logger.info(
f"SWAKVPool mem usage: {self.mem_usage} GB, swa size: {self.size_swa}, full size: {self.size}"
)
def get_kv_size_bytes(self):
k_size, v_size = self.full_kv_pool.get_kv_size_bytes()
......@@ -888,15 +981,19 @@ class SWAKVPool(KVCache):
full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
self.full_kv_pool.get_contiguous_buf_infos()
)
kv_data_ptrs = full_kv_data_ptrs
kv_data_lens = full_kv_data_lens
kv_item_lens = full_kv_item_lens
return kv_data_ptrs, kv_data_lens, kv_item_lens
def get_state_buf_infos(self):
swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = (
self.swa_kv_pool.get_contiguous_buf_infos()
)
kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs
kv_data_lens = full_kv_data_lens + swa_kv_data_lens
kv_item_lens = full_kv_item_lens + swa_kv_item_lens
return kv_data_ptrs, kv_data_lens, kv_item_lens
return swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens
def get_key_buffer(self, layer_id: int):
layer_id_pool, is_swa = self.layers_mapping[layer_id]
......@@ -1152,19 +1249,6 @@ class MLATokenToKVPool(KVCache):
else (kv_lora_rank + qk_rope_head_dim)
)
# for disagg with nvlink
self.enable_custom_mem_pool = get_bool_env_var(
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
)
if self.enable_custom_mem_pool:
# TODO(shangming): abstract custom allocator class for more backends
from mooncake.allocator import NVLinkAllocator
allocator = NVLinkAllocator.get_allocator(self.device)
self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
else:
self.custom_mem_pool = None
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
......@@ -1207,9 +1291,6 @@ class MLATokenToKVPool(KVCache):
]
return kv_data_ptrs, kv_data_lens, kv_item_lens
def maybe_get_custom_mem_pool(self):
return self.custom_mem_pool
def get_key_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
......@@ -1346,24 +1427,31 @@ class NSATokenToKVPool(MLATokenToKVPool):
assert index_head_dim == 128
assert self.page_size == 64
self.index_k_with_scale_buffer = [
torch.zeros(
# Layout:
# ref: test_attention.py :: kv_cache_cast_to_fp8
# shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
# data: for page i,
# * buf[i, :page_size * head_dim] for fp8 data
# * buf[i, page_size * head_dim:].view(float32) for scale
(
(size + page_size + 1) // self.page_size,
self.page_size
* (index_head_dim + index_head_dim // self.quant_block_size * 4),
),
dtype=self.index_k_with_scale_buffer_dtype,
device=device,
)
for _ in range(layer_num)
]
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.custom_mem_pool
else nullcontext()
):
self.index_k_with_scale_buffer = [
torch.zeros(
# Layout:
# ref: test_attention.py :: kv_cache_cast_to_fp8
# shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
# data: for page i,
# * buf[i, :page_size * head_dim] for fp8 data
# * buf[i, page_size * head_dim:].view(float32) for scale
(
(size + page_size + 1) // self.page_size,
self.page_size
* (
index_head_dim + index_head_dim // self.quant_block_size * 4
),
),
dtype=self.index_k_with_scale_buffer_dtype,
device=device,
)
for _ in range(layer_num)
]
self._finalize_allocation_log(size)
def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
......@@ -1406,6 +1494,18 @@ class NSATokenToKVPool(MLATokenToKVPool):
pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
)
def get_state_buf_infos(self):
data_ptrs = [
self.index_k_with_scale_buffer[i].data_ptr() for i in range(self.layer_num)
]
data_lens = [
self.index_k_with_scale_buffer[i].nbytes for i in range(self.layer_num)
]
item_lens = [
self.index_k_with_scale_buffer[i][0].nbytes for i in range(self.layer_num)
]
return data_ptrs, data_lens, item_lens
def get_kv_size_bytes(self):
kv_size_bytes = super().get_kv_size_bytes()
for index_k_cache in self.index_k_with_scale_buffer:
......@@ -1636,27 +1736,38 @@ class DoubleSparseTokenToKVPool(KVCache):
)
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
# [size, head_num, head_dim] for each layer
self.k_buffer = [
torch.zeros(
(size + page_size, head_num, head_dim), dtype=dtype, device=device
)
for _ in range(layer_num)
]
self.v_buffer = [
torch.zeros(
(size + page_size, head_num, head_dim), dtype=dtype, device=device
)
for _ in range(layer_num)
]
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.enable_custom_mem_pool
else nullcontext()
):
# [size, head_num, head_dim] for each layer
self.k_buffer = [
torch.zeros(
(size + page_size, head_num, head_dim),
dtype=dtype,
device=device,
)
for _ in range(layer_num)
]
self.v_buffer = [
torch.zeros(
(size + page_size, head_num, head_dim),
dtype=dtype,
device=device,
)
for _ in range(layer_num)
]
# [size, head_num, heavy_channel_num] for each layer
self.label_buffer = [
torch.zeros(
(size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
)
for _ in range(layer_num)
]
# [size, head_num, heavy_channel_num] for each layer
self.label_buffer = [
torch.zeros(
(size + 1, head_num, heavy_channel_num),
dtype=dtype,
device=device,
)
for _ in range(layer_num)
]
def get_key_buffer(self, layer_id: int):
return self.k_buffer[layer_id - self.start_layer]
......
......@@ -1669,19 +1669,34 @@ class ModelRunner:
extra_max_context_len += self.server_args.speculative_num_draft_tokens
if self.server_args.disaggregation_mode == "decode":
from sglang.srt.disaggregation.decode import DecodeReqToTokenPool
from sglang.srt.disaggregation.decode import (
DecodeReqToTokenPool,
HybridMambaDecodeReqToTokenPool,
)
# subscribe memory for pre-allocated requests
# if max_num_reqs <= 32, we pre-allocate 2x requests
pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
self.req_to_token_pool = DecodeReqToTokenPool(
size=max_num_reqs,
max_context_len=self.model_config.context_len
+ extra_max_context_len,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
pre_alloc_size=pre_alloc_size,
)
if config := self.mambaish_config:
self.req_to_token_pool = HybridMambaDecodeReqToTokenPool(
size=max_num_reqs,
max_context_len=self.model_config.context_len
+ extra_max_context_len,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
cache_params=config.mamba2_cache_params,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
pre_alloc_size=pre_alloc_size,
)
else:
self.req_to_token_pool = DecodeReqToTokenPool(
size=max_num_reqs,
max_context_len=self.model_config.context_len
+ extra_max_context_len,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
pre_alloc_size=pre_alloc_size,
)
elif config := self.mambaish_config:
self.req_to_token_pool = HybridReqToTokenPool(
size=max_num_reqs,
......@@ -1807,6 +1822,7 @@ class ModelRunner:
),
enable_kvcache_transpose=False,
device=self.device,
mamba_pool=self.req_to_token_pool.mamba_pool,
)
else:
self.token_to_kv_pool = MHATokenToKVPool(
......
......@@ -163,6 +163,7 @@ suites = {
TestFile("test_deepseek_v3_basic.py", 275),
TestFile("test_deepseek_v3_mtp.py", 275),
TestFile("test_disaggregation_different_tp.py", 600),
TestFile("test_disaggregation_hybrid_attention.py", 200),
TestFile("test_disaggregation_pp.py", 140),
],
"per-commit-4-gpu-b200": [
......
import unittest
from types import SimpleNamespace
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_disaggregation_utils import TestDisaggregationBase
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
popen_launch_pd_server,
)
class TestDisaggregationHybridAttentionMamba(TestDisaggregationBase):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.model = "Qwen/Qwen3-Next-80B-A3B-Instruct"
        # Start both servers without blocking
cls.start_prefill()
cls.start_decode()
        # Block until both servers are healthy
cls.wait_server_ready(cls.prefill_url + "/health")
cls.wait_server_ready(cls.decode_url + "/health")
cls.launch_lb()
@classmethod
def start_prefill(cls):
prefill_args = [
"--trust-remote-code",
"--disaggregation-mode",
"prefill",
"--tp",
"4",
]
prefill_args += cls.transfer_backend + cls.rdma_devices
cls.process_prefill = popen_launch_pd_server(
cls.model,
cls.prefill_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=prefill_args,
)
@classmethod
def start_decode(cls):
decode_args = [
"--trust-remote-code",
"--disaggregation-mode",
"decode",
"--tp",
"4",
"--base-gpu-id",
"4",
]
decode_args += cls.transfer_backend + cls.rdma_devices
cls.process_decode = popen_launch_pd_server(
cls.model,
cls.decode_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=decode_args,
)
def test_gsm8k(self):
args = SimpleNamespace(
num_shots=5,
data_path=None,
num_questions=200,
max_new_tokens=512,
parallel=128,
host=f"http://{self.base_host}",
port=int(self.lb_port),
)
metrics = run_eval_few_shot_gsm8k(args)
print(f"Evaluation metrics: {metrics}")
self.assertGreater(metrics["accuracy"], 0.93)
if __name__ == "__main__":
unittest.main()
......@@ -42,6 +42,7 @@ class TestMamba(unittest.TestCase):
full_attention_layer_ids=full_attention_layer_ids,
enable_kvcache_transpose=False,
device=device,
mamba_pool=None,
)
assert pool._transfer_full_attention_id(global_interval - 1) == 0
assert pool._transfer_full_attention_id(2 * global_interval - 1) == 1
......@@ -173,6 +174,7 @@ class TestMamba(unittest.TestCase):
full_attention_layer_ids=full_attention_layer_ids,
enable_kvcache_transpose=False,
device=device,
mamba_pool=req_to_token_pool.mamba_pool,
)
# setup token to kv pool allocator
......