Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eef99f73
Commit
eef99f73
authored
Jan 21, 2026
by
laibao
Browse files
feat: kvpress新增 KV cache 申请/截断支持
parent
8d3d07fc
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
59 additions
and
3 deletions
+59
-3
vllm/v1/core/kv_cache_coordinator.py
vllm/v1/core/kv_cache_coordinator.py
+11
-0
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+21
-3
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/core/single_type_kv_cache_manager.py
+27
-0
No files found.
vllm/v1/core/kv_cache_coordinator.py
View file @
eef99f73
...
@@ -154,6 +154,17 @@ class KVCacheCoordinator(ABC):
...
@@ -154,6 +154,17 @@ class KVCacheCoordinator(ABC):
for
manager
in
self
.
single_type_managers
:
for
manager
in
self
.
single_type_managers
:
manager
.
remove_skipped_blocks
(
request_id
,
num_computed_tokens
)
manager
.
remove_skipped_blocks
(
request_id
,
num_computed_tokens
)
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
    """Shrink the blocks held by `request_id` to fit `num_tokens` slots.

    The call is forwarded to every manager in `self.single_type_managers`;
    later managers are still invoked after an earlier one reports that it
    freed blocks.

    Returns:
        True if any manager reported that blocks were freed.
    """
    any_freed = False
    for single_type_manager in self.single_type_managers:
        freed = single_type_manager.truncate_to_num_tokens(
            request_id, num_tokens)
        if freed:
            any_freed = freed
    return any_freed
def
get_blocks
(
self
,
request_id
:
str
)
->
tuple
[
list
[
KVCacheBlock
],
...]:
def
get_blocks
(
self
,
request_id
:
str
)
->
tuple
[
list
[
KVCacheBlock
],
...]:
"""
"""
Get the blocks for the request.
Get the blocks for the request.
...
...
vllm/v1/core/kv_cache_manager.py
View file @
eef99f73
...
@@ -7,6 +7,8 @@ from typing import Optional
...
@@ -7,6 +7,8 @@ from typing import Optional
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
import
vllm.envs
as
envs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
sha256
from
vllm.utils
import
sha256
from
vllm.v1.core.kv_cache_coordinator
import
get_kv_cache_coordinator
from
vllm.v1.core.kv_cache_coordinator
import
get_kv_cache_coordinator
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
KVCacheBlock
,
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
KVCacheBlock
,
...
@@ -251,9 +253,17 @@ class KVCacheManager:
...
@@ -251,9 +253,17 @@ class KVCacheManager:
# the new prefix caching hits
# the new prefix caching hits
num_computed_tokens
=
(
request
.
num_computed_tokens
+
num_computed_tokens
=
(
request
.
num_computed_tokens
+
num_new_computed_tokens
)
num_new_computed_tokens
)
num_tokens_need_slot
=
min
(
if
envs
.
VLLM_ENABLE_KV_COMPRESSION
and
not
current_platform
.
is_tpu
():
num_computed_tokens
+
num_new_tokens
+
num_lookahead_tokens
,
# KV compression decouples logical positions from KV cache
self
.
max_model_len
)
# positions. Allocate based on the KV cache length (plus the tokens
# scheduled for this step, which are temporarily written to cache).
num_tokens_need_slot
=
min
(
request
.
num_kv_tokens
+
num_new_tokens
+
num_lookahead_tokens
,
self
.
max_model_len
)
else
:
num_tokens_need_slot
=
min
(
num_computed_tokens
+
num_new_tokens
+
num_lookahead_tokens
,
self
.
max_model_len
)
num_blocks_to_allocate
=
self
.
coordinator
.
get_num_blocks_to_allocate
(
num_blocks_to_allocate
=
self
.
coordinator
.
get_num_blocks_to_allocate
(
request_id
=
request
.
request_id
,
request_id
=
request
.
request_id
,
...
@@ -385,6 +395,14 @@ class KVCacheManager:
...
@@ -385,6 +395,14 @@ class KVCacheManager:
return
KVCacheBlocks
(
return
KVCacheBlocks
(
self
.
coordinator
.
get_blocks
(
request_id
)).
get_block_ids
()
self
.
coordinator
.
get_blocks
(
request_id
)).
get_block_ids
()
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
    """Shrink the blocks held by `request_id` to fit `num_tokens` slots.

    Best-effort: the work is delegated to `self.coordinator`, which may
    release blocks back to the pool.

    Returns:
        True if any blocks were freed.
    """
    freed_any = self.coordinator.truncate_to_num_tokens(
        request_id, num_tokens)
    return freed_any
def
cache_blocks
(
self
,
request
:
Request
,
num_computed_tokens
:
int
)
->
None
:
def
cache_blocks
(
self
,
request
:
Request
,
num_computed_tokens
:
int
)
->
None
:
"""Cache the blocks for the request, if enabled."""
"""Cache the blocks for the request, if enabled."""
if
self
.
enable_caching
:
if
self
.
enable_caching
:
...
...
vllm/v1/core/single_type_kv_cache_manager.py
View file @
eef99f73
...
@@ -174,6 +174,15 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -174,6 +174,15 @@ class SingleTypeKVCacheManager(ABC):
self
.
block_pool
.
free_blocks
(
ordered_blocks
)
self
.
block_pool
.
free_blocks
(
ordered_blocks
)
self
.
num_cached_block
.
pop
(
request_id
,
None
)
self
.
num_cached_block
.
pop
(
request_id
,
None
)
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
    """Hook for shrinking a request's blocks to fit `num_tokens` slots.

    This base implementation does nothing and always reports that no
    blocks were freed. Subclasses may override it to release blocks that
    are no longer needed (e.g., after KV compaction).

    Returns:
        False, since the default implementation frees nothing.
    """
    return False
@
abstractmethod
@
abstractmethod
def
get_num_common_prefix_blocks
(
self
,
request_id
:
str
,
def
get_num_common_prefix_blocks
(
self
,
request_id
:
str
,
num_running_requests
:
int
)
->
int
:
num_running_requests
:
int
)
->
int
:
...
@@ -283,6 +292,24 @@ class FullAttentionManager(SingleTypeKVCacheManager):
...
@@ -283,6 +292,24 @@ class FullAttentionManager(SingleTypeKVCacheManager):
# No need to remove blocks for full attention.
# No need to remove blocks for full attention.
pass
pass
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
    """Free the trailing blocks of a request beyond `num_tokens` slots.

    A negative `num_tokens` is treated as 0. The number of blocks kept is
    `cdiv(num_tokens, self.block_size)`. Nothing is changed when the
    request has no blocks or already holds no more blocks than that.

    Returns:
        True if blocks were handed back to `self.block_pool`, else False.
    """
    target_tokens = max(int(num_tokens), 0)
    req_blocks = self.req_to_blocks.get(request_id)
    if not req_blocks:
        return False

    keep = cdiv(target_tokens, self.block_size)
    if len(req_blocks) <= keep:
        return False

    # Detach the tail from the request's own list in place, then release it
    # to the pool in reverse order (last block first).
    tail = req_blocks[keep:]
    del req_blocks[keep:]
    self.block_pool.free_blocks(reversed(tail))

    # Clamp the cached-block counter to the shortened block list.
    if request_id in self.num_cached_block:
        cached = self.num_cached_block[request_id]
        self.num_cached_block[request_id] = min(cached, len(req_blocks))
    return True
def
get_num_common_prefix_blocks
(
self
,
request_id
:
str
,
def
get_num_common_prefix_blocks
(
self
,
request_id
:
str
,
num_running_requests
:
int
)
->
int
:
num_running_requests
:
int
)
->
int
:
blocks
=
self
.
req_to_blocks
[
request_id
]
blocks
=
self
.
req_to_blocks
[
request_id
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment