Unverified Commit 2694a5d2 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Fix normalize for different dtype than float32 (#1021)

* Fix normalize for different dtype than float32

* Fix lint
parent f052c53f
......@@ -809,6 +809,15 @@ class Tester(unittest.TestCase):
# Checking if Normalize can be printed as string
transforms.Normalize(mean, std).__repr__()
def test_normalize_different_dtype(self):
for dtype1 in [torch.float32, torch.float64]:
img = torch.rand(3, 10, 10, dtype=dtype1)
for dtype2 in [torch.int64, torch.float32, torch.float64]:
mean = torch.tensor([1, 2, 3], dtype=dtype2)
std = torch.tensor([1, 2, 1], dtype=dtype2)
# checks that it doesn't crash
transforms.functional.normalize(img, mean, std)
def test_adjust_brightness(self):
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
......
......@@ -203,8 +203,9 @@ def normalize(tensor, mean, std, inplace=False):
if not inplace:
tensor = tensor.clone()
mean = torch.as_tensor(mean, dtype=torch.float32, device=tensor.device)
std = torch.as_tensor(std, dtype=torch.float32, device=tensor.device)
dtype = tensor.dtype
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
return tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment