"tests/vscode:/vscode.git/clone" did not exist on "020732800caf3ba1eae62098c4264dac7ef35611"
Unverified commit cbbae38f, authored by Cyrus Leung, committed by GitHub
Browse files

[2/N] Move cache factories to MM registry (#32382)


Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
parent cdba4c74
...@@ -19,8 +19,6 @@ from vllm.multimodal.cache import ( ...@@ -19,8 +19,6 @@ from vllm.multimodal.cache import (
MultiModalReceiverCache, MultiModalReceiverCache,
ShmObjectStoreReceiverCache, ShmObjectStoreReceiverCache,
ShmObjectStoreSenderCache, ShmObjectStoreSenderCache,
engine_receiver_cache_from_config,
processor_cache_from_config,
) )
from vllm.multimodal.hasher import MultiModalHasher from vllm.multimodal.hasher import MultiModalHasher
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
...@@ -132,10 +130,10 @@ def _compare_caches( ...@@ -132,10 +130,10 @@ def _compare_caches(
n_iter: int = 100, n_iter: int = 100,
seed: int = 0, seed: int = 0,
): ):
cache_0_p0 = processor_cache_from_config(config_0, MULTIMODAL_REGISTRY) cache_0_p0 = MULTIMODAL_REGISTRY.processor_cache_from_config(config_0)
cache_0_p1 = engine_receiver_cache_from_config(config_0, MULTIMODAL_REGISTRY) cache_0_p1 = MULTIMODAL_REGISTRY.engine_receiver_cache_from_config(config_0)
cache_1_p0 = processor_cache_from_config(config_1, MULTIMODAL_REGISTRY) cache_1_p0 = MULTIMODAL_REGISTRY.processor_cache_from_config(config_1)
cache_1_p1 = engine_receiver_cache_from_config(config_1, MULTIMODAL_REGISTRY) cache_1_p1 = MULTIMODAL_REGISTRY.engine_receiver_cache_from_config(config_1)
cache_size_gb = max( cache_size_gb = max(
config_0.model_config.multimodal_config.mm_processor_cache_gb, config_0.model_config.multimodal_config.mm_processor_cache_gb,
......
...@@ -6,9 +6,8 @@ import pytest ...@@ -6,9 +6,8 @@ import pytest
from vllm.assets.image import ImageAsset from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset from vllm.assets.video import VideoAsset
from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, DeviceConfig, ModelConfig, VllmConfig
from vllm.multimodal import MultiModalUUIDDict from vllm.multimodal import MultiModalRegistry, MultiModalUUIDDict
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.engine import input_processor as input_processor_mod
from vllm.v1.engine.input_processor import InputProcessor from vllm.v1.engine.input_processor import InputProcessor
cherry_pil_image = ImageAsset("cherry_blossom").pil_image cherry_pil_image = ImageAsset("cherry_blossom").pil_image
...@@ -36,9 +35,9 @@ def _mock_input_processor( ...@@ -36,9 +35,9 @@ def _mock_input_processor(
raising=True, raising=True,
) )
monkeypatch.setattr( monkeypatch.setattr(
input_processor_mod, MultiModalRegistry,
"processor_cache_from_config", "processor_cache_from_config",
lambda vllm_config, mm_registry: None, lambda self, vllm_config: None,
raising=True, raising=True,
) )
......
...@@ -135,9 +135,15 @@ class LoRAModelManager: ...@@ -135,9 +135,15 @@ class LoRAModelManager:
llm_punica_wrapper llm_punica_wrapper
) )
def _maybe_init_mm(self, vllm_config: VllmConfig, max_num_batched_tokens) -> None: def _maybe_init_mm(
self.supports_tower_connector_lora = False self,
vllm_config: VllmConfig,
max_num_batched_tokens: int,
) -> None:
model_config: ModelConfig = vllm_config.model_config model_config: ModelConfig = vllm_config.model_config
mm_registry = MULTIMODAL_REGISTRY
self.supports_tower_connector_lora = False
self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping() self.mm_mapping: MultiModelKeys = self.model.get_mm_mapping()
# Only one language model can be included in the model. # Only one language model can be included in the model.
...@@ -154,9 +160,7 @@ class LoRAModelManager: ...@@ -154,9 +160,7 @@ class LoRAModelManager:
self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper self.punica_wrapper_mapping[lm_prefix] = llm_punica_wrapper
if self.lora_config.enable_tower_connector_lora: if self.lora_config.enable_tower_connector_lora:
self.mm_processor_info = MULTIMODAL_REGISTRY.create_processor( self.mm_processor_info = mm_registry.create_processor(model_config).info
model_config
).info
self.supports_tower_connector_lora = self.supports_mm and hasattr( self.supports_tower_connector_lora = self.supports_mm and hasattr(
self.model, "get_num_mm_encoder_tokens" self.model, "get_num_mm_encoder_tokens"
) )
...@@ -169,11 +173,7 @@ class LoRAModelManager: ...@@ -169,11 +173,7 @@ class LoRAModelManager:
"GitHub if you encounter them." "GitHub if you encounter them."
) )
mm_budget = MultiModalBudget( mm_budget = MultiModalBudget(vllm_config, mm_registry)
model_config,
vllm_config.scheduler_config,
MULTIMODAL_REGISTRY,
)
limit_per_prompt: int = max( limit_per_prompt: int = max(
self.mm_processor_info.get_allowed_mm_limits().values() self.mm_processor_info.get_allowed_mm_limits().values()
) )
......
...@@ -35,7 +35,6 @@ if TYPE_CHECKING: ...@@ -35,7 +35,6 @@ if TYPE_CHECKING:
from vllm.config import ModelConfig, VllmConfig from vllm.config import ModelConfig, VllmConfig
from .processing.processor import ResolvedPromptUpdate from .processing.processor import ResolvedPromptUpdate
from .registry import MultiModalRegistry
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -561,67 +560,6 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache): ...@@ -561,67 +560,6 @@ class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
return mm_item return mm_item
def _enable_processor_cache(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
) -> bool:
    """Tell whether the multimodal processor cache should be used.

    The cache is enabled only when `mm_registry` reports that the model
    supports multimodal inputs and `mm_processor_cache_gb` is positive.
    """
    supported = mm_registry.supports_multimodal_inputs(model_config)
    if supported:
        # The multimodal config is only consulted for supported models.
        cache_gb = model_config.get_multimodal_config().mm_processor_cache_gb
        return cache_gb > 0

    return False
def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
    """Tell whether IPC caching is supported by the parallel configuration.

    IPC caching is supported when there is a single API process and a data
    parallel size of 1, or whenever `data_parallel_external_lb` is set.
    """
    parallel = vllm_config.parallel_config
    if parallel._api_process_count == 1 and parallel.data_parallel_size == 1:
        return True

    # Otherwise the external load-balancer flag alone decides.
    return parallel.data_parallel_external_lb
def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool:
    """Tell whether the shared-memory based cache should be enabled.

    This requires IPC caching to be supported and `mm_processor_cache_type`
    to be `"shm"`.
    """
    if _enable_ipc_cache(vllm_config):
        multimodal = vllm_config.model_config.get_multimodal_config()
        return multimodal.mm_processor_cache_type == "shm"

    return False
def processor_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> BaseMultiModalProcessorCache | None:
    """Build the processor-side multimodal cache, or `None` if disabled.

    The variant is chosen in order: no cache when the processor cache is
    disabled, a processor-only cache when IPC caching is unsupported, a
    sender cache when the shared-memory cache is not selected, and a
    shared-memory sender cache otherwise.
    """
    mc = vllm_config.model_config

    cache: BaseMultiModalProcessorCache | None
    if not _enable_processor_cache(mc, mm_registry):
        cache = None
    elif not _enable_ipc_cache(vllm_config):
        cache = MultiModalProcessorOnlyCache(mc)
    elif not _enable_mm_input_shm_cache(vllm_config):
        cache = MultiModalProcessorSenderCache(mc)
    else:
        cache = ShmObjectStoreSenderCache(vllm_config)

    return cache
def processor_only_cache_from_config(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
) -> "MultiModalProcessorOnlyCache | None":
    """Return a `MultiModalProcessorOnlyCache`, if enabled.

    Unlike `processor_cache_from_config`, only the processor-cache switch is
    consulted here; the IPC and shared-memory settings are not.

    Returns:
        A new `MultiModalProcessorOnlyCache` built from `model_config`, or
        `None` when the processor cache is disabled.
    """
    if not _enable_processor_cache(model_config, mm_registry):
        return None

    return MultiModalProcessorOnlyCache(model_config)
class BaseMultiModalReceiverCache( class BaseMultiModalReceiverCache(
BaseMultiModalCache[MultiModalKwargsItem | None, MultiModalKwargsItem] BaseMultiModalCache[MultiModalKwargsItem | None, MultiModalKwargsItem]
): ):
...@@ -780,50 +718,3 @@ class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache): ...@@ -780,50 +718,3 @@ class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache):
@override @override
def clear_cache(self) -> None: def clear_cache(self) -> None:
self._shm_cache.clear() self._shm_cache.clear()
def engine_receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> BaseMultiModalReceiverCache | None:
    """
    Build the receiver cache used in the engine process.

    A `MultiModalReceiverCache` is returned only when the processor cache and
    IPC caching are both enabled and the shared-memory cache is not selected;
    otherwise `None` is returned.
    """
    mc = vllm_config.model_config

    # Evaluated left to right with short-circuiting, like successive guards.
    use_engine_receiver = (
        _enable_processor_cache(mc, mm_registry)
        and _enable_ipc_cache(vllm_config)
        and not _enable_mm_input_shm_cache(vllm_config)
    )
    if use_engine_receiver:
        return MultiModalReceiverCache(mc)

    return None
def worker_receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
    shared_worker_lock: LockType,
) -> BaseMultiModalReceiverCache | None:
    """
    Build the receiver cache used in the worker process.

    A `ShmObjectStoreReceiverCache` is returned only when the processor
    cache, IPC caching and the shared-memory cache are all enabled;
    otherwise `None` is returned.
    """
    mc = vllm_config.model_config

    # Evaluated left to right with short-circuiting, like successive guards.
    use_shm_receiver = (
        _enable_processor_cache(mc, mm_registry)
        and _enable_ipc_cache(vllm_config)
        and _enable_mm_input_shm_cache(vllm_config)
    )
    if use_shm_receiver:
        return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)

    return None
...@@ -2,14 +2,23 @@ ...@@ -2,14 +2,23 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Mapping from collections.abc import Mapping
from dataclasses import dataclass from dataclasses import dataclass
from typing import TYPE_CHECKING, Generic, Protocol, TypeVar, cast from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Literal, Protocol, TypeVar, cast
from vllm.config.multimodal import BaseDummyOptions from vllm.config.multimodal import BaseDummyOptions
from vllm.config.observability import ObservabilityConfig from vllm.config.observability import ObservabilityConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config from vllm.tokenizers import TokenizerLike, cached_tokenizer_from_config
from .cache import BaseMultiModalProcessorCache from .cache import (
BaseMultiModalProcessorCache,
BaseMultiModalReceiverCache,
MultiModalProcessorOnlyCache,
MultiModalProcessorSenderCache,
MultiModalReceiverCache,
ShmObjectStoreReceiverCache,
ShmObjectStoreSenderCache,
)
from .inputs import MultiModalInputs from .inputs import MultiModalInputs
from .processing import ( from .processing import (
BaseDummyInputsBuilder, BaseDummyInputsBuilder,
...@@ -19,7 +28,7 @@ from .processing import ( ...@@ -19,7 +28,7 @@ from .processing import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig, ObservabilityConfig from vllm.config import ModelConfig, ObservabilityConfig, VllmConfig
from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.interfaces import SupportsMultiModal
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -355,3 +364,84 @@ class MultiModalRegistry: ...@@ -355,3 +364,84 @@ class MultiModalRegistry:
first_modality = next(iter(max_tokens)) first_modality = next(iter(max_tokens))
return max_tokens[first_modality] return max_tokens[first_modality]
def _get_cache_type(
    self,
    vllm_config: "VllmConfig",
) -> Literal[None, "processor_only", "lru", "shm"]:
    """Resolve which multimodal cache variant applies to `vllm_config`.

    Returns:
        `None` when the model does not support multimodal inputs or
        `mm_processor_cache_gb` is not positive; `"processor_only"` when IPC
        caching is not supported by the parallel configuration; otherwise
        the configured `mm_processor_cache_type`.
    """
    model_config = vllm_config.model_config
    if not self.supports_multimodal_inputs(model_config):
        return None

    # Check if the cache is disabled.
    mm_config = model_config.get_multimodal_config()
    if mm_config.mm_processor_cache_gb <= 0:
        return None

    # Check if IPC caching is supported.
    # NOTE(review): with more than one API process this is False even when
    # `data_parallel_external_lb` is set -- confirm this is intended.
    parallel_config = vllm_config.parallel_config
    is_ipc_supported = parallel_config._api_process_count == 1 and (
        parallel_config.data_parallel_size == 1
        or parallel_config.data_parallel_external_lb
    )
    if not is_ipc_supported:
        return "processor_only"

    # `mm_config` was already fetched above; no need to look it up again.
    return mm_config.mm_processor_cache_type
def processor_cache_from_config(
    self,
    vllm_config: "VllmConfig",
) -> BaseMultiModalProcessorCache | None:
    """Build the processor-side cache matching the resolved cache type.

    Returns `None` when no cache type applies.

    Raises:
        ValueError: If the resolved cache type is not recognised.
    """
    kind = self._get_cache_type(vllm_config)

    if kind is None:
        return None
    if kind == "processor_only":
        return MultiModalProcessorOnlyCache(vllm_config.model_config)
    if kind == "lru":
        return MultiModalProcessorSenderCache(vllm_config.model_config)
    if kind == "shm":
        return ShmObjectStoreSenderCache(vllm_config)

    raise ValueError(f"Unknown cache type: {kind!r}")
def processor_only_cache_from_config(
    self,
    vllm_config: "VllmConfig",
) -> MultiModalProcessorOnlyCache | None:
    """Build a `MultiModalProcessorOnlyCache` unless caching is disabled.

    Any non-`None` cache type yields a processor-only cache built from the
    model config; `None` is returned when no cache type applies.
    """
    if self._get_cache_type(vllm_config) is not None:
        return MultiModalProcessorOnlyCache(vllm_config.model_config)

    return None
def engine_receiver_cache_from_config(
    self,
    vllm_config: "VllmConfig",
) -> BaseMultiModalReceiverCache | None:
    """Build the receiver cache for the engine process.

    Only the `"lru"` cache type produces a `MultiModalReceiverCache`; the
    other known types (and `None`) produce no engine-side cache.

    Raises:
        ValueError: If the resolved cache type is not recognised.
    """
    kind = self._get_cache_type(vllm_config)

    if kind == "lru":
        return MultiModalReceiverCache(vllm_config.model_config)
    if kind in (None, "processor_only", "shm"):
        return None

    raise ValueError(f"Unknown cache type: {kind!r}")
def worker_receiver_cache_from_config(
    self,
    vllm_config: "VllmConfig",
    shared_worker_lock: LockType,
) -> BaseMultiModalReceiverCache | None:
    """Build the receiver cache for the worker process.

    Only the `"shm"` cache type produces a `ShmObjectStoreReceiverCache`
    (constructed with `shared_worker_lock`); the other known types (and
    `None`) produce no worker-side cache.

    Raises:
        ValueError: If the resolved cache type is not recognised.
    """
    kind = self._get_cache_type(vllm_config)

    if kind == "shm":
        return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
    if kind in (None, "processor_only", "lru"):
        return None

    raise ValueError(f"Unknown cache type: {kind!r}")
...@@ -23,7 +23,6 @@ from vllm.logger import init_logger ...@@ -23,7 +23,6 @@ from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import engine_receiver_cache_from_config
from vllm.tasks import POOLING_TASKS, SupportedTask from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
from vllm.utils.gc_utils import ( from vllm.utils.gc_utils import (
...@@ -149,8 +148,8 @@ class EngineCore: ...@@ -149,8 +148,8 @@ class EngineCore:
self.model_executor.init_kv_output_aggregator(self.scheduler.connector) # type: ignore self.model_executor.init_kv_output_aggregator(self.scheduler.connector) # type: ignore
self.mm_registry = mm_registry = MULTIMODAL_REGISTRY self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
self.mm_receiver_cache = engine_receiver_cache_from_config( self.mm_receiver_cache = mm_registry.engine_receiver_cache_from_config(
vllm_config, mm_registry vllm_config
) )
# If a KV connector is initialized for scheduler, we want to collect # If a KV connector is initialized for scheduler, we want to collect
......
...@@ -14,7 +14,6 @@ from vllm.inputs.preprocess import InputPreprocessor ...@@ -14,7 +14,6 @@ from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
from vllm.multimodal.parse import MultiModalDataParser from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing.context import set_request_id from vllm.multimodal.processing.context import set_request_id
...@@ -58,7 +57,7 @@ class InputProcessor: ...@@ -58,7 +57,7 @@ class InputProcessor:
self.generation_config_fields = self.model_config.try_get_generation_config() self.generation_config_fields = self.model_config.try_get_generation_config()
self.mm_registry = mm_registry self.mm_registry = mm_registry
self.mm_processor_cache = processor_cache_from_config(vllm_config, mm_registry) self.mm_processor_cache = mm_registry.processor_cache_from_config(vllm_config)
self.input_preprocessor = InputPreprocessor( self.input_preprocessor = InputPreprocessor(
self.model_config, self.model_config,
......
...@@ -623,11 +623,7 @@ class GPUModelRunner( ...@@ -623,11 +623,7 @@ class GPUModelRunner(
self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config) self.cudagraph_dispatcher = CudagraphDispatcher(self.vllm_config)
self.mm_budget = ( self.mm_budget = (
MultiModalBudget( MultiModalBudget(self.vllm_config, self.mm_registry)
self.model_config,
self.scheduler_config,
self.mm_registry,
)
if self.supports_mm_inputs if self.supports_mm_inputs
else None else None
) )
......
...@@ -8,11 +8,10 @@ import torch ...@@ -8,11 +8,10 @@ import torch
from typing_extensions import deprecated from typing_extensions import deprecated
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MultiModalEmbeddings from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.utils import extract_layer_index from vllm.model_executor.models.utils import extract_layer_index
from vllm.multimodal.cache import processor_only_cache_from_config
from vllm.multimodal.registry import MultiModalRegistry from vllm.multimodal.registry import MultiModalRegistry
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.mem_utils import MemorySnapshot, format_gib from vllm.utils.mem_utils import MemorySnapshot, format_gib
...@@ -28,16 +27,15 @@ class MultiModalBudget: ...@@ -28,16 +27,15 @@ class MultiModalBudget:
def __init__( def __init__(
self, self,
model_config: ModelConfig, vllm_config: VllmConfig,
scheduler_config: SchedulerConfig,
mm_registry: MultiModalRegistry, mm_registry: MultiModalRegistry,
) -> None: ) -> None:
super().__init__() super().__init__()
self.model_config = model_config self.model_config = model_config = vllm_config.model_config
self.scheduler_config = scheduler_config self.scheduler_config = scheduler_config = vllm_config.scheduler_config
self.mm_registry = mm_registry self.mm_registry = mm_registry
self.cache = cache = processor_only_cache_from_config(model_config, mm_registry) self.cache = cache = mm_registry.processor_only_cache_from_config(vllm_config)
self.max_model_len = model_config.max_model_len self.max_model_len = model_config.max_model_len
self.max_num_reqs = scheduler_config.max_num_seqs self.max_num_reqs = scheduler_config.max_num_seqs
......
...@@ -12,7 +12,6 @@ from vllm.config import VllmConfig, set_current_vllm_config ...@@ -12,7 +12,6 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils.import_utils import resolve_obj_by_qualname from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.system_utils import update_environment_variables from vllm.utils.system_utils import update_environment_variables
from vllm.v1.kv_cache_interface import KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheSpec
...@@ -303,11 +302,12 @@ class WorkerWrapperBase: ...@@ -303,11 +302,12 @@ class WorkerWrapperBase:
self.mm_receiver_cache = None self.mm_receiver_cache = None
else: else:
self.mm_receiver_cache = worker_receiver_cache_from_config( self.mm_receiver_cache = (
MULTIMODAL_REGISTRY.worker_receiver_cache_from_config(
vllm_config, vllm_config,
MULTIMODAL_REGISTRY,
shared_worker_lock, shared_worker_lock,
) )
)
with set_current_vllm_config(self.vllm_config): with set_current_vllm_config(self.vllm_config):
# To make vLLM config available during worker initialization # To make vLLM config available during worker initialization
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment