Unverified commit 3b523e38, authored by Lukas Geiger, committed by GitHub
Browse files

[Core] Do not copy array during hashing (#19484)


Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
parent 16c16301
...@@ -60,3 +60,15 @@ def test_hash_collision_array_shape(): ...@@ -60,3 +60,15 @@ def test_hash_collision_array_shape():
hasher = MultiModalHasher hasher = MultiModalHasher
assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2) assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2)
def test_hash_non_contiguous_array():
    """Non-contiguous arrays hash the same as their contiguous copy."""
    # Transposing a C-ordered 2-D array gives a view that is not
    # C-contiguous, which is the layout this test wants to exercise.
    transposed = np.arange(24).reshape(4, 6).T
    assert not transposed.flags.c_contiguous

    # Same values, but laid out contiguously in memory.
    contiguous = np.ascontiguousarray(transposed)
    assert contiguous.flags.c_contiguous

    hasher = MultiModalHasher
    # Both should be hashable and produce the same hashes
    hash_transposed = hasher.hash_kwargs(data=transposed)
    hash_contiguous = hasher.hash_kwargs(data=contiguous)
    assert hash_transposed == hash_contiguous
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import pickle import pickle
from collections.abc import Iterable, Mapping from collections.abc import Iterable, Mapping
from typing import Union
import numpy as np import numpy as np
import torch import torch
...@@ -23,11 +24,11 @@ A dictionary containing hashes for items in each modality. ...@@ -23,11 +24,11 @@ A dictionary containing hashes for items in each modality.
class MultiModalHasher: class MultiModalHasher:
@classmethod @classmethod
def serialize_item(cls, obj: object) -> bytes: def serialize_item(cls, obj: object) -> Union[bytes, memoryview]:
# Simple cases # Simple cases
if isinstance(obj, str): if isinstance(obj, str):
return obj.encode("utf-8") return obj.encode("utf-8")
if isinstance(obj, bytes): if isinstance(obj, (bytes, memoryview)):
return obj return obj
if isinstance(obj, (int, float)): if isinstance(obj, (int, float)):
return np.array(obj).tobytes() return np.array(obj).tobytes()
...@@ -38,12 +39,13 @@ class MultiModalHasher: ...@@ -38,12 +39,13 @@ class MultiModalHasher:
if isinstance(obj, torch.Tensor): if isinstance(obj, torch.Tensor):
return cls.item_to_bytes("tensor", obj.numpy()) return cls.item_to_bytes("tensor", obj.numpy())
if isinstance(obj, np.ndarray): if isinstance(obj, np.ndarray):
return cls.item_to_bytes( # If the array is non-contiguous, we need to copy it first
"ndarray", { arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
"dtype": obj.dtype.str, return cls.item_to_bytes("ndarray", {
"shape": obj.shape, "dtype": obj.dtype.str,
"data": obj.tobytes(), "shape": obj.shape,
}) "data": arr_data,
})
logger.warning( logger.warning(
"No serialization method found for %s. " "No serialization method found for %s. "
...@@ -64,7 +66,7 @@ class MultiModalHasher: ...@@ -64,7 +66,7 @@ class MultiModalHasher:
cls, cls,
key: str, key: str,
obj: object, obj: object,
) -> Iterable[tuple[bytes, bytes]]: ) -> Iterable[tuple[bytes, Union[bytes, memoryview]]]:
# Recursive cases # Recursive cases
if isinstance(obj, (list, tuple)): if isinstance(obj, (list, tuple)):
for i, elem in enumerate(obj): for i, elem in enumerate(obj):
...@@ -73,7 +75,7 @@ class MultiModalHasher: ...@@ -73,7 +75,7 @@ class MultiModalHasher:
for k, v in obj.items(): for k, v in obj.items():
yield from cls.iter_item_to_bytes(f"{key}.{k}", v) yield from cls.iter_item_to_bytes(f"{key}.{k}", v)
else: else:
key_bytes = cls.serialize_item(key) key_bytes = key.encode("utf-8")
value_bytes = cls.serialize_item(obj) value_bytes = cls.serialize_item(obj)
yield key_bytes, value_bytes yield key_bytes, value_bytes
......
...@@ -140,7 +140,7 @@ class MsgpackEncoder: ...@@ -140,7 +140,7 @@ class MsgpackEncoder:
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]: ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
assert self.aux_buffers is not None assert self.aux_buffers is not None
# If the array is non-contiguous, we need to copy it first # If the array is non-contiguous, we need to copy it first
arr_data = obj.data if obj.data.c_contiguous else obj.tobytes() arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
if not obj.shape or obj.nbytes < self.size_threshold: if not obj.shape or obj.nbytes < self.size_threshold:
# Encode small arrays and scalars inline. Using this extension type # Encode small arrays and scalars inline. Using this extension type
# ensures we can avoid copying when decoding. # ensures we can avoid copying when decoding.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment