# test_inputs.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
import pytest

6
from vllm.config import ModelConfig
7
from vllm.inputs import zip_enc_dec_prompts
8
from vllm.inputs.parse import parse_raw_prompts
9
10
from vllm.inputs.preprocess import InputPreprocessor
from vllm.transformers_utils.tokenizer import init_tokenizer_from_configs
11

12
13
pytestmark = pytest.mark.cpu_test

# Plain-text prompts of increasing word count, starting with the empty
# string; fed to the parse_raw_prompts string tests below.
STRING_INPUTS = [
    '',
    'foo',
    'foo bar',
    'foo baz bar',
    'foo bar qux baz',
]

# Token-ID prompts of increasing length. The [-1] entry presumably checks
# that parsing does not validate IDs against a vocabulary — TODO confirm.
TOKEN_INPUTS = [
    [-1],
    [1],
    [1, 2],
    [1, 3, 4],
    [1, 2, 4, 3],
]

# Slices used to compare "parse then slice" against "slice then parse":
# full reversal, every other element, and every other element reversed.
INPUTS_SLICES = [
    slice(None, None, -1),
    slice(None, None, 2),
    slice(None, None, -2),
]


def test_parse_raw_single_batch_empty():
    """An empty batch, or a batch holding one empty token list, is rejected.

    Both cases must raise ValueError mentioning "at least one prompt".
    """
    with pytest.raises(ValueError, match="at least one prompt"):
        parse_raw_prompts([])

    with pytest.raises(ValueError, match="at least one prompt"):
        parse_raw_prompts([[]])


@pytest.mark.parametrize('string_input', STRING_INPUTS)
def test_parse_raw_single_batch_string_consistent(string_input: str):
    """A bare string parses the same as a one-element batch holding it."""
    assert (parse_raw_prompts(string_input)
            == parse_raw_prompts([string_input]))


@pytest.mark.parametrize('token_input', TOKEN_INPUTS)
def test_parse_raw_single_batch_token_consistent(token_input: list[int]):
    """A bare token list parses the same as a one-element batch holding it."""
    assert (parse_raw_prompts(token_input)
            == parse_raw_prompts([token_input]))


@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES)
def test_parse_raw_single_batch_string_slice(inputs_slice: slice):
    """Slicing the parsed batch equals parsing the sliced input batch."""
    assert (parse_raw_prompts(STRING_INPUTS)[inputs_slice]
            == parse_raw_prompts(STRING_INPUTS[inputs_slice]))


# yapf: disable
@pytest.mark.parametrize('mm_processor_kwargs,expected_mm_kwargs', [
    (None, [{}, {}]),
    ({}, [{}, {}]),
    ({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
    ([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
])
# yapf: enable
def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
    """Test mm_processor_kwargs init for zipping enc/dec prompts.

    Per the parametrized cases: None or {} yields empty kwargs for every
    pair, a single dict is applied to every pair, and a list supplies one
    dict per encoder/decoder pair.
    """
    encoder_prompts = ['An encoder prompt', 'Another encoder prompt']
    decoder_prompts = ['A decoder prompt', 'Another decoder prompt']
    zipped_prompts = zip_enc_dec_prompts(encoder_prompts, decoder_prompts,
                                         mm_processor_kwargs)
    assert len(zipped_prompts) == len(encoder_prompts) == len(decoder_prompts)
    for enc, dec, exp_kwargs, zipped in zip(encoder_prompts, decoder_prompts,
                                            expected_mm_kwargs,
                                            zipped_prompts):
        assert isinstance(zipped, dict)
        # Exactly the three keys checked below, and nothing else.
        assert len(zipped) == 3
        assert zipped['encoder_prompt'] == enc
        assert zipped['decoder_prompt'] == dec
        assert zipped['mm_processor_kwargs'] == exp_kwargs


@pytest.mark.parametrize("model_id", ["facebook/opt-125m"])
@pytest.mark.parametrize("prompt", [
    {"prompt": "", "multi_modal_data": {"dummy": []}},
    {"prompt_token_ids": [], "multi_modal_data": {"dummy": []}},
])
def test_preprocessor_text_no_mm_inputs(model_id, prompt):
    """A model without multimodal support rejects multi_modal_data.

    Both a text prompt and a token-ID prompt carrying a multi_modal_data
    entry must raise ValueError.
    """
    config = ModelConfig(model=model_id)
    preprocessor = InputPreprocessor(config,
                                     init_tokenizer_from_configs(config))

    with pytest.raises(ValueError, match="does not support multimodal inputs"):
        preprocessor.preprocess(prompt)


@pytest.mark.parametrize("model_id", ["facebook/chameleon-7b"])
@pytest.mark.parametrize("prompt", ["", {"prompt_token_ids": []}])
def test_preprocessor_always_mm_code_path(model_id, prompt):
    """Empty prompts still pass through the HF processor for this model.

    The evidence is the separator token, which the HF processor adds: it
    must show up in the resulting prompt token IDs.
    """
    config = ModelConfig(model=model_id)
    tok = init_tokenizer_from_configs(config)
    preprocessor = InputPreprocessor(config, tok)

    # HF processor adds sep token
    expected_sep_id = tok.vocab[tok.sep_token]

    token_ids = preprocessor.preprocess(prompt)["prompt_token_ids"]
    assert expected_sep_id in token_ids