Unverified Commit 0e7ae64b authored by DevPranjal's avatar DevPranjal Committed by GitHub
Browse files

Port tests in test_transforms_video.py to pytest (#4040)

parent fb2598b8
import torch import torch
from torchvision.transforms import Compose from torchvision.transforms import Compose
import unittest import pytest
import random import random
import numpy as np import numpy as np
import warnings import warnings
...@@ -17,7 +17,7 @@ with warnings.catch_warnings(record=True): ...@@ -17,7 +17,7 @@ with warnings.catch_warnings(record=True):
import torchvision.transforms._transforms_video as transforms import torchvision.transforms._transforms_video as transforms
class TestVideoTransforms(unittest.TestCase): class TestVideoTransforms():
def test_random_crop_video(self): def test_random_crop_video(self):
numFrames = random.randint(4, 128) numFrames = random.randint(4, 128)
...@@ -30,8 +30,8 @@ class TestVideoTransforms(unittest.TestCase): ...@@ -30,8 +30,8 @@ class TestVideoTransforms(unittest.TestCase):
transforms.ToTensorVideo(), transforms.ToTensorVideo(),
transforms.RandomCropVideo((oheight, owidth)), transforms.RandomCropVideo((oheight, owidth)),
])(clip) ])(clip)
self.assertEqual(result.size(2), oheight) assert result.size(2) == oheight
self.assertEqual(result.size(3), owidth) assert result.size(3) == owidth
transforms.RandomCropVideo((oheight, owidth)).__repr__() transforms.RandomCropVideo((oheight, owidth)).__repr__()
...@@ -46,8 +46,8 @@ class TestVideoTransforms(unittest.TestCase): ...@@ -46,8 +46,8 @@ class TestVideoTransforms(unittest.TestCase):
transforms.ToTensorVideo(), transforms.ToTensorVideo(),
transforms.RandomResizedCropVideo((oheight, owidth)), transforms.RandomResizedCropVideo((oheight, owidth)),
])(clip) ])(clip)
self.assertEqual(result.size(2), oheight) assert result.size(2) == oheight
self.assertEqual(result.size(3), owidth) assert result.size(3) == owidth
transforms.RandomResizedCropVideo((oheight, owidth)).__repr__() transforms.RandomResizedCropVideo((oheight, owidth)).__repr__()
...@@ -70,7 +70,7 @@ class TestVideoTransforms(unittest.TestCase): ...@@ -70,7 +70,7 @@ class TestVideoTransforms(unittest.TestCase):
msg = "height: " + str(height) + " width: " \ msg = "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertEqual(result.sum().item(), 0, msg) assert result.sum().item() == 0, msg
oheight += 1 oheight += 1
owidth += 1 owidth += 1
...@@ -82,7 +82,7 @@ class TestVideoTransforms(unittest.TestCase): ...@@ -82,7 +82,7 @@ class TestVideoTransforms(unittest.TestCase):
msg = "height: " + str(height) + " width: " \ msg = "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertEqual(sum1.item() > 1, True, msg) assert sum1.item() > 1, msg
oheight += 1 oheight += 1
owidth += 1 owidth += 1
...@@ -94,28 +94,29 @@ class TestVideoTransforms(unittest.TestCase): ...@@ -94,28 +94,29 @@ class TestVideoTransforms(unittest.TestCase):
msg = "height: " + str(height) + " width: " \ msg = "height: " + str(height) + " width: " \
+ str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth) + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
self.assertTrue(sum2.item() > 1, msg) assert sum2.item() > 1, msg
self.assertTrue(sum2.item() > sum1.item(), msg) assert sum2.item() > sum1.item(), msg
@unittest.skipIf(stats is None, 'scipy.stats is not available') @pytest.mark.skipif(stats is None, reason='scipy.stats is not available')
def test_normalize_video(self): @pytest.mark.parametrize('channels', [1, 3])
def test_normalize_video(self, channels):
def samples_from_standard_normal(tensor): def samples_from_standard_normal(tensor):
p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue
return p_value > 0.0001 return p_value > 0.0001
random_state = random.getstate() random_state = random.getstate()
random.seed(42) random.seed(42)
for channels in [1, 3]:
numFrames = random.randint(4, 128) numFrames = random.randint(4, 128)
height = random.randint(32, 256) height = random.randint(32, 256)
width = random.randint(32, 256) width = random.randint(32, 256)
mean = random.random() mean = random.random()
std = random.random() std = random.random()
clip = torch.normal(mean, std, size=(channels, numFrames, height, width)) clip = torch.normal(mean, std, size=(channels, numFrames, height, width))
mean = [clip[c].mean().item() for c in range(channels)] mean = [clip[c].mean().item() for c in range(channels)]
std = [clip[c].std().item() for c in range(channels)] std = [clip[c].std().item() for c in range(channels)]
normalized = transforms.NormalizeVideo(mean, std)(clip) normalized = transforms.NormalizeVideo(mean, std)(clip)
self.assertTrue(samples_from_standard_normal(normalized)) assert samples_from_standard_normal(normalized)
random.setstate(random_state) random.setstate(random_state)
# Checking the optional in-place behaviour # Checking the optional in-place behaviour
...@@ -129,11 +130,11 @@ class TestVideoTransforms(unittest.TestCase): ...@@ -129,11 +130,11 @@ class TestVideoTransforms(unittest.TestCase):
numFrames, height, width = 64, 4, 4 numFrames, height, width = 64, 4, 4
trans = transforms.ToTensorVideo() trans = transforms.ToTensorVideo()
with self.assertRaises(TypeError): with pytest.raises(TypeError):
trans(np.random.rand(numFrames, height, width, 1).tolist()) trans(np.random.rand(numFrames, height, width, 1).tolist())
trans(torch.rand((numFrames, height, width, 1), dtype=torch.float)) trans(torch.rand((numFrames, height, width, 1), dtype=torch.float))
with self.assertRaises(ValueError): with pytest.raises(ValueError):
trans(torch.ones((3, numFrames, height, width, 3), dtype=torch.uint8)) trans(torch.ones((3, numFrames, height, width, 3), dtype=torch.uint8))
trans(torch.ones((height, width, 3), dtype=torch.uint8)) trans(torch.ones((height, width, 3), dtype=torch.uint8))
trans(torch.ones((width, 3), dtype=torch.uint8)) trans(torch.ones((width, 3), dtype=torch.uint8))
...@@ -141,7 +142,7 @@ class TestVideoTransforms(unittest.TestCase): ...@@ -141,7 +142,7 @@ class TestVideoTransforms(unittest.TestCase):
trans.__repr__() trans.__repr__()
@unittest.skipIf(stats is None, 'scipy.stats not available') @pytest.mark.skipif(stats is None, reason='scipy.stats not available')
def test_random_horizontal_flip_video(self): def test_random_horizontal_flip_video(self):
random_state = random.getstate() random_state = random.getstate()
random.seed(42) random.seed(42)
...@@ -157,7 +158,7 @@ class TestVideoTransforms(unittest.TestCase): ...@@ -157,7 +158,7 @@ class TestVideoTransforms(unittest.TestCase):
p_value = stats.binom_test(num_horizontal, num_samples, p=0.5) p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
random.setstate(random_state) random.setstate(random_state)
self.assertGreater(p_value, 0.0001) assert p_value > 0.0001
num_samples = 250 num_samples = 250
num_horizontal = 0 num_horizontal = 0
...@@ -168,10 +169,10 @@ class TestVideoTransforms(unittest.TestCase): ...@@ -168,10 +169,10 @@ class TestVideoTransforms(unittest.TestCase):
p_value = stats.binom_test(num_horizontal, num_samples, p=0.7) p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
random.setstate(random_state) random.setstate(random_state)
self.assertGreater(p_value, 0.0001) assert p_value > 0.0001
transforms.RandomHorizontalFlipVideo().__repr__() transforms.RandomHorizontalFlipVideo().__repr__()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() pytest.main([__file__])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment