Unverified commit bfa27438, authored by ykwd, committed by GitHub
Browse files

[HiCache] Configurable and Dynamic Prefetch Timeout (#10512)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 86cb4db0
......@@ -250,7 +250,7 @@ class HiCacheController:
storage_backend: Optional[str] = None,
prefetch_threshold: int = 256,
model_name: Optional[str] = None,
storage_backend_extra_config: Optional[str] = None,
storage_backend_extra_config: Optional[dict] = None,
):
self.mem_pool_device_allocator = token_to_kv_pool_allocator
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
......@@ -361,7 +361,7 @@ class HiCacheController:
def _generate_storage_config(
self,
model_name: Optional[str] = None,
storage_backend_extra_config: Optional[str] = None,
storage_backend_extra_config: Optional[dict] = None,
):
if is_dp_attention_enabled():
......@@ -376,23 +376,13 @@ class HiCacheController:
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
# Parse extra config JSON if provided
extra_config = None
if storage_backend_extra_config:
try:
import json
extra_config = json.loads(storage_backend_extra_config)
except Exception as e:
logger.error(f"Invalid backend extra config JSON: {e}")
return HiCacheStorageConfig(
tp_rank=self.tp_rank,
tp_size=self.tp_size,
is_mla_model=is_mla_backend,
is_page_first_layout=self.mem_pool_host.layout == "page_first",
model_name=model_name,
extra_config=extra_config,
extra_config=storage_backend_extra_config,
)
def reset(self):
......
import heapq
import json
import logging
import threading
import time
from queue import Queue
from typing import List, Optional
import torch
......@@ -78,9 +78,19 @@ class HiRadixCache(RadixCache):
self.enable_storage = hicache_storage_backend is not None
self.enable_storage_metrics = self.enable_storage and enable_metrics
# todo: customizable storage prefetch threshold and timeout
self.prefetch_threshold = 256
self.prefetch_timeout = 3 # seconds
(
extra_config,
prefetch_threshold,
prefetch_timeout_base,
prefetch_timeout_per_ki_token,
) = self._parse_storage_backend_extra_config(storage_backend_extra_config)
self.prefetch_threshold = prefetch_threshold
self.prefetch_timeout_base = prefetch_timeout_base
self.prefetch_timeout_per_page = (
page_size / 1024 * prefetch_timeout_per_ki_token
)
# TODO: support more timeout check functions
self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
self.prefetch_stop_policy = hicache_storage_prefetch_policy
self.load_cache_event = threading.Event()
......@@ -95,7 +105,7 @@ class HiRadixCache(RadixCache):
storage_backend=hicache_storage_backend,
prefetch_threshold=self.prefetch_threshold,
model_name=model_name,
storage_backend_extra_config=storage_backend_extra_config,
storage_backend_extra_config=extra_config,
)
if self.enable_storage_metrics:
# TODO: support pp
......@@ -127,6 +137,53 @@ class HiRadixCache(RadixCache):
eviction_policy=eviction_policy,
)
def _parse_storage_backend_extra_config(
self, storage_backend_extra_config: Optional[str]
):
"""
Parse storage backend extra config JSON and extract specific parameters.
Args:
storage_backend_extra_config: JSON string containing extra configuration
Returns:
tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token)
"""
# Parse extra config JSON if provided
extra_config = {}
if storage_backend_extra_config:
try:
extra_config = json.loads(storage_backend_extra_config)
except Exception as e:
logger.error(f"Invalid backend extra config JSON: {e}")
raise e
prefetch_threshold = extra_config.pop("prefetch_threshold", 256) # tokens
prefetch_timeout_base = extra_config.pop("prefetch_timeout_base", 1) # seconds
prefetch_timeout_per_ki_token = extra_config.pop(
"prefetch_timeout_per_ki_token", 0.25
) # seconds per 1024 tokens
if not isinstance(prefetch_threshold, int):
raise ValueError(
f"prefetch_threshold must be int, got {type(prefetch_threshold).__name__}"
)
if not isinstance(prefetch_timeout_base, (int, float)):
raise ValueError(
f"prefetch_timeout_base must be number, got {type(prefetch_timeout_base).__name__}"
)
if not isinstance(prefetch_timeout_per_ki_token, (int, float)):
raise ValueError(
f"prefetch_timeout_per_ki_token must be number, got {type(prefetch_timeout_per_ki_token).__name__}"
)
return (
extra_config,
prefetch_threshold,
float(prefetch_timeout_base),
float(prefetch_timeout_per_ki_token),
)
def reset(self):
TreeNode.counter = 0
self.cache_controller.reset()
......@@ -490,6 +547,15 @@ class HiRadixCache(RadixCache):
host_indices = torch.cat(host_indices_list, dim=0)
cc.mem_pool_host.free(host_indices)
# Timeout is linearly increasing with the number of pages
def _prefetch_timeout_check_linear_func(self, operation: PrefetchOperation):
    """
    Report whether a prefetch operation has outlived its linear deadline.

    The deadline is ``prefetch_timeout_base`` seconds plus
    ``prefetch_timeout_per_page`` seconds for every entry in
    ``operation.hash_value`` (presumably one entry per page -- confirm
    against where hash_value is populated).

    Args:
        operation: the prefetch operation; its ``start_time`` is compared
            against ``time.monotonic()``.

    Returns:
        bool: True when the elapsed time is strictly greater than the deadline.
    """
    elapsed_seconds = time.monotonic() - operation.start_time
    allowed_seconds = (
        self.prefetch_timeout_base
        + len(operation.hash_value) * self.prefetch_timeout_per_page
    )
    return elapsed_seconds > allowed_seconds
def can_terminate_prefetch(self, operation: PrefetchOperation):
can_terminate = True
......@@ -506,9 +572,7 @@ class HiRadixCache(RadixCache):
if self.prefetch_stop_policy == "wait_complete":
can_terminate = completed
elif self.prefetch_stop_policy == "timeout":
can_terminate = completed or (
time.monotonic() - operation.start_time > self.prefetch_timeout
)
can_terminate = completed or self.is_prefetch_timeout(operation)
else:
# unknown prefetch stop policy, just return True
return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment