Unverified Commit 9715f7bb authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Bugfix] Fix incorrect original shape in hashing (#23672)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: default avatarLukas Geiger <lukas.geiger94@gmail.com>
parent 98aa16ff
......@@ -45,10 +45,11 @@ def test_hash_collision_image_transpose():
assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2)
def test_hash_collision_tensor_shape():
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
def test_hash_collision_tensor_shape(dtype):
# The hash should be different though the data is the same when flattened
arr1 = torch.zeros((5, 10, 20, 3))
arr2 = torch.zeros((10, 20, 5, 3))
arr1 = torch.zeros((5, 10, 20, 3), dtype=dtype)
arr2 = torch.zeros((10, 20, 5, 3), dtype=dtype)
hasher = MultiModalHasher
assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2)
......
......@@ -45,16 +45,22 @@ class MultiModalHasher:
if isinstance(obj, torch.Tensor):
tensor_obj: torch.Tensor = obj.cpu()
tensor_dtype = tensor_obj.dtype
tensor_shape = tensor_obj.shape
# NumPy does not support bfloat16.
# Workaround: View the tensor as a contiguous 1D array of bytes
if tensor_dtype == torch.bfloat16:
tensor_obj = tensor_obj.contiguous()
tensor_obj = tensor_obj.view(
(tensor_obj.numel(), )).view(torch.uint8)
return cls.item_to_bytes(
"tensor", {
"original_dtype": str(tensor_dtype),
"original_shape": tuple(tensor_obj.shape),
"data": tensor_obj.numpy()
"original_shape": tuple(tensor_shape),
"data": tensor_obj.numpy(),
})
return cls.item_to_bytes("tensor", tensor_obj.numpy())
if isinstance(obj, np.ndarray):
# If the array is non-contiguous, we need to copy it first
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment