# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from contextlib import nullcontext

import numpy as np
import pytest

from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import (
    InputProcessingContext,
    PlaceholderFeaturesInfo,
    PromptIndexTargets,
    PromptInsertion,
    PromptReplacement,
    _apply_matches,
    apply_text_matches,
    apply_token_matches,
    find_mm_placeholders,
    iter_token_matches,
    replace_token_matches,
)

from .utils import random_image

# Apply the `cpu_test` marker to every test collected from this module.
pytestmark = pytest.mark.cpu_test


@pytest.mark.parametrize(
    ("token_ids", "match_ids", "expected"),
    [
        ([], [], []),
        ([], [32000], []),
        (
            [32000, 32000, 32000],
            [32000],
            [
                {"start_idx": 0, "end_idx": 1},
                {"start_idx": 1, "end_idx": 2},
                {"start_idx": 2, "end_idx": 3},
            ],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000],
            [{"start_idx": 0, "end_idx": 2}],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000, 32000],
            [{"start_idx": 0, "end_idx": 3}],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000],
            [
                {"start_idx": 1, "end_idx": 3},
                {"start_idx": 6, "end_idx": 8},
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000, 32000, 32000],
            [
                {"start_idx": 1, "end_idx": 5},
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 0, 32000],
            [],
        ),
    ],
)
@pytest.mark.parametrize("start_idx", [0, 4, 8])
def test_iter_token_matches(token_ids, match_ids, expected, start_idx):
    """Check `iter_token_matches` against hand-computed match spans.

    `expected` lists every match for a search from index 0; for a non-zero
    `start_idx` only the matches beginning at or after `start_idx` are
    expected.
    """
    result = list(iter_token_matches(token_ids, match_ids, start_idx=start_idx))

    # Manually constructed results
    assert [item._asdict() for item in result] == [
        item for item in expected if item["start_idx"] >= start_idx
    ]

    # Invariants: every match spans exactly `len(match_ids)` tokens.
    # Read the fields by name so the check does not depend on tuple order.
    match_lens = [item.end_idx - item.start_idx for item in result]
    print("match_lens:", match_lens)  # Only displayed on error
    assert all(match_len == len(match_ids) for match_len in match_lens)


@pytest.mark.parametrize(
    ("token_ids", "match_ids", "new_ids", "expected"),
    [
        ([], [], [-1], []),
        ([], [32000], [-1], []),
        (
            [32000, 32000, 32000],
            [32000],
            [-1],
            [-1, -1, -1],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000],
            [-1],
            [-1, 32000],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000, 32000],
            [-1],
            [-1],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000],
            [-1],
            [9833, -1, 32000, 32000, 9833, -1, 32000, 918],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000, 32000, 32000],
            [-1],
            [9833, -1, 9833, 28747, 32000, 32000, 918],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 0, 32000],
            [-1],
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
        ),
    ],
)
def test_replace_token_matches(token_ids, match_ids, new_ids, expected):
    """Check that `replace_token_matches` substitutes `new_ids` for each match.

    When `match_ids` does not occur, the token list is returned unchanged.
    """
    result = replace_token_matches(token_ids, match_ids, new_ids)

    # Manually constructed results
    assert result == expected


@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        (
            [],
            {
                "pattern_1": [],
                "pattern_2": [32000],
                "pattern_3": PromptIndexTargets.start(),
                "pattern_4": PromptIndexTargets.prefix([32000]),
                "pattern_5": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [],
                "pattern_2": [],
                "pattern_3": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_4": [],
                "pattern_5": [
                    {"start_idx": 0, "end_idx": 0},
                ],
            },
        ),
        (
            [32000, 32000, 32000, 32000],
            {
                "pattern_1": [32000],
                "pattern_2": [32000, 32000],
                "pattern_3": [32000, 32000, 32000],
                "pattern_4": PromptIndexTargets.start(),
                "pattern_5": PromptIndexTargets.prefix([32000]),
                "pattern_6": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [
                    {"start_idx": 0, "end_idx": 1},
                    {"start_idx": 1, "end_idx": 2},
                    {"start_idx": 2, "end_idx": 3},
                    {"start_idx": 3, "end_idx": 4},
                ],
                "pattern_2": [
                    {"start_idx": 0, "end_idx": 2},
                    {"start_idx": 2, "end_idx": 4},
                ],
                "pattern_3": [
                    {"start_idx": 0, "end_idx": 3},
                ],
                "pattern_4": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_5": [
                    {"start_idx": 1, "end_idx": 1},
                ],
                "pattern_6": [
                    {"start_idx": 4, "end_idx": 4},
                ],
            },
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            {
                "pattern_1": [28747, 32000],
                "pattern_2": [28747, 32000, 32000, 32000],
                "pattern_3": [28747, 0, 32000],
                "pattern_4": PromptIndexTargets.start(),
                "pattern_5": PromptIndexTargets.prefix([28747, 32000]),
                "pattern_6": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [
                    {"start_idx": 1, "end_idx": 3},
                    {"start_idx": 6, "end_idx": 8},
                ],
                "pattern_2": [
                    {"start_idx": 1, "end_idx": 5},
                ],
                "pattern_3": [],
                "pattern_4": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_5": [],
                "pattern_6": [
                    {"start_idx": 10, "end_idx": 10},
                ],
            },
        ),
    ],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
def test_find_token_matches(
    prompt,
    target_by_key,
    expected_by_key,
    update_type,
):
    """Check where each resolved update's target matches in a token prompt.

    Targets are either literal token sequences or `PromptIndexTargets`
    (start / prefix / end). The update content is left empty (`[]`) because
    only the match positions are compared.
    """
    prompt_updates = {
        key: update_type(key, target, []).resolve(0)
        for key, target in target_by_key.items()
    }
    result = {
        key: list(update.iter_token_matches(prompt, tokenizer=None))
        for key, update in prompt_updates.items()
    }

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        # Detokenized test cases of `test_find_token_matches`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            "",
            {
                "pattern_1": "",
                "pattern_2": "<image>",
                "pattern_3": PromptIndexTargets.start(),
                "pattern_4": PromptIndexTargets.prefix("<image>"),
                "pattern_5": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [{"start_idx": 0, "end_idx": 0}],
                "pattern_2": [],
                "pattern_3": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_4": [],
                "pattern_5": [
                    {"start_idx": 0, "end_idx": 0},
                ],
            },
        ),
        (
            "<image><image><image><image>",
            {
                "pattern_1": "<image>",
                "pattern_2": "<image><image>",
                "pattern_3": "<image><image><image>",
                "pattern_4": PromptIndexTargets.start(),
                "pattern_5": PromptIndexTargets.prefix("<image>"),
                "pattern_6": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [
                    {"start_idx": 0, "end_idx": 7},
                    {"start_idx": 7, "end_idx": 14},
                    {"start_idx": 14, "end_idx": 21},
                    {"start_idx": 21, "end_idx": 28},
                ],
                "pattern_2": [
                    {"start_idx": 0, "end_idx": 14},
                    {"start_idx": 14, "end_idx": 28},
                ],
                "pattern_3": [
                    {"start_idx": 0, "end_idx": 21},
                ],
                "pattern_4": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_5": [
                    {"start_idx": 7, "end_idx": 7},
                ],
                "pattern_6": [
                    {"start_idx": 28, "end_idx": 28},
                ],
            },
        ),
        (
            "Image:<image><image><image>Image:<image><image>!",
            {
                "pattern_1": "Image:<image>",
                "pattern_2": "Image:<image><image><image>",
                "pattern_3": "Image:<unk><image>",
                "pattern_4": PromptIndexTargets.start(),
                "pattern_5": PromptIndexTargets.prefix("Image:<image>"),
                "pattern_6": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [
                    {"start_idx": 0, "end_idx": 13},
                    {"start_idx": 27, "end_idx": 40},
                ],
                "pattern_2": [
                    {"start_idx": 0, "end_idx": 27},
                ],
                "pattern_3": [],
                "pattern_4": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_5": [
                    {"start_idx": 13, "end_idx": 13},
                ],
                "pattern_6": [
                    {"start_idx": 48, "end_idx": 48},
                ],
            },
        ),
        # Test regex escape
        (
            "<|image|><image><|image|><image>",
            {
                "pattern_1": "<|image|>",
                "pattern_2": "<|image|><image>",
                "pattern_3": "<|image|><image><|image|>",
            },
            {
                "pattern_1": [
                    {"start_idx": 0, "end_idx": 9},
                    {"start_idx": 16, "end_idx": 25},
                ],
                "pattern_2": [
                    {"start_idx": 0, "end_idx": 16},
                    {"start_idx": 16, "end_idx": 32},
                ],
                "pattern_3": [
                    {"start_idx": 0, "end_idx": 25},
                ],
            },
        ),
    ],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
def test_find_text_matches(
    prompt,
    target_by_key,
    expected_by_key,
    update_type,
):
    """Check where each resolved update's target matches in a text prompt.

    Mirrors `test_find_token_matches` on detokenized prompts. Indices are
    character offsets into `prompt`. Targets containing regex metacharacters
    (`|`) must be matched literally.
    """
    prompt_updates = {
        key: update_type(key, target, []).resolve(0)
        for key, target in target_by_key.items()
    }
    result = {
        key: list(update.iter_text_matches(prompt, tokenizer=None))
        for key, update in prompt_updates.items()
    }

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"),  # noqa: E501
    [
        (
            "Image:<image>Image:<image><image>!",
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": "<image>",
                "pattern_2": "Image:",
                "pattern_3": "!",
            },
            {
                # Test whether target is confused with replacement
                "pattern_1": "<image><image>",
                # Test empty replacement
                "pattern_2": "",
                # Test dynamic replacement (beyond the form of `unit * count`)
                "pattern_3": "?!?",
            },
            {
                PromptInsertion: {
                    0: "Image:<image>Image:<image><image>!",
                    1: "Image:<image><image><image>Image:<image><image>!?!?",
                    2: "Image:<image><image><image><image><image>Image:<image><image>!?!??!?",  # noqa: E501
                },
                PromptReplacement: {
                    0: "Image:<image>Image:<image><image>!",
                    1: "<image><image>Image:<image><image>?!?",
                    2: "<image><image><image><image><image>?!?",
                },
            },
        ),
        # Test index targets
        (
            "",
            {
                "pattern_1": PromptIndexTargets.start(),
                "pattern_2": PromptIndexTargets.prefix("<image>"),
                "pattern_3": PromptIndexTargets.end(),
            },
            {
                "pattern_1": "1",
                "pattern_2": "2",
                "pattern_3": "3",
            },
            {
                PromptInsertion: {
                    0: "",
                    1: "13",
                    2: "1133",
                },
                PromptReplacement: {
                    0: "",
                    1: "13",
                    2: "1133",
                },
            },
        ),
        (
            "<image>",
            {
                "pattern_1": PromptIndexTargets.start(),
                "pattern_2": PromptIndexTargets.prefix("<image>"),
                "pattern_3": PromptIndexTargets.end(),
            },
            {
                "pattern_1": "1",
                "pattern_2": "2",
                "pattern_3": "3",
            },
            {
                PromptInsertion: {
                    0: "<image>",
                    1: "1<image>23",
                    2: "11<image>2233",
                },
                PromptReplacement: {
                    0: "<image>",
                    1: "1<image>23",
                    2: "11<image>2233",
                },
            },
        ),
        # Test different replacement per item
        (
            "<image><image><image>",
            {
                "pattern_1": "<image>",
            },
            {
                "pattern_1": lambda idx: str(idx + 1),
            },
            {
                PromptInsertion: {
                    0: "<image><image><image>",
                    1: "<image>1<image><image>",
                    2: "<image>12<image><image>",
                },
                PromptReplacement: {
                    0: "<image><image><image>",
                    1: "1<image><image>",
                    2: "12<image>",
                },
            },
        ),
        (
            "<image><image><image>",
            {
                "pattern_1": PromptIndexTargets.prefix("<image>"),
            },
            {
                "pattern_1": lambda idx: str(idx + 1),
            },
            {
                PromptInsertion: {
                    0: "<image><image><image>",
                    1: "<image>1<image><image>",
                    2: "<image>12<image><image>",
                },
                PromptReplacement: {
                    0: "<image><image><image>",
                    1: "<image>1<image><image>",
                    2: "<image>12<image><image>",
                },
            },
        ),
    ],
)
def test_find_update_text(
    prompt,
    target_by_key,
    repl_by_key,
    expected_by_update_type_mm_count,
):
    """Check that `apply_text_matches` rewrites a text prompt as expected.

    For each update type (insertion / replacement) and each item count
    `mm_count`, the updates are resolved for item indices `0..mm_count-1` and
    applied to `prompt`. The resulting prompt is compared with the
    hand-written expectation. A replacement may be a fixed string or a
    callable taking the item index.
    """
    for (
        update_type,
        expected_by_mm_count,
    ) in expected_by_update_type_mm_count.items():
        for mm_count, expected in expected_by_mm_count.items():
            # key -> one single-update list per item index in range(mm_count)
            mm_prompt_updates = {
                key: [
                    [update_type(key, target, repl_by_key[key]).resolve(i)]
                    for i in range(mm_count)
                ]
                for key, target in target_by_key.items()
            }

            new_prompt, result = apply_text_matches(
                prompt,
                mm_prompt_updates,
                tokenizer=None,
            )

            # Only displayed on error
            print("update_type:", update_type)
            print("mm_count:", mm_count)
            print("mm_prompt_updates:", mm_prompt_updates)
            print("new_prompt:", new_prompt)
            print("result:", result)

            # Manually constructed results
            assert new_prompt == expected


@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"),  # noqa: E501
    [
        # Tokenized test cases of `test_find_update_text`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": [32000],
                "pattern_2": [9833, 28747],
                "pattern_3": [918],
            },
            {
                # Test whether target is confused with replacement
                "pattern_1": [32000, 32000],
                # Test empty replacement
                "pattern_2": [],
                # Test dynamic replacement (beyond the form of `unit * count`)
                "pattern_3": [1550, 918, 1550],
            },
            {
                PromptInsertion: {
                    0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
                    1: [
                        1,
                        9833,
                        28747,
                        32000,
                        32000,
                        32000,
                        9833,
                        28747,
                        32000,
                        32000,
                        918,
                        1550,
                        918,
                        1550,
                    ],
                    2: [
                        1,
                        9833,
                        28747,
                        32000,
                        32000,
                        32000,
                        32000,
                        32000,
                        9833,
                        28747,
                        32000,
                        32000,
                        918,
                        1550,
                        918,
                        1550,
                        1550,
                        918,
                        1550,
                    ],
                },
                PromptReplacement: {
                    0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
                    1: [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],  # noqa: E501
                    2: [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
                },
            },
        ),
        # Test index targets
        (
            [],
            {
                "pattern_1": PromptIndexTargets.start(),
                "pattern_2": PromptIndexTargets.prefix([32000]),
                "pattern_3": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [-1],
                "pattern_2": [-2],
                "pattern_3": [-3],
            },
            {
                PromptInsertion: {
                    0: [],
                    1: [-1, -3],
                    2: [-1, -1, -3, -3],
                },
                PromptReplacement: {
                    0: [],
                    1: [-1, -3],
                    2: [-1, -1, -3, -3],
                },
            },
        ),
        (
            [32000],
            {
                "pattern_1": PromptIndexTargets.start(),
                "pattern_2": PromptIndexTargets.prefix([32000]),
                "pattern_3": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [-1],
                "pattern_2": [-2],
                "pattern_3": [-3],
            },
            {
                PromptInsertion: {
                    0: [32000],
                    1: [-1, 32000, -2, -3],
                    2: [-1, -1, 32000, -2, -2, -3, -3],
                },
                PromptReplacement: {
                    0: [32000],
                    1: [-1, 32000, -2, -3],
                    2: [-1, -1, 32000, -2, -2, -3, -3],
                },
            },
        ),
        # Test different replacement per item
        (
            [32000, 32000, 32000],
            {
                "pattern_1": [32000],
            },
            {
                "pattern_1": lambda idx: [-(idx + 1)],
            },
            {
                PromptInsertion: {
                    0: [32000, 32000, 32000],
                    1: [32000, -1, 32000, 32000],
                    2: [32000, -1, -2, 32000, 32000],
                },
                PromptReplacement: {
                    0: [32000, 32000, 32000],
                    1: [-1, 32000, 32000],
                    2: [-1, -2, 32000],
                },
            },
        ),
        (
            [32000, 32000, 32000],
            {
                "pattern_1": PromptIndexTargets.prefix([32000]),
            },
            {
                "pattern_1": lambda idx: [-(idx + 1)],
            },
            {
                PromptInsertion: {
                    0: [32000, 32000, 32000],
                    1: [32000, -1, 32000, 32000],
                    2: [32000, -1, -2, 32000, 32000],
                },
                PromptReplacement: {
                    0: [32000, 32000, 32000],
                    1: [32000, -1, 32000, 32000],
                    2: [32000, -1, -2, 32000, 32000],
                },
            },
        ),
    ],
)
def test_find_update_tokens(
    prompt,
    target_by_key,
    repl_by_key,
    expected_by_update_type_mm_count,
):
    """Check that `apply_token_matches` rewrites a token prompt as expected.

    This is the token-level counterpart of `test_find_update_text`. For each
    update type and each item count `mm_count`, the updates are resolved for
    item indices `0..mm_count-1` and applied to `prompt`. The result is
    compared with the hand-written expectation.
    """
    for (
        update_type,
        expected_by_mm_count,
    ) in expected_by_update_type_mm_count.items():
        for mm_count, expected in expected_by_mm_count.items():
            # key -> one single-update list per item index in range(mm_count)
            mm_prompt_updates = {
                key: [
                    [update_type(key, target, repl_by_key[key]).resolve(i)]
                    for i in range(mm_count)
                ]
                for key, target in target_by_key.items()
            }

            new_prompt, result = apply_token_matches(
                prompt,
                mm_prompt_updates,
                tokenizer=None,
            )

            # Only displayed on error
            print("update_type:", update_type)
            print("mm_count:", mm_count)
            print("mm_prompt_updates:", mm_prompt_updates)
            print("new_prompt:", new_prompt)
            print("result:", result)

            # Manually constructed results
            assert new_prompt == expected


@pytest.mark.parametrize(
    "repl_by_key",
    [
        {
            "pattern_1": [32000, 32000],
            "pattern_2": [],
            "pattern_3": [1550, 918, 1550],
            # Test different modalities having the same tokens (32000)
            "pattern_4": [32000],
        },
    ],
)
@pytest.mark.parametrize(
    ("prompt", "expected"),
    [
        (
            [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
            {
                "pattern_1": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=6,
                        tokens=[32000, 32000],
                        is_embed=None,
                    ),
                ],
                "pattern_4": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_4",
                        item_idx=0,
                        start_idx=3,
                        tokens=[32000],
                        is_embed=None,
                    ),
                ],
            },
        ),
        (
            [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
            {
                "pattern_1": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=1,
                        tokens=[32000, 32000],
                        is_embed=None,
                    ),
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=1,
                        start_idx=5,
                        tokens=[32000, 32000],
                        is_embed=None,
                    ),
                ],
                "pattern_3": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_3",
                        item_idx=0,
                        start_idx=7,
                        tokens=[1550, 918, 1550],
                        is_embed=None,
                    ),
                ],
                # No match for pattern_4 as it has lower priority than pattern_1
            },
        ),
        (
            [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
            {
                "pattern_1": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=1,
                        tokens=[32000, 32000],
                        is_embed=None,
                    ),
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=1,
                        start_idx=3,
                        tokens=[32000, 32000],
                        is_embed=None,
                    ),
                ],
                "pattern_4": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_4",
                        item_idx=0,
                        start_idx=5,
                        tokens=[32000],
                        is_embed=None,
                    ),
                ],
                "pattern_3": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_3",
                        item_idx=0,
                        start_idx=6,
                        tokens=[1550, 918, 1550],
                        is_embed=None,
                    ),
                ],
            },
        ),
    ],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
def test_find_mm_placeholders(
    repl_by_key,
    prompt,
    expected,
    update_type,
):
    """Check that `find_mm_placeholders` locates update content in a prompt.

    The prompt is one that updates have already been applied to. Each key gets
    updates resolved for item indices 0, 1 and 2, with an empty target.
    The expected placeholders record the modality, item index, start offset
    and token content of each occurrence found. When two keys share the same
    tokens (pattern_1 / pattern_4), the expected output reflects
    pattern_1 taking priority.
    """
    mm_prompt_updates = {
        key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
        for key, repl in repl_by_key.items()
    }

    result = find_mm_placeholders(prompt, mm_prompt_updates, tokenizer=None)

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    assert result == expected


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
903
904
@pytest.mark.parametrize(
    ("limit", "num_supported", "is_valid"),
905
906
907
908
909
910
911
912
913
    [
        (0, 0, True),
        (0, 1, True),
        (1, 0, False),
        (1, 1, True),
        (1, 2, True),
        (2, 1, False),
        (2, 2, True),
    ],
914
915
916
917
918
919
920
921
922
)
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
    limit_mm_per_prompt = {"image": limit}

    model_config = ModelConfig(
        model=model_id,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )

923
    processor = MULTIMODAL_REGISTRY.create_processor(model_config)
924
    processor._supported_mm_limits = {"image": num_supported}
925

926
    exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")
927
928

    with exc_ctx:
929
930
        processor.dummy_inputs.get_decoder_dummy_data(
            processor,
931
932
933
            model_config.max_model_len,
            mm_counts=limit_mm_per_prompt,
        )
934
935


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("num_images", "limit", "is_valid"),
    [
        (0, 0, True),
        (0, 1, True),
        (1, 0, False),
        (1, 1, True),
        (1, 2, True),
        (2, 1, False),
        (2, 2, True),
    ],
)
def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
    """Check that ``processor.apply`` enforces ``limit_mm_per_prompt``.

    Per the cases above, a prompt with ``num_images`` images must be accepted
    when ``num_images <= limit``. Otherwise it must raise a ``ValueError``
    whose message matches "At most".
    """
    limit_mm_per_prompt = {"image": limit}

    model_config = ModelConfig(
        model=model_id,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )

    processor = MULTIMODAL_REGISTRY.create_processor(model_config)

    # Fixed seed so the generated image is reproducible across runs.
    rng = np.random.RandomState(0)
    image = random_image(rng, min_wh=128, max_wh=256)

    # Exercise both input forms: no data, a single bare item, and a list of items.
    if num_images == 0:
        mm_data = {}
    elif num_images == 1:
        mm_data = {"image": image}
    else:
        mm_data = {"image": [image] * num_images}

    exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")

    with exc_ctx:
        processor.apply(
            "<image>" * num_images,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )


978
979
class DummyProcessor:
    def __init__(self, a: int = 0, b: int = 0) -> None:
980
981
        super().__init__()

982
983
        self.a = a
        self.b = b
984
985
986

    def __call__(
        self,
987
988
        a: int = 0,
        c: int = 0,
989
        return_tensors: str | None = None,
990
991
    ) -> dict[str, int]:
        return dict(a=a, c=c)
992
993


@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])  # Dummy
@pytest.mark.parametrize(
    ("config_kwargs", "inference_kwargs", "expected_kwargs"),
    [
        ({"a": 1}, {}, {"a": 1, "b": 0}),
        ({}, {"a": 1}, {"a": 1, "b": 0}),
        # inference_kwargs should take precedence
        ({"a": 1}, {"a": 2}, {"a": 2, "b": 0}),
        # Should ignore extra kwargs
        ({"a": 1, "c": 1}, {}, {"a": 1, "b": 0}),
        ({"b": 1, "c": 1}, {}, {"a": 0, "b": 1}),
    ],
)
def test_hf_processor_init_kwargs(
    model_id,
    config_kwargs,
    inference_kwargs,
    expected_kwargs,
):
    """Check which kwargs ``get_hf_processor`` passes to the processor's ``__init__``.

    Both ``mm_processor_kwargs`` from the model config and the inference
    kwargs reach the constructor, and inference kwargs win on conflicts.
    Kwargs the constructor does not accept (``c`` for ``DummyProcessor``)
    must be ignored rather than cause an error.
    """
    ctx = InputProcessingContext(
        model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
        tokenizer=None,
    )

    processor = ctx.get_hf_processor(
        DummyProcessor,  # type: ignore[arg-type]
        **inference_kwargs,
    )

    assert processor.a == expected_kwargs["a"]
    assert processor.b == expected_kwargs["b"]


@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])  # Dummy
@pytest.mark.parametrize(
    ("config_kwargs", "inference_kwargs", "expected_kwargs"),
    [
        ({"a": 1}, {}, {"a": 1, "c": 0}),
        ({}, {"a": 1}, {"a": 1, "c": 0}),
        # inference_kwargs should take precedence
        ({"a": 1}, {"a": 2}, {"a": 2, "c": 0}),
        # Should ignore extra kwargs
        ({"a": 1, "c": 1}, {}, {"a": 1, "c": 1}),
        ({"b": 1, "c": 1}, {}, {"a": 0, "c": 1}),
    ],
)
def test_hf_processor_call_kwargs(
    model_id,
    config_kwargs,
    inference_kwargs,
    expected_kwargs,
):
    """Check which kwargs ``call_hf_processor`` forwards to ``__call__``.

    Config kwargs and inference kwargs are both forwarded, and inference
    kwargs win on conflicts. Kwargs the call does not accept (``b`` for
    ``DummyProcessor``) must not appear in the result.
    """
    ctx = InputProcessingContext(
        model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
        tokenizer=None,
    )

    processor = ctx.get_hf_processor(DummyProcessor)  # type: ignore[arg-type]

    # DummyProcessor.__call__ echoes back the kwargs it received.
    result = ctx.call_hf_processor(processor, {}, inference_kwargs)
    assert result == expected_kwargs


def test_apply_matches_no_match_exits_quickly():
    """
    Test that _apply_matches exits quickly when no matches are found.

    Previously, _apply_matches had O(n²) behavior when no match was found
    because it would increment start_idx by 1 each iteration while
    re-scanning the entire prompt from prev_end_idx=0.

    With the fix, it should exit immediately when no match is found.
    """
    # Create a long prompt with no placeholder
    long_prompt = "x" * 10000

    # Create update looking for a placeholder that doesn't exist
    mm_prompt_updates = {
        "image": [[PromptReplacement("image", "<image>", "REPLACED").resolve(0)]]
    }

    start = time.perf_counter()
    result, _ = _apply_matches(
        long_prompt,
        mm_prompt_updates,
        tokenizer=None,
    )
    elapsed = time.perf_counter() - start

    # Should complete in < 100ms (was taking seconds before the fix)
    assert elapsed < 0.1, f"_apply_matches took {elapsed:.2f}s, expected < 0.1s"
    # Nothing was replaced, so the returned pieces must rejoin into the input.
    assert "".join(result) == long_prompt