# test_inputs.py - tests for vllm.inputs.parse.parse_and_batch_prompt
1
2
3
4
from typing import List

import pytest

5
from vllm.inputs.parse import parse_and_batch_prompt
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53

# String prompts fed to the parser, from the empty string up to four words.
STRING_INPUTS = [
    "",
    "foo",
    "foo bar",
    "foo baz bar",
    "foo bar qux baz",
]

# Token-id prompts of length one to four; the first entry holds a negative id.
TOKEN_INPUTS = [
    [-1],
    [1],
    [1, 2],
    [1, 3, 4],
    [1, 2, 4, 3],
]

# Whole-sequence slices that differ only in step: reversed, every second
# element, and every second element in reverse.
INPUTS_SLICES = [slice(None, None, step) for step in (-1, 2, -2)]


def test_parse_single_batch_empty():
    """Empty inputs must be rejected with a ValueError.

    Both an empty list and a list holding a single empty list are expected
    to raise an error whose message matches "at least one prompt".
    """
    # Checked in the same order as before: the flat empty list first, then
    # the nested one.
    for empty_prompt in ([], [[]]):
        with pytest.raises(ValueError, match="at least one prompt"):
            parse_and_batch_prompt(empty_prompt)


@pytest.mark.parametrize('string_input', STRING_INPUTS)
def test_parse_single_batch_string_consistent(string_input: str):
    """A bare string must parse the same as a one-element list holding it."""
    parsed_bare = parse_and_batch_prompt(string_input)
    parsed_wrapped = parse_and_batch_prompt([string_input])

    assert parsed_bare == parsed_wrapped


@pytest.mark.parametrize('token_input', TOKEN_INPUTS)
def test_parse_single_batch_token_consistent(token_input: List[int]):
    """A bare token list must parse the same as a one-element list holding it."""
    parsed_bare = parse_and_batch_prompt(token_input)
    parsed_wrapped = parse_and_batch_prompt([token_input])

    assert parsed_bare == parsed_wrapped


@pytest.mark.parametrize('inputs_slice', INPUTS_SLICES)
def test_parse_single_batch_string_slice(inputs_slice: slice):
    """Slicing the parsed batch must equal parsing the sliced input list."""
    # Parse the full list first, then slice the result ...
    parse_then_slice = parse_and_batch_prompt(STRING_INPUTS)[inputs_slice]
    # ... and compare with slicing the raw inputs before parsing.
    slice_then_parse = parse_and_batch_prompt(STRING_INPUTS[inputs_slice])

    assert parse_then_slice == slice_then_parse