Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cc826a20
Unverified
Commit
cc826a20
authored
Aug 16, 2025
by
Isotr0py
Committed by
GitHub
Aug 16, 2025
Browse files
[Multimodal] Update Tensor schema test to cover arbitrary shape mm inputs (#22867)
Signed-off-by:
Isotr0py
<
mozf@mail2.sysu.edu.cn
>
parent
6d3da472
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
138 additions
and
27 deletions
+138
-27
tests/models/multimodal/test_tensor_schema.py
tests/models/multimodal/test_tensor_schema.py
+124
-19
vllm/model_executor/models/keye.py
vllm/model_executor/models/keye.py
+14
-8
No files found.
tests/models/multimodal/test_tensor_schema.py
View file @
cc826a20
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
collections.abc
import
Iterable
from
functools
import
partial
from
functools
import
partial
from
typing
import
Any
,
Union
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
numpy
as
np
import
pytest
import
pytest
from
mistral_common.protocol.instruct.messages
import
(
ImageChunk
,
TextChunk
,
UserMessage
)
from
mistral_common.protocol.instruct.request
import
ChatCompletionRequest
from
PIL
import
Image
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
from
vllm.engine.llm_engine
import
LLMEngine
as
V0LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
as
V0LLMEngine
from
vllm.inputs
import
InputProcessingContext
from
vllm.inputs
import
InputProcessingContext
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
MultiModalKwargs
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalKwargs
)
from
vllm.multimodal.processing
import
BaseMultiModalProcessor
from
vllm.multimodal.processing
import
BaseMultiModalProcessor
from
vllm.multimodal.utils
import
group_mm_kwargs_by_modality
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
vllm.transformers_utils.tokenizer
import
cached_tokenizer_from_config
from
vllm.utils
import
GiB_bytes
,
set_default_torch_num_threads
from
vllm.utils
import
GiB_bytes
,
is_list_of
,
set_default_torch_num_threads
from
vllm.v1.core.kv_cache_utils
import
get_kv_cache_config
from
vllm.v1.core.kv_cache_utils
import
get_kv_cache_config
from
vllm.v1.engine.core
import
EngineCore
as
V1EngineCore
from
vllm.v1.engine.core
import
EngineCore
as
V1EngineCore
...
@@ -23,12 +32,64 @@ ARCH_TO_SKIP = {
...
@@ -23,12 +32,64 @@ ARCH_TO_SKIP = {
"MolmoForCausalLM"
:
"incompatible requirements"
,
"MolmoForCausalLM"
:
"incompatible requirements"
,
"MiniMaxVL01ForConditionalGeneration"
:
"broken model"
,
"MiniMaxVL01ForConditionalGeneration"
:
"broken model"
,
}
}
# Architectures for which get_model_id_to_test() also yields the repos listed
# in the model info's `extras`, in addition to the default repo.
ARCH_NEEDS_EXTRAS = [
    "InternVLChatModel",
    "Idefics3ForConditionalGeneration",
    "LlavaForConditionalGeneration",
    "MiniCPMV",
    "PaliGemmaForConditionalGeneration",
]
# Repo ids that test_model_tensor_schema() skips, mapped to the skip reason
# that is reported through pytest.skip().
REPO_ID_TO_SKIP = {"nm-testing/pixtral-12b-FP8-dynamic": "duplicated test"}

# Per-modality containers accepted and returned by resize_mm_data().
# One PIL image per item.
ImageInput = list[Image.Image]
# One video per item: a list of PIL frames, a frame array, or an
# (array, metadata dict) pair.
VideoInput = Union[list[Image.Image], list[np.ndarray],
                   list[tuple[np.ndarray, dict[str, Any]]]]
# One (waveform array, int) pair per item.
# NOTE(review): the int is presumably the sampling rate -- confirm.
AudioInput = list[tuple[np.ndarray, int]]
def _resize_data(_data: Union[Image.Image, np.ndarray],
                 size_factor: float) -> Union[Image.Image, np.ndarray]:
    """Shrink a single multimodal item by ``size_factor``.

    The handling depends on the runtime type of ``_data``:

    * ``PIL.Image.Image``: resized to the scaled width and height.
    * list of ``PIL.Image.Image`` (video frames): the leading scaled number
      of frames is kept and each kept frame is resized.
    * ``np.ndarray`` with ``ndim >= 4`` (video): the last four axes are read
      as ``(T, H, W, C)``; the array is cropped (not interpolated) to the
      scaled ``T``, ``H`` and ``W``.
    * 1-D ``np.ndarray`` (audio): truncated to the scaled number of samples.

    Every scaled length is clamped to at least 1 so that a small item and a
    small factor never produce a zero-sized target.

    Args:
        _data: The item to shrink.
        size_factor: Scale factor; must not exceed 1.

    Returns:
        The shrunken item, of the same kind as ``_data``.

    Raises:
        AssertionError: If ``size_factor`` is greater than 1, or if ``_data``
            matches none of the supported kinds (e.g. a 2-D or 3-D array).
    """
    # The check allows exactly 1 (no resizing), so the message says so too.
    assert size_factor <= 1, "Size factor must be less than or equal to 1"

    def _scale(length: int) -> int:
        # Clamp to 1: PIL rejects zero-sized targets and an empty crop would
        # leave nothing for the processor to consume.
        return max(int(length * size_factor), 1)

    # Image input
    if isinstance(_data, Image.Image):
        return _data.resize((_scale(_data.width), _scale(_data.height)))
    # Video input with PIL Images
    elif is_list_of(_data, Image.Image):
        # NOTE(review): is_list_of may accept an empty list; there is no
        # first frame to read a size from in that case, so return it as-is.
        if not _data:
            return []
        first_frame = _data[0]
        W, H = _scale(first_frame.width), _scale(first_frame.height)
        T = _scale(len(_data))
        return [d.resize((W, H)) for d in _data[:T]]
    # Video input with numpy arrays
    elif isinstance(_data, np.ndarray) and _data.ndim >= 4:
        T, H, W, C = _data.shape[-4:]
        T, H, W = _scale(T), _scale(H), _scale(W)
        # Leading (batch-like) axes, if any, are left untouched.
        return _data[..., :T, :H, :W, :C]
    # Audio input
    elif isinstance(_data, np.ndarray) and _data.ndim == 1:
        return _data[:_scale(len(_data))]
    raise AssertionError("This line should be unreachable.")
def resize_mm_data(
    data: Union[ImageInput, VideoInput, AudioInput],
    size_factors: tuple[float, ...],
) -> Union[ImageInput, VideoInput, AudioInput]:
    """Shrink every item of one modality by its matching size factor.

    Items and factors are paired positionally with ``zip``, so pairing stops
    at the shorter of the two: surplus factors are ignored and items beyond
    the number of factors are not included in the result.

    Args:
        data: Items of a single modality. Either bare items (PIL images,
            arrays, or lists of frames) or ``(item, meta)`` tuples.
        size_factors: One scale factor per item, applied in order.

    Returns:
        A new list of resized items. For tuple inputs the second element of
        each tuple is passed through unchanged.

    Raises:
        ValueError: If ``data`` is neither of the supported list kinds.
    """
    factors = size_factors[:len(data)]

    # Bare items: resize each one directly.
    if is_list_of(data, (Image.Image, np.ndarray, list)):
        resized = []
        for item, factor in zip(data, factors):
            resized.append(_resize_data(item, factor))
        return resized

    # (item, meta) pairs: resize the item and carry the metadata along.
    if is_list_of(data, tuple):
        resized_pairs = []
        for (item, meta), factor in zip(data, factors):
            resized_pairs.append((_resize_data(item, factor), meta))
        return resized_pairs

    raise ValueError("Unsupported multimodal data type.")
def
create_batched_mm_kwargs
(
def
create_batched_mm_kwargs
(
model_config
:
ModelConfig
,
model_config
:
ModelConfig
,
processor
:
BaseMultiModalProcessor
,
processor
:
BaseMultiModalProcessor
,
)
->
MultiModalKwargs
:
size_factors
:
tuple
[
float
,
...]
=
(
1.0
,
0.5
,
0.25
),
)
->
Iterable
[
tuple
[
str
,
int
,
BatchedTensorInputs
]]:
processing_info
=
processor
.
info
processing_info
=
processor
.
info
dummy_inputs
=
processor
.
dummy_inputs
dummy_inputs
=
processor
.
dummy_inputs
supported_mm_limits
=
processing_info
.
get_supported_mm_limits
()
supported_mm_limits
=
processing_info
.
get_supported_mm_limits
()
...
@@ -40,30 +101,69 @@ def create_batched_mm_kwargs(
...
@@ -40,30 +101,69 @@ def create_batched_mm_kwargs(
seq_len
=
model_config
.
max_model_len
,
seq_len
=
model_config
.
max_model_len
,
mm_counts
=
mm_counts
,
mm_counts
=
mm_counts
,
)
)
mm_data
=
processor_inputs
.
mm_data
resized_mm_data
=
{
modality
:
resize_mm_data
(
data
,
size_factors
)
for
modality
,
data
in
mm_data
.
items
()
}
# Mistral chat outputs tokens directly, rather than text prompts
if
model_config
.
tokenizer_mode
==
"mistral"
:
images
=
resized_mm_data
.
get
(
"image"
,
[])
request
=
ChatCompletionRequest
(
messages
=
[
UserMessage
(
content
=
[
TextChunk
(
text
=
""
),
*
(
ImageChunk
(
image
=
image
)
for
image
in
images
),
]),
])
tokenizer
=
processing_info
.
get_tokenizer
()
res
=
tokenizer
.
mistral
.
encode_chat_completion
(
request
)
prompt
=
res
.
tokens
else
:
prompt
=
processor_inputs
.
prompt
mm_kwargs
=
processor
.
apply
(
mm_kwargs
=
processor
.
apply
(
prompt
=
processor_inputs
.
prompt
,
prompt
=
prompt
,
mm_data
=
processor_inputs
.
mm_data
,
mm_data
=
resized_
mm_data
,
hf_processor_mm_kwargs
=
processor_inputs
.
hf_processor_mm_kwargs
,
hf_processor_mm_kwargs
=
processor_inputs
.
hf_processor_mm_kwargs
,
tokenization_kwargs
=
processor_inputs
.
tokenization_kwargs
,
tokenization_kwargs
=
processor_inputs
.
tokenization_kwargs
,
)[
"mm_kwargs"
]
)[
"mm_kwargs"
]
mm_kwargs
=
MultiModalKwargs
.
batch
([
mm_kwargs
])
items
=
[
return
mm_kwargs
item
for
modality
in
supported_mm_limits
for
item
in
mm_kwargs
.
get_items
(
modality
)
]
return
group_mm_kwargs_by_modality
(
items
)
def get_model_id_to_test(
        model_arch_list: Iterable[str]) -> list[tuple[str, str]]:
    """Expand architecture names into ``(model_arch, model_id)`` test cases.

    Each architecture contributes its default repo from
    ``HF_EXAMPLE_MODELS``. Architectures listed in ``ARCH_NEEDS_EXTRAS``
    whose model info has non-empty ``extras`` additionally contribute one
    case per extras value, after the default.

    Args:
        model_arch_list: Architecture names to look up.

    Returns:
        The ``(model_arch, model_id)`` pairs, in input order.
    """
    arch_repo_pairs: list[tuple[str, str]] = []
    for model_arch in model_arch_list:
        model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
        include_extras = (model_info.extras
                          and model_arch in ARCH_NEEDS_EXTRAS)

        # The default repo always comes first.
        repo_ids = [model_info.default]
        if include_extras:
            repo_ids.extend(model_info.extras.values())

        arch_repo_pairs.extend(
            (model_arch, repo_id) for repo_id in repo_ids)
    return arch_repo_pairs
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
core_model
@
pytest
.
mark
.
parametrize
(
"model_arch"
,
list
(
_MULTIMODAL_EXAMPLE_MODELS
.
keys
()))
@
pytest
.
mark
.
parametrize
(
def
test_model_tensor_schema
(
model_arch
:
str
,
vllm_runner
:
type
[
VllmRunner
],
"model_arch, model_id"
,
monkeypatch
):
get_model_id_to_test
(
_MULTIMODAL_EXAMPLE_MODELS
.
keys
()))
def
test_model_tensor_schema
(
model_arch
:
str
,
model_id
:
str
,
vllm_runner
:
type
[
VllmRunner
],
monkeypatch
):
if
model_arch
in
ARCH_TO_SKIP
:
if
model_arch
in
ARCH_TO_SKIP
:
pytest
.
skip
(
f
"Skipping
{
model_arch
}
due to
{
ARCH_TO_SKIP
[
model_arch
]
}
"
)
pytest
.
skip
(
f
"Skipping
{
model_arch
}
due to
{
ARCH_TO_SKIP
[
model_arch
]
}
"
)
if
model_id
in
REPO_ID_TO_SKIP
:
pytest
.
skip
(
f
"Skipping
{
model_id
}
due to
{
REPO_ID_TO_SKIP
[
model_id
]
}
"
)
model_info
=
HF_EXAMPLE_MODELS
.
get_hf_info
(
model_arch
)
model_info
=
HF_EXAMPLE_MODELS
.
get_hf_info
(
model_arch
)
model_info
.
check_available_online
(
on_fail
=
"skip"
)
model_info
.
check_available_online
(
on_fail
=
"skip"
)
model_info
.
check_transformers_version
(
on_fail
=
"skip"
,
model_info
.
check_transformers_version
(
on_fail
=
"skip"
,
check_max_version
=
False
)
check_max_version
=
False
)
model_id
=
model_info
.
default
hf_overrides_fn
=
partial
(
dummy_hf_overrides
,
hf_overrides_fn
=
partial
(
dummy_hf_overrides
,
model_arch
=
model_arch
,
model_arch
=
model_arch
,
exist_overrides
=
model_info
.
hf_overrides
)
exist_overrides
=
model_info
.
hf_overrides
)
...
@@ -119,6 +219,7 @@ def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
...
@@ -119,6 +219,7 @@ def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
if
model_info
.
v0_only
:
if
model_info
.
v0_only
:
m
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
m
.
setenv
(
"VLLM_USE_V1"
,
"0"
)
# TODO(Isotr0py): Can we avoid initializing engine?
with
(
with
(
set_default_torch_num_threads
(
1
),
set_default_torch_num_threads
(
1
),
vllm_runner
(
vllm_runner
(
...
@@ -145,12 +246,16 @@ def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
...
@@ -145,12 +246,16 @@ def test_model_tensor_schema(model_arch: str, vllm_runner: type[VllmRunner],
mm_registry
=
llm_engine
.
input_preprocessor
.
mm_registry
mm_registry
=
llm_engine
.
input_preprocessor
.
mm_registry
processor
=
mm_registry
.
create_processor
(
model_config
)
processor
=
mm_registry
.
create_processor
(
model_config
)
mm_kwargs
=
create_batched_mm_kwargs
(
model_config
,
processor
)
def
validate_model_input
(
model
):
def
validate_model_input
(
model
,
modality
:
str
,
for
modality
in
(
"audio"
,
"image"
,
"video"
):
mm_kwargs
:
MultiModalKwargs
):
method_name
=
f
"_parse_and_validate_
{
modality
}
_input"
method_name
=
f
"_parse_and_validate_
{
modality
}
_input"
if
hasattr
(
model
,
method_name
):
if
hasattr
(
model
,
method_name
):
getattr
(
model
,
method_name
)(
**
mm_kwargs
)
getattr
(
model
,
method_name
)(
**
mm_kwargs
)
vllm_model
.
apply_model
(
validate_model_input
)
for
modality
,
_
,
mm_kwargs
in
create_batched_mm_kwargs
(
\ No newline at end of file
model_config
,
processor
):
valid_func
=
partial
(
validate_model_input
,
modality
=
modality
,
mm_kwargs
=
mm_kwargs
)
vllm_model
.
apply_model
(
valid_func
)
vllm/model_executor/models/keye.py
View file @
cc826a20
...
@@ -30,7 +30,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
...
@@ -30,7 +30,7 @@ from vllm.model_executor.layers.quantization.gptq_marlin import (
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.model_executor.models.module_mapping
import
MultiModelKeys
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
from
vllm.multimodal
import
MULTIMODAL_REGISTRY
,
NestedTensors
from
vllm.multimodal.inputs
import
(
ImageItem
,
ModalityData
,
from
vllm.multimodal.inputs
import
(
ImageItem
,
ModalityData
,
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalDataDict
,
MultiModalFieldConfig
,
MultiModalKwargs
,
VideoItem
)
MultiModalKwargs
,
VideoItem
)
...
@@ -44,6 +44,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
...
@@ -44,6 +44,7 @@ from vllm.multimodal.profiling import BaseDummyInputsBuilder
from
vllm.platforms
import
_Backend
from
vllm.platforms
import
_Backend
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.transformers_utils.config
import
uses_mrope
from
vllm.utils
import
is_list_of
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
from
.interfaces
import
(
MultiModalEmbeddings
,
SupportsLoRA
,
...
@@ -112,8 +113,9 @@ class KeyeImagePixelInputs(TensorSchema):
...
@@ -112,8 +113,9 @@ class KeyeImagePixelInputs(TensorSchema):
- g: Grid dimensions (3 for t, h, w)
- g: Grid dimensions (3 for t, h, w)
"""
"""
type
:
Literal
[
"pixel_values"
]
type
:
Literal
[
"pixel_values"
]
pixel_values
:
Annotated
[
torch
.
Tensor
,
pixel_values
:
Annotated
[
TensorShape
(
"b"
,
"np"
,
3
,
"ps"
,
"ps"
)]
torch
.
Tensor
,
TensorShape
(
"b"
,
"np"
,
3
,
"ps"
,
"ps"
,
dynamic_dims
=
{
"np"
})]
image_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"ni"
,
3
)]
image_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"ni"
,
3
)]
...
@@ -145,8 +147,9 @@ class KeyeVideoPixelInputs(TensorSchema):
...
@@ -145,8 +147,9 @@ class KeyeVideoPixelInputs(TensorSchema):
- g: Grid dimensions (3 for t, h, w)
- g: Grid dimensions (3 for t, h, w)
"""
"""
type
:
Literal
[
"pixel_values_videos"
]
type
:
Literal
[
"pixel_values_videos"
]
pixel_values_videos
:
Annotated
[
torch
.
Tensor
,
pixel_values_videos
:
Annotated
[
TensorShape
(
"b"
,
"np"
,
3
,
"ps"
,
"ps"
)]
torch
.
Tensor
,
TensorShape
(
"b"
,
"np"
,
3
,
"ps"
,
"ps"
,
dynamic_dims
=
{
"np"
})]
video_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"nv"
,
3
)]
video_grid_thw
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"nv"
,
3
)]
...
@@ -1295,7 +1298,7 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
...
@@ -1295,7 +1298,7 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
return
None
return
None
return
quant_config
return
quant_config
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
object
,
def
_validate_and_reshape_mm_tensor
(
self
,
mm_input
:
NestedTensors
,
name
:
str
)
->
torch
.
Tensor
:
name
:
str
)
->
torch
.
Tensor
:
if
not
isinstance
(
mm_input
,
(
torch
.
Tensor
,
list
)):
if
not
isinstance
(
mm_input
,
(
torch
.
Tensor
,
list
)):
raise
ValueError
(
f
"Incorrect type of
{
name
}
. "
raise
ValueError
(
f
"Incorrect type of
{
name
}
. "
...
@@ -1310,8 +1313,11 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
...
@@ -1310,8 +1313,11 @@ class KeyeForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsLoRA,
f
"Got ndim:
{
mm_input
.
ndim
}
"
f
"Got ndim:
{
mm_input
.
ndim
}
"
f
"(shape=
{
mm_input
.
shape
}
)"
)
f
"(shape=
{
mm_input
.
shape
}
)"
)
return
torch
.
concat
(
list
(
mm_input
))
return
torch
.
concat
(
list
(
mm_input
))
else
:
elif
is_list_of
(
mm_input
,
torch
.
Tensor
):
return
torch
.
concat
(
mm_input
)
if
all
(
p
.
dim
()
==
4
for
p
in
mm_input
)
or
all
(
p
.
dim
()
==
2
for
p
in
mm_input
):
return
mm_input
return
torch
.
concat
(
list
(
mm_input
))
def
_parse_and_validate_image_input
(
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
)
->
Optional
[
KeyeImageInputs
]:
self
,
**
kwargs
:
object
)
->
Optional
[
KeyeImageInputs
]:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment