Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
82dfb12e
Unverified
Commit
82dfb12e
authored
Sep 09, 2025
by
Zebing Lin
Committed by
GitHub
Sep 08, 2025
Browse files
[Core] Use sha256 bytes instead of BlockHash to reduce GC overhead (#23673)
Signed-off-by:
linzebing
<
linzebing1995@gmail.com
>
parent
bba1042c
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
298 additions
and
283 deletions
+298
-283
examples/online_serving/kv_events_subscriber.py
examples/online_serving/kv_events_subscriber.py
+5
-3
tests/utils_/test_utils.py
tests/utils_/test_utils.py
+9
-11
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+27
-38
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+122
-103
tests/v1/core/test_single_type_kv_cache_manager.py
tests/v1/core/test_single_type_kv_cache_manager.py
+8
-8
tests/v1/core/utils.py
tests/v1/core/utils.py
+3
-2
tests/v1/engine/test_engine_args.py
tests/v1/engine/test_engine_args.py
+7
-6
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+3
-2
vllm/config/cache.py
vllm/config/cache.py
+6
-11
vllm/distributed/kv_events.py
vllm/distributed/kv_events.py
+4
-3
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+6
-14
vllm/envs.py
vllm/envs.py
+6
-0
vllm/utils/__init__.py
vllm/utils/__init__.py
+10
-17
vllm/v1/core/block_pool.py
vllm/v1/core/block_pool.py
+18
-9
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+64
-56
No files found.
examples/online_serving/kv_events_subscriber.py
View file @
82dfb12e
...
@@ -6,6 +6,8 @@ import msgspec
...
@@ -6,6 +6,8 @@ import msgspec
import
zmq
import
zmq
from
msgspec.msgpack
import
Decoder
from
msgspec.msgpack
import
Decoder
from
vllm.v1.core.kv_cache_utils
import
BlockHash
#
#
# Types copied from vllm.distributed.kv_events
# Types copied from vllm.distributed.kv_events
...
@@ -22,8 +24,8 @@ class KVCacheEvent(
...
@@ -22,8 +24,8 @@ class KVCacheEvent(
class
BlockStored
(
KVCacheEvent
):
class
BlockStored
(
KVCacheEvent
):
block_hashes
:
list
[
int
]
block_hashes
:
list
[
BlockHash
]
parent_block_hash
:
Optional
[
int
]
parent_block_hash
:
Optional
[
BlockHash
]
token_ids
:
list
[
int
]
token_ids
:
list
[
int
]
block_size
:
int
block_size
:
int
lora_id
:
Optional
[
int
]
lora_id
:
Optional
[
int
]
...
@@ -31,7 +33,7 @@ class BlockStored(KVCacheEvent):
...
@@ -31,7 +33,7 @@ class BlockStored(KVCacheEvent):
class
BlockRemoved
(
KVCacheEvent
):
class
BlockRemoved
(
KVCacheEvent
):
block_hashes
:
list
[
int
]
block_hashes
:
list
[
BlockHash
]
medium
:
Optional
[
str
]
medium
:
Optional
[
str
]
...
...
tests/utils_/test_utils.py
View file @
82dfb12e
...
@@ -835,22 +835,20 @@ def test_model_specification(parser_with_config, cli_config_file,
...
@@ -835,22 +835,20 @@ def test_model_specification(parser_with_config, cli_config_file,
@
pytest
.
mark
.
parametrize
(
"input"
,
[(),
(
"abc"
,
),
(
None
,
),
@
pytest
.
mark
.
parametrize
(
"input"
,
[(),
(
"abc"
,
),
(
None
,
),
(
None
,
bool
,
[
1
,
2
,
3
])])
(
None
,
bool
,
[
1
,
2
,
3
])])
@
pytest
.
mark
.
parametrize
(
"output"
,
[
0
,
1
,
2
])
def
test_sha256
(
input
:
tuple
):
def
test_sha256
(
input
:
tuple
,
output
:
int
):
digest
=
sha256
(
input
)
hash
=
sha256
(
input
)
assert
digest
is
not
None
assert
hash
is
not
None
assert
isinstance
(
digest
,
bytes
)
assert
isinstance
(
hash
,
int
)
assert
digest
!=
b
""
assert
hash
!=
0
bytes
=
pickle
.
dumps
(
input
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
input_bytes
=
pickle
.
dumps
(
input
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
assert
hash
==
int
.
from_bytes
(
hashlib
.
sha256
(
bytes
).
digest
(),
assert
digest
==
hashlib
.
sha256
(
input_bytes
).
digest
()
byteorder
=
"big"
)
# hashing again, returns the same value
# hashing again, returns the same value
assert
hash
==
sha256
(
input
)
assert
digest
==
sha256
(
input
)
# hashing different input, returns different value
# hashing different input, returns different value
assert
hash
!=
sha256
(
input
+
(
1
,
))
assert
digest
!=
sha256
(
input
+
(
1
,
))
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
...
tests/v1/core/test_kv_cache_utils.py
View file @
82dfb12e
...
@@ -6,20 +6,22 @@ from typing import Callable, Optional
...
@@ -6,20 +6,22 @@ from typing import Callable, Optional
import
pytest
import
pytest
import
torch
import
torch
import
vllm.v1.core.kv_cache_utils
as
kv_cache_utils
from
vllm.config
import
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.config
import
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
MultiModalKwargsItem
,
PlaceholderRange
)
MultiModalKwargsItem
,
PlaceholderRange
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
GiB_bytes
,
sha256
,
sha256_cbor
_64bit
from
vllm.utils
import
GiB_bytes
,
sha256
,
sha256_cbor
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
# disable yapf here as it formats differently than isort such that both fail
# disable yapf here as it formats differently than isort such that both fail
# yapf: disable
# yapf: disable
from
vllm.v1.core.kv_cache_utils
import
(
from
vllm.v1.core.kv_cache_utils
import
(
FreeKVCacheBlockQueue
,
KVCacheBlock
,
PrefixCachingMetrics
,
BlockHash
,
FreeKVCacheBlockQueue
,
KVCacheBlock
,
PrefixCachingMetrics
,
estimate_max_model_len
,
generate_block_hash_extra_keys
,
estimate_max_model_len
,
generate_block_hash_extra_keys
,
get_kv_cache_config
,
get_max_concurrency_for_kv_cache_config
,
get_kv_cache_config
,
get_max_concurrency_for_kv_cache_config
,
get_request_block_hasher
,
hash_block_tokens
,
init_none_hash
,
get_request_block_hasher
,
hash_block_tokens
,
init_none_hash
,
is_kv_cache_type_uniform
,
unify_kv_cache_configs
)
is_kv_cache_type_uniform
,
make_block_hash_with_group_id
,
unify_kv_cache_configs
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
,
KVCacheGroupSpec
,
KVCacheTensor
,
SlidingWindowSpec
)
SlidingWindowSpec
)
...
@@ -88,7 +90,7 @@ def new_sliding_window_spec(block_size=16,
...
@@ -88,7 +90,7 @@ def new_sliding_window_spec(block_size=16,
sliding_window
=
sliding_window
)
sliding_window
=
sliding_window
)
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
_64bit
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_none_hash
(
monkeypatch
,
hash_fn
):
def
test_none_hash
(
monkeypatch
,
hash_fn
):
import
vllm.v1.core.kv_cache_utils
import
vllm.v1.core.kv_cache_utils
...
@@ -98,8 +100,8 @@ def test_none_hash(monkeypatch, hash_fn):
...
@@ -98,8 +100,8 @@ def test_none_hash(monkeypatch, hash_fn):
reloaded_kv_cache_utils
=
importlib
.
reload
(
vllm
.
v1
.
core
.
kv_cache_utils
)
reloaded_kv_cache_utils
=
importlib
.
reload
(
vllm
.
v1
.
core
.
kv_cache_utils
)
reloaded_kv_cache_utils
.
init_none_hash
(
hash_fn
)
reloaded_kv_cache_utils
.
init_none_hash
(
hash_fn
)
assert
reloaded_kv_cache_utils
.
NONE_HASH
is
not
None
assert
reloaded_kv_cache_utils
.
NONE_HASH
is
not
None
assert
isinstance
(
reloaded_kv_cache_utils
.
NONE_HASH
,
int
)
assert
isinstance
(
reloaded_kv_cache_utils
.
NONE_HASH
,
bytes
)
assert
reloaded_kv_cache_utils
.
NONE_HASH
!=
0
assert
reloaded_kv_cache_utils
.
NONE_HASH
!=
b
""
# case 2: PYTHONHASHSEED is set, use the seed and hash_fn
# case 2: PYTHONHASHSEED is set, use the seed and hash_fn
with
monkeypatch
.
context
()
as
m
:
with
monkeypatch
.
context
()
as
m
:
...
@@ -107,12 +109,11 @@ def test_none_hash(monkeypatch, hash_fn):
...
@@ -107,12 +109,11 @@ def test_none_hash(monkeypatch, hash_fn):
reloaded_kv_cache_utils
=
importlib
.
reload
(
vllm
.
v1
.
core
.
kv_cache_utils
)
reloaded_kv_cache_utils
=
importlib
.
reload
(
vllm
.
v1
.
core
.
kv_cache_utils
)
reloaded_kv_cache_utils
.
init_none_hash
(
hash_fn
)
reloaded_kv_cache_utils
.
init_none_hash
(
hash_fn
)
assert
reloaded_kv_cache_utils
.
NONE_HASH
is
not
None
assert
reloaded_kv_cache_utils
.
NONE_HASH
is
not
None
assert
isinstance
(
reloaded_kv_cache_utils
.
NONE_HASH
,
int
)
assert
isinstance
(
reloaded_kv_cache_utils
.
NONE_HASH
,
bytes
)
assert
hash_fn
(
'python hash seed'
)
==
reloaded_kv_cache_utils
.
NONE_HASH
assert
hash_fn
(
'python hash seed'
)
==
reloaded_kv_cache_utils
.
NONE_HASH
def
test_kv_cache_block
():
def
test_kv_cache_block
():
import
vllm.v1.core.kv_cache_utils
# Test KVCacheBlock initialization
# Test KVCacheBlock initialization
block
=
KVCacheBlock
(
block_id
=
0
)
block
=
KVCacheBlock
(
block_id
=
0
)
...
@@ -127,8 +128,7 @@ def test_kv_cache_block():
...
@@ -127,8 +128,7 @@ def test_kv_cache_block():
assert
block
.
ref_cnt
==
0
assert
block
.
ref_cnt
==
0
# Test block hash setting and resetting
# Test block hash setting and resetting
block_hash
=
vllm
.
v1
.
core
.
kv_cache_utils
.
BlockHash
(
hash_value
=
123
,
block_hash
=
make_block_hash_with_group_id
(
BlockHash
(
b
"abc"
),
0
)
token_ids
=
(
1
,
2
,
3
))
block
.
block_hash
=
block_hash
block
.
block_hash
=
block_hash
assert
block
.
block_hash
==
block_hash
assert
block
.
block_hash
==
block_hash
...
@@ -407,27 +407,23 @@ def test_generate_block_hash_extra_keys_cache_salt():
...
@@ -407,27 +407,23 @@ def test_generate_block_hash_extra_keys_cache_salt():
assert
next_mm_idx
==
1
assert
next_mm_idx
==
1
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
_64bit
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_hash_block_tokens
(
hash_fn
):
def
test_hash_block_tokens
(
hash_fn
):
import
vllm.v1.core.kv_cache_utils
init_none_hash
(
hash_fn
)
init_none_hash
(
hash_fn
)
parent_block_hash
=
123
parent_block_hash
=
BlockHash
(
b
"
123
"
)
curr_block_token_ids
=
(
1
,
2
,
3
)
curr_block_token_ids
=
(
1
,
2
,
3
)
extra_keys
=
(
"key1"
,
"key2"
)
extra_keys
=
(
"key1"
,
"key2"
)
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
curr_block_token_ids
,
extra_keys
)
curr_block_token_ids
,
extra_keys
)
assert
isinstance
(
block_hash
,
vllm
.
v1
.
core
.
kv_cache_utils
.
BlockHash
)
expected
=
hash_fn
((
parent_block_hash
,
curr_block_token_ids
,
extra_keys
))
assert
block_hash
.
hash_value
==
hash_fn
(
assert
block_hash
==
expected
(
parent_block_hash
,
curr_block_token_ids
,
extra_keys
))
assert
block_hash
.
token_ids
==
curr_block_token_ids
assert
block_hash
.
extra_keys
==
extra_keys
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
_64bit
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_request_block_hasher
(
hash_fn
):
def
test_request_block_hasher
(
hash_fn
):
import
vllm.v1.core.kv_cache_utils
kv_cache_utils
.
init_none_hash
(
hash_fn
)
init_none_hash
(
hash_fn
)
request
=
make_request
(
request
=
make_request
(
request_id
=
"0"
,
request_id
=
"0"
,
prompt_token_ids
=
[
_
for
_
in
range
(
6
)],
prompt_token_ids
=
[
_
for
_
in
range
(
6
)],
...
@@ -442,19 +438,13 @@ def test_request_block_hasher(hash_fn):
...
@@ -442,19 +438,13 @@ def test_request_block_hasher(hash_fn):
block_hashes
=
request
.
block_hashes
block_hashes
=
request
.
block_hashes
assert
len
(
block_hashes
)
==
2
assert
len
(
block_hashes
)
==
2
assert
isinstance
(
block_hashes
[
0
],
vllm
.
v1
.
core
.
kv_cache_utils
.
BlockHash
)
assert
block_hashes
[
0
]
==
hash_fn
(
assert
isinstance
(
block_hashes
[
1
],
vllm
.
v1
.
core
.
kv_cache_utils
.
BlockHash
)
(
kv_cache_utils
.
NONE_HASH
,
(
0
,
1
,
2
),
(
"hash1"
,
)))
assert
block_hashes
[
1
]
==
hash_fn
(
# Check the first block
(
block_hashes
[
0
],
(
3
,
4
,
5
),
(
"hash2"
,
)))
assert
block_hashes
[
0
].
token_ids
==
(
0
,
1
,
2
)
assert
block_hashes
[
0
].
extra_keys
==
(
"hash1"
,
)
# Check the second block
assert
block_hashes
[
1
].
token_ids
==
(
3
,
4
,
5
)
assert
block_hashes
[
1
].
extra_keys
==
(
"hash2"
,
)
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor_64bit
,
hash
])
def
test_hash_tokens_different_mm_input
(
hash_fn
):
def
test_hash_tokens_different_mm_input
(
hash_fn
):
init_none_hash
(
hash_fn
)
init_none_hash
(
hash_fn
)
...
@@ -484,9 +474,9 @@ def test_hash_tokens_different_mm_input(hash_fn):
...
@@ -484,9 +474,9 @@ def test_hash_tokens_different_mm_input(hash_fn):
assert
block_hashes1
[
1
]
!=
block_hashes2
[
1
]
assert
block_hashes1
[
1
]
!=
block_hashes2
[
1
]
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
_64bit
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_hash_request_tokens_no_mm_inputs
(
hash_fn
):
def
test_hash_request_tokens_no_mm_inputs
(
hash_fn
):
init_none_hash
(
hash_fn
)
kv_cache_utils
.
init_none_hash
(
hash_fn
)
request
=
make_request
(
request
=
make_request
(
request_id
=
"0"
,
request_id
=
"0"
,
...
@@ -500,10 +490,9 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
...
@@ -500,10 +490,9 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
block_hashes
=
request
.
block_hashes
block_hashes
=
request
.
block_hashes
assert
len
(
block_hashes
)
==
2
assert
len
(
block_hashes
)
==
2
assert
block_hashes
[
0
].
token_ids
==
(
0
,
1
,
2
)
assert
block_hashes
[
0
]
==
hash_fn
(
assert
block_hashes
[
0
].
extra_keys
is
None
(
kv_cache_utils
.
NONE_HASH
,
(
0
,
1
,
2
),
None
))
assert
block_hashes
[
1
].
token_ids
==
(
3
,
4
,
5
)
assert
block_hashes
[
1
]
==
hash_fn
((
block_hashes
[
0
],
(
3
,
4
,
5
),
None
))
assert
block_hashes
[
1
].
extra_keys
is
None
def
test_metrics
():
def
test_metrics
():
...
...
tests/v1/core/test_prefix_caching.py
View file @
82dfb12e
...
@@ -8,17 +8,19 @@ from typing import Callable, Optional
...
@@ -8,17 +8,19 @@ from typing import Callable, Optional
import
pytest
import
pytest
import
torch
import
torch
import
vllm.v1.core.kv_cache_utils
as
kv_cache_utils
from
vllm.distributed.kv_events
import
AllBlocksCleared
,
BlockRemoved
from
vllm.distributed.kv_events
import
AllBlocksCleared
,
BlockRemoved
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
MultiModalKwargsItem
,
PlaceholderRange
)
MultiModalKwargsItem
,
PlaceholderRange
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
sha256
,
sha256_cbor
_64bit
from
vllm.utils
import
sha256
,
sha256_cbor
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
,
Request
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
,
Request
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
BlockHashWithGroupId
,
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
KVCacheBlock
,
KVCacheBlock
,
get_block_hash
,
get_group_id
,
get_request_block_hasher
,
get_request_block_hasher
,
hash_block_tokens
,
init_none_hash
)
hash_block_tokens
,
init_none_hash
,
make_block_hash_with_group_id
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
SlidingWindowSpec
)
KVCacheGroupSpec
,
SlidingWindowSpec
)
...
@@ -101,8 +103,10 @@ def make_kv_cache_config_hybrid_model(block_size: int,
...
@@ -101,8 +103,10 @@ def make_kv_cache_config_hybrid_model(block_size: int,
)
)
@
pytest
.
mark
.
parametrize
(
"hash_algo"
,
[
"sha256"
,
"sha256_cbor_64bit"
,
"hash"
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_prefill
(
hash_algo
):
def
test_prefill
(
hash_fn
):
init_none_hash
(
hash_fn
)
block_size
=
16
block_size
=
16
manager
=
KVCacheManager
(
manager
=
KVCacheManager
(
make_kv_cache_config
(
block_size
,
11
),
make_kv_cache_config
(
block_size
,
11
),
...
@@ -110,10 +114,6 @@ def test_prefill(hash_algo):
...
@@ -110,10 +114,6 @@ def test_prefill(hash_algo):
enable_caching
=
True
,
enable_caching
=
True
,
)
)
# choose the hash function according to the parameter
hash_fn
=
(
sha256_cbor_64bit
if
hash_algo
==
"sha256_cbor_64bit"
else
sha256
if
hash_algo
==
"sha256"
else
hash
)
# Complete 3 blocks (48 tokens)
# Complete 3 blocks (48 tokens)
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
...
@@ -137,10 +137,12 @@ def test_prefill(hash_algo):
...
@@ -137,10 +137,12 @@ def test_prefill(hash_algo):
block_tokens
=
tuple
(
all_token_ids
[(
block_id
-
1
)
*
16
:
block_id
*
16
])
block_tokens
=
tuple
(
all_token_ids
[(
block_id
-
1
)
*
16
:
block_id
*
16
])
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_tokens
)
block_tokens
)
assert
manager
.
block_pool
.
blocks
[
blk_hash
=
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
block_id
].
block_hash
.
block_hash
==
block_hash
assert
blk_hash
is
not
None
assert
get_block_hash
(
blk_hash
)
==
block_hash
assert
get_group_id
(
blk_hash
)
==
0
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
parent_block_hash
=
block_hash
.
hash_value
parent_block_hash
=
block_hash
# Check partial block metadata
# Check partial block metadata
for
block_id
in
(
4
,
):
for
block_id
in
(
4
,
):
...
@@ -233,7 +235,7 @@ def test_prefill_hybrid_model():
...
@@ -233,7 +235,7 @@ def test_prefill_hybrid_model():
enable_caching
=
True
,
enable_caching
=
True
,
)
)
hash_fn
=
ha
sh
hash_fn
=
sh
a256
# Complete 3 blocks (48 tokens)
# Complete 3 blocks (48 tokens)
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
block_size
)]
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
block_size
)]
...
@@ -260,11 +262,13 @@ def test_prefill_hybrid_model():
...
@@ -260,11 +262,13 @@ def test_prefill_hybrid_model():
block_tokens
=
tuple
(
all_token_ids
[(
length
-
1
)
*
16
:
length
*
16
])
block_tokens
=
tuple
(
all_token_ids
[(
length
-
1
)
*
16
:
length
*
16
])
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_tokens
)
block_tokens
)
for
block_id
in
block_ids
:
for
group_id
,
block_id
in
enumerate
(
block_ids
):
assert
manager
.
block_pool
.
blocks
[
blk_hash
=
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
block_id
].
block_hash
.
block_hash
==
block_hash
assert
blk_hash
is
not
None
assert
get_block_hash
(
blk_hash
)
==
block_hash
assert
get_group_id
(
blk_hash
)
==
group_id
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
parent_block_hash
=
block_hash
.
hash_value
parent_block_hash
=
block_hash
# Check partial block metadata
# Check partial block metadata
for
block_id
in
(
4
,
8
,
12
):
for
block_id
in
(
4
,
8
,
12
):
...
@@ -298,11 +302,10 @@ def test_prefill_hybrid_model():
...
@@ -298,11 +302,10 @@ def test_prefill_hybrid_model():
cached_block_hash_to_block_bak
=
copy
.
copy
(
cached_block_hash_to_block_bak
=
copy
.
copy
(
manager
.
block_pool
.
cached_block_hash_to_block
)
manager
.
block_pool
.
cached_block_hash_to_block
)
def
test_partial_request_hit
(
request_id
:
str
,
def
test_partial_request_hit
(
request_id
:
str
,
hash_to_evict
:
list
[
bytes
],
hash_to_evict
:
list
[
BlockHashWithGroupId
],
expect_hit_length
:
int
):
expect_hit_length
:
int
):
req
=
make_request
(
request_id
,
common_token_ids
+
unique_token_ids
,
req
=
make_request
(
request_id
,
common_token_ids
+
unique_token_ids
,
block_size
,
ha
sh
)
block_size
,
sh
a256
)
for
hash_with_group_id
in
hash_to_evict
:
for
hash_with_group_id
in
hash_to_evict
:
manager
.
block_pool
.
cached_block_hash_to_block
.
pop
(
manager
.
block_pool
.
cached_block_hash_to_block
.
pop
(
hash_with_group_id
)
hash_with_group_id
)
...
@@ -319,33 +322,32 @@ def test_prefill_hybrid_model():
...
@@ -319,33 +322,32 @@ def test_prefill_hybrid_model():
# Evict the blocks outside sliding window, does not affect the hit length.
# Evict the blocks outside sliding window, does not affect the hit length.
test_partial_request_hit
(
"2"
,
[
test_partial_request_hit
(
"2"
,
[
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
0
],
1
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
0
],
1
),
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
0
],
2
)
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
0
],
2
)
],
3
)
],
3
)
# Evict the first block of full attention, makes total cache miss.
# Evict the first block of full attention, makes total cache miss.
test_partial_request_hit
(
"3"
,
[
test_partial_request_hit
(
BlockHashWithGroupId
(
block_hashes
[
0
],
0
),
"3"
,
[
make_block_hash_with_group_id
(
block_hashes
[
0
],
0
)],
0
)
],
0
)
# Evict the last block of all layers, reduces the hit length to 2.
# Evict the last block of all layers, reduces the hit length to 2.
test_partial_request_hit
(
"4"
,
[
test_partial_request_hit
(
"4"
,
[
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
2
],
0
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
2
],
0
),
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
2
],
1
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
2
],
1
),
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
2
],
2
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
2
],
2
),
],
2
)
],
2
)
# Evict the last block of full attention, reduces the hit length to 2.
# Evict the last block of full attention, reduces the hit length to 2.
test_partial_request_hit
(
"5"
,
[
BlockHashWithGroupId
(
block_hashes
[
2
],
0
)],
test_partial_request_hit
(
2
)
"5"
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
0
)],
2
)
# Evict the last block of sliding window, reduces the hit length to 2.
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit
(
"6"
,
[
BlockHashWithGroupId
(
block_hashes
[
2
],
1
)],
test_partial_request_hit
(
2
)
"6"
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
1
)],
2
)
# Evict the last block of sliding window, reduces the hit length to 2.
# Evict the last block of sliding window, reduces the hit length to 2.
test_partial_request_hit
(
"7"
,
[
BlockHashWithGroupId
(
block_hashes
[
2
],
2
)],
test_partial_request_hit
(
2
)
"7"
,
[
make_block_hash_with_group_id
(
block_hashes
[
2
],
2
)],
2
)
# Evict different set of blocks for full attention and sliding window makes
# Evict different set of blocks for full attention and sliding window makes
# total cache miss.
# total cache miss.
...
@@ -353,9 +355,9 @@ def test_prefill_hybrid_model():
...
@@ -353,9 +355,9 @@ def test_prefill_hybrid_model():
# The cache hit length of sliding window is 2 * block_size.
# The cache hit length of sliding window is 2 * block_size.
# Then it is cache miss as the two type of layers have different hit length.
# Then it is cache miss as the two type of layers have different hit length.
test_partial_request_hit
(
"8"
,
[
test_partial_request_hit
(
"8"
,
[
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
2
],
0
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
2
],
0
),
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
0
],
1
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
0
],
1
),
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hashes
[
0
],
2
),
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hashes
[
0
],
2
),
],
0
)
],
0
)
...
@@ -372,8 +374,8 @@ def test_prefill_plp():
...
@@ -372,8 +374,8 @@ def test_prefill_plp():
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
)
)
# the default hash function is
ha
sh
# the default hash function is sh
a256
hash_fn
=
ha
sh
hash_fn
=
sh
a256
# Complete 3 blocks (48 tokens)
# Complete 3 blocks (48 tokens)
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
...
@@ -404,10 +406,12 @@ def test_prefill_plp():
...
@@ -404,10 +406,12 @@ def test_prefill_plp():
block_tokens
=
tuple
(
all_token_ids
[(
block_id
-
1
)
*
16
:
block_id
*
16
])
block_tokens
=
tuple
(
all_token_ids
[(
block_id
-
1
)
*
16
:
block_id
*
16
])
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_hash
=
hash_block_tokens
(
hash_fn
,
parent_block_hash
,
block_tokens
)
block_tokens
)
assert
manager
.
block_pool
.
blocks
[
blk_hash
=
(
manager
.
block_pool
.
blocks
[
block_id
].
block_hash
)
block_id
].
block_hash
.
block_hash
==
block_hash
assert
blk_hash
is
not
None
assert
get_block_hash
(
blk_hash
)
==
block_hash
assert
get_group_id
(
blk_hash
)
==
0
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
assert
manager
.
block_pool
.
blocks
[
block_id
].
ref_cnt
==
1
parent_block_hash
=
block_hash
.
hash_value
parent_block_hash
=
block_hash
# Check partial block metadata
# Check partial block metadata
for
block_id
in
(
4
,
):
for
block_id
in
(
4
,
):
...
@@ -493,7 +497,7 @@ def test_decode():
...
@@ -493,7 +497,7 @@ def test_decode():
# Incomplete 1 block (7 tokens)
# Incomplete 1 block (7 tokens)
unique_token_ids
=
[
3
]
*
7
unique_token_ids
=
[
3
]
*
7
req0
=
make_request
(
"0"
,
common_token_ids
+
unique_token_ids
,
block_size
,
req0
=
make_request
(
"0"
,
common_token_ids
+
unique_token_ids
,
block_size
,
ha
sh
)
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -538,7 +542,7 @@ def test_evict():
...
@@ -538,7 +542,7 @@ def test_evict():
)
)
last_token_id
=
5
*
16
+
7
last_token_id
=
5
*
16
+
7
req0
=
make_request
(
"0"
,
list
(
range
(
last_token_id
)),
block_size
,
ha
sh
)
req0
=
make_request
(
"0"
,
list
(
range
(
last_token_id
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -550,7 +554,7 @@ def test_evict():
...
@@ -550,7 +554,7 @@ def test_evict():
# 3 blocks.
# 3 blocks.
req1
=
make_request
(
"1"
,
list
(
range
(
last_token_id
,
req1
=
make_request
(
"1"
,
list
(
range
(
last_token_id
,
last_token_id
+
3
*
16
)),
block_size
,
last_token_id
+
3
*
16
)),
block_size
,
ha
sh
)
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -572,7 +576,7 @@ def test_evict():
...
@@ -572,7 +576,7 @@ def test_evict():
]
==
[
10
,
6
,
5
,
4
,
3
,
2
,
1
,
9
,
8
,
7
]
]
==
[
10
,
6
,
5
,
4
,
3
,
2
,
1
,
9
,
8
,
7
]
# Touch the first 2 blocks.
# Touch the first 2 blocks.
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)),
block_size
,
ha
sh
)
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
computed_blocks
.
get_block_ids
()
==
([
1
,
2
],
)
assert
computed_blocks
.
get_block_ids
()
==
([
1
,
2
],
)
assert
num_computed_tokens
==
2
*
16
assert
num_computed_tokens
==
2
*
16
...
@@ -597,7 +601,7 @@ def test_hash_block_correct_reuse():
...
@@ -597,7 +601,7 @@ def test_hash_block_correct_reuse():
# Allocate 1 block and cache it.
# Allocate 1 block and cache it.
num_tokens
=
block_size
*
1
num_tokens
=
block_size
*
1
req
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
ha
sh
)
req
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -611,7 +615,7 @@ def test_hash_block_correct_reuse():
...
@@ -611,7 +615,7 @@ def test_hash_block_correct_reuse():
# Allocate a new block that's not full, make sure hash info on the
# Allocate a new block that's not full, make sure hash info on the
# block is cleared.
# block is cleared.
req
=
make_request
(
"1"
,
list
(
range
(
num_tokens
-
1
)),
block_size
,
ha
sh
)
req
=
make_request
(
"1"
,
list
(
range
(
num_tokens
-
1
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -638,7 +642,7 @@ def test_computed_blocks_not_evicted():
...
@@ -638,7 +642,7 @@ def test_computed_blocks_not_evicted():
# Allocate a block and cache it.
# Allocate a block and cache it.
num_tokens
=
block_size
*
1
num_tokens
=
block_size
*
1
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
ha
sh
)
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -650,7 +654,7 @@ def test_computed_blocks_not_evicted():
...
@@ -650,7 +654,7 @@ def test_computed_blocks_not_evicted():
# Allocate another block.
# Allocate another block.
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
,
num_tokens
*
2
)),
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
,
num_tokens
*
2
)),
block_size
,
ha
sh
)
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -666,7 +670,7 @@ def test_computed_blocks_not_evicted():
...
@@ -666,7 +670,7 @@ def test_computed_blocks_not_evicted():
# Now if we have a cache hit on the first block, we should evict the second
# Now if we have a cache hit on the first block, we should evict the second
# cached block rather than the first one.
# cached block rather than the first one.
req2
=
make_request
(
"2"
,
list
(
range
(
num_tokens
*
2
)),
block_size
,
ha
sh
)
req2
=
make_request
(
"2"
,
list
(
range
(
num_tokens
*
2
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
assert
computed_blocks
.
blocks
[
0
][
0
].
block_id
==
1
assert
computed_blocks
.
blocks
[
0
][
0
].
block_id
==
1
...
@@ -691,7 +695,7 @@ def test_basic_prefix_caching_disabled():
...
@@ -691,7 +695,7 @@ def test_basic_prefix_caching_disabled():
)
)
req1
=
make_request
(
"1"
,
list
(
range
(
10
)),
block_size
,
req1
=
make_request
(
"1"
,
list
(
range
(
10
)),
block_size
,
ha
sh
)
# 2 blocks and some more
sh
a256
)
# 2 blocks and some more
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
...
@@ -706,7 +710,7 @@ def test_basic_prefix_caching_disabled():
...
@@ -706,7 +710,7 @@ def test_basic_prefix_caching_disabled():
# No caching.
# No caching.
req2
=
make_request
(
"2"
,
list
(
range
(
16
)),
block_size
,
req2
=
make_request
(
"2"
,
list
(
range
(
16
)),
block_size
,
ha
sh
)
# shared prefix
sh
a256
)
# shared prefix
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -716,7 +720,7 @@ def test_basic_prefix_caching_disabled():
...
@@ -716,7 +720,7 @@ def test_basic_prefix_caching_disabled():
assert
len
(
blocks
.
blocks
[
0
])
==
4
assert
len
(
blocks
.
blocks
[
0
])
==
4
# New requests should not have any blocks.
# New requests should not have any blocks.
req3
=
make_request
(
"3"
,
list
(
range
(
4
)),
block_size
,
ha
sh
)
req3
=
make_request
(
"3"
,
list
(
range
(
4
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -726,7 +730,7 @@ def test_basic_prefix_caching_disabled():
...
@@ -726,7 +730,7 @@ def test_basic_prefix_caching_disabled():
assert
not
blocks
assert
not
blocks
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
_64bit
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor
])
def
test_cache_blocks
(
hash_fn
):
def
test_cache_blocks
(
hash_fn
):
"""
"""
This is a unit test that tests the correctness of the _cache_full_blocks
This is a unit test that tests the correctness of the _cache_full_blocks
...
@@ -787,7 +791,7 @@ def test_cache_blocks_multi_group():
...
@@ -787,7 +791,7 @@ def test_cache_blocks_multi_group():
# Block 1/5: [4, 5, 6, 7]
# Block 1/5: [4, 5, 6, 7]
# Block 2/6: [8, 9, 10, 11]
# Block 2/6: [8, 9, 10, 11]
# Block 3/7: [12, 13]
# Block 3/7: [12, 13]
req
=
make_request
(
"0"
,
list
(
range
(
14
)),
block_size
,
ha
sh
)
req
=
make_request
(
"0"
,
list
(
range
(
14
)),
block_size
,
sh
a256
)
# Cache the blocks for group 0.
# Cache the blocks for group 0.
blocks
=
[
KVCacheBlock
(
block_id
=
i
)
for
i
in
range
(
2
)]
blocks
=
[
KVCacheBlock
(
block_id
=
i
)
for
i
in
range
(
2
)]
...
@@ -845,6 +849,8 @@ def test_mm_prefix_caching():
...
@@ -845,6 +849,8 @@ def test_mm_prefix_caching():
"""
"""
This tests that the multi-modal prefix caching is correct.
This tests that the multi-modal prefix caching is correct.
"""
"""
kv_cache_utils
.
init_none_hash
(
sha256
)
block_size
=
16
block_size
=
16
manager
=
KVCacheManager
(
manager
=
KVCacheManager
(
make_kv_cache_config
(
block_size
,
11
),
make_kv_cache_config
(
block_size
,
11
),
...
@@ -874,23 +880,30 @@ def test_mm_prefix_caching():
...
@@ -874,23 +880,30 @@ def test_mm_prefix_caching():
req0
=
make_request
(
"0"
,
req0
=
make_request
(
"0"
,
all_token_ids
,
all_token_ids
,
block_size
,
block_size
,
ha
sh
,
sh
a256
,
mm_positions
=
mm_positions
,
mm_positions
=
mm_positions
,
mm_hashes
=
mm_hashes
)
mm_hashes
=
mm_hashes
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
# Completed block should have hashes
with extra keys.
# Completed block should have hashes
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
block_hashes
=
req0
.
block_hashes
block_hashes
=
req0
.
block_hashes
assert
len
(
block_hashes
)
==
3
assert
len
(
block_hashes
)
==
3
assert
block_hashes
[
0
].
extra_keys
==
(
"aaa"
,
)
assert
block_hashes
[
0
]
==
sha256
(
assert
block_hashes
[
1
].
extra_keys
==
(
"aaa"
,
"bbb"
)
(
kv_cache_utils
.
NONE_HASH
,
tuple
(
all_token_ids
[:
block_size
]),
assert
block_hashes
[
2
].
extra_keys
==
(
"bbb"
,
)
(
"aaa"
,
)))
assert
block_hashes
[
1
]
==
sha256
(
(
block_hashes
[
0
],
tuple
(
all_token_ids
[
block_size
:
block_size
*
2
]),
(
"aaa"
,
"bbb"
)))
assert
block_hashes
[
2
]
==
sha256
(
(
block_hashes
[
1
],
tuple
(
all_token_ids
[
block_size
*
2
:
block_size
*
3
]),
(
"bbb"
,
)))
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
is
not
None
assert
blocks
.
get_block_ids
()
==
([
1
,
2
,
3
,
4
],
)
assert
blocks
.
get_block_ids
()
==
([
1
,
2
,
3
,
4
],
)
req0
.
num_computed_tokens
=
59
req0
.
num_computed_tokens
=
59
...
@@ -901,10 +914,10 @@ def test_mm_prefix_caching():
...
@@ -901,10 +914,10 @@ def test_mm_prefix_caching():
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
[
0
])
==
0
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
[
0
])
==
0
# The just completed block should have hashes with extra keys.
assert
len
(
block_hashes
)
==
4
assert
len
(
block_hashes
)
==
4
assert
block_hashes
[
3
].
extra_keys
==
(
"ccc"
,
)
assert
block_hashes
[
3
]
==
sha256
(
(
block_hashes
[
2
],
tuple
(
all_token_ids
[
3
*
block_size
:]
+
[
8
]
*
5
),
(
"ccc"
,
)))
# Cache hit.
# Cache hit.
unique_token_ids
=
[
-
1
]
*
7
+
[
200
]
*
5
unique_token_ids
=
[
-
1
]
*
7
+
[
200
]
*
5
...
@@ -916,7 +929,7 @@ def test_mm_prefix_caching():
...
@@ -916,7 +929,7 @@ def test_mm_prefix_caching():
req1
=
make_request
(
"1"
,
req1
=
make_request
(
"1"
,
all_token_ids
,
all_token_ids
,
block_size
,
block_size
,
ha
sh
,
sh
a256
,
mm_positions
=
mm_positions
,
mm_positions
=
mm_positions
,
mm_hashes
=
mm_hashes
)
mm_hashes
=
mm_hashes
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
...
@@ -929,6 +942,8 @@ def test_cache_key_salting():
...
@@ -929,6 +942,8 @@ def test_cache_key_salting():
This tests that cache salts are applied during hashing and the cache
This tests that cache salts are applied during hashing and the cache
is separated cache as expected.
is separated cache as expected.
"""
"""
kv_cache_utils
.
init_none_hash
(
sha256
)
block_size
=
16
block_size
=
16
manager
=
KVCacheManager
(
manager
=
KVCacheManager
(
make_kv_cache_config
(
block_size
,
11
),
make_kv_cache_config
(
block_size
,
11
),
...
@@ -939,21 +954,26 @@ def test_cache_key_salting():
...
@@ -939,21 +954,26 @@ def test_cache_key_salting():
# 3 complete blocks and an incomplete block with 11 tokens.
# 3 complete blocks and an incomplete block with 11 tokens.
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
block_size
)]
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
block_size
)]
token_ids
=
common_token_ids
+
[
3
]
*
11
token_ids
=
common_token_ids
+
[
3
]
*
11
req0
=
make_request
(
"0"
,
token_ids
,
block_size
,
ha
sh
,
cache_salt
=
"salt1"
)
req0
=
make_request
(
"0"
,
token_ids
,
block_size
,
sh
a256
,
cache_salt
=
"salt1"
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
# Completed block should have hashes
with extra keys.
# Completed block should have hashes
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
block_hashes
=
req0
.
block_hashes
block_hashes
=
req0
.
block_hashes
assert
len
(
block_hashes
)
==
3
assert
len
(
block_hashes
)
==
3
assert
block_hashes
[
0
].
extra_keys
==
(
"salt1"
,
)
assert
block_hashes
[
0
]
==
sha256
(
assert
block_hashes
[
1
].
extra_keys
is
None
(
kv_cache_utils
.
NONE_HASH
,
tuple
(
token_ids
[:
block_size
]),
(
"salt1"
,
)))
assert
block_hashes
[
2
].
extra_keys
is
None
assert
block_hashes
[
1
]
==
sha256
(
(
block_hashes
[
0
],
tuple
(
token_ids
[
block_size
:
block_size
*
2
]),
None
))
assert
block_hashes
[
2
]
==
sha256
(
(
block_hashes
[
1
],
tuple
(
token_ids
[
block_size
*
2
:
block_size
*
3
]),
None
))
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
blocks
is
not
None
assert
blocks
.
get_block_ids
()
==
([
1
,
2
,
3
,
4
],
)
assert
blocks
.
get_block_ids
()
==
([
1
,
2
,
3
,
4
],
)
req0
.
num_computed_tokens
=
59
req0
.
num_computed_tokens
=
59
...
@@ -964,14 +984,13 @@ def test_cache_key_salting():
...
@@ -964,14 +984,13 @@ def test_cache_key_salting():
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
[
0
])
==
0
assert
new_blocks
is
not
None
and
len
(
new_blocks
.
blocks
[
0
])
==
0
# Now one more block that should not have extra keys.
assert
len
(
block_hashes
)
==
4
assert
len
(
block_hashes
)
==
4
assert
block_hashes
[
3
].
extra_keys
is
None
assert
block_hashes
[
3
]
==
sha256
(
(
block_hashes
[
2
],
tuple
(
token_ids
[
3
*
block_size
:]
+
[
8
]
*
5
),
None
))
# Test cache hit with a new request that has the same salt.
# Test cache hit with a new request that has the same salt.
token_ids
=
common_token_ids
+
[
4
]
*
11
token_ids
=
common_token_ids
+
[
4
]
*
11
req1
=
make_request
(
"1"
,
token_ids
,
block_size
,
ha
sh
,
cache_salt
=
"salt1"
)
req1
=
make_request
(
"1"
,
token_ids
,
block_size
,
sh
a256
,
cache_salt
=
"salt1"
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
# Should match only a prefix of 3 blocks.
# Should match only a prefix of 3 blocks.
assert
len
(
computed_blocks
.
blocks
[
0
])
==
3
assert
len
(
computed_blocks
.
blocks
[
0
])
==
3
...
@@ -979,13 +998,19 @@ def test_cache_key_salting():
...
@@ -979,13 +998,19 @@ def test_cache_key_salting():
# Test cache miss with same content but different salt.
# Test cache miss with same content but different salt.
token_ids
=
common_token_ids
+
[
4
]
*
11
token_ids
=
common_token_ids
+
[
4
]
*
11
req2
=
make_request
(
"2"
,
token_ids
,
block_size
,
ha
sh
,
cache_salt
=
"salt2"
)
req2
=
make_request
(
"2"
,
token_ids
,
block_size
,
sh
a256
,
cache_salt
=
"salt2"
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
computed_blocks
.
blocks
[
0
])
==
0
assert
len
(
computed_blocks
.
blocks
[
0
])
==
0
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
block_hashes
=
req2
.
block_hashes
block_hashes
=
req2
.
block_hashes
assert
len
(
block_hashes
)
==
3
assert
len
(
block_hashes
)
==
3
assert
block_hashes
[
0
].
extra_keys
==
(
"salt2"
,
)
assert
block_hashes
[
0
]
==
sha256
(
(
kv_cache_utils
.
NONE_HASH
,
tuple
(
token_ids
[:
block_size
]),
(
"salt2"
,
)))
assert
block_hashes
[
1
]
==
sha256
(
(
block_hashes
[
0
],
tuple
(
token_ids
[
block_size
:
block_size
*
2
]),
None
))
assert
block_hashes
[
2
]
==
sha256
(
(
block_hashes
[
1
],
tuple
(
token_ids
[
block_size
*
2
:
block_size
*
3
]),
None
))
def
test_prefill_not_enough_free_blocks_with_computed_blocks
():
def
test_prefill_not_enough_free_blocks_with_computed_blocks
():
...
@@ -1004,7 +1029,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -1004,7 +1029,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# Complete 3 blocks (48 tokens)
# Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... |
# | Common-0 | Common-1 | Common-2 | ... |
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
req0
=
make_request
(
"0"
,
common_token_ids
,
block_size
,
ha
sh
)
req0
=
make_request
(
"0"
,
common_token_ids
,
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -1015,7 +1040,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -1015,7 +1040,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
req0
.
request_id
]
req0
.
request_id
]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1
=
make_request
(
"1"
,
common_token_ids
*
2
,
block_size
,
ha
sh
)
req1
=
make_request
(
"1"
,
common_token_ids
*
2
,
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
computed_blocks
.
blocks
[
0
]
==
block_part0
assert
computed_blocks
.
blocks
[
0
]
==
block_part0
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
...
@@ -1032,7 +1057,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -1032,7 +1057,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2
=
make_request
(
"2"
,
[
7
]
*
block_size
*
2
,
block_size
,
ha
sh
)
req2
=
make_request
(
"2"
,
[
7
]
*
block_size
*
2
,
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -1044,7 +1069,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -1044,7 +1069,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# but it cannot be allocated due to insufficient free blocks (2).
# but it cannot be allocated due to insufficient free blocks (2).
# In this case, the ref_cnt of the computed blocks should not be changed.
# In this case, the ref_cnt of the computed blocks should not be changed.
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
5
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
5
req3
=
make_request
(
"3"
,
common_token_ids
*
3
,
block_size
,
ha
sh
)
req3
=
make_request
(
"3"
,
common_token_ids
*
3
,
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
computed_blocks
.
blocks
[
0
]
==
block_part1
assert
computed_blocks
.
blocks
[
0
]
==
block_part1
assert
num_computed_tokens
==
6
*
16
assert
num_computed_tokens
==
6
*
16
...
@@ -1069,13 +1094,13 @@ def test_reset_prefix_cache():
...
@@ -1069,13 +1094,13 @@ def test_reset_prefix_cache():
full_block_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
full_block_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
unique_token_ids
=
[
3
]
*
7
unique_token_ids
=
[
3
]
*
7
all_token_ids
=
full_block_token_ids
+
unique_token_ids
all_token_ids
=
full_block_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
,
block_size
,
ha
sh
)
req0
=
make_request
(
"0"
,
all_token_ids
,
block_size
,
sh
a256
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
)
assert
blocks
.
get_block_ids
()
==
([
1
,
2
,
3
,
4
],
)
assert
blocks
.
get_block_ids
()
==
([
1
,
2
,
3
,
4
],
)
unique_token_ids
=
[
4
]
*
7
unique_token_ids
=
[
4
]
*
7
all_token_ids
=
full_block_token_ids
+
unique_token_ids
all_token_ids
=
full_block_token_ids
+
unique_token_ids
req1
=
make_request
(
"1"
,
all_token_ids
,
block_size
,
ha
sh
)
req1
=
make_request
(
"1"
,
all_token_ids
,
block_size
,
sh
a256
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
req1
.
block_hashes
)
==
3
assert
len
(
req1
.
block_hashes
)
==
3
assert
len
(
computed_blocks
.
blocks
[
0
])
==
3
assert
len
(
computed_blocks
.
blocks
[
0
])
==
3
...
@@ -1109,7 +1134,7 @@ def test_prefix_cache_stats_disabled():
...
@@ -1109,7 +1134,7 @@ def test_prefix_cache_stats_disabled():
assert
manager
.
prefix_cache_stats
is
None
assert
manager
.
prefix_cache_stats
is
None
# Call all functions that check whether log_stats is disabled.
# Call all functions that check whether log_stats is disabled.
req
=
make_request
(
"0"
,
list
(
range
(
16
)),
block_size
,
ha
sh
)
req
=
make_request
(
"0"
,
list
(
range
(
16
)),
block_size
,
sh
a256
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -1124,15 +1149,9 @@ def test_prefix_cache_stats_disabled():
...
@@ -1124,15 +1149,9 @@ def test_prefix_cache_stats_disabled():
def
test_maybe_evict_cached_block
():
def
test_maybe_evict_cached_block
():
pool
=
BlockPool
(
num_gpu_blocks
=
4
,
enable_caching
=
True
)
pool
=
BlockPool
(
num_gpu_blocks
=
4
,
enable_caching
=
True
)
block_hash0
=
BlockHashWithGroupId
(
block_hash
=
BlockHash
(
hash_value
=
10
,
block_hash0
=
make_block_hash_with_group_id
(
BlockHash
(
b
"10"
),
1000
)
token_ids
=
(
100
,
)),
block_hash1
=
make_block_hash_with_group_id
(
BlockHash
(
b
"20"
),
2000
)
group_id
=
1000
)
block_hash2
=
make_block_hash_with_group_id
(
BlockHash
(
b
"30"
),
3000
)
block_hash1
=
BlockHashWithGroupId
(
block_hash
=
BlockHash
(
hash_value
=
20
,
token_ids
=
(
200
,
)),
group_id
=
2000
)
block_hash2
=
BlockHashWithGroupId
(
block_hash
=
BlockHash
(
hash_value
=
30
,
token_ids
=
(
300
,
)),
group_id
=
3000
)
block_hashes
=
[
block_hashes
=
[
block_hash0
,
block_hash0
,
block_hash1
,
block_hash1
,
...
@@ -1206,7 +1225,7 @@ def test_kv_cache_events(blocks_to_cache: int):
...
@@ -1206,7 +1225,7 @@ def test_kv_cache_events(blocks_to_cache: int):
)
)
num_tokens
=
block_size
*
blocks_to_cache
num_tokens
=
block_size
*
blocks_to_cache
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
ha
sh
)
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)),
block_size
,
sh
a256
)
_
=
manager
.
allocate_slots
(
req0
,
num_tokens
)
_
=
manager
.
allocate_slots
(
req0
,
num_tokens
)
events
=
manager
.
take_events
()
events
=
manager
.
take_events
()
...
@@ -1222,7 +1241,7 @@ def test_kv_cache_events(blocks_to_cache: int):
...
@@ -1222,7 +1241,7 @@ def test_kv_cache_events(blocks_to_cache: int):
# Should see block_to_cache number of removed block events and a new block
# Should see block_to_cache number of removed block events and a new block
# stored event
# stored event
manager
.
free
(
req0
)
manager
.
free
(
req0
)
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
)),
block_size
,
ha
sh
)
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
)),
block_size
,
sh
a256
)
_
=
manager
.
allocate_slots
(
req1
,
num_tokens
)
_
=
manager
.
allocate_slots
(
req1
,
num_tokens
)
events
=
manager
.
take_events
()
events
=
manager
.
take_events
()
...
@@ -1256,7 +1275,7 @@ def test_eagle_enabled_removes_last_block():
...
@@ -1256,7 +1275,7 @@ def test_eagle_enabled_removes_last_block():
# Request with 3 full blocks (48 tokens)
# Request with 3 full blocks (48 tokens)
token_ids
=
[
0
]
*
(
3
*
block_size
)
token_ids
=
[
0
]
*
(
3
*
block_size
)
req
=
make_request
(
"divisible_request"
,
token_ids
,
block_size
,
ha
sh
)
req
=
make_request
(
"divisible_request"
,
token_ids
,
block_size
,
sh
a256
)
# Prime the cache
# Prime the cache
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
...
@@ -1266,7 +1285,7 @@ def test_eagle_enabled_removes_last_block():
...
@@ -1266,7 +1285,7 @@ def test_eagle_enabled_removes_last_block():
manager
.
free
(
req
)
manager
.
free
(
req
)
# New request with same tokens + Eagle enabled
# New request with same tokens + Eagle enabled
req_eagle
=
make_request
(
"eagle_divisible"
,
token_ids
,
block_size
,
ha
sh
)
req_eagle
=
make_request
(
"eagle_divisible"
,
token_ids
,
block_size
,
sh
a256
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
# Should retain 1 block:
# Should retain 1 block:
...
@@ -1287,7 +1306,7 @@ def test_eagle_with_partial_blocks():
...
@@ -1287,7 +1306,7 @@ def test_eagle_with_partial_blocks():
)
)
# 2 full blocks + 5 tokens (non-divisible length)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids
=
[
0
]
*
(
2
*
block_size
+
5
)
token_ids
=
[
0
]
*
(
2
*
block_size
+
5
)
req
=
make_request
(
"partial_block_test"
,
token_ids
,
block_size
,
ha
sh
)
req
=
make_request
(
"partial_block_test"
,
token_ids
,
block_size
,
sh
a256
)
# Prime the cache
# Prime the cache
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
...
@@ -1297,7 +1316,7 @@ def test_eagle_with_partial_blocks():
...
@@ -1297,7 +1316,7 @@ def test_eagle_with_partial_blocks():
manager
.
free
(
req
)
manager
.
free
(
req
)
# New request with Eagle enabled
# New request with Eagle enabled
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
,
block_size
,
ha
sh
)
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
,
block_size
,
sh
a256
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
...
@@ -1328,7 +1347,7 @@ def test_eagle_with_sliding_window():
...
@@ -1328,7 +1347,7 @@ def test_eagle_with_sliding_window():
# 2 full blocks + 5 tokens (non-divisible length)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids
=
[
0
]
*
(
2
*
block_size
+
5
)
token_ids
=
[
0
]
*
(
2
*
block_size
+
5
)
req
=
make_request
(
"partial_block_test"
,
token_ids
,
block_size
,
ha
sh
)
req
=
make_request
(
"partial_block_test"
,
token_ids
,
block_size
,
sh
a256
)
# Prime the cache
# Prime the cache
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
...
@@ -1341,7 +1360,7 @@ def test_eagle_with_sliding_window():
...
@@ -1341,7 +1360,7 @@ def test_eagle_with_sliding_window():
manager
.
free
(
req
)
manager
.
free
(
req
)
# New request with Eagle enabled
# New request with Eagle enabled
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
,
block_size
,
ha
sh
)
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
,
block_size
,
sh
a256
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
...
@@ -1351,11 +1370,11 @@ def test_eagle_with_sliding_window():
...
@@ -1351,11 +1370,11 @@ def test_eagle_with_sliding_window():
assert
manager
.
block_pool
.
get_cached_block
(
assert
manager
.
block_pool
.
get_cached_block
(
block_hash_first_block
,
kv_cache_group_ids
=
[
0
])
is
not
None
block_hash_first_block
,
kv_cache_group_ids
=
[
0
])
is
not
None
manager
.
block_pool
.
cached_block_hash_to_block
.
pop
(
manager
.
block_pool
.
cached_block_hash_to_block
.
pop
(
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hash_first_block
,
0
))
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hash_first_block
,
0
))
# New request
# New request
req_after_evict
=
make_request
(
"partial_eagle_after_evict"
,
token_ids
,
req_after_evict
=
make_request
(
"partial_eagle_after_evict"
,
token_ids
,
block_size
,
ha
sh
)
block_size
,
sh
a256
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_after_evict
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_after_evict
)
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# not considered. But after dropping the last matched block due to eagle,
# not considered. But after dropping the last matched block due to eagle,
...
...
tests/v1/core/test_single_type_kv_cache_manager.py
View file @
82dfb12e
...
@@ -6,8 +6,8 @@ import random
...
@@ -6,8 +6,8 @@ import random
import
torch
import
torch
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
BlockHashWithGroupId
,
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
KVCacheBlock
,
KVCacheBlock
)
make_block_hash_with_group_id
)
from
vllm.v1.core.single_type_kv_cache_manager
import
(
from
vllm.v1.core.single_type_kv_cache_manager
import
(
ChunkedLocalAttentionManager
,
SlidingWindowManager
)
ChunkedLocalAttentionManager
,
SlidingWindowManager
)
from
vllm.v1.kv_cache_interface
import
(
ChunkedLocalAttentionSpec
,
from
vllm.v1.kv_cache_interface
import
(
ChunkedLocalAttentionSpec
,
...
@@ -44,7 +44,7 @@ def test_chunked_local_attention_possible_cached_prefix():
...
@@ -44,7 +44,7 @@ def test_chunked_local_attention_possible_cached_prefix():
def
run_one_case
(
block_is_cached
,
tail_token
,
expect_length
):
def
run_one_case
(
block_is_cached
,
tail_token
,
expect_length
):
block_hash_list
=
[
block_hash_list
=
[
BlockHash
(
i
,
())
for
i
in
range
(
len
(
block_is_cached
))
BlockHash
(
str
(
i
).
encode
())
for
i
in
range
(
len
(
block_is_cached
))
]
]
block_pool
.
cached_block_hash_to_block
.
clear
()
block_pool
.
cached_block_hash_to_block
.
clear
()
...
@@ -53,8 +53,8 @@ def test_chunked_local_attention_possible_cached_prefix():
...
@@ -53,8 +53,8 @@ def test_chunked_local_attention_possible_cached_prefix():
for
i
,
(
block_hash
,
for
i
,
(
block_hash
,
is_cached
)
in
enumerate
(
zip
(
block_hash_list
,
block_is_cached
)):
is_cached
)
in
enumerate
(
zip
(
block_hash_list
,
block_is_cached
)):
if
is_cached
:
if
is_cached
:
block_pool
.
cached_block_hash_to_block
[
BlockHashWithGroupId
(
block_pool
.
cached_block_hash_to_block
[
block_hash
,
0
)]
=
{
make_block_hash_with_group_id
(
block_hash
,
0
)]
=
{
i
:
block_pool
.
blocks
[
i
+
10
],
i
:
block_pool
.
blocks
[
i
+
10
],
}
}
...
@@ -109,7 +109,7 @@ def test_sliding_window_possible_cached_prefix():
...
@@ -109,7 +109,7 @@ def test_sliding_window_possible_cached_prefix():
def
run_one_case
(
block_is_cached
,
expect_length
):
def
run_one_case
(
block_is_cached
,
expect_length
):
block_hash_list
=
[
block_hash_list
=
[
BlockHash
(
i
,
())
for
i
in
range
(
len
(
block_is_cached
))
BlockHash
(
str
(
i
).
encode
())
for
i
in
range
(
len
(
block_is_cached
))
]
]
block_pool
.
cached_block_hash_to_block
.
clear
()
block_pool
.
cached_block_hash_to_block
.
clear
()
...
@@ -118,8 +118,8 @@ def test_sliding_window_possible_cached_prefix():
...
@@ -118,8 +118,8 @@ def test_sliding_window_possible_cached_prefix():
for
i
,
(
block_hash
,
for
i
,
(
block_hash
,
is_cached
)
in
enumerate
(
zip
(
block_hash_list
,
block_is_cached
)):
is_cached
)
in
enumerate
(
zip
(
block_hash_list
,
block_is_cached
)):
if
is_cached
:
if
is_cached
:
block_pool
.
cached_block_hash_to_block
[
BlockHashWithGroupId
(
block_pool
.
cached_block_hash_to_block
[
block_hash
,
0
)]
=
{
make_block_hash_with_group_id
(
block_hash
,
0
)]
=
{
i
:
block_pool
.
blocks
[
i
+
10
],
i
:
block_pool
.
blocks
[
i
+
10
],
}
}
...
...
tests/v1/core/utils.py
View file @
82dfb12e
...
@@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
...
@@ -9,6 +9,7 @@ from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
from
vllm.multimodal.inputs
import
(
MultiModalFeatureSpec
,
MultiModalKwargsItem
,
PlaceholderRange
)
MultiModalKwargsItem
,
PlaceholderRange
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
sha256
from
vllm.v1.core.kv_cache_utils
import
(
get_request_block_hasher
,
from
vllm.v1.core.kv_cache_utils
import
(
get_request_block_hasher
,
init_none_hash
)
init_none_hash
)
from
vllm.v1.core.sched.async_scheduler
import
AsyncScheduler
from
vllm.v1.core.sched.async_scheduler
import
AsyncScheduler
...
@@ -130,10 +131,10 @@ def create_requests(
...
@@ -130,10 +131,10 @@ def create_requests(
)
->
list
[
Request
]:
)
->
list
[
Request
]:
global
_none_hash_initialized
global
_none_hash_initialized
if
not
_none_hash_initialized
:
if
not
_none_hash_initialized
:
init_none_hash
(
ha
sh
)
init_none_hash
(
sh
a256
)
_none_hash_initialized
=
True
_none_hash_initialized
=
True
block_hasher
=
get_request_block_hasher
(
block_size
,
ha
sh
)
block_hasher
=
get_request_block_hasher
(
block_size
,
sh
a256
)
sampling_params
=
SamplingParams
(
ignore_eos
=
False
,
sampling_params
=
SamplingParams
(
ignore_eos
=
False
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
stop_token_ids
=
stop_token_ids
,
stop_token_ids
=
stop_token_ids
,
...
...
tests/v1/engine/test_engine_args.py
View file @
82dfb12e
...
@@ -36,18 +36,19 @@ def test_prefix_caching_from_cli():
...
@@ -36,18 +36,19 @@ def test_prefix_caching_from_cli():
assert
vllm_config
.
cache_config
.
enable_prefix_caching
assert
vllm_config
.
cache_config
.
enable_prefix_caching
# default hash algorithm is "builtin"
# default hash algorithm is "builtin"
assert
vllm_config
.
cache_config
.
prefix_caching_hash_algo
==
"builtin"
assert
vllm_config
.
cache_config
.
prefix_caching_hash_algo
==
"sha256"
# set hash algorithm to sha256_cbor
args
=
parser
.
parse_args
([
"--prefix-caching-hash-algo"
,
"sha256_cbor"
])
vllm_config
=
EngineArgs
.
from_cli_args
(
args
=
args
).
create_engine_config
()
assert
vllm_config
.
cache_config
.
prefix_caching_hash_algo
==
\
"sha256_cbor"
# set hash algorithm to sha256
# set hash algorithm to sha256
args
=
parser
.
parse_args
([
"--prefix-caching-hash-algo"
,
"sha256"
])
args
=
parser
.
parse_args
([
"--prefix-caching-hash-algo"
,
"sha256"
])
vllm_config
=
EngineArgs
.
from_cli_args
(
args
=
args
).
create_engine_config
()
vllm_config
=
EngineArgs
.
from_cli_args
(
args
=
args
).
create_engine_config
()
assert
vllm_config
.
cache_config
.
prefix_caching_hash_algo
==
"sha256"
assert
vllm_config
.
cache_config
.
prefix_caching_hash_algo
==
"sha256"
# set hash algorithm to builtin
args
=
parser
.
parse_args
([
"--prefix-caching-hash-algo"
,
"builtin"
])
vllm_config
=
EngineArgs
.
from_cli_args
(
args
=
args
).
create_engine_config
()
assert
vllm_config
.
cache_config
.
prefix_caching_hash_algo
==
"builtin"
# an invalid hash algorithm raises an error
# an invalid hash algorithm raises an error
parser
.
exit_on_error
=
False
parser
.
exit_on_error
=
False
with
pytest
.
raises
(
ArgumentError
):
with
pytest
.
raises
(
ArgumentError
):
...
...
tests/v1/kv_connector/unit/utils.py
View file @
82dfb12e
...
@@ -13,6 +13,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
...
@@ -13,6 +13,7 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory
)
KVConnectorFactory
)
from
vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector
import
(
# noqa
from
vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector
import
(
# noqa
SharedStorageConnector
)
SharedStorageConnector
)
from
vllm.utils
import
sha256
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.core.kv_cache_utils
import
(
get_request_block_hasher
,
from
vllm.v1.core.kv_cache_utils
import
(
get_request_block_hasher
,
init_none_hash
)
init_none_hash
)
...
@@ -127,11 +128,11 @@ def create_request(request_id: int,
...
@@ -127,11 +128,11 @@ def create_request(request_id: int,
use_all_1s_for_prompt_tokens
:
bool
=
False
,
use_all_1s_for_prompt_tokens
:
bool
=
False
,
num_remote_blocks
:
int
=
3
,
num_remote_blocks
:
int
=
3
,
block_size
:
int
=
16
,
block_size
:
int
=
16
,
hash_fn
:
Callable
=
ha
sh
)
->
Request
:
hash_fn
:
Callable
=
sh
a256
)
->
Request
:
"""Make dummy request for testing."""
"""Make dummy request for testing."""
global
_none_hash_initialized
global
_none_hash_initialized
if
not
_none_hash_initialized
:
if
not
_none_hash_initialized
:
init_none_hash
(
hash
)
init_none_hash
(
hash
_fn
)
_none_hash_initialized
=
True
_none_hash_initialized
=
True
kv_transfer_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
kv_transfer_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
...
...
vllm/config/cache.py
View file @
82dfb12e
...
@@ -24,7 +24,7 @@ logger = init_logger(__name__)
...
@@ -24,7 +24,7 @@ logger = init_logger(__name__)
BlockSize
=
Literal
[
1
,
8
,
16
,
32
,
64
,
128
]
BlockSize
=
Literal
[
1
,
8
,
16
,
32
,
64
,
128
]
CacheDType
=
Literal
[
"auto"
,
"fp8"
,
"fp8_e4m3"
,
"fp8_e5m2"
,
"fp8_inc"
]
CacheDType
=
Literal
[
"auto"
,
"fp8"
,
"fp8_e4m3"
,
"fp8_e5m2"
,
"fp8_inc"
]
MambaDType
=
Literal
[
"auto"
,
"float32"
]
MambaDType
=
Literal
[
"auto"
,
"float32"
]
PrefixCachingHashAlgo
=
Literal
[
"builtin"
,
"sha256"
,
"sha256_cbor
_64bit
"
]
PrefixCachingHashAlgo
=
Literal
[
"sha256"
,
"sha256_cbor"
]
@
config
@
config
...
@@ -63,17 +63,12 @@ class CacheConfig:
...
@@ -63,17 +63,12 @@ class CacheConfig:
"""Sliding window size for the KV cache. This is primarily set in
"""Sliding window size for the KV cache. This is primarily set in
`ModelConfig` and that value should be manually duplicated here."""
`ModelConfig` and that value should be manually duplicated here."""
enable_prefix_caching
:
Optional
[
bool
]
=
None
enable_prefix_caching
:
Optional
[
bool
]
=
None
"""Whether to enable prefix caching. Disabled by default for V0. Enabled by
"""Whether to enable prefix caching. Enabled by default for V1."""
default for V1."""
prefix_caching_hash_algo
:
PrefixCachingHashAlgo
=
"sha256"
prefix_caching_hash_algo
:
PrefixCachingHashAlgo
=
"builtin"
"""Set the hash algorithm for prefix caching:
\n
"""Set the hash algorithm for prefix caching:
\n
- "builtin" is Python's built-in hash.
\n
- "sha256" uses Pickle for object serialization before hashing.
\n
- "sha256" is collision resistant but with certain overheads.
- "sha256_cbor" provides a reproducible, cross-language compatible hash. It
This option uses Pickle for object serialization before hashing.
\n
serializes objects using canonical CBOR and hashes them with SHA-256."""
- "sha256_cbor_64bit" provides a reproducible, cross-language compatible
hash. It serializes objects using canonical CBOR and hashes them with
SHA-256. The resulting hash consists of the lower 64 bits of the SHA-256
digest."""
cpu_offload_gb
:
float
=
0
cpu_offload_gb
:
float
=
0
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to
no offloading. Intuitively, this argument can be seen as a virtual way to
...
...
vllm/distributed/kv_events.py
View file @
82dfb12e
...
@@ -16,6 +16,7 @@ import zmq
...
@@ -16,6 +16,7 @@ import zmq
from
vllm.config.kv_events
import
KVEventsConfig
from
vllm.config.kv_events
import
KVEventsConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_utils
import
ExternalBlockHash
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -44,8 +45,8 @@ MEDIUM_GPU = "GPU"
...
@@ -44,8 +45,8 @@ MEDIUM_GPU = "GPU"
class
BlockStored
(
KVCacheEvent
):
class
BlockStored
(
KVCacheEvent
):
block_hashes
:
list
[
int
]
block_hashes
:
list
[
ExternalBlockHash
]
parent_block_hash
:
Optional
[
int
]
parent_block_hash
:
Optional
[
ExternalBlockHash
]
token_ids
:
list
[
int
]
token_ids
:
list
[
int
]
block_size
:
int
block_size
:
int
lora_id
:
Optional
[
int
]
lora_id
:
Optional
[
int
]
...
@@ -53,7 +54,7 @@ class BlockStored(KVCacheEvent):
...
@@ -53,7 +54,7 @@ class BlockStored(KVCacheEvent):
class
BlockRemoved
(
KVCacheEvent
):
class
BlockRemoved
(
KVCacheEvent
):
block_hashes
:
list
[
int
]
block_hashes
:
list
[
ExternalBlockHash
]
medium
:
Optional
[
str
]
medium
:
Optional
[
str
]
...
...
vllm/engine/arg_utils.py
View file @
82dfb12e
...
@@ -1592,21 +1592,13 @@ class EngineArgs:
...
@@ -1592,21 +1592,13 @@ class EngineArgs:
"in low performance due to small KV cache size. Consider "
"in low performance due to small KV cache size. Consider "
"setting --max-model-len to a smaller value."
,
max_model_len
)
"setting --max-model-len to a smaller value."
,
max_model_len
)
# if using prefix caching, we must set a hash algo
if
self
.
enable_prefix_caching
:
# Disable prefix caching for multimodal models for VLLM_V0.
# Disable prefix caching for multimodal models for VLLM_V0.
if
model_config
.
is_multimodal_model
:
if
self
.
enable_prefix_caching
and
model_config
.
is_multimodal_model
:
logger
.
warning
(
logger
.
warning
(
"--enable-prefix-caching is not supported for multimodal "
"--enable-prefix-caching is not supported for multimodal "
"models in V0 and has been disabled."
)
"models in V0 and has been disabled."
)
self
.
enable_prefix_caching
=
False
self
.
enable_prefix_caching
=
False
# VLLM_V0 only supports builtin hash algo for prefix caching.
if
self
.
prefix_caching_hash_algo
==
"sha256"
:
raise
ValueError
(
"sha256 is not supported for prefix caching in V0 engine. "
"Please use 'builtin'."
)
# Set max_num_seqs to 256 for VLLM_V0.
# Set max_num_seqs to 256 for VLLM_V0.
if
self
.
max_num_seqs
is
None
:
if
self
.
max_num_seqs
is
None
:
self
.
max_num_seqs
=
256
self
.
max_num_seqs
=
256
...
...
vllm/envs.py
View file @
82dfb12e
...
@@ -171,6 +171,7 @@ if TYPE_CHECKING:
...
@@ -171,6 +171,7 @@ if TYPE_CHECKING:
VLLM_GPT_OSS_USE_CONTAINER_TOOL
:
bool
=
False
VLLM_GPT_OSS_USE_CONTAINER_TOOL
:
bool
=
False
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS
:
bool
=
False
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS
:
bool
=
False
VLLM_CUSTOM_SCOPES_FOR_PROFILING
:
bool
=
False
VLLM_CUSTOM_SCOPES_FOR_PROFILING
:
bool
=
False
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES
:
bool
=
True
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -1215,6 +1216,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -1215,6 +1216,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Add optional custom scopes for profiling, disable to avoid overheads
# Add optional custom scopes for profiling, disable to avoid overheads
"VLLM_CUSTOM_SCOPES_FOR_PROFILING"
:
"VLLM_CUSTOM_SCOPES_FOR_PROFILING"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_CUSTOM_SCOPES_FOR_PROFILING"
,
"0"
))),
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_CUSTOM_SCOPES_FOR_PROFILING"
,
"0"
))),
# Represent block hashes in KV cache events as 64-bit integers instead of
# raw bytes. Defaults to True for backward compatibility.
"VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES"
:
lambda
:
bool
(
int
(
os
.
getenv
(
"VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES"
,
"1"
))),
}
}
# --8<-- [end:env-vars-definition]
# --8<-- [end:env-vars-definition]
...
...
vllm/utils/__init__.py
View file @
82dfb12e
...
@@ -3249,7 +3249,7 @@ def check_use_alibi(model_config: ModelConfig) -> bool:
...
@@ -3249,7 +3249,7 @@ def check_use_alibi(model_config: ModelConfig) -> bool:
and
getattr
(
cfg
.
attn_config
,
"alibi"
,
False
)))))
and
getattr
(
cfg
.
attn_config
,
"alibi"
,
False
)))))
def
sha256
(
input
)
->
int
:
def
sha256
(
input
)
->
bytes
:
"""Hash any picklable Python object using SHA-256.
"""Hash any picklable Python object using SHA-256.
The input is serialized using pickle before hashing, which allows
The input is serialized using pickle before hashing, which allows
...
@@ -3260,16 +3260,15 @@ def sha256(input) -> int:
...
@@ -3260,16 +3260,15 @@ def sha256(input) -> int:
input: Any picklable Python object.
input: Any picklable Python object.
Returns:
Returns:
An integer
representing the SHA-256 hash of the serialized input.
Bytes
representing the SHA-256 hash of the serialized input.
"""
"""
input_bytes
=
pickle
.
dumps
(
input
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
input_bytes
=
pickle
.
dumps
(
input
,
protocol
=
pickle
.
HIGHEST_PROTOCOL
)
return
int
.
from_bytes
(
hashlib
.
sha256
(
input_bytes
).
digest
(),
return
hashlib
.
sha256
(
input_bytes
).
digest
()
byteorder
=
"big"
)
def
sha256_cbor
_64bit
(
input
)
->
int
:
def
sha256_cbor
(
input
)
->
bytes
:
"""
"""
Hash objects using CBOR serialization and SHA-256
, then truncate to 64bits
.
Hash objects using CBOR serialization and SHA-256.
This option is useful for non-Python-dependent serialization and hashing.
This option is useful for non-Python-dependent serialization and hashing.
...
@@ -3280,17 +3279,13 @@ def sha256_cbor_64bit(input) -> int:
...
@@ -3280,17 +3279,13 @@ def sha256_cbor_64bit(input) -> int:
Custom classes must implement CBOR serialization methods.
Custom classes must implement CBOR serialization methods.
Returns:
Returns:
An integer in the range [0, 2^64-1] representing the lower 64 bits
Bytes representing the SHA-256 hash of the CBOR serialized input.
of the SHA-256 hash of the CBOR serialized input.
"""
"""
input_bytes
=
cbor2
.
dumps
(
input
,
canonical
=
True
)
input_bytes
=
cbor2
.
dumps
(
input
,
canonical
=
True
)
full_hash
=
int
.
from_bytes
(
hashlib
.
sha256
(
input_bytes
).
digest
(),
return
hashlib
.
sha256
(
input_bytes
).
digest
()
byteorder
=
"big"
)
return
full_hash
&
((
1
<<
64
)
-
1
)
def
get_hash_fn_by_name
(
hash_fn_name
:
str
)
->
Callable
[[
Any
],
bytes
]:
def
get_hash_fn_by_name
(
hash_fn_name
:
str
)
->
Callable
[[
Any
],
int
]:
"""Get a hash function by name, or raise an error if
"""Get a hash function by name, or raise an error if
the function is not found.
the function is not found.
Args:
Args:
...
@@ -3300,10 +3295,8 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]:
...
@@ -3300,10 +3295,8 @@ def get_hash_fn_by_name(hash_fn_name: str) -> Callable[[Any], int]:
"""
"""
if
hash_fn_name
==
"sha256"
:
if
hash_fn_name
==
"sha256"
:
return
sha256
return
sha256
if
hash_fn_name
==
"sha256_cbor_64bit"
:
if
hash_fn_name
==
"sha256_cbor"
:
return
sha256_cbor_64bit
return
sha256_cbor
if
hash_fn_name
==
"builtin"
:
return
hash
raise
ValueError
(
f
"Unsupported hash function:
{
hash_fn_name
}
"
)
raise
ValueError
(
f
"Unsupported hash function:
{
hash_fn_name
}
"
)
...
...
vllm/v1/core/block_pool.py
View file @
82dfb12e
...
@@ -9,7 +9,11 @@ from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
...
@@ -9,7 +9,11 @@ from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
KVCacheEvent
)
KVCacheEvent
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
BlockHashWithGroupId
,
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
BlockHashWithGroupId
,
FreeKVCacheBlockQueue
,
KVCacheBlock
)
ExternalBlockHash
,
FreeKVCacheBlockQueue
,
KVCacheBlock
,
get_block_hash
,
make_block_hash_with_group_id
,
maybe_convert_block_hash
)
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -84,8 +88,10 @@ class BlockPool:
...
@@ -84,8 +88,10 @@ class BlockPool:
"""
"""
cached_blocks
=
[]
cached_blocks
=
[]
for
group_id
in
kv_cache_group_ids
:
for
group_id
in
kv_cache_group_ids
:
block_hash_with_group_id
=
make_block_hash_with_group_id
(
block_hash
,
group_id
)
cached_blocks_one_group
=
self
.
cached_block_hash_to_block
.
get
(
cached_blocks_one_group
=
self
.
cached_block_hash_to_block
.
get
(
B
lock
H
ash
W
ith
GroupId
(
block_hash
,
group_id
)
)
b
lock
_h
ash
_w
ith
_
group_id
)
if
not
cached_blocks_one_group
:
if
not
cached_blocks_one_group
:
return
None
return
None
first_block
=
next
(
iter
(
cached_blocks_one_group
.
values
()))
first_block
=
next
(
iter
(
cached_blocks_one_group
.
values
()))
...
@@ -124,28 +130,29 @@ class BlockPool:
...
@@ -124,28 +130,29 @@ class BlockPool:
assert
len
(
request
.
block_hashes
)
>=
num_full_blocks
assert
len
(
request
.
block_hashes
)
>=
num_full_blocks
new_block_hashes
=
request
.
block_hashes
[
num_cached_blocks
:]
new_block_hashes
=
request
.
block_hashes
[
num_cached_blocks
:]
new_hashes
:
Optional
[
list
[
int
]]
=
([]
if
self
.
enable_kv_cache_events
new_hashes
:
Optional
[
list
[
ExternalBlockHash
]]
=
(
else
None
)
[]
if
self
.
enable_kv_cache_events
else
None
)
for
i
,
blk
in
enumerate
(
new_full_blocks
):
for
i
,
blk
in
enumerate
(
new_full_blocks
):
assert
blk
.
block_hash
is
None
assert
blk
.
block_hash
is
None
block_hash
=
new_block_hashes
[
i
]
block_hash
=
new_block_hashes
[
i
]
# Update and added the full block to the cache.
# Update and added the full block to the cache.
block_hash_with_group_id
=
B
lock
H
ash
W
ith
G
roup
I
d
(
block_hash_with_group_id
=
make_b
lock
_h
ash
_w
ith
_g
roup
_i
d
(
block_hash
,
kv_cache_group_id
)
block_hash
,
kv_cache_group_id
)
blk
.
block_hash
=
block_hash_with_group_id
blk
.
block_hash
=
block_hash_with_group_id
self
.
cached_block_hash_to_block
[
block_hash_with_group_id
][
self
.
cached_block_hash_to_block
[
block_hash_with_group_id
][
blk
.
block_id
]
=
blk
blk
.
block_id
]
=
blk
if
new_hashes
is
not
None
:
if
new_hashes
is
not
None
:
new_hashes
.
append
(
block_hash
.
hash_value
)
new_hashes
.
append
(
maybe_convert_block_hash
(
block_hash
)
)
if
self
.
enable_kv_cache_events
:
if
self
.
enable_kv_cache_events
:
if
num_cached_blocks
==
0
:
if
num_cached_blocks
==
0
:
parent_block_hash
=
None
parent_block_hash
:
Optional
[
ExternalBlockHash
]
=
None
else
:
else
:
parent_block
=
blocks
[
num_cached_blocks
-
1
]
parent_block
=
blocks
[
num_cached_blocks
-
1
]
assert
parent_block
.
block_hash
is
not
None
assert
parent_block
.
block_hash
is
not
None
parent_block_hash
=
parent_block
.
block_hash
.
get_hash_value
()
parent_block_hash
=
maybe_convert_block_hash
(
get_block_hash
(
parent_block
.
block_hash
))
self
.
kv_event_queue
.
append
(
self
.
kv_event_queue
.
append
(
BlockStored
(
BlockStored
(
...
@@ -220,7 +227,9 @@ class BlockPool:
...
@@ -220,7 +227,9 @@ class BlockPool:
# we disable hybrid kv cache manager when kv cache event is
# we disable hybrid kv cache manager when kv cache event is
# enabled, so there is only one group.
# enabled, so there is only one group.
self
.
kv_event_queue
.
append
(
self
.
kv_event_queue
.
append
(
BlockRemoved
(
block_hashes
=
[
block_hash
.
get_hash_value
()],
BlockRemoved
(
block_hashes
=
[
maybe_convert_block_hash
(
get_block_hash
(
block_hash
))
],
medium
=
MEDIUM_GPU
))
medium
=
MEDIUM_GPU
))
return
True
return
True
...
...
vllm/v1/core/kv_cache_utils.py
View file @
82dfb12e
...
@@ -6,11 +6,12 @@ import os
...
@@ -6,11 +6,12 @@ import os
from
collections
import
defaultdict
,
deque
from
collections
import
defaultdict
,
deque
from
collections.abc
import
Iterable
,
Sequence
from
collections.abc
import
Iterable
,
Sequence
from
dataclasses
import
astuple
,
dataclass
from
dataclasses
import
astuple
,
dataclass
from
typing
import
Any
,
Callable
,
N
amedTupl
e
,
Optional
from
typing
import
Any
,
Callable
,
N
ewTyp
e
,
Optional
,
Union
from
vllm
import
envs
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
GiB_bytes
,
cdiv
,
sha256_cbor
_64bit
from
vllm.utils
import
GiB_bytes
,
cdiv
,
sha256_cbor
from
vllm.v1.kv_cache_interface
import
(
ChunkedLocalAttentionSpec
,
from
vllm.v1.kv_cache_interface
import
(
ChunkedLocalAttentionSpec
,
FullAttentionSpec
,
KVCacheConfig
,
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheSpec
,
KVCacheGroupSpec
,
KVCacheSpec
,
...
@@ -18,59 +19,78 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
...
@@ -18,59 +19,78 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
# BlockHash represents the hash of a single KV-cache block used for
# prefix caching. Treating it as a distinct type from ``bytes`` helps
# catch accidental misuse when passing around raw byte strings.
BlockHash
=
NewType
(
"BlockHash"
,
bytes
)
# ``BlockHashWithGroupId`` combines a ``BlockHash`` with its KV cache group ID.
# It is represented as raw bytes for compactness and efficiency. The helper
# functions below pack/unpack the ``BlockHash`` and group id into/from the key.
BlockHashWithGroupId
=
NewType
(
"BlockHashWithGroupId"
,
bytes
)
# ExternalBlockHash is used for reproducible prefix-cache block hashing.
# It's a union of ``bytes`` and ``int`` to keep backward compatibility
# after we default block hashing to use sha256 bytes.
ExternalBlockHash
=
Union
[
bytes
,
int
]
def
make_block_hash_with_group_id
(
block_hash
:
BlockHash
,
group_id
:
int
)
->
BlockHashWithGroupId
:
"""Pack a ``BlockHash`` and group id into a ``BlockHashWithGroupId``.
class
BlockHash
(
NamedTuple
):
The group id is encoded using 4 bytes in big-endian order and appended to
"""Hash value of a block (int), the token IDs in the block, and extra keys.
the block hash bytes. This representation avoids creating tuples while
We keep a tuple of token IDs and extra keys to reduce the likelihood of
still allowing us to recover both components when needed.
hash collisions when the hash value is the same. By using SHA256 however,
hash collisions are practically impossible.
"""
"""
# Hash value of the block in an integer.
return
BlockHashWithGroupId
(
block_hash
+
hash_value
:
int
group_id
.
to_bytes
(
4
,
"big"
,
signed
=
False
))
# Token IDs in the block.
token_ids
:
tuple
[
int
,
...]
# Extra keys for the block.
def
get_block_hash
(
key
:
BlockHashWithGroupId
)
->
BlockHash
:
extra_keys
:
Optional
[
Any
]
=
None
"""Extract the ``BlockHash`` from a ``BlockHashWithGroupId``."""
return
BlockHash
(
key
[:
-
4
])
def
get_group_id
(
key
:
BlockHashWithGroupId
)
->
int
:
"""Extract the group id from a ``BlockHashWithGroupId``."""
return
int
.
from_bytes
(
key
[
-
4
:],
"big"
,
signed
=
False
)
class
BlockHashWithGroupId
(
NamedTuple
):
# The hash value for the contents (e.g., token_ids) of a block without group
# ID. The value is the same for blocks representing the same tokens but for
# different groups.
block_hash
:
BlockHash
# The KV cache group ID.
group_id
:
int
def
get_hash_value
(
self
)
->
int
:
def
maybe_convert_block_hash
(
hash_bytes
:
BlockHash
)
->
ExternalBlockHash
:
return
self
.
block_hash
.
hash_value
if
not
envs
.
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES
:
return
hash_bytes
return
int
.
from_bytes
(
hash_bytes
,
byteorder
=
"big"
)
&
((
1
<<
64
)
-
1
)
logger
=
init_logger
(
__name__
)
# The hash seed for the first block of any prefix block sequence.
# The hash seed for the first block of any prefix block sequence.
#
#
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment
# We use a random value to avoid hash collisions or PYTHONHASHSEED environment
# variable if set such that processes can share the seed if needed.
# variable if set such that processes can share the seed if needed.
This aligns
#
This aligns
with the behavior of Python's hash() function, which also uses
# with the behavior of Python's hash() function, which also uses
a random seed
#
a random seed
if PYTHONHASHSEED is not set.
# if PYTHONHASHSEED is not set.
#
#
# The function `init_none_hash` initializes this variable globally.
# The function `init_none_hash` initializes this variable globally.
NONE_HASH
:
int
NONE_HASH
:
BlockHash
def
init_none_hash
(
hash_fn
:
Callable
):
def
init_none_hash
(
hash_fn
:
Callable
[[
Any
],
bytes
]
):
global
NONE_HASH
global
NONE_HASH
hash_seed
=
os
.
getenv
(
"PYTHONHASHSEED"
)
hash_seed
=
os
.
getenv
(
"PYTHONHASHSEED"
)
if
hash_seed
is
None
and
hash_fn
is
sha256_cbor
_64bit
:
if
hash_seed
is
None
and
hash_fn
is
sha256_cbor
:
logger
.
warning
(
logger
.
warning
(
"PYTHONHASHSEED is not set. This will lead to non-reproducible "
"PYTHONHASHSEED is not set. This will lead to non-reproducible "
"block-hashes when using sha256_cbor
_64bit
as the hash function."
"block-hashes when using sha256_cbor as the hash function."
"Consider setting PYTHONHASHSEED to a fixed value for "
"Consider setting PYTHONHASHSEED to a fixed value for "
"reproducibility."
)
"reproducibility."
)
NONE_HASH
=
(
int
.
from_bytes
(
os
.
urandom
(
32
),
byteorder
=
"big"
)
if
hash_seed
is
None
:
if
hash_seed
is
None
else
hash_fn
(
hash_seed
))
NONE_HASH
=
BlockHash
(
os
.
urandom
(
32
))
else
:
NONE_HASH
=
BlockHash
(
hash_fn
(
hash_seed
))
class
PrefixCachingMetrics
:
class
PrefixCachingMetrics
:
...
@@ -142,8 +162,8 @@ class KVCacheBlock:
...
@@ -142,8 +162,8 @@ class KVCacheBlock:
block_id
:
int
block_id
:
int
# Reference count.
# Reference count.
ref_cnt
:
int
=
0
ref_cnt
:
int
=
0
# The hash
of the block composed of (block hash, tuple of token IDs).
# The hash
key (block hash + group id) of the block, only available
#
It is only available
when the block is full.
# when the block is full
and cached
.
_block_hash
:
Optional
[
BlockHashWithGroupId
]
=
None
_block_hash
:
Optional
[
BlockHashWithGroupId
]
=
None
# Used to construct a doubly linked list for free blocks.
# Used to construct a doubly linked list for free blocks.
...
@@ -177,7 +197,7 @@ class KVCacheBlock:
...
@@ -177,7 +197,7 @@ class KVCacheBlock:
if
self
.
next_free_block
else
None
)
if
self
.
next_free_block
else
None
)
return
(
f
"KVCacheBlock(block_id=
{
self
.
block_id
}
, "
return
(
f
"KVCacheBlock(block_id=
{
self
.
block_id
}
, "
f
"ref_cnt=
{
self
.
ref_cnt
}
, "
f
"ref_cnt=
{
self
.
ref_cnt
}
, "
f
"_block_hash=
{
self
.
_block_hash
}
, "
f
"_block_hash=
{
self
.
_block_hash
!
r
}
, "
f
"prev_free_block=
{
prev_block_id
}
, "
f
"prev_free_block=
{
prev_block_id
}
, "
f
"next_free_block=
{
next_block_id
}
)"
)
f
"next_free_block=
{
next_block_id
}
)"
)
...
@@ -517,15 +537,14 @@ def generate_block_hash_extra_keys(
...
@@ -517,15 +537,14 @@ def generate_block_hash_extra_keys(
def
hash_block_tokens
(
def
hash_block_tokens
(
hash_function
:
Callable
,
hash_function
:
Callable
[[
Any
],
bytes
]
,
parent_block_hash
:
Optional
[
int
],
parent_block_hash
:
Optional
[
BlockHash
],
curr_block_token_ids
:
Sequence
[
int
],
curr_block_token_ids
:
Sequence
[
int
],
extra_keys
:
Optional
[
tuple
[
Any
,
...]]
=
None
)
->
BlockHash
:
extra_keys
:
Optional
[
tuple
[
Any
,
...]]
=
None
)
->
BlockHash
:
"""Computes a hash value corresponding to the contents of a block and
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
prefix caching. We use LRU cache for this function to avoid recomputing
hash values for the same block contents.
hash values for the same block contents.
Args:
Args:
hash_function: The hash function used to compute block hash.
hash_function: The hash function used to compute block hash.
parent_block_hash: The hash of the parent block. None
parent_block_hash: The hash of the parent block. None
...
@@ -533,7 +552,6 @@ def hash_block_tokens(
...
@@ -533,7 +552,6 @@ def hash_block_tokens(
curr_block_token_ids: A list of token ids in the current
curr_block_token_ids: A list of token ids in the current
block. The current block is assumed to be full.
block. The current block is assumed to be full.
extra_keys: Extra keys for the block.
extra_keys: Extra keys for the block.
Returns:
Returns:
The hash value of the block and the token ids in the block.
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
The entire tuple is used as the hash key of the block.
...
@@ -544,26 +562,16 @@ def hash_block_tokens(
...
@@ -544,26 +562,16 @@ def hash_block_tokens(
curr_block_token_ids_tuple
=
tuple
(
curr_block_token_ids
)
curr_block_token_ids_tuple
=
tuple
(
curr_block_token_ids
)
return
BlockHash
(
return
BlockHash
(
hash_function
(
hash_function
(
(
parent_block_hash
,
curr_block_token_ids_tuple
,
extra_keys
)),
(
parent_block_hash
,
curr_block_token_ids_tuple
,
extra_keys
)))
curr_block_token_ids_tuple
,
extra_keys
)
def
get_request_block_hasher
(
def
get_request_block_hasher
(
block_size
:
int
,
block_size
:
int
,
caching_hash_fn
:
Callable
[[
Any
],
caching_hash_fn
:
Callable
[[
Any
],
bytes
],
int
]
)
->
Callable
[[
Request
],
list
[
BlockHash
]]:
)
->
Callable
[[
Request
],
list
[
BlockHash
]]:
"""
"""
Returns a function which computes the list of un-computed block hashes
Returns a function which computes the list of un-computed block hashes
of a request.
of a request."""
Each request holds a list of its block hashes (request.block_hashes).
When a request is created, it calls the below function to compute
the hashes of all full blocks of the request's initial tokens.
The hashes are then stored in request.block_hashes.
Later, whenever new tokens are appended to the request, it calls
the below function again to compute any new full blocks of tokens.
The returned new hashes are appended to request.block_hashes.
"""
def
request_block_hasher
(
request
:
Request
)
->
list
[
BlockHash
]:
def
request_block_hasher
(
request
:
Request
)
->
list
[
BlockHash
]:
start_token_idx
=
len
(
request
.
block_hashes
)
*
block_size
start_token_idx
=
len
(
request
.
block_hashes
)
*
block_size
...
@@ -577,8 +585,8 @@ def get_request_block_hasher(
...
@@ -577,8 +585,8 @@ def get_request_block_hasher(
# last mm input.
# last mm input.
curr_mm_idx
=
-
1
curr_mm_idx
=
-
1
prev_block_hash_value
=
request
.
block_hashes
[
-
1
]
.
hash_value
\
prev_block_hash_value
=
(
request
.
block_hashes
[
-
1
]
if
request
.
block_hashes
else
None
if
request
.
block_hashes
else
None
)
new_block_hashes
:
list
[
BlockHash
]
=
[]
new_block_hashes
:
list
[
BlockHash
]
=
[]
while
True
:
while
True
:
end_token_idx
=
start_token_idx
+
block_size
end_token_idx
=
start_token_idx
+
block_size
...
@@ -598,7 +606,7 @@ def get_request_block_hasher(
...
@@ -598,7 +606,7 @@ def get_request_block_hasher(
new_block_hashes
.
append
(
block_hash
)
new_block_hashes
.
append
(
block_hash
)
start_token_idx
+=
block_size
start_token_idx
+=
block_size
prev_block_hash_value
=
block_hash
.
hash_value
prev_block_hash_value
=
block_hash
return
new_block_hashes
return
new_block_hashes
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment