Unverified Commit f05c6873 authored by Teng Ma's avatar Teng Ma Committed by GitHub
Browse files

[HiCache] Clear kvcache in storage backend with fastAPI (#9750)


Co-authored-by: default avatarhzh0425 <hzh0425@apache.org>
parent 9a0d0b75
...@@ -480,6 +480,16 @@ async def flush_cache(): ...@@ -480,6 +480,16 @@ async def flush_cache():
) )
@app.api_route("/clear_hicache_storage_backend", methods=["GET", "POST"])
async def clear_hicache_storage_backend():
"""Clear the hierarchical cache storage backend."""
ret = await _global_state.tokenizer_manager.clear_hicache_storage()
return Response(
content="Hierarchical cache storage backend cleared.\n",
status_code=200 if ret.success else HTTPStatus.BAD_REQUEST,
)
@app.api_route("/start_profile", methods=["GET", "POST"]) @app.api_route("/start_profile", methods=["GET", "POST"])
async def start_profile_async(obj: Optional[ProfileReqInput] = None): async def start_profile_async(obj: Optional[ProfileReqInput] = None):
"""Start profiling.""" """Start profiling."""
......
...@@ -814,6 +814,16 @@ class BatchEmbeddingOut: ...@@ -814,6 +814,16 @@ class BatchEmbeddingOut:
cached_tokens: List[int] cached_tokens: List[int]
@dataclass
class ClearHiCacheReqInput:
pass
@dataclass
class ClearHiCacheReqOutput:
success: bool
@dataclass @dataclass
class FlushCacheReqInput: class FlushCacheReqInput:
pass pass
......
...@@ -69,6 +69,8 @@ from sglang.srt.managers.io_struct import ( ...@@ -69,6 +69,8 @@ from sglang.srt.managers.io_struct import (
AbortReq, AbortReq,
BatchTokenizedEmbeddingReqInput, BatchTokenizedEmbeddingReqInput,
BatchTokenizedGenerateReqInput, BatchTokenizedGenerateReqInput,
ClearHiCacheReqInput,
ClearHiCacheReqOutput,
CloseSessionReqInput, CloseSessionReqInput,
ExpertDistributionReq, ExpertDistributionReq,
ExpertDistributionReqOutput, ExpertDistributionReqOutput,
...@@ -515,6 +517,7 @@ class Scheduler( ...@@ -515,6 +517,7 @@ class Scheduler(
(BatchTokenizedGenerateReqInput, self.handle_batch_generate_request), (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
(BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request), (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
(FlushCacheReqInput, self.flush_cache_wrapped), (FlushCacheReqInput, self.flush_cache_wrapped),
(ClearHiCacheReqInput, self.clear_hicache_storage_wrapped),
(AbortReq, self.abort_request), (AbortReq, self.abort_request),
(OpenSessionReqInput, self.open_session), (OpenSessionReqInput, self.open_session),
(CloseSessionReqInput, self.close_session), (CloseSessionReqInput, self.close_session),
...@@ -2207,6 +2210,16 @@ class Scheduler( ...@@ -2207,6 +2210,16 @@ class Scheduler(
success = self.flush_cache() success = self.flush_cache()
return FlushCacheReqOutput(success=success) return FlushCacheReqOutput(success=success)
def clear_hicache_storage_wrapped(self, recv_req: ClearHiCacheReqInput):
if self.enable_hierarchical_cache:
self.tree_cache.clear_storage_backend()
logger.info("Hierarchical cache cleared successfully!")
if_success = True
else:
logging.warning("Hierarchical cache is not enabled.")
if_success = False
return ClearHiCacheReqOutput(success=if_success)
def flush_cache(self): def flush_cache(self):
"""Flush the memory pool and cache.""" """Flush the memory pool and cache."""
if ( if (
......
...@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import ( ...@@ -73,6 +73,8 @@ from sglang.srt.managers.io_struct import (
BatchTokenIDOut, BatchTokenIDOut,
BatchTokenizedEmbeddingReqInput, BatchTokenizedEmbeddingReqInput,
BatchTokenizedGenerateReqInput, BatchTokenizedGenerateReqInput,
ClearHiCacheReqInput,
ClearHiCacheReqOutput,
CloseSessionReqInput, CloseSessionReqInput,
ConfigureLoggingReq, ConfigureLoggingReq,
EmbeddingReqInput, EmbeddingReqInput,
...@@ -386,6 +388,9 @@ class TokenizerManager: ...@@ -386,6 +388,9 @@ class TokenizerManager:
self.flush_cache_communicator = _Communicator( self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size self.send_to_scheduler, server_args.dp_size
) )
self.clear_hicache_storage_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.profile_communicator = _Communicator( self.profile_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size self.send_to_scheduler, server_args.dp_size
) )
...@@ -447,6 +452,10 @@ class TokenizerManager: ...@@ -447,6 +452,10 @@ class TokenizerManager:
SlowDownReqOutput, SlowDownReqOutput,
self.slow_down_communicator.handle_recv, self.slow_down_communicator.handle_recv,
), ),
(
ClearHiCacheReqOutput,
self.clear_hicache_storage_communicator.handle_recv,
),
( (
FlushCacheReqOutput, FlushCacheReqOutput,
self.flush_cache_communicator.handle_recv, self.flush_cache_communicator.handle_recv,
...@@ -988,6 +997,13 @@ class TokenizerManager: ...@@ -988,6 +997,13 @@ class TokenizerManager:
async def flush_cache(self) -> FlushCacheReqOutput: async def flush_cache(self) -> FlushCacheReqOutput:
return (await self.flush_cache_communicator(FlushCacheReqInput()))[0] return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
"""Clear the hierarchical cache storage."""
# Delegate to the scheduler to handle HiCacheStorage clearing
return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
0
]
def abort_request(self, rid: str = "", abort_all: bool = False): def abort_request(self, rid: str = "", abort_all: bool = False):
if not abort_all and rid not in self.rid_to_state: if not abort_all and rid not in self.rid_to_state:
return return
......
...@@ -102,6 +102,20 @@ class HiCacheStorage(ABC): ...@@ -102,6 +102,20 @@ class HiCacheStorage(ABC):
""" """
pass pass
@abstractmethod
def delete(self, key: str) -> bool:
"""
Delete the entry associated with the given key.
"""
pass
@abstractmethod
def clear(self) -> bool:
"""
Clear all entries in the storage.
"""
pass
def batch_exists(self, keys: List[str]) -> int: def batch_exists(self, keys: List[str]) -> int:
""" """
Check if the keys exist in the storage. Check if the keys exist in the storage.
...@@ -214,12 +228,14 @@ class HiCacheFile(HiCacheStorage): ...@@ -214,12 +228,14 @@ class HiCacheFile(HiCacheStorage):
logger.warning(f"Key {key} does not exist. Cannot delete.") logger.warning(f"Key {key} does not exist. Cannot delete.")
return return
def clear(self) -> None: def clear(self) -> bool:
try: try:
for filename in os.listdir(self.file_path): for filename in os.listdir(self.file_path):
file_path = os.path.join(self.file_path, filename) file_path = os.path.join(self.file_path, filename)
if os.path.isfile(file_path): if os.path.isfile(file_path):
os.remove(file_path) os.remove(file_path)
logger.info("Cleared all entries in HiCacheFile storage.") logger.info("Cleared all entries in HiCacheFile storage.")
return True
except Exception as e: except Exception as e:
logger.error(f"Failed to clear HiCacheFile storage: {e}") logger.error(f"Failed to clear HiCacheFile storage: {e}")
return False
...@@ -125,6 +125,15 @@ class HiRadixCache(RadixCache): ...@@ -125,6 +125,15 @@ class HiRadixCache(RadixCache):
height += 1 height += 1
return height return height
def clear_storage_backend(self):
if self.enable_storage:
self.cache_controller.storage_backend.clear()
logger.info("Hierarchical cache storage backend cleared successfully!")
return True
else:
logger.warning("Hierarchical cache storage backend is not enabled.")
return False
def write_backup(self, node: TreeNode, write_back=False): def write_backup(self, node: TreeNode, write_back=False):
host_indices = self.cache_controller.write( host_indices = self.cache_controller.write(
device_indices=node.value, device_indices=node.value,
......
...@@ -393,8 +393,14 @@ class HiCacheHF3FS(HiCacheStorage): ...@@ -393,8 +393,14 @@ class HiCacheHF3FS(HiCacheStorage):
return len(keys) return len(keys)
def clear(self) -> None: def clear(self) -> bool:
self.metadata_client.clear(self.rank) try:
self.metadata_client.clear(self.rank)
logger.info(f"Cleared HiCacheHF3FS for rank {self.rank}")
return True
except Exception as e:
logger.error(f"Failed to clear HiCacheHF3FS: {e}")
return False
def close(self) -> None: def close(self) -> None:
try: try:
......
...@@ -254,7 +254,7 @@ class MooncakeStore(HiCacheStorage): ...@@ -254,7 +254,7 @@ class MooncakeStore(HiCacheStorage):
pass pass
def clear(self) -> None: def clear(self) -> None:
raise (NotImplementedError) self.store.remove_all()
def _put_batch_zero_copy_impl( def _put_batch_zero_copy_impl(
self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int] self, key_strs: List[str], buffer_ptrs: List[int], buffer_sizes: List[int]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment