Unverified Commit 145482f4 authored by Zhiqiang Xie, committed by GitHub

HiCache Storage TP Refinement (#8307)


Co-authored-by: pansicheng <sicheng.pan.chn@gmail.com>
parent 39fe1e88
@@ -219,6 +219,7 @@ class HiCacheController:
         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
         mem_pool_host: HostKVCache,
         page_size: int,
+        tp_group: torch.distributed.ProcessGroup,
         load_cache_event: threading.Event = None,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
@@ -244,11 +245,17 @@ class HiCacheController:
         self.enable_storage = False
         # todo: move backend initialization to storage backend module
         if storage_backend is not None:
+            # create a new communication group for synchronizing storage operations across TP workers
+            self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
+            if self.tp_world_size > 1:
+                group_ranks = torch.distributed.get_process_group_ranks(tp_group)
+                self.tp_group = torch.distributed.new_group(group_ranks, backend="gloo")
             if storage_backend == "file":
                 self.storage_backend = HiCacheFile()
                 self.enable_storage = True
                 # todo: threshold policy for prefetching
-                self.prefetch_threshold = prefetch_threshold
+                self.prefetch_threshold = max(prefetch_threshold, self.page_size)
             else:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
@@ -568,13 +575,32 @@ class HiCacheController:
                     else:
                         break
+                if self.tp_world_size > 1:
+                    storage_hit_count_tensor = torch.tensor(
+                        storage_hit_count, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        storage_hit_count_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    storage_hit_count = storage_hit_count_tensor.item()
                 if storage_hit_count < self.prefetch_threshold:
                     # not to prefetch if not enough benefits
                     self.prefetch_revoke_queue.put(operation.request_id)
+                    logger.debug(
+                        f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
+                    )
                 else:
-                    operation.hash_value = hash_value
+                    operation.hash_value = hash_value[
+                        : (storage_hit_count // self.page_size)
+                    ]
+                    # free the pre-allocated memory for pages that are not hit
+                    self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
+                    operation.host_indices = operation.host_indices[:storage_hit_count]
                     logger.debug(
-                        f"Prefetching {len(hash_value)} pages for request {operation.request_id}."
+                        f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
                     )
                 self.prefetch_buffer.put(operation)
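Each TP worker probes the storage backend independently, so their hit counts can disagree; reducing with `ReduceOp.MIN` makes every worker commit to the same (smallest) count before truncating `hash_value` and the host indices. A small sketch of that synchronization step, with a hypothetical helper name:

```python
import torch
import torch.distributed as dist

def agree_on_hit_count(local_hit_tokens: int, group) -> int:
    """Return the smallest storage hit count across TP workers."""
    t = torch.tensor(local_hit_tokens, dtype=torch.int)  # CPU tensor; reduced over the gloo group
    dist.all_reduce(t, op=dist.ReduceOp.MIN, group=group)
    return t.item()
```

With the agreed count in hand, every rank keeps `count // page_size` hashes and frees the host pages it pre-allocated for the rest, exactly as in the hunk above.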
@@ -611,17 +637,37 @@ class HiCacheController:
                     last_hash = get_hash_str(
                         tokens_to_backup[i : i + self.page_size], last_hash
                     )
-                    # todo, handle failures in storage backend
-                    self.storage_backend.set(
+                    success = self.storage_backend.set(
                         last_hash,
                         self.mem_pool_host.get_flat_data_page(
                             operation.host_indices[i]
                         ),
                     )
+                    if not success:
+                        logger.warning(f"Failed to write page {last_hash} to storage.")
+                        break
                     operation.completed_tokens += self.page_size
                     operation.hash_value.append(last_hash)

-                self.ack_backup_queue.put((operation.id, operation.hash_value))
+                min_completed_tokens = operation.completed_tokens
+                if self.tp_world_size > 1:
+                    completed_tokens_tensor = torch.tensor(
+                        min_completed_tokens, dtype=torch.int
+                    )
+                    torch.distributed.all_reduce(
+                        completed_tokens_tensor,
+                        op=torch.distributed.ReduceOp.MIN,
+                        group=self.tp_group,
+                    )
+                    min_completed_tokens = completed_tokens_tensor.item()
+                self.ack_backup_queue.put(
+                    (
+                        operation.id,
+                        operation.hash_value[: min_completed_tokens // self.page_size],
+                        min_completed_tokens,
+                    )
+                )
             except Empty:
                 continue
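The backup acknowledgement now carries only the prefix that every TP worker managed to write, again via a MIN reduction over completed tokens. Illustrative numbers (made up, not from the PR) for how the ack payload is trimmed:

```python
page_size = 64
hash_value = [f"page-{i}" for i in range(8)]   # this rank backed up 8 pages locally
min_completed_tokens = 448                     # MIN across ranks: the slowest rank wrote 7 pages
acked_hashes = hash_value[: min_completed_tokens // page_size]
assert len(acked_hashes) == 7                  # only pages present on every rank are acknowledged
```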
@@ -9,6 +9,12 @@ import torch

 logger = logging.getLogger(__name__)

+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+

 def get_hash_str(token_ids: List[int], prior_hash: Optional[str] = None) -> str:
     hasher = hashlib.sha256()
@@ -80,13 +86,20 @@ class HiCacheFile(HiCacheStorage):

     def __init__(self, file_path: str = "/tmp/hicache"):
         self.file_path = file_path
-        if not os.path.exists(self.file_path):
+        tp_rank = get_tensor_model_parallel_rank()
+        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
+        if not os.path.exists(self.file_path) and tp_rank == 0:
             os.makedirs(self.file_path)
             logger.info(f"Created HiCacheFile storage directory at {self.file_path}")

+    def _get_suffixed_key(self, key: str) -> str:
+        return key + self.tp_suffix
+
     def get(
         self, key: str, target_location: Optional[torch.Tensor] = None
     ) -> torch.Tensor | None:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             # todo: fixing the target_location logic to enable in-place loading
@@ -112,6 +125,7 @@ class HiCacheFile(HiCacheStorage):
         ]

     def set(self, key: str, value: torch.Tensor) -> bool:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         if self.exists(key):
             logger.debug(f"Key {key} already exists. Skipped.")
@@ -130,10 +144,12 @@ class HiCacheFile(HiCacheStorage):
         return True

     def exists(self, key: str) -> bool:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         return os.path.exists(tensor_path)

     def delete(self, key: str) -> None:
+        key = self._get_suffixed_key(key)
         tensor_path = os.path.join(self.file_path, f"{key}.bin")
         try:
             os.remove(tensor_path)
......
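Because each TP worker holds a different shard of the KV cache, `HiCacheFile` now appends `_{tp_rank}_{tp_size}` to every key so shards from different ranks never collide in the shared directory. A toy store showing the same keying scheme (class name and I/O details are illustrative, not HiCacheFile's actual implementation):

```python
import os
import torch

class ShardedFileStore:
    def __init__(self, root: str, tp_rank: int, tp_size: int):
        self.root = root
        # one file per (key, rank) pair; single-GPU runs keep the plain key
        self.suffix = f"_{tp_rank}_{tp_size}" if tp_size > 1 else ""
        # exist_ok sidesteps the mkdir race the PR avoids by letting only rank 0 create the directory
        os.makedirs(self.root, exist_ok=True)

    def _path(self, key: str) -> str:
        return os.path.join(self.root, f"{key}{self.suffix}.bin")

    def set(self, key: str, value: torch.Tensor) -> bool:
        try:
            torch.save(value, self._path(key))
            return True
        except OSError:
            return False

    def exists(self, key: str) -> bool:
        return os.path.exists(self._path(key))
```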
@@ -50,6 +50,7 @@ class HiRadixCache(RadixCache):
             raise ValueError(f"HiRadixCache only supports MHA and MLA yet")

         self.tp_group = tp_cache_group
+        self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
         self.enable_storage = hicache_storage_backend is not None
         # todo: customizable storage prefetch threshold
         self.prefetch_threshold = 256
@@ -59,6 +60,7 @@ class HiRadixCache(RadixCache):
             token_to_kv_pool_allocator,
             self.token_to_kv_pool_host,
             page_size,
+            self.tp_group,
             load_cache_event=self.load_cache_event,
             write_policy=hicache_write_policy,
             io_backend=hicache_io_backend,
@@ -153,7 +155,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchronize TP workers to make the same update to radix cache
             torch.distributed.all_reduce(
                 queue_size,
@@ -353,7 +355,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchronize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
                 queue_size,
@@ -372,7 +374,7 @@ class HiRadixCache(RadixCache):
         queue_size = torch.tensor(
             self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
         )
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        if self.tp_world_size > 1:
             # synchronize TP workers to make the same update to hiradix cache
             torch.distributed.all_reduce(
                 queue_size,
@@ -380,9 +382,15 @@ class HiRadixCache(RadixCache):
                 group=self.tp_group,
             )
         for _ in range(queue_size.item()):
-            ack_id, hash_value = self.cache_controller.ack_backup_queue.get()
-            self.ongoing_backup[ack_id].hash_value = hash_value
-            self.ongoing_backup[ack_id].release_host()
+            ack_id, hash_value, completed_tokens = (
+                self.cache_controller.ack_backup_queue.get()
+            )
+            host_node = self.ongoing_backup[ack_id]
+            if completed_tokens < len(host_node.key):
+                # backup is only partially successful, split the node
+                new_node = self._split_node(host_node.key, host_node, completed_tokens)
+                new_node.hash_value = hash_value
+            host_node.release_host()
             del self.ongoing_backup[ack_id]

     def check_prefetch_progress(self, req_id: str):
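When only part of a node's tokens reach storage, the hunk above splits the radix node at the completed-token boundary so that just the confirmed prefix carries the storage hash values. A rough, self-contained illustration of that idea (the real `_split_node` also rewires parent/child links and host indices, omitted here):

```python
class _Node:
    """Bare-bones stand-in for a radix-tree node (illustrative only)."""
    def __init__(self, key, hash_value=None):
        self.key = list(key)
        self.hash_value = hash_value

def split_at_completed(node: _Node, completed_tokens: int, hashes):
    # carve off the prefix that every TP worker confirmed was backed up;
    # only that prefix is allowed to carry the storage hash values
    backed_up = _Node(node.key[:completed_tokens], hash_value=hashes)
    node.key = node.key[completed_tokens:]
    return backed_up
```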
@@ -400,15 +408,18 @@ class HiRadixCache(RadixCache):
         )
         logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

-        min_completed_tokens = torch.tensor(completed_tokens, dtype=torch.int)
-        if torch.distributed.get_world_size(group=self.tp_group) > 1:
+        min_completed_tokens = completed_tokens
+        if self.tp_world_size > 1:
             # synchronize TP workers to make the same update to hiradix cache
+            completed_tokens_tensor = torch.tensor(
+                min_completed_tokens, dtype=torch.int
+            )
             torch.distributed.all_reduce(
-                min_completed_tokens,
+                completed_tokens_tensor,
                 op=torch.distributed.ReduceOp.MIN,
                 group=self.tp_group,
             )
-        min_completed_tokens = min_completed_tokens.item()
+            min_completed_tokens = completed_tokens_tensor.item()

         fetched_token_ids = token_ids[:min_completed_tokens]
         written_indices = host_indices[:min_completed_tokens]
         matched_length = self._insert_helper_host(
@@ -465,16 +476,19 @@ class HiRadixCache(RadixCache):
         new_input_tokens: List[int],
         last_hash: Optional[str] = None,
     ):
-        if not self.enable_storage or len(new_input_tokens) < self.prefetch_threshold:
+        # align the number of fetching tokens to the page size
+        prefetch_length = len(new_input_tokens) - (
+            len(new_input_tokens) % self.page_size
+        )
+        new_input_tokens = new_input_tokens[:prefetch_length]
+        if not self.enable_storage or prefetch_length < self.prefetch_threshold:
             return

         last_host_node.protect_host()
-        host_indices = self.cache_controller.mem_pool_host.alloc(len(new_input_tokens))
+        host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
         if host_indices is None:
-            self.evict_host(len(new_input_tokens))
-            host_indices = self.cache_controller.mem_pool_host.alloc(
-                len(new_input_tokens)
-            )
+            self.evict_host(prefetch_length)
+            host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
             if host_indices is None:
                 last_host_node.release_host()
                 # no sufficient host memory to prefetch
......
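`prefetch_from_storage` now rounds the request down to a whole number of pages before allocating host memory, which is what lets the new page-multiple assert in `HostKVCache.alloc` (next hunk) hold. A quick numeric illustration (values are made up):

```python
page_size = 64
new_input_tokens = list(range(300))  # 300 uncached tokens arrive
prefetch_length = len(new_input_tokens) - (len(new_input_tokens) % page_size)
assert prefetch_length == 256        # 4 full pages; the 44-token tail is not prefetched
new_input_tokens = new_input_tokens[:prefetch_length]
```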
@@ -126,6 +126,9 @@ class HostKVCache(abc.ABC):
     @synchronized()
     def alloc(self, need_size: int) -> torch.Tensor:
+        assert (
+            need_size % self.page_size == 0
+        ), "The requested size should be a multiple of the page size."
         if need_size > self.available_size():
             return None
......
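The assert formalizes the contract that host allocations are page-granular. A toy allocator sketch under that contract (simplified; the real `HostKVCache` bookkeeping also involves the `@synchronized()` locking and eviction):

```python
import torch

class PagedHostPool:
    def __init__(self, num_tokens: int, page_size: int):
        self.page_size = page_size
        self.free_slots = torch.arange(num_tokens, dtype=torch.int64)

    def available_size(self) -> int:
        return self.free_slots.numel()

    def alloc(self, need_size: int):
        # callers must request whole pages, as the new assert enforces
        assert need_size % self.page_size == 0, "need_size must be a multiple of page_size"
        if need_size > self.available_size():
            return None
        out, self.free_slots = self.free_slots[:need_size], self.free_slots[need_size:]
        return out
```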