test_cache.py 1.72 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
7
8
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
                                    MultiModalKwargsItems,
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
                                    MultiModalSharedField)


def _dummy_elem(modality: str, key: str, size: int):
    """Create a field element backed by an int8 tensor of ``size`` entries.

    The tensor is allocated with ``torch.empty``, so its contents are
    uninitialized; only its length is meaningful to callers in this file.
    """
    payload = torch.empty((size, ), dtype=torch.int8)
    shared_field = MultiModalSharedField(1)
    return MultiModalFieldElem(
        modality=modality,
        key=key,
        data=payload,
        field=shared_field,
    )


def _dummy_item(modality: str, size_by_key: dict[str, int]):
    """Create a kwargs item for ``modality`` with one dummy element per key.

    ``size_by_key`` maps each element key to the length of its dummy tensor.
    """
    elems = []
    for elem_key, elem_size in size_by_key.items():
        elems.append(_dummy_elem(modality, elem_key, elem_size))
    return MultiModalKwargsItem.from_elems(elems)


27
28
def _dummy_items(size_by_key_modality: dict[str, dict[str, int]]):
    """Create a ``MultiModalKwargsItems`` with one dummy item per modality.

    ``size_by_key_modality`` maps each modality name to a ``{key: size}``
    dict, which is forwarded to ``_dummy_item`` for that modality.
    """
    return MultiModalKwargsItems.from_seq([
        _dummy_item(modality, size_by_key)
        for modality, size_by_key in size_by_key_modality.items()
    ])


# yapf: disable
@pytest.mark.parametrize(
    ("item", "expected_size"),
    [
        (_dummy_item("a", {"a1": 100}), 100),
        (_dummy_item("a", {"a1": 100, "a2": 110}), 210),
        (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460),  # noqa: E501
        (_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}).get_data(), 460),  # noqa: E501
    ],
)
# yapf: enable
def test_cache_item_size(item, expected_size):
    """Check the size the cache accounts for a stored multimodal item.

    Each case's ``expected_size`` equals the summed lengths of the int8
    dummy tensors inside ``item``. The same total must be reported both for
    the raw item and for its ``MultiModalCacheItemMetadata`` wrapper.
    """
    # NOTE(review): 2048 is presumably the cache capacity, chosen larger than
    # every expected size so nothing gets evicted — confirm against
    # MultiModalCache.get_lru_cache.
    cache = MultiModalCache.get_lru_cache(2048, type(item))

    # Store the item itself and check the accounted size.
    cache[""] = item
    assert cache.currsize == expected_size

    # Overwrite the same key with the metadata wrapper; the accounted size
    # must be unchanged.
    cache[""] = MultiModalCacheItemMetadata.wraps(item)
    assert cache.currsize == expected_size