Unverified commit 947f2f53, authored by Russell Bryant, committed by GitHub
Browse files

[V1] Allow turning off pickle fallback in vllm.v1.serial_utils (#17427)


Signed-off-by: Russell Bryant <rbryant@redhat.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
parent 739e03b3
...@@ -5,6 +5,7 @@ from typing import Optional ...@@ -5,6 +5,7 @@ from typing import Optional
import msgspec import msgspec
import numpy as np import numpy as np
import pytest
import torch import torch
from vllm.multimodal.inputs import (MultiModalBatchedField, from vllm.multimodal.inputs import (MultiModalBatchedField,
...@@ -196,3 +197,100 @@ def assert_equal(obj1: MyType, obj2: MyType): ...@@ -196,3 +197,100 @@ def assert_equal(obj1: MyType, obj2: MyType):
assert torch.equal(obj1.large_non_contig_tensor, assert torch.equal(obj1.large_non_contig_tensor,
obj2.large_non_contig_tensor) obj2.large_non_contig_tensor)
assert torch.equal(obj1.empty_tensor, obj2.empty_tensor) assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
@pytest.mark.parametrize("allow_pickle", [True, False])
def test_dict_serialization(allow_pickle: bool):
    """Round-trip a plain dict through the msgpack encoder/decoder.

    Runs once with ``allow_pickle=True`` and once with
    ``allow_pickle=False``; the decoded dict must equal the input in both
    configurations.
    """
    encoder = MsgpackEncoder(allow_pickle=allow_pickle)
    decoder = MsgpackDecoder(allow_pickle=allow_pickle)

    # Sample payload: a small dict with a string value and an int value.
    obj = {"key": "value", "number": 42}

    # Encode, then immediately decode the resulting buffers.
    decoded = decoder.decode(encoder.encode(obj))

    assert obj == decoded, "Decoded object does not match the original object."
@pytest.mark.parametrize("allow_pickle", [True, False])
def test_tensor_serialization(allow_pickle: bool):
    """Round-trip a torch.Tensor through the msgpack encoder/decoder.

    Runs once with ``allow_pickle=True`` and once with
    ``allow_pickle=False``; the decoder is constructed with
    ``torch.Tensor`` as its target type.
    """
    encoder = MsgpackEncoder(allow_pickle=allow_pickle)
    decoder = MsgpackDecoder(torch.Tensor, allow_pickle=allow_pickle)

    # Create a sample tensor
    tensor = torch.rand(10, 10)

    # Encode the tensor, then decode the resulting buffers
    encoded = encoder.encode(tensor)
    decoded = decoder.decode(encoded)

    # A serialization round trip must reproduce the tensor exactly, so use
    # torch.equal (same size and elements) rather than the tolerance-based
    # torch.allclose, which would let small value corruption go unnoticed.
    # This matches the torch.equal checks in assert_equal in this file.
    assert torch.equal(
        tensor, decoded), "Decoded tensor does not match the original tensor."
@pytest.mark.parametrize("allow_pickle", [True, False])
def test_numpy_array_serialization(allow_pickle: bool):
    """Round-trip a numpy array through the msgpack encoder/decoder.

    Runs once with ``allow_pickle=True`` and once with
    ``allow_pickle=False``; the decoder is constructed with
    ``np.ndarray`` as its target type.
    """
    encoder = MsgpackEncoder(allow_pickle=allow_pickle)
    decoder = MsgpackDecoder(np.ndarray, allow_pickle=allow_pickle)

    # Create a sample numpy array
    array = np.random.rand(10, 10)

    # Encode the numpy array, then decode the resulting buffers
    encoded = encoder.encode(array)
    decoded = decoder.decode(encoded)

    # A serialization round trip must reproduce the array exactly, so check
    # dtype and use np.array_equal (same shape and elements) instead of the
    # tolerance-based np.allclose.
    assert decoded.dtype == array.dtype, (
        "Decoded numpy array dtype does not match the original array.")
    assert np.array_equal(
        array,
        decoded), "Decoded numpy array does not match the original array."
class CustomClass:
    """Plain user-defined class with value-based equality.

    Instances compare equal when both are ``CustomClass`` objects holding
    equal ``value`` attributes.
    """

    # Defining __eq__ already makes instances unhashable; spell it out so the
    # intent is visible to readers.
    __hash__ = None

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Return NotImplemented (not False) for foreign types so Python can
        # try the reflected comparison; the final result is still False when
        # the other operand does not know how to compare either.
        if not isinstance(other, CustomClass):
            return NotImplemented
        return self.value == other.value

    def __repr__(self):
        # Readable representation so assertion failures show the payload
        # instead of an opaque object address.
        return f"{type(self).__name__}(value={self.value!r})"
def test_custom_class_serialization_allowed_with_pickle():
    """A CustomClass instance round-trips when allow_pickle=True.

    Both the encoder and the decoder are built with ``allow_pickle=True``,
    and the decoder is given ``CustomClass`` as its target type.
    """
    encoder = MsgpackEncoder(allow_pickle=True)
    decoder = MsgpackDecoder(CustomClass, allow_pickle=True)

    obj = CustomClass("test_value")

    # Encode the instance and decode the result in one step.
    decoded = decoder.decode(encoder.encode(obj))

    # Equality is defined by CustomClass.__eq__ (type and value must match).
    assert obj == decoded, "Decoded object does not match the original object."
def test_custom_class_serialization_disallowed_without_pickle():
    """Encoding a CustomClass instance raises TypeError when
    allow_pickle=False."""
    strict_encoder = MsgpackEncoder(allow_pickle=False)
    instance = CustomClass("test_value")

    # With the pickle fallback turned off, the encode call must raise
    # TypeError.
    with pytest.raises(TypeError):
        strict_encoder.encode(instance)
...@@ -47,7 +47,9 @@ class MsgpackEncoder: ...@@ -47,7 +47,9 @@ class MsgpackEncoder:
via dedicated messages. Note that this is a per-tensor limit. via dedicated messages. Note that this is a per-tensor limit.
""" """
def __init__(self, size_threshold: Optional[int] = None): def __init__(self,
size_threshold: Optional[int] = None,
allow_pickle: bool = True):
if size_threshold is None: if size_threshold is None:
size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook) self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
...@@ -56,6 +58,7 @@ class MsgpackEncoder: ...@@ -56,6 +58,7 @@ class MsgpackEncoder:
# pass custom data to the hook otherwise. # pass custom data to the hook otherwise.
self.aux_buffers: Optional[list[bytestr]] = None self.aux_buffers: Optional[list[bytestr]] = None
self.size_threshold = size_threshold self.size_threshold = size_threshold
self.allow_pickle = allow_pickle
def encode(self, obj: Any) -> Sequence[bytestr]: def encode(self, obj: Any) -> Sequence[bytestr]:
try: try:
...@@ -105,6 +108,9 @@ class MsgpackEncoder: ...@@ -105,6 +108,9 @@ class MsgpackEncoder:
for itemlist in mm._items_by_modality.values() for itemlist in mm._items_by_modality.values()
for item in itemlist] for item in itemlist]
if not self.allow_pickle:
raise TypeError(f"Object of type {type(obj)} is not serializable")
if isinstance(obj, FunctionType): if isinstance(obj, FunctionType):
# `pickle` is generally faster than cloudpickle, but can have # `pickle` is generally faster than cloudpickle, but can have
# problems serializing methods. # problems serializing methods.
...@@ -179,12 +185,13 @@ class MsgpackDecoder: ...@@ -179,12 +185,13 @@ class MsgpackDecoder:
not thread-safe when encoding tensors / numpy arrays. not thread-safe when encoding tensors / numpy arrays.
""" """
def __init__(self, t: Optional[Any] = None): def __init__(self, t: Optional[Any] = None, allow_pickle: bool = True):
args = () if t is None else (t, ) args = () if t is None else (t, )
self.decoder = msgpack.Decoder(*args, self.decoder = msgpack.Decoder(*args,
ext_hook=self.ext_hook, ext_hook=self.ext_hook,
dec_hook=self.dec_hook) dec_hook=self.dec_hook)
self.aux_buffers: Sequence[bytestr] = () self.aux_buffers: Sequence[bytestr] = ()
self.allow_pickle = allow_pickle
def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any: def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)): if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
...@@ -265,6 +272,8 @@ class MsgpackDecoder: ...@@ -265,6 +272,8 @@ class MsgpackDecoder:
def ext_hook(self, code: int, data: memoryview) -> Any: def ext_hook(self, code: int, data: memoryview) -> Any:
if code == CUSTOM_TYPE_RAW_VIEW: if code == CUSTOM_TYPE_RAW_VIEW:
return data return data
if self.allow_pickle:
if code == CUSTOM_TYPE_PICKLE: if code == CUSTOM_TYPE_PICKLE:
return pickle.loads(data) return pickle.loads(data)
if code == CUSTOM_TYPE_CLOUDPICKLE: if code == CUSTOM_TYPE_CLOUDPICKLE:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment