Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5095e966
Unverified
Commit
5095e966
authored
Feb 03, 2025
by
Cody Yu
Committed by
GitHub
Feb 03, 2025
Browse files
[V1] Revert `uncache_blocks` and support recaching full blocks (#12415)
Signed-off-by:
Cody Yu
<
hao.yu.cody@gmail.com
>
parent
cf58b9c4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
16 additions
and
55 deletions
+16
-55
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+0
-30
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+16
-25
No files found.
tests/v1/core/test_prefix_caching.py
View file @
5095e966
...
@@ -629,33 +629,3 @@ def test_reset_prefix_cache():
...
@@ -629,33 +629,3 @@ def test_reset_prefix_cache():
assert
manager
.
reset_prefix_cache
()
assert
manager
.
reset_prefix_cache
()
assert
not
manager
.
cached_block_hash_to_block
assert
not
manager
.
cached_block_hash_to_block
assert
all
([
blk
.
block_hash
is
None
for
blk
in
manager
.
block_pool
])
assert
all
([
blk
.
block_hash
is
None
for
blk
in
manager
.
block_pool
])
def
test_uncache_blocks
():
manager
=
KVCacheManager
(
block_size
=
16
,
num_gpu_blocks
=
10
,
max_model_len
=
8192
,
sliding_window
=
None
,
enable_caching
=
True
,
num_preallocate_tokens
=
0
,
)
req0
=
make_request
(
"0"
,
list
(
range
(
30
)))
blocks
=
manager
.
allocate_slots
(
req0
,
30
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
]
assert
len
(
manager
.
cached_block_hash_to_block
)
==
1
req0
.
num_computed_tokens
=
30
# Simulate speculative tokens.
for
_
in
range
(
5
):
req0
.
append_output_token_ids
(
8
)
manager
.
allocate_slots
(
req0
,
5
)
assert
len
(
manager
.
cached_block_hash_to_block
)
==
2
# After sampling, assuming only 1 token is accepted.
req0
.
num_computed_tokens
=
31
num_uncached_blocks
=
manager
.
uncache_blocks
(
req0
)
assert
num_uncached_blocks
==
1
assert
len
(
manager
.
cached_block_hash_to_block
)
==
1
vllm/v1/core/kv_cache_manager.py
View file @
5095e966
...
@@ -252,29 +252,6 @@ class KVCacheManager:
...
@@ -252,29 +252,6 @@ class KVCacheManager:
if
block
.
ref_cnt
==
0
:
if
block
.
ref_cnt
==
0
:
self
.
free_block_queue
.
append
(
block
)
self
.
free_block_queue
.
append
(
block
)
def
uncache_blocks
(
self
,
request
:
Request
)
->
int
:
"""Uncache the blocks that are no longer full based on the
num_computed_tokens in the given request. This happens when
the blocks were full and cached due to speculative tokens, but the
speculative tokens are not accepted.
Args:
request: The request.
Returns:
The number of uncached blocks.
"""
blocks
=
self
.
req_to_blocks
[
request
.
request_id
]
num_computed_tokens
=
request
.
num_computed_tokens
num_full_blocks
=
num_computed_tokens
//
self
.
block_size
num_uncached_blocks
=
0
for
block
in
blocks
[
num_full_blocks
:]:
# If the block is not cached, the following blocks are not cached.
if
not
self
.
_maybe_evict_cached_block
(
block
):
break
num_uncached_blocks
+=
1
return
num_uncached_blocks
def
reset_prefix_cache
(
self
)
->
bool
:
def
reset_prefix_cache
(
self
)
->
bool
:
"""Reset prefix cache. This function may be used in RLHF
"""Reset prefix cache. This function may be used in RLHF
flows to invalidate prefix caching after the weights are updated,
flows to invalidate prefix caching after the weights are updated,
...
@@ -470,8 +447,22 @@ class KVCacheManager:
...
@@ -470,8 +447,22 @@ class KVCacheManager:
assert
prev_block
.
block_hash
is
not
None
assert
prev_block
.
block_hash
is
not
None
prev_block_hash_value
=
prev_block
.
block_hash
.
hash_value
prev_block_hash_value
=
prev_block
.
block_hash
.
hash_value
for
i
,
blk
in
enumerate
(
full_blocks
):
# Find the first uncached block. This case should only happen when
blk_idx
=
blk_start_idx
+
i
# speculative decoding is used.
offset
=
0
for
blk
in
full_blocks
:
if
blk
.
block_hash
is
None
:
break
else
:
prev_block_hash_value
=
blk
.
block_hash
.
hash_value
offset
+=
1
else
:
# All blocks are cached.
return
for
i
,
blk
in
enumerate
(
full_blocks
[
offset
:]):
blk_idx
=
blk_start_idx
+
offset
+
i
assert
blk
.
block_hash
is
None
if
blk_idx
<
num_cached_block_hashes
:
if
blk_idx
<
num_cached_block_hashes
:
# The block hash may already be computed in
# The block hash may already be computed in
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment