# test_serial_utils.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
from collections import UserDict
from dataclasses import dataclass

6
import msgspec
7
import numpy as np
8
import pytest
9
10
import torch

11
12
13
14
15
16
17
18
19
from vllm.multimodal.inputs import (
    MultiModalBatchedField,
    MultiModalFieldElem,
    MultiModalFlatField,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    MultiModalSharedField,
    NestedTensors,
)
20
21
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder

22
23
pytestmark = pytest.mark.cpu_test

24
25
26
27
28
29
30
31
32
33
34
35
36
37

class UnrecognizedType(UserDict):
    """Dict-like helper type that carries a single extra integer attribute.

    The mapping itself is left empty; the only payload is ``an_int``, which
    the round-trip tests compare after decoding.
    """

    def __init__(self, an_int: int) -> None:
        # Initialise the underlying (empty) UserDict storage first, then
        # attach the value that the tests check.
        super().__init__()
        self.an_int = an_int


@dataclass
class MyType:
    """Composite payload used to exercise the msgpack encode/decode round trip.

    It mixes plain values, tensors with different memory layouts, a numpy
    array and a custom (non-builtin) type in a single object.
    """

    tensor1: torch.Tensor
    a_string: str
    list_of_tensors: list[torch.Tensor]
    numpy_array: np.ndarray
    unrecognized: UnrecognizedType
    # Transposed tensors (built with ``.t()`` in the test).
    small_f_contig_tensor: torch.Tensor
    large_f_contig_tensor: torch.Tensor
    # Column slices of a larger tensor (built with ``[:, a:b]`` in the test).
    small_non_contig_tensor: torch.Tensor
    large_non_contig_tensor: torch.Tensor
    # Zero-element tensor.
    empty_tensor: torch.Tensor
43
44


45
def test_encode_decode(monkeypatch: pytest.MonkeyPatch):
    """Test encode/decode loop with zero-copy tensors.

    Builds a ``MyType`` instance containing tensors of various sizes, dtypes
    and memory layouts, encodes it with ``encode`` and ``encode_into``, and
    checks that both results decode back to an equal object.
    """
    with monkeypatch.context() as m:
        # NOTE(review): presumably required so that the UnrecognizedType
        # field can be serialized through the insecure fallback path --
        # confirm against vllm.v1.serial_utils.
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

        obj = MyType(
            tensor1=torch.randint(low=0, high=100, size=(1024,), dtype=torch.int32),
            a_string="hello",
            list_of_tensors=[
                torch.rand((1, 10), dtype=torch.float32),
                torch.rand((3, 5, 4000), dtype=torch.float64),
                torch.tensor(1984),  # test scalar too
                # Make sure to test bf16 which numpy doesn't support.
                torch.rand((3, 5, 1000), dtype=torch.bfloat16),
                torch.tensor(
                    [float("-inf"), float("inf")] * 1024, dtype=torch.bfloat16
                ),
            ],
            numpy_array=np.arange(512),
            unrecognized=UnrecognizedType(33),
            small_f_contig_tensor=torch.rand(5, 4).t(),
            large_f_contig_tensor=torch.rand(1024, 4).t(),
            small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
            large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
            empty_tensor=torch.empty(0),
        )

        encoder = MsgpackEncoder(size_threshold=256)
        decoder = MsgpackDecoder(MyType)

        encoded = encoder.encode(obj)

        # Expected buffers: the main buffer + 6 large tensor buffers
        # (tensor1, three entries of list_of_tensors, large_f_contig_tensor,
        # large_non_contig_tensor) + 1 large numpy array = 8.
        # "Large" presumably means at least size_threshold (256) bytes --
        # confirm against MsgpackEncoder. The remaining tensors (the
        # (1, 10) float32 tensor, the scalar, both small layout variants and
        # the empty tensor) are far below 256 bytes and are expected inline.
        assert len(encoded) == 8

        decoded: MyType = decoder.decode(encoded)

        assert_equal(decoded, obj)

        # Test encode_into case

        preallocated = bytearray()

        encoded2 = encoder.encode_into(obj, preallocated)

        assert len(encoded2) == 8
        # The caller-supplied buffer must be reused as the main buffer.
        assert encoded2[0] is preallocated

        decoded2: MyType = decoder.decode(encoded2)

        assert_equal(decoded2, obj)
99
100


101
class MyRequest(msgspec.Struct):
    """Minimal mock request wrapping multimodal kwargs for the decoder test."""

    # Optional list of multimodal kwargs collections carried by the request.
    mm: list[MultiModalKwargsItems] | None
103
104
105


def test_multimodal_kwargs():
    """Round-trip a ``MultiModalKwargsItems`` through the msgpack codec.

    Builds audio, video and image items with different field types, wraps
    them in a ``MyRequest``, and checks the buffer count, the approximate
    encoded size and the decoded contents.
    """
    e1 = MultiModalFieldElem(
        "audio",
        "a0",
        torch.zeros(1000, dtype=torch.bfloat16),
        MultiModalBatchedField(),
    )
    e2 = MultiModalFieldElem(
        "video",
        "v0",
        [torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
        MultiModalFlatField(
            slices=[[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]],
            dim=0,
        ),
    )
    e3 = MultiModalFieldElem(
        "image",
        "i0",
        torch.zeros(1000, dtype=torch.int32),
        MultiModalSharedField(batch_size=4),
    )
    e4 = MultiModalFieldElem(
        "image",
        "i1",
        torch.zeros(1000, dtype=torch.int32),
        MultiModalFlatField(slices=[slice(1, 2, 3), slice(4, 5, 6)], dim=2),
    )
    audio = MultiModalKwargsItem.from_elems([e1])
    video = MultiModalKwargsItem.from_elems([e2])
    image = MultiModalKwargsItem.from_elems([e3, e4])
    mm = MultiModalKwargsItems.from_seq([audio, video, image])

    # pack mm kwargs into a mock request so that it can be decoded properly
    req = MyRequest([mm])

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(MyRequest)

    encoded = encoder.encode(req)

    # NOTE(review): presumably 1 main buffer + 7 tensor buffers
    # (1 audio + 4 video + 2 image) -- confirm against MsgpackEncoder.
    assert len(encoded) == 8

    total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)

    # expected total encoding length is about 14395; the accepted window
    # (14375..14425) leaves some slack for minor encoding changes
    assert 14375 <= total_len <= 14425
    decoded = decoder.decode(encoded).mm[0]
    assert isinstance(decoded, MultiModalKwargsItems)

    # check all modalities were recovered and do some basic sanity checks
    assert len(decoded) == 3
    images = decoded["image"]
    assert len(images) == 1
    assert len(images[0].items()) == 2
    assert list(images[0].keys()) == ["i0", "i1"]

    # check the tensor contents and layout in the main dict
    mm_data = mm.get_data()
    decoded_data = decoded.get_data()
    assert all(nested_equal(mm_data[k], decoded_data[k]) for k in mm_data)
166
167
168
169
170


def nested_equal(a: "NestedTensors", b: "NestedTensors") -> bool:
    """Return True if two nested tensor structures are element-wise equal.

    A tensor is compared with ``torch.equal``; any other value is treated as
    a sequence and compared item by item, recursively.

    Args:
        a: Reference structure (a tensor or a nested sequence of tensors).
        b: Structure to compare against ``a``.

    Returns:
        True only when both structures have the same nesting, the same
        lengths at every level, and equal tensors at every leaf.
    """
    if isinstance(a, torch.Tensor):
        # A tensor can only be equal to another tensor.
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
    # zip() stops at the shorter input, so sequences of different length
    # would otherwise compare equal on their common prefix.
    if len(a) != len(b):
        return False
    return all(nested_equal(x, y) for x, y in zip(a, b))
172
173


174
175
176
177
def assert_equal(obj1: "MyType", obj2: "MyType") -> None:
    """Assert that two ``MyType`` instances hold equal data, field by field.

    Args:
        obj1: First object to compare.
        obj2: Second object to compare.

    Raises:
        AssertionError: If any field differs between the two objects.
    """
    assert torch.equal(obj1.tensor1, obj2.tensor1)
    assert obj1.a_string == obj2.a_string
    # zip() truncates to the shorter list, so check the lengths explicitly;
    # otherwise a list that lost trailing tensors would still pass.
    assert len(obj1.list_of_tensors) == len(obj2.list_of_tensors)
    assert all(
        torch.equal(a, b) for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors)
    )
    assert np.array_equal(obj1.numpy_array, obj2.numpy_array)
    assert obj1.unrecognized.an_int == obj2.unrecognized.an_int
    assert torch.equal(obj1.small_f_contig_tensor, obj2.small_f_contig_tensor)
    assert torch.equal(obj1.large_f_contig_tensor, obj2.large_f_contig_tensor)
    assert torch.equal(obj1.small_non_contig_tensor, obj2.small_non_contig_tensor)
    assert torch.equal(obj1.large_non_contig_tensor, obj2.large_non_contig_tensor)
    assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
187
188


189
def test_dict_serialization():
    """Test encoding and decoding of a plain dict of built-in values.

    No decode type is given and VLLM_ALLOW_INSECURE_SERIALIZATION is not set
    by this test, so the dict is expected to round-trip without it.
    """
    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder()

    # Create a sample Python object
    obj = {"key": "value", "number": 42}

    # Encode the object
    encoded = encoder.encode(obj)

    # Decode the object
    decoded = decoder.decode(encoded)

    # Verify the decoded object matches the original
    assert obj == decoded, "Decoded object does not match the original object."


207
def test_tensor_serialization():
    """Test encoding and decoding of a torch.Tensor as the top-level object."""
    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(torch.Tensor)

    # Create a sample tensor
    tensor = torch.rand(10, 10)

    # Encode the tensor
    encoded = encoder.encode(tensor)

    # Decode the tensor
    decoded = decoder.decode(encoded)

    # Verify the decoded tensor matches the original
    assert torch.allclose(tensor, decoded), (
        "Decoded tensor does not match the original tensor."
    )
225
226


227
def test_numpy_array_serialization():
    """Test encoding and decoding of a numpy array as the top-level object."""
    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(np.ndarray)

    # Create a sample numpy array
    array = np.random.rand(10, 10)

    # Encode the numpy array
    encoded = encoder.encode(array)

    # Decode the numpy array
    decoded = decoder.decode(encoded)

    # Verify the decoded array matches the original
    assert np.allclose(array, decoded), (
        "Decoded numpy array does not match the original array."
    )
245
246
247
248
249
250
251
252
253
254


class CustomClass:
    """Simple value holder with value-based equality, used by the pickle tests."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Only another CustomClass can be equal; compare the wrapped values.
        if not isinstance(other, CustomClass):
            return False
        return self.value == other.value


255
def test_custom_class_serialization_allowed_with_pickle(
    monkeypatch: pytest.MonkeyPatch,
):
    """Test that a custom class round-trips when insecure serialization is on.

    The test sets VLLM_ALLOW_INSECURE_SERIALIZATION=1 for its duration and
    checks that a ``CustomClass`` instance encodes and decodes to an equal
    object.
    """
    with monkeypatch.context() as m:
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        encoder = MsgpackEncoder()
        decoder = MsgpackDecoder(CustomClass)

        obj = CustomClass("test_value")

        # Encode the custom class
        encoded = encoder.encode(obj)

        # Decode the custom class
        decoded = decoder.decode(encoded)

        # Verify the decoded object matches the original
        assert obj == decoded, "Decoded object does not match the original object."
275
276
277
278


def test_custom_class_serialization_disallowed_without_pickle():
    """Test that encoding a custom class fails without insecure serialization.

    Unlike the "allowed" variant, this test does not set
    VLLM_ALLOW_INSECURE_SERIALIZATION and expects ``encode`` to raise
    ``TypeError``.
    """
    # NOTE(review): this relies on VLLM_ALLOW_INSECURE_SERIALIZATION not being
    # set in the surrounding environment -- confirm for the CI setup.
    encoder = MsgpackEncoder()

    obj = CustomClass("test_value")

    with pytest.raises(TypeError):
        # Attempt to encode the custom class
        encoder.encode(obj)