Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4f7cde72
Unverified
Commit
4f7cde72
authored
Aug 29, 2025
by
Adit Chawdhary
Committed by
GitHub
Aug 29, 2025
Browse files
Adds `json_count_leaves` utility function (#23899)
Signed-off-by:
aditchawdhary
<
aditxy@hotmail.com
>
parent
67c14906
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
72 additions
and
10 deletions
+72
-10
tests/utils_/test_utils.py
tests/utils_/test_utils.py
+33
-3
vllm/multimodal/cache.py
vllm/multimodal/cache.py
+27
-5
vllm/utils/jsontree.py
vllm/utils/jsontree.py
+12
-2
No files found.
tests/utils_/test_utils.py
View file @
4f7cde72
...
...
@@ -948,6 +948,36 @@ def test_join_host_port():
assert
join_host_port
(
"::1"
,
5555
)
==
"[::1]:5555"
def
test_json_count_leaves
():
"""Test json_count_leaves function from jsontree utility."""
from
vllm.utils.jsontree
import
json_count_leaves
# Single leaf values
assert
json_count_leaves
(
42
)
==
1
assert
json_count_leaves
(
"hello"
)
==
1
assert
json_count_leaves
(
None
)
==
1
# Empty containers
assert
json_count_leaves
([])
==
0
assert
json_count_leaves
({})
==
0
assert
json_count_leaves
(())
==
0
# Flat structures
assert
json_count_leaves
([
1
,
2
,
3
])
==
3
assert
json_count_leaves
({
"a"
:
1
,
"b"
:
2
})
==
2
assert
json_count_leaves
((
1
,
2
,
3
))
==
3
# Nested structures
nested_dict
=
{
"a"
:
1
,
"b"
:
{
"c"
:
2
,
"d"
:
3
}}
assert
json_count_leaves
(
nested_dict
)
==
3
nested_list
=
[
1
,
[
2
,
3
],
4
]
assert
json_count_leaves
(
nested_list
)
==
4
mixed_nested
=
{
"list"
:
[
1
,
2
],
"dict"
:
{
"x"
:
3
},
"value"
:
4
}
assert
json_count_leaves
(
mixed_nested
)
==
4
def
test_convert_ids_list_to_tokens
():
tokenizer
=
AutoTokenizer
.
from_pretrained
(
"Qwen/Qwen2.5-1.5B-Instruct"
)
token_ids
=
tokenizer
.
encode
(
"Hello, world!"
)
...
...
vllm/multimodal/cache.py
View file @
4f7cde72
...
...
@@ -10,7 +10,8 @@ from typing_extensions import TypeAlias, override
from
vllm.logger
import
init_logger
from
vllm.utils
import
GiB_bytes
,
LRUCache
from
vllm.utils.jsontree
import
json_map_leaves
,
json_reduce_leaves
from
vllm.utils.jsontree
import
(
json_count_leaves
,
json_map_leaves
,
json_reduce_leaves
)
from
.inputs
import
(
MultiModalFeatureSpec
,
MultiModalFieldElem
,
MultiModalKwargs
,
MultiModalKwargsItem
,
...
...
@@ -127,11 +128,32 @@ class MultiModalCache:
)
if
debug
:
logger
.
debug
(
"Calculated size of %s to be %.2f GiB"
,
type
(
value
),
size
/
GiB_bytes
)
leaf_count
=
json_count_leaves
(
value
)
logger
.
debug
(
"Calculated size of %s to be %.2f GiB (%d leaves)"
,
type
(
value
),
size
/
GiB_bytes
,
leaf_count
,
)
return
size
@
classmethod
def
get_item_complexity
(
cls
,
value
:
MultiModalCacheValue
)
->
int
:
"""
Get the number of leaf elements in a multi-modal cache value.
This provides a measure of structural complexity that can be useful
for debugging cache performance and understanding data patterns.
Args:
value: The multi-modal cache value to analyze.
Returns:
The number of leaf elements in the nested structure.
"""
return
json_count_leaves
(
value
)
@
classmethod
def
get_lru_cache
(
cls
,
...
...
vllm/utils/jsontree.py
View file @
4f7cde72
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Helper functions to work with nested JSON structures."""
from
collections.abc
import
Iterable
from
functools
import
reduce
from
typing
import
Callable
,
TypeVar
,
Union
,
overload
...
...
@@ -8,8 +9,12 @@ from typing import Callable, TypeVar, Union, overload
_T
=
TypeVar
(
"_T"
)
_U
=
TypeVar
(
"_U"
)
JSONTree
=
Union
[
dict
[
str
,
"JSONTree[_T]"
],
list
[
"JSONTree[_T]"
],
tuple
[
"JSONTree[_T]"
,
...],
_T
]
JSONTree
=
Union
[
dict
[
str
,
"JSONTree[_T]"
],
list
[
"JSONTree[_T]"
],
tuple
[
"JSONTree[_T]"
,
...],
_T
,
]
"""A nested JSON structure where the leaves need not be JSON-serializable."""
...
...
@@ -78,3 +83,8 @@ def json_reduce_leaves(
json_iter_leaves
(
value
),
initial
,
)
def
json_count_leaves
(
value
:
JSONTree
[
_T
])
->
int
:
"""Count the number of leaves in a nested JSON structure."""
return
sum
(
1
for
_
in
json_iter_leaves
(
value
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment