# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest

6
from vllm.inputs import zip_enc_dec_prompts
7
from vllm.inputs.parse import parse_and_batch_prompt
8

# Module-level pytest marker: applies the ``cpu_test`` mark to every test
# defined in this file.
pytestmark = pytest.mark.cpu_test

# Text prompts, from the empty string up to a four-word string.
STRING_INPUTS = [
    "",
    "foo",
    "foo bar",
    "foo baz bar",
    "foo bar qux baz",
]

# Integer-list (token) prompts of increasing length.
TOKEN_INPUTS = [[-1], [1], [1, 2], [1, 3, 4], [1, 2, 4, 3]]

# Whole-sequence slices that differ only in step: reversed, every second
# item, and every second item walking backwards.
INPUTS_SLICES = [slice(None, None, step) for step in (-1, 2, -2)]


def test_parse_single_batch_empty():
    """An empty list and a list holding one empty list both raise ValueError.

    In each case the error message must match "at least one prompt".
    """
    for empty_prompt in ([], [[]]):
        with pytest.raises(ValueError, match="at least one prompt"):
            parse_and_batch_prompt(empty_prompt)


@pytest.mark.parametrize('string_input', STRING_INPUTS)
def test_parse_single_batch_string_consistent(string_input: str):
    """A bare string parses the same as a one-element list wrapping it."""
    parsed_bare = parse_and_batch_prompt(string_input)
    parsed_wrapped = parse_and_batch_prompt([string_input])
    assert parsed_bare == parsed_wrapped


@pytest.mark.parametrize('token_input', TOKEN_INPUTS)
def test_parse_single_batch_token_consistent(token_input: list[int]):
    """A bare token list parses the same as a one-element list wrapping it."""
    parsed_bare = parse_and_batch_prompt(token_input)
    parsed_wrapped = parse_and_batch_prompt([token_input])
    assert parsed_bare == parsed_wrapped


@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES)
def test_parse_single_batch_string_slice(inputs_slice: slice):
    """Slicing the parsed batch equals parsing the sliced string inputs."""
    sliced_after_parse = parse_and_batch_prompt(STRING_INPUTS)[inputs_slice]
    parsed_after_slice = parse_and_batch_prompt(STRING_INPUTS[inputs_slice])
    assert sliced_after_parse == parsed_after_slice


# yapf: disable
@pytest.mark.parametrize('mm_processor_kwargs,expected_mm_kwargs', [
    (None, [{}, {}]),
    ({}, [{}, {}]),
    ({"foo": 100}, [{"foo": 100}, {"foo": 100}]),
    ([{"foo": 100}, {"bar": 200}], [{"foo": 100}, {"bar": 200}]),
])
# yapf: enable
def test_zip_enc_dec_prompts(mm_processor_kwargs, expected_mm_kwargs):
    """Check the dicts produced by zipping encoder and decoder prompts.

    Two encoder prompts and two decoder prompts are zipped together with
    ``mm_processor_kwargs``. Each resulting entry must be a three-key dict
    carrying the matching encoder prompt, decoder prompt, and the expected
    per-entry ``mm_processor_kwargs``.
    """
    enc_inputs = ['An encoder prompt', 'Another encoder prompt']
    dec_inputs = ['A decoder prompt', 'Another decoder prompt']

    pairs = zip_enc_dec_prompts(enc_inputs, dec_inputs, mm_processor_kwargs)
    assert len(pairs) == len(enc_inputs) == len(dec_inputs)

    rows = zip(enc_inputs, dec_inputs, expected_mm_kwargs, pairs)
    for enc_text, dec_text, want_kwargs, pair in rows:
        assert isinstance(pair, dict)
        expected_fields = {
            'encoder_prompt': enc_text,
            'decoder_prompt': dec_text,
            'mm_processor_kwargs': want_kwargs,
        }
        # Exactly these three keys, checked in the listed order.
        assert len(pair) == len(expected_fields)
        for field, want in expected_fields.items():
            assert pair[field] == want