Unverified Commit 0b1447f8 authored by Lukas Geiger's avatar Lukas Geiger Committed by GitHub
Browse files

[Bugfix] Ensure tensors are contiguous during serialisation (#18860)


Signed-off-by: default avatarLukas Geiger <lukas.geiger94@gmail.com>
parent 24d0ef89
...@@ -158,8 +158,8 @@ class MsgpackEncoder: ...@@ -158,8 +158,8 @@ class MsgpackEncoder:
self, obj: torch.Tensor self, obj: torch.Tensor
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
assert self.aux_buffers is not None assert self.aux_buffers is not None
# view the tensor as a 1D array of bytes # view the tensor as a contiguous 1D array of bytes
arr = obj.flatten().view(torch.uint8).numpy() arr = obj.flatten().contiguous().view(torch.uint8).numpy()
if obj.nbytes < self.size_threshold: if obj.nbytes < self.size_threshold:
# Smaller tensors are encoded inline, just like ndarrays. # Smaller tensors are encoded inline, just like ndarrays.
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data) data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr.data)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment