Unverified Commit 69244e67 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Use key-only cache for `BaseMultiModalProcessor` (#23018)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 8dbf6ed7
......@@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
If you run out of CPU RAM, try the following options:
- (Multi-modal models only) you can set the size of multi-modal processor cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB per API process + 4 GiB per engine core process)
- (Multi-modal models only) you can set the size of multi-modal cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB).
- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).
## Multi-modal input limits
......
......@@ -204,20 +204,33 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2
to avoid CPU resource exhaustion.
!!! note
[Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled
because it requires a one-to-one correspondence between API and engine core processes.
API server scale-out disables [multi-modal IPC caching](#ipc-caching)
because it requires a one-to-one correspondence between API and engine core processes.
## Multi-Modal Caching
This does not impact [multi-modal processor caching](#processor-caching).
### Processor Cache
## Multi-Modal Caching
By default, the multi-modal processor cache is enabled to avoid repeatedly processing
the same multi-modal inputs via Hugging Face `AutoProcessor`,
Multi-modal caching avoids repeated transfer or processing of the same multi-modal data,
which commonly occurs in multi-turn conversations.
You can adjust the size of the cache by setting the value of `mm_processor_cache_gb`
(default 4 GiB per API process + 4 GiB per engine core process).
If you do not benefit much from the cache, you can disable it completely via `mm_processor_cache_gb=0`.
### Processor Caching
Multi-modal processor caching is automatically enabled
to avoid repeatedly processing the same multi-modal inputs in `BaseMultiModalProcessor`.
### IPC Caching
Multi-modal IPC caching is automatically enabled when
there is a one-to-one correspondence between API (`P0`) and engine core (`P1`) processes,
to avoid repeatedly transferring the same multi-modal inputs between them.
### Configuration
You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB).
If you do not benefit much from the cache, you can disable both IPC
and processor caching completely via `mm_processor_cache_gb=0`.
Examples:
......@@ -230,3 +243,16 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
mm_processor_cache_gb=0)
```
### Cache Placement
Based on the configuration, the contents of the multi-modal caches on `P0` and `P1` are as follows:
| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. Memory |
|-------------------|-------------|------------|------------|-------------|
| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` |
| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` |
| ❌ | ❌ | N/A | N/A | `0` |
K: Stores the hashes of multi-modal items
V: Stores the processed tensor data of multi-modal items
......@@ -14,8 +14,9 @@ from PIL import Image
from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
cached_tokenizer_from_config,
encode_tokens)
......@@ -63,6 +64,8 @@ def _test_processing_correctness(
revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
# Ensure that the cache can fit all of the data
mm_processor_cache_gb=2048,
)
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
......@@ -71,8 +74,7 @@ def _test_processing_correctness(
model_config,
tokenizer=cached_tokenizer_from_config(model_config),
)
# Ensure that it can fit all of the data
cache = ProcessingCache(capacity_gb=2048)
cache = MultiModalProcessorOnlyCache(model_config)
processing_info = factories.info(ctx)
supported_mm_limits = processing_info.get_supported_mm_limits()
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import numpy as np
import pytest
import torch
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
from vllm.multimodal.cache import (MultiModalCache,
MultiModalProcessorCacheItem,
MultiModalProcessorCacheItemMetadata,
processor_cache_from_config,
receiver_cache_from_config)
from vllm.multimodal.hasher import MultiModalHasher
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
MultiModalKwargsItems,
MultiModalSharedField)
from vllm.multimodal.processing import PromptInsertion
from vllm.multimodal.registry import MultiModalRegistry
def _dummy_elem(
modality: str,
key: str,
size: int,
*,
rng: Optional[np.random.RandomState] = None,
):
if rng is None:
data = torch.empty((size, ), dtype=torch.int8)
else:
data = torch.from_numpy(rng.randint(4, size=(size, ), dtype=np.int8))
def _dummy_elem(modality: str, key: str, size: int):
return MultiModalFieldElem(
modality=modality,
key=key,
data=torch.empty((size, ), dtype=torch.int8),
data=data,
field=MultiModalSharedField(1),
)
def _dummy_item(modality: str, size_by_key: dict[str, int]):
def _dummy_item(
modality: str,
size_by_key: dict[str, int],
*,
rng: Optional[np.random.RandomState] = None,
):
return MultiModalKwargsItem.from_elems([
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
_dummy_elem(modality, key, size, rng=rng)
for key, size in size_by_key.items()
])
def _dummy_items(size_by_key_modality: dict[str, dict[str, int]]):
def _dummy_items(
size_by_key_modality: dict[str, dict[str, int]],
*,
rng: Optional[np.random.RandomState] = None,
):
return MultiModalKwargsItems.from_seq([
_dummy_item(modality, size_by_key)
_dummy_item(modality, size_by_key, rng=rng)
for modality, size_by_key in size_by_key_modality.items()
])
......@@ -48,5 +80,139 @@ def test_cache_item_size(item, expected_size):
cache[""] = item
assert cache.currsize == expected_size
cache[""] = MultiModalCacheItemMetadata.wraps(item)
prompt_update = PromptInsertion("dummy", "target", "insertion") \
.resolve(0)
cache[""] = MultiModalProcessorCacheItem(item, [prompt_update])
assert cache.currsize == expected_size
cache[""] = MultiModalProcessorCacheItemMetadata(item, [prompt_update])
assert cache.currsize == expected_size
def _create_vllm_config(
    *,
    mm_processor_cache_gb: float,
    enable_ipc: bool,
):
    """Build a minimal `VllmConfig` for the cache-consistency tests.

    Args:
        mm_processor_cache_gb: Cache size (in GiB) forwarded to `ModelConfig`.
        enable_ipc: Selects `data_parallel_size`: 1 when true, 2 otherwise.

    Returns:
        A `VllmConfig` holding only the model and parallel configs.
    """
    # NOTE(review): presumably a single data-parallel rank is what yields the
    # one-to-one P0/P1 mapping required for IPC caching - confirm against
    # `receiver_cache_from_config`.
    if enable_ipc:
        dp_size = 1
    else:
        dp_size = 2

    model_config = ModelConfig(mm_processor_cache_gb=mm_processor_cache_gb)
    parallel_config = ParallelConfig(data_parallel_size=dp_size)

    return VllmConfig(
        model_config=model_config,
        parallel_config=parallel_config,
    )
def _compare_caches(
    config_0: VllmConfig,
    config_1: VllmConfig,
    *,
    item_capacity: int = 8,
    hit_rate: float = 0.5,
    max_items_per_iter: int = 3,
    is_cached_calls_per_iter: int,
    n_iter: int = 100,
    seed: int = 0,
):
    """Check that two configurations deliver identical items to the model.

    For each configuration a P0 (processor) cache and a P1 (receiver) cache
    are built from the config; either may be `None`. The same randomly
    selected items are then pushed through both P0 -> P1 pipelines and the
    final outputs are asserted to be equal on every iteration.

    Args:
        config_0: First configuration to compare.
        config_1: Second configuration to compare.
        item_capacity: Divisor applied to the cache size to get the item size.
        hit_rate: The item pool holds `int(item_capacity / hit_rate)` items.
        max_items_per_iter: Upper bound passed to `rng.randint` when choosing
            how many items to select per iteration (NumPy excludes the upper
            bound).
        is_cached_calls_per_iter: Number of `is_cached()` calls issued on each
            P0 cache before `get_and_update()`.
        n_iter: Number of iterations to run.
        seed: Seed for the NumPy random state.

    Raises:
        AssertionError: If the two pipelines disagree at some iteration.
    """
    mm_registry = MultiModalRegistry()

    # Index 0 belongs to `config_0`, index 1 to `config_1`. For each config
    # the P0 cache is created before the P1 cache.
    p0_caches = []
    p1_caches = []
    for config in (config_0, config_1):
        p0_caches.append(processor_cache_from_config(config, mm_registry))
        p1_caches.append(receiver_cache_from_config(config, mm_registry))

    cache_size_gb = max(config.model_config.mm_processor_cache_gb
                        for config in (config_0, config_1))
    # NOTE(review): this value is used as the element count of each dummy
    # int8 tensor. For cache sizes below `item_capacity` GiB (such as the
    # 1 / (1 << 20) used by the test in this file) `int()` truncates it to 0,
    # so every dummy item is empty - confirm this is intended.
    item_size_gb = int(cache_size_gb / item_capacity)

    rng = np.random.RandomState(seed)

    all_items = []
    for _ in range(int(item_capacity / hit_rate)):
        all_items.append(_dummy_item("item", {"key": item_size_gb}, rng=rng))

    all_hashes = [
        MultiModalHasher.hash_kwargs(item=item.get_data())
        for item in all_items
    ]

    # Should not be used since there is nothing to convert to text
    prompt_update = PromptInsertion("dummy", "target", "insertion")

    def _through_p0(p0_cache, items, hashes):
        # Without a P0 cache, the items are forwarded untouched.
        if p0_cache is None:
            return items

        for _ in range(is_cached_calls_per_iter):
            p0_cache.is_cached(hashes)

        updated = p0_cache.get_and_update(
            [(item, prompt_update.content) for item in items],
            hashes,
        )
        return [item for item, _ in updated]

    def _through_p1(p1_cache, p0_out, hashes):
        # Without a P1 cache, the P0 output goes straight to the model.
        if p1_cache is None:
            return p0_out

        return p1_cache.get_and_update(p0_out, hashes)

    for it in range(n_iter):
        num_items_to_select = rng.randint(0, max_items_per_iter)
        item_idxs_to_select = rng.choice(len(all_items), num_items_to_select)

        selected_items = [all_items[idx] for idx in item_idxs_to_select]
        selected_hashes = [all_hashes[idx] for idx in item_idxs_to_select]

        # Both P0 caches are driven before either P1 cache.
        p0_outs = [
            _through_p0(p0_cache, selected_items, selected_hashes)
            for p0_cache in p0_caches
        ]
        p1_outs = [
            _through_p1(p1_cache, p0_out, selected_hashes)
            for p1_cache, p0_out in zip(p1_caches, p0_outs)
        ]

        assert p1_outs[0] == p1_outs[1], f"Failed at {it=}"
@pytest.mark.parametrize("is_cached_calls_per_iter", [1, 2, 3])
def test_ipc_enable_disable_consistency(is_cached_calls_per_iter):
    """Verify that all cache configurations feed the same data to the model.

    Three configurations are compared pairwise:

    - IPC enabled: non-zero cache size, `enable_ipc=True`.
    - IPC disabled: non-zero cache size, `enable_ipc=False`.
    - Cache disabled: zero cache size.

    Args:
        is_cached_calls_per_iter: Number of `is_cached()` calls made on each
            P0 cache per iteration before `get_and_update()`.
    """
    cache_size_gb = 1 / (1 << 20)

    vllm_config_ipc_enabled = _create_vllm_config(
        mm_processor_cache_gb=cache_size_gb,
        enable_ipc=True,
    )
    # Keep the cache size non-zero here so that turning IPC off still
    # exercises a real processor cache instead of no cache at all.
    vllm_config_ipc_disabled = _create_vllm_config(
        mm_processor_cache_gb=cache_size_gb,
        enable_ipc=False,
    )
    # A cache size of zero is what disables caching.
    vllm_config_cache_disabled = _create_vllm_config(
        mm_processor_cache_gb=0,
        enable_ipc=True,
    )

    config_pairs = [
        (vllm_config_ipc_enabled, vllm_config_ipc_disabled),
        (vllm_config_ipc_disabled, vllm_config_cache_disabled),
        (vllm_config_cache_disabled, vllm_config_ipc_enabled),
    ]
    for config_a, config_b in config_pairs:
        _compare_caches(
            config_a,
            config_b,
            is_cached_calls_per_iter=is_cached_calls_per_iter,
        )
......@@ -437,7 +437,7 @@ class ModelConfig:
from `AutoProcessor.from_pretrained`. The available overrides depend on the
model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
"""
mm_processor_cache_gb: int = 4
mm_processor_cache_gb: float = 4
"""The size (in GiB) of the multi-modal processor cache, which is used to
avoid re-processing past multi-modal inputs.
......@@ -884,12 +884,6 @@ class ModelConfig:
return None
def set_mm_processor_cache_gb(self, value: int) -> None:
mm_config = self.get_multimodal_config()
self.mm_processor_cache_gb = value
mm_config.mm_processor_cache_gb = value
def _get_encoder_config(self):
return get_sentence_transformer_tokenizer_config(
self.model, self.revision)
......@@ -1697,22 +1691,6 @@ class ModelConfig:
def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None
@property
def enable_mm_processor_cache(self) -> bool:
"""Whether the multi-modal processor cache should be enabled."""
mm_config = self.multimodal_config
if mm_config is None:
return False
return mm_config.mm_processor_cache_gb > 0
def get_mm_input_cache_gb(self) -> int:
mm_config = self.multimodal_config
if mm_config is None:
return 0
return envs.VLLM_MM_INPUT_CACHE_GIB
@property
def is_cross_encoder(self) -> bool:
return (self._model_info.supports_cross_encoding
......@@ -2561,7 +2539,7 @@ class MultiModalConfig:
`{"num_crops": 4}`.
"""
mm_processor_cache_gb: int = 4
mm_processor_cache_gb: float = 4
"""
The size (in GiB) of the multi-modal processor cache, which is used to
......
......@@ -351,7 +351,7 @@ class EngineArgs:
mm_processor_kwargs: Optional[Dict[str, Any]] = \
MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = False # DEPRECATED
mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb
mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
# LoRA fields
......@@ -1293,18 +1293,6 @@ class EngineArgs:
worker_extension_cls=self.worker_extension_cls,
)
if model_config.is_multimodal_model:
dp_supports_mm_processor_cache = (self.data_parallel_size == 1
or data_parallel_external_lb)
if (not dp_supports_mm_processor_cache
and model_config.mm_processor_cache_gb > 0):
logger.warning(
"Multi-modal processor cache is disabled because "
"it is not compatible with data parallelism when "
"there does not exist a one-to-one correspondance "
"between API and engine core processes.")
model_config.set_mm_processor_cache_gb(0)
speculative_config = self.create_speculative_config(
target_model_config=model_config,
target_parallel_config=parallel_config,
......
......@@ -36,6 +36,7 @@ from vllm.logits_process import get_bad_words_logits_processors
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory)
......@@ -250,9 +251,13 @@ class LLMEngine:
self.generation_config_fields = (
self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer,
mm_registry)
self.input_preprocessor = InputPreprocessor(
self.model_config,
self.tokenizer,
mm_registry,
mm_processor_cache=processor_only_cache_from_config(
self.model_config, mm_registry),
)
self.model_executor = executor_class(vllm_config=vllm_config)
......@@ -840,8 +845,8 @@ class LLMEngine:
def reset_mm_cache(self) -> bool:
"""Reset the multi-modal cache."""
return self.input_preprocessor.mm_registry.reset_processor_cache(
self.model_config)
self.input_preprocessor.clear_cache()
return True
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache for all devices."""
......
......@@ -11,6 +11,7 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
from vllm.transformers_utils.tokenizer import AnyTokenizer
......@@ -32,12 +33,14 @@ class InputPreprocessor:
model_config: ModelConfig,
tokenizer: Optional[TokenizerGroup],
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
) -> None:
super().__init__()
self.model_config = model_config
self.tokenizer = tokenizer
self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache
def get_tokenizer_group(self) -> TokenizerGroup:
if self.tokenizer is None:
......@@ -261,8 +264,11 @@ class InputPreprocessor:
"""
tokenizer = self._get_mm_tokenizer(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)
mm_processor = self.mm_registry.create_processor(
self.model_config,
tokenizer=tokenizer,
cache=self.mm_processor_cache,
)
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
......@@ -286,8 +292,12 @@ class InputPreprocessor:
"""
tokenizer = await self._get_mm_tokenizer_async(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)
mm_processor = self.mm_registry.create_processor(
self.model_config,
tokenizer=tokenizer,
cache=self.mm_processor_cache,
)
if mm_processor_kwargs is None:
mm_processor_kwargs = {}
......@@ -860,3 +870,7 @@ class InputPreprocessor:
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
)
def clear_cache(self) -> None:
    """Clear the multi-modal processor cache, if one was provided."""
    cache = self.mm_processor_cache
    if cache is None:
        return

    cache.clear_cache()
......@@ -223,20 +223,26 @@ class InputRegistry:
The model is identified by ``model_config``.
"""
# Avoid circular import
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.sequence import SequenceData
if not model_config.is_multimodal_model:
seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
return DummyData(seq_data=seq_data)
cache = processor_only_cache_from_config(model_config, mm_registry)
# Encoder dummy data does not contain multi-modal data
if is_encoder_data:
enc_data = mm_registry.get_encoder_dummy_data(
model_config, seq_len)
enc_data = mm_registry.get_encoder_dummy_data(model_config,
seq_len,
cache=cache)
seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids)
return DummyData(seq_data=seq_data)
dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len)
dec_data = mm_registry.get_decoder_dummy_data(model_config,
seq_len,
cache=cache)
return DummyData(
seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
......
......@@ -33,12 +33,13 @@ from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
PromptReplacement, PromptUpdate)
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
......@@ -367,7 +368,7 @@ def _build_hcxvision_hf_processor(
info: HCXVisionProcessingInfo,
dummy_inputs: BaseDummyInputsBuilder[HCXVisionProcessingInfo],
*,
cache: Optional[ProcessingCache] = None,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor:
if isinstance(info, HCXVisionProcessingInfo):
return HCXVisionMultiModalProcessor(
......
......@@ -22,14 +22,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargsItems)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves
......@@ -394,7 +394,7 @@ def _build_llava_or_pixtral_hf_processor(
info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor:
if isinstance(info, PixtralHFProcessingInfo):
return PixtralHFMultiModalProcessor(
......
......@@ -58,7 +58,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
VideoItem, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
PromptUpdate, PromptUpdateDetails,
ResolvedPromptUpdate, _seq2text)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
......@@ -744,6 +745,43 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
for modality, pattern in placeholders
]
def _recompute_cached_prompt_update(
    self,
    cached_update: ResolvedPromptUpdate,
    new_item_idx: int,
) -> ResolvedPromptUpdate:
    """Re-target a cached prompt update at a new item index.

    For image updates, the first occurrence of the image-ID marker
    (`<start><item_idx><end>`) in the cached content is rewritten so that
    it carries `new_item_idx` instead of the cached item index. Updates for
    other modalities are returned exactly as the base class produces them.

    Args:
        cached_update: The prompt update stored in the cache.
        new_item_idx: The index of the item in the current request.

    Returns:
        The prompt update adjusted for `new_item_idx`.
    """
    new_update = super()._recompute_cached_prompt_update(
        cached_update,
        new_item_idx,
    )

    if cached_update.modality != "image":
        return new_update

    tokenizer = self.info.get_tokenizer()
    image_processor = self.info.get_image_processor()
    version = self.info.get_model_version()

    text = _seq2text(tokenizer, cached_update.content.full)
    prev_item_idx = cached_update.item_idx

    # Versions 2.0 and 2.5 read the marker tokens from different attributes.
    if version in ((2, 0), (2, 5)):
        im_start = image_processor.im_start_token
        im_end = image_processor.im_end_token
    else:
        im_start = image_processor.im_id_start
        im_end = image_processor.im_id_end

    old_marker = f"{im_start}{prev_item_idx}{im_end}"
    new_marker = f"{im_start}{new_item_idx}{im_end}"

    # Only the first marker is replaced.
    updated_text = text.replace(old_marker, new_marker, 1)

    return new_update.with_content(
        PromptUpdateDetails.select_text(updated_text, "<unk>"))
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
......
......@@ -22,14 +22,14 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
PromptReplacement, PromptUpdate,
PromptUpdateDetails)
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
......@@ -322,7 +322,7 @@ def _build_mistral3_processor(
info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor:
assert isinstance(info, Mistral3ProcessingInfo)
return Mistral3MultiModalProcessor(
......
......@@ -41,7 +41,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo,
MultiModalPromptUpdates,
PlaceholderFeaturesInfo,
PromptReplacement, PromptUpdate)
PromptReplacement, PromptUpdate,
ResolvedPromptUpdate)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
......@@ -440,6 +441,23 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
)
]
def _recompute_cached_prompt_update(
    self,
    cached_update: ResolvedPromptUpdate,
    new_item_idx: int,
) -> ResolvedPromptUpdate:
    """Re-target a cached prompt update at a new item index.

    For image updates, the target is replaced by the HF processor's image
    token at position `new_item_idx`. Updates for other modalities are
    returned exactly as the base class produces them.

    Args:
        cached_update: The prompt update stored in the cache.
        new_item_idx: The index of the item in the current request.

    Returns:
        The prompt update adjusted for `new_item_idx`.
    """
    new_update = super()._recompute_cached_prompt_update(
        cached_update,
        new_item_idx,
    )

    if cached_update.modality != "image":
        return new_update

    hf_processor = self.info.get_hf_processor()
    image_tokens: list[str] = hf_processor.img_tokens  # type: ignore

    return new_update.with_target(image_tokens[new_item_idx])
def _apply_prompt_updates(
self,
token_ids: list[int],
......
......@@ -27,7 +27,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems,
MultiModalDataItems, MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
PromptUpdate, ResolvedPromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
......@@ -850,6 +850,25 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
),
]
def _recompute_cached_prompt_update(
    self,
    cached_update: ResolvedPromptUpdate,
    new_item_idx: int,
) -> ResolvedPromptUpdate:
    """Re-target a cached prompt update at a new item index.

    For image and audio updates, the target is replaced by the token at
    position `new_item_idx` in `self.info.image_tokens` or
    `self.info.audio_tokens` respectively. Updates for other modalities are
    returned exactly as the base class produces them.

    Args:
        cached_update: The prompt update stored in the cache.
        new_item_idx: The index of the item in the current request.

    Returns:
        The prompt update adjusted for `new_item_idx`.
    """
    new_update = super()._recompute_cached_prompt_update(
        cached_update,
        new_item_idx,
    )
    modality = cached_update.modality

    if modality == "image":
        image_tokens: list[str] = self.info.image_tokens  # type: ignore
        return new_update.with_target(image_tokens[new_item_idx])

    if modality == "audio":
        audio_tokens: list[str] = self.info.audio_tokens  # type: ignore
        return new_update.with_target(audio_tokens[new_item_idx])

    return new_update
@MULTIMODAL_REGISTRY.register_processor(
Phi4MMMultiModalProcessor,
......
......@@ -25,12 +25,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.llava import LlavaDummyInputsBuilder
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache,
PromptReplacement, PromptUpdate)
BaseProcessingInfo, PromptReplacement,
PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves
......@@ -332,7 +333,7 @@ def _build_tarsier_hf_processor(
info: _I_Tarsier,
dummy_inputs: BaseDummyInputsBuilder[_I_Tarsier],
*,
cache: Optional[ProcessingCache] = None,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor:
if isinstance(info, TarsierProcessingInfo):
return TarsierMultiModalProcessor(
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TypeVar, Union
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union
import torch
from typing_extensions import TypeAlias, override
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, LRUCache
......@@ -15,24 +16,67 @@ from .inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem, MultiModalKwargsItems,
NestedTensors)
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from .processing import ResolvedPromptUpdate
from .registry import MultiModalRegistry
logger = init_logger(__name__)
@dataclass
class MultiModalCacheItemMetadata:
size: int
class MultiModalProcessorCacheItem:
"""
The data to store inside `MultiModalProcessorOnlyCache`.
@classmethod
def wraps(cls, value: "MultiModalCacheValue"):
return cls(size=MultiModalCache.get_item_size(value))
Args:
item: The processed tensor data corresponding to a multi-modal item.
prompt_updates: The prompt updates corresponding to `item`.
"""
def __init__(
self,
item: MultiModalKwargsItem,
prompt_updates: Sequence["ResolvedPromptUpdate"],
) -> None:
super().__init__()
self.item = item
self.prompt_updates = prompt_updates
class MultiModalProcessorCacheItemMetadata:
    """
    The metadata to store inside `MultiModalProcessorSenderCache`.

    Only the size of the item is kept, not the item itself.

    Args:
        item: The processed tensor data corresponding to a multi-modal item.
            Since P1 already stores the tensor data, we only store its size
            metadata in P0 to reduce memory usage. The size metadata is still
            needed so that the P0 cache follows the same eviction policy
            as the cache on P1.
        prompt_updates: The prompt updates corresponding to `item`.
            This needs to stay on P0 because for some models, they are
            dependent on the processed tensor data (cached on P1).
    """

    def __init__(
        self,
        item: MultiModalKwargsItem,
        prompt_updates: Sequence["ResolvedPromptUpdate"],
    ) -> None:
        super().__init__()

        # Measure the item once, up front; the item itself is not retained.
        size_of_item = MultiModalCache.get_item_size(item)

        self.item_size = size_of_item
        self.prompt_updates = prompt_updates
MultiModalCacheValue = Union[
MultiModalProcessorCacheItem,
MultiModalProcessorCacheItemMetadata,
MultiModalKwargsItems,
MultiModalKwargsItem,
MultiModalKwargs,
Mapping[str, NestedTensors],
MultiModalCacheItemMetadata,
]
_V = TypeVar("_V", bound=MultiModalCacheValue)
......@@ -47,8 +91,10 @@ class MultiModalCache:
*,
debug: bool = False,
) -> int:
if isinstance(leaf, MultiModalFieldElem):
return cls.get_item_size(leaf.data) # type: ignore
if isinstance(leaf, MultiModalProcessorCacheItem):
return cls.get_leaf_size(leaf.item)
if isinstance(leaf, MultiModalProcessorCacheItemMetadata):
return leaf.item_size
# These are not subclasses of dict
if isinstance(leaf, MultiModalKwargsItems):
......@@ -58,13 +104,13 @@ class MultiModalCache:
if isinstance(leaf, MultiModalKwargs):
return cls.get_item_size(leaf.data) # type: ignore
if isinstance(leaf, MultiModalFieldElem):
return cls.get_item_size(leaf.data) # type: ignore
# sys.getsizeof doesn't work for tensors
if isinstance(leaf, torch.Tensor):
return leaf.nbytes
if isinstance(leaf, MultiModalCacheItemMetadata):
return leaf.size
return sys.getsizeof(leaf)
@classmethod
......@@ -98,3 +144,332 @@ class MultiModalCache:
GiB_bytes * capacity_gb,
getsizeof=lambda x: cls.get_item_size(x, debug=debug),
)
_I = TypeVar("_I", contravariant=True)
_O = TypeVar("_O", covariant=True)


class BaseMultiModalCache(ABC, Generic[_I, _O]):
    """
    Abstract interface for reading/writing multi-modal items from a cache.

    Multi-modal caching is modelled as a client and a server: the client
    runs in the frontend process (=P0) and the server runs in the core
    process (=P1). The data flow is as follows:

    ```
                  is_cached() x N    get_and_update()
    P0: From API -----------------> -----------------> To P1

                 get_and_update()
    P1: From P0 -----------------> To model
    ```

    `is_cached()` can be called any number of times in P0. However,
    `get_and_update()` must be called in P0 and P1 one after another
    so that their cache eviction order remains the same.

    This keeps the keys of the P0 and P1 caches mirrored, which lets P0
    tell whether a key is cached in P1 by looking at its own cache,
    without having to communicate with P1.
    """

    @abstractmethod
    def get_and_update_item(
        self,
        mm_item: _I,
        mm_hash: str,
    ) -> _O:
        """
        Possibly update a multi-modal item based on whether it is
        in the underlying cache.

        This update is done out-of-place and updates the cache eviction order.

        Args:
            mm_item: The multi-modal item to update.
            mm_hash: The hash of `mm_item`.

        Returns:
            The updated multi-modal item.
        """
        raise NotImplementedError

    def get_and_update(
        self,
        mm_items: Sequence[_I],
        mm_hashes: list[str],
    ) -> list[_O]:
        """
        Possibly update a sequence of multi-modal items based on whether they
        are in the underlying cache.

        This update is done out-of-place and updates the cache eviction order.
        Items are handled one by one, in order, via `get_and_update_item()`.

        Args:
            mm_items: The multi-modal items to update.
            mm_hashes: The hash of each item in `mm_items`.

        Returns:
            A new list of updated multi-modal items.
        """
        assert len(mm_items) == len(mm_hashes)

        updated_items: list[_O] = []
        for mm_item, mm_hash in zip(mm_items, mm_hashes):
            updated_items.append(self.get_and_update_item(mm_item, mm_hash))

        return updated_items

    @abstractmethod
    def clear_cache(self) -> None:
        """Clear the underlying cache."""
        raise NotImplementedError
MultiModalProcessorCacheInItem: TypeAlias = \
Optional[tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]]]
MultiModalProcessorCacheOutItem: TypeAlias = \
tuple[Optional[MultiModalKwargsItem], Sequence["ResolvedPromptUpdate"]]
class BaseMultiModalProcessorCache(
        BaseMultiModalCache[MultiModalProcessorCacheInItem,
                            MultiModalProcessorCacheOutItem]):
    """The required interface for caches on P0."""

    @abstractmethod
    def is_cached_item(self, mm_hash: str) -> bool:
        """
        Check whether a single multi-modal item is in the underlying cache.

        This **DOES NOT** update the cache eviction order.

        Args:
            mm_hash: The hash of the item to check.

        Returns:
            `True` if the item is cached, otherwise `False`.
        """
        raise NotImplementedError

    def is_cached(self, mm_hashes: list[str]) -> list[bool]:
        """
        Check whether each of a sequence of multi-modal items is
        in the underlying cache.

        This **DOES NOT** update the cache eviction order.

        Args:
            mm_hashes: The hash of each item to check.

        Returns:
            For each item, `True` if the item is cached, otherwise `False`.
        """
        cached_flags: list[bool] = []
        for mm_hash in mm_hashes:
            cached_flags.append(self.is_cached_item(mm_hash))

        return cached_flags
class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
    """
    P0-side cache used when IPC caching is disabled.

    Per-item update rule:

    - Cache hit: the input is replaced by the cached item.
    - Cache miss: the input (which includes tensor data and metadata) is
      stored in the cache and returned unchanged.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        capacity_gb = (
            model_config.get_multimodal_config().mm_processor_cache_gb)
        self._cache = MultiModalCache.get_lru_cache(
            capacity_gb,
            MultiModalProcessorCacheItem,
        )

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        # Plain membership test; does not go through `get`.
        return mm_hash in self._cache

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        hit = self._cache.get(mm_hash)
        if hit is not None:
            return hit.item, hit.prompt_updates

        # On a miss the caller must have supplied the processed item.
        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
        self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item)
        return mm_item

    @override
    def clear_cache(self) -> None:
        self._cache.clear()
class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
    """
    P0-side cache used when IPC caching is enabled.

    Per-item update rule:

    - Cache hit: the input is cleared (returned as `None`) to avoid
      unnecessary IPC.
    - Cache miss: only the metadata of the item is stored, so that the
      eviction policy stays the same as the cache on P1, and the input
      is returned unchanged.

    Storing only the metadata avoids keeping the data itself in memory
    inside P0.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        capacity_gb = (
            model_config.get_multimodal_config().mm_processor_cache_gb)
        self._cache = MultiModalCache.get_lru_cache(
            capacity_gb,
            MultiModalProcessorCacheItemMetadata,
        )

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        # Plain membership test; does not go through `get`.
        return mm_hash in self._cache

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        hit = self._cache.get(mm_hash)
        if hit is not None:
            # Only the prompt updates are returned; the item slot is None.
            return None, hit.prompt_updates

        # On a miss the caller must have supplied the processed item.
        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
        self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(*mm_item)
        return mm_item

    @override
    def clear_cache(self) -> None:
        self._cache.clear()
def _enable_processor_cache(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
) -> bool:
    """
    Decide whether the multi-modal processor cache should be used.

    The cache is enabled only when the registry reports that the model
    supports multi-modal inputs and `mm_processor_cache_gb` is positive.
    """
    if mm_registry.supports_multimodal_inputs(model_config):
        cache_gb = model_config.get_multimodal_config().mm_processor_cache_gb
        return cache_gb > 0

    return False
def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
    """
    Decide whether IPC caching is supported by the parallel configuration.

    It is supported when `data_parallel_size` is 1; otherwise the result
    is the value of `data_parallel_external_lb`.
    """
    parallel_config = vllm_config.parallel_config

    if parallel_config.data_parallel_size == 1:
        return True

    return parallel_config.data_parallel_external_lb
def processor_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> Optional[BaseMultiModalProcessorCache]:
    """
    Return a `BaseMultiModalProcessorCache`, if enabled.

    Returns:
        `None` when the processor cache is disabled; otherwise a
        `MultiModalProcessorSenderCache` when IPC caching is enabled, or
        a `MultiModalProcessorOnlyCache` when it is not.
    """
    model_config = vllm_config.model_config

    if not _enable_processor_cache(model_config, mm_registry):
        return None

    if _enable_ipc_cache(vllm_config):
        return MultiModalProcessorSenderCache(model_config)

    return MultiModalProcessorOnlyCache(model_config)
def processor_only_cache_from_config(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
) -> Optional[MultiModalProcessorOnlyCache]:
    """
    Return a `MultiModalProcessorOnlyCache`, if enabled.

    Unlike `processor_cache_from_config`, this never consults the IPC
    cache setting: it only needs the model config.

    Args:
        model_config: The model configuration to build the cache from.
        mm_registry: The registry used to check multi-modal support.

    Returns:
        A new `MultiModalProcessorOnlyCache`, or `None` when
        `_enable_processor_cache` reports that the cache is disabled.
    """
    if not _enable_processor_cache(model_config, mm_registry):
        return None

    return MultiModalProcessorOnlyCache(model_config)
class BaseMultiModalReceiverCache(
        BaseMultiModalCache[Optional[MultiModalKwargsItem],
                            MultiModalKwargsItem]):
    """Interface that every cache living on P1 has to implement."""
class MultiModalReceiverCache(BaseMultiModalReceiverCache):
    """
    P1-side cache used when IPC caching is enabled.

    Per-item update rule:

    - Cache hit: the input is replaced by the cached item.
    - Cache miss: the input (which includes tensor data) is stored in the
      cache and returned unchanged.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        capacity_gb = (
            model_config.get_multimodal_config().mm_processor_cache_gb)
        self._cache = MultiModalCache.get_lru_cache(
            capacity_gb,
            MultiModalKwargsItem,
        )

    @override
    def get_and_update_item(
        self,
        mm_item: Optional[MultiModalKwargsItem],
        mm_hash: str,
    ) -> MultiModalKwargsItem:
        hit = self._cache.get(mm_hash)
        if hit is not None:
            return hit

        # On a miss the sender must have transferred the actual item.
        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"
        self._cache[mm_hash] = mm_item
        return mm_item

    @override
    def clear_cache(self) -> None:
        self._cache.clear()
def receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> Optional[BaseMultiModalReceiverCache]:
    """
    Return a `BaseMultiModalReceiverCache`, if enabled.

    Returns:
        A `MultiModalReceiverCache` when both the processor cache and IPC
        caching are enabled, otherwise `None`.
    """
    model_config = vllm_config.model_config

    # Short-circuits: the IPC check only runs if the processor cache is on.
    if (_enable_processor_cache(model_config, mm_registry)
            and _enable_ipc_cache(vllm_config)):
        return MultiModalReceiverCache(model_config)

    return None
......@@ -7,11 +7,11 @@ from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from functools import partial
from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
Union, cast, final)
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union,
cast, final)
import numpy as np
from typing_extensions import NotRequired, TypeAlias, deprecated
from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated
from vllm.utils import LazyLoader, full_groupby, is_list_of
from vllm.utils.jsontree import JSONTree, json_map_leaves
......@@ -668,7 +668,15 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
return {key: elem.data for key, elem in self.items()}
class MultiModalKwargsItems(UserDict[str, Sequence[MultiModalKwargsItem]]):
# Item type stored by `MultiModalKwargsItems`: constrained to either
# `MultiModalKwargsItem` or `Optional[MultiModalKwargsItem]`.
# `default=` makes an unparameterized use resolve to `MultiModalKwargsItem`
# (type-variable defaults, as provided by `typing_extensions.TypeVar`).
_I = TypeVar(
    "_I",
    MultiModalKwargsItem,
    Optional[MultiModalKwargsItem],
    default=MultiModalKwargsItem,
)
class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
"""
A dictionary of
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
......@@ -714,27 +722,37 @@ class MultiModalKwargsItems(UserDict[str, Sequence[MultiModalKwargsItem]]):
items_by_modality = full_groupby(items, key=lambda x: x.modality)
return MultiModalKwargsItems(items_by_modality)
def __getitem__(self, modality: str):
def __getitem__(self, modality: str) -> Sequence[_I]:
if modality not in self:
raise KeyError(f"Modality {modality!r} not found. "
f"Available modalities: {set(self.keys())}")
return super().__getitem__(modality)
return super().__getitem__(modality) # type: ignore[return-value]
def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
for items in self.values():
for item in items:
for modality, items in self.items():
for i, item in enumerate(items):
if item is None:
raise RuntimeError("Cannot build data from empty "
f"mm_items[{modality}][{i}]")
for key, elem in item.items():
elems_by_key[key].append(elem)
return MultiModalKwargs({
key:
elems[0].field.reduce_data(elems, pin_memory=pin_memory)
for key, elems in elems_by_key.items() if len(elems) > 0
for key, elems in elems_by_key.items()
})
# Either a fully populated `MultiModalKwargsItems`, or one whose
# per-modality sequences may hold `None` in place of an item.
MultiModalKwargsOptionalItems: TypeAlias = Union[
    MultiModalKwargsItems[MultiModalKwargsItem],
    MultiModalKwargsItems[Optional[MultiModalKwargsItem]],
]
class MultiModalKwargs(UserDict[str, NestedTensors]):
"""
A dictionary that represents the keyword arguments to
......@@ -898,7 +916,7 @@ class MultiModalInputs(TypedDict):
token_type_ids: NotRequired[list[int]]
"""The token type IDs of the prompt."""
mm_kwargs: MultiModalKwargsItems
mm_kwargs: MultiModalKwargsOptionalItems
"""Keyword arguments to be directly passed to the model after batching."""
mm_hashes: "MultiModalHashDict"
......
......@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
Sequence)
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from enum import Enum
from functools import lru_cache
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
......@@ -20,12 +20,11 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens)
from vllm.utils import flatten_2d_lists, full_groupby
from .cache import MultiModalCache
from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalInputs,
MultiModalKwargsItem, MultiModalKwargsItems,
PlaceholderRange)
MultiModalKwargsOptionalItems, PlaceholderRange)
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
MultiModalDataParser)
......@@ -34,6 +33,7 @@ if TYPE_CHECKING:
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from .cache import BaseMultiModalProcessorCache
from .profiling import BaseDummyInputsBuilder
logger = init_logger(__name__)
......@@ -557,6 +557,15 @@ class ResolvedPromptUpdate:
return self.iter_token_matches(prompt, tokenizer, start_idx=start_idx)
def with_target(self, target: UpdateTarget):
    """Return a copy of this update with `target` swapped in."""
    return replace(self, target=target)
def with_content(self, content: PromptUpdateInfo):
    """
    Return a copy of this update with `content` swapped in.

    Content that is not already a `PromptUpdateDetails` is first wrapped
    via `PromptUpdateDetails.from_seq`.
    """
    details = (content if isinstance(content, PromptUpdateDetails) else
               PromptUpdateDetails.from_seq(content))
    return replace(self, content=details)
class _TokenMatch(NamedTuple):
start_idx: int
......@@ -865,21 +874,6 @@ def find_mm_placeholders(
return dict(full_groupby_modality(it))
class ProcessingCache(MultiModalCache):
    """Cache of `MultiModalKwargsItem`s exposing `get`, `put` and `reset`."""

    def __init__(self, capacity_gb: float) -> None:
        super().__init__()

        # Backing store, created via `get_lru_cache` with the given capacity.
        self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)

        # Expose the backing store's bound methods directly on the instance;
        # note that `reset` is the store's `clear`.
        self.get = self._cache.get
        self.put = self._cache.put
        self.reset = self._cache.clear
# A cached `MultiModalKwargsItem`, or a `str` hash standing in for an item.
_CacheItemOrHash = Union[MultiModalKwargsItem, str]
class BaseProcessingInfo:
"""Base class to provide the information necessary for data processing."""
......@@ -982,7 +976,7 @@ For an item `MultiModalPromptUpdates[k][i]`,
class MultiModalProcessingInfo(NamedTuple):
kwargs: MultiModalKwargsItems
kwargs: MultiModalKwargsOptionalItems
hashes: MultiModalHashes
prompt_updates: MultiModalPromptUpdates
......@@ -994,11 +988,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Not to be confused with `transformers.ProcessorMixin`.
"""
def __init__(self,
info: _I,
dummy_inputs: "BaseDummyInputsBuilder[_I]",
*,
cache: Optional[ProcessingCache] = None) -> None:
def __init__(
self,
info: _I,
dummy_inputs: "BaseDummyInputsBuilder[_I]",
*,
cache: Optional["BaseMultiModalProcessorCache"] = None,
) -> None:
super().__init__()
self.info = info
......@@ -1355,32 +1351,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return prompt_ids, mm_processed_data, False
def _get_cache_missing_items(
    self,
    cache: ProcessingCache,
    mm_data_items: MultiModalDataItems,
    mm_hashes: MultiModalHashes,
) -> tuple[dict[str, list[_CacheItemOrHash]], MultiModalDataItems]:
    """
    Look up every hash in `cache` and collect the data items that missed.

    Returns:
        A pair of (per-modality list holding the cached item on a hit or
        the hash string on a miss, the missing data items passed through
        `_to_mm_items`).
    """
    # Phase 1: resolve every hash against the cache.
    items_or_hashes_by_modality: dict[str, list[_CacheItemOrHash]] = {}
    for modality, hashes in mm_hashes.items():
        resolved: list[_CacheItemOrHash] = []
        for item_hash in hashes:
            cached = cache.get(item_hash)
            resolved.append(item_hash if cached is None else cached)
        items_or_hashes_by_modality[modality] = resolved

    # Phase 2: an entry that is still a hash string marks a cache miss, so
    # pick the data item at the same position.
    missing_data = {
        modality: [
            mm_data_items[modality][idx]
            for idx, entry in enumerate(entries) if isinstance(entry, str)
        ]
        for modality, entries in items_or_hashes_by_modality.items()
    }

    return items_or_hashes_by_modality, self._to_mm_items(missing_data)
def _hash_mm_items(
self,
mm_items: MultiModalDataItems,
......@@ -1401,28 +1371,92 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
for modality, items in mm_items.items()
}
def _get_cache_missing_items(
    self,
    cache: "BaseMultiModalProcessorCache",
    mm_data_items: MultiModalDataItems,
    mm_hashes: MultiModalHashes,
) -> MultiModalDataItems:
    """
    Collect the data items whose hashes are not present in `cache`.

    Membership is checked with `cache.is_cached`; the selected items are
    grouped by modality and passed through `_to_mm_items`.
    """
    missing_data = {}
    for modality, hashes in mm_hashes.items():
        cached_flags = cache.is_cached(hashes)
        # Keep only the data items at positions reported as not cached.
        missing_data[modality] = [
            mm_data_items[modality][idx]
            for idx, is_hit in enumerate(cached_flags) if not is_hit
        ]

    return self._to_mm_items(missing_data)
def _recompute_cached_prompt_update(
    self,
    cached_update: ResolvedPromptUpdate,
    new_item_idx: int,
) -> ResolvedPromptUpdate:
    """
    Return a copy of `cached_update` with `item_idx` set to `new_item_idx`.

    Override this if other attributes of `ResolvedPromptUpdate`
    also need to be recomputed after retrieving from the cache.
    """
    return replace(cached_update, item_idx=new_item_idx)
def _merge_mm_kwargs(
self,
cache: ProcessingCache,
mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]],
cache: "BaseMultiModalProcessorCache",
mm_hashes: MultiModalHashes,
mm_missing_kwargs: MultiModalKwargsItems,
) -> MultiModalKwargsItems:
mm_missing_prompt_updates: MultiModalPromptUpdates,
) -> tuple[MultiModalKwargsOptionalItems, MultiModalPromptUpdates]:
# Need to calculate this at the beginning to avoid skipping cache logic
# for subsequently repeated items in the same modality
mm_is_cached = {
modality: cache.is_cached(hashes)
for modality, hashes in mm_hashes.items()
}
mm_missing_next_idx = defaultdict[str, int](lambda: 0)
merged_items = defaultdict[str, list[MultiModalKwargsItem]](list)
for modality, items_or_hashes in mm_cache_items_or_hashes.items():
for item_or_hash in items_or_hashes:
if isinstance(item_or_hash, str):
kw_item = mm_missing_kwargs[modality][
mm_missing_next_idx[modality]]
cache.put(item_or_hash, kw_item)
merged_kwargs = defaultdict[str,
list[Optional[MultiModalKwargsItem]]](list)
merged_prompt_updates = defaultdict[
str, list[Sequence[ResolvedPromptUpdate]]](list)
for modality, hashes in mm_hashes.items():
missing_kwargs = mm_missing_kwargs.get(modality, [])
missing_prompt_updates = mm_missing_prompt_updates.get(
modality, [])
for item_idx, item_hash in enumerate(hashes):
kwargs: Optional[MultiModalKwargsItem]
if not mm_is_cached[modality][item_idx]:
missing_next_idx = mm_missing_next_idx[modality]
kwargs = missing_kwargs[missing_next_idx]
updates = missing_prompt_updates[missing_next_idx]
mm_missing_next_idx[modality] += 1
item = kwargs, updates
else:
kw_item = item_or_hash
item = None
kwargs, updates = cache.get_and_update_item(item, item_hash)
merged_kwargs[modality].append(kwargs)
merged_prompt_updates[modality].append([
self._recompute_cached_prompt_update(update, item_idx)
for update in updates
])
merged_items[modality].append(kw_item)
mm_kwargs = MultiModalKwargsItems(merged_kwargs)
mm_prompt_updates = dict(merged_prompt_updates)
return MultiModalKwargsItems(merged_items)
return mm_kwargs, mm_prompt_updates
def _apply_hf_processor(
self,
......@@ -1490,10 +1524,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs)
(
mm_cache_items_or_hashes,
mm_missing_data_items,
) = self._get_cache_missing_items(
mm_missing_data_items = self._get_cache_missing_items(
cache=cache,
mm_data_items=mm_data_items,
mm_hashes=mm_hashes,
......@@ -1520,16 +1552,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs),
)
mm_kwargs = self._merge_mm_kwargs(
cache,
mm_cache_items_or_hashes=mm_cache_items_or_hashes,
mm_missing_kwargs=mm_missing_kwargs,
mm_missing_prompt_updates = self._get_mm_prompt_updates(
mm_missing_data_items,
hf_processor_mm_kwargs,
mm_missing_kwargs,
)
mm_prompt_updates = self._get_mm_prompt_updates(
mm_data_items,
hf_processor_mm_kwargs,
mm_kwargs,
mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
cache,
mm_hashes=mm_hashes,
mm_missing_kwargs=mm_missing_kwargs,
mm_missing_prompt_updates=mm_missing_prompt_updates,
)
mm_info = MultiModalProcessingInfo(
......@@ -1614,7 +1647,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _validate_mm_kwargs(
self,
mm_kwargs: MultiModalKwargsItems,
mm_kwargs: MultiModalKwargsOptionalItems,
mm_item_counts: Mapping[str, int],
) -> None:
for modality, item_count in mm_item_counts.items():
......@@ -1655,7 +1688,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self,
mm_items: MultiModalDataItems,
prompt_ids: list[int],
mm_kwargs: MultiModalKwargsItems,
mm_kwargs: MultiModalKwargsOptionalItems,
mm_prompt_updates: MultiModalPromptUpdates,
is_update_applied: bool,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
......
......@@ -13,7 +13,7 @@ import vllm.envs as envs
from vllm.logger import init_logger
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs, MultiModalKwargsItems,
MultiModalInputs, MultiModalKwargsOptionalItems,
MultiModalPlaceholderDict)
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
EncDecMultiModalProcessor)
......@@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple):
"""Dummy data used for profiling."""
prompt_token_ids: list[int]
multi_modal_data: MultiModalKwargsItems
multi_modal_data: MultiModalKwargsOptionalItems
multi_modal_placeholders: MultiModalPlaceholderDict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment