Commit 68bb2534 authored by Ashok M's avatar Ashok M Committed by Francisco Massa
Browse files

Added test for normalize functionality in make_grid function. (#840)

parent 0d975d6d
...@@ -18,6 +18,23 @@ class Tester(unittest.TestCase): ...@@ -18,6 +18,23 @@ class Tester(unittest.TestCase):
utils.make_grid(t, normalize=True, scale_each=True) utils.make_grid(t, normalize=True, scale_each=True)
assert torch.equal(t, t_clone), 'make_grid modified tensor in-place' assert torch.equal(t, t_clone), 'make_grid modified tensor in-place'
def test_normalize_in_make_grid(self):
t = torch.rand(5, 3, 10, 10) * 255
norm_max = torch.tensor(1.0)
norm_min = torch.tensor(0.0)
grid = utils.make_grid(t, normalize=True)
grid_max = torch.max(grid)
grid_min = torch.min(grid)
# Rounding the result to one decimal for comparison
n_digits = 1
rounded_grid_max = torch.round(grid_max * 10 ** n_digits) / (10 ** n_digits)
rounded_grid_min = torch.round(grid_min * 10 ** n_digits) / (10 ** n_digits)
assert torch.equal(norm_max, rounded_grid_max), 'Normalized max is not equal to 1'
assert torch.equal(norm_min, rounded_grid_min), 'Normalized min is not equal to 0'
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment