Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bf8717eb
Unverified
Commit
bf8717eb
authored
Dec 17, 2024
by
Cody Yu
Committed by
GitHub
Dec 17, 2024
Browse files
[V1] Prefix caching for vision language models (#11187)
Signed-off-by:
Cody Yu
<
hao.yu.cody@gmail.com
>
parent
c77eb8a3
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
342 additions
and
98 deletions
+342
-98
tests/v1/core/test_prefix_caching.py
tests/v1/core/test_prefix_caching.py
+86
-2
tests/v1/engine/test_engine_args.py
tests/v1/engine/test_engine_args.py
+0
-15
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+14
-13
vllm/inputs/data.py
vllm/inputs/data.py
+20
-0
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+3
-0
vllm/v1/core/kv_cache_manager.py
vllm/v1/core/kv_cache_manager.py
+50
-24
vllm/v1/core/kv_cache_utils.py
vllm/v1/core/kv_cache_utils.py
+105
-10
vllm/v1/core/scheduler.py
vllm/v1/core/scheduler.py
+2
-0
vllm/v1/engine/async_llm.py
vllm/v1/engine/async_llm.py
+7
-3
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+4
-4
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+6
-3
vllm/v1/engine/mm_input_mapper.py
vllm/v1/engine/mm_input_mapper.py
+15
-18
vllm/v1/engine/processor.py
vllm/v1/engine/processor.py
+7
-5
vllm/v1/request.py
vllm/v1/request.py
+23
-1
No files found.
tests/v1/core/test_prefix_caching.py
View file @
bf8717eb
...
@@ -2,16 +2,23 @@
...
@@ -2,16 +2,23 @@
import
pytest
import
pytest
from
vllm.inputs
import
token_inputs
from
vllm.inputs
import
token_inputs
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
,
Request
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
,
Request
from
vllm.v1.core.kv_cache_utils
import
KVCacheBlock
,
hash_block_tokens
from
vllm.v1.core.kv_cache_utils
import
KVCacheBlock
,
hash_block_tokens
def
make_request
(
request_id
,
prompt_token_ids
):
def
make_request
(
request_id
,
prompt_token_ids
,
mm_positions
=
None
,
mm_hashes
=
None
):
return
Request
(
return
Request
(
request_id
=
request_id
,
request_id
=
request_id
,
inputs
=
token_inputs
(
prompt_token_ids
=
prompt_token_ids
),
inputs
=
token_inputs
(
prompt_token_ids
=
prompt_token_ids
,
multi_modal_placeholders
=
{
"image"
:
mm_positions
}
if
mm_positions
else
None
,
multi_modal_hashes
=
mm_hashes
),
sampling_params
=
SamplingParams
(
max_tokens
=
17
),
sampling_params
=
SamplingParams
(
max_tokens
=
17
),
eos_token_id
=
100
,
eos_token_id
=
100
,
arrival_time
=
0
,
arrival_time
=
0
,
...
@@ -38,6 +45,7 @@ def test_prefill():
...
@@ -38,6 +45,7 @@ def test_prefill():
all_token_ids
=
common_token_ids
+
unique_token_ids
all_token_ids
=
common_token_ids
+
unique_token_ids
req0
=
make_request
(
"0"
,
all_token_ids
)
req0
=
make_request
(
"0"
,
all_token_ids
)
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
assert
len
(
req0
.
kv_block_hashes
)
==
3
assert
not
computed_blocks
assert
not
computed_blocks
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req0
,
55
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
...
@@ -61,6 +69,7 @@ def test_prefill():
...
@@ -61,6 +69,7 @@ def test_prefill():
unique_token_ids
=
[
3
]
*
5
unique_token_ids
=
[
3
]
*
5
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
req1
=
make_request
(
"1"
,
common_token_ids
+
unique_token_ids
)
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
req1
.
kv_block_hashes
)
==
3
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
assert
[
b
.
block_id
for
b
in
computed_blocks
]
==
[
0
,
1
,
2
]
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req1
,
num_new_tokens
,
computed_blocks
)
...
@@ -90,6 +99,7 @@ def test_prefill():
...
@@ -90,6 +99,7 @@ def test_prefill():
unique_token_ids
=
[
3
]
*
6
unique_token_ids
=
[
3
]
*
6
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
req2
=
make_request
(
"2"
,
common_token_ids
+
unique_token_ids
)
computed_block
=
manager
.
get_computed_blocks
(
req2
)
computed_block
=
manager
.
get_computed_blocks
(
req2
)
assert
len
(
req2
.
kv_block_hashes
)
==
3
assert
[
b
.
block_id
for
b
in
computed_block
]
==
[
0
,
1
,
2
]
assert
[
b
.
block_id
for
b
in
computed_block
]
==
[
0
,
1
,
2
]
num_new_tokens
=
53
-
3
*
16
num_new_tokens
=
53
-
3
*
16
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
computed_blocks
)
blocks
=
manager
.
allocate_slots
(
req2
,
num_new_tokens
,
computed_blocks
)
...
@@ -416,3 +426,77 @@ def test_cache_blocks():
...
@@ -416,3 +426,77 @@ def test_cache_blocks():
)
)
assert
len
(
manager
.
cached_block_hash_to_block
)
==
3
assert
len
(
manager
.
cached_block_hash_to_block
)
==
3
assert
blocks
[
0
].
block_hash
is
not
None
assert
blocks
[
0
].
block_hash
is
not
None
def
test_mm_prefix_caching
():
"""
This tests that the multi-modal prefix caching is correct.
"""
manager
=
KVCacheManager
(
block_size
=
16
,
num_gpu_blocks
=
10
,
max_model_len
=
8192
,
sliding_window
=
None
,
enable_caching
=
True
,
num_preallocate_tokens
=
16
,
)
# Common prompt tokens (T is text tokens and P is image placeholder tokens)
# [T,...,T, P0,...,P0], [P0,...,P0,T,...,T,P1,...,P1], [P1,...,P1]
common_token_ids
=
list
(
range
(
10
))
+
[
-
1
]
*
6
common_token_ids
+=
[
-
1
]
*
4
+
list
(
range
(
10
,
20
))
+
[
-
1
]
*
2
common_token_ids
+=
[
-
1
]
*
16
common_mm_positions
=
[
PlaceholderRange
(
offset
=
11
,
length
=
10
),
PlaceholderRange
(
offset
=
30
,
length
=
18
),
]
common_mm_hashes
=
[
"aaa"
,
"bbb"
]
# A unique image plus some text tokens.
unique_token_ids
=
[
-
1
]
*
7
+
[
100
]
*
4
all_token_ids
=
common_token_ids
+
unique_token_ids
mm_positions
=
common_mm_positions
+
[
PlaceholderRange
(
offset
=
48
,
length
=
7
)
]
mm_hashes
=
common_mm_hashes
+
[
"ccc"
]
req0
=
make_request
(
"0"
,
all_token_ids
,
mm_positions
=
mm_positions
,
mm_hashes
=
mm_hashes
)
computed_blocks
=
manager
.
get_computed_blocks
(
req0
)
# Completed block should have hashes with extra keys.
assert
not
computed_blocks
assert
len
(
req0
.
kv_block_hashes
)
==
3
assert
req0
.
kv_block_hashes
[
0
].
extra_keys
==
((
"aaa"
,
0
),
)
assert
req0
.
kv_block_hashes
[
1
].
extra_keys
==
((
"aaa"
,
5
),
(
"bbb"
,
0
))
assert
req0
.
kv_block_hashes
[
2
].
extra_keys
==
((
"bbb"
,
2
),
)
blocks
=
manager
.
allocate_slots
(
req0
,
59
,
computed_blocks
)
assert
[
b
.
block_id
for
b
in
blocks
]
==
[
0
,
1
,
2
,
3
,
4
]
req0
.
num_computed_tokens
=
59
# Append slots without allocating a new block.
for
_
in
range
(
5
):
req0
.
append_output_token_ids
(
8
)
new_blocks
=
manager
.
append_slots
(
req0
,
5
)
assert
new_blocks
is
not
None
and
len
(
new_blocks
)
==
0
# The just completed block should have hashes with extra keys.
assert
len
(
req0
.
kv_block_hashes
)
==
4
assert
req0
.
kv_block_hashes
[
3
].
extra_keys
==
((
"ccc"
,
0
),
)
# Cache hit.
unique_token_ids
=
[
-
1
]
*
7
+
[
200
]
*
5
all_token_ids
=
common_token_ids
+
unique_token_ids
mm_positions
=
common_mm_positions
+
[
PlaceholderRange
(
offset
=
48
,
length
=
7
)
]
mm_hashes
=
common_mm_hashes
+
[
"ccc"
]
req1
=
make_request
(
"1"
,
all_token_ids
,
mm_positions
=
mm_positions
,
mm_hashes
=
mm_hashes
)
computed_blocks
=
manager
.
get_computed_blocks
(
req1
)
assert
len
(
computed_blocks
)
==
3
tests/v1/engine/test_engine_args.py
View file @
bf8717eb
...
@@ -31,14 +31,6 @@ def test_prefix_caching_from_cli():
...
@@ -31,14 +31,6 @@ def test_prefix_caching_from_cli():
assert
engine_args
.
enable_prefix_caching
assert
engine_args
.
enable_prefix_caching
def
test_defaults
():
engine_args
=
EngineArgs
(
model
=
"facebook/opt-125m"
)
# Assert V1 defaults
assert
(
engine_args
.
enable_prefix_caching
),
"V1 turns on prefix caching by default"
def
test_defaults_with_usage_context
():
def
test_defaults_with_usage_context
():
engine_args
=
EngineArgs
(
model
=
"facebook/opt-125m"
)
engine_args
=
EngineArgs
(
model
=
"facebook/opt-125m"
)
vllm_config
:
VllmConfig
=
engine_args
.
create_engine_config
(
vllm_config
:
VllmConfig
=
engine_args
.
create_engine_config
(
...
@@ -52,10 +44,3 @@ def test_defaults_with_usage_context():
...
@@ -52,10 +44,3 @@ def test_defaults_with_usage_context():
UsageContext
.
OPENAI_API_SERVER
)
UsageContext
.
OPENAI_API_SERVER
)
assert
vllm_config
.
scheduler_config
.
max_num_seqs
==
1024
assert
vllm_config
.
scheduler_config
.
max_num_seqs
==
1024
assert
vllm_config
.
scheduler_config
.
max_num_batched_tokens
==
2048
assert
vllm_config
.
scheduler_config
.
max_num_batched_tokens
==
2048
def
test_prefix_cache_disabled_with_multimodel
():
engine_args
=
EngineArgs
(
model
=
"llava-hf/llava-1.5-7b-hf"
)
vllm_config
=
engine_args
.
create_engine_config
(
UsageContext
.
LLM_CLASS
)
assert
not
vllm_config
.
cache_config
.
enable_prefix_caching
vllm/engine/arg_utils.py
View file @
bf8717eb
...
@@ -205,6 +205,7 @@ class EngineArgs:
...
@@ -205,6 +205,7 @@ class EngineArgs:
# by user.
# by user.
if
self
.
enable_prefix_caching
is
None
:
if
self
.
enable_prefix_caching
is
None
:
self
.
enable_prefix_caching
=
bool
(
envs
.
VLLM_USE_V1
)
self
.
enable_prefix_caching
=
bool
(
envs
.
VLLM_USE_V1
)
# Override max_num_seqs if it's not set by user.
# Override max_num_seqs if it's not set by user.
if
self
.
max_num_seqs
is
None
:
if
self
.
max_num_seqs
is
None
:
self
.
max_num_seqs
=
256
if
not
envs
.
VLLM_USE_V1
else
1024
self
.
max_num_seqs
=
256
if
not
envs
.
VLLM_USE_V1
else
1024
...
@@ -1026,11 +1027,11 @@ class EngineArgs:
...
@@ -1026,11 +1027,11 @@ class EngineArgs:
device_config
=
DeviceConfig
(
device
=
self
.
device
)
device_config
=
DeviceConfig
(
device
=
self
.
device
)
model_config
=
self
.
create_model_config
()
model_config
=
self
.
create_model_config
()
if
model_config
.
is_multimodal_model
:
if
(
model_config
.
is_multimodal_model
and
not
envs
.
VLLM_USE_V1
if
self
.
enable_prefix_caching
:
and
self
.
enable_prefix_caching
)
:
logger
.
warning
(
logger
.
warning
(
"--enable-prefix-caching is currently not "
"--enable-prefix-caching is currently not
"
"supported for multimodal models in v0 and
"
"supported for multimodal models and
has been disabled."
)
"
has been disabled."
)
self
.
enable_prefix_caching
=
False
self
.
enable_prefix_caching
=
False
cache_config
=
CacheConfig
(
cache_config
=
CacheConfig
(
...
@@ -1249,11 +1250,14 @@ class EngineArgs:
...
@@ -1249,11 +1250,14 @@ class EngineArgs:
# When no user override, set the default values based on the usage
# When no user override, set the default values based on the usage
# context.
# context.
# TODO(woosuk): Tune the default values for different hardware.
# TODO(woosuk): Tune the default values for different hardware.
if
self
.
max_num_batched_tokens
is
None
:
default_max_num_batched_tokens
=
{
if
usage_context
==
UsageContext
.
LLM_CLASS
:
UsageContext
.
LLM_CLASS
:
8192
,
self
.
max_num_batched_tokens
=
8192
UsageContext
.
OPENAI_API_SERVER
:
2048
,
elif
usage_context
==
UsageContext
.
OPENAI_API_SERVER
:
}
self
.
max_num_batched_tokens
=
2048
if
(
self
.
max_num_batched_tokens
is
None
and
usage_context
in
default_max_num_batched_tokens
):
self
.
max_num_batched_tokens
=
default_max_num_batched_tokens
[
usage_context
]
logger
.
warning
(
logger
.
warning
(
"Setting max_num_batched_tokens to %d for %s usage context."
,
"Setting max_num_batched_tokens to %d for %s usage context."
,
self
.
max_num_batched_tokens
,
usage_context
.
value
)
self
.
max_num_batched_tokens
,
usage_context
.
value
)
...
@@ -1263,9 +1267,6 @@ class EngineArgs:
...
@@ -1263,9 +1267,6 @@ class EngineArgs:
Override the EngineConfig's configs based on the usage context for V1.
Override the EngineConfig's configs based on the usage context for V1.
"""
"""
assert
envs
.
VLLM_USE_V1
,
"V1 is not enabled"
assert
envs
.
VLLM_USE_V1
,
"V1 is not enabled"
if
engine_config
.
model_config
.
is_multimodal_model
:
# TODO (ywang96): Enable APC by default when VLM supports it.
assert
not
engine_config
.
cache_config
.
enable_prefix_caching
@
dataclass
@
dataclass
...
...
vllm/inputs/data.py
View file @
bf8717eb
...
@@ -162,6 +162,11 @@ class TokenInputs(TypedDict):
...
@@ -162,6 +162,11 @@ class TokenInputs(TypedDict):
Placeholder ranges for the multi-modal data.
Placeholder ranges for the multi-modal data.
"""
"""
multi_modal_hashes
:
NotRequired
[
List
[
str
]]
"""
The hashes of the multi-modal data.
"""
mm_processor_kwargs
:
NotRequired
[
Dict
[
str
,
Any
]]
mm_processor_kwargs
:
NotRequired
[
Dict
[
str
,
Any
]]
"""
"""
Optional multi-modal processor kwargs to be forwarded to the
Optional multi-modal processor kwargs to be forwarded to the
...
@@ -177,6 +182,7 @@ def token_inputs(
...
@@ -177,6 +182,7 @@ def token_inputs(
prompt
:
Optional
[
str
]
=
None
,
prompt
:
Optional
[
str
]
=
None
,
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
multi_modal_data
:
Optional
[
"MultiModalDataDict"
]
=
None
,
multi_modal_inputs
:
Optional
[
"MultiModalKwargs"
]
=
None
,
multi_modal_inputs
:
Optional
[
"MultiModalKwargs"
]
=
None
,
multi_modal_hashes
:
Optional
[
List
[
str
]]
=
None
,
multi_modal_placeholders
:
Optional
[
"MultiModalPlaceholderDict"
]
=
None
,
multi_modal_placeholders
:
Optional
[
"MultiModalPlaceholderDict"
]
=
None
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
)
->
TokenInputs
:
)
->
TokenInputs
:
...
@@ -191,6 +197,8 @@ def token_inputs(
...
@@ -191,6 +197,8 @@ def token_inputs(
inputs
[
"multi_modal_data"
]
=
multi_modal_data
inputs
[
"multi_modal_data"
]
=
multi_modal_data
if
multi_modal_inputs
is
not
None
:
if
multi_modal_inputs
is
not
None
:
inputs
[
"multi_modal_inputs"
]
=
multi_modal_inputs
inputs
[
"multi_modal_inputs"
]
=
multi_modal_inputs
if
multi_modal_hashes
is
not
None
:
inputs
[
"multi_modal_hashes"
]
=
multi_modal_hashes
if
multi_modal_placeholders
is
not
None
:
if
multi_modal_placeholders
is
not
None
:
inputs
[
"multi_modal_placeholders"
]
=
multi_modal_placeholders
inputs
[
"multi_modal_placeholders"
]
=
multi_modal_placeholders
if
mm_processor_kwargs
is
not
None
:
if
mm_processor_kwargs
is
not
None
:
...
@@ -295,6 +303,18 @@ class SingletonInputsAdapter:
...
@@ -295,6 +303,18 @@ class SingletonInputsAdapter:
assert_never
(
inputs
)
assert_never
(
inputs
)
@
cached_property
def
multi_modal_hashes
(
self
)
->
List
[
str
]:
inputs
=
self
.
inputs
if
inputs
[
"type"
]
==
"token"
:
return
inputs
.
get
(
"multi_modal_hashes"
,
[])
if
inputs
[
"type"
]
==
"multimodal"
:
return
inputs
.
get
(
"mm_hashes"
,
[])
assert_never
(
inputs
)
@
cached_property
@
cached_property
def
multi_modal_placeholders
(
self
)
->
"MultiModalPlaceholderDict"
:
def
multi_modal_placeholders
(
self
)
->
"MultiModalPlaceholderDict"
:
inputs
=
self
.
inputs
inputs
=
self
.
inputs
...
...
vllm/multimodal/inputs.py
View file @
bf8717eb
...
@@ -215,6 +215,9 @@ class MultiModalInputsV2(TypedDict):
...
@@ -215,6 +215,9 @@ class MultiModalInputsV2(TypedDict):
mm_kwargs
:
MultiModalKwargs
mm_kwargs
:
MultiModalKwargs
"""Keyword arguments to be directly passed to the model after batching."""
"""Keyword arguments to be directly passed to the model after batching."""
mm_hashes
:
NotRequired
[
List
[
str
]]
"""The hashes of the multi-modal data."""
mm_placeholders
:
MultiModalPlaceholderDict
mm_placeholders
:
MultiModalPlaceholderDict
"""
"""
For each modality, information about the placeholder tokens in
For each modality, information about the placeholder tokens in
...
...
vllm/v1/core/kv_cache_manager.py
View file @
bf8717eb
...
@@ -4,7 +4,9 @@ from typing import Dict, Iterable, List, Optional
...
@@ -4,7 +4,9 @@ from typing import Dict, Iterable, List, Optional
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
from
vllm.v1.core.kv_cache_utils
import
(
BlockHashType
,
FreeKVCacheBlockQueue
,
from
vllm.v1.core.kv_cache_utils
import
(
BlockHashType
,
FreeKVCacheBlockQueue
,
KVCacheBlock
,
hash_block_tokens
,
KVCacheBlock
,
generate_block_hash_extra_keys
,
hash_block_tokens
,
hash_request_tokens
)
hash_request_tokens
)
from
vllm.v1.request
import
Request
from
vllm.v1.request
import
Request
...
@@ -83,10 +85,12 @@ class KVCacheManager:
...
@@ -83,10 +85,12 @@ class KVCacheManager:
computed_blocks
=
[]
computed_blocks
=
[]
# TODO(rickyx): potentially we could cache this so we don't have to
# The block hashes for the request may already be computed
# recompute it every time.
# if the request was preempted and resumed.
block_hashes
=
hash_request_tokens
(
self
.
block_size
,
if
not
request
.
kv_block_hashes
:
request
.
all_token_ids
)
request
.
set_kv_block_hashes
(
hash_request_tokens
(
self
.
block_size
,
request
))
block_hashes
=
request
.
kv_block_hashes
for
block_hash
in
block_hashes
:
for
block_hash
in
block_hashes
:
# block_hashes is a chain of block hashes. If a block hash is not
# block_hashes is a chain of block hashes. If a block hash is not
...
@@ -242,14 +246,16 @@ class KVCacheManager:
...
@@ -242,14 +246,16 @@ class KVCacheManager:
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
num_computed_tokens
=
len
(
computed_blocks
)
*
self
.
block_size
num_full_blocks
=
(
num_computed_tokens
+
num_tokens
)
//
self
.
block_size
num_full_blocks
=
(
num_computed_tokens
+
num_tokens
)
//
self
.
block_size
self
.
_cache_full_blocks
(
new_full_blocks
=
self
.
req_to_blocks
[
request
=
request
,
request
.
request_id
][
len
(
computed_blocks
):
num_full_blocks
]
blk_start_idx
=
len
(
computed_blocks
),
if
new_full_blocks
:
# The new full blocks are the full blocks that are not computed.
self
.
_cache_full_blocks
(
full_blocks
=
self
.
req_to_blocks
[
request
.
request_id
]
request
=
request
,
[
len
(
computed_blocks
):
num_full_blocks
],
blk_start_idx
=
len
(
computed_blocks
),
prev_block
=
computed_blocks
[
-
1
]
if
computed_blocks
else
None
,
# The new full blocks are the full blocks that are not computed.
)
full_blocks
=
new_full_blocks
,
prev_block
=
computed_blocks
[
-
1
]
if
computed_blocks
else
None
,
)
return
new_blocks
return
new_blocks
...
@@ -376,6 +382,8 @@ class KVCacheManager:
...
@@ -376,6 +382,8 @@ class KVCacheManager:
full_blocks: The list of blocks to update hash metadata.
full_blocks: The list of blocks to update hash metadata.
prev_block: The previous block in the chain.
prev_block: The previous block in the chain.
"""
"""
num_cached_block_hashes
=
len
(
request
.
kv_block_hashes
)
# Update the new blocks with the block hashes through the chain.
# Update the new blocks with the block hashes through the chain.
prev_block_hash_value
=
None
prev_block_hash_value
=
None
if
prev_block
is
not
None
:
if
prev_block
is
not
None
:
...
@@ -387,17 +395,35 @@ class KVCacheManager:
...
@@ -387,17 +395,35 @@ class KVCacheManager:
for
i
,
blk
in
enumerate
(
full_blocks
):
for
i
,
blk
in
enumerate
(
full_blocks
):
blk_idx
=
blk_start_idx
+
i
blk_idx
=
blk_start_idx
+
i
block_tokens
=
request
.
all_token_ids
[
blk_idx
*
if
blk_idx
<
num_cached_block_hashes
:
self
.
block_size
:(
blk_idx
+
# The block hash may already be computed in
1
)
*
# "get_computed_blocks" if the tokens are not generated by
self
.
block_size
]
# this request (either the prompt tokens or the previously
assert
len
(
block_tokens
)
==
self
.
block_size
,
(
# generated tokens with preemption). In this case we simply
f
"Expected
{
self
.
block_size
}
tokens, got
{
len
(
block_tokens
)
}
"
# reuse the block hash.
f
"at
{
blk_idx
}
th block for request "
block_hash
=
request
.
kv_block_hashes
[
blk_idx
]
f
"
{
request
.
request_id
}
(
{
request
}
)"
)
else
:
# Otherwise compute the block hash and cache it in the request
# Compute the hash of the current block.
# in case it will be preempted in the future.
block_hash
=
hash_block_tokens
(
prev_block_hash_value
,
block_tokens
)
start_token_idx
=
blk_idx
*
self
.
block_size
end_token_idx
=
(
blk_idx
+
1
)
*
self
.
block_size
block_tokens
=
request
.
all_token_ids
[
start_token_idx
:
end_token_idx
]
assert
len
(
block_tokens
)
==
self
.
block_size
,
(
f
"Expected
{
self
.
block_size
}
tokens, got "
f
"
{
len
(
block_tokens
)
}
at
{
blk_idx
}
th block for request "
f
"
{
request
.
request_id
}
(
{
request
}
)"
)
# Generate extra keys for multi-modal inputs. Note that since
# we reach to this branch only when the block is completed with
# generated tokens, we only need to consider the last mm input.
extra_keys
,
_
=
generate_block_hash_extra_keys
(
request
,
start_token_idx
,
end_token_idx
,
-
1
)
# Compute the hash of the current block.
block_hash
=
hash_block_tokens
(
prev_block_hash_value
,
block_tokens
,
extra_keys
)
request
.
append_kv_block_hashes
(
block_hash
)
# Update and added the full block to the cache.
# Update and added the full block to the cache.
blk
.
block_hash
=
block_hash
blk
.
block_hash
=
block_hash
...
...
vllm/v1/core/kv_cache_utils.py
View file @
bf8717eb
"""KV-Cache Utilities."""
"""KV-Cache Utilities."""
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
NamedTuple
,
Optional
,
Tuple
from
typing
import
Any
,
List
,
NamedTuple
,
Optional
,
Tuple
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
class
BlockHashType
(
NamedTuple
):
class
BlockHashType
(
NamedTuple
):
"""Hash value of a block
and
the token IDs in the block.
"""Hash value of a block
(int),
the token IDs in the block
, and extra keys
.
The reason we keep a tuple of token IDs is to make sure
no hash
The reason we keep a tuple of token IDs
and extra keys
is to make sure
collision happens when the hash value is the same.
no hash
collision happens when the hash value is the same.
"""
"""
# Hash value of the block in an integer.
hash_value
:
int
hash_value
:
int
# Token IDs in the block.
token_ids
:
Tuple
[
int
,
...]
token_ids
:
Tuple
[
int
,
...]
# Extra keys for the block.
extra_keys
:
Optional
[
Any
]
=
None
@
dataclass
@
dataclass
...
@@ -159,8 +164,80 @@ class FreeKVCacheBlockQueue:
...
@@ -159,8 +164,80 @@ class FreeKVCacheBlockQueue:
return
ret
return
ret
def
hash_block_tokens
(
parent_block_hash
:
Optional
[
int
],
def
generate_block_hash_extra_keys
(
curr_block_token_ids
:
Sequence
[
int
])
->
BlockHashType
:
request
:
Request
,
start_token_idx
:
int
,
end_token_idx
:
int
,
start_mm_idx
:
int
)
->
Tuple
[
Optional
[
Tuple
[
Any
,
...]],
int
]:
"""Generate extra keys for the block hash. The extra keys can come from
the multi-modal inputs and request specific metadata (e.g., LoRA ID).
For multi-modal inputs, the extra keys are (mm_hash, start_offset) that
indicate a mm input contained in the block and its starting offset in
the block tokens.
Args:
request: The request object.
start_token_idx: The start token index of the block.
end_token_idx: The end token index of the block.
start_mm_idx: The start multi-modal index of the block.
Returns:
A tuple of extra keys and the next multi-modal index.
"""
mm_positions
,
mm_hashes
=
request
.
mm_positions
,
request
.
mm_hashes
if
not
mm_positions
:
return
None
,
start_mm_idx
if
mm_positions
and
len
(
mm_positions
)
!=
len
(
mm_hashes
):
raise
ValueError
(
"The number of multi-modal positions and hashes must match. This "
"is likely because you do not enable MM preprocessor hashing. "
"Please set mm_cache_preprocessor=True."
)
# Note that we assume mm_positions is sorted by offset.
# We do not need to check all mm inputs if the start token index is out of
# range. This usually happens in the late prefill phase and decoding phase.
if
mm_positions
[
-
1
][
"offset"
]
+
mm_positions
[
-
1
][
"length"
]
<
start_token_idx
:
return
None
,
start_mm_idx
# Support start_mm_idx == -1 to indicate the last mm input.
if
start_mm_idx
<
0
:
assert
-
start_mm_idx
<=
len
(
mm_positions
)
start_mm_idx
=
len
(
mm_positions
)
+
start_mm_idx
extra_keys
=
[]
curr_mm_idx
=
start_mm_idx
while
mm_positions
and
curr_mm_idx
<
len
(
mm_positions
):
assert
mm_hashes
[
curr_mm_idx
]
is
not
None
offset
=
mm_positions
[
curr_mm_idx
][
"offset"
]
length
=
mm_positions
[
curr_mm_idx
][
"length"
]
if
end_token_idx
>
offset
:
if
start_token_idx
>
offset
+
length
:
# This block has passed the current mm input.
curr_mm_idx
+=
1
continue
# The block contains the current mm input.
mm_start
=
max
(
0
,
start_token_idx
-
offset
)
extra_keys
.
append
((
mm_hashes
[
curr_mm_idx
],
mm_start
))
if
end_token_idx
>=
offset
+
length
:
# If this block contains the end of the current mm input,
# move to the next mm input as this block may also contain
# the next mm input.
curr_mm_idx
+=
1
else
:
# Otherwise this block is done with mm inputs.
break
else
:
# This block has not reached the current mm input.
break
return
tuple
(
extra_keys
),
curr_mm_idx
def
hash_block_tokens
(
parent_block_hash
:
Optional
[
int
],
curr_block_token_ids
:
Sequence
[
int
],
extra_keys
:
Optional
[
Tuple
[
Any
,
...]]
=
None
)
->
BlockHashType
:
"""Computes a hash value corresponding to the contents of a block and
"""Computes a hash value corresponding to the contents of a block and
the contents of the preceding block(s). The hash value is used for
the contents of the preceding block(s). The hash value is used for
prefix caching. We use LRU cache for this function to avoid recomputing
prefix caching. We use LRU cache for this function to avoid recomputing
...
@@ -174,27 +251,39 @@ def hash_block_tokens(parent_block_hash: Optional[int],
...
@@ -174,27 +251,39 @@ def hash_block_tokens(parent_block_hash: Optional[int],
if this is the first block.
if this is the first block.
curr_block_token_ids: A list of token ids in the current
curr_block_token_ids: A list of token ids in the current
block. The current block is assumed to be full.
block. The current block is assumed to be full.
extra_keys: Extra keys for the block.
Returns:
Returns:
The hash value of the block and the token ids in the block.
The hash value of the block and the token ids in the block.
The entire tuple is used as the hash key of the block.
The entire tuple is used as the hash key of the block.
"""
"""
return
BlockHashType
(
hash
((
parent_block_hash
,
*
curr_block_token_ids
)),
return
BlockHashType
(
hash
((
parent_block_hash
,
*
curr_block_token_ids
)),
tuple
(
curr_block_token_ids
))
tuple
(
curr_block_token_ids
)
,
extra_keys
)
def
hash_request_tokens
(
block_size
:
int
,
def
hash_request_tokens
(
block_size
:
int
,
token_ids
:
S
eque
nce
[
int
]
)
->
List
[
BlockHashType
]:
request
:
R
eque
st
)
->
List
[
BlockHashType
]:
"""Computes hash values of a chain of blocks given a sequence of
"""Computes hash values of a chain of blocks given a sequence of
token IDs. The hash value is used for prefix caching.
token IDs. The hash value is used for prefix caching.
Args:
Args:
block_size: The size of each block.
block_size: The size of each block.
token_ids: A sequence of token ids in t
he request.
request: T
he request
object
.
Returns:
Returns:
The list of computed hash values.
The list of computed hash values.
"""
"""
token_ids
=
request
.
all_token_ids
mm_positions
,
mm_hashes
=
request
.
mm_positions
,
request
.
mm_hashes
if
mm_positions
and
len
(
mm_positions
)
!=
len
(
mm_hashes
):
raise
ValueError
(
"The number of multi-modal positions and hashes must match."
)
# TODO: Extend this to support other features such as LoRA.
need_extra_keys
=
bool
(
mm_positions
)
extra_keys
=
None
curr_mm_idx
=
0
ret
=
[]
ret
=
[]
parent_block_hash_value
=
None
parent_block_hash_value
=
None
for
start
in
range
(
0
,
len
(
token_ids
),
block_size
):
for
start
in
range
(
0
,
len
(
token_ids
),
block_size
):
...
@@ -203,8 +292,14 @@ def hash_request_tokens(block_size: int,
...
@@ -203,8 +292,14 @@ def hash_request_tokens(block_size: int,
# Do not hash the block if it is not full.
# Do not hash the block if it is not full.
if
len
(
block_token_ids
)
<
block_size
:
if
len
(
block_token_ids
)
<
block_size
:
break
break
# Add extra keys if the block is a multi-modal block.
if
need_extra_keys
:
extra_keys
,
curr_mm_idx
=
generate_block_hash_extra_keys
(
request
,
start
,
end
,
curr_mm_idx
)
block_hash
=
hash_block_tokens
(
parent_block_hash_value
,
block_hash
=
hash_block_tokens
(
parent_block_hash_value
,
block_token_ids
)
block_token_ids
,
extra_keys
)
ret
.
append
(
block_hash
)
ret
.
append
(
block_hash
)
parent_block_hash_value
=
block_hash
.
hash_value
parent_block_hash_value
=
block_hash
.
hash_value
return
ret
return
ret
vllm/v1/core/scheduler.py
View file @
bf8717eb
...
@@ -516,6 +516,7 @@ class NewRequestData:
...
@@ -516,6 +516,7 @@ class NewRequestData:
prompt_token_ids
:
List
[
int
]
prompt_token_ids
:
List
[
int
]
prompt
:
Optional
[
str
]
prompt
:
Optional
[
str
]
mm_inputs
:
List
[
"MultiModalKwargs"
]
mm_inputs
:
List
[
"MultiModalKwargs"
]
mm_hashes
:
List
[
str
]
mm_positions
:
List
[
"PlaceholderRange"
]
mm_positions
:
List
[
"PlaceholderRange"
]
sampling_params
:
SamplingParams
sampling_params
:
SamplingParams
block_ids
:
List
[
int
]
block_ids
:
List
[
int
]
...
@@ -533,6 +534,7 @@ class NewRequestData:
...
@@ -533,6 +534,7 @@ class NewRequestData:
prompt_token_ids
=
request
.
prompt_token_ids
,
prompt_token_ids
=
request
.
prompt_token_ids
,
prompt
=
request
.
prompt
,
prompt
=
request
.
prompt
,
mm_inputs
=
request
.
mm_inputs
,
mm_inputs
=
request
.
mm_inputs
,
mm_hashes
=
request
.
mm_hashes
,
mm_positions
=
request
.
mm_positions
,
mm_positions
=
request
.
mm_positions
,
sampling_params
=
request
.
sampling_params
,
sampling_params
=
request
.
sampling_params
,
block_ids
=
block_ids
,
block_ids
=
block_ids
,
...
...
vllm/v1/engine/async_llm.py
View file @
bf8717eb
...
@@ -60,9 +60,13 @@ class AsyncLLM(EngineClient):
...
@@ -60,9 +60,13 @@ class AsyncLLM(EngineClient):
self
.
client_aborted_requests
:
List
[
str
]
=
[]
self
.
client_aborted_requests
:
List
[
str
]
=
[]
# Processor (converts Inputs --> EngineCoreRequests).
# Processor (converts Inputs --> EngineCoreRequests).
self
.
processor
=
Processor
(
vllm_config
.
model_config
,
self
.
processor
=
Processor
(
vllm_config
.
lora_config
,
self
.
tokenizer
,
model_config
=
vllm_config
.
model_config
,
input_registry
)
cache_config
=
vllm_config
.
cache_config
,
lora_config
=
vllm_config
.
lora_config
,
tokenizer
=
self
.
tokenizer
,
input_registry
=
input_registry
,
)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput).
# Detokenizer (converts EngineCoreOutputs --> RequestOutput).
self
.
detokenizer
=
Detokenizer
(
self
.
detokenizer
=
Detokenizer
(
...
...
vllm/v1/engine/core.py
View file @
bf8717eb
...
@@ -65,7 +65,8 @@ class EngineCore:
...
@@ -65,7 +65,8 @@ class EngineCore:
self
.
_last_logging_time
=
time
.
time
()
self
.
_last_logging_time
=
time
.
time
()
self
.
mm_input_mapper_server
=
MMInputMapperServer
()
self
.
mm_input_mapper_server
=
MMInputMapperServer
(
vllm_config
.
model_config
)
def
_initialize_kv_caches
(
self
,
def
_initialize_kv_caches
(
self
,
cache_config
:
CacheConfig
)
->
Tuple
[
int
,
int
]:
cache_config
:
CacheConfig
)
->
Tuple
[
int
,
int
]:
...
@@ -98,9 +99,8 @@ class EngineCore:
...
@@ -98,9 +99,8 @@ class EngineCore:
# MM mapper, so anything that has a hash must have a HIT cache
# MM mapper, so anything that has a hash must have a HIT cache
# entry here as well.
# entry here as well.
assert
request
.
mm_inputs
is
not
None
assert
request
.
mm_inputs
is
not
None
request
.
mm_inputs
,
request
.
mm_hashes
=
(
request
.
mm_inputs
=
self
.
mm_input_mapper_server
.
process_inputs
(
self
.
mm_input_mapper_server
.
process_inputs
(
request
.
mm_inputs
,
request
.
mm_hashes
)
request
.
mm_inputs
,
request
.
mm_hashes
))
req
=
Request
.
from_engine_core_request
(
request
)
req
=
Request
.
from_engine_core_request
(
request
)
...
...
vllm/v1/engine/llm_engine.py
View file @
bf8717eb
...
@@ -55,9 +55,12 @@ class LLMEngine:
...
@@ -55,9 +55,12 @@ class LLMEngine:
self
.
tokenizer
.
ping
()
self
.
tokenizer
.
ping
()
# Processor (convert Inputs --> EngineCoreRequests)
# Processor (convert Inputs --> EngineCoreRequests)
self
.
processor
=
Processor
(
vllm_config
.
model_config
,
self
.
processor
=
Processor
(
model_config
=
vllm_config
.
model_config
,
vllm_config
.
lora_config
,
self
.
tokenizer
,
cache_config
=
vllm_config
.
cache_config
,
input_registry
,
mm_registry
)
lora_config
=
vllm_config
.
lora_config
,
tokenizer
=
self
.
tokenizer
,
input_registry
=
input_registry
,
mm_registry
=
mm_registry
)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput)
# Detokenizer (converts EngineCoreOutputs --> RequestOutput)
self
.
detokenizer
=
Detokenizer
(
self
.
detokenizer
=
Detokenizer
(
...
...
vllm/v1/engine/mm_input_mapper.py
View file @
bf8717eb
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
import
PIL
import
PIL
from
blake3
import
blake3
from
blake3
import
blake3
...
@@ -42,6 +42,8 @@ class MMInputMapperClient:
...
@@ -42,6 +42,8 @@ class MMInputMapperClient:
model_config
)
model_config
)
self
.
mm_registry
.
init_mm_limits_per_prompt
(
model_config
)
self
.
mm_registry
.
init_mm_limits_per_prompt
(
model_config
)
# Init cache
self
.
use_cache
=
model_config
.
mm_cache_preprocessor
self
.
mm_cache
=
LRUDictCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
self
.
mm_cache
=
LRUDictCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
# DEBUG: Set to None to disable
# DEBUG: Set to None to disable
...
@@ -61,7 +63,7 @@ class MMInputMapperClient:
...
@@ -61,7 +63,7 @@ class MMInputMapperClient:
mm_hashes
:
Optional
[
List
[
str
]],
mm_hashes
:
Optional
[
List
[
str
]],
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]],
mm_processor_kwargs
:
Optional
[
Dict
[
str
,
Any
]],
precomputed_mm_inputs
:
Optional
[
List
[
MultiModalKwargs
]],
precomputed_mm_inputs
:
Optional
[
List
[
MultiModalKwargs
]],
)
->
Tuple
[
List
[
MultiModalKwargs
]
,
Optional
[
List
[
str
]]]
:
)
->
List
[
MultiModalKwargs
]:
if
precomputed_mm_inputs
is
None
:
if
precomputed_mm_inputs
is
None
:
image_inputs
=
mm_data
[
"image"
]
image_inputs
=
mm_data
[
"image"
]
if
not
isinstance
(
image_inputs
,
list
):
if
not
isinstance
(
image_inputs
,
list
):
...
@@ -70,26 +72,21 @@ class MMInputMapperClient:
...
@@ -70,26 +72,21 @@ class MMInputMapperClient:
else
:
else
:
num_inputs
=
len
(
precomputed_mm_inputs
)
num_inputs
=
len
(
precomputed_mm_inputs
)
# Check if hash is enabled
# Sanity
use_hash
=
mm_hashes
is
not
None
if
self
.
use_cache
:
if
use_hash
:
assert
mm_hashes
is
not
None
assert
mm_hashes
is
not
None
assert
num_inputs
==
len
(
assert
num_inputs
==
len
(
mm_hashes
)
mm_hashes
),
"num_inputs = {} len(mm_hashes) = {}"
.
format
(
num_inputs
,
len
(
mm_hashes
))
# Process each image input separately, so that later we can schedule
# Process each image input separately, so that later we can schedule
# them in a fine-grained manner.
# them in a fine-grained manner.
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
ret_hashes
:
Optional
[
List
[
str
]]
=
[]
if
use_hash
else
None
ret_inputs
:
List
[
MultiModalKwargs
]
=
[]
ret_inputs
:
List
[
MultiModalKwargs
]
=
[]
for
input_id
in
range
(
num_inputs
):
for
input_id
in
range
(
num_inputs
):
if
self
.
mm_debug_cache_hit_ratio_steps
is
not
None
:
if
self
.
mm_debug_cache_hit_ratio_steps
is
not
None
:
self
.
cache_hit_ratio
(
self
.
mm_debug_cache_hit_ratio_steps
)
self
.
cache_hit_ratio
(
self
.
mm_debug_cache_hit_ratio_steps
)
mm_hash
=
None
mm_input
=
None
mm_input
=
None
if
use_hash
:
if
self
.
use_cache
:
assert
mm_hashes
is
not
None
assert
mm_hashes
is
not
None
mm_hash
=
mm_hashes
[
input_id
]
mm_hash
=
mm_hashes
[
input_id
]
mm_input
=
self
.
mm_cache
.
get
(
mm_hash
)
mm_input
=
self
.
mm_cache
.
get
(
mm_hash
)
...
@@ -106,7 +103,7 @@ class MMInputMapperClient:
...
@@ -106,7 +103,7 @@ class MMInputMapperClient:
mm_processor_kwargs
=
mm_processor_kwargs
,
mm_processor_kwargs
=
mm_processor_kwargs
,
)
)
if
use_hash
:
if
self
.
use_cache
:
# Add to cache
# Add to cache
assert
mm_hash
is
not
None
assert
mm_hash
is
not
None
self
.
mm_cache
.
put
(
mm_hash
,
mm_input
)
self
.
mm_cache
.
put
(
mm_hash
,
mm_input
)
...
@@ -114,18 +111,15 @@ class MMInputMapperClient:
...
@@ -114,18 +111,15 @@ class MMInputMapperClient:
self
.
mm_cache_hits
+=
1
self
.
mm_cache_hits
+=
1
mm_input
=
None
# Avoids sending mm_input to Server
mm_input
=
None
# Avoids sending mm_input to Server
if
use_hash
:
assert
mm_hash
is
not
None
assert
ret_hashes
is
not
None
ret_hashes
.
append
(
mm_hash
)
ret_inputs
.
append
(
mm_input
)
ret_inputs
.
append
(
mm_input
)
return
ret_inputs
,
ret_hashes
return
ret_inputs
class
MMInputMapperServer
:
class
MMInputMapperServer
:
def
__init__
(
self
,
):
def
__init__
(
self
,
model_config
):
self
.
use_cache
=
model_config
.
mm_cache_preprocessor
self
.
mm_cache
=
LRUDictCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
self
.
mm_cache
=
LRUDictCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
def
process_inputs
(
def
process_inputs
(
...
@@ -135,6 +129,9 @@ class MMInputMapperServer:
...
@@ -135,6 +129,9 @@ class MMInputMapperServer:
)
->
List
[
MultiModalKwargs
]:
)
->
List
[
MultiModalKwargs
]:
assert
len
(
mm_inputs
)
==
len
(
mm_hashes
)
assert
len
(
mm_inputs
)
==
len
(
mm_hashes
)
if
not
self
.
use_cache
:
return
mm_inputs
full_mm_inputs
=
[]
full_mm_inputs
=
[]
for
mm_input
,
mm_hash
in
zip
(
mm_inputs
,
mm_hashes
):
for
mm_input
,
mm_hash
in
zip
(
mm_inputs
,
mm_hashes
):
assert
mm_hash
is
not
None
assert
mm_hash
is
not
None
...
...
vllm/v1/engine/processor.py
View file @
bf8717eb
import
time
import
time
from
typing
import
Any
,
Dict
,
Mapping
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
Mapping
,
Optional
,
Tuple
,
Union
from
vllm.config
import
LoRAConfig
,
ModelConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
,
ModelConfig
from
vllm.inputs
import
(
INPUT_REGISTRY
,
InputRegistry
,
ProcessorInputs
,
from
vllm.inputs
import
(
INPUT_REGISTRY
,
InputRegistry
,
ProcessorInputs
,
PromptType
,
SingletonInputsAdapter
)
PromptType
,
SingletonInputsAdapter
)
from
vllm.inputs.parse
import
is_encoder_decoder_inputs
from
vllm.inputs.parse
import
is_encoder_decoder_inputs
...
@@ -23,6 +23,7 @@ class Processor:
...
@@ -23,6 +23,7 @@ class Processor:
def
__init__
(
def
__init__
(
self
,
self
,
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
lora_config
:
Optional
[
LoRAConfig
],
lora_config
:
Optional
[
LoRAConfig
],
tokenizer
:
BaseTokenizerGroup
,
tokenizer
:
BaseTokenizerGroup
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
input_registry
:
InputRegistry
=
INPUT_REGISTRY
,
...
@@ -45,8 +46,9 @@ class Processor:
...
@@ -45,8 +46,9 @@ class Processor:
self
.
mm_input_mapper_client
=
MMInputMapperClient
(
model_config
)
self
.
mm_input_mapper_client
=
MMInputMapperClient
(
model_config
)
# Multi-modal hasher (for images)
# Multi-modal hasher (for images)
self
.
mm_hasher
=
MMHasher
(
self
.
use_hash
=
model_config
.
mm_cache_preprocessor
or
\
)
if
model_config
.
mm_cache_preprocessor
else
None
cache_config
.
enable_prefix_caching
self
.
mm_hasher
=
MMHasher
()
# TODO: run in an ThreadpoolExecutor or BackgroundProcess.
# TODO: run in a ThreadPoolExecutor or BackgroundProcess.
# TODO: run in a ThreadPoolExecutor or BackgroundProcess.
# This ideally should release the GIL, so we should not block the
# This ideally should release the GIL, so we should not block the
@@ -77,7 +79,7 @@ class Processor:
...
@@ -77,7 +79,7 @@ class Processor:
# Compute MM hashes (if enabled)
# Compute MM hashes (if enabled)
mm_hashes
=
None
mm_hashes
=
None
if
self
.
mm
_hash
er
is
not
None
:
if
self
.
use
_hash
:
mm_hashes
=
self
.
mm_hasher
.
hash
(
prompt
)
mm_hashes
=
self
.
mm_hasher
.
hash
(
prompt
)
# Process inputs.
# Process inputs.
...
@@ -118,7 +120,7 @@ class Processor:
...
@@ -118,7 +120,7 @@ class Processor:
# Apply MM mapper
# Apply MM mapper
mm_inputs
=
None
mm_inputs
=
None
if
len
(
decoder_inputs
.
multi_modal_data
)
>
0
:
if
len
(
decoder_inputs
.
multi_modal_data
)
>
0
:
mm_inputs
,
mm_hashes
=
self
.
mm_input_mapper_client
.
process_inputs
(
mm_inputs
=
self
.
mm_input_mapper_client
.
process_inputs
(
decoder_inputs
.
multi_modal_data
,
mm_hashes
,
decoder_inputs
.
multi_modal_data
,
mm_hashes
,
decoder_inputs
.
mm_processor_kwargs
,
precomputed_mm_inputs
)
decoder_inputs
.
mm_processor_kwargs
,
precomputed_mm_inputs
)
...
...
vllm/v1/request.py
View file @
bf8717eb
import
enum
import
enum
from
typing
import
List
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Union
from
vllm.inputs
import
DecoderOnlyInputs
,
SingletonInputsAdapter
,
token_inputs
from
vllm.inputs
import
DecoderOnlyInputs
,
SingletonInputsAdapter
,
token_inputs
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -9,6 +9,9 @@ from vllm.sequence import RequestMetrics
...
@@ -9,6 +9,9 @@ from vllm.sequence import RequestMetrics
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.engine
import
EngineCoreRequest
from
vllm.v1.utils
import
ConstantList
from
vllm.v1.utils
import
ConstantList
if
TYPE_CHECKING
:
from
vllm.v1.core.kv_cache_utils
import
BlockHashType
class
Request
:
class
Request
:
...
@@ -45,6 +48,7 @@ class Request:
...
@@ -45,6 +48,7 @@ class Request:
self
.
_all_token_ids
:
List
[
int
]
=
self
.
prompt_token_ids
.
copy
()
self
.
_all_token_ids
:
List
[
int
]
=
self
.
prompt_token_ids
.
copy
()
self
.
num_computed_tokens
=
0
self
.
num_computed_tokens
=
0
# Multi-modal input metadata.
mm_positions
=
self
.
inputs
.
multi_modal_placeholders
mm_positions
=
self
.
inputs
.
multi_modal_placeholders
if
mm_positions
:
if
mm_positions
:
# FIXME(woosuk): Support other modalities.
# FIXME(woosuk): Support other modalities.
...
@@ -56,6 +60,12 @@ class Request:
...
@@ -56,6 +60,12 @@ class Request:
if
self
.
inputs
.
multi_modal_inputs
:
if
self
.
inputs
.
multi_modal_inputs
:
self
.
mm_inputs
=
self
.
inputs
.
multi_modal_inputs
self
.
mm_inputs
=
self
.
inputs
.
multi_modal_inputs
self
.
mm_hashes
:
List
[
str
]
=
self
.
inputs
.
multi_modal_hashes
# Cache the computed kv block hashes of the request to avoid
# recomputing.
self
.
_kv_block_hashes
:
List
[
BlockHashType
]
=
[]
@
classmethod
@
classmethod
def
from_engine_core_request
(
cls
,
request
:
EngineCoreRequest
)
->
"Request"
:
def
from_engine_core_request
(
cls
,
request
:
EngineCoreRequest
)
->
"Request"
:
return
cls
(
return
cls
(
...
@@ -65,6 +75,7 @@ class Request:
...
@@ -65,6 +75,7 @@ class Request:
prompt
=
request
.
prompt
,
prompt
=
request
.
prompt
,
multi_modal_data
=
None
,
multi_modal_data
=
None
,
multi_modal_inputs
=
request
.
mm_inputs
,
multi_modal_inputs
=
request
.
mm_inputs
,
multi_modal_hashes
=
request
.
mm_hashes
,
multi_modal_placeholders
=
request
.
mm_placeholders
,
multi_modal_placeholders
=
request
.
mm_placeholders
,
mm_processor_kwargs
=
None
,
mm_processor_kwargs
=
None
,
),
),
...
@@ -121,6 +132,17 @@ class Request:
...
@@ -121,6 +132,17 @@ class Request:
num_tokens
=
self
.
mm_positions
[
input_id
][
"length"
]
num_tokens
=
self
.
mm_positions
[
input_id
][
"length"
]
return
num_tokens
return
num_tokens
@
property
def
kv_block_hashes
(
self
)
->
ConstantList
[
"BlockHashType"
]:
# Prevent directly appending to the kv_block_hashes.
return
ConstantList
(
self
.
_kv_block_hashes
)
def
set_kv_block_hashes
(
self
,
value
:
List
[
"BlockHashType"
])
->
None
:
self
.
_kv_block_hashes
=
value
def
append_kv_block_hashes
(
self
,
block_hash
:
"BlockHashType"
)
->
None
:
self
.
_kv_block_hashes
.
append
(
block_hash
)
class
RequestStatus
(
enum
.
IntEnum
):
class
RequestStatus
(
enum
.
IntEnum
):
"""Status of a request."""
"""Status of a request."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment