# Tests for the prompt matching and replacement helpers imported from
# `vllm.multimodal.processing`.
from typing import cast

import pytest

5
6
7
8
from vllm.multimodal.processing import (MultiModalDataItems, PromptReplacement,
                                        _PlaceholderInfo, find_text_matches,
                                        find_token_matches, iter_placeholders,
                                        iter_token_matches,
9
10
                                        replace_text_matches,
                                        replace_token_matches)
11
12
13
14
15
16
17
18
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import full_groupby


# yapf: disable
@pytest.mark.parametrize(
    ("token_ids", "match_ids", "expected"),
    [
        ([], [], []),
        ([], [32000], []),
        (
            [32000, 32000, 32000],
            [32000],
            [
                { "start_idx": 0, "end_idx": 1 },
                { "start_idx": 1, "end_idx": 2 },
                { "start_idx": 2, "end_idx": 3 },
            ],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000],
            [{ "start_idx": 0, "end_idx": 2 }],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000, 32000],
            [{ "start_idx": 0, "end_idx": 3 }],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000],
            [
                { "start_idx": 1, "end_idx": 3 },
                { "start_idx": 6, "end_idx": 8 },
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000, 32000, 32000],
            [
                { "start_idx": 1, "end_idx": 5 },
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 0, 32000],
            [],
        ),
    ],
)
# yapf: enable
def test_iter_token_matches(token_ids, match_ids, expected):
    """Check `iter_token_matches` against hand-written match positions.

    Args:
        token_ids: Token sequence that is searched.
        match_ids: Token sub-sequence to look for in `token_ids`.
        expected: Expected matches, in order, each given as a dict with
            `start_idx` and `end_idx` keys.
    """
    result = list(iter_token_matches(token_ids, match_ids))

    # Manually constructed results
    assert [item._asdict() for item in result] == expected

    # Invariants: every reported span is exactly as long as the pattern
    match_lens = [end - start for start, end in result]
    print("match_lens:", match_lens)  # Only displayed on error
    assert all(match_len == len(match_ids) for match_len in match_lens)


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        (
            [],
            {
                "pattern_1": [],
                "pattern_2": [32000],
            },
            {
                "pattern_1": [],
                "pattern_2": [],
            }
        ),
        (
            [32000, 32000, 32000, 32000],
            {
                "pattern_1": [32000],
                "pattern_2": [32000, 32000],
                "pattern_3": [32000, 32000, 32000],
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 1 },
                    { "start_idx": 1, "end_idx": 2 },
                    { "start_idx": 2, "end_idx": 3 },
                    { "start_idx": 3, "end_idx": 4 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 2 },
                    { "start_idx": 2, "end_idx": 4 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 3 },
                ],
            },
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            {
                "pattern_1": [28747, 32000],
                "pattern_2": [28747, 32000, 32000, 32000],
                "pattern_3": [28747, 0, 32000],
            },
            {
                "pattern_1": [
                    { "start_idx": 1, "end_idx": 3 },
                    { "start_idx": 6, "end_idx": 8 },
                ],
                "pattern_2": [
                    { "start_idx": 1, "end_idx": 5 },
                ],
                "pattern_3": [],
            },
        ),
    ],
)
# yapf: enable
def test_find_token_matches(prompt, target_by_key, expected_by_key):
    """Check `find_token_matches` for several token patterns at once.

    Args:
        prompt: Token sequence that is searched.
        target_by_key: Token pattern to look for, per pattern key.
        expected_by_key: Expected matches per pattern key, each given as a
            dict with `start_idx` and `end_idx` keys.
    """
    # Should not be used since there is nothing to convert to token IDs
    mock_tokenizer = cast(AnyTokenizer, object())

    prompt_repls = [
        PromptReplacement(key, target, []).bind(mock_tokenizer)
        for key, target in target_by_key.items()
    ]
    result = find_token_matches(prompt, prompt_repls)

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    # Matches are grouped by their `modality` attribute; keys without any
    # match fall back to an empty list so they compare equal to `[]`.
    result_groups = dict(full_groupby(result, key=lambda x: x.modality))
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result_groups.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        # Detokenized test cases of `test_find_token_matches`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            "",
            {
                "pattern_1": "",
                "pattern_2": "<image>",
            },
            {
                "pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
                "pattern_2": [],
            }
        ),
        (
            "<image><image><image><image>",
            {
                "pattern_1": "<image>",
                "pattern_2": "<image><image>",
                "pattern_3": "<image><image><image>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 7 },
                    { "start_idx": 7, "end_idx": 14 },
                    { "start_idx": 14, "end_idx": 21 },
                    { "start_idx": 21, "end_idx": 28 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 14 },
                    { "start_idx": 14, "end_idx": 28 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 21 },
                ],
            },
        ),
        (
            "Image:<image><image><image>Image:<image><image>!",
            {
                "pattern_1": "Image:<image>",
                "pattern_2": "Image:<image><image><image>",
                "pattern_3": "Image:<unk><image>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 13 },
                    { "start_idx": 27, "end_idx": 40 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 27 },
                ],
                "pattern_3": [],
            },
        ),
        # Test regex escape
        (
            "<|image|><image><|image|><image>",
            {
                "pattern_1": "<|image|>",
                "pattern_2": "<|image|><image>",
                "pattern_3": "<|image|><image><|image|>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 9 },
                    { "start_idx": 16, "end_idx": 25 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 16 },
                    { "start_idx": 16, "end_idx": 32 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 25 },
                ],
            },
        ),
    ],
)
# yapf: enable
def test_find_text_matches(prompt, target_by_key, expected_by_key):
    """Check `find_text_matches` for several text patterns at once.

    Args:
        prompt: Text that is searched.
        target_by_key: Text pattern to look for, per pattern key.
        expected_by_key: Expected matches per pattern key, each given as a
            dict with character-offset `start_idx` and `end_idx` keys.
    """
    # Should not be used since there is nothing to convert to text
    mock_tokenizer = cast(AnyTokenizer, object())

    prompt_repls = [
        PromptReplacement(key, target, []).bind(mock_tokenizer)
        for key, target in target_by_key.items()
    ]
    result = find_text_matches(prompt, prompt_repls)

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    # Matches are grouped by their `modality` attribute; keys without any
    # match fall back to an empty list so they compare equal to `[]`.
    result_groups = dict(full_groupby(result, key=lambda x: x.modality))
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result_groups.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key"),
    [
        (
            "Image:<image>Image:<image><image>!",
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": "<image>",
                "pattern_2": "Image:",
                "pattern_3": "!",
            },
            {
                # Test whether target is confused with replacement
                "pattern_1": "<image><image>",
                # Test empty replacement
                "pattern_2": "",
                # Test dynamic replacement (beyond the form of `unit * count`)
                "pattern_3": "?!?",
            },
        ),
    ]
)
@pytest.mark.parametrize(
    ("mm_count", "expected"),
    [
        (0, "Image:<image>Image:<image><image>!"),
        (1, "<image><image>Image:<image><image>?!?"),
        (2, "<image><image><image><image><image>?!?"),
    ]
)
# yapf: enable
def test_find_replace_text(
    prompt,
    target_by_key,
    repl_by_key,
    mm_count,
    expected,
):
    """Check `find_text_matches` followed by `replace_text_matches`.

    Args:
        prompt: Text in which the targets are searched and replaced.
        target_by_key: Text to look for, per pattern key.
        repl_by_key: Replacement text, per pattern key.
        mm_count: Number of (dummy) multi-modal items supplied per key.
        expected: Expected prompt text after replacement for `mm_count`.
    """
    # Should not be used since there is nothing to convert to text
    mock_tokenizer = cast(AnyTokenizer, object())

    prompt_repls = [
        PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
        for key, target in target_by_key.items()
    ]
    matches = find_text_matches(prompt, prompt_repls)

    result = replace_text_matches(
        prompt,
        matches,
        MultiModalDataItems({key: [None] * mm_count
                             for key in repl_by_key}),
    )

    # Only displayed on error
    print("matches:", matches)
    print("result:", result)

    # Manually constructed results
    assert result == expected


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key"),
    [
        # Tokenized test cases of `test_find_replace_text`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": [32000],
                "pattern_2": [9833, 28747],
                "pattern_3": [918],
            },
            {
                # Test whether target is confused with replacement
                "pattern_1": [32000, 32000],
                # Test empty replacement
                "pattern_2": [],
                # Test dynamic replacement (beyond the form of `unit * count`)
                "pattern_3": [1550, 918, 1550],
            },
        ),
    ]
)
@pytest.mark.parametrize(
    ("mm_count", "expected"),
    [
        (0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
        (1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
        (2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
    ]
)
# yapf: enable
def test_find_replace_tokens(
    prompt,
    target_by_key,
    repl_by_key,
    mm_count,
    expected,
):
    """Check `find_token_matches` followed by `replace_token_matches`.

    Args:
        prompt: Token sequence in which targets are searched and replaced.
        target_by_key: Token pattern to look for, per pattern key.
        repl_by_key: Replacement tokens, per pattern key.
        mm_count: Number of (dummy) multi-modal items supplied per key.
        expected: Expected token sequence after replacement for `mm_count`.
    """
    # Should not be used since there is nothing to convert to tokens
    mock_tokenizer = cast(AnyTokenizer, object())

    prompt_repls = [
        PromptReplacement(key, target, repl_by_key[key]).bind(mock_tokenizer)
        for key, target in target_by_key.items()
    ]
    matches = find_token_matches(prompt, prompt_repls)

    result = replace_token_matches(
        prompt,
        matches,
        MultiModalDataItems({key: [None] * mm_count
                             for key in repl_by_key}),
    )

    # Only displayed on error
    print("matches:", matches)
    print("result:", result)

    # Manually constructed results
    assert result == expected


# yapf: disable
@pytest.mark.parametrize(
    "repl_by_key",
    [
        {
            "pattern_1": [32000, 32000],
            "pattern_2": [],
            "pattern_3": [1550, 918, 1550],
        },
    ],
)
@pytest.mark.parametrize(
    ("prompt", "expected"),
    [
        (
            [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
            [
                _PlaceholderInfo(
                    modality="pattern_1",
                    start_idx=6,
                    replacement=[32000, 32000],
                ),
            ],
        ),
        (
            [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
            [
                _PlaceholderInfo(
                    modality="pattern_1",
                    start_idx=1,
                    replacement=[32000, 32000],
                ),
                _PlaceholderInfo(
                    modality="pattern_1",
                    start_idx=5,
                    replacement=[32000, 32000],
                ),
                _PlaceholderInfo(
                    modality="pattern_3",
                    start_idx=7,
                    replacement=[1550, 918, 1550],
                ),
            ],
        ),
        (
            [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
            [
                _PlaceholderInfo(
                    modality="pattern_1",
                    start_idx=1,
                    replacement=[32000, 32000],
                ),
                _PlaceholderInfo(
                    modality="pattern_1",
                    start_idx=3,
                    replacement=[32000, 32000],
                ),
                _PlaceholderInfo(
                    modality="pattern_3",
                    start_idx=6,
                    replacement=[1550, 918, 1550],
                ),
            ],
        ),
    ]
)
# yapf: enable
def test_iter_placeholders(
    repl_by_key,
    prompt,
    expected,
):
    """Check that `iter_placeholders` locates replacements in a prompt.

    Args:
        repl_by_key: Replacement tokens per pattern key; these are the
            sequences searched for in `prompt`.
        prompt: Token sequence that is scanned for the replacements.
        expected: Expected `_PlaceholderInfo` entries, in order.
    """
    # Should not be used since there is nothing to convert to tokens
    mock_tokenizer = cast(AnyTokenizer, object())

    # The target is left empty: only the replacement side matters here.
    prompt_repls = [
        PromptReplacement(key, [], repl).bind(mock_tokenizer)
        for key, repl in repl_by_key.items()
    ]

    result = list(
        iter_placeholders(
            prompt_repls,
            prompt,
            # Effectively match all occurrences in the prompt
            MultiModalDataItems({key: [None] * 3 for key in repl_by_key}),
        ))

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    assert result == expected