Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fb2c1c86
Unverified
Commit
fb2c1c86
authored
Aug 02, 2024
by
Zach Zheng
Committed by
GitHub
Aug 02, 2024
Browse files
[Bugfix] Fix block table for seqs that have prefix cache hits (#7018)
parent
0c25435d
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
65 additions
and
3 deletions
+65
-3
tests/prefix_caching/test_prefix_caching.py
tests/prefix_caching/test_prefix_caching.py
+56
-0
vllm/attention/backends/flash_attn.py
vllm/attention/backends/flash_attn.py
+9
-3
No files found.
tests/prefix_caching/test_prefix_caching.py
View file @
fb2c1c86
...
@@ -6,10 +6,17 @@ from typing import List
...
@@ -6,10 +6,17 @@ from typing import List
import
pytest
import
pytest
from
tests.kernels.utils
import
override_backend_env_variable
from
vllm.block
import
PhysicalTokenBlock
from
vllm.block
import
PhysicalTokenBlock
from
vllm.core.block_manager_v1
import
CachedBlockAllocator
from
vllm.core.block_manager_v1
import
CachedBlockAllocator
from
vllm.utils
import
Device
from
vllm.utils
import
Device
from
..models.utils
import
check_outputs_equal
MODELS
=
[
"facebook/opt-125m"
,
]
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"block_size"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
[
16
])
@
pytest
.
mark
.
parametrize
(
"num_blocks"
,
[
16
])
...
@@ -76,3 +83,52 @@ def test_eviction(num_blocks: int, ):
...
@@ -76,3 +83,52 @@ def test_eviction(num_blocks: int, ):
assert
(
realloc_block
!=
new_block
)
assert
(
realloc_block
!=
new_block
)
assert
(
new_block
.
block_hash
==
new_block_hash
)
assert
(
new_block
.
block_hash
==
new_block_hash
)
assert
(
new_block
.
block_number
==
2
)
assert
(
new_block
.
block_number
==
2
)
@
pytest
.
mark
.
parametrize
(
"model"
,
MODELS
)
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"FLASH_ATTN"
,
"FLASHINFER"
,
"XFORMERS"
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
"half"
])
@
pytest
.
mark
.
parametrize
(
"max_tokens"
,
[
5
])
@
pytest
.
mark
.
parametrize
(
"cached_position"
,
[
0
,
1
])
@
pytest
.
mark
.
parametrize
(
"use_v2_block_manager"
,
[
False
,
True
])
def
test_mixed_requests
(
hf_runner
,
vllm_runner
,
example_prompts
,
model
:
str
,
backend
:
str
,
dtype
:
str
,
max_tokens
:
int
,
cached_position
:
int
,
use_v2_block_manager
:
bool
,
monkeypatch
,
)
->
None
:
"""
Test the case where some sequences have a prefix cache hit
and the others don't. The cached position determines where
the cache-hit sequence sits within the batch of prefills.
"""
override_backend_env_variable
(
monkeypatch
,
backend
)
with
hf_runner
(
model
,
dtype
=
dtype
)
as
hf_model
:
hf_outputs
=
hf_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
cached_prompt
=
example_prompts
[
cached_position
]
with
vllm_runner
(
model
,
dtype
=
dtype
,
enable_prefix_caching
=
True
,
use_v2_block_manager
=
use_v2_block_manager
,
)
as
vllm_model
:
# Run the cached prompt first so the cache is populated
vllm_outputs
=
vllm_model
.
generate_greedy
([
cached_prompt
],
max_tokens
)
# Run all the prompts
vllm_outputs
=
vllm_model
.
generate_greedy
(
example_prompts
,
max_tokens
)
check_outputs_equal
(
outputs_0_lst
=
hf_outputs
,
outputs_1_lst
=
vllm_outputs
,
name_0
=
"hf"
,
name_1
=
"vllm"
,
)
vllm/attention/backends/flash_attn.py
View file @
fb2c1c86
...
@@ -209,6 +209,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -209,6 +209,7 @@ class FlashAttentionMetadataBuilder(
self
.
num_prefills
=
0
self
.
num_prefills
=
0
self
.
num_prefill_tokens
=
0
self
.
num_prefill_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
num_decode_tokens
=
0
self
.
has_prefix_cache_hit
=
False
self
.
input_builder
=
input_builder
self
.
input_builder
=
input_builder
self
.
runner
=
input_builder
.
runner
self
.
runner
=
input_builder
.
runner
...
@@ -219,7 +220,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -219,7 +220,7 @@ class FlashAttentionMetadataBuilder(
def
_add_seq_group
(
def
_add_seq_group
(
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
self
,
inter_data
:
"ModelInputForGPUBuilder.InterDataForSeqGroup"
,
chunked_prefill_enabled
:
bool
):
chunked_prefill_enabled
:
bool
,
prefix_cache_hit
:
bool
):
"""Add a sequence group to the metadata. Specifically update/append
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
1. context length.
2. block table.
2. block table.
...
@@ -252,7 +253,7 @@ class FlashAttentionMetadataBuilder(
...
@@ -252,7 +253,7 @@ class FlashAttentionMetadataBuilder(
# only allowing multiple of block_size chunk size.
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
# NOTE: This only works for oooooooxxx style attention.
block_table
=
[]
block_table
=
[]
if
inter_data
.
prefix_cache_hit
:
if
prefix_cache_hit
:
# NOTE(woosuk): For flash-attn, the block table should
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
# include the entries for the incoming prefill tokens.
block_table
=
block_tables
[
seq_id
]
block_table
=
block_tables
[
seq_id
]
...
@@ -281,9 +282,14 @@ class FlashAttentionMetadataBuilder(
...
@@ -281,9 +282,14 @@ class FlashAttentionMetadataBuilder(
-1 if cuda graph is not used.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
batch_size: The maybe padded batch size.
"""
"""
prefix_cache_hit
=
any
([
inter_data
.
prefix_cache_hit
for
inter_data
in
self
.
input_builder
.
inter_data_list
])
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
for
inter_data
in
self
.
input_builder
.
inter_data_list
:
self
.
_add_seq_group
(
inter_data
,
self
.
_add_seq_group
(
inter_data
,
self
.
input_builder
.
chunked_prefill_enabled
)
self
.
input_builder
.
chunked_prefill_enabled
,
prefix_cache_hit
)
device
=
self
.
runner
.
device
device
=
self
.
runner
.
device
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
use_captured_graph
=
cuda_graph_pad_size
!=
-
1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment