# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
4
import torch
5

6
7
8
9
10
11
12
13
from vllm.multimodal.inputs import (
    MultiModalBatchedField,
    MultiModalFieldElem,
    MultiModalKwargsItem,
    MultiModalSharedField,
    PlaceholderRange,
)
from vllm.multimodal.utils import argsort_mm_positions, group_and_batch_mm_items
14
15


16
17
18
@pytest.mark.parametrize(
    "case",
    [
        # Single modality
        ## Internally sorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=3, length=2),
                ]
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Internally unsorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=3, length=2),
                    PlaceholderRange(offset=0, length=2),
                ]
            },
            expected_modality_idxs=[
                ("image", 1),
                ("image", 0),
            ],
        ),
        # Two modalities
        ## Internally sorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=7, length=4),
                    PlaceholderRange(offset=11, length=5),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                ],
            },
            expected_modality_idxs=[
                ("audio", 0),
                ("audio", 1),
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Interleaved, internally sorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=4),
                    PlaceholderRange(offset=8, length=2),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                    PlaceholderRange(offset=11, length=4),
                ],
            },
            expected_modality_idxs=[
                ("image", 0),
                ("audio", 0),
                ("image", 1),
                ("audio", 1),
            ],
        ),
        ## Interleaved, internally unsorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=8, length=2),
                    PlaceholderRange(offset=0, length=4),
                ],
                "audio": [
                    PlaceholderRange(offset=11, length=4),
                    PlaceholderRange(offset=5, length=2),
                ],
            },
            expected_modality_idxs=[
                ("image", 1),
                ("audio", 1),
                ("image", 0),
                ("audio", 0),
            ],
        ),
        # Three modalities
        ## Internally sorted
        # NOTE(review): the last video range (offset=12, length=6) overlaps
        # the first image range (offset=15); the expected order below still
        # follows ascending offset -- confirm the overlap is intentional.
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=15, length=7),
                    PlaceholderRange(offset=22, length=8),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=3, length=4),
                    PlaceholderRange(offset=7, length=5),
                    PlaceholderRange(offset=12, length=6),
                ],
            },
            expected_modality_idxs=[
                ("audio", 0),
                ("video", 0),
                ("video", 1),
                ("video", 2),
                ("image", 0),
                ("image", 1),
            ],
        ),
        ## Interleaved, internally sorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                    PlaceholderRange(offset=20, length=4),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=8, length=5),
                ],
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 1),
                ("audio", 0),
                ("video", 0),
                ("image", 2),
            ],
        ),
        ## Interleaved, internally unsorted
        dict(
            mm_positions={
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=20, length=4),
                    PlaceholderRange(offset=2, length=3),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=8, length=5),
                ],
            },
            expected_modality_idxs=[
                ("image", 0),
                ("image", 2),
                ("audio", 0),
                ("video", 0),
                ("image", 1),
            ],
        ),
    ],
)
def test_argsort_mm_positions(case):
    """Check the ordering returned by ``argsort_mm_positions``.

    Each ``case`` maps modality names to lists of ``PlaceholderRange`` and
    gives the expected list of ``(modality, index)`` pairs, where ``index``
    is the position of the range within its modality's input list. In every
    case the expected pairs are listed by ascending placeholder offset,
    whether or not the per-modality input lists are themselves sorted.
    """
    mm_positions = case["mm_positions"]
    expected_modality_idxs = case["expected_modality_idxs"]

    modality_idxs = argsort_mm_positions(mm_positions)

    assert modality_idxs == expected_modality_idxs
185
186


187
188
189
190
def test_group_and_batch_mm_items_split_by_fieldset():
    """Check that items are batched only with neighbours sharing their keys.

    The expected batch sizes ``[2, 1, 1, 1]`` encode three expectations:
    key insertion order does not matter (``item1`` and ``item2`` batch
    together), a different key set starts a new batch (``item3``, ``item4``),
    and a key set seen earlier is not merged back across an intervening
    batch (``item5`` has the same keys as ``item1`` but stands alone).
    """
    # The same element is reused everywhere so only the key sets differ.
    elem = MultiModalFieldElem(
        data=torch.empty(1, dtype=torch.uint8),
        field=MultiModalBatchedField(),
    )
    item1 = MultiModalKwargsItem({"x": elem, "y": elem})
    item2 = MultiModalKwargsItem({"y": elem, "x": elem})
    item3 = MultiModalKwargsItem({"x": elem, "y": elem, "z": elem})
    item4 = MultiModalKwargsItem({"x": elem})
    item5 = MultiModalKwargsItem({"x": elem, "y": elem})

    res = group_and_batch_mm_items([item1, item2, item3, item4, item5])
    assert [num_items for num_items, _ in res] == [2, 1, 1, 1]
200
201


202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
def test_group_and_batch_mm_items_split_by_shared_data():
    """Check that items with differing shared-field data are not batched.

    Every item has the single key ``"x"`` backed by a
    ``MultiModalSharedField``; only the element's tensor differs
    (``torch.zeros(1)`` vs. ``torch.zeros(2)``). The expected batch sizes
    ``[2, 1, 1, 1]`` mean adjacent items using ``elem1`` are grouped, while
    each switch between ``elem1`` and ``elem2`` starts a new batch.
    """
    elem1 = MultiModalFieldElem(
        data=torch.zeros(1, dtype=torch.uint8),
        field=MultiModalSharedField(batch_size=1),
    )
    # Same field type and key as elem1, but the tensor has a different length.
    elem2 = MultiModalFieldElem(
        data=torch.zeros(2, dtype=torch.uint8),
        field=MultiModalSharedField(batch_size=1),
    )
    item1 = MultiModalKwargsItem({"x": elem1})
    item2 = MultiModalKwargsItem({"x": elem1})
    item3 = MultiModalKwargsItem({"x": elem2})
    item4 = MultiModalKwargsItem({"x": elem1})
    item5 = MultiModalKwargsItem({"x": elem2})

    res = group_and_batch_mm_items([item1, item2, item3, item4, item5])
    batch_sizes = [num_items for num_items, _ in res]
    assert batch_sizes == [2, 1, 1, 1], f"unexpected batch sizes: {batch_sizes}"