Unverified commit 1d7f7835, authored by cctry, committed by GitHub

Refactor kv cache free (#11351)

parent 32595146
@@ -611,8 +611,8 @@ class DecodeTransferQueue:
                 self.scheduler.stream_output(
                     [decode_req.req], decode_req.req.return_logprob
                 )
-                # unlock the kv cache or it will have memory leak
-                self.tree_cache.cache_finished_req(decode_req.req)
+                # release pre-allocated kv cache, but don't insert into the tree since it's failed
+                self.tree_cache.cache_finished_req(decode_req.req, is_insert=False)
                 indices_to_remove.add(i)
                 if self.scheduler.enable_metrics:
                     self.scheduler.metrics_collector.increment_transfer_failed_reqs()
...
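The change above is the user-visible part of the refactor: `cache_finished_req` gains an `is_insert` flag so that failed or retracted requests can release their KV slots without leaving an entry in the radix tree. Below is a minimal, self-contained sketch of that contract; `ToyAllocator`, `ToyPrefixCache`, and `ToyReq` are hypothetical stand-ins that only mirror the shape of the sglang API shown in this diff, not the real implementation.

```python
# Illustrative sketch of the is_insert contract; none of these classes exist in sglang.
from dataclasses import dataclass, field


class ToyAllocator:
    def __init__(self, num_slots: int):
        self.free_slots = list(range(num_slots))

    def alloc(self, n: int) -> list:
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return out

    def free(self, slots: list) -> None:
        self.free_slots.extend(slots)


@dataclass
class ToyReq:
    token_ids: tuple
    kv_indices: list = field(default_factory=list)


class ToyPrefixCache:
    def __init__(self, allocator: ToyAllocator):
        self.allocator = allocator
        self.tree = {}  # token prefix -> kv slots

    def cache_finished_req(self, req: ToyReq, is_insert: bool = True) -> None:
        if is_insert:
            # normal completion: keep the KV for prefix reuse; if the same
            # prefix is already cached, free the duplicate slots instead
            if req.token_ids in self.tree:
                self.allocator.free(req.kv_indices)
            else:
                self.tree[req.token_ids] = req.kv_indices
        else:
            # failed transfer / retraction: free the slots right away and
            # leave no tree entry behind (the real code keeps any slots that
            # already belong to a cached prefix)
            self.allocator.free(req.kv_indices)


allocator = ToyAllocator(8)
req = ToyReq(token_ids=(1, 2, 3), kv_indices=allocator.alloc(3))
ToyPrefixCache(allocator).cache_finished_req(req, is_insert=False)
assert len(allocator.free_slots) == 8  # everything returned, nothing cached
```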
@@ -64,6 +64,7 @@ from sglang.srt.mem_cache.common import (
     alloc_for_decode,
     alloc_for_extend,
     alloc_token_slots,
+    evict_from_tree_cache,
 )
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
@@ -1406,7 +1407,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             * self.token_to_kv_pool_allocator.page_size
         )
-        self._evict_tree_cache_if_needed(num_tokens)
+        evict_from_tree_cache(self.tree_cache, num_tokens)
         return self._is_available_size_sufficient(num_tokens)

     def retract_decode(self, server_args: ServerArgs):
@@ -1454,6 +1455,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)
+            # release memory and don't insert into the tree because we need the space instantly
             self.release_req(idx, len(sorted_indices), server_args)

         if len(retracted_reqs) == 0:
@@ -1478,39 +1480,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
         req = self.reqs[idx]
-        seq_lens_cpu = self.seq_lens_cpu.numpy()

         if server_args.disaggregation_mode == "decode":
             req.offload_kv_cache(
                 self.req_to_token_pool, self.token_to_kv_pool_allocator
             )
-        if isinstance(self.tree_cache, ChunkCache):
-            # ChunkCache does not have eviction
-            token_indices = self.req_to_token_pool.req_to_token[
-                req.req_pool_idx, : seq_lens_cpu[idx]
-            ]
-            self.token_to_kv_pool_allocator.free(token_indices)
-            self.req_to_token_pool.free(req.req_pool_idx)
-        else:
-            # TODO: apply more fine-grained retraction
-            last_uncached_pos = (
-                len(req.prefix_indices) // server_args.page_size
-            ) * server_args.page_size
-            token_indices = self.req_to_token_pool.req_to_token[
-                req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
-            ]
-            self.token_to_kv_pool_allocator.free(token_indices)
-            self.req_to_token_pool.free(req.req_pool_idx)
-
-            # release the last node
-            if self.is_hybrid:
-                self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
-            else:
-                self.tree_cache.dec_lock_ref(req.last_node)
+        # TODO (csy): for preempted requests, we may want to insert into the tree
+        self.tree_cache.cache_finished_req(req, is_insert=False)

         # NOTE(lsyin): we should use the newly evictable memory instantly.
         num_tokens = remaing_req_count * envs.SGLANG_RETRACT_DECODE_STEPS.get()
-        self._evict_tree_cache_if_needed(num_tokens)
+        evict_from_tree_cache(self.tree_cache, num_tokens)

         req.reset_for_retract()
@@ -1808,24 +1787,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             enable_overlap=self.enable_overlap,
         )

-    def _evict_tree_cache_if_needed(self, num_tokens: int):
-        if isinstance(self.tree_cache, (SWAChunkCache, ChunkCache)):
-            return
-
-        if self.is_hybrid:
-            full_available_size = self.token_to_kv_pool_allocator.full_available_size()
-            swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
-            if full_available_size < num_tokens or swa_available_size < num_tokens:
-                if self.tree_cache is not None:
-                    full_num_tokens = max(0, num_tokens - full_available_size)
-                    swa_num_tokens = max(0, num_tokens - swa_available_size)
-                    self.tree_cache.evict(full_num_tokens, swa_num_tokens)
-        else:
-            if self.token_to_kv_pool_allocator.available_size() < num_tokens:
-                if self.tree_cache is not None:
-                    self.tree_cache.evict(num_tokens)
-
     def _is_available_size_sufficient(self, num_tokens: int) -> bool:
         if self.is_hybrid:
             return (
...
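The removed `_evict_tree_cache_if_needed` above is replaced by the module-level `evict_from_tree_cache(tree_cache, num_tokens)` imported from `sglang.srt.mem_cache.common`. Its body is not part of this diff; a plausible sketch, assuming it simply hoists the removed method's logic into a free function and reads the allocator off the cache, is shown below. The import paths and the `token_to_kv_pool_allocator` attribute access are assumptions, and the hybrid check via `isinstance(..., SWARadixCache)` stands in for the old `self.is_hybrid` flag.

```python
# Plausible sketch only -- the real helper in sglang.srt.mem_cache.common is
# not shown in this diff and may differ. Import paths are assumed.
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache


def evict_from_tree_cache(tree_cache, num_tokens: int) -> None:
    # ChunkCache variants have no eviction, mirroring the removed method
    if tree_cache is None or isinstance(tree_cache, (SWAChunkCache, ChunkCache)):
        return

    allocator = tree_cache.token_to_kv_pool_allocator  # assumed attribute
    if isinstance(tree_cache, SWARadixCache):
        # hybrid setup: full-attention and sliding-window pools are tracked separately
        full_available = allocator.full_available_size()
        swa_available = allocator.swa_available_size()
        if full_available < num_tokens or swa_available < num_tokens:
            tree_cache.evict(
                max(0, num_tokens - full_available),
                max(0, num_tokens - swa_available),
            )
    elif allocator.available_size() < num_tokens:
        tree_cache.evict(num_tokens)
```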
@@ -40,7 +40,7 @@ class BasePrefixCache(ABC):
         pass

     @abstractmethod
-    def cache_finished_req(self, req: Req, **kwargs):
+    def cache_finished_req(self, req: Req, is_insert: bool = True, **kwargs):
         pass

     @abstractmethod
...
@@ -49,7 +49,7 @@ class ChunkCache(BasePrefixCache):
             last_host_node=None,
         )

-    def cache_finished_req(self, req: Req, insert: bool = True):
+    def cache_finished_req(self, req: Req, is_insert: bool = True):
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx,
             # For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
...
@@ -330,18 +330,18 @@ class RadixCache(BasePrefixCache):
         return self._insert_helper(self.root_node, key, value)

-    def cache_finished_req(self, req: Req):
+    def cache_finished_req(self, req: Req, is_insert: bool = True):
         """Cache request when it finishes."""
+        all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
         if self.disable:
             kv_indices = self.req_to_token_pool.req_to_token[
-                req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
+                req.req_pool_idx, :all_token_len
             ]
             self.token_to_kv_pool_allocator.free(kv_indices)
             self.req_to_token_pool.free(req.req_pool_idx)
             return

-        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
-        all_token_len = len(token_ids)
+        token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
         # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
         # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
         actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
@@ -354,12 +354,9 @@ class RadixCache(BasePrefixCache):
             page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
                 dtype=torch.int64, copy=True
             )
-            self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
         else:
             page_aligned_len = actual_kv_len
             page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
-            if self.is_eagle:
-                self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])

         page_aligned_token_len = (
             page_aligned_len + 1 if self.is_eagle else page_aligned_len
@@ -372,11 +369,22 @@ class RadixCache(BasePrefixCache):
             old_prefix_len -= 1

         # Radix Cache takes one ref in memory pool
-        new_prefix_len = self.insert(
-            RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
-            page_aligned_kv_indices,
-        )
-        self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
+        if is_insert:
+            new_prefix_len = self.insert(
+                RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
+                page_aligned_kv_indices,
+            )
+            # Free the duplicates that were already in the tree
+            self.token_to_kv_pool_allocator.free(
+                kv_indices[old_prefix_len:new_prefix_len]
+            )
+        else:
+            self.token_to_kv_pool_allocator.free(
+                kv_indices[old_prefix_len:page_aligned_len]
+            )
+
+        # free the unaligned tail
+        self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])

         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
...
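The length bookkeeping in the hunk above (and in the SWA variant further down) is easiest to see with concrete numbers. The following is plain arithmetic under the bigram-key convention described in the comments, not sglang code; the values are illustrative.

```python
# Worked example of the slicing above (plain arithmetic, not sglang code).
origin_input_ids = [1, 2, 3, 4]  # prompt tokens
output_ids = [5, 6, 7]           # generated tokens
page_size = 2
is_eagle = True

# the newest output token has no KV entry yet, so it is excluded
all_token_len = len(origin_input_ids) + max(len(output_ids) - 1, 0)  # 6
token_ids = (origin_input_ids + output_ids)[:all_token_len]          # [1, 2, 3, 4, 5, 6]

# bigram keys shrink the key length by one, so the KV length does too
actual_kv_len = all_token_len - 1 if is_eagle else all_token_len      # 5

# only whole pages can be inserted into the radix tree
page_aligned_len = actual_kv_len // page_size * page_size             # 4
page_aligned_token_len = page_aligned_len + 1 if is_eagle else page_aligned_len  # 5

# With is_insert=True the page-aligned slice kv_indices[:page_aligned_len] is
# offered to the tree and any overlap with an existing prefix is freed; with
# is_insert=False, kv_indices[old_prefix_len:page_aligned_len] is freed
# directly. The unaligned tail kv_indices[page_aligned_len:] is always freed.
assert (all_token_len, actual_kv_len, page_aligned_len) == (6, 5, 4)
```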
@@ -151,32 +151,37 @@ class RadixCacheCpp(BasePrefixCache):
     def total_size(self):
         return self.tree.total_size()

-    def cache_finished_req(self, req: Req):
+    def cache_finished_req(self, req: Req, is_insert: bool = True):
         """Cache request when it finishes."""
         assert req.req_pool_idx is not None
-        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+        all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
+        token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
         overall_len = len(token_ids)  # prefill + decode
         kv_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx, :overall_len]

         # NOTE: our C++ implementation don't need `token_ids` and `kv_indices` to be page-aligned
         # it will automatically align them, but length of them should be equal
         old_prefix_len = len(req.prefix_indices) // self.page_size * self.page_size
-        new_prefix_len = self._insert(RadixKey(token_ids, req.extra_key), kv_indices)
-
-        # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
-        assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
-
-        # KVCache between old & new is newly generated, but already exists in the pool
-        # we need to free this newly generated kv indices
-        if old_prefix_len < new_prefix_len:
-            self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
+        page_aligned_overall_len = overall_len // self.page_size * self.page_size
+        if is_insert:
+            new_prefix_len = self._insert(
+                RadixKey(token_ids, req.extra_key), kv_indices
+            )
+            # NOTE: kv_indices[:old_prefix_len] == req.prefix_indices
+            assert old_prefix_len <= new_prefix_len, "Wrong prefix indices"
+            # Free duplicates that were already in the pool
+            if old_prefix_len < new_prefix_len:
+                self.token_to_kv_pool.free(kv_indices[old_prefix_len:new_prefix_len])
+        else:
+            self.token_to_kv_pool.free(
+                kv_indices[old_prefix_len:page_aligned_overall_len]
+            )

         # need to free the unaligned part, since it cannot be inserted into the radix tree
-        if self.page_size != 1 and (  # unaligned tail only exists when page_size > 1
-            (unaligned_len := overall_len % self.page_size) > 0
-        ):
+        if page_aligned_overall_len < overall_len:
             # NOTE: sglang PagedAllocator support unaligned free (which will automatically align it)
-            self.token_to_kv_pool.free(kv_indices[overall_len - unaligned_len :])
+            self.token_to_kv_pool.free(kv_indices[page_aligned_overall_len:])

         # Remove req slot release the cache lock
         self.dec_lock_ref(req.last_node)
...
@@ -217,10 +217,12 @@ class LMCRadixCache(RadixCache):
         return base_res

-    def cache_finished_req(self, req: "Req") -> None:  # type: ignore[override]
+    def cache_finished_req(self, req: "Req", is_insert: bool = True) -> None:  # type: ignore[override]
         """On request completion, insert device KV into radix and store to LMCache."""
-        super().cache_finished_req(req)
+        super().cache_finished_req(req, is_insert=is_insert)
+        if not is_insert:
+            return

         token_ids = (req.origin_input_ids + req.output_ids)[:-1]
         kv_indices = self.req_to_token_pool.req_to_token[
...
@@ -427,19 +427,18 @@ class SWARadixCache(BasePrefixCache):
         return self._insert_helper(self.root_node, key, value, prev_prefix_len)

-    def cache_finished_req(self, req: Req) -> None:
+    def cache_finished_req(self, req: Req, is_insert: bool = True) -> None:
         """Cache request when it finishes."""
+        all_token_len = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
         if self.disable:
             kv_indices = self.req_to_token_pool.req_to_token[
-                req.req_pool_idx,
-                : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
+                req.req_pool_idx, :all_token_len
             ]
             self.token_to_kv_pool_allocator.free(kv_indices)
             self.req_to_token_pool.free(req.req_pool_idx)
             return

-        token_ids = (req.origin_input_ids + req.output_ids)[:-1]
-        all_token_len = len(token_ids)
+        token_ids = (req.origin_input_ids + req.output_ids)[:all_token_len]
         # For EAGLE radix cache, we will convert the key to bigram key, e.g. [1,2,3,4] -> [(1,2), (2,3), (3,4)], the length will -1. ((len([(1,2), (2,3), (3,4)]) = len([1,2,3,4]) - 1))
         # So for the corresponding kv length should also -1. Then we get the actual_kv_len, and use it to do later calculation and slicing.
         actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
@@ -452,7 +451,6 @@ class SWARadixCache(BasePrefixCache):
             page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
                 dtype=torch.int64, copy=True
             )
-            self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
         else:
             page_aligned_len = actual_kv_len
             page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
@@ -472,11 +470,19 @@ class SWARadixCache(BasePrefixCache):
         # Radix Cache takes one ref in memory pool
         # insert the token_ids and kv_indices into the radix tree
         # Note: the insert function already frees the overlapped kv_indices
-        new_prefix_len = self.insert(
-            RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
-            page_aligned_kv_indices,
-            old_prefix_len,
-        )
+        if is_insert:
+            new_prefix_len = self.insert(
+                RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
+                page_aligned_kv_indices,
+                old_prefix_len,
+            )
+        else:
+            self.token_to_kv_pool_allocator.free(
+                kv_indices[old_prefix_len:page_aligned_len]
+            )
+
+        # free the unaligned tail
+        self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])

         # Remove req slot release the cache lock
         self.req_to_token_pool.free(req.req_pool_idx)
...