Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4634a89d
Unverified
Commit
4634a89d
authored
Nov 22, 2024
by
Ricky Xu
Committed by
GitHub
Nov 22, 2024
Browse files
Prefix Cache Aware Scheduling [1/n] (#10128)
Signed-off-by:
rickyx
<
rickyx@anyscale.com
>
parent
7c25fe45
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
967 additions
and
241 deletions
+967
-241
tests/core/block/test_prefix_caching_block.py
tests/core/block/test_prefix_caching_block.py
+175
-6
tests/core/test_scheduler.py
tests/core/test_scheduler.py
+175
-4
tests/core/utils.py
tests/core/utils.py
+46
-5
tests/prefix_caching/test_prefix_caching.py
tests/prefix_caching/test_prefix_caching.py
+100
-6
vllm/core/block/cpu_gpu_block_allocator.py
vllm/core/block/cpu_gpu_block_allocator.py
+7
-8
vllm/core/block/interfaces.py
vllm/core/block/interfaces.py
+21
-15
vllm/core/block/naive_block.py
vllm/core/block/naive_block.py
+4
-7
vllm/core/block/prefix_caching_block.py
vllm/core/block/prefix_caching_block.py
+162
-96
vllm/core/block_manager.py
vllm/core/block_manager.py
+14
-9
vllm/core/interfaces.py
vllm/core/interfaces.py
+4
-0
vllm/core/placeholder_block_space_manager.py
vllm/core/placeholder_block_space_manager.py
+3
-0
vllm/core/scheduler.py
vllm/core/scheduler.py
+253
-85
vllm/sequence.py
vllm/sequence.py
+3
-0
No files found.
tests/core/block/test_prefix_caching_block.py
View file @
4634a89d
...
...
@@ -5,9 +5,14 @@ from unittest.mock import MagicMock
import
pytest
from
tests.core.utils
import
create_dummy_sequence
from
vllm.core.block.cpu_gpu_block_allocator
import
CpuGpuBlockAllocator
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
from
vllm.core.block.prefix_caching_block
import
(
PrefixCachingBlock
,
from
vllm.core.block.prefix_caching_block
import
(
ComputedBlocksTracker
,
PrefixCachingBlock
,
PrefixCachingBlockAllocator
)
from
vllm.sequence
import
Logprob
from
vllm.utils
import
Device
class
TestPrefixCachingBlock
:
...
...
@@ -726,18 +731,71 @@ class TestPrefixCachingBlockAllocator:
token_ids
=
common_token_ids
,
allocator
=
allocator
,
)
block_
id
s
=
[
block
.
block_id
for
block
in
blocks
]
block_
hashe
s
=
[
block
.
content_hash
for
block
in
blocks
]
# The allocated blocks should be marked as touched
# but not computed.
computed_block_ids
=
allocator
.
get_computed_block_ids
(
[],
block_
ids
,
skip_last_block_id
=
False
)
computed_block_ids
=
allocator
.
find_cached_blocks_prefix
(
block_
hashes
)
assert
len
(
computed_block_ids
)
==
0
allocator
.
mark_blocks_as_computed
([])
computed_block_ids
=
allocator
.
get_computed_block_ids
(
[],
block_
ids
,
skip_last_block_id
=
False
)
computed_block_ids
=
allocator
.
find_cached_blocks_prefix
(
block_
hashes
=
block_hashes
)
assert
len
(
computed_block_ids
)
==
common_blocks
@
staticmethod
def
test_find_cached_blocks_prefix
():
"""
This test verifies the behavior of find_cached_blocks_prefix.
"""
block_size
=
4
num_blocks
=
8
total_test_blocks
=
12
allocator
=
PrefixCachingBlockAllocator
(
num_blocks
=
num_blocks
,
block_size
=
block_size
)
token_ids
=
list
(
range
(
total_test_blocks
*
block_size
))
block_tokens_seq1
=
token_ids
[:
num_blocks
*
block_size
]
blocks_seq1
=
TestPrefixCachingBlockAllocator
.
create_immutable_chain
(
block_size
=
block_size
,
token_ids
=
block_tokens_seq1
,
allocator
=
allocator
,
)
block_hashes_seq1
=
[
block
.
content_hash
for
block
in
blocks_seq1
]
allocator
.
mark_blocks_as_computed
([])
# All blocks should be cached.
cached_blocks_seq1
=
allocator
.
find_cached_blocks_prefix
(
block_hashes
=
block_hashes_seq1
)
assert
len
(
cached_blocks_seq1
)
==
num_blocks
# Free the first sequence.
for
block
in
blocks_seq1
:
allocator
.
free
(
block
)
# All blocks should be still be cached if not required to be allocated.
cached_blocks
=
allocator
.
find_cached_blocks_prefix
(
block_hashes
=
block_hashes_seq1
)
assert
len
(
cached_blocks
)
==
num_blocks
block_tokens_seq2
=
token_ids
[
num_blocks
*
block_size
:]
blocks_seq2
=
TestPrefixCachingBlockAllocator
.
create_immutable_chain
(
block_size
=
block_size
,
token_ids
=
block_tokens_seq2
,
allocator
=
allocator
,
)
block_hashes_seq2
=
[
block
.
content_hash
for
block
in
blocks_seq2
]
allocator
.
mark_blocks_as_computed
([])
cached_blocks
=
allocator
.
find_cached_blocks_prefix
(
block_hashes
=
block_hashes_seq2
)
assert
len
(
cached_blocks
)
==
len
(
blocks_seq2
)
# Half of the blocks from seq1 should still be cached.
num_evicted_blocks
=
len
(
blocks_seq2
)
cached_blocks
=
allocator
.
find_cached_blocks_prefix
(
block_hashes
=
block_hashes_seq1
)
assert
len
(
cached_blocks
)
==
len
(
blocks_seq1
)
-
num_evicted_blocks
@
staticmethod
def
create_immutable_chain
(
block_size
:
int
,
...
...
@@ -762,3 +820,114 @@ class TestPrefixCachingBlockAllocator:
blocks
.
append
(
prev_block
)
return
blocks
class
TestComputedBlocksTracker
:
@
staticmethod
def
_get_mock_allocator
():
return
MagicMock
(
spec
=
PrefixCachingBlockAllocator
)
@
staticmethod
def
test_get_num_cached_tokens
():
"""
Test it correctly computes the number of cached tokens for a given
sequence:
- The cache token count is derived from the number of cached blocks.
- The cache token count is updated when the allocator is updated.
- When a sequence is removed, the cache token count should be updated
accordingly.
# TODO(rickyx): This behaviour for prefill sequence is a hack until
we fix the computed blocks tracking.
- The cache token count for prefill sequence doesn't change while
the sequence is in continuous prefill (chunked prefill).
"""
block_size
=
4
mock_allocator
=
TestComputedBlocksTracker
.
_get_mock_allocator
()
tracker
=
ComputedBlocksTracker
(
allocator
=
mock_allocator
,
block_size
=
block_size
,
enable_caching
=
True
,
)
# Not yet allocated.
tokens
=
[
0
,
1
,
2
,
3
,
4
,
5
]
seq1
=
create_dummy_sequence
(
request_id
=
0
,
token_ids
=
tokens
,
block_size
=
block_size
)
mock_allocator
.
find_cached_blocks_prefix
.
return_value
=
[]
assert
tracker
.
get_num_cached_tokens
(
seq1
)
==
0
mock_allocator
.
find_cached_blocks_prefix
.
return_value
=
[
None
]
# 1 block cached.
# Result is cached for prefill sequence.
assert
tracker
.
get_num_cached_tokens
(
seq1
)
==
0
# Mark the sequence as non-prefill.
seq1
.
data
.
update_num_computed_tokens
(
len
(
tokens
))
# 6 tokens computed.
assert
not
seq1
.
is_prefill
()
# Recomputes for decoding sequence.
assert
tracker
.
get_num_cached_tokens
(
seq1
)
==
4
# Append new tokens to the sequence.
num_new_tokens
=
3
for
i
in
range
(
num_new_tokens
):
seq1
.
append_token_id
(
i
,
{
i
:
Logprob
(
logprob
=
0.0
)})
assert
tracker
.
get_num_cached_tokens
(
seq1
)
==
4
# Update the allocator.
mock_allocator
.
find_cached_blocks_prefix
.
return_value
=
[
None
]
*
2
# 2 blocks cached.
assert
tracker
.
get_num_cached_tokens
(
seq1
)
==
8
# Remove the sequence.
tracker
.
remove_seq
(
seq1
.
seq_id
)
# Re-create the sequence with the same request id to simulate recompute.
seq1
=
create_dummy_sequence
(
request_id
=
0
,
token_ids
=
tokens
,
block_size
=
block_size
)
mock_allocator
.
find_cached_blocks_prefix
.
return_value
=
[
]
# no cached block
assert
tracker
.
get_num_cached_tokens
(
seq1
)
==
0
@
staticmethod
def
test_correct_block_hash
():
"""
Test that the block hash is correctly computed for a sequence (should
match the underlying block allocator's block hash). So the number of
cached tokens is correctly retrieved.
"""
block_size
=
4
allocator
=
CpuGpuBlockAllocator
.
create
(
allocator_type
=
"prefix_caching"
,
num_gpu_blocks
=
16
,
num_cpu_blocks
=
16
,
block_size
=
block_size
,
)
gpu_allocator
=
allocator
.
_allocators
[
Device
.
GPU
]
tracker
=
ComputedBlocksTracker
(
allocator
=
allocator
,
block_size
=
block_size
,
enable_caching
=
True
,
)
tokens
=
list
(
range
(
block_size
*
4
))
# 4 blocks.
seq
=
create_dummy_sequence
(
request_id
=
0
,
token_ids
=
tokens
,
block_size
=
block_size
)
_
=
TestPrefixCachingBlockAllocator
.
create_immutable_chain
(
block_size
=
block_size
,
token_ids
=
tokens
,
allocator
=
gpu_allocator
,
)
allocator
.
mark_blocks_as_computed
([])
assert
tracker
.
get_num_cached_tokens
(
seq
)
==
len
(
tokens
)
tests/core/test_scheduler.py
View file @
4634a89d
...
...
@@ -12,9 +12,9 @@ from vllm.core.scheduler import Scheduler, SchedulingBudget
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
SequenceGroup
from
.utils
import
(
append_new_token
,
append_new_token_seq
_group
,
create_dummy_prompt
,
get_sequence_groups
,
schedule_and_update_computed_tokens
)
from
.utils
import
(
append_new_token
,
append_new_token_seq
,
append_new_token_seq_group
,
create_dummy_prompt
,
get_sequence_groups
,
schedule_and_update_computed_tokens
)
def
test_scheduler_add_seq_group
():
...
...
@@ -305,6 +305,8 @@ def initialize_scheduler(
block_size
=
4
,
num_cpu_blocks
=
8
,
num_gpu_blocks
=
8
,
enable_prefix_caching
=
False
,
enable_chunked_prefill
=
False
,
):
block_size
=
block_size
scheduler_config
=
SchedulerConfig
(
...
...
@@ -312,8 +314,15 @@ def initialize_scheduler(
max_num_batched_tokens
=
max_token_budget
,
max_num_seqs
=
max_num_seqs
,
max_model_len
=
max_model_len
,
enable_chunked_prefill
=
enable_chunked_prefill
,
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
,
enable_prefix_caching
=
enable_prefix_caching
,
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
)
cache_config
.
num_cpu_blocks
=
num_cpu_blocks
cache_config
.
num_gpu_blocks
=
num_gpu_blocks
scheduler
=
Scheduler
(
scheduler_config
,
cache_config
,
lora_config
)
...
...
@@ -800,3 +809,165 @@ def test_scheduling_budget():
assert
budget
.
num_curr_seqs
==
0
budget
.
subtract_num_seqs
(
seq_group
.
request_id
,
2
)
assert
budget
.
num_curr_seqs
==
0
@
pytest
.
mark
.
parametrize
(
"enable_prefix_caching"
,
[
True
,
False
])
def
test_prefix_caching_aware_prefills
(
enable_prefix_caching
):
"""
Test the below scenario:
For 3 sequences, seqA, seqB, seqC, share the first block as prefix.
The test verifies the below scenarios:
1. SeqA is first scheduled.
2. SeqB and SeqC can be prefilled together in a single schedule round
even though there are not enough token budgets to prefill both without
considering prefix caching.
"""
block_size
=
4
max_num_batched_tokens
=
12
max_seq_group
=
3
scheduler
=
initialize_scheduler
(
block_size
=
block_size
,
num_cpu_blocks
=
16
,
num_gpu_blocks
=
16
,
max_token_budget
=
max_num_batched_tokens
,
max_num_seqs
=
max_seq_group
,
max_model_len
=
max_num_batched_tokens
,
enable_prefix_caching
=
enable_prefix_caching
,
)
seqA_tokens
=
list
(
range
(
8
))
num_shared_tokens
=
4
seqB_tokens
=
seqA_tokens
[:
num_shared_tokens
]
+
list
(
range
(
12
,
16
))
# Shared prefix first 4.
seqC_tokens
=
seqA_tokens
[:
num_shared_tokens
]
+
list
(
range
(
16
,
20
))
# Shared prefix first 4.
seqA
,
seqA_group
=
create_dummy_prompt
(
"0"
,
prompt_tokens
=
seqA_tokens
,
block_size
=
block_size
)
seqB
,
seqB_group
=
create_dummy_prompt
(
"1"
,
prompt_tokens
=
seqB_tokens
,
block_size
=
block_size
)
seqC
,
seqC_group
=
create_dummy_prompt
(
"2"
,
prompt_tokens
=
seqC_tokens
,
block_size
=
block_size
)
# Schedule seqA prefill.
scheduler
.
add_seq_group
(
seqA_group
)
metas
,
out
,
_
=
scheduler
.
schedule
()
assert
(
len
(
out
.
scheduled_seq_groups
)
==
1
and
out
.
scheduled_seq_groups
[
0
].
seq_group
==
seqA_group
)
assert
out
.
scheduled_seq_groups
[
0
].
token_chunk_size
==
len
(
seqA_tokens
)
# Schedule seqA decode.
append_new_token_seq_group
(
len
(
seqA_tokens
),
seqA_group
,
999
)
metas
,
out
,
_
=
scheduler
.
schedule
()
assert
len
(
out
.
scheduled_seq_groups
)
==
1
assert
out
.
scheduled_seq_groups
[
0
].
seq_group
==
seqA_group
assert
out
.
scheduled_seq_groups
[
0
].
token_chunk_size
==
1
# Schedule seqB and seqC prefills should work with prefix caching.
scheduler
.
add_seq_group
(
seqB_group
)
scheduler
.
add_seq_group
(
seqC_group
)
metas
,
out
,
_
=
scheduler
.
schedule
()
if
enable_prefix_caching
:
assert
len
(
out
.
scheduled_seq_groups
)
==
2
assert
set
([
out
.
scheduled_seq_groups
[
0
].
seq_group
,
out
.
scheduled_seq_groups
[
1
].
seq_group
,
])
==
set
([
seqB_group
,
seqC_group
])
assert
len
(
metas
)
==
2
for
meta
in
metas
:
assert
meta
.
token_chunk_size
==
8
assert
(
len
(
meta
.
computed_block_nums
)
==
num_shared_tokens
//
block_size
)
# 1 Block for the 8 tokens.
else
:
assert
len
(
out
.
scheduled_seq_groups
)
==
1
assert
len
(
metas
)
==
1
assert
metas
[
0
].
token_chunk_size
==
8
assert
len
(
metas
[
0
].
computed_block_nums
)
==
0
# No blocks computed.
def
test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching
(
):
"""
This test verifies that we don't schedule new prefills if there's already
a continuous prefill in progress even though the new prefills with shared
prefix can fit in the token budget:
- SeqA is being chunked prefill.
- SeqB with the same prompt shouldn't be scheduled for prefill even though
there's enough token budget to prefill the cached tokens.
- Neither should seqC be scheduled.
- When seqA is in decoding phase, seqB and seqC can be scheduled.
- Entire seqB should be prefilled since it's a full prefix cache hit.
- SeqC would be partially prefilled with the prefix shared, and the
remaining unique tokens would be prefilled (rounded down to be
block-size aligned).
"""
block_size
=
2
max_num_batched_tokens
=
4
max_seq_group
=
3
scheduler
=
initialize_scheduler
(
block_size
=
block_size
,
num_cpu_blocks
=
16
,
num_gpu_blocks
=
16
,
max_token_budget
=
max_num_batched_tokens
,
max_num_seqs
=
max_seq_group
,
max_model_len
=
100
,
enable_prefix_caching
=
True
,
enable_chunked_prefill
=
True
,
)
seqA_tokens
=
list
(
range
(
8
))
seqB_tokens
=
seqA_tokens
seqC_shared_prefix_len
=
4
seqC_tokens
=
seqA_tokens
[:
seqC_shared_prefix_len
]
+
list
(
range
(
12
,
20
))
seqA
,
seqA_group
=
create_dummy_prompt
(
"0"
,
prompt_tokens
=
seqA_tokens
,
block_size
=
block_size
)
seqB
,
seqB_group
=
create_dummy_prompt
(
"1"
,
prompt_tokens
=
seqB_tokens
,
block_size
=
block_size
)
# Chunked prefill seqA.
scheduler
.
add_seq_group
(
seqA_group
)
metas
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
len
(
out
.
scheduled_seq_groups
)
==
1
assert
out
.
scheduled_seq_groups
[
0
].
seq_group
==
seqA_group
assert
out
.
scheduled_seq_groups
[
0
].
token_chunk_size
==
4
# seqB should not be scheduled with ongoing prefills.
scheduler
.
add_seq_group
(
seqB_group
)
metas
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
len
(
out
.
scheduled_seq_groups
)
==
1
assert
out
.
scheduled_seq_groups
[
0
].
seq_group
==
seqA_group
assert
out
.
scheduled_seq_groups
[
0
].
token_chunk_size
==
4
# both seqB and seqC can now be scheduled with seqA is over.
# seqA is in decoding phase.
append_new_token_seq
(
seqA
,
999
)
seqC
,
seqC_group
=
create_dummy_prompt
(
"2"
,
prompt_tokens
=
seqC_tokens
,
block_size
=
block_size
)
scheduler
.
add_seq_group
(
seqC_group
)
metas
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
len
(
out
.
scheduled_seq_groups
)
==
3
metas
=
{
meta
.
request_id
:
meta
for
meta
in
metas
}
assert
metas
[
seqA_group
.
request_id
].
token_chunk_size
==
1
# Decode
assert
(
metas
[
seqB_group
.
request_id
].
token_chunk_size
==
8
)
# Fully cached prefill
assert
(
metas
[
seqC_group
.
request_id
].
token_chunk_size
==
6
),
"A partial prefix of C (4 tokens) should be prefilled, with the "
"remaining tokens fit into 3 token budget (4-1 from the seqA). It will "
"then be rounded down to 2 tokens on block size, thus 6 tokens in total."
tests/core/utils.py
View file @
4634a89d
import
time
from
typing
import
List
,
Optional
from
collections
import
defaultdict
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Sequence
as
GenericSequence
from
typing
import
Tuple
from
vllm
import
SamplingParams
from
vllm.core.scheduler
import
Scheduler
,
SchedulerOutputs
from
vllm.inputs
import
EncoderDecoderInputs
,
token_inputs
from
vllm.lora.request
import
LoRARequest
from
vllm.sequence
import
Logprob
,
Sequence
,
SequenceGroup
from
vllm.sequence
import
(
Logprob
,
Sequence
,
SequenceGroup
,
SequenceGroupMetadata
)
def
create_dummy_prompt
(
request_id
:
str
,
prompt_length
:
int
,
prompt_length
:
int
=
-
1
,
block_size
:
Optional
[
int
]
=
None
,
lora_request
:
Optional
[
LoRARequest
]
=
None
,
best_of
:
int
=
1
,
...
...
@@ -26,6 +29,7 @@ def create_dummy_prompt(
# Create dummy prompt sequence with tokens 0...block_size-1
# and prompt "0 ... block_size".
prompt_tokens
=
list
(
range
(
prompt_length
))
prompt_str
=
" "
.
join
([
str
(
t
)
for
t
in
prompt_tokens
])
prompt
=
Sequence
(
int
(
request_id
),
inputs
=
token_inputs
(
prompt_tokens
,
prompt
=
prompt_str
),
...
...
@@ -42,6 +46,15 @@ def create_dummy_prompt(
return
prompt
,
seq_group
def
create_dummy_sequence
(
request_id
:
int
,
token_ids
:
List
[
int
],
block_size
:
int
)
->
Sequence
:
return
Sequence
(
seq_id
=
request_id
,
inputs
=
token_inputs
(
token_ids
),
block_size
=
block_size
,
)
def
create_dummy_prompt_encoder_decoder
(
request_id
:
str
,
decoder_prompt_length
:
int
,
...
...
@@ -194,12 +207,40 @@ def append_new_token(out, token_id: int):
def
schedule_and_update_computed_tokens
(
scheduler
):
metas
,
out
,
_
=
scheduler
.
schedule
()
for
s
,
meta
in
zip
(
out
.
scheduled_seq_groups
,
metas
)
:
s
.
seq_group
.
update_num_computed_tokens
(
meta
.
token_chunk_size
)
for
s
in
out
.
scheduled_seq_groups
:
s
.
seq_group
.
update_num_computed_tokens
(
s
.
token_chunk_size
)
return
metas
,
out
def
append_new_token_seq
(
seq
:
Sequence
,
token_id
:
int
):
seq
.
append_token_id
(
token_id
,
{
token_id
:
Logprob
(
token_id
)})
def
append_new_token_seq_group
(
token_chunk_size
,
seq_group
,
token_id
:
int
):
seq_group
.
update_num_computed_tokens
(
token_chunk_size
)
for
seq
in
seq_group
.
get_seqs
():
seq
.
append_token_id
(
token_id
,
{
token_id
:
Logprob
(
token_id
)})
class
SchedulerProxy
:
"""
A proxy class to forward calls to the scheduler.
"""
def
__init__
(
self
,
scheduler
:
Scheduler
):
self
.
scheduler_
=
scheduler
self
.
call_history
:
Dict
[
str
,
List
[
Any
]]
=
defaultdict
(
list
)
def
__getattr__
(
self
,
name
:
str
)
->
Any
:
def
wrapper
(
*
args
,
**
kwargs
):
result
=
getattr
(
self
.
scheduler_
,
name
)(
*
args
,
**
kwargs
)
self
.
call_history
[
name
].
append
((
args
,
kwargs
,
result
))
return
result
return
wrapper
def
last_schedule_ret
(
self
,
)
->
Tuple
[
List
[
SequenceGroupMetadata
],
SchedulerOutputs
,
Any
]:
_
,
_
,
ret
=
self
.
call_history
[
"schedule"
][
-
1
]
return
ret
tests/prefix_caching/test_prefix_caching.py
View file @
4634a89d
...
...
@@ -2,10 +2,15 @@
Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
import
pytest
from
tests.conftest
import
VllmRunner
from
tests.core.utils
import
SchedulerProxy
,
create_dummy_prompt
from
tests.kernels.utils
import
override_backend_env_variable
from
vllm
import
SamplingParams
,
TokensPrompt
from
vllm.core.scheduler
import
Scheduler
from
vllm.engine.llm_engine
import
LLMEngine
from
..models.utils
import
check_outputs_equal
...
...
@@ -27,6 +32,7 @@ UNSTABLE_PROMPT_SEQUENCE = [
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"cached_position"
,
[
0
,
1
])
@
pytest
.
mark
.
parametrize
(
"enable_chunked_prefill"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
def
test_mixed_requests
(
hf_runner
,
...
...
@@ -37,6 +43,7 @@ def test_mixed_requests(
dtype
:
str
,
max_tokens
:
int
,
cached_position
:
int
,
enable_chunked_prefill
:
bool
,
block_size
:
int
,
monkeypatch
,
)
->
None
:
...
...
@@ -55,6 +62,7 @@ def test_mixed_requests(
model
,
dtype
=
dtype
,
enable_prefix_caching
=
True
,
enable_chunked_prefill
=
enable_chunked_prefill
,
block_size
=
block_size
,
)
as
vllm_model
:
# Run the first prompt so the cache is populated
...
...
@@ -72,13 +80,13 @@ def test_mixed_requests(
block_size
)
*
block_size
else
:
expected_num_cached_tokens
=
0
assert
req_outputs
[
i
].
num_cached_tokens
==
expected_num_cached_tokens
assert
(
req_outputs
[
i
].
num_cached_tokens
==
expected_num_cached_tokens
)
vllm_outputs
=
[
(
output
.
prompt_token_ids
+
list
(
output
.
outputs
[
0
].
token_ids
),
output
.
prompt
+
output
.
outputs
[
0
].
text
)
for
output
in
req_outputs
]
vllm_outputs
=
[
(
output
.
prompt_token_ids
+
list
(
output
.
outputs
[
0
].
token_ids
),
output
.
prompt
+
output
.
outputs
[
0
].
text
,
)
for
output
in
req_outputs
]
check_outputs_equal
(
outputs_0_lst
=
hf_outputs
,
...
...
@@ -105,3 +113,89 @@ def test_unstable_prompt_sequence(
for
prompt
in
UNSTABLE_PROMPT_SEQUENCE
:
vllm_model
.
generate
(
TokensPrompt
(
prompt_token_ids
=
prompt
),
SamplingParams
(
max_tokens
=
1
))
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
def
test_fully_cached_prefill_needs_uncached_token
(
model
):
block_size
=
16
max_num_batched_tokens
=
16
num_output_tokens
=
5
# Make a vllm engine
runner
=
VllmRunner
(
model_name
=
model
,
gpu_memory_utilization
=
0.7
,
enable_chunked_prefill
=
True
,
enforce_eager
=
True
,
enable_prefix_caching
=
True
,
block_size
=
block_size
,
max_num_batched_tokens
=
max_num_batched_tokens
,
max_num_seqs
=
max_num_batched_tokens
,
)
engine
:
LLMEngine
=
runner
.
model
.
llm_engine
scheduler
:
Scheduler
=
SchedulerProxy
(
engine
.
scheduler
[
0
])
# type: ignore
engine
.
scheduler
[
0
]
=
scheduler
# SeqA
seqA_tokens
=
list
(
range
(
2
*
block_size
))
seqA
,
seq_groupA
=
create_dummy_prompt
(
request_id
=
"0"
,
prompt_tokens
=
seqA_tokens
,
max_tokens
=
num_output_tokens
,
block_size
=
block_size
,
)
scheduler
.
add_seq_group
(
seq_groupA
)
assert
seqA
.
data
.
get_num_computed_tokens
()
==
0
# Prefill seqA
while
not
seqA
.
is_finished
():
engine
.
step
()
# seqB
seqB_tokens
=
[
t
+
1
for
t
in
seqA_tokens
]
# shift by 1
seqB
,
seq_groupB
=
create_dummy_prompt
(
request_id
=
"1"
,
prompt_tokens
=
seqB_tokens
,
max_tokens
=
num_output_tokens
,
block_size
=
block_size
,
)
# seqC is the same as seqA
seqC
,
seq_groupC
=
create_dummy_prompt
(
request_id
=
"2"
,
prompt_tokens
=
seqA_tokens
,
max_tokens
=
num_output_tokens
,
block_size
=
block_size
,
)
scheduler
.
add_seq_group
(
seq_groupB
)
scheduler
.
add_seq_group
(
seq_groupC
)
# Even seqC is fully cached, it should not be prefilled since we
# require at least 1 uncached token.
engine
.
step
()
sched_metas
,
sched_out
,
_
=
scheduler
.
last_schedule_ret
()
assert
len
(
sched_out
.
scheduled_seq_groups
)
==
1
assert
(
sched_out
.
scheduled_seq_groups
[
0
].
seq_group
.
request_id
==
seq_groupB
.
request_id
)
assert
(
sched_out
.
scheduled_seq_groups
[
0
].
token_chunk_size
==
max_num_batched_tokens
)
# When seqB is finished, seqC could be prefilled.
while
not
seqB
.
is_finished
():
engine
.
step
()
sched_metas
,
sched_out
,
_
=
scheduler
.
last_schedule_ret
()
assert
len
(
sched_out
.
scheduled_seq_groups
)
==
1
assert
(
sched_out
.
scheduled_seq_groups
[
0
].
seq_group
.
request_id
==
seq_groupB
.
request_id
)
engine
.
step
()
sched_metas
,
sched_out
,
_
=
scheduler
.
last_schedule_ret
()
assert
len
(
sched_out
.
scheduled_seq_groups
)
==
1
assert
(
sched_out
.
scheduled_seq_groups
[
0
].
seq_group
.
request_id
==
seq_groupC
.
request_id
)
assert
sched_out
.
scheduled_seq_groups
[
0
].
token_chunk_size
==
len
(
seqA_tokens
)
vllm/core/block/cpu_gpu_block_allocator.py
View file @
4634a89d
...
...
@@ -306,14 +306,6 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
device
=
Device
.
GPU
return
self
.
_allocators
[
device
].
mark_blocks_as_computed
(
block_ids
)
def
get_computed_block_ids
(
self
,
prev_computed_block_ids
:
List
[
int
],
block_ids
:
List
[
int
],
skip_last_block_id
:
bool
)
->
List
[
int
]:
# Prefix caching only supported on GPU.
device
=
Device
.
GPU
return
self
.
_allocators
[
device
].
get_computed_block_ids
(
prev_computed_block_ids
,
block_ids
,
skip_last_block_id
)
def
get_common_computed_block_ids
(
self
,
computed_seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
# Prefix caching only supported on GPU.
...
...
@@ -342,6 +334,13 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
self
.
_swap_mapping
.
clear
()
return
list
(
mapping
.
items
())
def
find_cached_blocks_prefix
(
self
,
block_hashes
:
List
[
int
],
device
:
Device
=
Device
.
GPU
,
)
->
List
[
int
]:
return
self
.
_allocators
[
device
].
find_cached_blocks_prefix
(
block_hashes
)
class
NullBlock
(
Block
):
"""
...
...
vllm/core/block/interfaces.py
View file @
4634a89d
...
...
@@ -159,12 +159,6 @@ class BlockAllocator(ABC):
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
pass
@
abstractmethod
def
get_computed_block_ids
(
self
,
prev_computed_block_ids
:
List
[
int
],
block_ids
:
List
[
int
],
skip_last_block_id
:
bool
)
->
List
[
int
]:
pass
@
abstractmethod
def
get_common_computed_block_ids
(
self
,
computed_seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
...
...
@@ -192,6 +186,13 @@ class BlockAllocator(ABC):
class
NoFreeBlocksError
(
ValueError
):
pass
@
abstractmethod
def
find_cached_blocks_prefix
(
self
,
block_hashes
:
List
[
int
],
)
->
List
[
int
]:
pass
class
DeviceAwareBlockAllocator
(
ABC
):
...
...
@@ -207,9 +208,12 @@ class DeviceAwareBlockAllocator(ABC):
pass
@
abstractmethod
def
allocate_immutable_blocks
(
self
,
prev_block
:
Optional
[
Block
],
block_token_ids
:
List
[
List
[
int
]],
device
:
Device
)
->
List
[
Block
]:
def
allocate_immutable_blocks
(
self
,
prev_block
:
Optional
[
Block
],
block_token_ids
:
List
[
List
[
int
]],
device
:
Device
,
)
->
List
[
Block
]:
pass
@
abstractmethod
...
...
@@ -246,12 +250,6 @@ class DeviceAwareBlockAllocator(ABC):
def
mark_blocks_as_computed
(
self
,
block_ids
:
List
[
int
])
->
None
:
pass
@
abstractmethod
def
get_computed_block_ids
(
self
,
prev_computed_block_ids
:
List
[
int
],
block_ids
:
List
[
int
],
skip_last_block_id
:
bool
)
->
List
[
int
]:
pass
@
abstractmethod
def
get_common_computed_block_ids
(
self
,
computed_seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
...
...
@@ -284,3 +282,11 @@ class DeviceAwareBlockAllocator(ABC):
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass
@
abstractmethod
def
find_cached_blocks_prefix
(
self
,
block_hashes
:
List
[
int
],
device
:
Device
=
Device
.
GPU
,
)
->
List
[
int
]:
pass
vllm/core/block/naive_block.py
View file @
4634a89d
...
...
@@ -262,13 +262,6 @@ class NaiveBlockAllocator(BlockAllocator):
"""
pass
def
get_computed_block_ids
(
self
,
prev_computed_block_ids
:
List
[
int
],
block_ids
:
List
[
int
],
skip_last_block_id
:
bool
)
->
List
[
int
]:
"""No prefix caching here => return empty list
"""
return
[]
def
get_common_computed_block_ids
(
self
,
computed_seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
"""Determine blocks that can be skipped in prefill.
...
...
@@ -329,6 +322,10 @@ class NaiveBlockAllocator(BlockAllocator):
def
get_prefix_cache_hit_rate
(
self
)
->
float
:
return
-
1
def
find_cached_blocks_prefix
(
self
,
block_hashes
:
List
[
int
])
->
List
[
int
]:
# Not applicable for naive block allocator.
return
[]
class
NaiveBlock
(
Block
):
"""An implementation of the Block class that does not support prefix
...
...
vllm/core/block/prefix_caching_block.py
View file @
4634a89d
"""Token blocks."""
import
sys
from
bisect
import
bisect_left
from
os.path
import
commonprefix
from
typing
import
Dict
,
FrozenSet
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
(
Callable
,
Dict
,
FrozenSet
,
Iterable
,
List
,
Optional
,
Set
,
Tuple
)
from
vllm.core.block.common
import
(
CacheMetricData
,
CopyOnWriteTracker
,
get_all_blocks_recursively
)
from
vllm.core.block.interfaces
import
Block
,
BlockAllocator
,
BlockId
,
Device
from
vllm.core.block.interfaces
import
(
Block
,
BlockAllocator
,
BlockId
,
Device
,
DeviceAwareBlockAllocator
)
from
vllm.core.block.naive_block
import
(
BlockPool
,
NaiveBlock
,
NaiveBlockAllocator
)
from
vllm.core.evictor
import
EvictionPolicy
,
Evictor
,
make_evictor
from
vllm.sequence
import
Sequence
PrefixHash
=
int
...
...
@@ -534,26 +539,6 @@ class PrefixCachingBlockAllocator(BlockAllocator):
else
:
return
block_id
in
self
.
evictor
def
get_computed_block_ids
(
self
,
prev_computed_block_ids
:
List
[
int
],
block_ids
:
List
[
int
],
skip_last_block_id
:
bool
=
True
)
->
List
[
int
]:
prev_prefix_size
=
len
(
prev_computed_block_ids
)
cur_size
=
len
(
block_ids
)
if
skip_last_block_id
:
cur_size
-=
1
# Sanity checks
assert
cur_size
>=
0
assert
prev_prefix_size
<=
cur_size
ret
=
prev_computed_block_ids
for
i
in
range
(
prev_prefix_size
,
cur_size
):
block_id
=
block_ids
[
i
]
if
self
.
block_is_computed
(
block_id
):
ret
.
append
(
block_id
)
return
ret
def
get_common_computed_block_ids
(
self
,
computed_seq_block_ids
:
List
[
List
[
int
]])
->
List
[
int
]:
"""Return the block ids that are common for a given sequence group.
...
...
@@ -634,6 +619,47 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block
.
block_id
=
block_id
# Assign block_id
def
find_cached_blocks_prefix
(
self
,
block_hashes
:
List
[
int
])
->
List
[
int
]:
"""
Given a list of block hashes, return the prefix of the block hashes that
are all cached.
Since a block's block hash includes the hashes of all previous blocks,
and we only allocate/deallocate blocks in the entire sequence, so if a
block is cached, then all previous blocks are also cached. With this
property, we can use binary search to find the prefix of cached blocks.
Args:
block_hashes (List[int]): The list of block hashes.
Returns:
List[int]: The prefix of the `block_hashes` that are cached.
"""
def
_block_is_cached
(
block_hash
:
PrefixHash
)
->
bool
:
if
block_hash
not
in
self
.
_cached_blocks
:
return
False
cached_block_id
=
self
.
_cached_blocks
[
block_hash
]
# We only consider the blocks that are marked as computed.
return
self
.
block_is_computed
(
cached_block_id
)
def
_bisect_left
(
a
,
x
,
key
:
Callable
[[
PrefixHash
],
bool
])
->
int
:
# python <= 3.10 don't have the key argument
if
sys
.
version_info
<
(
3
,
10
):
a
=
[
key
(
e
)
for
e
in
a
]
return
bisect_left
(
a
,
x
)
else
:
return
bisect_left
(
a
,
x
,
key
=
key
)
# Look for the first block that's not cached, and returns the prefix
# i.e. blocks that are cached.
idx
=
_bisect_left
(
block_hashes
,
True
,
key
=
lambda
x
:
not
_block_is_cached
(
x
))
return
block_hashes
[:
idx
]
class
PrefixCachingBlock
(
Block
):
"""A block implementation that supports prefix caching.
...
...
@@ -843,86 +869,126 @@ class PrefixCachingBlock(Block):
class
ComputedBlocksTracker
:
"""Handles caching of per-sequence computed block ids.
When a sequence appears for the first time, it traverses all of the
blocks and detects the prefix of blocks that is computed. On the
subsequent times, it only traverses the new blocks that were added
and updates the already recorded prefix of blocks with the newly
computed blocks.
To avoid redundant traversals, the algorithm also detects when there
is a "gap" in the computed prefix. For example, if we have blocks =
[1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then
we won't try to add more computed blocks to [1,2,3] in this sequence
iteration, and will add more computed blocks only after the sequence is
freed and reused again.
Note that currently, for a given sequence, we also skip the last
block id for caching purposes, to avoid caching of a full sequence
"""
Tracks the computed blocks for each sequence.
def
__init__
(
self
,
allocator
):
self
.
_allocator
=
allocator
self
.
_cached_computed_seq_blocks
:
Dict
[
int
,
Tuple
[
List
[
int
],
bool
]]
=
{}
Internally, it maintains a map from sequence id to the list of block hashes
for the sequence. We cache the hashes of the full blocks for each sequence,
and make sure the hash is calculated in the same way as the allocator.
When a sequence is being decoded, we also update the sequence's hash
accordingly and incrementally.
def
add_seq
(
self
,
seq_id
:
int
)
->
None
:
"""Start tracking seq_id
"""
assert
seq_id
not
in
self
.
_cached_computed_seq_blocks
self
.
_cached_computed_seq_blocks
[
seq_id
]
=
([],
False
)
def
remove_seq
(
self
,
seq_id
:
int
)
->
None
:
"""Stop tracking seq_id
"""
assert
seq_id
in
self
.
_cached_computed_seq_blocks
del
self
.
_cached_computed_seq_blocks
[
seq_id
]
def
get_cached_computed_blocks_and_update
(
self
,
seq_id
:
int
,
block_ids
:
List
[
int
])
->
List
[
int
]:
""" Look at the class documentation for details
"""
# Ensure seq_id is already tracked
assert
seq_id
in
self
.
_cached_computed_seq_blocks
# Get cached data (may be empty on the first time)
prev_computed_block_ids
,
has_gap
=
self
.
_cached_computed_seq_blocks
[
seq_id
]
if
has_gap
:
# When gap is detected, we do not add more computed blocks at this
# sequence iteration
return
prev_computed_block_ids
# We do not consider the last block id for caching purposes.
num_cur_blocks
=
len
(
block_ids
)
-
1
assert
num_cur_blocks
>=
0
if
len
(
prev_computed_block_ids
)
>=
num_cur_blocks
:
# Cache HIT
assert
len
(
prev_computed_block_ids
)
==
num_cur_blocks
return
prev_computed_block_ids
# If here, then we may possibly add more computed blocks. As a result,
# traverse the additional blocks after prev_computed_block_ids to
# detect more computed blocks and add them.
# Incremental init for seq_id => Look only at the new blocks
computed_block_ids
=
self
.
_allocator
.
get_computed_block_ids
(
# noqa: E501
prev_computed_block_ids
,
block_ids
,
skip_last_block_id
=
True
,
# We skip last block id to avoid caching of full seq
)
From the sequence hash, with prefix caching enabled, we could also calculate
the number of cached tokens for the sequence by looking up the number of
cached block hashes in the allocator.
"""
# Detect if there is a "gap"
has_gap
=
len
(
computed_block_ids
)
<
num_cur_blocks
def
__init__
(
self
,
allocator
:
DeviceAwareBlockAllocator
,
block_size
:
int
,
enable_caching
:
bool
,
):
self
.
_allocator
=
allocator
self
.
_block_size
=
block_size
self
.
_enable_caching
=
enable_caching
# A map from seq_id to the list of block hashes for the
# sequence. This is so that we don't have to recompute the block hashes
# for the sequence when we need to check if the sequence is cached.
# Note a block that's not full will not have its hash calculated and
# recorded.
self
.
_seq_id_to_blocks_hashes
:
Dict
[
int
,
List
[
int
]]
=
{}
# A map from seq_id to the number of tokens that are cached for the
# sequence.
# We need this so that a sequence in continuous prefill doesn't
# accidentally see its cached token count change. See comments in
# `get_num_cached_tokens` for more details.
self
.
_seq_id_to_num_tokens_computed
:
Dict
[
int
,
int
]
=
{}
def
_update_seq_hashes
(
self
,
seq
:
Sequence
)
->
None
:
"""Incrementally update the sequence's block hashes and record them."""
assert
self
.
_enable_caching
block_hashes_recorded
=
self
.
_seq_id_to_blocks_hashes
.
get
(
seq
.
seq_id
,
[])
cur_num_blocks_recorded
=
len
(
block_hashes_recorded
)
token_ids
=
seq
.
get_token_ids
()
assert
len
(
token_ids
)
>=
cur_num_blocks_recorded
*
self
.
_block_size
,
(
f
"The sequence has
{
len
(
token_ids
)
}
tokens, but"
f
" already recorded
{
cur_num_blocks_recorded
}
blocks. "
"This should not happen since we assume blocks are "
"only appended other than recomputation. When the sequence is "
"recomputed, we should have removed the info of the old blocks."
)
# Update the computed block hashes for the sequence. Since only full
# blocks are considered as "computed", we take floor here.
num_computed_blocks
=
len
(
token_ids
)
//
self
.
_block_size
# We need to know the hash of the previous block to compute the hash of
# the current block so that blocks could be uniquely identified across
# sequences of prefixes.
prev_block_hash
=
(
None
if
cur_num_blocks_recorded
==
0
else
block_hashes_recorded
[
-
1
])
# Only update the computed block hashes for the new blocks
for
i
in
range
(
cur_num_blocks_recorded
,
num_computed_blocks
):
assert
len
(
token_ids
)
>=
(
i
+
1
)
*
self
.
_block_size
block_token_ids
=
token_ids
[
i
*
self
.
_block_size
:(
i
+
1
)
*
self
.
_block_size
]
# This has to be kept in sync with the allocator's hash
# calculation.
block_hash
=
PrefixCachingBlock
.
hash_block_tokens
(
is_first_block
=
prev_block_hash
is
None
,
prev_block_hash
=
prev_block_hash
,
cur_block_token_ids
=
block_token_ids
,
)
block_hashes_recorded
.
append
(
block_hash
)
prev_block_hash
=
block_hash
self
.
_seq_id_to_blocks_hashes
[
seq
.
seq_id
]
=
block_hashes_recorded
def
get_num_cached_tokens
(
self
,
seq
:
Sequence
)
->
int
:
if
not
self
.
_enable_caching
:
return
0
# We always try to update the sequence hashes on the fly.
# This is to ensure that we don't miss any cached tokens for the
# sequence during decode.
# This routine should only update hash for any new blocks too.
self
.
_update_seq_hashes
(
seq
)
num_computed_tokens_prev
=
self
.
_seq_id_to_num_tokens_computed
.
get
(
seq
.
seq_id
,
None
)
# TODO(rickyx): This hack could be removed once we mark blocks as
# computed correctly with chunked prefills.
if
num_computed_tokens_prev
is
not
None
and
seq
.
is_prefill
():
# For a sequence that is still in prefill, we don't
# recompute the number of cached tokens.
# This also handles correctly chunked prefill since currently
# we mark blocks as computed even if the sequence is still partially
# prefilled. So a continuously prefilled sequence should not
# see its cached token count change while running.
return
num_computed_tokens_prev
block_hashes
=
self
.
_seq_id_to_blocks_hashes
[
seq
.
seq_id
]
# This is O(logN), where N is the number of blocks.
num_cached_blocks
=
len
(
self
.
_allocator
.
find_cached_blocks_prefix
(
block_hashes
))
num_cached_tokens
=
num_cached_blocks
*
self
.
_block_size
self
.
_seq_id_to_num_tokens_computed
[
seq
.
seq_id
]
=
num_cached_tokens
return
num_cached_tokens
# Record
self
.
_cached_computed_seq_blocks
[
seq_id
]
=
(
computed_block_ids
,
has_gap
)
def
remove_seq
(
self
,
seq_id
:
int
)
->
None
:
"""Stop tracking the sequence."""
if
not
self
.
_enable_caching
:
return
assert
seq_id
in
self
.
_seq_id_to_blocks_hashes
del
self
.
_seq_id_to_blocks_hashes
[
seq_id
]
return
computed_block_ids
assert
seq_id
in
self
.
_seq_id_to_num_tokens_computed
del
self
.
_seq_id_to_num_tokens_computed
[
seq_id
]
class
LastAccessBlocksTracker
:
...
...
vllm/core/block_manager.py
View file @
4634a89d
...
...
@@ -101,7 +101,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self
.
cross_block_tables
:
Dict
[
EncoderSeqId
,
BlockTable
]
=
{}
self
.
_computed_blocks_tracker
=
ComputedBlocksTracker
(
self
.
block_allocator
)
self
.
block_allocator
,
self
.
block_size
,
self
.
enable_caching
)
self
.
_last_access_blocks_tracker
=
LastAccessBlocksTracker
(
self
.
block_allocator
)
...
...
@@ -170,7 +170,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
# Track seq
self
.
_computed_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
self
.
_last_access_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
# Assign the block table for each sequence.
...
...
@@ -178,7 +177,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self
.
block_tables
[
seq
.
seq_id
]
=
block_table
.
fork
()
# Track seq
self
.
_computed_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
self
.
_last_access_blocks_tracker
.
add_seq
(
seq
.
seq_id
)
# Allocate cross-attention block table for encoder sequence
...
...
@@ -314,11 +312,13 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
"""
computed_seq_block_ids
=
[]
for
seq
in
seqs
:
computed_seq_block_ids
.
append
(
self
.
_computed_blocks_tracker
.
get_cached_computed_blocks_and_update
(
seq
.
seq_id
,
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
))
all_blocks
=
self
.
block_tables
[
seq
.
seq_id
].
physical_block_ids
num_cached_tokens
=
(
self
.
_computed_blocks_tracker
.
get_num_cached_tokens
(
seq
))
assert
num_cached_tokens
%
self
.
block_size
==
0
num_cached_blocks
=
num_cached_tokens
//
self
.
block_size
computed_block_ids
=
all_blocks
[:
num_cached_blocks
]
computed_seq_block_ids
.
append
(
computed_block_ids
)
# NOTE(sang): This assumes seq_block_ids doesn't contain any None.
return
self
.
block_allocator
.
get_common_computed_block_ids
(
...
...
@@ -332,7 +332,6 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self
.
block_tables
[
child_seq
.
seq_id
]
=
src_block_table
.
fork
()
# Track child seq
self
.
_computed_blocks_tracker
.
add_seq
(
child_seq
.
seq_id
)
self
.
_last_access_blocks_tracker
.
add_seq
(
child_seq
.
seq_id
)
def
can_swap_in
(
self
,
seq_group
:
SequenceGroup
,
...
...
@@ -503,3 +502,9 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
return
AllocStatus
.
OK
else
:
return
AllocStatus
.
LATER
def
get_num_cached_tokens
(
self
,
seq
:
Sequence
)
->
int
:
"""Get the number of tokens in blocks that are already computed and
cached in the block manager for the sequence.
"""
return
self
.
_computed_blocks_tracker
.
get_num_cached_tokens
(
seq
)
vllm/core/interfaces.py
View file @
4634a89d
...
...
@@ -121,3 +121,7 @@ class BlockSpaceManager(ABC):
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
"""Prefix cache hit rate. -1 means not supported or disabled."""
pass
@
abstractmethod
def
get_num_cached_tokens
(
self
,
seq
:
Sequence
)
->
int
:
pass
vllm/core/placeholder_block_space_manager.py
View file @
4634a89d
...
...
@@ -89,3 +89,6 @@ class PlaceholderBlockSpaceManager(BlockSpaceManager):
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
return
-
1
def
get_num_cached_tokens
(
self
,
seq
:
Sequence
)
->
int
:
return
0
vllm/core/scheduler.py
View file @
4634a89d
...
...
@@ -56,11 +56,16 @@ class SchedulingBudget:
max_num_seqs
:
int
_request_ids_num_batched_tokens
:
Set
[
str
]
=
field
(
default_factory
=
set
)
_request_ids_num_curr_seqs
:
Set
[
str
]
=
field
(
default_factory
=
set
)
# Number of cached tokens in the batch.
_num_cached_tokens
:
int
=
0
# Number of actual non-cached tokens in the batch.
_num_batched_tokens
:
int
=
0
_num_curr_seqs
:
int
=
0
def
can_schedule
(
self
,
*
,
num_new_tokens
:
int
,
num_new_seqs
:
int
):
assert
num_new_tokens
!=
0
# We allow num_new_tokens to be 0 when the entire sequence has
# been cached.
assert
num_new_tokens
>=
0
assert
num_new_seqs
!=
0
return
(
self
.
num_batched_tokens
+
num_new_tokens
<=
self
.
token_budget
and
self
.
num_curr_seqs
+
num_new_seqs
<=
self
.
max_num_seqs
)
...
...
@@ -68,12 +73,18 @@ class SchedulingBudget:
def
remaining_token_budget
(
self
):
return
self
.
token_budget
-
self
.
num_batched_tokens
def
add_num_batched_tokens
(
self
,
req_id
:
str
,
num_batched_tokens
:
int
):
def
add_num_batched_tokens
(
self
,
req_id
:
str
,
num_batched_tokens
:
int
,
num_cached_tokens
:
int
=
0
):
if
req_id
in
self
.
_request_ids_num_batched_tokens
:
return
assert
num_cached_tokens
>=
0
assert
num_batched_tokens
>=
0
self
.
_request_ids_num_batched_tokens
.
add
(
req_id
)
self
.
_num_batched_tokens
+=
num_batched_tokens
self
.
_num_cached_tokens
+=
num_cached_tokens
def
subtract_num_batched_tokens
(
self
,
req_id
:
str
,
num_batched_tokens
:
int
):
...
...
@@ -101,6 +112,10 @@ class SchedulingBudget:
def
num_curr_seqs
(
self
):
return
self
.
_num_curr_seqs
@
property
def
num_cached_tokens
(
self
):
return
self
.
_num_cached_tokens
@
dataclass
class
ScheduledSequenceGroup
:
...
...
@@ -541,9 +556,19 @@ class Scheduler:
assert
len
(
self
.
_async_stopped
)
==
0
while
running_queue
:
seq_group
=
running_queue
[
0
]
num_running_tokens
=
self
.
_get_num_new_tokens
(
seq_group
,
SequenceStatus
.
RUNNING
,
enable_chunking
,
budget
)
# We discard the cached tokens info here because we don't need it
# for running sequence:
# 1. If a sequence is running with chunked prefill, the cached
# tokens info was already used for the first prefill.
# 2. If a sequence is running with non-chunked prefill, then
# there it's a decoding sequence, and the cached tokens info is
# irrelevant.
num_uncached_new_tokens
,
_
=
(
self
.
_get_num_new_uncached_and_cached_tokens
(
seq_group
,
SequenceStatus
.
RUNNING
,
enable_chunking
,
budget
))
num_running_tokens
=
num_uncached_new_tokens
if
num_running_tokens
==
0
:
# No budget => Stop
break
...
...
@@ -715,13 +740,15 @@ class Scheduler:
# The total number of sequences in the RUNNING state should not
# exceed the maximum number of sequences.
num_new_seqs
=
seq_group
.
get_max_num_running_seqs
()
num_new_tokens
=
self
.
_get_num_new_tokens
(
seq_group
,
SequenceStatus
.
SWAPPED
,
enable_chunking
,
budget
)
if
(
num_new_tokens
==
0
or
not
budget
.
can_schedule
(
num_new_tokens
=
num_new_tokens
,
num_new_seqs
=
num_new_seqs
)):
num_new_tokens_uncached
,
num_new_tokens_cached
=
(
self
.
_get_num_new_uncached_and_cached_tokens
(
seq_group
,
SequenceStatus
.
SWAPPED
,
enable_chunking
,
budget
))
if
num_new_tokens_uncached
==
0
or
not
budget
.
can_schedule
(
num_new_tokens
=
num_new_tokens_uncached
,
num_new_seqs
=
num_new_seqs
,
):
break
if
lora_int_id
>
0
and
curr_loras
is
not
None
:
...
...
@@ -732,12 +759,19 @@ class Scheduler:
is_prefill
=
seq_group
.
is_prefill
()
if
is_prefill
:
prefill_seq_groups
.
append
(
ScheduledSequenceGroup
(
seq_group
,
token_chunk_size
=
num_new_tokens
))
ScheduledSequenceGroup
(
seq_group
,
token_chunk_size
=
num_new_tokens_uncached
+
num_new_tokens_cached
,
))
else
:
decode_seq_groups
.
append
(
ScheduledSequenceGroup
(
seq_group
,
token_chunk_size
=
1
))
budget
.
add_num_batched_tokens
(
seq_group
.
request_id
,
num_new_tokens
)
budget
.
add_num_batched_tokens
(
seq_group
.
request_id
,
num_batched_tokens
=
num_new_tokens_uncached
,
num_cached_tokens
=
num_new_tokens_cached
,
)
budget
.
add_num_seqs
(
seq_group
.
request_id
,
num_new_seqs
)
swapped_queue
.
extendleft
(
leftover_swapped
)
...
...
@@ -803,26 +837,30 @@ class Scheduler:
if
waiting_queue
:
seq_group
=
waiting_queue
.
popleft
()
num_new_seqs
=
seq_group
.
get_max_num_running_seqs
()
num_new_tokens
=
self
.
_get_num_new_tokens
(
seq_group
,
SequenceStatus
.
WAITING
,
False
,
budget
)
num_new_tokens
_uncached
,
_
=
(
self
.
_get_num_new_uncached_and_cached_tokens
(
seq_group
,
SequenceStatus
.
WAITING
,
False
,
budget
)
)
#Only preempt if priority inversion exists
while
running_queue
and
self
.
_get_priority
(
running_queue
[
-
1
])
>
self
.
_get_priority
(
seq_group
):
#Only preempt if waiting sequence cannot be allocated
can_allocate
=
self
.
block_manager
.
can_allocate
(
seq_group
)
if
(
num_new_tokens
and
can_allocate
==
AllocStatus
.
OK
and
budget
.
can_schedule
(
num_new_tokens
=
num_new_tokens
,
num_new_seqs
=
num_new_seqs
)):
if
(
num_new_tokens_uncached
>
0
and
can_allocate
==
AllocStatus
.
OK
and
budget
.
can_schedule
(
num_new_tokens
=
num_new_tokens_uncached
,
num_new_seqs
=
num_new_seqs
,
)):
break
#Adjust budget to remove the victim sequence group
vseq_group
=
running_queue
.
pop
()
num_running_tokens
=
self
.
_get_num_new_tokens
(
vseq_group
,
SequenceStatus
.
RUNNING
,
False
,
budget
)
budget
.
subtract_num_batched_tokens
(
vseq_group
.
request_id
,
num_running_tokens
)
num_running_tokens_uncached
,
_
=
(
self
.
_get_num_new_uncached_and_cached_tokens
(
vseq_group
,
SequenceStatus
.
RUNNING
,
False
,
budget
))
budget
.
subtract_num_batched_tokens
(
vseq_group
.
request_id
,
num_running_tokens_uncached
)
num_running_seqs
=
vseq_group
.
get_max_num_running_seqs
()
budget
.
subtract_num_seqs
(
vseq_group
.
request_id
,
num_running_seqs
)
...
...
@@ -882,9 +920,12 @@ class Scheduler:
assert
len
(
waiting_seqs
)
==
1
,
(
"Waiting sequence group should have only one prompt "
"sequence."
)
num_new_tokens
=
self
.
_get_num_new_tokens
(
seq_group
,
SequenceStatus
.
WAITING
,
enable_chunking
,
budget
)
num_new_tokens_uncached
,
num_new_tokens_cached
=
(
self
.
_get_num_new_uncached_and_cached_tokens
(
seq_group
,
SequenceStatus
.
WAITING
,
enable_chunking
,
budget
))
num_new_tokens
=
num_new_tokens_uncached
+
num_new_tokens_cached
if
not
enable_chunking
:
num_prompt_tokens
=
waiting_seqs
[
0
].
get_len
()
assert
num_new_tokens
==
num_prompt_tokens
...
...
@@ -935,10 +976,18 @@ class Scheduler:
waiting_queue
.
popleft
()
continue
if
(
budget
.
num_batched_tokens
>=
self
.
scheduler_config
.
max_num_batched_tokens
):
# We've reached the budget limit - since there might be
# continuous prefills in the running queue, we should break
# to avoid scheduling any new prefills.
break
num_new_seqs
=
seq_group
.
get_max_num_running_seqs
()
if
(
num_new_tokens
==
0
or
not
budget
.
can_schedule
(
num_new_tokens
=
num_new_tokens
,
num_new_seqs
=
num_new_seqs
)):
if
num_new_tokens_uncached
==
0
or
not
budget
.
can_schedule
(
num_new_tokens
=
num_new_tokens_uncached
,
num_new_seqs
=
num_new_seqs
,
):
break
# Can schedule this request.
...
...
@@ -967,7 +1016,11 @@ class Scheduler:
seq_groups
.
append
(
ScheduledSequenceGroup
(
seq_group
=
seq_group
,
token_chunk_size
=
num_new_tokens
))
budget
.
add_num_batched_tokens
(
seq_group
.
request_id
,
num_new_tokens
)
budget
.
add_num_batched_tokens
(
seq_group
.
request_id
,
num_batched_tokens
=
num_new_tokens_uncached
,
num_cached_tokens
=
num_new_tokens_cached
,
)
budget
.
add_num_seqs
(
seq_group
.
request_id
,
num_new_seqs
)
# Queue requests that couldn't be scheduled.
...
...
@@ -1075,7 +1128,8 @@ class Scheduler:
return
SchedulerOutputs
(
scheduled_seq_groups
=
scheduled_seq_groups
,
num_prefill_groups
=
num_prefill_groups
,
num_batched_tokens
=
budget
.
num_batched_tokens
,
num_batched_tokens
=
budget
.
num_batched_tokens
+
budget
.
num_cached_tokens
,
blocks_to_swap_in
=
swapped_in
.
blocks_to_swap_in
,
blocks_to_swap_out
=
running_scheduled
.
blocks_to_swap_out
,
blocks_to_copy
=
blocks_to_copy
,
...
...
@@ -1119,7 +1173,6 @@ class Scheduler:
running_scheduled
.
swapped_out
)
==
0
:
swapped_in
=
self
.
_schedule_swapped
(
budget
,
curr_loras
)
# Schedule new prefills.
prefills
=
self
.
_schedule_prefills
(
budget
,
curr_loras
,
enable_chunking
=
True
)
...
...
@@ -1157,7 +1210,8 @@ class Scheduler:
num_prefill_groups
=
(
len
(
prefills
.
seq_groups
)
+
len
(
swapped_in
.
prefill_seq_groups
)
+
len
(
running_scheduled
.
prefill_seq_groups
)),
num_batched_tokens
=
budget
.
num_batched_tokens
,
num_batched_tokens
=
budget
.
num_batched_tokens
+
budget
.
num_cached_tokens
,
blocks_to_swap_in
=
swapped_in
.
blocks_to_swap_in
,
blocks_to_swap_out
=
running_scheduled
.
blocks_to_swap_out
,
blocks_to_copy
=
running_scheduled
.
blocks_to_copy
+
...
...
@@ -1584,64 +1638,178 @@ class Scheduler:
return
self
.
scheduler_config
.
num_lookahead_slots
def
_get_num_new_tokens
(
self
,
seq_group
:
SequenceGroup
,
status
:
SequenceStatus
,
enable_chunking
:
bool
,
budget
:
SchedulingBudget
)
->
int
:
"""Get the next new tokens to compute for a given sequence group
that's in a given `status`.
def
_get_num_new_uncached_and_cached_tokens
(
self
,
seq_group
:
SequenceGroup
,
status
:
SequenceStatus
,
enable_chunking
:
bool
,
budget
:
SchedulingBudget
,
)
->
Tuple
[
int
,
int
]:
"""
Returns the number of new uncached and cached tokens to schedule for a
given sequence group that's in a given `status`.
The API could chunk the number of tokens to compute based on `budget`
if `enable_chunking` is True. If a sequence group has multiple
sequences (e.g., running beam search), it means it is in decoding
phase, so chunking doesn't happen.
Returns 0 if the new token cannot be computed due to token budget.
Returns (0, 0) if the new token cannot be computed due to token budget.
The cached tokens's blocks are already computed, and the attention
backend will reuse the cached blocks rather than recomputing them. So
the scheduler could schedule these cached tokens "for free".
Args:
seq_group: The sequence group to get the number of new tokens to
schedule.
status: The status of the sequences to get the number of new tokens
to schedule.
enable_chunking: Whether to chunk the number of tokens to compute.
budget: The budget to chunk the number of tokens to compute.
Returns:
A tuple of two ints. The first int is the number of new uncached
tokens to schedule. The second int is the number of cached tokens.
If no more new tokens can be scheduled, returns (0, 0).
"""
num_new_tokens
=
0
num_cached_new_tokens
=
0
num_uncached_new_tokens
=
0
seqs
=
seq_group
.
get_seqs
(
status
=
status
)
# Compute the number of new uncached and cached tokens for
# each sequence.
for
seq
in
seqs
:
num_new_tokens
+=
seq
.
get_num_new_tokens
()
assert
num_new_tokens
>
0
# Chunk if a running request cannot fit in the given budget.
# If number of seq > 1, it means it is doing beam search
# in a decode phase. Do not chunk.
if
not
seq
.
is_prefill
():
# Decode sequences should always just have 1 uncached token
# TODO(rickyx): Actually is this still correct for multi-step?
num_uncached_new_tokens
+=
1
continue
num_computed_tokens_seq
=
seq
.
get_num_computed_tokens
()
all_num_new_tokens_seq
=
seq
.
get_len
()
-
num_computed_tokens_seq
if
not
self
.
cache_config
.
enable_prefix_caching
:
# If prefix caching is not enabled, all new tokens are uncached.
num_uncached_new_tokens
+=
all_num_new_tokens_seq
continue
# NOTE: the cache token might be currently in a block that's in an
# evictor meaning that it's not yet allocated. However, we don't
# exclude such tokens in the cache count because it will be
# guaranteed to be allocated later if the sequence can be allocated.
num_cached_tokens_seq
=
self
.
block_manager
.
get_num_cached_tokens
(
seq
)
# Sanity check.
if
num_cached_tokens_seq
<
num_computed_tokens_seq
:
# This should only happen with chunked prefill, and
# the seq is still in prefill. The `num_cached_tokens_seq`
# is the value we calculated on scheduling the first prefill.
# For subsequent continuous prefill steps, we cached the
# number of cache tokens for the sequence so the cached token
# count could be less than the number of computed tokens.
# See comments on `ComputedBlocksTracker` for more details.
assert
(
seq
.
is_prefill
()
and
seq
.
status
==
SequenceStatus
.
RUNNING
and
self
.
scheduler_config
.
chunked_prefill_enabled
),
(
"Number of cached tokens should not be less than the "
"number of computed tokens for a sequence that's still "
f
"in prefill. But there are
{
num_cached_tokens_seq
}
cached "
f
"tokens and
{
num_computed_tokens_seq
}
computed tokens "
f
"for sequence
{
seq
.
seq_id
}
."
)
num_cached_new_tokens_seq
=
max
(
0
,
num_cached_tokens_seq
-
num_computed_tokens_seq
)
num_uncached_new_tokens_seq
=
(
all_num_new_tokens_seq
-
num_cached_new_tokens_seq
)
num_uncached_new_tokens
+=
num_uncached_new_tokens_seq
num_cached_new_tokens
+=
num_cached_new_tokens_seq
if
num_uncached_new_tokens
==
0
and
num_cached_new_tokens
>
0
:
# For a fully cached hit sequence, we actually need to recompute the
# last token. So we need at least 1 uncached token to schedule.
# See ModelRunner._compute_for_prefix_cache_hit for more details.
num_uncached_new_tokens
=
1
num_cached_new_tokens
-=
1
if
enable_chunking
and
len
(
seqs
)
==
1
:
remaining_token_budget
=
budget
.
remaining_token_budget
()
if
self
.
scheduler_config
.
is_multi_step
:
# The current multi-step + chunked prefill capability does
# not actually support chunking prompts.
#
# Therefore, `num_new_tokens` is computed in the same fashion
# for both multi-step+chunked-prefill &
# multi-step+chunked-prefill+APC
#
# Prompts with more tokens than the current remaining budget
# are postponed to future scheduler steps
if
num_new_tokens
>
self
.
_get_prompt_limit
(
seq_group
):
# If the seq_group is in prompt-stage, pass the
# num_new_tokens as-is so the caller can ignore
# the sequence.
pass
else
:
num_new_tokens
=
0
\
if
num_new_tokens
>
remaining_token_budget
\
else
num_new_tokens
elif
self
.
cache_config
.
enable_prefix_caching
:
# When prefix caching is enabled, we always allocate
# the number of new tokens that is dividable by the block
# size to avoid partial block matching.
block_size
=
self
.
cache_config
.
block_size
remainder
=
budget
.
token_budget
%
block_size
if
remainder
!=
0
:
raise
ValueError
(
"When enabling chunked prefill and "
"prefix caching, max_num_batched_tokens "
"(chunk size) must be dividable by "
"block size, but got chunk_size "
f
"(
{
budget
.
token_budget
}
) % block_size "
f
"(
{
block_size
}
) =
{
remainder
}
"
)
if
remaining_token_budget
<
num_new_tokens
:
num_new_tokens
=
(
remaining_token_budget
//
block_size
)
*
block_size
else
:
num_new_tokens
=
min
(
num_new_tokens
,
remaining_token_budget
)
# Chunk if a running request cannot fit in the given budget.
# If number of seq > 1, it means it is doing beam search
# in a decode phase. Do not chunk.
num_uncached_new_tokens
=
self
.
_chunk_new_tokens_to_schedule
(
self
.
scheduler_config
,
self
.
cache_config
,
budget
,
self
.
_get_prompt_limit
(
seq_group
),
num_uncached_new_tokens
,
)
return
num_uncached_new_tokens
,
num_cached_new_tokens
@
staticmethod
def
_chunk_new_tokens_to_schedule
(
scheduler_config
:
SchedulerConfig
,
cache_config
:
CacheConfig
,
budget
:
SchedulingBudget
,
prompt_limit
:
int
,
num_new_tokens
:
int
,
)
->
int
:
"""
Chunks the number of new tokens to schedule based on the budget when
chunked prefill is enabled.
Args:
scheduler_config: The scheduler config.
cache_config: The cache config.
budget: The budget to chunk the number of tokens to compute.
prompt_limit: The maximum number of tokens allowed in a prompt.
num_new_tokens: The number of new tokens to schedule.
Returns:
The number of new tokens to schedule after chunking.
"""
remaining_token_budget
=
budget
.
remaining_token_budget
()
if
scheduler_config
.
is_multi_step
:
# The current multi-step + chunked prefill capability does
# not actually support chunking prompts.
#
# Therefore, `num_new_tokens` is computed in the same fashion
# for both multi-step+chunked-prefill &
# multi-step+chunked-prefill+APC
#
# Prompts with more tokens than the current remaining budget
# are postponed to future scheduler steps
if
num_new_tokens
>
prompt_limit
:
# If the seq_group is in prompt-stage, pass the
# num_new_tokens as-is so the caller can ignore
# the sequence.
return
num_new_tokens
return
(
0
if
num_new_tokens
>
remaining_token_budget
else
num_new_tokens
)
if
cache_config
.
enable_prefix_caching
:
# Adjust the remaining token budget to be divisible by the block
# size when prefix caching is enabled.
# When prefix caching is enabled, we always allocate
# the number of new tokens that is dividable by the block
# size to avoid partial block matching.
block_size
=
cache_config
.
block_size
remainder
=
budget
.
token_budget
%
block_size
if
remainder
!=
0
:
raise
ValueError
(
"When enabling chunked prefill and "
"prefix caching, max_num_batched_tokens "
"(chunk size) must be dividable by "
"block size, but got chunk_size "
f
"(
{
budget
.
token_budget
}
) % block_size "
f
"(
{
block_size
}
) =
{
remainder
}
"
)
# Round down to block size.
remaining_token_budget
=
(
remaining_token_budget
//
block_size
*
block_size
)
num_new_tokens
=
min
(
num_new_tokens
,
remaining_token_budget
)
return
num_new_tokens
vllm/sequence.py
View file @
4634a89d
...
...
@@ -579,6 +579,9 @@ class Sequence:
return
1
return
self
.
data
.
get_num_uncomputed_tokens
()
def
get_num_computed_tokens
(
self
)
->
int
:
return
self
.
data
.
get_num_computed_tokens
()
def
is_prefill
(
self
)
->
bool
:
return
self
.
data
.
stage
==
SequenceStage
.
PREFILL
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment