Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
78ed8f57
Unverified
Commit
78ed8f57
authored
Dec 12, 2024
by
Cody Yu
Committed by
GitHub
Dec 13, 2024
Browse files
[Misc][V1] Fix type in v1 prefix caching (#11151)
parent
db6c264a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
27 additions
and
15 deletions
+27
-15
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+8
-4
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+4
-4
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+15
-7
No files found.
tests/v1/core/test_prefix_caching.py
View file @
78ed8f57
...
...
@@ -49,7 +49,7 @@ def test_prefill():
block_hash
=
hash_block_tokens
(
parent_block_hash
,
block_tokens
)
assert
manager
.
block_pool
[
block_id
].
block_hash
==
block_hash
assert
manager
.
block_pool
[
block_id
].
ref_cnt
==
1
parent_block_hash
=
block_hash
parent_block_hash
=
block_hash
.
hash_value
# Check partial/preallocated block metadata
for
block_id
in
(
3
,
4
):
...
...
@@ -360,11 +360,15 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
assert
not
computed_blocks
# Just ask for 1 block.
blocks
=
manager
.
allocate_slots
(
req
,
block_size
,
computed_blocks
)
req
.
num_computed_tokens
=
block_size
assert
len
(
blocks
)
==
1
+
num_preallocated_blocks
# Append slots to the block.
req
.
num_computed_tokens
=
block_size
*
len
(
blocks
)
# Assume all used.
blocks
=
manager
.
append_slots
(
req
,
block_size
)
# Append 1 block.
# Assume all computed.
manager
.
append_slots
(
req
,
block_size
*
(
len
(
blocks
)
-
1
))
req
.
num_computed_tokens
=
block_size
*
len
(
blocks
)
# Append 1 block.
blocks
=
manager
.
append_slots
(
req
,
block_size
)
assert
len
(
blocks
)
==
1
+
num_preallocated_blocks
...
...
vllm/v1/core/kv_cache_manager.py
View file @
78ed8f57
...
...
@@ -375,8 +375,8 @@ class KVCacheManager:
prev_block: The previous block in the chain.
"""
# Update the new blocks with the block hashes through the chain.
prev_block_hash
=
(
prev_block
.
block_hash
if
prev_block
is
not
None
else
None
)
prev_block_hash
_value
=
(
prev_block
.
block_hash
.
hash_value
if
prev_block
is
not
None
else
None
)
for
i
,
blk
in
enumerate
(
full_blocks
):
blk_idx
=
blk_start_idx
+
i
...
...
@@ -390,10 +390,10 @@ class KVCacheManager:
f
"
{
request
.
request_id
}
(
{
request
}
)"
)
# Compute the hash of the current block.
block_hash
=
hash_block_tokens
(
prev_block_hash
,
block_hash
=
hash_block_tokens
(
prev_block_hash
_value
,
tuple
(
block_tokens
))
# Update the full block and add it to the cache.
blk
.
block_hash
=
block_hash
self
.
cached_block_hash_to_block
[
block_hash
][
blk
.
block_id
]
=
blk
prev_block_hash
=
block_hash
prev_block_hash
_value
=
block_hash
.
hash_value
vllm/v1/core/kv_cache_utils.py
View file @
78ed8f57
"""KV-Cache Utilities."""
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
NamedTuple
,
Optional
,
Tuple
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
BlockHashType
=
Tuple
[
int
,
Tuple
[
int
]]
class BlockHashType(NamedTuple):
    """Hash value of a block together with the token IDs in the block.

    The token IDs are stored next to the hash value so that two blocks
    whose hash values happen to collide still compare unequal when their
    token IDs differ; the whole tuple is what gets compared.
    """
    # Hash value of the block.
    hash_value: int
    # Token IDs in the block. Variable length, hence ``Tuple[int, ...]``
    # (``Tuple[int]`` would mean a tuple of exactly one int).
    token_ids: Tuple[int, ...]
@
dataclass
...
...
@@ -171,8 +178,8 @@ def hash_block_tokens(parent_block_hash: Optional[int],
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
"""
return
(
hash
(
(
parent_block_hash
,
*
curr_block_token_ids
)),
curr_block_token_ids
)
return
BlockHashType
(
hash
((
parent_block_hash
,
*
curr_block_token_ids
)),
curr_block_token_ids
)
def
hash_request_tokens
(
block_size
:
int
,
...
...
@@ -188,14 +195,15 @@ def hash_request_tokens(block_size: int,
The list of computed hash values.
"""
ret
=
[]
parent_block_hash
=
None
parent_block_hash
_value
=
None
for
start
in
range
(
0
,
len
(
token_ids
),
block_size
):
end
=
start
+
block_size
block_token_ids
=
tuple
(
token_ids
[
start
:
end
])
# Do not hash the block if it is not full.
if
len
(
block_token_ids
)
<
block_size
:
break
block_hash
=
hash_block_tokens
(
parent_block_hash
,
block_token_ids
)
block_hash
=
hash_block_tokens
(
parent_block_hash_value
,
block_token_ids
)
ret
.
append
(
block_hash
)
parent_block_hash
=
block_hash
parent_block_hash
_value
=
block_hash
.
hash_value
return
ret
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment