Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d3f71f12
Unverified
Commit
d3f71f12
authored
Aug 18, 2025
by
Cyrus Leung
Committed by
GitHub
Aug 18, 2025
Browse files
[Refactor] Get prompt updates earlier (#23097)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
5a30bd10
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
84 additions
and
69 deletions
+84
-69
vllm/model_executor/models/deepseek_vl2.py
vllm/model_executor/models/deepseek_vl2.py
+3
-3
vllm/model_executor/models/h2ovl.py
vllm/model_executor/models/h2ovl.py
+4
-4
vllm/model_executor/models/pixtral.py
vllm/model_executor/models/pixtral.py
+5
-10
vllm/model_executor/models/qwen2_5_omni_thinker.py
vllm/model_executor/models/qwen2_5_omni_thinker.py
+15
-18
vllm/model_executor/models/voxtral.py
vllm/model_executor/models/voxtral.py
+5
-6
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+52
-28
No files found.
vllm/model_executor/models/deepseek_vl2.py
View file @
d3f71f12
...
...
@@ -25,7 +25,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
MultiModalHashes
,
BaseProcessingInfo
,
MultiModalProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -291,8 +292,7 @@ class DeepseekVL2MultiModalProcessor(
tokenization_kwargs
:
Mapping
[
str
,
object
],
*
,
return_mm_hashes
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargsItems
,
Optional
[
MultiModalHashes
],
bool
]:
)
->
tuple
[
list
[
int
],
MultiModalProcessingInfo
,
bool
]:
# The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
...
...
vllm/model_executor/models/h2ovl.py
View file @
d3f71f12
...
...
@@ -20,8 +20,9 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
MultiModalKwargsItems
from
vllm.multimodal.parse
import
(
ImageEmbeddingItems
,
ImageProcessorItems
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
MultiModalHashes
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.processing
import
(
MultiModalProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.transformers_utils.tokenizer
import
AnyTokenizer
from
.intern_vit
import
InternVisionModel
...
...
@@ -480,8 +481,7 @@ class H2OVLMultiModalProcessor(
tokenization_kwargs
:
Mapping
[
str
,
object
],
*
,
return_mm_hashes
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargsItems
,
Optional
[
MultiModalHashes
],
bool
]:
)
->
tuple
[
list
[
int
],
MultiModalProcessingInfo
,
bool
]:
# The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
...
...
vllm/model_executor/models/pixtral.py
View file @
d3f71f12
...
...
@@ -39,7 +39,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from
vllm.multimodal.parse
import
(
ImageProcessorItems
,
ImageSize
,
MultiModalDataItems
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
MultiModalHashes
,
BaseProcessingInfo
,
MultiModalProcessingInfo
,
PromptReplacement
,
PromptUpdate
,
PromptUpdateDetails
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
...
...
@@ -309,14 +310,8 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
tokenization_kwargs
:
Mapping
[
str
,
object
],
*
,
return_mm_hashes
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargsItems
,
Optional
[
MultiModalHashes
],
bool
]:
(
prompt_ids
,
mm_kwargs
,
mm_hashes
,
_
,
)
=
super
().
_cached_apply_hf_processor
(
)
->
tuple
[
list
[
int
],
MultiModalProcessingInfo
,
bool
]:
prompt_ids
,
mm_info
,
_
=
super
().
_cached_apply_hf_processor
(
prompt
=
prompt
,
mm_data_items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
...
...
@@ -325,7 +320,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
)
# NOTE: The tokens are already inserted by the chat template
return
prompt_ids
,
mm_
kwargs
,
mm_hashes
,
True
return
prompt_ids
,
mm_
info
,
True
@
MULTIMODAL_REGISTRY
.
register_processor
(
PixtralMultiModalProcessor
,
...
...
vllm/model_executor/models/qwen2_5_omni_thinker.py
View file @
d3f71f12
...
...
@@ -59,6 +59,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
ModalityDataItems
,
MultiModalDataItems
,
MultiModalDataParser
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
MultiModalPromptUpdates
,
PlaceholderFeaturesInfo
,
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
...
...
@@ -88,10 +89,7 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
video_grid_thw
=
hf_inputs
.
get
(
"video_grid_thw"
,
torch
.
empty
((
0
,
3
)))
video_grid_sizes
=
video_grid_thw
.
prod
(
-
1
)
# vllm use `second_per_grid_ts` to compute multimodal rotary embedding
video_second_per_grid
=
hf_inputs
.
get
(
"video_second_per_grid"
,
None
)
if
video_second_per_grid
is
not
None
:
hf_inputs
[
"second_per_grid_ts"
]
=
video_second_per_grid
num_videos
=
len
(
video_grid_sizes
)
return
dict
(
input_audio_features
=
MultiModalFieldConfig
.
flat_from_sizes
(
...
...
@@ -109,6 +107,7 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
"video"
,
video_grid_sizes
),
video_grid_thw
=
MultiModalFieldConfig
.
batched
(
"video"
),
second_per_grid_ts
=
MultiModalFieldConfig
.
batched
(
"video"
),
use_audio_in_video
=
MultiModalFieldConfig
.
shared
(
"video"
,
num_videos
),
)
...
...
@@ -251,6 +250,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
if
(
'audio_feature_lengths'
not
in
hf_inputs
and
feature_attention_mask
is
not
None
):
hf_inputs
[
'audio_feature_lengths'
]
=
feature_attention_mask
.
sum
(
-
1
)
video_second_per_grid
=
hf_inputs
.
get
(
"video_second_per_grid"
,
None
)
if
video_second_per_grid
is
not
None
:
hf_inputs
[
"second_per_grid_ts"
]
=
video_second_per_grid
use_audio_in_video
=
mm_kwargs
.
get
(
"use_audio_in_video"
,
False
)
hf_inputs
[
"use_audio_in_video"
]
=
torch
.
tensor
(
use_audio_in_video
)
return
hf_inputs
def
_get_mm_fields_config
(
...
...
@@ -263,27 +270,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
def
_maybe_apply_prompt_updates
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
prompt_ids
:
list
[
int
],
mm_kwargs
:
MultiModalKwargsItems
,
mm_prompt_updates
:
MultiModalPromptUpdates
,
is_update_applied
:
bool
,
)
->
tuple
[
list
[
int
],
str
,
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]]]:
"""
Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
"""
unbound_prompt_updates
=
self
.
_get_prompt_updates
(
mm_items
,
hf_processor_mm_kwargs
,
mm_kwargs
,
)
mm_prompt_updates
=
self
.
_bind_and_group_updates
(
unbound_prompt_updates
)
mm_item_counts
=
mm_items
.
get_all_counts
()
self
.
_validate_mm_kwargs
(
mm_kwargs
,
mm_item_counts
)
use_audio_in_video
=
hf_processor_mm_kwargs
.
get
(
"use_audio_in_video"
,
False
)
use_audio_in_video
=
(
all
(
item
[
"use_audio_in_video"
].
data
for
item
in
mm_kwargs
[
"video"
])
if
"video"
in
mm_kwargs
else
False
)
if
is_update_applied
:
mm_placeholders
=
self
.
_find_mm_placeholders
(
...
...
@@ -316,9 +316,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
tokenizer
=
self
.
info
.
get_tokenizer
()
prompt
=
decode_tokens
(
tokenizer
,
prompt_ids
)
if
use_audio_in_video
:
mm_kwargs
[
"use_audio_in_video"
]
=
True
return
prompt_ids
,
prompt
,
mm_placeholders
def
_get_prompt_updates
(
...
...
vllm/model_executor/models/voxtral.py
View file @
d3f71f12
...
...
@@ -35,7 +35,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from
vllm.multimodal.parse
import
(
AudioProcessorItems
,
MultiModalDataItems
,
MultiModalDataParser
)
from
vllm.multimodal.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
MultiModalHashes
,
BaseProcessingInfo
,
MultiModalProcessingInfo
,
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
,
ProcessorInputs
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -289,10 +290,8 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
tokenization_kwargs
:
Mapping
[
str
,
object
],
*
,
return_mm_hashes
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargsItems
,
Optional
[
MultiModalHashes
],
bool
]:
prompt_ids
,
mm_kwargs
,
mm_hashes
,
_
=
super
(
).
_cached_apply_hf_processor
(
)
->
tuple
[
list
[
int
],
MultiModalProcessingInfo
,
bool
]:
prompt_ids
,
mm_info
,
_
=
super
().
_cached_apply_hf_processor
(
prompt
=
prompt
,
mm_data_items
=
mm_data_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
...
...
@@ -301,7 +300,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
)
# NOTE: The tokens are already inserted by the chat template
return
prompt_ids
,
mm_
kwargs
,
mm_hashes
,
True
return
prompt_ids
,
mm_
info
,
True
def
_get_data_parser
(
self
)
->
MultiModalDataParser
:
sampling_rate
=
self
.
info
.
get_hf_processor
().
sampling_rate
...
...
vllm/multimodal/processing.py
View file @
d3f71f12
...
...
@@ -989,6 +989,18 @@ A collection of hashes with a similar structure as
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
"""
MultiModalPromptUpdates
=
dict
[
str
,
Sequence
[
BoundPromptUpdate
]]
"""
A collection of prompt updates with a similar structure as
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
"""
class
MultiModalProcessingInfo
(
NamedTuple
):
kwargs
:
MultiModalKwargsItems
hashes
:
Optional
[
MultiModalHashes
]
prompt_updates
:
MultiModalPromptUpdates
class
BaseMultiModalProcessor
(
ABC
,
Generic
[
_I
]):
"""
...
...
@@ -1363,7 +1375,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
cache
:
ProcessingCache
,
mm_cache_items_or_hashes
:
dict
[
str
,
list
[
_CacheItemOrHash
]],
mm_missing_kwargs
:
MultiModalKwargsItems
,
)
->
dict
[
str
,
list
[
MultiModalKwargsItem
]]
:
)
->
MultiModalKwargsItem
s
:
mm_missing_next_idx
=
defaultdict
[
str
,
int
](
lambda
:
0
)
merged_items
=
defaultdict
[
str
,
list
[
MultiModalKwargsItem
]](
list
)
...
...
@@ -1379,7 +1391,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
merged_items
[
modality
].
append
(
kw_item
)
return
dict
(
merged_items
)
return
MultiModalKwargsItems
(
merged_items
)
def
_apply_hf_processor
(
self
,
...
...
@@ -1389,8 +1401,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs
:
Mapping
[
str
,
object
],
*
,
return_mm_hashes
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargsItems
,
Optional
[
MultiModalHashes
],
bool
]:
)
->
tuple
[
list
[
int
],
MultiModalProcessingInfo
,
bool
]:
(
prompt_ids
,
mm_processed_data
,
...
...
@@ -1413,7 +1424,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs
)
if
return_mm_hashes
else
None
)
return
prompt_ids
,
mm_kwargs
,
mm_hashes
,
is_update_applied
unbound_prompt_updates
=
self
.
_get_prompt_updates
(
mm_data_items
,
hf_processor_mm_kwargs
,
mm_kwargs
,
)
mm_prompt_updates
=
self
.
_bind_and_group_updates
(
unbound_prompt_updates
)
mm_info
=
MultiModalProcessingInfo
(
kwargs
=
mm_kwargs
,
hashes
=
mm_hashes
,
prompt_updates
=
mm_prompt_updates
,
)
return
prompt_ids
,
mm_info
,
is_update_applied
def
_cached_apply_hf_processor
(
self
,
...
...
@@ -1423,8 +1448,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs
:
Mapping
[
str
,
object
],
*
,
return_mm_hashes
:
bool
,
)
->
tuple
[
list
[
int
],
MultiModalKwargsItems
,
Optional
[
MultiModalHashes
],
bool
]:
)
->
tuple
[
list
[
int
],
MultiModalProcessingInfo
,
bool
]:
"""
Apply the HF processor on the full prompt text,
caching the results and reusing cached results.
...
...
@@ -1475,18 +1499,27 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs
),
)
mm_
cache_items_merged
=
self
.
_merge_mm_kwargs
(
mm_
kwargs
=
self
.
_merge_mm_kwargs
(
cache
,
mm_cache_items_or_hashes
=
mm_cache_items_or_hashes
,
mm_missing_kwargs
=
mm_missing_kwargs
,
)
mm_kwargs
=
MultiModalKwargsItems
.
from_seq
([
item
for
cache_items
in
mm_cache_items_merged
.
values
()
for
item
in
cache_items
])
unbound_prompt_updates
=
self
.
_get_prompt_updates
(
mm_data_items
,
hf_processor_mm_kwargs
,
mm_kwargs
,
)
mm_prompt_updates
=
self
.
_bind_and_group_updates
(
unbound_prompt_updates
)
mm_info
=
MultiModalProcessingInfo
(
kwargs
=
mm_kwargs
,
hashes
=
mm_hashes_to_return
,
prompt_updates
=
mm_prompt_updates
,
)
return
prompt_ids
,
mm_
kwargs
,
mm_hashes_to_return
,
is_update_applied
return
prompt_ids
,
mm_
info
,
is_update_applied
def
_bind_and_group_updates
(
self
,
...
...
@@ -1626,19 +1659,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def
_maybe_apply_prompt_updates
(
self
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
prompt_ids
:
list
[
int
],
mm_kwargs
:
MultiModalKwargsItems
,
mm_prompt_updates
:
MultiModalPromptUpdates
,
is_update_applied
:
bool
,
)
->
tuple
[
list
[
int
],
str
,
Mapping
[
str
,
list
[
PlaceholderFeaturesInfo
]]]:
unbound_prompt_updates
=
self
.
_get_prompt_updates
(
mm_items
,
hf_processor_mm_kwargs
,
mm_kwargs
,
)
mm_prompt_updates
=
self
.
_bind_and_group_updates
(
unbound_prompt_updates
)
mm_item_counts
=
mm_items
.
get_all_counts
()
self
.
_validate_mm_kwargs
(
mm_kwargs
,
mm_item_counts
)
...
...
@@ -1694,8 +1719,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
(
prompt_ids
,
mm_kwargs
,
mm_hashes
,
mm_info
,
is_update_applied
,
)
=
self
.
_cached_apply_hf_processor
(
prompt
,
...
...
@@ -1708,9 +1732,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# NOTE: tokenization_kwargs are not required to init processor
prompt_ids
,
prompt
,
mm_placeholders
=
self
.
_maybe_apply_prompt_updates
(
mm_items
=
mm_items
,
hf_processor_mm_kwargs
=
hf_processor_mm_kwargs
,
prompt_ids
=
prompt_ids
,
mm_kwargs
=
mm_kwargs
,
mm_kwargs
=
mm_info
.
kwargs
,
mm_prompt_updates
=
mm_info
.
prompt_updates
,
is_update_applied
=
is_update_applied
,
)
...
...
@@ -1723,8 +1747,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
type
=
"multimodal"
,
prompt
=
prompt
,
prompt_token_ids
=
prompt_ids
,
mm_kwargs
=
mm_kwargs
,
mm_hashes
=
mm_hashes
,
mm_kwargs
=
mm_
info
.
kwargs
,
mm_hashes
=
mm_
info
.
hashes
,
mm_placeholders
=
mm_placeholder_ranges
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment