Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8c9da6be
Unverified
Commit
8c9da6be
authored
Aug 08, 2025
by
Cyrus Leung
Committed by
GitHub
Aug 07, 2025
Browse files
[Core] Simplify mm processing cache (#22457)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
399d2a10
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
97 additions
and
206 deletions
+97
-206
vllm/model_executor/models/qwen2_5_omni_thinker.py
vllm/model_executor/models/qwen2_5_omni_thinker.py
+6
-6
vllm/model_executor/models/transformers.py
vllm/model_executor/models/transformers.py
+3
-2
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+69
-179
vllm/v1/serial_utils.py
vllm/v1/serial_utils.py
+19
-19
No files found.
vllm/model_executor/models/qwen2_5_omni_thinker.py
View file @
8c9da6be
...
@@ -431,7 +431,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
...
@@ -431,7 +431,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
tokenization_kwargs
:
Mapping
[
str
,
object
],
tokenization_kwargs
:
Mapping
[
str
,
object
],
*
,
*
,
enable_hf_prompt_update
:
bool
,
enable_hf_prompt_update
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
bool
]:
)
->
tuple
[
list
[
int
],
BatchFeature
,
bool
]:
"""
"""
Qwen2.5-Omni reimplements this function to handle text only.
Qwen2.5-Omni reimplements this function to handle text only.
"""
"""
...
@@ -448,20 +448,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
...
@@ -448,20 +448,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
else
:
else
:
prompt_ids
=
self
.
_apply_hf_processor_tokens_only
(
prompt
)
prompt_ids
=
self
.
_apply_hf_processor_tokens_only
(
prompt
)
mm_
kwargs
=
self
.
_apply_hf_processor_mm_only
(
mm_
processed_data
=
self
.
_apply_hf_processor_mm_only
(
mm_items
=
mm_items
,
mm_items
=
mm_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
)
)
return
prompt_ids
,
mm_
kwargs
,
False
return
prompt_ids
,
mm_
processed_data
,
False
def
_apply_hf_processor_mm_only
(
def
_apply_hf_processor_mm_only
(
self
,
self
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
tokenization_kwargs
:
Mapping
[
str
,
object
],
tokenization_kwargs
:
Mapping
[
str
,
object
],
)
->
MultiModalKwargs
:
)
->
BatchFeature
:
"""
"""
Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
"""
"""
...
@@ -473,14 +473,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
...
@@ -473,14 +473,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
assert
"audio"
in
mm_counts
assert
"audio"
in
mm_counts
mm_counts
[
"audio"
]
-=
mm_counts
[
"video"
]
mm_counts
[
"audio"
]
-=
mm_counts
[
"video"
]
_
,
mm_
kwargs
,
_
=
self
.
_apply_hf_processor_text_mm
(
_
,
mm_
processed_data
,
_
=
self
.
_apply_hf_processor_text_mm
(
prompt_text
=
self
.
dummy_inputs
.
get_dummy_text
(
mm_counts
),
prompt_text
=
self
.
dummy_inputs
.
get_dummy_text
(
mm_counts
),
mm_items
=
mm_items
,
mm_items
=
mm_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
)
)
return
mm_
kwargs
return
mm_
processed_data
def
_validate_mm_placeholders
(
def
_validate_mm_placeholders
(
self
,
self
,
...
...
vllm/model_executor/models/transformers.py
View file @
8c9da6be
...
@@ -22,7 +22,8 @@ from typing import Literal, Optional, Union
...
@@ -22,7 +22,8 @@ from typing import Literal, Optional, Union
import
regex
as
re
import
regex
as
re
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
transformers
import
AutoModel
,
PretrainedConfig
,
PreTrainedModel
from
transformers
import
(
AutoModel
,
BatchFeature
,
PretrainedConfig
,
PreTrainedModel
)
from
transformers.modeling_utils
import
ALL_ATTENTION_FUNCTIONS
from
transformers.modeling_utils
import
ALL_ATTENTION_FUNCTIONS
from
vllm.attention
import
Attention
from
vllm.attention
import
Attention
...
@@ -269,7 +270,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
...
@@ -269,7 +270,7 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
tokenization_kwargs
:
Mapping
[
str
,
object
],
tokenization_kwargs
:
Mapping
[
str
,
object
],
):
)
->
tuple
[
list
[
int
],
BatchFeature
,
bool
]
:
"""
"""
Apply the HF processor on the prompt text and multi-modal data
Apply the HF processor on the prompt text and multi-modal data
together.
together.
...
...
vllm/multimodal/processing.py
View file @
8c9da6be
...
@@ -18,7 +18,7 @@ from vllm.inputs import InputProcessingContext
...
@@ -18,7 +18,7 @@ from vllm.inputs import InputProcessingContext
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
decode_tokens
,
from
vllm.transformers_utils.tokenizer
import
(
AnyTokenizer
,
decode_tokens
,
encode_tokens
)
encode_tokens
)
from
vllm.utils
import
GiB_bytes
,
flatten_2d_lists
,
full_groupby
from
vllm.utils
import
flatten_2d_lists
,
full_groupby
from
.cache
import
MultiModalCache
from
.cache
import
MultiModalCache
from
.hasher
import
MultiModalHasher
from
.hasher
import
MultiModalHasher
...
@@ -887,120 +887,19 @@ def find_mm_placeholders(
...
@@ -887,120 +887,19 @@ def find_mm_placeholders(
return
dict
(
full_groupby_modality
(
it
))
return
dict
(
full_groupby_modality
(
it
))
class
ProcessingCacheOptionalItem
(
NamedTuple
):
key
:
str
value
:
Optional
[
MultiModalKwargsItem
]
class
ProcessingCacheItem
(
NamedTuple
):
key
:
str
value
:
MultiModalKwargsItem
class
ProcessingCache
(
MultiModalCache
):
class
ProcessingCache
(
MultiModalCache
):
def
__init__
(
def
__init__
(
self
,
capacity_gb
:
float
)
->
None
:
self
,
capacity_gb
:
float
,
*
,
debug_cache_hit_ratio_steps
:
Optional
[
int
]
=
None
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
debug_cache_hit_ratio_steps
=
debug_cache_hit_ratio_steps
self
.
_cache
=
self
.
get_lru_cache
(
capacity_gb
,
MultiModalKwargsItem
)
self
.
debug_cache_hits
=
0
self
.
debug_cache_total
=
0
self
.
_cache
=
self
.
get_lru_cache
(
capacity_gb
,
MultiModalKwargsItem
,
debug
=
bool
(
debug_cache_hit_ratio_steps
),
)
def
_maybe_log_cache_stats
(
self
)
->
None
:
steps
=
self
.
debug_cache_hit_ratio_steps
if
not
steps
:
return
total
=
self
.
debug_cache_total
if
total
>
0
and
total
%
steps
==
0
:
logger
.
debug
(
"ProcessingCache: hit_ratio = %.2f"
,
self
.
debug_cache_hits
/
total
)
logger
.
debug
(
"ProcessingCache: size = %.2f / %.2f GiB"
,
self
.
_cache
.
currsize
/
GiB_bytes
,
self
.
_cache
.
maxsize
/
GiB_bytes
)
def
get
(
self
,
model_id
:
str
,
modality
:
str
,
input_item
:
object
,
input_kwargs
:
Mapping
[
str
,
object
],
)
->
Optional
[
MultiModalKwargsItem
]:
"""
Get a processed multi-modal item from the cache
according to its dependencies, including:
- The model ID
- The modality of the item
- The original data item passed to the HF processor
- The configuration options of the HF processor
"""
self
.
_maybe_log_cache_stats
()
cache_key
=
MultiModalHasher
.
hash_kwargs
(
model_id
=
model_id
,
self
.
get
=
self
.
_cache
.
get
**
{
modality
:
input_item
},
self
.
put
=
self
.
_cache
.
put
**
input_kwargs
)
self
.
reset
=
self
.
_cache
.
clear
if
self
.
debug_cache_hit_ratio_steps
:
if
cache_key
in
self
.
_cache
:
self
.
debug_cache_hits
+=
1
self
.
debug_cache_total
+=
1
_CacheItemOrHash
=
Union
[
MultiModalKwargsItem
,
str
]
return
self
.
_cache
.
get
(
cache_key
)
def
get_item
(
self
,
model_id
:
str
,
modality
:
str
,
input_item
:
object
,
input_kwargs
:
Mapping
[
str
,
object
],
)
->
ProcessingCacheOptionalItem
:
cache_key
=
MultiModalHasher
.
hash_kwargs
(
model_id
=
model_id
,
**
{
modality
:
input_item
},
**
input_kwargs
)
return
ProcessingCacheOptionalItem
(
key
=
cache_key
,
value
=
self
.
_cache
.
get
(
cache_key
),
)
def
put
(
self
,
model_id
:
str
,
modality
:
str
,
input_item
:
object
,
input_kwargs
:
Mapping
[
str
,
object
],
output_kwargs
:
MultiModalKwargsItem
,
)
->
None
:
"""
Put a processed multi-modal item into the cache
according to its dependencies
(see [`get`][vllm.multimodal.processing.ProcessingCache.get]).
"""
cache_key
=
MultiModalHasher
.
hash_kwargs
(
model_id
=
model_id
,
**
{
modality
:
input_item
},
**
input_kwargs
)
self
.
_cache
[
cache_key
]
=
output_kwargs
def
put_item
(
self
,
item
:
ProcessingCacheItem
)
->
None
:
self
.
_cache
[
item
.
key
]
=
item
.
value
def
reset
(
self
)
->
bool
:
self
.
_cache
.
clear
()
return
True
class
BaseProcessingInfo
:
class
BaseProcessingInfo
:
...
@@ -1279,7 +1178,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1279,7 +1178,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
tokenization_kwargs
:
Mapping
[
str
,
object
],
tokenization_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
bool
]:
)
->
tuple
[
list
[
int
],
"BatchFeature"
,
bool
]:
"""
"""
Apply the HF processor on the prompt text and multi-modal data
Apply the HF processor on the prompt text and multi-modal data
together.
together.
...
@@ -1298,11 +1197,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1298,11 +1197,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
prompt_ids
,
=
processed_data
.
pop
(
"input_ids"
).
tolist
()
prompt_ids
,
=
processed_data
.
pop
(
"input_ids"
).
tolist
()
mm_kwargs
=
MultiModalKwargs
.
from_hf_inputs
(
processed_data
,
self
.
_get_mm_fields_config
(
processed_data
,
hf_processor_mm_kwargs
),
)
is_update_applied
=
self
.
_hf_processor_applies_updates
(
is_update_applied
=
self
.
_hf_processor_applies_updates
(
prompt_text
=
prompt_text
,
prompt_text
=
prompt_text
,
mm_items
=
mm_items
,
mm_items
=
mm_items
,
...
@@ -1310,11 +1204,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1310,11 +1204,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
)
)
return
prompt_ids
,
mm_kwargs
,
is_update_applied
return
prompt_ids
,
processed_data
,
is_update_applied
def
_apply_hf_processor_text_only
(
def
_apply_hf_processor_text_only
(
self
,
prompt_text
:
str
,
self
,
tokenization_kwargs
:
Mapping
[
str
,
object
])
->
list
[
int
]:
prompt_text
:
str
,
tokenization_kwargs
:
Mapping
[
str
,
object
],
)
->
list
[
int
]:
"""
"""
Apply the HF processor on the prompt text only.
Apply the HF processor on the prompt text only.
...
@@ -1353,7 +1249,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1353,7 +1249,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
tokenization_kwargs
:
Mapping
[
str
,
object
],
tokenization_kwargs
:
Mapping
[
str
,
object
],
)
->
MultiModalKwargs
:
)
->
"BatchFeature"
:
"""
"""
Apply the HF processor on the multi-modal data only.
Apply the HF processor on the multi-modal data only.
...
@@ -1364,14 +1260,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1364,14 +1260,14 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
"""
mm_counts
=
mm_items
.
get_all_counts
()
mm_counts
=
mm_items
.
get_all_counts
()
_
,
mm_
kwargs
,
_
=
self
.
_apply_hf_processor_text_mm
(
_
,
mm_
processed_data
,
_
=
self
.
_apply_hf_processor_text_mm
(
prompt_text
=
self
.
dummy_inputs
.
get_dummy_text
(
mm_counts
),
prompt_text
=
self
.
dummy_inputs
.
get_dummy_text
(
mm_counts
),
mm_items
=
mm_items
,
mm_items
=
mm_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
)
)
return
mm_
kwargs
return
mm_
processed_data
def
_apply_hf_processor_main
(
def
_apply_hf_processor_main
(
self
,
self
,
...
@@ -1381,7 +1277,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1381,7 +1277,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs
:
Mapping
[
str
,
object
],
tokenization_kwargs
:
Mapping
[
str
,
object
],
*
,
*
,
enable_hf_prompt_update
:
bool
,
enable_hf_prompt_update
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
bool
]:
)
->
tuple
[
list
[
int
],
"BatchFeature"
,
bool
]:
"""
"""
Apply the HF processor on the prompt text and multi-modal data.
Apply the HF processor on the prompt text and multi-modal data.
...
@@ -1407,52 +1303,46 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1407,52 +1303,46 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
else
:
else
:
prompt_ids
=
self
.
_apply_hf_processor_tokens_only
(
prompt
)
prompt_ids
=
self
.
_apply_hf_processor_tokens_only
(
prompt
)
mm_
kwargs
=
self
.
_apply_hf_processor_mm_only
(
mm_
processed_data
=
self
.
_apply_hf_processor_mm_only
(
mm_items
=
mm_items
,
mm_items
=
mm_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
)
)
return
prompt_ids
,
mm_
kwargs
,
False
return
prompt_ids
,
mm_
processed_data
,
False
def
_get_cache_missing_items
(
def
_get_cache_missing_items
(
self
,
self
,
cache
:
ProcessingCache
,
cache
:
ProcessingCache
,
mm_data_items
:
MultiModalDataItems
,
mm_data_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
mm_hashes
:
MultiModalHashes
,
tokenization_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
dict
[
str
,
list
[
_CacheItemOrHash
]],
MultiModalDataItems
]:
)
->
tuple
[
dict
[
str
,
list
[
ProcessingCacheOptionalItem
]],
dict
[
mm_cache_items_or_hashes
:
dict
[
str
,
list
[
_CacheItemOrHash
]]
=
{
str
,
list
[
object
]]]:
modality
:
[(
h
if
(
v
:
=
cache
.
get
(
h
))
is
None
else
v
)
model_id
=
self
.
info
.
model_id
for
h
in
hashes
]
for
modality
,
hashes
in
mm_hashes
.
items
()
mm_cache_items
=
{
modality
:
[
cache
.
get_item
(
model_id
,
modality
,
item
,
dict
(
**
hf_processor_mm_kwargs
,
**
tokenization_kwargs
))
for
item
in
items
]
for
modality
,
items
in
mm_data_items
.
items
()
}
}
mm_missing_idxs
=
{
mm_missing_idxs
=
{
modality
:
[
modality
:
[
idx
for
idx
,
item
in
enumerate
(
cache_item
s
)
idx
for
idx
,
item
_or_hash
in
enumerate
(
items_or_hashe
s
)
if
i
tem
.
value
is
None
if
i
sinstance
(
item_or_hash
,
str
)
]
]
for
modality
,
cache_item
s
in
mm_cache_items
.
items
()
for
modality
,
items_or_hashe
s
in
mm_cache_items
_or_hashes
.
items
()
}
}
mm_missing_data
=
{
mm_missing_data
=
{
modality
:
[
mm_data_items
[
modality
][
idx
]
for
idx
in
idxs
]
modality
:
[
mm_data_items
[
modality
][
idx
]
for
idx
in
idxs
]
for
modality
,
idxs
in
mm_missing_idxs
.
items
()
for
modality
,
idxs
in
mm_missing_idxs
.
items
()
}
}
return
mm_cache_items
,
mm_missing_data
return
mm_cache_items
_or_hashes
,
self
.
_to_mm_items
(
mm_missing_data
)
def
_hash_mm_items
(
def
_hash_mm_items
(
self
,
mm_items
:
MultiModalDataItems
,
self
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
mm_items
:
MultiModalDataItems
,
tokenization_kwargs
:
Mapping
[
str
,
object
])
->
MultiModalHashes
:
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
tokenization_kwargs
:
Mapping
[
str
,
object
],
)
->
MultiModalHashes
:
"""Create MM hashes to be returned (only used in V1)."""
"""Create MM hashes to be returned (only used in V1)."""
model_id
=
self
.
info
.
model_id
model_id
=
self
.
info
.
model_id
...
@@ -1470,34 +1360,25 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1470,34 +1360,25 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def
_merge_mm_kwargs
(
def
_merge_mm_kwargs
(
self
,
self
,
cache
:
ProcessingCache
,
cache
:
ProcessingCache
,
mm_cache_items
:
dict
[
str
,
list
[
ProcessingCacheOptionalItem
]],
mm_cache_items_or_hashes
:
dict
[
str
,
list
[
_CacheItemOrHash
]],
mm_missing_data
:
dict
[
str
,
list
[
object
]],
mm_missing_kwargs
:
MultiModalKwargs
,
mm_missing_kwargs
:
MultiModalKwargs
,
)
->
dict
[
str
,
list
[
ProcessingCache
Item
]]:
)
->
dict
[
str
,
list
[
MultiModalKwargs
Item
]]:
mm_missing_next_idx
=
{
modality
:
0
for
modality
in
mm_missing_data
}
mm_missing_next_idx
=
defaultdict
[
str
,
int
](
lambda
:
0
)
merged_items
=
defaultdict
[
str
,
list
[
ProcessingCache
Item
]](
list
)
merged_items
=
defaultdict
[
str
,
list
[
MultiModalKwargs
Item
]](
list
)
for
modality
,
cache_item
s
in
mm_cache_items
.
items
():
for
modality
,
items_or_hashe
s
in
mm_cache_items
_or_hashes
.
items
():
for
cache_item
in
cache_item
s
:
for
item_or_hash
in
items_or_hashe
s
:
if
cache_item
.
value
is
None
:
if
isinstance
(
item_or_hash
,
str
)
:
kw_item
=
mm_missing_kwargs
.
get_item
(
kw_item
=
mm_missing_kwargs
.
get_item
(
modality
,
modality
,
mm_missing_next_idx
[
modality
],
mm_missing_next_idx
[
modality
],
)
)
cache_item_new
=
ProcessingCacheItem
(
cache
.
put
(
item_or_hash
,
kw_item
)
key
=
cache_item
.
key
,
value
=
kw_item
,
)
cache
.
put_item
(
cache_item_new
)
mm_missing_next_idx
[
modality
]
+=
1
mm_missing_next_idx
[
modality
]
+=
1
else
:
else
:
cache_item_new
=
ProcessingCacheItem
(
kw_item
=
item_or_hash
key
=
cache_item
.
key
,
value
=
cache_item
.
value
,
)
merged_items
[
modality
].
append
(
cache
_item
_new
)
merged_items
[
modality
].
append
(
kw
_item
)
return
dict
(
merged_items
)
return
dict
(
merged_items
)
...
@@ -1512,7 +1393,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1512,7 +1393,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
Optional
[
MultiModalHashes
],
bool
]:
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
Optional
[
MultiModalHashes
],
bool
]:
(
(
prompt_ids
,
prompt_ids
,
mm_
kwargs
,
mm_
processed_data
,
is_update_applied
,
is_update_applied
,
)
=
self
.
_apply_hf_processor_main
(
)
=
self
.
_apply_hf_processor_main
(
prompt
=
prompt
,
prompt
=
prompt
,
...
@@ -1522,6 +1403,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1522,6 +1403,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
enable_hf_prompt_update
=
True
,
enable_hf_prompt_update
=
True
,
)
)
mm_kwargs
=
MultiModalKwargs
.
from_hf_inputs
(
mm_processed_data
,
self
.
_get_mm_fields_config
(
mm_processed_data
,
hf_processor_mm_kwargs
),
)
mm_hashes
=
(
self
.
_hash_mm_items
(
mm_data_items
,
hf_processor_mm_kwargs
,
mm_hashes
=
(
self
.
_hash_mm_items
(
mm_data_items
,
hf_processor_mm_kwargs
,
tokenization_kwargs
)
tokenization_kwargs
)
if
return_mm_hashes
else
None
)
if
return_mm_hashes
else
None
)
...
@@ -1553,49 +1440,52 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1553,49 +1440,52 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return_mm_hashes
=
return_mm_hashes
,
return_mm_hashes
=
return_mm_hashes
,
)
)
mm_hashes
=
self
.
_hash_mm_items
(
mm_data_items
,
hf_processor_mm_kwargs
,
tokenization_kwargs
)
(
(
mm_cache_items
,
mm_cache_items
_or_hashes
,
mm_missing_data
,
mm_missing_data
_items
,
)
=
self
.
_get_cache_missing_items
(
)
=
self
.
_get_cache_missing_items
(
cache
=
cache
,
cache
=
cache
,
mm_data_items
=
mm_data_items
,
mm_data_items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
mm_hashes
=
mm_hashes
,
tokenization_kwargs
=
tokenization_kwargs
,
)
)
mm_hashes_to_return
=
mm_hashes
if
return_mm_hashes
else
None
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt updates until the new multimodal
# so we can't apply prompt updates until the new multimodal
# items are combined with the cached multimodal items
# items are combined with the cached multimodal items
(
(
prompt_ids
,
prompt_ids
,
mm_missing_
kwargs
,
mm_missing_
processed_data
,
is_update_applied
,
is_update_applied
,
)
=
self
.
_apply_hf_processor_main
(
)
=
self
.
_apply_hf_processor_main
(
prompt
=
prompt
,
prompt
=
prompt
,
mm_items
=
self
.
_to_mm_items
(
mm_missing_data
)
,
mm_items
=
mm_missing_data
_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
tokenization_kwargs
=
tokenization_kwargs
,
enable_hf_prompt_update
=
False
,
enable_hf_prompt_update
=
False
,
)
)
mm_missing_kwargs
=
MultiModalKwargs
.
from_hf_inputs
(
mm_missing_processed_data
,
self
.
_get_mm_fields_config
(
mm_missing_processed_data
,
hf_processor_mm_kwargs
),
)
mm_cache_items_merged
=
self
.
_merge_mm_kwargs
(
mm_cache_items_merged
=
self
.
_merge_mm_kwargs
(
cache
,
cache
,
mm_cache_items
=
mm_cache_items
,
mm_cache_items_or_hashes
=
mm_cache_items_or_hashes
,
mm_missing_data
=
mm_missing_data
,
mm_missing_kwargs
=
mm_missing_kwargs
,
mm_missing_kwargs
=
mm_missing_kwargs
,
)
)
mm_kwargs
=
MultiModalKwargs
.
from_items
([
mm_kwargs
=
MultiModalKwargs
.
from_items
([
item
.
value
for
cache_items
in
mm_cache_items_merged
.
values
()
item
for
cache_items
in
mm_cache_items_merged
.
values
()
for
item
in
cache_items
for
item
in
cache_items
])
])
mm_hashes
=
{
return
prompt_ids
,
mm_kwargs
,
mm_hashes_to_return
,
is_update_applied
modality
:
[
item
.
key
for
item
in
cache_items
]
for
modality
,
cache_items
in
mm_cache_items_merged
.
items
()
}
if
return_mm_hashes
else
None
return
prompt_ids
,
mm_kwargs
,
mm_hashes
,
is_update_applied
def
_bind_and_group_updates
(
def
_bind_and_group_updates
(
self
,
self
,
...
...
vllm/v1/serial_utils.py
View file @
8c9da6be
...
@@ -312,25 +312,25 @@ class MsgpackDecoder:
...
@@ -312,25 +312,25 @@ class MsgpackDecoder:
return
arr
.
view
(
torch_dtype
).
view
(
shape
)
return
arr
.
view
(
torch_dtype
).
view
(
shape
)
def
_decode_mm_items
(
self
,
obj
:
list
)
->
list
[
MultiModalKwargsItem
]:
def
_decode_mm_items
(
self
,
obj
:
list
)
->
list
[
MultiModalKwargsItem
]:
decode
d
_item
s
=
[
]
return
[
self
.
_
decode
_mm
_item
(
v
)
for
v
in
obj
]
for
item
in
obj
:
elems
=
[]
def
_decode_mm_item
(
self
,
obj
:
list
)
->
MultiModalKwargsItem
:
for
v
in
item
:
return
MultiModalKwargsItem
.
from_elems
(
v
[
"data"
]
=
self
.
_decode_nested_tensors
(
v
[
"data"
])
[
self
.
_decode_mm_field_elem
(
v
)
for
v
in
obj
])
# Reconstruct the field processor using MultiModalFieldConfig
factory_meth_name
,
*
field_args
=
v
[
"field"
]
def
_decode_mm_field_elem
(
self
,
obj
:
dict
)
->
MultiModalFieldElem
:
factory_meth
=
getattr
(
MultiModalFieldConfig
,
obj
[
"data"
]
=
self
.
_decode_nested_tensors
(
obj
[
"data"
])
factory_meth_name
)
# Reconstruct the field processor using MultiModalFieldConfig
factory_meth_name
,
*
field_args
=
obj
[
"field"
]
# Special case: decode the union "slices" field of
factory_meth
=
getattr
(
MultiModalFieldConfig
,
factory_meth_name
)
# MultiModalFlatField
if
factory_meth_name
==
"flat"
:
# Special case: decode the union "slices" field of
field_args
[
0
]
=
self
.
_decode_nested_slices
(
field_args
[
0
])
# MultiModalFlatField
if
factory_meth_name
==
"flat"
:
v
[
"field"
]
=
factory_meth
(
None
,
*
field_args
).
field
field_args
[
0
]
=
self
.
_decode_nested_slices
(
field_args
[
0
])
elems
.
append
(
MultiModalFieldElem
(
**
v
))
decoded_items
.
append
(
MultiModalKwargsItem
.
from_elems
(
elems
))
obj
[
"field"
]
=
factory_meth
(
None
,
*
field_args
).
field
return
decoded_items
return
MultiModalFieldElem
(
**
obj
)
def
_decode_nested_tensors
(
self
,
obj
:
Any
)
->
NestedTensors
:
def
_decode_nested_tensors
(
self
,
obj
:
Any
)
->
NestedTensors
:
if
isinstance
(
obj
,
(
int
,
float
)):
if
isinstance
(
obj
,
(
int
,
float
)):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment