Unverified Commit a55cf530 authored by Yi Zhang, committed by GitHub

[Feature] Support mamba radix cache v0 (#11214)


Co-authored-by: hanming-lu <hanming@x.ai>
Co-authored-by: hzh0425 <hzh0425@apache.org>
Co-authored-by: thalahors <ericalcaide1@gmail.com>
parent 19ba16aa
......@@ -65,7 +65,8 @@ from sglang.srt.mem_cache.common import (
alloc_for_extend,
alloc_token_slots,
)
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
......@@ -522,6 +523,7 @@ class Req:
# Memory pool info
self.req_pool_idx: Optional[int] = None
self.mamba_pool_idx: Optional[torch.Tensor] = None # shape: (1,)
# Check finish
self.tokenizer = None
......@@ -727,7 +729,12 @@ class Req:
self.last_host_node,
self.host_hit_length,
) = tree_cache.match_prefix(
key=RadixKey(token_ids=token_ids, extra_key=self.extra_key)
key=RadixKey(token_ids=token_ids, extra_key=self.extra_key),
**(
{"req": self, "cow_mamba": True}
if isinstance(tree_cache, MambaRadixCache)
else {}
),
)
self.last_matched_prefix_len = len(self.prefix_indices)
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
......@@ -877,6 +884,7 @@ class Req:
self.extend_logprob_start_len = 0
self.is_chunked = 0
self.req_pool_idx = None
self.mamba_pool_idx = None
self.already_computed = 0
def offload_kv_cache(self, req_to_token_pool, token_to_kv_pool_allocator):
......@@ -1071,6 +1079,27 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def is_empty(self):
return len(self.reqs) == 0
def alloc_req_slots(self, num_reqs: int, reqs: Optional[List[Req]] = None):
if isinstance(self.req_to_token_pool, HybridReqToTokenPool):
mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
if mamba_available_size < num_reqs:
if isinstance(self.tree_cache, MambaRadixCache):
mamba_num = max(0, num_reqs - mamba_available_size)
self.tree_cache.evict_mamba(mamba_num)
req_pool_indices = self.req_to_token_pool.alloc(num_reqs, reqs)
else:
req_pool_indices = self.req_to_token_pool.alloc(num_reqs)
if req_pool_indices is None:
raise RuntimeError(
"alloc_req_slots ran out of memory. "
"Please set a smaller value for `--max-running-requests`. "
f"{self.req_to_token_pool.available_size()=}, "
f"{num_reqs=}, "
)
return req_pool_indices
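# A minimal, self-contained sketch of the shortfall logic in alloc_req_slots
# above; all numbers are illustrative, not taken from the diff.
num_reqs = 5              # request slots the incoming batch needs
mamba_available_size = 2  # free slots left in the mamba pool
# Only the shortfall is evicted from the MambaRadixCache, so cached prefixes
# survive whenever the pool already covers the batch.
mamba_num = max(0, num_reqs - mamba_available_size)
assert mamba_num == 3  # evict_mamba(3) would free exactly the missing slots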
def allocate_for_eagle_v2(self):
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool
......
......@@ -27,6 +27,7 @@ import torch
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache, RadixKey, TreeNode
from sglang.srt.server_args import ServerArgs
......@@ -357,6 +358,7 @@ class PrefillAdder:
self.is_hybrid = isinstance(
self.token_to_kv_pool_allocator, SWATokenToKVPoolAllocator
)
self.is_hybrid_gdn_cache = isinstance(self.tree_cache, MambaRadixCache)
self.priority_scheduling_preemption_threshold = (
priority_scheduling_preemption_threshold
......@@ -380,6 +382,11 @@ class PrefillAdder:
self.token_to_kv_pool_allocator.swa_available_size()
+ self.tree_cache.swa_evictable_size(),
)
elif self.is_hybrid_gdn_cache:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.full_evictable_size()
)
else:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
......@@ -397,6 +404,11 @@ class PrefillAdder:
self.token_to_kv_pool_allocator.swa_available_size()
+ self.tree_cache.swa_evictable_size(),
)
elif self.is_hybrid_gdn_cache:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.full_evictable_size()
)
else:
available_and_evictable = (
self.token_to_kv_pool_allocator.available_size()
......
......@@ -146,6 +146,7 @@ from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
......@@ -470,6 +471,10 @@ class Scheduler(
# Hybrid memory pool
self.is_hybrid = self.tp_worker.is_hybrid
self.is_hybrid_gdn = (
self.tp_worker.worker.model_runner.hybrid_gdn_config is not None
)
if self.is_hybrid:
self.sliding_window_size = self.tp_worker.sliding_window_size
self.full_tokens_per_layer, self.swa_tokens_per_layer = (
......@@ -816,6 +821,16 @@ class Scheduler(
disable=server_args.disable_radix_cache,
is_eagle=self.spec_algorithm.is_eagle(),
)
elif self.is_hybrid_gdn:
assert (
self.server_args.disaggregation_mode == "null"
), "Hybrid GDN mode does not support disaggregation yet"
self.tree_cache = MambaRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
page_size=self.page_size,
disable=server_args.disable_radix_cache,
)
elif server_args.enable_lmcache:
from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
LMCRadixCache,
......@@ -1689,6 +1704,25 @@ class Scheduler(
f"{self.full_tokens_per_layer=}, {full_available_size=}, {full_evictable_size=}, {self.tree_cache.full_protected_size()=}\n"
f"{self.swa_tokens_per_layer=}, {swa_available_size=}, {swa_evictable_size=}, {self.tree_cache.swa_protected_size()=}\n"
)
elif self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache):
(
full_num_used,
mamba_num_used,
_,
_,
full_available_size,
full_evictable_size,
mamba_available_size,
mamba_evictable_size,
) = self._get_mamba_token_info()
memory_leak = (
full_num_used != self.tree_cache.full_protected_size()
or mamba_num_used != self.tree_cache.mamba_protected_size()
)
token_msg = (
f"{full_available_size=}, {full_evictable_size=}, {self.token_to_kv_pool_allocator.size=}, {self.tree_cache.full_protected_size()=}\n"
f"{mamba_available_size=}, {mamba_evictable_size=}, {self.req_to_token_pool.mamba_pool.size=}, {self.tree_cache.mamba_protected_size()=}\n"
)
else:
_, _, available_size, evictable_size = self._get_token_info()
protected_size = self.tree_cache.protected_size()
......@@ -1739,6 +1773,17 @@ class Scheduler(
) = self._get_swa_token_info()
num_used = max(full_num_used, swa_num_used)
token_usage = max(full_token_usage, swa_token_usage)
elif self.is_hybrid_gdn:
(
num_used,
_,
token_usage,
_,
_,
_,
_,
_,
) = self._get_mamba_token_info()
else:
num_used, token_usage, _, _ = self._get_token_info()
num_running_reqs = len(self.running_batch.reqs)
......@@ -1766,7 +1811,9 @@ class Scheduler(
self._publish_kv_events()
def check_tree_cache(self):
if self.is_hybrid and isinstance(self.tree_cache, SWARadixCache):
if (self.is_hybrid and isinstance(self.tree_cache, SWARadixCache)) or (
self.is_hybrid_gdn and isinstance(self.tree_cache, MambaRadixCache)
):
self.tree_cache.sanity_check()
def _get_token_info(self):
......@@ -1776,6 +1823,35 @@ class Scheduler(
token_usage = num_used / self.max_total_num_tokens
return num_used, token_usage, available_size, evictable_size
def _get_mamba_token_info(self):
is_radix_tree = isinstance(self.tree_cache, MambaRadixCache)
full_available_size = self.token_to_kv_pool_allocator.available_size()
full_evictable_size = (
self.tree_cache.full_evictable_size() if is_radix_tree else 0
)
mamba_available_size = self.req_to_token_pool.mamba_pool.available_size()
mamba_evictable_size = (
self.tree_cache.mamba_evictable_size() if is_radix_tree else 0
)
full_num_used = self.token_to_kv_pool_allocator.size - (
full_available_size + full_evictable_size
)
mamba_num_used = self.req_to_token_pool.mamba_pool.size - (
mamba_available_size + mamba_evictable_size
)
full_token_usage = full_num_used / self.token_to_kv_pool_allocator.size
mamba_usage = mamba_num_used / self.req_to_token_pool.mamba_pool.size
return (
full_num_used,
mamba_num_used,
full_token_usage,
mamba_usage,
full_available_size,
full_evictable_size,
mamba_available_size,
mamba_evictable_size,
)
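# A self-contained worked example of the bookkeeping in _get_mamba_token_info,
# with toy sizes: a slot counts as "used" when it is neither free in the
# allocator nor evictable from the radix tree.
pool_size = 100
available_size = 30   # free slots in the allocator
evictable_size = 20   # cached-but-unreferenced slots in the tree
num_used = pool_size - (available_size + evictable_size)
token_usage = num_used / pool_size
assert num_used == 50 and token_usage == 0.5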
def _get_swa_token_info(self):
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
full_evictable_size = self.tree_cache.full_evictable_size()
......
......@@ -104,6 +104,23 @@ class SchedulerMetricsMixin:
f"full token usage: {full_token_usage:.2f}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
elif self.is_hybrid_gdn:
(
full_num_used,
_,
full_token_usage,
mamba_usage,
_,
_,
_,
_,
) = self._get_mamba_token_info()
num_used = full_num_used
token_usage = full_token_usage
token_usage_msg = (
f"full token usage: {full_token_usage:.2f}, "
f"mamba usage: {mamba_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_usage_msg = f"token usage: {token_usage:.2f}, "
......@@ -203,6 +220,25 @@ class SchedulerMetricsMixin:
f"#swa token: {swa_num_used}, "
f"swa token usage: {swa_token_usage:.2f}, "
)
elif self.is_hybrid_gdn:
(
full_num_used,
mamba_used,
full_token_usage,
mamba_usage,
_,
_,
_,
_,
) = self._get_mamba_token_info()
num_used = full_num_used
token_usage = full_token_usage
token_usage_msg = (
f"#full token: {full_num_used}, "
f"full token usage: {full_token_usage:.2f}, "
f"mamba num: {mamba_used}, "
f"mamba usage: {mamba_usage:.2f}, "
)
else:
num_used, token_usage, _, _ = self._get_token_info()
token_usage_msg = f"#token: {num_used}, token usage: {token_usage:.2f}, "
......
This diff is collapsed.
......@@ -190,6 +190,7 @@ class MambaPool:
)
logger.info(
f"Mamba Cache is allocated. "
f"max_mamba_cache_size: {size}, "
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
......@@ -199,11 +200,13 @@ class MambaPool:
self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
logger.info(
f"Mamba Cache is allocated. "
f"max_mamba_cache_size: {size}, "
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
)
self.size = size
self.free_slots = list(range(size))
self.device = device
self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
......@@ -216,7 +219,7 @@ class MambaPool:
def available_size(self):
return len(self.free_slots)
def alloc(self, need_size: int) -> Optional[List[int]]:
def alloc(self, need_size: int) -> Optional[torch.Tensor]:
if need_size > len(self.free_slots):
return None
......@@ -225,17 +228,30 @@ class MambaPool:
return select_index
def free(self, free_index: Union[int, List[int]]):
if isinstance(free_index, (int,)):
self.free_slots.append(free_index)
else:
self.free_slots.extend(free_index)
def free(self, free_index: torch.Tensor):
if free_index.numel() == 0:
return
self.free_slots = torch.cat((self.free_slots, free_index))
self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
:, free_index
] = 0
def clear(self):
self.free_slots = list(range(self.size))
self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
def copy_from(self, src_index: torch.Tensor, dst_index: torch.Tensor):
self.mamba_cache.conv[:, dst_index] = self.mamba_cache.conv[:, src_index]
self.mamba_cache.temporal[:, dst_index] = self.mamba_cache.temporal[
:, src_index
]
def fork_from(self, src_index: torch.Tensor) -> Optional[torch.Tensor]:
dst_index = self.alloc(1)
if dst_index is None:
return None
self.copy_from(src_index, dst_index)
return dst_index
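# A hedged usage sketch of MambaPool.fork_from (copy-on-write). Building a
# real pool needs Mamba2CacheParams, as in the unit test at the bottom of this
# diff; here `pool` is assumed to be an already-constructed MambaPool with at
# least two free slots.
src_index = pool.alloc(1)             # slot holding a cached prefix state
dst_index = pool.fork_from(src_index)
if dst_index is None:
    # Pool exhausted: evict mamba entries (MambaRadixCache.evict_mamba) first,
    # then retry the fork.
    pass
else:
    # dst_index holds an independent copy of the conv/temporal states, so the
    # new request can mutate it without corrupting the cached prefix.
    assert dst_index.numel() == 1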
class HybridReqToTokenPool(ReqToTokenPool):
......@@ -245,6 +261,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
self,
*,
size: int,
mamba_size: int,
max_context_len: int,
device: str,
enable_memory_saver: bool,
......@@ -259,7 +276,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
)
self.mamba_pool = MambaPool(
size=size,
size=mamba_size,
cache_params=cache_params,
device=device,
speculative_num_draft_tokens=speculative_num_draft_tokens,
......@@ -271,9 +288,6 @@ class HybridReqToTokenPool(ReqToTokenPool):
size, dtype=torch.int32, device=self.device
)
self.rid_to_mamba_index_mapping: Dict[str, int] = {}
self.mamba_index_to_rid_mapping: Dict[int, str] = {}
# For a chunked prefill request, we do not need to allocate a new mamba cache
# slot; we can reuse the one that is already allocated.
def alloc(
......@@ -285,14 +299,14 @@ class HybridReqToTokenPool(ReqToTokenPool):
mamba_index = []
for req in reqs:
rid = req.rid
if rid in self.rid_to_mamba_index_mapping:
mid = self.rid_to_mamba_index_mapping[rid]
elif (mid := self.mamba_pool.alloc(1)) is not None:
mid = mid[0]
self.rid_to_mamba_index_mapping[rid] = mid
self.mamba_index_to_rid_mapping[mid] = rid
mamba_index.append(mid)
mid = None
if req.mamba_pool_idx is not None: # reuse the slot preserved by the radix cache
mid = req.mamba_pool_idx
elif (allocated := self.mamba_pool.alloc(1)) is not None:
mid = allocated[0]
req.mamba_pool_idx = mid
if mid is not None:
mamba_index.append(mid)
assert len(select_index) == len(
mamba_index
), "Not enough space for the mamba cache; try increasing --max-mamba-cache-size."
......@@ -313,17 +327,12 @@ class HybridReqToTokenPool(ReqToTokenPool):
# For chunked prefill, we cannot free the mamba cache here; later chunks still need it.
def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
if isinstance(free_index, (int,)):
free_index = [free_index]
super().free(free_index)
if free_mamba_cache:
mamba_index = self.req_index_to_mamba_index_mapping[free_index]
mamba_index_list = mamba_index.tolist()
if isinstance(mamba_index_list, int):
mamba_index_list = [mamba_index_list]
self.mamba_pool.free(mamba_index_list)
for mid in mamba_index_list:
rid = self.mamba_index_to_rid_mapping[mid]
self.mamba_index_to_rid_mapping.pop(mid)
self.rid_to_mamba_index_mapping.pop(rid)
self.mamba_pool.free(mamba_index)
def clear(self):
super().clear()
......
......@@ -191,6 +191,9 @@ SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
# Detect stragger ranks in model loading
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
# The ratio of mamba cache pool size to max_running_requests; it is safe when larger than 2 (yizhang2077)
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
logger = logging.getLogger(__name__)
......@@ -382,26 +385,10 @@ class ModelRunner:
if architectures and not any("Llama4" in arch for arch in architectures):
self.is_hybrid = self.model_config.is_hybrid = True
if config := self.mambaish_config:
if config := self.mamba2_config:
class_name = config.__class__.__name__
logger.warning(f"{class_name} model detected, disable radix cache")
self.server_args.disable_radix_cache = True
if self.server_args.max_mamba_cache_size is None:
if self.server_args.max_running_requests is not None:
self.server_args.max_mamba_cache_size = (
self.server_args.max_running_requests
)
else:
self.server_args.max_mamba_cache_size = 512
if self.hybrid_gdn_config is not None:
self.server_args.max_mamba_cache_size = (
self.server_args.max_mamba_cache_size
// (
self.server_args.dp_size
if self.server_args.enable_dp_attention
else 1
)
)
# For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
# models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
......@@ -1330,15 +1317,60 @@ class ModelRunner:
rest_memory = available_gpu_memory - total_gpu_memory * (
1 - self.mem_fraction_static
)
if config := self.mambaish_config:
rest_memory -= (
self.server_args.max_mamba_cache_size
* config.mamba2_cache_params.mamba_cache_per_req
/ (1 << 30)
)
if self.mambaish_config is not None:
rest_memory = self.handle_max_mamba_cache(rest_memory)
max_num_token = int(rest_memory * (1 << 30) // cell_size)
return max_num_token
def handle_max_mamba_cache(self, total_rest_memory: float) -> float:
config = self.mambaish_config
server_args = self.server_args
assert config is not None
speculative_ratio = (
0
if server_args.speculative_num_draft_tokens is None
else server_args.speculative_num_draft_tokens
)
if (
server_args.disable_radix_cache
or config.mamba2_cache_params.mamba_cache_per_req == 0
):
# With the radix cache disabled, set max_mamba_cache_size from max_running_requests.
if server_args.max_mamba_cache_size is None:
if server_args.max_running_requests is not None:
server_args.max_mamba_cache_size = server_args.max_running_requests
else:
server_args.max_mamba_cache_size = 512
else:
# Split the remaining memory between mamba state and the full KV cache by
# solving the equations:
# 1. mamba_state_memory + full_kv_cache_memory == total_rest_memory
# 2. mamba_state_memory / full_kv_cache_memory == server_args.mamba_full_memory_ratio
mamba_state_memory_raw = (
total_rest_memory
* server_args.mamba_full_memory_ratio
/ (1 + server_args.mamba_full_memory_ratio)
)
# calculate the max_mamba_cache_size based on the given total mamba memory
server_args.max_mamba_cache_size = int(
(mamba_state_memory_raw * (1 << 30))
// config.mamba2_cache_params.mamba_cache_per_req
// (1 + speculative_ratio)
)
if self.hybrid_gdn_config is not None:
server_args.max_mamba_cache_size = server_args.max_mamba_cache_size // (
server_args.dp_size if server_args.enable_dp_attention else 1
)
mamba_state_memory = (
server_args.max_mamba_cache_size
* config.mamba2_cache_params.mamba_cache_per_req
* (1 + speculative_ratio)
/ (1 << 30)
)
return total_rest_memory - mamba_state_memory
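# A self-contained numeric check of the split solved above; the numbers are
# illustrative. With 12 GB remaining and mamba_full_memory_ratio = 0.2, the
# two equations give the mamba state 2 GB and the full KV cache 10 GB.
total_rest_memory = 12.0        # GB left for dynamic state
mamba_full_memory_ratio = 0.2   # mamba_state_memory / full_kv_cache_memory
mamba_state_memory = (
    total_rest_memory * mamba_full_memory_ratio / (1 + mamba_full_memory_ratio)
)
full_kv_cache_memory = total_rest_memory - mamba_state_memory
assert abs(mamba_state_memory - 2.0) < 1e-9
assert abs(mamba_state_memory / full_kv_cache_memory - 0.2) < 1e-9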
@property
def hybrid_gdn_config(self):
config = self.model_config.hf_config
......@@ -1511,8 +1543,16 @@ class ModelRunner:
),
4096,
)
if self.mambaish_config is not None:
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
ratio = (
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO
if not self.server_args.disable_radix_cache
else 1
)
max_num_reqs = min(
max_num_reqs, self.server_args.max_mamba_cache_size // ratio
)
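# A self-contained check of the cap above, with toy numbers. With the radix
# cache enabled, cached prefixes can pin mamba slots beyond the running
# requests, so the pool is over-provisioned by the ratio constant.
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
max_mamba_cache_size = 512   # illustrative pool size
disable_radix_cache = False
ratio = 1 if disable_radix_cache else MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO
max_num_reqs = min(4096, max_mamba_cache_size // ratio)
assert max_num_reqs == 170  # 512 // 3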
if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
if self.is_draft_worker:
......@@ -1595,6 +1635,7 @@ class ModelRunner:
elif config := self.mambaish_config:
self.req_to_token_pool = HybridReqToTokenPool(
size=max_num_reqs,
mamba_size=self.server_args.max_mamba_cache_size,
max_context_len=self.model_config.context_len
+ extra_max_context_len,
device=self.device,
......
......@@ -362,6 +362,7 @@ class ServerArgs:
# Mamba cache
max_mamba_cache_size: Optional[int] = None
mamba_ssm_dtype: str = "float32"
mamba_full_memory_ratio: float = 0.2
# Hierarchical cache
enable_hierarchical_cache: bool = False
......@@ -2433,6 +2434,12 @@ class ServerArgs:
choices=["float32", "bfloat16"],
help="The data type of the SSM states in mamba cache.",
)
parser.add_argument(
"--mamba-full-memory-ratio",
type=float,
default=ServerArgs.mamba_full_memory_ratio,
help="The ratio of mamba state memory to full kv cache memory.",
)
# Args for multi-item-scoring
parser.add_argument(
"--multi-item-scoring-delimiter",
......
......@@ -84,6 +84,7 @@ suites = {
TestFile("test_io_struct.py", 8),
TestFile("test_jinja_template_utils.py", 1),
TestFile("test_logprobs.py", 55),
TestFile("test_mamba_unittest.py", 4),
TestFile("test_metrics.py", 32),
TestFile("test_metrics_utils.py", 1),
TestFile("test_mla.py", 167),
......
import os
import unittest
import torch
from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, HybridReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.sampling.sampling_params import SamplingParams
class TestMamba(unittest.TestCase):
@classmethod
def setUpClass(cls):
pass
@classmethod
def tearDownClass(cls):
pass
def test_hybrid_linear_kv_pool(self):
size = 16
head_num = 2
head_dim = 256
num_layers = 48
global_interval = 4
dtype = torch.bfloat16
device = "cuda"
full_attention_layer_ids = [
i for i in range(global_interval - 1, num_layers, global_interval)
]
pool = HybridLinearKVPool(
size=size,
dtype=dtype,
page_size=1,
head_num=head_num,
head_dim=head_dim,
full_attention_layer_ids=full_attention_layer_ids,
enable_kvcache_transpose=False,
device=device,
)
assert pool._transfer_full_attention_id(global_interval - 1) == 0
assert pool._transfer_full_attention_id(2 * global_interval - 1) == 1
with self.assertRaises(ValueError) as context:
pool._transfer_full_attention_id(1)
self.assertIn(
"layer_id=1 not in full attention layers:", str(context.exception)
)
def test_mamba_pool(self):
max_num_reqs = 10
mamba_cache_size = 20
max_context_len = 128
device = "cuda"
global_interval = 4
num_layers = 48
full_attention_layer_ids = [
i for i in range(global_interval - 1, num_layers, global_interval)
]
mamba_layers = [
i for i in range(num_layers) if i not in full_attention_layer_ids
]
shape = Mamba2StateShape.create(
tp_world_size=1,
intermediate_size=4096,
n_groups=16,
num_heads=32,
head_dim=128,
state_size=128,
conv_kernel=4,
)
os.environ["SGLANG_MAMBA_SSM_DTYPE"] = "bfloat16"
mamba2_cache_params = Mamba2CacheParams(shape=shape, layers=mamba_layers)
req_to_token_pool = HybridReqToTokenPool(
size=max_num_reqs,
mamba_size=mamba_cache_size,
max_context_len=max_context_len,
device=device,
enable_memory_saver=False,
cache_params=mamba2_cache_params,
speculative_num_draft_tokens=3,
)
assert req_to_token_pool.available_size() == max_num_reqs
assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size
sampling_params = SamplingParams(
temperature=0,
max_new_tokens=1,
)
req = Req(
rid=0,
origin_input_text="",
origin_input_ids=[],
sampling_params=sampling_params,
)
# alloc req
req_index = req_to_token_pool.alloc(1, [req])
assert req_to_token_pool.available_size() == max_num_reqs - 1
assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size - 1
# free req
req_to_token_pool.free(req_index)
assert req_to_token_pool.available_size() == max_num_reqs
assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size
# alloc req without free mamba cache
req.mamba_pool_idx = None
req_index = req_to_token_pool.alloc(1, [req])
req_to_token_pool.free(req_index, free_mamba_cache=False)
assert req_to_token_pool.available_size() == max_num_reqs
assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size - 1
# alloc again
req_index = req_to_token_pool.alloc(1, [req])
assert req_to_token_pool.available_size() == max_num_reqs - 1
assert req_to_token_pool.mamba_pool.available_size() == mamba_cache_size - 1
def test_mamba_radix_cache_1(self):
# kv cache
size = 128
dtype = torch.bfloat16
head_num = 2
head_dim = 256
num_layers = 48
global_interval = 4
max_num_reqs = 10
mamba_cache_size = 20
max_context_len = 128
device = "cuda"
full_attention_layer_ids = [
i for i in range(global_interval - 1, num_layers, global_interval)
]
# mamba
mamba_layers = [
i for i in range(num_layers) if i not in full_attention_layer_ids
]
os.environ["SGLANG_MAMBA_SSM_DTYPE"] = "bfloat16"
shape = Mamba2StateShape.create(
tp_world_size=1,
intermediate_size=4096,
n_groups=16,
num_heads=32,
head_dim=128,
state_size=128,
conv_kernel=4,
)
mamba2_cache_params = Mamba2CacheParams(shape=shape, layers=mamba_layers)
req_to_token_pool = HybridReqToTokenPool(
size=max_num_reqs,
mamba_size=mamba_cache_size,
max_context_len=max_context_len,
device=device,
enable_memory_saver=False,
cache_params=mamba2_cache_params,
speculative_num_draft_tokens=3,
)
# setup kv pool
pool = HybridLinearKVPool(
size=size,
dtype=dtype,
page_size=1,
head_num=head_num,
head_dim=head_dim,
full_attention_layer_ids=full_attention_layer_ids,
enable_kvcache_transpose=False,
device=device,
)
# setup token to kv pool allocator
allocator = TokenToKVPoolAllocator(
size=size,
dtype=dtype,
device=device,
kvcache=pool,
need_sort=False,
)
# setup radix cache
tree = MambaRadixCache(
req_to_token_pool=req_to_token_pool,
token_to_kv_pool_allocator=allocator,
page_size=1,
disable=False,
)
def make_dummy_req():
sampling_params = SamplingParams(
temperature=0,
max_new_tokens=1,
)
req = Req(
rid=0,
origin_input_text="",
origin_input_ids=[],
sampling_params=sampling_params,
)
req_to_token_pool.alloc(1, reqs=[req])
return req
mamba_pool = req_to_token_pool.mamba_pool
# test
print(
f"[Start] allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}"
)
req1 = make_dummy_req()
req1_token_ids, req1_kv_indices = [1, 2, 3], allocator.alloc(3)
assert len(req1_token_ids) == len(req1_kv_indices)
print(
f"req1: inserting, req1_token_ids: {req1_token_ids}, req1_kv_indices: {req1_kv_indices}"
)
prefix_len = tree.insert(
RadixKey(req1_token_ids), req1_kv_indices, req1.mamba_pool_idx.unsqueeze(0)
)
print(
f"req1: prefix_len: {prefix_len}, allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}"
)
req2 = make_dummy_req()
req2_token_ids, req2_kv_indices = [1, 2, 3, 4, 5, 6, 7], allocator.alloc(7)
assert len(req2_token_ids) == len(req2_kv_indices)
print(
f"req2: inserting, req2_token_ids: {req2_token_ids}, req2_kv_indices: {req2_kv_indices}"
)
prefix_len = tree.insert(
RadixKey(req2_token_ids), req2_kv_indices, req2.mamba_pool_idx.unsqueeze(0)
)
print(
f"req2: prefix_len: {prefix_len}, allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}"
)
req3 = make_dummy_req()
req3_token_ids, req3_kv_indices = [10, 11, 12], allocator.alloc(3)
assert len(req3_token_ids) == len(req3_kv_indices)
print(
f"req3: inserting, req3_token_ids: {req3_token_ids}, req3_kv_indices: {req3_kv_indices}"
)
prefix_len = tree.insert(
RadixKey(req3_token_ids), req3_kv_indices, req3.mamba_pool_idx.unsqueeze(0)
)
print(
f"req3: prefix_len: {prefix_len}, allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}"
)
req4 = make_dummy_req()
req4_token_ids, req4_kv_indices = [1, 2, 3, 4, 5, 60, 70], allocator.alloc(7)
assert len(req4_token_ids) == len(req4_kv_indices)
print(
f"req4: inserting, req4_token_ids: {req4_token_ids}, req4_kv_indices: {req4_kv_indices}"
)
prefix_len = tree.insert(
RadixKey(req4_token_ids), req4_kv_indices, req4.mamba_pool_idx.unsqueeze(0)
)
print(
f"req4: prefix_len: {prefix_len}, allocator mamba available size: {mamba_pool.available_size()}, full available size: {allocator.available_size()}"
)
tree.pretty_print()
full_num_tokens = 1
print(f"evicting {full_num_tokens} full token")
tree.evict(full_num_tokens=full_num_tokens)
tree.pretty_print()
mamba_num = 1
print(f"evicting {mamba_num} mamba")
tree.evict_mamba(mamba_num=mamba_num)
tree.pretty_print()
req5_token_ids = [1, 2, 3, 4, 5]
result = tree.match_prefix(RadixKey(req5_token_ids))
kv_indices, last_node = result.device_indices, result.last_device_node
print(
f"req5: token_ids: {req5_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
)
assert len(kv_indices) == 0
req6_token_ids = [1, 2, 3, 4, 5, 60, 70]
result = tree.match_prefix(RadixKey(req6_token_ids))
kv_indices, last_node = result.device_indices, result.last_device_node
print(
f"req6: token_ids: {req6_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
)
assert len(kv_indices) == 7
assert len(last_node.key) == 2
req7_token_ids = [1, 2, 3, 4, 5, 6, 7]
result = tree.match_prefix(RadixKey(req7_token_ids))
kv_indices, last_node = result.device_indices, result.last_device_node
print(
f"req7: token_ids: {req7_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
)
assert len(kv_indices) == 7
assert len(last_node.key) == 2
mamba_num = 1
print(f"evicting {mamba_num} mamba")
tree.evict_mamba(mamba_num=mamba_num)
tree.pretty_print()
req8_token_ids = [1, 2, 3, 4, 5, 60, 70]
result = tree.match_prefix(RadixKey(req8_token_ids))
kv_indices, last_node = result.device_indices, result.last_device_node
print(
f"req8: token_ids: {req8_token_ids}, matched kv_indices: {kv_indices}, last_node.key: {last_node.key}"
)
assert len(kv_indices) == 0
assert len(last_node.key) == 0
req9_token_ids = [1, 2, 3, 4, 5, 6, 7]
req9 = make_dummy_req()
result = tree.match_prefix(
RadixKey(req9_token_ids), **({"req": req9, "cow_mamba": True})
)
kv_indices, last_node = result.device_indices, result.last_device_node
assert req9.mamba_pool_idx is not None
assert torch.all(
mamba_pool.mamba_cache.conv[:, req9.mamba_pool_idx]
== mamba_pool.mamba_cache.conv[:, last_node.mamba_value]
)
assert torch.all(
mamba_pool.mamba_cache.temporal[:, req9.mamba_pool_idx]
== mamba_pool.mamba_cache.temporal[:, last_node.mamba_value]
)
if __name__ == "__main__":
unittest.main()