Unverified Commit 4ad521d8 authored by Nick Hill's avatar Nick Hill Committed by GitHub
Browse files

[Core] Add generic typing to `LRUCache` (#3511)

parent 9474e89b
......@@ -4,7 +4,7 @@ import logging
import math
import os
import re
from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type)
from typing import (Callable, Dict, Hashable, List, Optional, Tuple, Type)
import safetensors.torch
import torch
......@@ -535,14 +535,14 @@ class LoRAModelManager:
replacement_loras)
class LoRALRUCache(LRUCache):
class LoRALRUCache(LRUCache[LoRAModel]):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
None]):
super().__init__(capacity)
self.deactivate_lora_fn = deactivate_lora_fn
def _on_remove(self, key: Hashable, value: Any):
def _on_remove(self, key: Hashable, value: LoRAModel):
logger.debug(f"Removing LoRA. int id: {key}")
self.deactivate_lora_fn(key)
return super()._on_remove(key, value)
......
......@@ -22,27 +22,34 @@ class BaseTokenizerGroup(ABC):
pass
@abstractmethod
def encode(self, prompt: str, request_id: Optional[str],
lora_request: Optional[LoRARequest]) -> List[int]:
def encode(self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
"""Encode a prompt using the tokenizer group."""
pass
@abstractmethod
async def encode_async(self, prompt: str, request_id: Optional[str],
lora_request: Optional[LoRARequest]) -> List[int]:
async def encode_async(
self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
"""Encode a prompt using the tokenizer group."""
pass
@abstractmethod
def get_lora_tokenizer(
self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
"""Get a tokenizer for a LoRA request."""
pass
@abstractmethod
async def get_lora_tokenizer_async(
self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
"""Get a tokenizer for a LoRA request."""
pass
......@@ -21,10 +21,8 @@ class TokenizerGroup(BaseTokenizerGroup):
self.enable_lora = enable_lora
self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
if enable_lora:
self.lora_tokenizers = LRUCache(capacity=max_num_seqs)
else:
self.lora_tokenizers = None
self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
capacity=max_num_seqs) if enable_lora else None
def ping(self) -> bool:
"""Check if the tokenizer group is alive."""
......
......@@ -5,7 +5,7 @@ import subprocess
import uuid
import gc
from platform import uname
from typing import List, Tuple, Union
from typing import List, Tuple, Union, Generic
from packaging.version import parse, Version
import psutil
......@@ -53,10 +53,10 @@ class Counter:
self.counter = 0
class LRUCache:
class LRUCache(Generic[T]):
def __init__(self, capacity: int):
self.cache = OrderedDict()
self.cache = OrderedDict[Hashable, T]()
self.capacity = capacity
def __contains__(self, key: Hashable) -> bool:
......@@ -65,10 +65,10 @@ class LRUCache:
def __len__(self) -> int:
return len(self.cache)
def __getitem__(self, key: Hashable) -> Any:
def __getitem__(self, key: Hashable) -> T:
return self.get(key)
def __setitem__(self, key: Hashable, value: Any) -> None:
def __setitem__(self, key: Hashable, value: T) -> None:
self.put(key, value)
def __delitem__(self, key: Hashable) -> None:
......@@ -77,7 +77,9 @@ class LRUCache:
def touch(self, key: Hashable) -> None:
self.cache.move_to_end(key)
def get(self, key: Hashable, default_value: Optional[Any] = None) -> int:
def get(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
if key in self.cache:
value = self.cache[key]
self.cache.move_to_end(key)
......@@ -85,12 +87,12 @@ class LRUCache:
value = default_value
return value
def put(self, key: Hashable, value: Any) -> None:
def put(self, key: Hashable, value: T) -> None:
self.cache[key] = value
self.cache.move_to_end(key)
self._remove_old_if_needed()
def _on_remove(self, key: Hashable, value: Any):
def _on_remove(self, key: Hashable, value: T):
pass
def remove_oldest(self):
......@@ -103,7 +105,7 @@ class LRUCache:
while len(self.cache) > self.capacity:
self.remove_oldest()
def pop(self, key: int, default_value: Optional[Any] = None) -> Any:
def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
run_on_remove = key in self.cache
value = self.cache.pop(key, default_value)
if run_on_remove:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment