Unverified commit 22711ec5, authored by thatPepe and committed by GitHub

Merge pull request #262 from InfiniTensor/issue/244

issue/244 feat(llm): add prefix cache reuse for static KV cache
parents 3b8e1cb7 a89b194a
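
The change keys the static KV cache by chained block hashes so that a new prompt can reuse the prefix it shares with what is already cached. As a minimal standalone sketch of the idea (compute_hash here is a stand-in for BlockManager.compute_hash from infinilm.llm.cache_manager, whose real signature may differ; the block size mirrors the _BLOCK_SIZE = 16 used in the diff below):

from typing import List, Tuple

_BLOCK_SIZE = 16  # same block size the scheduler below uses

def compute_hash(block_tokens: List[int], prev_hash: int) -> int:
    # Stand-in for BlockManager.compute_hash: folding the previous block's
    # hash into the current one makes each hash identify the whole prefix.
    return hash((prev_hash, tuple(block_tokens)))

def match_prefix(cached: List[int], prompt: List[int]) -> Tuple[int, List[int]]:
    """Return (matched_blocks, pending_hashes) for a prompt checked against
    the chained hashes of blocks already held in the static KV cache."""
    matched, pending, prev_h = 0, [], -1
    for i in range(len(prompt) // _BLOCK_SIZE):
        h = compute_hash(prompt[i * _BLOCK_SIZE:(i + 1) * _BLOCK_SIZE], prev_h)
        if matched == i and i < len(cached) and h == cached[i]:
            matched = i + 1      # block i is already in the cache
        else:
            pending.append(h)    # will be computed during this prefill
        prev_h = h
    return matched, pending

With prefix_hit_len = matched * _BLOCK_SIZE, prefill only has to run over the remaining suffix of the prompt, which is what the scheduler changes below arrange.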
@@ -248,9 +248,14 @@ class LLMEngine:
sampled_tokens: List[int],
):
"""Update request status after inference step."""
# Only reset req blocks for paged cache
if is_prefill and self.cache_type == "paged":
self.scheduler.cache_manager.reset_req_blocks()
if is_prefill:
match self.cache_type:
case "paged":
self.scheduler.cache_manager.reset_req_blocks()
case "static":
self.scheduler.update_cache()
case _:
raise ValueError(f"Unsupported cache_type: {self.cache_type}")
for req, token_id in zip(requests, sampled_tokens):
......
@@ -7,6 +7,7 @@ import queue
import janus
from typing import List, Optional
from infinilm.llm.cache_manager import BlockManager
from infinilm.llm.request import (
RequestStatus,
InferenceRequest,
@@ -16,6 +17,8 @@ from infinilm.llm.request import (
logger = logging.getLogger(__name__)
_BLOCK_SIZE = 16
class StaticSchedulerOutput:
"""Static scheduler output containing single request and execution phase info."""
@@ -24,10 +27,12 @@ class StaticSchedulerOutput:
self,
scheduled_requests: List[InferenceRequest],
is_prefill: bool = False,
prefix_hit_len: int = 0,
):
self.scheduled_requests = scheduled_requests
self.num_requests = len(scheduled_requests)
self.is_prefill = is_prefill
self.prefix_hit_len = prefix_hit_len
def build_model_inputs(
self, temperature: float = 1.0, top_p: float = 0.8, top_k: int = 1
@@ -36,10 +41,10 @@
Static cache model inputs:
Prefill phase:
- input_ids: All prompt tokens [1, prompt_length]
- position_ids: [0, 1, 2, ..., prompt_length-1]
- past_kv_lengths: [0] (no cached tokens initially)
Prefill phase (with prefix cache reuse):
- input_ids: Tokens after the cached prefix [1, prompt_length - prefix_hit_len]
- position_ids: [prefix_hit_len, ..., prompt_length-1]
- past_kv_lengths: [prefix_hit_len] (reuse cached prefix)
- total_kv_lengths: [prompt_length]
Decode phase:
@@ -47,18 +52,19 @@
- position_ids: [current_position] (position in full sequence)
- past_kv_lengths: [num_cached_tokens]
- total_kv_lengths: [total_tokens]
-
"""
req = self.scheduled_requests[0]
if self.is_prefill:
# Prefill: send all prompt tokens
# Prefill: only send tokens not already in cache
tokens = req.get_input_tokens()
input_ids = [tokens]
position_ids = [list(range(len(tokens)))]
past_kv_len = 0
prefix_hit_len = self.prefix_hit_len
input_tokens = tokens[prefix_hit_len:]
input_ids = [input_tokens]
position_ids = [list(range(prefix_hit_len, len(tokens)))]
past_kv_len = prefix_hit_len
total_kv_len = len(tokens)
input_offsets = [0, len(tokens)]
input_offsets = [0, len(input_tokens)]
else:
# Decode: send only the last generated token
last_token = req.generated_token_ids[-1]
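
For intuition, here is what the prefill inputs above work out to for a 20-token prompt whose first full block (16 tokens) is already cached; the numbers are illustrative, and the variable names follow the code above:

prompt_len, prefix_hit_len = 20, 16
tokens = list(range(prompt_len))

input_tokens = tokens[prefix_hit_len:]                    # the 4 uncached tokens
input_ids = [input_tokens]                                # [[16, 17, 18, 19]]
position_ids = [list(range(prefix_hit_len, prompt_len))]  # [[16, 17, 18, 19]]
past_kv_len = prefix_hit_len                              # 16 tokens reused from the cache
total_kv_len = prompt_len                                 # 20 tokens in the cache after prefill
input_offsets = [0, len(input_tokens)]                    # [0, 4]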
@@ -91,12 +97,15 @@ class StaticScheduler:
- Only handles one request at a time
- No cache block management needed
- Simple waiting queue for incoming requests
- Prefix cache reuse via chained block hashing (block size = _BLOCK_SIZE)
"""
def __init__(self, max_cache_len: int = 4096):
self.waiting_queue = janus.Queue()
self.running_request: Optional[InferenceRequest] = None
self.max_cache_len = max_cache_len
self.cached_block_hashes: List[int] = []
self.pending_block_hashes: List[int] = []
def add_request(self, request: InferenceRequest):
if request is not None:
@@ -138,6 +147,23 @@
)
continue
total_length = req.get_total_length()
if total_length % _BLOCK_SIZE == 1 and total_length > _BLOCK_SIZE:
block_index = total_length // _BLOCK_SIZE - 1
if len(self.cached_block_hashes) <= block_index:
all_tokens = req.get_all_token_ids()
block_tokens = all_tokens[-(_BLOCK_SIZE + 1) : -1]
prev_h = (
self.cached_block_hashes[-1]
if self.cached_block_hashes
else -1
)
new_h = BlockManager.compute_hash(block_tokens, prev_h)
self.cached_block_hashes.append(new_h)
logger.debug(
f"Decode: appended block hash at index {block_index}"
)
return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=False)
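
The decode-side guard above fires on the first step after a block fills up. A small self-contained check of the arithmetic, with illustrative values:

_BLOCK_SIZE = 16
total_length = 33                                 # 2 full blocks plus 1 freshly generated token
assert total_length % _BLOCK_SIZE == 1 and total_length > _BLOCK_SIZE

block_index = total_length // _BLOCK_SIZE - 1     # 1: the block that has just been completed
all_tokens = list(range(total_length))
block_tokens = all_tokens[-(_BLOCK_SIZE + 1):-1]  # tokens 16..31, i.e. exactly block 1
assert block_tokens == all_tokens[block_index * _BLOCK_SIZE:(block_index + 1) * _BLOCK_SIZE]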
# Case 2: Get new request from waiting queue (prefill phase)
@@ -175,9 +201,55 @@
)
continue
tokens = req.prompt_token_ids
num_full_blocks = prompt_len // _BLOCK_SIZE
matched = 0
self.pending_block_hashes.clear()
for i in range(num_full_blocks):
prev_h = self.cached_block_hashes[i - 1] if i > 0 else -1
h = BlockManager.compute_hash(
tokens[i * _BLOCK_SIZE : (i + 1) * _BLOCK_SIZE], prev_h
)
if (
i < len(self.cached_block_hashes)
and h == self.cached_block_hashes[i]
):
matched = i + 1
else:
del self.cached_block_hashes[i:]
cur_h = h
self.pending_block_hashes.append(cur_h)
for j in range(i + 1, num_full_blocks):
cur_h = BlockManager.compute_hash(
tokens[j * _BLOCK_SIZE : (j + 1) * _BLOCK_SIZE],
cur_h,
)
self.pending_block_hashes.append(cur_h)
break
else:
del self.cached_block_hashes[matched:]
prefix_hit_len = matched * _BLOCK_SIZE
logger.info(
f"Prefill cache match: {matched}/{num_full_blocks} blocks "
f"({prefix_hit_len} tokens reused, {len(self.pending_block_hashes)} pending)"
)
req.status = RequestStatus.RUNNING
self.running_request = req
return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=True)
return StaticSchedulerOutput(
scheduled_requests=[req], is_prefill=True, prefix_hit_len=prefix_hit_len
)
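
As a concrete trace of the matching loop above, using the illustrative match_prefix sketch from the top of this page (a simplification, not the scheduler itself): a previous prompt filled three blocks, and a new 80-token prompt shares only the first two of them.

prev_prompt = list(range(48))                  # 3 full blocks, hashed by an earlier prefill
new_prompt = list(range(32)) + [999] * 48      # shares only blocks 0 and 1 (32 tokens)

_, first_hashes = match_prefix([], prev_prompt)    # nothing cached yet: 3 pending hashes
cached = list(first_hashes)                        # committed after that prefill

hit, pending = match_prefix(cached, new_prompt)
assert hit == 2 and len(pending) == 3              # 32 tokens reused, 3 blocks left to hash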
def update_cache(self):
"""Commit hashes computed during prefill into the confirmed cache hash list."""
self.cached_block_hashes.extend(self.pending_block_hashes)
self.pending_block_hashes.clear()
logger.debug(
f"update_cache: cached_block_hashes now has {len(self.cached_block_hashes)} blocks"
)
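
Putting the pieces together, a minimal driver for how schedule(), the engine step, and update_cache() are meant to interact (again reusing the hypothetical match_prefix sketch from the top of this page; run_prefill is a stand-in for the actual model call, not part of this PR):

cached_hashes: List[int] = []        # plays the role of cached_block_hashes

def run_prefill(uncached_tokens: List[int]) -> None:
    # Stand-in for the engine's prefill step over the uncached suffix only.
    print(f"prefill over {len(uncached_tokens)} tokens")

def serve(prompt: List[int]) -> None:
    hit, pending = match_prefix(cached_hashes, prompt)   # schedule(): prefill path
    del cached_hashes[hit:]                              # drop stale blocks past the match
    run_prefill(prompt[hit * _BLOCK_SIZE:])              # engine runs the model
    cached_hashes.extend(pending)                        # _update_requests -> update_cache()
    # Later decode steps append one more hash each time a generated block fills up.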
def complete_requests(self, requests: List[InferenceRequest]):
"""Handle completed requests."""
@@ -190,6 +262,8 @@
"""Get cache statistics."""
return {
"max_cache_len": self.max_cache_len,
"cached_blocks": len(self.cached_block_hashes),
"cached_tokens": len(self.cached_block_hashes) * _BLOCK_SIZE,
"running_request": (
self.running_request.request_id if self.running_request else None
),
......