Unverified commit e2715cf8, authored by Yi Zhang, committed by GitHub
Browse files

fix mamba prefix cache leak caused by abort (#12693)

parent 74630ba3
......@@ -2520,6 +2520,9 @@ class Scheduler(
if self.disaggregation_mode == DisaggregationMode.DECODE:
self.tree_cache.cache_finished_req(req)
# For mamba radix cache
if req.mamba_pool_idx is not None:
self.tree_cache.cache_finished_req(req, is_insert=False)
logger.debug(f"Abort queued request. {req.rid=}")
# Delete the requests in the grammar queue
......
......@@ -437,7 +437,8 @@ class MambaRadixCache(BasePrefixCache):
self.req_to_token_pool.free(req.req_pool_idx)
return
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
cache_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
token_ids = (req.origin_input_ids + req.output_ids)[:cache_len]
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids)
]
......@@ -448,11 +449,7 @@ class MambaRadixCache(BasePrefixCache):
# Radix Cache takes one ref in memory pool
# insert the token_ids and kv_indices into the radix tree
# Note: the insert function already frees the overlapped kv_indices
mamba_value = (
self.req_to_token_pool.get_mamba_indices(req.req_pool_idx)
.unsqueeze(-1)
.clone()
)
mamba_value = req.mamba_pool_idx.unsqueeze(-1).clone()
if is_insert:
new_prefix_len, mamba_exist = self.insert(
......@@ -469,8 +466,11 @@ class MambaRadixCache(BasePrefixCache):
)
mamba_exist = True
if req.req_pool_idx is not None:
self.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=mamba_exist)
self.dec_lock_ref(req.last_node)
else: # for abort case
self.req_to_token_pool.mamba_pool.free(mamba_value)
def cache_unfinished_req(self, req: Req, chunked=False) -> None:
"""Cache request when it is unfinished."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment