Unverified Commit 506475de authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Optim] Compute multimodal hash only once per item (#17314)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent cfe45320
...@@ -22,8 +22,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, ...@@ -22,8 +22,8 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, MultiModalHashes,
PromptUpdate) PromptReplacement, PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config, from vllm.transformers_utils.configs.deepseek_vl2 import (DeepseekVLV2Config,
...@@ -279,24 +279,26 @@ class DeepseekVL2MultiModalProcessor( ...@@ -279,24 +279,26 @@ class DeepseekVL2MultiModalProcessor(
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]: *,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
# The processor logic is different for len(images) <= 2 vs > 2 # The processor logic is different for len(images) <= 2 vs > 2
# Since the processing cache assumes that the processor output is # Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only # invariant of how many images are passed per prompt, we only
# perform caching for the most common case # perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 2: if mm_data_items.get_count("image", strict=False) > 2:
# This code path corresponds to the cache being disabled return self._apply_hf_processor(
return self._apply_hf_processor_main(
prompt=prompt, prompt=prompt,
mm_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True, return_mm_hashes=return_mm_hashes,
) )
return super()._cached_apply_hf_processor( return super()._cached_apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
return_mm_hashes=return_mm_hashes,
) )
......
...@@ -19,8 +19,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY ...@@ -19,8 +19,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalKwargs from vllm.multimodal.inputs import MultiModalKwargs
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate, from vllm.multimodal.processing import (MultiModalHashes, PromptReplacement,
PromptUpdateDetails) PromptUpdate, PromptUpdateDetails)
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
from .intern_vit import InternVisionModel from .intern_vit import InternVisionModel
...@@ -488,24 +488,26 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo] ...@@ -488,24 +488,26 @@ class H2OVLMultiModalProcessor(InternVLMultiModalProcessor[H2OVLProcessingInfo]
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]: *,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
# The processor logic is different for len(images) <= 1 vs > 1 # The processor logic is different for len(images) <= 1 vs > 1
# Since the processing cache assumes that the processor output is # Since the processing cache assumes that the processor output is
# invariant of how many images are passed per prompt, we only # invariant of how many images are passed per prompt, we only
# perform caching for the most common case # perform caching for the most common case
if mm_data_items.get_count("image", strict=False) > 1: if mm_data_items.get_count("image", strict=False) > 1:
# This code path corresponds to the cache being disabled return self._apply_hf_processor(
return self._apply_hf_processor_main(
prompt=prompt, prompt=prompt,
mm_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True, return_mm_hashes=return_mm_hashes,
) )
return super()._cached_apply_hf_processor( return super()._cached_apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
return_mm_hashes=return_mm_hashes,
) )
......
...@@ -396,14 +396,12 @@ def _build_llava_or_pixtral_hf_processor( ...@@ -396,14 +396,12 @@ def _build_llava_or_pixtral_hf_processor(
dummy_inputs: BaseDummyInputsBuilder[_I], dummy_inputs: BaseDummyInputsBuilder[_I],
*, *,
cache: Optional[ProcessingCache] = None, cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor: ) -> BaseMultiModalProcessor:
if isinstance(info, PixtralHFProcessingInfo): if isinstance(info, PixtralHFProcessingInfo):
return PixtralHFMultiModalProcessor( return PixtralHFMultiModalProcessor(
info, info,
dummy_inputs, # type: ignore dummy_inputs, # type: ignore
cache=cache, cache=cache,
enable_sanity_checks=enable_sanity_checks,
) )
if isinstance(info, LlavaProcessingInfo): if isinstance(info, LlavaProcessingInfo):
...@@ -411,7 +409,6 @@ def _build_llava_or_pixtral_hf_processor( ...@@ -411,7 +409,6 @@ def _build_llava_or_pixtral_hf_processor(
info, info,
dummy_inputs, # type: ignore dummy_inputs, # type: ignore
cache=cache, cache=cache,
enable_sanity_checks=enable_sanity_checks,
) )
raise NotImplementedError(type(info)) raise NotImplementedError(type(info))
......
...@@ -312,14 +312,12 @@ def _build_mistral3_processor( ...@@ -312,14 +312,12 @@ def _build_mistral3_processor(
dummy_inputs: BaseDummyInputsBuilder[_I], dummy_inputs: BaseDummyInputsBuilder[_I],
*, *,
cache: Optional[ProcessingCache] = None, cache: Optional[ProcessingCache] = None,
enable_sanity_checks: bool = True,
) -> BaseMultiModalProcessor: ) -> BaseMultiModalProcessor:
assert isinstance(info, Mistral3ProcessingInfo) assert isinstance(info, Mistral3ProcessingInfo)
return Mistral3MultiModalProcessor( return Mistral3MultiModalProcessor(
info, info,
dummy_inputs, # type: ignore dummy_inputs, # type: ignore
cache=cache, cache=cache,
enable_sanity_checks=enable_sanity_checks,
) )
......
...@@ -36,8 +36,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, ...@@ -36,8 +36,9 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, MultiModalHashes,
PromptUpdate, PromptUpdateDetails) PromptReplacement, PromptUpdate,
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.tokenizer import (MistralTokenizer, from vllm.transformers_utils.tokenizer import (MistralTokenizer,
...@@ -271,15 +272,19 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo] ...@@ -271,15 +272,19 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo]
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]: *,
prompt_ids, mm_kwargs, _ = super()._cached_apply_hf_processor( return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
prompt_ids, mm_kwargs, mm_hashes, _ = super(
)._cached_apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
return_mm_hashes=return_mm_hashes,
) )
# NOTE: The tokens are already inserted by the chat template # NOTE: The tokens are already inserted by the chat template
return prompt_ids, mm_kwargs, True return prompt_ids, mm_kwargs, mm_hashes, True
@MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor, @MULTIMODAL_REGISTRY.register_processor(PixtralMultiModalProcessor,
......
...@@ -876,6 +876,16 @@ def find_mm_placeholders( ...@@ -876,6 +876,16 @@ def find_mm_placeholders(
_V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]") _V = TypeVar("_V", bound="Union[MultiModalKwargs, MultiModalKwargsItem]")
class ProcessingCacheOptionalItem(NamedTuple):
key: str
value: Optional[MultiModalKwargsItem]
class ProcessingCacheItem(NamedTuple):
key: str
value: MultiModalKwargsItem
class ProcessingCache: class ProcessingCache:
@staticmethod @staticmethod
...@@ -980,6 +990,22 @@ class ProcessingCache: ...@@ -980,6 +990,22 @@ class ProcessingCache:
return self._cache.get(cache_key) return self._cache.get(cache_key)
def get_item(
self,
model_id: str,
modality: str,
input_item: object,
input_kwargs: Mapping[str, object],
) -> ProcessingCacheOptionalItem:
cache_key = MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: input_item},
**input_kwargs)
return ProcessingCacheOptionalItem(
key=cache_key,
value=self._cache.get(cache_key),
)
def put( def put(
self, self,
model_id: str, model_id: str,
...@@ -997,6 +1023,9 @@ class ProcessingCache: ...@@ -997,6 +1023,9 @@ class ProcessingCache:
**input_kwargs) **input_kwargs)
self._cache[cache_key] = output_kwargs self._cache[cache_key] = output_kwargs
def put_item(self, item: ProcessingCacheItem) -> None:
self._cache[item.key] = item.value
class BaseProcessingInfo: class BaseProcessingInfo:
"""Base class to provide the information necessary for data processing.""" """Base class to provide the information necessary for data processing."""
...@@ -1052,6 +1081,11 @@ class BaseProcessingInfo: ...@@ -1052,6 +1081,11 @@ class BaseProcessingInfo:
_I = TypeVar("_I", bound=BaseProcessingInfo) _I = TypeVar("_I", bound=BaseProcessingInfo)
MultiModalHashes = dict[str, list[str]]
"""
A collection of hashes with a similar structure as :class:`MultiModalKwargs`.
"""
class BaseMultiModalProcessor(ABC, Generic[_I]): class BaseMultiModalProcessor(ABC, Generic[_I]):
""" """
...@@ -1064,14 +1098,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1064,14 +1098,12 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
info: _I, info: _I,
dummy_inputs: "BaseDummyInputsBuilder[_I]", dummy_inputs: "BaseDummyInputsBuilder[_I]",
*, *,
cache: Optional[ProcessingCache] = None, cache: Optional[ProcessingCache] = None) -> None:
enable_sanity_checks: bool = True) -> None:
super().__init__() super().__init__()
self.info = info self.info = info
self.dummy_inputs = dummy_inputs self.dummy_inputs = dummy_inputs
self.cache = cache self.cache = cache
self.enable_sanity_checks = enable_sanity_checks
self.data_parser = self._get_data_parser() self.data_parser = self._get_data_parser()
...@@ -1340,46 +1372,144 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1340,46 +1372,144 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return prompt_ids, mm_kwargs, False return prompt_ids, mm_kwargs, False
def _get_cache_missing_items(
self,
cache: ProcessingCache,
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[dict[str, list[ProcessingCacheOptionalItem]], dict[
str, list[object]]]:
model_id = self.info.model_id
mm_cache_items = {
modality: [
cache.get_item(model_id, modality, item,
hf_processor_mm_kwargs) for item in items
]
for modality, items in mm_data_items.items()
}
mm_missing_idxs = {
modality: [
idx for idx, item in enumerate(cache_items)
if item.value is None
]
for modality, cache_items in mm_cache_items.items()
}
mm_missing_data = {
modality: [mm_data_items[modality][idx] for idx in idxs]
for modality, idxs in mm_missing_idxs.items()
}
return mm_cache_items, mm_missing_data
def _hash_mm_items(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalHashes:
"""Create MM hashes to be returned (only used in V1)."""
model_id = self.info.model_id
return {
modality: [
MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs)
for item in items
]
for modality, items in mm_items.items()
}
def _merge_mm_kwargs(
self,
cache: ProcessingCache,
mm_cache_items: dict[str, list[ProcessingCacheOptionalItem]],
mm_missing_data: dict[str, list[object]],
mm_missing_kwargs: MultiModalKwargs,
) -> dict[str, list[ProcessingCacheItem]]:
mm_missing_next_idx = {modality: 0 for modality in mm_missing_data}
merged_items = defaultdict[str, list[ProcessingCacheItem]](list)
for modality, cache_items in mm_cache_items.items():
for cache_item in cache_items:
if cache_item.value is None:
kw_item = mm_missing_kwargs.get_item(
modality,
mm_missing_next_idx[modality],
)
cache_item_new = ProcessingCacheItem(
key=cache_item.key,
value=kw_item,
)
cache.put_item(cache_item_new)
mm_missing_next_idx[modality] += 1
else:
cache_item_new = ProcessingCacheItem(
key=cache_item.key,
value=cache_item.value,
)
merged_items[modality].append(cache_item_new)
return dict(merged_items)
def _apply_hf_processor(
self,
prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
*,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
(
prompt_ids,
mm_kwargs,
is_update_applied,
) = self._apply_hf_processor_main(
prompt=prompt,
mm_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True,
)
mm_hashes = (self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs)
if return_mm_hashes else None)
return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
def _cached_apply_hf_processor( def _cached_apply_hf_processor(
self, self,
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_data_items: MultiModalDataItems, mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object], hf_processor_mm_kwargs: Mapping[str, object],
) -> tuple[list[int], MultiModalKwargs, bool]: *,
return_mm_hashes: bool,
) -> tuple[list[int], MultiModalKwargs, Optional[MultiModalHashes], bool]:
""" """
Apply the HF processor on the full prompt text, Apply the HF processor on the full prompt text,
caching the results and reusing cached results. caching the results and reusing cached results.
""" """
cache = self.cache cache = self.cache
model_id = self.info.model_id
_, passthrough_data = self._get_hf_mm_data(mm_data_items) _, passthrough_data = self._get_hf_mm_data(mm_data_items)
if cache is None or passthrough_data: if cache is None or passthrough_data:
return self._apply_hf_processor_main( return self._apply_hf_processor(
prompt=prompt, prompt=prompt,
mm_items=mm_data_items, mm_data_items=mm_data_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=True, return_mm_hashes=return_mm_hashes,
) )
mm_maybe_cached_kw_items = { (
modality: [ mm_cache_items,
cache.get(model_id, modality, item, hf_processor_mm_kwargs) mm_missing_data,
for item in items ) = self._get_cache_missing_items(
] cache=cache,
for modality, items in mm_data_items.items() mm_data_items=mm_data_items,
} hf_processor_mm_kwargs=hf_processor_mm_kwargs,
)
mm_missing_idxs = {
modality:
[idx for idx, item in enumerate(kw_items) if item is None]
for modality, kw_items in mm_maybe_cached_kw_items.items()
}
mm_missing_data = {
modality: [mm_data_items[modality][idx] for idx in idxs]
for modality, idxs in mm_missing_idxs.items()
}
mm_missing_data_items = self._to_mm_items(mm_missing_data)
# NOTE: `prompt` does not correspond to `mm_missing_data_items`, # NOTE: `prompt` does not correspond to `mm_missing_data_items`,
# so we can't apply prompt updates until the new multimodal # so we can't apply prompt updates until the new multimodal
...@@ -1390,48 +1520,29 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1390,48 +1520,29 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
is_update_applied, is_update_applied,
) = self._apply_hf_processor_main( ) = self._apply_hf_processor_main(
prompt=prompt, prompt=prompt,
mm_items=mm_missing_data_items, mm_items=self._to_mm_items(mm_missing_data),
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
enable_hf_prompt_update=False, enable_hf_prompt_update=False,
) )
mm_missing_next_idx = { mm_cache_items_merged = self._merge_mm_kwargs(
modality: 0 cache,
for modality in mm_missing_data_items mm_cache_items=mm_cache_items,
} mm_missing_data=mm_missing_data,
mm_missing_kwargs=mm_missing_kwargs,
merged_kw_items = list[MultiModalKwargsItem]() )
for modality, kw_items in mm_maybe_cached_kw_items.items():
for idx, kw_item in enumerate(kw_items):
if kw_item is None:
kw_item = mm_missing_kwargs.get_item(
modality,
mm_missing_next_idx[modality],
)
cache.put(
model_id,
modality,
mm_data_items[modality][idx],
hf_processor_mm_kwargs,
kw_item,
)
mm_missing_next_idx[modality] += 1
merged_kw_items.append(kw_item)
if self.enable_sanity_checks: mm_kwargs = MultiModalKwargs.from_items([
mm_missing_counts = mm_missing_data_items.get_all_counts() item.value for cache_items in mm_cache_items_merged.values()
assert all( for item in cache_items
item_count == mm_missing_counts[modality] ])
for modality, item_count in mm_missing_next_idx.items()), dict(
mm_missing_next_idx=mm_missing_next_idx,
mm_missing_counts=mm_missing_counts)
mm_kwargs = MultiModalKwargs.from_items(merged_kw_items) mm_hashes = {
modality: [item.key for item in cache_items]
for modality, cache_items in mm_cache_items_merged.items()
} if return_mm_hashes else None
return prompt_ids, mm_kwargs, is_update_applied return prompt_ids, mm_kwargs, mm_hashes, is_update_applied
def _bind_and_group_updates( def _bind_and_group_updates(
self, self,
...@@ -1569,27 +1680,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1569,27 +1680,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
"model (usually arising from an inconsistency between " "model (usually arising from an inconsistency between "
"`_call_hf_processor` and `_get_prompt_updates`).") "`_call_hf_processor` and `_get_prompt_updates`).")
def _hash_mm_items(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> dict[str, list[str]]:
"""Create MM hashes to be returned (only used in V1)."""
# TODO: Use these hash keys for caching operations in apply_hf_processor
# instead of rehashing.
model_id = self.info.model_id
return {
modality: [
MultiModalHasher.hash_kwargs(model_id=model_id,
**{modality: item},
**hf_processor_mm_kwargs)
for item in items
]
for modality, items in mm_items.items()
}
def _maybe_apply_prompt_updates( def _maybe_apply_prompt_updates(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
...@@ -1655,17 +1745,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1655,17 +1745,16 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
""" """
mm_items = self._to_mm_items(mm_data) mm_items = self._to_mm_items(mm_data)
mm_hashes = (self._hash_mm_items(mm_items, hf_processor_mm_kwargs)
if return_mm_hashes else None)
( (
prompt_ids, prompt_ids,
mm_kwargs, mm_kwargs,
mm_hashes,
is_update_applied, is_update_applied,
) = self._cached_apply_hf_processor( ) = self._cached_apply_hf_processor(
prompt, prompt,
mm_items, mm_items,
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
return_mm_hashes=return_mm_hashes,
) )
prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates( prompt_ids, prompt, mm_placeholders = self._maybe_apply_prompt_updates(
...@@ -1717,28 +1806,12 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1717,28 +1806,12 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
"""Create input prompt for the decoder.""" """Create input prompt for the decoder."""
return prompt return prompt
def apply( def _get_enc_dec_inputs(
self, self,
prompt: Union[str, list[int]], prompt: Union[str, list[int]],
mm_data: MultiModalDataDict, mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object], encoder_inputs: MultiModalInputs,
return_mm_hashes: bool = False, ):
) -> MultiModalEncDecInputs:
"""
Process multi-modal inputs to be used in vLLM.
The main processing steps are modified to fit encoder-decoder model:
1. Create encoder prompt from input prompt text.
2. Apply the HF processor on encoder prompt.
3. Copy the input prompt text as decoder prompt inputs.
"""
encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
encoder_inputs = super().apply(
encoder_prompt,
mm_data,
hf_processor_mm_kwargs,
return_mm_hashes,
)
tokenizer = self.info.get_tokenizer() tokenizer = self.info.get_tokenizer()
decoder_prompt = self.create_decoder_prompt(prompt, mm_data) decoder_prompt = self.create_decoder_prompt(prompt, mm_data)
if isinstance(decoder_prompt, str): if isinstance(decoder_prompt, str):
...@@ -1758,3 +1831,31 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -1758,3 +1831,31 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
"prompt_token_ids": decoder_prompt_ids "prompt_token_ids": decoder_prompt_ids
}) })
return mm_inputs return mm_inputs
def apply(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
return_mm_hashes: bool = False,
) -> MultiModalEncDecInputs:
"""
Process multi-modal inputs to be used in vLLM.
The main processing steps are modified to fit encoder-decoder model:
1. Create encoder prompt from input prompt text.
2. Apply the HF processor on encoder prompt.
3. Copy the input prompt text as decoder prompt inputs.
"""
encoder_prompt = self.create_encoder_prompt(prompt, mm_data)
encoder_inputs = super().apply(
encoder_prompt,
mm_data,
hf_processor_mm_kwargs,
return_mm_hashes,
)
return self._get_enc_dec_inputs(
prompt=prompt,
mm_data=mm_data,
encoder_inputs=encoder_inputs,
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment