Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
78ed8f57
Unverified
Commit
78ed8f57
authored
Dec 12, 2024
by
Cody Yu
Committed by
GitHub
Dec 13, 2024
Browse files
[Misc][V1] Fix type in v1 prefix caching (#11151)
parent
db6c264a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
27 additions
and
15 deletions
+27
-15
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+8
-4
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+4
-4
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+15
-7
No files found.
tests/v1/core/test_prefix_caching.py
View file @
78ed8f57
...
@@ -49,7 +49,7 @@ def test_prefill():
...
@@ -49,7 +49,7 @@ def test_prefill():
block_hash
=
hash_block_tokens
(
parent_block_hash
,
block_tokens
)
block_hash
=
hash_block_tokens
(
parent_block_hash
,
block_tokens
)
assert
manager
.
block_pool
[
block_id
].
block_hash
==
block_hash
assert
manager
.
block_pool
[
block_id
].
block_hash
==
block_hash
assert
manager
.
block_pool
[
block_id
].
ref_cnt
==
1
assert
manager
.
block_pool
[
block_id
].
ref_cnt
==
1
parent_block_hash
=
block_hash
parent_block_hash
=
block_hash
.
hash_value
# Check partial/preallocated block metadata
# Check partial/preallocated block metadata
for
block_id
in
(
3
,
4
):
for
block_id
in
(
3
,
4
):
...
@@ -360,11 +360,15 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
...
@@ -360,11 +360,15 @@ def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
assert
not
computed_blocks
assert
not
computed_blocks
# Just ask for 1 block.
# Just ask for 1 block.
blocks
=
manager
.
allocate_slots
(
req
,
block_size
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req
,
block_size
,
computed_blocks
)
req
.
num_computed_tokens
=
block_size
assert
len
(
blocks
)
==
1
+
num_preallocated_blocks
assert
len
(
blocks
)
==
1
+
num_preallocated_blocks
# Append slots to the block.
# Assume all computed.
req
.
num_computed_tokens
=
block_size
*
len
(
blocks
)
# Assume all used.
manager
.
append_slots
(
req
,
block_size
*
(
len
(
blocks
)
-
1
))
blocks
=
manager
.
append_slots
(
req
,
block_size
)
# Append 1 block.
req
.
num_computed_tokens
=
block_size
*
len
(
blocks
)
# Append 1 block.
blocks
=
manager
.
append_slots
(
req
,
block_size
)
assert
len
(
blocks
)
==
1
+
num_preallocated_blocks
assert
len
(
blocks
)
==
1
+
num_preallocated_blocks
...
...
vllm/v1/core/kv_cache_manager.py
View file @
78ed8f57
...
@@ -375,8 +375,8 @@ class KVCacheManager:
...
@@ -375,8 +375,8 @@ class KVCacheManager:
prev_block: The previous block in the chain.
prev_block: The previous block in the chain.
"""
"""
# Update the new blocks with the block hashes through the chain.
# Update the new blocks with the block hashes through the chain.
prev_block_hash
=
(
prev_block
.
block_hash
prev_block_hash
_value
=
(
prev_block
.
block_hash
.
hash_value
if
prev_block
is
not
None
else
None
)
if
prev_block
is
not
None
else
None
)
for
i
,
blk
in
enumerate
(
full_blocks
):
for
i
,
blk
in
enumerate
(
full_blocks
):
blk_idx
=
blk_start_idx
+
i
blk_idx
=
blk_start_idx
+
i
...
@@ -390,10 +390,10 @@ class KVCacheManager:
...
@@ -390,10 +390,10 @@ class KVCacheManager:
f
"
{
request
.
request_id
}
(
{
request
}
)"
)
f
"
{
request
.
request_id
}
(
{
request
}
)"
)
# Compute the hash of the current block.
# Compute the hash of the current block.
block_hash
=
hash_block_tokens
(
prev_block_hash
,
block_hash
=
hash_block_tokens
(
prev_block_hash
_value
,
tuple
(
block_tokens
))
tuple
(
block_tokens
))
# Update and added the full block to the cache.
# Update and added the full block to the cache.
blk
.
block_hash
=
block_hash
blk
.
block_hash
=
block_hash
self
.
cached_block_hash_to_block
[
block_hash
][
blk
.
block_id
]
=
blk
self
.
cached_block_hash_to_block
[
block_hash
][
blk
.
block_id
]
=
blk
prev_block_hash
=
block_hash
prev_block_hash
_value
=
block_hash
.
hash_value
vllm/v1/core/kv_cache_utils.py
View file @
78ed8f57
"""KV-Cache Utilities."""
"""KV-Cache Utilities."""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
NamedTuple
,
Optional
,
Tuple
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
BlockHashType
=
Tuple
[
int
,
Tuple
[
int
]]
class
BlockHashType
(
NamedTuple
):
"""Hash value of a block and the token IDs in the block.
The reason we keep a tuple of token IDs is to make sure no hash
collision happens when the hash value is the same.
"""
hash_value
:
int
token_ids
:
Tuple
[
int
]
@
dataclass
@
dataclass
...
@@ -171,8 +178,8 @@ def hash_block_tokens(parent_block_hash: Optional[int],
...
@@ -171,8 +178,8 @@ def hash_block_tokens(parent_block_hash: Optional[int],
The hash value of the block and the token ids in the block.
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
The entire tuple is used as the hash key of the block.
"""
"""
return
(
hash
(
return
BlockHashType
(
hash
((
parent_block_hash
,
*
curr_block_token_ids
)),
(
parent_block_hash
,
*
curr_block_token_ids
)),
curr_block_token_ids
)
curr_block_token_ids
)
def
hash_request_tokens
(
block_size
:
int
,
def
hash_request_tokens
(
block_size
:
int
,
...
@@ -188,14 +195,15 @@ def hash_request_tokens(block_size: int,
...
@@ -188,14 +195,15 @@ def hash_request_tokens(block_size: int,
The list of computed hash values.
The list of computed hash values.
"""
"""
ret
=
[]
ret
=
[]
parent_block_hash
=
None
parent_block_hash
_value
=
None
for
start
in
range
(
0
,
len
(
token_ids
),
block_size
):
for
start
in
range
(
0
,
len
(
token_ids
),
block_size
):
end
=
start
+
block_size
end
=
start
+
block_size
block_token_ids
=
tuple
(
token_ids
[
start
:
end
])
block_token_ids
=
tuple
(
token_ids
[
start
:
end
])
# Do not hash the block if it is not full.
# Do not hash the block if it is not full.
if
len
(
block_token_ids
)
<
block_size
:
if
len
(
block_token_ids
)
<
block_size
:
break
break
block_hash
=
hash_block_tokens
(
parent_block_hash
,
block_token_ids
)
block_hash
=
hash_block_tokens
(
parent_block_hash_value
,
block_token_ids
)
ret
.
append
(
block_hash
)
ret
.
append
(
block_hash
)
parent_block_hash
=
block_hash
parent_block_hash
_value
=
block_hash
.
hash_value
return
ret
return
ret
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment