Unverified commit e2715cf8, authored by Yi Zhang and committed by GitHub
Browse files

fix mamba prefix cache leak caused by abort (#12693)

parent 74630ba3
...@@ -2520,6 +2520,9 @@ class Scheduler( ...@@ -2520,6 +2520,9 @@ class Scheduler(
if self.disaggregation_mode == DisaggregationMode.DECODE: if self.disaggregation_mode == DisaggregationMode.DECODE:
self.tree_cache.cache_finished_req(req) self.tree_cache.cache_finished_req(req)
# For mamba radix cache
if req.mamba_pool_idx is not None:
self.tree_cache.cache_finished_req(req, is_insert=False)
logger.debug(f"Abort queued request. {req.rid=}") logger.debug(f"Abort queued request. {req.rid=}")
# Delete the requests in the grammar queue # Delete the requests in the grammar queue
......
...@@ -437,7 +437,8 @@ class MambaRadixCache(BasePrefixCache): ...@@ -437,7 +437,8 @@ class MambaRadixCache(BasePrefixCache):
self.req_to_token_pool.free(req.req_pool_idx) self.req_to_token_pool.free(req.req_pool_idx)
return return
token_ids = (req.origin_input_ids + req.output_ids)[:-1] cache_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
token_ids = (req.origin_input_ids + req.output_ids)[:cache_len]
kv_indices = self.req_to_token_pool.req_to_token[ kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : len(token_ids) req.req_pool_idx, : len(token_ids)
] ]
...@@ -448,11 +449,7 @@ class MambaRadixCache(BasePrefixCache): ...@@ -448,11 +449,7 @@ class MambaRadixCache(BasePrefixCache):
# Radix Cache takes one ref in memory pool # Radix Cache takes one ref in memory pool
# insert the token_ids and kv_indices into the radix tree # insert the token_ids and kv_indices into the radix tree
# Note: the insert function already frees the overlapped kv_indices # Note: the insert function already frees the overlapped kv_indices
mamba_value = ( mamba_value = req.mamba_pool_idx.unsqueeze(-1).clone()
self.req_to_token_pool.get_mamba_indices(req.req_pool_idx)
.unsqueeze(-1)
.clone()
)
if is_insert: if is_insert:
new_prefix_len, mamba_exist = self.insert( new_prefix_len, mamba_exist = self.insert(
...@@ -469,8 +466,11 @@ class MambaRadixCache(BasePrefixCache): ...@@ -469,8 +466,11 @@ class MambaRadixCache(BasePrefixCache):
) )
mamba_exist = True mamba_exist = True
self.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=mamba_exist) if req.req_pool_idx is not None:
self.dec_lock_ref(req.last_node) self.req_to_token_pool.free(req.req_pool_idx, free_mamba_cache=mamba_exist)
self.dec_lock_ref(req.last_node)
else: # for abort case
self.req_to_token_pool.mamba_pool.free(mamba_value)
def cache_unfinished_req(self, req: Req, chunked=False) -> None: def cache_unfinished_req(self, req: Req, chunked=False) -> None:
"""Cache request when it is unfinished.""" """Cache request when it is unfinished."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment