Commit 2d7c0667 authored by Surgan Jandial's avatar Surgan Jandial Committed by Francisco Massa
Browse files

Test for checking non mutating behaviour of tensor transforms (#1656)

* out_place checks

* lint ups
parent 5c03d593
...@@ -12,20 +12,25 @@ class Tester(unittest.TestCase): ...@@ -12,20 +12,25 @@ class Tester(unittest.TestCase):
def test_vflip(self): def test_vflip(self):
img_tensor = torch.randn(3, 16, 16) img_tensor = torch.randn(3, 16, 16)
img_tensor_clone = img_tensor.clone()
vflipped_img = F_t.vflip(img_tensor) vflipped_img = F_t.vflip(img_tensor)
vflipped_img_again = F_t.vflip(vflipped_img) vflipped_img_again = F_t.vflip(vflipped_img)
self.assertEqual(vflipped_img.shape, img_tensor.shape) self.assertEqual(vflipped_img.shape, img_tensor.shape)
self.assertTrue(torch.equal(img_tensor, vflipped_img_again)) self.assertTrue(torch.equal(img_tensor, vflipped_img_again))
self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
def test_hflip(self): def test_hflip(self):
img_tensor = torch.randn(3, 16, 16) img_tensor = torch.randn(3, 16, 16)
img_tensor_clone = img_tensor.clone()
hflipped_img = F_t.hflip(img_tensor) hflipped_img = F_t.hflip(img_tensor)
hflipped_img_again = F_t.hflip(hflipped_img) hflipped_img_again = F_t.hflip(hflipped_img)
self.assertEqual(hflipped_img.shape, img_tensor.shape) self.assertEqual(hflipped_img.shape, img_tensor.shape)
self.assertTrue(torch.equal(img_tensor, hflipped_img_again)) self.assertTrue(torch.equal(img_tensor, hflipped_img_again))
self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
def test_crop(self): def test_crop(self):
img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8) img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
img_tensor_clone = img_tensor.clone()
top = random.randint(0, 15) top = random.randint(0, 15)
left = random.randint(0, 15) left = random.randint(0, 15)
height = random.randint(1, 16 - top) height = random.randint(1, 16 - top)
...@@ -34,7 +39,7 @@ class Tester(unittest.TestCase): ...@@ -34,7 +39,7 @@ class Tester(unittest.TestCase):
img_PIL = transforms.ToPILImage()(img_tensor) img_PIL = transforms.ToPILImage()(img_tensor)
img_PIL_cropped = F.crop(img_PIL, top, left, height, width) img_PIL_cropped = F.crop(img_PIL, top, left, height, width)
img_cropped_GT = transforms.ToTensor()(img_PIL_cropped) img_cropped_GT = transforms.ToTensor()(img_PIL_cropped)
self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
self.assertTrue(torch.equal(img_cropped, (img_cropped_GT * 255).to(torch.uint8)), self.assertTrue(torch.equal(img_cropped, (img_cropped_GT * 255).to(torch.uint8)),
"functional_tensor crop not working") "functional_tensor crop not working")
...@@ -54,6 +59,7 @@ class Tester(unittest.TestCase): ...@@ -54,6 +59,7 @@ class Tester(unittest.TestCase):
img = torch.randint(0, 256, shape, dtype=torch.uint8) img = torch.randint(0, 256, shape, dtype=torch.uint8)
factor = 3 * torch.rand(1) factor = 3 * torch.rand(1)
img_clone = img.clone()
for f, ft in fns: for f, ft in fns:
ft_img = ft(img, factor) ft_img = ft(img, factor)
...@@ -68,23 +74,29 @@ class Tester(unittest.TestCase): ...@@ -68,23 +74,29 @@ class Tester(unittest.TestCase):
# difference in values caused by (at most 5) truncations. # difference in values caused by (at most 5) truncations.
max_diff = (ft_img - f_img).abs().max() max_diff = (ft_img - f_img).abs().max()
self.assertLess(max_diff, 5 / 255 + 1e-5) self.assertLess(max_diff, 5 / 255 + 1e-5)
self.assertTrue(torch.equal(img, img_clone))
def test_rgb_to_grayscale(self): def test_rgb_to_grayscale(self):
img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8) img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
img_tensor_clone = img_tensor.clone()
grayscale_tensor = F_t.rgb_to_grayscale(img_tensor).to(int) grayscale_tensor = F_t.rgb_to_grayscale(img_tensor).to(int)
grayscale_pil_img = torch.tensor(np.array(F.to_grayscale(F.to_pil_image(img_tensor)))).to(int) grayscale_pil_img = torch.tensor(np.array(F.to_grayscale(F.to_pil_image(img_tensor)))).to(int)
max_diff = (grayscale_tensor - grayscale_pil_img).abs().max() max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
self.assertLess(max_diff, 1.0001) self.assertLess(max_diff, 1.0001)
self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
def test_center_crop(self): def test_center_crop(self):
img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
img_tensor_clone = img_tensor.clone()
cropped_tensor = F_t.center_crop(img_tensor, [10, 10]) cropped_tensor = F_t.center_crop(img_tensor, [10, 10])
cropped_pil_image = F.center_crop(transforms.ToPILImage()(img_tensor), [10, 10]) cropped_pil_image = F.center_crop(transforms.ToPILImage()(img_tensor), [10, 10])
cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8) cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8)
self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor)) self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor))
self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
def test_five_crop(self): def test_five_crop(self):
img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
img_tensor_clone = img_tensor.clone()
cropped_tensor = F_t.five_crop(img_tensor, [10, 10]) cropped_tensor = F_t.five_crop(img_tensor, [10, 10])
cropped_pil_image = F.five_crop(transforms.ToPILImage()(img_tensor), [10, 10]) cropped_pil_image = F.five_crop(transforms.ToPILImage()(img_tensor), [10, 10])
self.assertTrue(torch.equal(cropped_tensor[0], self.assertTrue(torch.equal(cropped_tensor[0],
...@@ -97,9 +109,11 @@ class Tester(unittest.TestCase): ...@@ -97,9 +109,11 @@ class Tester(unittest.TestCase):
(transforms.ToTensor()(cropped_pil_image[3]) * 255).to(torch.uint8))) (transforms.ToTensor()(cropped_pil_image[3]) * 255).to(torch.uint8)))
self.assertTrue(torch.equal(cropped_tensor[4], self.assertTrue(torch.equal(cropped_tensor[4],
(transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8))) (transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8)))
self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
def test_ten_crop(self): def test_ten_crop(self):
img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8) img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
img_tensor_clone = img_tensor.clone()
cropped_tensor = F_t.ten_crop(img_tensor, [10, 10]) cropped_tensor = F_t.ten_crop(img_tensor, [10, 10])
cropped_pil_image = F.ten_crop(transforms.ToPILImage()(img_tensor), [10, 10]) cropped_pil_image = F.ten_crop(transforms.ToPILImage()(img_tensor), [10, 10])
self.assertTrue(torch.equal(cropped_tensor[0], self.assertTrue(torch.equal(cropped_tensor[0],
...@@ -122,6 +136,7 @@ class Tester(unittest.TestCase): ...@@ -122,6 +136,7 @@ class Tester(unittest.TestCase):
(transforms.ToTensor()(cropped_pil_image[8]) * 255).to(torch.uint8))) (transforms.ToTensor()(cropped_pil_image[8]) * 255).to(torch.uint8)))
self.assertTrue(torch.equal(cropped_tensor[9], self.assertTrue(torch.equal(cropped_tensor[9],
(transforms.ToTensor()(cropped_pil_image[9]) * 255).to(torch.uint8))) (transforms.ToTensor()(cropped_pil_image[9]) * 255).to(torch.uint8)))
self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment