Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e3580537
Unverified
Commit
e3580537
authored
Aug 28, 2024
by
Cody Yu
Committed by
GitHub
Aug 28, 2024
Browse files
[Performance] Enable chunked prefill and prefix caching together (#7753)
parent
f508e03e
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
225 additions
and
27 deletions
+225
-27
tests/basic_correctness/test_chunked_prefill.py
tests/basic_correctness/test_chunked_prefill.py
+66
-0
tests/core/test_block_manager.py
tests/core/test_block_manager.py
+40
-0
tests/core/test_chunked_prefill_scheduler.py
tests/core/test_chunked_prefill_scheduler.py
+39
-0
vllm/core/block_manager_v1.py
vllm/core/block_manager_v1.py
+13
-6
vllm/core/block_manager_v2.py
vllm/core/block_manager_v2.py
+2
-1
vllm/core/embedding_model_block_manager.py
vllm/core/embedding_model_block_manager.py
+2
-1
vllm/core/interfaces.py
vllm/core/interfaces.py
+2
-1
vllm/core/scheduler.py
vllm/core/scheduler.py
+24
-6
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+37
-12
No files found.
tests/basic_correctness/test_chunked_prefill.py
View file @
e3580537
...
@@ -6,6 +6,7 @@ prefill requests are chunked.
...
@@ -6,6 +6,7 @@ prefill requests are chunked.
Run `pytest tests/basic_correctness/test_chunked_prefill.py`.
Run `pytest tests/basic_correctness/test_chunked_prefill.py`.
"""
"""
from
contextlib
import
nullcontext
import
pytest
import
pytest
...
@@ -156,3 +157,68 @@ def test_models_with_fp8_kv_cache(
...
@@ -156,3 +157,68 @@ def test_models_with_fp8_kv_cache(
name_0
=
"no_chunked_prefill"
,
name_0
=
"no_chunked_prefill"
,
name_1
=
"chunked_prefill"
,
name_1
=
"chunked_prefill"
,
)
)
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"enforce_eager"
,
[
False
])
@
pytest
.
mark
.
parametrize
(
"chunk_size"
,
[
30
,
32
])
@
pytest
.
mark
.
parametrize
(
"use_v2_block_manager"
,
[
False
,
True
])
# NOTE: Increasing this in this suite will fail CI because we currently cannot
# reset distributed env properly. Use a value > 1 just when you test.
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
[
1
])
def
test_with_prefix_caching
(
vllm_runner
,
max_tokens
:
int
,
enforce_eager
:
bool
,
chunk_size
:
int
,
use_v2_block_manager
:
bool
,
tensor_parallel_size
:
int
,
)
->
None
:
"""
Checks exact match decode with and without prefix caching
with chunked prefill enabled.
"""
model
=
"meta-llama/Llama-2-7b-chat-hf"
# The common prompt has 142 tokens with Llama-2 tokenizer.
common_prompt
=
"You are a helpful AI assistant "
*
20
unique_prompts
=
[
"Question"
,
# Warmup
"Question"
,
# Fully cached
"Another question"
,
# Partial cached
]
full_prompts
=
[
f
"
{
common_prompt
}
\n
{
p
}
"
for
p
in
unique_prompts
]
max_num_batched_tokens
=
max_num_seqs
=
chunk_size
outputs
=
{}
# type: ignore
check_result
=
True
for
enable
in
(
True
,
False
):
with
vllm_runner
(
model
,
dtype
=
"half"
,
max_num_batched_tokens
=
max_num_batched_tokens
,
enable_chunked_prefill
=
True
,
enable_prefix_caching
=
enable
,
tensor_parallel_size
=
tensor_parallel_size
,
use_v2_block_manager
=
use_v2_block_manager
,
enforce_eager
=
enforce_eager
,
max_num_seqs
=
max_num_seqs
,
)
as
vllm_model
:
# It should fail when prefix caching is enabled and the chunk
# size is not a multiple of the block size (16).
should_fail
=
chunk_size
%
16
!=
0
and
enable
check_result
&=
not
should_fail
outputs
[
enable
]
=
[]
# Send the request one-by-one to ensure the cache is populated.
with
pytest
.
raises
(
ValueError
)
if
should_fail
else
nullcontext
():
for
prompt
in
full_prompts
:
outputs
[
enable
]
+=
vllm_model
.
generate_greedy
([
prompt
],
max_tokens
)
# Check results only if we did not expect a failure.
if
check_result
:
check_outputs_equal
(
outputs_0_lst
=
outputs
[
False
],
outputs_1_lst
=
outputs
[
True
],
name_0
=
"w/o prefix caching"
,
name_1
=
"with prefix caching"
,
)
tests/core/test_block_manager.py
View file @
e3580537
...
@@ -595,3 +595,43 @@ def test_sliding_window_multi_seq():
...
@@ -595,3 +595,43 @@ def test_sliding_window_multi_seq():
# assert all blocks are free now
# assert all blocks are free now
assert
block_manager
.
get_num_free_gpu_blocks
()
==
num_gpu_blocks
assert
block_manager
.
get_num_free_gpu_blocks
()
==
num_gpu_blocks
def
test_mark_blocks_as_computed_with_prefix_cache_and_chunked_prefill
():
"""When prefix cache and chunked prefill are enabled, the block manager
should only mark a chunk of blocks as computed instead of all blocks.
"""
block_size
=
4
num_cpu_blocks
=
0
num_gpu_blocks
=
16
block_manager
=
BlockSpaceManagerV1
(
block_size
,
num_gpu_blocks
,
num_cpu_blocks
,
watermark
=
0
,
enable_caching
=
True
)
# Set prompt size to have num_gpu_blocks - 1 full blocks.
prompt_length
=
block_size
*
num_gpu_blocks
-
1
# Allocate (reserve) all blocks.
_
,
seq_group
=
create_dummy_prompt
(
"0"
,
prompt_length
,
block_size
=
block_size
)
block_manager
.
allocate
(
seq_group
)
assert
seq_group
.
seqs
[
0
].
n_blocks
==
num_gpu_blocks
# 1st chunk: Compute 2 and a half blocks. Only the 2 full blocks should be marked as computed.
token_chunk_size
=
int
(
block_size
*
2.5
)
block_manager
.
mark_blocks_as_computed
(
seq_group
,
token_chunk_size
)
computed_blocks
=
block_manager
.
get_all_computed_blocks
(
seq_group
.
seqs
[
0
])
assert
len
(
computed_blocks
)
==
2
# Actual computed tokens.
seq_group
.
seqs
[
0
].
data
.
update_num_computed_tokens
(
token_chunk_size
)
# 2nd chunk: Complete 3rd block and additional 4 blocks.
token_chunk_size
=
int
(
block_size
*
4.5
)
block_manager
.
mark_blocks_as_computed
(
seq_group
,
token_chunk_size
)
computed_blocks
=
block_manager
.
get_all_computed_blocks
(
seq_group
.
seqs
[
0
])
assert
len
(
computed_blocks
)
==
7
tests/core/test_chunked_prefill_scheduler.py
View file @
e3580537
...
@@ -562,3 +562,42 @@ def test_chunked_prefill_max_seqs():
...
@@ -562,3 +562,42 @@ def test_chunked_prefill_max_seqs():
assert
len
(
get_sequence_groups
(
out
))
==
max_seqs
assert
len
(
get_sequence_groups
(
out
))
==
max_seqs
assert
not
running
[
0
].
is_prefill
()
assert
not
running
[
0
].
is_prefill
()
assert
not
running
[
1
].
is_prefill
()
assert
not
running
[
1
].
is_prefill
()
def
test_perfix_caching
():
"""Verify allocating full blocks when prefix caching is enabled."""
block_size
=
4
max_seqs
=
10
max_model_len
=
80
max_num_batched_tokens
=
64
scheduler_config
=
SchedulerConfig
(
max_num_batched_tokens
,
max_seqs
,
max_model_len
,
enable_chunked_prefill
=
True
)
cache_config
=
CacheConfig
(
block_size
,
1.0
,
1
,
"auto"
,
enable_prefix_caching
=
True
)
cache_config
.
num_cpu_blocks
=
0
cache_config
.
num_gpu_blocks
=
32
scheduler
=
Scheduler
(
scheduler_config
,
cache_config
,
None
)
running
:
List
[
SequenceGroup
]
=
[]
# Add seq groups to scheduler.
for
i
in
range
(
2
):
_
,
seq_group
=
create_dummy_prompt
(
str
(
i
),
block_size
=
block_size
,
prompt_length
=
50
)
scheduler
.
add_seq_group
(
seq_group
)
running
.
append
(
seq_group
)
seq_group_meta
,
out
=
schedule_and_update_computed_tokens
(
scheduler
)
assert
set
(
get_sequence_groups
(
out
))
==
set
(
running
)
assert
seq_group_meta
[
0
].
token_chunk_size
==
50
# Verify it is chunked. Note that although the budget is 64-50=14,
# we only allocate full blocks for prefix caching, so only 4*(14//4)=12
# tokens are allocated.
assert
seq_group_meta
[
1
].
token_chunk_size
==
12
assert
out
.
num_prefill_groups
==
2
assert
out
.
num_batched_tokens
==
62
vllm/core/block_manager_v1.py
View file @
e3580537
...
@@ -681,14 +681,20 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -681,14 +681,20 @@ class BlockSpaceManagerV1(BlockSpaceManager):
for
block
in
block_table
:
for
block
in
block_table
:
block
.
last_accessed
=
access_time
block
.
last_accessed
=
access_time
def
compute_full_blocks_in_seq
(
self
,
seq
:
Sequence
):
def
compute_full_blocks_in_seq
(
self
,
seq
:
Sequence
,
token_chunk_size
:
int
):
if
seq
.
seq_id
not
in
self
.
block_tables
:
if
seq
.
seq_id
not
in
self
.
block_tables
:
return
return
max_full_block
=
seq
.
get_len
()
//
self
.
block_size
-
1
# When chunked prefill is enabled, the computed full blocks
# should be calculated based on the number of computed tokens.
max_computed_tokens
=
(
seq
.
data
.
get_num_computed_tokens
()
+
token_chunk_size
)
computed_full_blocks
=
max_computed_tokens
//
self
.
block_size
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
if
max
_full_block
==
-
1
:
if
computed
_full_block
s
==
0
:
return
return
for
i
in
reversed
(
range
(
max
_full_block
)):
for
i
in
reversed
(
range
(
computed
_full_block
s
)):
if
block_table
[
i
].
computed
:
if
block_table
[
i
].
computed
:
break
break
block_table
[
i
].
computed
=
True
block_table
[
i
].
computed
=
True
...
@@ -718,10 +724,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
...
@@ -718,10 +724,11 @@ class BlockSpaceManagerV1(BlockSpaceManager):
ids_list
=
[
self
.
get_all_computed_blocks
(
seq
)
for
seq
in
seqs
]
ids_list
=
[
self
.
get_all_computed_blocks
(
seq
)
for
seq
in
seqs
]
return
commonprefix
([
ids
for
ids
in
ids_list
if
ids
!=
[]])
return
commonprefix
([
ids
for
ids
in
ids_list
if
ids
!=
[]])
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
):
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
,
token_chunk_size
:
int
):
if
self
.
enable_caching
:
if
self
.
enable_caching
:
for
seq
in
seq_group
.
get_seqs
():
for
seq
in
seq_group
.
get_seqs
():
self
.
compute_full_blocks_in_seq
(
seq
)
self
.
compute_full_blocks_in_seq
(
seq
,
token_chunk_size
)
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
if
device
==
Device
.
GPU
:
if
device
==
Device
.
GPU
:
...
...
vllm/core/block_manager_v2.py
View file @
e3580537
...
@@ -290,7 +290,8 @@ class BlockSpaceManagerV2(BlockSpaceManager):
...
@@ -290,7 +290,8 @@ class BlockSpaceManagerV2(BlockSpaceManager):
self
.
_last_access_blocks_tracker
.
update_last_access
(
self
.
_last_access_blocks_tracker
.
update_last_access
(
seq
.
seq_id
,
now
)
seq
.
seq_id
,
now
)
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
):
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
,
token_chunk_size
:
int
):
# If prefix caching is enabled, mark immutable blocks as computed
# If prefix caching is enabled, mark immutable blocks as computed
# right after they have been scheduled (for prefill). This assumes
# right after they have been scheduled (for prefill). This assumes
# the scheduler is synchronous so blocks are actually computed when
# the scheduler is synchronous so blocks are actually computed when
...
...
vllm/core/embedding_model_block_manager.py
View file @
e3580537
...
@@ -80,7 +80,8 @@ class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
...
@@ -80,7 +80,8 @@ class EmbeddingModelBlockSpaceManager(BlockSpaceManager):
seq_group
:
List
[
Sequence
])
->
List
[
int
]:
seq_group
:
List
[
Sequence
])
->
List
[
int
]:
return
[]
return
[]
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
):
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
,
token_chunk_size
:
int
):
pass
pass
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
def
get_prefix_cache_hit_rate
(
self
,
device
:
Device
)
->
float
:
...
...
vllm/core/interfaces.py
View file @
e3580537
...
@@ -115,7 +115,8 @@ class BlockSpaceManager(ABC):
...
@@ -115,7 +115,8 @@ class BlockSpaceManager(ABC):
pass
pass
@
abstractmethod
@
abstractmethod
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
):
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
,
token_chunk_size
:
int
):
pass
pass
@
abstractmethod
@
abstractmethod
...
...
vllm/core/scheduler.py
View file @
e3580537
...
@@ -1226,7 +1226,8 @@ class Scheduler:
...
@@ -1226,7 +1226,8 @@ class Scheduler:
# will crash the vLLM instance / will not retry.
# will crash the vLLM instance / will not retry.
for
scheduled_seq_group
in
scheduler_outputs
.
scheduled_seq_groups
:
for
scheduled_seq_group
in
scheduler_outputs
.
scheduled_seq_groups
:
self
.
block_manager
.
mark_blocks_as_computed
(
self
.
block_manager
.
mark_blocks_as_computed
(
scheduled_seq_group
.
seq_group
)
scheduled_seq_group
.
seq_group
,
scheduled_seq_group
.
token_chunk_size
)
self
.
_seq_group_metadata_cache
[
self
.
next_cache_id
].
reset
()
self
.
_seq_group_metadata_cache
[
self
.
next_cache_id
].
reset
()
...
@@ -1457,10 +1458,27 @@ class Scheduler:
...
@@ -1457,10 +1458,27 @@ class Scheduler:
for
seq
in
seqs
:
for
seq
in
seqs
:
num_new_tokens
+=
seq
.
get_num_new_tokens
()
num_new_tokens
+=
seq
.
get_num_new_tokens
()
assert
num_new_tokens
>
0
assert
num_new_tokens
>
0
# Chunk if a running request cannot fit in.
# Chunk if a running request cannot fit in
the given budget
.
# If number of seq > 1, it means it is doing beam search
in a
# If number of seq > 1, it means it is doing beam search
# decode phase. Do not chunk
in that case
.
#
in a
decode phase. Do not chunk.
if
enable_chunking
and
len
(
seqs
)
==
1
:
if
enable_chunking
and
len
(
seqs
)
==
1
:
num_new_tokens
=
min
(
num_new_tokens
,
remaining_token_budget
=
budget
.
remaining_token_budget
()
budget
.
remaining_token_budget
())
if
self
.
cache_config
.
enable_prefix_caching
:
# When prefix caching is enabled, we always allocate
# a number of new tokens that is divisible by the block size
# to avoid partial block matching.
block_size
=
self
.
cache_config
.
block_size
reminder
=
budget
.
token_budget
%
block_size
if
reminder
!=
0
:
raise
ValueError
(
"When enabling chunked prefill and "
"prefix caching, max_num_batched_tokens "
"(chunk size) must be dividable by "
"block size, but got chunk_size "
f
"(
{
budget
.
token_budget
}
) % block_size "
f
"(
{
block_size
}
) =
{
reminder
}
"
)
if
remaining_token_budget
<
num_new_tokens
:
num_new_tokens
=
(
remaining_token_budget
//
block_size
)
*
block_size
else
:
num_new_tokens
=
min
(
num_new_tokens
,
remaining_token_budget
)
return
num_new_tokens
return
num_new_tokens
vllm/worker/model_runner.py
View file @
e3580537
...
@@ -501,23 +501,48 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
...
@@ -501,23 +501,48 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
and
self
.
sliding_window
is
None
and
self
.
sliding_window
is
None
and
inter_data
.
is_prompt
)
and
inter_data
.
is_prompt
)
inter_data
.
prefix_cache_hit
=
prefix_cache_hit
inter_data
.
prefix_cache_hit
=
prefix_cache_hit
if
self
.
chunked_prefill_enabled
and
prefix_cache_hit
:
raise
RuntimeError
(
"chunked prefill cannot be used with prefix caching now."
)
# If prefix cache is hit, advance context length to bypass
if
not
prefix_cache_hit
:
# hit blocks. Accordingly, input tokens, position and query length
return
# have to be updated.
if
prefix_cache_hit
:
assert
computed_block_nums
is
not
None
assert
computed_block_nums
is
not
None
context_len
=
len
(
computed_block_nums
)
*
self
.
block_size
# The cache hit prompt tokens in this sequence. Note that
# this may be larger than the sequence length if chunked
# prefill is enabled.
prefix_cache_len
=
len
(
computed_block_nums
)
*
self
.
block_size
# The number of so far computed prompt tokens in this sequence.
context_len
=
inter_data
.
context_lens
[
seq_idx
]
# The total number of prompt tokens in this sequence.
# When chunked prefill is enabled, this is the token number of
# computed chunks + current chunk.
seq_len
=
inter_data
.
seq_lens
[
seq_idx
]
if
prefix_cache_len
<=
context_len
:
# We already passed the cache hit region,
# so do normal computation.
pass
elif
context_len
<
prefix_cache_len
<
seq_len
:
# Partial hit. Compute the missing part.
uncomputed_start
=
prefix_cache_len
-
context_len
inter_data
.
input_tokens
[
seq_idx
]
=
inter_data
.
input_tokens
[
inter_data
.
input_tokens
[
seq_idx
]
=
inter_data
.
input_tokens
[
seq_idx
][
context_len
:]
seq_idx
][
uncomputed_start
:]
inter_data
.
input_positions
[
seq_idx
]
=
inter_data
.
input_positions
[
inter_data
.
input_positions
[
seq_idx
]
=
inter_data
.
input_positions
[
seq_idx
][
context_len
:]
seq_idx
][
uncomputed_start
:]
context_len
=
prefix_cache_len
inter_data
.
context_lens
[
seq_idx
]
=
context_len
inter_data
.
context_lens
[
seq_idx
]
=
context_len
inter_data
.
query_lens
[
inter_data
.
query_lens
[
seq_idx
]
=
inter_data
.
seq_lens
[
seq_idx
]
-
context_len
seq_idx
]
=
inter_data
.
seq_lens
[
seq_idx
]
-
context_len
elif
seq_len
<=
prefix_cache_len
:
# Full hit. Only compute the last token to avoid
# erroneous behavior. FIXME: Ideally we should directly
# mark all tokens as computed in the scheduler and do not
# schedule this sequence, so this case should not happen.
inter_data
.
input_tokens
[
seq_idx
]
=
inter_data
.
input_tokens
[
seq_idx
][
-
1
:]
inter_data
.
input_positions
[
seq_idx
]
=
inter_data
.
input_positions
[
seq_idx
][
-
1
:]
inter_data
.
query_lens
[
seq_idx
]
=
1
inter_data
.
context_lens
[
seq_idx
]
=
inter_data
.
seq_lens
[
seq_idx
]
-
1
def
_compute_for_sliding_window
(
self
,
inter_data
:
InterDataForSeqGroup
,
def
_compute_for_sliding_window
(
self
,
inter_data
:
InterDataForSeqGroup
,
seq_idx
:
int
,
seq_idx
:
int
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment