Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3d446433
Unverified
Commit
3d446433
authored
Mar 19, 2025
by
Cyrus Leung
Committed by
GitHub
Mar 19, 2025
Browse files
[Bugfix] Fix size calculation of processing cache (#15114)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
1fe0fd12
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
92 additions
and
16 deletions
+92
-16
tests/multimodal/test_processing.py
tests/multimodal/test_processing.py
+46
-2
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+46
-14
No files found.
tests/multimodal/test_processing.py
View file @
3d446433
...
@@ -7,15 +7,20 @@ from unittest.mock import MagicMock
...
@@ -7,15 +7,20 @@ from unittest.mock import MagicMock
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
import
torch
from
transformers
import
ProcessorMixin
from
transformers
import
ProcessorMixin
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal.inputs
import
(
MultiModalFieldElem
,
MultiModalKwargs
,
MultiModalKwargsItem
,
MultiModalSharedField
)
# yapf conflicts with isort for this block
# yapf conflicts with isort for this block
# yapf: disable
# yapf: disable
from
vllm.multimodal.processing
import
(
PlaceholderFeaturesInfo
,
from
vllm.multimodal.processing
import
(
PlaceholderFeaturesInfo
,
PromptIndexTargets
,
PromptInsertion
,
ProcessingCache
,
PromptIndexTargets
,
PromptReplacement
,
apply_text_matches
,
PromptInsertion
,
PromptReplacement
,
apply_text_matches
,
apply_token_matches
,
apply_token_matches
,
find_mm_placeholders
,
find_mm_placeholders
,
find_text_matches
,
find_token_matches
,
find_text_matches
,
find_token_matches
,
...
@@ -890,6 +895,45 @@ def test_find_mm_placeholders(
...
@@ -890,6 +895,45 @@ def test_find_mm_placeholders(
assert
result
==
expected
assert
result
==
expected
def
_dummy_elem
(
modality
:
str
,
key
:
str
,
size
:
int
):
return
MultiModalFieldElem
(
modality
=
modality
,
key
=
key
,
data
=
torch
.
empty
((
size
,
),
dtype
=
torch
.
int8
),
field
=
MultiModalSharedField
(
1
),
)
def
_dummy_item
(
modality
:
str
,
size_by_key
:
dict
[
str
,
int
]):
return
MultiModalKwargsItem
.
from_elems
([
_dummy_elem
(
modality
,
key
,
size
)
for
key
,
size
in
size_by_key
.
items
()
])
def
_dummy_kw
(
size_by_key_modality
:
dict
[
str
,
dict
[
str
,
int
]]):
return
MultiModalKwargs
.
from_items
([
_dummy_item
(
modality
,
size_by_key
)
for
modality
,
size_by_key
in
size_by_key_modality
.
items
()
])
# yapf: disable
@
pytest
.
mark
.
parametrize
(
(
"item"
,
"expected_size"
),
[
(
_dummy_item
(
"a"
,
{
"a1"
:
100
}),
100
),
(
_dummy_item
(
"a"
,
{
"a1"
:
100
,
"a2"
:
110
}),
210
),
(
_dummy_kw
({
"a"
:
{
"a1"
:
100
,
"a2"
:
110
},
"b"
:
{
"b1"
:
120
,
"b2"
:
130
}}),
460
),
# noqa: E501
],
)
# yapf: enable
def
test_cache_item_size
(
item
,
expected_size
):
cache
=
ProcessingCache
.
get_lru_cache
(
2048
,
type
(
item
))
cache
[
""
]
=
item
assert
cache
.
currsize
==
expected_size
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"llava-hf/llava-v1.6-mistral-7b-hf"
])
@
pytest
.
mark
.
parametrize
(
"model_id"
,
[
"llava-hf/llava-v1.6-mistral-7b-hf"
])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
(
"limit"
,
"num_supported"
,
"is_valid"
),
(
"limit"
,
"num_supported"
,
"is_valid"
),
...
...
vllm/multimodal/processing.py
View file @
3d446433
...
@@ -26,7 +26,7 @@ from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
...
@@ -26,7 +26,7 @@ from vllm.utils import GiB_bytes, flatten_2d_lists, full_groupby
from
.hasher
import
MultiModalHasher
from
.hasher
import
MultiModalHasher
from
.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
from
.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalFieldConfig
,
MultiModalInputs
,
MultiModalKwargs
,
MultiModalFieldConfig
,
MultiModalInputs
,
MultiModalKwargs
,
MultiModalKwargsItem
,
PlaceholderRange
)
MultiModalKwargsItem
,
NestedTensors
,
PlaceholderRange
)
from
.parse
import
(
DictEmbeddingItems
,
EmbeddingItems
,
MultiModalDataItems
,
from
.parse
import
(
DictEmbeddingItems
,
EmbeddingItems
,
MultiModalDataItems
,
MultiModalDataParser
)
MultiModalDataParser
)
...
@@ -853,33 +853,62 @@ class ProcessingCache:
...
@@ -853,33 +853,62 @@ class ProcessingCache:
@
staticmethod
@
staticmethod
def
get_lru_cache
(
def
get_lru_cache
(
capacity_gb
:
in
t
,
capacity_gb
:
floa
t
,
value_type
:
type
[
_V
],
value_type
:
type
[
_V
],
*
,
debug
:
bool
=
False
,
)
->
LRUCache
[
str
,
_V
]:
)
->
LRUCache
[
str
,
_V
]:
def
get_size
(
leaf
:
object
)
->
int
:
def
get_leaf_size
(
leaf
:
object
)
->
int
:
# MultiModalKwargs is not a subclass of dict
if
isinstance
(
leaf
,
MultiModalKwargs
):
return
get_item_size
(
leaf
.
data
)
# MultiModalKwargsItem is not a subclass of dict
if
isinstance
(
leaf
,
MultiModalKwargsItem
):
leaf_data
=
{
k
:
v
.
data
for
k
,
v
in
leaf
.
items
()}
return
get_item_size
(
leaf_data
)
# sys.getsizeof doesn't work for tensors
if
isinstance
(
leaf
,
torch
.
Tensor
):
if
isinstance
(
leaf
,
torch
.
Tensor
):
return
leaf
.
nbytes
# sys.getsizeof doesn't work for tensors
return
leaf
.
nbytes
return
sys
.
getsizeof
(
leaf
)
return
sys
.
getsizeof
(
leaf
)
return
LRUCache
[
str
,
_V
](
def
get_item_size
(
GiB_bytes
*
capacity_gb
,
value
:
Union
[
MultiModalKwargs
,
MultiModalKwargsItem
,
getsizeof
=
lambda
x
:
json_reduce_leaves
(
Mapping
[
str
,
NestedTensors
]]
)
->
int
:
size
=
json_reduce_leaves
(
lambda
a
,
b
:
a
+
b
,
lambda
a
,
b
:
a
+
b
,
json_map_leaves
(
get_size
,
x
),
json_map_leaves
(
get_leaf_size
,
value
),
),
)
)
if
debug
:
logger
.
debug
(
"Calculated size of %s to be %.2f GiB"
,
type
(
value
),
size
/
GiB_bytes
)
def
__init__
(
self
,
capacity_gb
:
int
)
->
None
:
return
size
return
LRUCache
(
GiB_bytes
*
capacity_gb
,
getsizeof
=
get_item_size
)
def
__init__
(
self
,
capacity_gb
:
float
,
*
,
debug_cache_hit_ratio_steps
:
Optional
[
int
]
=
None
,
)
->
None
:
super
().
__init__
()
super
().
__init__
()
# DEBUG: Set to None to disable
self
.
debug_cache_hit_ratio_steps
=
debug_cache_hit_ratio_steps
self
.
debug_cache_hit_ratio_steps
:
Optional
[
int
]
=
None
self
.
debug_cache_hits
=
0
self
.
debug_cache_hits
=
0
self
.
debug_cache_total
=
0
self
.
debug_cache_total
=
0
self
.
_cache
=
self
.
get_lru_cache
(
capacity_gb
,
MultiModalKwargsItem
)
self
.
_cache
=
self
.
get_lru_cache
(
capacity_gb
,
MultiModalKwargsItem
,
debug
=
bool
(
debug_cache_hit_ratio_steps
),
)
def
_maybe_log_cache_stats
(
self
)
->
None
:
def
_maybe_log_cache_stats
(
self
)
->
None
:
steps
=
self
.
debug_cache_hit_ratio_steps
steps
=
self
.
debug_cache_hit_ratio_steps
...
@@ -890,6 +919,9 @@ class ProcessingCache:
...
@@ -890,6 +919,9 @@ class ProcessingCache:
if
total
>
0
and
total
%
steps
==
0
:
if
total
>
0
and
total
%
steps
==
0
:
logger
.
debug
(
"ProcessingCache: hit_ratio = %.2f"
,
logger
.
debug
(
"ProcessingCache: hit_ratio = %.2f"
,
self
.
debug_cache_hits
/
total
)
self
.
debug_cache_hits
/
total
)
logger
.
debug
(
"ProcessingCache: size = %.2f / %.2f GiB"
,
self
.
_cache
.
currsize
/
GiB_bytes
,
self
.
_cache
.
maxsize
/
GiB_bytes
)
def
get
(
def
get
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment