test_processing.py 20.6 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
from contextlib import nullcontext
4
from types import MethodType
5
from typing import cast
6
from unittest.mock import MagicMock
7

8
import numpy as np
9
import pytest
10
from transformers import ProcessorMixin
11
12
13

from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
14
15
16
17
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (PlaceholderFeaturesInfo,
                                        PromptReplacement,
18
                                        find_mm_placeholders,
19
                                        find_text_matches, find_token_matches,
20
                                        iter_token_matches,
21
22
                                        replace_text_matches,
                                        replace_token_matches)
23
# yapf: enable
24
from vllm.multimodal.profiling import MultiModalProfiler
25
26
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
                                               cached_tokenizer_from_config)
27
28
from vllm.utils import full_groupby

29
30
from .utils import random_image


# yapf: disable
@pytest.mark.parametrize(
    ("token_ids", "match_ids", "expected"),
    [
        ([], [], []),
        ([], [32000], []),
        (
            [32000, 32000, 32000],
            [32000],
            [
                { "start_idx": 0, "end_idx": 1 },
                { "start_idx": 1, "end_idx": 2 },
                { "start_idx": 2, "end_idx": 3 },
            ],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000],
            [{ "start_idx": 0, "end_idx": 2 }],
        ),
        (
            [32000, 32000, 32000],
            [32000, 32000, 32000],
            [{ "start_idx": 0, "end_idx": 3 }],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000],
            [
                { "start_idx": 1, "end_idx": 3 },
                { "start_idx": 6, "end_idx": 8 },
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 32000, 32000, 32000],
            [
                { "start_idx": 1, "end_idx": 5 },
            ],
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            [28747, 0, 32000],
            [],
        ),
    ],
)
# yapf: enable
def test_iter_token_matches(token_ids, match_ids, expected):
    """Check `iter_token_matches` against hand-computed match spans.

    As the cases above encode, matches do not overlap: `[32000, 32000]`
    is found only once in three consecutive `32000` tokens.
    """
    result = list(iter_token_matches(token_ids, match_ids))

    # Manually constructed results
    assert [item._asdict() for item in result] == expected

    # Invariants: every match spans exactly `len(match_ids)` tokens
    match_lens = [end - start for start, end in result]
    print("match_lens:", match_lens)  # Only displayed on error
    assert all(match_len == len(match_ids) for match_len in match_lens)


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        (
            [],
            {
                "pattern_1": [],
                "pattern_2": [32000],
            },
            {
                "pattern_1": [],
                "pattern_2": [],
            }
        ),
        (
            [32000, 32000, 32000, 32000],
            {
                "pattern_1": [32000],
                "pattern_2": [32000, 32000],
                "pattern_3": [32000, 32000, 32000],
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 1 },
                    { "start_idx": 1, "end_idx": 2 },
                    { "start_idx": 2, "end_idx": 3 },
                    { "start_idx": 3, "end_idx": 4 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 2 },
                    { "start_idx": 2, "end_idx": 4 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 3 },
                ],
            },
        ),
        (
            [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918],
            {
                "pattern_1": [28747, 32000],
                "pattern_2": [28747, 32000, 32000, 32000],
                "pattern_3": [28747, 0, 32000],
            },
            {
                "pattern_1": [
                    { "start_idx": 1, "end_idx": 3 },
                    { "start_idx": 6, "end_idx": 8 },
                ],
                "pattern_2": [
                    { "start_idx": 1, "end_idx": 5 },
                ],
                "pattern_3": [],
            },
        ),
    ],
)
# yapf: enable
def test_find_token_matches(prompt, target_by_key, expected_by_key):
    """Check `find_token_matches` with several token-ID targets at once.

    Results are grouped by their `modality` attribute. Each group is compared,
    as a list of start/end spans, against the `pattern_N` key that was passed
    to `PromptReplacement`.
    """
    # Should not be used since there is nothing to convert to token IDs
    mock_tokenizer = cast(AnyTokenizer, object())

    prompt_repls = [
        PromptReplacement(key, target, []).bind(mock_tokenizer)
        for key, target in target_by_key.items()
    ]
    result = find_token_matches(prompt, prompt_repls)

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    result_groups = dict(full_groupby(result, key=lambda x: x.modality))
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result_groups.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "expected_by_key"),
    [
        # Detokenized test cases of `test_find_token_matches`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            "",
            {
                "pattern_1": "",
                "pattern_2": "<image>",
            },
            {
                "pattern_1": [{ "start_idx": 0, "end_idx": 0 }],
                "pattern_2": [],
            }
        ),
        (
            "<image><image><image><image>",
            {
                "pattern_1": "<image>",
                "pattern_2": "<image><image>",
                "pattern_3": "<image><image><image>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 7 },
                    { "start_idx": 7, "end_idx": 14 },
                    { "start_idx": 14, "end_idx": 21 },
                    { "start_idx": 21, "end_idx": 28 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 14 },
                    { "start_idx": 14, "end_idx": 28 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 21 },
                ],
            },
        ),
        (
            "Image:<image><image><image>Image:<image><image>!",
            {
                "pattern_1": "Image:<image>",
                "pattern_2": "Image:<image><image><image>",
                "pattern_3": "Image:<unk><image>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 13 },
                    { "start_idx": 27, "end_idx": 40 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 27 },
                ],
                "pattern_3": [],
            },
        ),
        # Test regex escape
        (
            "<|image|><image><|image|><image>",
            {
                "pattern_1": "<|image|>",
                "pattern_2": "<|image|><image>",
                "pattern_3": "<|image|><image><|image|>",
            },
            {
                "pattern_1": [
                    { "start_idx": 0, "end_idx": 9 },
                    { "start_idx": 16, "end_idx": 25 },
                ],
                "pattern_2": [
                    { "start_idx": 0, "end_idx": 16 },
                    { "start_idx": 16, "end_idx": 32 },
                ],
                "pattern_3": [
                    { "start_idx": 0, "end_idx": 25 },
                ],
            },
        ),
    ],
)
# yapf: enable
def test_find_text_matches(prompt, target_by_key, expected_by_key):
    """Check `find_text_matches` with several string targets at once.

    The text counterpart of `test_find_token_matches`: spans are character
    offsets into `prompt`. Targets containing regex metacharacters (`|`) must
    match literally.
    """
    # Should not be used since there is nothing to convert to text
    mock_tokenizer = cast(AnyTokenizer, object())

    prompt_repls = [
        PromptReplacement(key, target, []).bind(mock_tokenizer)
        for key, target in target_by_key.items()
    ]
    result = find_text_matches(prompt, prompt_repls)

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    result_groups = dict(full_groupby(result, key=lambda x: x.modality))
    assert {
        key: [
            dict(start_idx=item.start_idx, end_idx=item.end_idx)
            for item in result_groups.get(key, [])
        ]
        for key in expected_by_key
    } == expected_by_key


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key"),
    [
        (
            "Image:<image>Image:<image><image>!",
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": "<image>",
                "pattern_2": "Image:",
                "pattern_3": "!",
            },
            {
                # Test whether target is confused with replacement
                "pattern_1": "<image><image>",
                # Test empty replacement
                "pattern_2": "",
                # Test dynamic replacement (beyond the form of `unit * count`)
                "pattern_3": "?!?",
            },
        ),
    ]
)
@pytest.mark.parametrize(
    ("mm_count", "expected"),
    [
        (0, "Image:<image>Image:<image><image>!"),
        (1, "<image><image>Image:<image><image>?!?"),
        (2, "<image><image><image><image><image>?!?"),
    ]
)
# yapf: enable
def test_find_replace_text(
    prompt,
    target_by_key,
    repl_by_key,
    mm_count,
    expected,
):
    """Find string targets in `prompt` and replace them.

    Every key is given the same item count `mm_count`. As the expected
    values show, at most that many matches per key are replaced, in prompt
    order. With `mm_count=0` the prompt is returned unchanged.
    """
    # Should not be used since there is nothing to convert to text
    mock_tokenizer = cast(AnyTokenizer, object())

    mm_prompt_repls = {
        key: [
            PromptReplacement(key, target,
                              repl_by_key[key]).bind(mock_tokenizer)
        ]
        for key, target in target_by_key.items()
    }
    mm_matches = {
        key: find_text_matches(prompt, prompt_repls)
        for key, prompt_repls in mm_prompt_repls.items()
    }

    result = replace_text_matches(
        prompt,
        mm_matches,
        {key: mm_count
         for key in repl_by_key},
    )

    # Only displayed on error
    print("mm_matches:", mm_matches)
    print("result:", result)

    # Manually constructed results
    assert result == expected


# yapf: disable
@pytest.mark.parametrize(
    ("prompt", "target_by_key", "repl_by_key"),
    [
        # Tokenized test cases of `test_find_replace_text`
        # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf
        (
            [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
            {
                # We use `<image>` before `Image:` to test matches that
                # occur out of order
                "pattern_1": [32000],
                "pattern_2": [9833, 28747],
                "pattern_3": [918],
            },
            {
                # Test whether target is confused with replacement
                "pattern_1": [32000, 32000],
                # Test empty replacement
                "pattern_2": [],
                # Test dynamic replacement (beyond the form of `unit * count`)
                "pattern_3": [1550, 918, 1550],
            },
        ),
    ]
)
@pytest.mark.parametrize(
    ("mm_count", "expected"),
    [
        (0, [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918]),
        (1, [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550]),
        (2, [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550]),
    ]
)
# yapf: enable
def test_find_replace_tokens(
    prompt,
    target_by_key,
    repl_by_key,
    mm_count,
    expected,
):
    """Token-ID counterpart of `test_find_replace_text`.

    Finds token targets in `prompt` and replaces at most `mm_count` matches
    per key, as encoded in the expected token lists.
    """
    # Should not be used since there is nothing to convert to tokens
    mock_tokenizer = cast(AnyTokenizer, object())

    mm_prompt_repls = {
        key: [
            PromptReplacement(key, target,
                              repl_by_key[key]).bind(mock_tokenizer)
        ]
        for key, target in target_by_key.items()
    }
    mm_matches = {
        key: find_token_matches(prompt, prompt_repls)
        for key, prompt_repls in mm_prompt_repls.items()
    }

    result = replace_token_matches(
        prompt,
        mm_matches,
        {key: mm_count
         for key in repl_by_key},
    )

    # Only displayed on error
    print("mm_matches:", mm_matches)
    print("result:", result)

    # Manually constructed results
    assert result == expected


# yapf: disable
@pytest.mark.parametrize(
    "repl_by_key",
    [
        {
            "pattern_1": [32000, 32000],
            "pattern_2": [],
            "pattern_3": [1550, 918, 1550],
            # Test different modalities having the same tokens (32000)
            "pattern_4": [32000],
        },
    ],
)
@pytest.mark.parametrize(
    ("prompt", "expected"),
    [
        (
            [1, 9833, 28747, 32000, 9833, 28747, 32000, 32000, 918],
            {
                "pattern_1": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=6,
                        tokens=[32000, 32000],
                    ),
                ],
                "pattern_4": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_4",
                        item_idx=0,
                        start_idx=3,
                        tokens=[32000],
                    ),
                ],
            }
        ),
        (
            [1, 32000, 32000, 9833, 28747, 32000, 32000, 1550, 918, 1550],
            {
                "pattern_1": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=1,
                        tokens=[32000, 32000],
                    ),
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=1,
                        start_idx=5,
                        tokens=[32000, 32000],
                    ),
                ],
                "pattern_3": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_3",
                        item_idx=0,
                        start_idx=7,
                        tokens=[1550, 918, 1550],
                    ),
                ],
                # No match for pattern_4 as it has lower priority than pattern_1
            }
        ),
        (
            [1, 32000, 32000, 32000, 32000, 32000, 1550, 918, 1550],
            {
                "pattern_1": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=0,
                        start_idx=1,
                        tokens=[32000, 32000],
                    ),
                    PlaceholderFeaturesInfo(
                        modality="pattern_1",
                        item_idx=1,
                        start_idx=3,
                        tokens=[32000, 32000],
                    ),
                ],
                "pattern_4": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_4",
                        item_idx=0,
                        start_idx=5,
                        tokens=[32000],
                    ),
                ],
                "pattern_3": [
                    PlaceholderFeaturesInfo(
                        modality="pattern_3",
                        item_idx=0,
                        start_idx=6,
                        tokens=[1550, 918, 1550],
                    ),
                ],
            }
        ),
    ]
)
# yapf: enable
def test_find_mm_placeholders(
    repl_by_key,
    prompt,
    expected,
):
    """Locate already-inserted replacement tokens in `prompt`.

    Each `PromptReplacement` is built with an empty target, so only the
    replacement tokens in `repl_by_key` are relevant here. The result is
    compared with `PlaceholderFeaturesInfo` entries keyed by modality.
    """
    # Should not be used since there is nothing to convert to tokens
    mock_tokenizer = cast(AnyTokenizer, object())

    mm_prompt_repls = {
        key: [PromptReplacement(key, [], repl).bind(mock_tokenizer)]
        for key, repl in repl_by_key.items()
    }

    result = find_mm_placeholders(
        mm_prompt_repls,
        prompt,
        # Effectively match all occurrences in the prompt
        {key: 3
         for key in repl_by_key},
    )

    # Only displayed on error
    print("result:", result)

    # Manually constructed results
    assert result == expected


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("limit", "num_supported", "is_valid"),
    [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
     (2, 1, False), (2, 2, True)],
)
def test_limit_mm_per_prompt_dummy(model_id, limit, num_supported, is_valid):
    """Profiling dummy data must reject an image limit the model can't meet.

    `get_supported_mm_limits` is mocked to report `num_supported` images.
    The invalid cases are those where `limit` exceeds that number, and they
    must raise `ValueError`.
    """
    limit_mm_per_prompt = {"image": limit}

    model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )

    processor = MULTIMODAL_REGISTRY.create_processor(
        model_config,
        tokenizer=cached_tokenizer_from_config(model_config),
    )
    profiler = MultiModalProfiler(processor)

    mock_supported_mm_limits = MagicMock(return_value={"image": num_supported})
    processor.info.get_supported_mm_limits = mock_supported_mm_limits

    if is_valid:
        exc_ctx = nullcontext()
    else:
        exc_ctx = pytest.raises(ValueError, match="this model only supports")

    with exc_ctx:
        profiler.get_dummy_data(model_config.max_model_len)


@pytest.mark.parametrize("model_id", ["llava-hf/llava-v1.6-mistral-7b-hf"])
@pytest.mark.parametrize(
    ("num_images", "limit", "is_valid"),
    [(0, 0, True), (0, 1, True), (1, 0, False), (1, 1, True), (1, 2, True),
     (2, 1, False), (2, 2, True)],
)
def test_limit_mm_per_prompt_apply(model_id, num_images, limit, is_valid):
    """`processor.apply` must reject more images than `limit_mm_per_prompt`.

    The invalid cases are those where `num_images` exceeds `limit`, and they
    must raise `ValueError`. A single image is passed bare; several are
    passed as a list.
    """
    limit_mm_per_prompt = {"image": limit}

    model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
        limit_mm_per_prompt=limit_mm_per_prompt,
    )

    processor = MULTIMODAL_REGISTRY.create_processor(
        model_config,
        tokenizer=cached_tokenizer_from_config(model_config),
    )

    rng = np.random.RandomState(0)
    image = random_image(rng, min_wh=128, max_wh=256)
    if num_images == 0:
        mm_data = {}
    elif num_images == 1:
        mm_data = {"image": image}
    else:
        mm_data = {"image": [image] * num_images}

    if is_valid:
        exc_ctx = nullcontext()
    else:
        exc_ctx = pytest.raises(ValueError, match=f"passed {num_images} image")

    with exc_ctx:
        processor.apply(
            "<image>" * num_images,
            mm_data=mm_data,
            hf_processor_mm_kwargs={},
        )


class _ProcessorProxy:

    def __init__(self, processor: ProcessorMixin) -> None:
        super().__init__()

        self.__processor = processor

    def __getattr__(self, key: str):
        return getattr(self.__processor, key)

    def __call__(
        self,
        text=None,
        images=None,
        videos=None,
        exists=None,
        return_tensors=None,
    ):
        return dict(exists=exists)


@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"])  # Dummy
# yapf: disable
@pytest.mark.parametrize(
    ("call_kwargs", "expected_kwargs"),
    [
        # Should ignore invalid kwargs
        ({"does_not_exist": 100}, {"exists": None}),
        ({"exists": 1}, {"exists": 1}),
        ({"does_not_exist": 100, "exists": 1}, {"exists": 1}),
    ],
)
# yapf: enable
def test_hf_processor_kwargs(model_id, call_kwargs, expected_kwargs):
    """`_call_hf_processor` should forward only kwargs the processor accepts.

    `get_hf_processor` is patched to assert it receives `call_kwargs` (if it
    is called) and to return a `_ProcessorProxy`, whose call echoes back the
    `exists` kwarg. The expected values require `does_not_exist` to be
    dropped.
    """
    model_config = ModelConfig(
        model=model_id,
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="half",
        revision=None,
    )

    processor = MULTIMODAL_REGISTRY.create_processor(
        model_config,
        tokenizer=cached_tokenizer_from_config(model_config),
    )
    orig_get_hf_processor = processor.info.get_hf_processor

    def get_hf_processor(self, **kwargs):
        assert kwargs == call_kwargs
        return _ProcessorProxy(orig_get_hf_processor())

    processor.info.get_hf_processor = MethodType(get_hf_processor,
                                                 processor.info)

    out_kwargs = processor._call_hf_processor(
        prompt="",
        mm_data={},
        mm_kwargs=call_kwargs,
    )

    assert out_kwargs == expected_kwargs