# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

6
from vllm.config import ModelConfig
7
from vllm.inputs import zip_enc_dec_prompts
8
from vllm.inputs.parse import parse_raw_prompts
9
from vllm.inputs.preprocess import InputPreprocessor
10

11
12
# Module-level pytest marker: applies the `cpu_test` mark to every test
# collected from this file.
pytestmark = pytest.mark.cpu_test

13
# Text prompts of increasing length, starting with the empty string.
STRING_INPUTS = [
    "",
    "foo",
    "foo bar",
    "foo baz bar",
    "foo bar qux baz",
]

# Token-id prompts of increasing length; the first entry holds a negative id.
TOKEN_INPUTS = [
    [-1],
    [1],
    [1, 2],
    [1, 3, 4],
    [1, 2, 4, 3],
]

# Slices applied to a whole batch: reversed, every second item, and
# every second item in reverse order.
INPUTS_SLICES = [
    slice(None, None, -1),
    slice(None, None, 2),
    slice(None, None, -2),
]


36
def test_parse_raw_single_batch_empty():
    """Check that `parse_raw_prompts` rejects empty input.

    Both an empty list and a list holding a single empty list must raise
    `ValueError` with a message matching "at least one prompt".
    """
    with pytest.raises(ValueError, match="at least one prompt"):
        parse_raw_prompts([])

    with pytest.raises(ValueError, match="at least one prompt"):
        parse_raw_prompts([[]])
42
43


44
@pytest.mark.parametrize("string_input", STRING_INPUTS)
def test_parse_raw_single_batch_string_consistent(string_input: str):
    """Check that a bare string parses the same as a one-element list of it."""
    assert parse_raw_prompts(string_input) == parse_raw_prompts([string_input])
47
48


49
@pytest.mark.parametrize("token_input", TOKEN_INPUTS)
def test_parse_raw_single_batch_token_consistent(token_input: list[int]):
    """Check that a bare token list parses the same as a one-element batch."""
    assert parse_raw_prompts(token_input) == parse_raw_prompts([token_input])
52
53


54
@pytest.mark.parametrize("inputs_slice", INPUTS_SLICES)
def test_parse_raw_single_batch_string_slice(inputs_slice: slice):
    """Check that slicing commutes with parsing for a batch of strings.

    Slicing the parsed result of the full batch must equal parsing the
    already-sliced batch.
    """
    assert parse_raw_prompts(STRING_INPUTS)[inputs_slice] == parse_raw_prompts(
        STRING_INPUTS[inputs_slice]
    )
59
60


61
62
63
64
65
66
67
68
69
@pytest.mark.parametrize(
    "mm_processor_kwargs,expected_mm_kwargs",
    [
        (None, [{}, {}]),
        ({}, [{}, {}]),
        ({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
        ([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
    ],
)
def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
    """Test mm_processor_kwargs init for zipping enc/dec prompts.

    The parametrized cases cover `None`, an empty dict, a single dict, and
    a per-prompt list of dicts; `expected_mm_kwargs` lists the kwargs each
    zipped prompt is expected to carry, in order.
    """
    encoder_prompts = ["An encoder prompt", "Another encoder prompt"]
    decoder_prompts = ["A decoder prompt", "Another decoder prompt"]
    zipped_prompts = zip_enc_dec_prompts(
        encoder_prompts, decoder_prompts, mm_processor_kwargs
    )
    assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts)

    for enc, dec, exp_kwargs, zipped in zip(
        encoder_prompts, decoder_prompts, expected_mm_kwargs, zipped_prompts
    ):
        # Each zipped entry is a dict with exactly the three keys checked below.
        assert isinstance(zipped, dict)
        assert len(zipped) == 3
        assert zipped["encoder_prompt"] == enc
        assert zipped["decoder_prompt"] == dec
        assert zipped["mm_processor_kwargs"] == exp_kwargs


@pytest.mark.parametrize(
    "model_id",
    [
        "facebook/opt-125m",
    ],
)
@pytest.mark.parametrize(
    "prompt",
    [
        {
            "prompt": "",
            "multi_modal_data": {"dummy": []},
        },
        {
            "prompt_token_ids": [],
            "multi_modal_data": {"dummy": []},
        },
    ],
)
def test_preprocessor_text_no_mm_inputs(model_id, prompt):
    """Check that multimodal data is rejected for the parametrized model.

    Each prompt (one text-based, one token-id-based) carries a
    `multi_modal_data` entry; preprocessing it must raise `ValueError`
    with a message matching "does not support multimodal inputs".
    """
    model_config = ModelConfig(model=model_id)
    input_preprocessor = InputPreprocessor(model_config)

    with pytest.raises(ValueError, match="does not support multimodal inputs"):
        input_preprocessor.preprocess(prompt)


115
116
117
118
119
120
121
122
123
124
125
126
127
@pytest.mark.parametrize(
    "model_id",
    [
        "facebook/chameleon-7b",
    ],
)
@pytest.mark.parametrize(
    "prompt",
    [
        "",
        {"prompt_token_ids": []},
    ],
)
def test_preprocessor_always_mm_code_path(model_id, prompt):
    """Check that empty prompts still come back containing the sep token.

    Both an empty string and an empty token-id prompt are preprocessed, and
    the tokenizer's sep token id must appear in the resulting
    `prompt_token_ids`.
    """
    model_config = ModelConfig(model=model_id)
    input_preprocessor = InputPreprocessor(model_config)
    tokenizer = input_preprocessor.tokenizer

    # HF processor adds sep token
    sep_token_id = tokenizer.vocab[tokenizer.sep_token]

    processed_inputs = input_preprocessor.preprocess(prompt)
    assert sep_token_id in processed_inputs["prompt_token_ids"]