Commit b5d03f59 authored by Alykhan Tejani's avatar Alykhan Tejani
Browse files

added tests for transforms.Normalize

parent 12a37ba7
......@@ -430,6 +430,22 @@ class Tester(unittest.TestCase):
random.setstate(random_state)
assert p_value > 0.0001
@unittest.skipIf(stats is None, 'scipt.stats is not available')
def test_normalize(self):
def samples_from_standard_normal(tensor):
p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue
return p_value > 0.0001
random_state = random.getstate()
random.seed(42)
for channels in [1, 3]:
img = torch.rand(channels, 10, 10)
mean = [img[c].mean() for c in range(channels)]
std = [img[c].std() for c in range(channels)]
normalized = transforms.Normalize(mean, std)(img)
assert samples_from_standard_normal(normalized)
random.setstate(random_state)
def test_adjust_brightness(self):
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment