Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
44ea8513
Unverified
Commit
44ea8513
authored
Oct 04, 2025
by
Cyrus Leung
Committed by
GitHub
Oct 04, 2025
Browse files
[Model] Support nested structures for TensorSchema (#26212)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
d3d649ef
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
274 additions
and
292 deletions
+274
-292
tests/utils_/test_tensor_schema.py
tests/utils_/test_tensor_schema.py
+39
-29
vllm/model_executor/models/glm4_1v.py
vllm/model_executor/models/glm4_1v.py
+2
-2
vllm/model_executor/models/hyperclovax_vision.py
vllm/model_executor/models/hyperclovax_vision.py
+177
-210
vllm/model_executor/models/phi3v.py
vllm/model_executor/models/phi3v.py
+3
-3
vllm/utils/tensor_schema.py
vllm/utils/tensor_schema.py
+53
-48
No files found.
tests/utils_/test_tensor_schema.py
View file @
44ea8513
...
@@ -6,37 +6,39 @@ import torch
...
@@ -6,37 +6,39 @@ import torch
from
vllm.model_executor.models.glm4_1v
import
Glm4vImageEmbeddingInputs
from
vllm.model_executor.models.glm4_1v
import
Glm4vImageEmbeddingInputs
from
vllm.model_executor.models.granite_speech
import
GraniteSpeechAudioInputs
from
vllm.model_executor.models.granite_speech
import
GraniteSpeechAudioInputs
from
vllm.model_executor.models.hyperclovax_vision
import
(
HCXVisionVideoPixelInputs
)
from
vllm.model_executor.models.phi3v
import
Phi3VImagePixelInputs
from
vllm.model_executor.models.phi3v
import
Phi3VImagePixelInputs
def
test_tensor_schema_valid_tensor
():
def
test_tensor_schema_valid_tensor
():
Phi3VImagePixelInputs
(
Phi3VImagePixelInputs
(
data
=
torch
.
randn
(
16
,
64
,
3
,
32
,
32
),
pixel_values
=
torch
.
randn
(
16
,
64
,
3
,
32
,
32
),
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
)
)
def
test_tensor_schema_optional_fields
():
def
test_tensor_schema_optional_fields
():
Phi3VImagePixelInputs
(
Phi3VImagePixelInputs
(
data
=
torch
.
randn
(
16
,
64
,
3
,
32
,
32
),
pixel_values
=
torch
.
randn
(
16
,
64
,
3
,
32
,
32
),
image_sizes
=
None
,
image_sizes
=
None
,
)
)
Phi3VImagePixelInputs
(
data
=
torch
.
randn
(
16
,
64
,
3
,
32
,
32
)
,
)
Phi3VImagePixelInputs
(
pixel_values
=
torch
.
randn
(
16
,
64
,
3
,
32
,
32
))
def
test_tensor_schema_constant_dim_failure
():
def
test_tensor_schema_constant_dim_failure
():
with
pytest
.
raises
(
ValueError
,
match
=
"dim
\\
[2
\\
] expected 3, got 4"
):
with
pytest
.
raises
(
ValueError
,
match
=
"dim
\\
[2
\\
] expected 3, got 4"
):
Phi3VImagePixelInputs
(
Phi3VImagePixelInputs
(
data
=
torch
.
randn
(
16
,
64
,
4
,
32
,
32
),
# dim[2] = 4
pixel_values
=
torch
.
randn
(
16
,
64
,
4
,
32
,
32
),
# dim[2] = 4
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
)
)
def
test_tensor_schema_invalid_types_in_list
():
def
test_tensor_schema_invalid_types_in_list
():
with
pytest
.
raises
(
Valu
eError
,
match
=
"is not
a torch.Tensor
"
):
with
pytest
.
raises
(
Typ
eError
,
match
=
"is not
one of the expected types
"
):
Phi3VImagePixelInputs
(
Phi3VImagePixelInputs
(
data
=
[
pixel_values
=
[
torch
.
randn
(
64
,
3
,
32
,
32
),
torch
.
randn
(
64
,
3
,
32
,
32
),
"not_a_tensor"
,
"not_a_tensor"
,
torch
.
randn
(
64
,
3
,
32
,
32
),
torch
.
randn
(
64
,
3
,
32
,
32
),
...
@@ -48,27 +50,28 @@ def test_tensor_schema_invalid_types_in_list():
...
@@ -48,27 +50,28 @@ def test_tensor_schema_invalid_types_in_list():
def
test_tensor_schema_rank_mismatch
():
def
test_tensor_schema_rank_mismatch
():
with
pytest
.
raises
(
ValueError
,
match
=
"has rank 3 but expected 5"
):
with
pytest
.
raises
(
ValueError
,
match
=
"has rank 3 but expected 5"
):
Phi3VImagePixelInputs
(
Phi3VImagePixelInputs
(
data
=
torch
.
randn
(
16
,
64
,
3
),
pixel_values
=
torch
.
randn
(
16
,
64
,
3
),
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
)
)
def
test_tensor_schema_missing_required_field
():
def
test_tensor_schema_missing_required_field
():
with
pytest
.
raises
(
ValueError
,
match
=
"Required field 'data' is missing"
):
with
pytest
.
raises
(
ValueError
,
match
=
"Required field 'pixel_values' is missing"
):
Phi3VImagePixelInputs
(
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
)
Phi3VImagePixelInputs
(
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
)
def
test_tensor_schema_symbolic_dim_mismatch
():
def
test_tensor_schema_symbolic_dim_mismatch
():
with
pytest
.
raises
(
ValueError
,
match
=
"expected 'bn'=12, got 16"
):
with
pytest
.
raises
(
ValueError
,
match
=
"expected 'bn'=12, got 16"
):
Phi3VImagePixelInputs
(
Phi3VImagePixelInputs
(
data
=
torch
.
randn
(
12
,
64
,
3
,
32
,
32
),
pixel_values
=
torch
.
randn
(
12
,
64
,
3
,
32
,
32
),
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
)
)
def
test_tensor_schema_list_tensor_valid
():
def
test_tensor_schema_list_tensor_valid
():
Phi3VImagePixelInputs
(
Phi3VImagePixelInputs
(
data
=
[
torch
.
randn
(
64
,
3
,
32
,
32
)
for
_
in
range
(
16
)],
pixel_values
=
[
torch
.
randn
(
64
,
3
,
32
,
32
)
for
_
in
range
(
16
)],
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
)
)
...
@@ -76,39 +79,46 @@ def test_tensor_schema_list_tensor_valid():
...
@@ -76,39 +79,46 @@ def test_tensor_schema_list_tensor_valid():
def
test_tensor_schema_variable_patch_counts_valid
():
def
test_tensor_schema_variable_patch_counts_valid
():
# Each image has a different number of patches (p)
# Each image has a different number of patches (p)
# Each tensor has shape (p, 3, 32, 32)
# Each tensor has shape (p, 3, 32, 32)
data
=
[
Phi3VImagePixelInputs
(
pixel_values
=
[
torch
.
randn
(
16
,
3
,
32
,
32
),
# p = 16
torch
.
randn
(
16
,
3
,
32
,
32
),
# p = 16
torch
.
randn
(
32
,
3
,
32
,
32
),
# p = 32
torch
.
randn
(
32
,
3
,
32
,
32
),
# p = 32
torch
.
randn
(
64
,
3
,
32
,
32
),
# p = 64
torch
.
randn
(
64
,
3
,
32
,
32
),
# p = 64
]
],
image_sizes
=
torch
.
randint
(
0
,
256
,
(
3
,
2
))
# bn = 3
image_sizes
=
torch
.
randint
(
0
,
256
,
(
3
,
2
)),
# bn = 3
Phi3VImagePixelInputs
(
data
=
data
,
image_sizes
=
image_sizes
,
)
)
def
test_tensor_schema_tuple_tensor_valid
():
def
test_tensor_schema_tuple_tensor_valid
():
Phi3VImagePixelInputs
(
Phi3VImagePixelInputs
(
data
=
tuple
(
torch
.
randn
(
64
,
3
,
32
,
32
)
for
_
in
range
(
16
)),
pixel_values
=
tuple
(
torch
.
randn
(
64
,
3
,
32
,
32
)
for
_
in
range
(
16
)),
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
)
)
def
test_tensor_schema_double_nested_tensors
():
x
=
torch
.
rand
(
4
,
3
,
32
,
32
)
y
=
torch
.
rand
(
2
,
3
,
32
,
32
)
HCXVisionVideoPixelInputs
(
pixel_values_videos
=
([
x
,
y
,
x
],
[
y
],
[
x
,
y
]))
def
test_tensor_schema_inconsistent_shapes_in_list
():
def
test_tensor_schema_inconsistent_shapes_in_list
():
with
pytest
.
raises
(
ValueError
,
match
=
"contains inconsistent shapes"
):
with
pytest
.
raises
(
ValueError
,
match
=
"contains inconsistent shapes"
):
Phi3VImagePixelInputs
(
Phi3VImagePixelInputs
(
data
=
[
torch
.
randn
(
64
,
3
,
32
,
32
),
pixel_values
=
[
torch
.
randn
(
64
,
3
,
16
,
16
)]
+
torch
.
randn
(
64
,
3
,
32
,
32
),
[
torch
.
randn
(
64
,
3
,
32
,
32
)
for
_
in
range
(
14
)],
torch
.
randn
(
64
,
3
,
16
,
16
),
*
(
torch
.
randn
(
64
,
3
,
32
,
32
)
for
_
in
range
(
14
)),
],
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
)
)
def
test_tensor_schema_empty_list
():
def
test_tensor_schema_empty_list
():
with
pytest
.
raises
(
ValueError
,
match
=
"is an empty
list
"
):
with
pytest
.
raises
(
ValueError
,
match
=
"is an empty
sequence
"
):
Phi3VImagePixelInputs
(
Phi3VImagePixelInputs
(
data
=
[],
pixel_values
=
[],
image_sizes
=
torch
.
randint
(
0
,
256
,
(
0
,
2
)),
image_sizes
=
torch
.
randint
(
0
,
256
,
(
0
,
2
)),
)
)
...
@@ -117,18 +127,18 @@ def test_tensor_schema_validation_disabled_skips_shape_check():
...
@@ -117,18 +127,18 @@ def test_tensor_schema_validation_disabled_skips_shape_check():
# This should NOT raise, because validation is turned off
# This should NOT raise, because validation is turned off
# This would normally fail (dim[2] should be 3, not 4)
# This would normally fail (dim[2] should be 3, not 4)
Phi3VImagePixelInputs
(
Phi3VImagePixelInputs
(
data
=
torch
.
randn
(
16
,
64
,
4
,
32
,
32
),
pixel_values
=
torch
.
randn
(
16
,
64
,
4
,
32
,
32
),
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
)),
validate
=
False
,
validate
=
False
,
)
)
def
test_tensor_schema_with_valid_resolve_binding_dims
():
def
test_tensor_schema_with_valid_resolve_binding_dims
():
data
=
torch
.
randn
(
16
,
64
,
3
,
336
,
336
)
# h=336, w=336
pixel_values
=
torch
.
randn
(
16
,
64
,
3
,
336
,
336
)
# h=336, w=336
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
))
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
))
Phi3VImagePixelInputs
(
Phi3VImagePixelInputs
(
data
=
data
,
pixel_values
=
pixel_values
,
image_sizes
=
image_sizes
,
image_sizes
=
image_sizes
,
resolve_bindings
=
{
resolve_bindings
=
{
"h"
:
336
,
"h"
:
336
,
...
@@ -138,13 +148,13 @@ def test_tensor_schema_with_valid_resolve_binding_dims():
...
@@ -138,13 +148,13 @@ def test_tensor_schema_with_valid_resolve_binding_dims():
def
test_tensor_schema_with_invalid_resolve_binding_dims
():
def
test_tensor_schema_with_invalid_resolve_binding_dims
():
data
=
torch
.
randn
(
16
,
64
,
3
,
36
,
36
)
# h=36, w=36
pixel_values
=
torch
.
randn
(
16
,
64
,
3
,
36
,
36
)
# h=36, w=36
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
))
image_sizes
=
torch
.
randint
(
0
,
256
,
(
16
,
2
))
# Should raise because 'h' and 'w' don't match resolve bindings
# Should raise because 'h' and 'w' don't match resolve bindings
with
pytest
.
raises
(
ValueError
,
match
=
"dim
\\
[3
\\
] expected 336, got 36"
):
with
pytest
.
raises
(
ValueError
,
match
=
"dim
\\
[3
\\
] expected 336, got 36"
):
Phi3VImagePixelInputs
(
Phi3VImagePixelInputs
(
data
=
data
,
pixel_values
=
pixel_values
,
image_sizes
=
image_sizes
,
image_sizes
=
image_sizes
,
resolve_bindings
=
{
resolve_bindings
=
{
"h"
:
336
,
"h"
:
336
,
...
...
vllm/model_executor/models/glm4_1v.py
View file @
44ea8513
...
@@ -29,7 +29,7 @@
...
@@ -29,7 +29,7 @@
import
math
import
math
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
partial
from
functools
import
partial
from
typing
import
Annotated
,
Any
,
Callable
,
Literal
,
Optional
,
Union
,
override
from
typing
import
Annotated
,
Any
,
Callable
,
Literal
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -1170,7 +1170,7 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
...
@@ -1170,7 +1170,7 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
"video.height override (%d) exceeds model's "
"video.height override (%d) exceeds model's "
"maximum height (%d), will be ignored"
,
"maximum height (%d), will be ignored"
,
overrides
.
height
,
height
)
overrides
.
height
,
height
)
height
=
min
(
height
,
override
.
height
)
height
=
min
(
height
,
override
s
.
height
)
video
=
np
.
full
((
num_frames
,
width
,
height
,
3
),
255
,
dtype
=
np
.
uint8
)
video
=
np
.
full
((
num_frames
,
width
,
height
,
3
),
255
,
dtype
=
np
.
uint8
)
video_items
=
[]
video_items
=
[]
...
...
vllm/model_executor/models/hyperclovax_vision.py
View file @
44ea8513
...
@@ -2,27 +2,16 @@
...
@@ -2,27 +2,16 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# copied from : https://github.com/huggingface/transformers
# copied from : https://github.com/huggingface/transformers
import
ast
import
ast
import
sys
from
collections
import
defaultdict
from
collections
import
defaultdict
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
collections.abc
import
Iterable
,
Mapping
,
Sequence
from
functools
import
partial
from
functools
import
partial
from
itertools
import
chain
from
itertools
import
accumulate
from
typing
import
Any
,
Literal
,
Optional
,
TypedDict
,
Union
from
typing
import
Annotated
,
Any
,
Literal
,
Optional
,
Union
import
numpy
as
np
import
numpy
as
np
import
PIL
from
einops
import
rearrange
from
PIL
import
Image
if
sys
.
version_info
>=
(
3
,
11
):
import
typing
Unpack
=
typing
.
Unpack
else
:
import
typing_extensions
Unpack
=
typing_extensions
.
Unpack
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
einops
import
rearrange
from
timm.layers
import
LayerNorm
,
LayerNorm2d
from
timm.layers
import
LayerNorm
,
LayerNorm2d
from
timm.models.regnet
import
RegStage
from
timm.models.regnet
import
RegStage
from
transformers
import
BatchFeature
,
CLIPVisionConfig
,
SiglipVisionConfig
from
transformers
import
BatchFeature
,
CLIPVisionConfig
,
SiglipVisionConfig
...
@@ -42,11 +31,13 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
...
@@ -42,11 +31,13 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
PromptReplacement
,
PromptUpdate
)
PromptReplacement
,
PromptUpdate
)
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.multimodal.profiling
import
BaseDummyInputsBuilder
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils.tensor_schema
import
TensorSchema
,
TensorShape
from
.clip
import
CLIPVisionModel
from
.clip
import
CLIPVisionModel
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.interfaces
import
MultiModalEmbeddings
,
SupportsMultiModal
,
SupportsPP
from
.siglip
import
SiglipVisionModel
from
.siglip
import
SiglipVisionModel
from
.utils
import
AutoWeightsLoader
,
init_vllm_registered_model
,
maybe_prefix
from
.utils
import
(
AutoWeightsLoader
,
flatten_bn
,
init_vllm_registered_model
,
maybe_prefix
)
from
.vision
import
get_vision_encoder_info
from
.vision
import
get_vision_encoder_info
EOT
=
"<|endofturn|>"
EOT
=
"<|endofturn|>"
...
@@ -69,28 +60,42 @@ def get_num_combined_frames(
...
@@ -69,28 +60,42 @@ def get_num_combined_frames(
return
num_canvases
+
(
leftover_frames
>
0
)
return
num_canvases
+
(
leftover_frames
>
0
)
class
HCXVisionMultimodalPixelInputs
(
TypedDict
):
class
HCXVisionImagePixelInputs
(
TensorSchema
):
type
:
Literal
[
"pixel_values"
]
pixel_values_images
:
list
[
torch
.
Tensor
]
"""
"""
Shape: `[(num_grids, num_channels, height, width), ...]` if anyres
Dimensions:
- n: Number of images
Note that `height` or `width` may be different per batch and image,
- g: Number of grids
in which case the data is passed as a list instead of a batched tensor.
- c: Number of channels (3)
- h: Height
- w: Width
"""
"""
image_sizes_images
:
list
[
tuple
[
Union
[
int
,
float
]]]
type
:
Literal
[
"pixel_values"
]
=
"pixel_values"
"""
pixel_values_images
:
Annotated
[
Shape: `[(height, width), ...]`
list
[
torch
.
Tensor
],
"""
TensorShape
(
"n"
,
"g"
,
3
,
"h"
,
"w"
,
dynamic_dims
=
{
"g"
})]
vision_query_lengths_images
:
list
[
Union
[
int
,
float
]]
image_sizes_images
:
Annotated
[
torch
.
Tensor
,
TensorShape
(
"n"
,
2
)]
pixel_values_videos
:
list
[
tuple
[
Union
[
int
,
float
]]]
HCXVisionImageInputs
=
HCXVisionImagePixelInputs
class
HCXVisionVideoPixelInputs
(
TensorSchema
):
"""
"""
Shape: `[(num_grids, num_channels, height, width), ...]` if anyres
Dimensions:
- n: Number of videos
- f: Number of frames
- g: Number of grids
- c: Number of channels (3)
- h: Height
- w: Width
"""
"""
vision_query_lengths_videos
:
list
[
Union
[
int
,
float
]]
type
:
Literal
[
"pixel_values_videos"
]
=
"pixel_values_videos"
pixel_values_videos
:
Annotated
[
list
[
list
[
torch
.
Tensor
]],
TensorShape
(
"n"
,
"f"
,
"g"
,
3
,
"h"
,
"w"
,
dynamic_dims
=
{
"f"
,
"g"
})]
HCXVision
Multimodal
Inputs
=
Union
[
HCXVision
Multimodal
PixelInputs
]
HCXVision
Video
Inputs
=
HCXVision
Video
PixelInputs
class
HCXVisionProcessingInfo
(
BaseProcessingInfo
):
class
HCXVisionProcessingInfo
(
BaseProcessingInfo
):
...
@@ -191,26 +196,8 @@ class HCXVisionMultiModalProcessor(
...
@@ -191,26 +196,8 @@ class HCXVisionMultiModalProcessor(
mm_kwargs
:
Mapping
[
str
,
object
],
mm_kwargs
:
Mapping
[
str
,
object
],
tok_kwargs
:
Mapping
[
str
,
object
],
tok_kwargs
:
Mapping
[
str
,
object
],
)
->
BatchFeature
:
)
->
BatchFeature
:
def
replace_multimodal_token
(
token_ids
:
torch
.
Tensor
,
target_token
:
int
,
repeats
:
list
[
int
],
):
output
=
list
[
int
]()
_repeats_idx
=
0
for
token_id
in
token_ids
:
if
token_id
==
target_token
:
output
+=
[
token_id
.
item
()]
*
repeats
[
_repeats_idx
]
_repeats_idx
+=
1
else
:
output
+=
[
token_id
.
item
()]
return
torch
.
tensor
(
output
,
device
=
token_ids
.
device
)
for
video_idx
,
video_arr
in
enumerate
(
mm_data
.
get
(
"videos"
,
[])):
for
video_idx
,
video_arr
in
enumerate
(
mm_data
.
get
(
"videos"
,
[])):
if
video_arr
.
dtype
==
np
.
uint8
:
if
video_arr
.
dtype
!=
np
.
uint8
:
continue
mm_data
[
"videos"
][
video_idx
]
=
video_arr
.
astype
(
np
.
uint8
)
mm_data
[
"videos"
][
video_idx
]
=
video_arr
.
astype
(
np
.
uint8
)
processed_outputs
=
self
.
info
.
ctx
.
call_hf_processor
(
processed_outputs
=
self
.
info
.
ctx
.
call_hf_processor
(
...
@@ -223,20 +210,16 @@ class HCXVisionMultiModalProcessor(
...
@@ -223,20 +210,16 @@ class HCXVisionMultiModalProcessor(
)
# text-only
)
# text-only
if
len
(
mm_data
)
>
0
:
if
len
(
mm_data
)
>
0
:
# batchify input as a single item
images
=
mm_data
.
get
(
"images"
)
images
=
mm_data
.
get
(
"images"
,
None
)
videos
=
mm_data
.
get
(
"videos"
)
batched_images
=
None
if
images
is
None
else
[
images
]
# list of video in single conversation
videos
=
mm_data
.
get
(
"videos"
,
None
)
batched_videos
=
None
if
videos
is
None
else
[
videos
]
# batchify input as a single item
_processed_outputs
=
self
.
info
.
ctx
.
call_hf_processor
(
_processed_outputs
=
self
.
info
.
ctx
.
call_hf_processor
(
hf_processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
),
hf_processor
=
self
.
info
.
get_hf_processor
(
**
mm_kwargs
),
data
=
dict
(
data
=
dict
(
text
=
None
,
text
=
None
,
images
=
batched_
images
,
images
=
None
if
images
is
None
else
[
images
]
,
videos
=
batched_
videos
,
videos
=
None
if
videos
is
None
else
[
videos
]
,
),
),
)
# mm-only
)
# mm-only
...
@@ -246,51 +229,43 @@ class HCXVisionMultiModalProcessor(
...
@@ -246,51 +229,43 @@ class HCXVisionMultiModalProcessor(
_processed_outputs
[
k
]
=
v
[
0
]
_processed_outputs
[
k
]
=
v
[
0
]
if
images
:
if
images
:
tokenizer
=
self
.
info
.
get_tokenizer
()
_processed_outputs
[
"image_sizes_images"
]
=
torch
.
tensor
(
image_token_id
=
tokenizer
.
convert_tokens_to_ids
(
IMAGE_TOKEN
)
_processed_outputs
[
"image_sizes_images"
])
processed_outputs
[
"input_ids"
]
=
torch
.
stack
([
_processed_outputs
[
replace_multimodal_token
(
"vision_query_lengths_images"
]
=
torch
.
tensor
(
token_ids
=
_input_ids
,
_processed_outputs
[
"vision_query_lengths_images"
])
target_token
=
image_token_id
,
repeats
=
_processed_outputs
[
"vision_query_lengths_images"
],
)
for
_input_ids
in
processed_outputs
[
"input_ids"
]
],
dim
=
0
)
if
videos
:
if
videos
:
_num_per_videos
=
[
_idx_per_video
=
[
get_num_combined_frames
(
len
(
video
))
for
video
in
videos
0
,
*
accumulate
(
get_num_combined_frames
(
len
(
video
))
for
video
in
videos
)
]
]
_processed_outputs
[
"pixel_values_videos"
]
=
[
_processed_outputs
[
"pixel_values_videos"
]
=
[
_processed_outputs
[
"pixel_values_videos"
]
_processed_outputs
[
"pixel_values_videos"
]
[
sum
(
_num
_per_video
s
[:
_i
]):
sum
(
_num
_per_video
s
[:
_
i
+
1
]
)
]
[
_idx
_per_video
[
i
]:
_idx
_per_video
[
i
+
1
]]
for
_
i
in
range
(
len
(
videos
))
for
i
in
range
(
len
(
videos
))
]
]
_processed_outputs
[
"vision_query_lengths_videos"
]
=
[
_processed_outputs
[
"vision_query_lengths_videos"
]
=
[
torch
.
tensor
(
_processed_outputs
[
"vision_query_lengths_videos"
]
_processed_outputs
[
"vision_query_lengths_videos"
]
[
sum
(
_num
_per_video
s
[:
_i
]):
sum
(
_num
_per_video
s
[:
_
i
+
1
]
)
]
[
_idx
_per_video
[
i
]:
_idx
_per_video
[
i
+
1
]]
)
for
_
i
in
range
(
len
(
videos
))
for
i
in
range
(
len
(
videos
))
]
]
tokenizer
=
self
.
info
.
get_tokenizer
()
video_token_id
=
tokenizer
.
convert_tokens_to_ids
(
VIDEO_TOKEN
)
processed_outputs
[
"input_ids"
]
=
torch
.
stack
([
replace_multimodal_token
(
token_ids
=
_input_ids
,
target_token
=
video_token_id
,
repeats
=
[
sum
(
lens
)
for
lens
in
_processed_outputs
[
"vision_query_lengths_videos"
]
],
)
for
_input_ids
in
processed_outputs
[
"input_ids"
]
],
dim
=
0
)
processed_outputs
.
update
(
_processed_outputs
)
processed_outputs
.
update
(
_processed_outputs
)
return
processed_outputs
return
processed_outputs
def
_hf_processor_applies_updates
(
self
,
prompt_text
:
str
,
mm_items
:
MultiModalDataItems
,
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
tokenization_kwargs
:
Mapping
[
str
,
object
],
)
->
bool
:
return
False
def
_get_prompt_updates
(
def
_get_prompt_updates
(
self
,
self
,
mm_items
:
MultiModalDataItems
,
mm_items
:
MultiModalDataItems
,
...
@@ -311,11 +286,11 @@ class HCXVisionMultiModalProcessor(
...
@@ -311,11 +286,11 @@ class HCXVisionMultiModalProcessor(
out_item
=
out_mm_kwargs
[
modality
][
item_idx
]
out_item
=
out_mm_kwargs
[
modality
][
item_idx
]
if
modality
==
"image"
:
if
modality
==
"image"
:
lens
=
out_item
[
"vision_query_lengths_images"
].
data
lens
=
out_item
[
"vision_query_lengths_images"
].
data
.
tolist
()
num_tokens
=
self
.
info
.
get_num_image_tokens
(
num_tokens
=
self
.
info
.
get_num_image_tokens
(
vision_query_length
=
lens
)
vision_query_length
=
lens
)
elif
modality
==
"video"
:
elif
modality
==
"video"
:
lens
=
out_item
[
"vision_query_lengths_videos"
].
data
lens
=
out_item
[
"vision_query_lengths_videos"
].
data
.
tolist
()
num_tokens
=
self
.
info
.
get_num_video_tokens
(
num_tokens
=
self
.
info
.
get_num_video_tokens
(
vision_query_length
=
lens
)
vision_query_length
=
lens
)
else
:
else
:
...
@@ -343,26 +318,11 @@ class HCXVisionMultiModalProcessor(
...
@@ -343,26 +318,11 @@ class HCXVisionMultiModalProcessor(
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
hf_processor_mm_kwargs
:
Mapping
[
str
,
object
],
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
)
->
Mapping
[
str
,
MultiModalFieldConfig
]:
return
dict
(
return
dict
(
# image
pixel_values_images
=
MultiModalFieldConfig
.
batched
(
"image"
),
pixel_values_images
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_sizes_images
=
MultiModalFieldConfig
.
batched
(
"image"
),
image_sizes_images
=
MultiModalFieldConfig
.
batched
(
"image"
),
vision_query_lengths_images
=
MultiModalFieldConfig
.
batched
(
"image"
),
vision_query_lengths_images
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_queries_vis_abstractors_images
=
MultiModalFieldConfig
.
batched
(
"image"
),
num_queries_vis_abstractors_slow_images
=
MultiModalFieldConfig
.
batched
(
"image"
),
first_last_frames_slows_images
=
MultiModalFieldConfig
.
batched
(
"image"
),
# video
pixel_values_videos
=
MultiModalFieldConfig
.
batched
(
"video"
),
pixel_values_videos
=
MultiModalFieldConfig
.
batched
(
"video"
),
image_sizes_videos
=
MultiModalFieldConfig
.
batched
(
"video"
),
vision_query_lengths_videos
=
MultiModalFieldConfig
.
batched
(
"video"
),
vision_query_lengths_videos
=
MultiModalFieldConfig
.
batched
(
"video"
),
num_queries_vis_abstractors_videos
=
MultiModalFieldConfig
.
batched
(
"video"
),
num_queries_vis_abstractors_slow_videos
=
MultiModalFieldConfig
.
batched
(
"video"
),
first_last_frames_slows_videos
=
MultiModalFieldConfig
.
batched
(
"video"
),
)
)
...
@@ -617,6 +577,7 @@ class HCXVisionCAbstractor(nn.Module):
...
@@ -617,6 +577,7 @@ class HCXVisionCAbstractor(nn.Module):
info
=
_build_hcxvision_hf_info
,
info
=
_build_hcxvision_hf_info
,
dummy_inputs
=
HCXVisionDummyInputsBuilder
)
dummy_inputs
=
HCXVisionDummyInputsBuilder
)
class
HCXVisionForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
class
HCXVisionForCausalLM
(
nn
.
Module
,
SupportsMultiModal
,
SupportsPP
):
merge_by_field_config
=
True
packed_modules_mapping
=
{
packed_modules_mapping
=
{
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
"qkv_proj"
:
[
"q_proj"
,
"k_proj"
,
"v_proj"
],
...
@@ -692,55 +653,94 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -692,55 +653,94 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
raise
ValueError
(
"Only image or video modality is supported"
)
raise
ValueError
(
"Only image or video modality is supported"
)
def
_parse_and_validate_image_input
(
self
,
**
kwargs
:
object
,
)
->
Optional
[
HCXVisionImageInputs
]:
pixel_values_images
=
kwargs
.
pop
(
"pixel_values_images"
,
None
)
if
pixel_values_images
is
None
:
return
None
image_sizes_images
=
kwargs
.
pop
(
"image_sizes_images"
)
return
HCXVisionImagePixelInputs
(
pixel_values_images
=
pixel_values_images
,
image_sizes_images
=
image_sizes_images
,
)
def
_parse_and_validate_video_input
(
self
,
**
kwargs
:
object
,
)
->
Optional
[
HCXVisionVideoInputs
]:
pixel_values_videos
=
kwargs
.
pop
(
"pixel_values_videos"
,
None
)
if
pixel_values_videos
is
None
:
return
None
return
HCXVisionVideoPixelInputs
(
pixel_values_videos
=
pixel_values_videos
,
)
def
_process_image_input
(
self
,
image_input
:
HCXVisionImageInputs
,
)
->
tuple
[
torch
.
Tensor
,
...]:
return
self
.
forward_images
(
pixel_values_images
=
image_input
[
"pixel_values_images"
],
image_sizes_images
=
image_input
[
"image_sizes_images"
],
)
def
_process_video_input
(
self
,
video_input
:
HCXVisionVideoInputs
,
)
->
tuple
[
torch
.
Tensor
,
...]:
return
self
.
forward_videos
(
pixel_values_videos
=
video_input
[
"pixel_values_videos"
],
)
def
_parse_and_validate_multimodal_inputs
(
self
,
**
kwargs
:
object
)
->
dict
:
modalities
=
{}
# Preserve the order of modalities if there are multiple of them
# from the order of kwargs.
for
input_key
in
kwargs
:
if
(
input_key
==
"pixel_values_images"
and
"images"
not
in
modalities
):
modalities
[
"images"
]
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
(
input_key
==
"pixel_values_videos"
and
"videos"
not
in
modalities
):
modalities
[
"videos"
]
=
self
.
_parse_and_validate_video_input
(
**
kwargs
)
return
modalities
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
def
get_language_model
(
self
)
->
torch
.
nn
.
Module
:
return
self
.
language_model
return
self
.
language_model
def
get_multimodal_embeddings
(
def
get_multimodal_embeddings
(
self
,
self
,
**
kwargs
:
Unpack
[
HCXVisionMultimodalInputs
]
,
**
kwargs
:
object
,
)
->
MultiModalEmbeddings
:
)
->
MultiModalEmbeddings
:
modalities
=
self
.
_parse_and_validate_multimodal_inputs
(
**
kwargs
)
if
not
modalities
:
return
[]
# The result multimodal_embeddings is tuple of tensors, with each
# tensor correspoending to a multimodal data item (image or video).
multimodal_embeddings
:
tuple
[
torch
.
Tensor
,
...]
=
()
# NOTE: It is important to iterate over the keys in this dictionary
# to preserve the order of the modalities.
for
modality
in
modalities
:
if
modality
==
"images"
:
image_input
=
modalities
[
"images"
]
vision_embeddings
=
self
.
_process_image_input
(
image_input
)
multimodal_embeddings
+=
vision_embeddings
if
modality
==
"videos"
:
video_input
=
modalities
[
"videos"
]
video_embeddings
=
self
.
_process_video_input
(
video_input
)
multimodal_embeddings
+=
video_embeddings
multimodal_embeddings
=
list
()
if
kwargs
.
get
(
"pixel_values_images"
)
is
not
None
:
for
_pixel_values_images
,
_image_sizes_images
in
zip
(
kwargs
[
"pixel_values_images"
],
kwargs
[
"image_sizes_images"
]):
_pixel_values_images
=
_pixel_values_images
.
unsqueeze
(
dim
=
0
)
_image_sizes_images
=
_image_sizes_images
.
unsqueeze
(
dim
=
0
)
_len_pixel_values_images
=
[
len
(
pixel_value
)
for
pixel_value
in
_pixel_values_images
]
if
isinstance
(
_image_sizes_images
,
torch
.
Tensor
):
_image_sizes_images
=
_image_sizes_images
.
detach
().
cpu
(
).
tolist
()
_multimodal_embeddings_images
=
self
.
forward_images
(
pixel_values_images
=
_pixel_values_images
,
image_sizes_images
=
_image_sizes_images
,
len_pixel_values_images
=
_len_pixel_values_images
,
)
_multimodal_embeddings_images
=
torch
.
cat
(
_multimodal_embeddings_images
,
dim
=
0
)
multimodal_embeddings
.
append
(
_multimodal_embeddings_images
)
if
kwargs
.
get
(
"pixel_values_videos"
)
is
not
None
:
for
_pixel_values_videos
,
_vision_query_lengths_videos
in
zip
(
kwargs
[
"pixel_values_videos"
],
kwargs
[
"vision_query_lengths_videos"
]):
_len_pixel_values_videos
=
[
len
(
_vision_query_lengths
)
for
_vision_query_lengths
in
_vision_query_lengths_videos
]
_c
,
_w
,
_h
=
_pixel_values_videos
.
shape
[
-
3
:]
_pixel_values_videos
=
_pixel_values_videos
.
reshape
(
sum
(
_len_pixel_values_videos
),
-
1
,
_c
,
_w
,
_h
).
unsqueeze
(
dim
=
0
)
_multimodal_embeddings_videos
=
self
.
forward_videos
(
pixel_values_videos
=
_pixel_values_videos
,
len_pixel_values_videos
=
_len_pixel_values_videos
,
)
_multimodal_embeddings_videos
=
torch
.
cat
(
_multimodal_embeddings_videos
,
dim
=
0
)
multimodal_embeddings
.
append
(
_multimodal_embeddings_videos
)
return
multimodal_embeddings
return
multimodal_embeddings
def
forward
(
def
forward
(
...
@@ -762,28 +762,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -762,28 +762,20 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def
forward_images
(
def
forward_images
(
self
,
self
,
pixel_values_images
:
list
[
list
[
torch
.
FloatTensor
]],
pixel_values_images
:
list
[
torch
.
Tensor
],
image_sizes_images
:
list
[
list
[
tuple
[
int
,
int
]]],
image_sizes_images
:
torch
.
Tensor
,
len_pixel_values_images
:
list
[
int
],
)
->
tuple
[
torch
.
Tensor
,
...]:
)
->
list
[
list
[
torch
.
Tensor
]]:
pixel_values_image_flat
=
flatten_bn
(
pixel_values_images
,
concat
=
True
)
if
sum
(
len_pixel_values_images
)
==
0
:
return
None
concat_pixel_values_images
=
torch
.
cat
(
list
(
chain
(
*
pixel_values_images
)),
dim
=
0
)
visual_token_idx
=
0
if
"siglip"
in
self
.
vision_config
.
model_type
else
1
visual_token_idx
=
0
if
"siglip"
in
self
.
vision_config
.
model_type
else
1
image_forward_outs
=
self
.
vision_model
(
image_forward_outs
=
self
.
vision_model
(
concat_
pixel_values_image
s
)[:,
visual_token_idx
:]
pixel_values_image
_flat
)[:,
visual_token_idx
:]
image_forward_outs
=
image_forward_outs
.
to
(
image_forward_outs
=
image_forward_outs
.
to
(
dtype
=
self
.
mm_projector
.
dtype
)
dtype
=
self
.
mm_projector
.
dtype
)
image_forward_outs
=
self
.
mm_projector
(
image_forward_outs
)
# b (h w) d
image_forward_outs
=
self
.
mm_projector
(
image_forward_outs
)
# b (h w) d
split_sizes
=
[
split_sizes
=
[
len
(
item
)
for
item
in
pixel_values_images
]
pixel_value
.
shape
[
0
]
for
pixel_value
in
chain
(
*
pixel_values_images
)
]
image_forward_outs
=
torch
.
split
(
image_forward_outs
,
image_forward_outs
=
torch
.
split
(
image_forward_outs
,
split_sizes
,
split_sizes
,
dim
=
0
)
dim
=
0
)
...
@@ -791,10 +783,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -791,10 +783,7 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
# newline for anyres postprocessing
# newline for anyres postprocessing
image_features
=
anyres_postprocessing
(
image_features
=
anyres_postprocessing
(
image_forward_outs
=
image_forward_outs
,
image_forward_outs
=
image_forward_outs
,
image_sizes
=
[
image_sizes
=
image_sizes_images
.
tolist
(),
image_size
for
image_sizes
in
image_sizes_images
for
image_size
in
image_sizes
],
num_queries_vis_abstractor
=
self
.
config
.
num_queries_vis_abstractor
=
self
.
config
.
num_queries_vis_abstractor_image
,
num_queries_vis_abstractor_image
,
unpad
=
self
.
config
.
unpad
,
unpad
=
self
.
config
.
unpad
,
...
@@ -803,26 +792,21 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -803,26 +792,21 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
image_newline
=
self
.
image_newline
,
image_newline
=
self
.
image_newline
,
possible_resolutions
=
self
.
config
.
possible_resolutions
,
possible_resolutions
=
self
.
config
.
possible_resolutions
,
)
)
return
image_features
return
tuple
(
image_features
)
def
forward_videos
(
def
forward_videos
(
self
,
self
,
pixel_values_videos
:
list
[
list
[
torch
.
FloatTensor
]],
pixel_values_videos
:
list
[
list
[
torch
.
Tensor
]],
len_pixel_values_videos
:
list
[
int
],
)
->
tuple
[
torch
.
Tensor
,
...]:
)
->
list
[
torch
.
Tensor
]:
pixel_values_videos_flat
=
flatten_bn
(
[
frame
for
frames
in
pixel_values_videos
for
frame
in
frames
],
len_video_grids
=
sum
(
len_pixel_values_videos
)
concat
=
True
,
if
len_video_grids
==
0
:
)
return
None
# Run Vision Model
concat_pixel_values_videos
=
torch
.
cat
(
list
(
chain
(
*
pixel_values_videos
)),
dim
=
0
)
visual_token_idx
=
0
if
"siglip"
in
self
.
vision_config
.
model_type
else
1
visual_token_idx
=
0
if
"siglip"
in
self
.
vision_config
.
model_type
else
1
video_forward_outs
=
self
.
vision_model
(
video_forward_outs
=
self
.
vision_model
(
concat_
pixel_values_videos
)[:,
visual_token_idx
:]
pixel_values_videos
_flat
)[:,
visual_token_idx
:]
video_forward_outs
=
video_forward_outs
.
to
(
video_forward_outs
=
video_forward_outs
.
to
(
dtype
=
self
.
mm_projector
.
dtype
)
dtype
=
self
.
mm_projector
.
dtype
)
...
@@ -905,7 +889,11 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -905,7 +889,11 @@ class HCXVisionForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
)
==
0
,
f
"target_features is not empty!!
{
target_features
}
"
)
==
0
,
f
"target_features is not empty!!
{
target_features
}
"
assert
len
(
video_groups
)
==
len
(
video_features
)
assert
len
(
video_groups
)
==
len
(
video_features
)
return
video_features
feats_per_video
=
[
len
(
video
)
for
video
in
pixel_values_videos
]
idxs_per_video
=
[
0
,
*
accumulate
(
feats_per_video
)]
return
tuple
(
torch
.
cat
(
video_features
[
idxs_per_video
[
i
]:
idxs_per_video
[
i
+
1
]])
for
i
in
range
(
len
(
feats_per_video
)))
def
_prepare_multimodal_kwargs
(
self
,
**
kwargs
:
object
):
def
_prepare_multimodal_kwargs
(
self
,
**
kwargs
:
object
):
output
=
defaultdict
(
list
)
output
=
defaultdict
(
list
)
...
@@ -1111,15 +1099,15 @@ def reshape_and_unpad_image_features(
...
@@ -1111,15 +1099,15 @@ def reshape_and_unpad_image_features(
def
anyres_postprocessing
(
def
anyres_postprocessing
(
image_forward_outs
:
list
[
torch
.
Float
Tensor
],
image_forward_outs
:
list
[
torch
.
Tensor
],
image_sizes
:
list
[
list
[
int
]],
image_sizes
:
list
[
list
[
int
]],
possible_resolutions
:
list
[
tuple
[
int
,
int
]],
possible_resolutions
:
list
[
tuple
[
int
,
int
]],
patch_size
:
int
,
patch_size
:
int
,
grid_size
:
int
,
grid_size
:
int
,
image_newline
:
torch
.
Float
Tensor
,
image_newline
:
torch
.
Tensor
,
num_queries_vis_abstractor
:
int
=
-
1
,
num_queries_vis_abstractor
:
int
=
-
1
,
unpad
:
bool
=
False
,
unpad
:
bool
=
False
,
)
->
list
[
torch
.
Float
Tensor
]:
)
->
list
[
torch
.
Tensor
]:
height
=
width
=
grid_size
//
patch_size
height
=
width
=
grid_size
//
patch_size
if
num_queries_vis_abstractor
>
0
:
if
num_queries_vis_abstractor
>
0
:
...
@@ -1147,26 +1135,5 @@ def anyres_postprocessing(
...
@@ -1147,26 +1135,5 @@ def anyres_postprocessing(
(
image_feature
,
image_newline
[
None
].
to
(
image_feature
.
device
)),
(
image_feature
,
image_newline
[
None
].
to
(
image_feature
.
device
)),
dim
=
0
)
dim
=
0
)
new_image_features
.
append
(
image_feature
)
new_image_features
.
append
(
image_feature
)
image_features
=
new_image_features
return
image_features
return
new_image_features
def
resize_image
(
image
:
Union
[
np
.
ndarray
,
PIL
.
Image
.
Image
],
max_side
:
int
=
378
,
)
->
np
.
ndarray
:
image_arr
=
image
if
isinstance
(
image
,
np
.
ndarray
):
image
=
Image
.
fromarray
(
image
)
width
,
height
=
image
.
size
cur_max_size
=
max
(
width
,
height
)
if
cur_max_size
<=
max_side
:
return
image_arr
scale
=
max_side
/
cur_max_size
width
=
int
(
width
*
scale
)
height
=
int
(
height
*
scale
)
image
=
image
.
resize
((
width
,
height
),
Image
.
LANCZOS
)
image_arr
=
np
.
array
(
image
)
return
image_arr
vllm/model_executor/models/phi3v.py
View file @
44ea8513
...
@@ -109,7 +109,7 @@ class Phi3VImagePixelInputs(TensorSchema):
...
@@ -109,7 +109,7 @@ class Phi3VImagePixelInputs(TensorSchema):
type
:
Literal
[
"pixel_values"
,
"image_embeds"
]
=
"pixel_values"
type
:
Literal
[
"pixel_values"
,
"image_embeds"
]
=
"pixel_values"
# Supports either a stacked tensor or a list of (p, 3, h, w) tensors
# Supports either a stacked tensor or a list of (p, 3, h, w) tensors
data
:
Annotated
[
pixel_values
:
Annotated
[
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
Union
[
torch
.
Tensor
,
list
[
torch
.
Tensor
]],
TensorShape
(
"bn"
,
"p"
,
3
,
"h"
,
"w"
,
dynamic_dims
=
{
"p"
}
TensorShape
(
"bn"
,
"p"
,
3
,
"h"
,
"w"
,
dynamic_dims
=
{
"p"
}
),
# 'p' may vary across items
),
# 'p' may vary across items
...
@@ -594,7 +594,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -594,7 +594,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
if
pixel_values
is
not
None
:
if
pixel_values
is
not
None
:
return
Phi3VImagePixelInputs
(
return
Phi3VImagePixelInputs
(
type
=
"pixel_values"
,
type
=
"pixel_values"
,
data
=
flatten_bn
(
pixel_values
),
pixel_values
=
flatten_bn
(
pixel_values
),
image_sizes
=
flatten_bn
(
image_sizes
,
concat
=
True
),
image_sizes
=
flatten_bn
(
image_sizes
,
concat
=
True
),
resolve_bindings
=
{
resolve_bindings
=
{
"h"
:
CLIP_VIT_LARGE_PATCH14_336_CONFIG
.
image_size
,
"h"
:
CLIP_VIT_LARGE_PATCH14_336_CONFIG
.
image_size
,
...
@@ -628,7 +628,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
...
@@ -628,7 +628,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
)
)
assert
self
.
vision_embed_tokens
is
not
None
assert
self
.
vision_embed_tokens
is
not
None
image_embeds
=
self
.
vision_embed_tokens
(
image_input
[
"
data
"
],
image_embeds
=
self
.
vision_embed_tokens
(
image_input
[
"
pixel_values
"
],
image_input
[
"image_sizes"
])
image_input
[
"image_sizes"
])
return
image_embeds
return
image_embeds
...
...
vllm/utils/tensor_schema.py
View file @
44ea8513
...
@@ -94,34 +94,63 @@ class TensorSchema:
...
@@ -94,34 +94,63 @@ class TensorSchema:
return
False
return
False
return
True
return
True
def
_validate_nested_tensors
(
def
_fmt_indexer
(
self
,
idxs
:
tuple
[
int
,
...])
->
str
:
if
not
idxs
:
return
""
return
str
(
list
(
idxs
))
def
_validate_field
(
self
,
self
,
value
:
Union
[
list
[
torch
.
Tensor
],
tuple
[
torch
.
Tensor
,
...]]
,
value
:
object
,
field_name
:
str
,
field_name
:
str
,
expected_shape
:
tuple
[
Union
[
int
,
str
],
...],
expected_shape
:
tuple
[
Union
[
int
,
str
],
...],
dynamic_dims
:
set
[
str
],
dynamic_dims
:
set
[
str
],
leading_idxs
:
tuple
[
int
,
...]
=
(),
)
->
tuple
[
int
,
...]:
)
->
tuple
[
int
,
...]:
"""Validate a list/tuple of tensors and return the actual shape."""
"""Validate a field and return the actual shape."""
if
isinstance
(
value
,
(
int
,
float
)):
return
()
# Scalar
if
isinstance
(
value
,
torch
.
Tensor
):
return
value
.
shape
if
not
isinstance
(
value
,
(
list
,
tuple
)):
raise
TypeError
(
f
"
{
field_name
}{
self
.
_fmt_indexer
(
leading_idxs
)
}
is not "
f
"one of the expected types: int, float, Tensor, list, tuple. "
f
"Got:
{
type
(
value
)
}
"
)
if
len
(
value
)
==
0
:
raise
ValueError
(
f
"
{
field_name
}{
self
.
_fmt_indexer
(
leading_idxs
)
}
"
f
"is an empty sequence"
)
# Ensure all tensors in the list have the same
# Ensure all tensors in the list have the same
# shape, besides dynamic dimensions
# shape, besides dynamic dimensions
first
=
value
[
0
]
for
i
,
v
in
enumerate
(
value
):
for
i
,
v
in
enumerate
(
value
):
if
not
isinstance
(
v
,
torch
.
Tensor
):
shape
=
self
.
_validate_field
(
raise
ValueError
(
f
"
{
field_name
}
[
{
i
}
] is not a "
v
,
f
"torch.Tensor"
)
field_name
,
if
not
self
.
_match_shape_with_dynamic
(
expected_shape
[
1
:],
v
.
shape
,
dynamic_dims
,
first
.
shape
,
leading_idxs
=
leading_idxs
+
(
i
,
),
)
if
i
==
0
:
first_shape
=
shape
elif
not
self
.
_match_shape_with_dynamic
(
shape
,
first_shape
,
expected_shape
,
expected_shape
,
dynamic_dims
,
dynamic_dims
,
):
):
raise
ValueError
(
f
"
{
field_name
}
contains inconsistent "
raise
ValueError
(
f
"shapes:
{
first
.
shape
}
vs
{
v
.
shape
}
"
f
"
{
field_name
}{
self
.
_fmt_indexer
(
leading_idxs
)
}
"
f
"at index
{
i
}
"
)
f
"contains inconsistent shapes:
{
first_shape
}
"
f
"(index 0) vs
{
shape
}
(index
{
i
}
)"
)
# Treat the list as a stacked tensor:
# Treat the list as a stacked tensor:
# shape = (len(list), *tensor.shape)
# shape = (len(list), *tensor.shape)
return
(
len
(
value
),
)
+
first
.
shape
return
(
len
(
value
),
)
+
first
_
shape
def
_validate_tensor_shape_expected
(
def
_validate_tensor_shape_expected
(
self
,
self
,
...
@@ -187,36 +216,12 @@ class TensorSchema:
...
@@ -187,36 +216,12 @@ class TensorSchema:
for
arg
in
args
:
for
arg
in
args
:
if
isinstance
(
arg
,
TensorShape
):
if
isinstance
(
arg
,
TensorShape
):
expected_shape
=
arg
.
resolve
(
**
self
.
_resolve_bindings
)
expected_shape
=
arg
.
resolve
(
**
self
.
_resolve_bindings
)
if
isinstance
(
value
,
(
list
,
tuple
)):
actual_shape
=
self
.
_validate_field
(
# list/tuple of Tensors → shape = (len(value), ...)
value
,
if
value
and
isinstance
(
value
[
0
],
torch
.
Tensor
):
field_name
,
actual_shape
=
self
.
_validate_nested_tensors
(
expected_shape
,
value
,
field_name
,
expected_shape
,
arg
.
dynamic_dims
,
arg
.
dynamic_dims
)
)
elif
value
:
# list/tuple of scalars → shape = (len(value),)
actual_shape
=
(
len
(
value
),
)
else
:
raise
ValueError
(
f
"
{
field_name
}
is an empty list"
)
# Tensor → shape = tensor.shape
elif
isinstance
(
value
,
torch
.
Tensor
):
actual_shape
=
value
.
shape
# Otherwise, it's an unsupported type
else
:
type_names
=
[]
for
arg
in
args
:
if
hasattr
(
arg
,
"__name__"
):
type_names
.
append
(
str
(
arg
.
__name__
))
else
:
type_names
.
append
(
str
(
arg
))
expected_types
=
", "
.
join
(
type_names
)
raise
ValueError
(
f
"
{
field_name
}
is not one of the expected "
f
"types:
{
expected_types
}
"
)
self
.
_validate_tensor_shape_expected
(
self
.
_validate_tensor_shape_expected
(
actual_shape
,
expected_shape
,
field_name
,
actual_shape
,
expected_shape
,
field_name
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment