Unverified Commit 69244e67 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Core] Use key-only cache for `BaseMultiModalProcessor` (#23018)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 8dbf6ed7
...@@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", ...@@ -86,7 +86,7 @@ llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct",
If you run out of CPU RAM, try the following options: If you run out of CPU RAM, try the following options:
- (Multi-modal models only) you can set the size of multi-modal processor cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB per API process + 4 GiB per engine core process) - (Multi-modal models only) you can set the size of multi-modal cache by setting `mm_processor_cache_gb` engine argument (default 4 GiB).
- (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB). - (CPU backend only) you can set the size of KV cache using `VLLM_CPU_KVCACHE_SPACE` environment variable (default 4 GiB).
## Multi-modal input limits ## Multi-modal input limits
......
...@@ -204,20 +204,33 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2 ...@@ -204,20 +204,33 @@ vllm serve Qwen/Qwen2.5-VL-3B-Instruct --api-server-count 4 -dp 2
to avoid CPU resource exhaustion. to avoid CPU resource exhaustion.
!!! note !!! note
[Multi-modal processor cache](#processor-cache) is disabled when API server scale-out is enabled API server scale-out disables [multi-modal IPC caching](#ipc-caching)
because it requires a one-to-one correspondence between API and engine core processes. because it requires a one-to-one correspondence between API and engine core processes.
## Multi-Modal Caching This does not impact [multi-modal processor caching](#processor-caching).
### Processor Cache ## Multi-Modal Caching
By default, the multi-modal processor cache is enabled to avoid repeatedly processing Multi-modal caching avoids repeated transfer or processing of the same multi-modal data,
the same multi-modal inputs via Hugging Face `AutoProcessor`,
which commonly occurs in multi-turn conversations. which commonly occurs in multi-turn conversations.
You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` ### Processor Caching
(default 4 GiB per API process + 4 GiB per engine core process).
If you do not benefit much from the cache, you can disable it completely via `mm_processor_cache_gb=0`. Multi-modal processor caching is automatically enabled
to avoid repeatedly processing the same multi-modal inputs in `BaseMultiModalProcessor`.
### IPC Caching
Multi-modal IPC caching is automatically enabled when
there is a one-to-one correspondence between API (`P0`) and engine core (`P1`) processes,
to avoid repeatedly transferring the same multi-modal inputs between them.
### Configuration
You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB).
If you do not benefit much from the cache, you can disable both IPC
and processor caching completely via `mm_processor_cache_gb=0`.
Examples: Examples:
...@@ -230,3 +243,16 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", ...@@ -230,3 +243,16 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct", llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
mm_processor_cache_gb=0) mm_processor_cache_gb=0)
``` ```
### Cache Placement
Based on the configuration, the contents of the multi-modal caches on `P0` and `P1` are as follows:
| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. Memory |
|-------------------|-------------|------------|------------|-------------|
| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` |
| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` |
| ❌ | ❌ | N/A | N/A | `0` |
K: Stores the hashes of multi-modal items
V: Stores the processed tensor data of multi-modal items
...@@ -14,8 +14,9 @@ from PIL import Image ...@@ -14,8 +14,9 @@ from PIL import Image
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.cache import MultiModalProcessorOnlyCache
from vllm.multimodal.inputs import MultiModalInputs from vllm.multimodal.inputs import MultiModalInputs
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache from vllm.multimodal.processing import BaseMultiModalProcessor
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer, from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
cached_tokenizer_from_config, cached_tokenizer_from_config,
encode_tokens) encode_tokens)
...@@ -63,6 +64,8 @@ def _test_processing_correctness( ...@@ -63,6 +64,8 @@ def _test_processing_correctness(
revision=model_info.revision, revision=model_info.revision,
trust_remote_code=model_info.trust_remote_code, trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides, hf_overrides=model_info.hf_overrides,
# Ensure that the cache can fit all of the data
mm_processor_cache_gb=2048,
) )
model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config) model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
...@@ -71,8 +74,7 @@ def _test_processing_correctness( ...@@ -71,8 +74,7 @@ def _test_processing_correctness(
model_config, model_config,
tokenizer=cached_tokenizer_from_config(model_config), tokenizer=cached_tokenizer_from_config(model_config),
) )
# Ensure that it can fit all of the data cache = MultiModalProcessorOnlyCache(model_config)
cache = ProcessingCache(capacity_gb=2048)
processing_info = factories.info(ctx) processing_info = factories.info(ctx)
supported_mm_limits = processing_info.get_supported_mm_limits() supported_mm_limits = processing_info.get_supported_mm_limits()
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import numpy as np
import pytest import pytest
import torch import torch
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata from vllm.config import ModelConfig, ParallelConfig, VllmConfig
from vllm.multimodal.cache import (MultiModalCache,
MultiModalProcessorCacheItem,
MultiModalProcessorCacheItemMetadata,
processor_cache_from_config,
receiver_cache_from_config)
from vllm.multimodal.hasher import MultiModalHasher
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem, from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
MultiModalKwargsItems, MultiModalKwargsItems,
MultiModalSharedField) MultiModalSharedField)
from vllm.multimodal.processing import PromptInsertion
from vllm.multimodal.registry import MultiModalRegistry
def _dummy_elem(
modality: str,
key: str,
size: int,
*,
rng: Optional[np.random.RandomState] = None,
):
if rng is None:
data = torch.empty((size, ), dtype=torch.int8)
else:
data = torch.from_numpy(rng.randint(4, size=(size, ), dtype=np.int8))
def _dummy_elem(modality: str, key: str, size: int):
return MultiModalFieldElem( return MultiModalFieldElem(
modality=modality, modality=modality,
key=key, key=key,
data=torch.empty((size, ), dtype=torch.int8), data=data,
field=MultiModalSharedField(1), field=MultiModalSharedField(1),
) )
def _dummy_item(modality: str, size_by_key: dict[str, int]): def _dummy_item(
modality: str,
size_by_key: dict[str, int],
*,
rng: Optional[np.random.RandomState] = None,
):
return MultiModalKwargsItem.from_elems([ return MultiModalKwargsItem.from_elems([
_dummy_elem(modality, key, size) for key, size in size_by_key.items() _dummy_elem(modality, key, size, rng=rng)
for key, size in size_by_key.items()
]) ])
def _dummy_items(size_by_key_modality: dict[str, dict[str, int]]): def _dummy_items(
size_by_key_modality: dict[str, dict[str, int]],
*,
rng: Optional[np.random.RandomState] = None,
):
return MultiModalKwargsItems.from_seq([ return MultiModalKwargsItems.from_seq([
_dummy_item(modality, size_by_key) _dummy_item(modality, size_by_key, rng=rng)
for modality, size_by_key in size_by_key_modality.items() for modality, size_by_key in size_by_key_modality.items()
]) ])
...@@ -48,5 +80,139 @@ def test_cache_item_size(item, expected_size): ...@@ -48,5 +80,139 @@ def test_cache_item_size(item, expected_size):
cache[""] = item cache[""] = item
assert cache.currsize == expected_size assert cache.currsize == expected_size
cache[""] = MultiModalCacheItemMetadata.wraps(item) prompt_update = PromptInsertion("dummy", "target", "insertion") \
.resolve(0)
cache[""] = MultiModalProcessorCacheItem(item, [prompt_update])
assert cache.currsize == expected_size
cache[""] = MultiModalProcessorCacheItemMetadata(item, [prompt_update])
assert cache.currsize == expected_size assert cache.currsize == expected_size
def _create_vllm_config(
    *,
    mm_processor_cache_gb: float,
    enable_ipc: bool,
):
    """Build a minimal `VllmConfig` for the cache-consistency tests.

    Args:
        mm_processor_cache_gb: Value forwarded to
            `ModelConfig(mm_processor_cache_gb=...)`.
        enable_ipc: Selects the data-parallel size: 1 when true, 2 otherwise.

    Returns:
        A `VllmConfig` wrapping the model and parallel configs built here.
    """
    # NOTE(review): presumably a data-parallel size above 1 is what turns
    # IPC caching off inside the cache factories -- confirm against
    # `vllm.multimodal.cache`.
    dp_size = 1 if enable_ipc else 2

    model_config = ModelConfig(mm_processor_cache_gb=mm_processor_cache_gb)
    parallel_config = ParallelConfig(data_parallel_size=dp_size)

    return VllmConfig(
        model_config=model_config,
        parallel_config=parallel_config,
    )
def _compare_caches(
    config_0: VllmConfig,
    config_1: VllmConfig,
    *,
    item_capacity: int = 8,
    hit_rate: float = 0.5,
    max_items_per_iter: int = 3,
    is_cached_calls_per_iter: int,
    n_iter: int = 100,
    seed: int = 0,
):
    """Assert that two configs yield the same items after passing the caches.

    For each config a processor-side (`p0`) and a receiver-side (`p1`) cache
    is built. The same randomly selected items are pushed through both
    pipelines on every iteration and the final outputs are compared.

    Args:
        config_0: First configuration to compare.
        config_1: Second configuration to compare.
        item_capacity: Divisor applied to the larger cache size to derive
            the per-item size.
        hit_rate: The pool holds `int(item_capacity / hit_rate)` items.
        max_items_per_iter: Exclusive upper bound passed to `rng.randint`
            when drawing how many items to select per iteration.
        is_cached_calls_per_iter: How many times `is_cached` is called on a
            `p0` cache before each `get_and_update`.
        n_iter: Number of iterations.
        seed: Seed for the `numpy` random state.

    Raises:
        AssertionError: If the two pipelines disagree on any iteration.
    """
    mm_registry = MultiModalRegistry()
    cache_0_p0 = processor_cache_from_config(config_0, mm_registry)
    cache_0_p1 = receiver_cache_from_config(config_0, mm_registry)
    cache_1_p0 = processor_cache_from_config(config_1, mm_registry)
    cache_1_p1 = receiver_cache_from_config(config_1, mm_registry)

    larger_cache_gb = max(
        config_0.model_config.mm_processor_cache_gb,
        config_1.model_config.mm_processor_cache_gb,
    )
    # NOTE(review): `int()` truncates, so any quotient below 1 becomes 0 and
    # the dummy tensors end up empty -- confirm this is intended.
    item_size_gb = int(larger_cache_gb / item_capacity)

    rng = np.random.RandomState(seed)
    pool_size = int(item_capacity / hit_rate)
    all_items = [
        _dummy_item("item", {"key": item_size_gb}, rng=rng)
        for _ in range(pool_size)
    ]
    all_hashes = [
        MultiModalHasher.hash_kwargs(item=item.get_data())
        for item in all_items
    ]

    # Should not be used since there is nothing to convert to text
    prompt_update = PromptInsertion("dummy", "target", "insertion")

    def _through_p0(p0_cache, items, hashes):
        # Without a processor-side cache the items pass through unchanged.
        if p0_cache is None:
            return items

        for _ in range(is_cached_calls_per_iter):
            p0_cache.is_cached(hashes)

        updated = p0_cache.get_and_update(
            [(item, prompt_update.content) for item in items],
            hashes,
        )
        return [item for item, _ in updated]

    def _through_p1(p1_cache, p0_out, hashes):
        # Without a receiver-side cache the P0 output is the final output.
        if p1_cache is None:
            return p0_out

        return p1_cache.get_and_update(p0_out, hashes)

    for it in range(n_iter):
        num_items_to_select = rng.randint(0, max_items_per_iter)
        item_idxs_to_select = rng.choice(len(all_items), num_items_to_select)

        selected_items = [all_items[idx] for idx in item_idxs_to_select]
        selected_hashes = [all_hashes[idx] for idx in item_idxs_to_select]

        # Both P0 stages run before either P1 stage.
        out_0_p0 = _through_p0(cache_0_p0, selected_items, selected_hashes)
        out_1_p0 = _through_p0(cache_1_p0, selected_items, selected_hashes)

        out_0_p1 = _through_p1(cache_0_p1, out_0_p0, selected_hashes)
        out_1_p1 = _through_p1(cache_1_p1, out_1_p0, selected_hashes)

        assert out_0_p1 == out_1_p1, f"Failed at {it=}"
@pytest.mark.parametrize("is_cached_calls_per_iter", [1, 2, 3])
def test_ipc_enable_disable_consistency(is_cached_calls_per_iter):
    """Check that every caching mode hands the same items to the model.

    Three configurations are compared pairwise via `_compare_caches`:

    - IPC enabled: non-zero cache size, `enable_ipc=True`.
    - IPC disabled: same non-zero cache size, `enable_ipc=False`.
    - Cache disabled: cache size 0 (IPC flag left on, nothing to cache).

    Args:
        is_cached_calls_per_iter: Number of `is_cached` calls issued on the
            processor-side cache before each `get_and_update`.
    """
    cache_size_gb = 1 / (1 << 20)

    vllm_config_ipc_enabled = _create_vllm_config(
        mm_processor_cache_gb=cache_size_gb,
        enable_ipc=True,
    )
    # The cache must stay enabled here; only IPC is switched off. Otherwise
    # the cache-on/IPC-off mode would never be exercised.
    vllm_config_ipc_disabled = _create_vllm_config(
        mm_processor_cache_gb=cache_size_gb,
        enable_ipc=False,
    )
    # A cache size of 0 is what disables caching altogether.
    vllm_config_cache_disabled = _create_vllm_config(
        mm_processor_cache_gb=0,
        enable_ipc=True,
    )

    config_pairs = (
        (vllm_config_ipc_enabled, vllm_config_ipc_disabled),
        (vllm_config_ipc_disabled, vllm_config_cache_disabled),
        (vllm_config_cache_disabled, vllm_config_ipc_enabled),
    )
    for config_a, config_b in config_pairs:
        _compare_caches(
            config_a,
            config_b,
            is_cached_calls_per_iter=is_cached_calls_per_iter,
        )
...@@ -437,7 +437,7 @@ class ModelConfig: ...@@ -437,7 +437,7 @@ class ModelConfig:
from `AutoProcessor.from_pretrained`. The available overrides depend on the from `AutoProcessor.from_pretrained`. The available overrides depend on the
model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`. model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
""" """
mm_processor_cache_gb: int = 4 mm_processor_cache_gb: float = 4
"""The size (in GiB) of the multi-modal processor cache, which is used to """The size (in GiB) of the multi-modal processor cache, which is used to
avoid re-processing past multi-modal inputs. avoid re-processing past multi-modal inputs.
...@@ -884,12 +884,6 @@ class ModelConfig: ...@@ -884,12 +884,6 @@ class ModelConfig:
return None return None
def set_mm_processor_cache_gb(self, value: int) -> None:
    """Store `value` as the processor cache size on both config objects.

    The multimodal config is fetched through `get_multimodal_config()`
    before anything is assigned.

    Args:
        value: New value for `mm_processor_cache_gb`.
    """
    multimodal_config = self.get_multimodal_config()

    # Chained assignment binds left to right: this config first, then the
    # multimodal config.
    self.mm_processor_cache_gb = multimodal_config.mm_processor_cache_gb = value
def _get_encoder_config(self): def _get_encoder_config(self):
return get_sentence_transformer_tokenizer_config( return get_sentence_transformer_tokenizer_config(
self.model, self.revision) self.model, self.revision)
...@@ -1697,22 +1691,6 @@ class ModelConfig: ...@@ -1697,22 +1691,6 @@ class ModelConfig:
def is_multimodal_model(self) -> bool: def is_multimodal_model(self) -> bool:
return self.multimodal_config is not None return self.multimodal_config is not None
@property
def enable_mm_processor_cache(self) -> bool:
    """Whether the multi-modal processor cache should be enabled.

    Returns:
        `False` when `multimodal_config` is `None`; otherwise whether its
        `mm_processor_cache_gb` is greater than zero.
    """
    multimodal_config = self.multimodal_config
    return (multimodal_config is not None
            and multimodal_config.mm_processor_cache_gb > 0)
def get_mm_input_cache_gb(self) -> int:
    """Return the multi-modal input cache size.

    Returns:
        `0` when `multimodal_config` is `None`; otherwise the value of
        `envs.VLLM_MM_INPUT_CACHE_GIB`.
    """
    # The environment value is only read when a multimodal config exists.
    if self.multimodal_config is not None:
        return envs.VLLM_MM_INPUT_CACHE_GIB

    return 0
@property @property
def is_cross_encoder(self) -> bool: def is_cross_encoder(self) -> bool:
return (self._model_info.supports_cross_encoding return (self._model_info.supports_cross_encoding
...@@ -2561,7 +2539,7 @@ class MultiModalConfig: ...@@ -2561,7 +2539,7 @@ class MultiModalConfig:
`{"num_crops": 4}`. `{"num_crops": 4}`.
""" """
mm_processor_cache_gb: int = 4 mm_processor_cache_gb: float = 4
""" """
The size (in GiB) of the multi-modal processor cache, which is used to The size (in GiB) of the multi-modal processor cache, which is used to
......
...@@ -351,7 +351,7 @@ class EngineArgs: ...@@ -351,7 +351,7 @@ class EngineArgs:
mm_processor_kwargs: Optional[Dict[str, Any]] = \ mm_processor_kwargs: Optional[Dict[str, Any]] = \
MultiModalConfig.mm_processor_kwargs MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = False # DEPRECATED disable_mm_preprocessor_cache: bool = False # DEPRECATED
mm_processor_cache_gb: int = MultiModalConfig.mm_processor_cache_gb mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
# LoRA fields # LoRA fields
...@@ -1293,18 +1293,6 @@ class EngineArgs: ...@@ -1293,18 +1293,6 @@ class EngineArgs:
worker_extension_cls=self.worker_extension_cls, worker_extension_cls=self.worker_extension_cls,
) )
if model_config.is_multimodal_model:
dp_supports_mm_processor_cache = (self.data_parallel_size == 1
or data_parallel_external_lb)
if (not dp_supports_mm_processor_cache
and model_config.mm_processor_cache_gb > 0):
logger.warning(
"Multi-modal processor cache is disabled because "
"it is not compatible with data parallelism when "
"there does not exist a one-to-one correspondance "
"between API and engine core processes.")
model_config.set_mm_processor_cache_gb(0)
speculative_config = self.create_speculative_config( speculative_config = self.create_speculative_config(
target_model_config=model_config, target_model_config=model_config,
target_parallel_config=parallel_config, target_parallel_config=parallel_config,
......
...@@ -36,6 +36,7 @@ from vllm.logits_process import get_bad_words_logits_processors ...@@ -36,6 +36,7 @@ from vllm.logits_process import get_bad_words_logits_processors
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.processing import EncDecMultiModalProcessor from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.outputs import (PoolingRequestOutput, RequestOutput, from vllm.outputs import (PoolingRequestOutput, RequestOutput,
RequestOutputFactory) RequestOutputFactory)
...@@ -250,9 +251,13 @@ class LLMEngine: ...@@ -250,9 +251,13 @@ class LLMEngine:
self.generation_config_fields = ( self.generation_config_fields = (
self.model_config.try_get_generation_config()) self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(self.model_config, self.input_preprocessor = InputPreprocessor(
self.tokenizer, self.model_config,
mm_registry) self.tokenizer,
mm_registry,
mm_processor_cache=processor_only_cache_from_config(
self.model_config, mm_registry),
)
self.model_executor = executor_class(vllm_config=vllm_config) self.model_executor = executor_class(vllm_config=vllm_config)
...@@ -840,8 +845,8 @@ class LLMEngine: ...@@ -840,8 +845,8 @@ class LLMEngine:
def reset_mm_cache(self) -> bool: def reset_mm_cache(self) -> bool:
"""Reset the multi-modal cache.""" """Reset the multi-modal cache."""
return self.input_preprocessor.mm_registry.reset_processor_cache( self.input_preprocessor.clear_cache()
self.model_config) return True
def reset_prefix_cache(self, device: Optional[Device] = None) -> bool: def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
"""Reset prefix cache for all devices.""" """Reset prefix cache for all devices."""
......
...@@ -11,6 +11,7 @@ from vllm.config import ModelConfig ...@@ -11,6 +11,7 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs) MultiModalInputs)
from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer import AnyTokenizer
...@@ -32,12 +33,14 @@ class InputPreprocessor: ...@@ -32,12 +33,14 @@ class InputPreprocessor:
model_config: ModelConfig, model_config: ModelConfig,
tokenizer: Optional[TokenizerGroup], tokenizer: Optional[TokenizerGroup],
mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY, mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
mm_processor_cache: Optional[BaseMultiModalProcessorCache] = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.model_config = model_config self.model_config = model_config
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.mm_registry = mm_registry self.mm_registry = mm_registry
self.mm_processor_cache = mm_processor_cache
def get_tokenizer_group(self) -> TokenizerGroup: def get_tokenizer_group(self) -> TokenizerGroup:
if self.tokenizer is None: if self.tokenizer is None:
...@@ -261,8 +264,11 @@ class InputPreprocessor: ...@@ -261,8 +264,11 @@ class InputPreprocessor:
""" """
tokenizer = self._get_mm_tokenizer(lora_request) tokenizer = self._get_mm_tokenizer(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config, mm_processor = self.mm_registry.create_processor(
tokenizer=tokenizer) self.model_config,
tokenizer=tokenizer,
cache=self.mm_processor_cache,
)
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
...@@ -286,8 +292,12 @@ class InputPreprocessor: ...@@ -286,8 +292,12 @@ class InputPreprocessor:
""" """
tokenizer = await self._get_mm_tokenizer_async(lora_request) tokenizer = await self._get_mm_tokenizer_async(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config, mm_processor = self.mm_registry.create_processor(
tokenizer=tokenizer) self.model_config,
tokenizer=tokenizer,
cache=self.mm_processor_cache,
)
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
...@@ -860,3 +870,7 @@ class InputPreprocessor: ...@@ -860,3 +870,7 @@ class InputPreprocessor:
tokenization_kwargs=tokenization_kwargs, tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request, lora_request=lora_request,
) )
def clear_cache(self) -> None:
    """Clear the multi-modal processor cache, if one was provided."""
    processor_cache = self.mm_processor_cache
    if processor_cache is None:
        # No cache is attached, so there is nothing to clear.
        return

    processor_cache.clear_cache()
...@@ -223,20 +223,26 @@ class InputRegistry: ...@@ -223,20 +223,26 @@ class InputRegistry:
The model is identified by ``model_config``. The model is identified by ``model_config``.
""" """
# Avoid circular import # Avoid circular import
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.sequence import SequenceData from vllm.sequence import SequenceData
if not model_config.is_multimodal_model: if not model_config.is_multimodal_model:
seq_data = SequenceData.from_prompt_token_counts((0, seq_len)) seq_data = SequenceData.from_prompt_token_counts((0, seq_len))
return DummyData(seq_data=seq_data) return DummyData(seq_data=seq_data)
cache = processor_only_cache_from_config(model_config, mm_registry)
# Encoder dummy data does not contain multi-modal data # Encoder dummy data does not contain multi-modal data
if is_encoder_data: if is_encoder_data:
enc_data = mm_registry.get_encoder_dummy_data( enc_data = mm_registry.get_encoder_dummy_data(model_config,
model_config, seq_len) seq_len,
cache=cache)
seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids) seq_data = SequenceData.from_seqs(enc_data.prompt_token_ids)
return DummyData(seq_data=seq_data) return DummyData(seq_data=seq_data)
dec_data = mm_registry.get_decoder_dummy_data(model_config, seq_len) dec_data = mm_registry.get_decoder_dummy_data(model_config,
seq_len,
cache=cache)
return DummyData( return DummyData(
seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids), seq_data=SequenceData.from_seqs(dec_data.prompt_token_ids),
......
...@@ -33,12 +33,13 @@ from vllm.inputs import InputProcessingContext ...@@ -33,12 +33,13 @@ from vllm.inputs import InputProcessingContext
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems) MultiModalKwargsItems)
from vllm.multimodal.parse import ImageSize, MultiModalDataItems from vllm.multimodal.parse import ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache, BaseProcessingInfo, PromptReplacement,
PromptReplacement, PromptUpdate) PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -367,7 +368,7 @@ def _build_hcxvision_hf_processor( ...@@ -367,7 +368,7 @@ def _build_hcxvision_hf_processor(
info: HCXVisionProcessingInfo, info: HCXVisionProcessingInfo,
dummy_inputs: BaseDummyInputsBuilder[HCXVisionProcessingInfo], dummy_inputs: BaseDummyInputsBuilder[HCXVisionProcessingInfo],
*, *,
cache: Optional[ProcessingCache] = None, cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor: ) -> BaseMultiModalProcessor:
if isinstance(info, HCXVisionProcessingInfo): if isinstance(info, HCXVisionProcessingInfo):
return HCXVisionMultiModalProcessor( return HCXVisionMultiModalProcessor(
......
...@@ -22,14 +22,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ...@@ -22,14 +22,14 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputs, MultiModalKwargsItems) MultiModalInputs, MultiModalKwargsItems)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache, BaseProcessingInfo, PromptReplacement,
PromptReplacement, PromptUpdate, PromptUpdate, PromptUpdateDetails)
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves from vllm.utils.jsontree import json_map_leaves
...@@ -394,7 +394,7 @@ def _build_llava_or_pixtral_hf_processor( ...@@ -394,7 +394,7 @@ def _build_llava_or_pixtral_hf_processor(
info: _I, info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I], dummy_inputs: BaseDummyInputsBuilder[_I],
*, *,
cache: Optional[ProcessingCache] = None, cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor: ) -> BaseMultiModalProcessor:
if isinstance(info, PixtralHFProcessingInfo): if isinstance(info, PixtralHFProcessingInfo):
return PixtralHFMultiModalProcessor( return PixtralHFMultiModalProcessor(
......
...@@ -58,7 +58,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem, ...@@ -58,7 +58,8 @@ from vllm.multimodal.parse import (DictEmbeddingItems, ImageItem,
VideoItem, VideoProcessorItems) VideoItem, VideoProcessorItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails) PromptUpdate, PromptUpdateDetails,
ResolvedPromptUpdate, _seq2text)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -744,6 +745,43 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]): ...@@ -744,6 +745,43 @@ class MiniCPMVMultiModalProcessor(BaseMultiModalProcessor[_I]):
for modality, pattern in placeholders for modality, pattern in placeholders
] ]
def _recompute_cached_prompt_update(
    self,
    cached_update: ResolvedPromptUpdate,
    new_item_idx: int,
) -> ResolvedPromptUpdate:
    """Adapt a cached prompt update to the item's new index.

    For the "image" modality, the update content embeds the item index
    between two image-id marker strings. The first occurrence of the marker
    carrying the cached index is rewritten to carry `new_item_idx`.

    Args:
        cached_update: The prompt update taken from the cache.
        new_item_idx: Index of the item in the current request.

    Returns:
        The update from the parent implementation, with its content
        replaced for image items.
    """
    new_update = super()._recompute_cached_prompt_update(
        cached_update,
        new_item_idx,
    )

    # Only image updates carry an index inside their content.
    if cached_update.modality != "image":
        return new_update

    tokenizer = self.info.get_tokenizer()
    image_processor = self.info.get_image_processor()
    version = self.info.get_model_version()

    text = _seq2text(tokenizer, cached_update.content.full)
    prev_item_idx = cached_update.item_idx

    # Versions (2, 0) and (2, 5) read the marker pair from different
    # image-processor attributes than the other versions.
    if version == (2, 0) or version == (2, 5):
        im_start = image_processor.im_start_token
        im_end = image_processor.im_end_token
    else:
        im_start = image_processor.im_id_start
        im_end = image_processor.im_id_end

    stale_marker = f"{im_start}{prev_item_idx}{im_end}"
    fresh_marker = f"{im_start}{new_item_idx}{im_end}"

    # Replace only the first occurrence of the stale marker.
    updated_text = text.replace(stale_marker, fresh_marker, 1)
    updated_content = PromptUpdateDetails.select_text(updated_text, "<unk>")

    return new_update.with_content(updated_content)
def _get_mm_fields_config( def _get_mm_fields_config(
self, self,
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
......
...@@ -22,14 +22,14 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -22,14 +22,14 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig, from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems) MultiModalKwargsItems)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize, from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems) MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache, BaseProcessingInfo, PromptReplacement,
PromptReplacement, PromptUpdate, PromptUpdate, PromptUpdateDetails)
PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape from vllm.utils.tensor_schema import TensorSchema, TensorShape
...@@ -322,7 +322,7 @@ def _build_mistral3_processor( ...@@ -322,7 +322,7 @@ def _build_mistral3_processor(
info: _I, info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I], dummy_inputs: BaseDummyInputsBuilder[_I],
*, *,
cache: Optional[ProcessingCache] = None, cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor: ) -> BaseMultiModalProcessor:
assert isinstance(info, Mistral3ProcessingInfo) assert isinstance(info, Mistral3ProcessingInfo)
return Mistral3MultiModalProcessor( return Mistral3MultiModalProcessor(
......
...@@ -41,7 +41,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor, ...@@ -41,7 +41,8 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BaseProcessingInfo,
MultiModalPromptUpdates, MultiModalPromptUpdates,
PlaceholderFeaturesInfo, PlaceholderFeaturesInfo,
PromptReplacement, PromptUpdate) PromptReplacement, PromptUpdate,
ResolvedPromptUpdate)
# yapf: enable # yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
...@@ -440,6 +441,23 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]): ...@@ -440,6 +441,23 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor[Phi3VProcessingInfo]):
) )
] ]
def _recompute_cached_prompt_update(
    self,
    cached_update: ResolvedPromptUpdate,
    new_item_idx: int,
) -> ResolvedPromptUpdate:
    """
    Adapt a cached prompt update to the item's index in the current request.

    The base implementation is applied first. For the `"image"` modality,
    the update's target is then replaced by the entry of the HF processor's
    `img_tokens` at `new_item_idx`. Other modalities are returned as-is.
    """
    refreshed = super()._recompute_cached_prompt_update(
        cached_update,
        new_item_idx,
    )

    if cached_update.modality != "image":
        return refreshed

    # `img_tokens` is indexed by item position, so the cached target has to
    # be swapped for the one belonging to the new index.
    processor = self.info.get_hf_processor()
    img_tokens: list[str] = processor.img_tokens  # type: ignore
    return refreshed.with_target(img_tokens[new_item_idx])
def _apply_prompt_updates( def _apply_prompt_updates(
self, self,
token_ids: list[int], token_ids: list[int],
......
...@@ -27,7 +27,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems, ...@@ -27,7 +27,7 @@ from vllm.multimodal.parse import (AudioProcessorItems, ImageEmbeddingItems,
MultiModalDataItems, MultiModalDataParser) MultiModalDataItems, MultiModalDataParser)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement, BaseProcessingInfo, PromptReplacement,
PromptUpdate) PromptUpdate, ResolvedPromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of from vllm.utils import is_list_of
...@@ -850,6 +850,25 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]): ...@@ -850,6 +850,25 @@ class Phi4MMMultiModalProcessor(BaseMultiModalProcessor[Phi4MMProcessingInfo]):
), ),
] ]
def _recompute_cached_prompt_update(
    self,
    cached_update: ResolvedPromptUpdate,
    new_item_idx: int,
) -> ResolvedPromptUpdate:
    """
    Adapt a cached prompt update to the item's index in the current request.

    The base implementation is applied first. For the `"image"` and
    `"audio"` modalities, the update's target is then replaced by the
    entry of `self.info.image_tokens` / `self.info.audio_tokens` at
    `new_item_idx`. Any other modality is returned as-is.
    """
    refreshed = super()._recompute_cached_prompt_update(
        cached_update,
        new_item_idx,
    )

    modality = cached_update.modality
    placeholder_tokens: list[str]
    if modality == "image":
        placeholder_tokens = self.info.image_tokens  # type: ignore
    elif modality == "audio":
        placeholder_tokens = self.info.audio_tokens  # type: ignore
    else:
        return refreshed

    # The token lists are indexed by item position within the modality.
    return refreshed.with_target(placeholder_tokens[new_item_idx])
@MULTIMODAL_REGISTRY.register_processor( @MULTIMODAL_REGISTRY.register_processor(
Phi4MMMultiModalProcessor, Phi4MMMultiModalProcessor,
......
...@@ -25,12 +25,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig ...@@ -25,12 +25,13 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.llava import LlavaDummyInputsBuilder from vllm.model_executor.models.llava import LlavaDummyInputsBuilder
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import BaseMultiModalProcessorCache
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems, from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems) ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor, from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, ProcessingCache, BaseProcessingInfo, PromptReplacement,
PromptReplacement, PromptUpdate) PromptUpdate)
from vllm.multimodal.profiling import BaseDummyInputsBuilder from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils.jsontree import json_map_leaves from vllm.utils.jsontree import json_map_leaves
...@@ -332,7 +333,7 @@ def _build_tarsier_hf_processor( ...@@ -332,7 +333,7 @@ def _build_tarsier_hf_processor(
info: _I_Tarsier, info: _I_Tarsier,
dummy_inputs: BaseDummyInputsBuilder[_I_Tarsier], dummy_inputs: BaseDummyInputsBuilder[_I_Tarsier],
*, *,
cache: Optional[ProcessingCache] = None, cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor: ) -> BaseMultiModalProcessor:
if isinstance(info, TarsierProcessingInfo): if isinstance(info, TarsierProcessingInfo):
return TarsierMultiModalProcessor( return TarsierMultiModalProcessor(
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys import sys
from collections.abc import Mapping from abc import ABC, abstractmethod
from dataclasses import dataclass from collections.abc import Mapping, Sequence
from typing import TypeVar, Union from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union
import torch import torch
from typing_extensions import TypeAlias, override
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import GiB_bytes, LRUCache from vllm.utils import GiB_bytes, LRUCache
...@@ -15,24 +16,67 @@ from .inputs import (MultiModalFieldElem, MultiModalKwargs, ...@@ -15,24 +16,67 @@ from .inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem, MultiModalKwargsItems, MultiModalKwargsItem, MultiModalKwargsItems,
NestedTensors) NestedTensors)
if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig
from .processing import ResolvedPromptUpdate
from .registry import MultiModalRegistry
logger = init_logger(__name__) logger = init_logger(__name__)
@dataclass class MultiModalProcessorCacheItem:
class MultiModalCacheItemMetadata: """
size: int The data to store inside `MultiModalProcessorOnlyCache`.
@classmethod Args:
def wraps(cls, value: "MultiModalCacheValue"): item: The processed tensor data corresponding to a multi-modal item.
return cls(size=MultiModalCache.get_item_size(value)) prompt_updates: The prompt updates corresponding to `item`.
"""
def __init__(
    self,
    item: MultiModalKwargsItem,
    prompt_updates: Sequence["ResolvedPromptUpdate"],
) -> None:
    """Bundle a processed item with the prompt updates that belong to it."""
    super().__init__()

    self.item, self.prompt_updates = item, prompt_updates
class MultiModalProcessorCacheItemMetadata:
    """
    The metadata to store inside `MultiModalProcessorSenderCache`.

    Args:
        item: The processed tensor data corresponding to a multi-modal item.
            P1 already holds the tensor data, so P0 keeps only its size to
            reduce memory usage. The size is still required so that the
            eviction policy remains the same as the cache on P1.
        prompt_updates: The prompt updates corresponding to `item`.
            These stay on P0 because, for some models, they depend on
            the processed tensor data (which is cached on P1).
    """

    def __init__(
        self,
        item: MultiModalKwargsItem,
        prompt_updates: Sequence["ResolvedPromptUpdate"],
    ) -> None:
        super().__init__()

        # Only the size is retained; no reference to `item` is stored.
        size_in_bytes = MultiModalCache.get_item_size(item)

        self.item_size = size_in_bytes
        self.prompt_updates = prompt_updates
MultiModalCacheValue = Union[ MultiModalCacheValue = Union[
MultiModalProcessorCacheItem,
MultiModalProcessorCacheItemMetadata,
MultiModalKwargsItems, MultiModalKwargsItems,
MultiModalKwargsItem, MultiModalKwargsItem,
MultiModalKwargs, MultiModalKwargs,
Mapping[str, NestedTensors], Mapping[str, NestedTensors],
MultiModalCacheItemMetadata,
] ]
_V = TypeVar("_V", bound=MultiModalCacheValue) _V = TypeVar("_V", bound=MultiModalCacheValue)
...@@ -47,8 +91,10 @@ class MultiModalCache: ...@@ -47,8 +91,10 @@ class MultiModalCache:
*, *,
debug: bool = False, debug: bool = False,
) -> int: ) -> int:
if isinstance(leaf, MultiModalFieldElem): if isinstance(leaf, MultiModalProcessorCacheItem):
return cls.get_item_size(leaf.data) # type: ignore return cls.get_leaf_size(leaf.item)
if isinstance(leaf, MultiModalProcessorCacheItemMetadata):
return leaf.item_size
# These are not subclasses of dict # These are not subclasses of dict
if isinstance(leaf, MultiModalKwargsItems): if isinstance(leaf, MultiModalKwargsItems):
...@@ -58,13 +104,13 @@ class MultiModalCache: ...@@ -58,13 +104,13 @@ class MultiModalCache:
if isinstance(leaf, MultiModalKwargs): if isinstance(leaf, MultiModalKwargs):
return cls.get_item_size(leaf.data) # type: ignore return cls.get_item_size(leaf.data) # type: ignore
if isinstance(leaf, MultiModalFieldElem):
return cls.get_item_size(leaf.data) # type: ignore
# sys.getsizeof doesn't work for tensors # sys.getsizeof doesn't work for tensors
if isinstance(leaf, torch.Tensor): if isinstance(leaf, torch.Tensor):
return leaf.nbytes return leaf.nbytes
if isinstance(leaf, MultiModalCacheItemMetadata):
return leaf.size
return sys.getsizeof(leaf) return sys.getsizeof(leaf)
@classmethod @classmethod
...@@ -98,3 +144,332 @@ class MultiModalCache: ...@@ -98,3 +144,332 @@ class MultiModalCache:
GiB_bytes * capacity_gb, GiB_bytes * capacity_gb,
getsizeof=lambda x: cls.get_item_size(x, debug=debug), getsizeof=lambda x: cls.get_item_size(x, debug=debug),
) )
_I = TypeVar("_I", contravariant=True)
_O = TypeVar("_O", covariant=True)


class BaseMultiModalCache(ABC, Generic[_I, _O]):
    """
    Abstract base class to read/write multi-modal items from cache.

    Multi-modal caching is modelled as a client and a server: the client
    runs in the frontend process (=P0) and the server in the core process
    (=P1). The data flow is as follows:

    ```
                  is_cached() x N    get_and_update()
    P0: From API -----------------> -----------------> To P1

                 get_and_update()
    P1: From P0 -----------------> To model
    ```

    `is_cached()` can be called any number of times in P0. However,
    `get_and_update()` must be called in P0 and P1 one after another
    so that their cache eviction order remains the same.

    This keeps the keys of the P0 and P1 caches mirrored, which lets P0
    decide whether a key is cached in P1 by consulting its own cache,
    without having to communicate with P1.
    """

    @abstractmethod
    def get_and_update_item(
        self,
        mm_item: _I,
        mm_hash: str,
    ) -> _O:
        """
        Possibly update a multi-modal item based on whether it is
        in the underlying cache.

        This update is done out-of-place and updates the cache eviction order.

        Args:
            mm_item: The multi-modal item to update.
            mm_hash: The hash of `mm_item`.

        Returns:
            The updated multi-modal item.
        """
        raise NotImplementedError

    def get_and_update(
        self,
        mm_items: Sequence[_I],
        mm_hashes: list[str],
    ) -> list[_O]:
        """
        Possibly update a sequence of multi-modal items based on whether they
        are in the underlying cache.

        This update is done out-of-place and updates the cache eviction order.

        Args:
            mm_items: The multi-modal items to update.
            mm_hashes: The hash of each item in `mm_items`.

        Returns:
            A new list of updated multi-modal items.
        """
        assert len(mm_items) == len(mm_hashes)

        # Items are handled strictly in input order, one lookup per item.
        updated: list[_O] = []
        for mm_item, mm_hash in zip(mm_items, mm_hashes):
            updated.append(self.get_and_update_item(mm_item, mm_hash))

        return updated

    @abstractmethod
    def clear_cache(self) -> None:
        """Clear the underlying cache."""
        raise NotImplementedError
# Input to a P0 processor cache: a `(kwargs item, prompt updates)` pair for a
# freshly processed item, or `None` when the caller has nothing to supply and
# expects the entry to be served from the cache.
MultiModalProcessorCacheInItem: TypeAlias = \
Optional[tuple[MultiModalKwargsItem, Sequence["ResolvedPromptUpdate"]]]
# Output of a P0 processor cache: the prompt updates are always present, while
# the kwargs item may be `None` (`MultiModalProcessorSenderCache` returns
# `None` in that slot on a cache hit).
MultiModalProcessorCacheOutItem: TypeAlias = \
tuple[Optional[MultiModalKwargsItem], Sequence["ResolvedPromptUpdate"]]
class BaseMultiModalProcessorCache(
        BaseMultiModalCache[MultiModalProcessorCacheInItem,
                            MultiModalProcessorCacheOutItem]):
    """The required interface for caches on P0."""

    @abstractmethod
    def is_cached_item(self, mm_hash: str) -> bool:
        """
        Check whether a multi-modal item is in the underlying cache.

        This **DOES NOT** update the cache eviction order.

        Args:
            mm_hash: The hash of the item to check.

        Returns:
            `True` if the item is cached, otherwise `False`.
        """
        raise NotImplementedError

    def is_cached(self, mm_hashes: list[str]) -> list[bool]:
        """
        Check whether a sequence of multi-modal items are
        in the underlying cache.

        This **DOES NOT** update the cache eviction order.

        Args:
            mm_hashes: The hash of each item to check.

        Returns:
            For each item, `True` if the item is cached, otherwise `False`.
        """
        # One membership query per hash, in input order.
        return list(map(self.is_cached_item, mm_hashes))
class MultiModalProcessorOnlyCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is disabled.

    How to update each item:

    - On a hit, the cached item (tensor data and prompt updates) is returned
      in place of the input.
    - On a miss, the input item (tensor data and metadata) is stored in the
      cache and returned unchanged.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        capacity_gb = (
            model_config.get_multimodal_config().mm_processor_cache_gb)

        self._cache = MultiModalCache.get_lru_cache(
            capacity_gb,
            MultiModalProcessorCacheItem,
        )

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return mm_hash in self._cache

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        hit = self._cache.get(mm_hash)

        if hit is None:
            # A miss is only valid when the caller supplied the item.
            assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

            self._cache[mm_hash] = MultiModalProcessorCacheItem(*mm_item)
            return mm_item

        return hit.item, hit.prompt_updates

    @override
    def clear_cache(self) -> None:
        self._cache.clear()
class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is enabled.

    How to update each item:

    - On a hit, the tensor data is cleared from the output (returned as
      `None`) to avoid unnecessary IPC; the cached prompt updates are
      still returned.
    - On a miss, the metadata of the item is stored so that the eviction
      policy remains the same as the cache on P1, and the input is returned.

    By only storing the metadata, we avoid keeping the data itself in
    memory inside P0.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        capacity_gb = (
            model_config.get_multimodal_config().mm_processor_cache_gb)

        self._cache = MultiModalCache.get_lru_cache(
            capacity_gb,
            MultiModalProcessorCacheItemMetadata,
        )

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return mm_hash in self._cache

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:
        hit = self._cache.get(mm_hash)

        if hit is None:
            # A miss is only valid when the caller supplied the item.
            assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

            self._cache[mm_hash] = MultiModalProcessorCacheItemMetadata(
                *mm_item)
            return mm_item

        return None, hit.prompt_updates

    @override
    def clear_cache(self) -> None:
        self._cache.clear()
def _enable_processor_cache(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
) -> bool:
    """
    Return whether the multi-modal processor cache should be used.

    The cache is enabled only when the registry reports that the model
    supports multi-modal inputs and `mm_processor_cache_gb` is positive.
    """
    if mm_registry.supports_multimodal_inputs(model_config):
        cache_gb = model_config.get_multimodal_config().mm_processor_cache_gb
        return cache_gb > 0

    return False
def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
    """
    Return whether IPC caching is supported for this configuration.

    It is supported when `data_parallel_size` is 1 or when
    `data_parallel_external_lb` is set in the parallel config.
    """
    parallel = vllm_config.parallel_config

    return (parallel.data_parallel_size == 1
            or parallel.data_parallel_external_lb)
def processor_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> Optional[BaseMultiModalProcessorCache]:
    """
    Return a `BaseMultiModalProcessorCache`, if enabled.

    Returns `None` when the processor cache is disabled. Otherwise returns
    a `MultiModalProcessorSenderCache` when IPC caching is supported, and
    a `MultiModalProcessorOnlyCache` when it is not.
    """
    model_config = vllm_config.model_config

    if not _enable_processor_cache(model_config, mm_registry):
        return None

    cache_cls = (MultiModalProcessorSenderCache
                 if _enable_ipc_cache(vllm_config) else
                 MultiModalProcessorOnlyCache)

    return cache_cls(model_config)
def processor_only_cache_from_config(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
) -> Optional["MultiModalProcessorOnlyCache"]:
    """
    Return a `MultiModalProcessorOnlyCache`, if enabled.

    Args:
        model_config: The model configuration; its multi-modal config
            supplies the cache size.
        mm_registry: The registry used to check whether the model supports
            multi-modal inputs.

    Returns:
        A new `MultiModalProcessorOnlyCache`, or `None` when the processor
        cache is disabled for this model.
    """
    if not _enable_processor_cache(model_config, mm_registry):
        return None

    return MultiModalProcessorOnlyCache(model_config)
class BaseMultiModalReceiverCache(
        BaseMultiModalCache[Optional[MultiModalKwargsItem],
                            MultiModalKwargsItem]):
    """
    The required interface for caches on P1.

    Inputs may be `None`; outputs are always a `MultiModalKwargsItem`.
    """
class MultiModalReceiverCache(BaseMultiModalReceiverCache):
    """
    The cache which is used on P1 when IPC caching is enabled.

    How to update each item:

    - On a hit, the cached item is returned in place of the input.
    - On a miss, the input item (which includes tensor data) is stored in
      the cache and returned unchanged.
    """

    def __init__(self, model_config: "ModelConfig") -> None:
        super().__init__()

        capacity_gb = (
            model_config.get_multimodal_config().mm_processor_cache_gb)

        self._cache = MultiModalCache.get_lru_cache(
            capacity_gb,
            MultiModalKwargsItem,
        )

    @override
    def get_and_update_item(
        self,
        mm_item: Optional[MultiModalKwargsItem],
        mm_hash: str,
    ) -> MultiModalKwargsItem:
        hit = self._cache.get(mm_hash)
        if hit is not None:
            return hit

        # A miss is only valid when the sender actually transferred the item.
        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        self._cache[mm_hash] = mm_item
        return mm_item

    @override
    def clear_cache(self) -> None:
        self._cache.clear()
def receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> Optional[BaseMultiModalReceiverCache]:
    """
    Return a `BaseMultiModalReceiverCache`, if enabled.

    A `MultiModalReceiverCache` is created only when both the processor
    cache and IPC caching are enabled; otherwise `None` is returned.
    """
    model_config = vllm_config.model_config

    # Short-circuits, so the IPC check only runs if the processor cache is on.
    use_receiver_cache = (_enable_processor_cache(model_config, mm_registry)
                          and _enable_ipc_cache(vllm_config))

    if use_receiver_cache:
        return MultiModalReceiverCache(model_config)

    return None
...@@ -7,11 +7,11 @@ from collections.abc import Mapping, Sequence ...@@ -7,11 +7,11 @@ from collections.abc import Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from itertools import accumulate from itertools import accumulate
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar, from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union,
Union, cast, final) cast, final)
import numpy as np import numpy as np
from typing_extensions import NotRequired, TypeAlias, deprecated from typing_extensions import NotRequired, TypeAlias, TypeVar, deprecated
from vllm.utils import LazyLoader, full_groupby, is_list_of from vllm.utils import LazyLoader, full_groupby, is_list_of
from vllm.utils.jsontree import JSONTree, json_map_leaves from vllm.utils.jsontree import JSONTree, json_map_leaves
...@@ -668,7 +668,15 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]): ...@@ -668,7 +668,15 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
return {key: elem.data for key, elem in self.items()} return {key: elem.data for key, elem in self.items()}
class MultiModalKwargsItems(UserDict[str, Sequence[MultiModalKwargsItem]]): _I = TypeVar(
"_I",
MultiModalKwargsItem,
Optional[MultiModalKwargsItem],
default=MultiModalKwargsItem,
)
class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
""" """
A dictionary of A dictionary of
[`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s [`MultiModalKwargsItem`][vllm.multimodal.inputs.MultiModalKwargsItem]s
...@@ -714,27 +722,37 @@ class MultiModalKwargsItems(UserDict[str, Sequence[MultiModalKwargsItem]]): ...@@ -714,27 +722,37 @@ class MultiModalKwargsItems(UserDict[str, Sequence[MultiModalKwargsItem]]):
items_by_modality = full_groupby(items, key=lambda x: x.modality) items_by_modality = full_groupby(items, key=lambda x: x.modality)
return MultiModalKwargsItems(items_by_modality) return MultiModalKwargsItems(items_by_modality)
def __getitem__(self, modality: str): def __getitem__(self, modality: str) -> Sequence[_I]:
if modality not in self: if modality not in self:
raise KeyError(f"Modality {modality!r} not found. " raise KeyError(f"Modality {modality!r} not found. "
f"Available modalities: {set(self.keys())}") f"Available modalities: {set(self.keys())}")
return super().__getitem__(modality) return super().__getitem__(modality) # type: ignore[return-value]
def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs": def get_data(self, *, pin_memory: bool = False) -> "MultiModalKwargs":
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list) elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
for items in self.values(): for modality, items in self.items():
for item in items: for i, item in enumerate(items):
if item is None:
raise RuntimeError("Cannot build data from empty "
f"mm_items[{modality}][{i}]")
for key, elem in item.items(): for key, elem in item.items():
elems_by_key[key].append(elem) elems_by_key[key].append(elem)
return MultiModalKwargs({ return MultiModalKwargs({
key: key:
elems[0].field.reduce_data(elems, pin_memory=pin_memory) elems[0].field.reduce_data(elems, pin_memory=pin_memory)
for key, elems in elems_by_key.items() if len(elems) > 0 for key, elems in elems_by_key.items()
}) })
MultiModalKwargsOptionalItems: TypeAlias = Union[
MultiModalKwargsItems[MultiModalKwargsItem],
MultiModalKwargsItems[Optional[MultiModalKwargsItem]],
]
class MultiModalKwargs(UserDict[str, NestedTensors]): class MultiModalKwargs(UserDict[str, NestedTensors]):
""" """
A dictionary that represents the keyword arguments to A dictionary that represents the keyword arguments to
...@@ -898,7 +916,7 @@ class MultiModalInputs(TypedDict): ...@@ -898,7 +916,7 @@ class MultiModalInputs(TypedDict):
token_type_ids: NotRequired[list[int]] token_type_ids: NotRequired[list[int]]
"""The token type IDs of the prompt.""" """The token type IDs of the prompt."""
mm_kwargs: MultiModalKwargsItems mm_kwargs: MultiModalKwargsOptionalItems
"""Keyword arguments to be directly passed to the model after batching.""" """Keyword arguments to be directly passed to the model after batching."""
mm_hashes: "MultiModalHashDict" mm_hashes: "MultiModalHashDict"
......
...@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod ...@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping, from collections.abc import (Callable, Generator, ItemsView, Iterable, Mapping,
Sequence) Sequence)
from dataclasses import dataclass, field from dataclasses import dataclass, field, replace
from enum import Enum from enum import Enum
from functools import lru_cache from functools import lru_cache
from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
...@@ -20,12 +20,11 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, ...@@ -20,12 +20,11 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens) encode_tokens)
from vllm.utils import flatten_2d_lists, full_groupby from vllm.utils import flatten_2d_lists, full_groupby
from .cache import MultiModalCache
from .hasher import MultiModalHasher from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalInputs, MultiModalFieldConfig, MultiModalInputs,
MultiModalKwargsItem, MultiModalKwargsItems, MultiModalKwargsItem, MultiModalKwargsItems,
PlaceholderRange) MultiModalKwargsOptionalItems, PlaceholderRange)
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
...@@ -34,6 +33,7 @@ if TYPE_CHECKING: ...@@ -34,6 +33,7 @@ if TYPE_CHECKING:
from transformers.feature_extraction_utils import BatchFeature from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin from transformers.processing_utils import ProcessorMixin
from .cache import BaseMultiModalProcessorCache
from .profiling import BaseDummyInputsBuilder from .profiling import BaseDummyInputsBuilder
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -557,6 +557,15 @@ class ResolvedPromptUpdate: ...@@ -557,6 +557,15 @@ class ResolvedPromptUpdate:
return self.iter_token_matches(prompt, tokenizer, start_idx=start_idx) return self.iter_token_matches(prompt, tokenizer, start_idx=start_idx)
def with_target(self, target: UpdateTarget):
    """Return a copy of this update whose `target` is replaced."""
    retargeted = replace(self, target=target)
    return retargeted
def with_content(self, content: PromptUpdateInfo):
    """
    Return a copy of this update whose `content` is replaced.

    Content that is not already a `PromptUpdateDetails` is wrapped via
    `PromptUpdateDetails.from_seq` first.
    """
    details = (content if isinstance(content, PromptUpdateDetails) else
               PromptUpdateDetails.from_seq(content))

    return replace(self, content=details)
class _TokenMatch(NamedTuple): class _TokenMatch(NamedTuple):
start_idx: int start_idx: int
...@@ -865,21 +874,6 @@ def find_mm_placeholders( ...@@ -865,21 +874,6 @@ def find_mm_placeholders(
return dict(full_groupby_modality(it)) return dict(full_groupby_modality(it))
class ProcessingCache(MultiModalCache):
def __init__(self, capacity_gb: float) -> None:
super().__init__()
self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem)
self.get = self._cache.get
self.put = self._cache.put
self.reset = self._cache.clear
_CacheItemOrHash = Union[MultiModalKwargsItem, str]
class BaseProcessingInfo: class BaseProcessingInfo:
"""Base class to provide the information necessary for data processing.""" """Base class to provide the information necessary for data processing."""
...@@ -982,7 +976,7 @@ For an item `MultiModalPromptUpdates[k][i]`, ...@@ -982,7 +976,7 @@ For an item `MultiModalPromptUpdates[k][i]`,
class MultiModalProcessingInfo(NamedTuple): class MultiModalProcessingInfo(NamedTuple):
kwargs: MultiModalKwargsItems kwargs: MultiModalKwargsOptionalItems
hashes: MultiModalHashes hashes: MultiModalHashes
prompt_updates: MultiModalPromptUpdates prompt_updates: MultiModalPromptUpdates
...@@ -994,11 +988,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -994,11 +988,13 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
Not to be confused with `transformers.ProcessorMixin`. Not to be confused with `transformers.ProcessorMixin`.
""" """
def __init__(self, def __init__(
info: _I, self,
dummy_inputs: "BaseDummyInputsBuilder[_I]", info: _I,
*, dummy_inputs: "BaseDummyInputsBuilder[_I]",
cache: Optional[ProcessingCache] = None) -> None: *,
cache: Optional["BaseMultiModalProcessorCache"] = None,
) -> None:
super().__init__() super().__init__()
self.info = info self.info = info
...@@ -1355,32 +1351,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1355,32 +1351,6 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
return prompt_ids, mm_processed_data, False return prompt_ids, mm_processed_data, False
def _get_cache_missing_items(
self,
cache: ProcessingCache,
mm_data_items: MultiModalDataItems,
mm_hashes: MultiModalHashes,
) -> tuple[dict[str, list[_CacheItemOrHash]], MultiModalDataItems]:
mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]] = {
modality: [(h if (v := cache.get(h)) is None else v)
for h in hashes]
for modality, hashes in mm_hashes.items()
}
mm_missing_idxs = {
modality: [
idx for idx, item_or_hash in enumerate(items_or_hashes)
if isinstance(item_or_hash, str)
]
for modality, items_or_hashes in mm_cache_items_or_hashes.items()
}
mm_missing_data = {
modality: [mm_data_items[modality][idx] for idx in idxs]
for modality, idxs in mm_missing_idxs.items()
}
return mm_cache_items_or_hashes, self._to_mm_items(mm_missing_data)
def _hash_mm_items( def _hash_mm_items(
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
...@@ -1401,28 +1371,92 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1401,28 +1371,92 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
for modality, items in mm_items.items() for modality, items in mm_items.items()
} }
def _get_cache_missing_items(
    self,
    cache: "BaseMultiModalProcessorCache",
    mm_data_items: MultiModalDataItems,
    mm_hashes: MultiModalHashes,
) -> MultiModalDataItems:
    """
    Collect the data items whose hashes are not present in `cache`.

    Every modality in `mm_hashes` appears in the result, with an empty
    list if all of its items are cached. The selected data is passed
    through `self._to_mm_items` before being returned.
    """
    # Query the cache for every modality up front, before touching the data.
    cached_flags = {
        modality: cache.is_cached(hashes)
        for modality, hashes in mm_hashes.items()
    }

    missing_data = {}
    for modality, flags in cached_flags.items():
        missing_data[modality] = [
            mm_data_items[modality][idx] for idx, hit in enumerate(flags)
            if not hit
        ]

    return self._to_mm_items(missing_data)
def _recompute_cached_prompt_update(
    self,
    cached_update: ResolvedPromptUpdate,
    new_item_idx: int,
) -> ResolvedPromptUpdate:
    """
    Return a copy of `cached_update` with `item_idx` set to `new_item_idx`.

    Override this if other attributes of `ResolvedPromptUpdate`
    also need to be recomputed after retrieving from the cache.
    """
    reindexed = replace(cached_update, item_idx=new_item_idx)
    return reindexed
def _merge_mm_kwargs( def _merge_mm_kwargs(
self, self,
cache: ProcessingCache, cache: "BaseMultiModalProcessorCache",
mm_cache_items_or_hashes: dict[str, list[_CacheItemOrHash]], mm_hashes: MultiModalHashes,
mm_missing_kwargs: MultiModalKwargsItems, mm_missing_kwargs: MultiModalKwargsItems,
) -> MultiModalKwargsItems: mm_missing_prompt_updates: MultiModalPromptUpdates,
) -> tuple[MultiModalKwargsOptionalItems, MultiModalPromptUpdates]:
# Need to calculate this at the beginning to avoid skipping cache logic
# for subsequently repeated items in the same modality
mm_is_cached = {
modality: cache.is_cached(hashes)
for modality, hashes in mm_hashes.items()
}
mm_missing_next_idx = defaultdict[str, int](lambda: 0) mm_missing_next_idx = defaultdict[str, int](lambda: 0)
merged_items = defaultdict[str, list[MultiModalKwargsItem]](list) merged_kwargs = defaultdict[str,
for modality, items_or_hashes in mm_cache_items_or_hashes.items(): list[Optional[MultiModalKwargsItem]]](list)
for item_or_hash in items_or_hashes: merged_prompt_updates = defaultdict[
if isinstance(item_or_hash, str): str, list[Sequence[ResolvedPromptUpdate]]](list)
kw_item = mm_missing_kwargs[modality][ for modality, hashes in mm_hashes.items():
mm_missing_next_idx[modality]] missing_kwargs = mm_missing_kwargs.get(modality, [])
cache.put(item_or_hash, kw_item) missing_prompt_updates = mm_missing_prompt_updates.get(
modality, [])
for item_idx, item_hash in enumerate(hashes):
kwargs: Optional[MultiModalKwargsItem]
if not mm_is_cached[modality][item_idx]:
missing_next_idx = mm_missing_next_idx[modality]
kwargs = missing_kwargs[missing_next_idx]
updates = missing_prompt_updates[missing_next_idx]
mm_missing_next_idx[modality] += 1 mm_missing_next_idx[modality] += 1
item = kwargs, updates
else: else:
kw_item = item_or_hash item = None
kwargs, updates = cache.get_and_update_item(item, item_hash)
merged_kwargs[modality].append(kwargs)
merged_prompt_updates[modality].append([
self._recompute_cached_prompt_update(update, item_idx)
for update in updates
])
merged_items[modality].append(kw_item) mm_kwargs = MultiModalKwargsItems(merged_kwargs)
mm_prompt_updates = dict(merged_prompt_updates)
return MultiModalKwargsItems(merged_items) return mm_kwargs, mm_prompt_updates
def _apply_hf_processor( def _apply_hf_processor(
self, self,
...@@ -1490,10 +1524,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1490,10 +1524,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs, mm_hashes = self._hash_mm_items(mm_data_items, hf_processor_mm_kwargs,
tokenization_kwargs) tokenization_kwargs)
(
mm_cache_items_or_hashes, mm_missing_data_items = self._get_cache_missing_items(
mm_missing_data_items,
) = self._get_cache_missing_items(
cache=cache, cache=cache,
mm_data_items=mm_data_items, mm_data_items=mm_data_items,
mm_hashes=mm_hashes, mm_hashes=mm_hashes,
...@@ -1520,16 +1552,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1520,16 +1552,17 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
hf_processor_mm_kwargs), hf_processor_mm_kwargs),
) )
mm_kwargs = self._merge_mm_kwargs( mm_missing_prompt_updates = self._get_mm_prompt_updates(
cache, mm_missing_data_items,
mm_cache_items_or_hashes=mm_cache_items_or_hashes, hf_processor_mm_kwargs,
mm_missing_kwargs=mm_missing_kwargs, mm_missing_kwargs,
) )
mm_prompt_updates = self._get_mm_prompt_updates( mm_kwargs, mm_prompt_updates = self._merge_mm_kwargs(
mm_data_items, cache,
hf_processor_mm_kwargs, mm_hashes=mm_hashes,
mm_kwargs, mm_missing_kwargs=mm_missing_kwargs,
mm_missing_prompt_updates=mm_missing_prompt_updates,
) )
mm_info = MultiModalProcessingInfo( mm_info = MultiModalProcessingInfo(
...@@ -1614,7 +1647,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1614,7 +1647,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
def _validate_mm_kwargs( def _validate_mm_kwargs(
self, self,
mm_kwargs: MultiModalKwargsItems, mm_kwargs: MultiModalKwargsOptionalItems,
mm_item_counts: Mapping[str, int], mm_item_counts: Mapping[str, int],
) -> None: ) -> None:
for modality, item_count in mm_item_counts.items(): for modality, item_count in mm_item_counts.items():
...@@ -1655,7 +1688,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): ...@@ -1655,7 +1688,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
self, self,
mm_items: MultiModalDataItems, mm_items: MultiModalDataItems,
prompt_ids: list[int], prompt_ids: list[int],
mm_kwargs: MultiModalKwargsItems, mm_kwargs: MultiModalKwargsOptionalItems,
mm_prompt_updates: MultiModalPromptUpdates, mm_prompt_updates: MultiModalPromptUpdates,
is_update_applied: bool, is_update_applied: bool,
) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]: ) -> tuple[list[int], str, Mapping[str, list[PlaceholderFeaturesInfo]]]:
......
...@@ -13,7 +13,7 @@ import vllm.envs as envs ...@@ -13,7 +13,7 @@ import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs, MultiModalKwargsItems, MultiModalInputs, MultiModalKwargsOptionalItems,
MultiModalPlaceholderDict) MultiModalPlaceholderDict)
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo, from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
EncDecMultiModalProcessor) EncDecMultiModalProcessor)
...@@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple): ...@@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple):
"""Dummy data used for profiling.""" """Dummy data used for profiling."""
prompt_token_ids: list[int] prompt_token_ids: list[int]
multi_modal_data: MultiModalKwargsItems multi_modal_data: MultiModalKwargsOptionalItems
multi_modal_placeholders: MultiModalPlaceholderDict multi_modal_placeholders: MultiModalPlaceholderDict
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment