Unverified commit 22711ec5, authored by thatPepe, committed by GitHub
Browse files

Merge pull request #262 from InfiniTensor/issue/244

issue/244 feat(llm): add prefix cache reuse for static KV cache
parents 3b8e1cb7 a89b194a
...@@ -248,9 +248,14 @@ class LLMEngine: ...@@ -248,9 +248,14 @@ class LLMEngine:
sampled_tokens: List[int], sampled_tokens: List[int],
): ):
"""Update request status after inference step.""" """Update request status after inference step."""
# Only reset req blocks for paged cache if is_prefill:
if is_prefill and self.cache_type == "paged": match self.cache_type:
self.scheduler.cache_manager.reset_req_blocks() case "paged":
self.scheduler.cache_manager.reset_req_blocks()
case "static":
self.scheduler.update_cache()
case _:
raise ValueError(f"Unsupported cache_type: {self.cache_type}")
for req, token_id in zip(requests, sampled_tokens): for req, token_id in zip(requests, sampled_tokens):
......
...@@ -7,6 +7,7 @@ import queue ...@@ -7,6 +7,7 @@ import queue
import janus import janus
from typing import List, Optional from typing import List, Optional
from infinilm.llm.cache_manager import BlockManager
from infinilm.llm.request import ( from infinilm.llm.request import (
RequestStatus, RequestStatus,
InferenceRequest, InferenceRequest,
...@@ -16,6 +17,8 @@ from infinilm.llm.request import ( ...@@ -16,6 +17,8 @@ from infinilm.llm.request import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_BLOCK_SIZE = 16
class StaticSchedulerOutput: class StaticSchedulerOutput:
"""Static scheduler output containing single request and execution phase info.""" """Static scheduler output containing single request and execution phase info."""
...@@ -24,10 +27,12 @@ class StaticSchedulerOutput: ...@@ -24,10 +27,12 @@ class StaticSchedulerOutput:
self, self,
scheduled_requests: List[InferenceRequest], scheduled_requests: List[InferenceRequest],
is_prefill: bool = False, is_prefill: bool = False,
prefix_hit_len: int = 0,
): ):
self.scheduled_requests = scheduled_requests self.scheduled_requests = scheduled_requests
self.num_requests = len(scheduled_requests) self.num_requests = len(scheduled_requests)
self.is_prefill = is_prefill self.is_prefill = is_prefill
self.prefix_hit_len = prefix_hit_len
def build_model_inputs( def build_model_inputs(
self, temperature: float = 1.0, top_p: float = 0.8, top_k: int = 1 self, temperature: float = 1.0, top_p: float = 0.8, top_k: int = 1
...@@ -36,10 +41,10 @@ class StaticSchedulerOutput: ...@@ -36,10 +41,10 @@ class StaticSchedulerOutput:
Static cache model inputs: Static cache model inputs:
Prefill phase: Prefill phase (with prefix cache reuse):
- input_ids: All prompt tokens [1, prompt_length] - input_ids: Tokens after the cached prefix [1, prompt_length - prefix_hit_len]
- position_ids: [0, 1, 2, ..., prompt_length-1] - position_ids: [prefix_hit_len, ..., prompt_length-1]
- past_kv_lengths: [0] (no cached tokens initially) - past_kv_lengths: [prefix_hit_len] (reuse cached prefix)
- total_kv_lengths: [prompt_length] - total_kv_lengths: [prompt_length]
Decode phase: Decode phase:
...@@ -47,18 +52,19 @@ class StaticSchedulerOutput: ...@@ -47,18 +52,19 @@ class StaticSchedulerOutput:
- position_ids: [current_position] (position in full sequence) - position_ids: [current_position] (position in full sequence)
- past_kv_lengths: [num_cached_tokens] - past_kv_lengths: [num_cached_tokens]
- total_kv_lengths: [total_tokens] - total_kv_lengths: [total_tokens]
-
""" """
req = self.scheduled_requests[0] req = self.scheduled_requests[0]
if self.is_prefill: if self.is_prefill:
# Prefill: send all prompt tokens # Prefill: only send tokens not already in cache
tokens = req.get_input_tokens() tokens = req.get_input_tokens()
input_ids = [tokens] prefix_hit_len = self.prefix_hit_len
position_ids = [list(range(len(tokens)))] input_tokens = tokens[prefix_hit_len:]
past_kv_len = 0 input_ids = [input_tokens]
position_ids = [list(range(prefix_hit_len, len(tokens)))]
past_kv_len = prefix_hit_len
total_kv_len = len(tokens) total_kv_len = len(tokens)
input_offsets = [0, len(tokens)] input_offsets = [0, len(input_tokens)]
else: else:
# Decode: send only the last generated token # Decode: send only the last generated token
last_token = req.generated_token_ids[-1] last_token = req.generated_token_ids[-1]
...@@ -91,12 +97,15 @@ class StaticScheduler: ...@@ -91,12 +97,15 @@ class StaticScheduler:
- Only handles one request at a time - Only handles one request at a time
- No cache block management needed - No cache block management needed
- Simple waiting queue for incoming requests - Simple waiting queue for incoming requests
- Prefix cache reuse via chained block hashing (block size = _BLOCK_SIZE)
""" """
def __init__(self, max_cache_len: int = 4096): def __init__(self, max_cache_len: int = 4096):
self.waiting_queue = janus.Queue() self.waiting_queue = janus.Queue()
self.running_request: Optional[InferenceRequest] = None self.running_request: Optional[InferenceRequest] = None
self.max_cache_len = max_cache_len self.max_cache_len = max_cache_len
self.cached_block_hashes: List[int] = []
self.pending_block_hashes: List[int] = []
def add_request(self, request: InferenceRequest): def add_request(self, request: InferenceRequest):
if request is not None: if request is not None:
...@@ -138,6 +147,23 @@ class StaticScheduler: ...@@ -138,6 +147,23 @@ class StaticScheduler:
) )
continue continue
total_length = req.get_total_length()
if total_length % _BLOCK_SIZE == 1 and total_length > _BLOCK_SIZE:
block_index = total_length // _BLOCK_SIZE - 1
if len(self.cached_block_hashes) <= block_index:
all_tokens = req.get_all_token_ids()
block_tokens = all_tokens[-(_BLOCK_SIZE + 1) : -1]
prev_h = (
self.cached_block_hashes[-1]
if self.cached_block_hashes
else -1
)
new_h = BlockManager.compute_hash(block_tokens, prev_h)
self.cached_block_hashes.append(new_h)
logger.debug(
f"Decode: appended block hash at index {block_index}"
)
return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=False) return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=False)
# Case 2: Get new request from waiting queue (prefill phase) # Case 2: Get new request from waiting queue (prefill phase)
...@@ -175,9 +201,55 @@ class StaticScheduler: ...@@ -175,9 +201,55 @@ class StaticScheduler:
) )
continue continue
tokens = req.prompt_token_ids
num_full_blocks = prompt_len // _BLOCK_SIZE
matched = 0
self.pending_block_hashes.clear()
for i in range(num_full_blocks):
prev_h = self.cached_block_hashes[i - 1] if i > 0 else -1
h = BlockManager.compute_hash(
tokens[i * _BLOCK_SIZE : (i + 1) * _BLOCK_SIZE], prev_h
)
if (
i < len(self.cached_block_hashes)
and h == self.cached_block_hashes[i]
):
matched = i + 1
else:
del self.cached_block_hashes[i:]
cur_h = h
self.pending_block_hashes.append(cur_h)
for j in range(i + 1, num_full_blocks):
cur_h = BlockManager.compute_hash(
tokens[j * _BLOCK_SIZE : (j + 1) * _BLOCK_SIZE],
cur_h,
)
self.pending_block_hashes.append(cur_h)
break
else:
del self.cached_block_hashes[matched:]
prefix_hit_len = matched * _BLOCK_SIZE
logger.info(
f"Prefill cache match: {matched}/{num_full_blocks} blocks "
f"({prefix_hit_len} tokens reused, {len(self.pending_block_hashes)} pending)"
)
req.status = RequestStatus.RUNNING req.status = RequestStatus.RUNNING
self.running_request = req self.running_request = req
return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=True) return StaticSchedulerOutput(
scheduled_requests=[req], is_prefill=True, prefix_hit_len=prefix_hit_len
)
def update_cache(self):
"""Commit hashes computed during prefill into the confirmed cache hash list."""
self.cached_block_hashes.extend(self.pending_block_hashes)
self.pending_block_hashes.clear()
logger.debug(
f"update_cache: cached_block_hashes now has {len(self.cached_block_hashes)} blocks"
)
def complete_requests(self, requests: List[InferenceRequest]): def complete_requests(self, requests: List[InferenceRequest]):
"""Handle completed requests.""" """Handle completed requests."""
...@@ -190,6 +262,8 @@ class StaticScheduler: ...@@ -190,6 +262,8 @@ class StaticScheduler:
"""Get cache statistics.""" """Get cache statistics."""
return { return {
"max_cache_len": self.max_cache_len, "max_cache_len": self.max_cache_len,
"cached_blocks": len(self.cached_block_hashes),
"cached_tokens": len(self.cached_block_hashes) * _BLOCK_SIZE,
"running_request": ( "running_request": (
self.running_request.request_id if self.running_request else None self.running_request.request_id if self.running_request else None
), ),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment