Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b35cc934
Unverified
Commit
b35cc934
authored
Mar 08, 2024
by
ElizaWszola
Committed by
GitHub
Mar 07, 2024
Browse files
Fix auto prefix bug (#3239)
parent
8cbba462
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
51 additions
and
12 deletions
+51
-12
tests/engine/test_computed_prefix_blocks.py
tests/engine/test_computed_prefix_blocks.py
+34
-0
vllm/core/block_manager.py
vllm/core/block_manager.py
+16
-12
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+1
-0
No files found.
tests/engine/test_computed_prefix_blocks.py
0 → 100644
View file @
b35cc934
import
pytest
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.sampling_params
import
SamplingParams
@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("block_size", [16])
def test_computed_prefix_blocks(model: str, block_size: int):
    """Smoke-test prefix caching when a later prompt is a prefix of an earlier one.

    Two requests are submitted to an engine built with prefix caching
    enabled, and the engine is stepped once after each submission. The
    first request carries the long prompt (shared prefix + suffix); the
    second carries only the shared prefix. The test makes no explicit
    assertions: it passes as long as neither ``step()`` call raises.

    Intended scenario (per the original author's note; depends on how the
    prompts tokenize for the given block size): every block of the second
    request's prompt is full and already computed when that request
    arrives.
    """
    shared_prefix = (
        "You are a helpful assistant. How do I build a car from cardboard and "
        "paper clips? Is there an easy to follow video tutorial available "
        "online for free?")
    extra_suffix = (
        " Please recommend to me some resources where I can learn not only to "
        "handle technical difficulties of building a car, but also "
        "decoration.")

    engine = LLMEngine.from_engine_args(
        EngineArgs(model=model,
                   block_size=block_size,
                   enable_prefix_caching=True))
    default_params = SamplingParams()

    # Submit the long prompt first so that its blocks exist before the
    # prefix-only request is scheduled; step once after each submission.
    requests = (
        ("0", shared_prefix + extra_suffix),
        ("1", shared_prefix),
    )
    for request_id, text in requests:
        engine.add_request(request_id, text, default_params)
        engine.step()
vllm/core/block_manager.py
View file @
b35cc934
"""A block manager that manages token blocks."""
"""A block manager that manages token blocks."""
import
enum
import
enum
from
itertools
import
count
from
itertools
import
count
,
takewhile
from
os.path
import
commonprefix
from
os.path
import
commonprefix
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
from
typing
import
Dict
,
List
,
Optional
,
Set
,
Tuple
...
@@ -426,23 +426,29 @@ class BlockSpaceManager:
...
@@ -426,23 +426,29 @@ class BlockSpaceManager:
for
block
in
block_table
:
for
block
in
block_table
:
block
.
last_accessed
=
access_time
block
.
last_accessed
=
access_time
def
compute_
last_
full_block_in_seq
(
self
,
seq
:
Sequence
):
def
compute_full_block
s
_in_seq
(
self
,
seq
:
Sequence
):
if
seq
.
seq_id
not
in
self
.
block_tables
:
if
seq
.
seq_id
not
in
self
.
block_tables
:
return
return
max_full_block
=
seq
.
get_len
()
//
self
.
block_size
-
1
max_full_block
=
seq
.
get_len
()
//
self
.
block_size
-
1
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
if
max_full_block
==
-
1
:
if
max_full_block
==
-
1
:
return
return
block_table
[
max_full_block
].
computed
=
True
for
i
in
reversed
(
range
(
max_full_block
)):
if
block_table
[
i
].
computed
:
break
block_table
[
i
].
computed
=
True
def
get_all_
block_ids_till_
computed
(
self
,
seq
:
Sequence
)
->
List
[
int
]:
def
get_all_computed
_blocks
(
self
,
seq
:
Sequence
)
->
List
[
int
]:
if
seq
.
seq_id
not
in
self
.
block_tables
:
if
seq
.
seq_id
not
in
self
.
block_tables
:
return
[]
return
[]
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
block_table
=
self
.
block_tables
[
seq
.
seq_id
]
for
block_idx
in
reversed
(
range
(
len
(
block_table
))):
# NOTE We exclude the last block to avoid the case where the entire
if
block_table
[
block_idx
].
computed
:
# prompt is cached. This would cause erroneous behavior in model
return
[
b
.
block_number
for
b
in
block_table
[:
block_idx
+
1
]]
# runner.
return
[]
return
[
b
.
block_number
for
b
in
takewhile
(
lambda
b
:
b
.
computed
,
block_table
[:
-
1
])
]
def
get_common_computed_block_ids
(
self
,
def
get_common_computed_block_ids
(
self
,
seq_group
:
SequenceGroup
)
->
List
[
int
]:
seq_group
:
SequenceGroup
)
->
List
[
int
]:
...
@@ -451,14 +457,12 @@ class BlockSpaceManager:
...
@@ -451,14 +457,12 @@ class BlockSpaceManager:
return
[]
return
[]
ids_list
=
[
ids_list
=
[
self
.
get_all_
block_ids_till_
computed
(
seq
)
self
.
get_all_computed
_blocks
(
seq
)
for
seq
in
iter
(
seq_group
.
seqs_dict
.
values
())
for
seq
in
iter
(
seq_group
.
seqs_dict
.
values
())
]
]
return
commonprefix
([
ids
for
ids
in
ids_list
if
ids
!=
[]])
return
commonprefix
([
ids
for
ids
in
ids_list
if
ids
!=
[]])
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
):
def
mark_blocks_as_computed
(
self
,
seq_group
:
SequenceGroup
):
# NOTE: We only mark the last full block because with prefix caching,
# all blocks until the marked one are guaranteed to be computed.
if
self
.
enable_caching
:
if
self
.
enable_caching
:
for
seq
in
seq_group
.
seqs_dict
.
values
():
for
seq
in
seq_group
.
seqs_dict
.
values
():
self
.
compute_
last_
full_block_in_seq
(
seq
)
self
.
compute_full_block
s
_in_seq
(
seq
)
vllm/worker/model_runner.py
View file @
b35cc934
...
@@ -215,6 +215,7 @@ class ModelRunner:
...
@@ -215,6 +215,7 @@ class ModelRunner:
slot_mapping
[
-
1
].
append
(
slot
)
slot_mapping
[
-
1
].
append
(
slot
)
max_prompt_len
=
max
(
subquery_lens
)
max_prompt_len
=
max
(
subquery_lens
)
assert
max_prompt_len
>
0
input_tokens
=
_make_tensor_with_pad
(
input_tokens
,
input_tokens
=
_make_tensor_with_pad
(
input_tokens
,
max_prompt_len
,
max_prompt_len
,
pad
=
0
,
pad
=
0
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment