Unverified commit 69244e67, authored by Cyrus Leung and committed by GitHub
Browse files

[Core] Use key-only cache for `BaseMultiModalProcessor` (#23018)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent 8dbf6ed7
......@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Generic, Optional, Protocol, TypeVar
import torch.nn as nn
......@@ -13,8 +12,9 @@ from vllm.transformers_utils.tokenizer import (AnyTokenizer,
cached_tokenizer_from_config)
from vllm.utils import ClassRegistry
from .processing import (BaseMultiModalProcessor, BaseProcessingInfo,
ProcessingCache)
from .cache import (BaseMultiModalProcessorCache,
processor_only_cache_from_config)
from .processing import BaseMultiModalProcessor, BaseProcessingInfo
from .profiling import (BaseDummyInputsBuilder, DummyDecoderData,
DummyEncoderData, MultiModalProfiler)
......@@ -65,7 +65,7 @@ class MultiModalProcessorFactory(Protocol[_I]):
info: _I,
dummy_inputs: BaseDummyInputsBuilder[_I],
*,
cache: Optional[ProcessingCache] = None,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor[_I]:
...
......@@ -80,20 +80,13 @@ class _ProcessorFactories(Generic[_I]):
self,
ctx: InputProcessingContext,
*,
cache: Optional[ProcessingCache] = None,
cache: Optional[BaseMultiModalProcessorCache] = None,
):
info = self.info(ctx)
dummy_inputs_builder = self.dummy_inputs(info)
return self.processor(info, dummy_inputs_builder, cache=cache)
# Make sure a different cache is used for each model config
# NOTE: ModelConfig is not hashable so it cannot be passed directly
@lru_cache(maxsize=1)
def _get_processor_cache(model_id: str, capacity_gb: int):
return ProcessingCache(capacity_gb) if capacity_gb > 0 else None
class MultiModalRegistry:
"""
A registry that dispatches data processing according to the model.
......@@ -103,31 +96,6 @@ class MultiModalRegistry:
self._processor_factories = ClassRegistry[nn.Module,
_ProcessorFactories]()
def _get_processor_cache(self, model_config: "ModelConfig"):
    """Look up the processing cache that belongs to `model_config`.

    Reads `model_config.model` and `model_config.mm_processor_cache_gb`
    and forwards them to the module-level `_get_processor_cache` helper,
    returning whatever that helper returns.
    """
    return _get_processor_cache(
        model_config.model,
        model_config.mm_processor_cache_gb,
    )
def reset_processor_cache(self, model_config: "ModelConfig") -> bool:
"""Reset the multi-modal processing cache."""
if processor_cache := self._get_processor_cache(model_config):
processor_cache.reset()
return True # Success
def enable_mm_input_cache(self, model_config: "ModelConfig") -> bool:
"""Whether the multi-modal input cache should be enabled.
NOTE: This is put under MultiModalRegistry on purpose to respect
text-only mode for multimodal models.
"""
if not self.supports_multimodal_inputs(model_config):
return False
mm_config = model_config.get_multimodal_config()
return mm_config.mm_processor_cache_gb > 0
def supports_multimodal_inputs(self, model_config: "ModelConfig") -> bool:
"""
Checks if the model supports multimodal inputs.
......@@ -157,6 +125,8 @@ class MultiModalRegistry:
def get_max_tokens_per_item_by_modality(
self,
model_config: "ModelConfig",
*,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
......@@ -165,11 +135,11 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
return {}
processor = self.create_processor(model_config, disable_cache=False)
processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor)
seq_len = model_config.max_model_len
mm_limits = self.get_mm_limits_per_prompt(model_config)
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
return profiler.get_mm_max_contiguous_tokens(
seq_len,
......@@ -182,6 +152,8 @@ class MultiModalRegistry:
def get_max_tokens_per_item_by_nonzero_modality(
self,
model_config: "ModelConfig",
*,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> Mapping[str, int]:
"""
Get the maximum number of tokens per data item from each modality based
......@@ -192,15 +164,19 @@ class MultiModalRegistry:
This is currently directly used only in V1 for profiling the memory
usage of a model.
"""
mm_limits = self.get_mm_limits_per_prompt(model_config)
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
model_config,
cache=cache,
)
return {
key: max_tokens_per_mm_item
for key, max_tokens_per_mm_item in
self.get_max_tokens_per_item_by_modality(model_config).items()
for key, max_tokens_per_mm_item in max_tokens_per_item.items()
if mm_limits[key] > 0
}
# TODO: Remove once V0 is gone
def get_max_tokens_by_modality(
self,
model_config: "ModelConfig",
......@@ -209,14 +185,19 @@ class MultiModalRegistry:
Get the maximum number of tokens from each modality
for profiling the memory usage of a model.
"""
mm_limits = self.get_mm_limits_per_prompt(model_config)
cache = processor_only_cache_from_config(model_config, self)
mm_limits = self.get_mm_limits_per_prompt(model_config, cache=cache)
max_tokens_per_item = self.get_max_tokens_per_item_by_modality(
model_config,
cache=cache,
)
return {
key: mm_limits[key] * max_tokens_per_mm_item
for key, max_tokens_per_mm_item in
self.get_max_tokens_per_item_by_modality(model_config).items()
for key, max_tokens_per_mm_item in max_tokens_per_item.items()
}
# TODO: Remove once V0 is gone
def get_max_multimodal_tokens(self, model_config: "ModelConfig") -> int:
"""
Get the maximum number of multi-modal tokens
......@@ -227,6 +208,8 @@ class MultiModalRegistry:
def get_mm_limits_per_prompt(
self,
model_config: "ModelConfig",
*,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> Mapping[str, int]:
"""
Get the maximum number of multi-modal input instances for each modality
......@@ -235,7 +218,7 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
return {}
processor = self.create_processor(model_config, disable_cache=False)
processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor)
return profiler.get_mm_limits()
......@@ -303,7 +286,7 @@ class MultiModalRegistry:
model_config: "ModelConfig",
*,
tokenizer: Optional[AnyTokenizer] = None,
disable_cache: Optional[bool] = None,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> BaseMultiModalProcessor[BaseProcessingInfo]:
"""
Create a multi-modal processor for a specific model and tokenizer.
......@@ -311,15 +294,10 @@ class MultiModalRegistry:
if not model_config.is_multimodal_model:
raise ValueError(f"{model_config.model} is not a multimodal model")
if disable_cache is None:
disable_cache = not model_config.enable_mm_processor_cache
model_cls = self._get_model_cls(model_config)
factories = self._processor_factories[model_cls]
ctx = self._create_processing_ctx(model_config, tokenizer)
cache = None if disable_cache else self._get_processor_cache(
model_config)
return factories.build_processor(ctx, cache=cache)
......@@ -328,13 +306,15 @@ class MultiModalRegistry:
model_config: "ModelConfig",
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
*,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> DummyDecoderData:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by ``model_config``.
"""
processor = self.create_processor(model_config, disable_cache=False)
processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_decoder_dummy_data(seq_len, mm_counts)
......@@ -352,13 +332,15 @@ class MultiModalRegistry:
model_config: "ModelConfig",
seq_len: int,
mm_counts: Optional[Mapping[str, int]] = None,
*,
cache: Optional[BaseMultiModalProcessorCache] = None,
) -> DummyEncoderData:
"""
Create dummy data for profiling the memory usage of a model.
The model is identified by ``model_config``.
"""
processor = self.create_processor(model_config, disable_cache=False)
processor = self.create_processor(model_config, cache=cache)
profiler = MultiModalProfiler(processor)
dummy_data = profiler.get_encoder_dummy_data(seq_len, mm_counts)
......
......@@ -597,8 +597,7 @@ class AsyncLLM(EngineClient):
await asyncio.gather(*coros)
async def reset_mm_cache(self) -> None:
self.processor.mm_registry.reset_processor_cache(self.model_config)
self.processor.mm_input_cache_client.reset()
self.processor.clear_cache()
await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache(self,
......
......@@ -22,6 +22,7 @@ from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import receiver_cache_from_config
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
......@@ -38,7 +39,6 @@ from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType,
ReconfigureDistributedRequest, ReconfigureRankType,
UtilityOutput, UtilityResult)
from vllm.v1.engine.mm_input_cache import MultiModalInputCacheServer
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
......@@ -128,8 +128,9 @@ class EngineCore:
)
self.use_spec_decode = vllm_config.speculative_config is not None
self.mm_input_cache_server = MultiModalInputCacheServer(
vllm_config.model_config, MULTIMODAL_REGISTRY)
self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
self.mm_receiver_cache = receiver_cache_from_config(
vllm_config, mm_registry)
# Setup batch queue for pipeline parallelism.
# Batch queue for scheduled batches. This enables us to asynchronously
......@@ -370,7 +371,8 @@ class EngineCore:
logger.warning("Resetting the multi-modal cache when requests are "
"in progress may lead to desynced internal caches.")
self.mm_input_cache_server.reset()
if self.mm_receiver_cache is not None:
self.mm_receiver_cache.clear_cache()
def reset_prefix_cache(self):
self.scheduler.reset_prefix_cache()
......@@ -435,10 +437,11 @@ class EngineCore:
assert request.mm_kwargs is not None
# Note on thread safety: no race condition.
# `mm_input_cache_server` is reset at the end of LLMEngine init,
# `mm_receiver_cache` is reset at the end of LLMEngine init,
# and will only be accessed in the input processing thread afterwards.
request.mm_kwargs = self.mm_input_cache_server.get_and_update(
request.mm_kwargs, request.mm_hashes)
if self.mm_receiver_cache is not None:
request.mm_kwargs = self.mm_receiver_cache.get_and_update(
request.mm_kwargs, request.mm_hashes)
req = Request.from_engine_core_request(request,
self.request_block_hasher)
......
......@@ -271,8 +271,7 @@ class LLMEngine:
self.engine_core.profile(False)
def reset_mm_cache(self):
self.processor.mm_registry.reset_processor_cache(self.model_config)
self.processor.mm_input_cache_client.reset()
self.processor.clear_cache()
self.engine_core.reset_mm_cache()
def reset_prefix_cache(self, device: Optional[Device] = None):
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING, Optional
from vllm.multimodal import MultiModalRegistry
from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
from vllm.multimodal.inputs import MultiModalKwargsItem
from vllm.utils import is_list_of
if TYPE_CHECKING:
from vllm.config import ModelConfig
# The idea of multimodal input caching is based on having a client and
# a server, where the client executes in the frontend process (=P0) and the
# server in the core process (=P1).
#
# -- P0:
# - BaseMultiModalProcessor calls MultiModalHasher to get the `mm_hash` of
# each input multi-modal item (e.g. image),
# - BaseMultiModalProcessor processes the input items into `mm_kwargs`,
# which are MultiModalKwargsItem instances that each correspond to an
# input multi-modal item.
# - MultiModalInputCacheClient accepts the `mm_kwargs` and corresponding
# `mm_hash` for each item. It stores the `mm_hash` as keys and the size
# of `mm_kwargs`, but not the `mm_kwargs` themselves, to avoid taking
# up additional memory in P0.
# - The `mm_hash` is always sent to P1.
# - The corresponding `mm_kwargs` are only sent to P1 if they are not cached
# in MultiModalInputCacheServer.
#
# -- P1:
# - If the `mm_hash` is cached (i.e. `mm_kwargs` are not sent from P0),
# MultiModalInputCacheServer retrieves the corresponding `mm_kwargs`.
# - If the `mm_hash` is not cached (i.e. `mm_kwargs` are sent from P0),
# MultiModalInputCacheServer stores `mm_kwargs` under the key `mm_hash`.
# - Either way, the `mm_hash` and corresponding `mm_kwargs` are sent to
# the engine for model execution.
#
# Both Client and Server must perform cache update and eviction based on the
# same item size. This ensures that the keys of MultiModalInputCacheClient
# and MultiModalInputCacheServer are mirrored, allowing us to determine in P0
# whether a key is cached in MultiModalInputCacheServer by querying
# MultiModalInputCacheClient without having to communicate with P1.
class MultiModalInputCacheClient:
    """Used by P0 to check whether multi-modal kwargs are cached in P1.

    The cache maps each `mm_hash` to the result of
    `MultiModalCacheItemMetadata.wraps(item)` rather than to the item
    itself.
    """

    def __init__(self, model_config: "ModelConfig",
                 mm_registry: MultiModalRegistry) -> None:
        """Build the cache from the model config.

        Args:
            model_config: Supplies the cache capacity through
                `get_mm_input_cache_gb()`.
            mm_registry: Asked via `enable_mm_input_cache` whether caching
                is enabled for this model.
        """
        super().__init__()

        self.enabled = mm_registry.enable_mm_input_cache(model_config)

        capacity_gb = model_config.get_mm_input_cache_gb()
        self.mm_cache = MultiModalCache.get_lru_cache(
            capacity_gb,
            MultiModalCacheItemMetadata,
        )

    def get_and_update(
        self,
        mm_kwargs: Sequence[MultiModalKwargsItem],
        mm_hashes: list[str],
    ) -> list[Optional[MultiModalKwargsItem]]:
        """Replace already-cached items with `None` and record new ones.

        Args:
            mm_kwargs: The items, one per entry of `mm_hashes`.
            mm_hashes: The hash of each item, in the same order.

        Returns:
            A list as long as the input. An entry is `None` where the hash
            was already in the cache, and the original item otherwise.
            When the cache is disabled, a plain list copy of `mm_kwargs`.
        """
        if not self.enabled:
            return list(mm_kwargs)

        assert len(mm_kwargs) == len(mm_hashes)

        pending: list[Optional[MultiModalKwargsItem]] = []
        for item, item_hash in zip(mm_kwargs, mm_hashes):
            if self.mm_cache.get(item_hash) is None:
                # Miss: record the key first, then keep the real item.
                self.mm_cache[item_hash] = \
                    MultiModalCacheItemMetadata.wraps(item)
                pending.append(item)
            else:
                # Hit: the item is replaced by a `None` placeholder.
                pending.append(None)

        return pending

    def reset(self) -> None:
        """Clear every entry from the cache."""
        self.mm_cache.clear()
class MultiModalInputCacheServer:
    """Used by P1 to avoid requiring past multi-modal kwargs from P0.

    The cache maps each `mm_hash` to the full `MultiModalKwargsItem`, so
    an item that arrives as `None` can be restored from its hash.
    """

    def __init__(self, model_config: "ModelConfig",
                 mm_registry: MultiModalRegistry) -> None:
        """Build the cache from the model config.

        Args:
            model_config: Supplies the cache capacity through
                `get_mm_input_cache_gb()`.
            mm_registry: Asked via `enable_mm_input_cache` whether caching
                is enabled for this model.
        """
        super().__init__()

        self.enabled = mm_registry.enable_mm_input_cache(model_config)

        capacity_gb = model_config.get_mm_input_cache_gb()
        self.mm_cache = MultiModalCache.get_lru_cache(
            capacity_gb,
            MultiModalKwargsItem,
        )

    def get_and_update(
        self,
        mm_kwargs: Sequence[Optional[MultiModalKwargsItem]],
        mm_hashes: list[str],
    ) -> list[MultiModalKwargsItem]:
        """Fill in `None` items from the cache and store the new ones.

        Args:
            mm_kwargs: The items, one per entry of `mm_hashes`. An entry
                may be `None`, in which case it is looked up by its hash.
            mm_hashes: The hash of each item, in the same order.

        Returns:
            A list as long as the input, with every `None` replaced by the
            cached item. When the cache is disabled, a list copy of
            `mm_kwargs`, asserted to contain only `MultiModalKwargsItem`.
        """
        if not self.enabled:
            passthrough = list(mm_kwargs)
            assert is_list_of(passthrough, MultiModalKwargsItem)
            return passthrough

        assert len(mm_kwargs) == len(mm_hashes)

        resolved: list[MultiModalKwargsItem] = []
        for item, item_hash in zip(mm_kwargs, mm_hashes):
            if item is not None:
                # A real item was supplied: store it, then pass it through.
                self.mm_cache[item_hash] = item
                resolved.append(item)
            else:
                # Only the hash was supplied: restore the stored item.
                resolved.append(self.mm_cache[item_hash])

        return resolved

    def reset(self) -> None:
        """Clear every entry from the cache."""
        self.mm_cache.clear()
......@@ -11,6 +11,7 @@ from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
......@@ -18,7 +19,6 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.backend_lm_format_enforcer import (
......@@ -47,16 +47,17 @@ class Processor:
self.generation_config_fields = (
self.model_config.try_get_generation_config())
self.input_preprocessor = InputPreprocessor(self.model_config,
self.tokenizer,
mm_registry)
self.mm_input_cache_client = MultiModalInputCacheClient(
self.model_config, mm_registry)
self.mm_registry = mm_registry
self.mm_processor_cache = processor_cache_from_config(
vllm_config, mm_registry)
@property
def mm_registry(self):
return self.input_preprocessor.mm_registry
self.input_preprocessor = InputPreprocessor(
self.model_config,
self.tokenizer,
mm_registry,
mm_processor_cache=self.mm_processor_cache,
)
def _validate_logprobs(
self,
......@@ -310,7 +311,7 @@ class Processor:
# in the input sequence.
sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
orig_sorted_mm_inputs = [
sorted_mm_inputs = [
decoder_mm_inputs[modality][idx]
for modality, idx in sorted_mm_idxs
]
......@@ -323,11 +324,6 @@ class Processor:
for modality, idx in sorted_mm_idxs
]
sorted_mm_inputs = self.mm_input_cache_client.get_and_update(
orig_sorted_mm_inputs,
sorted_mm_hashes,
)
return decoder_inputs.get("prompt"), EngineCoreRequest(
request_id=request_id,
prompt_token_ids=decoder_inputs["prompt_token_ids"],
......@@ -415,3 +411,6 @@ class Processor:
# TODO: Find out how many placeholder tokens are there so we can
# check that chunked prefill does not truncate them
# max_batch_len = self.scheduler_config.max_num_batched_tokens
def clear_cache(self) -> None:
self.input_preprocessor.clear_cache()
......@@ -2186,10 +2186,13 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_items_per_batch: int,
) -> BatchedTensorInputs:
"""Dummy data for profiling and precompiling multimodal models."""
assert self.mm_budget is not None
dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_counts={modality: 1},
cache=self.mm_budget.cache,
)
dummy_mm_data = dummy_decoder_data.multi_modal_data
......
......@@ -1813,10 +1813,13 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
max_items_per_batch: int,
) -> BatchedTensorInputs:
"""Dummy data for profiling and precompiling multimodal models."""
assert self.mm_budget is not None
dummy_decoder_data = self.mm_registry.get_decoder_dummy_data(
model_config=self.model_config,
seq_len=self.max_num_tokens,
mm_counts={modality: 1},
cache=self.mm_budget.cache,
)
dummy_mm_data = dummy_decoder_data.multi_modal_data
......
......@@ -10,6 +10,7 @@ from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import ModelConfig, SchedulerConfig
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry
from vllm.v1.attention.backends.utils import AttentionMetadataBuilder
from vllm.v1.core.encoder_cache_manager import compute_mm_encoder_budget
......@@ -33,14 +34,18 @@ class MultiModalBudget:
self.model_config = model_config
self.scheduler_config = scheduler_config
self.mm_registry = mm_registry
self.cache = cache = processor_only_cache_from_config(
model_config, mm_registry)
self.max_model_len = model_config.max_model_len
self.max_num_reqs = scheduler_config.max_num_seqs
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config)
self.mm_limits = mm_registry.get_mm_limits_per_prompt(model_config,
cache=cache)
max_tokens_by_modality = mm_registry \
.get_max_tokens_per_item_by_nonzero_modality(model_config)
.get_max_tokens_per_item_by_nonzero_modality(model_config,
cache=cache)
encoder_compute_budget, encoder_cache_size = compute_mm_encoder_budget(
scheduler_config,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment