Unverified commit 4ad521d8, authored by Nick Hill, committed by GitHub
Browse files

[Core] Add generic typing to `LRUCache` (#3511)

parent 9474e89b
...@@ -4,7 +4,7 @@ import logging ...@@ -4,7 +4,7 @@ import logging
import math import math
import os import os
import re import re
from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type) from typing import (Callable, Dict, Hashable, List, Optional, Tuple, Type)
import safetensors.torch import safetensors.torch
import torch import torch
...@@ -535,14 +535,14 @@ class LoRAModelManager: ...@@ -535,14 +535,14 @@ class LoRAModelManager:
replacement_loras) replacement_loras)
class LoRALRUCache(LRUCache): class LoRALRUCache(LRUCache[LoRAModel]):
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable], def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
None]): None]):
super().__init__(capacity) super().__init__(capacity)
self.deactivate_lora_fn = deactivate_lora_fn self.deactivate_lora_fn = deactivate_lora_fn
def _on_remove(self, key: Hashable, value: Any): def _on_remove(self, key: Hashable, value: LoRAModel):
logger.debug(f"Removing LoRA. int id: {key}") logger.debug(f"Removing LoRA. int id: {key}")
self.deactivate_lora_fn(key) self.deactivate_lora_fn(key)
return super()._on_remove(key, value) return super()._on_remove(key, value)
......
...@@ -22,27 +22,34 @@ class BaseTokenizerGroup(ABC): ...@@ -22,27 +22,34 @@ class BaseTokenizerGroup(ABC):
pass pass
@abstractmethod @abstractmethod
def encode(self, prompt: str, request_id: Optional[str], def encode(self,
lora_request: Optional[LoRARequest]) -> List[int]: prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.""" """Encode a prompt using the tokenizer group."""
pass pass
@abstractmethod @abstractmethod
async def encode_async(self, prompt: str, request_id: Optional[str], async def encode_async(
lora_request: Optional[LoRARequest]) -> List[int]: self,
prompt: str,
request_id: Optional[str] = None,
lora_request: Optional[LoRARequest] = None) -> List[int]:
"""Encode a prompt using the tokenizer group.""" """Encode a prompt using the tokenizer group."""
pass pass
@abstractmethod @abstractmethod
def get_lora_tokenizer( def get_lora_tokenizer(
self, self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
"""Get a tokenizer for a LoRA request.""" """Get a tokenizer for a LoRA request."""
pass pass
@abstractmethod @abstractmethod
async def get_lora_tokenizer_async( async def get_lora_tokenizer_async(
self, self,
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": lora_request: Optional[LoRARequest] = None
) -> "PreTrainedTokenizer":
"""Get a tokenizer for a LoRA request.""" """Get a tokenizer for a LoRA request."""
pass pass
...@@ -21,10 +21,8 @@ class TokenizerGroup(BaseTokenizerGroup): ...@@ -21,10 +21,8 @@ class TokenizerGroup(BaseTokenizerGroup):
self.enable_lora = enable_lora self.enable_lora = enable_lora
self.max_input_length = max_input_length self.max_input_length = max_input_length
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
if enable_lora: self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
self.lora_tokenizers = LRUCache(capacity=max_num_seqs) capacity=max_num_seqs) if enable_lora else None
else:
self.lora_tokenizers = None
def ping(self) -> bool: def ping(self) -> bool:
"""Check if the tokenizer group is alive.""" """Check if the tokenizer group is alive."""
......
...@@ -5,7 +5,7 @@ import subprocess ...@@ -5,7 +5,7 @@ import subprocess
import uuid import uuid
import gc import gc
from platform import uname from platform import uname
from typing import List, Tuple, Union from typing import List, Tuple, Union, Generic
from packaging.version import parse, Version from packaging.version import parse, Version
import psutil import psutil
...@@ -53,10 +53,10 @@ class Counter: ...@@ -53,10 +53,10 @@ class Counter:
self.counter = 0 self.counter = 0
class LRUCache: class LRUCache(Generic[T]):
def __init__(self, capacity: int): def __init__(self, capacity: int):
self.cache = OrderedDict() self.cache = OrderedDict[Hashable, T]()
self.capacity = capacity self.capacity = capacity
def __contains__(self, key: Hashable) -> bool: def __contains__(self, key: Hashable) -> bool:
...@@ -65,10 +65,10 @@ class LRUCache: ...@@ -65,10 +65,10 @@ class LRUCache:
def __len__(self) -> int: def __len__(self) -> int:
return len(self.cache) return len(self.cache)
def __getitem__(self, key: Hashable) -> Any: def __getitem__(self, key: Hashable) -> T:
return self.get(key) return self.get(key)
def __setitem__(self, key: Hashable, value: Any) -> None: def __setitem__(self, key: Hashable, value: T) -> None:
self.put(key, value) self.put(key, value)
def __delitem__(self, key: Hashable) -> None: def __delitem__(self, key: Hashable) -> None:
...@@ -77,7 +77,9 @@ class LRUCache: ...@@ -77,7 +77,9 @@ class LRUCache:
def touch(self, key: Hashable) -> None: def touch(self, key: Hashable) -> None:
self.cache.move_to_end(key) self.cache.move_to_end(key)
def get(self, key: Hashable, default_value: Optional[Any] = None) -> int: def get(self,
key: Hashable,
default_value: Optional[T] = None) -> Optional[T]:
if key in self.cache: if key in self.cache:
value = self.cache[key] value = self.cache[key]
self.cache.move_to_end(key) self.cache.move_to_end(key)
...@@ -85,12 +87,12 @@ class LRUCache: ...@@ -85,12 +87,12 @@ class LRUCache:
value = default_value value = default_value
return value return value
def put(self, key: Hashable, value: Any) -> None: def put(self, key: Hashable, value: T) -> None:
self.cache[key] = value self.cache[key] = value
self.cache.move_to_end(key) self.cache.move_to_end(key)
self._remove_old_if_needed() self._remove_old_if_needed()
def _on_remove(self, key: Hashable, value: Any): def _on_remove(self, key: Hashable, value: T):
pass pass
def remove_oldest(self): def remove_oldest(self):
...@@ -103,7 +105,7 @@ class LRUCache: ...@@ -103,7 +105,7 @@ class LRUCache:
while len(self.cache) > self.capacity: while len(self.cache) > self.capacity:
self.remove_oldest() self.remove_oldest()
def pop(self, key: int, default_value: Optional[Any] = None) -> Any: def pop(self, key: Hashable, default_value: Optional[Any] = None) -> T:
run_on_remove = key in self.cache run_on_remove = key in self.cache
value = self.cache.pop(key, default_value) value = self.cache.pop(key, default_value)
if run_on_remove: if run_on_remove:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment