Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4634a89d
Unverified
Commit
4634a89d
authored
Nov 22, 2024
by
Ricky Xu
Committed by
GitHub
Nov 22, 2024
Browse files
Prefix Cache Aware Scheduling [1/n] (#10128)
Signed-off-by:
rickyx
<
rickyx@anyscale.com
>
parent
7c25fe45
Changes
13
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
967 additions
and
241 deletions
+967
-241
tests/core/block/test_prefix_caching_block.py
tests/core/block/test_prefix_caching_block.py
+175
-6
tests/core/test_scheduler.py
tests/core/test_scheduler.py
+175
-4
tests/core/utils.py
tests/core/utils.py
+46
-5
tests/prefix_caching/test_prefix_caching.py
tests/prefix_caching/test_prefix_caching.py
+100
-6
vllm/core/block/cpu_gpu_block_allocator.py
vllm/core/block/cpu_gpu_block_allocator.py
+7
-8
vllm/core/block/interfaces.py
vllm/core/block/interfaces.py
+21
-15
vllm/core/block/naive_block.py
vllm/core/block/naive_block.py
+4
-7
vllm/core/block/prefix_caching_block.py
vllm/core/block/prefix_caching_block.py
+162
-96
vllm/core/block_manager.py
vllm/core/block_manager.py
+14
-9
vllm/core/interfaces.py
vllm/core/interfaces.py
+4
-0
vllm/core/placeholder_block_space_manager.py
vllm/core/placeholder_block_space_manager.py
+3
-0
vllm/core/scheduler.py
vllm/core/scheduler.py
+253
-85
vllm/sequence.py
vllm/sequence.py
+3
-0
No files found.
tests/core/block/test_prefix_caching_block.py
View file @
4634a89d
...
@@ -5,9 +5,14 @@ from unittest.mock import MagicMock
...
@@ -5,9 +5,14 @@ from unittest.mock import MagicMock
import
pytest
import
pytest
from
tests.core.utils
import
create_dummy_sequence
from
vllm.core.block.cpu_gpu_block_allocator
import
CpuGpuBlockAllocator
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
from
vllm.core.block.prefix_caching_block
import
(
PrefixCachingBlock
,
from
vllm.core.block.prefix_caching_block
import
(
ComputedBlocksTracker
,
PrefixCachingBlock
,
PrefixCachingBlockAllocator
)
PrefixCachingBlockAllocator
)
from
vllm.sequence
import
Logprob
from
vllm.utils
import
Device
class
TestPrefixCachingBlock
:
class
TestPrefixCachingBlock
:
...
@@ -726,18 +731,71 @@ class TestPrefixCachingBlockAllocator:
...
@@ -726,18 +731,71 @@ class TestPrefixCachingBlockAllocator:
token_ids
=
common_token_ids
,
token_ids
=
common_token_ids
,
allocator
=
allocator
,
allocator
=
allocator
,
)
)
block_
id
s
=
[
block
.
block_id
for
block
in
blocks
]
block_
hashe
s
=
[
block
.
content_hash
for
block
in
blocks
]
# The allocated blocks should be marked as touched
# The allocated blocks should be marked as touched
# but not computed.
# but not computed.
computed_block_ids
=
allocator
.
get_computed_block_ids
(
computed_block_ids
=
allocator
.
find_cached_blocks_prefix
(
[],
block_
ids
,
skip_last_block_id
=
False
)
block_
hashes
)
assert
len
(
computed_block_ids
)
==
0
assert
len
(
computed_block_ids
)
==
0
allocator
.
mark_blocks_as_computed
([])
allocator
.
mark_blocks_as_computed
([])
computed_block_ids
=
allocator
.
get_computed_block_ids
(
computed_block_ids
=
allocator
.
find_cached_blocks_prefix
(
[],
block_
ids
,
skip_last_block_id
=
False
)
block_
hashes
=
block_hashes
)
assert
len
(
computed_block_ids
)
==
common_blocks
assert
len
(
computed_block_ids
)
==
common_blocks
@staticmethod
def test_find_cached_blocks_prefix():
    """
    This test verifies the behavior of find_cached_blocks_prefix.

    Scenario: an allocator with 8 blocks first holds an 8-block sequence,
    which is then freed; a second 4-block sequence with different tokens is
    allocated afterwards. The cached prefix reported for each sequence's
    block hashes is checked at every step.
    """
    block_size = 4
    num_blocks = 8
    total_test_blocks = 12
    allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks,
                                            block_size=block_size)

    # 12 blocks worth of distinct tokens: the first 8 blocks go to seq1,
    # the remaining 4 blocks go to seq2.
    token_ids = list(range(total_test_blocks * block_size))
    block_tokens_seq1 = token_ids[:num_blocks * block_size]
    blocks_seq1 = TestPrefixCachingBlockAllocator.create_immutable_chain(
        block_size=block_size,
        token_ids=block_tokens_seq1,
        allocator=allocator,
    )
    block_hashes_seq1 = [block.content_hash for block in blocks_seq1]
    allocator.mark_blocks_as_computed([])

    # All blocks should be cached.
    cached_blocks_seq1 = allocator.find_cached_blocks_prefix(
        block_hashes=block_hashes_seq1)
    assert len(cached_blocks_seq1) == num_blocks

    # Free the first sequence.
    for block in blocks_seq1:
        allocator.free(block)

    # All blocks should still be cached as long as none of them had to be
    # re-allocated.
    cached_blocks = allocator.find_cached_blocks_prefix(
        block_hashes=block_hashes_seq1)
    assert len(cached_blocks) == num_blocks

    block_tokens_seq2 = token_ids[num_blocks * block_size:]
    blocks_seq2 = TestPrefixCachingBlockAllocator.create_immutable_chain(
        block_size=block_size,
        token_ids=block_tokens_seq2,
        allocator=allocator,
    )
    block_hashes_seq2 = [block.content_hash for block in blocks_seq2]
    allocator.mark_blocks_as_computed([])
    cached_blocks = allocator.find_cached_blocks_prefix(
        block_hashes=block_hashes_seq2)
    assert len(cached_blocks) == len(blocks_seq2)

    # Half of the blocks from seq1 should still be cached: seq2 needed 4 of
    # the 8 blocks, so (presumably via eviction) only 8 - 4 seq1 blocks are
    # still reported as a cached prefix.
    num_evicted_blocks = len(blocks_seq2)
    cached_blocks = allocator.find_cached_blocks_prefix(
        block_hashes=block_hashes_seq1)
    assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks
@
staticmethod
@
staticmethod
def
create_immutable_chain
(
def
create_immutable_chain
(
block_size
:
int
,
block_size
:
int
,
...
@@ -762,3 +820,114 @@ class TestPrefixCachingBlockAllocator:
...
@@ -762,3 +820,114 @@ class TestPrefixCachingBlockAllocator:
blocks
.
append
(
prev_block
)
blocks
.
append
(
prev_block
)
return
blocks
return
blocks
class TestComputedBlocksTracker:
    """Tests for ComputedBlocksTracker.get_num_cached_tokens."""

    @staticmethod
    def _get_mock_allocator():
        """Return a MagicMock restricted to the PrefixCachingBlockAllocator
        spec."""
        return MagicMock(spec=PrefixCachingBlockAllocator)

    @staticmethod
    def test_get_num_cached_tokens():
        """
        Test it correctly computes the number of cached tokens for a given
        sequence:

        - The cache token count is derived from the number of cached blocks.
        - The cache token count is updated when the allocator is updated.
        - When a sequence is removed, the cache token count should be updated
          accordingly.

        # TODO(rickyx): This behaviour for prefill sequence is a hack until
        we fix the computed blocks tracking.
        - The cache token count for prefill sequence doesn't change while
          the sequence is in continuous prefill (chunked prefill).
        """
        block_size = 4
        mock_allocator = TestComputedBlocksTracker._get_mock_allocator()
        tracker = ComputedBlocksTracker(
            allocator=mock_allocator,
            block_size=block_size,
            enable_caching=True,
        )

        # Not yet allocated.
        tokens = [0, 1, 2, 3, 4, 5]
        seq1 = create_dummy_sequence(request_id=0,
                                     token_ids=tokens,
                                     block_size=block_size)
        mock_allocator.find_cached_blocks_prefix.return_value = []
        assert tracker.get_num_cached_tokens(seq1) == 0

        mock_allocator.find_cached_blocks_prefix.return_value = [
            None
        ]  # 1 block cached.
        # Result is cached for prefill sequence, so the new allocator answer
        # is not picked up yet.
        assert tracker.get_num_cached_tokens(seq1) == 0

        # Mark the sequence as non-prefill.
        seq1.data.update_num_computed_tokens(len(tokens))  # 6 tokens computed.
        assert not seq1.is_prefill()

        # Recomputes for decoding sequence: 1 cached block * block_size 4.
        assert tracker.get_num_cached_tokens(seq1) == 4

        # Append new tokens to the sequence.
        num_new_tokens = 3
        for i in range(num_new_tokens):
            seq1.append_token_id(i, {i: Logprob(logprob=0.0)})

        # The allocator still reports a single cached block.
        assert tracker.get_num_cached_tokens(seq1) == 4

        # Update the allocator.
        mock_allocator.find_cached_blocks_prefix.return_value = [
            None
        ] * 2  # 2 blocks cached.
        assert tracker.get_num_cached_tokens(seq1) == 8

        # Remove the sequence.
        tracker.remove_seq(seq1.seq_id)

        # Re-create the sequence with the same request id to simulate recompute.
        seq1 = create_dummy_sequence(request_id=0,
                                     token_ids=tokens,
                                     block_size=block_size)
        mock_allocator.find_cached_blocks_prefix.return_value = [
        ]  # no cached block
        assert tracker.get_num_cached_tokens(seq1) == 0

    @staticmethod
    def test_correct_block_hash():
        """
        Test that the block hash is correctly computed for a sequence (should
        match the underlying block allocator's block hash). So the number of
        cached tokens is correctly retrieved.
        """
        block_size = 4
        allocator = CpuGpuBlockAllocator.create(
            allocator_type="prefix_caching",
            num_gpu_blocks=16,
            num_cpu_blocks=16,
            block_size=block_size,
        )
        # The immutable chain below is allocated directly on the GPU
        # allocator, while the tracker queries the device-aware wrapper.
        gpu_allocator = allocator._allocators[Device.GPU]

        tracker = ComputedBlocksTracker(
            allocator=allocator,
            block_size=block_size,
            enable_caching=True,
        )

        tokens = list(range(block_size * 4))  # 4 blocks.
        seq = create_dummy_sequence(request_id=0,
                                    token_ids=tokens,
                                    block_size=block_size)
        _ = TestPrefixCachingBlockAllocator.create_immutable_chain(
            block_size=block_size,
            token_ids=tokens,
            allocator=gpu_allocator,
        )
        allocator.mark_blocks_as_computed([])

        # Every token of the sequence should be reported as cached.
        assert tracker.get_num_cached_tokens(seq) == len(tokens)
tests/core/test_scheduler.py
View file @
4634a89d
...
@@ -12,9 +12,9 @@ from vllm.core.scheduler import Scheduler, SchedulingBudget
...
@@ -12,9 +12,9 @@ from vllm.core.scheduler import Scheduler, SchedulingBudget
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SequenceGroup
from
vllm.sequence
import
SequenceGroup
from
.utils
import
(
append_new_token
,
append_new_token_seq
_group
,
from
.utils
import
(
append_new_token
,
append_new_token_seq
,
create_dummy_prompt
,
get_sequence_groups
,
append_new_token_seq_group
,
create_dummy_prompt
,
schedule_and_update_computed_tokens
)
get_sequence_groups
,
schedule_and_update_computed_tokens
)
def
test_scheduler_add_seq_group
():
def
test_scheduler_add_seq_group
():
...
@@ -305,6 +305,8 @@ def initialize_scheduler(
...
@@ -305,6 +305,8 @@ def initialize_scheduler(
block_size
=
4
,
block_size
=
4
,
num_cpu_blocks
=
8
,
num_cpu_blocks
=
8
,
num_gpu_blocks
=
8
,
num_gpu_blocks
=
8
,
enable_prefix_caching
=
False
,
enable_chunked_prefill
=
False
,
):
):
block_size
=
block_size
block_size
=
block_size
scheduler_config
=
SchedulerConfig
(
scheduler_config
=
SchedulerConfig
(
...
@@ -312,8 +314,15 @@ def initialize_scheduler(
...
@@ -312,8 +314,15 @@ def initialize_scheduler(
max_num_batched_tokens
=
max_token_budget
,
max_num_batched_tokens
=
max_token_budget
,
max_num_seqs
=
max_num_seqs
,
max_num_seqs
=
max_num_seqs
,
max_model_len
=
max_model_len
,
max_model_len
=
max_model_len
,
enable_chunked_prefill
=
enable_chunked_prefill
,
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
,
enable_prefix_caching
=
enable_prefix_caching
,
)
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
scheduler
=
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
)
scheduler
=
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
)
...
@@ -800,3 +809,165 @@ def test_scheduling_budget():
...
@@ -800,3 +809,165 @@ def test_scheduling_budget():
assert
budget
.
num_curr_seqs
==
0
assert
budget
.
num_curr_seqs
==
0
budget
.
subtract_num_seqs
(
seq_group
.
request_id
,
2
)
budget
.
subtract_num_seqs
(
seq_group
.
request_id
,
2
)
assert
budget
.
num_curr_seqs
==
0
assert
budget
.
num_curr_seqs
==
0
@pytest.mark.parametrize("enable_prefix_caching", [True, False])
def test_prefix_caching_aware_prefills(enable_prefix_caching):
    """
    Test the scenario below.

    Three sequences, seqA, seqB and seqC, share the first block as a prefix.

    The test verifies that:
    1. SeqA is scheduled first.
    2. SeqB and SeqC can be prefilled together in a single schedule round
       even though the token budget would not be enough to prefill both
       without taking prefix caching into account.
    """

    block_size = 4
    max_num_batched_tokens = 12
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_token_budget=max_num_batched_tokens,
        max_num_seqs=max_seq_group,
        max_model_len=max_num_batched_tokens,
        enable_prefix_caching=enable_prefix_caching,
    )

    seqA_tokens = list(range(8))
    num_shared_tokens = 4
    seqB_tokens = seqA_tokens[:num_shared_tokens] + list(range(
        12, 16))  # Shared prefix first 4.
    seqC_tokens = seqA_tokens[:num_shared_tokens] + list(range(
        16, 20))  # Shared prefix first 4.

    seqA, seqA_group = create_dummy_prompt("0",
                                           prompt_tokens=seqA_tokens,
                                           block_size=block_size)
    seqB, seqB_group = create_dummy_prompt("1",
                                           prompt_tokens=seqB_tokens,
                                           block_size=block_size)
    seqC, seqC_group = create_dummy_prompt("2",
                                           prompt_tokens=seqC_tokens,
                                           block_size=block_size)

    # Schedule seqA prefill.
    scheduler.add_seq_group(seqA_group)
    metas, out, _ = scheduler.schedule()
    assert (len(out.scheduled_seq_groups) == 1
            and out.scheduled_seq_groups[0].seq_group == seqA_group)
    assert out.scheduled_seq_groups[0].token_chunk_size == len(seqA_tokens)

    # Schedule seqA decode.
    append_new_token_seq_group(len(seqA_tokens), seqA_group, 999)
    metas, out, _ = scheduler.schedule()

    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 1

    # Scheduling the seqB and seqC prefills together should work with prefix
    # caching enabled.
    scheduler.add_seq_group(seqB_group)
    scheduler.add_seq_group(seqC_group)
    metas, out, _ = scheduler.schedule()

    if enable_prefix_caching:
        assert len(out.scheduled_seq_groups) == 2
        assert set([
            out.scheduled_seq_groups[0].seq_group,
            out.scheduled_seq_groups[1].seq_group,
        ]) == set([seqB_group, seqC_group])
        assert len(metas) == 2
        for meta in metas:
            assert meta.token_chunk_size == 8
            assert (len(meta.computed_block_nums) == num_shared_tokens //
                    block_size)  # 1 computed (shared) block of the 8 tokens.
    else:
        # Without prefix caching only one of the two prefills is scheduled.
        assert len(out.scheduled_seq_groups) == 1
        assert len(metas) == 1
        assert metas[0].token_chunk_size == 8
        assert len(metas[0].computed_block_nums) == 0  # No blocks computed.
def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
):
    """
    This test verifies that we don't schedule new prefills if there's already
    a continuous prefill in progress even though the new prefills with shared
    prefix can fit in the token budget:

    - SeqA is being chunked prefill.
    - SeqB with the same prompt shouldn't be scheduled for prefill even though
      there's enough token budget to prefill the cached tokens.
    - Neither should seqC be scheduled.

    - When seqA is in decoding phase, seqB and seqC can be scheduled.
        - Entire seqB should be prefilled since it's a full prefix cache hit.
        - SeqC would be partially prefilled with the prefix shared, and the
          remaining unique tokens would be prefilled (rounded down to be
          block-size aligned).
    """

    block_size = 2
    max_num_batched_tokens = 4
    max_seq_group = 3
    scheduler = initialize_scheduler(
        block_size=block_size,
        num_cpu_blocks=16,
        num_gpu_blocks=16,
        max_token_budget=max_num_batched_tokens,
        max_num_seqs=max_seq_group,
        max_model_len=100,
        enable_prefix_caching=True,
        enable_chunked_prefill=True,
    )

    seqA_tokens = list(range(8))
    seqB_tokens = seqA_tokens
    seqC_shared_prefix_len = 4
    seqC_tokens = seqA_tokens[:seqC_shared_prefix_len] + list(range(12, 20))

    seqA, seqA_group = create_dummy_prompt("0",
                                           prompt_tokens=seqA_tokens,
                                           block_size=block_size)
    seqB, seqB_group = create_dummy_prompt("1",
                                           prompt_tokens=seqB_tokens,
                                           block_size=block_size)

    # Chunked prefill seqA: the first chunk uses the whole 4-token budget.
    scheduler.add_seq_group(seqA_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 4

    # seqB should not be scheduled with ongoing prefills.
    scheduler.add_seq_group(seqB_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.scheduled_seq_groups[0].seq_group == seqA_group
    assert out.scheduled_seq_groups[0].token_chunk_size == 4

    # Both seqB and seqC can now be scheduled since seqA's prefill is over.
    # seqA is in decoding phase.
    append_new_token_seq(seqA, 999)
    seqC, seqC_group = create_dummy_prompt("2",
                                           prompt_tokens=seqC_tokens,
                                           block_size=block_size)
    scheduler.add_seq_group(seqC_group)
    metas, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 3

    metas = {meta.request_id: meta for meta in metas}
    assert metas[seqA_group.request_id].token_chunk_size == 1  # Decode
    assert (metas[seqB_group.request_id].token_chunk_size == 8
            )  # Fully cached prefill
    # The message is parenthesised so that all three fragments are
    # concatenated into the assertion message; without the parentheses only
    # the first fragment is attached and the others are no-op statements.
    assert metas[seqC_group.request_id].token_chunk_size == 6, (
        "A partial prefix of C (4 tokens) should be prefilled, with the "
        "remaining tokens fit into 3 token budget (4-1 from the seqA). It will "
        "then be rounded down to 2 tokens on block size, thus 6 tokens in total."
    )
tests/core/utils.py
View file @
4634a89d
import
time
import
time
from
typing
import
List
,
Optional
from
collections
import
defaultdict
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Tuple
from
typing
import
Tuple
from
vllm
import
SamplingParams
from
vllm
import
SamplingParams
from
vllm.core.scheduler
import
Scheduler
,
SchedulerOutputs
from
vllm.inputs
import
EncoderDecoderInputs
,
token_inputs
from
vllm.inputs
import
EncoderDecoderInputs
,
token_inputs
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
Logprob
,
Sequence
,
SequenceGroup
from
vllm.sequence
import
(
Logprob
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
)
def
create_dummy_prompt
(
def
create_dummy_prompt
(
request_id
:
str
,
request_id
:
str
,
prompt_length
:
int
,
prompt_length
:
int
=
-
1
,
block_size
:
Optional
[
int
]
=
None
,
block_size
:
Optional
[
int
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
best_of
:
int
=
1
,
best_of
:
int
=
1
,
...
@@ -26,6 +29,7 @@ def create_dummy_prompt(
...
@@ -26,6 +29,7 @@ def create_dummy_prompt(
# Create dummy prompt sequence with tokens 0...block_size-1
# Create dummy prompt sequence with tokens 0...block_size-1
# and prompt "0 ... block_size".
# and prompt "0 ... block_size".
prompt_tokens
=
list
(
range
(
prompt_length
))
prompt_tokens
=
list
(
range
(
prompt_length
))
prompt_str
=
" "
.
join
([
str
(
t
)
for
t
in
prompt_tokens
])
prompt_str
=
" "
.
join
([
str
(
t
)
for
t
in
prompt_tokens
])
prompt
=
Sequence
(
int
(
request_id
),
prompt
=
Sequence
(
int
(
request_id
),
inputs
=
token_inputs
(
prompt_tokens
,
prompt
=
prompt_str
),
inputs
=
token_inputs
(
prompt_tokens
,
prompt
=
prompt_str
),
...
@@ -42,6 +46,15 @@ def create_dummy_prompt(
...
@@ -42,6 +46,15 @@ def create_dummy_prompt(
return
prompt
,
seq_group
return
prompt
,
seq_group
def create_dummy_sequence(request_id: int, token_ids: List[int],
                          block_size: int) -> Sequence:
    """Build a bare Sequence for tests.

    Args:
        request_id: Used directly as the sequence's ``seq_id``.
        token_ids: Prompt token ids, wrapped with ``token_inputs``.
        block_size: Block size passed through to the Sequence.

    Returns:
        The newly constructed Sequence (not attached to any SequenceGroup).
    """
    return Sequence(
        seq_id=request_id,
        inputs=token_inputs(token_ids),
        block_size=block_size,
    )
def
create_dummy_prompt_encoder_decoder
(
def
create_dummy_prompt_encoder_decoder
(
request_id
:
str
,
request_id
:
str
,
decoder_prompt_length
:
int
,
decoder_prompt_length
:
int
,
...
@@ -194,12 +207,40 @@ def append_new_token(out, token_id: int):
...
@@ -194,12 +207,40 @@ def append_new_token(out, token_id: int):
def
schedule_and_update_computed_tokens
(
scheduler
):
def
schedule_and_update_computed_tokens
(
scheduler
):
metas
,
out
,
_
=
scheduler
.
schedule
()
metas
,
out
,
_
=
scheduler
.
schedule
()
for
s
,
meta
in
zip
(
out
.
scheduled_seq_groups
,
metas
)
:
for
s
in
out
.
scheduled_seq_groups
:
s
.
seq_group
.
update_num_computed_tokens
(
meta
.
token_chunk_size
)
s
.
seq_group
.
update_num_computed_tokens
(
s
.
token_chunk_size
)
return
metas
,
out
return
metas
,
out
def append_new_token_seq(seq: Sequence, token_id: int):
    """Append ``token_id`` to ``seq`` as a newly generated token.

    The logprobs mapping holds a single entry for ``token_id``. The token id
    is also passed as Logprob's first positional argument; presumably that is
    the log-probability value and is only a dummy here - confirm against
    Logprob's definition.
    """
    seq.append_token_id(token_id, {token_id: Logprob(token_id)})
def
append_new_token_seq_group
(
token_chunk_size
,
seq_group
,
token_id
:
int
):
def
append_new_token_seq_group
(
token_chunk_size
,
seq_group
,
token_id
:
int
):
seq_group
.
update_num_computed_tokens
(
token_chunk_size
)
seq_group
.
update_num_computed_tokens
(
token_chunk_size
)
for
seq
in
seq_group
.
get_seqs
():
for
seq
in
seq_group
.
get_seqs
():
seq
.
append_token_id
(
token_id
,
{
token_id
:
Logprob
(
token_id
)})
seq
.
append_token_id
(
token_id
,
{
token_id
:
Logprob
(
token_id
)})
class SchedulerProxy:
    """
    A proxy class to forward calls to the scheduler.

    Every call made through the proxy is recorded in ``call_history`` as an
    ``(args, kwargs, result)`` tuple under the method name, so tests can
    inspect what the wrapped scheduler returned. Non-callable attributes are
    forwarded as plain values and are not recorded.
    """

    def __init__(self, scheduler: "Scheduler"):
        self.scheduler_ = scheduler
        # Method name -> list of (args, kwargs, result), in call order.
        self.call_history: Dict[str, List[Any]] = defaultdict(list)

    def __getattr__(self, name: str) -> Any:
        """Forward ``name`` to the wrapped scheduler.

        Callables are wrapped so their invocations are recorded; any other
        attribute value is returned unchanged.

        Raises:
            AttributeError: If the wrapped scheduler has no such attribute.
        """
        # __getattr__ only runs when normal lookup fails. If the proxy's own
        # fields are missing (e.g. __init__ has not run), fail fast instead
        # of recursing through ``self.scheduler_`` below.
        if name in ("scheduler_", "call_history"):
            raise AttributeError(name)

        # Resolve eagerly so that unknown names raise AttributeError here
        # (keeping hasattr() truthful) rather than when the wrapper is called.
        target = getattr(self.scheduler_, name)
        if not callable(target):
            # Plain data attribute: hand back the value, not a wrapper.
            return target

        def wrapper(*args, **kwargs):
            result = target(*args, **kwargs)
            self.call_history[name].append((args, kwargs, result))
            return result

        return wrapper

    def last_schedule_ret(
        self, ) -> "Tuple[List[SequenceGroupMetadata], SchedulerOutputs, Any]":
        """Return the result of the most recent ``schedule()`` call.

        Raises:
            IndexError: If ``schedule()`` has not been called via this proxy.
        """
        # .get() avoids inserting an empty "schedule" entry into the
        # defaultdict as a side effect of merely asking.
        history = self.call_history.get("schedule")
        if not history:
            raise IndexError(
                "schedule() has not been called through this proxy yet")
        _, _, ret = history[-1]
        return ret
tests/prefix_caching/test_prefix_caching.py
View file @
4634a89d
...
@@ -2,10 +2,15 @@
...
@@ -2,10 +2,15 @@
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
"""
import
pytest
import
pytest
from
tests.conftest
import
VllmRunner
from
tests.core.utils
import
SchedulerProxy
,
create_dummy_prompt
from
tests.kernels.utils
import
override_backend_env_variable
from
tests.kernels.utils
import
override_backend_env_variable
from
vllm
import
SamplingParams
,
TokensPrompt
from
vllm
import
SamplingParams
,
TokensPrompt
from
vllm.core.scheduler
import
Scheduler
from
vllm.engine.llm_engine
import
LLMEngine
from
..models.utils
import
check_outputs_equal
from
..models.utils
import
check_outputs_equal
...
@@ -27,6 +32,7 @@ UNSTABLE_PROMPT_SEQUENCE = [
...
@@ -27,6 +32,7 @@ UNSTABLE_PROMPT_SEQUENCE = [
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"cached_position"
,
[
0
,
1
])
@
pytest
.
mark
.
parametrize
(
"cached_position"
,
[
0
,
1
])
@
pytest
.
mark
.
parametrize
(
"enable_chunked_prefill"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
def
test_mixed_requests
(
def
test_mixed_requests
(
hf_runner
,
hf_runner
,
...
@@ -37,6 +43,7 @@ def test_mixed_requests(
...
@@ -37,6 +43,7 @@ def test_mixed_requests(
dtype
:
str
,
dtype
:
str
,
max_tokens
:
int
,
max_tokens
:
int
,
cached_position
:
int
,
cached_position
:
int
,
enable_chunked_prefill
:
bool
,
block_size
:
int
,
block_size
:
int
,
monkeypatch
,
monkeypatch
,
)
->
None
:
)
->
None
:
...
@@ -55,6 +62,7 @@ def test_mixed_requests(
...
@@ -55,6 +62,7 @@ def test_mixed_requests(
model
,
model
,
dtype
=
dtype
,
dtype
=
dtype
,
enable_prefix_caching
=
True
,
enable_prefix_caching
=
True
,
enable_chunked_prefill
=
enable_chunked_prefill
,
block_size
=
block_size
,
block_size
=
block_size
,
)
as
vllm_model
:
)
as
vllm_model
:
# Run the first prompt so the cache is populated
# Run the first prompt so the cache is populated
...
@@ -72,13 +80,13 @@ def test_mixed_requests(
...
@@ -72,13 +80,13 @@ def test_mixed_requests(
block_size
)
*
block_size
block_size
)
*
block_size
else
:
else
:
expected_num_cached_tokens
=
0
expected_num_cached_tokens
=
0
assert
req_outputs
[
assert
(
i
].
num_cached_tokens
==
expected_num_cached_tokens
req_outputs
[
i
].
num_cached_tokens
==
expected_num_cached_tokens
)
vllm_outputs
=
[
vllm_outputs
=
[
(
(
output
.
prompt_token_ids
+
list
(
output
.
outputs
[
0
].
token_ids
),
output
.
prompt_token_ids
+
list
(
output
.
outputs
[
0
].
token_ids
),
output
.
prompt
+
output
.
outputs
[
0
].
text
)
for
output
in
req_outputs
output
.
prompt
+
output
.
outputs
[
0
].
text
,
]
)
for
output
in
req_outputs
]
check_outputs_equal
(
check_outputs_equal
(
outputs_0_lst
=
hf_outputs
,
outputs_0_lst
=
hf_outputs
,
...
@@ -105,3 +113,89 @@ def test_unstable_prompt_sequence(
...
@@ -105,3 +113,89 @@ def test_unstable_prompt_sequence(
for
prompt
in
UNSTABLE_PROMPT_SEQUENCE
:
for
prompt
in
UNSTABLE_PROMPT_SEQUENCE
:
vllm_model
.
generate
(
TokensPrompt
(
prompt_token_ids
=
prompt
),
vllm_model
.
generate
(
TokensPrompt
(
prompt_token_ids
=
prompt
),
SamplingParams
(
max_tokens
=
1
))
SamplingParams
(
max_tokens
=
1
))
@pytest.mark.parametrize("model", MODELS)
def test_fully_cached_prefill_needs_uncached_token(model):
    """Check that a fully cached prompt is not prefilled alongside another
    prefill that consumes the whole token budget.

    seqA is run to completion first. seqB (seqA's tokens shifted by one) and
    seqC (identical to seqA) are then queued together; the test expects only
    seqB to be scheduled until it finishes, after which seqC is scheduled
    with a chunk covering its whole prompt.
    """
    block_size = 16
    max_num_batched_tokens = 16
    num_output_tokens = 5
    # Make a vllm engine
    runner = VllmRunner(
        model_name=model,
        gpu_memory_utilization=0.7,
        enable_chunked_prefill=True,
        enforce_eager=True,
        enable_prefix_caching=True,
        block_size=block_size,
        max_num_batched_tokens=max_num_batched_tokens,
        max_num_seqs=max_num_batched_tokens,
    )
    engine: LLMEngine = runner.model.llm_engine

    # Swap in a recording proxy so the result of each schedule() call made
    # by engine.step() can be inspected afterwards.
    scheduler: Scheduler = SchedulerProxy(engine.scheduler[0])  # type: ignore
    engine.scheduler[0] = scheduler

    # SeqA
    seqA_tokens = list(range(2 * block_size))
    seqA, seq_groupA = create_dummy_prompt(
        request_id="0",
        prompt_tokens=seqA_tokens,
        max_tokens=num_output_tokens,
        block_size=block_size,
    )

    scheduler.add_seq_group(seq_groupA)

    assert seqA.data.get_num_computed_tokens() == 0

    # Run seqA until it is finished.
    while not seqA.is_finished():
        engine.step()

    # seqB
    seqB_tokens = [t + 1 for t in seqA_tokens]  # shift by 1
    seqB, seq_groupB = create_dummy_prompt(
        request_id="1",
        prompt_tokens=seqB_tokens,
        max_tokens=num_output_tokens,
        block_size=block_size,
    )

    # seqC is the same as seqA
    seqC, seq_groupC = create_dummy_prompt(
        request_id="2",
        prompt_tokens=seqA_tokens,
        max_tokens=num_output_tokens,
        block_size=block_size,
    )

    scheduler.add_seq_group(seq_groupB)
    scheduler.add_seq_group(seq_groupC)

    # Even though seqC is fully cached, it should not be prefilled since we
    # require at least 1 uncached token.
    engine.step()

    sched_metas, sched_out, _ = scheduler.last_schedule_ret()
    assert len(sched_out.scheduled_seq_groups) == 1
    assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
            seq_groupB.request_id)
    assert (sched_out.scheduled_seq_groups[0].token_chunk_size ==
            max_num_batched_tokens)

    # When seqB is finished, seqC could be prefilled.
    # NOTE(review): the extent of this loop body is reconstructed from a
    # rendering that lost indentation - confirm against the original file.
    while not seqB.is_finished():
        engine.step()
        sched_metas, sched_out, _ = scheduler.last_schedule_ret()
        assert len(sched_out.scheduled_seq_groups) == 1
        assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
                seq_groupB.request_id)

    engine.step()
    sched_metas, sched_out, _ = scheduler.last_schedule_ret()
    assert len(sched_out.scheduled_seq_groups) == 1
    assert (sched_out.scheduled_seq_groups[0].seq_group.request_id ==
            seq_groupC.request_id)
    assert sched_out.scheduled_seq_groups[0].token_chunk_size == len(
        seqA_tokens)
vllm/core/block/cpu_gpu_block_allocator.py
View file @
4634a89d
...
@@ -306,14 +306,6 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -306,14 +306,6 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
device
=
Device
.
GPU
device
=
Device
.
GPU
return
self
.
_allocators
[
device
].
mark_blocks_as_computed
(
block_ids
)
return
self
.
_allocators
[
device
].
mark_blocks_as_computed
(
block_ids
)
def
get_computed_block_ids
(
self
,
prev_computed_block_ids
:
List
[
int
],
block_ids
:
List
[
int
],
skip_last_block_id
:
bool
)
->
List
[
int
]:
# Prefix caching only supported on GPU.
device
=
Device
.
GPU
return
self
.
_allocators
[
device
].
get_computed_block_ids
(
prev_computed_block_ids
,
block_ids
,
skip_last_block_id
)
def
get_common_computed_block_ids
(
def
get_common_computed_block_ids
(
self
,
computed_seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
computed_seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
# Prefix caching only supported on GPU.
# Prefix caching only supported on GPU.
...
@@ -342,6 +334,13 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
...
@@ -342,6 +334,13 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
self
.
_swap_mapping
.
clear
()
self
.
_swap_mapping
.
clear
()
return
list
(
mapping
.
items
())
return
list
(
mapping
.
items
())
def find_cached_blocks_prefix(
    self,
    block_hashes: List[int],
    device: Device = Device.GPU,
) -> List[int]:
    """Forward the cached-prefix lookup to the allocator of ``device``.

    Args:
        block_hashes: Block content hashes to look up.
        device: Which per-device allocator to query; defaults to the GPU
            allocator.

    Returns:
        Whatever the selected allocator's ``find_cached_blocks_prefix``
        returns for ``block_hashes``.
    """
    return self._allocators[device].find_cached_blocks_prefix(block_hashes)
class
NullBlock
(
Block
):
class
NullBlock
(
Block
):
"""
"""
...
...
vllm/core/block/interfaces.py
View file @
4634a89d
...
@@ -159,12 +159,6 @@ class BlockAllocator(ABC):
...
@@ -159,12 +159,6 @@ class BlockAllocator(ABC):
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
pass
pass
@
abstractmethod
def
get_computed_block_ids
(
self
,
prev_computed_block_ids
:
List
[
int
],
block_ids
:
List
[
int
],
skip_last_block_id
:
bool
)
->
List
[
int
]:
pass
@
abstractmethod
@
abstractmethod
def
get_common_computed_block_ids
(
def
get_common_computed_block_ids
(
self
,
computed_seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
computed_seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
...
@@ -192,6 +186,13 @@ class BlockAllocator(ABC):
...
@@ -192,6 +186,13 @@ class BlockAllocator(ABC):
class
NoFreeBlocksError
(
ValueError
):
class
NoFreeBlocksError
(
ValueError
):
pass
pass
@abstractmethod
def find_cached_blocks_prefix(
    self,
    block_hashes: List[int],
) -> List[int]:
    """Look up which of ``block_hashes`` are cached by this allocator.

    NOTE(review): judging by the name, implementations presumably return
    the leading run (prefix) of ``block_hashes`` that is cached - confirm
    against the concrete allocators.
    """
    pass
class
DeviceAwareBlockAllocator
(
ABC
):
class
DeviceAwareBlockAllocator
(
ABC
):
...
@@ -207,9 +208,12 @@ class DeviceAwareBlockAllocator(ABC):
...
@@ -207,9 +208,12 @@ class DeviceAwareBlockAllocator(ABC):
pass
pass
@
abstractmethod
@
abstractmethod
def
allocate_immutable_blocks
(
self
,
prev_block
:
Optional
[
Block
],
def
allocate_immutable_blocks
(
block_token_ids
:
List
[
List
[
int
]],
self
,
device
:
Device
)
->
List
[
Block
]:
prev_block
:
Optional
[
Block
],
block_token_ids
:
List
[
List
[
int
]],
device
:
Device
,
)
->
List
[
Block
]:
pass
pass
@
abstractmethod
@
abstractmethod
...
@@ -246,12 +250,6 @@ class DeviceAwareBlockAllocator(ABC):
...
@@ -246,12 +250,6 @@ class DeviceAwareBlockAllocator(ABC):
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
pass
pass
@
abstractmethod
def
get_computed_block_ids
(
self
,
prev_computed_block_ids
:
List
[
int
],
block_ids
:
List
[
int
],
skip_last_block_id
:
bool
)
->
List
[
int
]:
pass
@
abstractmethod
@
abstractmethod
def
get_common_computed_block_ids
(
def
get_common_computed_block_ids
(
self
,
computed_seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
computed_seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
...
@@ -284,3 +282,11 @@ class DeviceAwareBlockAllocator(ABC):
...
@@ -284,3 +282,11 @@ class DeviceAwareBlockAllocator(ABC):
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
"""Prefix cache hit rate. -1 means not supported or disabled."""
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass
pass
@abstractmethod
def find_cached_blocks_prefix(
    self,
    block_hashes: List[int],
    device: Device = Device.GPU,
) -> List[int]:
    """Look up which of ``block_hashes`` are cached on ``device``.

    ``device`` defaults to the GPU.

    NOTE(review): judging by the name, implementations presumably return
    the leading run (prefix) of ``block_hashes`` that is cached - confirm
    against the concrete allocators.
    """
    pass
vllm/core/block/naive_block.py
View file @
4634a89d
...
@@ -262,13 +262,6 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -262,13 +262,6 @@ class NaiveBlockAllocator(BlockAllocator):
"""
"""
pass
pass
def
get_computed_block_ids
(
self
,
prev_computed_block_ids
:
List
[
int
],
block_ids
:
List
[
int
],
skip_last_block_id
:
bool
)
->
List
[
int
]:
"""No prefix caching here => return empty list
"""
return
[]
def
get_common_computed_block_ids
(
def
get_common_computed_block_ids
(
self
,
computed_seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
computed_seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
"""Determine blocks that can be skipped in prefill.
"""Determine blocks that can be skipped in prefill.
...
@@ -329,6 +322,10 @@ class NaiveBlockAllocator(BlockAllocator):
...
@@ -329,6 +322,10 @@ class NaiveBlockAllocator(BlockAllocator):
def
get_prefix_cache_hit_rate
(
self
)
->
float
:
def
get_prefix_cache_hit_rate
(
self
)
->
float
:
return
-
1
return
-
1
def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
    """Return the cached prefix of `block_hashes`, which is always empty.

    Prefix lookup is not applicable for the naive block allocator, so the
    argument is ignored and a fresh empty list is handed back.

    Args:
        block_hashes (List[int]): Block hashes to look up (unused).

    Returns:
        List[int]: An empty list.
    """
    no_cached_hashes: List[int] = []
    return no_cached_hashes
class
NaiveBlock
(
Block
):
class
NaiveBlock
(
Block
):
"""An implementation of the Block class that does not support prefix
"""An implementation of the Block class that does not support prefix
...
...
vllm/core/block/prefix_caching_block.py
View file @
4634a89d
"""Token blocks."""
"""Token blocks."""
import
sys
from
bisect
import
bisect_left
from
os.path
import
commonprefix
from
os.path
import
commonprefix
from
typing
import
Dict
,
FrozenSet
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
(
Callable
,
Dict
,
FrozenSet
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
)
from
vllm.core.block.common
import
(
CacheMetricData
,
CopyOnWriteTracker
,
from
vllm.core.block.common
import
(
CacheMetricData
,
CopyOnWriteTracker
,
get_all_blocks_recursively
)
get_all_blocks_recursively
)
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
,
BlockId
,
Device
from
vllm.core.block.interfaces
import
(
Block
,
BlockAllocator
,
BlockId
,
Device
,
DeviceAwareBlockAllocator
)
from
vllm.core.block.naive_block
import
(
BlockPool
,
NaiveBlock
,
from
vllm.core.block.naive_block
import
(
BlockPool
,
NaiveBlock
,
NaiveBlockAllocator
)
NaiveBlockAllocator
)
from
vllm.core.evictor
import
EvictionPolicy
,
Evictor
,
make_evictor
from
vllm.core.evictor
import
EvictionPolicy
,
Evictor
,
make_evictor
from
vllm.sequence
import
Sequence
PrefixHash
=
int
PrefixHash
=
int
...
@@ -534,26 +539,6 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -534,26 +539,6 @@ class PrefixCachingBlockAllocator(BlockAllocator):
else
:
else
:
return
block_id
in
self
.
evictor
return
block_id
in
self
.
evictor
def
get_computed_block_ids
(
self
,
prev_computed_block_ids
:
List
[
int
],
block_ids
:
List
[
int
],
skip_last_block_id
:
bool
=
True
)
->
List
[
int
]:
prev_prefix_size
=
len
(
prev_computed_block_ids
)
cur_size
=
len
(
block_ids
)
if
skip_last_block_id
:
cur_size
-=
1
# Sanity checks
assert
cur_size
>=
0
assert
prev_prefix_size
<=
cur_size
ret
=
prev_computed_block_ids
for
i
in
range
(
prev_prefix_size
,
cur_size
):
block_id
=
block_ids
[
i
]
if
self
.
block_is_computed
(
block_id
):
ret
.
append
(
block_id
)
return
ret
def
get_common_computed_block_ids
(
def
get_common_computed_block_ids
(
self
,
computed_seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
self
,
computed_seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
"""Return the block ids that are common for a given sequence group.
"""Return the block ids that are common for a given sequence group.
...
@@ -634,6 +619,47 @@ class PrefixCachingBlockAllocator(BlockAllocator):
...
@@ -634,6 +619,47 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block
.
block_id
=
block_id
# Assign block_id
block
.
block_id
=
block_id
# Assign block_id
def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]:
    """
    Given a list of block hashes, return the prefix of the block hashes that
    are all cached.

    Since a block's block hash includes the hashes of all previous blocks,
    and we only allocate/deallocate blocks in the entire sequence, if a
    block is cached, then all previous blocks are also cached. With this
    property, we can use binary search to find the prefix of cached blocks.

    Args:
        block_hashes (List[int]): The list of block hashes.

    Returns:
        List[int]: The prefix of the `block_hashes` that are cached.
    """

    def _block_is_cached(block_hash: int) -> bool:
        if block_hash not in self._cached_blocks:
            return False

        cached_block_id = self._cached_blocks[block_hash]
        # We only consider the blocks that are marked as computed.
        return self.block_is_computed(cached_block_id)

    # Bisect by hand instead of calling `bisect_left(..., key=...)`: the
    # `key` argument only exists on Python >= 3.10, and emulating it by
    # pre-computing the key of every element would probe all N hashes,
    # defeating the purpose of the binary search. This loop probes only
    # O(log N) hashes on every interpreter version.
    lo, hi = 0, len(block_hashes)
    while lo < hi:
        mid = (lo + hi) // 2
        if _block_is_cached(block_hashes[mid]):
            # Everything up to and including `mid` is cached.
            lo = mid + 1
        else:
            # The first uncached block is at `mid` or earlier.
            hi = mid

    # `lo` is the index of the first block that's not cached, so the slice
    # before it is exactly the cached prefix.
    return block_hashes[:lo]
class
PrefixCachingBlock
(
Block
):
class
PrefixCachingBlock
(
Block
):
"""A block implementation that supports prefix caching.
"""A block implementation that supports prefix caching.
...
@@ -843,86 +869,126 @@ class PrefixCachingBlock(Block):
...
@@ -843,86 +869,126 @@ class PrefixCachingBlock(Block):
class
ComputedBlocksTracker
:
class
ComputedBlocksTracker
:
"""Handles caching of per-sequence computed block ids.
When a sequence appears for the first time, it traverses all of the
blocks and detects the prefix of blocks that is computed. On the
subsequent times, it only traverses the new blocks that were added
and updates the already recorded prefix of blocks with the newly
computed blocks.
To avoid redundant traversals, the algorithm also detects when there
is a "gap" in the computed prefix. For example, if we have blocks =
[1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then
we won't try to add more computed blocks to [1,2,3] in this sequence
iteration, and will add more computed blocks only after the sequence is
freed and reused again.
Note that currently, for a given sequence, we also skip the last
block id for caching purposes, to avoid caching of a full sequence
"""
"""
Tracks the computed blocks for each sequence.
def
__init__
(
self
,
allocator
):
Internally, it maintains a map from sequence id to the list of block hashes
self
.
_allocator
=
allocator
for the sequence. We cache the hashes of the full blocks for each sequence,
self
.
_cached_computed_seq_blocks
:
Dict
[
int
,
Tuple
[
List
[
int
],
and make sure the hash is calculated in the same way as the allocator.
bool
]]
=
{}
When a sequence is being decoded, we also update the sequence's hash
accordingly and incrementally.
def
add_seq
(
self
,
seq_id
:
int
)
->
None
:
From the sequence hash, with prefix caching enabled, we could also calculate
"""Start tracking seq_id
the number of cached tokens for the sequence by looking up the number of
"""
cached block hashes in the allocator.
assert
seq_id
not
in
self
.
_cached_computed_seq_blocks
"""
self
.
_cached_computed_seq_blocks
[
seq_id
]
=
([],
False
)
def
remove_seq
(
self
,
seq_id
:
int
)
->
None
:
"""Stop tracking seq_id
"""
assert
seq_id
in
self
.
_cached_computed_seq_blocks
del
self
.
_cached_computed_seq_blocks
[
seq_id
]
def
get_cached_computed_blocks_and_update
(
self
,
seq_id
:
int
,
block_ids
:
List
[
int
])
->
List
[
int
]:
""" Look at the class documentation for details
"""
# Ensure seq_id is already tracked
assert
seq_id
in
self
.
_cached_computed_seq_blocks
# Get cached data (may be empty on the first time)
prev_computed_block_ids
,
has_gap
=
self
.
_cached_computed_seq_blocks
[
seq_id
]
if
has_gap
:
# When gap is detected, we do not add more computed blocks at this
# sequence iteration
return
prev_computed_block_ids
# We do not consider the last block id for caching purposes.
num_cur_blocks
=
len
(
block_ids
)
-
1
assert
num_cur_blocks
>=
0
if
len
(
prev_computed_block_ids
)
>=
num_cur_blocks
:
# Cache HIT
assert
len
(
prev_computed_block_ids
)
==
num_cur_blocks
return
prev_computed_block_ids
# If here, then we may possibly add more computed blocks. As a result,
# traverse the additional blocks after prev_computed_block_ids to
# detect more computed blocks and add them.
# Incremental init for seq_id => Look only at the new blocks
computed_block_ids
=
self
.
_allocator
.
get_computed_block_ids
(
# noqa: E501
prev_computed_block_ids
,
block_ids
,
skip_last_block_id
=
True
,
# We skip last block id to avoid caching of full seq
)
# Detect if there is a "gap"
def
__init__
(
has_gap
=
len
(
computed_block_ids
)
<
num_cur_blocks
self
,
allocator
:
DeviceAwareBlockAllocator
,
block_size
:
int
,
enable_caching
:
bool
,
):
self
.
_allocator
=
allocator
self
.
_block_size
=
block_size
self
.
_enable_caching
=
enable_caching
# A map from seq_id to the list of block hashes for the
# sequence. This is so that we don't have to recompute the block hashes
# for the sequence when we need to check if the sequence is cached.
# Note a block that's not full will not have its hash calculated and
# recorded.
self
.
_seq_id_to_blocks_hashes
:
Dict
[
int
,
List
[
int
]]
=
{}
# A map from seq_id to the number of tokens that are cached for the
# sequence.
# We need this so that a sequence in continuous prefill doesn't
# accidentally see its cached token count change. See comments in
# `get_num_cached_tokens` for more details.
self
.
_seq_id_to_num_tokens_computed
:
Dict
[
int
,
int
]
=
{}
def _update_seq_hashes(self, seq: Sequence) -> None:
    """Extend the recorded block hashes of `seq` with its newly full blocks.

    Hashes already recorded for the sequence are kept as-is; only blocks
    that became completely filled since the last call are hashed and
    appended. Each new hash is chained to the hash of the block before it.

    Args:
        seq (Sequence): The sequence whose block hashes should be updated.
    """
    assert self._enable_caching

    seq_id = seq.seq_id
    block_size = self._block_size
    recorded_hashes = self._seq_id_to_blocks_hashes.get(seq_id, [])
    num_recorded = len(recorded_hashes)
    token_ids = seq.get_token_ids()
    assert len(token_ids) >= num_recorded * block_size, (
        f"The sequence has {len(token_ids)} tokens, but"
        f" already recorded {num_recorded} blocks. "
        "This should not happen since we assume blocks are "
        "only appended other than recomputation. When the sequence is "
        "recomputed, we should have removed the info of the old blocks.")

    # Partially filled blocks are never hashed, hence the floor division.
    num_full_blocks = len(token_ids) // block_size

    # The chain starts from the last recorded hash, or from None when
    # nothing has been recorded yet (i.e. the next block is the first one).
    last_hash = recorded_hashes[-1] if recorded_hashes else None

    for block_idx in range(num_recorded, num_full_blocks):
        start = block_idx * block_size
        end = start + block_size
        assert len(token_ids) >= end
        # This has to be kept in sync with the allocator's hash
        # calculation.
        last_hash = PrefixCachingBlock.hash_block_tokens(
            is_first_block=last_hash is None,
            prev_block_hash=last_hash,
            cur_block_token_ids=token_ids[start:end],
        )
        recorded_hashes.append(last_hash)

    self._seq_id_to_blocks_hashes[seq_id] = recorded_hashes
def get_num_cached_tokens(self, seq: Sequence) -> int:
    """Return the number of tokens of `seq` that sit in cached blocks.

    Always 0 when caching is disabled. Otherwise the sequence's block
    hashes are refreshed, the allocator is asked for the cached prefix of
    those hashes, and the prefix length is converted to a token count.
    The result is remembered per sequence.

    Args:
        seq (Sequence): The sequence to query.

    Returns:
        int: Number of cached tokens, a multiple of the block size.
    """
    if not self._enable_caching:
        return 0

    # Refresh the hashes on every call so that blocks filled since the
    # previous call (e.g. during decode) are taken into account.
    self._update_seq_hashes(seq)

    seq_id = seq.seq_id
    previously_reported = self._seq_id_to_num_tokens_computed.get(seq_id)

    # TODO(rickyx): This hack could be removed once we mark blocks as
    # computed correctly with chunked prefills.
    if previously_reported is not None and seq.is_prefill():
        # While a sequence is still in prefill we keep returning the first
        # answer: blocks are currently marked as computed even when the
        # sequence is only partially prefilled, so recomputing here would
        # make the cached token count change under a running sequence.
        return previously_reported

    block_hashes = self._seq_id_to_blocks_hashes[seq_id]
    cached_prefix = self._allocator.find_cached_blocks_prefix(block_hashes)
    num_cached_tokens = len(cached_prefix) * self._block_size

    self._seq_id_to_num_tokens_computed[seq_id] = num_cached_tokens
    return num_cached_tokens
# Record
def
remove_seq
(
self
,
seq_id
:
int
)
->
None
:
self
.
_cached_computed_seq_blocks
[
seq_id
]
=
(
computed_block_ids
,
"""Stop tracking the sequence."""
has_gap
)
if
not
self
.
_enable_caching
:
return
assert
seq_id
in
self
.
_seq_id_to_blocks_hashes
del
self
.
_seq_id_to_blocks_hashes
[
seq_id
]
return
computed_block_ids
assert
seq_id
in
self
.
_seq_id_to_num_tokens_computed
del
self
.
_seq_id_to_num_tokens_computed
[
seq_id
]
class
LastAccessBlocksTracker
:
class
LastAccessBlocksTracker
:
...
...
vllm/core/block_manager.py
View file @
4634a89d
...
@@ -101,7 +101,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
...
@@ -101,7 +101,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self
.
cross_block_tables
:
Dict
[
EncoderSeqId
,
BlockTable
]
=
{}
self
.
cross_block_tables
:
Dict
[
EncoderSeqId
,
BlockTable
]
=
{}
self
.
_computed_blocks_tracker
=
ComputedBlocksTracker
(
self
.
_computed_blocks_tracker
=
ComputedBlocksTracker
(
self
.
block_allocator
)
self
.
block_allocator
,
self
.
block_size
,
self
.
enable_caching
)
self
.
_last_access_blocks_tracker
=
LastAccessBlocksTracker
(
self
.
_last_access_blocks_tracker
=
LastAccessBlocksTracker
(
self
.
block_allocator
)
self
.
block_allocator
)
...
@@ -170,7 +170,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
...
@@ -170,7 +170,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
# Track seq
# Track seq
self
.
_computed_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
self
.
_last_access_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
self
.
_last_access_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
# Assign the block table for each sequence.
# Assign the block table for each sequence.
...
@@ -178,7 +177,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
...
@@ -178,7 +177,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
.
fork
()
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
.
fork
()
# Track seq
# Track seq
self
.
_computed_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
self
.
_last_access_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
self
.
_last_access_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
# Allocate cross-attention block table for encoder sequence
# Allocate cross-attention block table for encoder sequence
...
@@ -314,11 +312,13 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
...
@@ -314,11 +312,13 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
"""
"""
computed_seq_block_ids
=
[]
computed_seq_block_ids
=
[]
for
seq
in
seqs
:
for
seq
in
seqs
:
computed_seq_block_ids
.
append
(
all_blocks
=
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
self
.
_computed_blocks_tracker
.
num_cached_tokens
=
(
get_cached_computed_blocks_and_update
(
self
.
_computed_blocks_tracker
.
get_num_cached_tokens
(
seq
))
seq
.
seq_id
,
assert
num_cached_tokens
%
self
.
block_size
==
0
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
))
num_cached_blocks
=
num_cached_tokens
//
self
.
block_size
computed_block_ids
=
all_blocks
[:
num_cached_blocks
]
computed_seq_block_ids
.
append
(
computed_block_ids
)
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return
self
.
block_allocator
.
get_common_computed_block_ids
(
return
self
.
block_allocator
.
get_common_computed_block_ids
(
...
@@ -332,7 +332,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
...
@@ -332,7 +332,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self
.
block_tables
[
child_seq
.
seq_id
]
=
src_block_table
.
fork
()
self
.
block_tables
[
child_seq
.
seq_id
]
=
src_block_table
.
fork
()
# Track child seq
# Track child seq
self
.
_computed_blocks_tracker
.
add_seq
(
child_seq
.
seq_id
)
self
.
_last_access_blocks_tracker
.
add_seq
(
child_seq
.
seq_id
)
self
.
_last_access_blocks_tracker
.
add_seq
(
child_seq
.
seq_id
)
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
...
@@ -503,3 +502,9 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
...
@@ -503,3 +502,9 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
return
AllocStatus
.
OK
return
AllocStatus
.
OK
else
:
else
:
return
AllocStatus
.
LATER
return
AllocStatus
.
LATER
def get_num_cached_tokens(self, seq: Sequence) -> int:
    """Return how many tokens of `seq` are in already computed, cached blocks.

    The question is forwarded unchanged to the computed-blocks tracker of
    this block manager.

    Args:
        seq (Sequence): The sequence to query.

    Returns:
        int: The tracker's cached token count for the sequence.
    """
    tracker = self._computed_blocks_tracker
    return tracker.get_num_cached_tokens(seq)
vllm/core/interfaces.py
View file @
4634a89d
...
@@ -121,3 +121,7 @@ class BlockSpaceManager(ABC):
...
@@ -121,3 +121,7 @@ class BlockSpaceManager(ABC):
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
"""Prefix cache hit rate. -1 means not supported or disabled."""
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass
pass
@abstractmethod
def get_num_cached_tokens(self, seq: Sequence) -> int:
    """Abstract query for the number of cached tokens of `seq`.

    Args:
        seq (Sequence): The sequence to query.

    Returns:
        int: Defined by the implementing block space manager.
    """
    ...
vllm/core/placeholder_block_space_manager.py
View file @
4634a89d
...
@@ -89,3 +89,6 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
...
@@ -89,3 +89,6 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
return
-
1
return
-
1
def get_num_cached_tokens(self, seq: Sequence) -> int:
    """Report the number of cached tokens for `seq`, which is always zero.

    Args:
        seq (Sequence): The sequence to query (unused).

    Returns:
        int: Always 0.
    """
    num_cached_tokens = 0
    return num_cached_tokens
vllm/core/scheduler.py
View file @
4634a89d
This diff is collapsed.
Click to expand it.
vllm/sequence.py
View file @
4634a89d
...
@@ -579,6 +579,9 @@ class Sequence:
...
@@ -579,6 +579,9 @@ class Sequence:
return
1
return
1
return
self
.
data
.
get_num_uncomputed_tokens
()
return
self
.
data
.
get_num_uncomputed_tokens
()
def get_num_computed_tokens(self) -> int:
    """Return the computed-token count reported by this sequence's data."""
    sequence_data = self.data
    return sequence_data.get_num_computed_tokens()
def
is_prefill
(
self
)
->
bool
:
def
is_prefill
(
self
)
->
bool
:
return
self
.
data
.
stage
==
SequenceStage
.
PREFILL
return
self
.
data
.
stage
==
SequenceStage
.
PREFILL
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment