Unverified commit d73a9457, authored by Lukas Geiger, committed by GitHub
Browse files

[Core] Improve Tensor serialisation (#18774)


Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
parent a3896c7f
...@@ -158,10 +158,8 @@ class MsgpackEncoder: ...@@ -158,10 +158,8 @@ class MsgpackEncoder:
self, obj: torch.Tensor self, obj: torch.Tensor
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
assert self.aux_buffers is not None assert self.aux_buffers is not None
# this creates a copy of the tensor if it's not already contiguous
obj = obj.contiguous()
# view the tensor as a 1D array of bytes # view the tensor as a 1D array of bytes
arr = obj.view((obj.numel(), )).view(torch.uint8).numpy() arr = obj.flatten().view(torch.uint8).numpy()
if obj.nbytes < self.size_threshold: if obj.nbytes < self.size_threshold:
# Smaller tensors are encoded inline, just like ndarrays. # Smaller tensors are encoded inline, just like ndarrays.
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
...@@ -169,7 +167,7 @@ class MsgpackEncoder: ...@@ -169,7 +167,7 @@ class MsgpackEncoder:
# Otherwise encode index of backing buffer to avoid copy. # Otherwise encode index of backing buffer to avoid copy.
data = len(self.aux_buffers) data = len(self.aux_buffers)
self.aux_buffers.append(arr.data) self.aux_buffers.append(arr.data)
dtype = str(obj.dtype)[6:] # remove 'torch.' prefix dtype = str(obj.dtype).removeprefix("torch.")
return dtype, obj.shape, data return dtype, obj.shape, data
def _encode_nested_tensors(self, nt: NestedTensors) -> Any: def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
...@@ -245,7 +243,7 @@ class MsgpackDecoder: ...@@ -245,7 +243,7 @@ class MsgpackDecoder:
# zero-copy decode. We assume the ndarray will not be kept around, # zero-copy decode. We assume the ndarray will not be kept around,
# as it now locks the whole received message buffer in memory. # as it now locks the whole received message buffer in memory.
buffer = self.aux_buffers[data] if isinstance(data, int) else data buffer = self.aux_buffers[data] if isinstance(data, int) else data
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape) return np.frombuffer(buffer, dtype=dtype).reshape(shape)
def _decode_tensor(self, arr: Any) -> torch.Tensor: def _decode_tensor(self, arr: Any) -> torch.Tensor:
dtype, shape, data = arr dtype, shape, data = arr
...@@ -254,12 +252,15 @@ class MsgpackDecoder: ...@@ -254,12 +252,15 @@ class MsgpackDecoder:
# not complain about a readonly memoryview. # not complain about a readonly memoryview.
buffer = self.aux_buffers[data] if isinstance(data, int) \ buffer = self.aux_buffers[data] if isinstance(data, int) \
else bytearray(data) else bytearray(data)
# Create numpy wrapper around the bytes
arr = np.ndarray(buffer=buffer, dtype=np.uint8, shape=(len(buffer), ))
torch_dtype = getattr(torch, dtype) torch_dtype = getattr(torch, dtype)
assert isinstance(torch_dtype, torch.dtype) assert isinstance(torch_dtype, torch.dtype)
if not buffer: # torch.frombuffer doesn't like empty buffers
assert 0 in shape
return torch.empty(shape, dtype=torch_dtype)
# Create uint8 array
arr = torch.frombuffer(buffer, dtype=torch.uint8)
# Convert back to proper shape & type # Convert back to proper shape & type
return torch.from_numpy(arr).view(torch_dtype).view(shape) return arr.view(torch_dtype).view(shape)
def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]: def _decode_mm_items(self, obj: list) -> list[MultiModalKwargsItem]:
decoded_items = [] decoded_items = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment