Unverified Commit e61b68e0 authored by Danylo Ulianych's avatar Danylo Ulianych Committed by GitHub
Browse files

F.normalize unsqueeze mean & std only for 1-d arrays (#2002)

* F.normalize unsqueeze mean & std if necessary

* added tests to F.normalize for 3d mean & std tensors
parent ae228fef
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
import torchvision.transforms as transforms import torchvision.transforms as transforms
import torchvision.transforms.functional as F import torchvision.transforms.functional as F
from torch._utils_internal import get_file_path_2 from torch._utils_internal import get_file_path_2
from numpy.testing import assert_array_almost_equal
import unittest import unittest
import math import math
import random import random
...@@ -843,6 +844,25 @@ class Tester(unittest.TestCase): ...@@ -843,6 +844,25 @@ class Tester(unittest.TestCase):
# checks that it doesn't crash # checks that it doesn't crash
transforms.functional.normalize(img, mean, std) transforms.functional.normalize(img, mean, std)
def test_normalize_3d_tensor(self):
torch.manual_seed(28)
n_channels = 3
img_size = 10
mean = torch.rand(n_channels)
std = torch.rand(n_channels)
img = torch.rand(n_channels, img_size, img_size)
target = F.normalize(img, mean, std).numpy()
mean_unsqueezed = mean.view(-1, 1, 1)
std_unsqueezed = std.view(-1, 1, 1)
result1 = F.normalize(img, mean_unsqueezed, std_unsqueezed)
result2 = F.normalize(img,
mean_unsqueezed.repeat(1, img_size, img_size),
std_unsqueezed.repeat(1, img_size, img_size))
assert_array_almost_equal(target, result1.numpy())
assert_array_almost_equal(target, result2.numpy())
def test_adjust_brightness(self): def test_adjust_brightness(self):
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
......
...@@ -211,7 +211,11 @@ def normalize(tensor, mean, std, inplace=False): ...@@ -211,7 +211,11 @@ def normalize(tensor, mean, std, inplace=False):
std = torch.as_tensor(std, dtype=dtype, device=tensor.device) std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
if (std == 0).any(): if (std == 0).any():
raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype)) raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
tensor.sub_(mean[:, None, None]).div_(std[:, None, None]) if mean.ndim == 1:
mean = mean[:, None, None]
if std.ndim == 1:
std = std[:, None, None]
tensor.sub_(mean).div_(std)
return tensor return tensor
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment