"tests/vscode:/vscode.git/clone" did not exist on "437cb36c65140cf71b8ea9419351ffa3aee62c14"
Unverified Commit 8b6966d0 authored by Zhiqiang Xie, committed by GitHub

[HiCache] Storage Refactoring (#9797)


Co-authored-by: pansicheng <27603155+pansicheng@users.noreply.github.com>
parent a391f73a
@@ -250,26 +250,21 @@ class HiCacheController:
         self.write_policy = write_policy
         self.page_size = page_size
         self.io_backend = io_backend

         self.enable_storage = False
-        # todo: move backend initialization to storage backend module
         if storage_backend is not None:
             self.storage_backend_type = storage_backend
             from sglang.srt.mem_cache.hicache_storage import get_hash_str

             self.get_hash_str = get_hash_str
             self.storage_config = self._generate_storage_config(
                 model_name, storage_backend_extra_config
             )
-            # In MLA backend, only one rank needs to backup the KV cache
+            # for MLA models, only one rank needs to backup the KV cache
             self.backup_skip = (
                 self.storage_config.is_mla_model
-                # todo: for load balancing, decide which rank to backup the KV cache by hash value
+                # todo: load balancing
                 and self.storage_config.tp_rank != 0
-                # todo: support other storage backends
-                and self.storage_backend_type in ["file", "mooncake"]
             )

             if storage_backend == "file":
@@ -309,12 +304,15 @@ class HiCacheController:
             raise NotImplementedError(
                 f"Unsupported storage backend: {storage_backend}"
             )

             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
             self.prefetch_capacity_limit = int(
                 0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
             )
+            # granularity of batch storage IO operations, in number of pages
+            self.storage_batch_size = 128
+
             # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
             self.prefetch_tokens_occupied = 0
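The hunk above replaces the scattered per-backend batch-size constants with a single `storage_batch_size` knob (128 pages). A minimal self-contained sketch of the slicing arithmetic this standardizes, with made-up sizes, since hashes are per page while host indices are per token:

```python
# Sketch only: values are illustrative, not sglang's real pool sizes.
page_size = 64            # tokens per page
storage_batch_size = 128  # pages per storage IO batch

hash_value = [f"hash_{i}" for i in range(300)]            # one hash per page
host_indices = list(range(len(hash_value) * page_size))   # one slot per token

for i in range(0, len(hash_value), storage_batch_size):
    batch_hashes = hash_value[i : i + storage_batch_size]
    # token-level slice covering exactly the pages in this batch
    batch_host_indices = host_indices[
        i * page_size : (i + len(batch_hashes)) * page_size
    ]
    assert len(batch_host_indices) == len(batch_hashes) * page_size
```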
@@ -325,12 +323,6 @@ class HiCacheController:
             self.prefetch_tp_group = torch.distributed.new_group(
                 group_ranks, backend="gloo"
             )
-            self.prefetch_io_tp_group = torch.distributed.new_group(
-                group_ranks, backend="gloo"
-            )
-            self.backup_tp_group = torch.distributed.new_group(
-                group_ranks, backend="gloo"
-            )

         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -380,6 +372,7 @@ class HiCacheController:
         self.prefetch_revoke_queue = Queue()
         self.ack_backup_queue = Queue()
+        self.host_mem_release_queue = Queue()

         self.prefetch_thread.start()
         self.backup_thread.start()
@@ -618,7 +611,11 @@ class HiCacheController:
         operation.mark_done()
         return operation.completed_tokens, operation.hash_value

-    # zero copy
+    def append_host_mem_release(self, host_indices: torch.Tensor):
+        chunks = host_indices.split(self.mem_pool_host.page_size)
+        for chunk in chunks:
+            self.host_mem_release_queue.put(chunk)
+
     def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
         hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
             hash_values, host_indices
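The new `append_host_mem_release` defers host-memory frees: it splits the indices into page-sized chunks and enqueues them, and the scheduler thread drains the queue later (see `drain_storage_control_queues` further down). A sketch of this chunk-and-enqueue pattern using only `torch` and `queue`, with illustrative sizes:

```python
from queue import Queue

import torch

page_size = 4
host_mem_release_queue: "Queue[torch.Tensor]" = Queue()

host_indices = torch.arange(10)  # 2 full pages plus a 2-element remainder
for chunk in host_indices.split(page_size):  # last chunk may be shorter
    host_mem_release_queue.put(chunk)

# The consumer can later concatenate all chunks and free them in one call.
drained = [host_mem_release_queue.get() for _ in range(host_mem_release_queue.qsize())]
print(torch.cat(drained))  # tensor([0, 1, ..., 9])
```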
@@ -631,7 +628,6 @@ class HiCacheController:
                 f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
             )

-    # zero copy
     def _mooncake_page_get(self, operation, hash_values, host_indices):
         key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
             hash_values,
@@ -650,9 +646,7 @@ class HiCacheController:
         if get_result != 0:
             operation.increment(get_result * self.page_size)

-    # non-zero copy
     def _generic_page_get(self, operation, hash_values, host_indices):
-        # todo: zero copy
         dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
             hash_values
         )
@@ -675,22 +669,19 @@ class HiCacheController:
     def _page_transfer(self, operation):
         # Select the get function and batch size
-        if self.is_mooncake_backend():
+        if self.storage_backend_type == "mooncake":
             get_func = self._mooncake_page_get
-            batch_size = 128
-        elif self.storage_backend_type == "hf3fs":
-            if self.mem_pool_host.layout == "page_first":
-                get_func = self._3fs_zero_copy_page_get
-            elif self.mem_pool_host.layout == "layer_first":
-                get_func = self._generic_page_get
-            batch_size = 128
+        elif (
+            self.storage_backend_type == "hf3fs"
+            and self.mem_pool_host.layout == "page_first"
+        ):
+            get_func = self._3fs_zero_copy_page_get
         else:
             get_func = self._generic_page_get
-            batch_size = 8

         # Transfer batch by batch
-        for i in range(0, len(operation.hash_value), batch_size):
-            batch_hashes = operation.hash_value[i : i + batch_size]
+        for i in range(0, len(operation.hash_value), self.storage_batch_size):
+            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
             batch_host_indices = operation.host_indices[
                 i * self.page_size : (i + len(batch_hashes)) * self.page_size
             ]
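The refactored `_page_transfer` folds the nested layout check into one condition: zero copy only when the backend is hf3fs and the host layout is page_first, generic per-page copy otherwise, with no per-backend batch size. A hypothetical standalone rendering of just the dispatch (the stub functions are ours, not sglang's):

```python
# Illustrative stand-ins for the three transfer paths in the diff.
def select_get_func(storage_backend_type: str, layout: str):
    def mooncake_page_get(): ...
    def hf3fs_zero_copy_page_get(): ...
    def generic_page_get(): ...

    if storage_backend_type == "mooncake":
        return mooncake_page_get
    elif storage_backend_type == "hf3fs" and layout == "page_first":
        return hf3fs_zero_copy_page_get
    else:
        # hf3fs with layer_first layout also lands here now
        return generic_page_get

assert select_get_func("hf3fs", "layer_first").__name__ == "generic_page_get"
```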
@@ -704,10 +695,9 @@ class HiCacheController:
         ):
             break  # Some operations fail or operation terminated by controller
         # release pre-allocated memory
-        self.mem_pool_host.free(operation.host_indices[operation.completed_tokens :])
-
-    def is_mooncake_backend(self):
-        return self.storage_backend_type == "mooncake"
+        self.append_host_mem_release(
+            operation.host_indices[operation.completed_tokens :]
+        )

     def prefetch_io_aux_func(self):
         """
@@ -717,47 +707,49 @@ class HiCacheController:
             try:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
                 self._page_transfer(operation)
-                if self.tp_world_size > 1:
-                    # to ensure all TP workers release the host memory at the same time
-                    torch.distributed.barrier(group=self.prefetch_io_tp_group)
                 # operation terminated by controller, release pre-allocated memory
-                self.mem_pool_host.free(
+                self.append_host_mem_release(
                     operation.host_indices[operation.completed_tokens :]
                 )
             except Empty:
                 continue

-    def prefetch_rate_limit_check(self) -> bool:
+    def prefetch_rate_limited(self) -> bool:
         """
         Rate limit the prefetching operations to avoid overwhelming the storage backend.
         """
         # cancel prefetch if too much memory is occupied
         if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
-            return False
+            return True
         # todo: more sophisticated rate limiting based on storage backend performance
-        return True
+        return False

-    def _generic_storage_hit_query(self, operation) -> tuple[list[str], int]:
+    def _storage_hit_query(self, operation) -> tuple[list[str], int]:
         last_hash = operation.last_hash
         tokens_to_fetch = operation.token_ids

         storage_query_count = 0
-        remaining_tokens = len(tokens_to_fetch)
         hash_value = []
-        while remaining_tokens >= self.page_size:
-            last_hash = self.get_hash_str(
-                tokens_to_fetch[
-                    storage_query_count : storage_query_count + self.page_size
-                ],
-                last_hash,
-            )
-            hash_value.append(last_hash)
-            storage_query_count += self.page_size
-            remaining_tokens -= self.page_size
-        # deferring to batch exists
-        hit_page_num = self.storage_backend.batch_exists(hash_value)
-        return hash_value[:hit_page_num], hit_page_num * self.page_size
+
+        for start in range(
+            0, len(tokens_to_fetch), self.page_size * self.storage_batch_size
+        ):
+            end = min(
+                start + self.page_size * self.storage_batch_size, len(tokens_to_fetch)
+            )
+            batch_tokens = tokens_to_fetch[start:end]
+            batch_hashes = []
+            for i in range(0, len(batch_tokens), self.page_size):
+                last_hash = self.get_hash_str(
+                    batch_tokens[i : i + self.page_size], last_hash
+                )
+                batch_hashes.append(last_hash)
+            hit_page_num = self.storage_backend.batch_exists(batch_hashes)
+            hash_value.extend(batch_hashes[:hit_page_num])
+            storage_query_count += hit_page_num * self.page_size
+            if hit_page_num < len(batch_hashes):
+                break
+        return hash_value, storage_query_count

     def prefetch_thread_func(self):
         """
@@ -772,13 +764,7 @@ class HiCacheController:
                 if operation is None:
                     continue

-                if (
-                    operation.host_indices is not None
-                ) and self.prefetch_rate_limit_check():
-                    hash_value, storage_hit_count = self._generic_storage_hit_query(
-                        operation
-                    )
-
+                hash_value, storage_hit_count = self._storage_hit_query(operation)
                 if self.tp_world_size > 1:
                     storage_hit_count_tensor = torch.tensor(
                         storage_hit_count, dtype=torch.int
@@ -793,8 +779,7 @@ class HiCacheController:
                 if storage_hit_count < self.prefetch_threshold:
                     # not to prefetch if not enough benefits
                     self.prefetch_revoke_queue.put(operation.request_id)
-                    if operation.host_indices is not None:
-                        self.mem_pool_host.free(operation.host_indices)
+                    self.append_host_mem_release(operation.host_indices)
                     logger.debug(
                         f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                     )
@@ -803,7 +788,9 @@ class HiCacheController:
                         : (storage_hit_count // self.page_size)
                     ]
                     # free the pre-allocated memory for pages that are not hit
-                    self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
+                    self.append_host_mem_release(
+                        operation.host_indices[storage_hit_count:]
+                    )
                     operation.host_indices = operation.host_indices[:storage_hit_count]
                     logger.debug(
                         f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
@@ -858,21 +845,18 @@ class HiCacheController:
     # Backup batch by batch
     def _page_backup(self, operation):
         # Select the set function and batch size
-        if self.is_mooncake_backend():
+        if self.storage_backend_type == "mooncake":
             backup_set_func = self._mooncake_page_set
-            batch_size = 128
-        elif self.storage_backend_type == "hf3fs":
-            if self.mem_pool_host.layout == "page_first":
-                backup_set_func = self._3fs_zero_copy_page_set
-            elif self.mem_pool_host.layout == "layer_first":
-                backup_set_func = self._generic_page_set
-            batch_size = 128
+        elif (
+            self.storage_backend_type == "hf3fs"
+            and self.mem_pool_host.layout == "page_first"
+        ):
+            backup_set_func = self._3fs_zero_copy_page_set
         else:
             backup_set_func = self._generic_page_set
-            batch_size = 8
         # Backup batch by batch
-        for i in range(0, len(operation.hash_value), batch_size):
-            batch_hashes = operation.hash_value[i : i + batch_size]
+        for i in range(0, len(operation.hash_value), self.storage_batch_size):
+            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
             batch_host_indices = operation.host_indices[
                 i * self.page_size : (i + len(batch_hashes)) * self.page_size
             ]
@@ -898,27 +882,7 @@ class HiCacheController:
                 if not self.backup_skip:
                     self._page_backup(operation)
-                    min_completed_tokens = operation.completed_tokens
-                else:
-                    min_completed_tokens = len(operation.token_ids)
-
-                if self.tp_world_size > 1:
-                    completed_tokens_tensor = torch.tensor(
-                        min_completed_tokens, dtype=torch.int
-                    )
-                    torch.distributed.all_reduce(
-                        completed_tokens_tensor,
-                        op=torch.distributed.ReduceOp.MIN,
-                        group=self.backup_tp_group,
-                    )
-                    min_completed_tokens = completed_tokens_tensor.item()
-
-                self.ack_backup_queue.put(
-                    (
-                        operation.id,
-                        min_completed_tokens,
-                    )
-                )
+                self.ack_backup_queue.put(operation.id)
             except Empty:
                 continue
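With the per-operation `all_reduce` over completed tokens gone (and `backup_tp_group` deleted above), the backup thread acks with just the operation id; cross-rank agreement now happens once per scheduler tick via the queue-size MIN reduction in `drain_storage_control_queues`. A minimal sketch of the simplified ack hand-off; the threading scaffolding is illustrative:

```python
import threading
from queue import Queue

ack_backup_queue: "Queue[int]" = Queue()

def backup_thread_func(op_ids):
    for op_id in op_ids:
        # ... the real thread would run _page_backup(operation) here ...
        ack_backup_queue.put(op_id)  # ack with the id alone, no token count

t = threading.Thread(target=backup_thread_func, args=([1, 2, 3],))
t.start()
t.join()
while not ack_backup_queue.empty():
    print("acked:", ack_backup_queue.get())
```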
@@ -104,9 +104,6 @@ class HiRadixCache(RadixCache):
         self.write_through_threshold = (
             1 if hicache_write_policy == "write_through" else 2
         )
-        self.write_through_threshold_storage = (
-            1 if hicache_write_policy == "write_through" else 3
-        )
         self.load_back_threshold = 10
         super().__init__(
             req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -174,14 +171,6 @@ class HiRadixCache(RadixCache):
         if node.hit_count >= self.write_through_threshold:
             # write to host if the node is not backuped
             self.write_backup(node)
-        else:
-            if (
-                self.enable_storage
-                and (not node.backuped_storage)
-                and node.hit_count >= self.write_through_threshold_storage
-            ):
-                # if the node is backuped on host memory but not on storage
-                self.write_backup_storage(node)

     def writing_check(self, write_back=False):
         if write_back:
@@ -202,8 +191,11 @@ class HiRadixCache(RadixCache):
             )
             for _ in range(queue_size.item()):
                 ack_id = self.cache_controller.ack_write_queue.get()
-                self.dec_lock_ref(self.ongoing_write_through[ack_id])
+                backuped_node = self.ongoing_write_through[ack_id]
+                self.dec_lock_ref(backuped_node)
                 del self.ongoing_write_through[ack_id]
+                if self.enable_storage:
+                    self.write_backup_storage(backuped_node)

     def loading_check(self):
         while not self.cache_controller.ack_load_queue.empty():
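Storage write-through is now chained off the host write ack: once the host write completes, the node is backed up to storage immediately, replacing the separate `write_through_threshold_storage` counter and the per-node `backuped_storage` flag. A toy sketch of the chaining; the node class and backup function are stand-ins, not sglang's API:

```python
class Node:
    def __init__(self, key):
        self.key = key

ongoing_write_through = {42: Node("abc")}
enable_storage = True

def write_backup_storage(node):
    # stand-in for the real storage backup call
    print("backing up", node.key)

ack_id = 42  # would come from cache_controller.ack_write_queue.get()
backuped_node = ongoing_write_through.pop(ack_id)
# (the real code also releases the device lock on the node here)
if enable_storage:
    write_backup_storage(backuped_node)
```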
@@ -386,57 +378,54 @@ class HiRadixCache(RadixCache):
         self.writing_check()
         self.loading_check()
         if self.enable_storage:
-            self.check_revoked_prefetch()
-            self.check_backup_progress()
-
-    def check_revoked_prefetch(self):
-        queue_size = torch.tensor(
-            self.cache_controller.prefetch_revoke_queue.qsize(), dtype=torch.int
-        )
-        if self.tp_world_size > 1:
-            # synchrnoize TP workers to make the same update to hiradix cache
-            torch.distributed.all_reduce(
-                queue_size,
-                op=torch.distributed.ReduceOp.MIN,
-                group=self.tp_group,
-            )
-        for _ in range(queue_size.item()):
-            req_id = self.cache_controller.prefetch_revoke_queue.get()
-            if req_id in self.ongoing_prefetch:
-                last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
-                last_host_node.release_host()
-                del self.ongoing_prefetch[req_id]
-                self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
-            else:
-                # the revoked operation already got terminated
-                pass
-
-    def check_backup_progress(self):
-        queue_size = torch.tensor(
-            self.cache_controller.ack_backup_queue.qsize(), dtype=torch.int
-        )
-        if self.tp_world_size > 1:
-            # synchrnoize TP workers to make the same update to hiradix cache
-            torch.distributed.all_reduce(
-                queue_size,
-                op=torch.distributed.ReduceOp.MIN,
-                group=self.tp_group,
-            )
-        for _ in range(queue_size.item()):
-            ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
-            host_node = self.ongoing_backup[ack_id]
-
-            if completed_tokens > 0:
-                if completed_tokens < len(host_node.key):
-                    # backup is only partially successful, split the node
-                    new_node = self._split_node(
-                        host_node.key, host_node, completed_tokens
-                    )
-                    new_node.backuped_storage = True
-                else:
-                    host_node.backuped_storage = True
-            host_node.release_host()
-            del self.ongoing_backup[ack_id]
+            self.drain_storage_control_queues()
+
+    def drain_storage_control_queues(self):
+        """
+        Combine prefetch revoke, backup ack, and host mem release checks
+        to minimize TP synchronization and Python overhead.
+        """
+        cc = self.cache_controller
+
+        qsizes = torch.tensor(
+            [
+                cc.prefetch_revoke_queue.qsize(),
+                cc.ack_backup_queue.qsize(),
+                cc.host_mem_release_queue.qsize(),
+            ],
+            dtype=torch.int,
+        )
+        if self.tp_world_size > 1:
+            torch.distributed.all_reduce(
+                qsizes, op=torch.distributed.ReduceOp.MIN, group=self.tp_group
+            )
+
+        n_revoke, n_backup, n_release = map(int, qsizes.tolist())
+
+        # process prefetch revokes
+        for _ in range(n_revoke):
+            req_id = cc.prefetch_revoke_queue.get()
+            info = self.ongoing_prefetch.pop(req_id, None)
+            if info is not None:
+                last_host_node, token_ids, _, _ = info
+                last_host_node.release_host()
+                cc.prefetch_tokens_occupied -= len(token_ids)
+            # else: the revoked operation already got terminated, nothing to do
+
+        # process backup acks
+        for _ in range(n_backup):
+            ack_id = cc.ack_backup_queue.get()
+            entry = self.ongoing_backup.pop(ack_id, None)
+            if entry is not None:
+                entry.release_host()
+
+        # release host memory
+        host_indices_list = []
+        for _ in range(n_release):
+            host_indices_list.append(cc.host_mem_release_queue.get())
+        if host_indices_list:
+            host_indices = torch.cat(host_indices_list, dim=0)
+            cc.mem_pool_host.free(host_indices)

     def can_terminate_prefetch(self, operation: PrefetchOperation):
         can_terminate = True
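The key technique above is draining three control queues under a single MIN `all_reduce` over their sizes: each rank processes only as many items as every rank is guaranteed to have, so all TP workers apply identical tree updates with one collective per tick instead of one per queue. A single-process sketch of the counting step (no process group is initialized, so the reduction is skipped):

```python
import torch

def drain_counts(local_qsizes, tp_world_size, tp_group=None):
    qsizes = torch.tensor(local_qsizes, dtype=torch.int)
    if tp_world_size > 1:
        # MIN across ranks: never drain an item some rank has not produced yet
        torch.distributed.all_reduce(
            qsizes, op=torch.distributed.ReduceOp.MIN, group=tp_group
        )
    return [int(x) for x in qsizes.tolist()]

# With tp_world_size == 1 this is a no-op passthrough:
print(drain_counts([3, 1, 7], tp_world_size=1))  # [3, 1, 7]
```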
@@ -519,7 +508,7 @@ class HiRadixCache(RadixCache):
         self.cache_controller.mem_pool_host.update_prefetch(written_indices)

         self.cache_controller.mem_pool_host.free(host_indices[:matched_length])
-        self.cache_controller.mem_pool_host.free(
+        self.cache_controller.append_host_mem_release(
             host_indices[min_completed_tokens:completed_tokens]
         )
         last_host_node.release_host()
@@ -575,7 +564,11 @@ class HiRadixCache(RadixCache):
             len(new_input_tokens) % self.page_size
         )
         new_input_tokens = new_input_tokens[:prefetch_length]
-        if not self.enable_storage or prefetch_length < self.prefetch_threshold:
+        if (
+            not self.enable_storage
+            or prefetch_length < self.prefetch_threshold
+            or self.cache_controller.prefetch_rate_limited()
+        ):
             return

         last_host_node.protect_host()
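`prefetch_rate_limit_check` is renamed to `prefetch_rate_limited` with its polarity flipped, so the gate above reads as three positive skip conditions. A small usage sketch with a stand-in controller class (the field values are made up):

```python
class _Controller:  # stand-in exposing only what the predicate reads
    prefetch_tokens_occupied = 900
    prefetch_capacity_limit = 800

    def prefetch_rate_limited(self) -> bool:
        # True means "skip prefetch": too many tokens already locked
        return self.prefetch_tokens_occupied >= self.prefetch_capacity_limit

cc = _Controller()
if cc.prefetch_rate_limited():
    print("skip prefetch")  # capacity exceeded
```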
@@ -583,6 +576,10 @@ class HiRadixCache(RadixCache):
         if host_indices is None:
             self.evict_host(prefetch_length)
             host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
+            if host_indices is None:
+                last_host_node.release_host()
+                # no sufficient host memory for prefetch
+                return
         operation = self.cache_controller.prefetch(
             req_id, host_indices, new_input_tokens, last_hash
         )
...
@@ -62,7 +62,6 @@ class TreeNode:
         self.host_value: Optional[torch.Tensor] = None
         # store hash values of each pages
         self.hash_value: Optional[List[str]] = None
-        self.backuped_storage = False

         self.id = TreeNode.counter if id is None else id
         TreeNode.counter += 1
...