Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
506475de
Unverified
Commit
506475de
authored
Apr 29, 2025
by
Cyrus Leung
Committed by
GitHub
Apr 29, 2025
Browse files
[Optim] Compute multimodal hash only once per item (#17314)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
cfe45320
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
233 additions
and
128 deletions
+233
-128
vllm/model_executor/models/deepseek_vl2.py
vllm/model_executor/models/deepseek_vl2.py
+9
-7
vllm/model_executor/models/h2ovl.py
vllm/model_executor/models/h2ovl.py
+9
-7
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+0
-3
vllm/model_executor/models/mistral3.py
vllm/model_executor/models/mistral3.py
+0
-2
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+10
-5
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+205
-104
No files found.
vllm/model_executor/models/deepseek_vl2.py
View file @
506475de
...
...
@@ -22,8 +22,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
BaseProcessingInfo
,
MultiModalHashes
,
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.configs.deepseek_vl2
import
(
DeepseekVLV2Config
,
...
...
@@ -279,24 +279,26 @@ class DeepseekVL2MultiModalProcessor(
prompt
:
Union
[
str
,
list
[
int
]],
mm_data_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
bool
]:
*
,
return_mm_hashes
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
Optional
[
MultiModalHashes
],
bool
]:
# The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if
mm_data_items
.
get_count
(
"image"
,
strict
=
False
)
>
2
:
# This code path corresponds to the cache being disabled
return
self
.
_apply_hf_processor_main
(
return
self
.
_apply_hf_processor
(
prompt
=
prompt
,
mm_items
=
mm_data_items
,
mm_
data_
items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
enable_hf_prompt_update
=
True
,
return_mm_hashes
=
return_mm_hashes
,
)
return
super
().
_cached_apply_hf_processor
(
prompt
=
prompt
,
mm_data_items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
return_mm_hashes
=
return_mm_hashes
,
)
...
...
vllm/model_executor/models/h2ovl.py
View file @
506475de
...
...
@@ -19,8 +19,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargs
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.processing
import
(
MultiModalHashes
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
.intern_vit
import
InternVisionModel
...
...
@@ -488,24 +488,26 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
prompt
:
Union
[
str
,
list
[
int
]],
mm_data_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
bool
]:
*
,
return_mm_hashes
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
Optional
[
MultiModalHashes
],
bool
]:
# The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
# perform caching for the most common case
if
mm_data_items
.
get_count
(
"image"
,
strict
=
False
)
>
1
:
# This code path corresponds to the cache being disabled
return
self
.
_apply_hf_processor_main
(
return
self
.
_apply_hf_processor
(
prompt
=
prompt
,
mm_items
=
mm_data_items
,
mm_
data_
items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
enable_hf_prompt_update
=
True
,
return_mm_hashes
=
return_mm_hashes
,
)
return
super
().
_cached_apply_hf_processor
(
prompt
=
prompt
,
mm_data_items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
return_mm_hashes
=
return_mm_hashes
,
)
...
...
vllm/model_executor/models/llava.py
View file @
506475de
...
...
@@ -396,14 +396,12 @@ def _build_llava_or_pixtral_hf_processor(
dummy_inputs
:
BaseDummyInputsBuilder
[
_I
],
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
enable_sanity_checks
:
bool
=
True
,
)
->
BaseMultiModalProcessor
:
if
isinstance
(
info
,
PixtralHFProcessingInfo
):
return
PixtralHFMultiModalProcessor
(
info
,
dummy_inputs
,
# type: ignore
cache
=
cache
,
enable_sanity_checks
=
enable_sanity_checks
,
)
if
isinstance
(
info
,
LlavaProcessingInfo
):
...
...
@@ -411,7 +409,6 @@ def _build_llava_or_pixtral_hf_processor(
info
,
dummy_inputs
,
# type: ignore
cache
=
cache
,
enable_sanity_checks
=
enable_sanity_checks
,
)
raise
NotImplementedError
(
type
(
info
))
...
...
vllm/model_executor/models/mistral3.py
View file @
506475de
...
...
@@ -312,14 +312,12 @@ def _build_mistral3_processor(
dummy_inputs
:
BaseDummyInputsBuilder
[
_I
],
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
enable_sanity_checks
:
bool
=
True
,
)
->
BaseMultiModalProcessor
:
assert
isinstance
(
info
,
Mistral3ProcessingInfo
)
return
Mistral3MultiModalProcessor
(
info
,
dummy_inputs
,
# type: ignore
cache
=
cache
,
enable_sanity_checks
=
enable_sanity_checks
,
)
...
...
vllm/model_executor/models/pixtral.py
View file @
506475de
...
...
@@ -36,8 +36,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
BaseProcessingInfo
,
MultiModalHashes
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.tokenizer
import
(
MistralTokenizer
,
...
...
@@ -271,15 +272,19 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
prompt
:
Union
[
str
,
list
[
int
]],
mm_data_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
bool
]:
prompt_ids
,
mm_kwargs
,
_
=
super
().
_cached_apply_hf_processor
(
*
,
return_mm_hashes
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
Optional
[
MultiModalHashes
],
bool
]:
prompt_ids
,
mm_kwargs
,
mm_hashes
,
_
=
super
(
).
_cached_apply_hf_processor
(
prompt
=
prompt
,
mm_data_items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
return_mm_hashes
=
return_mm_hashes
,
)
# NOTE: The tokens are already inserted by the chat template
return
prompt_ids
,
mm_kwargs
,
True
return
prompt_ids
,
mm_kwargs
,
mm_hashes
,
True
@
MULTIMODAL_REGISTRY
.
register_processor
(
PixtralMultiModalProcessor
,
...
...
vllm/multimodal/processing.py
View file @
506475de
...
...
@@ -876,6 +876,16 @@ def find_mm_placeholders(
_V
=
TypeVar
(
"_V"
,
bound
=
"Union[MultiModalKwargs, MultiModalKwargsItem]"
)
class
ProcessingCacheOptionalItem
(
NamedTuple
):
key
:
str
value
:
Optional
[
MultiModalKwargsItem
]
class
ProcessingCacheItem
(
NamedTuple
):
key
:
str
value
:
MultiModalKwargsItem
class
ProcessingCache
:
@
staticmethod
...
...
@@ -980,6 +990,22 @@ class ProcessingCache:
return
self
.
_cache
.
get
(
cache_key
)
def
get_item
(
self
,
model_id
:
str
,
modality
:
str
,
input_item
:
object
,
input_kwargs
:
Mapping
[
str
,
object
],
)
->
ProcessingCacheOptionalItem
:
cache_key
=
MultiModalHasher
.
hash_kwargs
(
model_id
=
model_id
,
**
{
modality
:
input_item
},
**
input_kwargs
)
return
ProcessingCacheOptionalItem
(
key
=
cache_key
,
value
=
self
.
_cache
.
get
(
cache_key
),
)
def
put
(
self
,
model_id
:
str
,
...
...
@@ -997,6 +1023,9 @@ class ProcessingCache:
**
input_kwargs
)
self
.
_cache
[
cache_key
]
=
output_kwargs
def
put_item
(
self
,
item
:
ProcessingCacheItem
)
->
None
:
self
.
_cache
[
item
.
key
]
=
item
.
value
class
BaseProcessingInfo
:
"""Base class to provide the information necessary for data processing."""
...
...
@@ -1052,6 +1081,11 @@ class BaseProcessingInfo:
_I
=
TypeVar
(
"_I"
,
bound
=
BaseProcessingInfo
)
MultiModalHashes
=
dict
[
str
,
list
[
str
]]
"""
A collection of hashes with a similar structure as :class:`MultiModalKwargs`.
"""
class
BaseMultiModalProcessor
(
ABC
,
Generic
[
_I
]):
"""
...
...
@@ -1064,14 +1098,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
info
:
_I
,
dummy_inputs
:
"BaseDummyInputsBuilder[_I]"
,
*
,
cache
:
Optional
[
ProcessingCache
]
=
None
,
enable_sanity_checks
:
bool
=
True
)
->
None
:
cache
:
Optional
[
ProcessingCache
]
=
None
)
->
None
:
super
().
__init__
()
self
.
info
=
info
self
.
dummy_inputs
=
dummy_inputs
self
.
cache
=
cache
self
.
enable_sanity_checks
=
enable_sanity_checks
self
.
data_parser
=
self
.
_get_data_parser
()
...
...
@@ -1340,46 +1372,144 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return
prompt_ids
,
mm_kwargs
,
False
def
_get_cache_missing_items
(
self
,
cache
:
ProcessingCache
,
mm_data_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
dict
[
str
,
list
[
ProcessingCacheOptionalItem
]],
dict
[
str
,
list
[
object
]]]:
model_id
=
self
.
info
.
model_id
mm_cache_items
=
{
modality
:
[
cache
.
get_item
(
model_id
,
modality
,
item
,
hf_processor_mm_kwargs
)
for
item
in
items
]
for
modality
,
items
in
mm_data_items
.
items
()
}
mm_missing_idxs
=
{
modality
:
[
idx
for
idx
,
item
in
enumerate
(
cache_items
)
if
item
.
value
is
None
]
for
modality
,
cache_items
in
mm_cache_items
.
items
()
}
mm_missing_data
=
{
modality
:
[
mm_data_items
[
modality
][
idx
]
for
idx
in
idxs
]
for
modality
,
idxs
in
mm_missing_idxs
.
items
()
}
return
mm_cache_items
,
mm_missing_data
def
_hash_mm_items
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
MultiModalHashes
:
"""Create MM hashes to be returned (only used in V1)."""
model_id
=
self
.
info
.
model_id
return
{
modality
:
[
MultiModalHasher
.
hash_kwargs
(
model_id
=
model_id
,
**
{
modality
:
item
},
**
hf_processor_mm_kwargs
)
for
item
in
items
]
for
modality
,
items
in
mm_items
.
items
()
}
def
_merge_mm_kwargs
(
self
,
cache
:
ProcessingCache
,
mm_cache_items
:
dict
[
str
,
list
[
ProcessingCacheOptionalItem
]],
mm_missing_data
:
dict
[
str
,
list
[
object
]],
mm_missing_kwargs
:
MultiModalKwargs
,
)
->
dict
[
str
,
list
[
ProcessingCacheItem
]]:
mm_missing_next_idx
=
{
modality
:
0
for
modality
in
mm_missing_data
}
merged_items
=
defaultdict
[
str
,
list
[
ProcessingCacheItem
]](
list
)
for
modality
,
cache_items
in
mm_cache_items
.
items
():
for
cache_item
in
cache_items
:
if
cache_item
.
value
is
None
:
kw_item
=
mm_missing_kwargs
.
get_item
(
modality
,
mm_missing_next_idx
[
modality
],
)
cache_item_new
=
ProcessingCacheItem
(
key
=
cache_item
.
key
,
value
=
kw_item
,
)
cache
.
put_item
(
cache_item_new
)
mm_missing_next_idx
[
modality
]
+=
1
else
:
cache_item_new
=
ProcessingCacheItem
(
key
=
cache_item
.
key
,
value
=
cache_item
.
value
,
)
merged_items
[
modality
].
append
(
cache_item_new
)
return
dict
(
merged_items
)
def
_apply_hf_processor
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
mm_data_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
*
,
return_mm_hashes
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
Optional
[
MultiModalHashes
],
bool
]:
(
prompt_ids
,
mm_kwargs
,
is_update_applied
,
)
=
self
.
_apply_hf_processor_main
(
prompt
=
prompt
,
mm_items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
enable_hf_prompt_update
=
True
,
)
mm_hashes
=
(
self
.
_hash_mm_items
(
mm_data_items
,
hf_processor_mm_kwargs
)
if
return_mm_hashes
else
None
)
return
prompt_ids
,
mm_kwargs
,
mm_hashes
,
is_update_applied
def
_cached_apply_hf_processor
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
mm_data_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
bool
]:
*
,
return_mm_hashes
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargs
,
Optional
[
MultiModalHashes
],
bool
]:
"""
Apply the HF processor on the full prompt text,
caching the results and reusing cached results.
"""
cache
=
self
.
cache
model_id
=
self
.
info
.
model_id
_
,
passthrough_data
=
self
.
_get_hf_mm_data
(
mm_data_items
)
if
cache
is
None
or
passthrough_data
:
return
self
.
_apply_hf_processor
_main
(
return
self
.
_apply_hf_processor
(
prompt
=
prompt
,
mm_items
=
mm_data_items
,
mm_
data_
items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
enable_hf_prompt_update
=
True
,
return_mm_hashes
=
return_mm_hashes
,
)
mm_maybe_cached_kw_items
=
{
modality
:
[
cache
.
get
(
model_id
,
modality
,
item
,
hf_processor_mm_kwargs
)
for
item
in
items
]
for
modality
,
items
in
mm_data_items
.
items
()
}
mm_missing_idxs
=
{
modality
:
[
idx
for
idx
,
item
in
enumerate
(
kw_items
)
if
item
is
None
]
for
modality
,
kw_items
in
mm_maybe_cached_kw_items
.
items
()
}
mm_missing_data
=
{
modality
:
[
mm_data_items
[
modality
][
idx
]
for
idx
in
idxs
]
for
modality
,
idxs
in
mm_missing_idxs
.
items
()
}
mm_missing_data_items
=
self
.
_to_mm_items
(
mm_missing_data
)
(
mm_cache_items
,
mm_missing_data
,
)
=
self
.
_get_cache_missing_items
(
cache
=
cache
,
mm_data_items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
)
# NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt updates until the new multimodal
...
...
@@ -1390,48 +1520,29 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
is_update_applied
,
)
=
self
.
_apply_hf_processor_main
(
prompt
=
prompt
,
mm_items
=
mm_missing_data
_items
,
mm_items
=
self
.
_to_mm_items
(
mm_missing_data
)
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
enable_hf_prompt_update
=
False
,
)
mm_missing_next_idx
=
{
modality
:
0
for
modality
in
mm_missing_data_items
}
merged_kw_items
=
list
[
MultiModalKwargsItem
]()
for
modality
,
kw_items
in
mm_maybe_cached_kw_items
.
items
():
for
idx
,
kw_item
in
enumerate
(
kw_items
):
if
kw_item
is
None
:
kw_item
=
mm_missing_kwargs
.
get_item
(
modality
,
mm_missing_next_idx
[
modality
],
)
cache
.
put
(
model_id
,
modality
,
mm_data_items
[
modality
][
idx
],
hf_processor_mm_kwargs
,
kw_item
,
mm_cache_items_merged
=
self
.
_merge_mm_kwargs
(
cache
,
mm_cache_items
=
mm_cache_items
,
mm_missing_data
=
mm_missing_data
,
mm_missing_kwargs
=
mm_missing_kwargs
,
)
mm_missing_next_idx
[
modality
]
+=
1
merged_kw_items
.
append
(
kw_item
)
mm_kwargs
=
MultiModalKwargs
.
from_items
([
item
.
value
for
cache_items
in
mm_cache_items_merged
.
values
()
for
item
in
cache_items
])
if
self
.
enable_sanity_checks
:
mm_missing_counts
=
mm_missing_data_items
.
get_all_counts
()
assert
all
(
item_count
==
mm_missing_counts
[
modality
]
for
modality
,
item_count
in
mm_missing_next_idx
.
items
()),
dict
(
mm_missing_next_idx
=
mm_missing_next_idx
,
mm_missing_counts
=
mm_missing_counts
)
mm_hashes
=
{
modality
:
[
item
.
key
for
item
in
cache_items
]
for
modality
,
cache_items
in
mm_cache_items_merged
.
items
()
}
if
return_mm_hashes
else
None
mm_kwargs
=
MultiModalKwargs
.
from_items
(
merged_kw_items
)
return
prompt_ids
,
mm_kwargs
,
is_update_applied
return
prompt_ids
,
mm_kwargs
,
mm_hashes
,
is_update_applied
def
_bind_and_group_updates
(
self
,
...
...
@@ -1569,27 +1680,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_updates`)."
)
def
_hash_mm_items
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
dict
[
str
,
list
[
str
]]:
"""Create MM hashes to be returned (only used in V1)."""
# TODO: Use these hash keys for caching operations in apply_hf_processor
# instead of rehashing.
model_id
=
self
.
info
.
model_id
return
{
modality
:
[
MultiModalHasher
.
hash_kwargs
(
model_id
=
model_id
,
**
{
modality
:
item
},
**
hf_processor_mm_kwargs
)
for
item
in
items
]
for
modality
,
items
in
mm_items
.
items
()
}
def
_maybe_apply_prompt_updates
(
self
,
mm_items
:
MultiModalDataItems
,
...
...
@@ -1655,17 +1745,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
mm_items
=
self
.
_to_mm_items
(
mm_data
)
mm_hashes
=
(
self
.
_hash_mm_items
(
mm_items
,
hf_processor_mm_kwargs
)
if
return_mm_hashes
else
None
)
(
prompt_ids
,
mm_kwargs
,
mm_hashes
,
is_update_applied
,
)
=
self
.
_cached_apply_hf_processor
(
prompt
,
mm_items
,
hf_processor_mm_kwargs
,
return_mm_hashes
=
return_mm_hashes
,
)
prompt_ids
,
prompt
,
mm_placeholders
=
self
.
_maybe_apply_prompt_updates
(
...
...
@@ -1717,28 +1806,12 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
"""Create input prompt for the decoder."""
return
prompt
def
apply
(
def
_get_enc_dec_inputs
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
mm_data
:
MultiModalDataDict
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
return_mm_hashes
:
bool
=
False
,
)
->
MultiModalEncDecInputs
:
"""
Process multi-modal inputs to be used in vLLM.
The main processing steps are modified to fit encoder-decoder model:
1. Create encoder prompt from input prompt text.
2. Apply the HF processor on encoder prompt.
3. Copy the input prompt text as decoder prompt inputs.
"""
encoder_prompt
=
self
.
create_encoder_prompt
(
prompt
,
mm_data
)
encoder_inputs
=
super
().
apply
(
encoder_prompt
,
mm_data
,
hf_processor_mm_kwargs
,
return_mm_hashes
,
)
encoder_inputs
:
MultiModalInputs
,
):
tokenizer
=
self
.
info
.
get_tokenizer
()
decoder_prompt
=
self
.
create_decoder_prompt
(
prompt
,
mm_data
)
if
isinstance
(
decoder_prompt
,
str
):
...
...
@@ -1758,3 +1831,31 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
"prompt_token_ids"
:
decoder_prompt_ids
})
return
mm_inputs
def
apply
(
self
,
prompt
:
Union
[
str
,
list
[
int
]],
mm_data
:
MultiModalDataDict
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
return_mm_hashes
:
bool
=
False
,
)
->
MultiModalEncDecInputs
:
"""
Process multi-modal inputs to be used in vLLM.
The main processing steps are modified to fit encoder-decoder model:
1. Create encoder prompt from input prompt text.
2. Apply the HF processor on encoder prompt.
3. Copy the input prompt text as decoder prompt inputs.
"""
encoder_prompt
=
self
.
create_encoder_prompt
(
prompt
,
mm_data
)
encoder_inputs
=
super
().
apply
(
encoder_prompt
,
mm_data
,
hf_processor_mm_kwargs
,
return_mm_hashes
,
)
return
self
.
_get_enc_dec_inputs
(
prompt
=
prompt
,
mm_data
=
mm_data
,
encoder_inputs
=
encoder_inputs
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment