Unverified commit 9d33fcfb authored by Zhiqiang Xie, committed by GitHub

Hicache Storage Layer Prototype (#7704)

parent 7891bac1
@@ -25,6 +25,8 @@ if TYPE_CHECKING:
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
from sglang.srt.mem_cache.hicache_storage import HiCacheFile, get_hash_str
logger = logging.getLogger(__name__)
@@ -159,6 +161,57 @@ class TransferBuffer:
self.buffers.queue.clear()
class StorageOperation:
counter = 0
def __init__(
self,
host_indices: torch.Tensor,
token_ids: List[int],
last_hash: Optional[str] = None,
):
self.host_indices = host_indices
self.token_ids = token_ids
self.last_hash = last_hash
self.completed_tokens = 0
self.hash_value = []
self.id = StorageOperation.counter
StorageOperation.counter += 1
def __lt__(self, other: "StorageOperation"):
return self.id < other.id
class PrefetchOperation(StorageOperation):
def __init__(
self,
request_id: str,
host_indices: torch.Tensor,
token_ids: List[int],
last_hash: Optional[str] = None,
):
self.request_id = request_id
self._done_flag = False
self._lock = threading.Lock()
super().__init__(host_indices, token_ids, last_hash)
def increment(self, num_tokens: int):
with self._lock:
if self._done_flag:
return
self.completed_tokens += num_tokens
def mark_done(self):
with self._lock:
self._done_flag = True
def is_done(self) -> bool:
return self._done_flag
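A minimal sketch (not part of this diff) of how these two pieces are meant to interact: an IO worker bumps completed_tokens one page at a time, while the controller can mark the operation done at any point, after which further increments are dropped.

# illustrative only; assumes the PrefetchOperation class above is in scope
import threading
import torch

op = PrefetchOperation(
    request_id="demo-req",
    host_indices=torch.arange(512),
    token_ids=list(range(512)),
)

def io_worker(op: PrefetchOperation, page_size: int = 64):
    # the IO thread increments by one page per fetched KV page
    for _ in range(len(op.token_ids) // page_size):
        op.increment(page_size)

t = threading.Thread(target=io_worker, args=(op,))
t.start()
op.mark_done()  # controller-side termination; late increment() calls become no-ops
t.join()
print(op.completed_tokens)  # only the pages fetched before mark_done() are counted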
class HiCacheController:
def __init__(
@@ -169,6 +222,8 @@ class HiCacheController:
load_cache_event: threading.Event = None,
write_policy: str = "write_through_selective",
io_backend: str = "",
storage_backend: Optional[str] = None,
prefetch_threshold: int = 256,
):
self.mem_pool_device_allocator = token_to_kv_pool_allocator
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
@@ -186,6 +241,19 @@ class HiCacheController:
else:
self.io_backend = io_backend
self.enable_storage = False
# todo: move backend initialization to storage backend module
if storage_backend is not None:
if storage_backend == "file":
self.storage_backend = HiCacheFile()
self.enable_storage = True
# todo: threshold policy for prefetching
self.prefetch_threshold = prefetch_threshold
else:
raise NotImplementedError(
f"Unsupported storage backend: {storage_backend}"
)
self.load_cache_event = load_cache_event
self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
@@ -218,9 +286,26 @@ class HiCacheController:
self.load_thread = threading.Thread(
target=self.load_thread_func_layer_by_layer, daemon=True
)
self.write_thread.start()
self.load_thread.start()
if self.enable_storage:
self.prefetch_thread = threading.Thread(
target=self.prefetch_thread_func, daemon=True
)
self.backup_thread = threading.Thread(
target=self.backup_thread_func, daemon=True
)
self.prefetch_queue = Queue()
self.backup_queue = Queue()
self.prefetch_revoke_queue = Queue()
self.ack_backup_queue = Queue()
self.prefetch_thread.start()
self.backup_thread.start()
def reset(self):
self.stop_event.set()
self.write_thread.join()
@@ -232,6 +317,13 @@ class HiCacheController:
self.load_buffer.clear()
self.ack_write_queue.queue.clear()
self.ack_load_queue.queue.clear()
if self.enable_storage:
self.prefetch_thread.join()
self.backup_thread.join()
self.prefetch_queue.queue.clear()
self.backup_queue.queue.clear()
self.prefetch_revoke_queue.queue.clear()
self.ack_backup_queue.queue.clear()
self.write_thread = threading.Thread(
target=self.write_thread_func_direct, daemon=True
@@ -243,6 +335,16 @@ class HiCacheController:
self.write_thread.start()
self.load_thread.start()
if self.enable_storage:
self.prefetch_thread = threading.Thread(
target=self.prefetch_thread_func, daemon=True
)
self.backup_thread = threading.Thread(
target=self.backup_thread_func, daemon=True
)
self.prefetch_thread.start()
self.backup_thread.start()
def write(
self,
device_indices: torch.Tensor,
@@ -383,3 +485,142 @@ class HiCacheController:
raise ValueError(
f"Inconsistent states: {self.mem_pool_host.get_state(host_indices)}"
)
def prefetch(
self,
request_id: str,
host_indices: torch.Tensor,
new_input_tokens: List[int],
last_hash: Optional[str] = None,
) -> PrefetchOperation:
"""
Prefetch KV caches from storage backend to host memory.
"""
operation = PrefetchOperation(
request_id, host_indices, new_input_tokens, last_hash
)
self.prefetch_queue.put(operation)
return operation
def terminate_prefetch(self, operation):
operation.mark_done()
return operation.completed_tokens, operation.hash_value
def prefetch_io_aux_func(self):
"""
Auxiliary function conducting IO operations for prefetching.
"""
while not self.stop_event.is_set():
try:
operation = self.prefetch_buffer.get(block=True, timeout=1)
for h in operation.hash_value:
page_data = self.storage_backend.get(h)
if page_data is None:
logger.warning(
f"Prefetch operation {operation.request_id} failed to retrieve page {h}."
)
break
self.mem_pool_host.set_from_flat_data_page(
operation.host_indices[operation.completed_tokens],
page_data,
)
operation.increment(self.page_size)
if operation.is_done():
# operation terminated by controller, release pre-allocated memory
self.mem_pool_host.free(
operation.host_indices[operation.completed_tokens :]
)
break
except Empty:
continue
def prefetch_thread_func(self):
"""
Manage prefetching operations from storage backend to host memory.
"""
self.prefetch_buffer = Queue()
aux_thread = threading.Thread(target=self.prefetch_io_aux_func, daemon=True)
aux_thread.start()
while (not self.stop_event.is_set()) or not self.prefetch_queue.empty():
try:
operation = self.prefetch_queue.get(block=True, timeout=1)
if operation is None:
continue
last_hash = operation.last_hash
tokens_to_fetch = operation.token_ids
storage_hit_count = 0
remaining_tokens = len(tokens_to_fetch)
hash_value = []
while remaining_tokens >= self.page_size:
last_hash = get_hash_str(
tokens_to_fetch[
storage_hit_count : storage_hit_count + self.page_size
],
last_hash,
)
if self.storage_backend.exists(last_hash):
storage_hit_count += self.page_size
hash_value.append(last_hash)
remaining_tokens -= self.page_size
else:
break
if storage_hit_count < self.prefetch_threshold:
# skip prefetching when the expected benefit is too small
self.prefetch_revoke_queue.put(operation.request_id)
else:
operation.hash_value = hash_value
logger.debug(
f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
)
self.prefetch_buffer.put(operation)
except Empty:
continue
def write_storage(
self,
host_indices: torch.Tensor,
token_ids: List[int],
last_hash: Optional[str] = None,
) -> int:
"""
Write KV caches from host memory to storage backend.
"""
operation = StorageOperation(host_indices, token_ids, last_hash)
self.backup_queue.put(operation)
return operation.id
def backup_thread_func(self):
"""
Manage backup operations from host memory to storage backend.
"""
while not self.stop_event.is_set():
try:
operation = self.backup_queue.get(block=True, timeout=1)
if operation is None:
continue
last_hash = operation.last_hash
tokens_to_backup = operation.token_ids
for i in range(0, len(tokens_to_backup), self.page_size):
last_hash = get_hash_str(
tokens_to_backup[i : i + self.page_size], last_hash
)
# todo, handle failures in storage backend
self.storage_backend.set(
last_hash,
self.mem_pool_host.get_flat_data_page(
operation.host_indices[i]
),
)
operation.completed_tokens += self.page_size
operation.hash_value.append(last_hash)
self.ack_backup_queue.put((operation.id, operation.hash_value))
except Empty:
continue
@@ -262,6 +262,7 @@ class Scheduler(
)
self.gpu_id = gpu_id
self.enable_hierarchical_cache = server_args.enable_hierarchical_cache
self.enable_hicache_storage = server_args.hicache_storage_backend is not None
self.page_size = server_args.page_size
self.dp_size = server_args.dp_size
self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
@@ -614,6 +615,7 @@ class Scheduler(
== "fa3"  # hot fix for incompatibility
else server_args.hicache_io_backend
),
hicache_storage_backend=server_args.hicache_storage_backend,
)
self.tp_worker.register_hicache_layer_transfer_counter(
self.tree_cache.cache_controller.layer_done_counter
@@ -1258,6 +1260,15 @@ class Scheduler(
elif self.disaggregation_mode == DisaggregationMode.DECODE:
self.disagg_decode_prealloc_queue.add(req)
else:
if self.enable_hicache_storage:
req.init_next_round_input(self.tree_cache)
last_hash = req.last_host_node.get_last_hash_value()
matched_len = len(req.prefix_indices) + req.host_hit_length
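# prefetch only when the matched prefix ends at a page with a known hash, or when nothing matched at all so the hash chain can start fresh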
if (matched_len > 0 and last_hash is not None) or matched_len == 0:
new_input_tokens = req.fill_ids[matched_len:]
self.tree_cache.prefetch_from_storage(
req.rid, req.last_host_node, new_input_tokens, last_hash
)
self.waiting_queue.append(req)
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
@@ -1731,6 +1742,9 @@ class Scheduler(
self.running_batch.batch_is_full = True
break
if self.enable_hicache_storage:
self.tree_cache.check_prefetch_progress(req.rid)
req.init_next_round_input(self.tree_cache)
res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
...
import hashlib
import logging
import os
from abc import ABC, abstractmethod
from typing import List, Optional
import torch
logger = logging.getLogger(__name__)
def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
hasher = hashlib.sha256()
if prior_hash:
hasher.update(bytes.fromhex(prior_hash))
for t in token_ids:
hasher.update(t.to_bytes(4, byteorder="little", signed=False))
return hasher.hexdigest()
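For illustration (not part of this diff), page keys are chained: each page's hash is folded into the next one, so identical token blocks behind different prefixes map to different keys. The page size of 4 below is arbitrary.

page_a = [1, 2, 3, 4]
page_b = [5, 6, 7, 8]

h_a = get_hash_str(page_a)                  # first page, no prior hash
h_b = get_hash_str(page_b, prior_hash=h_a)  # chained on the previous page's hash

# the same tokens behind a different prefix yield a different key
h_b_other = get_hash_str(page_b, prior_hash=get_hash_str([7, 7, 7, 7]))
assert h_b != h_b_other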
class HiCacheStorage(ABC):
"""
HiCacheStorage is a class that provides a generic key-value interface for storing and retrieving KV cache.
It abstracts the underlying storage mechanism, allowing different implementations to be used.
"""
# todo, translate tensor object access for different TP ranks
# potentially pass model and TP configs into storage backend
# todo: the page size of the storage backend does not have to match that of the host memory pool
@abstractmethod
def get(
self, key: str, target_location: Optional[torch.Tensor] = None
) -> torch.Tensor | None:
"""
Retrieve the value associated with the given key.
Returns None if the key does not exist.
"""
pass
@abstractmethod
def batch_get(
self, keys: List[str], target_locations: Optional[List[torch.Tensor]] = None
) -> List[torch.Tensor | None]:
"""
Retrieve values for multiple keys.
Returns a list of tensors or None for each key.
"""
pass
@abstractmethod
def set(self, key: str, value: torch.Tensor) -> bool:
"""
Store the value associated with the given key.
Returns True if the operation was successful, False otherwise.
"""
pass
@abstractmethod
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
"""
Store multiple key-value pairs.
Returns True if all operations were successful, False otherwise.
"""
pass
@abstractmethod
def exists(self, key: str) -> bool:
"""
Check if the key exists in the storage.
Returns True if the key exists, False otherwise.
"""
pass
class HiCacheFile(HiCacheStorage):
def __init__(self, file_path: str = "/tmp/hicache"):
self.file_path = file_path
if not os.path.exists(self.file_path):
os.makedirs(self.file_path)
logger.info(f"Created HiCacheFile storage directory at {self.file_path}")
def get(
self, key: str, target_location: Optional[torch.Tensor] = None
) -> torch.Tensor | None:
tensor_path = os.path.join(self.file_path, f"{key}.bin")
try:
# todo: fixing the target_location logic to enable in-place loading
loaded_tensor = torch.load(tensor_path)
if isinstance(loaded_tensor, torch.Tensor):
return loaded_tensor
else:
logger.error(f"Loaded data for key {key} is not a tensor.")
return None
except FileNotFoundError:
return None
def batch_get(
self,
keys: List[str],
target_locations: Optional[List[torch.Tensor]] = None,
) -> List[torch.Tensor | None]:
return [
self.get(key, target_location)
for key, target_location in zip(
keys, target_locations or [None] * len(keys)
)
]
def set(self, key: str, value: torch.Tensor) -> bool:
tensor_path = os.path.join(self.file_path, f"{key}.bin")
if self.exists(key):
logger.debug(f"Key {key} already exists. Skipped.")
return True
try:
torch.save(value, tensor_path)
return True
except Exception as e:
logger.error(f"Failed to save tensor {key}: {e}")
return False
def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
for key, value in zip(keys, values):
if not self.set(key, value):
return False
return True
def exists(self, key: str) -> bool:
tensor_path = os.path.join(self.file_path, f"{key}.bin")
return os.path.exists(tensor_path)
def delete(self, key: str) -> None:
tensor_path = os.path.join(self.file_path, f"{key}.bin")
try:
os.remove(tensor_path)
except FileNotFoundError:
logger.warning(f"Key {key} does not exist. Cannot delete.")
return
def clear(self) -> None:
try:
for filename in os.listdir(self.file_path):
file_path = os.path.join(self.file_path, filename)
if os.path.isfile(file_path):
os.remove(file_path)
logger.info("Cleared all entries in HiCacheFile storage.")
except Exception as e:
logger.error(f"Failed to clear HiCacheFile storage: {e}")
@@ -35,6 +35,7 @@ class HiRadixCache(RadixCache):
hicache_size: int,
hicache_write_policy: str,
hicache_io_backend: str,
hicache_storage_backend: Optional[str] = None,
):
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
if isinstance(self.kv_cache, MHATokenToKVPool):
@@ -49,6 +50,9 @@ class HiRadixCache(RadixCache):
raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
self.tp_group = tp_cache_group
self.enable_storage = hicache_storage_backend is not None
# todo: customizable storage prefetch threshold
self.prefetch_threshold = 256
self.load_cache_event = threading.Event()
self.cache_controller = HiCacheController(
@@ -58,16 +62,22 @@ class HiRadixCache(RadixCache):
load_cache_event=self.load_cache_event,
write_policy=hicache_write_policy,
io_backend=hicache_io_backend,
storage_backend=hicache_storage_backend,
prefetch_threshold=self.prefetch_threshold,
)
# record the nodes with ongoing write through
self.ongoing_write_through = {}
# record the node segments with ongoing load back
self.ongoing_load_back = {}
# record the ongoing prefetch requests
self.ongoing_prefetch = {}
self.ongoing_backup = {}
# todo: dynamically adjust the threshold
self.write_through_threshold = (
1 if hicache_write_policy == "write_through" else 3
)
self.write_through_threshold_storage = 3
self.load_back_threshold = 10
super().__init__(
req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -108,13 +118,30 @@ class HiRadixCache(RadixCache):
return len(host_indices)
def write_backup_storage(self, node: TreeNode):
operation_id = self.cache_controller.write_storage(
node.host_value, node.key, node.parent.get_last_hash_value()
)
self.ongoing_backup[operation_id] = node
node.protect_host()
def inc_hit_count(self, node: TreeNode):
    if self.cache_controller.write_policy == "write_back":
        return
    node.hit_count += 1

    if not node.backuped:
        if node.hit_count >= self.write_through_threshold:
            # write through to host memory if the node has not been backed up yet
            self.write_backup(node)
    else:
        if (
            self.enable_storage
            and (not node.backuped_storage)
            and node.hit_count >= self.write_through_threshold_storage
        ):
            # the node is backed up in host memory but not yet in storage
            self.write_backup_storage(node)
def writing_check(self, write_back=False):
if write_back:
@@ -221,6 +248,10 @@ class HiRadixCache(RadixCache):
if not x.evicted:
continue
# node is protected from eviction as it has ongoing prefetch or backup to storage
if x.host_ref_counter > 0:
continue
num_evicted += self.cache_controller.evict_host(x.host_value)
for k, v in x.parent.children.items():
@@ -314,6 +345,85 @@ class HiRadixCache(RadixCache):
def check_hicache_events(self):
self.writing_check()
self.loading_check()
if self.enable_storage:
self.check_revoked_prefetch()
self.check_backup_progress()
def check_revoked_prefetch(self):
queue_size = torch.tensor(
self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
)
if torch.distributed.get_world_size(group=self.tp_group) > 1:
# synchronize TP workers so they make the same update to the hiradix cache
torch.distributed.all_reduce(
queue_size,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
for _ in range(queue_size.item()):
req_id = self.cache_controller.prefetch_revoke_queue.get()
if req_id in self.ongoing_prefetch:
last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id]
last_host_node.release_host()
self.cache_controller.mem_pool_host.free(host_indices)
del self.ongoing_prefetch[req_id]
def check_backup_progress(self):
queue_size = torch.tensor(
self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
)
if torch.distributed.get_world_size(group=self.tp_group) > 1:
# synchronize TP workers so they make the same update to the hiradix cache
torch.distributed.all_reduce(
queue_size,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
for _ in range(queue_size.item()):
ack_id, hash_value = self.cache_controller.ack_backup_queue.get()
self.ongoing_backup[ack_id].hash_value = hash_value
self.ongoing_backup[ack_id].release_host()
del self.ongoing_backup[ack_id]
def check_prefetch_progress(self, req_id: str):
if req_id not in self.ongoing_prefetch:
# there is no ongoing prefetch for this request or it has been revoked
return
# todo: more policies for prefetch progress, such as a timeout
# the current policy is best-effort: prefetching runs until the request leaves the waiting queue, then terminates
last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
req_id
]
completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
operation
)
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
if torch.distributed.get_world_size(group=self.tp_group) > 1:
# synchronize TP workers so they make the same update to the hiradix cache
torch.distributed.all_reduce(
min_completed_tokens,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_group,
)
min_completed_tokens = min_completed_tokens.item()
fetched_token_ids = token_ids[:min_completed_tokens]
written_indices = host_indices[:min_completed_tokens]
matched_length = self._insert_helper_host(
last_host_node,
fetched_token_ids,
written_indices,
hash_value[:min_completed_tokens],
)
self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
self.cache_controller.mem_pool_host.free(
host_indices[min_completed_tokens:completed_tokens]
)
last_host_node.release_host()
del self.ongoing_prefetch[req_id]
def match_prefix(self, key: List[int], **kwargs):
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
@@ -348,6 +458,71 @@ class HiRadixCache(RadixCache):
host_hit_length=host_hit_length,
)
def prefetch_from_storage(
self,
req_id: str,
last_host_node: TreeNode,
new_input_tokens: List[int],
last_hash: Optional[str] = None,
):
if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold:
return
last_host_node.protect_host()
host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens))
if host_indices is None:
self.evict_host(len(new_input_tokens))
host_indices = self.cache_controller.mem_pool_host.alloc(
len(new_input_tokens)
)
if host_indices is None:
last_host_node.release_host()
# insufficient host memory to prefetch
return
operation = self.cache_controller.prefetch(
req_id, host_indices, new_input_tokens, last_hash
)
self.ongoing_prefetch[req_id] = (
last_host_node,
new_input_tokens,
host_indices,
operation,
)
def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
node.last_access_time = time.monotonic()
if len(key) == 0:
return 0
child_key = self.get_child_key_fn(key)
matched_length = 0
while len(key) > 0 and child_key in node.children.keys():
node = node.children[child_key]
node.last_access_time = time.monotonic()
prefix_len = self.key_match_fn(node.key, key)
key = key[prefix_len:]
host_value = host_value[prefix_len:]
hash_value = hash_value[prefix_len:]
matched_length += prefix_len
if prefix_len < len(node.key):
new_node = self._split_node(node.key, node, prefix_len)
node = new_node
if len(key):
child_key = self.get_child_key_fn(key)
if len(key):
new_node = TreeNode()
new_node.parent = node
new_node.key = key
new_node.value = None
new_node.host_value = host_value
new_node.hash_value = hash_value
node.children[child_key] = new_node
return matched_length
def _match_prefix_helper(self, node: TreeNode, key: List):
node.last_access_time = time.monotonic()
child_key = self.get_child_key_fn(key)
...
@@ -99,6 +99,20 @@ class HostKVCache(abc.ABC):
def init_kv_buffer(self):
raise NotImplementedError()
@abc.abstractmethod
def get_flat_data_page(self, index) -> torch.Tensor:
"""
Get a flat data page from the host memory pool.
"""
raise NotImplementedError()
@abc.abstractmethod
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
"""
Set a flat data page to the host memory pool.
"""
raise NotImplementedError()
@synchronized()
def clear(self):
# Initialize memory states and tracking structures.
@@ -227,6 +241,19 @@ class MHATokenToKVPoolHost(HostKVCache):
pin_memory=self.pin_memory,
)
# todo: page-first memory layout
def get_flat_data_page(self, index) -> torch.Tensor:
return self.kv_buffer[:, :, index : index + self.page_size, :, :].flatten()
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
self.kv_buffer[:, :, index : index + self.page_size, :, :] = data_page.reshape(
2,
self.layer_num,
self.page_size,
self.head_num,
self.head_dim,
)
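For intuition, a standalone sketch of the flatten/reshape invariant these two methods rely on, using toy dimensions (all sizes are illustrative, not the real pool configuration):

import torch

layer_num, page_size, head_num, head_dim = 2, 4, 8, 16
# host KV buffer layout: (K/V, layer, token, head, head_dim)
kv_buffer = torch.zeros(2, layer_num, 64, head_num, head_dim)

# one flat page carries 2 * layer_num * page_size * head_num * head_dim elements
page = torch.randn(2 * layer_num * page_size * head_num * head_dim)

index = 8  # page-aligned token index
kv_buffer[:, :, index : index + page_size, :, :] = page.reshape(
    2, layer_num, page_size, head_num, head_dim
)
flat_again = kv_buffer[:, :, index : index + page_size, :, :].flatten()
assert torch.equal(flat_again, page)  # the flat read undoes the reshaped write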
@property
def k_buffer(self):
return self.kv_buffer[0]
@@ -276,3 +303,14 @@ class MLATokenToKVPoolHost(HostKVCache):
device=self.device,
pin_memory=self.pin_memory,
)
def get_flat_data_page(self, index) -> torch.Tensor:
return self.kv_buffer[:, index : index + self.page_size, :, :].flatten()
def set_from_flat_data_page(self, index: int, data_page: torch.Tensor) -> None:
self.kv_buffer[:, index : index + self.page_size, :, :] = data_page.reshape(
self.layer_num,
self.page_size,
1,
self.kv_lora_rank + self.qk_rope_head_dim,
)
@@ -55,8 +55,13 @@ class TreeNode:
self.hit_count = 0
# indicating the node is loading KV cache from host
self.loading = False
# indicating the node is locked to protect from eviction
# incremented when the node is referenced by a storage operation
self.host_ref_counter = 0
# store the host indices of KV cache
self.host_value: Optional[torch.Tensor] = None
# store the hash value of each page
self.hash_value: Optional[List[str]] = None
self.id = TreeNode.counter if id is None else id
TreeNode.counter += 1
@@ -69,6 +74,27 @@ class TreeNode:
def backuped(self):
return self.host_value is not None
@property
def backuped_storage(self):
return self.hash_value is not None and len(self.hash_value) > 0
def protect_host(self):
"""Protect the host value from eviction."""
self.host_ref_counter += 1
def release_host(self):
"""Release the host value, allowing it to be evicted."""
if self.host_ref_counter > 0:
self.host_ref_counter -= 1
else:
raise RuntimeError("Host reference counter is already zero.")
def get_last_hash_value(self) -> Optional[str]:
"""Returns the hash value of the last page in this node."""
if self.hash_value is None or len(self.hash_value) == 0:
return None
return self.hash_value[-1]
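A small sketch (illustrative only) of the reference-counting contract: prefetch_from_storage and write_backup_storage take a reference before queuing work, eviction skips nodes with a non-zero counter, and the reference is dropped once the operation completes or is revoked.

node = TreeNode()            # assumes the TreeNode class above is in scope
node.protect_host()          # taken when a storage operation starts using node.host_value
assert node.host_ref_counter == 1   # evict_host() passes over this node while > 0
node.release_host()          # dropped when the operation finishes or is revoked
assert node.host_ref_counter == 0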
def __lt__(self, other: "TreeNode"):
return self.last_access_time < other.last_access_time
...
@@ -222,6 +222,7 @@ class ServerArgs:
hicache_size: int = 0
hicache_write_policy: str = "write_through_selective"
hicache_io_backend: str = ""
hicache_storage_backend: Optional[str] = None
flashinfer_mla_disable_ragged: bool = False
disable_shared_experts_fusion: bool = False
disable_chunked_prefix_cache: bool = False
@@ -1604,6 +1605,13 @@ class ServerArgs:
default=ServerArgs.hicache_io_backend,
help="The IO backend for KV cache transfer between CPU and GPU",
)
parser.add_argument(
"--hicache-storage-backend",
type=str,
choices=["file"], # todo, mooncacke
default=ServerArgs.hicache_storage_backend,
help="The storage backend for hierarchical KV cache.",
)
parser.add_argument(
"--flashinfer-mla-disable-ragged",
action="store_true",
...
@@ -64,6 +64,7 @@ suites = {
TestFile("test_fused_moe.py", 30),
TestFile("test_hicache.py", 116),
TestFile("test_hicache_mla.py", 127),
TestFile("test_hicache_storage.py", 127),
TestFile("test_hidden_states.py", 55), TestFile("test_hidden_states.py", 55),
TestFile("test_int8_kernel.py", 8), TestFile("test_int8_kernel.py", 8),
TestFile("test_input_embeddings.py", 38), TestFile("test_input_embeddings.py", 38),
......
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestHiCache(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
cls.model,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--enable-hierarchical-cache",
"--mem-fraction-static",
0.7,
"--hicache-size",
100,
"--page-size",
"64",
"--hicache-storage-backend",
"file",
],
)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)
def test_mmlu(self):
args = SimpleNamespace(
base_url=self.base_url,
model=self.model,
eval_name="mmlu",
num_examples=64,
num_threads=32,
)
metrics = run_eval(args)
self.assertGreaterEqual(metrics["score"], 0.65)
if __name__ == "__main__":
unittest.main()