Unverified Commit 68b254d6 authored by Benji Beck's avatar Benji Beck Committed by GitHub
Browse files

Fix TensorSchema validation test for symbolic dims (#22366)


Signed-off-by: default avatarBenji Beck <benjibeck@meta.com>
parent 8c50d62f
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
import pytest import pytest
import torch import torch
from vllm.model_executor.models.fuyu import FuyuImagePatchInputs
from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs
from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs
from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs
...@@ -129,23 +129,27 @@ def test_tensor_schema_with_invalid_resolve_binding_dims(): ...@@ -129,23 +129,27 @@ def test_tensor_schema_with_invalid_resolve_binding_dims():
def test_tensor_schema_with_list_of_symbolic_dim(): def test_tensor_schema_with_list_of_symbolic_dim():
flat_data = torch.stack([torch.randn(768) for _ in range(3)]) # (bn=3, fn) input_features = torch.randn(3, 10, 160) # (b=3, fi=10, 160)
patches_per_image = [64, 64, 64] # len = bn = 3 input_features_mask = torch.randn(3, 8) # (b=3, fo=8)
audio_embed_sizes = [8, 8, 8] # len = b = 3
FuyuImagePatchInputs(
flat_data=flat_data, GraniteSpeechAudioInputs(
patches_per_image=patches_per_image, input_features=input_features,
input_features_mask=input_features_mask,
audio_embed_sizes=audio_embed_sizes,
) )
def test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length(): def test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length():
flat_data = torch.stack([torch.randn(768) for _ in range(4)]) # (bn=4, fn) input_features = torch.randn(4, 10, 160) # (b=4, fi=10, 160)
patches_per_image = [64, 64, 64] # len = 3 ≠ bn input_features_mask = torch.randn(4, 8) # (b=4, fo=8)
audio_embed_sizes = [8, 8, 8] # len = 3 ≠ b
with pytest.raises(ValueError, match="expected 'bn'=4, got 3"):
FuyuImagePatchInputs( with pytest.raises(ValueError, match="expected 'b'=4, got 3"):
flat_data=flat_data, GraniteSpeechAudioInputs(
patches_per_image=patches_per_image, input_features=input_features,
input_features_mask=input_features_mask,
audio_embed_sizes=audio_embed_sizes,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment