# test_serial_utils.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
from collections import UserDict
from dataclasses import dataclass

6
import msgspec
7
import numpy as np
8
import pytest
9
10
import torch

11
12
13
14
15
16
17
18
19
from vllm.multimodal.inputs import (
    MultiModalBatchedField,
    MultiModalFieldElem,
    MultiModalFlatField,
    MultiModalKwargsItem,
    MultiModalKwargsItems,
    MultiModalSharedField,
    NestedTensors,
)
20
21
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder

22
23
pytestmark = pytest.mark.cpu_test

24
25
26
27
28
29
30
31
32
33
34
35
36
37

class UnrecognizedType(UserDict):
    """Dict-like helper that carries a single extra integer attribute.

    The mapping part is left empty by the constructor; only ``an_int`` is set.
    """

    def __init__(self, an_int: int):
        # Let UserDict set up its (empty) backing storage first.
        super().__init__()
        # Extra payload stored as a plain attribute, not as a mapping entry.
        self.an_int = an_int


@dataclass
class MyType:
    """Composite payload used to exercise an encode/decode round trip.

    Bundles tensors, a string, a list of tensors, a numpy array and a
    dict-like custom object in one record.
    """

    tensor1: torch.Tensor
    a_string: str
    list_of_tensors: list[torch.Tensor]
    numpy_array: np.ndarray
    unrecognized: UnrecognizedType
    # Layout-specific cases; the field names describe the kind of tensor the
    # caller is expected to store in each slot.
    small_f_contig_tensor: torch.Tensor
    large_f_contig_tensor: torch.Tensor
    small_non_contig_tensor: torch.Tensor
    large_non_contig_tensor: torch.Tensor
    empty_tensor: torch.Tensor
43
44


45
def test_encode_decode(monkeypatch: pytest.MonkeyPatch):
    """Test encode/decode loop with zero-copy tensors.

    Builds a ``MyType`` mixing small and large tensors (contiguous,
    transposed, sliced, scalar, empty, bfloat16), a numpy array and a custom
    dict-like object, then round-trips it through both ``encode`` and
    ``encode_into`` and compares the result field by field.
    """

    with monkeypatch.context() as m:
        # The payload contains a custom (non-msgpack) type, so the insecure
        # serialization switch is enabled for the duration of this test.
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

        obj = MyType(
            tensor1=torch.randint(low=0, high=100, size=(1024,), dtype=torch.int32),
            a_string="hello",
            list_of_tensors=[
                torch.rand((1, 10), dtype=torch.float32),
                torch.rand((3, 5, 4000), dtype=torch.float64),
                torch.tensor(1984),  # test scalar too
                # Make sure to test bf16 which numpy doesn't support.
                torch.rand((3, 5, 1000), dtype=torch.bfloat16),
                torch.tensor(
                    [float("-inf"), float("inf")] * 1024, dtype=torch.bfloat16
                ),
            ],
            numpy_array=np.arange(512),
            unrecognized=UnrecognizedType(33),
            small_f_contig_tensor=torch.rand(5, 4).t(),
            large_f_contig_tensor=torch.rand(1024, 4).t(),
            small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
            large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
            empty_tensor=torch.empty(0),
        )

        encoder = MsgpackEncoder(size_threshold=256)
        decoder = MsgpackDecoder(MyType)

        encoded = encoder.encode(obj)

        # Expected: the main buffer + 7 separate buffers, one per payload whose
        # raw data is larger than size_threshold (256, presumably bytes):
        # tensor1, the three big entries of list_of_tensors, numpy_array,
        # large_f_contig_tensor and large_non_contig_tensor.
        # The remaining tensors (each under 100 bytes, scalar or empty) are
        # expected to be encoded inline in the main buffer.
        assert len(encoded) == 8

        decoded: MyType = decoder.decode(encoded)

        assert_equal(decoded, obj)

        # Test encode_into case: the caller-supplied bytearray must be reused
        # as the first (main) buffer of the result.
        preallocated = bytearray()

        encoded2 = encoder.encode_into(obj, preallocated)

        assert len(encoded2) == 8
        assert encoded2[0] is preallocated

        decoded2: MyType = decoder.decode(encoded2)

        assert_equal(decoded2, obj)
99
100


101
class MyRequest(msgspec.Struct):
    """Minimal msgspec request wrapper holding optional multimodal items."""

    # List of per-request multimodal kwargs containers, or None when absent.
    mm: list[MultiModalKwargsItems] | None
103
104
105


def test_multimodal_kwargs():
    """Round-trip ``MultiModalKwargsItems`` wrapped in a ``MyRequest``.

    Covers batched, flat (nested and plain slices) and shared field types
    across three modalities, then checks buffer count, total encoded size and
    the decoded contents.
    """
    e1 = MultiModalFieldElem(
        torch.zeros(1000, dtype=torch.bfloat16),
        MultiModalBatchedField(),
    )
    e2 = MultiModalFieldElem(
        [torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
        MultiModalFlatField(
            slices=[[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]],
            dim=0,
        ),
    )
    e3 = MultiModalFieldElem(
        torch.zeros(1000, dtype=torch.int32),
        MultiModalSharedField(batch_size=4),
    )
    e4 = MultiModalFieldElem(
        torch.zeros(1000, dtype=torch.int32),
        MultiModalFlatField(slices=[slice(1, 2, 3), slice(4, 5, 6)], dim=2),
    )
    mm = MultiModalKwargsItems(
        {
            "audio": [MultiModalKwargsItem({"a0": e1})],
            "video": [MultiModalKwargsItem({"v0": e2})],
            "image": [MultiModalKwargsItem({"i0": e3, "i1": e4})],
        }
    )

    # pack mm kwargs into a mock request so that it can be decoded properly
    req = MyRequest([mm])

    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(MyRequest)

    encoded = encoder.encode(req)

    # Consistent with one main buffer plus one buffer per tensor above
    # (1 audio + 4 video + 2 image = 7).
    assert len(encoded) == 8

    total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)

    # The raw tensor data alone is 14000 bytes (2000 + 4 * 1000 + 2 * 4000);
    # the rest is encoding overhead. Expected total is about 14319, and the
    # window below leaves roughly 20 bytes of slack for minor format changes.
    assert 14300 <= total_len <= 14340

    decoded = decoder.decode(encoded).mm[0]
    assert isinstance(decoded, MultiModalKwargsItems)

    # check all modalities were recovered and do some basic sanity checks
    assert len(decoded) == 3
    images = decoded["image"]
    assert len(images) == 1
    assert len(images[0].items()) == 2
    assert list(images[0].keys()) == ["i0", "i1"]

    # check the tensor contents and layout in the main dict
    mm_data = mm.get_data()
    decoded_data = decoded.get_data()
    assert all(nested_equal(mm_data[k], decoded_data[k]) for k in mm_data)
161
162
163
164
165


def nested_equal(a: NestedTensors, b: NestedTensors):
    """Return whether two (possibly nested) tensor structures are equal.

    Args:
        a: A tensor, or a sequence whose items are tensors or further
            sequences.
        b: The structure to compare against ``a``.

    Returns:
        True when every tensor in ``a`` equals (``torch.equal``) its
        counterpart in ``b`` and every sequence level has the same length;
        False otherwise.
    """
    if isinstance(a, torch.Tensor):
        # torch.equal raises on a non-tensor argument; treat that as "not equal".
        return isinstance(b, torch.Tensor) and torch.equal(a, b)
    # zip() stops at the shorter input, so a missing trailing element would
    # otherwise go unnoticed.
    if len(a) != len(b):
        return False
    return all(nested_equal(x, y) for x, y in zip(a, b))
167
168


169
170
171
172
def assert_equal(obj1: MyType, obj2: MyType):
    """Assert that two ``MyType`` instances hold equal data, field by field.

    Tensors are compared with ``torch.equal``, the numpy array with
    ``np.array_equal``, and the custom object through its ``an_int``
    attribute.

    Raises:
        AssertionError: On the first field that differs.
    """
    assert torch.equal(obj1.tensor1, obj2.tensor1)
    assert obj1.a_string == obj2.a_string
    # zip() stops at the shorter list, so without this check a truncated (or
    # empty) list would compare as equal.
    assert len(obj1.list_of_tensors) == len(obj2.list_of_tensors)
    assert all(
        torch.equal(a, b) for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors)
    )
    assert np.array_equal(obj1.numpy_array, obj2.numpy_array)
    assert obj1.unrecognized.an_int == obj2.unrecognized.an_int
    assert torch.equal(obj1.small_f_contig_tensor, obj2.small_f_contig_tensor)
    assert torch.equal(obj1.large_f_contig_tensor, obj2.large_f_contig_tensor)
    assert torch.equal(obj1.small_non_contig_tensor, obj2.small_non_contig_tensor)
    assert torch.equal(obj1.large_non_contig_tensor, obj2.large_non_contig_tensor)
    assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
182
183


184
def test_dict_serialization():
    """Test encoding and decoding of a plain Python dict.

    No insecure-serialization switch is set here, so the dict of str/int
    values has to round-trip through the default encoder/decoder.
    """
    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder()

    # Create a sample Python object
    obj = {"key": "value", "number": 42}

    # Encode the object
    encoded = encoder.encode(obj)

    # Decode the object
    decoded = decoder.decode(encoded)

    # Verify the decoded object matches the original
    assert obj == decoded, "Decoded object does not match the original object."


202
def test_tensor_serialization():
    """Test encoding and decoding of a torch.Tensor."""
    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(torch.Tensor)

    # Create a sample tensor
    tensor = torch.rand(10, 10)

    # Encode the tensor
    encoded = encoder.encode(tensor)

    # Decode the tensor
    decoded = decoder.decode(encoded)

    # torch.allclose broadcasts its operands, so pin the shape explicitly
    # before comparing values.
    assert decoded.shape == tensor.shape, (
        f"Decoded tensor shape {decoded.shape} != original {tensor.shape}"
    )

    # Verify the decoded tensor matches the original
    assert torch.allclose(tensor, decoded), (
        "Decoded tensor does not match the original tensor."
    )
220
221


222
def test_numpy_array_serialization():
    """Test encoding and decoding of a numpy array."""
    encoder = MsgpackEncoder()
    decoder = MsgpackDecoder(np.ndarray)

    # Create a sample numpy array
    array = np.random.rand(10, 10)

    # Encode the numpy array
    encoded = encoder.encode(array)

    # Decode the numpy array
    decoded = decoder.decode(encoded)

    # np.allclose broadcasts its operands, so pin the shape explicitly
    # before comparing values.
    assert decoded.shape == array.shape, (
        f"Decoded array shape {decoded.shape} != original {array.shape}"
    )

    # Verify the decoded array matches the original
    assert np.allclose(array, decoded), (
        "Decoded numpy array does not match the original array."
    )
240
241
242
243
244
245
246
247
248
249


class CustomClass:
    """Simple value holder with value-based equality."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        """Equal only to another ``CustomClass`` wrapping an equal value."""
        if not isinstance(other, CustomClass):
            return False
        return self.value == other.value


250
def test_custom_class_serialization_allowed_with_pickle(
    monkeypatch: pytest.MonkeyPatch,
):
    """Test that a custom class round-trips when insecure serialization is on.

    The switch is the ``VLLM_ALLOW_INSECURE_SERIALIZATION`` environment
    variable (presumably enabling a pickle-based fallback, as the test name
    suggests), not an argument to the encoder or decoder.
    """

    with monkeypatch.context() as m:
        # Set the env var before building the encoder/decoder under test.
        m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
        encoder = MsgpackEncoder()
        decoder = MsgpackDecoder(CustomClass)

        obj = CustomClass("test_value")

        # Encode the custom class
        encoded = encoder.encode(obj)

        # Decode the custom class
        decoded = decoder.decode(encoded)

        # Verify the decoded object matches the original
        assert obj == decoded, "Decoded object does not match the original object."
270
271
272
273


def test_custom_class_serialization_disallowed_without_pickle():
    """Test that encoding a custom class raises TypeError by default.

    NOTE(review): this relies on VLLM_ALLOW_INSECURE_SERIALIZATION not being
    set in the ambient environment; confirm the test runner guarantees that.
    """
    encoder = MsgpackEncoder()

    obj = CustomClass("test_value")

    with pytest.raises(TypeError):
        # Attempt to encode the custom class
        encoder.encode(obj)
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425


@dataclass
class RequestWithTensor:
    """Mock request with non-multimodal tensor field like EngineCoreRequest."""

    prompt_embeds: torch.Tensor | None
    data: str


def test_non_multimodal_tensor_with_ipc():
    """Round-trip a non-multimodal tensor field while tensor IPC is enabled.

    Regression test for fields such as ``prompt_embeds: torch.Tensor | None``
    failing to decode with IPC enabled, because _decode_tensor expected a raw
    tensor tuple but received a msgpack-decoded TensorIpcHandle list.
    """
    import torch.multiprocessing as torch_mp

    from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender

    # A single queue is shared by the sending and the receiving side.
    ipc_queue = torch_mp.Queue()

    # Encoding side: the encoder hands tensors to the IPC sender.
    encoder = MsgpackEncoder(oob_tensor_consumer=TensorIpcSender(ipc_queue))

    # Decoding side: the decoder obtains tensors from the IPC receiver.
    decoder = MsgpackDecoder(
        RequestWithTensor, oob_tensor_provider=TensorIpcReceiver(ipc_queue)
    )

    # Build the request around a random float32 tensor.
    original_tensor = torch.randn(5, 10, dtype=torch.float32)
    request = RequestWithTensor(prompt_embeds=original_tensor, data="test_data")

    # Encoding must yield at least one buffer.
    encoded = encoder.encode(request)
    assert len(encoded) > 0

    # Decoding is the step that used to fail (see docstring).
    decoded = decoder.decode(encoded)

    # The decoded request has to mirror what was sent.
    assert isinstance(decoded, RequestWithTensor)
    assert decoded.data == "test_data"
    restored = decoded.prompt_embeds
    assert restored is not None
    assert torch.allclose(restored, original_tensor), (
        "Decoded tensor does not match the original tensor."
    )


def test_non_multimodal_tensor_with_ipc_none_value():
    """Check that a ``None`` tensor field survives a round trip with IPC on."""
    import torch.multiprocessing as torch_mp

    from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender

    # A single queue is shared by the sending and the receiving side.
    ipc_queue = torch_mp.Queue()

    # Encoder wired to the IPC sender, decoder wired to the IPC receiver.
    encoder = MsgpackEncoder(oob_tensor_consumer=TensorIpcSender(ipc_queue))
    decoder = MsgpackDecoder(
        RequestWithTensor, oob_tensor_provider=TensorIpcReceiver(ipc_queue)
    )

    # The tensor field is deliberately left empty.
    request = RequestWithTensor(prompt_embeds=None, data="test_data_with_none")

    # Encode, then immediately decode again.
    decoded = decoder.decode(encoder.encode(request))

    # The decoded request has to mirror what was sent.
    assert isinstance(decoded, RequestWithTensor)
    assert decoded.data == "test_data_with_none"
    assert decoded.prompt_embeds is None


def test_multiple_senders_single_receiver_ipc():
    """Exercise several senders feeding one receiver through a shared queue.

    Mirrors the vLLM topology in which multiple API server frontends each own
    a MsgpackEncoder + TensorIpcSender, all putting tensors onto the same
    torch.mp queue, while a single engine core decodes them with one
    MsgpackDecoder + TensorIpcReceiver.
    """
    import torch.multiprocessing as torch_mp

    from vllm.v1.engine.tensor_ipc import TensorIpcReceiver, TensorIpcSender

    num_senders = 3
    num_messages_per_sender = 2
    tensor_queue = torch_mp.Queue()

    # Build N independent sender/encoder pairs (each sender gets its own
    # uuid-based sender_id).
    senders = []
    encoders = []
    for _ in range(num_senders):
        ipc_sender = TensorIpcSender(tensor_queue)
        senders.append(ipc_sender)
        encoders.append(MsgpackEncoder(oob_tensor_consumer=ipc_sender))

    # One receiver/decoder serves all senders.
    decoder = MsgpackDecoder(
        RequestWithTensor, oob_tensor_provider=TensorIpcReceiver(tensor_queue)
    )

    # Encode round by round so tensors from different senders end up
    # interleaved on the shared queue.
    encoded_payloads: list[tuple[int, int, torch.Tensor, list]] = []
    for msg_idx in range(num_messages_per_sender):
        for sender_idx, sender_encoder in enumerate(encoders):
            # Shape and fill value are unique per (sender, message) pair.
            shape = (sender_idx + 1, msg_idx + 2)
            fill_value = float(sender_idx * 100 + msg_idx)
            tensor = torch.full(shape, fill_value, dtype=torch.float32)
            req = RequestWithTensor(
                prompt_embeds=tensor,
                data=f"s{sender_idx}_m{msg_idx}",
            )
            encoded_payloads.append(
                (sender_idx, msg_idx, tensor, sender_encoder.encode(req))
            )

    # Decode everything in send order; the receiver has to pair every tensor
    # handle with the matching TensorIpcData from the shared queue.
    for sender_idx, msg_idx, original_tensor, encoded in encoded_payloads:
        decoded = decoder.decode(encoded)
        assert isinstance(decoded, RequestWithTensor)
        assert decoded.data == f"s{sender_idx}_m{msg_idx}"
        assert decoded.prompt_embeds is not None
        assert decoded.prompt_embeds.shape == original_tensor.shape, (
            f"Shape mismatch for sender {sender_idx} msg {msg_idx}: "
            f"{decoded.prompt_embeds.shape} != {original_tensor.shape}"
        )
        assert torch.allclose(decoded.prompt_embeds, original_tensor), (
            f"Value mismatch for sender {sender_idx} msg {msg_idx}"
        )