# test_processing.py
from contextlib import nullcontext
from typing import cast
from unittest.mock import MagicMock

import numpy as np
import pytest

from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
                                        PromptReplacement,
                                        find_mm_placeholders,
                                        find_text_matches, find_token_matches,
                                        iter_token_matches,
                                        replace_text_matches,
                                        replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import full_groupby

from .utils import random_image


# yapf: disable
@pytest.mark.parametrize(
    ("token_ids", "match_ids", "expected"),
    [
        ([], [], []),
        ([], [32000], []),
        (
            [32000, 32000, 32000],
            [32000],
            [
                { "start_idx": 0, "end_idx": 1 },
                { "start_idx": 1, "end_idx": 2 },
                { "start_idx": 2, "end_idx": 3 },
            ],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000],
            [{ "start_idx": 0, "end_idx": 2 }],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000, 32000],
            [{ "start_idx": 0, "end_idx": 3 }],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000],
            [
                { "start_idx": 1, "end_idx": 3 },
                { "start_idx": 6, "end_idx": 8 },
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000, 32000, 32000],
            [
                { "start_idx": 1, "end_idx": 5 },
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 0, 32000],
            [],
        ),
    ],
)
# yapf: enable
def test_iter_token_matches(token_ids, match_ids, expected):
    """Check ``iter_token_matches`` against hand-computed match spans.

    Each expected entry is the ``start_idx``/``end_idx`` of one occurrence of
    ``match_ids`` in ``token_ids``. The cases expect non-overlapping matches:
    ``[32000, 32000]`` is found only once in three consecutive ``32000``
    tokens.
    """
    result = list(iter_token_matches(token_ids, match_ids))

    # Manually constructed results
    assert [item._asdict() for item in result] == expected

    # Invariants: every match spans exactly len(match_ids) tokens
    match_lens = [end - start for start, end in result]
    print("match_lens:", match_lens)  # Only displayed on error
    assert all(match_len == len(match_ids) for match_len in match_lens)


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        (
            [],
            {
                "pattern_1": [],
                "pattern_2": [32000],
            },
            {
                "pattern_1": [],
                "pattern_2": [],
            }
        ),
        (
            [32000, 32000, 32000, 32000],
            {
                "pattern_1": [32000],
                "pattern_2": [32000, 32000],
                "pattern_3": [32000, 32000, 32000],
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 1 },
                    { "start_idx": 1, "end_idx": 2 },
                    { "start_idx": 2, "end_idx": 3 },
                    { "start_idx": 3, "end_idx": 4 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 2 },
                    { "start_idx": 2, "end_idx": 4 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 3 },
                ],
            },
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            {
                "pattern_1": [28747, 32000],
                "pattern_2": [28747, 32000, 32000, 32000],
                "pattern_3": [28747, 0, 32000],
            },
            {
                "pattern_1": [
                    { "start_idx": 1, "end_idx": 3 },
                    { "start_idx": 6, "end_idx": 8 },
                ],
                "pattern_2": [
                    { "start_idx": 1, "end_idx": 5 },
                ],
                "pattern_3": [],
            },
        ),
    ],
)
# yapf: enable
def test_find_token_matches(prompt, target_by_key, expected_by_key):
    """Check ``find_token_matches`` for several token-ID targets at once.

    One ``PromptReplacement`` is bound per key. The matches are grouped by
    their ``modality`` attribute, which is set from the key, and each group
    is compared with the expected spans.
    """
    # Should not be used since there is nothing to convert to token IDs
    mock_tokenizer = cast(AnyTokenizer, object())

    prompt_repls = [
        PromptReplacement(key, target, []).bind(mock_tokenizer)
        for key, target in target_by_key.items()
    ]
    result = find_token_matches(prompt, prompt_repls)

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    result_groups = dict(full_groupby(result, key=lambda x: x.modality))
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result_groups.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        # Detokenized test cases of `test_find_token_matches`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            "",
            {
                "pattern_1": "",
                "pattern_2": "<image>",
            },
            {
                "pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
                "pattern_2": [],
            }
        ),
        (
            "<image><image><image><image>",
            {
                "pattern_1": "<image>",
                "pattern_2": "<image><image>",
                "pattern_3": "<image><image><image>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 7 },
                    { "start_idx": 7, "end_idx": 14 },
                    { "start_idx": 14, "end_idx": 21 },
                    { "start_idx": 21, "end_idx": 28 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 14 },
                    { "start_idx": 14, "end_idx": 28 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 21 },
                ],
            },
        ),
        (
            "Image:<image><image><image>Image:<image><image>!",
            {
                "pattern_1": "Image:<image>",
                "pattern_2": "Image:<image><image><image>",
                "pattern_3": "Image:<unk><image>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 13 },
                    { "start_idx": 27, "end_idx": 40 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 27 },
                ],
                "pattern_3": [],
            },
        ),
        # Test regex escape
        (
            "<|image|><image><|image|><image>",
            {
                "pattern_1": "<|image|>",
                "pattern_2": "<|image|><image>",
                "pattern_3": "<|image|><image><|image|>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 9 },
                    { "start_idx": 16, "end_idx": 25 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 16 },
                    { "start_idx": 16, "end_idx": 32 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 25 },
                ],
            },
        ),
    ],
)
# yapf: enable
def test_find_text_matches(prompt, target_by_key, expected_by_key):
    """Check ``find_text_matches`` for several string targets at once.

    This is the text counterpart of ``test_find_token_matches``. The last
    case uses targets containing ``|`` to check that targets are matched
    literally rather than as regular expressions.
    """
    # Should not be used since there is nothing to convert to text
    mock_tokenizer = cast(AnyTokenizer, object())

    prompt_repls = [
        PromptReplacement(key, target, []).bind(mock_tokenizer)
        for key, target in target_by_key.items()
    ]
    result = find_text_matches(prompt, prompt_repls)

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    result_groups = dict(full_groupby(result, key=lambda x: x.modality))
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result_groups.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key"),
    [
        (
            "Image:<image>Image:<image><image>!",
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": "<image>",
                "pattern_2": "Image:",
                "pattern_3": "!",
            },
            {
                # Test whether target is confused with replacement
                "pattern_1": "<image><image>",
                # Test empty replacement
                "pattern_2": "",
                # Test dynamic replacement (beyond the form of `unit * count`)
                "pattern_3": "?!?",
            },
        ),
    ]
)
@pytest.mark.parametrize(
    ("mm_count", "expected"),
    [
        (0, "Image:<image>Image:<image><image>!"),
        (1, "<image><image>Image:<image><image>?!?"),
        (2, "<image><image><image><image><image>?!?"),
    ]
)
# yapf: enable
def test_find_replace_text(
    prompt,
    target_by_key,
    repl_by_key,
    mm_count,
    expected,
):
    """Check that ``replace_text_matches`` rewrites only the first matches.

    Per the expected outputs, only the first ``mm_count`` matches of each key
    are replaced. Any later matches, and all text outside matches, are left
    unchanged.
    """
    # Should not be used since there is nothing to convert to text
    mock_tokenizer = cast(AnyTokenizer, object())

    mm_prompt_repls = {
        key: [
            PromptReplacement(key, target,
                              repl_by_key[key]).bind(mock_tokenizer)
        ]
        for key, target in target_by_key.items()
    }
    mm_matches = {
        key: find_text_matches(prompt, prompt_repls)
        for key, prompt_repls in mm_prompt_repls.items()
    }

    result = replace_text_matches(
        prompt,
        mm_matches,
        {key: mm_count
         for key in repl_by_key},
    )

    # Only displayed on error
    print("mm_matches:", mm_matches)
    print("result:", result)

    # Manually constructed results
    assert result == expected


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key"),
    [
        # Tokenized test cases of `test_find_replace_text`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": [32000],
                "pattern_2": [9833, 28747],
                "pattern_3": [918],
            },
            {
                # Test whether target is confused with replacement
                "pattern_1": [32000, 32000],
                # Test empty replacement
                "pattern_2": [],
                # Test dynamic replacement (beyond the form of `unit * count`)
                "pattern_3": [1550, 918, 1550],
            },
        ),
    ]
)
@pytest.mark.parametrize(
    ("mm_count", "expected"),
    [
        (0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
        (1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
        (2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
    ]
)
# yapf: enable
def test_find_replace_tokens(
    prompt,
    target_by_key,
    repl_by_key,
    mm_count,
    expected,
):
    """Check that ``replace_token_matches`` rewrites only the first matches.

    This is the token-ID counterpart of ``test_find_replace_text``. Only the
    first ``mm_count`` matches of each key are replaced.
    """
    # Should not be used since there is nothing to convert to tokens
    mock_tokenizer = cast(AnyTokenizer, object())

    mm_prompt_repls = {
        key: [
            PromptReplacement(key, target,
                              repl_by_key[key]).bind(mock_tokenizer)
        ]
        for key, target in target_by_key.items()
    }
    mm_matches = {
        key: find_token_matches(prompt, prompt_repls)
        for key, prompt_repls in mm_prompt_repls.items()
    }

    result = replace_token_matches(
        prompt,
        mm_matches,
        {key: mm_count
         for key in repl_by_key},
    )

    # Only displayed on error
    print("mm_matches:", mm_matches)
    print("result:", result)

    # Manually constructed results
    assert result == expected


# yapf: disable
@pytest.mark.parametrize(
    "repl_by_key",
    [
        {
            "pattern_1": [32000, 32000],
            "pattern_2": [],
            "pattern_3": [1550, 918, 1550],
            # Test different modalities having the same tokens (32000)
            "pattern_4": [32000],
        },
    ],
)
@pytest.mark.parametrize(
    ("prompt", "expected"),
    [
        (
            [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
            {
                "pattern_1": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=6,
                        tokens=[32000, 32000],
                    ),
                ],
                "pattern_4": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_4",
                        item_idx=0,
                        start_idx=3,
                        tokens=[32000],
                    ),
                ],
            },
        ),
        (
            [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
            {
                "pattern_1": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=1,
                        tokens=[32000, 32000],
                    ),
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=1,
                        start_idx=5,
                        tokens=[32000, 32000],
                    ),
                ],
                "pattern_3": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_3",
                        item_idx=0,
                        start_idx=7,
                        tokens=[1550, 918, 1550],
                    ),
                ],
                # No match for pattern_4 as it has lower priority than pattern_1
            },
        ),
        (
            [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
            {
                "pattern_1": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=1,
                        tokens=[32000, 32000],
                    ),
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=1,
                        start_idx=3,
                        tokens=[32000, 32000],
                    ),
                ],
                "pattern_4": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_4",
                        item_idx=0,
                        start_idx=5,
                        tokens=[32000],
                    ),
                ],
                "pattern_3": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_3",
                        item_idx=0,
                        start_idx=6,
                        tokens=[1550, 918, 1550],
                    ),
                ],
            },
        ),
    ]
)
# yapf: enable
def test_find_mm_placeholders(
    repl_by_key,
    prompt,
    expected,
):
    """Check ``find_mm_placeholders`` on prompts with replacement tokens.

    One ``PromptReplacement`` with an empty target is bound per key. The test
    checks which occurrences of each key's replacement tokens in ``prompt``
    are reported as ``PlaceholderFeaturesInfo`` entries.
    """
    # Should not be used since there is nothing to convert to tokens
    mock_tokenizer = cast(AnyTokenizer, object())

    mm_prompt_repls = {
        key: [PromptReplacement(key, [], repl).bind(mock_tokenizer)]
        for key, repl in repl_by_key.items()
    }

    result = find_mm_placeholders(
        mm_prompt_repls,
        prompt,
        # Effectively match all occurrences in the prompt
        {key: 3
         for key in repl_by_key},
    )

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    assert result == expected


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("limit", "num_supported", "is_valid"),
    [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
     (2, 1, False), (2, 2, True)],
)
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
    """Check dummy-data profiling against the model's supported image count.

    ``get_supported_mm_limits`` is mocked to report ``num_supported`` images.
    Per the parametrization, a configured ``limit_mm_per_prompt`` above that
    count must raise ``ValueError``; otherwise profiling succeeds.
    """
    limit_mm_per_prompt = {"image": limit}

    model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )

    processor = MULTIMODAL_REGISTRY.create_processor(
        model_config,
        tokenizer=cached_get_tokenizer(model_config.tokenizer),
    )
    profiler = MultiModalProfiler(processor)

    # Override what the model claims to support, independent of the real model
    mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
    processor.info.get_supported_mm_limits = mock_supported_mm_limits

    if is_valid:
        exc_ctx = nullcontext()
    else:
        exc_ctx = pytest.raises(ValueError, match="this model only supports")

    with exc_ctx:
        profiler.get_dummy_data(model_config.max_model_len)




@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("num_images", "limit", "is_valid"),
    [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
     (2, 1, False), (2, 2, True)],
)
def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
    """Check that ``processor.apply`` enforces ``limit_mm_per_prompt``.

    Per the parametrization, passing more images than the configured limit
    must raise ``ValueError``; otherwise processing succeeds.
    """
    limit_mm_per_prompt = {"image": limit}

    model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )

    processor = MULTIMODAL_REGISTRY.create_processor(
        model_config,
        tokenizer=cached_get_tokenizer(model_config.tokenizer),
    )

    # Fixed seed so the generated image is the same on every run
    rng = np.random.RandomState(0)
    image = random_image(rng, min_wh=128, max_wh=256)
    # Exercise all three input forms: no image, a single image, a list of images
    if num_images == 0:
        mm_data = {}
    elif num_images == 1:
        mm_data = {"image": image}
    else:
        mm_data = {"image": [image] * num_images}

    if is_valid:
        exc_ctx = nullcontext()
    else:
        exc_ctx = pytest.raises(ValueError, match=f"passed {num_images} image")

    with exc_ctx:
        processor.apply(
            "<image>" * num_images,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )