Unverified Commit bfa27438 authored by ykwd's avatar ykwd Committed by GitHub
Browse files

[HiCache] Configurable and Dynamic Prefetch Timeout (#10512)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 86cb4db0
...@@ -250,7 +250,7 @@ class HiCacheController: ...@@ -250,7 +250,7 @@ class HiCacheController:
storage_backend: Optional[str] = None, storage_backend: Optional[str] = None,
prefetch_threshold: int = 256, prefetch_threshold: int = 256,
model_name: Optional[str] = None, model_name: Optional[str] = None,
storage_backend_extra_config: Optional[str] = None, storage_backend_extra_config: Optional[dict] = None,
): ):
self.mem_pool_device_allocator = token_to_kv_pool_allocator self.mem_pool_device_allocator = token_to_kv_pool_allocator
self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache() self.mem_pool_device = token_to_kv_pool_allocator.get_kvcache()
...@@ -361,7 +361,7 @@ class HiCacheController: ...@@ -361,7 +361,7 @@ class HiCacheController:
def _generate_storage_config( def _generate_storage_config(
self, self,
model_name: Optional[str] = None, model_name: Optional[str] = None,
storage_backend_extra_config: Optional[str] = None, storage_backend_extra_config: Optional[dict] = None,
): ):
if is_dp_attention_enabled(): if is_dp_attention_enabled():
...@@ -376,23 +376,13 @@ class HiCacheController: ...@@ -376,23 +376,13 @@ class HiCacheController:
# Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool. # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool) is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
# Parse extra config JSON if provided
extra_config = None
if storage_backend_extra_config:
try:
import json
extra_config = json.loads(storage_backend_extra_config)
except Exception as e:
logger.error(f"Invalid backend extra config JSON: {e}")
return HiCacheStorageConfig( return HiCacheStorageConfig(
tp_rank=self.tp_rank, tp_rank=self.tp_rank,
tp_size=self.tp_size, tp_size=self.tp_size,
is_mla_model=is_mla_backend, is_mla_model=is_mla_backend,
is_page_first_layout=self.mem_pool_host.layout == "page_first", is_page_first_layout=self.mem_pool_host.layout == "page_first",
model_name=model_name, model_name=model_name,
extra_config=extra_config, extra_config=storage_backend_extra_config,
) )
def reset(self): def reset(self):
......
import heapq import heapq
import json
import logging import logging
import threading import threading
import time import time
from queue import Queue
from typing import List, Optional from typing import List, Optional
import torch import torch
...@@ -78,9 +78,19 @@ class HiRadixCache(RadixCache): ...@@ -78,9 +78,19 @@ class HiRadixCache(RadixCache):
self.enable_storage = hicache_storage_backend is not None self.enable_storage = hicache_storage_backend is not None
self.enable_storage_metrics = self.enable_storage and enable_metrics self.enable_storage_metrics = self.enable_storage and enable_metrics
# todo: customizable storage prefetch threshold and timeout (
self.prefetch_threshold = 256 extra_config,
self.prefetch_timeout = 3 # seconds prefetch_threshold,
prefetch_timeout_base,
prefetch_timeout_per_ki_token,
) = self._parse_storage_backend_extra_config(storage_backend_extra_config)
self.prefetch_threshold = prefetch_threshold
self.prefetch_timeout_base = prefetch_timeout_base
self.prefetch_timeout_per_page = (
page_size / 1024 * prefetch_timeout_per_ki_token
)
# TODO: support more timeout check functions
self.is_prefetch_timeout = self._prefetch_timeout_check_linear_func
self.prefetch_stop_policy = hicache_storage_prefetch_policy self.prefetch_stop_policy = hicache_storage_prefetch_policy
self.load_cache_event = threading.Event() self.load_cache_event = threading.Event()
...@@ -95,7 +105,7 @@ class HiRadixCache(RadixCache): ...@@ -95,7 +105,7 @@ class HiRadixCache(RadixCache):
storage_backend=hicache_storage_backend, storage_backend=hicache_storage_backend,
prefetch_threshold=self.prefetch_threshold, prefetch_threshold=self.prefetch_threshold,
model_name=model_name, model_name=model_name,
storage_backend_extra_config=storage_backend_extra_config, storage_backend_extra_config=extra_config,
) )
if self.enable_storage_metrics: if self.enable_storage_metrics:
# TODO: support pp # TODO: support pp
...@@ -127,6 +137,53 @@ class HiRadixCache(RadixCache): ...@@ -127,6 +137,53 @@ class HiRadixCache(RadixCache):
eviction_policy=eviction_policy, eviction_policy=eviction_policy,
) )
def _parse_storage_backend_extra_config(
self, storage_backend_extra_config: Optional[str]
):
"""
Parse storage backend extra config JSON and extract specific parameters.
Args:
storage_backend_extra_config: JSON string containing extra configuration
Returns:
tuple: (extra_config_dict, prefetch_threshold, prefetch_timeout_base, prefetch_timeout_per_ki_token)
"""
# Parse extra config JSON if provided
extra_config = {}
if storage_backend_extra_config:
try:
extra_config = json.loads(storage_backend_extra_config)
except Exception as e:
logger.error(f"Invalid backend extra config JSON: {e}")
raise e
prefetch_threshold = extra_config.pop("prefetch_threshold", 256) # tokens
prefetch_timeout_base = extra_config.pop("prefetch_timeout_base", 1) # seconds
prefetch_timeout_per_ki_token = extra_config.pop(
"prefetch_timeout_per_ki_token", 0.25
) # seconds per 1024 tokens
if not isinstance(prefetch_threshold, int):
raise ValueError(
f"prefetch_threshold must be int, got {type(prefetch_threshold).__name__}"
)
if not isinstance(prefetch_timeout_base, (int, float)):
raise ValueError(
f"prefetch_timeout_base must be number, got {type(prefetch_timeout_base).__name__}"
)
if not isinstance(prefetch_timeout_per_ki_token, (int, float)):
raise ValueError(
f"prefetch_timeout_per_ki_token must be number, got {type(prefetch_timeout_per_ki_token).__name__}"
)
return (
extra_config,
prefetch_threshold,
float(prefetch_timeout_base),
float(prefetch_timeout_per_ki_token),
)
def reset(self): def reset(self):
TreeNode.counter = 0 TreeNode.counter = 0
self.cache_controller.reset() self.cache_controller.reset()
...@@ -490,6 +547,15 @@ class HiRadixCache(RadixCache): ...@@ -490,6 +547,15 @@ class HiRadixCache(RadixCache):
host_indices = torch.cat(host_indices_list, dim=0) host_indices = torch.cat(host_indices_list, dim=0)
cc.mem_pool_host.free(host_indices) cc.mem_pool_host.free(host_indices)
# Timeout is linearly increasing with the number of pages
def _prefetch_timeout_check_linear_func(self, operation: PrefetchOperation):
    """Return True once `operation` has run longer than its linear time budget.

    The budget is ``self.prefetch_timeout_base`` plus
    ``self.prefetch_timeout_per_page`` for every entry in
    ``operation.hash_value``. With an empty ``hash_value`` only the base
    allowance applies.
    """
    elapsed = time.monotonic() - operation.start_time
    per_page_allowance = len(operation.hash_value) * self.prefetch_timeout_per_page
    budget = self.prefetch_timeout_base + per_page_allowance
    return elapsed > budget
def can_terminate_prefetch(self, operation: PrefetchOperation): def can_terminate_prefetch(self, operation: PrefetchOperation):
can_terminate = True can_terminate = True
...@@ -506,9 +572,7 @@ class HiRadixCache(RadixCache): ...@@ -506,9 +572,7 @@ class HiRadixCache(RadixCache):
if self.prefetch_stop_policy == "wait_complete": if self.prefetch_stop_policy == "wait_complete":
can_terminate = completed can_terminate = completed
elif self.prefetch_stop_policy == "timeout": elif self.prefetch_stop_policy == "timeout":
can_terminate = completed or ( can_terminate = completed or self.is_prefetch_timeout(operation)
time.monotonic() - operation.start_time > self.prefetch_timeout
)
else: else:
# unknown prefetch stop policy, just return True # unknown prefetch stop policy, just return True
return True return True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment