Unverified Commit 9f78f391 authored by Zhiqiang Xie's avatar Zhiqiang Xie Committed by GitHub
Browse files

HiCache Storage: generate hash when inserting new nodes (#9053)

parent f508cd3c
......@@ -169,12 +169,13 @@ class StorageOperation:
host_indices: torch.Tensor,
token_ids: List[int],
last_hash: Optional[str] = None,
hash_value: Optional[List[str]] = None,
):
self.host_indices = host_indices
self.token_ids = token_ids
self.last_hash = last_hash
self.completed_tokens = 0
self.hash_value = []
self.hash_value = hash_value if hash_value is not None else []
self.id = StorageOperation.counter
StorageOperation.counter += 1
......@@ -702,12 +703,12 @@ class HiCacheController:
self,
host_indices: torch.Tensor,
token_ids: List[int],
last_hash: Optional[str] = None,
hash_value: Optional[List[str]] = None,
) -> int:
"""
Write KV caches from host memory to storage backend.
"""
operation = StorageOperation(host_indices, token_ids, last_hash)
operation = StorageOperation(host_indices, token_ids, hash_value=hash_value)
self.backup_queue.put(operation)
return operation.id
......@@ -762,24 +763,6 @@ class HiCacheController:
if operation is None:
continue
last_hash = operation.last_hash
tokens_to_backup = operation.token_ids
backup_hit_count = 0
remaining_tokens = len(tokens_to_backup)
hash_value = []
while remaining_tokens >= self.page_size:
last_hash = self.get_hash_str(
tokens_to_backup[
backup_hit_count : backup_hit_count + self.page_size
],
last_hash,
)
backup_hit_count += self.page_size
hash_value.append(last_hash)
remaining_tokens -= self.page_size
operation.hash_value = hash_value
if self.is_mooncake_backend():
self.mooncake_page_backup(operation)
elif self.storage_backend_type == "hf3fs":
......@@ -802,7 +785,6 @@ class HiCacheController:
self.ack_backup_queue.put(
(
operation.id,
operation.hash_value[: min_completed_tokens // self.page_size],
min_completed_tokens,
)
)
......
......@@ -15,7 +15,7 @@ from sglang.srt.distributed import (
)
def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
hasher = hashlib.sha256()
if prior_hash:
......
......@@ -151,7 +151,7 @@ class HiRadixCache(RadixCache):
def write_backup_storage(self, node: TreeNode):
    """Kick off an asynchronous backup of ``node``'s KV cache to storage.

    The SOURCE span interleaves the pre-commit call (which passed the
    parent's last hash) with the post-commit call; this is the coherent
    post-commit version, which passes the node's own precomputed per-page
    hashes (``node.hash_value``) to the cache controller.

    Args:
        node: Radix-tree node whose host-resident KV data
            (``node.host_value``) and token key (``node.key``) should be
            written to the storage backend.
    """
    operation_id = self.cache_controller.write_storage(
        node.host_value, node.key, node.hash_value
    )
    # Remember the in-flight operation so its completion ack can be
    # matched back to this node later.
    self.ongoing_backup[operation_id] = node
    # Pin the host copy so it cannot be evicted while the backup runs.
    node.protect_host()
......@@ -414,18 +414,18 @@ class HiRadixCache(RadixCache):
group=self.tp_group,
)
for _ in range(queue_size.item()):
ack_id, hash_value, completed_tokens = (
self.cache_controller.ack_backup_queue.get()
)
ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
host_node = self.ongoing_backup[ack_id]
if completed_tokens == 0:
host_node.hash_value = None
elif completed_tokens < len(host_node.key):
# backup is only partially successful, split the node
new_node = self._split_node(host_node.key, host_node, completed_tokens)
new_node.hash_value = hash_value
else:
host_node.hash_value = hash_value
if completed_tokens > 0:
if completed_tokens < len(host_node.key):
# backup is only partially successful, split the node
new_node = self._split_node(
host_node.key, host_node, completed_tokens
)
new_node.backuped_storage = True
else:
host_node.backuped_storage = True
host_node.release_host()
del self.ongoing_backup[ack_id]
......@@ -717,6 +717,21 @@ class HiRadixCache(RadixCache):
node.children[child_key] = new_node
self.evictable_size_ += len(value)
if self.enable_storage:
last_hash = node.get_last_hash_value()
assert (node == self.root_node) or (
last_hash is not None
), "Parent node must have a hash value with storage enabled"
new_node.hash_value = []
for idx in range(0, len(key), self.page_size):
new_node.hash_value.append(
self.cache_controller.get_hash_str(
key[idx : idx + self.page_size],
prior_hash=last_hash,
)
)
last_hash = new_node.hash_value[-1]
if self.cache_controller.write_policy != "write_back":
self.inc_hit_count(new_node)
return total_prefix_length
......
......@@ -62,6 +62,7 @@ class TreeNode:
self.host_value: Optional[torch.Tensor] = None
# store hash values of each pages
self.hash_value: Optional[List[str]] = None
self.backuped_storage = False
self.id = TreeNode.counter if id is None else id
TreeNode.counter += 1
......@@ -74,10 +75,6 @@ class TreeNode:
def backuped(self):
return self.host_value is not None
@property
def backuped_storage(self):
return self.hash_value is not None and len(self.hash_value) > 0
def protect_host(self):
"""Protect the host value from eviction."""
self.host_ref_counter += 1
......
......@@ -18,13 +18,12 @@ DEFAULT_LOCAL_BUFFER_SIZE = 128 * 1024 * 1024 # 128 MB
logger = logging.getLogger(__name__)
def get_hash_str_mooncake(token_ids: List[int], prior_hash: str = None):
    """Compute a Mooncake-backend key for one page of token ids.

    The SOURCE span interleaves the pre-commit signature/body with the
    post-commit one; this is the coherent post-commit version.  Hashes are
    chained: the previous page's hash string is itself hashed and used as
    a prefix, so each page's key depends on its entire token prefix.

    Args:
        token_ids: Token ids of the current page.
        prior_hash: Hash string of the preceding page; falsy (None or "")
            for the first page.  NOTE(review): per PEP 484 this should be
            annotated ``Optional[str]``; kept as-is to match the commit.

    Returns:
        ``"<prefix_sha256>_<first 64 bits of page sha256>_<tp_rank>"`` —
        the tensor-parallel rank suffix keeps keys distinct per rank.
    """
    local_rank = get_tensor_model_parallel_rank()
    prefix_str = ""
    if prior_hash:
        prefix_str = hashlib.sha256(prior_hash.encode()).hexdigest()
    # NOTE(review): np.array(token_ids).tobytes() depends on the default
    # integer dtype, which is platform-dependent — confirm keys are only
    # compared within one deployment.
    current_token_ids_bytes = np.array(token_ids).tobytes()
    current_hash_object = hashlib.sha256(current_token_ids_bytes)
    current_hash_hex = current_hash_object.hexdigest()
    return f"{prefix_str}_{int(current_hash_hex[:16], 16)}_{local_rank}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment