test_tensor_schema.py 5.33 KB
Newer Older
1
2
3
4
5
6
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

7
from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs
8
from vllm.model_executor.models.granite_speech import GraniteSpeechAudioInputs
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs


def test_tensor_schema_valid_tensor():
    """Build the schema from a 5-D `data` tensor and a (16, 2) `image_sizes`.

    The test passes as long as construction does not raise.
    """
    pixel_values = torch.randn(16, 64, 3, 32, 32)
    sizes = torch.randint(0, 256, (16, 2))
    Phi3VImagePixelInputs(data=pixel_values, image_sizes=sizes)


def test_tensor_schema_optional_fields():
    """`image_sizes` may be passed as None or left out entirely.

    Both constructions below are expected to succeed without raising.
    """
    # Case 1: the optional field is given explicitly as None.
    explicit_none_data = torch.randn(16, 64, 3, 32, 32)
    Phi3VImagePixelInputs(data=explicit_none_data, image_sizes=None)

    # Case 2: the optional field is not passed at all.
    omitted_field_data = torch.randn(16, 64, 3, 32, 32)
    Phi3VImagePixelInputs(data=omitted_field_data)


def test_tensor_schema_constant_dim_failure():
    """A `data` tensor whose dim[2] is 4 must be rejected.

    The expected ValueError message reports that dim[2] should be 3.
    """
    bad_data = torch.randn(16, 64, 4, 32, 32)  # dim[2] = 4
    sizes = torch.randint(0, 256, (16, 2))

    with pytest.raises(ValueError, match="dim\\[2\\] expected 3, got 4"):
        Phi3VImagePixelInputs(data=bad_data, image_sizes=sizes)


def test_tensor_schema_symbolic_dim_mismatch():
    """Mismatched leading dims (12 vs. 16) must raise a 'bn' binding error."""
    data_with_12 = torch.randn(12, 64, 3, 32, 32)  # leading dim 12
    sizes_with_16 = torch.randint(0, 256, (16, 2))  # leading dim 16

    with pytest.raises(ValueError, match="expected 'bn'=12, got 16"):
        Phi3VImagePixelInputs(
            data=data_with_12,
            image_sizes=sizes_with_16,
        )


def test_tensor_schema_list_tensor_valid():
    """`data` given as a list of 16 equally shaped 4-D tensors is accepted."""
    num_items = 16
    per_item = [torch.randn(64, 3, 32, 32) for _ in range(num_items)]
    sizes = torch.randint(0, 256, (16, 2))

    Phi3VImagePixelInputs(data=per_item, image_sizes=sizes)


def test_tensor_schema_variable_patch_counts_valid():
    """A list whose tensors differ only in their first dim (p) is accepted.

    Each entry has shape (p, 3, 32, 32) with p in {16, 32, 64}; the
    matching `image_sizes` has three rows (bn = 3).
    """
    patch_counts = (16, 32, 64)
    per_image = [torch.randn(p, 3, 32, 32) for p in patch_counts]
    sizes = torch.randint(0, 256, (3, 2))  # bn = 3

    Phi3VImagePixelInputs(data=per_image, image_sizes=sizes)


def test_tensor_schema_tuple_tensor_valid():
    """`data` given as a tuple of 16 equally shaped tensors is accepted."""
    as_list = [torch.randn(64, 3, 32, 32) for _ in range(16)]
    sizes = torch.randint(0, 256, (16, 2))

    Phi3VImagePixelInputs(data=tuple(as_list), image_sizes=sizes)


def test_tensor_schema_inconsistent_shapes_in_list():
    """One odd-shaped tensor in a list of 16 must raise ValueError.

    The second element is (64, 3, 16, 16) while all others are
    (64, 3, 32, 32).
    """
    regular = (64, 3, 32, 32)
    odd_one_out = (64, 3, 16, 16)
    shapes = [regular, odd_one_out] + [regular] * 14
    mixed = [torch.randn(*shape) for shape in shapes]
    sizes = torch.randint(0, 256, (16, 2))

    with pytest.raises(ValueError, match="contains inconsistent shapes"):
        Phi3VImagePixelInputs(data=mixed, image_sizes=sizes)


def test_tensor_schema_empty_list():
    """An empty list for `data` must raise ValueError ("is an empty list")."""
    no_tensors = []
    zero_row_sizes = torch.randint(0, 256, (0, 2))

    with pytest.raises(ValueError, match="is an empty list"):
        Phi3VImagePixelInputs(data=no_tensors, image_sizes=zero_row_sizes)


def test_tensor_schema_validation_disabled_skips_shape_check():
    """With `validate=False`, a wrongly shaped `data` tensor is not rejected.

    The same input (dim[2] is 4 instead of 3) fails when validation is on;
    here construction is expected to complete without raising.
    """
    wrong_shape = torch.randn(16, 64, 4, 32, 32)
    sizes = torch.randint(0, 256, (16, 2))

    Phi3VImagePixelInputs(data=wrong_shape, image_sizes=sizes, validate=False)


def test_tensor_schema_with_valid_resolve_binding_dims():
    """`resolve_bindings` for h/w that match the tensor (336x336) pass."""
    bindings = {"h": 336, "w": 336}
    pixel_values = torch.randn(16, 64, 3, 336, 336)  # h=336, w=336
    sizes = torch.randint(0, 256, (16, 2))

    Phi3VImagePixelInputs(
        data=pixel_values,
        image_sizes=sizes,
        resolve_bindings=bindings,
    )


def test_tensor_schema_with_invalid_resolve_binding_dims():
    """A 36x36 tensor checked against h=w=336 bindings must raise.

    The expected ValueError reports dim[3] as 36 where 336 was required.
    """
    bindings = {"h": 336, "w": 336}
    small_images = torch.randn(16, 64, 3, 36, 36)  # h=36, w=36
    sizes = torch.randint(0, 256, (16, 2))

    # 'h' and 'w' of the tensor disagree with the values in `bindings`.
    with pytest.raises(ValueError, match="dim\\[3\\] expected 336, got 36"):
        Phi3VImagePixelInputs(
            data=small_images,
            image_sizes=sizes,
            resolve_bindings=bindings,
        )
129
130
131


def test_tensor_schema_with_list_of_symbolic_dim():
    """A plain list whose length equals the batch dim `b` is accepted.

    `audio_embed_sizes` has three entries, matching the leading dim of
    both tensors; construction is expected to complete without raising.
    """
    input_features = torch.randn(3, 10, 160)  # (b=3, fi=10, 160)
    input_features_mask = torch.randn(3, 8)  # (b=3, fo=8)
    audio_embed_sizes = [8, 8, 8]  # len = b = 3

    GraniteSpeechAudioInputs(
        input_features=input_features,
        input_features_mask=input_features_mask,
        audio_embed_sizes=audio_embed_sizes,
    )


def test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length():
    """A list shorter than the batch dim `b` must raise ValueError.

    Both tensors have leading dim 4 while `audio_embed_sizes` has only
    three entries; the expected message is "expected 'b'=4, got 3".
    """
    input_features = torch.randn(4, 10, 160)  # (b=4, fi=10, 160)
    input_features_mask = torch.randn(4, 8)  # (b=4, fo=8)
    audio_embed_sizes = [8, 8, 8]  # len = 3 != b

    with pytest.raises(ValueError, match="expected 'b'=4, got 3"):
        GraniteSpeechAudioInputs(
            input_features=input_features,
            input_features_mask=input_features_mask,
            audio_embed_sizes=audio_embed_sizes,
        )


def test_valid_tensor_schema_with_static_last_dim():
    """An `image_grid_thw` tensor with last dim 3 is accepted."""
    embeds = torch.randn(256, 1024)
    grid_thw = torch.randint(0, 4, (2, 3))

    Glm4vImageEmbeddingInputs(image_embeds=embeds, image_grid_thw=grid_thw)


def test_invalid_tensor_schema_with_static_last_dim():
    """An `image_grid_thw` tensor with last dim 4 must raise ValueError.

    The expected message reports dim[1] as 4 where 3 was required.
    """
    embeds = torch.randn(256, 1024)
    bad_grid_thw = torch.randint(0, 4, (2, 4))  # Wrong last dim

    with pytest.raises(ValueError, match="dim\\[1\\] expected 3, got 4"):
        Glm4vImageEmbeddingInputs(
            image_embeds=embeds,
            image_grid_thw=bad_grid_thw,
        )