Unverified Commit 44ea8513 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Model] Support nested structures for TensorSchema (#26212)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent d3d649ef
...@@ -6,37 +6,39 @@ import torch ...@@ -6,37 +6,39 @@ import torch
from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs
from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs
from vllm.model_executor.models.hyperclovax_vision import (
HCXVisionVideoPixelInputs)
from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs
def test_tensor_schema_valid_tensor(): def test_tensor_schema_valid_tensor():
Phi3VImagePixelInputs( Phi3VImagePixelInputs(
data=torch.randn(16, 64, 3, 32, 32), pixel_values=torch.randn(16, 64, 3, 32, 32),
image_sizes=torch.randint(0, 256, (16, 2)), image_sizes=torch.randint(0, 256, (16, 2)),
) )
def test_tensor_schema_optional_fields(): def test_tensor_schema_optional_fields():
Phi3VImagePixelInputs( Phi3VImagePixelInputs(
data=torch.randn(16, 64, 3, 32, 32), pixel_values=torch.randn(16, 64, 3, 32, 32),
image_sizes=None, image_sizes=None,
) )
Phi3VImagePixelInputs(data=torch.randn(16, 64, 3, 32, 32), ) Phi3VImagePixelInputs(pixel_values=torch.randn(16, 64, 3, 32, 32))
def test_tensor_schema_constant_dim_failure(): def test_tensor_schema_constant_dim_failure():
with pytest.raises(ValueError, match="dim\\[2\\] expected 3, got 4"): with pytest.raises(ValueError, match="dim\\[2\\] expected 3, got 4"):
Phi3VImagePixelInputs( Phi3VImagePixelInputs(
data=torch.randn(16, 64, 4, 32, 32), # dim[2] = 4 pixel_values=torch.randn(16, 64, 4, 32, 32), # dim[2] = 4
image_sizes=torch.randint(0, 256, (16, 2)), image_sizes=torch.randint(0, 256, (16, 2)),
) )
def test_tensor_schema_invalid_types_in_list(): def test_tensor_schema_invalid_types_in_list():
with pytest.raises(ValueError, match="is not a torch.Tensor"): with pytest.raises(TypeError, match="is not one of the expected types"):
Phi3VImagePixelInputs( Phi3VImagePixelInputs(
data=[ pixel_values=[
torch.randn(64, 3, 32, 32), torch.randn(64, 3, 32, 32),
"not_a_tensor", "not_a_tensor",
torch.randn(64, 3, 32, 32), torch.randn(64, 3, 32, 32),
...@@ -48,27 +50,28 @@ def test_tensor_schema_invalid_types_in_list(): ...@@ -48,27 +50,28 @@ def test_tensor_schema_invalid_types_in_list():
def test_tensor_schema_rank_mismatch(): def test_tensor_schema_rank_mismatch():
with pytest.raises(ValueError, match="has rank 3 but expected 5"): with pytest.raises(ValueError, match="has rank 3 but expected 5"):
Phi3VImagePixelInputs( Phi3VImagePixelInputs(
data=torch.randn(16, 64, 3), pixel_values=torch.randn(16, 64, 3),
image_sizes=torch.randint(0, 256, (16, 2)), image_sizes=torch.randint(0, 256, (16, 2)),
) )
def test_tensor_schema_missing_required_field(): def test_tensor_schema_missing_required_field():
with pytest.raises(ValueError, match="Required field 'data' is missing"): with pytest.raises(ValueError,
match="Required field 'pixel_values' is missing"):
Phi3VImagePixelInputs(image_sizes=torch.randint(0, 256, (16, 2)), ) Phi3VImagePixelInputs(image_sizes=torch.randint(0, 256, (16, 2)), )
def test_tensor_schema_symbolic_dim_mismatch(): def test_tensor_schema_symbolic_dim_mismatch():
with pytest.raises(ValueError, match="expected 'bn'=12, got 16"): with pytest.raises(ValueError, match="expected 'bn'=12, got 16"):
Phi3VImagePixelInputs( Phi3VImagePixelInputs(
data=torch.randn(12, 64, 3, 32, 32), pixel_values=torch.randn(12, 64, 3, 32, 32),
image_sizes=torch.randint(0, 256, (16, 2)), image_sizes=torch.randint(0, 256, (16, 2)),
) )
def test_tensor_schema_list_tensor_valid(): def test_tensor_schema_list_tensor_valid():
Phi3VImagePixelInputs( Phi3VImagePixelInputs(
data=[torch.randn(64, 3, 32, 32) for _ in range(16)], pixel_values=[torch.randn(64, 3, 32, 32) for _ in range(16)],
image_sizes=torch.randint(0, 256, (16, 2)), image_sizes=torch.randint(0, 256, (16, 2)),
) )
...@@ -76,39 +79,46 @@ def test_tensor_schema_list_tensor_valid(): ...@@ -76,39 +79,46 @@ def test_tensor_schema_list_tensor_valid():
def test_tensor_schema_variable_patch_counts_valid(): def test_tensor_schema_variable_patch_counts_valid():
# Each image has a different number of patches (p) # Each image has a different number of patches (p)
# Each tensor has shape (p, 3, 32, 32) # Each tensor has shape (p, 3, 32, 32)
data = [ Phi3VImagePixelInputs(
pixel_values=[
torch.randn(16, 3, 32, 32), # p = 16 torch.randn(16, 3, 32, 32), # p = 16
torch.randn(32, 3, 32, 32), # p = 32 torch.randn(32, 3, 32, 32), # p = 32
torch.randn(64, 3, 32, 32), # p = 64 torch.randn(64, 3, 32, 32), # p = 64
] ],
image_sizes = torch.randint(0, 256, (3, 2)) # bn = 3 image_sizes=torch.randint(0, 256, (3, 2)), # bn = 3
Phi3VImagePixelInputs(
data=data,
image_sizes=image_sizes,
) )
def test_tensor_schema_tuple_tensor_valid(): def test_tensor_schema_tuple_tensor_valid():
Phi3VImagePixelInputs( Phi3VImagePixelInputs(
data=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)), pixel_values=tuple(torch.randn(64, 3, 32, 32) for _ in range(16)),
image_sizes=torch.randint(0, 256, (16, 2)), image_sizes=torch.randint(0, 256, (16, 2)),
) )
def test_tensor_schema_double_nested_tensors():
x = torch.rand(4, 3, 32, 32)
y = torch.rand(2, 3, 32, 32)
HCXVisionVideoPixelInputs(pixel_values_videos=([x, y, x], [y], [x, y]))
def test_tensor_schema_inconsistent_shapes_in_list(): def test_tensor_schema_inconsistent_shapes_in_list():
with pytest.raises(ValueError, match="contains inconsistent shapes"): with pytest.raises(ValueError, match="contains inconsistent shapes"):
Phi3VImagePixelInputs( Phi3VImagePixelInputs(
data=[torch.randn(64, 3, 32, 32), pixel_values=[
torch.randn(64, 3, 16, 16)] + torch.randn(64, 3, 32, 32),
[torch.randn(64, 3, 32, 32) for _ in range(14)], torch.randn(64, 3, 16, 16),
*(torch.randn(64, 3, 32, 32) for _ in range(14)),
],
image_sizes=torch.randint(0, 256, (16, 2)), image_sizes=torch.randint(0, 256, (16, 2)),
) )
def test_tensor_schema_empty_list(): def test_tensor_schema_empty_list():
with pytest.raises(ValueError, match="is an empty list"): with pytest.raises(ValueError, match="is an empty sequence"):
Phi3VImagePixelInputs( Phi3VImagePixelInputs(
data=[], pixel_values=[],
image_sizes=torch.randint(0, 256, (0, 2)), image_sizes=torch.randint(0, 256, (0, 2)),
) )
...@@ -117,18 +127,18 @@ def test_tensor_schema_validation_disabled_skips_shape_check(): ...@@ -117,18 +127,18 @@ def test_tensor_schema_validation_disabled_skips_shape_check():
# This should NOT raise, because validation is turned off # This should NOT raise, because validation is turned off
# This would normally fail (dim[2] should be 3, not 4) # This would normally fail (dim[2] should be 3, not 4)
Phi3VImagePixelInputs( Phi3VImagePixelInputs(
data=torch.randn(16, 64, 4, 32, 32), pixel_values=torch.randn(16, 64, 4, 32, 32),
image_sizes=torch.randint(0, 256, (16, 2)), image_sizes=torch.randint(0, 256, (16, 2)),
validate=False, validate=False,
) )
def test_tensor_schema_with_valid_resolve_binding_dims(): def test_tensor_schema_with_valid_resolve_binding_dims():
data = torch.randn(16, 64, 3, 336, 336) # h=336, w=336 pixel_values = torch.randn(16, 64, 3, 336, 336) # h=336, w=336
image_sizes = torch.randint(0, 256, (16, 2)) image_sizes = torch.randint(0, 256, (16, 2))
Phi3VImagePixelInputs( Phi3VImagePixelInputs(
data=data, pixel_values=pixel_values,
image_sizes=image_sizes, image_sizes=image_sizes,
resolve_bindings={ resolve_bindings={
"h": 336, "h": 336,
...@@ -138,13 +148,13 @@ def test_tensor_schema_with_valid_resolve_binding_dims(): ...@@ -138,13 +148,13 @@ def test_tensor_schema_with_valid_resolve_binding_dims():
def test_tensor_schema_with_invalid_resolve_binding_dims(): def test_tensor_schema_with_invalid_resolve_binding_dims():
data = torch.randn(16, 64, 3, 36, 36) # h=36, w=36 pixel_values = torch.randn(16, 64, 3, 36, 36) # h=36, w=36
image_sizes = torch.randint(0, 256, (16, 2)) image_sizes = torch.randint(0, 256, (16, 2))
# Should raise because 'h' and 'w' don't match resolve bindings # Should raise because 'h' and 'w' don't match resolve bindings
with pytest.raises(ValueError, match="dim\\[3\\] expected 336, got 36"): with pytest.raises(ValueError, match="dim\\[3\\] expected 336, got 36"):
Phi3VImagePixelInputs( Phi3VImagePixelInputs(
data=data, pixel_values=pixel_values,
image_sizes=image_sizes, image_sizes=image_sizes,
resolve_bindings={ resolve_bindings={
"h": 336, "h": 336,
......
...@@ -29,7 +29,7 @@ ...@@ -29,7 +29,7 @@
import math import math
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import partial from functools import partial
from typing import Annotated, Any, Callable, Literal, Optional, Union, override from typing import Annotated, Any, Callable, Literal, Optional, Union
import numpy as np import numpy as np
import torch import torch
...@@ -1170,7 +1170,7 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]): ...@@ -1170,7 +1170,7 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
"video.height override (%d) exceeds model's " "video.height override (%d) exceeds model's "
"maximum height (%d), will be ignored", "maximum height (%d), will be ignored",
overrides.height, height) overrides.height, height)
height = min(height, override.height) height = min(height, overrides.height)
video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8) video = np.full((num_frames, width, height, 3), 255, dtype=np.uint8)
video_items = [] video_items = []
......
...@@ -109,7 +109,7 @@ class Phi3VImagePixelInputs(TensorSchema): ...@@ -109,7 +109,7 @@ class Phi3VImagePixelInputs(TensorSchema):
type: Literal["pixel_values", "image_embeds"] = "pixel_values" type: Literal["pixel_values", "image_embeds"] = "pixel_values"
# Supports either a stacked tensor or a list of (p, 3, h, w) tensors # Supports either a stacked tensor or a list of (p, 3, h, w) tensors
data: Annotated[ pixel_values: Annotated[
Union[torch.Tensor, list[torch.Tensor]], Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"} TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"}
), # 'p' may vary across items ), # 'p' may vary across items
...@@ -594,7 +594,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -594,7 +594,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
if pixel_values is not None: if pixel_values is not None:
return Phi3VImagePixelInputs( return Phi3VImagePixelInputs(
type="pixel_values", type="pixel_values",
data=flatten_bn(pixel_values), pixel_values=flatten_bn(pixel_values),
image_sizes=flatten_bn(image_sizes, concat=True), image_sizes=flatten_bn(image_sizes, concat=True),
resolve_bindings={ resolve_bindings={
"h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size, "h": CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
...@@ -628,7 +628,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, ...@@ -628,7 +628,7 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP,
) )
assert self.vision_embed_tokens is not None assert self.vision_embed_tokens is not None
image_embeds = self.vision_embed_tokens(image_input["data"], image_embeds = self.vision_embed_tokens(image_input["pixel_values"],
image_input["image_sizes"]) image_input["image_sizes"])
return image_embeds return image_embeds
......
...@@ -94,34 +94,63 @@ class TensorSchema: ...@@ -94,34 +94,63 @@ class TensorSchema:
return False return False
return True return True
def _validate_nested_tensors( def _fmt_indexer(self, idxs: tuple[int, ...]) -> str:
if not idxs:
return ""
return str(list(idxs))
def _validate_field(
self, self,
value: Union[list[torch.Tensor], tuple[torch.Tensor, ...]], value: object,
field_name: str, field_name: str,
expected_shape: tuple[Union[int, str], ...], expected_shape: tuple[Union[int, str], ...],
dynamic_dims: set[str], dynamic_dims: set[str],
leading_idxs: tuple[int, ...] = (),
) -> tuple[int, ...]: ) -> tuple[int, ...]:
"""Validate a list/tuple of tensors and return the actual shape.""" """Validate a field and return the actual shape."""
if isinstance(value, (int, float)):
return () # Scalar
if isinstance(value, torch.Tensor):
return value.shape
if not isinstance(value, (list, tuple)):
raise TypeError(
f"{field_name}{self._fmt_indexer(leading_idxs)} is not "
f"one of the expected types: int, float, Tensor, list, tuple. "
f"Got: {type(value)}")
if len(value) == 0:
raise ValueError(f"{field_name}{self._fmt_indexer(leading_idxs)} "
f"is an empty sequence")
# Ensure all tensors in the list have the same # Ensure all tensors in the list have the same
# shape, besides dynamic dimensions # shape, besides dynamic dimensions
first = value[0]
for i, v in enumerate(value): for i, v in enumerate(value):
if not isinstance(v, torch.Tensor): shape = self._validate_field(
raise ValueError(f"{field_name}[{i}] is not a " v,
f"torch.Tensor") field_name,
if not self._match_shape_with_dynamic( expected_shape[1:],
v.shape, dynamic_dims,
first.shape, leading_idxs=leading_idxs + (i, ),
)
if i == 0:
first_shape = shape
elif not self._match_shape_with_dynamic(
shape,
first_shape,
expected_shape, expected_shape,
dynamic_dims, dynamic_dims,
): ):
raise ValueError(f"{field_name} contains inconsistent " raise ValueError(
f"shapes: {first.shape} vs {v.shape} " f"{field_name}{self._fmt_indexer(leading_idxs)} "
f"at index {i}") f"contains inconsistent shapes: {first_shape} "
f"(index 0) vs {shape} (index {i})")
# Treat the list as a stacked tensor: # Treat the list as a stacked tensor:
# shape = (len(list), *tensor.shape) # shape = (len(list), *tensor.shape)
return (len(value), ) + first.shape return (len(value), ) + first_shape
def _validate_tensor_shape_expected( def _validate_tensor_shape_expected(
self, self,
...@@ -187,36 +216,12 @@ class TensorSchema: ...@@ -187,36 +216,12 @@ class TensorSchema:
for arg in args: for arg in args:
if isinstance(arg, TensorShape): if isinstance(arg, TensorShape):
expected_shape = arg.resolve(**self._resolve_bindings) expected_shape = arg.resolve(**self._resolve_bindings)
if isinstance(value, (list, tuple)): actual_shape = self._validate_field(
# list/tuple of Tensors → shape = (len(value), ...) value,
if value and isinstance(value[0], torch.Tensor): field_name,
actual_shape = self._validate_nested_tensors( expected_shape,
value, field_name, expected_shape, arg.dynamic_dims,
arg.dynamic_dims) )
elif value:
# list/tuple of scalars → shape = (len(value),)
actual_shape = (len(value), )
else:
raise ValueError(
f"{field_name} is an empty list")
# Tensor → shape = tensor.shape
elif isinstance(value, torch.Tensor):
actual_shape = value.shape
# Otherwise, it's an unsupported type
else:
type_names = []
for arg in args:
if hasattr(arg, "__name__"):
type_names.append(str(arg.__name__))
else:
type_names.append(str(arg))
expected_types = ", ".join(type_names)
raise ValueError(
f"{field_name} is not one of the expected "
f"types: {expected_types}")
self._validate_tensor_shape_expected( self._validate_tensor_shape_expected(
actual_shape, expected_shape, field_name, actual_shape, expected_shape, field_name,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment