cache.py 2.42 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import sys
from collections.abc import Mapping
from dataclasses import dataclass
from typing import TypeVar, Union

import torch

from vllm.jsontree import json_map_leaves, json_reduce_leaves
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, LRUCache

from .inputs import MultiModalKwargs, MultiModalKwargsItem, NestedTensors

# Module logger; MultiModalCache.get_item_size reports computed sizes through
# it (at debug level) when called with debug=True.
logger = init_logger(__name__)


@dataclass
class MultiModalCacheItemMetadata:
    """Metadata that records only the size of a multi-modal cache value.

    Use :meth:`wraps` to build one from an actual value; its ``size`` is then
    reported directly by ``MultiModalCache.get_leaf_size``.
    """

    # Size of the described value in bytes, as measured by
    # MultiModalCache.get_item_size.
    size: int

    @classmethod
    def wraps(cls, value: "MultiModalCacheValue"):
        """Measure ``value`` and return metadata holding its size."""
        measured = MultiModalCache.get_item_size(value)
        return cls(size=measured)


# Value types accepted by MultiModalCache.get_item_size, and the upper bound on
# what an LRU cache built by MultiModalCache.get_lru_cache may store.
MultiModalCacheValue = Union[
    MultiModalKwargs,
    MultiModalKwargsItem,
    Mapping[str, NestedTensors],
    MultiModalCacheItemMetadata,
]

# Type variable for the value type of a cache created by get_lru_cache.
_V = TypeVar("_V", bound=MultiModalCacheValue)


class MultiModalCache:
    """Size accounting for multi-modal cache values.

    Sizes are measured in bytes. They are used as the per-entry weight of the
    ``LRUCache`` built by :meth:`get_lru_cache`, whose capacity is given in
    GiB.
    """

    @classmethod
    def get_leaf_size(
        cls,
        leaf: object,
        *,
        debug: bool = False,
    ) -> int:
        """Return the size in bytes of a single leaf of a cache value.

        Args:
            leaf: A leaf visited by ``json_map_leaves``. ``MultiModalKwargs``
                and ``MultiModalKwargsItem`` are not dict subclasses, so they
                arrive here as leaves and are unpacked and sized recursively.
            debug: Forwarded to :meth:`get_item_size` for nested values.

        Returns:
            ``nbytes`` for tensors, the recorded ``size`` for
            :class:`MultiModalCacheItemMetadata`, and ``sys.getsizeof``
            (a shallow object size) for anything else.
        """
        # MultiModalKwargs is not a subclass of dict; size its data instead.
        if isinstance(leaf, MultiModalKwargs):
            return cls.get_item_size(leaf.data, debug=debug)

        # MultiModalKwargsItem is not a subclass of dict; size the data of
        # each of its elements.
        if isinstance(leaf, MultiModalKwargsItem):
            leaf_data = {k: v.data for k, v in leaf.items()}
            return cls.get_item_size(leaf_data, debug=debug)

        # sys.getsizeof doesn't work for tensors: it reports the Python
        # object, not the tensor's storage.
        if isinstance(leaf, torch.Tensor):
            return leaf.nbytes

        # Metadata already carries the size measured by its `wraps` call.
        if isinstance(leaf, MultiModalCacheItemMetadata):
            return leaf.size

        return sys.getsizeof(leaf)

    @classmethod
    def get_item_size(
        cls,
        value: MultiModalCacheValue,
        *,
        debug: bool = False,
    ) -> int:
        """Return the total size in bytes of a cache value.

        Each leaf of ``value`` (as visited by the JSON tree helpers) is sized
        with :meth:`get_leaf_size`, and the leaf sizes are summed. A value
        with no leaves at all (e.g. an empty mapping) has size 0.

        Args:
            value: The value to measure.
            debug: If True, log the computed size at debug level.

        Returns:
            The size of ``value`` in bytes.
        """
        leaf_sizes = json_map_leaves(
            lambda x: cls.get_leaf_size(x, debug=debug), value)

        # json_reduce_leaves is called without an initial value, so a tree
        # with no leaves (e.g. `{}`, which an empty MultiModalKwargs unpacks
        # to) would leave it nothing to reduce. Prepending a 0 leaf makes the
        # empty case evaluate to 0 and leaves every other total unchanged.
        size = json_reduce_leaves(lambda a, b: a + b, [0, leaf_sizes])

        if debug:
            logger.debug("Calculated size of %s to be %.2f GiB", type(value),
                         size / GiB_bytes)

        return size

    @classmethod
    def get_lru_cache(
        cls,
        capacity_gb: float,
        value_type: type[_V],
        *,
        debug: bool = False,
    ) -> LRUCache[str, _V]:
        """Create an ``LRUCache`` whose entries are weighed by their size.

        Args:
            capacity_gb: Cache capacity in GiB. It is converted to bytes via
                ``GiB_bytes`` before being passed to ``LRUCache``.
            value_type: Type of the cached values. It is not used at runtime;
                it only fixes the static value type of the returned cache.
            debug: Forwarded to :meth:`get_item_size`, which then logs each
                measured size.

        Returns:
            An ``LRUCache`` keyed by ``str`` that uses :meth:`get_item_size`
            as its ``getsizeof`` function.
        """
        return LRUCache(
            GiB_bytes * capacity_gb,
            getsizeof=lambda x: cls.get_item_size(x, debug=debug),
        )