Unverified commit 2fbb754e, authored by hzh0425 and committed by GitHub
Browse files

feature(pd-hicache): Prefill instances support reusing the RemoteStorage Cache via HiCache. (#8516)


Co-authored-by: Shangming Cai <csmthu@gmail.com>
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent a85ebf50
@@ -1185,22 +1185,27 @@ class Scheduler(
def _add_request_to_queue(self, req: Req):
    """Stamp the request's queue start time and enqueue it for the current mode.

    The destination depends on ``self.disaggregation_mode``:

    * ``PREFILL``: start a KV-cache prefetch, then add the request to
      ``disagg_prefill_bootstrap_queue`` together with the model's
      ``num_key_value_heads``.
    * ``DECODE``: add the request to ``disagg_decode_prealloc_queue``
      (no prefetch is issued on this path).
    * otherwise: start a KV-cache prefetch, then append the request to
      ``waiting_queue``.

    Args:
        req: The request to enqueue. Its ``queue_time_start`` attribute is
            overwritten with the current ``time.perf_counter()`` value.
    """
    req.queue_time_start = time.perf_counter()
    if self.disaggregation_mode == DisaggregationMode.PREFILL:
        # Prefill instances also reuse the storage-backed cache, so the
        # prefetch is issued before the request enters the bootstrap queue.
        self._prefetch_kvcache(req)
        self.disagg_prefill_bootstrap_queue.add(
            req, self.model_config.num_key_value_heads
        )
    elif self.disaggregation_mode == DisaggregationMode.DECODE:
        self.disagg_decode_prealloc_queue.add(req)
    else:
        self._prefetch_kvcache(req)
        self.waiting_queue.append(req)
def _prefetch_kvcache(self, req: Req):
    """Ask the tree cache to prefetch the unmatched part of ``req`` from storage.

    Does nothing unless ``self.enable_hicache_storage`` is truthy. Otherwise
    it calls ``req.init_next_round_input(self.tree_cache)``, computes how many
    leading tokens are already matched (``len(req.prefix_indices)`` plus
    ``req.host_hit_length``), and hands the remaining ``req.fill_ids`` to
    ``self.tree_cache.prefetch_from_storage``.

    This method only triggers the prefetch; it does not enqueue the request.

    Args:
        req: The request whose remaining input tokens should be prefetched.
    """
    if self.enable_hicache_storage:
        req.init_next_round_input(self.tree_cache)
        last_hash = req.last_host_node.get_last_hash_value()
        matched_len = len(req.prefix_indices) + req.host_hit_length
        # TODO: free-form fetching, calculating hash keys on the fly.
        # Prefetch only when nothing is matched yet, or when the matched
        # prefix ends in a node that has a hash value (presumably the storage
        # keys are derived from that hash -- confirm against the tree cache).
        if (matched_len > 0 and last_hash is not None) or matched_len == 0:
            new_input_tokens = req.fill_ids[matched_len:]
            self.tree_cache.prefetch_from_storage(
                req.rid, req.last_host_node, new_input_tokens, last_hash
            )
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False): def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
if self.disaggregation_mode == DisaggregationMode.PREFILL: if self.disaggregation_mode == DisaggregationMode.PREFILL:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment