Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cdf22afd
Unverified
Commit
cdf22afd
authored
Dec 20, 2024
by
Cyrus Leung
Committed by
GitHub
Dec 20, 2024
Browse files
[Misc] Clean up and consolidate LRUCache (#11339)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
e24113a8
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
34 additions
and
67 deletions
+34
-67
vllm/adapter_commons/models.py
vllm/adapter_commons/models.py
+4
-5
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
+1
-1
vllm/utils.py
vllm/utils.py
+26
-33
vllm/v1/engine/mm_input_mapper.py
vllm/v1/engine/mm_input_mapper.py
+3
-3
vllm/v1/utils.py
vllm/v1/utils.py
+0
-25
No files found.
vllm/adapter_commons/models.py
View file @
cdf22afd
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Callable
,
Dict
,
Hashable
,
Optional
,
TypeVar
from
typing
import
Any
,
Callable
,
Dict
,
Optional
,
TypeVar
from
torch
import
nn
...
...
@@ -24,14 +24,13 @@ class AdapterModel(ABC):
T
=
TypeVar
(
'T'
)
class
AdapterLRUCache
(
LRUCache
[
T
]):
class
AdapterLRUCache
(
LRUCache
[
int
,
T
]):
def
__init__
(
self
,
capacity
:
int
,
deactivate_fn
:
Callable
[[
Hashable
],
None
]):
def
__init__
(
self
,
capacity
:
int
,
deactivate_fn
:
Callable
[[
int
],
object
]):
super
().
__init__
(
capacity
)
self
.
deactivate_fn
=
deactivate_fn
def
_on_remove
(
self
,
key
:
Hashable
,
value
:
Optional
[
T
]):
def
_on_remove
(
self
,
key
:
int
,
value
:
Optional
[
T
]):
logger
.
debug
(
"Removing adapter int id: %d"
,
key
)
self
.
deactivate_fn
(
key
)
return
super
().
_on_remove
(
key
,
value
)
...
...
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
View file @
cdf22afd
...
...
@@ -22,7 +22,7 @@ class TokenizerGroup(BaseTokenizerGroup):
self
.
max_input_length
=
max_input_length
self
.
tokenizer
=
get_tokenizer
(
self
.
tokenizer_id
,
**
tokenizer_config
)
max_loras
=
tokenizer_config
.
get
(
"max_loras"
,
0
)
self
.
lora_tokenizers
=
LRUCache
[
AnyTokenizer
](
self
.
lora_tokenizers
=
LRUCache
[
int
,
AnyTokenizer
](
capacity
=
max
(
max_loras
,
max_num_seqs
)
if
enable_lora
else
0
)
@
classmethod
...
...
vllm/utils.py
View file @
cdf22afd
...
...
@@ -21,14 +21,13 @@ import uuid
import
warnings
import
weakref
from
asyncio
import
FIRST_COMPLETED
,
AbstractEventLoop
,
Task
from
collections
import
UserDict
,
defaultdict
from
collections
import
OrderedDict
,
UserDict
,
defaultdict
from
collections.abc
import
Iterable
,
Mapping
from
dataclasses
import
dataclass
,
field
from
functools
import
lru_cache
,
partial
,
wraps
from
typing
import
(
TYPE_CHECKING
,
Any
,
AsyncGenerator
,
Awaitable
,
Callable
,
Dict
,
Generator
,
Generic
,
Hashable
,
List
,
Literal
,
Optional
,
OrderedDict
,
Set
,
Tuple
,
Type
,
TypeVar
,
Union
,
overload
)
Optional
,
Tuple
,
Type
,
TypeVar
,
Union
,
overload
)
from
uuid
import
uuid4
import
numpy
as
np
...
...
@@ -154,10 +153,12 @@ TORCH_DTYPE_TO_NUMPY_DTYPE = {
}
P
=
ParamSpec
(
'P'
)
K
=
TypeVar
(
"K"
)
T
=
TypeVar
(
"T"
)
U
=
TypeVar
(
"U"
)
_K
=
TypeVar
(
"_K"
,
bound
=
Hashable
)
_V
=
TypeVar
(
"_V"
)
class
_Sentinel
:
...
...
...
@@ -190,50 +191,48 @@ class Counter:
self
.
counter
=
0
class
LRUCache
(
Generic
[
T
]):
class
LRUCache
(
Generic
[
_K
,
_V
]):
def
__init__
(
self
,
capacity
:
int
):
self
.
cache
:
OrderedDict
[
Hashable
,
T
]
=
OrderedDict
()
self
.
pinned_items
:
Set
[
Hashable
]
=
set
()
def
__init__
(
self
,
capacity
:
int
)
->
None
:
self
.
cache
=
OrderedDict
[
_K
,
_V
]
()
self
.
pinned_items
=
set
[
_K
]
()
self
.
capacity
=
capacity
def
__contains__
(
self
,
key
:
Hashable
)
->
bool
:
def
__contains__
(
self
,
key
:
_K
)
->
bool
:
return
key
in
self
.
cache
def
__len__
(
self
)
->
int
:
return
len
(
self
.
cache
)
def
__getitem__
(
self
,
key
:
Hashable
)
->
T
:
def
__getitem__
(
self
,
key
:
_K
)
->
_V
:
value
=
self
.
cache
[
key
]
# Raise KeyError if not exists
self
.
cache
.
move_to_end
(
key
)
return
value
def
__setitem__
(
self
,
key
:
Hashable
,
value
:
T
)
->
None
:
def
__setitem__
(
self
,
key
:
_K
,
value
:
_V
)
->
None
:
self
.
put
(
key
,
value
)
def
__delitem__
(
self
,
key
:
Hashable
)
->
None
:
def
__delitem__
(
self
,
key
:
_K
)
->
None
:
self
.
pop
(
key
)
def
touch
(
self
,
key
:
Hashable
)
->
None
:
def
touch
(
self
,
key
:
_K
)
->
None
:
self
.
cache
.
move_to_end
(
key
)
def
get
(
self
,
key
:
Hashable
,
default_value
:
Optional
[
T
]
=
None
)
->
Optional
[
T
]:
value
:
Optional
[
T
]
def
get
(
self
,
key
:
_K
,
default
:
Optional
[
_V
]
=
None
)
->
Optional
[
_V
]:
value
:
Optional
[
_V
]
if
key
in
self
.
cache
:
value
=
self
.
cache
[
key
]
self
.
cache
.
move_to_end
(
key
)
else
:
value
=
default
_value
value
=
default
return
value
def
put
(
self
,
key
:
Hashable
,
value
:
T
)
->
None
:
def
put
(
self
,
key
:
_K
,
value
:
_V
)
->
None
:
self
.
cache
[
key
]
=
value
self
.
cache
.
move_to_end
(
key
)
self
.
_remove_old_if_needed
()
def
pin
(
self
,
key
:
Hashable
)
->
None
:
def
pin
(
self
,
key
:
_K
)
->
None
:
"""
Pins a key in the cache preventing it from being
evicted in the LRU order.
...
...
@@ -242,13 +241,13 @@ class LRUCache(Generic[T]):
raise
ValueError
(
f
"Cannot pin key:
{
key
}
not in cache."
)
self
.
pinned_items
.
add
(
key
)
def
_unpin
(
self
,
key
:
Hashable
)
->
None
:
def
_unpin
(
self
,
key
:
_K
)
->
None
:
self
.
pinned_items
.
remove
(
key
)
def
_on_remove
(
self
,
key
:
Hashable
,
value
:
Optional
[
T
])
:
def
_on_remove
(
self
,
key
:
_K
,
value
:
Optional
[
_V
])
->
None
:
pass
def
remove_oldest
(
self
,
remove_pinned
=
False
)
:
def
remove_oldest
(
self
,
*
,
remove_pinned
:
bool
=
False
)
->
None
:
if
not
self
.
cache
:
return
...
...
@@ -262,17 +261,15 @@ class LRUCache(Generic[T]):
"cannot remove oldest from the cache."
)
else
:
lru_key
=
next
(
iter
(
self
.
cache
))
self
.
pop
(
lru_key
)
self
.
pop
(
lru_key
)
# type: ignore
def
_remove_old_if_needed
(
self
)
->
None
:
while
len
(
self
.
cache
)
>
self
.
capacity
:
self
.
remove_oldest
()
def
pop
(
self
,
key
:
Hashable
,
default_value
:
Optional
[
T
]
=
None
)
->
Optional
[
T
]:
def
pop
(
self
,
key
:
_K
,
default
:
Optional
[
_V
]
=
None
)
->
Optional
[
_V
]:
run_on_remove
=
key
in
self
.
cache
value
:
Optional
[
T
]
=
self
.
cache
.
pop
(
key
,
default
_value
)
value
=
self
.
cache
.
pop
(
key
,
default
)
# remove from pinned items
if
key
in
self
.
pinned_items
:
self
.
_unpin
(
key
)
...
...
@@ -280,7 +277,7 @@ class LRUCache(Generic[T]):
self
.
_on_remove
(
key
,
value
)
return
value
def
clear
(
self
):
def
clear
(
self
)
->
None
:
while
len
(
self
.
cache
)
>
0
:
self
.
remove_oldest
(
remove_pinned
=
True
)
self
.
cache
.
clear
()
...
...
@@ -843,10 +840,6 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]:
return
[
item
for
sublist
in
lists
for
item
in
sublist
]
_K
=
TypeVar
(
"_K"
,
bound
=
Hashable
)
_V
=
TypeVar
(
"_V"
)
def
full_groupby
(
values
:
Iterable
[
_V
],
*
,
key
:
Callable
[[
_V
],
_K
]):
"""
Unlike :class:`itertools.groupby`, groups are not broken by
...
...
vllm/v1/engine/mm_input_mapper.py
View file @
cdf22afd
...
...
@@ -8,7 +8,7 @@ from vllm.inputs import PromptType
from
vllm.logger
import
init_logger
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
MultiModalDataDict
,
MultiModalKwargs
,
MultiModalRegistry
)
from
vllm.
v1.
utils
import
LRU
Dict
Cache
from
vllm.utils
import
LRUCache
logger
=
init_logger
(
__name__
)
...
...
@@ -44,7 +44,7 @@ class MMInputMapperClient:
# Init cache
self
.
use_cache
=
not
model_config
.
disable_mm_preprocessor_cache
self
.
mm_cache
=
LRU
Dict
Cache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
self
.
mm_cache
=
LRUCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
# DEBUG: Set to None to disable
self
.
mm_debug_cache_hit_ratio_steps
=
None
...
...
@@ -120,7 +120,7 @@ class MMInputMapperServer:
def
__init__
(
self
,
model_config
):
self
.
use_cache
=
not
model_config
.
disable_mm_preprocessor_cache
self
.
mm_cache
=
LRU
Dict
Cache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
self
.
mm_cache
=
LRUCache
[
str
,
MultiModalKwargs
](
MM_CACHE_SIZE
)
def
process_inputs
(
self
,
...
...
vllm/v1/utils.py
View file @
cdf22afd
from
collections
import
OrderedDict
from
collections.abc
import
Sequence
from
contextlib
import
contextmanager
from
typing
import
(
Any
,
Generic
,
Iterator
,
List
,
Optional
,
TypeVar
,
Union
,
...
...
@@ -102,27 +101,3 @@ def make_zmq_socket(
finally
:
ctx
.
destroy
(
linger
=
0
)
K
=
TypeVar
(
'K'
)
V
=
TypeVar
(
'V'
)
class
LRUDictCache
(
Generic
[
K
,
V
]):
def
__init__
(
self
,
size
:
int
):
self
.
cache
:
OrderedDict
[
K
,
V
]
=
OrderedDict
()
self
.
size
=
size
def
get
(
self
,
key
:
K
,
default
=
None
)
->
V
:
if
key
not
in
self
.
cache
:
return
default
self
.
cache
.
move_to_end
(
key
)
return
self
.
cache
[
key
]
def
put
(
self
,
key
:
K
,
value
:
V
):
self
.
cache
[
key
]
=
value
self
.
cache
.
move_to_end
(
key
)
if
len
(
self
.
cache
)
>
self
.
size
:
self
.
cache
.
popitem
(
last
=
False
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment