Unverified Commit 3d446433 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Fix size calculation of processing cache (#15114)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 1fe0fd12
...@@ -7,15 +7,20 @@ from unittest.mock import MagicMock ...@@ -7,15 +7,20 @@ from unittest.mock import MagicMock
import numpy as np import numpy as np
import pytest import pytest
import torch
from transformers import ProcessorMixin from transformers import ProcessorMixin
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargs,
MultiModalKwargsItem,
MultiModalSharedField)
# yapf conflicts with isort for this block # yapf conflicts with isort for this block
# yapf: disable # yapf: disable
from vllm.multimodal.processing import (PlaceholderFeaturesInfo, from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
PromptIndexTargets, PromptInsertion, ProcessingCache, PromptIndexTargets,
PromptReplacement, apply_text_matches, PromptInsertion, PromptReplacement,
apply_text_matches,
apply_token_matches, apply_token_matches,
find_mm_placeholders, find_mm_placeholders,
find_text_matches, find_token_matches, find_text_matches, find_token_matches,
...@@ -890,6 +895,45 @@ def test_find_mm_placeholders( ...@@ -890,6 +895,45 @@ def test_find_mm_placeholders(
assert result == expected assert result == expected
def _dummy_elem(modality: str, key: str, size: int):
return MultiModalFieldElem(
modality=modality,
key=key,
data=torch.empty((size, ), dtype=torch.int8),
field=MultiModalSharedField(1),
)
def _dummy_item(modality: str, size_by_key: dict[str, int]):
return MultiModalKwargsItem.from_elems([
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
])
def _dummy_kw(size_by_key_modality: dict[str, dict[str, int]]):
return MultiModalKwargs.from_items([
_dummy_item(modality, size_by_key)
for modality, size_by_key in size_by_key_modality.items()
])
# yapf: disable
@pytest.mark.parametrize(
("item", "expected_size"),
[
(_dummy_item("a", {"a1": 100}), 100),
(_dummy_item("a", {"a1": 100, "a2": 110}), 210),
(_dummy_kw({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
],
)
# yapf: enable
def test_cache_item_size(item, expected_size):
cache = ProcessingCache.get_lru_cache(2048, type(item))
cache[""] = item
assert cache.currsize == expected_size
@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"]) @pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
("limit", "num_supported", "is_valid"), ("limit", "num_supported", "is_valid"),
......
...@@ -26,7 +26,7 @@ from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby ...@@ -26,7 +26,7 @@ from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
from .hasher import MultiModalHasher from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs, MultiModalFieldConfig, MultiModalInputs, MultiModalKwargs,
MultiModalKwargsItem, PlaceholderRange) MultiModalKwargsItem, NestedTensors, PlaceholderRange)
from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems, from .parse import (DictEmbeddingItems, EmbeddingItems, MultiModalDataItems,
MultiModalDataParser) MultiModalDataParser)
...@@ -853,33 +853,62 @@ class ProcessingCache: ...@@ -853,33 +853,62 @@ class ProcessingCache:
@staticmethod @staticmethod
def get_lru_cache( def get_lru_cache(
capacity_gb: int, capacity_gb: float,
value_type: type[_V], value_type: type[_V],
*,
debug: bool = False,
) -> LRUCache[str, _V]: ) -> LRUCache[str, _V]:
def get_size(leaf: object) -> int: def get_leaf_size(leaf: object) -> int:
# MultiModalKwargs is not a subclass of dict
if isinstance(leaf, MultiModalKwargs):
return get_item_size(leaf.data)
# MultiModalKwargsItem is not a subclass of dict
if isinstance(leaf, MultiModalKwargsItem):
leaf_data = {k: v.data for k, v in leaf.items()}
return get_item_size(leaf_data)
# sys.getsizeof doesn't work for tensors
if isinstance(leaf, torch.Tensor): if isinstance(leaf, torch.Tensor):
return leaf.nbytes # sys.getsizeof doesn't work for tensors return leaf.nbytes
return sys.getsizeof(leaf) return sys.getsizeof(leaf)
return LRUCache[str, _V]( def get_item_size(
GiB_bytes * capacity_gb, value: Union[MultiModalKwargs, MultiModalKwargsItem,
getsizeof=lambda x: json_reduce_leaves( Mapping[str, NestedTensors]]
) -> int:
size = json_reduce_leaves(
lambda a, b: a + b, lambda a, b: a + b,
json_map_leaves(get_size, x), json_map_leaves(get_leaf_size, value),
),
) )
def __init__(self, capacity_gb: int) -> None: if debug:
logger.debug("Calculated size of %s to be %.2f GiB",
type(value), size / GiB_bytes)
return size
return LRUCache(GiB_bytes * capacity_gb, getsizeof=get_item_size)
def __init__(
self,
capacity_gb: float,
*,
debug_cache_hit_ratio_steps: Optional[int] = None,
) -> None:
super().__init__() super().__init__()
# DEBUG: Set to None to disable self.debug_cache_hit_ratio_steps = debug_cache_hit_ratio_steps
self.debug_cache_hit_ratio_steps: Optional[int] = None
self.debug_cache_hits = 0 self.debug_cache_hits = 0
self.debug_cache_total = 0 self.debug_cache_total = 0
self._cache = self.get_lru_cache(capacity_gb, MultiModalKwargsItem) self._cache = self.get_lru_cache(
capacity_gb,
MultiModalKwargsItem,
debug=bool(debug_cache_hit_ratio_steps),
)
def _maybe_log_cache_stats(self) -> None: def _maybe_log_cache_stats(self) -> None:
steps = self.debug_cache_hit_ratio_steps steps = self.debug_cache_hit_ratio_steps
...@@ -890,6 +919,9 @@ class ProcessingCache: ...@@ -890,6 +919,9 @@ class ProcessingCache:
if total > 0 and total % steps == 0: if total > 0 and total % steps == 0:
logger.debug("ProcessingCache: hit_ratio = %.2f", logger.debug("ProcessingCache: hit_ratio = %.2f",
self.debug_cache_hits / total) self.debug_cache_hits / total)
logger.debug("ProcessingCache: size = %.2f / %.2f GiB",
self._cache.currsize / GiB_bytes,
self._cache.maxsize / GiB_bytes)
def get( def get(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment