# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.model_executor.models.fuyu import FuyuImagePatchInputs
from vllm.model_executor.models.glm4_1v import Glm4vImageEmbeddingInputs
from vllm.model_executor.models.phi3v import Phi3VImagePixelInputs


def test_tensor_schema_valid_tensor():
    """Build the schema from a stacked tensor and a matching ``image_sizes``.

    The test passes as long as construction does not raise.
    """
    pixel_data = torch.randn(16, 64, 3, 32, 32)
    sizes = torch.randint(0, 256, (16, 2))
    Phi3VImagePixelInputs(data=pixel_data, image_sizes=sizes)


def test_tensor_schema_optional_fields():
    """``image_sizes`` may be given as ``None`` or left out entirely."""
    # Case 1: the optional field is passed explicitly as ``None``.
    first_batch = torch.randn(16, 64, 3, 32, 32)
    Phi3VImagePixelInputs(data=first_batch, image_sizes=None)

    # Case 2: the optional field is not passed at all.
    second_batch = torch.randn(16, 64, 3, 32, 32)
    Phi3VImagePixelInputs(data=second_batch)


def test_tensor_schema_constant_dim_failure():
    """A size of 4 at ``dim[2]`` must be rejected with a ``ValueError``.

    The expected message states that ``dim[2]`` should be 3 but was 4.
    """
    bad_data = torch.randn(16, 64, 4, 32, 32)  # dim[2] = 4
    sizes = torch.randint(0, 256, (16, 2))
    expected_message = "dim\\[2\\] expected 3, got 4"

    with pytest.raises(ValueError, match=expected_message):
        Phi3VImagePixelInputs(data=bad_data, image_sizes=sizes)


def test_tensor_schema_symbolic_dim_mismatch():
    """Differing leading sizes (12 vs. 16) for ``'bn'`` raise ``ValueError``."""
    twelve_items = torch.randn(12, 64, 3, 32, 32)
    sixteen_sizes = torch.randint(0, 256, (16, 2))
    expected_message = "expected 'bn'=12, got 16"

    with pytest.raises(ValueError, match=expected_message):
        Phi3VImagePixelInputs(data=twelve_items, image_sizes=sixteen_sizes)


def test_tensor_schema_list_tensor_valid():
    """A list of 16 equally shaped tensors is accepted as ``data``."""
    per_image = []
    for _ in range(16):
        per_image.append(torch.randn(64, 3, 32, 32))
    sizes = torch.randint(0, 256, (16, 2))

    Phi3VImagePixelInputs(data=per_image, image_sizes=sizes)


def test_tensor_schema_variable_patch_counts_valid():
    """List entries may differ in their leading (patch count) dimension."""
    # Each image has a different number of patches (p); every tensor has
    # shape (p, 3, 32, 32).
    patch_counts = (16, 32, 64)
    per_image = [torch.randn(p, 3, 32, 32) for p in patch_counts]
    sizes = torch.randint(0, 256, (3, 2))  # bn = 3

    Phi3VImagePixelInputs(data=per_image, image_sizes=sizes)


def test_tensor_schema_tuple_tensor_valid():
    """A tuple of 16 equally shaped tensors is accepted as ``data``."""
    tensors = [torch.randn(64, 3, 32, 32) for _ in range(16)]
    sizes = torch.randint(0, 256, (16, 2))

    Phi3VImagePixelInputs(data=tuple(tensors), image_sizes=sizes)


def test_tensor_schema_inconsistent_shapes_in_list():
    """One list entry with a different trailing shape triggers ``ValueError``."""
    # The second entry is (64, 3, 16, 16); the other fifteen are
    # (64, 3, 32, 32).
    mixed = [torch.randn(64, 3, 32, 32), torch.randn(64, 3, 16, 16)]
    mixed.extend(torch.randn(64, 3, 32, 32) for _ in range(14))
    sizes = torch.randint(0, 256, (16, 2))

    with pytest.raises(ValueError, match="contains inconsistent shapes"):
        Phi3VImagePixelInputs(data=mixed, image_sizes=sizes)


def test_tensor_schema_empty_list():
    """Passing an empty list as ``data`` raises ``ValueError``."""
    no_images = []
    no_sizes = torch.randint(0, 256, (0, 2))

    with pytest.raises(ValueError, match="is an empty list"):
        Phi3VImagePixelInputs(data=no_images, image_sizes=no_sizes)


def test_tensor_schema_validation_disabled_skips_shape_check():
    """With ``validate=False`` a wrongly shaped tensor must not raise."""
    # dim[2] is 4 rather than the expected 3, so this input would normally
    # be rejected; with validation turned off it should go through.
    wrong_shape = torch.randn(16, 64, 4, 32, 32)
    sizes = torch.randint(0, 256, (16, 2))

    Phi3VImagePixelInputs(data=wrong_shape, image_sizes=sizes, validate=False)


def test_tensor_schema_with_valid_resolve_binding_dims():
    """Data whose ``h``/``w`` equal the resolve bindings is accepted."""
    pixel_data = torch.randn(16, 64, 3, 336, 336)  # h=336, w=336
    sizes = torch.randint(0, 256, (16, 2))
    bindings = {"h": 336, "w": 336}

    Phi3VImagePixelInputs(
        data=pixel_data,
        image_sizes=sizes,
        resolve_bindings=bindings,
    )


def test_tensor_schema_with_invalid_resolve_binding_dims():
    """Data whose ``h``/``w`` differ from the resolve bindings is rejected.

    The tensor uses 36 for both spatial sizes while the bindings require
    336; the expected ``ValueError`` message reports the mismatch at
    ``dim[3]``.
    """
    data = torch.randn(16, 64, 3, 36, 36)  # h=36, w=36
    image_sizes = torch.randint(0, 256, (16, 2))

    # Should raise because 'h' and 'w' don't match resolve bindings
    with pytest.raises(ValueError, match="dim\\[3\\] expected 336, got 36"):
        Phi3VImagePixelInputs(
            data=data,
            image_sizes=image_sizes,
            resolve_bindings={
                "h": 336,
                "w": 336
            },
        )


def test_tensor_schema_with_list_of_symbolic_dim():
    """A plain list whose length equals the tensor's leading size is accepted."""
    rows = [torch.randn(768) for _ in range(3)]
    stacked = torch.stack(rows)  # (bn=3, fn)
    patch_counts = [64, 64, 64]  # len = bn = 3

    FuyuImagePatchInputs(flat_data=stacked, patches_per_image=patch_counts)


def test_tensor_schema_with_list_of_symbolic_dim_mismatch_in_length():
    """A list shorter than the tensor's leading size raises ``ValueError``.

    ``flat_data`` has a leading size of 4 while ``patches_per_image`` has
    only 3 entries; the expected message reports ``'bn'=4, got 3``.
    """
    flat_data = torch.stack([torch.randn(768) for _ in range(4)])  # (bn=4, fn)
    patches_per_image = [64, 64, 64]  # len = 3 ≠ bn

    with pytest.raises(ValueError, match="expected 'bn'=4, got 3"):
        FuyuImagePatchInputs(
            flat_data=flat_data,
            patches_per_image=patches_per_image,
        )


def test_valid_tensor_schema_with_static_last_dim():
    """``image_grid_thw`` with a last dimension of 3 is accepted."""
    embeds = torch.randn(256, 1024)
    grid_thw = torch.randint(0, 4, (2, 3))

    Glm4vImageEmbeddingInputs(image_embeds=embeds, image_grid_thw=grid_thw)


def test_invalid_tensor_schema_with_static_last_dim():
    """``image_grid_thw`` with a last dimension of 4 raises ``ValueError``.

    The expected message states that ``dim[1]`` should be 3 but was 4.
    """
    embeds = torch.randn(256, 1024)
    bad_grid_thw = torch.randint(0, 4, (2, 4))  # Wrong last dim
    expected_message = "dim\\[1\\] expected 3, got 4"

    with pytest.raises(ValueError, match=expected_message):
        Glm4vImageEmbeddingInputs(
            image_embeds=embeds,
            image_grid_thw=bad_grid_thw,
        )