Unverified Commit d3f71f12 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Refactor] Get prompt updates earlier (#23097)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 5a30bd10
......@@ -25,7 +25,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, MultiModalHashes,
BaseProcessingInfo,
MultiModalProcessingInfo,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
......@@ -291,8 +292,7 @@ class DeepseekVL2MultiModalProcessor(
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
......
......@@ -20,8 +20,9 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargsItems
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems)
from vllm.multimodal.processing import (MultiModalHashes, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.processing import (MultiModalProcessingInfo,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .intern_vit import InternVisionModel
......@@ -480,8 +481,7 @@ class H2OVLMultiModalProcessor(
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
# The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only
......
......@@ -39,7 +39,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, MultiModalHashes,
BaseProcessingInfo,
MultiModalProcessingInfo,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
......@@ -309,14 +310,8 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
(
prompt_ids,
mm_kwargs,
mm_hashes,
_,
) = super()._cached_apply_hf_processor(
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......@@ -325,7 +320,7 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
)
# NOTE: The tokens are already inserted by the chat template
return prompt_ids, mm_kwargs, mm_hashes, True
return prompt_ids, mm_info, True
@MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,
......
......@@ -59,6 +59,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, DictEmbeddingItems,
ModalityDataItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalPromptUpdates,
PlaceholderFeaturesInfo,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
......@@ -88,10 +89,7 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
video_grid_thw = hf_inputs.get("video_grid_thw", torch.empty((0, 3)))
video_grid_sizes = video_grid_thw.prod(-1)
# vllm use `second_per_grid_ts` to compute multimodal rotary embedding
video_second_per_grid = hf_inputs.get("video_second_per_grid", None)
if video_second_per_grid is not None:
hf_inputs["second_per_grid_ts"] = video_second_per_grid
num_videos = len(video_grid_sizes)
return dict(
input_audio_features=MultiModalFieldConfig.flat_from_sizes(
......@@ -109,6 +107,7 @@ def _qwen2_5_omni_thinker_field_config(hf_inputs: Mapping[str, torch.Tensor]):
"video", video_grid_sizes),
video_grid_thw=MultiModalFieldConfig.batched("video"),
second_per_grid_ts=MultiModalFieldConfig.batched("video"),
use_audio_in_video=MultiModalFieldConfig.shared("video", num_videos),
)
......@@ -251,6 +250,14 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
if ('audio_feature_lengths' not in hf_inputs
and feature_attention_mask is not None):
hf_inputs['audio_feature_lengths'] = feature_attention_mask.sum(-1)
video_second_per_grid = hf_inputs.get("video_second_per_grid", None)
if video_second_per_grid is not None:
hf_inputs["second_per_grid_ts"] = video_second_per_grid
use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
hf_inputs["use_audio_in_video"] = torch.tensor(use_audio_in_video)
return hf_inputs
def _get_mm_fields_config(
......@@ -263,27 +270,20 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
def _maybe_apply_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
prompt_ids: list[int],
mm_kwargs: MultiModalKwargsItems,
mm_prompt_updates: MultiModalPromptUpdates,
is_update_applied: bool,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
"""
Qwen2.5-Omni reimplements this function to handle `use_audio_in_video`.
"""
unbound_prompt_updates = self._get_prompt_updates(
mm_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
mm_prompt_updates = self._bind_and_group_updates(
unbound_prompt_updates)
mm_item_counts = mm_items.get_all_counts()
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
use_audio_in_video = hf_processor_mm_kwargs.get(
"use_audio_in_video", False)
use_audio_in_video = (all(
item["use_audio_in_video"].data
for item in mm_kwargs["video"]) if "video" in mm_kwargs else False)
if is_update_applied:
mm_placeholders = self._find_mm_placeholders(
......@@ -316,9 +316,6 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
tokenizer = self.info.get_tokenizer()
prompt = decode_tokens(tokenizer, prompt_ids)
if use_audio_in_video:
mm_kwargs["use_audio_in_video"] = True
return prompt_ids, prompt, mm_placeholders
def _get_prompt_updates(
......
......@@ -35,7 +35,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, MultiModalHashes,
BaseProcessingInfo,
MultiModalProcessingInfo,
PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
......@@ -289,10 +290,8 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
prompt_ids, mm_kwargs, mm_hashes, _ = super(
)._cached_apply_hf_processor(
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt,
mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
......@@ -301,7 +300,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
)
# NOTE: The tokens are already inserted by the chat template
return prompt_ids, mm_kwargs, mm_hashes, True
return prompt_ids, mm_info, True
def _get_data_parser(self) -> MultiModalDataParser:
sampling_rate = self.info.get_hf_processor().sampling_rate
......
......@@ -989,6 +989,18 @@ A collection of hashes with a similar structure as
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
"""
MultiModalPromptUpdates = dict[str, Sequence[BoundPromptUpdate]]
"""
A collection of prompt updates with a similar structure as
[`MultiModalKwargsItems`][vllm.multimodal.inputs.MultiModalKwargsItems].
"""
class MultiModalProcessingInfo(NamedTuple):
kwargs: MultiModalKwargsItems
hashes: Optional[MultiModalHashes]
prompt_updates: MultiModalPromptUpdates
class BaseMultiModalProcessor(ABC, Generic[_I]):
"""
......@@ -1363,7 +1375,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
cache: ProcessingCache,
mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]],
mm_missing_kwargs: MultiModalKwargsItems,
) -> dict[str, list[MultiModalKwargsItem]]:
) -> MultiModalKwargsItems:
mm_missing_next_idx = defaultdict[str, int](lambda: 0)
merged_items = defaultdict[str, list[MultiModalKwargsItem]](list)
......@@ -1379,7 +1391,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
merged_items[modality].append(kw_item)
return dict(merged_items)
return MultiModalKwargsItems(merged_items)
def _apply_hf_processor(
self,
......@@ -1389,8 +1401,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
(
prompt_ids,
mm_processed_data,
......@@ -1413,7 +1424,21 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs)
if return_mm_hashes else None)
return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
unbound_prompt_updates = self._get_prompt_updates(
mm_data_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
mm_prompt_updates = self._bind_and_group_updates(
unbound_prompt_updates)
mm_info = MultiModalProcessingInfo(
kwargs=mm_kwargs,
hashes=mm_hashes,
prompt_updates=mm_prompt_updates,
)
return prompt_ids, mm_info, is_update_applied
def _cached_apply_hf_processor(
self,
......@@ -1423,8 +1448,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
tokenization_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargsItems, Optional[MultiModalHashes],
bool]:
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
"""
Apply the HF processor on the full prompt text,
caching the results and reusing cached results.
......@@ -1475,18 +1499,27 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs),
)
mm_cache_items_merged = self._merge_mm_kwargs(
mm_kwargs = self._merge_mm_kwargs(
cache,
mm_cache_items_or_hashes=mm_cache_items_or_hashes,
mm_missing_kwargs=mm_missing_kwargs,
)
mm_kwargs = MultiModalKwargsItems.from_seq([
item for cache_items in mm_cache_items_merged.values()
for item in cache_items
])
unbound_prompt_updates = self._get_prompt_updates(
mm_data_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
mm_prompt_updates = self._bind_and_group_updates(
unbound_prompt_updates)
mm_info = MultiModalProcessingInfo(
kwargs=mm_kwargs,
hashes=mm_hashes_to_return,
prompt_updates=mm_prompt_updates,
)
return prompt_ids, mm_kwargs, mm_hashes_to_return, is_update_applied
return prompt_ids, mm_info, is_update_applied
def _bind_and_group_updates(
self,
......@@ -1626,19 +1659,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _maybe_apply_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
prompt_ids: list[int],
mm_kwargs: MultiModalKwargsItems,
mm_prompt_updates: MultiModalPromptUpdates,
is_update_applied: bool,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
unbound_prompt_updates = self._get_prompt_updates(
mm_items,
hf_processor_mm_kwargs,
mm_kwargs,
)
mm_prompt_updates = self._bind_and_group_updates(
unbound_prompt_updates)
mm_item_counts = mm_items.get_all_counts()
self._validate_mm_kwargs(mm_kwargs, mm_item_counts)
......@@ -1694,8 +1719,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
(
prompt_ids,
mm_kwargs,
mm_hashes,
mm_info,
is_update_applied,
) = self._cached_apply_hf_processor(
prompt,
......@@ -1708,9 +1732,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
# NOTE: tokenization_kwargs are not required to init processor
prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
prompt_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_kwargs=mm_info.kwargs,
mm_prompt_updates=mm_info.prompt_updates,
is_update_applied=is_update_applied,
)
......@@ -1723,8 +1747,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
type="multimodal",
prompt=prompt,
prompt_token_ids=prompt_ids,
mm_kwargs=mm_kwargs,
mm_hashes=mm_hashes,
mm_kwargs=mm_info.kwargs,
mm_hashes=mm_info.hashes,
mm_placeholders=mm_placeholder_ranges,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment