test_inputs.py 3.28 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import pytest
5
6
import torch

7
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
8

9
10
# Module-level pytest marker: tags every test in this file with `cpu_test`.
pytestmark = pytest.mark.cpu_test

11
12
13

def assert_nested_tensors_equal(expected: NestedTensors,
                                actual: NestedTensors):
    """Assert that two nested tensor structures are identical.

    Both arguments must have exactly the same type at every level. Tensors
    are compared with ``torch.equal``; any other value is treated as a
    sequence whose length and items must match pairwise (recursively).

    Raises:
        AssertionError: If the types, lengths, or tensor contents differ.
    """
    # Exact type match on purpose: a list must not compare equal to a tuple
    # or to a stacked tensor.
    assert type(expected) is type(actual)

    if isinstance(expected, torch.Tensor):
        assert torch.equal(expected, actual)
        return

    # ``zip`` stops at the shorter input, so without this check a missing or
    # extra trailing item would go unnoticed.
    assert len(expected) == len(actual)
    for expected_item, actual_item in zip(expected, actual):
        assert_nested_tensors_equal(expected_item, actual_item)


22
23
def assert_multimodal_inputs_equal(expected: MultiModalKwargs,
                                   actual: MultiModalKwargs):
    """Assert that two multimodal input mappings are identical.

    The mappings must have the same set of keys, and the values stored under
    each key must match according to ``assert_nested_tensors_equal``.
    """
    expected_keys = set(expected.keys())
    actual_keys = set(actual.keys())
    assert expected_keys == actual_keys

    for key in expected:
        expected_value = expected[key]
        actual_value = actual[key]
        assert_nested_tensors_equal(expected_value, actual_value)


def test_multimodal_input_batch_single_tensor():
    """Check batching of a single item that holds one tensor.

    The expected result is the same tensor with a new leading dimension.
    """
    tensor = torch.rand([1, 2])
    items = [{"image": tensor}]

    batched = MultiModalKwargs.batch(items)

    expected = {"image": tensor.unsqueeze(0)}
    assert_multimodal_inputs_equal(batched, expected)


def test_multimodal_input_batch_multiple_tensors():
    """Check batching of three items holding same-shaped tensors.

    The expected result is a single tensor stacking all three inputs.
    """
    first = torch.rand([1, 1, 2])
    second = torch.rand([1, 1, 2])
    third = torch.rand([1, 1, 2])
    items = [{"image": first}, {"image": second}, {"image": third}]

    batched = MultiModalKwargs.batch(items)

    expected = {"image": torch.stack([first, second, third])}
    assert_multimodal_inputs_equal(batched, expected)


def test_multimodal_input_batch_multiple_heterogeneous_tensors():
    """Check batching of three items whose tensors differ in shape.

    The expected result keeps the tensors as a plain list in input order.
    """
    first = torch.rand([1, 2, 2])
    second = torch.rand([1, 3, 2])
    third = torch.rand([1, 4, 2])
    items = [{"image": first}, {"image": second}, {"image": third}]

    batched = MultiModalKwargs.batch(items)

    expected = {"image": [first, second, third]}
    assert_multimodal_inputs_equal(batched, expected)


def test_multimodal_input_batch_nested_tensors():
    """Check batching of items that each wrap one tensor in a list.

    All tensors share a shape; the expected result is one tensor that stacks
    each input after a leading dimension has been added to it.
    """
    first = torch.rand([2, 3])
    second = torch.rand([2, 3])
    third = torch.rand([2, 3])
    items = [{"image": [first]}, {"image": [second]}, {"image": [third]}]

    batched = MultiModalKwargs.batch(items)

    expected_image = torch.stack([
        first.unsqueeze(0),
        second.unsqueeze(0),
        third.unsqueeze(0),
    ])
    assert_multimodal_inputs_equal(batched, {"image": expected_image})


def test_multimodal_input_batch_heterogeneous_lists():
    """Check batching of items whose tensor lists differ in length.

    The items hold two and one same-shaped tensors respectively; the expected
    result is a list with one stacked tensor per item.
    """
    first = torch.rand([1, 2, 3])
    second = torch.rand([1, 2, 3])
    third = torch.rand([1, 2, 3])
    items = [{"image": [first, second]}, {"image": [third]}]

    batched = MultiModalKwargs.batch(items)

    expected_image = [torch.stack([first, second]), third.unsqueeze(0)]
    assert_multimodal_inputs_equal(batched, {"image": expected_image})


def test_multimodal_input_batch_multiple_batchable_lists():
    """Check batching of two items that each hold two same-shaped tensors.

    The expected result is a single tensor: each item's list is stacked, and
    the two stacked tensors are stacked again.
    """
    first = torch.rand([1, 2, 3])
    second = torch.rand([1, 2, 3])
    third = torch.rand([1, 2, 3])
    fourth = torch.rand([1, 2, 3])
    items = [{"image": [first, second]}, {"image": [third, fourth]}]

    batched = MultiModalKwargs.batch(items)

    stacked_head = torch.stack([first, second])
    stacked_tail = torch.stack([third, fourth])
    expected = {"image": torch.stack([stacked_head, stacked_tail])}
    assert_multimodal_inputs_equal(batched, expected)
90
91
92
93
94
95
96


def test_multimodal_input_batch_mixed_stacking_depths():
    """Check batching when only some items can be turned into one tensor.

    The three tensors all differ in shape. In both orderings, the item holding
    two tensors is expected to stay a list, while the item holding a single
    tensor is expected to become that tensor with a leading dimension added.
    """
    small = torch.rand([1, 2, 3])
    medium = torch.rand([1, 3, 3])
    large = torch.rand([1, 4, 3])

    # Two-tensor item first, single-tensor item second.
    pair_first = MultiModalKwargs.batch([
        {"image": [small, medium]},
        {"image": [large]},
    ])
    assert_multimodal_inputs_equal(
        pair_first,
        {"image": [[small, medium], large.unsqueeze(0)]},
    )

    # Single-tensor item first, two-tensor item second.
    pair_last = MultiModalKwargs.batch([
        {"image": [small]},
        {"image": [medium, large]},
    ])
    assert_multimodal_inputs_equal(
        pair_last,
        {"image": [small.unsqueeze(0), [medium, large]]},
    )