Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
506475de
Unverified
Commit
506475de
authored
Apr 29, 2025
by
Cyrus Leung
Committed by
GitHub
Apr 29, 2025
Browse files
[Optim] Compute multimodal hash only once per item (#17314)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
cfe45320
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
233 additions
and
128 deletions
+233
-128
vllm/model_executor/models/deepseek_vl2.py
vllm/model_executor/models/deepseek_vl2.py
+9
-7
vllm/model_executor/models/h2ovl.py
vllm/model_executor/models/h2ovl.py
+9
-7
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+0
-3
vllm/model_executor/models/mistral3.py
vllm/model_executor/models/mistral3.py
+0
-2
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+10
-5
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+205
-104
No files found.
vllm/model_executor/models/deepseek_vl2.py
View file @
506475de
...
@@ -22,8 +22,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
...
@@ -22,8 +22,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
MultiModalHashes
,
PromptUpdate
)
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.deepseek_vl2
import
(
DeepseekVLV2Config
,
from
vllm.transformers_utils.configs.deepseek_vl2
import
(
DeepseekVLV2Config
,
...
@@ -279,24 +279,26 @@ class DeepseekVL2MultiModalProcessor(
...
@@ -279,24 +279,26 @@ class DeepseekVL2MultiModalProcessor(
prompt
:
Union
[
str
,
list
[
int
]],
prompt
:
Union
[
str
,
list
[
int
]],
mm_data_items
:
MultiModalDataItems
,
mm_data_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
bool
]:
*
,
return_mm_hashes
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
Optional
[
MultiModalHashes
],
bool
]:
# The processor logic is different for len(images) <= 2 vs > 2
# The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
# perform caching for the most common case
if
mm_data_items
.
get_count
(
"image"
,
strict
=
False
)
>
2
:
if
mm_data_items
.
get_count
(
"image"
,
strict
=
False
)
>
2
:
# This code path corresponds to the cache being disabled
return
self
.
_apply_hf_processor
(
return
self
.
_apply_hf_processor_main
(
prompt
=
prompt
,
prompt
=
prompt
,
mm_items
=
mm_data_items
,
mm_
data_
items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
enable_hf_prompt_update
=
True
,
return_mm_hashes
=
return_mm_hashes
,
)
)
return
super
().
_cached_apply_hf_processor
(
return
super
().
_cached_apply_hf_processor
(
prompt
=
prompt
,
prompt
=
prompt
,
mm_data_items
=
mm_data_items
,
mm_data_items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
return_mm_hashes
=
return_mm_hashes
,
)
)
...
...
vllm/model_executor/models/h2ovl.py
View file @
506475de
...
@@ -19,8 +19,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
...
@@ -19,8 +19,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
MultiModalDataItems
)
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
PromptReplacement
,
PromptUpdate
,
from
vllm.multimodal.processing
import
(
MultiModalHashes
,
PromptReplacement
,
PromptUpdateDetails
)
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
.intern_vit
import
InternVisionModel
from
.intern_vit
import
InternVisionModel
...
@@ -488,24 +488,26 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
...
@@ -488,24 +488,26 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
prompt
:
Union
[
str
,
list
[
int
]],
prompt
:
Union
[
str
,
list
[
int
]],
mm_data_items
:
MultiModalDataItems
,
mm_data_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
bool
]:
*
,
return_mm_hashes
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
Optional
[
MultiModalHashes
],
bool
]:
# The processor logic is different for len(images) <= 1 vs > 1
# The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
# perform caching for the most common case
if
mm_data_items
.
get_count
(
"image"
,
strict
=
False
)
>
1
:
if
mm_data_items
.
get_count
(
"image"
,
strict
=
False
)
>
1
:
# This code path corresponds to the cache being disabled
return
self
.
_apply_hf_processor
(
return
self
.
_apply_hf_processor_main
(
prompt
=
prompt
,
prompt
=
prompt
,
mm_items
=
mm_data_items
,
mm_
data_
items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
enable_hf_prompt_update
=
True
,
return_mm_hashes
=
return_mm_hashes
,
)
)
return
super
().
_cached_apply_hf_processor
(
return
super
().
_cached_apply_hf_processor
(
prompt
=
prompt
,
prompt
=
prompt
,
mm_data_items
=
mm_data_items
,
mm_data_items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
return_mm_hashes
=
return_mm_hashes
,
)
)
...
...
vllm/model_executor/models/llava.py
View file @
506475de
...
@@ -396,14 +396,12 @@ def _build_llava_or_pixtral_hf_processor(
...
@@ -396,14 +396,12 @@ def _build_llava_or_pixtral_hf_processor(
dummy_inputs
:
BaseDummyInputsBuilder
[
_I
],
dummy_inputs
:
BaseDummyInputsBuilder
[
_I
],
*
,
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
enable_sanity_checks
:
bool
=
True
,
)
->
BaseMultiModalProcessor
:
)
->
BaseMultiModalProcessor
:
if
isinstance
(
info
,
PixtralHFProcessingInfo
):
if
isinstance
(
info
,
PixtralHFProcessingInfo
):
return
PixtralHFMultiModalProcessor
(
return
PixtralHFMultiModalProcessor
(
info
,
info
,
dummy_inputs
,
# type: ignore
dummy_inputs
,
# type: ignore
cache
=
cache
,
cache
=
cache
,
enable_sanity_checks
=
enable_sanity_checks
,
)
)
if
isinstance
(
info
,
LlavaProcessingInfo
):
if
isinstance
(
info
,
LlavaProcessingInfo
):
...
@@ -411,7 +409,6 @@ def _build_llava_or_pixtral_hf_processor(
...
@@ -411,7 +409,6 @@ def _build_llava_or_pixtral_hf_processor(
info
,
info
,
dummy_inputs
,
# type: ignore
dummy_inputs
,
# type: ignore
cache
=
cache
,
cache
=
cache
,
enable_sanity_checks
=
enable_sanity_checks
,
)
)
raise
NotImplementedError
(
type
(
info
))
raise
NotImplementedError
(
type
(
info
))
...
...
vllm/model_executor/models/mistral3.py
View file @
506475de
...
@@ -312,14 +312,12 @@ def _build_mistral3_processor(
...
@@ -312,14 +312,12 @@ def _build_mistral3_processor(
dummy_inputs
:
BaseDummyInputsBuilder
[
_I
],
dummy_inputs
:
BaseDummyInputsBuilder
[
_I
],
*
,
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
enable_sanity_checks
:
bool
=
True
,
)
->
BaseMultiModalProcessor
:
)
->
BaseMultiModalProcessor
:
assert
isinstance
(
info
,
Mistral3ProcessingInfo
)
assert
isinstance
(
info
,
Mistral3ProcessingInfo
)
return
Mistral3MultiModalProcessor
(
return
Mistral3MultiModalProcessor
(
info
,
info
,
dummy_inputs
,
# type: ignore
dummy_inputs
,
# type: ignore
cache
=
cache
,
cache
=
cache
,
enable_sanity_checks
=
enable_sanity_checks
,
)
)
...
...
vllm/model_executor/models/pixtral.py
View file @
506475de
...
@@ -36,8 +36,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
...
@@ -36,8 +36,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
BaseProcessingInfo
,
MultiModalHashes
,
PromptUpdate
,
PromptUpdateDetails
)
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.tokenizer
import
(
MistralTokenizer
,
from
vllm.transformers_utils.tokenizer
import
(
MistralTokenizer
,
...
@@ -271,15 +272,19 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
...
@@ -271,15 +272,19 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
prompt
:
Union
[
str
,
list
[
int
]],
prompt
:
Union
[
str
,
list
[
int
]],
mm_data_items
:
MultiModalDataItems
,
mm_data_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
bool
]:
*
,
prompt_ids
,
mm_kwargs
,
_
=
super
().
_cached_apply_hf_processor
(
return_mm_hashes
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
Optional
[
MultiModalHashes
],
bool
]:
prompt_ids
,
mm_kwargs
,
mm_hashes
,
_
=
super
(
).
_cached_apply_hf_processor
(
prompt
=
prompt
,
prompt
=
prompt
,
mm_data_items
=
mm_data_items
,
mm_data_items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
return_mm_hashes
=
return_mm_hashes
,
)
)
# NOTE: The tokens are already inserted by the chat template
# NOTE: The tokens are already inserted by the chat template
return
prompt_ids
,
mm_kwargs
,
True
return
prompt_ids
,
mm_kwargs
,
mm_hashes
,
True
@
MULTIMODAL_REGISTRY
.
register_processor
(
PixtralMultiModalProcessor
,
@
MULTIMODAL_REGISTRY
.
register_processor
(
PixtralMultiModalProcessor
,
...
...
vllm/multimodal/processing.py
View file @
506475de
...
@@ -876,6 +876,16 @@ def find_mm_placeholders(
...
@@ -876,6 +876,16 @@ def find_mm_placeholders(
_V
=
TypeVar
(
"_V"
,
bound
=
"Union[MultiModalKwargs, MultiModalKwargsItem]"
)
_V
=
TypeVar
(
"_V"
,
bound
=
"Union[MultiModalKwargs, MultiModalKwargsItem]"
)
class
ProcessingCacheOptionalItem
(
NamedTuple
):
key
:
str
value
:
Optional
[
MultiModalKwargsItem
]
class
ProcessingCacheItem
(
NamedTuple
):
key
:
str
value
:
MultiModalKwargsItem
class
ProcessingCache
:
class
ProcessingCache
:
@
staticmethod
@
staticmethod
...
@@ -980,6 +990,22 @@ class ProcessingCache:
...
@@ -980,6 +990,22 @@ class ProcessingCache:
return
self
.
_cache
.
get
(
cache_key
)
return
self
.
_cache
.
get
(
cache_key
)
def
get_item
(
self
,
model_id
:
str
,
modality
:
str
,
input_item
:
object
,
input_kwargs
:
Mapping
[
str
,
object
],
)
->
ProcessingCacheOptionalItem
:
cache_key
=
MultiModalHasher
.
hash_kwargs
(
model_id
=
model_id
,
**
{
modality
:
input_item
},
**
input_kwargs
)
return
ProcessingCacheOptionalItem
(
key
=
cache_key
,
value
=
self
.
_cache
.
get
(
cache_key
),
)
def
put
(
def
put
(
self
,
self
,
model_id
:
str
,
model_id
:
str
,
...
@@ -997,6 +1023,9 @@ class ProcessingCache:
...
@@ -997,6 +1023,9 @@ class ProcessingCache:
**
input_kwargs
)
**
input_kwargs
)
self
.
_cache
[
cache_key
]
=
output_kwargs
self
.
_cache
[
cache_key
]
=
output_kwargs
def
put_item
(
self
,
item
:
ProcessingCacheItem
)
->
None
:
self
.
_cache
[
item
.
key
]
=
item
.
value
class
BaseProcessingInfo
:
class
BaseProcessingInfo
:
"""Base class to provide the information necessary for data processing."""
"""Base class to provide the information necessary for data processing."""
...
@@ -1052,6 +1081,11 @@ class BaseProcessingInfo:
...
@@ -1052,6 +1081,11 @@ class BaseProcessingInfo:
_I
=
TypeVar
(
"_I"
,
bound
=
BaseProcessingInfo
)
_I
=
TypeVar
(
"_I"
,
bound
=
BaseProcessingInfo
)
MultiModalHashes
=
dict
[
str
,
list
[
str
]]
"""
A collection of hashes with a similar structure as :class:`MultiModalKwargs`.
"""
class
BaseMultiModalProcessor
(
ABC
,
Generic
[
_I
]):
class
BaseMultiModalProcessor
(
ABC
,
Generic
[
_I
]):
"""
"""
...
@@ -1064,14 +1098,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1064,14 +1098,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
info
:
_I
,
info
:
_I
,
dummy_inputs
:
"BaseDummyInputsBuilder[_I]"
,
dummy_inputs
:
"BaseDummyInputsBuilder[_I]"
,
*
,
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
cache
:
Optional
[
ProcessingCache
]
=
None
)
->
None
:
enable_sanity_checks
:
bool
=
True
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
info
=
info
self
.
info
=
info
self
.
dummy_inputs
=
dummy_inputs
self
.
dummy_inputs
=
dummy_inputs
self
.
cache
=
cache
self
.
cache
=
cache
self
.
enable_sanity_checks
=
enable_sanity_checks
self
.
data_parser
=
self
.
_get_data_parser
()
self
.
data_parser
=
self
.
_get_data_parser
()
...
@@ -1340,46 +1372,144 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1340,46 +1372,144 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return
prompt_ids
,
mm_kwargs
,
False
return
prompt_ids
,
mm_kwargs
,
False
def
_get_cache_missing_items
(
self
,
cache
:
ProcessingCache
,
mm_data_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
dict
[
str
,
list
[
ProcessingCacheOptionalItem
]],
dict
[
str
,
list
[
object
]]]:
model_id
=
self
.
info
.
model_id
mm_cache_items
=
{
modality
:
[
cache
.
get_item
(
model_id
,
modality
,
item
,
hf_processor_mm_kwargs
)
for
item
in
items
]
for
modality
,
items
in
mm_data_items
.
items
()
}
mm_missing_idxs
=
{
modality
:
[
idx
for
idx
,
item
in
enumerate
(
cache_items
)
if
item
.
value
is
None
]
for
modality
,
cache_items
in
mm_cache_items
.
items
()
}
mm_missing_data
=
{
modality
:
[
mm_data_items
[
modality
][
idx
]
for
idx
in
idxs
]
for
modality
,
idxs
in
mm_missing_idxs
.
items
()
}
return
mm_cache_items
,
mm_missing_data
def
_hash_mm_items
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
MultiModalHashes
:
"""Create MM hashes to be returned (only used in V1)."""
model_id
=
self
.
info
.
model_id
return
{
modality
:
[
MultiModalHasher
.
hash_kwargs
(
model_id
=
model_id
,
**
{
modality
:
item
},
**
hf_processor_mm_kwargs
)
for
item
in
items
]
for
modality
,
items
in
mm_items
.
items
()
}
def
_merge_mm_kwargs
(
self
,
cache
:
ProcessingCache
,
mm_cache_items
:
dict
[
str
,
list
[
ProcessingCacheOptionalItem
]],
mm_missing_data
:
dict
[
str
,
list
[
object
]],
mm_missing_kwargs
:
MultiModalKwargs
,
)
->
dict
[
str
,
list
[
ProcessingCacheItem
]]:
mm_missing_next_idx
=
{
modality
:
0
for
modality
in
mm_missing_data
}
merged_items
=
defaultdict
[
str
,
list
[
ProcessingCacheItem
]](
list
)
for
modality
,
cache_items
in
mm_cache_items
.
items
():
for
cache_item
in
cache_items
:
if
cache_item
.
value
is
None
:
kw_item
=
mm_missing_kwargs
.
get_item
(
modality
,
mm_missing_next_idx
[
modality
],
)
cache_item_new
=
ProcessingCacheItem
(
key
=
cache_item
.
key
,
value
=
kw_item
,
)
cache
.
put_item
(
cache_item_new
)
mm_missing_next_idx
[
modality
]
+=
1
else
:
cache_item_new
=
ProcessingCacheItem
(
key
=
cache_item
.
key
,
value
=
cache_item
.
value
,
)
merged_items
[
modality
].
append
(
cache_item_new
)
return
dict
(
merged_items
)
def
_apply_hf_processor
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
mm_data_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
*
,
return_mm_hashes
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
Optional
[
MultiModalHashes
],
bool
]:
(
prompt_ids
,
mm_kwargs
,
is_update_applied
,
)
=
self
.
_apply_hf_processor_main
(
prompt
=
prompt
,
mm_items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
enable_hf_prompt_update
=
True
,
)
mm_hashes
=
(
self
.
_hash_mm_items
(
mm_data_items
,
hf_processor_mm_kwargs
)
if
return_mm_hashes
else
None
)
return
prompt_ids
,
mm_kwargs
,
mm_hashes
,
is_update_applied
def
_cached_apply_hf_processor
(
def
_cached_apply_hf_processor
(
self
,
self
,
prompt
:
Union
[
str
,
list
[
int
]],
prompt
:
Union
[
str
,
list
[
int
]],
mm_data_items
:
MultiModalDataItems
,
mm_data_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
bool
]:
*
,
return_mm_hashes
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
Optional
[
MultiModalHashes
],
bool
]:
"""
"""
Apply the HF processor on the full prompt text,
Apply the HF processor on the full prompt text,
caching the results and reusing cached results.
caching the results and reusing cached results.
"""
"""
cache
=
self
.
cache
cache
=
self
.
cache
model_id
=
self
.
info
.
model_id
_
,
passthrough_data
=
self
.
_get_hf_mm_data
(
mm_data_items
)
_
,
passthrough_data
=
self
.
_get_hf_mm_data
(
mm_data_items
)
if
cache
is
None
or
passthrough_data
:
if
cache
is
None
or
passthrough_data
:
return
self
.
_apply_hf_processor
_main
(
return
self
.
_apply_hf_processor
(
prompt
=
prompt
,
prompt
=
prompt
,
mm_items
=
mm_data_items
,
mm_
data_
items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
enable_hf_prompt_update
=
True
,
return_mm_hashes
=
return_mm_hashes
,
)
)
mm_maybe_cached_kw_items
=
{
(
modality
:
[
mm_cache_items
,
cache
.
get
(
model_id
,
modality
,
item
,
hf_processor_mm_kwargs
)
mm_missing_data
,
for
item
in
items
)
=
self
.
_get_cache_missing_items
(
]
cache
=
cache
,
for
modality
,
items
in
mm_data_items
.
items
()
mm_data_items
=
mm_data_items
,
}
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
)
mm_missing_idxs
=
{
modality
:
[
idx
for
idx
,
item
in
enumerate
(
kw_items
)
if
item
is
None
]
for
modality
,
kw_items
in
mm_maybe_cached_kw_items
.
items
()
}
mm_missing_data
=
{
modality
:
[
mm_data_items
[
modality
][
idx
]
for
idx
in
idxs
]
for
modality
,
idxs
in
mm_missing_idxs
.
items
()
}
mm_missing_data_items
=
self
.
_to_mm_items
(
mm_missing_data
)
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt updates until the new multimodal
# so we can't apply prompt updates until the new multimodal
...
@@ -1390,48 +1520,29 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1390,48 +1520,29 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
is_update_applied
,
is_update_applied
,
)
=
self
.
_apply_hf_processor_main
(
)
=
self
.
_apply_hf_processor_main
(
prompt
=
prompt
,
prompt
=
prompt
,
mm_items
=
mm_missing_data
_items
,
mm_items
=
self
.
_to_mm_items
(
mm_missing_data
)
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
enable_hf_prompt_update
=
False
,
enable_hf_prompt_update
=
False
,
)
)
mm_missing_next_idx
=
{
mm_cache_items_merged
=
self
.
_merge_mm_kwargs
(
modality
:
0
cache
,
for
modality
in
mm_missing_data_items
mm_cache_items
=
mm_cache_items
,
}
mm_missing_data
=
mm_missing_data
,
mm_missing_kwargs
=
mm_missing_kwargs
,
merged_kw_items
=
list
[
MultiModalKwargsItem
]()
)
for
modality
,
kw_items
in
mm_maybe_cached_kw_items
.
items
():
for
idx
,
kw_item
in
enumerate
(
kw_items
):
if
kw_item
is
None
:
kw_item
=
mm_missing_kwargs
.
get_item
(
modality
,
mm_missing_next_idx
[
modality
],
)
cache
.
put
(
model_id
,
modality
,
mm_data_items
[
modality
][
idx
],
hf_processor_mm_kwargs
,
kw_item
,
)
mm_missing_next_idx
[
modality
]
+=
1
merged_kw_items
.
append
(
kw_item
)
if
self
.
enable_sanity_checks
:
mm_kwargs
=
MultiModalKwargs
.
from_items
([
mm_missing_counts
=
mm_missing_data_items
.
get_all_counts
()
item
.
value
for
cache_items
in
mm_cache_items_merged
.
values
()
assert
all
(
for
item
in
cache_items
item_count
==
mm_missing_counts
[
modality
]
])
for
modality
,
item_count
in
mm_missing_next_idx
.
items
()),
dict
(
mm_missing_next_idx
=
mm_missing_next_idx
,
mm_missing_counts
=
mm_missing_counts
)
mm_kwargs
=
MultiModalKwargs
.
from_items
(
merged_kw_items
)
mm_hashes
=
{
modality
:
[
item
.
key
for
item
in
cache_items
]
for
modality
,
cache_items
in
mm_cache_items_merged
.
items
()
}
if
return_mm_hashes
else
None
return
prompt_ids
,
mm_kwargs
,
is_update_applied
return
prompt_ids
,
mm_kwargs
,
mm_hashes
,
is_update_applied
def
_bind_and_group_updates
(
def
_bind_and_group_updates
(
self
,
self
,
...
@@ -1569,27 +1680,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1569,27 +1680,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"model (usually arising from an inconsistency between "
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_updates`)."
)
"`_call_hf_processor` and `_get_prompt_updates`)."
)
def
_hash_mm_items
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
dict
[
str
,
list
[
str
]]:
"""Create MM hashes to be returned (only used in V1)."""
# TODO: Use these hash keys for caching operations in apply_hf_processor
# instead of rehashing.
model_id
=
self
.
info
.
model_id
return
{
modality
:
[
MultiModalHasher
.
hash_kwargs
(
model_id
=
model_id
,
**
{
modality
:
item
},
**
hf_processor_mm_kwargs
)
for
item
in
items
]
for
modality
,
items
in
mm_items
.
items
()
}
def
_maybe_apply_prompt_updates
(
def
_maybe_apply_prompt_updates
(
self
,
self
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
...
@@ -1655,17 +1745,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
...
@@ -1655,17 +1745,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
"""
mm_items
=
self
.
_to_mm_items
(
mm_data
)
mm_items
=
self
.
_to_mm_items
(
mm_data
)
mm_hashes
=
(
self
.
_hash_mm_items
(
mm_items
,
hf_processor_mm_kwargs
)
if
return_mm_hashes
else
None
)
(
(
prompt_ids
,
prompt_ids
,
mm_kwargs
,
mm_kwargs
,
mm_hashes
,
is_update_applied
,
is_update_applied
,
)
=
self
.
_cached_apply_hf_processor
(
)
=
self
.
_cached_apply_hf_processor
(
prompt
,
prompt
,
mm_items
,
mm_items
,
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
,
return_mm_hashes
=
return_mm_hashes
,
)
)
prompt_ids
,
prompt
,
mm_placeholders
=
self
.
_maybe_apply_prompt_updates
(
prompt_ids
,
prompt
,
mm_placeholders
=
self
.
_maybe_apply_prompt_updates
(
...
@@ -1717,28 +1806,12 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -1717,28 +1806,12 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
"""Create input prompt for the decoder."""
"""Create input prompt for the decoder."""
return
prompt
return
prompt
def
apply
(
def
_get_enc_dec_inputs
(
self
,
self
,
prompt
:
Union
[
str
,
list
[
int
]],
prompt
:
Union
[
str
,
list
[
int
]],
mm_data
:
MultiModalDataDict
,
mm_data
:
MultiModalDataDict
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
encoder_inputs
:
MultiModalInputs
,
return_mm_hashes
:
bool
=
False
,
):
)
->
MultiModalEncDecInputs
:
"""
Process multi-modal inputs to be used in vLLM.
The main processing steps are modified to fit encoder-decoder model:
1. Create encoder prompt from input prompt text.
2. Apply the HF processor on encoder prompt.
3. Copy the input prompt text as decoder prompt inputs.
"""
encoder_prompt
=
self
.
create_encoder_prompt
(
prompt
,
mm_data
)
encoder_inputs
=
super
().
apply
(
encoder_prompt
,
mm_data
,
hf_processor_mm_kwargs
,
return_mm_hashes
,
)
tokenizer
=
self
.
info
.
get_tokenizer
()
tokenizer
=
self
.
info
.
get_tokenizer
()
decoder_prompt
=
self
.
create_decoder_prompt
(
prompt
,
mm_data
)
decoder_prompt
=
self
.
create_decoder_prompt
(
prompt
,
mm_data
)
if
isinstance
(
decoder_prompt
,
str
):
if
isinstance
(
decoder_prompt
,
str
):
...
@@ -1758,3 +1831,31 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
...
@@ -1758,3 +1831,31 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
"prompt_token_ids"
:
decoder_prompt_ids
"prompt_token_ids"
:
decoder_prompt_ids
})
})
return
mm_inputs
return
mm_inputs
def
apply
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
mm_data
:
MultiModalDataDict
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
return_mm_hashes
:
bool
=
False
,
)
->
MultiModalEncDecInputs
:
"""
Process multi-modal inputs to be used in vLLM.
The main processing steps are modified to fit encoder-decoder model:
1. Create encoder prompt from input prompt text.
2. Apply the HF processor on encoder prompt.
3. Copy the input prompt text as decoder prompt inputs.
"""
encoder_prompt
=
self
.
create_encoder_prompt
(
prompt
,
mm_data
)
encoder_inputs
=
super
().
apply
(
encoder_prompt
,
mm_data
,
hf_processor_mm_kwargs
,
return_mm_hashes
,
)
return
self
.
_get_enc_dec_inputs
(
prompt
=
prompt
,
mm_data
=
mm_data
,
encoder_inputs
=
encoder_inputs
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment