Unverified commit d3235cb5, authored by prashanth058, committed by GitHub
Browse files

[Fix] Enable mm_processor_cache with vision LoRA (#31927)


Signed-off-by: prashanth058 <prashanth.dannamaneni@uipath.com>
parent 287b37cd
...@@ -24,10 +24,12 @@ from vllm.multimodal.cache import ( ...@@ -24,10 +24,12 @@ from vllm.multimodal.cache import (
) )
from vllm.multimodal.hasher import MultiModalHasher from vllm.multimodal.hasher import MultiModalHasher
from vllm.multimodal.inputs import ( from vllm.multimodal.inputs import (
MultiModalFeatureSpec,
MultiModalFieldElem, MultiModalFieldElem,
MultiModalKwargsItem, MultiModalKwargsItem,
MultiModalKwargsItems, MultiModalKwargsItems,
MultiModalSharedField, MultiModalSharedField,
PlaceholderRange,
) )
from vllm.multimodal.processing import PromptInsertion from vllm.multimodal.processing import PromptInsertion
from vllm.utils.mem_constants import GiB_bytes, MiB_bytes from vllm.utils.mem_constants import GiB_bytes, MiB_bytes
...@@ -518,3 +520,40 @@ def test_cache_eviction_shm_cache(): ...@@ -518,3 +520,40 @@ def test_cache_eviction_shm_cache():
receiver_cache = ShmObjectStoreReceiverCache(vllm_config, mp.Lock()) receiver_cache = ShmObjectStoreReceiverCache(vllm_config, mp.Lock())
_run_test_cache_eviction_shm(sender_cache, receiver_cache, base_item_size=MiB_bytes) _run_test_cache_eviction_shm(sender_cache, receiver_cache, base_item_size=MiB_bytes)
def test_processor_cache_shared_across_loras():
    """Receiver cache entries are stored under ``mm_hash``, not the identifier.

    A feature that carries data is pushed through the receiver cache using one
    LoRA-prefixed identifier. A second feature with a different identifier,
    the same ``mm_hash`` and ``data=None`` is then pushed through the same
    cache and is expected to come back with the first feature's data.
    """
    cache = MultiModalReceiverCache(
        ModelConfig(
            model="llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
            mm_processor_cache_gb=1,
        )
    )

    shared_hash = "image_hash_abc123"
    payload = MultiModalKwargsItem.dummy("test_image", nbytes=1024)

    # First request: carries the actual item and should populate the cache
    # under the un-prefixed hash.
    populating_feature = MultiModalFeatureSpec(
        data=payload,
        modality="image",
        identifier=f"12345:{shared_hash}",
        mm_position=PlaceholderRange(offset=0, length=100),
        mm_hash=shared_hash,
    )
    cache.get_and_update_features([populating_feature])
    assert shared_hash in cache._cache

    # Second request: different identifier prefix, no data attached; the data
    # must be filled in from the entry stored by the first request.
    lookup_feature = MultiModalFeatureSpec(
        data=None,
        modality="image",
        identifier=f"67890:{shared_hash}",
        mm_position=PlaceholderRange(offset=0, length=100),
        mm_hash=shared_hash,
    )
    cache.get_and_update_features([lookup_feature])
    assert lookup_feature.data == payload
...@@ -1649,19 +1649,6 @@ class EngineArgs: ...@@ -1649,19 +1649,6 @@ class EngineArgs:
else None else None
) )
if (
lora_config is not None
and lora_config.enable_tower_connector_lora
and self.mm_processor_cache_gb != 0
):
raise ValueError(
"Currently, enable_tower_connector_lora is "
"incompatible with the multi-modal processor cache. "
"When enable_tower_connector_lora is set, "
"mm_processor_cache_gb must be 0, got %s",
self.mm_processor_cache_gb,
)
if ( if (
lora_config is not None lora_config is not None
and speculative_config is not None and speculative_config is not None
......
...@@ -635,12 +635,17 @@ class BaseMultiModalReceiverCache( ...@@ -635,12 +635,17 @@ class BaseMultiModalReceiverCache(
Update multimodal features with cached encoder outputs. Update multimodal features with cached encoder outputs.
Touch all identifier at first before update to avoid Touch all identifier at first before update to avoid
item in updated list evict during update. item in updated list evict during update.
Uses mm_hash for cache key to share across LoRAs (falls back to
identifier for backward compatibility).
""" """
for feature in mm_features: for feature in mm_features:
self.touch_receiver_cache_item(feature.identifier, feature.data) cache_key = feature.mm_hash or feature.identifier
self.touch_receiver_cache_item(cache_key, feature.data)
for feature in mm_features: for feature in mm_features:
feature.data = self.get_and_update_item(feature.data, feature.identifier) cache_key = feature.mm_hash or feature.identifier
feature.data = self.get_and_update_item(feature.data, cache_key)
return mm_features return mm_features
@abstractmethod @abstractmethod
......
...@@ -330,6 +330,9 @@ class MultiModalFeatureSpec: ...@@ -330,6 +330,9 @@ class MultiModalFeatureSpec:
mm_position: PlaceholderRange mm_position: PlaceholderRange
"""e.g., PlaceholderRange(offset=2, length=336)""" """e.g., PlaceholderRange(offset=2, length=336)"""
mm_hash: str | None = None
"""Base mm_hash for processor cache (without LoRA prefix)."""
@staticmethod @staticmethod
def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]): def gather_kwargs(features: list["MultiModalFeatureSpec"], keys: set[str]):
kwargs = defaultdict[str, list[NestedTensors]](list) kwargs = defaultdict[str, list[NestedTensors]](list)
......
...@@ -562,15 +562,17 @@ class InputProcessor: ...@@ -562,15 +562,17 @@ class InputProcessor:
mm_features = [] mm_features = []
for modality, idx in sorted_mm_idxs: for modality, idx in sorted_mm_idxs:
base_mm_hash = decoder_mm_hashes[modality][idx]
mm_features.append( mm_features.append(
MultiModalFeatureSpec( MultiModalFeatureSpec(
data=decoder_mm_inputs[modality][idx], data=decoder_mm_inputs[modality][idx],
modality=modality, modality=modality,
identifier=self._get_mm_identifier( identifier=self._get_mm_identifier(
decoder_mm_hashes[modality][idx], base_mm_hash,
lora_request, lora_request,
), ),
mm_position=decoder_mm_positions[modality][idx], mm_position=decoder_mm_positions[modality][idx],
mm_hash=base_mm_hash,
) )
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment