# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from contextlib import nullcontext

import numpy as np
import pytest

from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import (
    InputProcessingContext,
    PlaceholderFeaturesInfo,
    PromptIndexTargets,
    PromptInsertion,
    PromptReplacement,
    _apply_matches,
    apply_text_matches,
    apply_token_matches,
    find_mm_placeholders,
    iter_token_matches,
    replace_token_matches,
)
from vllm.multimodal.profiling import MultiModalProfiler

from .utils import random_image

# Module-level `pytestmark`: pytest applies the custom `cpu_test` marker to
# every test collected from this module.
pytestmark = pytest.mark.cpu_test


@pytest.mark.parametrize(
    ("token_ids", "match_ids", "expected"),
    [
        ([], [], []),
        ([], [32000], []),
        (
            [32000, 32000, 32000],
            [32000],
            [
                {"start_idx": 0, "end_idx": 1},
                {"start_idx": 1, "end_idx": 2},
                {"start_idx": 2, "end_idx": 3},
            ],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000],
            [{"start_idx": 0, "end_idx": 2}],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000, 32000],
            [{"start_idx": 0, "end_idx": 3}],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000],
            [
                {"start_idx": 1, "end_idx": 3},
                {"start_idx": 6, "end_idx": 8},
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000, 32000, 32000],
            [
                {"start_idx": 1, "end_idx": 5},
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 0, 32000],
            [],
        ),
    ],
)
@pytest.mark.parametrize("start_idx", [0, 4, 8])
def test_iter_token_matches(token_ids, match_ids, expected, start_idx):
    """Check `iter_token_matches` against hand-written match spans.

    `expected` lists every match over the whole of `token_ids`; only the
    entries whose `start_idx` is at or after the `start_idx` parameter are
    expected to be yielded.
    """
    result = list(iter_token_matches(token_ids, match_ids, start_idx=start_idx))

    # Manually constructed results
    assert [item._asdict() for item in result] == [
        item for item in expected if item["start_idx"] >= start_idx
    ]

    # Invariants: every match spans exactly `len(match_ids)` tokens
    match_lens = [end - start for start, end in result]
    print("match_lens:", match_lens)  # Only displayed on error
    assert all(match_len == len(match_ids) for match_len in match_lens)


@pytest.mark.parametrize(
    ("token_ids", "match_ids", "new_ids", "expected"),
    [
        ([], [], [-1], []),
        ([], [32000], [-1], []),
        ([32000, 32000, 32000], [32000], [-1], [-1, -1, -1]),
        ([32000, 32000, 32000], [32000, 32000], [-1], [-1, 32000]),
        ([32000, 32000, 32000], [32000, 32000, 32000], [-1], [-1]),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000],
            [-1],
            [9833, -1, 32000, 32000, 9833, -1, 32000, 918],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000, 32000, 32000],
            [-1],
            [9833, -1, 9833, 28747, 32000, 32000, 918],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 0, 32000],
            [-1],
            # The pattern never occurs, so the input comes back unchanged
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
        ),
    ],
)
def test_replace_token_matches(token_ids, match_ids, new_ids, expected):
    """Each non-overlapping occurrence of `match_ids` becomes `new_ids`."""
    replaced = replace_token_matches(token_ids, match_ids, new_ids)

    # Compare against the hand-written expectation
    assert replaced == expected


@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        (
            [],
            {
                "pattern_1": [],
                "pattern_2": [32000],
                "pattern_3": PromptIndexTargets.start(),
                "pattern_4": PromptIndexTargets.prefix([32000]),
                "pattern_5": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [],
                "pattern_2": [],
                "pattern_3": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_4": [],
                "pattern_5": [
                    {"start_idx": 0, "end_idx": 0},
                ],
            },
        ),
        (
            [32000, 32000, 32000, 32000],
            {
                "pattern_1": [32000],
                "pattern_2": [32000, 32000],
                "pattern_3": [32000, 32000, 32000],
                "pattern_4": PromptIndexTargets.start(),
                "pattern_5": PromptIndexTargets.prefix([32000]),
                "pattern_6": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [
                    {"start_idx": 0, "end_idx": 1},
                    {"start_idx": 1, "end_idx": 2},
                    {"start_idx": 2, "end_idx": 3},
                    {"start_idx": 3, "end_idx": 4},
                ],
                "pattern_2": [
                    {"start_idx": 0, "end_idx": 2},
                    {"start_idx": 2, "end_idx": 4},
                ],
                "pattern_3": [
                    {"start_idx": 0, "end_idx": 3},
                ],
                "pattern_4": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_5": [
                    {"start_idx": 1, "end_idx": 1},
                ],
                "pattern_6": [
                    {"start_idx": 4, "end_idx": 4},
                ],
            },
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            {
                "pattern_1": [28747, 32000],
                "pattern_2": [28747, 32000, 32000, 32000],
                "pattern_3": [28747, 0, 32000],
                "pattern_4": PromptIndexTargets.start(),
                "pattern_5": PromptIndexTargets.prefix([28747, 32000]),
                "pattern_6": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [
                    {"start_idx": 1, "end_idx": 3},
                    {"start_idx": 6, "end_idx": 8},
                ],
                "pattern_2": [
                    {"start_idx": 1, "end_idx": 5},
                ],
                "pattern_3": [],
                "pattern_4": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_5": [],
                "pattern_6": [
                    {"start_idx": 10, "end_idx": 10},
                ],
            },
        ),
    ],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
def test_find_token_matches(
    prompt,
    target_by_key,
    expected_by_key,
    update_type,
):
    """Check the token spans each update target locates in `prompt`.

    Every update is built with empty content (`[]`) and resolved for item 0,
    so only target matching is exercised; the expected spans are the same for
    `PromptInsertion` and `PromptReplacement`.
    """
    prompt_updates = {
        key: update_type(key, target, []).resolve(0)
        for key, target in target_by_key.items()
    }
    result = {
        key: list(update.iter_token_matches(prompt, tokenizer=None))
        for key, update in prompt_updates.items()
    }

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        # Detokenized test cases of `test_find_token_matches`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            "",
            {
                "pattern_1": "",
                "pattern_2": "<image>",
                "pattern_3": PromptIndexTargets.start(),
                "pattern_4": PromptIndexTargets.prefix("<image>"),
                "pattern_5": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [{"start_idx": 0, "end_idx": 0}],
                "pattern_2": [],
                "pattern_3": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_4": [],
                "pattern_5": [
                    {"start_idx": 0, "end_idx": 0},
                ],
            },
        ),
        (
            "<image><image><image><image>",
            {
                "pattern_1": "<image>",
                "pattern_2": "<image><image>",
                "pattern_3": "<image><image><image>",
                "pattern_4": PromptIndexTargets.start(),
                "pattern_5": PromptIndexTargets.prefix("<image>"),
                "pattern_6": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [
                    {"start_idx": 0, "end_idx": 7},
                    {"start_idx": 7, "end_idx": 14},
                    {"start_idx": 14, "end_idx": 21},
                    {"start_idx": 21, "end_idx": 28},
                ],
                "pattern_2": [
                    {"start_idx": 0, "end_idx": 14},
                    {"start_idx": 14, "end_idx": 28},
                ],
                "pattern_3": [
                    {"start_idx": 0, "end_idx": 21},
                ],
                "pattern_4": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_5": [
                    {"start_idx": 7, "end_idx": 7},
                ],
                "pattern_6": [
                    {"start_idx": 28, "end_idx": 28},
                ],
            },
        ),
        (
            "Image:<image><image><image>Image:<image><image>!",
            {
                "pattern_1": "Image:<image>",
                "pattern_2": "Image:<image><image><image>",
                "pattern_3": "Image:<unk><image>",
                "pattern_4": PromptIndexTargets.start(),
                "pattern_5": PromptIndexTargets.prefix("Image:<image>"),
                "pattern_6": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [
                    {"start_idx": 0, "end_idx": 13},
                    {"start_idx": 27, "end_idx": 40},
                ],
                "pattern_2": [
                    {"start_idx": 0, "end_idx": 27},
                ],
                "pattern_3": [],
                "pattern_4": [
                    {"start_idx": 0, "end_idx": 0},
                ],
                "pattern_5": [
                    {"start_idx": 13, "end_idx": 13},
                ],
                "pattern_6": [
                    {"start_idx": 48, "end_idx": 48},
                ],
            },
        ),
        # Test regex escape
        (
            "<|image|><image><|image|><image>",
            {
                "pattern_1": "<|image|>",
                "pattern_2": "<|image|><image>",
                "pattern_3": "<|image|><image><|image|>",
            },
            {
                "pattern_1": [
                    {"start_idx": 0, "end_idx": 9},
                    {"start_idx": 16, "end_idx": 25},
                ],
                "pattern_2": [
                    {"start_idx": 0, "end_idx": 16},
                    {"start_idx": 16, "end_idx": 32},
                ],
                "pattern_3": [
                    {"start_idx": 0, "end_idx": 25},
                ],
            },
        ),
    ],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
def test_find_text_matches(
    prompt,
    target_by_key,
    expected_by_key,
    update_type,
):
    """Check the character spans each update target locates in `prompt`.

    Text counterpart of `test_find_token_matches`: updates carry empty content
    and are resolved for item 0, so only target matching is exercised. Targets
    containing `|` check that literal text is not interpreted as a regex.
    """
    prompt_updates = {
        key: update_type(key, target, []).resolve(0)
        for key, target in target_by_key.items()
    }
    result = {
        key: list(update.iter_text_matches(prompt, tokenizer=None))
        for key, update in prompt_updates.items()
    }

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"),  # noqa: E501
    [
        (
            "Image:<image>Image:<image><image>!",
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": "<image>",
                "pattern_2": "Image:",
                "pattern_3": "!",
            },
            {
                # Test whether target is confused with replacement
                "pattern_1": "<image><image>",
                # Test empty replacement
                "pattern_2": "",
                # Test dynamic replacement (beyond the form of `unit * count`)
                "pattern_3": "?!?",
            },
            {
                PromptInsertion: {
                    0: "Image:<image>Image:<image><image>!",
                    1: "Image:<image><image><image>Image:<image><image>!?!?",
                    2: "Image:<image><image><image><image><image>Image:<image><image>!?!??!?",  # noqa: E501
                },
                PromptReplacement: {
                    0: "Image:<image>Image:<image><image>!",
                    1: "<image><image>Image:<image><image>?!?",
                    2: "<image><image><image><image><image>?!?",
                },
            },
        ),
        # Test index targets
        (
            "",
            {
                "pattern_1": PromptIndexTargets.start(),
                "pattern_2": PromptIndexTargets.prefix("<image>"),
                "pattern_3": PromptIndexTargets.end(),
            },
            {
                "pattern_1": "1",
                "pattern_2": "2",
                "pattern_3": "3",
            },
            {
                PromptInsertion: {
                    0: "",
                    1: "13",
                    2: "1133",
                },
                PromptReplacement: {
                    0: "",
                    1: "13",
                    2: "1133",
                },
            },
        ),
        (
            "<image>",
            {
                "pattern_1": PromptIndexTargets.start(),
                "pattern_2": PromptIndexTargets.prefix("<image>"),
                "pattern_3": PromptIndexTargets.end(),
            },
            {
                "pattern_1": "1",
                "pattern_2": "2",
                "pattern_3": "3",
            },
            {
                PromptInsertion: {
                    0: "<image>",
                    1: "1<image>23",
                    2: "11<image>2233",
                },
                PromptReplacement: {
                    0: "<image>",
                    1: "1<image>23",
                    2: "11<image>2233",
                },
            },
        ),
        # Test different replacement per item
        (
            "<image><image><image>",
            {
                "pattern_1": "<image>",
            },
            {
                "pattern_1": lambda idx: str(idx + 1),
            },
            {
                PromptInsertion: {
                    0: "<image><image><image>",
                    1: "<image>1<image><image>",
                    2: "<image>12<image><image>",
                },
                PromptReplacement: {
                    0: "<image><image><image>",
                    1: "1<image><image>",
                    2: "12<image>",
                },
            },
        ),
        (
            "<image><image><image>",
            {
                "pattern_1": PromptIndexTargets.prefix("<image>"),
            },
            {
                "pattern_1": lambda idx: str(idx + 1),
            },
            {
                PromptInsertion: {
                    0: "<image><image><image>",
                    1: "<image>1<image><image>",
                    2: "<image>12<image><image>",
                },
                PromptReplacement: {
                    0: "<image><image><image>",
                    1: "<image>1<image><image>",
                    2: "<image>12<image><image>",
                },
            },
        ),
    ],
)
def test_find_update_text(
    prompt,
    target_by_key,
    repl_by_key,
    expected_by_update_type_mm_count,
):
    """Apply text prompt updates and compare the rewritten prompt.

    For every update type and every item count in
    `expected_by_update_type_mm_count`, one update per item is built for each
    target (item ``i`` resolved via ``resolve(i)``). The prompt returned by
    `apply_text_matches` must equal the hand-written expectation.
    """
    for (
        update_type,
        expected_by_mm_count,
    ) in expected_by_update_type_mm_count.items():
        for mm_count, expected in expected_by_mm_count.items():
            # One single-update list per item, for each modality key
            mm_prompt_updates = {
                key: [
                    [update_type(key, target, repl_by_key[key]).resolve(i)]
                    for i in range(mm_count)
                ]
                for key, target in target_by_key.items()
            }

            new_prompt, result = apply_text_matches(
                prompt,
                mm_prompt_updates,
                tokenizer=None,
            )

            # Only displayed on error
            print("update_type:", update_type)
            print("mm_count:", mm_count)
            print("mm_prompt_updates:", mm_prompt_updates)
            print("new_prompt:", new_prompt)
            print("result:", result)

            # Manually constructed results
            assert new_prompt == expected


@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key", "expected_by_update_type_mm_count"),  # noqa: E501
    [
        # Tokenized test cases of `test_find_update_text`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": [32000],
                "pattern_2": [9833, 28747],
                "pattern_3": [918],
            },
            {
                # Test whether target is confused with replacement
                "pattern_1": [32000, 32000],
                # Test empty replacement
                "pattern_2": [],
                # Test dynamic replacement (beyond the form of `unit * count`)
                "pattern_3": [1550, 918, 1550],
            },
            {
                PromptInsertion: {
                    0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
                    1: [
                        1,
                        9833,
                        28747,
                        32000,
                        32000,
                        32000,
                        9833,
                        28747,
                        32000,
                        32000,
                        918,
                        1550,
                        918,
                        1550,
                    ],
                    2: [
                        1,
                        9833,
                        28747,
                        32000,
                        32000,
                        32000,
                        32000,
                        32000,
                        9833,
                        28747,
                        32000,
                        32000,
                        918,
                        1550,
                        918,
                        1550,
                        1550,
                        918,
                        1550,
                    ],
                },
                PromptReplacement: {
                    0: [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
                    1: [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],  # noqa: E501
                    2: [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
                },
            },
        ),
        # Test index targets
        (
            [],
            {
                "pattern_1": PromptIndexTargets.start(),
                "pattern_2": PromptIndexTargets.prefix([32000]),
                "pattern_3": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [-1],
                "pattern_2": [-2],
                "pattern_3": [-3],
            },
            {
                PromptInsertion: {
                    0: [],
                    1: [-1, -3],
                    2: [-1, -1, -3, -3],
                },
                PromptReplacement: {
                    0: [],
                    1: [-1, -3],
                    2: [-1, -1, -3, -3],
                },
            },
        ),
        (
            [32000],
            {
                "pattern_1": PromptIndexTargets.start(),
                "pattern_2": PromptIndexTargets.prefix([32000]),
                "pattern_3": PromptIndexTargets.end(),
            },
            {
                "pattern_1": [-1],
                "pattern_2": [-2],
                "pattern_3": [-3],
            },
            {
                PromptInsertion: {
                    0: [32000],
                    1: [-1, 32000, -2, -3],
                    2: [-1, -1, 32000, -2, -2, -3, -3],
                },
                PromptReplacement: {
                    0: [32000],
                    1: [-1, 32000, -2, -3],
                    2: [-1, -1, 32000, -2, -2, -3, -3],
                },
            },
        ),
        # Test different replacement per item
        (
            [32000, 32000, 32000],
            {
                "pattern_1": [32000],
            },
            {
                "pattern_1": lambda idx: [-(idx + 1)],
            },
            {
                PromptInsertion: {
                    0: [32000, 32000, 32000],
                    1: [32000, -1, 32000, 32000],
                    2: [32000, -1, -2, 32000, 32000],
                },
                PromptReplacement: {
                    0: [32000, 32000, 32000],
                    1: [-1, 32000, 32000],
                    2: [-1, -2, 32000],
                },
            },
        ),
        (
            [32000, 32000, 32000],
            {
                "pattern_1": PromptIndexTargets.prefix([32000]),
            },
            {
                "pattern_1": lambda idx: [-(idx + 1)],
            },
            {
                PromptInsertion: {
                    0: [32000, 32000, 32000],
                    1: [32000, -1, 32000, 32000],
                    2: [32000, -1, -2, 32000, 32000],
                },
                PromptReplacement: {
                    0: [32000, 32000, 32000],
                    1: [32000, -1, 32000, 32000],
                    2: [32000, -1, -2, 32000, 32000],
                },
            },
        ),
    ],
)
def test_find_update_tokens(
    prompt,
    target_by_key,
    repl_by_key,
    expected_by_update_type_mm_count,
):
    """Apply token prompt updates and compare the rewritten token list.

    Token-level counterpart of `test_find_update_text`. For every update type
    and item count, one update per item is built for each target (item ``i``
    resolved via ``resolve(i)``). The tokens returned by
    `apply_token_matches` must equal the hand-written expectation.
    """
    for (
        update_type,
        expected_by_mm_count,
    ) in expected_by_update_type_mm_count.items():
        for mm_count, expected in expected_by_mm_count.items():
            # One single-update list per item, for each modality key
            mm_prompt_updates = {
                key: [
                    [update_type(key, target, repl_by_key[key]).resolve(i)]
                    for i in range(mm_count)
                ]
                for key, target in target_by_key.items()
            }

            new_prompt, result = apply_token_matches(
                prompt,
                mm_prompt_updates,
                tokenizer=None,
            )

            # Only displayed on error
            print("update_type:", update_type)
            print("mm_count:", mm_count)
            print("mm_prompt_updates:", mm_prompt_updates)
            print("new_prompt:", new_prompt)
            print("result:", result)

            # Manually constructed results
            assert new_prompt == expected


@pytest.mark.parametrize(
    "repl_by_key",
    [
        {
            "pattern_1": [32000, 32000],
            "pattern_2": [],
            "pattern_3": [1550, 918, 1550],
            # Test different modalities having the same tokens (32000)
            "pattern_4": [32000],
        },
    ],
)
@pytest.mark.parametrize(
    ("prompt", "expected"),
    [
        (
            [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
            {
                "pattern_1": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=6,
                        tokens=[32000, 32000],
                        is_embed=None,
                    ),
                ],
                "pattern_4": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_4",
                        item_idx=0,
                        start_idx=3,
                        tokens=[32000],
                        is_embed=None,
                    ),
                ],
            },
        ),
        (
            [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
            {
                "pattern_1": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=1,
                        tokens=[32000, 32000],
                        is_embed=None,
                    ),
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=1,
                        start_idx=5,
                        tokens=[32000, 32000],
                        is_embed=None,
                    ),
                ],
                "pattern_3": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_3",
                        item_idx=0,
                        start_idx=7,
                        tokens=[1550, 918, 1550],
                        is_embed=None,
                    ),
                ],
                # No match for pattern_4 as it has lower priority than pattern_1
            },
        ),
        (
            [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
            {
                "pattern_1": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=1,
                        tokens=[32000, 32000],
                        is_embed=None,
                    ),
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=1,
                        start_idx=3,
                        tokens=[32000, 32000],
                        is_embed=None,
                    ),
                ],
                "pattern_4": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_4",
                        item_idx=0,
                        start_idx=5,
                        tokens=[32000],
                        is_embed=None,
                    ),
                ],
                "pattern_3": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_3",
                        item_idx=0,
                        start_idx=6,
                        tokens=[1550, 918, 1550],
                        is_embed=None,
                    ),
                ],
            },
        ),
    ],
)
@pytest.mark.parametrize("update_type", [PromptInsertion, PromptReplacement])
def test_find_mm_placeholders(
    repl_by_key,
    prompt,
    expected,
    update_type,
):
    """Locate already-applied update content in `prompt` as placeholders.

    Each key gets three items whose update has an empty target (`[]`) and
    content `repl`. `find_mm_placeholders` must report where each item's
    content occurs in `prompt`, as `PlaceholderFeaturesInfo` grouped by key.
    """
    mm_prompt_updates = {
        key: [[update_type(key, [], repl).resolve(i)] for i in range(3)]
        for key, repl in repl_by_key.items()
    }

    result = find_mm_placeholders(prompt, mm_prompt_updates, tokenizer=None)

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    assert result == expected


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("limit", "num_supported", "is_valid"),
    [
        (0, 0, True),
        (0, 1, True),
        (1, 0, False),
        (1, 1, True),
        (1, 2, True),
        (2, 1, False),
        (2, 2, True),
    ],
)
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
    """Dummy-data profiling must respect the processor's supported item limit.

    The processor's supported image count is overridden to `num_supported`.
    Requesting dummy data for `limit` images is expected to raise a
    `ValueError` matching "At most" in exactly the cases where
    `limit > num_supported`.
    """
    limit_mm_per_prompt = {"image": limit}

    model_config = ModelConfig(
        model=model_id,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )

    processor = MULTIMODAL_REGISTRY.create_processor(model_config)
    # Override the private limit so the test controls what is "supported"
    processor._supported_mm_limits = {"image": num_supported}

    profiler = MultiModalProfiler(processor)

    exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")

    with exc_ctx:
        profiler.get_decoder_dummy_data(
            model_config.max_model_len,
            mm_counts=limit_mm_per_prompt,
        )


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("num_images", "limit", "is_valid"),
    [
        (0, 0, True),
        (0, 1, True),
        (1, 0, False),
        (1, 1, True),
        (1, 2, True),
        (2, 1, False),
        (2, 2, True),
    ],
)
def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
    """``processor.apply`` must enforce ``limit_mm_per_prompt``.

    A prompt with ``num_images`` ``<image>`` tokens (and the same number of
    images) is processed under a per-prompt image limit of ``limit``. Per the
    parametrization, a ``ValueError`` matching "At most" is expected exactly
    when ``num_images > limit``.
    """
    limit_mm_per_prompt = {"image": limit}

    model_config = ModelConfig(
        model=model_id,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )

    processor = MULTIMODAL_REGISTRY.create_processor(model_config)

    # Fixed seed so the generated image is deterministic across runs.
    rng = np.random.RandomState(0)
    image = random_image(rng, min_wh=128, max_wh=256)

    # Exercise all three input shapes: no image key, a single item, and a list.
    if num_images == 0:
        mm_data = {}
    elif num_images == 1:
        mm_data = {"image": image}
    else:
        mm_data = {"image": [image] * num_images}

    exc_ctx = nullcontext() if is_valid else pytest.raises(ValueError, match="At most")

    with exc_ctx:
        processor.apply(
            "<image>" * num_images,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )


980
981
class DummyProcessor:
    def __init__(self, a: int = 0, b: int = 0) -> None:
982
983
        super().__init__()

984
985
        self.a = a
        self.b = b
986
987
988

    def __call__(
        self,
989
990
        a: int = 0,
        c: int = 0,
991
        return_tensors: str | None = None,
992
993
    ) -> dict[str, int]:
        return dict(a=a, c=c)
994
995


@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])  # Dummy
@pytest.mark.parametrize(
    ("config_kwargs", "inference_kwargs", "expected_kwargs"),
    [
        ({"a": 1}, {}, {"a": 1, "b": 0}),
        ({}, {"a": 1}, {"a": 1, "b": 0}),
        # inference_kwargs should take precedence
        ({"a": 1}, {"a": 2}, {"a": 2, "b": 0}),
        # Should ignore extra kwargs
        ({"a": 1, "c": 1}, {}, {"a": 1, "b": 0}),
        ({"b": 1, "c": 1}, {}, {"a": 0, "b": 1}),
    ],
)
def test_hf_processor_init_kwargs(
    model_id,
    config_kwargs,
    inference_kwargs,
    expected_kwargs,
):
    """Check how ``get_hf_processor`` builds the processor from kwargs.

    ``config_kwargs`` come from ``mm_processor_kwargs`` in the model config,
    and ``inference_kwargs`` are passed directly. The resulting
    ``DummyProcessor`` attributes must match ``expected_kwargs``.
    """
    ctx = InputProcessingContext(
        model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
        tokenizer=None,
    )

    processor = ctx.get_hf_processor(
        DummyProcessor,  # type: ignore[arg-type]
        **inference_kwargs,
    )

    for k, v in expected_kwargs.items():
        assert getattr(processor, k) == v


@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])  # Dummy
@pytest.mark.parametrize(
    ("config_kwargs", "inference_kwargs", "expected_kwargs"),
    [
        ({"a": 1}, {}, {"a": 1, "c": 0}),
        ({}, {"a": 1}, {"a": 1, "c": 0}),
        # inference_kwargs should take precedence
        ({"a": 1}, {"a": 2}, {"a": 2, "c": 0}),
        # Kwargs accepted by DummyProcessor.__call__ (here ``c``) are forwarded
        ({"a": 1, "c": 1}, {}, {"a": 1, "c": 1}),
        # Kwargs it does not accept (here ``b``) should be ignored
        ({"b": 1, "c": 1}, {}, {"a": 0, "c": 1}),
    ],
)
def test_hf_processor_call_kwargs(
    model_id,
    config_kwargs,
    inference_kwargs,
    expected_kwargs,
):
    """Check which kwargs ``call_hf_processor`` passes to the processor call.

    ``config_kwargs`` come from ``mm_processor_kwargs`` in the model config,
    and ``inference_kwargs`` are passed at call time. ``DummyProcessor``
    echoes back ``a`` and ``c``, so the returned dict shows what was
    forwarded.
    """
    ctx = InputProcessingContext(
        model_config=ModelConfig(model_id, mm_processor_kwargs=config_kwargs),
        tokenizer=None,
    )

    processor = ctx.get_hf_processor(DummyProcessor)  # type: ignore[arg-type]

    # The empty dict is presumably the (absent) data payload; only the kwargs
    # matter for this test (TODO confirm against call_hf_processor's signature).
    result = ctx.call_hf_processor(processor, {}, inference_kwargs)
    assert result == expected_kwargs


def test_apply_matches_no_match_exits_quickly():
    """
    Test that _apply_matches exits quickly when no matches are found.

    Previously, _apply_matches had O(n²) behavior when no match was found
    because it would increment start_idx by 1 each iteration while
    re-scanning the entire prompt from prev_end_idx=0.

    With the fix, it should exit immediately when no match is found.
    """
    # Create a long prompt with no placeholder
    long_prompt = "x" * 10000

    # Create update looking for a placeholder that doesn't exist
    mm_prompt_updates = {
        "image": [[PromptReplacement("image", "<image>", "REPLACED").resolve(0)]]
    }

    start = time.perf_counter()
    result, _ = _apply_matches(
        long_prompt,
        mm_prompt_updates,
        tokenizer=None,
    )
    elapsed = time.perf_counter() - start

    # Should complete in < 100ms (was taking seconds before the fix)
    assert elapsed < 0.1, f"_apply_matches took {elapsed:.2f}s, expected < 0.1s"
    # With nothing to replace, the returned pieces must rebuild the prompt
    # unchanged.
    assert "".join(result) == long_prompt