jerrrrry / infinilm · Commit 22711ec5 (Unverified)

Authored Mar 10, 2026 by thatPepe; committed by GitHub, Mar 10, 2026

Merge pull request #262 from InfiniTensor/issue/244

issue/244 feat(llm): add prefix cache reuse for static KV cache

Parents: 3b8e1cb7, a89b194a
Showing 2 changed files with 93 additions and 14 deletions:

python/infinilm/llm/llm.py (+8, -3)
python/infinilm/llm/static_scheduler.py (+85, -11)
python/infinilm/llm/llm.py

@@ -248,9 +248,14 @@ class LLMEngine:
         sampled_tokens: List[int],
     ):
         """Update request status after inference step."""
-        # Only reset req blocks for paged cache
-        if is_prefill and self.cache_type == "paged":
-            self.scheduler.cache_manager.reset_req_blocks()
+        if is_prefill:
+            match self.cache_type:
+                case "paged":
+                    self.scheduler.cache_manager.reset_req_blocks()
+                case "static":
+                    self.scheduler.update_cache()
+                case _:
+                    raise ValueError(f"Unsupported cache_type: {self.cache_type}")
         for req, token_id in zip(requests, sampled_tokens):
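For readers skimming the diff, the sketch below shows what the new post-step dispatch does in isolation: after a prefill step, a paged cache resets its per-request blocks, while a static cache commits the prefix hashes computed during scheduling. The stand-in classes and the free function post_step_cache_update are hypothetical illustration only, not InfiniLM code; the real logic lives inside LLMEngine as shown above.

# Minimal sketch of the dispatch, assuming stand-in scheduler objects (Python 3.10+).
class _FakePagedCacheManager:
    def reset_req_blocks(self):
        print("paged: reset per-request cache blocks")


class _FakeScheduler:
    cache_manager = _FakePagedCacheManager()

    def update_cache(self):
        print("static: commit pending prefix block hashes")


def post_step_cache_update(scheduler, cache_type: str, is_prefill: bool):
    # Mirrors the match/case added in the hunk above.
    if is_prefill:
        match cache_type:
            case "paged":
                scheduler.cache_manager.reset_req_blocks()
            case "static":
                scheduler.update_cache()
            case _:
                raise ValueError(f"Unsupported cache_type: {cache_type}")


post_step_cache_update(_FakeScheduler(), "paged", is_prefill=True)   # resets blocks
post_step_cache_update(_FakeScheduler(), "static", is_prefill=True)  # commits pending hashes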
python/infinilm/llm/static_scheduler.py
@@ -7,6 +7,7 @@ import queue
 import janus
 from typing import List, Optional
+from infinilm.llm.cache_manager import BlockManager
 from infinilm.llm.request import (
     RequestStatus,
     InferenceRequest,
@@ -16,6 +17,8 @@ from infinilm.llm.request import (
 logger = logging.getLogger(__name__)
 
+_BLOCK_SIZE = 16
+
 
 class StaticSchedulerOutput:
     """Static scheduler output containing single request and execution phase info."""
@@ -24,10 +27,12 @@ class StaticSchedulerOutput:
         self,
         scheduled_requests: List[InferenceRequest],
         is_prefill: bool = False,
+        prefix_hit_len: int = 0,
     ):
         self.scheduled_requests = scheduled_requests
         self.num_requests = len(scheduled_requests)
         self.is_prefill = is_prefill
+        self.prefix_hit_len = prefix_hit_len
 
     def build_model_inputs(
         self, temperature: float = 1.0, top_p: float = 0.8, top_k: int = 1
@@ -36,10 +41,10 @@ class StaticSchedulerOutput:
         Static cache model inputs:
 
-        Prefill phase:
-            - input_ids: All prompt tokens [1, prompt_length]
-            - position_ids: [0, 1, 2, ..., prompt_length-1]
-            - past_kv_lengths: [0] (no cached tokens initially)
+        Prefill phase (with prefix cache reuse):
+            - input_ids: Tokens after the cached prefix [1, prompt_length - prefix_hit_len]
+            - position_ids: [prefix_hit_len, ..., prompt_length-1]
+            - past_kv_lengths: [prefix_hit_len] (reuse cached prefix)
             - total_kv_lengths: [prompt_length]
 
         Decode phase:
@@ -47,18 +52,19 @@ class StaticSchedulerOutput:
             - position_ids: [current_position] (position in full sequence)
             - past_kv_lengths: [num_cached_tokens]
             - total_kv_lengths: [total_tokens]
-
         """
         req = self.scheduled_requests[0]
 
         if self.is_prefill:
-            # Prefill: send all prompt tokens
+            # Prefill: only send tokens not already in cache
             tokens = req.get_input_tokens()
-            input_ids = [tokens]
-            position_ids = [list(range(len(tokens)))]
-            past_kv_len = 0
+            prefix_hit_len = self.prefix_hit_len
+            input_tokens = tokens[prefix_hit_len:]
+            input_ids = [input_tokens]
+            position_ids = [list(range(prefix_hit_len, len(tokens)))]
+            past_kv_len = prefix_hit_len
             total_kv_len = len(tokens)
-            input_offsets = [0, len(tokens)]
+            input_offsets = [0, len(input_tokens)]
         else:
             # Decode: send only the last generated token
             last_token = req.generated_token_ids[-1]
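As a concrete illustration of the prefill inputs built above, consider a hypothetical 40-token prompt whose first two 16-token blocks are already cached (prefix_hit_len = 32); the token values and lengths below are made up for the example and simply replay the assignments from the hunk.

# Illustrative values only; mirrors the prefill branch above.
tokens = list(range(40))                     # stand-in prompt token ids
prefix_hit_len = 32                          # two full blocks already cached

input_tokens = tokens[prefix_hit_len:]       # only the 8-token suffix is sent
input_ids = [input_tokens]                   # shape [1, 8]
position_ids = [list(range(prefix_hit_len, len(tokens)))]  # [[32, 33, ..., 39]]
past_kv_len = prefix_hit_len                 # 32 cached tokens reused
total_kv_len = len(tokens)                   # 40
input_offsets = [0, len(input_tokens)]       # [0, 8]

With the previous behaviour (past_kv_len = 0 and all 40 tokens in input_ids), the whole prompt would be recomputed during prefill; with the hit, only the 8-token suffix is.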
@@ -91,12 +97,15 @@ class StaticScheduler:
     - Only handles one request at a time
     - No cache block management needed
     - Simple waiting queue for incoming requests
+    - Prefix cache reuse via chained block hashing (block size = _BLOCK_SIZE)
     """
 
     def __init__(self, max_cache_len: int = 4096):
         self.waiting_queue = janus.Queue()
         self.running_request: Optional[InferenceRequest] = None
         self.max_cache_len = max_cache_len
+        self.cached_block_hashes: List[int] = []
+        self.pending_block_hashes: List[int] = []
 
     def add_request(self, request: InferenceRequest):
         if request is not None:
@@ -138,6 +147,23 @@ class StaticScheduler:
                 )
                 continue
 
+            total_length = req.get_total_length()
+            if total_length % _BLOCK_SIZE == 1 and total_length > _BLOCK_SIZE:
+                block_index = total_length // _BLOCK_SIZE - 1
+                if len(self.cached_block_hashes) <= block_index:
+                    all_tokens = req.get_all_token_ids()
+                    block_tokens = all_tokens[-(_BLOCK_SIZE + 1):-1]
+                    prev_h = (
+                        self.cached_block_hashes[-1] if self.cached_block_hashes else -1
+                    )
+                    new_h = BlockManager.compute_hash(block_tokens, prev_h)
+                    self.cached_block_hashes.append(new_h)
+                    logger.debug(f"Decode: appended block hash at index {block_index}")
+
             return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=False)
 
         # Case 2: Get new request from waiting queue (prefill phase)
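The hash appended here covers the 16 tokens that complete a block once the freshly sampled token pushes total_length to one past a block boundary, hence the [-(_BLOCK_SIZE + 1):-1] slice. The sketch below shows the chaining idea; BlockManager.compute_hash is only called, never defined, in this diff, so the built-in hash() used here is an assumption for illustration only.

_BLOCK_SIZE = 16

def chained_hashes(token_ids, block_size=_BLOCK_SIZE):
    """Hash each full block together with the previous block's hash."""
    hashes, prev_h = [], -1                  # -1 is the same "no previous block" sentinel
    num_full_blocks = len(token_ids) // block_size
    for i in range(num_full_blocks):
        block = tuple(token_ids[i * block_size:(i + 1) * block_size])
        prev_h = hash((prev_h, block))       # stand-in for BlockManager.compute_hash
        hashes.append(prev_h)
    return hashes

a = chained_hashes(list(range(48)))               # 3 full blocks
b = chained_hashes(list(range(32)) + [99] * 16)   # same first 32 tokens, different third block
assert a[:2] == b[:2] and a[2] != b[2]            # shared prefix matches, third block diverges

Because each hash folds in its predecessor, a match at block index i implies the entire i * 16-token prefix is identical, which is what makes the cached prefix safe to reuse.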
@@ -175,9 +201,55 @@ class StaticScheduler:
                 )
                 continue
 
+            tokens = req.prompt_token_ids
+            num_full_blocks = prompt_len // _BLOCK_SIZE
+            matched = 0
+            self.pending_block_hashes.clear()
+            for i in range(num_full_blocks):
+                prev_h = self.cached_block_hashes[i - 1] if i > 0 else -1
+                h = BlockManager.compute_hash(
+                    tokens[i * _BLOCK_SIZE : (i + 1) * _BLOCK_SIZE], prev_h
+                )
+                if (
+                    i < len(self.cached_block_hashes)
+                    and h == self.cached_block_hashes[i]
+                ):
+                    matched = i + 1
+                else:
+                    del self.cached_block_hashes[i:]
+                    cur_h = h
+                    self.pending_block_hashes.append(cur_h)
+                    for j in range(i + 1, num_full_blocks):
+                        cur_h = BlockManager.compute_hash(
+                            tokens[j * _BLOCK_SIZE : (j + 1) * _BLOCK_SIZE],
+                            cur_h,
+                        )
+                        self.pending_block_hashes.append(cur_h)
+                    break
+            else:
+                del self.cached_block_hashes[matched:]
+
+            prefix_hit_len = matched * _BLOCK_SIZE
+            logger.info(
+                f"Prefill cache match: {matched}/{num_full_blocks} blocks "
+                f"({prefix_hit_len} tokens reused, {len(self.pending_block_hashes)} pending)"
+            )
+
             req.status = RequestStatus.RUNNING
             self.running_request = req
-            return StaticSchedulerOutput(scheduled_requests=[req], is_prefill=True)
+            return StaticSchedulerOutput(
+                scheduled_requests=[req],
+                is_prefill=True,
+                prefix_hit_len=prefix_hit_len,
+            )
+
+    def update_cache(self):
+        """Commit hashes computed during prefill into the confirmed cache hash list."""
+        self.cached_block_hashes.extend(self.pending_block_hashes)
+        self.pending_block_hashes.clear()
+        logger.debug(
+            f"update_cache: cached_block_hashes now has {len(self.cached_block_hashes)} blocks"
+        )
 
     def complete_requests(self, requests: List[InferenceRequest]):
         """Handle completed requests."""
@@ -190,6 +262,8 @@ class StaticScheduler:
         """Get cache statistics."""
         return {
             "max_cache_len": self.max_cache_len,
+            "cached_blocks": len(self.cached_block_hashes),
+            "cached_tokens": len(self.cached_block_hashes) * _BLOCK_SIZE,
             "running_request": (
                 self.running_request.request_id if self.running_request else None
             ),
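With the new fields, the statistics dictionary above might look like the following once a prompt covering three full blocks has been prefilled and committed; the values are hypothetical, and only cached_blocks and cached_tokens are new in this patch.

# Hypothetical snapshot of the cache statistics after one committed prefill.
example_stats = {
    "max_cache_len": 4096,
    "cached_blocks": 3,
    "cached_tokens": 48,        # 3 * _BLOCK_SIZE
    "running_request": None,    # request_id of the running request, if any
}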