Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
12c1287d
Unverified
Commit
12c1287d
authored
Sep 25, 2025
by
Cyrus Leung
Committed by
GitHub
Sep 25, 2025
Browse files
[mypy] Further improve MM type annotations (#25654)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
17b4c668
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
90 additions
and
48 deletions
+90
-48
vllm/model_executor/models/transformers.py
vllm/model_executor/models/transformers.py
+5
-2
vllm/multimodal/inputs.py
vllm/multimodal/inputs.py
+15
-8
vllm/multimodal/processing.py
vllm/multimodal/processing.py
+9
-10
vllm/multimodal/profiling.py
vllm/multimodal/profiling.py
+3
-3
vllm/multimodal/utils.py
vllm/multimodal/utils.py
+24
-20
vllm/utils/jsontree.py
vllm/utils/jsontree.py
+34
-5
No files found.
vllm/model_executor/models/transformers.py
View file @
12c1287d
...
...
@@ -415,9 +415,12 @@ class MultiModalProcessor(BaseMultiModalProcessor[MultiModalProcessingInfo]):
self
.
_get_mm_fields_config
(
processed_data
,
hf_processor_mm_kwargs
,
num_image_patches
),
)
# Use overrides if provided; fallback to data-dependent hashing.
mm_hashes
=
(
mm_uuids
if
mm_uuids
is
not
None
else
self
.
_hash_mm_items
(
mm_items
,
hf_processor_mm_kwargs
,
tokenization_kwargs
))
mm_hashes
=
self
.
_hash_mm_items
(
mm_items
,
hf_processor_mm_kwargs
,
tokenization_kwargs
,
mm_uuids
=
mm_uuids
)
return
MultiModalInputs
(
type
=
"multimodal"
,
...
...
vllm/multimodal/inputs.py
View file @
12c1287d
...
...
@@ -14,7 +14,7 @@ import numpy as np
from
typing_extensions
import
NotRequired
,
TypeAlias
,
TypeVar
,
deprecated
from
vllm.utils
import
LazyLoader
,
full_groupby
,
is_list_of
from
vllm.utils.jsontree
import
JSONTree
,
json_map_leaves
from
vllm.utils.jsontree
import
json_map_leaves
if
TYPE_CHECKING
:
import
torch
...
...
@@ -203,7 +203,7 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
return
a
==
b
BatchedTensorInputs
:
TypeAlias
=
Mapping
[
str
,
NestedTensors
]
BatchedTensorInputs
:
TypeAlias
=
dict
[
str
,
NestedTensors
]
"""
A dictionary containing nested tensors which have been batched via
[`MultiModalKwargs.batch`][vllm.multimodal.inputs.MultiModalKwargs.batch].
...
...
@@ -377,6 +377,7 @@ class MultiModalBatchedField(BaseMultiModalField):
pin_memory
:
bool
,
)
->
NestedTensors
:
if
len
(
batch
)
>
0
and
is_list_of
(
batch
,
torch
.
Tensor
,
check
=
"all"
):
batch
=
cast
(
list
[
torch
.
Tensor
],
batch
)
if
len
(
batch
)
==
1
:
# An optimization when `batch` contains only one tensor:
# - produce exactly same result as `torch.stack(batch)`
...
...
@@ -422,6 +423,7 @@ class MultiModalFlatField(BaseMultiModalField):
pin_memory
:
bool
,
)
->
NestedTensors
:
if
len
(
batch
)
>
0
and
is_list_of
(
batch
,
torch
.
Tensor
,
check
=
"all"
):
batch
=
cast
(
list
[
torch
.
Tensor
],
batch
)
if
len
(
batch
)
==
1
:
# An optimization when `batch` contains only one tensor:
# - produce exactly same result as `torch.concat(batch)`
...
...
@@ -764,6 +766,15 @@ class MultiModalKwargsItems(UserDict[str, Sequence[_I]]):
return
super
().
__getitem__
(
modality
)
# type: ignore[return-value]
def
require_data
(
self
)
->
"MultiModalKwargsItems[MultiModalKwargsItem]"
:
for
modality
,
items
in
self
.
items
():
for
i
,
item
in
enumerate
(
items
):
if
item
is
None
:
raise
RuntimeError
(
f
"Found empty mm_items[
{
modality
}
][
{
i
}
]"
)
return
self
# type: ignore[return-value]
def
get_data
(
self
,
*
,
pin_memory
:
bool
=
False
)
->
"MultiModalKwargs"
:
elems_by_key
=
defaultdict
[
str
,
list
[
MultiModalFieldElem
]](
list
)
for
modality
,
items
in
self
.
items
():
...
...
@@ -897,15 +908,11 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
*
,
device
:
torch
.
types
.
Device
,
)
->
BatchedTensorInputs
:
json_inputs
=
cast
(
JSONTree
[
torch
.
Tensor
],
batched_inputs
)
json_mapped
=
json_map_leaves
(
return
json_map_leaves
(
lambda
x
:
x
.
to
(
device
=
device
,
non_blocking
=
True
),
json
_inputs
,
batched
_inputs
,
)
return
cast
(
BatchedTensorInputs
,
json_mapped
)
def
__getitem__
(
self
,
key
:
str
):
if
key
not
in
self
:
raise
KeyError
(
f
"Keyword argument
{
key
!
r
}
not found. "
...
...
vllm/multimodal/processing.py
View file @
12c1287d
...
...
@@ -1585,7 +1585,7 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
*
,
mm_uuids
:
Optional
[
MultiModalUUIDDict
]
=
None
,
)
->
MultiModalHashes
:
"""Create MM hashes to be returned
(only used in V1)
.
"""Create MM hashes to be returned.
Note: When overrides are provided via callers of `apply`,
...
...
@@ -2098,23 +2098,22 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
encoder_inputs
:
MultiModalInputs
,
):
tokenizer
=
self
.
info
.
get_tokenizer
()
decoder_prompt
=
self
.
create_decoder_prompt
(
prompt
,
mm_data
)
if
isinstance
(
decoder_prompt
,
str
):
decoder_prompt_raw
=
self
.
create_decoder_prompt
(
prompt
,
mm_data
)
if
isinstance
(
decoder_prompt_raw
,
str
):
decoder_prompt
=
decoder_prompt_raw
decoder_prompt_ids
=
encode_tokens
(
tokenizer
,
decoder_prompt
,
decoder_prompt
_raw
,
add_special_tokens
=
False
)
else
:
decoder_prompt
_ids
=
decoder_prompt
decoder_prompt
=
decode_tokens
(
tokenizer
,
decoder_prompt
)
decoder_prompt
=
decode_tokens
(
tokenizer
,
decoder_prompt
_raw
)
decoder_prompt
_ids
=
decoder_prompt
_raw
mm_inputs
=
MultiModalEncDecInputs
(
encoder_prompt
=
encoder_inputs
[
"prompt"
],
encoder_prompt_token_ids
=
encoder_inputs
[
"prompt_token_ids"
],
**
encoder_inputs
)
mm_inputs
.
update
({
"prompt"
:
decoder_prompt
,
"prompt_token_ids"
:
decoder_prompt_ids
})
mm_inputs
[
"prompt"
]
=
decoder_prompt
mm_inputs
[
"prompt_token_ids"
]
=
decoder_prompt_ids
return
mm_inputs
def
apply
(
...
...
vllm/multimodal/profiling.py
View file @
12c1287d
...
...
@@ -13,7 +13,7 @@ import vllm.envs as envs
from
vllm.logger
import
init_logger
from
.inputs
import
(
MultiModalDataDict
,
MultiModalEncDecInputs
,
MultiModalInputs
,
MultiModalKwargs
Optional
Items
,
MultiModalInputs
,
MultiModalKwargsItems
,
MultiModalPlaceholderDict
)
from
.processing
import
(
BaseMultiModalProcessor
,
BaseProcessingInfo
,
EncDecMultiModalProcessor
)
...
...
@@ -43,7 +43,7 @@ class DummyDecoderData(NamedTuple):
"""Dummy data used for profiling."""
prompt_token_ids
:
list
[
int
]
multi_modal_data
:
MultiModalKwargs
Optional
Items
multi_modal_data
:
MultiModalKwargsItems
multi_modal_placeholders
:
MultiModalPlaceholderDict
...
...
@@ -239,7 +239,7 @@ class MultiModalProfiler(Generic[_I]):
return
DummyDecoderData
(
prompt_token_ids
=
prompt_token_ids
,
multi_modal_data
=
mm_inputs
[
"mm_kwargs"
],
multi_modal_data
=
mm_inputs
[
"mm_kwargs"
]
.
require_data
()
,
multi_modal_placeholders
=
mm_inputs
[
"mm_placeholders"
],
)
...
...
vllm/multimodal/utils.py
View file @
12c1287d
...
...
@@ -19,6 +19,7 @@ from typing_extensions import deprecated
import
vllm.envs
as
envs
from
vllm.connections
import
HTTPConnection
,
global_http_connection
from
vllm.utils.jsontree
import
json_map_leaves
from
.audio
import
AudioMediaIO
from
.base
import
MediaIO
...
...
@@ -383,6 +384,7 @@ def group_mm_kwargs_by_modality(
*
,
device
:
torch
.
types
.
Device
=
None
,
pin_memory
:
bool
=
False
,
merge_by_field_config
:
bool
=
False
,
)
->
Iterable
[
tuple
[
str
,
int
,
BatchedTensorInputs
]]:
"""Group consecutive `MultiModalKwargsItem`s from `mm_kwargs` with the same
modality together into the same `MultiModalKwargs` instance.
...
...
@@ -400,29 +402,31 @@ def group_mm_kwargs_by_modality(
for
modality
,
items
in
groupby
(
mm_kwargs
,
key
=
lambda
item
:
item
.
modality
):
items_lst
=
list
(
items
)
# mm_kwargs_group = MultiModalKwargsItems.from_items(items_lst) \
# .get_data(pin_memory=pin_memory)
# if device is not None:
# mm_kwargs_group = json_map_leaves(
# lambda x: x.to(device=device),
# mm_kwargs_group,
# )
# TODO: Once V0 is removed, we can use the merging logic above
# TODO: Enable `merge_by_field_config` for all models
# to avoid creating an extra batch dimension (except for fields
# that are meant to be stacked anyway).
# We will also need to update each model to remove `flatten_bn`.
mm_kwargs_group
=
MultiModalKwargs
.
as_kwargs
(
MultiModalKwargs
.
batch
(
[
MultiModalKwargsItems
.
from_seq
([
item
]).
get_data
()
for
item
in
items_lst
],
pin_memory
=
pin_memory
,
),
device
=
device
,
)
if
merge_by_field_config
:
mm_kwargs_group
:
BatchedTensorInputs
=
dict
(
MultiModalKwargsItems
.
from_seq
(
items_lst
).
get_data
(
pin_memory
=
pin_memory
))
if
device
is
not
None
:
mm_kwargs_group
=
json_map_leaves
(
lambda
x
:
x
.
to
(
device
=
device
),
mm_kwargs_group
,
)
else
:
mm_kwargs_group
=
MultiModalKwargs
.
as_kwargs
(
MultiModalKwargs
.
batch
(
[
MultiModalKwargsItems
.
from_seq
([
item
]).
get_data
()
for
item
in
items_lst
],
pin_memory
=
pin_memory
,
),
device
=
device
,
)
yield
modality
,
len
(
items_lst
),
mm_kwargs_group
...
...
vllm/utils/jsontree.py
View file @
12c1287d
...
...
@@ -4,7 +4,12 @@
from
collections.abc
import
Iterable
from
functools
import
reduce
from
typing
import
Callable
,
TypeVar
,
Union
,
cast
,
overload
from
typing
import
TYPE_CHECKING
,
Callable
,
TypeVar
,
Union
,
cast
,
overload
if
TYPE_CHECKING
:
import
torch
from
vllm.multimodal.inputs
import
BatchedTensorInputs
_T
=
TypeVar
(
"_T"
)
_U
=
TypeVar
(
"_U"
)
...
...
@@ -17,6 +22,19 @@ JSONTree = Union[
]
"""A nested JSON structure where the leaves need not be JSON-serializable."""
_JSONTree
=
Union
[
dict
[
str
,
"JSONTree[_T]"
],
list
[
"JSONTree[_T]"
],
tuple
[
"JSONTree[_T]"
,
...],
dict
[
str
,
_T
],
list
[
_T
],
tuple
[
_T
,
...],
_T
,
]
"""
Same as `JSONTree` but with additional `Union` members to satisfy overloads.
"""
def
json_iter_leaves
(
value
:
JSONTree
[
_T
])
->
Iterable
[
_T
]:
"""Iterate through each leaf in a nested JSON structure."""
...
...
@@ -30,6 +48,14 @@ def json_iter_leaves(value: JSONTree[_T]) -> Iterable[_T]:
yield
value
@
overload
def
json_map_leaves
(
func
:
Callable
[[
"torch.Tensor"
],
"torch.Tensor"
],
value
:
"BatchedTensorInputs"
,
)
->
"BatchedTensorInputs"
:
...
@
overload
def
json_map_leaves
(
func
:
Callable
[[
_T
],
_U
],
...
...
@@ -64,11 +90,14 @@ def json_map_leaves(
def
json_map_leaves
(
func
:
Callable
[[
_T
],
_U
],
value
:
Union
[
dict
[
str
,
_T
],
list
[
_T
],
tuple
[
_T
,
...],
JSONTree
[
_T
]],
)
->
Union
[
dict
[
str
,
_U
],
list
[
_U
],
tuple
[
_U
,
...],
JSONTree
[
_U
]]:
value
:
Union
[
"BatchedTensorInputs"
,
_
JSONTree
[
_T
]],
)
->
Union
[
"BatchedTensorInputs"
,
_
JSONTree
[
_U
]]:
"""Apply a function to each leaf in a nested JSON structure."""
if
isinstance
(
value
,
dict
):
return
{
k
:
json_map_leaves
(
func
,
v
)
for
k
,
v
in
value
.
items
()}
return
{
k
:
json_map_leaves
(
func
,
v
)
# type: ignore[arg-type]
for
k
,
v
in
value
.
items
()
}
elif
isinstance
(
value
,
list
):
return
[
json_map_leaves
(
func
,
v
)
for
v
in
value
]
elif
isinstance
(
value
,
tuple
):
...
...
@@ -125,7 +154,7 @@ def json_reduce_leaves(
def
json_reduce_leaves
(
func
:
Callable
[...,
Union
[
_T
,
_U
]],
value
:
Union
[
dict
[
str
,
_T
],
list
[
_T
],
tuple
[
_T
,
...],
JSONTree
[
_T
]
]
,
value
:
_
JSONTree
[
_T
],
initial
:
_U
=
cast
(
_U
,
...),
# noqa: B008
/
,
)
->
Union
[
_T
,
_U
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment