test_inputs.py 2.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import pytest

6
from vllm.config import ModelConfig
7
from vllm.inputs import zip_enc_dec_prompts
8
from vllm.inputs.preprocess import InputPreprocessor
9

10
11
# pytest applies the module-level ``pytestmark`` to every test in this file,
# so all tests here carry the ``cpu_test`` marker (presumably selecting tests
# that can run without a GPU — confirm against the project's marker config).
pytestmark = pytest.mark.cpu_test

12

13
14
15
16
17
18
19
20
21
@pytest.mark.parametrize(
    "mm_processor_kwargs,expected_mm_kwargs",
    [
        # No kwargs (None or empty dict): each zipped prompt gets an empty dict.
        (None, [{}, {}]),
        ({}, [{}, {}]),
        # A single dict is expected to be attached to every prompt pair.
        ({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
        # A list supplies one kwargs dict per prompt pair, by position.
        ([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
    ],
)
def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
    """Test mm_processor_kwargs init for zipping enc/dec prompts.

    ``zip_enc_dec_prompts`` must pair each encoder prompt with the decoder
    prompt at the same position and attach the ``mm_processor_kwargs`` entry
    listed at that position in ``expected_mm_kwargs``.
    """
    encoder_prompts = ["An encoder prompt", "Another encoder prompt"]
    decoder_prompts = ["A decoder prompt", "Another decoder prompt"]
    zipped_prompts = zip_enc_dec_prompts(
        encoder_prompts, decoder_prompts, mm_processor_kwargs
    )
    assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts)

    for enc, dec, exp_kwargs, zipped in zip(
        encoder_prompts, decoder_prompts, expected_mm_kwargs, zipped_prompts
    ):
        assert isinstance(zipped, dict)
        # Whole-dict equality checks the exact key set (no extra or missing
        # keys) and every value at once, with a readable diff on failure.
        assert zipped == {
            "encoder_prompt": enc,
            "decoder_prompt": dec,
            "mm_processor_kwargs": exp_kwargs,
        }


@pytest.mark.parametrize(
    "model_id",
    [
        "facebook/chameleon-7b",
    ],
)
@pytest.mark.parametrize(
    "prompt",
    [
        "",
        {"prompt_token_ids": []},
    ],
)
@pytest.mark.skip(
    reason=(
        "Applying huggingface processor on text inputs results in "
        "significant performance regression for multimodal models. "
        "See https://github.com/vllm-project/vllm/issues/26320"
    )
)
def test_preprocessor_always_mm_code_path(model_id, prompt):
    """Check that text-only prompts are still run through the HF processor.

    Both parametrized prompts are empty (an empty string and an empty
    token-id list), so any token present after preprocessing must have been
    inserted by the preprocessor itself. The test expects the tokenizer's
    sep token to appear, which it attributes to the HF processor. The test
    is currently skipped; see the ``skip`` reason above.
    """
    model_config = ModelConfig(model=model_id)
    input_preprocessor = InputPreprocessor(model_config)

    # HF processor adds sep token; look up its id so the output can be
    # checked for it.
    tokenizer = input_preprocessor.get_tokenizer()
    sep_token_id = tokenizer.vocab[tokenizer.sep_token]

    processed_inputs = input_preprocessor.preprocess(prompt)
    assert sep_token_id in processed_inputs["prompt_token_ids"]