test_processing.py 24.8 KB
Newer Older
1
from contextlib import nullcontext
2
from functools import partial
3
from typing import cast
4
from unittest.mock import MagicMock
5

6
import numpy as np
7
import pytest
8
9
10
11
12
from PIL import Image

from vllm.config import ModelConfig
from vllm.inputs import InputProcessingContext
from vllm.multimodal import MULTIMODAL_REGISTRY
13
14
15
16
17
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (PlaceholderInfo, ProcessingCache,
                                        PromptReplacement,
                                        find_mm_placeholders,
18
                                        find_text_matches, find_token_matches,
19
                                        iter_token_matches,
20
21
                                        replace_text_matches,
                                        replace_token_matches)
22
23
# yapf: enable
from vllm.multimodal.profiling import MultiModalProfiler
24
from vllm.multimodal.utils import cached_get_tokenizer
25
26
27
28
29
30
31
32
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import full_groupby


# yapf: disable
@pytest.mark.parametrize(
    ("token_ids", "match_ids", "expected"),
    [
        ([], [], []),
        ([], [32000], []),
        (
            [32000, 32000, 32000],
            [32000],
            [
                { "start_idx": 0, "end_idx": 1 },
                { "start_idx": 1, "end_idx": 2 },
                { "start_idx": 2, "end_idx": 3 },
            ],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000],
            [{ "start_idx": 0, "end_idx": 2 }],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000, 32000],
            [{ "start_idx": 0, "end_idx": 3 }],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000],
            [
                { "start_idx": 1, "end_idx": 3 },
                { "start_idx": 6, "end_idx": 8 },
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000, 32000, 32000],
            [
                { "start_idx": 1, "end_idx": 5 },
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 0, 32000],
            [],
        ),
    ],
)
# yapf: enable
def test_iter_token_matches(token_ids, match_ids, expected):
    """Check the spans ``iter_token_matches`` yields for ``match_ids``.

    Each yielded item is compared (via ``_asdict()``) against the expected
    ``{"start_idx", "end_idx"}`` dicts. The test also checks that every
    span is exactly ``len(match_ids)`` tokens long. Per the cases above,
    matches do not overlap: three ``32000`` tokens yield only one
    two-token match.
    """
    result = list(iter_token_matches(token_ids, match_ids))

    # Manually constructed results
    assert [item._asdict() for item in result] == expected

    # Invariants
    match_lens = [end - start for start, end in result]
    print("match_lens:", match_lens)  # Only displayed on error
    assert all(match_len == len(match_ids) for match_len in match_lens)


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        (
            [],
            {
                "pattern_1": [],
                "pattern_2": [32000],
            },
            {
                "pattern_1": [],
                "pattern_2": [],
            }
        ),
        (
            [32000, 32000, 32000, 32000],
            {
                "pattern_1": [32000],
                "pattern_2": [32000, 32000],
                "pattern_3": [32000, 32000, 32000],
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 1 },
                    { "start_idx": 1, "end_idx": 2 },
                    { "start_idx": 2, "end_idx": 3 },
                    { "start_idx": 3, "end_idx": 4 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 2 },
                    { "start_idx": 2, "end_idx": 4 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 3 },
                ],
            },
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            {
                "pattern_1": [28747, 32000],
                "pattern_2": [28747, 32000, 32000, 32000],
                "pattern_3": [28747, 0, 32000],
            },
            {
                "pattern_1": [
                    { "start_idx": 1, "end_idx": 3 },
                    { "start_idx": 6, "end_idx": 8 },
                ],
                "pattern_2": [
                    { "start_idx": 1, "end_idx": 5 },
                ],
                "pattern_3": [],
            },
        ),
    ],
)
# yapf: enable
def test_find_token_matches(prompt, target_by_key, expected_by_key):
    """Check ``find_token_matches`` against token-ID targets.

    One ``PromptReplacement`` is bound per key, using the key as its
    modality. The matches are then grouped by ``modality`` and compared
    as ``{"start_idx", "end_idx"}`` dicts. Keys with no match are
    expected to map to an empty list.
    """
    # Should not be used since there is nothing to convert to token IDs
    mock_tokenizer = cast(AnyTokenizer, object())

    prompt_repls = [
        PromptReplacement(key, target, []).bind(mock_tokenizer)
        for key, target in target_by_key.items()
    ]
    result = find_token_matches(prompt, prompt_repls)

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    result_groups = dict(full_groupby(result, key=lambda x: x.modality))
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result_groups.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        # Detokenized test cases of `test_find_token_matches`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            "",
            {
                "pattern_1": "",
                "pattern_2": "<image>",
            },
            {
                "pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
                "pattern_2": [],
            }
        ),
        (
            "<image><image><image><image>",
            {
                "pattern_1": "<image>",
                "pattern_2": "<image><image>",
                "pattern_3": "<image><image><image>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 7 },
                    { "start_idx": 7, "end_idx": 14 },
                    { "start_idx": 14, "end_idx": 21 },
                    { "start_idx": 21, "end_idx": 28 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 14 },
                    { "start_idx": 14, "end_idx": 28 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 21 },
                ],
            },
        ),
        (
            "Image:<image><image><image>Image:<image><image>!",
            {
                "pattern_1": "Image:<image>",
                "pattern_2": "Image:<image><image><image>",
                "pattern_3": "Image:<unk><image>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 13 },
                    { "start_idx": 27, "end_idx": 40 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 27 },
                ],
                "pattern_3": [],
            },
        ),
        # Test regex escape
        (
            "<|image|><image><|image|><image>",
            {
                "pattern_1": "<|image|>",
                "pattern_2": "<|image|><image>",
                "pattern_3": "<|image|><image><|image|>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 9 },
                    { "start_idx": 16, "end_idx": 25 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 16 },
                    { "start_idx": 16, "end_idx": 32 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 25 },
                ],
            },
        ),
    ],
)
# yapf: enable
def test_find_text_matches(prompt, target_by_key, expected_by_key):
    """Check ``find_text_matches`` against string targets.

    This mirrors ``test_find_token_matches`` on detokenized prompts.
    Expected spans are character offsets into ``prompt``. The last case
    contains ``|``, so targets must be matched literally rather than as
    raw regex.
    """
    # Should not be used since there is nothing to convert to text
    mock_tokenizer = cast(AnyTokenizer, object())

    prompt_repls = [
        PromptReplacement(key, target, []).bind(mock_tokenizer)
        for key, target in target_by_key.items()
    ]
    result = find_text_matches(prompt, prompt_repls)

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    result_groups = dict(full_groupby(result, key=lambda x: x.modality))
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result_groups.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key"),
    [
        (
            "Image:<image>Image:<image><image>!",
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": "<image>",
                "pattern_2": "Image:",
                "pattern_3": "!",
            },
            {
                # Test whether target is confused with replacement
                "pattern_1": "<image><image>",
                # Test empty replacement
                "pattern_2": "",
                # Test dynamic replacement (beyond the form of `unit * count`)
                "pattern_3": "?!?",
            },
        ),
    ]
)
@pytest.mark.parametrize(
    ("mm_count", "expected"),
    [
        (0, "Image:<image>Image:<image><image>!"),
        (1, "<image><image>Image:<image><image>?!?"),
        (2, "<image><image><image><image><image>?!?"),
    ]
)
# yapf: enable
def test_find_replace_text(
    prompt,
    target_by_key,
    repl_by_key,
    mm_count,
    expected,
):
    """Check that ``replace_text_matches`` rewrites ``prompt`` as expected.

    Matches are found separately for each key. Each key is then given a
    count of ``mm_count``; the expected strings show that only that many
    matches per key are replaced (with ``0``, the prompt is unchanged).
    """
    # Should not be used since there is nothing to convert to text
    mock_tokenizer = cast(AnyTokenizer, object())

    mm_prompt_repls = {
        key: [
            PromptReplacement(key, target,
                              repl_by_key[key]).bind(mock_tokenizer)
        ]
        for key, target in target_by_key.items()
    }
    mm_matches = {
        key: find_text_matches(prompt, prompt_repls)
        for key, prompt_repls in mm_prompt_repls.items()
    }

    result = replace_text_matches(
        prompt,
        mm_matches,
        {key: mm_count
         for key in repl_by_key},
    )

    # Only displayed on error
    print("mm_matches:", mm_matches)
    print("result:", result)

    # Manually constructed results
    assert result == expected


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key"),
    [
        # Tokenized test cases of `test_find_replace_text`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": [32000],
                "pattern_2": [9833, 28747],
                "pattern_3": [918],
            },
            {
                # Test whether target is confused with replacement
                "pattern_1": [32000, 32000],
                # Test empty replacement
                "pattern_2": [],
                # Test dynamic replacement (beyond the form of `unit * count`)
                "pattern_3": [1550, 918, 1550],
            },
        ),
    ]
)
@pytest.mark.parametrize(
    ("mm_count", "expected"),
    [
        (0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
        (1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
        (2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
    ]
)
# yapf: enable
def test_find_replace_tokens(
    prompt,
    target_by_key,
    repl_by_key,
    mm_count,
    expected,
):
    """Token-ID counterpart of ``test_find_replace_text``.

    Matches are found with ``find_token_matches`` and applied with
    ``replace_token_matches``. Each key is given a count of ``mm_count``.
    """
    # Should not be used since there is nothing to convert to tokens
    mock_tokenizer = cast(AnyTokenizer, object())

    mm_prompt_repls = {
        key: [
            PromptReplacement(key, target,
                              repl_by_key[key]).bind(mock_tokenizer)
        ]
        for key, target in target_by_key.items()
    }
    mm_matches = {
        key: find_token_matches(prompt, prompt_repls)
        for key, prompt_repls in mm_prompt_repls.items()
    }

    result = replace_token_matches(
        prompt,
        mm_matches,
        {key: mm_count
         for key in repl_by_key},
    )

    # Only displayed on error
    print("mm_matches:", mm_matches)
    print("result:", result)

    # Manually constructed results
    assert result == expected


# yapf: disable
@pytest.mark.parametrize(
    "repl_by_key",
    [
        {
            "pattern_1": [32000, 32000],
            "pattern_2": [],
            "pattern_3": [1550, 918, 1550],
        },
    ],
)
@pytest.mark.parametrize(
    ("prompt", "expected"),
    [
        (
            [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
            {
                "pattern_1": [
                    PlaceholderInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=6,
                        replacement=[32000, 32000],
                    ),
                ],
            }

        ),
        (
            [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
            {
                "pattern_1": [
                    PlaceholderInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=1,
                        replacement=[32000, 32000],
                    ),
                    PlaceholderInfo(
                        modality="pattern_1",
                        item_idx=1,
                        start_idx=5,
                        replacement=[32000, 32000],
                    ),
                ],
                "pattern_3": [
                    PlaceholderInfo(
                        modality="pattern_3",
                        item_idx=0,
                        start_idx=7,
                        replacement=[1550, 918, 1550],
                    ),
                ],
            }
        ),
        (
            [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
            {
                "pattern_1": [
                    PlaceholderInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=1,
                        replacement=[32000, 32000],
                    ),
                    PlaceholderInfo(
                        modality="pattern_1",
                        item_idx=1,
                        start_idx=3,
                        replacement=[32000, 32000],
                    ),
                ],
                "pattern_3": [
                    PlaceholderInfo(
                        modality="pattern_3",
                        item_idx=0,
                        start_idx=6,
                        replacement=[1550, 918, 1550],
                    ),
                ],
            }
        ),
    ]
)
# yapf: enable
def test_find_mm_placeholders(
    repl_by_key,
    prompt,
    expected,
):
    """Check that ``find_mm_placeholders`` locates already-replaced tokens.

    Each key's replacement sequence is searched for in ``prompt``. The
    result is compared against ``PlaceholderInfo`` entries giving the
    modality, per-modality item index and start offset. Keys whose
    replacement never appears (e.g. the empty ``pattern_2``) are
    expected to be absent from the result.
    """
    # Should not be used since there is nothing to convert to tokens
    mock_tokenizer = cast(AnyTokenizer, object())

    mm_prompt_repls = {
        key: [PromptReplacement(key, [], repl).bind(mock_tokenizer)]
        for key, repl in repl_by_key.items()
    }

    result = find_mm_placeholders(
        mm_prompt_repls,
        prompt,
        # Effectively match all occurrences in the prompt
        {key: 3
         for key in repl_by_key},
    )

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    assert result == expected


def _rand_img(rng: np.random.RandomState, min_wh: int, max_wh: int):
    """Return a random RGB ``PIL.Image``.

    Both sides are drawn from ``[min_wh, max_wh)``. Pixel values cover the
    full ``uint8`` range ``[0, 255]``.
    """
    # ``Image.fromarray`` reads a (rows, cols, channels) array, so the first
    # drawn size is the height and the second the width.
    h, w = rng.randint(min_wh, max_wh, size=(2, ))
    # ``randint``'s upper bound is exclusive; use 256 so 255 can occur.
    arr = rng.randint(0, 256, size=(h, w, 3), dtype=np.uint8)
    return Image.fromarray(arr)


def _rand_video(
    rng: np.random.RandomState,
    min_frames: int,
    max_frames: int,
    min_wh: int,
    max_wh: int,
):
    # Temporary workaround for https://github.com/huggingface/transformers/issues/35412
    num_frames = rng.randint(min_frames, max_frames)
    num_frames = (num_frames // 2) * 2

    w, h = rng.randint(min_wh, max_wh, size=(2, ))
    return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8)


def _rand_audio(
    rng: np.random.RandomState,
    min_len: int,
    max_len: int,
    sr: int,
):
    audio_len = rng.randint(min_len, max_len)
    return rng.rand(audio_len), sr


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("limit", "num_supported", "is_valid"),
    [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
     (2, 1, False), (2, 2, True)],
)
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
    """Check that dummy-data profiling honours the model's supported limit.

    ``get_supported_mm_limits`` is mocked to report ``num_supported`` images.
    ``get_dummy_data`` should raise ``ValueError`` ("this model only
    supports") exactly when the configured ``limit`` exceeds that value.
    """
    limit_mm_per_prompt = {"image": limit}

    model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )

    processor = MULTIMODAL_REGISTRY.create_processor(
        model_config,
        tokenizer=cached_get_tokenizer(model_config.tokenizer),
    )
    profiler = MultiModalProfiler(processor)

    mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
    processor.info.get_supported_mm_limits = mock_supported_mm_limits

    if is_valid:
        exc_ctx = nullcontext()
    else:
        exc_ctx = pytest.raises(ValueError, match="this model only supports")

    with exc_ctx:
        profiler.get_dummy_data(model_config.max_model_len)


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("num_images", "limit", "is_valid"),
    [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
     (2, 1, False), (2, 2, True)],
)
def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
    """Check that ``processor.apply`` enforces ``limit_mm_per_prompt``.

    ``num_images`` copies of one random image are passed: as an empty dict
    for 0, a bare image for 1, or a list otherwise. The test expects a
    ``ValueError`` mentioning "passed {num_images} image" exactly when
    ``num_images`` exceeds ``limit``.
    """
    limit_mm_per_prompt = {"image": limit}

    model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )

    processor = MULTIMODAL_REGISTRY.create_processor(
        model_config,
        tokenizer=cached_get_tokenizer(model_config.tokenizer),
    )

    rng = np.random.RandomState(0)
    image = _rand_img(rng, min_wh=128, max_wh=256)
    if num_images == 0:
        mm_data = {}
    elif num_images == 1:
        mm_data = {"image": image}
    else:
        mm_data = {"image": [image] * num_images}

    if is_valid:
        exc_ctx = nullcontext()
    else:
        exc_ctx = pytest.raises(ValueError, match=f"passed {num_images} image")

    with exc_ctx:
        processor.apply(
            "<image>" * num_images,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )


def _test_processing_correctness(
    model_id: str,
    modalities: dict[str, bool],
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    """Check that cached and uncached processing produce identical outputs.

    Args:
        model_id: HF model to build the processors for.
        modalities: Maps each modality to whether the model accepts more
            than one item of it per prompt. The per-prompt limit is 3 if
            so, else 1.
        hit_rate: Probability that each item is the fixed
            ``input_to_hit`` input (so it can hit the cache) rather than a
            freshly generated random one.
        num_batches: Number of random requests to run.
        simplify_rate: Probability of dropping empty modalities and
            unwrapping single-item lists, to exercise single -> multi
            conversion.

    For every batch, the baseline (no cache) and cached processors must
    agree. Each must also give the same result for the text prompt and
    for its pre-tokenized form.
    """
    if model_id == "TIGER-Lab/Mantis-8B-siglip-llama3":
        hf_overrides = {"architectures": ["MantisForConditionalGeneration"]}
    else:
        hf_overrides = {}

    limit_mm_per_prompt = {
        modality: 3 if supports_multi else 1
        for modality, supports_multi in modalities.items()
    }

    model_config = ModelConfig(
        model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=True,
        seed=0,
        dtype="float16",
        revision=None,
        hf_overrides=hf_overrides,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )

    model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
    factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
    ctx = InputProcessingContext(
        model_config,
        tokenizer=cached_get_tokenizer(model_config.tokenizer),
    )
    # Ensure that it can fit all of the data
    cache = ProcessingCache(capacity=1 << 30)

    baseline_processor = factories.build_processor(ctx, cache=None)
    cached_processor = factories.build_processor(ctx, cache=cache)
    dummy_inputs = baseline_processor.dummy_inputs
    tokenizer = baseline_processor.info.get_tokenizer()

    rng = np.random.RandomState(0)

    input_to_hit = {
        "image": Image.new("RGB", size=(128, 128)),
        "video": np.zeros((4, 128, 128, 3), dtype=np.uint8),
        "audio": (np.zeros((512, )), 16000),
    }
    input_factory = {
        "image":
        partial(_rand_img, rng, min_wh=128, max_wh=256),
        "video":
        partial(_rand_video,
                rng,
                min_frames=2,
                max_frames=8,
                min_wh=128,
                max_wh=256),
        "audio":
        partial(_rand_audio, rng, min_len=512, max_len=1024, sr=16000),
    }

    for batch_idx in range(num_batches):
        # ``randint``'s upper bound is exclusive, so ``+ 1`` lets each
        # modality receive anywhere from 0 up to its limit. Without it,
        # single-item modalities (limit 1) would always get 0 items.
        mm_data = {
            k:
            [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
             for _ in range(rng.randint(limit_mm_per_prompt[k] + 1))]
            for k in modalities
        }

        mm_counts = {k: len(vs) for k, vs in mm_data.items()}
        prompt = dummy_inputs.get_dummy_processor_inputs(
            model_config.max_model_len,
            mm_counts,
        ).prompt_text

        # Drop unnecessary keys and test single -> multi conversion
        if rng.rand() < simplify_rate:
            for k in list(mm_data.keys()):
                if not mm_data[k]:
                    del mm_data[k]
                elif len(mm_data[k]) == 1:
                    mm_data[k] = mm_data[k][0]

        baseline_result = baseline_processor.apply(
            prompt,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )
        cached_result = cached_processor.apply(
            prompt,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )

        assert baseline_result == cached_result, (
            f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

        baseline_tokenized_result = baseline_processor.apply(
            tokenizer.encode(prompt),
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )

        assert baseline_result == baseline_tokenized_result, (
            f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")

        cached_tokenized_result = cached_processor.apply(
            tokenizer.encode(prompt),
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )

        assert cached_result == cached_tokenized_result, (
            f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")


# yapf: disable
# True if the model supports multiple data items of the modality per request
@pytest.mark.parametrize(("model_id", "modalities"), [
    ("rhymes-ai/Aria", {"image": True}),
    ("Salesforce/blip2-opt-2.7b", {"image": False}),
    ("facebook/chameleon-7b", {"image": False}),
    ("adept/fuyu-8b", {"image": False}),
    ("llava-hf/llava-1.5-7b-hf", {"image": True}),
    ("llava-hf/llava-v1.6-mistral-7b-hf", {"image": True}),
    ("llava-hf/LLaVA-NeXT-Video-7B-hf", {"video": False}),
    ("llava-hf/llava-onevision-qwen2-0.5b-ov-hf", {"image": True, "video": True}),  # noqa: E501
    ("TIGER-Lab/Mantis-8B-siglip-llama3", {"image": True}),
    ("mistral-community/pixtral-12b", {"image": True}),
    ("Qwen/Qwen2-VL-2B-Instruct", {"image": True, "video": True}),
    ("Qwen/Qwen2-Audio-7B-Instruct", {"audio": True}),
    ("fixie-ai/ultravox-v0_3", {"audio": True}),
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_correctness(
    model_id: str,
    modalities: dict[str, bool],
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    """Run ``_test_processing_correctness`` for each listed model."""
    _test_processing_correctness(
        model_id,
        modalities,
        hit_rate=hit_rate,
        num_batches=num_batches,
        simplify_rate=simplify_rate,
    )


# yapf: disable
@pytest.mark.parametrize(("model_id", "modalities"), [
    ("microsoft/Phi-3-vision-128k-instruct", {"image": True}),
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
@pytest.mark.parametrize("simplify_rate", [1.0])
# yapf: enable
def test_processing_correctness_phi3v(
    model_id: str,
    modalities: dict[str, bool],
    hit_rate: float,
    num_batches: int,
    simplify_rate: float,
):
    """Run ``_test_processing_correctness`` for Phi-3-vision.

    The HF image processor is loaded first, as a workaround for the
    transformers issue linked below.
    """
    # HACK - this is an attempted workaround for the following bug
    # https://github.com/huggingface/transformers/issues/34307
    from transformers import AutoImageProcessor  # noqa: F401
    from transformers import AutoProcessor  # noqa: F401

    AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)

    _test_processing_correctness(
        model_id,
        modalities,
        hit_rate=hit_rate,
        num_batches=num_batches,
        simplify_rate=simplify_rate,
    )