Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c280066f
Unverified
Commit
c280066f
authored
Aug 16, 2025
by
Or Ozeri
Committed by
GitHub
Aug 15, 2025
Browse files
[v1] Move block_hashes from KVCacheManager to Request.block_hashes (#19728)
Signed-off-by:
Or Ozeri
<
oro@il.ibm.com
>
parent
b9dc9d26
Changes
19
Show whitespace changes
Inline
Side-by-side
Showing
19 changed files
with
382 additions
and
336 deletions
+382
-336
tests/v1/core/test_async_scheduler.py
tests/v1/core/test_async_scheduler.py
+15
-7
tests/v1/core/test_kv_cache_utils.py
tests/v1/core/test_kv_cache_utils.py
+27
-23
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+122
-103
tests/v1/core/test_scheduler.py
tests/v1/core/test_scheduler.py
+16
-13
tests/v1/core/test_single_type_kv_cache_manager.py
tests/v1/core/test_single_type_kv_cache_manager.py
+0
-2
tests/v1/core/utils.py
tests/v1/core/utils.py
+16
-1
tests/v1/kv_connector/unit/test_nixl_connector.py
tests/v1/kv_connector/unit/test_nixl_connector.py
+2
-0
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
+8
-2
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
+15
-2
tests/v1/kv_connector/unit/utils.py
tests/v1/kv_connector/unit/utils.py
+20
-11
vllm/utils/__init__.py
vllm/utils/__init__.py
+18
-0
vllm/v1/core/block_pool.py
vllm/v1/core/block_pool.py
+17
-58
vllm/v1/core/kv_cache_coordinator.py
vllm/v1/core/kv_cache_coordinator.py
+12
-21
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+4
-47
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+50
-30
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+0
-2
vllm/v1/core/single_type_kv_cache_manager.py
vllm/v1/core/single_type_kv_cache_manager.py
+1
-9
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+19
-3
vllm/v1/request.py
vllm/v1/request.py
+20
-2
No files found.
tests/v1/core/test_async_scheduler.py
View file @
c280066f
...
@@ -7,6 +7,7 @@ import pytest
...
@@ -7,6 +7,7 @@ import pytest
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.request
import
RequestStatus
from
vllm.v1.request
import
RequestStatus
from
vllm.v1.utils
import
ConstantList
from
.utils
import
create_requests
,
create_scheduler
from
.utils
import
create_requests
,
create_scheduler
...
@@ -140,7 +141,8 @@ def test_prefix_caching_for_prefill_dedup():
...
@@ -140,7 +141,8 @@ def test_prefix_caching_for_prefill_dedup():
requests
=
create_requests
(
num_requests
=
5
,
requests
=
create_requests
(
num_requests
=
5
,
num_tokens
=
num_prompt_tokens
,
num_tokens
=
num_prompt_tokens
,
max_tokens
=
3
,
max_tokens
=
3
,
same_prompt
=
True
)
same_prompt
=
True
,
block_size
=
BLOCK_SIZE
)
requests_copy
=
requests
.
copy
()
requests_copy
=
requests
.
copy
()
# Two requests with the same prompt.
# Two requests with the same prompt.
...
@@ -188,7 +190,8 @@ def test_prefix_caching_for_multi_turn():
...
@@ -188,7 +190,8 @@ def test_prefix_caching_for_multi_turn():
block_size
=
BLOCK_SIZE
)
block_size
=
BLOCK_SIZE
)
requests
=
create_requests
(
num_requests
=
5
,
requests
=
create_requests
(
num_requests
=
5
,
num_tokens
=
num_prompt_tokens
,
num_tokens
=
num_prompt_tokens
,
max_tokens
=
num_output_tokens
)
max_tokens
=
num_output_tokens
,
block_size
=
BLOCK_SIZE
)
for
req
in
requests
:
for
req
in
requests
:
scheduler
.
add_request
(
req
)
scheduler
.
add_request
(
req
)
...
@@ -208,14 +211,19 @@ def test_prefix_caching_for_multi_turn():
...
@@ -208,14 +211,19 @@ def test_prefix_caching_for_multi_turn():
# Create next-turn requests whose prompts are the full output of the
# Create next-turn requests whose prompts are the full output of the
# previous turn.
# previous turn.
next_turn_requests
=
create_requests
(
next_turn_requests
=
create_requests
(
num_requests
=
5
,
num_requests
=
5
,
num_tokens
=
num_prompt_tokens
+
num_tokens
=
num_prompt_tokens
+
num_output_tokens
,
num_output_tokens
,
max_tokens
=
num_output_tokens
,
max_tokens
=
num_output_tokens
,
)
block_size
=
BLOCK_SIZE
)
for
i
,
req
in
enumerate
(
next_turn_requests
):
for
i
,
req
in
enumerate
(
next_turn_requests
):
req
.
prompt_token_ids
=
(
requests
[
i
].
prompt_token_ids
+
req
.
prompt_token_ids
=
(
requests
[
i
].
prompt_token_ids
+
list
(
requests
[
i
].
output_token_ids
))
list
(
requests
[
i
].
output_token_ids
))
req
.
_all_token_ids
=
req
.
prompt_token_ids
.
copy
()
req
.
all_token_ids
=
ConstantList
(
req
.
_all_token_ids
)
req
.
block_hashes
=
[]
req
.
block_hashes
=
req
.
get_hash_new_full_blocks
()
# Schedule the next-turn requests.
# Schedule the next-turn requests.
for
req
in
next_turn_requests
:
for
req
in
next_turn_requests
:
scheduler
.
add_request
(
req
)
scheduler
.
add_request
(
req
)
...
...
tests/v1/core/test_kv_cache_utils.py
View file @
c280066f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
importlib
import
importlib
from
typing
import
Optional
from
typing
import
Callable
,
Optional
import
pytest
import
pytest
import
torch
import
torch
...
@@ -19,7 +19,7 @@ from vllm.v1.core.kv_cache_utils import (
...
@@ -19,7 +19,7 @@ from vllm.v1.core.kv_cache_utils import (
FreeKVCacheBlockQueue
,
KVCacheBlock
,
PrefixCachingMetrics
,
FreeKVCacheBlockQueue
,
KVCacheBlock
,
PrefixCachingMetrics
,
estimate_max_model_len
,
generate_block_hash_extra_keys
,
estimate_max_model_len
,
generate_block_hash_extra_keys
,
get_kv_cache_config
,
get_max_concurrency_for_kv_cache_config
,
get_kv_cache_config
,
get_max_concurrency_for_kv_cache_config
,
hash
_block_
tokens
,
hash_
request
_tokens
,
init_none_hash
,
get_request
_block_
hasher
,
hash_
block
_tokens
,
init_none_hash
,
is_kv_cache_type_uniform
,
unify_kv_cache_configs
)
is_kv_cache_type_uniform
,
unify_kv_cache_configs
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
KVCacheTensor
,
KVCacheGroupSpec
,
KVCacheTensor
,
...
@@ -33,6 +33,8 @@ from vllm.v1.request import Request
...
@@ -33,6 +33,8 @@ from vllm.v1.request import Request
def
make_request
(
def
make_request
(
request_id
:
str
,
request_id
:
str
,
prompt_token_ids
:
list
[
int
],
prompt_token_ids
:
list
[
int
],
block_size
:
int
=
3
,
hash_fn
:
Callable
=
hash
,
mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
,
mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
,
mm_hashes
:
Optional
[
list
[
str
]]
=
None
,
mm_hashes
:
Optional
[
list
[
str
]]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
...
@@ -49,8 +51,7 @@ def make_request(
...
@@ -49,8 +51,7 @@ def make_request(
mm_item
=
MultiModalKwargsItem
.
from_elems
([
mm_elem
])
mm_item
=
MultiModalKwargsItem
.
from_elems
([
mm_elem
])
mm_kwargs
=
[
mm_item
]
*
len
(
mm_positions
)
mm_kwargs
=
[
mm_item
]
*
len
(
mm_positions
)
return
Request
(
return
Request
(
request_id
=
request_id
,
request_id
=
request_id
,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
multi_modal_kwargs
=
mm_kwargs
,
multi_modal_kwargs
=
mm_kwargs
,
multi_modal_hashes
=
mm_hashes
,
multi_modal_hashes
=
mm_hashes
,
...
@@ -60,7 +61,7 @@ def make_request(
...
@@ -60,7 +61,7 @@ def make_request(
eos_token_id
=
100
,
eos_token_id
=
100
,
lora_request
=
None
,
lora_request
=
None
,
cache_salt
=
cache_salt
,
cache_salt
=
cache_salt
,
)
block_hasher
=
get_request_block_hasher
(
block_size
,
hash_fn
)
)
def
new_kv_cache_spec
(
block_size
=
16
,
def
new_kv_cache_spec
(
block_size
=
16
,
...
@@ -428,12 +429,14 @@ def test_hash_block_tokens(hash_fn):
...
@@ -428,12 +429,14 @@ def test_hash_block_tokens(hash_fn):
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor_64bit
,
hash
])
@
pytest
.
mark
.
parametrize
(
"hash_fn"
,
[
sha256
,
sha256_cbor_64bit
,
hash
])
def
test_
hash_
request_
tokens
(
hash_fn
):
def
test_request_
block_hasher
(
hash_fn
):
import
vllm.v1.core.kv_cache_utils
import
vllm.v1.core.kv_cache_utils
init_none_hash
(
hash_fn
)
init_none_hash
(
hash_fn
)
request
=
make_request
(
request
=
make_request
(
request_id
=
"0"
,
request_id
=
"0"
,
prompt_token_ids
=
[
_
for
_
in
range
(
6
)],
prompt_token_ids
=
[
_
for
_
in
range
(
6
)],
block_size
=
3
,
hash_fn
=
hash_fn
,
mm_positions
=
[
mm_positions
=
[
PlaceholderRange
(
offset
=
0
,
length
=
3
),
PlaceholderRange
(
offset
=
0
,
length
=
3
),
PlaceholderRange
(
offset
=
3
,
length
=
3
),
PlaceholderRange
(
offset
=
3
,
length
=
3
),
...
@@ -441,9 +444,7 @@ def test_hash_request_tokens(hash_fn):
...
@@ -441,9 +444,7 @@ def test_hash_request_tokens(hash_fn):
mm_hashes
=
[
"hash1"
,
"hash2"
],
mm_hashes
=
[
"hash1"
,
"hash2"
],
)
)
block_size
=
3
block_hashes
=
request
.
block_hashes
block_hashes
=
hash_request_tokens
(
hash_fn
,
block_size
,
request
)
assert
len
(
block_hashes
)
==
2
assert
len
(
block_hashes
)
==
2
assert
isinstance
(
block_hashes
[
0
],
vllm
.
v1
.
core
.
kv_cache_utils
.
BlockHash
)
assert
isinstance
(
block_hashes
[
0
],
vllm
.
v1
.
core
.
kv_cache_utils
.
BlockHash
)
assert
isinstance
(
block_hashes
[
1
],
vllm
.
v1
.
core
.
kv_cache_utils
.
BlockHash
)
assert
isinstance
(
block_hashes
[
1
],
vllm
.
v1
.
core
.
kv_cache_utils
.
BlockHash
)
...
@@ -464,6 +465,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
...
@@ -464,6 +465,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
request1
=
make_request
(
request1
=
make_request
(
request_id
=
"0"
,
request_id
=
"0"
,
prompt_token_ids
=
[
_
for
_
in
range
(
6
)],
prompt_token_ids
=
[
_
for
_
in
range
(
6
)],
block_size
=
3
,
hash_fn
=
hash_fn
,
mm_positions
=
[
mm_positions
=
[
PlaceholderRange
(
offset
=
0
,
length
=
3
),
PlaceholderRange
(
offset
=
0
,
length
=
3
),
PlaceholderRange
(
offset
=
3
,
length
=
3
),
PlaceholderRange
(
offset
=
3
,
length
=
3
),
...
@@ -479,9 +482,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
...
@@ -479,9 +482,8 @@ def test_hash_tokens_different_mm_input(hash_fn):
],
],
mm_hashes
=
[
"hash3"
,
"hash2"
],
mm_hashes
=
[
"hash3"
,
"hash2"
],
)
)
block_size
=
3
block_hashes1
=
request1
.
block_hashes
block_hashes1
=
hash_request_tokens
(
hash_fn
,
block_size
,
request1
)
block_hashes2
=
request2
.
block_hashes
block_hashes2
=
hash_request_tokens
(
hash_fn
,
block_size
,
request2
)
assert
block_hashes1
[
0
]
!=
block_hashes2
[
0
]
assert
block_hashes1
[
0
]
!=
block_hashes2
[
0
]
assert
block_hashes1
[
1
]
!=
block_hashes2
[
1
]
assert
block_hashes1
[
1
]
!=
block_hashes2
[
1
]
...
@@ -493,12 +495,13 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
...
@@ -493,12 +495,13 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
request
=
make_request
(
request
=
make_request
(
request_id
=
"0"
,
request_id
=
"0"
,
prompt_token_ids
=
[
_
for
_
in
range
(
6
)],
prompt_token_ids
=
[
_
for
_
in
range
(
6
)],
block_size
=
3
,
hash_fn
=
hash_fn
,
mm_positions
=
None
,
mm_positions
=
None
,
mm_hashes
=
None
,
mm_hashes
=
None
,
)
)
block_size
=
3
block_hashes
=
request
.
block_hashes
block_hashes
=
hash_request_tokens
(
hash_fn
,
block_size
,
request
)
assert
len
(
block_hashes
)
==
2
assert
len
(
block_hashes
)
==
2
assert
block_hashes
[
0
].
token_ids
==
(
0
,
1
,
2
)
assert
block_hashes
[
0
].
token_ids
==
(
0
,
1
,
2
)
...
@@ -858,6 +861,7 @@ def test_allocate_with_lookahead():
...
@@ -858,6 +861,7 @@ def test_allocate_with_lookahead():
request
=
make_request
(
request
=
make_request
(
request_id
=
"0"
,
request_id
=
"0"
,
prompt_token_ids
=
[],
prompt_token_ids
=
[],
block_size
=
block_size
,
mm_positions
=
None
,
mm_positions
=
None
,
mm_hashes
=
None
,
mm_hashes
=
None
,
)
)
...
...
tests/v1/core/test_prefix_caching.py
View file @
c280066f
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
"""Compare the with and without prefix caching."""
"""Compare the with and without prefix caching."""
import
copy
import
copy
from
typing
import
Optional
from
typing
import
Callable
,
Optional
import
pytest
import
pytest
import
torch
import
torch
...
@@ -17,8 +17,9 @@ from vllm.utils import sha256, sha256_cbor_64bit
...
@@ -17,8 +17,9 @@ from vllm.utils import sha256, sha256_cbor_64bit
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
,
Request
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
,
Request
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
BlockHashWithGroupId
,
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
BlockHashWithGroupId
,
KVCacheBlock
,
hash_block_tokens
,
KVCacheBlock
,
init_none_hash
)
get_request_block_hasher
,
hash_block_tokens
,
init_none_hash
)
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
,
SlidingWindowSpec
)
KVCacheGroupSpec
,
SlidingWindowSpec
)
...
@@ -26,6 +27,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
...
@@ -26,6 +27,8 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
def
make_request
(
def
make_request
(
request_id
:
str
,
request_id
:
str
,
prompt_token_ids
:
list
[
int
],
prompt_token_ids
:
list
[
int
],
block_size
:
int
,
hash_fn
:
Callable
,
mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
,
mm_positions
:
Optional
[
list
[
PlaceholderRange
]]
=
None
,
mm_hashes
:
Optional
[
list
[
str
]]
=
None
,
mm_hashes
:
Optional
[
list
[
str
]]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
,
...
@@ -43,19 +46,18 @@ def make_request(
...
@@ -43,19 +46,18 @@ def make_request(
mm_item
=
MultiModalKwargsItem
.
from_elems
([
mm_elem
])
mm_item
=
MultiModalKwargsItem
.
from_elems
([
mm_elem
])
mm_kwargs
=
[
mm_item
]
*
len
(
mm_positions
)
mm_kwargs
=
[
mm_item
]
*
len
(
mm_positions
)
return
Request
(
return
Request
(
request_id
=
request_id
,
request_id
=
request_id
,
prompt_token_ids
=
prompt_token_ids
,
prompt_token_ids
=
prompt_token_ids
,
multi_modal_kwargs
=
mm_kwargs
,
multi_modal_kwargs
=
mm_kwargs
,
multi_modal_hashes
=
mm_hashes
,
multi_modal_hashes
=
mm_hashes
,
multi_modal_placeholders
=
mm_positions
,
multi_modal_placeholders
=
mm_positions
,
sampling_params
=
SamplingParams
(
max_tokens
=
17
,
sampling_params
=
SamplingParams
(
prompt_logprobs
=
prompt_logprobs
),
max_tokens
=
17
,
prompt_logprobs
=
prompt_logprobs
),
pooling_params
=
None
,
pooling_params
=
None
,
eos_token_id
=
100
,
eos_token_id
=
100
,
lora_request
=
None
,
lora_request
=
None
,
cache_salt
=
cache_salt
,
cache_salt
=
cache_salt
,
)
block_hasher
=
get_request_block_hasher
(
block_size
,
hash_fn
)
)
def
make_kv_cache_config
(
block_size
:
int
,
num_blocks
:
int
)
->
KVCacheConfig
:
def
make_kv_cache_config
(
block_size
:
int
,
num_blocks
:
int
)
->
KVCacheConfig
:
...
@@ -105,11 +107,11 @@ def make_kv_cache_config_hybrid_model(block_size: int,
...
@@ -105,11 +107,11 @@ def make_kv_cache_config_hybrid_model(block_size: int,
@
pytest
.
mark
.
parametrize
(
"hash_algo"
,
[
"sha256"
,
"sha256_cbor_64bit"
,
"hash"
])
@
pytest
.
mark
.
parametrize
(
"hash_algo"
,
[
"sha256"
,
"sha256_cbor_64bit"
,
"hash"
])
def
test_prefill
(
hash_algo
):
def
test_prefill
(
hash_algo
):
block_size
=
16
manager
=
KVCacheManager
(
manager
=
KVCacheManager
(
make_kv_cache_config
(
16
,
11
),
make_kv_cache_config
(
block_size
,
11
),
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
caching_hash_algo
=
hash_algo
,
)
)
# choose the hash function according to the parameter
# choose the hash function according to the parameter
...
@@ -123,9 +125,9 @@ def test_prefill(hash_algo):
...
@@ -123,9 +125,9 @@ def test_prefill(hash_algo):
# Incomplete 1 block (7 tokens)
# Incomplete 1 block (7 tokens)
unique_token_ids
=
[
3
]
*
7
unique_token_ids
=
[
3
]
*
7
all_token_ids
=
common_token_ids
+
unique_token_ids
all_token_ids
=
common_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
)
req0
=
make_request
(
"0"
,
all_token_ids
,
block_size
,
hash_fn
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
len
(
manager
.
req_to_block_hashes
[
req0
.
request_id
]
)
==
3
assert
len
(
req0
.
block_hashes
)
==
3
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
...
@@ -152,9 +154,10 @@ def test_prefill(hash_algo):
...
@@ -152,9 +154,10 @@ def test_prefill(hash_algo):
# Cache hit in the common prefix when the original block is still in use.
# Cache hit in the common prefix when the original block is still in use.
# Incomplete 1 block (5 tokens)
# Incomplete 1 block (5 tokens)
unique_token_ids
=
[
3
]
*
5
unique_token_ids
=
[
3
]
*
5
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
,
block_size
,
hash_fn
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
]
)
==
3
assert
len
(
req1
.
block_hashes
)
==
3
assert
computed_blocks
.
get_block_ids
()
==
([
1
,
2
,
3
],
)
assert
computed_blocks
.
get_block_ids
()
==
([
1
,
2
,
3
],
)
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
...
@@ -187,9 +190,10 @@ def test_prefill(hash_algo):
...
@@ -187,9 +190,10 @@ def test_prefill(hash_algo):
# Cache hit in the common prefix when the original block is already free.
# Cache hit in the common prefix when the original block is already free.
# Incomplete 1 block (6 tokens)
# Incomplete 1 block (6 tokens)
unique_token_ids
=
[
3
]
*
6
unique_token_ids
=
[
3
]
*
6
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
,
block_size
,
hash_fn
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
manager
.
req_to_block_hashes
[
req2
.
request_id
]
)
==
3
assert
len
(
req2
.
block_hashes
)
==
3
assert
computed_blocks
.
get_block_ids
()
==
([
1
,
2
,
3
],
)
assert
computed_blocks
.
get_block_ids
()
==
([
1
,
2
,
3
],
)
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
...
@@ -208,7 +212,7 @@ def test_prefill(hash_algo):
...
@@ -208,7 +212,7 @@ def test_prefill(hash_algo):
manager
.
free
(
req2
)
manager
.
free
(
req2
)
# Cache miss and eviction.
# Cache miss and eviction.
req3
=
make_request
(
"3"
,
[
99
]
*
(
16
*
10
))
req3
=
make_request
(
"3"
,
[
99
]
*
(
16
*
10
)
,
block_size
,
hash_fn
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -242,9 +246,9 @@ def test_prefill_hybrid_model():
...
@@ -242,9 +246,9 @@ def test_prefill_hybrid_model():
# Incomplete 1 block (7 tokens)
# Incomplete 1 block (7 tokens)
unique_token_ids
=
[
3
]
*
7
unique_token_ids
=
[
3
]
*
7
all_token_ids
=
common_token_ids
+
unique_token_ids
all_token_ids
=
common_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
)
req0
=
make_request
(
"0"
,
all_token_ids
,
block_size
,
hash_fn
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
len
(
manager
.
req_to_block_hashes
[
req0
.
request_id
]
)
==
3
assert
len
(
req0
.
block_hashes
)
==
3
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
...
@@ -274,9 +278,10 @@ def test_prefill_hybrid_model():
...
@@ -274,9 +278,10 @@ def test_prefill_hybrid_model():
# Cache hit in the common prefix
# Cache hit in the common prefix
# Incomplete 1 block (5 tokens)
# Incomplete 1 block (5 tokens)
unique_token_ids
=
[
3
]
*
5
unique_token_ids
=
[
3
]
*
5
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
,
block_size
,
hash_fn
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
]
)
==
3
assert
len
(
req1
.
block_hashes
)
==
3
assert
computed_blocks
.
get_block_ids
()
==
([
1
,
2
,
3
],
[
0
,
6
,
assert
computed_blocks
.
get_block_ids
()
==
([
1
,
2
,
3
],
[
0
,
6
,
7
],
[
0
,
10
,
11
])
7
],
[
0
,
10
,
11
])
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
...
@@ -290,7 +295,7 @@ def test_prefill_hybrid_model():
...
@@ -290,7 +295,7 @@ def test_prefill_hybrid_model():
if
block
!=
manager
.
block_pool
.
null_block
:
if
block
!=
manager
.
block_pool
.
null_block
:
assert
block
.
ref_cnt
==
2
assert
block
.
ref_cnt
==
2
block_hashes
=
manager
.
req_to_block_hashes
[
req1
.
request_id
]
block_hashes
=
req1
.
block_hashes
manager
.
free
(
req0
)
manager
.
free
(
req0
)
manager
.
free
(
req1
)
manager
.
free
(
req1
)
...
@@ -300,12 +305,13 @@ def test_prefill_hybrid_model():
...
@@ -300,12 +305,13 @@ def test_prefill_hybrid_model():
def
test_partial_request_hit
(
request_id
:
str
,
def
test_partial_request_hit
(
request_id
:
str
,
hash_to_evict
:
list
[
BlockHashWithGroupId
],
hash_to_evict
:
list
[
BlockHashWithGroupId
],
expect_hit_length
:
int
):
expect_hit_length
:
int
):
req
=
make_request
(
request_id
,
common_token_ids
+
unique_token_ids
)
req
=
make_request
(
request_id
,
common_token_ids
+
unique_token_ids
,
block_size
,
hash
)
for
hash_with_group_id
in
hash_to_evict
:
for
hash_with_group_id
in
hash_to_evict
:
manager
.
block_pool
.
cached_block_hash_to_block
.
pop
(
manager
.
block_pool
.
cached_block_hash_to_block
.
pop
(
hash_with_group_id
)
hash_with_group_id
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
len
(
manager
.
req_to_block_hashes
[
req
.
request_id
]
)
==
3
assert
len
(
req
.
block_hashes
)
==
3
assert
num_computed_tokens
==
expect_hit_length
*
block_size
assert
num_computed_tokens
==
expect_hit_length
*
block_size
for
block_per_group
in
computed_blocks
.
blocks
:
for
block_per_group
in
computed_blocks
.
blocks
:
assert
len
(
block_per_group
)
==
num_computed_tokens
//
block_size
assert
len
(
block_per_group
)
==
num_computed_tokens
//
block_size
...
@@ -364,8 +370,9 @@ def test_prefill_plp():
...
@@ -364,8 +370,9 @@ def test_prefill_plp():
2. Schedule non-plp request and validate blocks
2. Schedule non-plp request and validate blocks
3. Schedule plp request; no hit should occur; validate blocks
3. Schedule plp request; no hit should occur; validate blocks
'''
'''
block_size
=
16
manager
=
KVCacheManager
(
manager
=
KVCacheManager
(
make_kv_cache_config
(
16
,
11
),
make_kv_cache_config
(
block_size
,
11
),
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
)
)
...
@@ -380,9 +387,13 @@ def test_prefill_plp():
...
@@ -380,9 +387,13 @@ def test_prefill_plp():
# Incomplete 1 block (7 tokens)
# Incomplete 1 block (7 tokens)
unique_token_ids
=
[
3
]
*
7
unique_token_ids
=
[
3
]
*
7
all_token_ids
=
common_token_ids
+
unique_token_ids
all_token_ids
=
common_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
,
prompt_logprobs
=
5
)
req0
=
make_request
(
"0"
,
all_token_ids
,
block_size
,
hash_fn
,
prompt_logprobs
=
5
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
len
(
manager
.
req_to_block_hashes
[
req0
.
request_id
]
)
==
0
assert
len
(
req0
.
block_hashes
)
==
3
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
...
@@ -411,9 +422,10 @@ def test_prefill_plp():
...
@@ -411,9 +422,10 @@ def test_prefill_plp():
# Cache hit in the common prefix when the original block is still in use.
# Cache hit in the common prefix when the original block is still in use.
# Incomplete 1 block (5 tokens)
# Incomplete 1 block (5 tokens)
unique_token_ids
=
[
3
]
*
5
unique_token_ids
=
[
3
]
*
5
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
,
block_size
,
hash_fn
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
]
)
==
3
assert
len
(
req1
.
block_hashes
)
==
3
assert
computed_blocks
.
get_block_ids
()
==
([
1
,
2
,
3
],
)
assert
computed_blocks
.
get_block_ids
()
==
([
1
,
2
,
3
],
)
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
...
@@ -447,9 +459,11 @@ def test_prefill_plp():
...
@@ -447,9 +459,11 @@ def test_prefill_plp():
unique_token_ids
=
[
3
]
*
6
unique_token_ids
=
[
3
]
*
6
req2
=
make_request
(
"2"
,
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
,
common_token_ids
+
unique_token_ids
,
block_size
,
hash_fn
,
prompt_logprobs
=
5
)
prompt_logprobs
=
5
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
manager
.
req_to_block_hashes
[
req2
.
request_id
]
)
==
0
assert
len
(
req2
.
block_hashes
)
==
3
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
blocks
=
manager
.
allocate_slots
(
req2
,
55
,
blocks
=
manager
.
allocate_slots
(
req2
,
55
,
...
@@ -469,8 +483,9 @@ def test_prefill_plp():
...
@@ -469,8 +483,9 @@ def test_prefill_plp():
def
test_decode
():
def
test_decode
():
block_size
=
16
manager
=
KVCacheManager
(
manager
=
KVCacheManager
(
make_kv_cache_config
(
16
,
11
),
make_kv_cache_config
(
block_size
,
11
),
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
)
)
...
@@ -481,7 +496,8 @@ def test_decode():
...
@@ -481,7 +496,8 @@ def test_decode():
# Fully cache miss
# Fully cache miss
# Incomplete 1 block (7 tokens)
# Incomplete 1 block (7 tokens)
unique_token_ids
=
[
3
]
*
7
unique_token_ids
=
[
3
]
*
7
req0
=
make_request
(
"0"
,
common_token_ids
+
unique_token_ids
)
req0
=
make_request
(
"0"
,
common_token_ids
+
unique_token_ids
,
block_size
,
hash
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -518,14 +534,15 @@ def test_decode():
...
@@ -518,14 +534,15 @@ def test_decode():
def
test_evict
():
def
test_evict
():
block_size
=
16
manager
=
KVCacheManager
(
manager
=
KVCacheManager
(
make_kv_cache_config
(
16
,
11
),
make_kv_cache_config
(
block_size
,
11
),
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
)
)
last_token_id
=
5
*
16
+
7
last_token_id
=
5
*
16
+
7
req0
=
make_request
(
"0"
,
list
(
range
(
last_token_id
)))
req0
=
make_request
(
"0"
,
list
(
range
(
last_token_id
))
,
block_size
,
hash
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -536,7 +553,8 @@ def test_evict():
...
@@ -536,7 +553,8 @@ def test_evict():
# 3 blocks.
# 3 blocks.
req1
=
make_request
(
"1"
,
list
(
range
(
last_token_id
,
req1
=
make_request
(
"1"
,
list
(
range
(
last_token_id
,
last_token_id
+
3
*
16
)))
last_token_id
+
3
*
16
)),
block_size
,
hash
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -558,7 +576,7 @@ def test_evict():
...
@@ -558,7 +576,7 @@ def test_evict():
]
==
[
10
,
6
,
5
,
4
,
3
,
2
,
1
,
9
,
8
,
7
]
]
==
[
10
,
6
,
5
,
4
,
3
,
2
,
1
,
9
,
8
,
7
]
# Touch the first 2 blocks.
# Touch the first 2 blocks.
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
)))
req2
=
make_request
(
"2"
,
list
(
range
(
2
*
16
+
3
))
,
block_size
,
hash
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
computed_blocks
.
get_block_ids
()
==
([
1
,
2
],
)
assert
computed_blocks
.
get_block_ids
()
==
([
1
,
2
],
)
assert
num_computed_tokens
==
2
*
16
assert
num_computed_tokens
==
2
*
16
...
@@ -583,7 +601,7 @@ def test_hash_block_correct_reuse():
...
@@ -583,7 +601,7 @@ def test_hash_block_correct_reuse():
# Allocate 1 block and cache it.
# Allocate 1 block and cache it.
num_tokens
=
block_size
*
1
num_tokens
=
block_size
*
1
req
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)))
req
=
make_request
(
"0"
,
list
(
range
(
num_tokens
))
,
block_size
,
hash
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -597,7 +615,7 @@ def test_hash_block_correct_reuse():
...
@@ -597,7 +615,7 @@ def test_hash_block_correct_reuse():
# Allocate a new block that's not full, make sure hash info on the
# Allocate a new block that's not full, make sure hash info on the
# block is cleared.
# block is cleared.
req
=
make_request
(
"1"
,
list
(
range
(
num_tokens
-
1
)))
req
=
make_request
(
"1"
,
list
(
range
(
num_tokens
-
1
))
,
block_size
,
hash
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -624,7 +642,7 @@ def test_computed_blocks_not_evicted():
...
@@ -624,7 +642,7 @@ def test_computed_blocks_not_evicted():
# Allocate a block and cache it.
# Allocate a block and cache it.
num_tokens
=
block_size
*
1
num_tokens
=
block_size
*
1
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)))
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
))
,
block_size
,
hash
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -635,7 +653,8 @@ def test_computed_blocks_not_evicted():
...
@@ -635,7 +653,8 @@ def test_computed_blocks_not_evicted():
assert
blocks
.
blocks
[
0
][
0
].
block_id
==
1
assert
blocks
.
blocks
[
0
][
0
].
block_id
==
1
# Allocate another block.
# Allocate another block.
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
,
num_tokens
*
2
)))
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
,
num_tokens
*
2
)),
block_size
,
hash
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -651,7 +670,7 @@ def test_computed_blocks_not_evicted():
...
@@ -651,7 +670,7 @@ def test_computed_blocks_not_evicted():
# Now if we have a cache hit on the first block, we should evict the second
# Now if we have a cache hit on the first block, we should evict the second
# cached block rather than the first one.
# cached block rather than the first one.
req2
=
make_request
(
"2"
,
list
(
range
(
num_tokens
*
2
)))
req2
=
make_request
(
"2"
,
list
(
range
(
num_tokens
*
2
))
,
block_size
,
hash
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
assert
computed_blocks
.
blocks
[
0
][
0
].
block_id
==
1
assert
computed_blocks
.
blocks
[
0
][
0
].
block_id
==
1
...
@@ -675,7 +694,8 @@ def test_basic_prefix_caching_disabled():
...
@@ -675,7 +694,8 @@ def test_basic_prefix_caching_disabled():
enable_caching
=
False
,
enable_caching
=
False
,
)
)
req1
=
make_request
(
"1"
,
list
(
range
(
10
)))
# 2 blocks and some more
req1
=
make_request
(
"1"
,
list
(
range
(
10
)),
block_size
,
hash
)
# 2 blocks and some more
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
...
@@ -689,7 +709,8 @@ def test_basic_prefix_caching_disabled():
...
@@ -689,7 +709,8 @@ def test_basic_prefix_caching_disabled():
manager
.
free
(
req1
)
manager
.
free
(
req1
)
# No caching.
# No caching.
req2
=
make_request
(
"2"
,
list
(
range
(
16
)))
# shared prefix
req2
=
make_request
(
"2"
,
list
(
range
(
16
)),
block_size
,
hash
)
# shared prefix
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -699,7 +720,7 @@ def test_basic_prefix_caching_disabled():
...
@@ -699,7 +720,7 @@ def test_basic_prefix_caching_disabled():
assert
len
(
blocks
.
blocks
[
0
])
==
4
assert
len
(
blocks
.
blocks
[
0
])
==
4
# New requests should not have any blocks.
# New requests should not have any blocks.
req3
=
make_request
(
"3"
,
list
(
range
(
4
)))
req3
=
make_request
(
"3"
,
list
(
range
(
4
))
,
block_size
,
hash
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -727,20 +748,17 @@ def test_cache_blocks(hash_fn):
...
@@ -727,20 +748,17 @@ def test_cache_blocks(hash_fn):
# Block 1: [4, 5, 6, 7]
# Block 1: [4, 5, 6, 7]
# Block 2: [8, 9, 10, 11]
# Block 2: [8, 9, 10, 11]
# Block 3: [12, 13]
# Block 3: [12, 13]
req
=
make_request
(
"0"
,
list
(
range
(
14
)))
req
=
make_request
(
"0"
,
list
(
range
(
14
))
,
block_size
,
hash_fn
)
# Test that blocks are cached correctly for 2 full blocks from the start.
# Test that blocks are cached correctly for 2 full blocks from the start.
blocks
=
[
KVCacheBlock
(
block_id
=
i
)
for
i
in
range
(
2
)]
blocks
=
[
KVCacheBlock
(
block_id
=
i
)
for
i
in
range
(
2
)]
block_hashes
:
list
[
BlockHash
]
=
[]
block_pool
.
cache_full_blocks
(
block_pool
.
cache_full_blocks
(
request
=
req
,
request
=
req
,
blocks
=
blocks
,
blocks
=
blocks
,
block_hashes
=
block_hashes
,
num_cached_blocks
=
0
,
num_cached_blocks
=
0
,
num_full_blocks
=
2
,
num_full_blocks
=
2
,
block_size
=
block_size
,
block_size
=
block_size
,
hash_fn
=
hash_fn
,
kv_cache_group_id
=
0
,
kv_cache_group_id
=
0
,
)
)
...
@@ -752,11 +770,9 @@ def test_cache_blocks(hash_fn):
...
@@ -752,11 +770,9 @@ def test_cache_blocks(hash_fn):
block_pool
.
cache_full_blocks
(
block_pool
.
cache_full_blocks
(
request
=
req
,
request
=
req
,
blocks
=
blocks
,
blocks
=
blocks
,
block_hashes
=
block_hashes
,
num_cached_blocks
=
2
,
num_cached_blocks
=
2
,
num_full_blocks
=
3
,
num_full_blocks
=
3
,
block_size
=
block_size
,
block_size
=
block_size
,
hash_fn
=
hash_fn
,
kv_cache_group_id
=
0
,
kv_cache_group_id
=
0
,
)
)
assert
len
(
block_pool
.
cached_block_hash_to_block
)
==
3
assert
len
(
block_pool
.
cached_block_hash_to_block
)
==
3
...
@@ -775,23 +791,20 @@ def test_cache_blocks_multi_group():
...
@@ -775,23 +791,20 @@ def test_cache_blocks_multi_group():
# Block 1/5: [4, 5, 6, 7]
# Block 1/5: [4, 5, 6, 7]
# Block 2/6: [8, 9, 10, 11]
# Block 2/6: [8, 9, 10, 11]
# Block 3/7: [12, 13]
# Block 3/7: [12, 13]
req
=
make_request
(
"0"
,
list
(
range
(
14
)))
req
=
make_request
(
"0"
,
list
(
range
(
14
))
,
block_size
,
hash
)
# Cache the blocks for group 0.
# Cache the blocks for group 0.
blocks
=
[
KVCacheBlock
(
block_id
=
i
)
for
i
in
range
(
2
)]
blocks
=
[
KVCacheBlock
(
block_id
=
i
)
for
i
in
range
(
2
)]
block_hashes
:
list
[
BlockHash
]
=
[]
block_pool
.
cache_full_blocks
(
block_pool
.
cache_full_blocks
(
request
=
req
,
request
=
req
,
blocks
=
blocks
,
blocks
=
blocks
,
block_hashes
=
block_hashes
,
num_cached_blocks
=
0
,
num_cached_blocks
=
0
,
num_full_blocks
=
2
,
num_full_blocks
=
2
,
block_size
=
block_size
,
block_size
=
block_size
,
hash_fn
=
hash
,
kv_cache_group_id
=
0
,
kv_cache_group_id
=
0
,
)
)
assert
len
(
block_pool
.
cached_block_hash_to_block
)
==
2
assert
len
(
block_pool
.
cached_block_hash_to_block
)
==
2
assert
len
(
block_hashes
)
==
2
assert
len
(
req
.
block_hashes
)
==
3
assert
all
([
block
.
block_hash
is
not
None
for
block
in
blocks
])
assert
all
([
block
.
block_hash
is
not
None
for
block
in
blocks
])
# Cache the blocks for group 1.
# Cache the blocks for group 1.
...
@@ -799,38 +812,36 @@ def test_cache_blocks_multi_group():
...
@@ -799,38 +812,36 @@ def test_cache_blocks_multi_group():
block_pool
.
cache_full_blocks
(
block_pool
.
cache_full_blocks
(
request
=
req
,
request
=
req
,
blocks
=
blocks
,
blocks
=
blocks
,
block_hashes
=
block_hashes
,
num_cached_blocks
=
0
,
num_cached_blocks
=
0
,
num_full_blocks
=
3
,
num_full_blocks
=
3
,
block_size
=
block_size
,
block_size
=
block_size
,
hash_fn
=
hash
,
kv_cache_group_id
=
1
,
kv_cache_group_id
=
1
,
)
)
assert
len
(
block_pool
.
cached_block_hash_to_block
)
==
5
assert
len
(
block_pool
.
cached_block_hash_to_block
)
==
5
assert
len
(
block_hashes
)
==
3
assert
len
(
req
.
block_hashes
)
==
3
assert
all
([
block
.
block_hash
is
not
None
for
block
in
blocks
])
assert
all
([
block
.
block_hash
is
not
None
for
block
in
blocks
])
# Block hash 0: hit for group 0 and 1
# Block hash 0: hit for group 0 and 1
# Block hash 1: hit for group 0 and 1
# Block hash 1: hit for group 0 and 1
# Block hash 2: hit for group 1
# Block hash 2: hit for group 1
assert
block_pool
.
get_cached_block
(
block_hashes
[
0
],
assert
block_pool
.
get_cached_block
(
req
.
block_hashes
[
0
],
kv_cache_group_ids
=
[
0
])
is
not
None
kv_cache_group_ids
=
[
0
])
is
not
None
assert
block_pool
.
get_cached_block
(
block_hashes
[
1
],
assert
block_pool
.
get_cached_block
(
req
.
block_hashes
[
1
],
kv_cache_group_ids
=
[
0
])
is
not
None
kv_cache_group_ids
=
[
0
])
is
not
None
assert
block_pool
.
get_cached_block
(
block_hashes
[
2
],
assert
block_pool
.
get_cached_block
(
req
.
block_hashes
[
2
],
kv_cache_group_ids
=
[
0
])
is
None
kv_cache_group_ids
=
[
0
])
is
None
assert
block_pool
.
get_cached_block
(
block_hashes
[
0
],
assert
block_pool
.
get_cached_block
(
req
.
block_hashes
[
0
],
kv_cache_group_ids
=
[
1
])
is
not
None
kv_cache_group_ids
=
[
1
])
is
not
None
assert
block_pool
.
get_cached_block
(
block_hashes
[
1
],
assert
block_pool
.
get_cached_block
(
req
.
block_hashes
[
1
],
kv_cache_group_ids
=
[
1
])
is
not
None
kv_cache_group_ids
=
[
1
])
is
not
None
assert
block_pool
.
get_cached_block
(
block_hashes
[
2
],
assert
block_pool
.
get_cached_block
(
req
.
block_hashes
[
2
],
kv_cache_group_ids
=
[
1
])
is
not
None
kv_cache_group_ids
=
[
1
])
is
not
None
assert
block_pool
.
get_cached_block
(
block_hashes
[
0
],
assert
block_pool
.
get_cached_block
(
req
.
block_hashes
[
0
],
kv_cache_group_ids
=
[
0
,
1
])
is
not
None
kv_cache_group_ids
=
[
0
,
1
])
is
not
None
assert
block_pool
.
get_cached_block
(
block_hashes
[
1
],
assert
block_pool
.
get_cached_block
(
req
.
block_hashes
[
1
],
kv_cache_group_ids
=
[
0
,
1
])
is
not
None
kv_cache_group_ids
=
[
0
,
1
])
is
not
None
assert
block_pool
.
get_cached_block
(
block_hashes
[
2
],
assert
block_pool
.
get_cached_block
(
req
.
block_hashes
[
2
],
kv_cache_group_ids
=
[
0
,
1
])
is
None
kv_cache_group_ids
=
[
0
,
1
])
is
None
...
@@ -838,8 +849,9 @@ def test_mm_prefix_caching():
...
@@ -838,8 +849,9 @@ def test_mm_prefix_caching():
"""
"""
This tests that the multi-modal prefix caching is correct.
This tests that the multi-modal prefix caching is correct.
"""
"""
block_size
=
16
manager
=
KVCacheManager
(
manager
=
KVCacheManager
(
make_kv_cache_config
(
16
,
11
),
make_kv_cache_config
(
block_size
,
11
),
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
)
)
...
@@ -865,6 +877,8 @@ def test_mm_prefix_caching():
...
@@ -865,6 +877,8 @@ def test_mm_prefix_caching():
mm_hashes
=
common_mm_hashes
+
[
"ccc"
]
mm_hashes
=
common_mm_hashes
+
[
"ccc"
]
req0
=
make_request
(
"0"
,
req0
=
make_request
(
"0"
,
all_token_ids
,
all_token_ids
,
block_size
,
hash
,
mm_positions
=
mm_positions
,
mm_positions
=
mm_positions
,
mm_hashes
=
mm_hashes
)
mm_hashes
=
mm_hashes
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
...
@@ -872,7 +886,7 @@ def test_mm_prefix_caching():
...
@@ -872,7 +886,7 @@ def test_mm_prefix_caching():
# Completed block should have hashes with extra keys.
# Completed block should have hashes with extra keys.
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
block_hashes
=
manager
.
req_to_block_hashes
[
req0
.
request_id
]
block_hashes
=
req0
.
block_hashes
assert
len
(
block_hashes
)
==
3
assert
len
(
block_hashes
)
==
3
assert
block_hashes
[
0
].
extra_keys
==
(
"aaa"
,
)
assert
block_hashes
[
0
].
extra_keys
==
(
"aaa"
,
)
assert
block_hashes
[
1
].
extra_keys
==
(
"aaa"
,
"bbb"
)
assert
block_hashes
[
1
].
extra_keys
==
(
"aaa"
,
"bbb"
)
...
@@ -905,6 +919,8 @@ def test_mm_prefix_caching():
...
@@ -905,6 +919,8 @@ def test_mm_prefix_caching():
mm_hashes
=
common_mm_hashes
+
[
"ccc"
]
mm_hashes
=
common_mm_hashes
+
[
"ccc"
]
req1
=
make_request
(
"1"
,
req1
=
make_request
(
"1"
,
all_token_ids
,
all_token_ids
,
block_size
,
hash
,
mm_positions
=
mm_positions
,
mm_positions
=
mm_positions
,
mm_hashes
=
mm_hashes
)
mm_hashes
=
mm_hashes
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
...
@@ -927,13 +943,13 @@ def test_cache_key_salting():
...
@@ -927,13 +943,13 @@ def test_cache_key_salting():
# 3 complete blocks and an incomplete block with 11 tokens.
# 3 complete blocks and an incomplete block with 11 tokens.
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
block_size
)]
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
block_size
)]
token_ids
=
common_token_ids
+
[
3
]
*
11
token_ids
=
common_token_ids
+
[
3
]
*
11
req0
=
make_request
(
"0"
,
token_ids
,
cache_salt
=
"salt1"
)
req0
=
make_request
(
"0"
,
token_ids
,
block_size
,
hash
,
cache_salt
=
"salt1"
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
# Completed block should have hashes with extra keys.
# Completed block should have hashes with extra keys.
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
block_hashes
=
manager
.
req_to_block_hashes
[
req0
.
request_id
]
block_hashes
=
req0
.
block_hashes
assert
len
(
block_hashes
)
==
3
assert
len
(
block_hashes
)
==
3
assert
block_hashes
[
0
].
extra_keys
==
(
"salt1"
,
)
assert
block_hashes
[
0
].
extra_keys
==
(
"salt1"
,
)
assert
block_hashes
[
1
].
extra_keys
is
None
assert
block_hashes
[
1
].
extra_keys
is
None
...
@@ -959,7 +975,7 @@ def test_cache_key_salting():
...
@@ -959,7 +975,7 @@ def test_cache_key_salting():
# Test cache hit with a new request that has the same salt.
# Test cache hit with a new request that has the same salt.
token_ids
=
common_token_ids
+
[
4
]
*
11
token_ids
=
common_token_ids
+
[
4
]
*
11
req1
=
make_request
(
"1"
,
token_ids
,
cache_salt
=
"salt1"
)
req1
=
make_request
(
"1"
,
token_ids
,
block_size
,
hash
,
cache_salt
=
"salt1"
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
# Should match only a prefix of 3 blocks.
# Should match only a prefix of 3 blocks.
assert
len
(
computed_blocks
.
blocks
[
0
])
==
3
assert
len
(
computed_blocks
.
blocks
[
0
])
==
3
...
@@ -967,11 +983,11 @@ def test_cache_key_salting():
...
@@ -967,11 +983,11 @@ def test_cache_key_salting():
# Test cache miss with same content but different salt.
# Test cache miss with same content but different salt.
token_ids
=
common_token_ids
+
[
4
]
*
11
token_ids
=
common_token_ids
+
[
4
]
*
11
req2
=
make_request
(
"2"
,
token_ids
,
cache_salt
=
"salt2"
)
req2
=
make_request
(
"2"
,
token_ids
,
block_size
,
hash
,
cache_salt
=
"salt2"
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
computed_blocks
.
blocks
[
0
])
==
0
assert
len
(
computed_blocks
.
blocks
[
0
])
==
0
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
block_hashes
=
manager
.
req_to_block_hashes
[
req2
.
request_id
]
block_hashes
=
req2
.
block_hashes
assert
len
(
block_hashes
)
==
3
assert
len
(
block_hashes
)
==
3
assert
block_hashes
[
0
].
extra_keys
==
(
"salt2"
,
)
assert
block_hashes
[
0
].
extra_keys
==
(
"salt2"
,
)
...
@@ -992,7 +1008,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -992,7 +1008,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# Complete 3 blocks (48 tokens)
# Complete 3 blocks (48 tokens)
# | Common-0 | Common-1 | Common-2 | ... |
# | Common-0 | Common-1 | Common-2 | ... |
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
common_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
req0
=
make_request
(
"0"
,
common_token_ids
)
req0
=
make_request
(
"0"
,
common_token_ids
,
block_size
,
hash
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req0
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -1003,7 +1019,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -1003,7 +1019,7 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
req0
.
request_id
]
req0
.
request_id
]
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
# | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
req1
=
make_request
(
"1"
,
common_token_ids
*
2
)
req1
=
make_request
(
"1"
,
common_token_ids
*
2
,
block_size
,
hash
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req1
)
assert
computed_blocks
.
blocks
[
0
]
==
block_part0
assert
computed_blocks
.
blocks
[
0
]
==
block_part0
assert
num_computed_tokens
==
3
*
16
assert
num_computed_tokens
==
3
*
16
...
@@ -1020,19 +1036,19 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -1020,19 +1036,19 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
# | Req1-5(F)| Req2-0 | Req2-1 | ... |
req2
=
make_request
(
"2"
,
[
7
]
*
block_size
*
2
)
req2
=
make_request
(
"2"
,
[
7
]
*
block_size
*
2
,
block_size
,
hash
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req2
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
manager
.
allocate_slots
(
req2
,
block_size
*
2
,
manager
.
allocate_slots
(
req2
,
block_size
*
2
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
block_size
,
computed_blocks
)
computed_blocks
)
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
# Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
# but it cannot be allocated due to insufficient free blocks (2).
# but it cannot be allocated due to insufficient free blocks (2).
# In this case, the ref_cnt of the computed blocks should not be changed.
# In this case, the ref_cnt of the computed blocks should not be changed.
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
5
assert
manager
.
block_pool
.
free_block_queue
.
num_free_blocks
==
5
req3
=
make_request
(
"3"
,
common_token_ids
*
3
)
req3
=
make_request
(
"3"
,
common_token_ids
*
3
,
block_size
,
hash
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req3
)
assert
computed_blocks
.
blocks
[
0
]
==
block_part1
assert
computed_blocks
.
blocks
[
0
]
==
block_part1
assert
num_computed_tokens
==
6
*
16
assert
num_computed_tokens
==
6
*
16
...
@@ -1047,8 +1063,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
...
@@ -1047,8 +1063,9 @@ def test_prefill_not_enough_free_blocks_with_computed_blocks():
def
test_reset_prefix_cache
():
def
test_reset_prefix_cache
():
block_size
=
16
manager
=
KVCacheManager
(
manager
=
KVCacheManager
(
make_kv_cache_config
(
16
,
11
),
make_kv_cache_config
(
block_size
,
11
),
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
)
)
...
@@ -1056,15 +1073,15 @@ def test_reset_prefix_cache():
...
@@ -1056,15 +1073,15 @@ def test_reset_prefix_cache():
full_block_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
full_block_token_ids
=
[
i
for
i
in
range
(
3
)
for
_
in
range
(
16
)]
unique_token_ids
=
[
3
]
*
7
unique_token_ids
=
[
3
]
*
7
all_token_ids
=
full_block_token_ids
+
unique_token_ids
all_token_ids
=
full_block_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
)
req0
=
make_request
(
"0"
,
all_token_ids
,
block_size
,
hash
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
)
assert
blocks
.
get_block_ids
()
==
([
1
,
2
,
3
,
4
],
)
assert
blocks
.
get_block_ids
()
==
([
1
,
2
,
3
,
4
],
)
unique_token_ids
=
[
4
]
*
7
unique_token_ids
=
[
4
]
*
7
all_token_ids
=
full_block_token_ids
+
unique_token_ids
all_token_ids
=
full_block_token_ids
+
unique_token_ids
req1
=
make_request
(
"1"
,
all_token_ids
)
req1
=
make_request
(
"1"
,
all_token_ids
,
block_size
,
hash
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
manager
.
req_to_block_hashes
[
req1
.
request_id
]
)
==
3
assert
len
(
req1
.
block_hashes
)
==
3
assert
len
(
computed_blocks
.
blocks
[
0
])
==
3
assert
len
(
computed_blocks
.
blocks
[
0
])
==
3
blocks
=
manager
.
allocate_slots
(
req1
,
7
,
blocks
=
manager
.
allocate_slots
(
req1
,
7
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
...
@@ -1086,8 +1103,9 @@ def test_reset_prefix_cache():
...
@@ -1086,8 +1103,9 @@ def test_reset_prefix_cache():
def
test_prefix_cache_stats_disabled
():
def
test_prefix_cache_stats_disabled
():
"""Test that prefix_cache_stats is None when log_stats is False."""
"""Test that prefix_cache_stats is None when log_stats is False."""
block_size
=
16
manager
=
KVCacheManager
(
manager
=
KVCacheManager
(
make_kv_cache_config
(
16
,
11
),
make_kv_cache_config
(
block_size
,
11
),
max_model_len
=
8192
,
max_model_len
=
8192
,
enable_caching
=
True
,
enable_caching
=
True
,
log_stats
=
False
,
# Disable logging stats
log_stats
=
False
,
# Disable logging stats
...
@@ -1095,7 +1113,7 @@ def test_prefix_cache_stats_disabled():
...
@@ -1095,7 +1113,7 @@ def test_prefix_cache_stats_disabled():
assert
manager
.
prefix_cache_stats
is
None
assert
manager
.
prefix_cache_stats
is
None
# Call all functions that check whether log_stats is disabled.
# Call all functions that check whether log_stats is disabled.
req
=
make_request
(
"0"
,
list
(
range
(
16
)))
req
=
make_request
(
"0"
,
list
(
range
(
16
))
,
block_size
,
hash
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
num_computed_tokens
=
manager
.
get_computed_blocks
(
req
)
assert
not
computed_blocks
.
blocks
[
0
]
assert
not
computed_blocks
.
blocks
[
0
]
assert
num_computed_tokens
==
0
assert
num_computed_tokens
==
0
...
@@ -1192,7 +1210,7 @@ def test_kv_cache_events(blocks_to_cache: int):
...
@@ -1192,7 +1210,7 @@ def test_kv_cache_events(blocks_to_cache: int):
)
)
num_tokens
=
block_size
*
blocks_to_cache
num_tokens
=
block_size
*
blocks_to_cache
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
)))
req0
=
make_request
(
"0"
,
list
(
range
(
num_tokens
))
,
block_size
,
hash
)
_
=
manager
.
allocate_slots
(
req0
,
num_tokens
)
_
=
manager
.
allocate_slots
(
req0
,
num_tokens
)
events
=
manager
.
take_events
()
events
=
manager
.
take_events
()
...
@@ -1208,7 +1226,7 @@ def test_kv_cache_events(blocks_to_cache: int):
...
@@ -1208,7 +1226,7 @@ def test_kv_cache_events(blocks_to_cache: int):
# Should see block_to_cache number of removed block events and a new block
# Should see block_to_cache number of removed block events and a new block
# stored event
# stored event
manager
.
free
(
req0
)
manager
.
free
(
req0
)
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
)))
req1
=
make_request
(
"1"
,
list
(
range
(
num_tokens
))
,
block_size
,
hash
)
_
=
manager
.
allocate_slots
(
req1
,
num_tokens
)
_
=
manager
.
allocate_slots
(
req1
,
num_tokens
)
events
=
manager
.
take_events
()
events
=
manager
.
take_events
()
...
@@ -1242,7 +1260,7 @@ def test_eagle_enabled_removes_last_block():
...
@@ -1242,7 +1260,7 @@ def test_eagle_enabled_removes_last_block():
# Request with 3 full blocks (48 tokens)
# Request with 3 full blocks (48 tokens)
token_ids
=
[
0
]
*
(
3
*
block_size
)
token_ids
=
[
0
]
*
(
3
*
block_size
)
req
=
make_request
(
"divisible_request"
,
token_ids
)
req
=
make_request
(
"divisible_request"
,
token_ids
,
block_size
,
hash
)
# Prime the cache
# Prime the cache
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
...
@@ -1252,7 +1270,7 @@ def test_eagle_enabled_removes_last_block():
...
@@ -1252,7 +1270,7 @@ def test_eagle_enabled_removes_last_block():
manager
.
free
(
req
)
manager
.
free
(
req
)
# New request with same tokens + Eagle enabled
# New request with same tokens + Eagle enabled
req_eagle
=
make_request
(
"eagle_divisible"
,
token_ids
)
req_eagle
=
make_request
(
"eagle_divisible"
,
token_ids
,
block_size
,
hash
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
# Should retain 1 block:
# Should retain 1 block:
...
@@ -1273,7 +1291,7 @@ def test_eagle_with_partial_blocks():
...
@@ -1273,7 +1291,7 @@ def test_eagle_with_partial_blocks():
)
)
# 2 full blocks + 5 tokens (non-divisible length)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids
=
[
0
]
*
(
2
*
block_size
+
5
)
token_ids
=
[
0
]
*
(
2
*
block_size
+
5
)
req
=
make_request
(
"partial_block_test"
,
token_ids
)
req
=
make_request
(
"partial_block_test"
,
token_ids
,
block_size
,
hash
)
# Prime the cache
# Prime the cache
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
...
@@ -1283,7 +1301,7 @@ def test_eagle_with_partial_blocks():
...
@@ -1283,7 +1301,7 @@ def test_eagle_with_partial_blocks():
manager
.
free
(
req
)
manager
.
free
(
req
)
# New request with Eagle enabled
# New request with Eagle enabled
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
)
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
,
block_size
,
hash
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
...
@@ -1314,7 +1332,7 @@ def test_eagle_with_sliding_window():
...
@@ -1314,7 +1332,7 @@ def test_eagle_with_sliding_window():
# 2 full blocks + 5 tokens (non-divisible length)
# 2 full blocks + 5 tokens (non-divisible length)
token_ids
=
[
0
]
*
(
2
*
block_size
+
5
)
token_ids
=
[
0
]
*
(
2
*
block_size
+
5
)
req
=
make_request
(
"partial_block_test"
,
token_ids
)
req
=
make_request
(
"partial_block_test"
,
token_ids
,
block_size
,
hash
)
# Prime the cache
# Prime the cache
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
computed_blocks
,
_
=
manager
.
get_computed_blocks
(
req
)
...
@@ -1322,12 +1340,12 @@ def test_eagle_with_sliding_window():
...
@@ -1322,12 +1340,12 @@ def test_eagle_with_sliding_window():
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
len
(
computed_blocks
.
blocks
[
0
])
*
16
,
computed_blocks
)
computed_blocks
)
# record the block hash of the first block in the request for later use
# record the block hash of the first block in the request for later use
block_hash_first_block
=
manager
.
req_to_block_hashes
[
req
.
request_id
]
[
0
]
block_hash_first_block
=
req
.
block_hashes
[
0
]
assert
block_hash_first_block
is
not
None
assert
block_hash_first_block
is
not
None
manager
.
free
(
req
)
manager
.
free
(
req
)
# New request with Eagle enabled
# New request with Eagle enabled
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
)
req_eagle
=
make_request
(
"partial_eagle"
,
token_ids
,
block_size
,
hash
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_eagle
)
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
# Original match: 2 full blocks → Eagle removes 1 → 1 remaining
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
assert
len
(
computed_blocks
.
blocks
[
0
])
==
1
...
@@ -1340,7 +1358,8 @@ def test_eagle_with_sliding_window():
...
@@ -1340,7 +1358,8 @@ def test_eagle_with_sliding_window():
BlockHashWithGroupId
(
block_hash_first_block
,
0
))
BlockHashWithGroupId
(
block_hash_first_block
,
0
))
# New request
# New request
req_after_evict
=
make_request
(
"partial_eagle_after_evict"
,
token_ids
)
req_after_evict
=
make_request
(
"partial_eagle_after_evict"
,
token_ids
,
block_size
,
hash
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_after_evict
)
computed_blocks
,
num_tokens
=
manager
.
get_computed_blocks
(
req_after_evict
)
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# Cache miss. The only hit prefix is [NULL_BLOCK, BLOCK_2] if eagle is
# not considered. But after dropping the last matched block due to eagle,
# not considered. But after dropping the last matched block due to eagle,
...
...
tests/v1/core/test_scheduler.py
View file @
c280066f
...
@@ -589,7 +589,7 @@ def test_preempt_during_execution():
...
@@ -589,7 +589,7 @@ def test_preempt_during_execution():
block_size
=
16
,
block_size
=
16
,
num_blocks
=
11
,
num_blocks
=
11
,
enable_prefix_caching
=
False
)
enable_prefix_caching
=
False
)
requests
=
create_requests
(
num_requests
=
2
,
num_tokens
=
80
)
requests
=
create_requests
(
num_requests
=
2
,
num_tokens
=
80
,
block_size
=
16
)
# Schedule the first request.
# Schedule the first request.
scheduler
.
add_request
(
requests
[
0
])
scheduler
.
add_request
(
requests
[
0
])
...
@@ -762,7 +762,7 @@ def _assert_right_scheduler_output(
...
@@ -762,7 +762,7 @@ def _assert_right_scheduler_output(
def
_assert_right_kv_cache_manager
(
def
_assert_right_kv_cache_manager
(
scheduler
:
Scheduler
,
scheduler
:
Scheduler
,
req
_id
s
:
list
[
st
r
],
req
uest
s
:
list
[
Reque
st
],
num_tokens
:
int
,
num_tokens
:
int
,
block_size
:
int
,
block_size
:
int
,
num_requests
:
int
,
num_requests
:
int
,
...
@@ -772,12 +772,12 @@ def _assert_right_kv_cache_manager(
...
@@ -772,12 +772,12 @@ def _assert_right_kv_cache_manager(
# Make sure the request stats are right.
# Make sure the request stats are right.
EXPECTED_TOTAL_BLOCKS
=
num_tokens
//
block_size
EXPECTED_TOTAL_BLOCKS
=
num_tokens
//
block_size
for
req
_id
in
req
_id
s
:
for
req
in
req
uest
s
:
blocks
=
(
scheduler
.
kv_cache_manager
.
coordinator
.
blocks
=
(
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
req_to_blocks
[
req_id
])
single_type_managers
[
0
].
req_to_blocks
[
req
.
request
_id
])
hashes
=
scheduler
.
kv_cache_manager
.
req_to_
block_hashes
[
req_id
]
hashes
=
req
.
block_hashes
assert
(
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
assert
(
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
num_cached_block
[
req_id
]
==
EXPECTED_TOTAL_BLOCKS
)
num_cached_block
[
req
.
request
_id
]
==
EXPECTED_TOTAL_BLOCKS
)
assert
len
(
blocks
)
==
EXPECTED_TOTAL_BLOCKS
assert
len
(
blocks
)
==
EXPECTED_TOTAL_BLOCKS
assert
len
(
hashes
)
==
EXPECTED_TOTAL_BLOCKS
assert
len
(
hashes
)
==
EXPECTED_TOTAL_BLOCKS
...
@@ -840,7 +840,8 @@ def test_kv_connector_basic():
...
@@ -840,7 +840,8 @@ def test_kv_connector_basic():
MAX_TOKENS
=
3
MAX_TOKENS
=
3
requests
=
create_requests
(
num_requests
=
NUM_REQUESTS
,
requests
=
create_requests
(
num_requests
=
NUM_REQUESTS
,
num_tokens
=
NUM_TOKENS
,
num_tokens
=
NUM_TOKENS
,
max_tokens
=
MAX_TOKENS
)
max_tokens
=
MAX_TOKENS
,
block_size
=
BLOCK_SIZE
)
req_ids
=
[]
req_ids
=
[]
req_to_index
=
{}
req_to_index
=
{}
for
i
,
request
in
enumerate
(
requests
):
for
i
,
request
in
enumerate
(
requests
):
...
@@ -868,7 +869,7 @@ def test_kv_connector_basic():
...
@@ -868,7 +869,7 @@ def test_kv_connector_basic():
)
)
# Ensure KVCacheManager is correct.
# Ensure KVCacheManager is correct.
_assert_right_kv_cache_manager
(
scheduler
,
req
_id
s
,
NUM_TOKENS
,
BLOCK_SIZE
,
_assert_right_kv_cache_manager
(
scheduler
,
req
uest
s
,
NUM_TOKENS
,
BLOCK_SIZE
,
NUM_REQUESTS
,
NUM_TOTAL_BLOCKS
)
NUM_REQUESTS
,
NUM_TOTAL_BLOCKS
)
# Continue Generation until done.
# Continue Generation until done.
...
@@ -886,7 +887,8 @@ def test_kv_connector_basic():
...
@@ -886,7 +887,8 @@ def test_kv_connector_basic():
NUM_TOKENS
=
NUM_TOKENS_PREFIX
*
2
NUM_TOKENS
=
NUM_TOKENS_PREFIX
*
2
requests
=
create_requests
(
num_requests
=
NUM_REQUESTS
,
requests
=
create_requests
(
num_requests
=
NUM_REQUESTS
,
num_tokens
=
NUM_TOKENS
,
num_tokens
=
NUM_TOKENS
,
max_tokens
=
MAX_TOKENS
)
max_tokens
=
MAX_TOKENS
,
block_size
=
BLOCK_SIZE
)
req_ids
=
[]
req_ids
=
[]
req_to_index
=
{}
req_to_index
=
{}
for
i
,
request
in
enumerate
(
requests
):
for
i
,
request
in
enumerate
(
requests
):
...
@@ -915,7 +917,7 @@ def test_kv_connector_basic():
...
@@ -915,7 +917,7 @@ def test_kv_connector_basic():
NUM_MATCHED_NEW_TOKENS
))
NUM_MATCHED_NEW_TOKENS
))
# Ensure KVCacheManager is correct.
# Ensure KVCacheManager is correct.
_assert_right_kv_cache_manager
(
scheduler
,
req
_id
s
,
NUM_TOKENS
,
BLOCK_SIZE
,
_assert_right_kv_cache_manager
(
scheduler
,
req
uest
s
,
NUM_TOKENS
,
BLOCK_SIZE
,
NUM_REQUESTS
,
NUM_TOTAL_BLOCKS
)
NUM_REQUESTS
,
NUM_TOTAL_BLOCKS
)
# Continue Generation until done.
# Continue Generation until done.
...
@@ -953,7 +955,8 @@ def test_kv_connector_unable_to_allocate():
...
@@ -953,7 +955,8 @@ def test_kv_connector_unable_to_allocate():
MAX_TOKENS
=
2
MAX_TOKENS
=
2
requests
=
create_requests
(
num_requests
=
NUM_REQUESTS
,
requests
=
create_requests
(
num_requests
=
NUM_REQUESTS
,
num_tokens
=
NUM_TOKENS
,
num_tokens
=
NUM_TOKENS
,
max_tokens
=
MAX_TOKENS
)
max_tokens
=
MAX_TOKENS
,
block_size
=
BLOCK_SIZE
)
req_ids
=
[]
req_ids
=
[]
req_to_index
=
{}
req_to_index
=
{}
for
i
,
request
in
enumerate
(
requests
):
for
i
,
request
in
enumerate
(
requests
):
...
@@ -1034,7 +1037,8 @@ def test_kv_connector_handles_preemption():
...
@@ -1034,7 +1037,8 @@ def test_kv_connector_handles_preemption():
MAX_TOKENS
=
BLOCK_SIZE
*
2
MAX_TOKENS
=
BLOCK_SIZE
*
2
requests
=
create_requests
(
num_requests
=
NUM_REQUESTS
,
requests
=
create_requests
(
num_requests
=
NUM_REQUESTS
,
num_tokens
=
NUM_TOKENS
,
num_tokens
=
NUM_TOKENS
,
max_tokens
=
MAX_TOKENS
)
max_tokens
=
MAX_TOKENS
,
block_size
=
BLOCK_SIZE
)
req_ids
=
[]
req_ids
=
[]
req_to_index
=
{}
req_to_index
=
{}
for
i
,
request
in
enumerate
(
requests
):
for
i
,
request
in
enumerate
(
requests
):
...
@@ -1162,7 +1166,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
...
@@ -1162,7 +1166,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
# KVCache Manager.
# KVCache Manager.
assert
len
(
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
assert
len
(
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
req_to_blocks
)
==
0
req_to_blocks
)
==
0
assert
len
(
scheduler
.
kv_cache_manager
.
req_to_block_hashes
)
==
0
assert
len
(
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
assert
len
(
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
num_cached_block
)
==
0
num_cached_block
)
==
0
num_free_blocks
=
(
num_free_blocks
=
(
...
...
tests/v1/core/test_single_type_kv_cache_manager.py
View file @
c280066f
...
@@ -17,7 +17,6 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
...
@@ -17,7 +17,6 @@ from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec,
def
get_sliding_window_manager
(
sliding_window_spec
,
block_pool
):
def
get_sliding_window_manager
(
sliding_window_spec
,
block_pool
):
return
SlidingWindowManager
(
sliding_window_spec
,
return
SlidingWindowManager
(
sliding_window_spec
,
block_pool
,
block_pool
,
caching_hash_fn
=
lambda
x
:
x
,
kv_cache_group_id
=
0
)
kv_cache_group_id
=
0
)
...
@@ -25,7 +24,6 @@ def get_chunked_local_attention_manager(chunked_local_attention_spec,
...
@@ -25,7 +24,6 @@ def get_chunked_local_attention_manager(chunked_local_attention_spec,
block_pool
):
block_pool
):
return
ChunkedLocalAttentionManager
(
chunked_local_attention_spec
,
return
ChunkedLocalAttentionManager
(
chunked_local_attention_spec
,
block_pool
,
block_pool
,
caching_hash_fn
=
lambda
x
:
x
,
kv_cache_group_id
=
0
)
kv_cache_group_id
=
0
)
...
...
tests/v1/core/utils.py
View file @
c280066f
...
@@ -10,6 +10,8 @@ from vllm.multimodal.inputs import (MultiModalBatchedField,
...
@@ -10,6 +10,8 @@ from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem
,
MultiModalKwargsItem
,
MultiModalFieldElem
,
MultiModalKwargsItem
,
PlaceholderRange
)
PlaceholderRange
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.kv_cache_utils
import
(
get_request_block_hasher
,
init_none_hash
)
from
vllm.v1.core.sched.async_scheduler
import
AsyncScheduler
from
vllm.v1.core.sched.async_scheduler
import
AsyncScheduler
from
vllm.v1.core.sched.scheduler
import
Scheduler
from
vllm.v1.core.sched.scheduler
import
Scheduler
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
...
@@ -114,6 +116,9 @@ def create_scheduler(
...
@@ -114,6 +116,9 @@ def create_scheduler(
)
)
_none_hash_initialized
=
False
def
create_requests
(
def
create_requests
(
num_requests
:
int
,
num_requests
:
int
,
num_tokens
:
int
=
10
,
num_tokens
:
int
=
10
,
...
@@ -122,7 +127,14 @@ def create_requests(
...
@@ -122,7 +127,14 @@ def create_requests(
stop_token_ids
:
Optional
[
list
[
int
]]
=
None
,
stop_token_ids
:
Optional
[
list
[
int
]]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
,
prompt_logprobs
:
Optional
[
int
]
=
None
,
same_prompt
:
bool
=
False
,
same_prompt
:
bool
=
False
,
block_size
:
int
=
16
,
)
->
list
[
Request
]:
)
->
list
[
Request
]:
global
_none_hash_initialized
if
not
_none_hash_initialized
:
init_none_hash
(
hash
)
_none_hash_initialized
=
True
block_hasher
=
get_request_block_hasher
(
block_size
,
hash
)
sampling_params
=
SamplingParams
(
ignore_eos
=
False
,
sampling_params
=
SamplingParams
(
ignore_eos
=
False
,
max_tokens
=
max_tokens
,
max_tokens
=
max_tokens
,
stop_token_ids
=
stop_token_ids
,
stop_token_ids
=
stop_token_ids
,
...
@@ -139,9 +151,11 @@ def create_requests(
...
@@ -139,9 +151,11 @@ def create_requests(
)
)
mm_item
=
MultiModalKwargsItem
.
from_elems
([
mm_elem
])
mm_item
=
MultiModalKwargsItem
.
from_elems
([
mm_elem
])
mm_kwargs
=
[
mm_item
]
*
len
(
mm_position
)
mm_kwargs
=
[
mm_item
]
*
len
(
mm_position
)
mm_hashes
=
[
"hash"
]
*
len
(
mm_position
)
else
:
else
:
mm_position
=
None
mm_position
=
None
mm_kwargs
=
None
mm_kwargs
=
None
mm_hashes
=
None
prompt_token_ids
=
([
0
]
*
num_tokens
if
same_prompt
else
[
i
]
*
prompt_token_ids
=
([
0
]
*
num_tokens
if
same_prompt
else
[
i
]
*
num_tokens
)
num_tokens
)
request
=
Request
(
request
=
Request
(
...
@@ -151,8 +165,9 @@ def create_requests(
...
@@ -151,8 +165,9 @@ def create_requests(
pooling_params
=
None
,
pooling_params
=
None
,
multi_modal_kwargs
=
mm_kwargs
,
multi_modal_kwargs
=
mm_kwargs
,
multi_modal_placeholders
=
mm_position
,
multi_modal_placeholders
=
mm_position
,
multi_modal_hashes
=
None
,
multi_modal_hashes
=
mm_hashes
,
eos_token_id
=
EOS_TOKEN_ID
,
eos_token_id
=
EOS_TOKEN_ID
,
block_hasher
=
block_hasher
,
)
)
requests
.
append
(
request
)
requests
.
append
(
request
)
return
requests
return
requests
tests/v1/kv_connector/unit/test_nixl_connector.py
View file @
c280066f
...
@@ -147,6 +147,7 @@ def test_basic_interface():
...
@@ -147,6 +147,7 @@ def test_basic_interface():
NUM_TOKENS
=
int
(
BLOCK_SIZE
*
(
NUM_EXTERNAL_FULL_BLOCKS
+
0.5
))
NUM_TOKENS
=
int
(
BLOCK_SIZE
*
(
NUM_EXTERNAL_FULL_BLOCKS
+
0.5
))
request
=
create_request
(
request_id
=
1
,
request
=
create_request
(
request_id
=
1
,
block_size
=
BLOCK_SIZE
,
num_tokens
=
NUM_TOKENS
,
num_tokens
=
NUM_TOKENS
,
do_remote_prefill
=
True
)
do_remote_prefill
=
True
)
request_id
=
request
.
request_id
request_id
=
request
.
request_id
...
@@ -186,6 +187,7 @@ def test_prompt_less_than_block_size():
...
@@ -186,6 +187,7 @@ def test_prompt_less_than_block_size():
# Request will have 1 partial remote block.
# Request will have 1 partial remote block.
request
=
create_request
(
request_id
=
1
,
request
=
create_request
(
request_id
=
1
,
block_size
=
BLOCK_SIZE
,
num_tokens
=
NUM_TOKENS
,
num_tokens
=
NUM_TOKENS
,
do_remote_prefill
=
True
,
do_remote_prefill
=
True
,
num_remote_blocks
=
1
)
num_remote_blocks
=
1
)
...
...
tests/v1/kv_connector/unit/test_remote_decode_lifecycle.py
View file @
c280066f
...
@@ -21,6 +21,7 @@ def test_basic_lifecycle():
...
@@ -21,6 +21,7 @@ def test_basic_lifecycle():
NUM_TOKENS
=
int
(
BLOCK_SIZE
*
(
NUM_EXTERNAL_FULL_BLOCKS
+
0.5
))
NUM_TOKENS
=
int
(
BLOCK_SIZE
*
(
NUM_EXTERNAL_FULL_BLOCKS
+
0.5
))
request
=
create_request
(
request_id
=
1
,
request
=
create_request
(
request_id
=
1
,
block_size
=
BLOCK_SIZE
,
max_tokens
=
1
,
max_tokens
=
1
,
num_tokens
=
NUM_TOKENS
,
num_tokens
=
NUM_TOKENS
,
do_remote_decode
=
True
)
do_remote_decode
=
True
)
...
@@ -103,8 +104,10 @@ def test_short_prompt_lifecycle():
...
@@ -103,8 +104,10 @@ def test_short_prompt_lifecycle():
scheduler
=
create_scheduler
(
vllm_config
)
scheduler
=
create_scheduler
(
vllm_config
)
# Not enough tokens for full block.
# Not enough tokens for full block.
NUM_TOKENS
=
vllm_config
.
cache_config
.
block_size
//
2
BLOCK_SIZE
=
vllm_config
.
cache_config
.
block_size
NUM_TOKENS
=
BLOCK_SIZE
//
2
request
=
create_request
(
request_id
=
1
,
request
=
create_request
(
request_id
=
1
,
block_size
=
BLOCK_SIZE
,
max_tokens
=
1
,
max_tokens
=
1
,
num_tokens
=
NUM_TOKENS
,
num_tokens
=
NUM_TOKENS
,
do_remote_decode
=
True
)
do_remote_decode
=
True
)
...
@@ -148,7 +151,9 @@ def test_prefix_cache_lifecycle():
...
@@ -148,7 +151,9 @@ def test_prefix_cache_lifecycle():
NUM_EXTERNAL_FULL_BLOCKS
=
3
NUM_EXTERNAL_FULL_BLOCKS
=
3
NUM_TOKENS
=
int
(
BLOCK_SIZE
*
(
NUM_EXTERNAL_FULL_BLOCKS
+
0.5
))
NUM_TOKENS
=
int
(
BLOCK_SIZE
*
(
NUM_EXTERNAL_FULL_BLOCKS
+
0.5
))
request_normal
=
create_request
(
request_id
=
1
,
num_tokens
=
NUM_TOKENS
)
request_normal
=
create_request
(
request_id
=
1
,
block_size
=
BLOCK_SIZE
,
num_tokens
=
NUM_TOKENS
)
scheduler
.
add_request
(
request_normal
)
scheduler
.
add_request
(
request_normal
)
scheduler_output
=
scheduler
.
schedule
()
scheduler_output
=
scheduler
.
schedule
()
...
@@ -166,6 +171,7 @@ def test_prefix_cache_lifecycle():
...
@@ -166,6 +171,7 @@ def test_prefix_cache_lifecycle():
NUM_TOKENS
=
int
(
BLOCK_SIZE
*
(
NUM_EXTERNAL_FULL_BLOCKS
+
0.5
))
NUM_TOKENS
=
int
(
BLOCK_SIZE
*
(
NUM_EXTERNAL_FULL_BLOCKS
+
0.5
))
request_remote
=
create_request
(
request_id
=
1
,
request_remote
=
create_request
(
request_id
=
1
,
block_size
=
BLOCK_SIZE
,
num_tokens
=
NUM_TOKENS
,
num_tokens
=
NUM_TOKENS
,
do_remote_decode
=
True
)
do_remote_decode
=
True
)
...
...
tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
View file @
c280066f
...
@@ -23,6 +23,7 @@ def test_basic_lifecycle():
...
@@ -23,6 +23,7 @@ def test_basic_lifecycle():
scheduler
.
kv_cache_manager
.
block_pool
.
free_block_queue
.
num_free_blocks
)
scheduler
.
kv_cache_manager
.
block_pool
.
free_block_queue
.
num_free_blocks
)
request
=
create_request
(
request_id
=
1
,
request
=
create_request
(
request_id
=
1
,
block_size
=
BLOCK_SIZE
,
num_tokens
=
NUM_TOKENS
,
num_tokens
=
NUM_TOKENS
,
do_remote_prefill
=
True
)
do_remote_prefill
=
True
)
...
@@ -133,14 +134,17 @@ def test_interleaved_lifecycle():
...
@@ -133,14 +134,17 @@ def test_interleaved_lifecycle():
NUM_TOKENS
=
int
(
BLOCK_SIZE
*
(
NUM_EXTERNAL_FULL_BLOCKS
+
0.5
))
NUM_TOKENS
=
int
(
BLOCK_SIZE
*
(
NUM_EXTERNAL_FULL_BLOCKS
+
0.5
))
request_remote
=
create_request
(
request_id
=
1
,
request_remote
=
create_request
(
request_id
=
1
,
block_size
=
BLOCK_SIZE
,
num_tokens
=
NUM_TOKENS
,
num_tokens
=
NUM_TOKENS
,
do_remote_prefill
=
True
)
do_remote_prefill
=
True
)
request_local_a
=
create_request
(
request_local_a
=
create_request
(
request_id
=
2
,
request_id
=
2
,
block_size
=
BLOCK_SIZE
,
num_tokens
=
NUM_TOKENS
,
num_tokens
=
NUM_TOKENS
,
)
)
request_local_b
=
create_request
(
request_local_b
=
create_request
(
request_id
=
3
,
request_id
=
3
,
block_size
=
BLOCK_SIZE
,
num_tokens
=
NUM_TOKENS
,
num_tokens
=
NUM_TOKENS
,
)
)
...
@@ -236,6 +240,7 @@ def test_no_spurious_prefix_caching():
...
@@ -236,6 +240,7 @@ def test_no_spurious_prefix_caching():
# Both of these requests have prompts like [1,1,1,1,1, ...]
# Both of these requests have prompts like [1,1,1,1,1, ...]
request_remote
=
create_request
(
request_remote
=
create_request
(
request_id
=
1
,
request_id
=
1
,
block_size
=
BLOCK_SIZE
,
num_tokens
=
NUM_TOKENS
,
num_tokens
=
NUM_TOKENS
,
do_remote_prefill
=
True
,
do_remote_prefill
=
True
,
use_all_1s_for_prompt_tokens
=
True
,
use_all_1s_for_prompt_tokens
=
True
,
...
@@ -243,6 +248,7 @@ def test_no_spurious_prefix_caching():
...
@@ -243,6 +248,7 @@ def test_no_spurious_prefix_caching():
request_local
=
create_request
(
request_local
=
create_request
(
request_id
=
2
,
request_id
=
2
,
block_size
=
BLOCK_SIZE
,
num_tokens
=
NUM_TOKENS
,
num_tokens
=
NUM_TOKENS
,
do_remote_prefill
=
False
,
do_remote_prefill
=
False
,
use_all_1s_for_prompt_tokens
=
True
,
use_all_1s_for_prompt_tokens
=
True
,
...
@@ -292,6 +298,7 @@ def test_full_block_prompt():
...
@@ -292,6 +298,7 @@ def test_full_block_prompt():
NUM_TOKENS
=
int
(
BLOCK_SIZE
*
NUM_EXTERNAL_FULL_BLOCKS
)
NUM_TOKENS
=
int
(
BLOCK_SIZE
*
NUM_EXTERNAL_FULL_BLOCKS
)
request
=
create_request
(
request_id
=
1
,
request
=
create_request
(
request_id
=
1
,
block_size
=
BLOCK_SIZE
,
num_tokens
=
NUM_TOKENS
,
num_tokens
=
NUM_TOKENS
,
do_remote_prefill
=
True
)
do_remote_prefill
=
True
)
...
@@ -364,8 +371,11 @@ def test_cannot_schedule_after_recv():
...
@@ -364,8 +371,11 @@ def test_cannot_schedule_after_recv():
NUM_TOKENS_LOCAL
=
int
(
BLOCK_SIZE
*
NUM_PROMPT_BLOCKS
)
NUM_TOKENS_LOCAL
=
int
(
BLOCK_SIZE
*
NUM_PROMPT_BLOCKS
)
NUM_TOKENS_REMOTE
=
int
(
BLOCK_SIZE
*
NUM_PROMPT_BLOCKS
)
NUM_TOKENS_REMOTE
=
int
(
BLOCK_SIZE
*
NUM_PROMPT_BLOCKS
)
request_normal
=
create_request
(
request_id
=
1
,
num_tokens
=
NUM_TOKENS_LOCAL
)
request_normal
=
create_request
(
request_id
=
1
,
block_size
=
BLOCK_SIZE
,
num_tokens
=
NUM_TOKENS_LOCAL
)
request_remote
=
create_request
(
request_id
=
2
,
request_remote
=
create_request
(
request_id
=
2
,
block_size
=
BLOCK_SIZE
,
num_tokens
=
NUM_TOKENS_REMOTE
,
num_tokens
=
NUM_TOKENS_REMOTE
,
do_remote_prefill
=
True
)
do_remote_prefill
=
True
)
...
@@ -456,8 +466,11 @@ def test_cannot_recv():
...
@@ -456,8 +466,11 @@ def test_cannot_recv():
NUM_TOKENS_LOCAL
=
int
(
BLOCK_SIZE
*
NUM_PROMPT_BLOCKS
)
NUM_TOKENS_LOCAL
=
int
(
BLOCK_SIZE
*
NUM_PROMPT_BLOCKS
)
NUM_TOKENS_REMOTE
=
int
(
BLOCK_SIZE
*
(
NUM_PROMPT_BLOCKS
+
0.5
))
NUM_TOKENS_REMOTE
=
int
(
BLOCK_SIZE
*
(
NUM_PROMPT_BLOCKS
+
0.5
))
request_normal
=
create_request
(
request_id
=
1
,
num_tokens
=
NUM_TOKENS_LOCAL
)
request_normal
=
create_request
(
request_id
=
1
,
block_size
=
BLOCK_SIZE
,
num_tokens
=
NUM_TOKENS_LOCAL
)
request_remote
=
create_request
(
request_id
=
2
,
request_remote
=
create_request
(
request_id
=
2
,
block_size
=
BLOCK_SIZE
,
num_tokens
=
NUM_TOKENS_REMOTE
,
num_tokens
=
NUM_TOKENS_REMOTE
,
do_remote_prefill
=
True
)
do_remote_prefill
=
True
)
...
...
tests/v1/kv_connector/unit/utils.py
View file @
c280066f
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
tempfile
import
tempfile
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Callable
,
Optional
import
torch
import
torch
...
@@ -14,6 +14,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
...
@@ -14,6 +14,8 @@ from vllm.distributed.kv_transfer.kv_connector.factory import (
from
vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector
import
(
# noqa
from
vllm.distributed.kv_transfer.kv_connector.v1.shared_storage_connector
import
(
# noqa
SharedStorageConnector
)
SharedStorageConnector
)
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.core.kv_cache_utils
import
(
get_request_block_hasher
,
init_none_hash
)
from
vllm.v1.core.sched.scheduler
import
Scheduler
from
vllm.v1.core.sched.scheduler
import
Scheduler
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
from
vllm.v1.kv_cache_interface
import
(
FullAttentionSpec
,
KVCacheConfig
,
KVCacheGroupSpec
)
KVCacheGroupSpec
)
...
@@ -40,7 +42,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
...
@@ -40,7 +42,6 @@ def assert_scheduler_empty(scheduler: Scheduler):
# KVCache Manager.
# KVCache Manager.
assert
len
(
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
assert
len
(
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
req_to_blocks
)
==
0
req_to_blocks
)
==
0
assert
len
(
scheduler
.
kv_cache_manager
.
req_to_block_hashes
)
==
0
assert
len
(
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
assert
len
(
scheduler
.
kv_cache_manager
.
coordinator
.
single_type_managers
[
0
].
num_cached_block
)
==
0
num_cached_block
)
==
0
num_free_blocks
=
(
num_free_blocks
=
(
...
@@ -115,16 +116,23 @@ def create_scheduler(
...
@@ -115,16 +116,23 @@ def create_scheduler(
)
)
def
create_request
(
_none_hash_initialized
=
False
request_id
:
int
,
def
create_request
(
request_id
:
int
,
num_tokens
:
int
=
10
,
num_tokens
:
int
=
10
,
max_tokens
:
int
=
16
,
max_tokens
:
int
=
16
,
do_remote_decode
:
bool
=
False
,
do_remote_decode
:
bool
=
False
,
do_remote_prefill
:
bool
=
False
,
do_remote_prefill
:
bool
=
False
,
use_all_1s_for_prompt_tokens
:
bool
=
False
,
use_all_1s_for_prompt_tokens
:
bool
=
False
,
num_remote_blocks
:
int
=
3
,
num_remote_blocks
:
int
=
3
,
)
->
Request
:
block_size
:
int
=
16
,
hash_fn
:
Callable
=
hash
)
->
Request
:
"""Make dummy request for testing."""
"""Make dummy request for testing."""
global
_none_hash_initialized
if
not
_none_hash_initialized
:
init_none_hash
(
hash
)
_none_hash_initialized
=
True
kv_transfer_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
kv_transfer_params
:
Optional
[
dict
[
str
,
Any
]]
=
None
...
@@ -158,6 +166,7 @@ def create_request(
...
@@ -158,6 +166,7 @@ def create_request(
multi_modal_placeholders
=
None
,
multi_modal_placeholders
=
None
,
multi_modal_hashes
=
None
,
multi_modal_hashes
=
None
,
eos_token_id
=
EOS_TOKEN_ID
,
eos_token_id
=
EOS_TOKEN_ID
,
block_hasher
=
get_request_block_hasher
(
block_size
,
hash_fn
),
)
)
req
.
kv_transfer_params
=
kv_transfer_params
req
.
kv_transfer_params
=
kv_transfer_params
return
req
return
req
...
...
vllm/utils/__init__.py
View file @
c280066f
...
@@ -3243,6 +3243,24 @@ def sha256_cbor_64bit(input) -> int:
...
@@ -3243,6 +3243,24 @@ def sha256_cbor_64bit(input) -> int:
return
full_hash
&
((
1
<<
64
)
-
1
)
return
full_hash
&
((
1
<<
64
)
-
1
)
def get_hash_fn_by_name(hash_fn_name: str) -> Callable:
    """Get a hash function by name, or raise an error if
    the function is not found.

    Args:
        hash_fn_name: Name of the hash function. The recognized names are
            "sha256", "sha256_cbor_64bit" and "builtin".

    Returns:
        A hash function.

    Raises:
        ValueError: If `hash_fn_name` is not one of the recognized names.
    """
    # The names are checked one at a time so that each candidate function is
    # only looked up when its name is actually requested.
    if hash_fn_name == "sha256":
        return sha256
    if hash_fn_name == "sha256_cbor_64bit":
        return sha256_cbor_64bit
    if hash_fn_name == "builtin":
        # "builtin" maps to Python's built-in hash().
        return hash

    raise ValueError(f"Unsupported hash function: {hash_fn_name}")
def
is_torch_equal_or_newer
(
target
:
str
)
->
bool
:
def
is_torch_equal_or_newer
(
target
:
str
)
->
bool
:
"""Check if the installed torch version is >= the target version.
"""Check if the installed torch version is >= the target version.
...
...
vllm/v1/core/block_pool.py
View file @
c280066f
...
@@ -2,15 +2,13 @@
...
@@ -2,15 +2,13 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections
import
defaultdict
from
collections
import
defaultdict
from
collections.abc
import
Iterable
from
collections.abc
import
Iterable
from
typing
import
Callable
,
Optional
from
typing
import
Optional
from
vllm.distributed.kv_events
import
(
AllBlocksCleared
,
BlockRemoved
,
from
vllm.distributed.kv_events
import
(
AllBlocksCleared
,
BlockRemoved
,
BlockStored
,
KVCacheEvent
)
BlockStored
,
KVCacheEvent
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
BlockHashWithGroupId
,
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
BlockHashWithGroupId
,
FreeKVCacheBlockQueue
,
KVCacheBlock
,
FreeKVCacheBlockQueue
,
KVCacheBlock
)
generate_block_hash_extra_keys
,
hash_block_tokens
)
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -97,84 +95,39 @@ class BlockPool:
...
@@ -97,84 +95,39 @@ class BlockPool:
self
,
self
,
request
:
Request
,
request
:
Request
,
blocks
:
list
[
KVCacheBlock
],
blocks
:
list
[
KVCacheBlock
],
block_hashes
:
list
[
BlockHash
],
num_cached_blocks
:
int
,
num_cached_blocks
:
int
,
num_full_blocks
:
int
,
num_full_blocks
:
int
,
block_size
:
int
,
block_size
:
int
,
kv_cache_group_id
:
int
,
kv_cache_group_id
:
int
,
hash_fn
:
Callable
,
)
->
None
:
)
->
None
:
"""Cache a list of full blocks for prefix caching.
"""Cache a list of full blocks for prefix caching.
This function takes a list of blocks that will have their block hash
This function takes a list of blocks that will have their block hash
metadata to be updated and cached. Given a request, it computes the
metadata to be updated and cached. Given a request, it updates the
block hashes for the blocks starting from `num_cached_blocks` to
metadata for each block and caching it in the
`num_full_blocks`, updating the metadata for each block
`cached_block_hash_to_block`.
and caching them in the `cached_block_hash_to_block`.
The block hashes values are computed by the Request object immediately
when it is created and when new tokens are appended.
Args:
Args:
request: The request to cache the blocks.
request: The request to cache the blocks.
blocks: All blocks in the request.
blocks: All blocks in the request.
block_hashes: Block hashes of the blocks in the request. Note that
this list may be shorter than the blocks list. In this case the
missed block hash will be computed in this function.
num_cached_blocks: The number of blocks that are already cached.
num_cached_blocks: The number of blocks that are already cached.
num_full_blocks: The number of blocks that are full and should
num_full_blocks: The number of blocks that are full and should
be cached after this function.
be cached after this function.
block_size: Number of tokens in each block.
block_size: Number of tokens in each block.
kv_cache_group_id: The id of the KV cache group.
kv_cache_group_id: The id of the KV cache group.
hash_fn: The hash function to use for block hashes.
"""
"""
if
num_cached_blocks
==
num_full_blocks
:
if
num_cached_blocks
==
num_full_blocks
:
return
return
new_full_blocks
=
blocks
[
num_cached_blocks
:
num_full_blocks
]
new_full_blocks
=
blocks
[
num_cached_blocks
:
num_full_blocks
]
assert
len
(
block_hashes
)
>=
num_
cached
_blocks
assert
len
(
request
.
block_hashes
)
>=
num_
full
_blocks
new_block_hashes
=
block_hashes
[
num_cached_blocks
:]
new_block_hashes
=
request
.
block_hashes
[
num_cached_blocks
:]
# Update the new blocks with the block hashes through the chain.
if
num_cached_blocks
==
0
:
prev_block_hash_value
=
None
else
:
prev_block
=
blocks
[
num_cached_blocks
-
1
]
assert
prev_block
.
block_hash
is
not
None
prev_block_hash_value
=
prev_block
.
block_hash
.
get_hash_value
()
parent_block_hash
=
prev_block_hash_value
new_hashes
:
Optional
[
list
[
int
]]
=
([]
if
self
.
enable_kv_cache_events
new_hashes
:
Optional
[
list
[
int
]]
=
([]
if
self
.
enable_kv_cache_events
else
None
)
else
None
)
for
i
,
blk
in
enumerate
(
new_full_blocks
):
for
i
,
blk
in
enumerate
(
new_full_blocks
):
assert
blk
.
block_hash
is
None
assert
blk
.
block_hash
is
None
if
i
<
len
(
new_block_hashes
):
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# generated tokens with preemption), or by other
# single_type_managers with the same block_size.
# In this case we simply reuse the block hash.
block_hash
=
new_block_hashes
[
i
]
block_hash
=
new_block_hashes
[
i
]
else
:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
blk_idx
=
num_cached_blocks
+
i
start_token_idx
=
blk_idx
*
block_size
end_token_idx
=
(
blk_idx
+
1
)
*
block_size
block_tokens
=
request
.
all_token_ids
[
start_token_idx
:
end_token_idx
]
assert
len
(
block_tokens
)
==
block_size
,
(
f
"Expected
{
block_size
}
tokens, got "
f
"
{
len
(
block_tokens
)
}
at
{
blk_idx
}
th block for request "
f
"
{
request
.
request_id
}
(
{
request
}
)"
)
# Generate extra keys for multi-modal inputs. Note that since
# we reach to this branch only when the block is completed with
# generated tokens, we only need to consider the last mm input.
extra_keys
,
_
=
generate_block_hash_extra_keys
(
request
,
start_token_idx
,
end_token_idx
,
-
1
)
# Compute the hash of the current block.
block_hash
=
hash_block_tokens
(
hash_fn
,
prev_block_hash_value
,
block_tokens
,
extra_keys
)
block_hashes
.
append
(
block_hash
)
# Update and added the full block to the cache.
# Update and added the full block to the cache.
block_hash_with_group_id
=
BlockHashWithGroupId
(
block_hash_with_group_id
=
BlockHashWithGroupId
(
...
@@ -184,9 +137,15 @@ class BlockPool:
...
@@ -184,9 +137,15 @@ class BlockPool:
blk
.
block_id
]
=
blk
blk
.
block_id
]
=
blk
if
new_hashes
is
not
None
:
if
new_hashes
is
not
None
:
new_hashes
.
append
(
block_hash
.
hash_value
)
new_hashes
.
append
(
block_hash
.
hash_value
)
prev_block_hash_value
=
block_hash
.
hash_value
if
self
.
enable_kv_cache_events
:
if
self
.
enable_kv_cache_events
:
if
num_cached_blocks
==
0
:
parent_block_hash
=
None
else
:
parent_block
=
blocks
[
num_cached_blocks
-
1
]
assert
parent_block
.
block_hash
is
not
None
parent_block_hash
=
parent_block
.
block_hash
.
get_hash_value
()
self
.
kv_event_queue
.
append
(
self
.
kv_event_queue
.
append
(
BlockStored
(
BlockStored
(
block_hashes
=
new_hashes
,
block_hashes
=
new_hashes
,
...
...
vllm/v1/core/kv_cache_coordinator.py
View file @
c280066f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Callable
,
Optional
from
typing
import
Optional
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
from
vllm.v1.core.kv_cache_utils
import
BlockHash
,
KVCacheBlock
...
@@ -23,7 +23,6 @@ class KVCacheCoordinator(ABC):
...
@@ -23,7 +23,6 @@ class KVCacheCoordinator(ABC):
max_model_len
:
int
,
max_model_len
:
int
,
use_eagle
:
bool
,
use_eagle
:
bool
,
enable_caching
:
bool
,
enable_caching
:
bool
,
caching_hash_fn
:
Callable
,
enable_kv_cache_events
:
bool
,
enable_kv_cache_events
:
bool
,
):
):
self
.
kv_cache_config
=
kv_cache_config
self
.
kv_cache_config
=
kv_cache_config
...
@@ -40,7 +39,6 @@ class KVCacheCoordinator(ABC):
...
@@ -40,7 +39,6 @@ class KVCacheCoordinator(ABC):
kv_cache_spec
=
kv_cache_group
.
kv_cache_spec
,
kv_cache_spec
=
kv_cache_group
.
kv_cache_spec
,
block_pool
=
self
.
block_pool
,
block_pool
=
self
.
block_pool
,
kv_cache_group_id
=
i
,
kv_cache_group_id
=
i
,
caching_hash_fn
=
caching_hash_fn
,
)
for
i
,
kv_cache_group
in
enumerate
(
)
for
i
,
kv_cache_group
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
))
self
.
kv_cache_config
.
kv_cache_groups
))
...
@@ -99,19 +97,17 @@ class KVCacheCoordinator(ABC):
...
@@ -99,19 +97,17 @@ class KVCacheCoordinator(ABC):
manager
.
allocate_new_blocks
(
request_id
,
num_tokens
)
manager
.
allocate_new_blocks
(
request_id
,
num_tokens
)
for
manager
in
self
.
single_type_managers
)
for
manager
in
self
.
single_type_managers
)
def
cache_blocks
(
self
,
request
:
Request
,
block_hashes
:
list
[
BlockHash
],
def
cache_blocks
(
self
,
request
:
Request
,
num_computed_tokens
:
int
)
->
None
:
num_computed_tokens
:
int
)
->
None
:
"""
"""
Cache the blocks for the request.
Cache the blocks for the request.
Args:
Args:
request: The request.
request: The request.
block_hashes: The block hashes of the request.
num_tokens: The total number of tokens that need to be cached
num_tokens: The total number of tokens that need to be cached
(including tokens that are already cached).
(including tokens that are already cached).
"""
"""
for
manager
in
self
.
single_type_managers
:
for
manager
in
self
.
single_type_managers
:
manager
.
cache_blocks
(
request
,
block_hashes
,
num_computed_tokens
)
manager
.
cache_blocks
(
request
,
num_computed_tokens
)
def
free
(
self
,
request_id
:
str
)
->
None
:
def
free
(
self
,
request_id
:
str
)
->
None
:
"""
"""
...
@@ -184,10 +180,9 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
...
@@ -184,10 +180,9 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
"""
"""
def
__init__
(
self
,
kv_cache_config
:
KVCacheConfig
,
max_model_len
:
int
,
def
__init__
(
self
,
kv_cache_config
:
KVCacheConfig
,
max_model_len
:
int
,
use_eagle
:
bool
,
caching_hash_fn
:
Callable
,
use_eagle
:
bool
,
enable_kv_cache_events
:
bool
):
enable_kv_cache_events
:
bool
):
super
().
__init__
(
kv_cache_config
,
max_model_len
,
use_eagle
,
False
,
super
().
__init__
(
kv_cache_config
,
max_model_len
,
use_eagle
,
False
,
caching_hash_fn
,
enable_kv_cache_events
)
enable_kv_cache_events
)
self
.
num_single_type_manager
=
len
(
self
.
single_type_managers
)
self
.
num_single_type_manager
=
len
(
self
.
single_type_managers
)
def
get_num_common_prefix_blocks
(
self
,
request_id
:
str
,
def
get_num_common_prefix_blocks
(
self
,
request_id
:
str
,
...
@@ -213,10 +208,9 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
...
@@ -213,10 +208,9 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
def
__init__
(
self
,
kv_cache_config
:
KVCacheConfig
,
max_model_len
:
int
,
def
__init__
(
self
,
kv_cache_config
:
KVCacheConfig
,
max_model_len
:
int
,
use_eagle
:
bool
,
enable_caching
:
bool
,
use_eagle
:
bool
,
enable_caching
:
bool
,
caching_hash_fn
:
Callable
,
enable_kv_cache_events
:
bool
):
enable_kv_cache_events
:
bool
):
super
().
__init__
(
kv_cache_config
,
max_model_len
,
use_eagle
,
super
().
__init__
(
kv_cache_config
,
max_model_len
,
use_eagle
,
enable_caching
,
caching_hash_fn
,
enable_caching
,
enable_kv_cache_events
)
enable_kv_cache_events
)
self
.
kv_cache_spec
=
self
.
kv_cache_config
.
kv_cache_groups
[
self
.
kv_cache_spec
=
self
.
kv_cache_config
.
kv_cache_groups
[
0
].
kv_cache_spec
0
].
kv_cache_spec
self
.
block_size
=
self
.
kv_cache_spec
.
block_size
self
.
block_size
=
self
.
kv_cache_spec
.
block_size
...
@@ -250,10 +244,9 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
...
@@ -250,10 +244,9 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
def
__init__
(
self
,
kv_cache_config
:
KVCacheConfig
,
max_model_len
:
int
,
def
__init__
(
self
,
kv_cache_config
:
KVCacheConfig
,
max_model_len
:
int
,
use_eagle
:
bool
,
enable_caching
:
bool
,
use_eagle
:
bool
,
enable_caching
:
bool
,
caching_hash_fn
:
Callable
,
enable_kv_cache_events
:
bool
):
enable_kv_cache_events
:
bool
):
super
().
__init__
(
kv_cache_config
,
max_model_len
,
use_eagle
,
super
().
__init__
(
kv_cache_config
,
max_model_len
,
use_eagle
,
enable_caching
,
caching_hash_fn
,
enable_caching
,
enable_kv_cache_events
)
enable_kv_cache_events
)
self
.
verify_and_split_kv_cache_groups
()
self
.
verify_and_split_kv_cache_groups
()
def
verify_and_split_kv_cache_groups
(
self
)
->
None
:
def
verify_and_split_kv_cache_groups
(
self
)
->
None
:
...
@@ -386,17 +379,15 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
...
@@ -386,17 +379,15 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
def
get_kv_cache_coordinator
(
def
get_kv_cache_coordinator
(
kv_cache_config
:
KVCacheConfig
,
max_model_len
:
int
,
use_eagle
:
bool
,
kv_cache_config
:
KVCacheConfig
,
max_model_len
:
int
,
use_eagle
:
bool
,
enable_caching
:
bool
,
caching_hash_fn
:
Callable
,
enable_caching
:
bool
,
enable_kv_cache_events
:
bool
)
->
KVCacheCoordinator
:
enable_kv_cache_events
:
bool
)
->
KVCacheCoordinator
:
if
not
enable_caching
:
if
not
enable_caching
:
return
KVCacheCoordinatorNoPrefixCache
(
kv_cache_config
,
max_model_len
,
return
KVCacheCoordinatorNoPrefixCache
(
kv_cache_config
,
max_model_len
,
use_eagle
,
caching_hash_fn
,
use_eagle
,
enable_kv_cache_events
)
enable_kv_cache_events
)
if
len
(
kv_cache_config
.
kv_cache_groups
)
==
1
:
if
len
(
kv_cache_config
.
kv_cache_groups
)
==
1
:
return
UnitaryKVCacheCoordinator
(
kv_cache_config
,
max_model_len
,
return
UnitaryKVCacheCoordinator
(
kv_cache_config
,
max_model_len
,
use_eagle
,
enable_caching
,
use_eagle
,
enable_caching
,
caching_hash_fn
,
enable_kv_cache_events
)
enable_kv_cache_events
)
return
HybridKVCacheCoordinator
(
kv_cache_config
,
max_model_len
,
use_eagle
,
return
HybridKVCacheCoordinator
(
kv_cache_config
,
max_model_len
,
use_eagle
,
enable_caching
,
caching_hash_fn
,
enable_caching
,
enable_kv_cache_events
)
enable_kv_cache_events
)
vllm/v1/core/kv_cache_manager.py
View file @
c280066f
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
from
typing
import
Optional
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
sha256
,
sha256_cbor_64bit
from
vllm.v1.core.kv_cache_coordinator
import
get_kv_cache_coordinator
from
vllm.v1.core.kv_cache_coordinator
import
get_kv_cache_coordinator
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
KVCacheBlock
,
from
vllm.v1.core.kv_cache_utils
import
KVCacheBlock
hash_request_tokens
,
init_none_hash
)
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.metrics.stats
import
PrefixCacheStats
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.request
import
Request
,
RequestStatus
...
@@ -71,23 +68,13 @@ class KVCacheManager:
...
@@ -71,23 +68,13 @@ class KVCacheManager:
kv_cache_config
:
KVCacheConfig
,
kv_cache_config
:
KVCacheConfig
,
max_model_len
:
int
,
max_model_len
:
int
,
enable_caching
:
bool
=
True
,
enable_caching
:
bool
=
True
,
caching_hash_algo
:
str
=
"builtin"
,
use_eagle
:
bool
=
False
,
use_eagle
:
bool
=
False
,
log_stats
:
bool
=
False
,
log_stats
:
bool
=
False
,
enable_kv_cache_events
:
bool
=
False
,
enable_kv_cache_events
:
bool
=
False
,
)
->
None
:
)
->
None
:
self
.
max_model_len
=
max_model_len
self
.
max_model_len
=
max_model_len
if
len
(
kv_cache_config
.
kv_cache_groups
)
==
0
:
# Attention free models don't have kv cache,
# thus don't need prefix caching.
enable_caching
=
False
self
.
enable_caching
=
enable_caching
self
.
enable_caching
=
enable_caching
self
.
caching_hash_fn
=
(
sha256_cbor_64bit
if
caching_hash_algo
==
"sha256_cbor_64bit"
else
sha256
if
caching_hash_algo
==
"sha256"
else
hash
)
init_none_hash
(
self
.
caching_hash_fn
)
self
.
use_eagle
=
use_eagle
self
.
use_eagle
=
use_eagle
self
.
log_stats
=
log_stats
self
.
log_stats
=
log_stats
# FIXME: make prefix cache stats conditional on log_stats
# FIXME: make prefix cache stats conditional on log_stats
...
@@ -107,19 +94,12 @@ class KVCacheManager:
...
@@ -107,19 +94,12 @@ class KVCacheManager:
max_model_len
=
self
.
max_model_len
,
max_model_len
=
self
.
max_model_len
,
use_eagle
=
self
.
use_eagle
,
use_eagle
=
self
.
use_eagle
,
enable_caching
=
self
.
enable_caching
,
enable_caching
=
self
.
enable_caching
,
caching_hash_fn
=
self
.
caching_hash_fn
,
enable_kv_cache_events
=
enable_kv_cache_events
,
enable_kv_cache_events
=
enable_kv_cache_events
,
)
)
self
.
num_kv_cache_groups
=
len
(
kv_cache_config
.
kv_cache_groups
)
self
.
num_kv_cache_groups
=
len
(
kv_cache_config
.
kv_cache_groups
)
self
.
block_pool
=
self
.
coordinator
.
block_pool
self
.
block_pool
=
self
.
coordinator
.
block_pool
self
.
kv_cache_config
=
kv_cache_config
self
.
kv_cache_config
=
kv_cache_config
# Mapping from request ID to kv block hashes.
# This is to avoid recomputing the block hashes for each call of
# `get_computed_blocks` or `allocate_slots`.
self
.
req_to_block_hashes
:
defaultdict
[
str
,
list
[
BlockHash
]]
=
defaultdict
(
list
)
@
property
@
property
def
usage
(
self
)
->
float
:
def
usage
(
self
)
->
float
:
"""Get the KV cache usage.
"""Get the KV cache usage.
...
@@ -161,15 +141,6 @@ class KVCacheManager:
...
@@ -161,15 +141,6 @@ class KVCacheManager:
and
request
.
sampling_params
.
prompt_logprobs
is
not
None
)):
and
request
.
sampling_params
.
prompt_logprobs
is
not
None
)):
return
self
.
create_empty_block_list
(),
0
return
self
.
create_empty_block_list
(),
0
# The block hashes for the request may already be computed
# if the scheduler has tried to schedule the request before.
block_hashes
=
self
.
req_to_block_hashes
[
request
.
request_id
]
if
not
block_hashes
:
assert
self
.
block_size
is
not
None
block_hashes
=
hash_request_tokens
(
self
.
caching_hash_fn
,
self
.
block_size
,
request
)
self
.
req_to_block_hashes
[
request
.
request_id
]
=
block_hashes
# NOTE: When all tokens hit the cache, we must recompute the last token
# NOTE: When all tokens hit the cache, we must recompute the last token
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
# to obtain logits. Thus, set max_cache_hit_length to prompt_length - 1.
# This can trigger recomputation of an entire block, rather than just
# This can trigger recomputation of an entire block, rather than just
...
@@ -178,7 +149,7 @@ class KVCacheManager:
...
@@ -178,7 +149,7 @@ class KVCacheManager:
# could slightly improve performance in the future.
# could slightly improve performance in the future.
max_cache_hit_length
=
request
.
num_tokens
-
1
max_cache_hit_length
=
request
.
num_tokens
-
1
computed_blocks
,
num_new_computed_tokens
=
(
computed_blocks
,
num_new_computed_tokens
=
(
self
.
coordinator
.
find_longest_cache_hit
(
block_hashes
,
self
.
coordinator
.
find_longest_cache_hit
(
request
.
block_hashes
,
max_cache_hit_length
))
max_cache_hit_length
))
if
self
.
log_stats
:
if
self
.
log_stats
:
...
@@ -296,11 +267,7 @@ class KVCacheManager:
...
@@ -296,11 +267,7 @@ class KVCacheManager:
# at `request.num_tokens`, ensuring only "finalized" tokens are cached.
# at `request.num_tokens`, ensuring only "finalized" tokens are cached.
num_tokens_to_cache
=
min
(
num_computed_tokens
+
num_new_tokens
,
num_tokens_to_cache
=
min
(
num_computed_tokens
+
num_new_tokens
,
request
.
num_tokens
)
request
.
num_tokens
)
self
.
coordinator
.
cache_blocks
(
self
.
coordinator
.
cache_blocks
(
request
,
num_tokens_to_cache
)
request
,
self
.
req_to_block_hashes
[
request
.
request_id
],
num_tokens_to_cache
,
)
return
KVCacheBlocks
(
new_blocks
)
return
KVCacheBlocks
(
new_blocks
)
...
@@ -373,14 +340,6 @@ class KVCacheManager:
...
@@ -373,14 +340,6 @@ class KVCacheManager:
return
self
.
coordinator
.
get_num_common_prefix_blocks
(
return
self
.
coordinator
.
get_num_common_prefix_blocks
(
request
.
request_id
,
num_running_requests
)
request
.
request_id
,
num_running_requests
)
def
free_block_hashes
(
self
,
request
:
Request
)
->
None
:
"""Discard the block hashes for the request.
NOTE: Unlike `free`, this method should be called only when the request
is finished, not when it is preempted.
"""
self
.
req_to_block_hashes
.
pop
(
request
.
request_id
,
None
)
def
take_events
(
self
)
->
list
[
KVCacheEvent
]:
def
take_events
(
self
)
->
list
[
KVCacheEvent
]:
"""Take the KV cache events from the block pool.
"""Take the KV cache events from the block pool.
...
@@ -397,9 +356,7 @@ class KVCacheManager:
...
@@ -397,9 +356,7 @@ class KVCacheManager:
def
cache_blocks
(
self
,
request
:
Request
,
num_computed_tokens
:
int
)
->
None
:
def
cache_blocks
(
self
,
request
:
Request
,
num_computed_tokens
:
int
)
->
None
:
"""Cache the blocks for the request, if enabled."""
"""Cache the blocks for the request, if enabled."""
if
self
.
enable_caching
:
if
self
.
enable_caching
:
block_hashes
=
self
.
req_to_block_hashes
[
request
.
request_id
]
self
.
coordinator
.
cache_blocks
(
request
,
num_computed_tokens
)
self
.
coordinator
.
cache_blocks
(
request
,
block_hashes
,
num_computed_tokens
)
def
create_empty_block_list
(
self
)
->
KVCacheBlocks
:
def
create_empty_block_list
(
self
)
->
KVCacheBlocks
:
"""Creates a new KVCacheBlocks instance with no blocks."""
"""Creates a new KVCacheBlocks instance with no blocks."""
...
...
vllm/v1/core/kv_cache_utils.py
View file @
c280066f
...
@@ -547,41 +547,61 @@ def hash_block_tokens(
...
@@ -547,41 +547,61 @@ def hash_block_tokens(
curr_block_token_ids_tuple
,
extra_keys
)
curr_block_token_ids_tuple
,
extra_keys
)
def
hash_request_tokens
(
hash_function
:
Any
,
block_size
:
int
,
def
get_request_block_hasher
(
request
:
Request
)
->
list
[
BlockHash
]:
block_size
:
int
,
"""Computes hash values of a chain of blocks given a sequence of
caching_hash_fn
:
Callable
[[
Any
],
token IDs. The hash value is used for prefix caching.
int
])
->
Callable
[[
Request
],
list
[
BlockHash
]]:
"""
Args:
Returns a function which computes the list of un-computed block hashes
block_size: The size of each block.
of a request.
request: The request object.
Each request holds a list of its block hashes (request.block_hashes).
Returns:
When a request is created, it calls the below function to compute
The list of computed hash values.
the hashes of all full blocks of the request's initial tokens.
The hashes are then stored in request.block_hashes.
Later, whenever new tokens are appended to the request, it calls
the below function again to compute any new full blocks of tokens.
The returned new hashes are appended to request.block_hashes.
"""
"""
token_ids
=
request
.
all_token_ids
req_need_extra_keys
=
need_extra_keys
(
request
)
def
request_block_hasher
(
request
:
Request
)
->
list
[
BlockHash
]:
req_extra_keys
=
None
start_token_idx
=
len
(
request
.
block_hashes
)
*
block_size
curr_mm_idx
=
0
num_tokens
=
request
.
num_tokens
ret
=
[]
curr_mm_idx
=
0
parent_block_hash_value
=
None
if
start_token_idx
>
0
:
# Only full blocks will be hashed
# Set curr_mm_idx = -1 to indicate the last mm input.
for
start
in
range
(
0
,
len
(
token_ids
)
-
block_size
+
1
,
block_size
):
# Note that since we reach to this branch only when the block is
end
=
start
+
block_size
# completed with generated tokens, we only need to consider the
block_token_ids
=
token_ids
[
start
:
end
]
# last mm input.
curr_mm_idx
=
-
1
prev_block_hash_value
=
request
.
block_hashes
[
-
1
].
hash_value
\
if
request
.
block_hashes
else
None
new_block_hashes
:
list
[
BlockHash
]
=
[]
while
True
:
end_token_idx
=
start_token_idx
+
block_size
if
end_token_idx
>
num_tokens
:
# We only hash full blocks
break
if
req_need_extra_keys
:
# MM and LoRA requests need extra keys for block-hash computation.
# MM and LoRA requests need extra keys for block-hash computation.
req_extra_keys
,
curr_mm_idx
=
generate_block_hash_extra_keys
(
extra_keys
,
curr_mm_idx
=
generate_block_hash_extra_keys
(
request
,
start
,
end
,
curr_mm_idx
)
request
,
start_token_idx
,
end_token_idx
,
curr_mm_idx
)
block_hash
=
hash_block_tokens
(
hash_function
,
parent_block_hash_value
,
# Compute the hash of the current block
block_token_ids
,
req_extra_keys
)
block_tokens
=
request
.
all_token_ids
[
start_token_idx
:
end_token_idx
]
ret
.
append
(
block_hash
)
block_hash
=
hash_block_tokens
(
caching_hash_fn
,
parent_block_hash_value
=
block_hash
.
hash_value
prev_block_hash_value
,
block_tokens
,
return
ret
extra_keys
)
new_block_hashes
.
append
(
block_hash
)
start_token_idx
+=
block_size
prev_block_hash_value
=
block_hash
.
hash_value
return
new_block_hashes
return
request_block_hasher
def
max_memory_usage_bytes
(
vllm_config
:
VllmConfig
,
def
max_memory_usage_bytes
(
vllm_config
:
VllmConfig
,
...
...
vllm/v1/core/sched/scheduler.py
View file @
c280066f
...
@@ -155,7 +155,6 @@ class Scheduler(SchedulerInterface):
...
@@ -155,7 +155,6 @@ class Scheduler(SchedulerInterface):
kv_cache_config
=
kv_cache_config
,
kv_cache_config
=
kv_cache_config
,
max_model_len
=
self
.
max_model_len
,
max_model_len
=
self
.
max_model_len
,
enable_caching
=
self
.
cache_config
.
enable_prefix_caching
,
enable_caching
=
self
.
cache_config
.
enable_prefix_caching
,
caching_hash_algo
=
self
.
cache_config
.
prefix_caching_hash_algo
,
use_eagle
=
self
.
use_eagle
,
use_eagle
=
self
.
use_eagle
,
log_stats
=
self
.
log_stats
,
log_stats
=
self
.
log_stats
,
enable_kv_cache_events
=
self
.
enable_kv_cache_events
,
enable_kv_cache_events
=
self
.
enable_kv_cache_events
,
...
@@ -1036,7 +1035,6 @@ class Scheduler(SchedulerInterface):
...
@@ -1036,7 +1035,6 @@ class Scheduler(SchedulerInterface):
def
_free_blocks
(
self
,
request
:
Request
):
def
_free_blocks
(
self
,
request
:
Request
):
assert
request
.
is_finished
()
assert
request
.
is_finished
()
self
.
kv_cache_manager
.
free
(
request
)
self
.
kv_cache_manager
.
free
(
request
)
self
.
kv_cache_manager
.
free_block_hashes
(
request
)
del
self
.
requests
[
request
.
request_id
]
del
self
.
requests
[
request
.
request_id
]
def
get_num_unfinished_requests
(
self
)
->
int
:
def
get_num_unfinished_requests
(
self
)
->
int
:
...
...
vllm/v1/core/single_type_kv_cache_manager.py
View file @
c280066f
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
import
itertools
import
itertools
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
Callable
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
from
vllm.v1.core.block_pool
import
BlockPool
from
vllm.v1.core.block_pool
import
BlockPool
...
@@ -25,7 +24,6 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -25,7 +24,6 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_spec
:
KVCacheSpec
,
kv_cache_spec
:
KVCacheSpec
,
block_pool
:
BlockPool
,
block_pool
:
BlockPool
,
kv_cache_group_id
:
int
,
kv_cache_group_id
:
int
,
caching_hash_fn
:
Callable
,
)
->
None
:
)
->
None
:
"""
"""
Initializes the SingleTypeKVCacheManager.
Initializes the SingleTypeKVCacheManager.
...
@@ -33,7 +31,6 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -33,7 +31,6 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_spec: The kv_cache_spec for this manager.
kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool.
block_pool: The block pool.
kv_cache_group_id: The id of the kv cache group of this manager.
kv_cache_group_id: The id of the kv cache group of this manager.
caching_hash_fn: The caching hash function.
"""
"""
self
.
block_size
=
kv_cache_spec
.
block_size
self
.
block_size
=
kv_cache_spec
.
block_size
...
@@ -52,7 +49,6 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -52,7 +49,6 @@ class SingleTypeKVCacheManager(ABC):
# data for preempted ones.
# data for preempted ones.
self
.
num_cached_block
:
dict
[
str
,
int
]
=
{}
self
.
num_cached_block
:
dict
[
str
,
int
]
=
{}
self
.
caching_hash_fn
=
caching_hash_fn
self
.
kv_cache_group_id
=
kv_cache_group_id
self
.
kv_cache_group_id
=
kv_cache_group_id
self
.
_null_block
=
block_pool
.
null_block
self
.
_null_block
=
block_pool
.
null_block
...
@@ -130,14 +126,12 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -130,14 +126,12 @@ class SingleTypeKVCacheManager(ABC):
req_blocks
.
extend
(
new_blocks
)
req_blocks
.
extend
(
new_blocks
)
return
new_blocks
return
new_blocks
def
cache_blocks
(
self
,
request
:
Request
,
block_hashes
:
list
[
BlockHash
],
def
cache_blocks
(
self
,
request
:
Request
,
num_tokens
:
int
)
->
None
:
num_tokens
:
int
)
->
None
:
"""
"""
Cache the blocks for the request.
Cache the blocks for the request.
Args:
Args:
request: The request.
request: The request.
block_hashes: The block hashes of the request.
num_tokens: The total number of tokens that need to be cached
num_tokens: The total number of tokens that need to be cached
(including tokens that are already cached).
(including tokens that are already cached).
"""
"""
...
@@ -147,12 +141,10 @@ class SingleTypeKVCacheManager(ABC):
...
@@ -147,12 +141,10 @@ class SingleTypeKVCacheManager(ABC):
self
.
block_pool
.
cache_full_blocks
(
self
.
block_pool
.
cache_full_blocks
(
request
=
request
,
request
=
request
,
blocks
=
self
.
req_to_blocks
[
request
.
request_id
],
blocks
=
self
.
req_to_blocks
[
request
.
request_id
],
block_hashes
=
block_hashes
,
num_cached_blocks
=
num_cached_blocks
,
num_cached_blocks
=
num_cached_blocks
,
num_full_blocks
=
num_full_blocks
,
num_full_blocks
=
num_full_blocks
,
block_size
=
self
.
block_size
,
block_size
=
self
.
block_size
,
kv_cache_group_id
=
self
.
kv_cache_group_id
,
kv_cache_group_id
=
self
.
kv_cache_group_id
,
hash_fn
=
self
.
caching_hash_fn
,
)
)
self
.
num_cached_block
[
request
.
request_id
]
=
num_full_blocks
self
.
num_cached_block
[
request
.
request_id
]
=
num_full_blocks
...
...
vllm/v1/engine/core.py
View file @
c280066f
...
@@ -25,9 +25,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
...
@@ -25,9 +25,11 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.tasks
import
POOLING_TASKS
,
SupportedTask
from
vllm.tasks
import
POOLING_TASKS
,
SupportedTask
from
vllm.transformers_utils.config
import
(
from
vllm.transformers_utils.config
import
(
maybe_register_config_serialize_by_value
)
maybe_register_config_serialize_by_value
)
from
vllm.utils
import
(
decorate_logs
,
make_zmq_socket
,
from
vllm.utils
import
(
decorate_logs
,
get_hash_fn_by_name
,
make_zmq_socket
,
resolve_obj_by_qualname
,
set_process_title
)
resolve_obj_by_qualname
,
set_process_title
)
from
vllm.v1.core.kv_cache_utils
import
(
get_kv_cache_config
,
from
vllm.v1.core.kv_cache_utils
import
(
BlockHash
,
get_kv_cache_config
,
get_request_block_hasher
,
init_none_hash
,
unify_kv_cache_configs
)
unify_kv_cache_configs
)
from
vllm.v1.core.sched.interface
import
SchedulerInterface
from
vllm.v1.core.sched.interface
import
SchedulerInterface
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
@@ -140,6 +142,19 @@ class EngineCore:
...
@@ -140,6 +142,19 @@ class EngineCore:
self
.
batch_queue_size
)
self
.
batch_queue_size
)
self
.
batch_queue
=
queue
.
Queue
(
self
.
batch_queue_size
)
self
.
batch_queue
=
queue
.
Queue
(
self
.
batch_queue_size
)
self
.
request_block_hasher
:
Optional
[
Callable
[[
Request
],
list
[
BlockHash
]]]
=
None
if
(
self
.
vllm_config
.
cache_config
.
enable_prefix_caching
or
self
.
scheduler
.
get_kv_connector
()
is
not
None
):
block_size
=
vllm_config
.
cache_config
.
block_size
caching_hash_fn
=
get_hash_fn_by_name
(
vllm_config
.
cache_config
.
prefix_caching_hash_algo
)
init_none_hash
(
caching_hash_fn
)
self
.
request_block_hasher
=
get_request_block_hasher
(
block_size
,
caching_hash_fn
)
def
_initialize_kv_caches
(
def
_initialize_kv_caches
(
self
,
vllm_config
:
VllmConfig
)
->
tuple
[
int
,
int
,
KVCacheConfig
]:
self
,
vllm_config
:
VllmConfig
)
->
tuple
[
int
,
int
,
KVCacheConfig
]:
start
=
time
.
time
()
start
=
time
.
time
()
...
@@ -417,7 +432,8 @@ class EngineCore:
...
@@ -417,7 +432,8 @@ class EngineCore:
request
.
mm_kwargs
=
self
.
mm_input_cache_server
.
get_and_update
(
request
.
mm_kwargs
=
self
.
mm_input_cache_server
.
get_and_update
(
request
.
mm_kwargs
,
request
.
mm_hashes
)
request
.
mm_kwargs
,
request
.
mm_hashes
)
req
=
Request
.
from_engine_core_request
(
request
)
req
=
Request
.
from_engine_core_request
(
request
,
self
.
request_block_hasher
)
if
req
.
use_structured_output
:
if
req
.
use_structured_output
:
# Note on thread safety: no race condition.
# Note on thread safety: no race condition.
# `grammar_init` is only invoked in input processing thread. For
# `grammar_init` is only invoked in input processing thread. For
...
...
vllm/v1/request.py
View file @
c280066f
...
@@ -3,7 +3,8 @@
...
@@ -3,7 +3,8 @@
import
enum
import
enum
import
time
import
time
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
,
Union
from
functools
import
partial
from
typing
import
TYPE_CHECKING
,
Any
,
Callable
,
Optional
,
Union
from
vllm.multimodal.inputs
import
MultiModalKwargsItem
,
PlaceholderRange
from
vllm.multimodal.inputs
import
MultiModalKwargsItem
,
PlaceholderRange
from
vllm.pooling_params
import
PoolingParams
from
vllm.pooling_params
import
PoolingParams
...
@@ -16,6 +17,7 @@ from vllm.v1.utils import ConstantList
...
@@ -16,6 +17,7 @@ from vllm.v1.utils import ConstantList
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.v1.core.kv_cache_utils
import
BlockHash
class
Request
:
class
Request
:
...
@@ -36,6 +38,8 @@ class Request:
...
@@ -36,6 +38,8 @@ class Request:
structured_output_request
:
Optional
[
"StructuredOutputRequest"
]
=
None
,
structured_output_request
:
Optional
[
"StructuredOutputRequest"
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
cache_salt
:
Optional
[
str
]
=
None
,
priority
:
int
=
0
,
priority
:
int
=
0
,
block_hasher
:
Optional
[
Callable
[[
"Request"
],
list
[
"BlockHash"
]]]
=
None
,
)
->
None
:
)
->
None
:
self
.
request_id
=
request_id
self
.
request_id
=
request_id
self
.
client_index
=
client_index
self
.
client_index
=
client_index
...
@@ -108,8 +112,18 @@ class Request:
...
@@ -108,8 +112,18 @@ class Request:
# indicates that the output is corrupted
# indicates that the output is corrupted
self
.
num_nans_in_logits
=
0
self
.
num_nans_in_logits
=
0
self
.
block_hashes
:
list
[
BlockHash
]
=
[]
self
.
get_hash_new_full_blocks
:
Optional
[
Callable
[
[],
list
[
BlockHash
]]]
=
None
if
block_hasher
is
not
None
:
self
.
get_hash_new_full_blocks
=
partial
(
block_hasher
,
self
)
self
.
block_hashes
=
self
.
get_hash_new_full_blocks
()
@
classmethod
@
classmethod
def
from_engine_core_request
(
cls
,
request
:
EngineCoreRequest
)
->
"Request"
:
def
from_engine_core_request
(
cls
,
request
:
EngineCoreRequest
,
block_hasher
:
Optional
[
Callable
[[
"Request"
],
list
[
"BlockHash"
]]]
)
->
"Request"
:
if
request
.
mm_kwargs
is
not
None
:
if
request
.
mm_kwargs
is
not
None
:
assert
is_list_of
(
request
.
mm_kwargs
,
MultiModalKwargsItem
),
(
assert
is_list_of
(
request
.
mm_kwargs
,
MultiModalKwargsItem
),
(
"mm_kwargs was not updated in EngineCore.add_request"
)
"mm_kwargs was not updated in EngineCore.add_request"
)
...
@@ -131,6 +145,7 @@ class Request:
...
@@ -131,6 +145,7 @@ class Request:
if
request
.
sampling_params
else
None
,
if
request
.
sampling_params
else
None
,
cache_salt
=
request
.
cache_salt
,
cache_salt
=
request
.
cache_salt
,
priority
=
request
.
priority
,
priority
=
request
.
priority
,
block_hasher
=
block_hasher
,
)
)
def
append_output_token_ids
(
def
append_output_token_ids
(
...
@@ -144,6 +159,9 @@ class Request:
...
@@ -144,6 +159,9 @@ class Request:
self
.
_output_token_ids
.
extend
(
token_ids
)
self
.
_output_token_ids
.
extend
(
token_ids
)
self
.
_all_token_ids
.
extend
(
token_ids
)
self
.
_all_token_ids
.
extend
(
token_ids
)
if
self
.
get_hash_new_full_blocks
is
not
None
:
self
.
block_hashes
.
extend
(
self
.
get_hash_new_full_blocks
())
@
property
@
property
def
is_output_corrupted
(
self
)
->
bool
:
def
is_output_corrupted
(
self
)
->
bool
:
return
self
.
num_nans_in_logits
>
0
return
self
.
num_nans_in_logits
>
0
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment