Unverified commit f4c98b4d, authored by Bella kira and committed by GitHub
Browse files

[Misc] Consolidate LRUCache implementations (#15481)


Signed-off-by: Bella kira <2374035698@qq.com>
parent e1e0fd75
...@@ -12,7 +12,6 @@ from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol, ...@@ -12,7 +12,6 @@ from typing import (TYPE_CHECKING, Generic, NamedTuple, Optional, Protocol,
TypeVar, Union, cast) TypeVar, Union, cast)
import torch import torch
from cachetools import LRUCache
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin from transformers import BatchFeature, PretrainedConfig, ProcessorMixin
from typing_extensions import assert_never from typing_extensions import assert_never
...@@ -21,7 +20,7 @@ from vllm.jsontree import json_map_leaves, json_reduce_leaves ...@@ -21,7 +20,7 @@ from vllm.jsontree import json_map_leaves, json_reduce_leaves
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens, from vllm.transformers_utils.tokenizer import (AnyTokenizer, decode_tokens,
encode_tokens) encode_tokens)
from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby from vllm.utils import GiB_bytes, LRUCache, flatten_2d_lists, full_groupby
from .hasher import MultiModalHasher from .hasher import MultiModalHasher
from .inputs import (MultiModalDataDict, MultiModalEncDecInputs, from .inputs import (MultiModalDataDict, MultiModalEncDecInputs,
......
...@@ -33,15 +33,17 @@ import uuid ...@@ -33,15 +33,17 @@ import uuid
import warnings import warnings
import weakref import weakref
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task from asyncio import FIRST_COMPLETED, AbstractEventLoop, Task
from collections import OrderedDict, UserDict, defaultdict from collections import UserDict, defaultdict
from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable, from collections.abc import (AsyncGenerator, Awaitable, Generator, Hashable,
Iterable, Iterator, Mapping) Iterable, Iterator, KeysView, Mapping)
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps from functools import cache, lru_cache, partial, wraps
from types import MappingProxyType
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, Type, TypeVar, Union) Optional, Type, TypeVar, Union, cast, overload)
from uuid import uuid4 from uuid import uuid4
import cachetools
import cloudpickle import cloudpickle
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
...@@ -173,6 +175,7 @@ U = TypeVar("U") ...@@ -173,6 +175,7 @@ U = TypeVar("U")
_K = TypeVar("_K", bound=Hashable) _K = TypeVar("_K", bound=Hashable)
_V = TypeVar("_V") _V = TypeVar("_V")
_T = TypeVar("_T")
class _Sentinel: class _Sentinel:
...@@ -206,6 +209,19 @@ class Counter: ...@@ -206,6 +209,19 @@ class Counter:
self.counter = 0 self.counter = 0
class _MappingOrderCacheView(UserDict[_K, _V]):
    """Mapping wrapper whose key iteration follows a separate key ordering.

    Lookups go through the ``UserDict`` storage built from ``data``, while
    ``__iter__`` and ``keys()`` walk ``ordered_keys`` instead.
    """

    def __init__(self, data: Mapping[_K, _V], ordered_keys: Mapping[_K, None]):
        """Store the ordering mapping and hand ``data`` to ``UserDict``.

        Args:
            data: Key/value pairs passed to the ``UserDict`` constructor.
            ordered_keys: Mapping whose key order drives iteration; it is
                kept by reference, not copied.
        """
        self.ordered_keys = ordered_keys
        super().__init__(data)

    def __iter__(self) -> Iterator[_K]:
        """Iterate over the keys in the order given by ``ordered_keys``."""
        key_order = self.ordered_keys
        return iter(key_order)

    def keys(self) -> KeysView[_K]:
        """Return a ``KeysView`` built over ``ordered_keys``."""
        key_order = self.ordered_keys
        return KeysView(key_order)
class CacheInfo(NamedTuple): class CacheInfo(NamedTuple):
hits: int hits: int
total: int total: int
...@@ -218,45 +234,62 @@ class CacheInfo(NamedTuple): ...@@ -218,45 +234,62 @@ class CacheInfo(NamedTuple):
return self.hits / self.total return self.hits / self.total
class LRUCache(Generic[_K, _V]): class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
"""Note: This class is not thread safe!"""
def __init__(self, capacity: int) -> None: def __init__(self,
self.cache = OrderedDict[_K, _V]() capacity: float,
getsizeof: Optional[Callable[[_V], float]] = None):
super().__init__(capacity, getsizeof)
self.pinned_items = set[_K]() self.pinned_items = set[_K]()
self.capacity = capacity self.capacity = capacity
self._hits = 0 self._hits = 0
self._total = 0 self._total = 0
def __contains__(self, key: _K) -> bool: def __delitem__(self, key: _K) -> None:
return key in self.cache run_on_remove = key in self
value = self.__getitem__(key)
def __len__(self) -> int: super().__delitem__(key)
return len(self.cache) if key in self.pinned_items:
# Todo: add warning to inform that del pinned item
def __getitem__(self, key: _K) -> _V: self._unpin(key)
value = self.cache[key] # Raise KeyError if not exists if run_on_remove:
self.cache.move_to_end(key) self._on_remove(key, value)
return value
def __setitem__(self, key: _K, value: _V) -> None: @property
self.put(key, value) def cache(self) -> Mapping[_K, _V]:
"""Return the internal cache dictionary in order (read-only)."""
return _MappingOrderCacheView(
self._Cache__data, # type: ignore
self.order)
def __delitem__(self, key: _K) -> None: @property
self.pop(key) def order(self) -> Mapping[_K, None]:
"""Return the internal order dictionary (read-only)."""
return MappingProxyType(self._LRUCache__order) # type: ignore
def stat(self) -> CacheInfo: def stat(self) -> CacheInfo:
return CacheInfo(hits=self._hits, total=self._total) return CacheInfo(hits=self._hits, total=self._total)
def touch(self, key: _K) -> None: def touch(self, key: _K) -> None:
self.cache.move_to_end(key) self._LRUCache__update(key) # type: ignore
@overload
def get(self, key: _K, /) -> Optional[_V]:
...
@overload
def get(self, key: _K, /, default: Union[_V, _T]) -> Union[_V, _T]:
...
def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]: def get(self,
value: Optional[_V] key: _K,
if key in self.cache: /,
value = self.cache[key] default: Optional[Union[_V,
self.cache.move_to_end(key) _T]] = None) -> Optional[Union[_V, _T]]:
value: Optional[Union[_V, _T]]
if key in self:
value = self.__getitem__(key)
self._hits += 1 self._hits += 1
else: else:
...@@ -265,60 +298,76 @@ class LRUCache(Generic[_K, _V]): ...@@ -265,60 +298,76 @@ class LRUCache(Generic[_K, _V]):
self._total += 1 self._total += 1
return value return value
@overload
def pop(self, key: _K) -> _V:
...
@overload
def pop(self, key: _K, default: Union[_V, _T]) -> Union[_V, _T]:
...
def pop(self,
key: _K,
default: Optional[Union[_V,
_T]] = None) -> Optional[Union[_V, _T]]:
value: Optional[Union[_V, _T]]
if key not in self:
return default
value = self[key]
del self[key]
return value
def put(self, key: _K, value: _V) -> None: def put(self, key: _K, value: _V) -> None:
self.cache[key] = value self.__setitem__(key, value)
self.cache.move_to_end(key)
self._remove_old_if_needed()
def pin(self, key: _K) -> None: def pin(self, key: _K) -> None:
""" """
Pins a key in the cache preventing it from being Pins a key in the cache preventing it from being
evicted in the LRU order. evicted in the LRU order.
""" """
if key not in self.cache: if key not in self:
raise ValueError(f"Cannot pin key: {key} not in cache.") raise ValueError(f"Cannot pin key: {key} not in cache.")
self.pinned_items.add(key) self.pinned_items.add(key)
def _unpin(self, key: _K) -> None: def _unpin(self, key: _K) -> None:
"""
Unpins a key in the cache allowing it to be
evicted in the LRU order.
"""
self.pinned_items.remove(key) self.pinned_items.remove(key)
def _on_remove(self, key: _K, value: Optional[_V]) -> None: def _on_remove(self, key: _K, value: Optional[_V]) -> None:
pass pass
def remove_oldest(self, *, remove_pinned: bool = False) -> None: def remove_oldest(self, *, remove_pinned: bool = False) -> None:
if not self.cache: if len(self) == 0:
return return
self.popitem(remove_pinned=remove_pinned)
def _remove_old_if_needed(self) -> None:
while self.currsize > self.capacity:
self.remove_oldest()
def clear(self) -> None:
while len(self) > 0:
self.remove_oldest(remove_pinned=True)
def popitem(self, remove_pinned: bool = False):
"""Remove and return the `(key, value)` pair least recently used."""
if not remove_pinned: if not remove_pinned:
# pop the oldest item in the cache that is not pinned # pop the oldest item in the cache that is not pinned
lru_key = next( lru_key = next(
(key for key in self.cache if key not in self.pinned_items), (key for key in self.order if key not in self.pinned_items),
ALL_PINNED_SENTINEL) ALL_PINNED_SENTINEL)
if lru_key is ALL_PINNED_SENTINEL: if lru_key is ALL_PINNED_SENTINEL:
raise RuntimeError("All items are pinned, " raise RuntimeError("All items are pinned, "
"cannot remove oldest from the cache.") "cannot remove oldest from the cache.")
else: else:
lru_key = next(iter(self.cache)) lru_key = next(iter(self.order))
self.pop(lru_key) # type: ignore value = self.pop(cast(_K, lru_key))
return (lru_key, value)
def _remove_old_if_needed(self) -> None:
while len(self.cache) > self.capacity:
self.remove_oldest()
def pop(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
run_on_remove = key in self.cache
value = self.cache.pop(key, default)
# remove from pinned items
if key in self.pinned_items:
self._unpin(key)
if run_on_remove:
self._on_remove(key, value)
return value
def clear(self) -> None:
while len(self.cache) > 0:
self.remove_oldest(remove_pinned=True)
self.cache.clear()
class PyObjectCache: class PyObjectCache:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment