Unverified Commit 6b634493 authored by hzh0425's avatar hzh0425 Committed by GitHub
Browse files

[HICache / PD]: Support offloading incremental KV cache in decode side. (#11966)

parent 756ad9ce
...@@ -60,6 +60,7 @@ class DecodeKVCacheOffloadManager: ...@@ -60,6 +60,7 @@ class DecodeKVCacheOffloadManager:
self.tp_group = tp_group self.tp_group = tp_group
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group) self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
self.cache_controller = HiCacheController( self.cache_controller = HiCacheController(
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
mem_pool_host=self.decode_host_mem_pool, mem_pool_host=self.decode_host_mem_pool,
...@@ -77,41 +78,59 @@ class DecodeKVCacheOffloadManager: ...@@ -77,41 +78,59 @@ class DecodeKVCacheOffloadManager:
logger.info("Enable offload kv cache for decode side") logger.info("Enable offload kv cache for decode side")
def offload_kv_cache(self, req) -> bool: def offload_kv_cache(self, req) -> bool:
"""Offload a finished request's KV cache to storage.""" """Offload incremental KV cache for decode side."""
if self.cache_controller is None or self.decode_host_mem_pool is None: if self.cache_controller is None or self.decode_host_mem_pool is None:
return False return False
if req.req_pool_idx == -1: if req.req_pool_idx == -1 or len(req.output_ids) == 0:
return False return False
token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx] token_indices = self.req_to_token_pool.req_to_token[req.req_pool_idx]
if token_indices.dim() == 0 or token_indices.numel() == 0: if token_indices.dim() == 0 or token_indices.numel() == 0:
logger.debug(
f"Request {req.rid} has invalid token_indices: {token_indices}"
)
return False return False
tokens = req.origin_input_ids + req.output_ids # Prefill side offloads page-aligned origin_input_ids, decode side offloads the incremental part
aligned_len = (len(tokens) // self.page_size) * self.page_size all_tokens = req.origin_input_ids + req.output_ids[:-1]
if aligned_len == 0: prefill_offloaded_len = (
len(req.origin_input_ids) // self.page_size * self.page_size
)
incremental_len = len(all_tokens) - prefill_offloaded_len
incremental_aligned_len = incremental_len // self.page_size * self.page_size
if incremental_aligned_len == 0:
return False return False
token_indices = token_indices[:aligned_len] # Extract incremental tokens and indices
tokens = tokens[:aligned_len] start, end = (
prefill_offloaded_len,
prefill_offloaded_len + incremental_aligned_len,
)
incremental_tokens = all_tokens[start:end]
incremental_indices = token_indices[start:end]
# Early free prefill-offloaded GPU memory
if prefill_offloaded_len > 0:
self.token_to_kv_pool_allocator.free(token_indices[:prefill_offloaded_len])
# Asynchronously offload KV cache from device to host by cache controller # Asynchronously offload incremental KV cache from device to host
self.request_counter += 1 self.request_counter += 1
ack_id = self.request_counter ack_id = self.request_counter
host_indices = self.cache_controller.write( host_indices = self.cache_controller.write(
device_indices=token_indices.long(), device_indices=incremental_indices.long(),
node_id=ack_id, node_id=ack_id,
) )
if host_indices is None: if host_indices is None:
logger.error(f"Not enough host memory for request {req.rid}") logger.error(f"Not enough host memory for request {req.rid}")
return False return False
self.ongoing_offload[ack_id] = (req, host_indices, tokens, time.time()) self.ongoing_offload[ack_id] = (
req,
host_indices,
incremental_tokens,
time.time(),
prefill_offloaded_len,
)
return True return True
def check_offload_progress(self): def check_offload_progress(self):
...@@ -140,14 +159,33 @@ class DecodeKVCacheOffloadManager: ...@@ -140,14 +159,33 @@ class DecodeKVCacheOffloadManager:
_, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0) _, finish_event, ack_list = self.cache_controller.ack_write_queue.pop(0)
finish_event.synchronize() finish_event.synchronize()
for ack_id in ack_list: for ack_id in ack_list:
req, host_indices, tokens, start_time = self.ongoing_offload.pop(ack_id) (
req,
host_indices,
incremental_tokens,
start_time,
prefill_offloaded_len,
) = self.ongoing_offload.pop(ack_id)
self._release_finished_req(req, prefill_offloaded_len)
self._trigger_backup(
req,
host_indices,
incremental_tokens,
start_time,
prefill_offloaded_len,
)
finish_count -= 1
# Release device def _release_finished_req(self, req, prefill_offloaded_len):
self.tree_cache.cache_finished_req(req) kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx,
: len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
]
# Trigger async backup from host to storage by cache controller # Free the incremental part of the request
self._trigger_backup(req.rid, host_indices, tokens, start_time) self.token_to_kv_pool_allocator.free(kv_indices[prefill_offloaded_len:])
finish_count -= 1 self.req_to_token_pool.free(req.req_pool_idx)
def _check_backup_progress(self, finish_count): def _check_backup_progress(self, finish_count):
"""Check the progress of backup from host to storage.""" """Check the progress of backup from host to storage."""
...@@ -159,25 +197,30 @@ class DecodeKVCacheOffloadManager: ...@@ -159,25 +197,30 @@ class DecodeKVCacheOffloadManager:
# Release host memory # Release host memory
self.decode_host_mem_pool.free(host_indices) self.decode_host_mem_pool.free(host_indices)
logger.debug( logger.info(
f"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds." f"Finished backup request {req_id}, free host memory, len:{len(host_indices)}, cost time:{time.time() - start_time:.2f} seconds."
) )
def _trigger_backup(self, req_id, host_indices, tokens, start_time): def _trigger_backup(
"""Trigger async backup from host to storage by cache controller.""" self, req, host_indices, incremental_tokens, start_time, prefill_offloaded_len
):
"""Trigger async backup from host to storage."""
prefill_hashes = self._compute_prefix_hash(
req.origin_input_ids[:prefill_offloaded_len]
)
last_prefill_hash = prefill_hashes[-1] if prefill_offloaded_len > 0 else ""
# Generate page hashes and write to storage page_hashes = self._compute_prefix_hash(incremental_tokens, last_prefill_hash)
page_hashes = self._compute_prefix_hash(tokens)
ack_id = self.cache_controller.write_storage( ack_id = self.cache_controller.write_storage(
host_indices, host_indices,
tokens, incremental_tokens,
hash_value=page_hashes, hash_value=page_hashes,
) )
self.ongoing_backup[ack_id] = (req_id, host_indices, start_time) self.ongoing_backup[ack_id] = (req.rid, host_indices, start_time)
def _compute_prefix_hash(self, tokens): def _compute_prefix_hash(self, tokens, prior_hash=""):
last_hash = ""
page_hashes = [] page_hashes = []
last_hash = prior_hash
for offset in range(0, len(tokens), self.page_size): for offset in range(0, len(tokens), self.page_size):
page_tokens = tokens[offset : offset + self.page_size] page_tokens = tokens[offset : offset + self.page_size]
last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash) last_hash = self.cache_controller.get_hash_str(page_tokens, last_hash)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment