Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aa29841e
Unverified
Commit
aa29841e
authored
Apr 15, 2025
by
Cyrus Leung
Committed by
GitHub
Apr 14, 2025
Browse files
[Bugfix] Multi-modal caches not acting like LRU caches (#16593)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
6bf27aff
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
187 additions
and
126 deletions
+187
-126
tests/lora/test_utils.py
tests/lora/test_utils.py
+0
-109
tests/test_utils.py
tests/test_utils.py
+128
-5
vllm/utils.py
vllm/utils.py
+58
-11
vllm/v1/engine/mm_input_cache.py
vllm/v1/engine/mm_input_cache.py
+1
-1
No files found.
tests/lora/test_utils.py
View file @
aa29841e
...
@@ -9,7 +9,6 @@ from torch import nn
...
@@ -9,7 +9,6 @@ from torch import nn
from
vllm.lora.utils
import
(
get_adapter_absolute_path
,
from
vllm.lora.utils
import
(
get_adapter_absolute_path
,
parse_fine_tuned_lora_name
,
replace_submodule
)
parse_fine_tuned_lora_name
,
replace_submodule
)
from
vllm.utils
import
LRUCache
def
test_parse_fine_tuned_lora_name_valid
():
def
test_parse_fine_tuned_lora_name_valid
():
...
@@ -85,114 +84,6 @@ def test_replace_submodule():
...
@@ -85,114 +84,6 @@ def test_replace_submodule():
assert
dict
(
model
.
named_modules
())[
"seq1.dense2"
]
==
dense2
assert
dict
(
model
.
named_modules
())[
"seq1.dense2"
]
==
dense2
class
TestLRUCache
(
LRUCache
):
def
_on_remove
(
self
,
key
,
value
):
if
not
hasattr
(
self
,
"_remove_counter"
):
self
.
_remove_counter
=
0
self
.
_remove_counter
+=
1
def
test_lru_cache
():
cache
=
TestLRUCache
(
3
)
cache
.
put
(
1
,
1
)
assert
len
(
cache
)
==
1
cache
.
put
(
1
,
1
)
assert
len
(
cache
)
==
1
cache
.
put
(
2
,
2
)
assert
len
(
cache
)
==
2
cache
.
put
(
3
,
3
)
assert
len
(
cache
)
==
3
assert
set
(
cache
.
cache
)
==
{
1
,
2
,
3
}
cache
.
put
(
4
,
4
)
assert
len
(
cache
)
==
3
assert
set
(
cache
.
cache
)
==
{
2
,
3
,
4
}
assert
cache
.
_remove_counter
==
1
assert
cache
.
get
(
2
)
==
2
cache
.
put
(
5
,
5
)
assert
set
(
cache
.
cache
)
==
{
2
,
4
,
5
}
assert
cache
.
_remove_counter
==
2
assert
cache
.
pop
(
5
)
==
5
assert
len
(
cache
)
==
2
assert
set
(
cache
.
cache
)
==
{
2
,
4
}
assert
cache
.
_remove_counter
==
3
cache
.
pop
(
10
)
assert
len
(
cache
)
==
2
assert
set
(
cache
.
cache
)
==
{
2
,
4
}
assert
cache
.
_remove_counter
==
3
cache
.
get
(
10
)
assert
len
(
cache
)
==
2
assert
set
(
cache
.
cache
)
==
{
2
,
4
}
assert
cache
.
_remove_counter
==
3
cache
.
put
(
6
,
6
)
assert
len
(
cache
)
==
3
assert
set
(
cache
.
cache
)
==
{
2
,
4
,
6
}
assert
2
in
cache
assert
4
in
cache
assert
6
in
cache
cache
.
remove_oldest
()
assert
len
(
cache
)
==
2
assert
set
(
cache
.
cache
)
==
{
2
,
6
}
assert
cache
.
_remove_counter
==
4
cache
.
clear
()
assert
len
(
cache
)
==
0
assert
cache
.
_remove_counter
==
6
cache
.
_remove_counter
=
0
cache
[
1
]
=
1
assert
len
(
cache
)
==
1
cache
[
1
]
=
1
assert
len
(
cache
)
==
1
cache
[
2
]
=
2
assert
len
(
cache
)
==
2
cache
[
3
]
=
3
assert
len
(
cache
)
==
3
assert
set
(
cache
.
cache
)
==
{
1
,
2
,
3
}
cache
[
4
]
=
4
assert
len
(
cache
)
==
3
assert
set
(
cache
.
cache
)
==
{
2
,
3
,
4
}
assert
cache
.
_remove_counter
==
1
assert
cache
[
2
]
==
2
cache
[
5
]
=
5
assert
set
(
cache
.
cache
)
==
{
2
,
4
,
5
}
assert
cache
.
_remove_counter
==
2
del
cache
[
5
]
assert
len
(
cache
)
==
2
assert
set
(
cache
.
cache
)
==
{
2
,
4
}
assert
cache
.
_remove_counter
==
3
cache
.
pop
(
10
)
assert
len
(
cache
)
==
2
assert
set
(
cache
.
cache
)
==
{
2
,
4
}
assert
cache
.
_remove_counter
==
3
cache
[
6
]
=
6
assert
len
(
cache
)
==
3
assert
set
(
cache
.
cache
)
==
{
2
,
4
,
6
}
assert
2
in
cache
assert
4
in
cache
assert
6
in
cache
# Unit tests for get_adapter_absolute_path
# Unit tests for get_adapter_absolute_path
@
patch
(
'os.path.isabs'
)
@
patch
(
'os.path.isabs'
)
def
test_get_adapter_absolute_path_absolute
(
mock_isabs
):
def
test_get_adapter_absolute_path_absolute
(
mock_isabs
):
...
...
tests/test_utils.py
View file @
aa29841e
...
@@ -13,11 +13,11 @@ import torch
...
@@ -13,11 +13,11 @@ import torch
from
vllm_test_utils.monitor
import
monitor
from
vllm_test_utils.monitor
import
monitor
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.config
import
ParallelConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.utils
import
(
FlexibleArgumentParser
,
MemorySnapshot
,
from
vllm.utils
import
(
CacheInfo
,
FlexibleArgumentParser
,
LRUCache
,
PlaceholderModule
,
StoreBoolean
,
bind_kv_cache
,
MemorySnapshot
,
PlaceholderModule
,
StoreBoolean
,
deprecate_kwargs
,
get_open_port
,
memory_profiling
,
bind_kv_cache
,
deprecate_kwargs
,
get_open_port
,
merge_async_iterators
,
sha256
,
supports_kw
,
memory_profiling
,
merge_async_iterators
,
sha256
,
swap_dict_values
)
supports_kw
,
swap_dict_values
)
from
.utils
import
create_new_process_for_each_test
,
error_on_warning
from
.utils
import
create_new_process_for_each_test
,
error_on_warning
...
@@ -417,6 +417,129 @@ def test_bind_kv_cache_pp():
...
@@ -417,6 +417,129 @@ def test_bind_kv_cache_pp():
assert
ctx
[
'layers.0.self_attn'
].
kv_cache
[
1
]
is
kv_cache
[
1
][
0
]
assert
ctx
[
'layers.0.self_attn'
].
kv_cache
[
1
]
is
kv_cache
[
1
][
0
]
class
TestLRUCache
(
LRUCache
):
def
_on_remove
(
self
,
key
,
value
):
if
not
hasattr
(
self
,
"_remove_counter"
):
self
.
_remove_counter
=
0
self
.
_remove_counter
+=
1
def
test_lru_cache
():
cache
=
TestLRUCache
(
3
)
assert
cache
.
stat
()
==
CacheInfo
(
hits
=
0
,
total
=
0
)
assert
cache
.
stat
(
delta
=
True
)
==
CacheInfo
(
hits
=
0
,
total
=
0
)
cache
.
put
(
1
,
1
)
assert
len
(
cache
)
==
1
cache
.
put
(
1
,
1
)
assert
len
(
cache
)
==
1
cache
.
put
(
2
,
2
)
assert
len
(
cache
)
==
2
cache
.
put
(
3
,
3
)
assert
len
(
cache
)
==
3
assert
set
(
cache
.
cache
)
==
{
1
,
2
,
3
}
cache
.
put
(
4
,
4
)
assert
len
(
cache
)
==
3
assert
set
(
cache
.
cache
)
==
{
2
,
3
,
4
}
assert
cache
.
_remove_counter
==
1
assert
cache
.
get
(
2
)
==
2
assert
cache
.
stat
()
==
CacheInfo
(
hits
=
1
,
total
=
1
)
assert
cache
.
stat
(
delta
=
True
)
==
CacheInfo
(
hits
=
1
,
total
=
1
)
assert
cache
[
2
]
==
2
assert
cache
.
stat
()
==
CacheInfo
(
hits
=
2
,
total
=
2
)
assert
cache
.
stat
(
delta
=
True
)
==
CacheInfo
(
hits
=
1
,
total
=
1
)
cache
.
put
(
5
,
5
)
assert
set
(
cache
.
cache
)
==
{
2
,
4
,
5
}
assert
cache
.
_remove_counter
==
2
assert
cache
.
pop
(
5
)
==
5
assert
len
(
cache
)
==
2
assert
set
(
cache
.
cache
)
==
{
2
,
4
}
assert
cache
.
_remove_counter
==
3
assert
cache
.
get
(
-
1
)
is
None
assert
cache
.
stat
()
==
CacheInfo
(
hits
=
2
,
total
=
3
)
assert
cache
.
stat
(
delta
=
True
)
==
CacheInfo
(
hits
=
0
,
total
=
1
)
cache
.
pop
(
10
)
assert
len
(
cache
)
==
2
assert
set
(
cache
.
cache
)
==
{
2
,
4
}
assert
cache
.
_remove_counter
==
3
cache
.
get
(
10
)
assert
len
(
cache
)
==
2
assert
set
(
cache
.
cache
)
==
{
2
,
4
}
assert
cache
.
_remove_counter
==
3
cache
.
put
(
6
,
6
)
assert
len
(
cache
)
==
3
assert
set
(
cache
.
cache
)
==
{
2
,
4
,
6
}
assert
2
in
cache
assert
4
in
cache
assert
6
in
cache
cache
.
remove_oldest
()
assert
len
(
cache
)
==
2
assert
set
(
cache
.
cache
)
==
{
2
,
6
}
assert
cache
.
_remove_counter
==
4
cache
.
clear
()
assert
len
(
cache
)
==
0
assert
cache
.
_remove_counter
==
6
assert
cache
.
stat
()
==
CacheInfo
(
hits
=
0
,
total
=
0
)
assert
cache
.
stat
(
delta
=
True
)
==
CacheInfo
(
hits
=
0
,
total
=
0
)
cache
.
_remove_counter
=
0
cache
[
1
]
=
1
assert
len
(
cache
)
==
1
cache
[
1
]
=
1
assert
len
(
cache
)
==
1
cache
[
2
]
=
2
assert
len
(
cache
)
==
2
cache
[
3
]
=
3
assert
len
(
cache
)
==
3
assert
set
(
cache
.
cache
)
==
{
1
,
2
,
3
}
cache
[
4
]
=
4
assert
len
(
cache
)
==
3
assert
set
(
cache
.
cache
)
==
{
2
,
3
,
4
}
assert
cache
.
_remove_counter
==
1
assert
cache
[
2
]
==
2
cache
[
5
]
=
5
assert
set
(
cache
.
cache
)
==
{
2
,
4
,
5
}
assert
cache
.
_remove_counter
==
2
del
cache
[
5
]
assert
len
(
cache
)
==
2
assert
set
(
cache
.
cache
)
==
{
2
,
4
}
assert
cache
.
_remove_counter
==
3
cache
.
pop
(
10
)
assert
len
(
cache
)
==
2
assert
set
(
cache
.
cache
)
==
{
2
,
4
}
assert
cache
.
_remove_counter
==
3
cache
[
6
]
=
6
assert
len
(
cache
)
==
3
assert
set
(
cache
.
cache
)
==
{
2
,
4
,
6
}
assert
2
in
cache
assert
4
in
cache
assert
6
in
cache
def
test_placeholder_module_error_handling
():
def
test_placeholder_module_error_handling
():
placeholder
=
PlaceholderModule
(
"placeholder_1234"
)
placeholder
=
PlaceholderModule
(
"placeholder_1234"
)
...
...
vllm/utils.py
View file @
aa29841e
...
@@ -236,6 +236,12 @@ class CacheInfo(NamedTuple):
...
@@ -236,6 +236,12 @@ class CacheInfo(NamedTuple):
return
self
.
hits
/
self
.
total
return
self
.
hits
/
self
.
total
def
__sub__
(
self
,
other
:
CacheInfo
):
return
CacheInfo
(
hits
=
self
.
hits
-
other
.
hits
,
total
=
self
.
total
-
other
.
total
,
)
class
LRUCache
(
cachetools
.
LRUCache
[
_K
,
_V
],
Generic
[
_K
,
_V
]):
class
LRUCache
(
cachetools
.
LRUCache
[
_K
,
_V
],
Generic
[
_K
,
_V
]):
...
@@ -243,15 +249,26 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
...
@@ -243,15 +249,26 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
capacity
:
float
,
capacity
:
float
,
getsizeof
:
Optional
[
Callable
[[
_V
],
float
]]
=
None
):
getsizeof
:
Optional
[
Callable
[[
_V
],
float
]]
=
None
):
super
().
__init__
(
capacity
,
getsizeof
)
super
().
__init__
(
capacity
,
getsizeof
)
self
.
pinned_items
=
set
[
_K
]()
self
.
pinned_items
=
set
[
_K
]()
self
.
capacity
=
capacity
self
.
_hits
=
0
self
.
_hits
=
0
self
.
_total
=
0
self
.
_total
=
0
self
.
_last_info
=
CacheInfo
(
hits
=
0
,
total
=
0
)
def
__getitem__
(
self
,
key
:
_K
,
*
,
update_info
:
bool
=
True
)
->
_V
:
value
=
super
().
__getitem__
(
key
)
if
update_info
:
self
.
_hits
+=
1
self
.
_total
+=
1
return
value
def
__delitem__
(
self
,
key
:
_K
)
->
None
:
def
__delitem__
(
self
,
key
:
_K
)
->
None
:
run_on_remove
=
key
in
self
run_on_remove
=
key
in
self
value
=
self
.
__getitem__
(
key
)
value
=
self
.
__getitem__
(
key
,
update_info
=
False
)
# type: ignore[call-arg]
super
().
__delitem__
(
key
)
super
().
__delitem__
(
key
)
if
key
in
self
.
pinned_items
:
if
key
in
self
.
pinned_items
:
# Todo: add warning to inform that del pinned item
# Todo: add warning to inform that del pinned item
...
@@ -271,8 +288,32 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
...
@@ -271,8 +288,32 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
"""Return the internal order dictionary (read-only)."""
"""Return the internal order dictionary (read-only)."""
return
MappingProxyType
(
self
.
_LRUCache__order
)
# type: ignore
return
MappingProxyType
(
self
.
_LRUCache__order
)
# type: ignore
def
stat
(
self
)
->
CacheInfo
:
@
property
return
CacheInfo
(
hits
=
self
.
_hits
,
total
=
self
.
_total
)
def
capacity
(
self
)
->
float
:
return
self
.
maxsize
@
property
def
usage
(
self
)
->
float
:
if
self
.
maxsize
==
0
:
return
0
return
self
.
currsize
/
self
.
maxsize
def
stat
(
self
,
*
,
delta
:
bool
=
False
)
->
CacheInfo
:
"""
Gets the cumulative number of hits and queries against this cache.
If :code:`delta=True`, instead gets these statistics
since the last call that also passed :code:`delta=True`.
"""
info
=
CacheInfo
(
hits
=
self
.
_hits
,
total
=
self
.
_total
)
if
delta
:
info_delta
=
info
-
self
.
_last_info
self
.
_last_info
=
info
info
=
info_delta
return
info
def
touch
(
self
,
key
:
_K
)
->
None
:
def
touch
(
self
,
key
:
_K
)
->
None
:
self
.
_LRUCache__update
(
key
)
# type: ignore
self
.
_LRUCache__update
(
key
)
# type: ignore
...
@@ -292,7 +333,8 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
...
@@ -292,7 +333,8 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
_T
]]
=
None
)
->
Optional
[
Union
[
_V
,
_T
]]:
_T
]]
=
None
)
->
Optional
[
Union
[
_V
,
_T
]]:
value
:
Optional
[
Union
[
_V
,
_T
]]
value
:
Optional
[
Union
[
_V
,
_T
]]
if
key
in
self
:
if
key
in
self
:
value
=
self
.
__getitem__
(
key
)
value
=
self
.
__getitem__
(
key
,
update_info
=
False
)
# type: ignore[call-arg]
self
.
_hits
+=
1
self
.
_hits
+=
1
else
:
else
:
...
@@ -317,8 +359,9 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
...
@@ -317,8 +359,9 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
if
key
not
in
self
:
if
key
not
in
self
:
return
default
return
default
value
=
self
[
key
]
value
=
self
.
__getitem__
(
key
,
del
self
[
key
]
update_info
=
False
)
# type: ignore[call-arg]
self
.
__delitem__
(
key
)
return
value
return
value
def
put
(
self
,
key
:
_K
,
value
:
_V
)
->
None
:
def
put
(
self
,
key
:
_K
,
value
:
_V
)
->
None
:
...
@@ -353,10 +396,6 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
...
@@ -353,10 +396,6 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
while
self
.
currsize
>
self
.
capacity
:
while
self
.
currsize
>
self
.
capacity
:
self
.
remove_oldest
()
self
.
remove_oldest
()
def
clear
(
self
)
->
None
:
while
len
(
self
)
>
0
:
self
.
remove_oldest
(
remove_pinned
=
True
)
def
popitem
(
self
,
remove_pinned
:
bool
=
False
):
def
popitem
(
self
,
remove_pinned
:
bool
=
False
):
"""Remove and return the `(key, value)` pair least recently used."""
"""Remove and return the `(key, value)` pair least recently used."""
if
not
remove_pinned
:
if
not
remove_pinned
:
...
@@ -372,6 +411,14 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
...
@@ -372,6 +411,14 @@ class LRUCache(cachetools.LRUCache[_K, _V], Generic[_K, _V]):
value
=
self
.
pop
(
cast
(
_K
,
lru_key
))
value
=
self
.
pop
(
cast
(
_K
,
lru_key
))
return
(
lru_key
,
value
)
return
(
lru_key
,
value
)
def
clear
(
self
)
->
None
:
while
len
(
self
)
>
0
:
self
.
remove_oldest
(
remove_pinned
=
True
)
self
.
_hits
=
0
self
.
_total
=
0
self
.
_last_info
=
CacheInfo
(
hits
=
0
,
total
=
0
)
class
PyObjectCache
:
class
PyObjectCache
:
"""Used to cache python objects to avoid object allocations
"""Used to cache python objects to avoid object allocations
...
...
vllm/v1/engine/mm_input_cache.py
View file @
aa29841e
...
@@ -50,7 +50,7 @@ class MirroredProcessingCache:
...
@@ -50,7 +50,7 @@ class MirroredProcessingCache:
full_mm_inputs
=
list
[
Optional
[
MultiModalKwargs
]]()
full_mm_inputs
=
list
[
Optional
[
MultiModalKwargs
]]()
for
mm_input
,
mm_hash
in
zip
(
mm_inputs
,
mm_hashes
):
for
mm_input
,
mm_hash
in
zip
(
mm_inputs
,
mm_hashes
):
if
mm_hash
in
self
.
mm_cache
:
if
self
.
mm_cache
.
get
(
mm_hash
)
is
not
None
:
mm_input
=
None
mm_input
=
None
else
:
else
:
self
.
mm_cache
[
mm_hash
]
=
mm_input
self
.
mm_cache
[
mm_hash
]
=
mm_input
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment