# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the multimodal prompt-processing utilities in
``vllm.multimodal.processing`` and related profiling helpers."""

from contextlib import nullcontext
from typing import Optional, cast

import numpy as np
import pytest

from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import (
    InputProcessingContext,
    PlaceholderFeaturesInfo,
    PromptIndexTargets,
    PromptInsertion,
    PromptReplacement,
    apply_text_matches,
    apply_token_matches,
    find_mm_placeholders,
    iter_token_matches,
    replace_token_matches,
)
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.transformers_utils.tokenizer import AnyTokenizer

from .utils import random_image

# A module-level ``pytestmark`` applies the project's ``cpu_test`` marker to
# every test collected from this module.
pytestmark = pytest.mark.cpu_test

@pytest.mark.parametrize(
    ("token_ids", "match_ids", "expected"),
    [
        ([], [], []),
        ([], [32000], []),
        (
            [32000, 32000, 32000],
            [32000],
            [
                {"start_idx": 0, "end_idx": 1},
                {"start_idx": 1, "end_idx": 2},
                {"start_idx": 2, "end_idx": 3},
            ],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000],
            [{"start_idx": 0, "end_idx": 2}],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000, 32000],
            [{"start_idx": 0, "end_idx": 3}],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000],
            [
                {"start_idx": 1, "end_idx": 3},
                {"start_idx": 6, "end_idx": 8},
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000, 32000, 32000],
            [
                {"start_idx": 1, "end_idx": 5},
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 0, 32000],
            [],
        ),
    ],
)
@pytest.mark.parametrize("start_idx", [0, 4, 8])
def test_iter_token_matches(token_ids, match_ids, expected, start_idx):
    """Check ``iter_token_matches`` against hand-computed matches.

    Matches are non-overlapping and scanned left to right; only matches
    beginning at or after ``start_idx`` are expected to be yielded.
    """
    result = list(iter_token_matches(token_ids, match_ids, start_idx=start_idx))

    # Manually constructed results
    assert [item._asdict() for item in result] == [
        item for item in expected if item["start_idx"] >= start_idx
    ]

    # Invariants
    match_lens = [end - start for start, end in result]
    print("match_lens:", match_lens)  # Only displayed on error
    assert all(match_len == len(match_ids) for match_len in match_lens)
    # Each reported span must actually contain the searched-for tokens
    assert all(token_ids[start:end] == match_ids for start, end in result)


@pytest.mark.parametrize(
    ("token_ids", "match_ids", "new_ids", "expected"),
    [
        ([], [], [-1], []),
        ([], [32000], [-1], []),
        (
            [32000, 32000, 32000],
            [32000],
            [-1],
            [-1, -1, -1],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000],
            [-1],
            [-1, 32000],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000, 32000],
            [-1],
            [-1],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000],
            [-1],
            [9833, -1, 32000, 32000, 9833, -1, 32000, 918],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000, 32000, 32000],
            [-1],
            [9833, -1, 9833, 28747, 32000, 32000, 918],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 0, 32000],
            [-1],
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
        ),
    ],
)
def test_replace_token_matches(token_ids, match_ids, new_ids, expected):
    """Check that ``replace_token_matches`` substitutes ``new_ids`` for each
    non-overlapping occurrence of ``match_ids``, leaving the sequence
    unchanged when there is no occurrence."""
    result = replace_token_matches(token_ids, match_ids, new_ids)

    # Manually constructed results
    assert result == expected


@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        (
            [],
            {
                "pattern_1": [],
                "pattern_2": [32000],
                "pattern_3": PromptIndexTargets.start(),
                "pattern_4": PromptIndexTargets.prefix([32000]),
                "pattern_5": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [],
                "pattern_2": [],
                "pattern_3": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_4": [],
                "pattern_5": [
                    {"start_idx": 0, "end_idx": 0},
                ],
            },
        ),
        (
            [32000, 32000, 32000, 32000],
            {
                "pattern_1": [32000],
                "pattern_2": [32000, 32000],
                "pattern_3": [32000, 32000, 32000],
                "pattern_4": PromptIndexTargets.start(),
                "pattern_5": PromptIndexTargets.prefix([32000]),
                "pattern_6": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [
                    {"start_idx": 0, "end_idx": 1},
                    {"start_idx": 1, "end_idx": 2},
                    {"start_idx": 2, "end_idx": 3},
                    {"start_idx": 3, "end_idx": 4},
                ],
                "pattern_2": [
                    {"start_idx": 0, "end_idx": 2},
                    {"start_idx": 2, "end_idx": 4},
                ],
                "pattern_3": [
                    {"start_idx": 0, "end_idx": 3},
                ],
                "pattern_4": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_5": [
                    {"start_idx": 1, "end_idx": 1},
                ],
                "pattern_6": [
                    {"start_idx": 4, "end_idx": 4},
                ],
            },
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            {
                "pattern_1": [28747, 32000],
                "pattern_2": [28747, 32000, 32000, 32000],
                "pattern_3": [28747, 0, 32000],
                "pattern_4": PromptIndexTargets.start(),
                "pattern_5": PromptIndexTargets.prefix([28747, 32000]),
                "pattern_6": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [
                    {"start_idx": 1, "end_idx": 3},
                    {"start_idx": 6, "end_idx": 8},
                ],
                "pattern_2": [
                    {"start_idx": 1, "end_idx": 5},
                ],
                "pattern_3": [],
                "pattern_4": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_5": [],
                "pattern_6": [
                    {"start_idx": 10, "end_idx": 10},
                ],
            },
        ),
    ],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
def test_find_token_matches(
    prompt,
    target_by_key,
    expected_by_key,
    update_type,
):
    """Check the token-ID matches found by a resolved prompt update.

    Literal token targets are matched without overlap, while
    ``PromptIndexTargets`` (start / prefix / end) yield zero-width matches.
    The expected matches are the same for insertions and replacements.
    """
    # Should not be used since there is nothing to convert to token IDs
    mock_tokenizer = cast(AnyTokenizer, object())

    prompt_updates = {
        key: update_type(key, target, []).resolve(0)
        for key, target in target_by_key.items()
    }
    result = {
        key: list(update.iter_token_matches(prompt, mock_tokenizer))
        for key, update in prompt_updates.items()
    }

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        # Detokenized test cases of `test_find_token_matches`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            "",
            {
                "pattern_1": "",
                "pattern_2": "<image>",
                "pattern_3": PromptIndexTargets.start(),
                "pattern_4": PromptIndexTargets.prefix("<image>"),
                "pattern_5": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [{"start_idx": 0, "end_idx": 0}],
                "pattern_2": [],
                "pattern_3": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_4": [],
                "pattern_5": [
                    {"start_idx": 0, "end_idx": 0},
                ],
            },
        ),
        (
            "<image><image><image><image>",
            {
                "pattern_1": "<image>",
                "pattern_2": "<image><image>",
                "pattern_3": "<image><image><image>",
                "pattern_4": PromptIndexTargets.start(),
                "pattern_5": PromptIndexTargets.prefix("<image>"),
                "pattern_6": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [
                    {"start_idx": 0, "end_idx": 7},
                    {"start_idx": 7, "end_idx": 14},
                    {"start_idx": 14, "end_idx": 21},
                    {"start_idx": 21, "end_idx": 28},
                ],
                "pattern_2": [
                    {"start_idx": 0, "end_idx": 14},
                    {"start_idx": 14, "end_idx": 28},
                ],
                "pattern_3": [
                    {"start_idx": 0, "end_idx": 21},
                ],
                "pattern_4": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_5": [
                    {"start_idx": 7, "end_idx": 7},
                ],
                "pattern_6": [
                    {"start_idx": 28, "end_idx": 28},
                ],
            },
        ),
        (
            "Image:<image><image><image>Image:<image><image>!",
            {
                "pattern_1": "Image:<image>",
                "pattern_2": "Image:<image><image><image>",
                "pattern_3": "Image:<unk><image>",
                "pattern_4": PromptIndexTargets.start(),
                "pattern_5": PromptIndexTargets.prefix("Image:<image>"),
                "pattern_6": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [
                    {"start_idx": 0, "end_idx": 13},
                    {"start_idx": 27, "end_idx": 40},
                ],
                "pattern_2": [
                    {"start_idx": 0, "end_idx": 27},
                ],
                "pattern_3": [],
                "pattern_4": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_5": [
                    {"start_idx": 13, "end_idx": 13},
                ],
                "pattern_6": [
                    {"start_idx": 48, "end_idx": 48},
                ],
            },
        ),
        # Test regex escape
        (
            "<|image|><image><|image|><image>",
            {
                "pattern_1": "<|image|>",
                "pattern_2": "<|image|><image>",
                "pattern_3": "<|image|><image><|image|>",
            },
            {
                "pattern_1": [
                    {"start_idx": 0, "end_idx": 9},
                    {"start_idx": 16, "end_idx": 25},
                ],
                "pattern_2": [
                    {"start_idx": 0, "end_idx": 16},
                    {"start_idx": 16, "end_idx": 32},
                ],
                "pattern_3": [
                    {"start_idx": 0, "end_idx": 25},
                ],
            },
        ),
    ],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
def test_find_text_matches(
    prompt,
    target_by_key,
    expected_by_key,
    update_type,
):
    """Text counterpart of ``test_find_token_matches``.

    Checks the character-offset matches found by a resolved prompt update,
    including targets containing regex metacharacters (``|``), which must be
    matched literally.
    """
    # Should not be used since there is nothing to convert to text
    mock_tokenizer = cast(AnyTokenizer, object())

    prompt_updates = {
        key: update_type(key, target, []).resolve(0)
        for key, target in target_by_key.items()
    }
    result = {
        key: list(update.iter_text_matches(prompt, mock_tokenizer))
        for key, update in prompt_updates.items()
    }

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"),  # noqa: E501
    [
        (
            "Image:<image>Image:<image><image>!",
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": "<image>",
                "pattern_2": "Image:",
                "pattern_3": "!",
            },
            {
                # Test whether target is confused with replacement
                "pattern_1": "<image><image>",
                # Test empty replacement
                "pattern_2": "",
                # Test dynamic replacement (beyond the form of `unit * count`)
                "pattern_3": "?!?",
            },
            {
                PromptInsertion: {
                    0: "Image:<image>Image:<image><image>!",
                    1: "Image:<image><image><image>Image:<image><image>!?!?",
                    2: "Image:<image><image><image><image><image>Image:<image><image>!?!??!?",  # noqa: E501
                },
                PromptReplacement: {
                    0: "Image:<image>Image:<image><image>!",
                    1: "<image><image>Image:<image><image>?!?",
                    2: "<image><image><image><image><image>?!?",
                },
            },
        ),
        # Test index targets
        (
            "",
            {
                "pattern_1": PromptIndexTargets.start(),
                "pattern_2": PromptIndexTargets.prefix("<image>"),
                "pattern_3": PromptIndexTargets.end(),
            },
            {
                "pattern_1": "1",
                "pattern_2": "2",
                "pattern_3": "3",
            },
            {
                PromptInsertion: {
                    0: "",
                    1: "13",
                    2: "1133",
                },
                PromptReplacement: {
                    0: "",
                    1: "13",
                    2: "1133",
                },
            },
        ),
        (
            "<image>",
            {
                "pattern_1": PromptIndexTargets.start(),
                "pattern_2": PromptIndexTargets.prefix("<image>"),
                "pattern_3": PromptIndexTargets.end(),
            },
            {
                "pattern_1": "1",
                "pattern_2": "2",
                "pattern_3": "3",
            },
            {
                PromptInsertion: {
                    0: "<image>",
                    1: "1<image>23",
                    2: "11<image>2233",
                },
                PromptReplacement: {
                    0: "<image>",
                    1: "1<image>23",
                    2: "11<image>2233",
                },
            },
        ),
        # Test different replacement per item
        (
            "<image><image><image>",
            {
                "pattern_1": "<image>",
            },
            {
                "pattern_1": lambda idx: str(idx + 1),
            },
            {
                PromptInsertion: {
                    0: "<image><image><image>",
                    1: "<image>1<image><image>",
                    2: "<image>12<image><image>",
                },
                PromptReplacement: {
                    0: "<image><image><image>",
                    1: "1<image><image>",
                    2: "12<image>",
                },
            },
        ),
        (
            "<image><image><image>",
            {
                "pattern_1": PromptIndexTargets.prefix("<image>"),
            },
            {
                "pattern_1": lambda idx: str(idx + 1),
            },
            {
                PromptInsertion: {
                    0: "<image><image><image>",
                    1: "<image>1<image><image>",
                    2: "<image>12<image><image>",
                },
                PromptReplacement: {
                    0: "<image><image><image>",
                    1: "<image>1<image><image>",
                    2: "<image>12<image><image>",
                },
            },
        ),
    ],
)
def test_find_update_text(
    prompt,
    target_by_key,
    repl_by_key,
    expected_by_update_type_mm_count,
):
    """Check the text produced by ``apply_text_matches``.

    For each update type and each ``mm_count``, every target gets
    ``mm_count`` updates, resolved with item indices ``0..mm_count-1``,
    and the resulting prompt is compared against the expected string.
    """
    # Should not be used since there is nothing to convert to text
    mock_tokenizer = cast(AnyTokenizer, object())

    for (
        update_type,
        expected_by_mm_count,
    ) in expected_by_update_type_mm_count.items():
        for mm_count, expected in expected_by_mm_count.items():
            mm_prompt_updates = {
                key: [
                    [update_type(key, target, repl_by_key[key]).resolve(i)]
                    for i in range(mm_count)
                ]
                for key, target in target_by_key.items()
            }

            new_prompt, result = apply_text_matches(
                prompt,
                mm_prompt_updates,
                mock_tokenizer,
            )

            # Only displayed on error
            print("update_type:", update_type)
            print("mm_count:", mm_count)
            print("mm_prompt_updates:", mm_prompt_updates)
            print("new_prompt:", new_prompt)
            print("result:", result)

            # Manually constructed results
            assert new_prompt == expected


@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"),  # noqa: E501
    [
        # Tokenized test cases of `test_find_update_text`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": [32000],
                "pattern_2": [9833, 28747],
                "pattern_3": [918],
            },
            {
                # Test whether target is confused with replacement
                "pattern_1": [32000, 32000],
                # Test empty replacement
                "pattern_2": [],
                # Test dynamic replacement (beyond the form of `unit * count`)
                "pattern_3": [1550, 918, 1550],
            },
            {
                PromptInsertion: {
                    0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
                    1: [
                        1,
                        9833,
                        28747,
                        32000,
                        32000,
                        32000,
                        9833,
                        28747,
                        32000,
                        32000,
                        918,
                        1550,
                        918,
                        1550,
                    ],
                    2: [
                        1,
                        9833,
                        28747,
                        32000,
                        32000,
                        32000,
                        32000,
                        32000,
                        9833,
                        28747,
                        32000,
                        32000,
                        918,
                        1550,
                        918,
                        1550,
                        1550,
                        918,
                        1550,
                    ],
                },
                PromptReplacement: {
                    0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
                    1: [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],  # noqa: E501
                    2: [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
                },
            },
        ),
        # Test index targets
        (
            [],
            {
                "pattern_1": PromptIndexTargets.start(),
                "pattern_2": PromptIndexTargets.prefix([32000]),
                "pattern_3": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [-1],
                "pattern_2": [-2],
                "pattern_3": [-3],
            },
            {
                PromptInsertion: {
                    0: [],
                    1: [-1, -3],
                    2: [-1, -1, -3, -3],
                },
                PromptReplacement: {
                    0: [],
                    1: [-1, -3],
                    2: [-1, -1, -3, -3],
                },
            },
        ),
        (
            [32000],
            {
                "pattern_1": PromptIndexTargets.start(),
                "pattern_2": PromptIndexTargets.prefix([32000]),
                "pattern_3": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [-1],
                "pattern_2": [-2],
                "pattern_3": [-3],
            },
            {
                PromptInsertion: {
                    0: [32000],
                    1: [-1, 32000, -2, -3],
                    2: [-1, -1, 32000, -2, -2, -3, -3],
                },
                PromptReplacement: {
                    0: [32000],
                    1: [-1, 32000, -2, -3],
                    2: [-1, -1, 32000, -2, -2, -3, -3],
                },
            },
        ),
        # Test different replacement per item
        (
            [32000, 32000, 32000],
            {
                "pattern_1": [32000],
            },
            {
                "pattern_1": lambda idx: [-(idx + 1)],
            },
            {
                PromptInsertion: {
                    0: [32000, 32000, 32000],
                    1: [32000, -1, 32000, 32000],
                    2: [32000, -1, -2, 32000, 32000],
                },
                PromptReplacement: {
                    0: [32000, 32000, 32000],
                    1: [-1, 32000, 32000],
                    2: [-1, -2, 32000],
                },
            },
        ),
        (
            [32000, 32000, 32000],
            {
                "pattern_1": PromptIndexTargets.prefix([32000]),
            },
            {
                "pattern_1": lambda idx: [-(idx + 1)],
            },
            {
                PromptInsertion: {
                    0: [32000, 32000, 32000],
                    1: [32000, -1, 32000, 32000],
                    2: [32000, -1, -2, 32000, 32000],
                },
                PromptReplacement: {
                    0: [32000, 32000, 32000],
                    1: [32000, -1, 32000, 32000],
                    2: [32000, -1, -2, 32000, 32000],
                },
            },
        ),
    ],
)
def test_find_update_tokens(
    prompt,
    target_by_key,
    repl_by_key,
    expected_by_update_type_mm_count,
):
    """Token-ID counterpart of ``test_find_update_text``.

    For each update type and each ``mm_count``, every target gets
    ``mm_count`` updates, resolved with item indices ``0..mm_count-1``, and
    the token list produced by ``apply_token_matches`` is compared against
    the expected one.
    """
    # Should not be used since there is nothing to convert to tokens
    mock_tokenizer = cast(AnyTokenizer, object())

    for (
        update_type,
        expected_by_mm_count,
    ) in expected_by_update_type_mm_count.items():
        for mm_count, expected in expected_by_mm_count.items():
            mm_prompt_updates = {
                key: [
                    [update_type(key, target, repl_by_key[key]).resolve(i)]
                    for i in range(mm_count)
                ]
                for key, target in target_by_key.items()
            }

            new_prompt, result = apply_token_matches(
                prompt,
                mm_prompt_updates,
                mock_tokenizer,
            )

            # Only displayed on error
            print("update_type:", update_type)
            print("mm_count:", mm_count)
            print("mm_prompt_updates:", mm_prompt_updates)
            print("new_prompt:", new_prompt)
            print("result:", result)

            # Manually constructed results
            assert new_prompt == expected


@pytest.mark.parametrize(
    "repl_by_key",
    [
        {
            "pattern_1": [32000, 32000],
            "pattern_2": [],
            "pattern_3": [1550, 918, 1550],
            # Test different modalities having the same tokens (32000)
            "pattern_4": [32000],
        },
    ],
)
@pytest.mark.parametrize(
    ("prompt", "expected"),
    [
        (
            [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
            {
                "pattern_1": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=6,
                        tokens=[32000, 32000],
                        is_embed=None,
                    ),
                ],
                "pattern_4": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_4",
                        item_idx=0,
                        start_idx=3,
                        tokens=[32000],
                        is_embed=None,
                    ),
                ],
            },
        ),
        (
            [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
            {
                "pattern_1": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=1,
                        tokens=[32000, 32000],
                        is_embed=None,
                    ),
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=1,
                        start_idx=5,
                        tokens=[32000, 32000],
                        is_embed=None,
                    ),
                ],
                "pattern_3": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_3",
                        item_idx=0,
                        start_idx=7,
                        tokens=[1550, 918, 1550],
                        is_embed=None,
                    ),
                ],
                # No match for pattern_4 as it has lower priority than pattern_1
            },
        ),
        (
            [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
            {
                "pattern_1": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=1,
                        tokens=[32000, 32000],
                        is_embed=None,
                    ),
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=1,
                        start_idx=3,
                        tokens=[32000, 32000],
                        is_embed=None,
                    ),
                ],
                "pattern_4": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_4",
                        item_idx=0,
                        start_idx=5,
                        tokens=[32000],
                        is_embed=None,
                    ),
                ],
                "pattern_3": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_3",
                        item_idx=0,
                        start_idx=6,
                        tokens=[1550, 918, 1550],
                        is_embed=None,
                    ),
                ],
            },
        ),
    ],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
def test_find_mm_placeholders(
    repl_by_key,
    prompt,
    expected,
    update_type,
):
    """Check that ``find_mm_placeholders`` locates each modality's
    replacement tokens in ``prompt`` and reports them as
    ``PlaceholderFeaturesInfo`` entries keyed by modality, with per-item
    indices and start offsets."""
    # Should not be used since there is nothing to convert to tokens
    mock_tokenizer = cast(AnyTokenizer, object())

    # Three resolved updates (item indices 0, 1, 2) per modality; only the
    # replacement tokens matter here, so the target is left empty.
    mm_prompt_updates = {
        key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
        for key, repl in repl_by_key.items()
    }

    result = find_mm_placeholders(prompt, mm_prompt_updates, mock_tokenizer)

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    assert result == expected


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("limit", "num_supported", "is_valid"),
    [
        (0, 0, True),
        (0, 1, True),
        (1, 0, False),
        (1, 1, True),
        (1, 2, True),
        (2, 1, False),
        (2, 2, True),
    ],
)
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
    """Check dummy-data generation against the per-prompt image limit.

    The processor's supported image count is overridden to ``num_supported``;
    per the parametrized cases, requesting ``limit`` images must raise a
    ``ValueError`` matching "At most" whenever ``limit`` exceeds
    ``num_supported``, and must succeed otherwise.
    """
    limit_mm_per_prompt = {"image": limit}

    model_config = ModelConfig(
        model=model_id,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )

    processor = MULTIMODAL_REGISTRY.create_processor(model_config)
    # Override the (private) supported-limit table to simulate a model that
    # supports only `num_supported` images.
    processor._supported_mm_limits = {"image": num_supported}

    profiler = MultiModalProfiler(processor)

    exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")

    with exc_ctx:
        profiler.get_decoder_dummy_data(
            model_config.max_model_len,
            mm_counts=limit_mm_per_prompt,
        )


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("num_images", "limit", "is_valid"),
    [
        (0, 0, True),
        (0, 1, True),
        (1, 0, False),
        (1, 1, True),
        (1, 2, True),
        (2, 1, False),
        (2, 2, True),
    ],
)
def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
    """Check ``processor.apply`` against the per-prompt image limit.

    A prompt containing ``num_images`` ``<image>`` placeholders is processed
    with ``limit_mm_per_prompt={"image": limit}``; per the parametrized cases,
    a ``ValueError`` matching "At most" is expected whenever ``num_images``
    exceeds ``limit``.
    """
    limit_mm_per_prompt = {"image": limit}

    model_config = ModelConfig(
        model=model_id,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )

    processor = MULTIMODAL_REGISTRY.create_processor(model_config)

    rng = np.random.RandomState(0)
    image = random_image(rng, min_wh=128, max_wh=256)

    # Exercise both input forms: a single bare item and a list of items.
    if num_images == 0:
        mm_data = {}
    elif num_images == 1:
        mm_data = {"image": image}
    else:
        mm_data = {"image": [image] * num_images}

    exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")

    with exc_ctx:
        processor.apply(
            "<image>" * num_images,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )


class DummyProcessor:
    """Minimal stand-in for a Hugging Face processor used by the kwargs tests.

    Constructor kwargs ``a`` and ``b`` are stored as attributes, so a test can
    inspect which init kwargs were forwarded. ``__call__`` accepts ``a`` and
    ``c`` and echoes them back, so a test can inspect which call kwargs were
    forwarded.
    """

    def __init__(self, a: int = 0, b: int = 0) -> None:
        super().__init__()

        self.a = a
        self.b = b

    def __call__(
        self,
        a: int = 0,
        c: int = 0,
        return_tensors: Optional[str] = None,
    ) -> dict[str, int]:
        """Return the received ``a`` and ``c`` values as a dict.

        ``return_tensors`` is accepted but ignored.
        """
        # NOTE(review): `return_tensors` is presumably accepted so the object
        # can be invoked like a real HF processor; confirm against
        # InputProcessingContext.call_hf_processor.
        return {"a": a, "c": c}


@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])  # Dummy
@pytest.mark.parametrize(
    ("config_kwargs", "inference_kwargs", "expected_kwargs"),
    [
        ({"a": 1}, {}, {"a": 1, "b": 0}),
        ({}, {"a": 1}, {"a": 1, "b": 0}),
        # inference_kwargs should take precedence
        ({"a": 1}, {"a": 2}, {"a": 2, "b": 0}),
        # Should ignore extra kwargs
        ({"a": 1, "c": 1}, {}, {"a": 1, "b": 0}),
        ({"b": 1, "c": 1}, {}, {"a": 0, "b": 1}),
    ],
)
def test_hf_processor_init_kwargs(
    model_id,
    config_kwargs,
    inference_kwargs,
    expected_kwargs,
):
    """Check which kwargs ``get_hf_processor`` passes to the constructor.

    Config-level ``mm_processor_kwargs`` and per-call ``inference_kwargs`` are
    expected to be merged (inference wins), and kwargs that
    ``DummyProcessor.__init__`` does not accept (``c``) are expected to be
    dropped; the constructed attributes must equal ``expected_kwargs``.
    """
    # Should not be used since there is nothing to convert to tokens
    mock_tokenizer = cast(AnyTokenizer, object())

    ctx = InputProcessingContext(
        model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
        tokenizer=mock_tokenizer,
    )

    processor = ctx.get_hf_processor(
        DummyProcessor,  # type: ignore[arg-type]
        **inference_kwargs,
    )

    for k, v in expected_kwargs.items():
        assert getattr(processor, k) == v


@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])  # Dummy
@pytest.mark.parametrize(
    ("config_kwargs", "inference_kwargs", "expected_kwargs"),
    [
        ({"a": 1}, {}, {"a": 1, "c": 0}),
        ({}, {"a": 1}, {"a": 1, "c": 0}),
        # inference_kwargs should take precedence
        ({"a": 1}, {"a": 2}, {"a": 2, "c": 0}),
        # Should ignore extra kwargs
        ({"a": 1, "c": 1}, {}, {"a": 1, "c": 1}),
        ({"b": 1, "c": 1}, {}, {"a": 0, "c": 1}),
    ],
)
def test_hf_processor_call_kwargs(
    model_id,
    config_kwargs,
    inference_kwargs,
    expected_kwargs,
):
    """Check which kwargs ``call_hf_processor`` forwards to ``__call__``.

    Config-level ``mm_processor_kwargs`` and ``inference_kwargs`` are expected
    to be merged (inference wins), and kwargs that ``DummyProcessor.__call__``
    does not accept (``b``) are expected to be dropped; the echoed result must
    equal ``expected_kwargs``.
    """
    # Should not be used since there is nothing to convert to tokens
    mock_tokenizer = cast(AnyTokenizer, object())

    ctx = InputProcessingContext(
        model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
        tokenizer=mock_tokenizer,
    )

    processor = ctx.get_hf_processor(DummyProcessor)  # type: ignore[arg-type]

    result = ctx.call_hf_processor(processor, {}, inference_kwargs)
    assert result == expected_kwargs