Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bf8717eb
Unverified
Commit
bf8717eb
authored
Dec 17, 2024
by
Cody Yu
Committed by
GitHub
Dec 17, 2024
Browse files
[V1] Prefix caching for vision language models (#11187)
Signed-off-by:
Cody Yu
<
hao.yu.cody@gmail.com
>
parent
c77eb8a3
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
342 additions
and
98 deletions
+342
-98
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+86
-2
tests/v1/engine/test_engine_args.py
tests/v1/engine/test_engine_args.py
+0
-15
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+14
-13
vllm/inputs/data.py
vllm/inputs/data.py
+20
-0
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+3
-0
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+50
-24
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+105
-10
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+2
-0
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+7
-3
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+4
-4
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+6
-3
vllm/v1/engine/mm_input_mapper.py
vllm/v1/engine/mm_input_mapper.py
+15
-18
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+7
-5
vllm/v1/request.py
vllm/v1/request.py
+23
-1
No files found.
tests/v1/core/test_prefix_caching.py
View file @
bf8717eb
...
...
@@ -2,16 +2,23 @@
import
pytest
from
vllm.inputs
import
token_inputs
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
cdiv
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
,
Request
from
vllm.v1.core.kv_cache_utils
import
KVCacheBlock
,
hash_block_tokens
def
make_request
(
request_id
,
prompt_token_ids
):
def
make_request
(
request_id
,
prompt_token_ids
,
mm_positions
=
None
,
mm_hashes
=
None
):
return
Request
(
request_id
=
request_id
,
inputs
=
token_inputs
(
prompt_token_ids
=
prompt_token_ids
),
inputs
=
token_inputs
(
prompt_token_ids
=
prompt_token_ids
,
multi_modal_placeholders
=
{
"image"
:
mm_positions
}
if
mm_positions
else
None
,
multi_modal_hashes
=
mm_hashes
),
sampling_params
=
SamplingParams
(
max_tokens
=
17
),
eos_token_id
=
100
,
arrival_time
=
0
,
...
...
@@ -38,6 +45,7 @@ def test_prefill():
all_token_ids
=
common_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
)
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
assert
len
(
req0
.
kv_block_hashes
)
==
3
assert
not
computed_blocks
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
...
...
@@ -61,6 +69,7 @@ def test_prefill():
unique_token_ids
=
[
3
]
*
5
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
req1
.
kv_block_hashes
)
==
3
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
computed_blocks
)
...
...
@@ -90,6 +99,7 @@ def test_prefill():
unique_token_ids
=
[
3
]
*
6
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
computed_block
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
req2
.
kv_block_hashes
)
==
3
assert
[
b
.
block_id
for
b
in
computed_block
]
==
[
0
,
1
,
2
]
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
computed_blocks
)
...
...
@@ -416,3 +426,77 @@ def test_cache_blocks():
)
assert
len
(
manager
.
cached_block_hash_to_block
)
==
3
assert
blocks
[
0
].
block_hash
is
not
None
def
test_mm_prefix_caching
():
"""
This tests that the multi-modal prefix caching is correct.
"""
manager
=
KVCacheManager
(
block_size
=
16
,
num_gpu_blocks
=
10
,
max_model_len
=
8192
,
sliding_window
=
None
,
enable_caching
=
True
,
num_preallocate_tokens
=
16
,
)
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
# [T,...,T, P0,...,P0], [P0,...,P0,T,...,T,P1,...,P1], [P1,...,P1]
common_token_ids
=
list
(
range
(
10
))
+
[
-
1
]
*
6
common_token_ids
+=
[
-
1
]
*
4
+
list
(
range
(
10
,
20
))
+
[
-
1
]
*
2
common_token_ids
+=
[
-
1
]
*
16
common_mm_positions
=
[
PlaceholderRange
(
offset
=
11
,
length
=
10
),
PlaceholderRange
(
offset
=
30
,
length
=
18
),
]
common_mm_hashes
=
[
"aaa"
,
"bbb"
]
# A unique image plus some text tokens.
unique_token_ids
=
[
-
1
]
*
7
+
[
100
]
*
4
all_token_ids
=
common_token_ids
+
unique_token_ids
mm_positions
=
common_mm_positions
+
[
PlaceholderRange
(
offset
=
48
,
length
=
7
)
]
mm_hashes
=
common_mm_hashes
+
[
"ccc"
]
req0
=
make_request
(
"0"
,
all_token_ids
,
mm_positions
=
mm_positions
,
mm_hashes
=
mm_hashes
)
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
# Completed block should have hashes with extra keys.
assert
not
computed_blocks
assert
len
(
req0
.
kv_block_hashes
)
==
3
assert
req0
.
kv_block_hashes
[
0
].
extra_keys
==
((
"aaa"
,
0
),
)
assert
req0
.
kv_block_hashes
[
1
].
extra_keys
==
((
"aaa"
,
5
),
(
"bbb"
,
0
))
assert
req0
.
kv_block_hashes
[
2
].
extra_keys
==
((
"bbb"
,
2
),
)
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
req0
.
num_computed_tokens
=
59
# Append slots without allocating a new block.
for
_
in
range
(
5
):
req0
.
append_output_token_ids
(
8
)
new_blocks
=
manager
.
append_slots
(
req0
,
5
)
assert
new_blocks
is
not
None
and
len
(
new_blocks
)
==
0
# The just completed block should have hashes with extra keys.
assert
len
(
req0
.
kv_block_hashes
)
==
4
assert
req0
.
kv_block_hashes
[
3
].
extra_keys
==
((
"ccc"
,
0
),
)
# Cache hit.
unique_token_ids
=
[
-
1
]
*
7
+
[
200
]
*
5
all_token_ids
=
common_token_ids
+
unique_token_ids
mm_positions
=
common_mm_positions
+
[
PlaceholderRange
(
offset
=
48
,
length
=
7
)
]
mm_hashes
=
common_mm_hashes
+
[
"ccc"
]
req1
=
make_request
(
"1"
,
all_token_ids
,
mm_positions
=
mm_positions
,
mm_hashes
=
mm_hashes
)
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
computed_blocks
)
==
3
tests/v1/engine/test_engine_args.py
View file @
bf8717eb
...
...
@@ -31,14 +31,6 @@ def test_prefix_caching_from_cli():
assert
engine_args
.
enable_prefix_caching
def
test_defaults
():
engine_args
=
EngineArgs
(
model
=
"facebook/opt-125m"
)
# Assert V1 defaults
assert
(
engine_args
.
enable_prefix_caching
),
"V1 turns on prefix caching by default"
def
test_defaults_with_usage_context
():
engine_args
=
EngineArgs
(
model
=
"facebook/opt-125m"
)
vllm_config
:
VllmConfig
=
engine_args
.
create_engine_config
(
...
...
@@ -52,10 +44,3 @@ def test_defaults_with_usage_context():
UsageContext
.
OPENAI_API_SERVER
)
assert
vllm_config
.
scheduler_config
.
max_num_seqs
==
1024
assert
vllm_config
.
scheduler_config
.
max_num_batched_tokens
==
2048
def
test_prefix_cache_disabled_with_multimodel
():
engine_args
=
EngineArgs
(
model
=
"llava-hf/llava-1.5-7b-hf"
)
vllm_config
=
engine_args
.
create_engine_config
(
UsageContext
.
LLM_CLASS
)
assert
not
vllm_config
.
cache_config
.
enable_prefix_caching
vllm/engine/arg_utils.py
View file @
bf8717eb
...
...
@@ -205,6 +205,7 @@ class EngineArgs:
# by user.
if
self
.
enable_prefix_caching
is
None
:
self
.
enable_prefix_caching
=
bool
(
envs
.
VLLM_USE_V1
)
# Override max_num_seqs if it's not set by user.
if
self
.
max_num_seqs
is
None
:
self
.
max_num_seqs
=
256
if
not
envs
.
VLLM_USE_V1
else
1024
...
...
@@ -1026,11 +1027,11 @@ class EngineArgs:
device_config
=
DeviceConfig
(
device
=
self
.
device
)
model_config
=
self
.
create_model_config
()
if
model_config
.
is_multimodal_model
:
if
self
.
enable_prefix_caching
:
logger
.
warning
(
"--enable-prefix-caching is currently not
"
"supported for multimodal models and
has been disabled."
)
if
(
model_config
.
is_multimodal_model
and
not
envs
.
VLLM_USE_V1
and
self
.
enable_prefix_caching
)
:
logger
.
warning
(
"--enable-prefix-caching is currently not "
"supported for multimodal models in v0 and
"
"
has been disabled."
)
self
.
enable_prefix_caching
=
False
cache_config
=
CacheConfig
(
...
...
@@ -1249,11 +1250,14 @@ class EngineArgs:
# When no user override, set the default values based on the usage
# context.
# TODO(woosuk): Tune the default values for different hardware.
if
self
.
max_num_batched_tokens
is
None
:
if
usage_context
==
UsageContext
.
LLM_CLASS
:
self
.
max_num_batched_tokens
=
8192
elif
usage_context
==
UsageContext
.
OPENAI_API_SERVER
:
self
.
max_num_batched_tokens
=
2048
default_max_num_batched_tokens
=
{
UsageContext
.
LLM_CLASS
:
8192
,
UsageContext
.
OPENAI_API_SERVER
:
2048
,
}
if
(
self
.
max_num_batched_tokens
is
None
and
usage_context
in
default_max_num_batched_tokens
):
self
.
max_num_batched_tokens
=
default_max_num_batched_tokens
[
usage_context
]
logger
.
warning
(
"Setting max_num_batched_tokens to %d for %s usage context."
,
self
.
max_num_batched_tokens
,
usage_context
.
value
)
...
...
@@ -1263,9 +1267,6 @@ class EngineArgs:
Override the EngineConfig's configs based on the usage context for V1.
"""
assert
envs
.
VLLM_USE_V1
,
"V1 is not enabled"
if
engine_config
.
model_config
.
is_multimodal_model
:
# TODO (ywang96): Enable APC by default when VLM supports it.
assert
not
engine_config
.
cache_config
.
enable_prefix_caching
@
dataclass
...
...
vllm/inputs/data.py
View file @
bf8717eb
...
...
@@ -162,6 +162,11 @@ class TokenInputs(TypedDict):
Placeholder ranges for the multi-modal data.
"""
multi_modal_hashes
:
NotRequired
[
List
[
str
]]
"""
The hashes of the multi-modal data.
"""
mm_processor_kwargs
:
NotRequired
[
Dict
[
str
,
Any
]]
"""
Optional multi-modal processor kwargs to be forwarded to the
...
...
@@ -177,6 +182,7 @@ def token_inputs(
prompt
:
Optional
[
str
]
=
None
,
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
multi_modal_inputs
:
Optional
[
"MultiModalKwargs"
]
=
None
,
multi_modal_hashes
:
Optional
[
List
[
str
]]
=
None
,
multi_modal_placeholders
:
Optional
[
"MultiModalPlaceholderDict"
]
=
None
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
TokenInputs
:
...
...
@@ -191,6 +197,8 @@ def token_inputs(
inputs
[
"multi_modal_data"
]
=
multi_modal_data
if
multi_modal_inputs
is
not
None
:
inputs
[
"multi_modal_inputs"
]
=
multi_modal_inputs
if
multi_modal_hashes
is
not
None
:
inputs
[
"multi_modal_hashes"
]
=
multi_modal_hashes
if
multi_modal_placeholders
is
not
None
:
inputs
[
"multi_modal_placeholders"
]
=
multi_modal_placeholders
if
mm_processor_kwargs
is
not
None
:
...
...
@@ -295,6 +303,18 @@ class SingletonInputsAdapter:
assert_never
(
inputs
)
@
cached_property
def
multi_modal_hashes
(
self
)
->
List
[
str
]:
inputs
=
self
.
inputs
if
inputs
[
"type"
]
==
"token"
:
return
inputs
.
get
(
"multi_modal_hashes"
,
[])
if
inputs
[
"type"
]
==
"multimodal"
:
return
inputs
.
get
(
"mm_hashes"
,
[])
assert_never
(
inputs
)
@
cached_property
def
multi_modal_placeholders
(
self
)
->
"MultiModalPlaceholderDict"
:
inputs
=
self
.
inputs
...
...
vllm/multimodal/inputs.py
View file @
bf8717eb
...
...
@@ -215,6 +215,9 @@ class MultiModalInputsV2(TypedDict):
mm_kwargs
:
MultiModalKwargs
"""Keyword arguments to be directly passed to the model after batching."""
mm_hashes
:
NotRequired
[
List
[
str
]]
"""The hashes of the multi-modal data."""
mm_placeholders
:
MultiModalPlaceholderDict
"""
For each modality, information about the placeholder tokens in
...
...
vllm/v1/core/kv_cache_manager.py
View file @
bf8717eb
...
...
@@ -4,7 +4,9 @@ from typing import Dict, Iterable, List, Optional
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.v1.core.kv_cache_utils
import
(
BlockHashType
,
FreeKVCacheBlockQueue
,
KVCacheBlock
,
hash_block_tokens
,
KVCacheBlock
,
generate_block_hash_extra_keys
,
hash_block_tokens
,
hash_request_tokens
)
from
vllm.v1.request
import
Request
...
...
@@ -83,10 +85,12 @@ class KVCacheManager:
computed_blocks
=
[]
# TODO(rickyx): potentially we could cache this so we don't have to
# recompute it every time.
block_hashes
=
hash_request_tokens
(
self
.
block_size
,
request
.
all_token_ids
)
# The block hashes for the request may already be computed
# if the request was preempted and resumed.
if
not
request
.
kv_block_hashes
:
request
.
set_kv_block_hashes
(
hash_request_tokens
(
self
.
block_size
,
request
))
block_hashes
=
request
.
kv_block_hashes
for
block_hash
in
block_hashes
:
# block_hashes is a chain of block hashes. If a block hash is not
...
...
@@ -242,14 +246,16 @@ class KVCacheManager:
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
num_full_blocks
=
(
num_computed_tokens
+
num_tokens
)
//
self
.
block_size
self
.
_cache_full_blocks
(
request
=
request
,
blk_start_idx
=
len
(
computed_blocks
),
# The new full blocks are the full blocks that are not computed.
full_blocks
=
self
.
req_to_blocks
[
request
.
request_id
]
[
len
(
computed_blocks
):
num_full_blocks
],
prev_block
=
computed_blocks
[
-
1
]
if
computed_blocks
else
None
,
)
new_full_blocks
=
self
.
req_to_blocks
[
request
.
request_id
][
len
(
computed_blocks
):
num_full_blocks
]
if
new_full_blocks
:
self
.
_cache_full_blocks
(
request
=
request
,
blk_start_idx
=
len
(
computed_blocks
),
# The new full blocks are the full blocks that are not computed.
full_blocks
=
new_full_blocks
,
prev_block
=
computed_blocks
[
-
1
]
if
computed_blocks
else
None
,
)
return
new_blocks
...
...
@@ -376,6 +382,8 @@ class KVCacheManager:
full_blocks: The list of blocks to update hash metadata.
prev_block: The previous block in the chain.
"""
num_cached_block_hashes
=
len
(
request
.
kv_block_hashes
)
# Update the new blocks with the block hashes through the chain.
prev_block_hash_value
=
None
if
prev_block
is
not
None
:
...
...
@@ -387,17 +395,35 @@ class KVCacheManager:
for
i
,
blk
in
enumerate
(
full_blocks
):
blk_idx
=
blk_start_idx
+
i
block_tokens
=
request
.
all_token_ids
[
blk_idx
*
self
.
block_size
:(
blk_idx
+
1
)
*
self
.
block_size
]
assert
len
(
block_tokens
)
==
self
.
block_size
,
(
f
"Expected
{
self
.
block_size
}
tokens, got
{
len
(
block_tokens
)
}
"
f
"at
{
blk_idx
}
th block for request "
f
"
{
request
.
request_id
}
(
{
request
}
)"
)
# Compute the hash of the current block.
block_hash
=
hash_block_tokens
(
prev_block_hash_value
,
block_tokens
)
if
blk_idx
<
num_cached_block_hashes
:
# The block hash may already be computed in
# "get_computed_blocks" if the tokens are not generated by
# this request (either the prompt tokens or the previously
# generated tokens with preemption). In this case we simply
# reuse the block hash.
block_hash
=
request
.
kv_block_hashes
[
blk_idx
]
else
:
# Otherwise compute the block hash and cache it in the request
# in case it will be preempted in the future.
start_token_idx
=
blk_idx
*
self
.
block_size
end_token_idx
=
(
blk_idx
+
1
)
*
self
.
block_size
block_tokens
=
request
.
all_token_ids
[
start_token_idx
:
end_token_idx
]
assert
len
(
block_tokens
)
==
self
.
block_size
,
(
f
"Expected
{
self
.
block_size
}
tokens, got "
f
"
{
len
(
block_tokens
)
}
at
{
blk_idx
}
th block for request "
f
"
{
request
.
request_id
}
(
{
request
}
)"
)
# Generate extra keys for multi-modal inputs. Note that since
# we reach this branch only when the block is completed with
# generated tokens, we only need to consider the last mm input.
extra_keys
,
_
=
generate_block_hash_extra_keys
(
request
,
start_token_idx
,
end_token_idx
,
-
1
)
# Compute the hash of the current block.
block_hash
=
hash_block_tokens
(
prev_block_hash_value
,
block_tokens
,
extra_keys
)
request
.
append_kv_block_hashes
(
block_hash
)
# Update and add the full block to the cache.
blk
.
block_hash
=
block_hash
...
...
vllm/v1/core/kv_cache_utils.py
View file @
bf8717eb
"""KV-Cache Utilities."""
from
collections.abc
import
Sequence
from
dataclasses
import
dataclass
from
typing
import
List
,
NamedTuple
,
Optional
,
Tuple
from
typing
import
Any
,
List
,
NamedTuple
,
Optional
,
Tuple
from
vllm.logger
import
init_logger
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
class
BlockHashType
(
NamedTuple
):
"""Hash value of a block
and the token IDs in the block.
The reason we keep a tuple of token IDs is to make sure no hash
collision happens when the hash value is the same.
"""Hash value of a block
(int), the token IDs in the block, and extra keys.
The reason we keep a tuple of token IDs and extra keys is to make sure no hash
collision happens when the hash value is the same.
"""
# Hash value of the block in an integer.
hash_value
:
int
# Token IDs in the block.
token_ids
:
Tuple
[
int
,
...]
# Extra keys for the block.
extra_keys
:
Optional
[
Any
]
=
None
@
dataclass
...
...
@@ -159,8 +164,80 @@ class FreeKVCacheBlockQueue:
return
ret
def
hash_block_tokens
(
parent_block_hash
:
Optional
[
int
],
curr_block_token_ids
:
Sequence
[
int
])
->
BlockHashType
:
def
generate_block_hash_extra_keys
(
request
:
Request
,
start_token_idx
:
int
,
end_token_idx
:
int
,
start_mm_idx
:
int
)
->
Tuple
[
Optional
[
Tuple
[
Any
,
...]],
int
]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA ID).
For multi-modal inputs, the extra keys are (mm_hash, start_offset) that
indicate a mm input contained in the block and its starting offset in
the block tokens.
Args:
request: The request object.
start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block.
start_mm_idx: The start multi-modal index of the block.
Returns:
A tuple of extra keys and the next multi-modal index.
"""
mm_positions
,
mm_hashes
=
request
.
mm_positions
,
request
.
mm_hashes
if
not
mm_positions
:
return
None
,
start_mm_idx
if
mm_positions
and
len
(
mm_positions
)
!=
len
(
mm_hashes
):
raise
ValueError
(
"The number of multi-modal positions and hashes must match. This "
"is likely because you do not enable MM preprocessor hashing. "
"Please set mm_cache_preprocessor=True."
)
# Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
if
mm_positions
[
-
1
][
"offset"
]
+
mm_positions
[
-
1
][
"length"
]
<
start_token_idx
:
return
None
,
start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
if
start_mm_idx
<
0
:
assert
-
start_mm_idx
<=
len
(
mm_positions
)
start_mm_idx
=
len
(
mm_positions
)
+
start_mm_idx
extra_keys
=
[]
curr_mm_idx
=
start_mm_idx
while
mm_positions
and
curr_mm_idx
<
len
(
mm_positions
):
assert
mm_hashes
[
curr_mm_idx
]
is
not
None
offset
=
mm_positions
[
curr_mm_idx
][
"offset"
]
length
=
mm_positions
[
curr_mm_idx
][
"length"
]
if
end_token_idx
>
offset
:
if
start_token_idx
>
offset
+
length
:
# This block has passed the current mm input.
curr_mm_idx
+=
1
continue
# The block contains the current mm input.
mm_start
=
max
(
0
,
start_token_idx
-
offset
)
extra_keys
.
append
((
mm_hashes
[
curr_mm_idx
],
mm_start
))
if
end_token_idx
>=
offset
+
length
:
# If this block contains the end of the current mm input,
# move to the next mm input as this block may also contain
# the next mm input.
curr_mm_idx
+=
1
else
:
# Otherwise this block is done with mm inputs.
break
else
:
# This block has not reached the current mm input.
break
return
tuple
(
extra_keys
),
curr_mm_idx
def
hash_block_tokens
(
parent_block_hash
:
Optional
[
int
],
curr_block_token_ids
:
Sequence
[
int
],
extra_keys
:
Optional
[
Tuple
[
Any
,
...]]
=
None
)
->
BlockHashType
:
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
...
...
@@ -174,27 +251,39 @@ def hash_block_tokens(parent_block_hash: Optional[int],
if this is the first block.
curr_block_token_ids: A list of token ids in the current
block. The current block is assumed to be full.
extra_keys: Extra keys for the block.
Returns:
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
"""
return
BlockHashType
(
hash
((
parent_block_hash
,
*
curr_block_token_ids
)),
tuple
(
curr_block_token_ids
))
tuple
(
curr_block_token_ids
)
,
extra_keys
)
def
hash_request_tokens
(
block_size
:
int
,
token_ids
:
S
eque
nce
[
int
]
)
->
List
[
BlockHashType
]:
request
:
R
eque
st
)
->
List
[
BlockHashType
]:
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
Args:
block_size: The size of each block.
token_ids: A sequence of token ids in the request.
request: The request object.
Returns:
The list of computed hash values.
"""
token_ids
=
request
.
all_token_ids
mm_positions
,
mm_hashes
=
request
.
mm_positions
,
request
.
mm_hashes
if
mm_positions
and
len
(
mm_positions
)
!=
len
(
mm_hashes
):
raise
ValueError
(
"The number of multi-modal positions and hashes must match."
)
# TODO: Extend this to support other features such as LoRA.
need_extra_keys
=
bool
(
mm_positions
)
extra_keys
=
None
curr_mm_idx
=
0
ret
=
[]
parent_block_hash_value
=
None
for
start
in
range
(
0
,
len
(
token_ids
),
block_size
):
...
...
@@ -203,8 +292,14 @@ def hash_request_tokens(block_size: int,
# Do not hash the block if it is not full.
if
len
(
block_token_ids
)
<
block_size
:
break
# Add extra keys if the block is a multi-modal block.
if
need_extra_keys
:
extra_keys
,
curr_mm_idx
=
generate_block_hash_extra_keys
(
request
,
start
,
end
,
curr_mm_idx
)
block_hash
=
hash_block_tokens
(
parent_block_hash_value
,
block_token_ids
)
block_token_ids
,
extra_keys
)
ret
.
append
(
block_hash
)
parent_block_hash_value
=
block_hash
.
hash_value
return
ret
vllm/v1/core/scheduler.py
View file @
bf8717eb
...
...
@@ -516,6 +516,7 @@ class NewRequestData:
prompt_token_ids
:
List
[
int
]
prompt
:
Optional
[
str
]
mm_inputs
:
List
[
"MultiModalKwargs"
]
mm_hashes
:
List
[
str
]
mm_positions
:
List
[
"PlaceholderRange"
]
sampling_params
:
SamplingParams
block_ids
:
List
[
int
]
...
...
@@ -533,6 +534,7 @@ class NewRequestData:
prompt_token_ids
=
request
.
prompt_token_ids
,
prompt
=
request
.
prompt
,
mm_inputs
=
request
.
mm_inputs
,
mm_hashes
=
request
.
mm_hashes
,
mm_positions
=
request
.
mm_positions
,
sampling_params
=
request
.
sampling_params
,
block_ids
=
block_ids
,
...
...
vllm/v1/engine/async_llm.py
View file @
bf8717eb
...
...
@@ -60,9 +60,13 @@ class AsyncLLM(EngineClient):
self
.
client_aborted_requests
:
List
[
str
]
=
[]
# Processor (converts Inputs --> EngineCoreRequests).
self
.
processor
=
Processor
(
vllm_config
.
model_config
,
vllm_config
.
lora_config
,
self
.
tokenizer
,
input_registry
)
self
.
processor
=
Processor
(
model_config
=
vllm_config
.
model_config
,
cache_config
=
vllm_config
.
cache_config
,
lora_config
=
vllm_config
.
lora_config
,
tokenizer
=
self
.
tokenizer
,
input_registry
=
input_registry
,
)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput).
self
.
detokenizer
=
Detokenizer
(
...
...
vllm/v1/engine/core.py
View file @
bf8717eb
...
...
@@ -65,7 +65,8 @@ class EngineCore:
self
.
_last_logging_time
=
time
.
time
()
self
.
mm_input_mapper_server
=
MMInputMapperServer
()
self
.
mm_input_mapper_server
=
MMInputMapperServer
(
vllm_config
.
model_config
)
def
_initialize_kv_caches
(
self
,
cache_config
:
CacheConfig
)
->
Tuple
[
int
,
int
]:
...
...
@@ -98,9 +99,8 @@ class EngineCore:
# MM mapper, so anything that has a hash must have a HIT cache
# entry here as well.
assert
request
.
mm_inputs
is
not
None
request
.
mm_inputs
,
request
.
mm_hashes
=
(
self
.
mm_input_mapper_server
.
process_inputs
(
request
.
mm_inputs
,
request
.
mm_hashes
))
request
.
mm_inputs
=
self
.
mm_input_mapper_server
.
process_inputs
(
request
.
mm_inputs
,
request
.
mm_hashes
)
req
=
Request
.
from_engine_core_request
(
request
)
...
...
vllm/v1/engine/llm_engine.py
View file @
bf8717eb
...
...
@@ -55,9 +55,12 @@ class LLMEngine:
self
.
tokenizer
.
ping
()
# Processor (convert Inputs --> EngineCoreRequests)
self
.
processor
=
Processor
(
vllm_config
.
model_config
,
vllm_config
.
lora_config
,
self
.
tokenizer
,
input_registry
,
mm_registry
)
self
.
processor
=
Processor
(
model_config
=
vllm_config
.
model_config
,
cache_config
=
vllm_config
.
cache_config
,
lora_config
=
vllm_config
.
lora_config
,
tokenizer
=
self
.
tokenizer
,
input_registry
=
input_registry
,
mm_registry
=
mm_registry
)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput)
self
.
detokenizer
=
Detokenizer
(
...
...
vllm/v1/engine/mm_input_mapper.py
View file @
bf8717eb
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
import
PIL
from
blake3
import
blake3
...
...
@@ -42,6 +42,8 @@ class MMInputMapperClient:
model_config
)
self
.
mm_registry
.
init_mm_limits_per_prompt
(
model_config
)
# Init cache
self
.
use_cache
=
model_config
.
mm_cache_preprocessor
self
.
mm_cache
=
LRUDictCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
# DEBUG: Set to None to disable
...
...
@@ -61,7 +63,7 @@ class MMInputMapperClient:
mm_hashes
:
Optional
[
List
[
str
]],
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]],
precomputed_mm_inputs
:
Optional
[
List
[
MultiModalKwargs
]],
)
->
Tuple
[
List
[
MultiModalKwargs
]
,
Optional
[
List
[
str
]]]
:
)
->
List
[
MultiModalKwargs
]:
if
precomputed_mm_inputs
is
None
:
image_inputs
=
mm_data
[
"image"
]
if
not
isinstance
(
image_inputs
,
list
):
...
...
@@ -70,26 +72,21 @@ class MMInputMapperClient:
else
:
num_inputs
=
len
(
precomputed_mm_inputs
)
# Check if hash is enabled
use_hash
=
mm_hashes
is
not
None
if
use_hash
:
# Sanity
if
self
.
use_cache
:
assert
mm_hashes
is
not
None
assert
num_inputs
==
len
(
mm_hashes
),
"num_inputs = {} len(mm_hashes) = {}"
.
format
(
num_inputs
,
len
(
mm_hashes
))
assert
num_inputs
==
len
(
mm_hashes
)
# Process each image input separately, so that later we can schedule
# them in a fine-grained manner.
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
ret_hashes
:
Optional
[
List
[
str
]]
=
[]
if
use_hash
else
None
ret_inputs
:
List
[
MultiModalKwargs
]
=
[]
for
input_id
in
range
(
num_inputs
):
if
self
.
mm_debug_cache_hit_ratio_steps
is
not
None
:
self
.
cache_hit_ratio
(
self
.
mm_debug_cache_hit_ratio_steps
)
mm_hash
=
None
mm_input
=
None
if
use_hash
:
if
self
.
use_cache
:
assert
mm_hashes
is
not
None
mm_hash
=
mm_hashes
[
input_id
]
mm_input
=
self
.
mm_cache
.
get
(
mm_hash
)
...
...
@@ -106,7 +103,7 @@ class MMInputMapperClient:
mm_processor_kwargs
=
mm_processor_kwargs
,
)
if
use_hash
:
if
self
.
use_cache
:
# Add to cache
assert
mm_hash
is
not
None
self
.
mm_cache
.
put
(
mm_hash
,
mm_input
)
...
...
@@ -114,18 +111,15 @@ class MMInputMapperClient:
self
.
mm_cache_hits
+=
1
mm_input
=
None
# Avoids sending mm_input to Server
if
use_hash
:
assert
mm_hash
is
not
None
assert
ret_hashes
is
not
None
ret_hashes
.
append
(
mm_hash
)
ret_inputs
.
append
(
mm_input
)
return
ret_inputs
,
ret_hashes
return
ret_inputs
class
MMInputMapperServer
:
def
__init__
(
self
,
):
def
__init__
(
self
,
model_config
):
self
.
use_cache
=
model_config
.
mm_cache_preprocessor
self
.
mm_cache
=
LRUDictCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
def
process_inputs
(
...
...
@@ -135,6 +129,9 @@ class MMInputMapperServer:
)
->
List
[
MultiModalKwargs
]:
assert
len
(
mm_inputs
)
==
len
(
mm_hashes
)
if
not
self
.
use_cache
:
return
mm_inputs
full_mm_inputs
=
[]
for
mm_input
,
mm_hash
in
zip
(
mm_inputs
,
mm_hashes
):
assert
mm_hash
is
not
None
...
...
vllm/v1/engine/processor.py
View file @
bf8717eb
import
time
from
typing
import
Any
,
Dict
,
Mapping
,
Optional
,
Tuple
,
Union
from
vllm.config
import
LoRAConfig
,
ModelConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
ModelConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
InputRegistry
,
ProcessorInputs
,
PromptType
,
SingletonInputsAdapter
)
from
vllm.inputs.parse
import
is_encoder_decoder_inputs
...
...
@@ -23,6 +23,7 @@ class Processor:
def
__init__
(
self
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
lora_config
:
Optional
[
LoRAConfig
],
tokenizer
:
BaseTokenizerGroup
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
...
...
@@ -45,8 +46,9 @@ class Processor:
self
.
mm_input_mapper_client
=
MMInputMapperClient
(
model_config
)
# Multi-modal hasher (for images)
self
.
mm_hasher
=
MMHasher
(
)
if
model_config
.
mm_cache_preprocessor
else
None
self
.
use_hash
=
model_config
.
mm_cache_preprocessor
or
\
cache_config
.
enable_prefix_caching
self
.
mm_hasher
=
MMHasher
()
# TODO: run in a ThreadPoolExecutor or BackgroundProcess.
# This ideally should release the GIL, so we should not block the
...
...
@@ -77,7 +79,7 @@ class Processor:
# Compute MM hashes (if enabled)
mm_hashes
=
None
if
self
.
mm
_hash
er
is
not
None
:
if
self
.
use
_hash
:
mm_hashes
=
self
.
mm_hasher
.
hash
(
prompt
)
# Process inputs.
...
...
@@ -118,7 +120,7 @@ class Processor:
# Apply MM mapper
mm_inputs
=
None
if
len
(
decoder_inputs
.
multi_modal_data
)
>
0
:
mm_inputs
,
mm_hashes
=
self
.
mm_input_mapper_client
.
process_inputs
(
mm_inputs
=
self
.
mm_input_mapper_client
.
process_inputs
(
decoder_inputs
.
multi_modal_data
,
mm_hashes
,
decoder_inputs
.
mm_processor_kwargs
,
precomputed_mm_inputs
)
...
...
vllm/v1/request.py
View file @
bf8717eb
import
enum
from
typing
import
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Union
from
vllm.inputs
import
DecoderOnlyInputs
,
SingletonInputsAdapter
,
token_inputs
from
vllm.lora.request
import
LoRARequest
...
...
@@ -9,6 +9,9 @@ from vllm.sequence import RequestMetrics
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.utils
import
ConstantList
if
TYPE_CHECKING
:
from
vllm.v1.core.kv_cache_utils
import
BlockHashType
class
Request
:
...
...
@@ -45,6 +48,7 @@ class Request:
self
.
_all_token_ids
:
List
[
int
]
=
self
.
prompt_token_ids
.
copy
()
self
.
num_computed_tokens
=
0
# Multi-modal input metadata.
mm_positions
=
self
.
inputs
.
multi_modal_placeholders
if
mm_positions
:
# FIXME(woosuk): Support other modalities.
...
...
@@ -56,6 +60,12 @@ class Request:
if
self
.
inputs
.
multi_modal_inputs
:
self
.
mm_inputs
=
self
.
inputs
.
multi_modal_inputs
self
.
mm_hashes
:
List
[
str
]
=
self
.
inputs
.
multi_modal_hashes
# Cache the computed kv block hashes of the request to avoid
# recomputing.
self
.
_kv_block_hashes
:
List
[
BlockHashType
]
=
[]
@
classmethod
def
from_engine_core_request
(
cls
,
request
:
EngineCoreRequest
)
->
"Request"
:
return
cls
(
...
...
@@ -65,6 +75,7 @@ class Request:
prompt
=
request
.
prompt
,
multi_modal_data
=
None
,
multi_modal_inputs
=
request
.
mm_inputs
,
multi_modal_hashes
=
request
.
mm_hashes
,
multi_modal_placeholders
=
request
.
mm_placeholders
,
mm_processor_kwargs
=
None
,
),
...
...
@@ -121,6 +132,17 @@ class Request:
num_tokens
=
self
.
mm_positions
[
input_id
][
"length"
]
return
num_tokens
@
property
def
kv_block_hashes
(
self
)
->
ConstantList
[
"BlockHashType"
]:
# Prevent directly appending to the kv_block_hashes.
return
ConstantList
(
self
.
_kv_block_hashes
)
def
set_kv_block_hashes
(
self
,
value
:
List
[
"BlockHashType"
])
->
None
:
self
.
_kv_block_hashes
=
value
def
append_kv_block_hashes
(
self
,
block_hash
:
"BlockHashType"
)
->
None
:
self
.
_kv_block_hashes
.
append
(
block_hash
)
class
RequestStatus
(
enum
.
IntEnum
):
"""Status of a request."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment