Unverified commit ee3bd8a1 authored by hzh0425, committed by GitHub

feat(hicache): Support passing prefix keys for l3 store. (#9045)


Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent d8467db7
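A quick orientation before the diff: this change threads an optional list of prefix keys (the hash values of all ancestor pages in the radix tree) through the HiCache prefetch and backup paths, so L3 storage backends receive the full prefix context via the new `HiCacheStorageExtraInfo` envelope. The feature is opt-in through the storage backend extra config, as the e2e test at the bottom of this diff shows; a minimal sketch of enabling it (flag name and defaults taken from this diff):

```python
import json

# Opt-in flag parsed by HiRadixCache._parse_storage_backend_extra_config;
# it defaults to False when absent.
extra_config = {"hicache_storage_pass_prefix_keys": True}

server_args = {
    "--enable-hierarchical-cache": True,
    "--hicache-storage-backend": "file",
    "--hicache-storage-backend-extra-config": json.dumps(extra_config),
}
```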
......@@ -22,7 +22,10 @@ from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
import torch
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageConfig
from sglang.srt.mem_cache.hicache_storage import (
HiCacheStorageConfig,
HiCacheStorageExtraInfo,
)
if TYPE_CHECKING:
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
......@@ -191,12 +194,14 @@ class StorageOperation:
token_ids: List[int],
last_hash: Optional[str] = None,
hash_value: Optional[List[str]] = None,
prefix_keys: Optional[List[str]] = None,
):
self.host_indices = host_indices
self.token_ids = token_ids
self.last_hash = last_hash
self.completed_tokens = 0
self.hash_value = hash_value if hash_value is not None else []
self.prefix_keys = prefix_keys
self.id = StorageOperation.counter
StorageOperation.counter += 1
......@@ -212,6 +217,7 @@ class PrefetchOperation(StorageOperation):
host_indices: torch.Tensor,
token_ids: List[int],
last_hash: Optional[str] = None,
prefix_keys: Optional[List[str]] = None,
):
self.request_id = request_id
......@@ -219,7 +225,7 @@ class PrefetchOperation(StorageOperation):
self._terminated_flag = False
self.start_time = time.monotonic()
super().__init__(host_indices, token_ids, last_hash)
super().__init__(host_indices, token_ids, last_hash, prefix_keys=prefix_keys)
def increment(self, num_tokens: int):
with self._lock:
......@@ -550,12 +556,13 @@ class HiCacheController:
host_indices: torch.Tensor,
new_input_tokens: List[int],
last_hash: Optional[str] = None,
prefix_keys: Optional[List[str]] = None,
) -> PrefetchOperation:
"""
Prefetch KV caches from storage backend to host memory.
"""
operation = PrefetchOperation(
request_id, host_indices, new_input_tokens, last_hash
request_id, host_indices, new_input_tokens, last_hash, prefix_keys
)
self.prefetch_queue.put(operation)
return operation
......@@ -571,8 +578,12 @@ class HiCacheController:
for page in pages:
self.host_mem_release_queue.put(page)
def _page_get_zero_copy(self, operation, hash_values, host_indices):
results = self.storage_backend.batch_get_v1(hash_values, host_indices)
def _page_get_zero_copy(
self, operation, hash_values, host_indices, extra_info=None
):
results = self.storage_backend.batch_get_v1(
hash_values, host_indices, extra_info
)
inc = 0
for i in range(len(hash_values)):
if not results[i]:
......@@ -584,7 +595,7 @@ class HiCacheController:
operation.increment(inc)
# todo: deprecate
def _generic_page_get(self, operation, hash_values, host_indices):
def _generic_page_get(self, operation, hash_values, host_indices, extra_info=None):
dummy_page_dst = [
self.mem_pool_host.get_dummy_flat_data_page() for _ in hash_values
]
......@@ -608,6 +619,7 @@ class HiCacheController:
def _page_transfer(self, operation):
# Transfer batch by batch
prefix_keys = operation.prefix_keys
for i in range(0, len(operation.hash_value), self.storage_batch_size):
batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
batch_host_indices = operation.host_indices[
......@@ -615,7 +627,8 @@ class HiCacheController:
]
prev_completed_tokens = operation.completed_tokens
# Fetch one batch of pages and update completed_tokens on success
self.page_get_func(operation, batch_hashes, batch_host_indices)
extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
self.page_get_func(operation, batch_hashes, batch_host_indices, extra_info)
# Check termination
if (
operation.completed_tokens
......@@ -623,6 +636,10 @@ class HiCacheController:
):
operation.mark_terminate()
break # Some operations fail or operation terminated by controller
if prefix_keys and len(prefix_keys) > 0:
prefix_keys += batch_hashes
# release pre-allocated memory
self.append_host_mem_release(
operation.host_indices[operation.completed_tokens :]
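The transfer loop above accumulates prefix keys batch by batch: each batch is fetched with an `extra_info` carrying every key seen so far, and the batch's own hashes are then appended so the next batch sees the full chain. A self-contained sketch of that rolling pattern (`fetch_batch` and the batch size are illustrative stand-ins, not part of this PR):

```python
from typing import List, Optional

def fetch_batch(batch: List[str], prefix_keys: List[str]) -> None:
    # Stand-in for page_get_func(operation, batch, indices, extra_info).
    print(f"get {batch} with prefix {prefix_keys}")

def transfer(hash_values: List[str],
             prefix_keys: Optional[List[str]],
             batch_size: int = 2) -> None:
    for i in range(0, len(hash_values), batch_size):
        batch = hash_values[i : i + batch_size]
        fetch_batch(batch, list(prefix_keys or []))
        # Mirror the PR's guard: only a non-empty, caller-supplied
        # prefix list keeps growing across batches.
        if prefix_keys:
            prefix_keys += batch

transfer(["h0", "h1", "h2", "h3"], prefix_keys=["root"])
```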
......@@ -656,6 +673,7 @@ class HiCacheController:
def _storage_hit_query(self, operation) -> tuple[list[str], int]:
last_hash = operation.last_hash
tokens_to_fetch = operation.token_ids
prefix_keys = operation.prefix_keys.copy() if operation.prefix_keys else None
storage_query_count = 0
hash_value = []
......@@ -673,11 +691,15 @@ class HiCacheController:
batch_tokens[i : i + self.page_size], last_hash
)
batch_hashes.append(last_hash)
hit_page_num = self.storage_backend.batch_exists(batch_hashes)
extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
hit_page_num = self.storage_backend.batch_exists(batch_hashes, extra_info)
hash_value.extend(batch_hashes[:hit_page_num])
storage_query_count += hit_page_num * self.page_size
if hit_page_num < len(batch_hashes):
break
if prefix_keys and len(prefix_keys) > 0:
prefix_keys += batch_hashes
return hash_value, storage_query_count
def prefetch_thread_func(self):
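`_storage_hit_query` above hashes the request page by page, chaining each page's hash to its predecessor, and stops at the first miss because `batch_exists` reports only the consecutive hit count from the front. A toy sketch of the chaining, with a hashlib stand-in for the controller's real hash helper (which this diff does not show):

```python
import hashlib
from typing import List, Optional

def page_hash(tokens: List[int], prior_hash: Optional[str]) -> str:
    # Stand-in: each page hash folds in the previous page's hash, so
    # identical pages at different positions map to different keys.
    h = hashlib.sha256()
    if prior_hash:
        h.update(prior_hash.encode())
    h.update(str(tokens).encode())
    return h.hexdigest()

tokens, page_size = list(range(12)), 4
last_hash, page_hashes = None, []
for i in range(0, len(tokens), page_size):
    last_hash = page_hash(tokens[i : i + page_size], last_hash)
    page_hashes.append(last_hash)
print(page_hashes)  # three chained page keys
```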
......@@ -734,28 +756,34 @@ class HiCacheController:
host_indices: torch.Tensor,
token_ids: List[int],
hash_value: Optional[List[str]] = None,
prefix_keys: Optional[List[str]] = None,
) -> int:
"""
Write KV caches from host memory to storage backend.
"""
operation = StorageOperation(host_indices, token_ids, hash_value=hash_value)
operation = StorageOperation(
host_indices, token_ids, hash_value=hash_value, prefix_keys=prefix_keys
)
self.backup_queue.put(operation)
return operation.id
# todo: deprecate
def _generic_page_set(self, hash_values, host_indices) -> bool:
def _generic_page_set(self, hash_values, host_indices, extra_info=None) -> bool:
data = [
self.mem_pool_host.get_data_page(host_indices[i * self.page_size])
for i in range(len(hash_values))
]
return self.storage_backend.batch_set(hash_values, data)
def _page_set_zero_copy(self, hash_values, host_indices) -> bool:
return all(self.storage_backend.batch_set_v1(hash_values, host_indices))
def _page_set_zero_copy(self, hash_values, host_indices, extra_info=None) -> bool:
return all(
self.storage_backend.batch_set_v1(hash_values, host_indices, extra_info)
)
# Backup batch by batch
def _page_backup(self, operation):
# Backup batch by batch
prefix_keys = operation.prefix_keys
for i in range(0, len(operation.hash_value), self.storage_batch_size):
batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
batch_host_indices = operation.host_indices[
......@@ -763,12 +791,16 @@ class HiCacheController:
]
# Write one batch of pages and record whether it succeeded.
# todo: allow partial success
success = self.page_set_func(batch_hashes, batch_host_indices)
extra_info = HiCacheStorageExtraInfo(prefix_keys=prefix_keys)
success = self.page_set_func(batch_hashes, batch_host_indices, extra_info)
if not success:
logger.warning(
f"Failed to write {len(batch_hashes)} pages to storage."
)
break
if prefix_keys and len(prefix_keys) > 0:
prefix_keys += batch_hashes
operation.completed_tokens += self.page_size * len(batch_hashes)
def backup_thread_func(self):
......
......@@ -1491,8 +1491,18 @@ class Scheduler(
last_hash = req.last_host_node.get_last_hash_value()
matched_len = len(req.prefix_indices) + req.host_hit_length
new_input_tokens = req.fill_ids[matched_len:]
prefix_keys = (
req.last_node.get_prefix_hash_values(req.last_node.parent)
if self.tree_cache.hicache_storage_pass_prefix_keys
else None
)
self.tree_cache.prefetch_from_storage(
req.rid, req.last_host_node, new_input_tokens, last_hash
req.rid,
req.last_host_node,
new_input_tokens,
last_hash,
prefix_keys,
)
def _add_request_to_queue(self, req: Req, is_retracted: bool = False):
......
......@@ -36,6 +36,7 @@ class HiCacheStorageConfig:
@dataclass
class HiCacheStorageExtraInfo:
prefix_keys: Optional[List[str]] = None
extra_info: Optional[dict] = None
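`HiCacheStorageExtraInfo` is the envelope that carries the prefix keys (plus a free-form dict for backend-specific extras) into `batch_exists`, `batch_get_v1`, and `batch_set_v1`. A minimal construction example; the field values here are illustrative:

```python
from sglang.srt.mem_cache.hicache_storage import HiCacheStorageExtraInfo

info = HiCacheStorageExtraInfo(
    prefix_keys=["hash_of_page_0", "hash_of_page_1"],  # ancestor page hashes
    extra_info={"hint": "demo"},  # free-form, backend-specific (illustrative)
)
```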
......@@ -139,7 +140,9 @@ class HiCacheStorage(ABC):
pass
# TODO: Use a finer-grained return type (e.g., List[bool])
def batch_exists(self, keys: List[str]) -> int:
def batch_exists(
self, keys: List[str], extra_info: Optional[HiCacheStorageExtraInfo] = None
) -> int:
"""
Check if the keys exist in the storage.
return the number of consecutive existing keys from the start.
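The contract is positional: the return value is how many keys exist consecutively from the start of the list, not the total hit count, since the caller only prefetches an unbroken prefix. A generic sketch of a conforming implementation, assuming a hypothetical single-key `_exists` probe:

```python
from typing import List, Optional

class ConsecutivePrefixMixin:
    """Sketch of a backend honoring the consecutive-prefix contract."""

    def batch_exists(
        self,
        keys: List[str],
        extra_info: Optional["HiCacheStorageExtraInfo"] = None,
    ) -> int:
        count = 0
        for key in keys:
            if not self._exists(key):  # hypothetical single-key probe
                break  # stop at the first gap; later hits don't count
            count += 1
        return count
```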
......
......@@ -84,12 +84,14 @@ class HiRadixCache(RadixCache):
prefetch_threshold,
prefetch_timeout_base,
prefetch_timeout_per_ki_token,
hicache_storage_pass_prefix_keys,
) = self._parse_storage_backend_extra_config(storage_backend_extra_config)
self.prefetch_threshold = prefetch_threshold
self.prefetch_timeout_base = prefetch_timeout_base
self.prefetch_timeout_per_page = (
page_size / 1024 * prefetch_timeout_per_ki_token
)
self.hicache_storage_pass_prefix_keys = hicache_storage_pass_prefix_keys
# TODO: support more timeout check functions
self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
self.prefetch_stop_policy = hicache_storage_prefetch_policy
......@@ -149,7 +151,7 @@ class HiRadixCache(RadixCache):
storage_backend_extra_config: JSON string containing extra configuration
Returns:
tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token)
tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token, hicache_storage_pass_prefix_keys)
"""
# Parse extra config JSON if provided
extra_config = {}
......@@ -165,6 +167,9 @@ class HiRadixCache(RadixCache):
prefetch_timeout_per_ki_token = extra_config.pop(
"prefetch_timeout_per_ki_token", 0.25
) # seconds per 1024 tokens
hicache_storage_pass_prefix_keys = extra_config.pop(
"hicache_storage_pass_prefix_keys", False
)
if not isinstance(prefetch_threshold, int):
raise ValueError(
......@@ -184,6 +189,7 @@ class HiRadixCache(RadixCache):
prefetch_threshold,
float(prefetch_timeout_base),
float(prefetch_timeout_per_ki_token),
hicache_storage_pass_prefix_keys,
)
def reset(self):
......@@ -245,8 +251,14 @@ class HiRadixCache(RadixCache):
return len(host_indices)
def write_backup_storage(self, node: TreeNode):
prefix_keys = (
node.get_prefix_hash_values(node.parent)
if self.hicache_storage_pass_prefix_keys
else None
)
operation_id = self.cache_controller.write_storage(
node.host_value, node.key, node.hash_value
node.host_value, node.key, node.hash_value, prefix_keys
)
self.ongoing_backup[operation_id] = node
node.protect_host()
......@@ -700,6 +712,7 @@ class HiRadixCache(RadixCache):
last_host_node: TreeNode,
new_input_tokens: List[int],
last_hash: Optional[str] = None,
prefix_keys: Optional[List[str]] = None,
):
# align the number of tokens to fetch to the page size
prefetch_length = len(new_input_tokens) - (
......@@ -723,7 +736,7 @@ class HiRadixCache(RadixCache):
# no sufficient host memory for prefetch
return
operation = self.cache_controller.prefetch(
req_id, host_indices, new_input_tokens, last_hash
req_id, host_indices, new_input_tokens, last_hash, prefix_keys
)
self.ongoing_prefetch[req_id] = (
last_host_node,
......
......@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
import heapq
import time
from collections import defaultdict
from functools import partial
from functools import lru_cache, partial
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union
import torch
......@@ -114,6 +114,13 @@ class TreeNode:
return None
return self.hash_value[-1]
@lru_cache(maxsize=1)
def get_prefix_hash_values(self, node: TreeNode) -> List[str]:
if node is None or node.hash_value is None:
return []
return node.get_prefix_hash_values(node.parent) + node.hash_value
def __lt__(self, other: "TreeNode"):
return self.last_access_time < other.last_access_time
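`get_prefix_hash_values` walks up the radix tree and concatenates every ancestor's `hash_value` list in root-to-leaf order; callers pass `node.parent`, so the node's own pages are excluded, and `lru_cache(maxsize=1)` memoizes the most recent call. A toy reproduction with a minimal stand-in node class (the real `TreeNode` carries many more fields):

```python
from typing import List, Optional

class Node:
    def __init__(self, hash_value: Optional[List[str]],
                 parent: "Optional[Node]" = None):
        self.hash_value = hash_value
        self.parent = parent

    def get_prefix_hash_values(self, node: "Optional[Node]") -> List[str]:
        if node is None or node.hash_value is None:
            return []
        # Recurse to the root first so ancestor hashes come out in order.
        return node.get_prefix_hash_values(node.parent) + node.hash_value

root = Node(["h0"])
mid = Node(["h1", "h2"], root)
leaf = Node(["h3"], mid)
assert leaf.get_prefix_hash_values(leaf.parent) == ["h0", "h1", "h2"]
```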
......
......@@ -13,7 +13,11 @@ from aibrix_kvcache import (
)
from aibrix_kvcache.common.absl_logging import log_every_n_seconds
from sglang.srt.mem_cache.hicache_storage import HiCacheStorage, HiCacheStorageConfig
from sglang.srt.mem_cache.hicache_storage import (
HiCacheStorage,
HiCacheStorageConfig,
HiCacheStorageExtraInfo,
)
from sglang.srt.mem_cache.memory_pool_host import HostKVCache
logger = logging.getLogger(__name__)
......@@ -140,7 +144,9 @@ class AibrixKVCacheStorage(HiCacheStorage):
) -> bool:
return self.batch_set([key], [value], [target_location], [target_size])
def batch_exists(self, keys: List[str]) -> int:
def batch_exists(
self, keys: List[str], extra_info: Optional[HiCacheStorageExtraInfo] = None
) -> int:
block_hash = BlockHashes(keys, self.page_size)
status = self.kv_cache_manager.exists(None, block_hash)
if status.is_ok():
......
......@@ -408,7 +408,9 @@ class EICStorage(HiCacheStorage):
exist_num = self.batch_exists([key])
return exist_num == 1
def batch_exists(self, keys) -> int:
def batch_exists(
self, keys, extra_info: Optional[HiCacheStorageExtraInfo] = None
) -> int:
if len(keys) == 0:
return 0
if self.use_zero_copy and not self.is_mla_model:
......
......@@ -454,7 +454,9 @@ class HiCacheHF3FS(HiCacheStorage):
result = self.metadata_client.exists(self.rank, [key])
return result[0] if result else False
def batch_exists(self, keys: List[str]) -> int:
def batch_exists(
self, keys: List[str], extra_info: Optional[HiCacheStorageExtraInfo] = None
) -> int:
factor = 1
if self.is_zero_copy and not self.is_mla_model:
keys = self._get_mha_zero_copy_keys(keys)
......
......@@ -399,7 +399,9 @@ class MooncakeStore(HiCacheStorage):
exist_result = self._batch_exist([key])
return exist_result[0] == 1
def batch_exists(self, keys) -> int:
def batch_exists(
self, keys, extra_info: Optional[HiCacheStorageExtraInfo] = None
) -> int:
if self.is_mla_backend:
query_keys = [f"{key}_k" for key in keys]
key_multiplier = 1
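All four bundled backends (AIBrix KVCache, EIC, HF3FS, Mooncake) only gain the widened signature in this PR and ignore `extra_info` for now. A backend that wants the prefix context could consume it along these lines; this subclass and its helpers are hypothetical, not part of the change:

```python
from typing import List, Optional

from sglang.srt.mem_cache.hicache_storage import (
    HiCacheStorage,
    HiCacheStorageExtraInfo,
)

class PrefixAwareStorage(HiCacheStorage):
    """Hypothetical backend sketch using prefix keys as a routing hint."""

    def batch_exists(
        self,
        keys: List[str],
        extra_info: Optional[HiCacheStorageExtraInfo] = None,
    ) -> int:
        prefix = extra_info.prefix_keys if extra_info else None
        # e.g. derive a shard or locality hint from the oldest prefix key.
        shard = self._shard_for(prefix[0]) if prefix else self._default_shard()
        return shard.count_consecutive_hits(keys)  # hypothetical helper
```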
......
......@@ -29,6 +29,7 @@ class HiCacheStorage3FSBackendBaseMixin(HiCacheStorageBaseMixin):
"numjobs": 2,
"entries": 8,
"use_mock_hf3fs_client": True,
"hicache_storage_pass_prefix_keys": True,
}
# Write config to temporary file
......
......@@ -4,6 +4,7 @@ Usage:
python3 -m pytest test/srt/hicache/test_hicache_storage_e2e.py -v
"""
import json
import os
import random
import tempfile
......@@ -70,6 +71,9 @@ class HiCacheStorageBaseMixin:
@classmethod
def _get_base_server_args(cls):
"""Get base server arguments - can be extended in subclasses"""
extra_config = {
"hicache_storage_pass_prefix_keys": True,
}
return {
"--enable-hierarchical-cache": True,
"--mem-fraction-static": 0.6,
......@@ -78,6 +82,7 @@ class HiCacheStorageBaseMixin:
"--enable-cache-report": True,
"--hicache-storage-prefetch-policy": "wait_complete",
"--hicache-storage-backend": "file",
"--hicache-storage-backend-extra-config": json.dumps(extra_config),
}
@classmethod
......